Old Python needs `__main__` to call a module
[pylucene.git] / test / test_FilteredQuery.py
1 # ====================================================================
2 #   Licensed under the Apache License, Version 2.0 (the "License");
3 #   you may not use this file except in compliance with the License.
4 #   You may obtain a copy of the License at
5 #
6 #       http://www.apache.org/licenses/LICENSE-2.0
7 #
8 #   Unless required by applicable law or agreed to in writing, software
9 #   distributed under the License is distributed on an "AS IS" BASIS,
10 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 #   See the License for the specific language governing permissions and
12 #   limitations under the License.
13 # ====================================================================
14
15 from unittest import TestCase, main
16 from lucene import *
17
18
class FilteredQueryTestCase(TestCase):
    """
    Unit tests ported from Java Lucene.

    Every test runs against the four-document index built in setUp() and a
    filter that only accepts document ids 1 and 3.
    """

    # (field, sorter) values of the indexed documents.  They are added in this
    # order to a freshly created index, and the assertions below assume they
    # receive document ids 0..3 accordingly.
    _DOCUMENTS = (
        ("one two three four five", "b"),   # doc 0
        ("one two three four", "d"),        # doc 1
        ("one two three y", "a"),           # doc 2
        ("one two x", "c"),                 # doc 3
    )

    def setUp(self):
        """Index the sample documents and build the shared query and filter."""

        self.directory = RAMDirectory()
        writer = IndexWriter(self.directory, WhitespaceAnalyzer(), True,
                             IndexWriter.MaxFieldLength.LIMITED)
        try:
            for text, sorter in self._DOCUMENTS:
                doc = Document()
                doc.add(Field("field", text,
                              Field.Store.YES, Field.Index.ANALYZED))
                doc.add(Field("sorter", sorter,
                              Field.Store.YES, Field.Index.ANALYZED))
                writer.addDocument(doc)
            writer.optimize()
        finally:
            # Release the writer even if indexing fails part-way.
            writer.close()

        self.searcher = IndexSearcher(self.directory, True)
        self.query = TermQuery(Term("field", "three"))

        # NOTE(review): kept local to setUp as in the original; presumably so
        # the PythonFilter subclass is only created after lucene.initVM().
        class DocsOneAndThreeFilter(PythonFilter):
            """Accept only document ids 1 and 3, regardless of the reader."""

            def getDocIdSet(self, reader):
                bits = BitSet(5)
                bits.set(1)
                bits.set(3)
                return DocIdBitSet(bits)

        self.filter = DocsOneAndThreeFilter()

    def tearDown(self):
        """Close the searcher and the in-memory directory built in setUp."""

        self.searcher.close()
        self.directory.close()

    def testFilteredQuery(self):
        """FilteredQuery must only return hits that the filter accepts."""

        # "three" occurs in docs 0, 1 and 2; the filter leaves only doc 1.
        filteredquery = FilteredQuery(self.query, self.filter)
        topDocs = self.searcher.search(filteredquery, 50)
        self.assertEqual(1, topDocs.totalHits)
        self.assertEqual(1, topDocs.scoreDocs[0].doc)

        # Same query with a sort on "sorter": filtering must still apply.
        topDocs = self.searcher.search(filteredquery, None, 50,
                                       Sort(SortField("sorter",
                                                      SortField.STRING)))
        self.assertEqual(1, topDocs.totalHits)
        self.assertEqual(1, topDocs.scoreDocs[0].doc)

        # "one" occurs in every doc; the filter keeps docs 1 and 3.
        filteredquery = FilteredQuery(TermQuery(Term("field", "one")),
                                      self.filter)
        topDocs = self.searcher.search(filteredquery, 50)
        self.assertEqual(2, topDocs.totalHits)

        # "x" occurs only in doc 3, which the filter accepts.
        filteredquery = FilteredQuery(TermQuery(Term("field", "x")),
                                      self.filter)
        topDocs = self.searcher.search(filteredquery, 50)
        self.assertEqual(1, topDocs.totalHits)
        self.assertEqual(3, topDocs.scoreDocs[0].doc)

        # "y" occurs only in doc 2, which the filter rejects.
        filteredquery = FilteredQuery(TermQuery(Term("field", "y")),
                                      self.filter)
        topDocs = self.searcher.search(filteredquery, 50)
        self.assertEqual(0, topDocs.totalHits)

    def testRangeQuery(self):
        """
        This tests FilteredQuery's rewrite correctness
        """

        # sorter in ["b", "d"] matches docs 0 ("b"), 1 ("d") and 3 ("c");
        # the filter keeps docs 1 and 3.
        rq = TermRangeQuery("sorter", "b", "d", True, True)
        filteredquery = FilteredQuery(rq, self.filter)
        scoreDocs = self.searcher.search(filteredquery, 1000).scoreDocs
        self.assertEqual(2, len(scoreDocs))
120
121
if __name__ == "__main__":
    import sys
    import traceback
    import lucene
    lucene.initVM()
    if '-loop' in sys.argv:
        # Re-run the whole suite forever (presumably for leak/stress testing).
        sys.argv.remove('-loop')
        while True:
            try:
                main()
            except SystemExit:
                # unittest.main() ends every run with sys.exit(); that is the
                # normal end of one iteration.  Catching only this (rather
                # than a bare except) lets Ctrl-C (KeyboardInterrupt) stop
                # the loop.
                pass
            except Exception:
                # Keep looping on unexpected errors, but do not hide them.
                traceback.print_exc()
    else:
        main()