1 # ====================================================================
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
6 # http://www.apache.org/licenses/LICENSE-2.0
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 # ====================================================================
15 from unittest import TestCase, main
# Test case checking how token position increments affect term positions,
# phrase / multi-phrase queries, query-parser phrases and span payloads.
#
# NOTE(review): this listing has lost its indentation, carries a stray number at
# the start of every line, and is missing statements between some visible lines
# (nothing visible initialises self_.i, d, q, dir, doc, expected, sqs, count,
# sawZero or s, and no loop headers are shown). The comments below describe only
# what the visible lines establish; anything else is marked as an assumption.
19 class PositionIncrementTestCase(TestCase):
# Presumably the class docstring text; its quote delimiters are not visible here.
21 Unit tests ported from Java Lucene
# Indexes one field through a custom token stream with explicit position
# increments, then checks term positions and a series of phrase queries.
24 def testSetPosition(self):
# Analyzer local to this test; its tokenStream defines the token stream class
# below (the statement returning an instance is not visible).
26 class _analyzer(PythonAnalyzer):
27 def tokenStream(_self, fieldName, reader):
28 class _tokenStream(PythonTokenStream):
# The following lines presumably form the stream's constructor (header not
# visible): fixed terms and the position increment to report for each one.
30 super(_tokenStream, self_).__init__()
31 self_.TOKENS = ["1", "2", "3", "4", "5"]
# Judging by the assertions further down, these increments appear to put the
# five terms at positions 0, 2, 3, 3, 4 -- TODO confirm.
32 self_.INCREMENTS = [1, 2, 1, 0, 1]
# Attributes written by incrementToken for every emitted token.
34 self_.posIncrAtt = self_.addAttribute(PositionIncrementAttribute.class_)
35 self_.termAtt = self_.addAttribute(TermAttribute.class_)
36 self_.offsetAtt = self_.addAttribute(OffsetAttribute.class_)
# Copies TOKENS[i], offset (i, i) and INCREMENTS[i] into the attributes.
# The end-of-stream branch body, the advance of self_.i and the return
# values are not visible in this listing.
37 def incrementToken(self_):
38 if self_.i == len(self_.TOKENS):
40 self_.termAtt.setTermBuffer(self_.TOKENS[self_.i])
41 self_.offsetAtt.setOffset(self_.i, self_.i)
42 self_.posIncrAtt.setPositionIncrement(self_.INCREMENTS[self_.i])
53 analyzer = _analyzer()
# Create a writer over a RAMDirectory that analyzes with the custom analyzer.
55 store = RAMDirectory()
56 writer = IndexWriter(store, analyzer, True,
57 IndexWriter.MaxFieldLength.LIMITED)
# The field text is "bogus"; presumably irrelevant because the custom stream
# ignores its reader -- verify. Creation of d, adding it to the writer and
# closing the writer are not visible.
59 d.add(Field("field", "bogus",
60 Field.Store.YES, Field.Index.ANALYZED))
65 searcher = IndexSearcher(store, True)
67 pos = searcher.getIndexReader().termPositions(Term("field", "1"))
69 # first token should be at position 0
70 self.assertEqual(0, pos.nextPosition())
72 pos = searcher.getIndexReader().termPositions(Term("field", "2"))
74 # second token should be at position 2
75 self.assertEqual(2, pos.nextPosition())
# From here on q is presumably a fresh PhraseQuery before each group of
# q.add calls; its construction is not visible.
# "1" then "2" without explicit positions: expected to match nothing.
78 q.add(Term("field", "1"))
79 q.add(Term("field", "2"))
80 hits = searcher.search(q, None, 1000).scoreDocs
81 self.assertEqual(0, len(hits))
83 # same as previous, just specify positions explicitly.
85 q.add(Term("field", "1"), 0)
86 q.add(Term("field", "2"), 1)
87 hits = searcher.search(q, None, 1000).scoreDocs
88 self.assertEqual(0, len(hits))
90 # specifying correct positions should find the phrase.
92 q.add(Term("field", "1"), 0)
93 q.add(Term("field", "2"), 2)
94 hits = searcher.search(q, None, 1000).scoreDocs
95 self.assertEqual(1, len(hits))
# "2" then "3": expected to match the document.
98 q.add(Term("field", "2"))
99 q.add(Term("field", "3"))
100 hits = searcher.search(q, None, 1000).scoreDocs
101 self.assertEqual(1, len(hits))
# "3" then "4" without explicit positions: expected to match nothing
# ("4" is emitted with a position increment of 0).
104 q.add(Term("field", "3"))
105 q.add(Term("field", "4"))
106 hits = searcher.search(q, None, 1000).scoreDocs
107 self.assertEqual(0, len(hits))
109 # phrase query would find it when correct positions are specified.
111 q.add(Term("field", "3"), 0)
112 q.add(Term("field", "4"), 0)
113 hits = searcher.search(q, None, 1000).scoreDocs
114 self.assertEqual(1, len(hits))
116 # phrase query should fail for non existing searched term
117 # even if other searched terms exist at the same searched position.
120 q.add(Term("field", "3"), 0)
121 q.add(Term("field", "9"), 0)
122 hits = searcher.search(q, None, 1000).scoreDocs
123 self.assertEqual(0, len(hits))
125 # multi-phrase query should succeed for non existing searched term
126 # because another searched term exists at the same searched position.
129 mq = MultiPhraseQuery()
130 mq.add([Term("field", "3"), Term("field", "9")], 0)
131 hits = searcher.search(mq, None, 1000).scoreDocs
132 self.assertEqual(1, len(hits))
# "2" then "4": expected to match.
135 q.add(Term("field", "2"))
136 q.add(Term("field", "4"))
137 hits = searcher.search(q, None, 1000).scoreDocs
138 self.assertEqual(1, len(hits))
# "3" then "5": expected to match.
141 q.add(Term("field", "3"))
142 q.add(Term("field", "5"))
143 hits = searcher.search(q, None, 1000).scoreDocs
144 self.assertEqual(1, len(hits))
# "4" then "5": expected to match.
147 q.add(Term("field", "4"))
148 q.add(Term("field", "5"))
149 hits = searcher.search(q, None, 1000).scoreDocs
150 self.assertEqual(1, len(hits))
# "2" then "5": expected to match nothing.
153 q.add(Term("field", "2"))
154 q.add(Term("field", "5"))
155 hits = searcher.search(q, None, 1000).scoreDocs
156 self.assertEqual(0, len(hits))
158 # should not find "1 2" because there is a gap of 1 in the index
159 qp = QueryParser(Version.LUCENE_CURRENT, "field",
160 StopWhitespaceAnalyzer(False))
161 q = PhraseQuery.cast_(qp.parse("\"1 2\""))
162 hits = searcher.search(q, None, 1000).scoreDocs
163 self.assertEqual(0, len(hits))
165 # omitted stop word cannot help because stop filter swallows the increments
167 q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
168 hits = searcher.search(q, None, 1000).scoreDocs
169 self.assertEqual(0, len(hits))
171 # query parser alone won't help, because stop filter swallows the increments
173 qp.setEnablePositionIncrements(True)
174 q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
175 hits = searcher.search(q, None, 1000).scoreDocs
176 self.assertEqual(0, len(hits))
178 # stop filter alone won't help, because query parser swallows the increments
# NOTE(review): as visible here qp still wraps StopWhitespaceAnalyzer(False);
# a line re-creating it with True is presumably missing -- confirm.
180 qp.setEnablePositionIncrements(False)
181 q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
182 hits = searcher.search(q, None, 1000).scoreDocs
183 self.assertEqual(0, len(hits))
185 # when both qp and stopFilter propagate increments, we should find the doc
187 qp = QueryParser(Version.LUCENE_CURRENT, "field",
188 StopWhitespaceAnalyzer(True))
189 qp.setEnablePositionIncrements(True)
190 q = PhraseQuery.cast_(qp.parse("\"1 stop 2\""))
191 hits = searcher.search(q, None, 1000).scoreDocs
192 self.assertEqual(1, len(hits))
# Indexes one document through TestPayloadAnalyzer and checks the positions of
# term "a" plus the spans and payloads produced by a span-near query on "a"/"k".
194 def testPayloadsPos0(self):
# dir is not assigned in the visible lines; presumably a directory created on a
# missing line -- as shown the name would resolve to the builtin.
197 writer = IndexWriter(dir, TestPayloadAnalyzer(), True,
198 IndexWriter.MaxFieldLength.LIMITED)
# Single "content" field read from a StringReader (creation of doc not visible).
201 doc.add(Field("content",
202 StringReader("a a b c d e a f g h i j a b k k")))
203 writer.addDocument(doc)
# Reader obtained directly from the writer.
205 r = writer.getReader()
207 tp = r.termPositions(Term("content", "a"))
# There must be a document containing "a", with four occurrences.
209 self.assert_(tp.next())
211 self.assertEqual(4, tp.freq())
# expected is set on a line not visible here; the three positions that
# follow it must be 1, 3 and 6.
214 self.assertEqual(expected, tp.nextPosition())
215 self.assertEqual(1, tp.nextPosition())
216 self.assertEqual(3, tp.nextPosition())
217 self.assertEqual(6, tp.nextPosition())
219 # only one doc has "a"
220 self.assert_(not tp.next())
222 searcher = IndexSearcher(r)
224 stq1 = SpanTermQuery(Term("content", "a"))
225 stq2 = SpanTermQuery(Term("content", "k"))
# sqs is presumably the list [stq1, stq2] built on a missing line; 30 and
# False are presumably the slop and in-order flag -- TODO confirm.
227 snq = SpanNearQuery(sqs, 30, False)
# First pass: read payloads from the spans. The loop headers and the
# initialisation/advance of count and sawZero are not visible.
232 pspans = snq.getSpans(searcher.getIndexReader())
234 payloads = pspans.getPayload()
235 sawZero |= pspans.start() == 0
237 it = payloads.iterator()
# Expect a count of 5 and at least one span starting at 0.
242 self.assertEqual(5, count)
243 self.assert_(sawZero)
# Second pass over the spans: expect a count of 4, again with a start of 0 seen.
245 spans = snq.getSpans(searcher.getIndexReader())
250 sawZero |= spans.start() == 0
252 self.assertEqual(4, count)
253 self.assert_(sawZero)
# Third pass: fetch payloads for the query via PayloadSpanUtil.
256 psu = PayloadSpanUtil(searcher.getIndexReader())
257 pls = psu.getPayloadsForQuery(snq)
# Each element is cast to a Java byte array; s is presumably its string
# form, built on a missing line.
261 bytes = JArray('byte').cast_(it.next())
263 sawZero |= s == "pos: 0"
# Expect a count of 5, one of them equal to "pos: 0".
265 self.assertEqual(5, count)
266 self.assert_(sawZero)
268 searcher.getIndexReader().close()
# Analyzer that wraps a WhitespaceAnalyzer and passes its token stream through a
# StopFilter, forwarding the position-increment flag given to the constructor.
# NOTE(review): indentation and some lines are missing from this listing.
272 class StopWhitespaceAnalyzer(PythonAnalyzer):
# enablePositionIncrements: stored unchanged and handed to StopFilter in tokenStream.
274 def __init__(self, enablePositionIncrements):
275 super(StopWhitespaceAnalyzer, self).__init__()
277 self.enablePositionIncrements = enablePositionIncrements
# Delegate analyzer that produces the unfiltered token stream.
278 self.a = WhitespaceAnalyzer()
# Returns the delegate's stream for (fieldName, reader) wrapped in a StopFilter.
280 def tokenStream(self, fieldName, reader):
282 ts = self.a.tokenStream(fieldName, reader)
# NOTE(review): no assignment to `set` is visible, so as shown this passes the
# builtin type. Presumably missing lines build a stop-word collection (the tests
# query the word "stop") -- confirm against the full file.
286 return StopFilter(self.enablePositionIncrements, ts, set)
class TestPayloadAnalyzer(PythonAnalyzer):
    """Analyzer that lower-case tokenizes input and wraps it in a PayloadFilter."""

    def tokenStream(self, fieldName, reader):
        """Return a PayloadFilter over a LowerCaseTokenizer reading ``reader``.

        ``fieldName`` is passed through to the PayloadFilter unchanged.
        """
        return PayloadFilter(LowerCaseTokenizer(reader), fieldName)
