Function get_tag_list is now a static method of class TagBase.
[wolnelektury.git] / newtagging / models.py
1 """
2 Models and managers for generic tagging.
3 """
4 # Python 2.3 compatibility
5 if not hasattr(__builtins__, 'set'):
6     from sets import Set as set
7
8 from django.contrib.contenttypes import generic
9 from django.contrib.contenttypes.models import ContentType
10 from django.db import connection, models
11 from django.db.models.query import QuerySet
12 from django.utils.translation import ugettext_lazy as _
13 from django.db.models.base import ModelBase
14
15 qn = connection.ops.quote_name
16
17 try:
18     from django.db.models.query import parse_lookup
19 except ImportError:
20     parse_lookup = None
21
22
23 def get_queryset_and_model(queryset_or_model):
24     """
25     Given a ``QuerySet`` or a ``Model``, returns a two-tuple of
26     (queryset, model).
27
28     If a ``Model`` is given, the ``QuerySet`` returned will be created
29     using its default manager.
30     """
31     try:
32         return queryset_or_model, queryset_or_model.model
33     except AttributeError:
34         return queryset_or_model._default_manager.all(), queryset_or_model
35
36
37 ############
38 # Managers #
39 ############
40 class TagManager(models.Manager):
41     def __init__(self, intermediary_table_model):
42         super(TagManager, self).__init__()
43         self.intermediary_table_model = intermediary_table_model
44     
45     def update_tags(self, obj, tags):
46         """
47         Update tags associated with an object.
48         """
49         content_type = ContentType.objects.get_for_model(obj)
50         current_tags = list(self.filter(items__content_type__pk=content_type.pk,
51                                         items__object_id=obj.pk))
52         updated_tags = self.model.get_tag_list(tags)
53     
54         # Remove tags which no longer apply
55         tags_for_removal = [tag for tag in current_tags \
56                             if tag not in updated_tags]
57         if len(tags_for_removal):
58             self.intermediary_table_model._default_manager.filter(content_type__pk=content_type.pk,
59                                                object_id=obj.pk,
60                                                tag__in=tags_for_removal).delete()
61         # Add new tags
62         for tag in updated_tags:
63             if tag not in current_tags:
64                 self.intermediary_table_model._default_manager.create(tag=tag, content_object=obj)
65     
66     def get_for_object(self, obj):
67         """
68         Create a queryset matching all tags associated with the given
69         object.
70         """
71         ctype = ContentType.objects.get_for_model(obj)
72         return self.filter(items__content_type__pk=ctype.pk,
73                            items__object_id=obj.pk)
74     
75     def _get_usage(self, model, counts=False, min_count=None, extra_joins=None, extra_criteria=None, params=None, extra=None):
76         """
77         Perform the custom SQL query for ``usage_for_model`` and
78         ``usage_for_queryset``.
79         """
80         if min_count is not None: counts = True
81
82         model_table = qn(model._meta.db_table)
83         model_pk = '%s.%s' % (model_table, qn(model._meta.pk.column))
84         tag_columns = self._get_tag_columns()
85         
86         if extra is None: extra = {}
87         extra_where = ''
88         if 'where' in extra:
89             extra_where = 'AND ' + ' AND '.join(extra['where'])
90         
91         query = """
92         SELECT DISTINCT %(tag_columns)s%(count_sql)s
93         FROM
94             %(tag)s
95             INNER JOIN %(tagged_item)s
96                 ON %(tag)s.id = %(tagged_item)s.tag_id
97             INNER JOIN %(model)s
98                 ON %(tagged_item)s.object_id = %(model_pk)s
99             %%s
100         WHERE %(tagged_item)s.content_type_id = %(content_type_id)s
101             %%s
102             %(extra_where)s
103         GROUP BY %(tag)s.id, %(tag)s.name
104         %%s
105         ORDER BY %(tag)s.%(ordering)s ASC""" % {
106             'tag': qn(self.model._meta.db_table),
107             'ordering': ', '.join(qn(field) for field in self.model._meta.ordering),
108             'tag_columns': tag_columns,
109             'count_sql': counts and (', COUNT(%s)' % model_pk) or '',
110             'tagged_item': qn(self.intermediary_table_model._meta.db_table),
111             'model': model_table,
112             'model_pk': model_pk,
113             'extra_where': extra_where,
114             'content_type_id': ContentType.objects.get_for_model(model).pk,
115         }
116
117         min_count_sql = ''
118         if min_count is not None:
119             min_count_sql = 'HAVING COUNT(%s) >= %%s' % model_pk
120             params.append(min_count)
121
122         cursor = connection.cursor()
123         cursor.execute(query % (extra_joins, extra_criteria, min_count_sql), params)
124         tags = []
125         for row in cursor.fetchall():
126             t = self.model(*row[:len(self.model._meta.fields)])
127             if counts:
128                 t.count = row[len(self.model._meta.fields)]
129             tags.append(t)
130         return tags
131
132     def usage_for_model(self, model, counts=False, min_count=None, filters=None, extra=None):
133         """
134         Obtain a list of tags associated with instances of the given
135         Model class.
136
137         If ``counts`` is True, a ``count`` attribute will be added to
138         each tag, indicating how many times it has been used against
139         the Model class in question.
140
141         If ``min_count`` is given, only tags which have a ``count``
142         greater than or equal to ``min_count`` will be returned.
143         Passing a value for ``min_count`` implies ``counts=True``.
144
145         To limit the tags (and counts, if specified) returned to those
146         used by a subset of the Model's instances, pass a dictionary
147         of field lookups to be applied to the given Model as the
148         ``filters`` argument.
149         """
150         if extra is None: extra = {}
151         if filters is None: filters = {}
152
153         if not parse_lookup:
154             # post-queryset-refactor (hand off to usage_for_queryset)
155             queryset = model._default_manager.filter()
156             for f in filters.items():
157                 queryset.query.add_filter(f)
158             usage = self.usage_for_queryset(queryset, counts, min_count, extra)
159         else:
160             # pre-queryset-refactor
161             extra_joins = ''
162             extra_criteria = ''
163             params = []
164             if len(filters) > 0:
165                 joins, where, params = parse_lookup(filters.items(), model._meta)
166                 extra_joins = ' '.join(['%s %s AS %s ON %s' % (join_type, table, alias, condition)
167                                         for (alias, (table, join_type, condition)) in joins.items()])
168                 extra_criteria = 'AND %s' % (' AND '.join(where))
169             usage = self._get_usage(model, counts, min_count, extra_joins, extra_criteria, params, extra)
170
171         return usage
172
173     def usage_for_queryset(self, queryset, counts=False, min_count=None, extra=None):
174         """
175         Obtain a list of tags associated with instances of a model
176         contained in the given queryset.
177
178         If ``counts`` is True, a ``count`` attribute will be added to
179         each tag, indicating how many times it has been used against
180         the Model class in question.
181
182         If ``min_count`` is given, only tags which have a ``count``
183         greater than or equal to ``min_count`` will be returned.
184         Passing a value for ``min_count`` implies ``counts=True``.
185         """
186         if parse_lookup:
187             raise AttributeError("'TagManager.usage_for_queryset' is not compatible with pre-queryset-refactor versions of Django.")
188
189         extra_joins = ' '.join(queryset.query.get_from_clause()[0][1:])
190         where, params = queryset.query.where.as_sql()
191         if where:
192             extra_criteria = 'AND %s' % where
193         else:
194             extra_criteria = ''
195         return self._get_usage(queryset.model, counts, min_count, extra_joins, extra_criteria, params, extra)
196
197     def related_for_model(self, tags, model, counts=False, min_count=None, extra=None):
198         """
199         Obtain a list of tags related to a given list of tags - that
200         is, other tags used by items which have all the given tags.
201
202         If ``counts`` is True, a ``count`` attribute will be added to
203         each tag, indicating the number of items which have it in
204         addition to the given list of tags.
205
206         If ``min_count`` is given, only tags which have a ``count``
207         greater than or equal to ``min_count`` will be returned.
208         Passing a value for ``min_count`` implies ``counts=True``.
209         """
210         if min_count is not None: counts = True
211         tags = self.model.get_tag_list(tags)
212         tag_count = len(tags)
213         tagged_item_table = qn(self.intermediary_table_model._meta.db_table)
214         tag_columns = self._get_tag_columns()
215         
216         if extra is None: extra = {}
217         extra_where = ''
218         if 'where' in extra:
219             extra_where = 'AND ' + ' AND '.join(extra['where'])
220         
221         query = """
222         SELECT %(tag_columns)s%(count_sql)s
223         FROM %(tagged_item)s INNER JOIN %(tag)s ON %(tagged_item)s.tag_id = %(tag)s.id
224         WHERE %(tagged_item)s.content_type_id = %(content_type_id)s
225           AND %(tagged_item)s.object_id IN
226           (
227               SELECT %(tagged_item)s.object_id
228               FROM %(tagged_item)s, %(tag)s
229               WHERE %(tagged_item)s.content_type_id = %(content_type_id)s
230                 AND %(tag)s.id = %(tagged_item)s.tag_id
231                 AND %(tag)s.id IN (%(tag_id_placeholders)s)
232               GROUP BY %(tagged_item)s.object_id
233               HAVING COUNT(%(tagged_item)s.object_id) = %(tag_count)s
234           )
235           AND %(tag)s.id NOT IN (%(tag_id_placeholders)s)
236           %(extra_where)s
237         GROUP BY %(tag)s.id, %(tag)s.name
238         %(min_count_sql)s
239         ORDER BY %(tag)s.%(ordering)s ASC""" % {
240             'tag': qn(self.model._meta.db_table),
241             'ordering': ', '.join(qn(field) for field in self.model._meta.ordering),
242             'tag_columns': tag_columns,
243             'count_sql': counts and ', COUNT(%s.object_id)' % tagged_item_table or '',
244             'tagged_item': tagged_item_table,
245             'content_type_id': ContentType.objects.get_for_model(model).pk,
246             'tag_id_placeholders': ','.join(['%s'] * tag_count),
247             'extra_where': extra_where,
248             'tag_count': tag_count,
249             'min_count_sql': min_count is not None and ('HAVING COUNT(%s.object_id) >= %%s' % tagged_item_table) or '',
250         }
251
252         params = [tag.pk for tag in tags] * 2
253         if min_count is not None:
254             params.append(min_count)
255
256         cursor = connection.cursor()
257         cursor.execute(query, params)
258         related = []
259         for row in cursor.fetchall():
260             tag = self.model(*row[:len(self.model._meta.fields)])
261             if counts is True:
262                 tag.count = row[len(self.model._meta.fields)]
263             related.append(tag)
264         return related
265
266     def _get_tag_columns(self):
267         tag_table = qn(self.model._meta.db_table)
268         return ', '.join('%s.%s' % (tag_table, qn(field.column)) for field in self.model._meta.fields)
269
270
271 class TaggedItemManager(models.Manager):
272     """
273     FIXME There's currently no way to get the ``GROUP BY`` and ``HAVING``
274           SQL clauses required by many of this manager's methods into
275           Django's ORM.
276
277           For now, we manually execute a query to retrieve the PKs of
278           objects we're interested in, then use the ORM's ``__in``
279           lookup to return a ``QuerySet``.
280
281           Once the queryset-refactor branch lands in trunk, this can be
282           tidied up significantly.
283     """
284     def __init__(self, tag_model):
285         super(TaggedItemManager, self).__init__()
286         self.tag_model = tag_model
287     
288     def get_by_model(self, queryset_or_model, tags):
289         """
290         Create a ``QuerySet`` containing instances of the specified
291         model associated with a given tag or list of tags.
292         """
293         tags = self.tag_model.get_tag_list(tags)
294         tag_count = len(tags)
295         if tag_count == 0:
296             # No existing tags were given
297             queryset, model = get_queryset_and_model(queryset_or_model)
298             return model._default_manager.none()
299         elif tag_count == 1:
300             # Optimisation for single tag - fall through to the simpler
301             # query below.
302             tag = tags[0]
303         else:
304             return self.get_intersection_by_model(queryset_or_model, tags)
305
306         queryset, model = get_queryset_and_model(queryset_or_model)
307         content_type = ContentType.objects.get_for_model(model)
308         opts = self.model._meta
309         tagged_item_table = qn(opts.db_table)
310         return queryset.extra(
311             tables=[opts.db_table],
312             where=[
313                 '%s.content_type_id = %%s' % tagged_item_table,
314                 '%s.tag_id = %%s' % tagged_item_table,
315                 '%s.%s = %s.object_id' % (qn(model._meta.db_table),
316                                           qn(model._meta.pk.column),
317                                           tagged_item_table)
318             ],
319             params=[content_type.pk, tag.pk],
320         )
321
322     def get_intersection_by_model(self, queryset_or_model, tags):
323         """
324         Create a ``QuerySet`` containing instances of the specified
325         model associated with *all* of the given list of tags.
326         """
327         tags = self.tag_model.get_tag_list(tags)
328         tag_count = len(tags)
329         queryset, model = get_queryset_and_model(queryset_or_model)
330
331         if not tag_count:
332             return model._default_manager.none()
333
334         model_table = qn(model._meta.db_table)
335         # This query selects the ids of all objects which have all the
336         # given tags.
337         query = """
338         SELECT %(model_pk)s
339         FROM %(model)s, %(tagged_item)s
340         WHERE %(tagged_item)s.content_type_id = %(content_type_id)s
341           AND %(tagged_item)s.tag_id IN (%(tag_id_placeholders)s)
342           AND %(model_pk)s = %(tagged_item)s.object_id
343         GROUP BY %(model_pk)s
344         HAVING COUNT(%(model_pk)s) = %(tag_count)s""" % {
345             'model_pk': '%s.%s' % (model_table, qn(model._meta.pk.column)),
346             'model': model_table,
347             'tagged_item': qn(self.model._meta.db_table),
348             'content_type_id': ContentType.objects.get_for_model(model).pk,
349             'tag_id_placeholders': ','.join(['%s'] * tag_count),
350             'tag_count': tag_count,
351         }
352
353         cursor = connection.cursor()
354         cursor.execute(query, [tag.pk for tag in tags])
355         object_ids = [row[0] for row in cursor.fetchall()]
356         if len(object_ids) > 0:
357             return queryset.filter(pk__in=object_ids)
358         else:
359             return model._default_manager.none()
360
361     def get_union_by_model(self, queryset_or_model, tags):
362         """
363         Create a ``QuerySet`` containing instances of the specified
364         model associated with *any* of the given list of tags.
365         """
366         tags = self.tag_model.get_tag_list(tags)
367         tag_count = len(tags)
368         queryset, model = get_queryset_and_model(queryset_or_model)
369
370         if not tag_count:
371             return model._default_manager.none()
372
373         model_table = qn(model._meta.db_table)
374         # This query selects the ids of all objects which have any of
375         # the given tags.
376         query = """
377         SELECT %(model_pk)s
378         FROM %(model)s, %(tagged_item)s
379         WHERE %(tagged_item)s.content_type_id = %(content_type_id)s
380           AND %(tagged_item)s.tag_id IN (%(tag_id_placeholders)s)
381           AND %(model_pk)s = %(tagged_item)s.object_id
382         GROUP BY %(model_pk)s""" % {
383             'model_pk': '%s.%s' % (model_table, qn(model._meta.pk.column)),
384             'model': model_table,
385             'tagged_item': qn(self.model._meta.db_table),
386             'content_type_id': ContentType.objects.get_for_model(model).pk,
387             'tag_id_placeholders': ','.join(['%s'] * tag_count),
388         }
389
390         cursor = connection.cursor()
391         cursor.execute(query, [tag.pk for tag in tags])
392         object_ids = [row[0] for row in cursor.fetchall()]
393         if len(object_ids) > 0:
394             return queryset.filter(pk__in=object_ids)
395         else:
396             return model._default_manager.none()
397
398     def get_related(self, obj, queryset_or_model, num=None):
399         """
400         Retrieve a list of instances of the specified model which share
401         tags with the model instance ``obj``, ordered by the number of
402         shared tags in descending order.
403
404         If ``num`` is given, a maximum of ``num`` instances will be
405         returned.
406         """
407         queryset, model = get_queryset_and_model(queryset_or_model)
408         model_table = qn(model._meta.db_table)
409         content_type = ContentType.objects.get_for_model(obj)
410         related_content_type = ContentType.objects.get_for_model(model)
411         query = """
412         SELECT %(model_pk)s, COUNT(related_tagged_item.object_id) AS %(count)s
413         FROM %(model)s, %(tagged_item)s, %(tag)s, %(tagged_item)s related_tagged_item
414         WHERE %(tagged_item)s.object_id = %%s
415           AND %(tagged_item)s.content_type_id = %(content_type_id)s
416           AND %(tag)s.id = %(tagged_item)s.tag_id
417           AND related_tagged_item.content_type_id = %(related_content_type_id)s
418           AND related_tagged_item.tag_id = %(tagged_item)s.tag_id
419           AND %(model_pk)s = related_tagged_item.object_id"""
420         if content_type.pk == related_content_type.pk:
421             # Exclude the given instance itself if determining related
422             # instances for the same model.
423             query += """
424           AND related_tagged_item.object_id != %(tagged_item)s.object_id"""
425         query += """
426         GROUP BY %(model_pk)s
427         ORDER BY %(count)s DESC
428         %(limit_offset)s"""
429         query = query % {
430             'model_pk': '%s.%s' % (model_table, qn(model._meta.pk.column)),
431             'count': qn('count'),
432             'model': model_table,
433             'tagged_item': qn(self.model._meta.db_table),
434             'tag': qn(self.model._meta.get_field('tag').rel.to._meta.db_table),
435             'content_type_id': content_type.pk,
436             'related_content_type_id': related_content_type.pk,
437             'limit_offset': num is not None and connection.ops.limit_offset_sql(num) or '',
438         }
439
440         cursor = connection.cursor()
441         cursor.execute(query, [obj.pk])
442         object_ids = [row[0] for row in cursor.fetchall()]
443         if len(object_ids) > 0:
444             # Use in_bulk here instead of an id__in lookup, because id__in would
445             # clobber the ordering.
446             object_dict = queryset.in_bulk(object_ids)
447             return [object_dict[object_id] for object_id in object_ids \
448                     if object_id in object_dict]
449         else:
450             return []
451
452
453 ##########
454 # Models #
455 ##########
456 def create_intermediary_table_model(model):
457     """Create an intermediary table model for the specific tag model"""
458     name = model.__name__ + 'Relation'
459      
460     class Meta:
461         db_table = '%s_relation' % model._meta.db_table
462         unique_together = (('tag', 'content_type', 'object_id'),)
463
464     def obj_unicode(self):
465         return u'%s [%s]' % (self.content_type.get_object_for_this_type(pk=self.object_id), self.tag)
466         
467     # Set up a dictionary to simulate declarations within a class    
468     attrs = {
469         '__module__': model.__module__,
470         'Meta': Meta,
471         'tag': models.ForeignKey(model, verbose_name=_('tag'), related_name='items'),
472         'content_type': models.ForeignKey(ContentType, verbose_name=_('content type')),
473         'object_id': models.PositiveIntegerField(_('object id'), db_index=True),
474         'content_object': generic.GenericForeignKey('content_type', 'object_id'),
475         '__unicode__': obj_unicode,
476     }
477
478     return type(name, (models.Model,), attrs)
479
480
481 class TagMeta(ModelBase):
482     "Metaclass for tag models (models inheriting from TagBase)."
483     def __new__(cls, name, bases, attrs):
484         model = super(TagMeta, cls).__new__(cls, name, bases, attrs)
485         if not model._meta.abstract:
486             # Create an intermediary table and register custom managers for concrete models
487             intermediary_table_model = create_intermediary_table_model(model)
488             TagManager(intermediary_table_model).contribute_to_class(model, 'objects')
489             TaggedItemManager(model).contribute_to_class(intermediary_table_model, 'objects')
490         return model
491
492
493 class TagBase(models.Model):
494     """Abstract class to be inherited by model classes."""
495     __metaclass__ = TagMeta
496     
497     class Meta:
498         abstract = True
499     
500     @staticmethod
501     def get_tag_list(tag_list):
502         """
503         Utility function for accepting tag input in a flexible manner.
504         
505         You should probably override this method in your subclass.
506         """
507         if isinstance(tag_list, TagBase):
508             return [tag_list]
509         else:
510             return tag_list
511