Merge branch 'minimal-double-reader' into rwd
[wolnelektury.git] / apps / search / custom.py
1
2 from sunburnt import sunburnt
3 from lxml import etree
4 import urllib
5 import warnings
6 from sunburnt import search
7 import copy
8 from httplib2 import socket
9 import re
10
11
12 class TermVectorOptions(search.Options):
13     def __init__(self, schema, original=None):
14         self.schema = schema
15         if original is None:
16             self.fields = set()
17             self.positions = False
18         else:
19             self.fields = copy.copy(original.fields)
20             self.positions = copy.copy(original.positions)
21
22     def update(self, positions=False, fields=None):
23         if fields is None:
24             fields = []
25         if isinstance(fields, basestring):
26             fields = [fields]
27         self.schema.check_fields(fields, {"stored": True})
28         self.fields.update(fields)
29         self.positions = positions
30
31     def options(self):
32         opts = {}
33         if self.positions or self.fields:
34             opts['tv'] = 'true'
35         if self.positions:
36             opts['tv.positions'] = 'true'
37         if self.fields:
38             opts['tv.fl'] = ','.join(sorted(self.fields))
39         return opts
40
41
42 class CustomSolrConnection(sunburnt.SolrConnection):
43     def __init__(self, *args, **kw):
44         super(CustomSolrConnection, self).__init__(*args, **kw)
45         self.analysis_url = self.url + "analysis/field/"
46
47     def analyze(self, params):
48         qs = urllib.urlencode(params)
49         url = "%s?%s" % (self.analysis_url, qs)
50         if len(url) > self.max_length_get_url:
51             warnings.warn("Long query URL encountered - POSTing instead of "
52                 "GETting. This query will not be cached at the HTTP layer")
53             url = self.analysis_url
54             kwargs = dict(
55                 method="POST",
56                 body=qs,
57                 headers={"Content-Type": "application/x-www-form-urlencoded"},
58             )
59         else:
60             kwargs = dict(method="GET")
61         r, c = self.request(url, **kwargs)
62         if r.status != 200:
63             raise sunburnt.SolrError(r, c)
64         return c
65
66
67 # monkey patching sunburnt SolrSearch
68 search.SolrSearch.option_modules += ('term_vectorer',)
69
70
71 def __term_vector(self, positions=False, fields=None):
72     newself = self.clone()
73     newself.term_vectorer.update(positions, fields)
74     return newself
75 setattr(search.SolrSearch, 'term_vector', __term_vector)
76
77
78 def __patched__init_common_modules(self):
79     __original__init_common_modules(self)
80     self.term_vectorer = TermVectorOptions(self.schema)
81 __original__init_common_modules = search.SolrSearch._init_common_modules
82 setattr(search.SolrSearch, '_init_common_modules', __patched__init_common_modules)
83
84
85 class CustomSolrInterface(sunburnt.SolrInterface):
86     # just copied from parent and SolrConnection -> CustomSolrConnection
87     def __init__(self, url, schemadoc=None, http_connection=None, mode='', retry_timeout=-1, max_length_get_url=sunburnt.MAX_LENGTH_GET_URL):
88         self.conn = CustomSolrConnection(url, http_connection, retry_timeout, max_length_get_url)
89         self.schemadoc = schemadoc
90         if 'w' not in mode:
91             self.writeable = False
92         elif 'r' not in mode:
93             self.readable = False
94         try:
95             self.init_schema()
96         except socket.error, e:
97             raise socket.error, "Cannot connect to Solr server, and search indexing is enabled (%s)" % str(e)
98
99     def _analyze(self, **kwargs):
100         if not self.readable:
101             raise TypeError("This Solr instance is only for writing")
102         args = {
103             'analysis_showmatch': True
104             }
105         if 'field' in kwargs: args['analysis_fieldname'] = kwargs['field']
106         if 'text' in kwargs: args['analysis_fieldvalue'] = kwargs['text']
107         if 'q' in kwargs: args['q'] = kwargs['q']
108         if 'query' in kwargs: args['q'] = kwargs['q']
109
110         params = map(lambda (k, v): (k.replace('_', '.'), v), sunburnt.params_from_dict(**args))
111
112         content = self.conn.analyze(params)
113         doc = etree.fromstring(content)
114         return doc
115
116     def highlight(self, **kwargs):
117         doc = self._analyze(**kwargs)
118         analyzed = doc.xpath("//lst[@name='index']/arr[last()]/lst[bool/@name='match']")
119         matches = set()
120         for wrd in analyzed:
121             start = int(wrd.xpath("int[@name='start']")[0].text)
122             end = int(wrd.xpath("int[@name='end']")[0].text)
123             matches.add((start, end))
124
125         if matches:
126             return self.substring(kwargs['text'], matches,
127                 margins=kwargs.get('margins', 30),
128                 mark=kwargs.get('mark', ("<b>", "</b>")))
129         else:
130             return None
131
132     def analyze(self, **kwargs):
133         doc = self._analyze(**kwargs)
134         terms = doc.xpath("//lst[@name='index']/arr[last()]/lst/str[1]")
135         terms = map(lambda n: unicode(n.text), terms)
136         return terms
137
138     def expand_margins(self, text, start, end):
139         totlen = len(text)
140
141         def is_boundary(x):
142             ws = re.compile(r"\W", re.UNICODE)
143             return bool(ws.match(x))
144
145         while start > 0:
146             if is_boundary(text[start - 1]):
147                 break
148             start -= 1
149
150         while end < totlen - 1:
151             if is_boundary(text[end + 1]):
152                 break
153             end += 1
154
155         return (start, end)
156
157     def substring(self, text, matches, margins=30, mark=("<b>", "</b>")):
158         start = None
159         end = None
160         totlen = len(text)
161         matches_margins = map(lambda (s, e):
162                               ((s, e),
163                                (max(0, s - margins), min(totlen, e + margins))),
164                                   matches)
165         matches_margins = map(lambda (m, (s, e)):
166                               (m, self.expand_margins(text, s, e)),
167             matches_margins)
168
169             # lets start with first match
170         (start, end) = matches_margins[0][1]
171         matches = [matches_margins[0][0]]
172
173         for (m, (s, e)) in matches_margins[1:]:
174             if end < s or start > e:
175                 continue
176             start = min(start, s)
177             end = max(end, e)
178             matches.append(m)
179
180         snip = text[start:end]
181         matches.sort(lambda a, b: cmp(b[0], a[0]))
182
183         for (s, e) in matches:
184             off = - start
185             snip = snip[:e + off] + mark[1] + snip[e + off:]
186             snip = snip[:s + off] + mark[0] + snip[s + off:]
187
188         return snip