4a5b512d78b467a32abe144671f6f1299ab5f8b8
[wolnelektury.git] / apps / south / db / generic.py
1
2 import datetime
3 from django.core.management.color import no_style
4 from django.db import connection, transaction, models
5 from django.db.backends.util import truncate_name
6 from django.db.models.fields import NOT_PROVIDED
7 from django.dispatch import dispatcher
8 from django.conf import settings
9
10
def alias(attrname):
    """
    Build a method that forwards every call to the method named 'attrname'.

    A plain ``foo = bar`` assignment would bind the base-class function
    directly, so a subclass overriding 'attrname' would be bypassed. Looking
    the name up on ``self`` at call time keeps the alias override-aware.
    """
    def forwarder(self, *args, **kwds):
        target = getattr(self, attrname)
        return target(*args, **kwds)
    return forwarder
19
20
class DatabaseOperations(object):

    """
    Generic SQL implementation of the DatabaseOperations.
    Some of this code comes from Django Evolution.

    The class-level SQL templates and feature flags (for example
    ``add_column_string``, ``allows_combined_alters``,
    ``supports_foreign_keys`` and ``has_ddl_transactions``) are meant to be
    overridden by database-specific subclasses.
    """

    # We assume the generic DB can handle DDL transactions. MySQL will change this.
    has_ddl_transactions = True

    def __init__(self):
        # When True, every statement passed to execute() is printed.
        self.debug = False
        # Statements queued by add_deferred_sql(), run by execute_deferred_sql().
        self.deferred_sql = []
        # When True, execute() and the transaction helpers do nothing.
        self.dry_run = False
        # (app_label, model_names) pairs queued by send_create_signal().
        self.pending_create_signals = []

    def execute(self, sql, params=None):
        """
        Executes the given SQL statement, with optional parameters.

        If the instance's debug attribute is True, prints the statement and
        its parameters first. In dry-run mode nothing is sent to the database.

        @param sql: The SQL statement to run.
        @param params: Parameters handed to the DB cursor (defaults to []).
        @return: The rows fetched from the cursor, or [] when the statement
                 produced nothing fetchable or when in dry-run mode.
        """
        if params is None:
            params = []

        if self.debug:
            print("   = %s %s" % (sql, params))

        if self.dry_run:
            return []

        cursor = connection.cursor()
        cursor.execute(sql, params)
        try:
            return cursor.fetchall()
        except Exception:
            # Statements without a result set (e.g. DDL) can make fetchall()
            # raise; treat that as "no rows".
            return []

    def add_deferred_sql(self, sql):
        """
        Add a SQL statement to the deferred list, that won't be executed until
        this instance's execute_deferred_sql method is run.
        """
        self.deferred_sql.append(sql)

    def execute_deferred_sql(self):
        """
        Executes all deferred SQL in the order it was queued, then resets
        the deferred_sql list.
        """
        for sql in self.deferred_sql:
            self.execute(sql)

        self.deferred_sql = []

    def clear_deferred_sql(self):
        """
        Resets the deferred_sql list to empty without executing anything.
        """
        self.deferred_sql = []

    def clear_run_data(self):
        """
        Resets variables to how they should be before a run. Used for dry runs.
        """
        self.clear_deferred_sql()
        self.pending_create_signals = []

    def create_table(self, table_name, fields):
        """
        Creates the table 'table_name'. 'fields' is a sequence of fields,
        each represented by a 2-part tuple of field name and a
        django.db.models.fields.Field object.

        Passing a dictionary is intentionally not supported (it is almost
        certainly not what you want). Fields whose column_sql() is None
        (no database column) are left out of the CREATE TABLE statement.
        """
        qn = connection.ops.quote_name

        columns = [
            self.column_sql(table_name, field_name, field)
            for field_name, field in fields
        ]

        self.execute('CREATE TABLE %s (%s);' % (qn(table_name), ', '.join([col for col in columns if col])))

    add_table = alias('create_table') # Alias for consistency's sake

    def rename_table(self, old_table_name, table_name):
        """
        Renames the table 'old_table_name' to 'table_name'.
        Does nothing if the two names are equal.
        """
        if old_table_name == table_name:
            # No Operation
            return
        qn = connection.ops.quote_name
        params = (qn(old_table_name), qn(table_name))
        self.execute('ALTER TABLE %s RENAME TO %s;' % params)

    def delete_table(self, table_name):
        """
        Deletes the table 'table_name'.
        """
        qn = connection.ops.quote_name
        params = (qn(table_name), )
        self.execute('DROP TABLE %s;' % params)

    drop_table = alias('delete_table')

    def clear_table(self, table_name):
        """
        Deletes all rows from 'table_name'.
        """
        qn = connection.ops.quote_name
        params = (qn(table_name), )
        self.execute('DELETE FROM %s;' % params)

    add_column_string = 'ALTER TABLE %s ADD COLUMN %s;'

    def add_column(self, table_name, name, field, keep_default=True):
        """
        Adds the column 'name' to the table 'table_name'.
        Uses the 'field' parameter, a django.db.models.fields.Field instance,
        to generate the necessary sql.

        @param table_name: The name of the table to add the column to
        @param name: The name of the column to add
        @param field: The field to use
        @param keep_default: If False and the field has a default, the
                             default is used only to populate existing rows
                             and is then dropped from the column again
                             (note: this resets field.default).
        """
        qn = connection.ops.quote_name
        sql = self.column_sql(table_name, name, field)
        if sql:
            params = (
                qn(table_name),
                sql,
            )
            sql = self.add_column_string % params
            self.execute(sql)

            # Now, drop the default if we need to. has_default() is used
            # rather than the truthiness of field.default: NOT_PROVIDED is
            # truthy, and real defaults like 0 / False / '' are falsy.
            if not keep_default and field.has_default():
                field.default = NOT_PROVIDED
                self.alter_column(table_name, name, field, explicit_name=False)

    # Note the naming: "set_null" makes the column nullable (DROP NOT NULL),
    # "drop_null" forbids NULLs (SET NOT NULL).
    alter_string_set_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
    alter_string_set_null = 'ALTER COLUMN %(column)s DROP NOT NULL'
    alter_string_drop_null = 'ALTER COLUMN %(column)s SET NOT NULL'
    allows_combined_alters = True

    def alter_column(self, table_name, name, field, explicit_name=True):
        """
        Alters the given column name so it will match the given field.
        Note that conversion between the two by the database must be possible.
        Will not automatically add _id by default; to have this behaviour, pass
        explicit_name=False.

        The type, default and nullity are changed; uniqueness is not (TODO).

        @param table_name: The name of the table containing the column
        @param name: The name of the column to alter
        @param field: The new field definition to use
        """
        # hook for the field to do any resolution prior to its attributes being queried
        if hasattr(field, 'south_init'):
            field.south_init()

        qn = connection.ops.quote_name

        # Add _id or whatever if we need to
        if not explicit_name:
            field.set_attributes_from_name(name)
            name = field.column

        params = {
            "column": qn(name),
            "type": field.db_type(),
        }

        # SQLs is a list of (SQL, values) pairs. First, change the type.
        sqls = [(self.alter_string_set_type % params, [])]

        # Next, set any default (passed as a query parameter, not inlined)
        if not field.null and field.has_default():
            default = field.get_default()
            sqls.append(('ALTER COLUMN %s SET DEFAULT %%s ' % (qn(name),), [default]))
        else:
            sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (qn(name),), []))

        # Next, nullity
        if field.null:
            sqls.append((self.alter_string_set_null % params, []))
        else:
            sqls.append((self.alter_string_drop_null % params, []))

        # TODO: Unique

        if self.allows_combined_alters:
            sqls, values = zip(*sqls)
            self.execute(
                "ALTER TABLE %s %s;" % (qn(table_name), ", ".join(sqls)),
                flatten(values),
            )
        else:
            # Databases like e.g. MySQL don't like more than one alter at once.
            for sql, values in sqls:
                self.execute("ALTER TABLE %s %s;" % (qn(table_name), sql), values)

    def column_sql(self, table_name, field_name, field, tablespace=''):
        """
        Creates the SQL snippet for a column. Used by add_column and add_table.

        Side effects: queues deferred SQL for unique indexes, ordinary
        indexes, foreign-key constraints and any field post_create_sql().

        @param tablespace: Fallback tablespace for the column's index; the
                           field's own db_tablespace takes precedence.
        @return: The column definition string, or None if the field has no
                 database column type.
        """
        qn = connection.ops.quote_name

        field.set_attributes_from_name(field_name)

        # hook for the field to do any resolution prior to its attributes being queried
        if hasattr(field, 'south_init'):
            field.south_init()

        sql = field.db_type()
        if sql:
            # Resolve the field-level tablespace first so the unique index
            # below is created in the same tablespace.
            tablespace = field.db_tablespace or tablespace

            field_output = [qn(field.column), sql]
            field_output.append('%sNULL' % (not field.null and 'NOT ' or ''))
            if field.primary_key:
                field_output.append('PRIMARY KEY')
                if tablespace and connection.features.supports_tablespaces:
                    # The PK index is created inline, so its tablespace must
                    # be given inline too. This clause is only valid after an
                    # inline PRIMARY KEY/UNIQUE constraint.
                    field_output.append(connection.ops.tablespace_sql(tablespace, inline=True))
            elif field.unique:
                # Instead of using UNIQUE, add a unique index with a predictable name
                self.add_deferred_sql(
                    self.create_index_sql(
                        table_name,
                        [field.column],
                        unique = True,
                        db_tablespace = tablespace,
                    )
                )

            sql = ' '.join(field_output)
            # if the field is "NOT NULL" and a default value is provided, create the column with it
            # this allows the addition of a NOT NULL field to a table with existing rows
            if not field.null and field.has_default():
                default = field.get_default()
                # If the default is a callable, then call it!
                if callable(default):
                    default = default()
                # Now do some very cheap quoting. TODO: Redesign return values to avoid this.
                if isinstance(default, basestring):
                    default = "'%s'" % default.replace("'", "''")
                elif isinstance(default, datetime.date):
                    default = "'%s'" % default
                # Format only the DEFAULT clause: a 1-tuple keeps tuple-like
                # defaults intact, and any '%' elsewhere in the column SQL is
                # left alone.
                sql += " DEFAULT %s" % (default,)

            if field.rel and self.supports_foreign_keys:
                self.add_deferred_sql(
                    self.foreign_key_sql(
                        table_name,
                        field.column,
                        field.rel.to._meta.db_table,
                        field.rel.to._meta.get_field(field.rel.field_name).column
                    )
                )

            if field.db_index and not field.unique:
                self.add_deferred_sql(self.create_index_sql(table_name, [field.column]))

        if hasattr(field, 'post_create_sql'):
            style = no_style()
            for stmt in field.post_create_sql(style, table_name):
                self.add_deferred_sql(stmt)

        if sql:
            return sql
        return None

    supports_foreign_keys = True

    def foreign_key_sql(self, from_table_name, from_column_name, to_table_name, to_column_name):
        """
        Generates a full SQL statement to add a foreign key constraint.

        The constraint name combines both column names with a hash of the two
        table names, truncated to the backend's maximum identifier length.
        """
        qn = connection.ops.quote_name
        constraint_name = '%s_refs_%s_%x' % (from_column_name, to_column_name, abs(hash((from_table_name, to_table_name))))
        return 'ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % (
            qn(from_table_name),
            qn(truncate_name(constraint_name, connection.ops.max_name_length())),
            qn(from_column_name),
            qn(to_table_name),
            qn(to_column_name),
            connection.ops.deferrable_sql() # Django knows this
        )

    def create_index_name(self, table_name, column_names):
        """
        Generate a deterministic name for the index: '<table>_<first column>',
        plus a hash suffix of all columns when there is more than one.
        """
        index_unique_name = ''
        if len(column_names) > 1:
            index_unique_name = '_%x' % abs(hash((table_name, ','.join(column_names))))

        return '%s_%s%s' % (table_name, column_names[0], index_unique_name)

    def create_index_sql(self, table_name, column_names, unique=False, db_tablespace=''):
        """
        Generates a CREATE [UNIQUE] INDEX statement on 'table_name' for a list
        of 'column_names'. The index name comes from create_index_name().

        Returns '' (after printing a warning) if 'column_names' is empty.
        """
        if not column_names:
            print("No column names supplied on which to create an index")
            return ''

        if db_tablespace and connection.features.supports_tablespaces:
            tablespace_sql = ' ' + connection.ops.tablespace_sql(db_tablespace)
        else:
            tablespace_sql = ''

        qn = connection.ops.quote_name
        index_name = self.create_index_name(table_name, column_names)
        return 'CREATE %sINDEX %s ON %s (%s)%s;' % (
            unique and 'UNIQUE ' or '',
            qn(index_name),
            qn(table_name),
            ','.join([qn(field) for field in column_names]),
            tablespace_sql
        )

    def create_index(self, table_name, column_names, unique=False, db_tablespace=''):
        """ Executes a create index statement """
        sql = self.create_index_sql(table_name, column_names, unique, db_tablespace)
        self.execute(sql)

    drop_index_string = 'DROP INDEX %(index_name)s'

    def delete_index(self, table_name, column_names, db_tablespace=''):
        """
        Deletes an index created with create_index.
        This is possible using only columns due to the deterministic
        index naming function which relies on column names.

        'column_names' may be a single column name or a list of names.
        'db_tablespace' is accepted for symmetry with create_index but is
        not used.
        """
        if isinstance(column_names, basestring):
            column_names = [column_names]
        name = self.create_index_name(table_name, column_names)
        qn = connection.ops.quote_name
        sql = self.drop_index_string % {"index_name": qn(name), "table_name": qn(table_name)}
        self.execute(sql)

    drop_index = alias('delete_index')

    delete_column_string = 'ALTER TABLE %s DROP COLUMN %s CASCADE;'

    def delete_column(self, table_name, name):
        """
        Deletes the column 'name' from the table 'table_name'.
        """
        qn = connection.ops.quote_name
        params = (qn(table_name), qn(name))
        self.execute(self.delete_column_string % params, [])

    drop_column = alias('delete_column')

    def rename_column(self, table_name, old, new):
        """
        Renames the column 'old' from the table 'table_name' to 'new'.
        There is no generic implementation; backends must override this.
        """
        raise NotImplementedError("rename_column has no generic SQL syntax")

    def start_transaction(self):
        """
        Makes sure the following commands are inside a transaction.
        Must be followed by a (commit|rollback)_transaction call.
        No-op in dry-run mode.
        """
        if self.dry_run:
            return
        transaction.commit_unless_managed()
        transaction.enter_transaction_management()
        transaction.managed(True)

    def commit_transaction(self):
        """
        Commits the current transaction.
        Must be preceded by a start_transaction call. No-op in dry-run mode.
        """
        if self.dry_run:
            return
        transaction.commit()
        transaction.leave_transaction_management()

    def rollback_transaction(self):
        """
        Rolls back the current transaction.
        Must be preceded by a start_transaction call. No-op in dry-run mode.
        """
        if self.dry_run:
            return
        transaction.rollback()
        transaction.leave_transaction_management()

    def send_create_signal(self, app_label, model_names):
        """
        Queues a post_syncdb signal; it is sent by send_pending_create_signals().
        """
        self.pending_create_signals.append((app_label, model_names))

    def send_pending_create_signals(self):
        """
        Sends every queued create signal, then empties the queue.
        """
        for (app_label, model_names) in self.pending_create_signals:
            self.really_send_create_signal(app_label, model_names)
        self.pending_create_signals = []

    def really_send_create_signal(self, app_label, model_names):
        """
        Sends a post_syncdb signal for the models specified.

        Models that cannot be found (perhaps they've been deleted?) are
        skipped; if none are found, no signal is sent.

        TODO: The behavior of django.contrib.* apps seems flawed in that
        they don't respect created_models.  Rather, they blindly execute
        over all models within the app sending the signal.  This is a
        patch we should push Django to make.  For now, this should work.
        """
        if self.debug:
            print(" - Sending post_syncdb signal for %s: %s" % (app_label, model_names))
        app = models.get_app(app_label)
        if not app:
            return

        created_models = []
        for model_name in model_names:
            model = models.get_model(app_label, model_name)
            if model:
                created_models.append(model)

        if created_models:
            # syncdb defaults -- perhaps take these as options?
            verbosity = 1
            interactive = True

            # Support both the old-style dispatcher.send() API and Signal.send().
            if hasattr(dispatcher, "send"):
                dispatcher.send(signal=models.signals.post_syncdb, sender=app,
                                app=app, created_models=created_models,
                                verbosity=verbosity, interactive=interactive)
            else:
                models.signals.post_syncdb.send(sender=app,
                                                app=app, created_models=created_models,
                                                verbosity=verbosity, interactive=interactive)

    def mock_model(self, model_name, db_table, db_tablespace='',
                   pk_field_name='id', pk_field_type=models.AutoField,
                   pk_field_args=[], pk_field_kwargs={}):
        """
        Generates a MockModel class that provides enough information
        to be used by a foreign key/many-to-many relationship.

        Migrations should prefer to use these rather than actual models
        as models could get deleted over time, but these can remain in
        migration files forever.

        'pk_field_args' / 'pk_field_kwargs' are passed to 'pk_field_type' to
        build the primary key field; they are never modified.
        """
        # Work on a copy: mutating the shared default dict (or the caller's
        # dict) would leak primary_key=True into later, unrelated calls.
        pk_kwargs = dict(pk_field_kwargs)
        if pk_field_type == models.AutoField:
            pk_kwargs['primary_key'] = True

        class MockOptions(object):
            def __init__(self):
                self.db_table = db_table
                self.db_tablespace = db_tablespace or settings.DEFAULT_TABLESPACE
                self.object_name = model_name
                self.module_name = model_name.lower()

                self.pk = pk_field_type(*pk_field_args, **pk_kwargs)
                self.pk.set_attributes_from_name(pk_field_name)
                self.abstract = False

            def get_field_by_name(self, field_name):
                # we only care about the pk field
                return (self.pk, self.model, True, False)

            def get_field(self, name):
                # we only care about the pk field
                return self.pk

        class MockModel(object):
            _meta = None

        # We need to return an actual class object here, not an instance
        MockModel._meta = MockOptions()
        MockModel._meta.model = MockModel
        return MockModel
536
# Single-level flattening of lists
def flatten(ls):
    """
    Flatten one level of nesting: return a new list holding the items of
    each element of 'ls', in order.
    """
    return [item for sub in ls for item in sub]
543