# Token filter that gives each token from the wrapped stream a payload holding
# the text "pos: <n>" and sets the token's position increment.
# NOTE(review): indentation and some lines are missing from this listing; the
# initialisation of self.pos, the computation of posIncr and the return
# statements of incrementToken are not visible.
297 class PayloadFilter(PythonTokenFilter):
# input: the token stream being wrapped; fieldName: stored on the instance
# (no visible line reads it back).
299 def __init__(self, input, fieldName):
300 super(PayloadFilter, self).__init__(input)
303 self.fieldName = fieldName
# Attributes are registered on the wrapped stream and kept for incrementToken.
306 self.posIncrAttr = input.addAttribute(PositionIncrementAttribute.class_)
307 self.payloadAttr = input.addAttribute(PayloadAttribute.class_)
308 self.termAttr = input.addAttribute(TermAttribute.class_)
# Advances the wrapped stream; when it yields a token, attaches the payload and
# position increment described below.
310 def incrementToken(self):
312 if self.input.incrementToken():
# Java byte array built from the text "pos: <self.pos>".
313 bytes = JArray('byte')("pos: %d" %(self.pos))
314 self.payloadAttr.setPayload(Payload(bytes))
# posIncr is computed on lines not visible here.
321 self.posIncrAttr.setPositionIncrement(posIncr)
# Script entry point. Only the handling of a '-loop' command-line flag is
# visible: when present it is removed from sys.argv. The import of sys and
# whatever runs afterwards are not part of this listing.
329 if __name__ == "__main__":
332 if '-loop' in sys.argv:
333 sys.argv.remove('-loop')