old python needs __main__ to call a module
[pylucene.git] / samples / LuceneInAction / lia / extsearch / sorting / DistanceSortingTest.py
1 # ====================================================================
2 #   Licensed under the Apache License, Version 2.0 (the "License");
3 #   you may not use this file except in compliance with the License.
4 #   You may obtain a copy of the License at
5 #
6 #       http://www.apache.org/licenses/LICENSE-2.0
7 #
8 #   Unless required by applicable law or agreed to in writing, software
9 #   distributed under the License is distributed on an "AS IS" BASIS,
10 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 #   See the License for the specific language governing permissions and
12 #   limitations under the License.
13 # ====================================================================
14
15 from math import sqrt
16 from unittest import TestCase
17
18 from lucene import \
19      WhitespaceAnalyzer, IndexSearcher, Term, TermQuery, RAMDirectory, \
20      Document, Field, IndexWriter, Sort, SortField, FieldDoc, Double
21
22 from lia.extsearch.sorting.DistanceComparatorSource import \
23      DistanceComparatorSource
24
25
class DistanceSortingTest(TestCase):
    """Check that search hits can be ordered by distance from a point.

    setUp() indexes four restaurants, each with integer ``x``/``y``
    coordinates, into an in-memory directory. The tests search for
    ``type:restaurant`` and sort the hits with a ``DistanceComparatorSource``
    built around a given origin.
    """

    def setUp(self):
        """Index the sample restaurants and open a searcher over them."""

        self.directory = RAMDirectory()
        writer = IndexWriter(self.directory, WhitespaceAnalyzer(), True,
                             IndexWriter.MaxFieldLength.UNLIMITED)

        # Always close the writer, even if adding a document fails.
        try:
            self.addPoint(writer, "El Charro", "restaurant", 1, 2)
            self.addPoint(writer, "Cafe Poca Cosa", "restaurant", 5, 9)
            self.addPoint(writer, "Los Betos", "restaurant", 9, 6)
            self.addPoint(writer, "Nico's Taco Shop", "restaurant", 3, 8)
        finally:
            writer.close()

        # Second argument True: open the searcher read-only; tests only search.
        self.searcher = IndexSearcher(self.directory, True)
        self.query = TermQuery(Term("type", "restaurant"))

    def tearDown(self):
        """Release the searcher and the in-memory directory opened by setUp()."""

        self.searcher.close()
        self.directory.close()

    def addPoint(self, writer, name, type, x, y):
        """Add one document with stored, untokenized name/type/x/y fields.

        ``x`` and ``y`` are stored as their ``str()`` form. The ``type``
        parameter shadows the builtin but keeps its name so keyword callers
        still work.
        """

        doc = Document()
        doc.add(Field("name", name, Field.Store.YES, Field.Index.NOT_ANALYZED))
        doc.add(Field("type", type, Field.Store.YES, Field.Index.NOT_ANALYZED))
        # Coordinates are indexed without norms; presumably read back by
        # DistanceComparatorSource (not visible here) -- confirm there.
        doc.add(Field("x", str(x), Field.Store.YES,
                      Field.Index.NOT_ANALYZED_NO_NORMS))
        doc.add(Field("y", str(y), Field.Store.YES,
                      Field.Index.NOT_ANALYZED_NO_NORMS))

        writer.addDocument(doc)

    def testNearestRestaurantToHome(self):
        """From (0, 0), El Charro (1, 2) is closest and Los Betos (9, 6) furthest."""

        sort = Sort(SortField("location", DistanceComparatorSource(0, 0)))

        scoreDocs = self.searcher.search(self.query, None, 50, sort).scoreDocs
        # Fail with a clear message instead of an IndexError on scoreDocs[3].
        self.assertEqual(4, len(scoreDocs))
        self.assertEqual("El Charro", self._name(scoreDocs[0]), "closest")
        self.assertEqual("Los Betos", self._name(scoreDocs[3]), "furthest")

    # (sic) "Neareast": name kept so runners that select it by name still work.
    def testNeareastRestaurantToWork(self):
        """From (10, 10), the top hit is Los Betos (9, 6) at distance sqrt(17)."""

        sort = Sort(SortField("location", DistanceComparatorSource(10, 10)))

        docs = self.searcher.search(self.query, None, 3, sort)
        self.assertEqual(4, docs.totalHits)
        self.assertEqual(3, len(docs.scoreDocs))

        first = docs.scoreDocs[0]
        # Floating-point result: compare with a tolerance, not exact equality.
        self.assertAlmostEqual(sqrt(17), self._distance(first), places=7,
                               msg="(10,10) -> (9,6) = sqrt(17)")
        self.assertEqual("Los Betos", self._name(first))

        self.dumpDocs(sort, docs)

    def dumpDocs(self, sort, docs):
        """Print the sort, then one "name @ (x,y) -> distance" line per hit."""

        # Single-argument print: same output under Python 2 and Python 3.
        print("Sorted by: %s" % (sort,))

        for scoreDoc in docs.scoreDocs:
            doc = self.searcher.doc(scoreDoc.doc)
            # addPoint() stores separate "x" and "y" fields; there is no
            # "location" field to print.
            print("  %s @ (%s,%s) -> %s" % (doc.get("name"), doc.get("x"),
                                            doc.get("y"),
                                            self._distance(scoreDoc)))

    def _name(self, scoreDoc):
        """Return the stored "name" field of the hit *scoreDoc*."""

        return self.searcher.doc(scoreDoc.doc).get("name")

    def _distance(self, scoreDoc):
        """Return the first sort value of a distance-sorted hit as a float.

        The value is expected to be a java.lang.Double produced by
        DistanceComparatorSource.
        """

        fieldDoc = FieldDoc.cast_(scoreDoc)
        return Double.cast_(fieldDoc.fields[0]).doubleValue()