3fbcc965700f6ff401a050d6bb2a8f8017b33db4
[wolnelektury.git] / apps / djangosphinx / manager.py
1 import select
2 import socket
3 import time
4 import struct
5 import warnings
6 import operator
7 import apis.current as sphinxapi
8
9 try:
10     import decimal
11 except ImportError:
12     from django.utils import _decimal as decimal # for Python 2.3
13
14 from django.db.models.query import QuerySet, Q
15 from django.conf import settings
16
17 __all__ = ('SearchError', 'ConnectionError', 'SphinxSearch', 'SphinxRelation')
18
19 from django.contrib.contenttypes.models import ContentType
20 from datetime import datetime, date
21
22 # server settings
23 SPHINX_SERVER           = getattr(settings, 'SPHINX_SERVER', 'localhost')
24 SPHINX_PORT             = int(getattr(settings, 'SPHINX_PORT', 3312))
25
26 # These require search API 275 (Sphinx 0.9.8)
27 SPHINX_RETRIES          = int(getattr(settings, 'SPHINX_RETRIES', 0))
28 SPHINX_RETRIES_DELAY    = int(getattr(settings, 'SPHINX_RETRIES_DELAY', 5))
29
30 MAX_INT = int(2**31-1)
31
32 class SearchError(Exception): pass
33 class ConnectionError(Exception): pass
34
35 class SphinxProxy(object):
36     """
37     Acts exactly like a normal instance of an object except that
38     it will handle any special sphinx attributes in a _sphinx class.
39     """
40     __slots__ = ('__dict__', '__instance__', '_sphinx')
41
42     def __init__(self, instance, attributes):
43         object.__setattr__(self, '__instance__', instance)
44         object.__setattr__(self, '_sphinx', attributes)
45
46     def _get_current_object(self):
47         """
48         Return the current object.  This is useful if you want the real object
49         behind the proxy at a time for performance reasons or because you want
50         to pass the object into a different context.
51         """
52         return self.__instance__
53     __current_object = property(_get_current_object)
54
55     def __dict__(self):
56         try:
57             return self.__current_object.__dict__
58         except RuntimeError:
59             return AttributeError('__dict__')
60     __dict__ = property(__dict__)
61
62     def __repr__(self):
63         try:
64             obj = self.__current_object
65         except RuntimeError:
66             return '<%s unbound>' % self.__class__.__name__
67         return repr(obj)
68
69     def __nonzero__(self):
70         try:
71             return bool(self.__current_object)
72         except RuntimeError:
73             return False
74
75     def __unicode__(self):
76         try:
77             return unicode(self.__current_oject)
78         except RuntimeError:
79             return repr(self)
80
81     def __dir__(self):
82         try:
83             return dir(self.__current_object)
84         except RuntimeError:
85             return []
86
87     def __getattr__(self, name, value=None):
88         if name == '__members__':
89             return dir(self.__current_object)
90         elif name == '_sphinx':
91             return object.__getattr__(self, '_sphinx', value)
92         return getattr(self.__current_object, name)
93
94     def __setattr__(self, name, value):
95         if name == '_sphinx':
96             return object.__setattr__(self, '_sphinx', value)
97         return setattr(self.__current_object, name, value)
98
99     def __setitem__(self, key, value):
100         self.__current_object[key] = value
101
102     def __delitem__(self, key):
103         del self.__current_object[key]
104
105     def __setslice__(self, i, j, seq):
106         self.__current_object[i:j] = seq
107
108     def __delslice__(self, i, j):
109         del self.__current_object[i:j]
110
111     __delattr__ = lambda x, n: delattr(x.__current_object, n)
112     __str__ = lambda x: str(x.__current_object)
113     __unicode__ = lambda x: unicode(x.__current_object)
114     __lt__ = lambda x, o: x.__current_object < o
115     __le__ = lambda x, o: x.__current_object <= o
116     __eq__ = lambda x, o: x.__current_object == o
117     __ne__ = lambda x, o: x.__current_object != o
118     __gt__ = lambda x, o: x.__current_object > o
119     __ge__ = lambda x, o: x.__current_object >= o
120     __cmp__ = lambda x, o: cmp(x.__current_object, o)
121     __hash__ = lambda x: hash(x.__current_object)
122     # attributes are currently not callable
123     # __call__ = lambda x, *a, **kw: x.__current_object(*a, **kw)
124     __len__ = lambda x: len(x.__current_object)
125     __getitem__ = lambda x, i: x.__current_object[i]
126     __iter__ = lambda x: iter(x.__current_object)
127     __contains__ = lambda x, i: i in x.__current_object
128     __getslice__ = lambda x, i, j: x.__current_object[i:j]
129     __add__ = lambda x, o: x.__current_object + o
130     __sub__ = lambda x, o: x.__current_object - o
131     __mul__ = lambda x, o: x.__current_object * o
132     __floordiv__ = lambda x, o: x.__current_object // o
133     __mod__ = lambda x, o: x.__current_object % o
134     __divmod__ = lambda x, o: x.__current_object.__divmod__(o)
135     __pow__ = lambda x, o: x.__current_object ** o
136     __lshift__ = lambda x, o: x.__current_object << o
137     __rshift__ = lambda x, o: x.__current_object >> o
138     __and__ = lambda x, o: x.__current_object & o
139     __xor__ = lambda x, o: x.__current_object ^ o
140     __or__ = lambda x, o: x.__current_object | o
141     __div__ = lambda x, o: x.__current_object.__div__(o)
142     __truediv__ = lambda x, o: x.__current_object.__truediv__(o)
143     __neg__ = lambda x: -(x.__current_object)
144     __pos__ = lambda x: +(x.__current_object)
145     __abs__ = lambda x: abs(x.__current_object)
146     __invert__ = lambda x: ~(x.__current_object)
147     __complex__ = lambda x: complex(x.__current_object)
148     __int__ = lambda x: int(x.__current_object)
149     __long__ = lambda x: long(x.__current_object)
150     __float__ = lambda x: float(x.__current_object)
151     __oct__ = lambda x: oct(x.__current_object)
152     __hex__ = lambda x: hex(x.__current_object)
153     __index__ = lambda x: x.__current_object.__index__()
154     __coerce__ = lambda x, o: x.__coerce__(x, o)
155     __enter__ = lambda x: x.__enter__()
156     __exit__ = lambda x, *a, **kw: x.__exit__(*a, **kw)
157
158 def to_sphinx(value):
159     "Convert a value into a sphinx query value"
160     if isinstance(value, date) or isinstance(value, datetime):
161         return int(time.mktime(value.timetuple()))
162     elif isinstance(value, decimal.Decimal) or isinstance(value, float):
163         return float(value)
164     return int(value)
165
166 class SphinxQuerySet(object):
167     available_kwargs = ('rankmode', 'mode', 'weights', 'maxmatches')
168     
169     def __init__(self, model=None, **kwargs):
170         self._select_related        = False
171         self._select_related_args   = {}
172         self._select_related_fields = []
173         self._filters               = {}
174         self._excludes              = {}
175         self._extra                 = {}
176         self._query                 = ''
177         self.__metadata             = None
178         self._offset                = 0
179         self._limit                 = 20
180
181         self._groupby               = None
182         self._sort                  = None
183         self._weights               = [1, 100]
184
185         self._maxmatches            = 1000
186         self._result_cache          = None
187         self._mode                  = sphinxapi.SPH_MATCH_ALL
188         self._rankmode              = getattr(sphinxapi, 'SPH_RANK_PROXIMITY_BM25', None)
189         self._model                 = model
190         self._anchor                = {}
191         self.__metadata             = {}
192         
193         self.set_options(**kwargs)
194
195         if model:
196             self._index             = kwargs.get('index', model._meta.db_table)
197         else:
198             self._index             = kwargs.get('index')
199
200     def __repr__(self):
201         if self._result_cache is not None:
202             return repr(self._get_data())
203         else:
204             return '<%s instance>' % (self.__class__.__name__,)
205
206     def __len__(self):
207         return len(self._get_data())
208         
209     def __iter__(self):
210         return iter(self._get_data())
211     
212     def __getitem__(self, k):
213         if not isinstance(k, (slice, int, long)):
214             raise TypeError
215         assert (not isinstance(k, slice) and (k >= 0)) \
216             or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \
217             "Negative indexing is not supported."
218         if type(k) == slice:
219             if self._offset < k.start or k.stop-k.start > self._limit:
220                 self._result_cache = None
221         else:
222             if k not in range(self._offset, self._limit+self._offset):
223                 self._result_cache = None
224         if self._result_cache is None:
225             if type(k) == slice:
226                 self._offset = k.start
227                 self._limit = k.stop-k.start
228                 return self._get_results()
229             else:
230                 self._offset = k
231                 self._limit = 1
232                 return self._get_results()[0]
233         else:
234             return self._result_cache[k]
235
236     def set_options(self, **kwargs):
237         if 'rankmode' in kwargs:
238             if kwargs.get('rankmode') is None:
239                 kwargs['rankmode'] = sphinxapi.SPH_RANK_NONE
240         for key in self.available_kwargs:
241             if key in kwargs:
242                 setattr(self, '_%s' % (key,), kwargs[key])
243
244     def query(self, string):
245         return self._clone(_query=unicode(string).encode('utf-8'))
246
247     def group_by(self, attribute, func, groupsort='@group desc'):
248         return self._clone(_groupby=attribute, _groupfunc=func, _groupsort=groupsort)
249
250     def rank_none(self):
251         warnings.warn('`rank_none()` is deprecated. Use `set_options(rankmode=None)` instead.', DeprecationWarning)
252         return self._clone(_rankmode=sphinxapi.SPH_RANK_NONE)
253
254     def mode(self, mode):
255         warnings.warn('`mode()` is deprecated. Use `set_options(mode='')` instead.', DeprecationWarning)
256         return self._clone(_mode=mode)
257
258     def weights(self, weights):
259         warnings.warn('`mode()` is deprecated. Use `set_options(weights=[])` instead.', DeprecationWarning)
260         return self._clone(_weights=weights)
261
262     def on_index(self, index):
263         warnings.warn('`mode()` is deprecated. Use `set_options(on_index=foo)` instead.', DeprecationWarning)
264         return self._clone(_index=index)
265
266     # only works on attributes
267     def filter(self, **kwargs):
268         filters = self._filters.copy()
269         for k,v in kwargs.iteritems():
270             if hasattr(v, 'next'):
271                 v = list(v)
272             elif not (isinstance(v, list) or isinstance(v, tuple)):
273                  v = [v,]
274             filters.setdefault(k, []).extend(map(to_sphinx, v))
275         return self._clone(_filters=filters)
276
277     def geoanchor(self, lat_attr, lng_attr, lat, lng):
278         assert(sphinxapi.VER_COMMAND_SEARCH >= 0x113, "You must upgrade sphinxapi to version 0.98 to use Geo Anchoring.")
279         return self._clone(_anchor=(lat_attr, lng_attr, float(lat), float(lng)))
280
281     # this actually does nothing, its just a passthru to
282     # keep things looking/working generally the same
283     def all(self):
284         return self
285
286     # only works on attributes
287     def exclude(self, **kwargs):
288         filters = self._excludes.copy()
289         for k,v in kwargs.iteritems():
290             if hasattr(v, 'next'):
291                 v = list(v)
292             elif not (isinstance(v, list) or isinstance(v, tuple)):
293                  v = [v,]
294             filters.setdefault(k, []).extend(map(to_sphinx, v))
295         return self._clone(_excludes=filters)
296
297     # you cannot order by @weight (it always orders in descending)
298     # keywords are @id, @weight, @rank, and @relevance
299     def order_by(self, *args):
300         sort_by = []
301         for arg in args:
302             sort = 'ASC'
303             if arg[0] == '-':
304                 arg = arg[1:]
305                 sort = 'DESC'
306             if arg == 'id':
307                 arg = '@id'
308             sort_by.append('%s %s' % (arg, sort))
309         if sort_by:
310             return self._clone(_sort=(sphinxapi.SPH_SORT_EXTENDED, ', '.join(sort_by)))
311         return self
312                     
313     # pass these thru on the queryset and let django handle it
314     def select_related(self, *args, **kwargs):
315         _args = self._select_related_fields[:]
316         _args.extend(args)
317         _kwargs = self._select_related_args.copy()
318         _kwargs.update(kwargs)
319         
320         return self._clone(
321             _select_related=True,
322             _select_related_fields=_args,
323             _select_related_args=_kwargs,
324         )
325     
326     def extra(self, **kwargs):
327         extra = self._extra.copy()
328         extra.update(kwargs)
329         return self._clone(_extra=extra)
330
331     def count(self):
332         return min(self._sphinx.get('total_found', 0), self._maxmatches)
333
334     def reset(self):
335         return self.__class__(self._model, self._index)
336
337     # Internal methods
338     def _clone(self, **kwargs):
339         # Clones the queryset passing any changed args
340         c = self.__class__()
341         c.__dict__.update(self.__dict__)
342         c.__dict__.update(kwargs)
343         return c
344     
345     def _sphinx(self):
346         if not self.__metadata:
347             # We have to force execution if this is accessed beforehand
348             self._get_data()
349         return self.__metadata
350     _sphinx = property(_sphinx)
351
352     def _get_data(self):
353         assert(self._index)
354         # need to find a way to make this work yet
355         if self._result_cache is None:
356             self._result_cache = list(self._get_results())
357         return self._result_cache
358
359     def _get_sphinx_results(self):
360         assert(self._offset + self._limit <= self._maxmatches)
361
362         client = sphinxapi.SphinxClient()
363         client.SetServer(SPHINX_SERVER, SPHINX_PORT)
364
365         if self._sort:
366             client.SetSortMode(*self._sort)
367         
368         if isinstance(self._weights, dict):
369             client.SetFieldWeights(self._weights)
370         else:
371             # assume its a list
372             client.SetWeights(map(int, self._weights))
373         
374         client.SetMatchMode(self._mode)
375
376         # 0.97 requires you to reset it
377         if hasattr(client, 'ResetFilters'):
378              client.ResetFilters()
379         if hasattr(client, 'ResetGroupBy'):
380              client.ResetGroupBy()
381         
382         def _handle_filters(filter_list, exclude=False):
383             for name, values in filter_list.iteritems():
384                 parts = len(name.split('__'))
385                 if parts > 2:
386                     raise NotImplementedError, 'Related object and/or multiple field lookups not supported'
387                 elif parts == 2:
388                     # The float handling for __gt and __lt is kind of ugly..
389                     name, lookup = name.split('__', 1)
390                     is_float = isinstance(values[0], float)
391                     if lookup == 'gt':
392                         value = is_float and values[0] + (1.0/MAX_INT) or values[0] - 1
393                         args = (name, value, MAX_INT, exclude)
394                     elif lookup == 'gte':
395                         args = (name, values[0], MAX_INT, exclude)
396                     elif lookup == 'lt':
397                         value = is_float and values[0] - (1.0/MAX_INT) or values[0] - 1
398                         args = (name, -MAX_INT, value, exclude)
399                     elif lookup == 'lte':
400                         args = (name, -MAX_INT, values[0], exclude)
401                     elif lookup == 'range':
402                         args = (name, values[0], values[1], exclude)
403                     else:
404                         raise NotImplementedError, 'Related object and/or field lookup "%s" not supported' % lookup
405                     if is_float:
406                         client.SetFilterFloatRange(*args)
407                     elif not exclude and self._model and name == self._model._meta.pk.column:
408                         client.SetIDRange(*args[1:3])
409                     else:
410                         client.SetFilterRange(*args)
411
412                 else:
413                     client.SetFilter(name, values, exclude)
414
415         # Include filters
416         if self._filters:
417             _handle_filters(self._filters)
418
419         # Exclude filters
420         if self._excludes:
421             _handle_filters(self._excludes, True)
422         
423         if self._groupby:
424             client.SetGroupBy(self._groupby, self._groupfunc, self._groupsort)
425
426         if self._anchor:
427             client.SetGeoAnchor(*self._anchor)
428
429         if self._rankmode:
430             client.SetRankingMode(self._rankmode)
431
432         if not self._limit > 0:
433             # Fix for Sphinx throwing an assertion error when you pass it an empty limiter
434             return []
435         
436
437         if sphinxapi.VER_COMMAND_SEARCH >= 0x113:
438             client.SetRetries(SPHINX_RETRIES, SPHINX_RETRIES_DELAY)
439         
440         client.SetLimits(int(self._offset), int(self._limit), int(self._maxmatches))
441         
442         results = client.Query(self._query, self._index)
443         
444         # The Sphinx API doesn't raise exceptions
445         if not results:
446             if client.GetLastError():
447                 raise SearchError, client.GetLastError()
448             elif client.GetLastWarning():
449                 raise SearchError, client.GetLastWarning()
450         return results
451
452     def _get_results(self):
453         results = self._get_sphinx_results()
454         if not results or not results['matches']:
455             results = []
456         elif self._model:
457             queryset = self._model.objects.all()
458             if self._select_related:
459                 queryset = queryset.select_related(*self._select_related_fields, **self._select_related_args)
460             if self._extra:
461                 queryset = queryset.extra(**self._extra)
462             pks = getattr(self._model._meta, 'pks', None)
463             if pks is None or len(pks) == 1:
464                 queryset = queryset.filter(pk__in=[r['id'] for r in results['matches']])
465                 queryset = dict([(o.pk, o) for o in queryset])
466             else:
467                 for r in results['matches']:
468                     r['id'] = ', '.join([unicode(r['attrs'][p.column]) for p in pks])
469                 q = reduce(operator.or_, [reduce(operator.and_, [Q(**{p.name: r['attrs'][p.column]}) for p in pks]) for r in results['matches']])
470                 if q:
471                     queryset = queryset.filter(q)
472                     queryset = dict([(', '.join([unicode(p) for p in o.pks]), o) for o in queryset])
473                 else:
474                     queryset = None
475         
476             if queryset:
477                 self.__metadata = {
478                     'total': results['total'],
479                     'total_found': results['total_found'],
480                     'words': results['words'],
481                 }
482                 results = [SphinxProxy(queryset[r['id']], r) for r in results['matches'] if r['id'] in queryset]
483             else:
484                 results = []
485         else:
486             "We did a query without a model, lets see if there's a content_type"
487             results['attrs'] = dict(results['attrs'])
488             if 'content_type' in results['attrs']:
489                 "Now we have to do one query per content_type"
490                 objcache = {}
491                 for r in results['matches']:
492                     ct = r['attrs']['content_type']
493                     if ct not in objcache:
494                         objcache[ct] = {}
495                     objcache[ct][r['id']] = None
496                 for ct in objcache:
497                     queryset = ContentType.objects.get(pk=ct).model_class().objects.filter(pk__in=objcache[ct])
498                     for o in queryset:
499                         objcache[ct][o.id] = o
500                 results = [objcache[r['attrs']['content_type']][r['id']] for r in results['matches']]
501             else:
502                 results = results['matches']
503         self._result_cache = results
504         return results
505
506 class SphinxModelManager(object):
507     def __init__(self, model, **kwargs):
508         self._model = model
509         self._index = kwargs.pop('index', model._meta.db_table)
510         self._kwargs = kwargs
511     
512     def _get_query_set(self):
513         return SphinxQuerySet(self._model, index=self._index, **self._kwargs)
514     
515     def get_index(self):
516         return self._index
517     
518     def all(self):
519         return self._get_query_set()
520     
521     def filter(self, **kwargs):
522         return self._get_query_set().filter(**kwargs)
523     
524     def query(self, *args, **kwargs):
525         return self._get_query_set().query(*args, **kwargs)
526
527     def on_index(self, *args, **kwargs):
528         return self._get_query_set().on_index(*args, **kwargs)
529
530     def geoanchor(self, *args, **kwargs):
531         return self._get_query_set().geoanchor(*args, **kwargs)
532
533 class SphinxInstanceManager(object):
534     """Collection of tools useful for objects which are in a Sphinx index."""
535     def __init__(self, instance, index):
536         self._instance = instance
537         self._index = index
538         
539     def update(self, **kwargs):
540         assert(sphinxapi.VER_COMMAND_SEARCH >= 0x113, "You must upgrade sphinxapi to version 0.98 to use Geo Anchoring.")
541         sphinxapi.UpdateAttributes(index, kwargs.keys(), dict(self.instance.pk, map(to_sphinx, kwargs.values())))
542
543
544 class SphinxSearch(object):
545     def __init__(self, index=None, **kwargs):
546         self._kwargs = kwargs
547         self._sphinx = None
548         self._index = index
549         self.model = None
550         
551     def __call__(self, index, **kwargs):
552         warnings.warn('For non-model searches use a SphinxQuerySet instance.', DeprecationWarning)
553         return SphinxQuerySet(index=index, **kwargs)
554         
555     def __get__(self, instance, model, **kwargs):
556         if instance:
557             return SphinxInstanceManager(instance, index)
558         return self._sphinx
559
560     def contribute_to_class(self, model, name, **kwargs):
561         if self._index is None:
562             self._index = model._meta.db_table
563         self._sphinx = SphinxModelManager(model, index=self._index, **self._kwargs)
564         self.model = model
565         if getattr(model, '__sphinx_indexes__', None) is None:
566             setattr(model, '__sphinx_indexes__', [self._index])
567         else:
568             model.__sphinx_indexes__.append(self._index)
569         setattr(model, name, self._sphinx)
570
571 class SphinxRelationProxy(SphinxProxy):
572     def count(self):
573         return min(self._sphinx['attrs']['@count'], self._maxmatches)
574     
575 class SphinxRelation(SphinxSearch):
576     """
577     Adds "related model" support to django-sphinx --
578     http://code.google.com/p/django-sphinx/
579     http://www.sphinxsearch.com/
580     
581     Example --
582     
583     class MySearch(SphinxSearch):
584         myrelatedobject = SphinxRelation(RelatedModel)
585         anotherone = SphinxRelation(AnotherModel)
586         ...
587     
588     class MyModel(models.Model):
589         search = MySearch('index')
590     
591     """
592     def __init__(self, model=None, attr=None, sort='@count desc', **kwargs):
593         if model:
594             self._related_model = model
595             self._related_attr = attr or model.__name__.lower()
596             self._related_sort = sort
597         super(SphinxRelation, self).__init__(**kwargs)
598         
599     def __get__(self, instance, instance_model, **kwargs):
600         self._mode = instance._mode
601         self._rankmode = instance._rankmode
602         self._index = instance._index
603         self._query = instance._query
604         self._filters = instance._filters
605         self._excludes = instance._excludes
606         self._model = self._related_model
607         self._groupby = self._related_attr
608         self._groupsort = self._related_sort
609         self._groupfunc = sphinxapi.SPH_GROUPBY_ATTR
610         return self
611
612     def _get_results(self):
613         results = self._get_sphinx_results()
614         if not results: return []
615         if results['matches'] and self._model:
616             ids = []
617             for r in results['matches']:
618                 value = r['attrs']['@groupby']
619                 if isinstance(value, (int, long)):
620                     ids.append(value)
621                 else:
622                     ids.extend()
623             qs = self._model.objects.filter(pk__in=set(ids))
624             if self._select_related:
625                 qs = qs.select_related(*self._select_related_fields,
626                                        **self._select_related_args)
627             if self._extra:
628                 qs = qs.extra(**self._extra)
629             queryset = dict([(o.id, o) for o in qs])
630             self.__metadata = {
631                 'total': results['total'],
632                 'total_found': results['total_found'],
633                 'words': results['words'],
634             }
635             results = [ SphinxRelationProxy(queryset[k['attrs']['@groupby']], k) \
636                         for k in results['matches'] \
637                         if k['attrs']['@groupby'] in queryset ]
638         else:
639             results = []
640         self._result_cache = results
641         return results
642
643     def _sphinx(self):
644         if not self.__metadata:
645             # We have to force execution if this is accessed beforehand
646             self._get_data()
647         return self.__metadata
648     _sphinx = property(_sphinx)