Old Python needs `__main__` to call a module.
[pylucene.git] / test / test_PositionIncrement.py
1 # ====================================================================
2 #   Licensed under the Apache License, Version 2.0 (the "License");
3 #   you may not use this file except in compliance with the License.
4 #   You may obtain a copy of the License at
5 #
6 #       http://www.apache.org/licenses/LICENSE-2.0
7 #
8 #   Unless required by applicable law or agreed to in writing, software
9 #   distributed under the License is distributed on an "AS IS" BASIS,
10 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 #   See the License for the specific language governing permissions and
12 #   limitations under the License.
13 # ====================================================================
14
15 from unittest import TestCase, main
16 from lucene import *
17
18
class PositionIncrementTestCase(TestCase):
    """
    Unit tests ported from Java Lucene, covering how token position
    increments end up in the index and how phrase, span and parsed
    queries honour them.
    """

    def testSetPosition(self):
        """Index a fixed token stream with custom increments, then check
        term positions, phrase queries and QueryParser behaviour."""

        class _analyzer(PythonAnalyzer):
            # Ignores the field text: every call yields the tokens "1".."5"
            # with the position increments listed in INCREMENTS.
            def tokenStream(_self, fieldName, reader):
                class _tokenStream(PythonTokenStream):
                    def __init__(self_):
                        super(_tokenStream, self_).__init__()
                        self_.TOKENS = ["1", "2", "3", "4", "5"]
                        self_.INCREMENTS = [1, 2, 1, 0, 1]
                        self_.i = 0
                        self_.posIncrAtt = self_.addAttribute(PositionIncrementAttribute.class_)
                        self_.termAtt = self_.addAttribute(TermAttribute.class_)
                        self_.offsetAtt = self_.addAttribute(OffsetAttribute.class_)
                    def incrementToken(self_):
                        if self_.i == len(self_.TOKENS):
                            return False
                        self_.termAtt.setTermBuffer(self_.TOKENS[self_.i])
                        self_.offsetAtt.setOffset(self_.i, self_.i)
                        self_.posIncrAtt.setPositionIncrement(self_.INCREMENTS[self_.i])
                        self_.i += 1
                        return True
                    def end(self_):
                        pass
                    def reset(self_):
                        # rewind so a reset stream replays all of its tokens
                        self_.i = 0
                    def close(self_):
                        pass
                return _tokenStream()

        store = RAMDirectory()
        try:
            writer = IndexWriter(store, _analyzer(), True,
                                 IndexWriter.MaxFieldLength.LIMITED)
            try:
                d = Document()
                d.add(Field("field", "bogus",
                            Field.Store.YES, Field.Index.ANALYZED))
                writer.addDocument(d)
                writer.optimize()
            finally:
                writer.close()

            searcher = IndexSearcher(store, True)
            try:
                self._checkTermPositions(searcher)
                self._checkPhraseQueries(searcher)
                self._checkQueryParserIncrements(searcher)
            finally:
                # closed even when an assertion fails, so repeated runs
                # (see the -loop option below) do not accumulate indexes
                searcher.close()
        finally:
            store.close()

    def _hitCount(self, searcher, query):
        """Return how many hits *searcher* returns for *query*, collecting
        at most 1000."""
        return len(searcher.search(query, None, 1000).scoreDocs)

    def _assertPhraseHits(self, searcher, expected, texts, positions=None):
        """Assert that a PhraseQuery over "field" matches *expected* docs.

        texts: term texts, added to the query in order.
        positions: optional explicit position for each text; when omitted,
            each term is added with PhraseQuery.add(Term) and no position.
        """
        q = PhraseQuery()
        if positions is None:
            for text in texts:
                q.add(Term("field", text))
        else:
            for text, position in zip(texts, positions):
                q.add(Term("field", text), position)
        self.assertEqual(expected, self._hitCount(searcher, q))

    def _checkTermPositions(self, searcher):
        """Assert the indexed positions of the first two tokens."""
        reader = searcher.getIndexReader()

        pos = reader.termPositions(Term("field", "1"))
        self.assertTrue(pos.next())
        # first token should be at position 0
        self.assertEqual(0, pos.nextPosition())

        pos = reader.termPositions(Term("field", "2"))
        self.assertTrue(pos.next())
        # second token should be at position 2
        self.assertEqual(2, pos.nextPosition())

    def _checkPhraseQueries(self, searcher):
        """Assert which PhraseQuery/MultiPhraseQuery matches the indexed
        positions allow."""
        # With INCREMENTS [1, 2, 1, 0, 1] the tokens sit at positions
        # "1"=0, "2"=2, "3"=3, "4"=3, "5"=4.
        self._assertPhraseHits(searcher, 0, ["1", "2"])

        # same as previous, just specify positions explicitly.
        self._assertPhraseHits(searcher, 0, ["1", "2"], [0, 1])

        # specifying correct positions should find the phrase.
        self._assertPhraseHits(searcher, 1, ["1", "2"], [0, 2])

        self._assertPhraseHits(searcher, 1, ["2", "3"])
        self._assertPhraseHits(searcher, 0, ["3", "4"])

        # phrase query would find it when correct positions are specified.
        self._assertPhraseHits(searcher, 1, ["3", "4"], [0, 0])

        # phrase query should fail for non existing searched term
        # even if there exist another searched terms in the same searched
        # position.
        self._assertPhraseHits(searcher, 0, ["3", "9"], [0, 0])

        # multi-phrase query should succeed for non existing searched term
        # because there exist another searched terms in the same searched
        # position.
        mq = MultiPhraseQuery()
        mq.add([Term("field", "3"), Term("field", "9")], 0)
        self.assertEqual(1, self._hitCount(searcher, mq))

        self._assertPhraseHits(searcher, 1, ["2", "4"])
        self._assertPhraseHits(searcher, 1, ["3", "5"])
        self._assertPhraseHits(searcher, 1, ["4", "5"])
        self._assertPhraseHits(searcher, 0, ["2", "5"])

    def _checkQueryParserIncrements(self, searcher):
        """Assert that the parsed phrase "1 stop 2" matches only when both
        QueryParser and StopWhitespaceAnalyzer enable position increments."""
        # should not find "1 2" because there is a gap of 1 in the index
        qp = QueryParser(Version.LUCENE_CURRENT, "field",
                         StopWhitespaceAnalyzer(False))
        q = PhraseQuery.cast_(qp.parse("\"1 2\""))
        self.assertEqual(0, self._hitCount(searcher, q))

        # omitted stop word cannot help because stop filter swallows the
        # increments.
        q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
        self.assertEqual(0, self._hitCount(searcher, q))

        # query parser alone won't help, because stop filter swallows the
        # increments.
        qp.setEnablePositionIncrements(True)
        q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
        self.assertEqual(0, self._hitCount(searcher, q))

        # stop filter alone won't help, because query parser swallows the
        # increments.
        qp.setEnablePositionIncrements(False)
        q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
        self.assertEqual(0, self._hitCount(searcher, q))

        # when both qp and stopFilter propagate increments, we should find
        # the doc.
        qp = QueryParser(Version.LUCENE_CURRENT, "field",
                         StopWhitespaceAnalyzer(True))
        qp.setEnablePositionIncrements(True)
        q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
        self.assertEqual(1, self._hitCount(searcher, q))

    def testPayloadsPos0(self):
        """Index text through TestPayloadAnalyzer, whose first token has a
        position increment of 0, and check that positions, spans and
        payloads all still start at position 0."""

        directory = RAMDirectory()
        writer = IndexWriter(directory, TestPayloadAnalyzer(), True,
                             IndexWriter.MaxFieldLength.LIMITED)
        r = None
        try:
            doc = Document()
            doc.add(Field("content",
                          StringReader("a a b c d e a f g h i j a b k k")))
            writer.addDocument(doc)

            r = writer.getReader()

            tp = r.termPositions(Term("content", "a"))
            self.assertTrue(tp.next())
            # "a" occurs 4 times
            self.assertEqual(4, tp.freq())

            # PayloadFilter alternates increments 0, 1, 0, 1, ... starting
            # with 0, so these are the positions of the four "a"s.
            self.assertEqual(0, tp.nextPosition())
            self.assertEqual(1, tp.nextPosition())
            self.assertEqual(3, tp.nextPosition())
            self.assertEqual(6, tp.nextPosition())

            # only one doc has "a"
            self.assertTrue(not tp.next())

            stq1 = SpanTermQuery(Term("content", "a"))
            stq2 = SpanTermQuery(Term("content", "k"))
            snq = SpanNearQuery([stq1, stq2], 30, False)

            # total number of payloads over all matching spans
            count = 0
            sawZero = False
            pspans = snq.getSpans(r)
            while pspans.next():
                payloads = pspans.getPayload()
                sawZero |= pspans.start() == 0
                count += payloads.size()

            self.assertEqual(5, count)
            self.assertTrue(sawZero)

            # number of matching spans
            count = 0
            sawZero = False
            spans = snq.getSpans(r)
            while spans.next():
                count += 1
                sawZero |= spans.start() == 0

            self.assertEqual(4, count)
            self.assertTrue(sawZero)

            # payloads as reported by PayloadSpanUtil
            sawZero = False
            psu = PayloadSpanUtil(r)
            pls = psu.getPayloadsForQuery(snq)
            count = pls.size()
            it = pls.iterator()
            while it.hasNext():
                payload = JArray('byte').cast_(it.next())
                sawZero |= payload.string_ == "pos: 0"

            self.assertEqual(5, count)
            self.assertTrue(sawZero)
        finally:
            # same close order as before: writer, then reader, then directory
            writer.close()
            if r is not None:
                r.close()
            directory.close()
