fix in librarian
[wolnelektury.git] / apps / search / custom.py
index fcc3bac..b3b704d 100644 (file)
@@ -1,15 +1,18 @@
-
+# -*- coding: utf-8 -*-
+# This file is part of Wolnelektury, licensed under GNU Affero GPLv3 or later.
+# Copyright © Fundacja Nowoczesna Polska. See NOTICE for more information.
+#
 from sunburnt import sunburnt
 from lxml import etree
 import urllib
 import warnings
 from sunburnt import search
 import copy
 from sunburnt import sunburnt
 from lxml import etree
 import urllib
 import warnings
 from sunburnt import search
 import copy
+from httplib2 import socket
+import re
 
 
 class TermVectorOptions(search.Options):
 
 
 class TermVectorOptions(search.Options):
-    option_name = "tv"
-
     def __init__(self, schema, original=None):
         self.schema = schema
         if original is None:
     def __init__(self, schema, original=None):
         self.schema = schema
         if original is None:
@@ -30,7 +33,8 @@ class TermVectorOptions(search.Options):
 
     def options(self):
         opts = {}
 
     def options(self):
         opts = {}
-        opts['tv'] = 'true'
+        if self.positions or self.fields:
+            opts['tv'] = 'true'
         if self.positions:
             opts['tv.positions'] = 'true'
         if self.fields:
         if self.positions:
             opts['tv.positions'] = 'true'
         if self.fields:
@@ -72,12 +76,12 @@ def __term_vector(self, positions=False, fields=None):
     newself.term_vectorer.update(positions, fields)
     return newself
 setattr(search.SolrSearch, 'term_vector', __term_vector)
     newself.term_vectorer.update(positions, fields)
     return newself
 setattr(search.SolrSearch, 'term_vector', __term_vector)
-__original__init_common_modules = search.SolrSearch._init_common_modules
 
 
 def __patched__init_common_modules(self):
     __original__init_common_modules(self)
     self.term_vectorer = TermVectorOptions(self.schema)
 
 
 def __patched__init_common_modules(self):
     __original__init_common_modules(self)
     self.term_vectorer = TermVectorOptions(self.schema)
+__original__init_common_modules = search.SolrSearch._init_common_modules
 setattr(search.SolrSearch, '_init_common_modules', __patched__init_common_modules)
 
 
 setattr(search.SolrSearch, '_init_common_modules', __patched__init_common_modules)
 
 
@@ -86,11 +90,14 @@ class CustomSolrInterface(sunburnt.SolrInterface):
     def __init__(self, url, schemadoc=None, http_connection=None, mode='', retry_timeout=-1, max_length_get_url=sunburnt.MAX_LENGTH_GET_URL):
         self.conn = CustomSolrConnection(url, http_connection, retry_timeout, max_length_get_url)
         self.schemadoc = schemadoc
     def __init__(self, url, schemadoc=None, http_connection=None, mode='', retry_timeout=-1, max_length_get_url=sunburnt.MAX_LENGTH_GET_URL):
         self.conn = CustomSolrConnection(url, http_connection, retry_timeout, max_length_get_url)
         self.schemadoc = schemadoc
-        if mode == 'r':
+        if 'w' not in mode:
             self.writeable = False
             self.writeable = False
-        elif mode == 'w':
+        elif 'r' not in mode:
             self.readable = False
             self.readable = False
-        self.init_schema()
+        try:
+            self.init_schema()
+        except socket.error, e:
+            raise socket.error, "Cannot connect to Solr server, and search indexing is enabled (%s)" % str(e)
 
     def _analyze(self, **kwargs):
         if not self.readable:
 
     def _analyze(self, **kwargs):
         if not self.readable:
@@ -118,39 +125,67 @@ class CustomSolrInterface(sunburnt.SolrInterface):
             end = int(wrd.xpath("int[@name='end']")[0].text)
             matches.add((start, end))
 
             end = int(wrd.xpath("int[@name='end']")[0].text)
             matches.add((start, end))
 
-        print matches
         if matches:
             return self.substring(kwargs['text'], matches,
         if matches:
             return self.substring(kwargs['text'], matches,
-                            margins=kwargs.get('margins', 30),
-            mark=kwargs.get('mark', ("<b>", "</b>")))
+                margins=kwargs.get('margins', 30),
+                mark=kwargs.get('mark', ("<b>", "</b>")))
         else:
             return None
 
     def analyze(self, **kwargs):
         else:
             return None
 
     def analyze(self, **kwargs):
-        doc = self._analyze(self, **kwargs)
-        terms = doc.xpath("/lst[@name='index']/arr[last()]/lst/str[1]")
+        doc = self._analyze(**kwargs)
+        terms = doc.xpath("//lst[@name='index']/arr[last()]/lst/str[1]")
         terms = map(lambda n: unicode(n.text), terms)
         return terms
 
         terms = map(lambda n: unicode(n.text), terms)
         return terms
 
+    def expand_margins(self, text, start, end):
+        totlen = len(text)
+
+        def is_boundary(x):
+            ws = re.compile(r"\W", re.UNICODE)
+            return bool(ws.match(x))
+
+        while start > 0:
+            if is_boundary(text[start - 1]):
+                break
+            start -= 1
+
+        while end < totlen - 1:
+            if is_boundary(text[end + 1]):
+                break
+            end += 1
+
+        return (start, end)
+
     def substring(self, text, matches, margins=30, mark=("<b>", "</b>")):
         start = None
         end = None
         totlen = len(text)
     def substring(self, text, matches, margins=30, mark=("<b>", "</b>")):
         start = None
         end = None
         totlen = len(text)
-        matches_margins = map(lambda (s, e): (max(0, s - margins), min(totlen, e + margins)), matches)
-        (start, end) = matches_margins[0]
-
-        for (s, e) in matches_margins[1:]:
+        matches_margins = map(lambda (s, e):
+                              ((s, e),
+                               (max(0, s - margins), min(totlen, e + margins))),
+                                  matches)
+        matches_margins = map(lambda (m, (s, e)):
+                              (m, self.expand_margins(text, s, e)),
+            matches_margins)
+
+            # lets start with first match
+        (start, end) = matches_margins[0][1]
+        matches = [matches_margins[0][0]]
+
+        for (m, (s, e)) in matches_margins[1:]:
             if end < s or start > e:
                 continue
             start = min(start, s)
             end = max(end, e)
             if end < s or start > e:
                 continue
             start = min(start, s)
             end = max(end, e)
+            matches.append(m)
 
         snip = text[start:end]
 
         snip = text[start:end]
-        matches = list(matches)
         matches.sort(lambda a, b: cmp(b[0], a[0]))
         matches.sort(lambda a, b: cmp(b[0], a[0]))
+
         for (s, e) in matches:
             off = - start
             snip = snip[:e + off] + mark[1] + snip[e + off:]
             snip = snip[:s + off] + mark[0] + snip[s + off:]
         for (s, e) in matches:
             off = - start
             snip = snip[:e + off] + mark[1] + snip[e + off:]
             snip = snip[:s + off] + mark[0] + snip[s + off:]
-            # maybe break on word boundaries
+
         return snip
         return snip