old python needs __main__ to call a module
[pylucene.git] / samples / LuceneInAction / lia / advsearching / CategorizerTest.py
1 # ====================================================================
2 #   Licensed under the Apache License, Version 2.0 (the "License");
3 #   you may not use this file except in compliance with the License.
4 #   You may obtain a copy of the License at
5 #
6 #       http://www.apache.org/licenses/LICENSE-2.0
7 #
8 #   Unless required by applicable law or agreed to in writing, software
9 #   distributed under the License is distributed on an "AS IS" BASIS,
10 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 #   See the License for the specific language governing permissions and
12 #   limitations under the License.
13 # ====================================================================
14
15 from math import pi, sqrt, acos
16 from lia.common.LiaTestCase import LiaTestCase
17
18 from lucene import Document, IndexReader 
19
20
21 class CategorizerTest(LiaTestCase):
22
23     def setUp(self):
24
25         super(CategorizerTest, self).setUp()
26         self.categoryMap = {}
27
28         self.buildCategoryVectors()
29         self.dumpCategoryVectors()
30
31     def testCategorization(self):
32         
33         self.assertEqual("/technology/computers/programming/methodology",
34                          self.getCategory("extreme agile methodology"))
35         self.assertEqual("/education/pedagogy",
36                          self.getCategory("montessori education philosophy"))
37
38     def dumpCategoryVectors(self):
39
40         for category, vectorMap in self.categoryMap.iteritems():
41             print "Category", category
42             for term, freq in vectorMap.iteritems():
43                 print "   ", term, "=", freq
44
45     def buildCategoryVectors(self):
46
47         reader = IndexReader.open(self.directory, True)
48
49         for id in xrange(reader.maxDoc()):
50             doc = reader.document(id)
51             category = doc.get("category")
52             vectorMap = self.categoryMap.get(category, None)
53             if vectorMap is None:
54                 vectorMap = self.categoryMap[category] = {}
55
56             termFreqVector = reader.getTermFreqVector(id, "subject")
57             self.addTermFreqToMap(vectorMap, termFreqVector)
58
59     def addTermFreqToMap(self, vectorMap, termFreqVector):
60
61         terms = termFreqVector.getTerms()
62         freqs = termFreqVector.getTermFrequencies()
63
64         i = 0
65         for term in terms:
66             if term in vectorMap:
67                 vectorMap[term] += freqs[i]
68             else:
69                 vectorMap[term] = freqs[i]
70             i += 1
71
72     def getCategory(self, subject):
73
74         words = subject.split(' ')
75
76         bestAngle = 2 * pi
77         bestCategory = None
78
79         for category, vectorMap in self.categoryMap.iteritems():
80             angle = self.computeAngle(words, category, vectorMap)
81             if angle != 'nan' and angle < bestAngle:
82                 bestAngle = angle
83                 bestCategory = category
84
85         return bestCategory
86
87     def computeAngle(self, words, category, vectorMap):
88
89         # assume words are unique and only occur once
90
91         dotProduct = 0
92         sumOfSquares = 0
93
94         for word in words:
95             categoryWordFreq = 0
96
97             if word in vectorMap:
98                 categoryWordFreq = vectorMap[word]
99
100             # optimized because we assume frequency in words is 1
101             dotProduct += categoryWordFreq
102             sumOfSquares += categoryWordFreq ** 2
103
104         if sumOfSquares == 0:
105             return 'nan'
106
107         if sumOfSquares == len(words):
108             # avoid precision issues for special case
109             # sqrt x * sqrt x = x
110             denominator = sumOfSquares 
111         else:
112             denominator = sqrt(sumOfSquares) * sqrt(len(words))
113
114         return acos(dotProduct / denominator)