270
271
class StopWhitespaceAnalyzer(PythonAnalyzer):
    """Whitespace-tokenizing analyzer that drops the word "stop".

    enablePositionIncrements is passed straight to StopFilter. The tests in
    this file use it to choose whether the gap left by a removed "stop" is
    kept in the positions of the following tokens.
    """

    def __init__(self, enablePositionIncrements):
        super(StopWhitespaceAnalyzer, self).__init__()

        self.enablePositionIncrements = enablePositionIncrements
        self.a = WhitespaceAnalyzer()

    def tokenStream(self, fieldName, reader):
        """Return WhitespaceAnalyzer's stream for *reader*, wrapped in a
        StopFilter whose only stop word is "stop"."""

        ts = self.a.tokenStream(fieldName, reader)
        stopWords = HashSet()
        stopWords.add("stop")

        return StopFilter(self.enablePositionIncrements, ts, stopWords)
287
288
class TestPayloadAnalyzer(PythonAnalyzer):
    """Analyzer that runs LowerCaseTokenizer output through PayloadFilter."""

    def tokenStream(self, fieldName, reader):
        """Wrap a LowerCaseTokenizer over *reader* in a PayloadFilter."""
        return PayloadFilter(LowerCaseTokenizer(reader), fieldName)
295
296
class PayloadFilter(PythonTokenFilter):
    """Token filter that attaches a "pos: N" payload to every token and
    alternates position increments 0, 1, 0, 1, ... starting with 0.

    N is the sum of the increments this filter assigned to earlier tokens
    (the value of self.pos when the token arrives).
    """

    def __init__(self, input, fieldName):
        super(PayloadFilter, self).__init__(input)
        self.input = input

        self.fieldName = fieldName
        self.pos = 0  # sum of increments assigned so far; written to payloads
        self.i = 0    # number of tokens decorated so far
        self.posIncrAttr = input.addAttribute(PositionIncrementAttribute.class_)
        self.payloadAttr = input.addAttribute(PayloadAttribute.class_)
        self.termAttr = input.addAttribute(TermAttribute.class_)

    def incrementToken(self):
        """Advance the wrapped stream and decorate the new token.

        Returns False, leaving state untouched, once the input is exhausted.
        """
        if not self.input.incrementToken():
            return False

        payload = JArray('byte')("pos: %d" % self.pos)
        self.payloadAttr.setPayload(Payload(payload))

        # even-numbered tokens (0th, 2nd, ...) get 0, odd-numbered ones get 1
        posIncr = self.i % 2

        self.posIncrAttr.setPositionIncrement(posIncr)
        self.pos += posIncr
        self.i += 1
        return True
327
328
if __name__ == "__main__":
    import sys
    import lucene
    lucene.initVM()
    if '-loop' in sys.argv:
        # Re-run the suite indefinitely. unittest's main() finishes by
        # raising SystemExit, and test failures must not end the loop, so
        # everything is swallowed -- except Ctrl-C, which has to propagate
        # or the loop could never be interrupted.
        sys.argv.remove('-loop')
        while True:
            try:
                main()
            except KeyboardInterrupt:
                raise
            except:
                pass
    else:
        main()