1 # ====================================================================
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at
6 # http://www.apache.org/licenses/LICENSE-2.0
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 # ====================================================================
from math import sqrt
from unittest import TestCase

from lucene import (
    WhitespaceAnalyzer, IndexSearcher, Term, TermQuery, RAMDirectory,
    Document, Field, IndexWriter, Sort, SortField, FieldDoc, Double,
)

from lia.extsearch.sorting.DistanceComparatorSource import \
    DistanceComparatorSource
class DistanceSortingTest(TestCase):
    """Tests sorting search hits by distance using a custom comparator source.

    Four restaurants are indexed with integer (x, y) coordinates.  Each test
    runs a ``type:restaurant`` query whose hits are sorted by a
    ``DistanceComparatorSource`` anchored at a given point.
    """

    def setUp(self):
        """Build an in-memory index of four restaurants and open a searcher."""
        self.directory = RAMDirectory()
        writer = IndexWriter(self.directory, WhitespaceAnalyzer(), True,
                             IndexWriter.MaxFieldLength.UNLIMITED)
        try:
            self.addPoint(writer, "El Charro", "restaurant", 1, 2)
            self.addPoint(writer, "Cafe Poca Cosa", "restaurant", 5, 9)
            self.addPoint(writer, "Los Betos", "restaurant", 9, 6)
            self.addPoint(writer, "Nico's Taco Shop", "restaurant", 3, 8)
        finally:
            # close() commits the pending documents and releases the writer;
            # the searcher opened below must see a committed index.
            writer.close()

        self.searcher = IndexSearcher(self.directory, True)
        self.query = TermQuery(Term("type", "restaurant"))

    def tearDown(self):
        """Release the searcher and the in-memory directory opened in setUp."""
        self.searcher.close()
        self.directory.close()

    def addPoint(self, writer, name, type, x, y):
        """Add one document describing a point of interest to ``writer``.

        ``name`` and ``type`` are stored as untokenized fields; ``x`` and
        ``y`` are stored via ``str()`` as untokenized, norm-less fields.
        (``type`` shadows the builtin but is kept for interface stability.)
        """
        doc = Document()
        doc.add(Field("name", name, Field.Store.YES, Field.Index.NOT_ANALYZED))
        doc.add(Field("type", type, Field.Store.YES, Field.Index.NOT_ANALYZED))
        doc.add(Field("x", str(x), Field.Store.YES,
                      Field.Index.NOT_ANALYZED_NO_NORMS))
        doc.add(Field("y", str(y), Field.Store.YES,
                      Field.Index.NOT_ANALYZED_NO_NORMS))
        writer.addDocument(doc)

    def testNearestRestaurantToHome(self):
        """From (0, 0): El Charro (1, 2) is closest, Los Betos (9, 6) furthest."""
        sort = Sort(SortField("location", DistanceComparatorSource(0, 0)))

        scoreDocs = self.searcher.search(self.query, None, 50, sort).scoreDocs
        # Check the count first so a missing hit fails clearly, not as an
        # index error on scoreDocs[3].
        self.assertEqual(4, len(scoreDocs))
        self.assertEqual("El Charro",
                         self.searcher.doc(scoreDocs[0].doc).get("name"),
                         "closest")
        self.assertEqual("Los Betos",
                         self.searcher.doc(scoreDocs[3].doc).get("name"),
                         "furthest")

    def testNeareastRestaurantToWork(self):
        """From (10, 10): Los Betos (9, 6) is closest, at distance sqrt(17).

        Only the top 3 hits are requested, while ``totalHits`` still counts
        all 4 matching restaurants.
        """
        sort = Sort(SortField("location", DistanceComparatorSource(10, 10)))

        docs = self.searcher.search(self.query, None, 3, sort)
        self.assertEqual(4, docs.totalHits)
        self.assertEqual(3, len(docs.scoreDocs))

        # The first sort value of a FieldDoc is the distance used for sorting.
        fieldDoc = FieldDoc.cast_(docs.scoreDocs[0])
        distance = Double.cast_(fieldDoc.fields[0]).doubleValue()

        # The distance is a computed float, so compare with a tolerance.
        self.assertAlmostEqual(sqrt(17), distance,
                               msg="(10,10) -> (9,6) = sqrt(17)")

        document = self.searcher.doc(fieldDoc.doc)
        self.assertEqual("Los Betos", document.get("name"))

        self.dumpDocs(sort, docs)

    def dumpDocs(self, sort, docs):
        """Print ``sort``, then each hit's name, (x,y) and sort distance."""
        # Single-argument print(...) gives the same output on Python 2 and 3.
        print("Sorted by: %s" % (sort,))

        for scoreDoc in docs.scoreDocs:
            fieldDoc = FieldDoc.cast_(scoreDoc)
            distance = Double.cast_(fieldDoc.fields[0]).doubleValue()
            doc = self.searcher.doc(fieldDoc.doc)
            # addPoint stores separate "x" and "y" fields; there is no
            # "location" field, so reading it would always print None.
            print(" %s @ (%s,%s) -> %s"
                  % (doc.get("name"), doc.get("x"), doc.get("y"), distance))