add --shared
[pylucene.git] / samples / LuceneInAction / lia / extsearch / sorting / DistanceComparatorSource.py
1 # ====================================================================
2 #   Licensed under the Apache License, Version 2.0 (the "License");
3 #   you may not use this file except in compliance with the License.
4 #   You may obtain a copy of the License at
5 #
6 #       http://www.apache.org/licenses/LICENSE-2.0
7 #
8 #   Unless required by applicable law or agreed to in writing, software
9 #   distributed under the License is distributed on an "AS IS" BASIS,
10 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 #   See the License for the specific language governing permissions and
12 #   limitations under the License.
13 # ====================================================================
14
15 from math import sqrt
16 from lucene import SortField, Term, IndexReader, FieldCache, \
17     PythonFieldComparatorSource, PythonFieldComparator, Double
18
19 #
20 # A FieldComparatorSource implementation
21 #
22
class DistanceComparatorSource(PythonFieldComparatorSource):
    """FieldComparatorSource that orders hits by distance from a point.

    Each comparator it creates reads the integer "x" and "y" fields of every
    document (via FieldCache) and ranks hits by the Euclidean distance
    between that document's (x, y) and the reference point given to the
    constructor, smallest distance first.
    """

    def __init__(self, x, y):
        """Create a source that measures distance from the point (x, y).

        x, y -- coordinates of the reference point; they are subtracted
                from the documents' integer "x"/"y" field values.
        """
        super(DistanceComparatorSource, self).__init__()

        self.x = x
        self.y = y

    def newComparator(self, fieldName, numHits, sortPos, reversed):
        """Return a new distance comparator with numHits value slots.

        sortPos and reversed are part of the FieldComparatorSource
        interface but are not used here.  fieldName is only recorded on the
        comparator: distances are always computed from the "x" and "y"
        fields.
        """

        def threeWay(a, b):
            # Shared -1/0/1 comparison used by compare() and compareBottom().
            if a < b:
                return -1
            if a > b:
                return 1
            return 0

        # The nested class names its own instance `_self` so that `self`
        # keeps referring to the enclosing DistanceComparatorSource, whose
        # x/y define the reference point.
        class DistanceScoreDocLookupComparator(PythonFieldComparator):

            def __init__(_self, fieldName, numHits):
                super(DistanceScoreDocLookupComparator, _self).__init__()
                # One cached distance per priority-queue slot.
                _self.values = [0.0] * numHits
                _self.fieldName = fieldName

            def setNextReader(_self, reader, docBase):
                """Load the "x" and "y" int values for the given reader.

                docBase is unused: doc ids handed to copy()/compareBottom()
                are used directly as indices into these per-reader arrays.
                """
                _self.xDoc = FieldCache.DEFAULT.getInts(reader, "x")
                _self.yDoc = FieldCache.DEFAULT.getInts(reader, "y")

            def _getDistance(_self, doc):
                """Return the Euclidean distance of doc from the reference point."""
                deltax = _self.xDoc[doc] - self.x
                deltay = _self.yDoc[doc] - self.y

                return sqrt(deltax * deltax + deltay * deltay)

            def compare(_self, slot1, slot2):
                """Compare two filled slots by distance; returns -1, 0 or 1."""
                return threeWay(_self.values[slot1], _self.values[slot2])

            def setBottom(_self, slot):
                """Remember the distance held in slot as the current bottom."""
                _self._bottom = _self.values[slot]

            def compareBottom(_self, doc):
                """Compare the bottom value (as slot1) against doc's distance.

                Relies on setBottom() having been called first; _bottom is
                not set anywhere else.
                """
                return threeWay(_self._bottom, _self._getDistance(doc))

            def copy(_self, slot, doc):
                """Store doc's distance into slot."""
                _self.values[slot] = _self._getDistance(doc)

            def value(_self, slot):
                """Return the distance stored in slot, boxed as a Double."""
                return Double(_self.values[slot])

            def sortType(_self):
                return SortField.CUSTOM

        return DistanceScoreDocLookupComparator(fieldName, numHits)

    def __str__(self):
        # %s formatting works for numeric coordinates; the previous
        # str + number concatenation raised TypeError for them.
        return "Distance from (%s,%s)" % (self.x, self.y)