Uaktualnienie django-south.
[wolnelektury.git] / apps / south / db / generic.py
index 09dde03..4a5b512 100644 (file)
@@ -1,10 +1,23 @@
 
+import datetime
 from django.core.management.color import no_style
 from django.db import connection, transaction, models
 from django.db.backends.util import truncate_name
+from django.db.models.fields import NOT_PROVIDED
 from django.dispatch import dispatcher
 from django.conf import settings
 
+
+def alias(attrname):
+    """
+    Returns a function which calls 'attrname' - for function aliasing.
+    We can't just use foo = bar, as this breaks subclassing.
+    """
+    def func(self, *args, **kwds):
+        return getattr(self, attrname)(*args, **kwds)
+    return func
+
+
 class DatabaseOperations(object):
 
     """
@@ -12,10 +25,14 @@ class DatabaseOperations(object):
     Some of this code comes from Django Evolution.
     """
 
+    # We assume the generic DB can handle DDL transactions. MySQL will change this.
+    has_ddl_transactions = True
+
     def __init__(self):
         self.debug = False
         self.deferred_sql = []
-
+        self.dry_run = False
+        self.pending_create_signals = []
 
     def execute(self, sql, params=[]):
         """
@@ -25,31 +42,50 @@ class DatabaseOperations(object):
         cursor = connection.cursor()
         if self.debug:
             print "   = %s" % sql, params
+
+        if self.dry_run:
+            return []
+
         cursor.execute(sql, params)
         try:
             return cursor.fetchall()
         except:
             return []
-            
-            
+
+
     def add_deferred_sql(self, sql):
         """
         Add a SQL statement to the deferred list, that won't be executed until
         this instance's execute_deferred_sql method is run.
         """
         self.deferred_sql.append(sql)
-        
-        
+
+
     def execute_deferred_sql(self):
         """
         Executes all deferred SQL, resetting the deferred_sql list
         """
         for sql in self.deferred_sql:
             self.execute(sql)
-            
+
         self.deferred_sql = []
 
 
+    def clear_deferred_sql(self):
+        """
+        Resets the deferred_sql list to empty.
+        """
+        self.deferred_sql = []
+    
+    
+    def clear_run_data(self):
+        """
+        Resets variables to how they should be before a run. Used for dry runs.
+        """
+        self.clear_deferred_sql()
+        self.pending_create_signals = []
+
+
     def create_table(self, table_name, fields):
         """
         Creates the table 'table_name'. 'fields' is a tuple of fields,
@@ -57,14 +93,22 @@ class DatabaseOperations(object):
         django.db.models.fields.Field object
         """
         qn = connection.ops.quote_name
+
+        # allow fields to be a dictionary
+        # removed for now - philosophical reasons (this is almost certainly not what you want)
+        #try:
+        #    fields = fields.items()
+        #except AttributeError:
+        #    pass
+
         columns = [
             self.column_sql(table_name, field_name, field)
             for field_name, field in fields
         ]
-        
+
         self.execute('CREATE TABLE %s (%s);' % (qn(table_name), ', '.join([col for col in columns if col])))
-    
-    add_table = create_table # Alias for consistency's sake
+
+    add_table = alias('create_table') # Alias for consistency's sake
 
 
     def rename_table(self, old_table_name, table_name):
@@ -86,16 +130,26 @@ class DatabaseOperations(object):
         qn = connection.ops.quote_name
         params = (qn(table_name), )
         self.execute('DROP TABLE %s;' % params)
-    
-    drop_table = delete_table
+
+    drop_table = alias('delete_table')
 
 
-    def add_column(self, table_name, name, field):
+    def clear_table(self, table_name):
+        """
+        Deletes all rows from 'table_name'.
+        """
+        qn = connection.ops.quote_name
+        params = (qn(table_name), )
+        self.execute('DELETE FROM %s;' % params)
+
+    add_column_string = 'ALTER TABLE %s ADD COLUMN %s;'
+
+    def add_column(self, table_name, name, field, keep_default=True):
         """
         Adds the column 'name' to the table 'table_name'.
         Uses the 'field' paramater, a django.db.models.fields.Field instance,
         to generate the necessary sql
-        
+
         @param table_name: The name of the table to add the column to
         @param name: The name of the column to add
         @param field: The field to use
@@ -107,68 +161,82 @@ class DatabaseOperations(object):
                 qn(table_name),
                 sql,
             )
-            sql = 'ALTER TABLE %s ADD COLUMN %s;' % params
+            sql = self.add_column_string % params
             self.execute(sql)
-    
-    
+
+            # Now, drop the default if we need to
+            if not keep_default and field.default:
+                field.default = NOT_PROVIDED
+                self.alter_column(table_name, name, field, explicit_name=False)
+
     alter_string_set_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
-    alter_string_set_null = 'ALTER COLUMN %(column)s SET NOT NULL'
-    alter_string_drop_null = 'ALTER COLUMN %(column)s DROP NOT NULL'
-    
-    def alter_column(self, table_name, name, field):
+    alter_string_set_null = 'ALTER COLUMN %(column)s DROP NOT NULL'
+    alter_string_drop_null = 'ALTER COLUMN %(column)s SET NOT NULL'
+    allows_combined_alters = True
+
+    def alter_column(self, table_name, name, field, explicit_name=True):
         """
         Alters the given column name so it will match the given field.
         Note that conversion between the two by the database must be possible.
-        
+        Will not automatically add _id by default; to have this behaviour, pass
+        explicit_name=False.
+
         @param table_name: The name of the table to add the column to
         @param name: The name of the column to alter
         @param field: The new field definition to use
         """
-        
+
         # hook for the field to do any resolution prior to it's attributes being queried
         if hasattr(field, 'south_init'):
             field.south_init()
-        
+
         qn = connection.ops.quote_name
         
+        # Add _id or whatever if we need to
+        if not explicit_name:
+            field.set_attributes_from_name(name)
+            name = field.column
+
         # First, change the type
         params = {
             "column": qn(name),
             "type": field.db_type(),
         }
-        sqls = [self.alter_string_set_type % params]
-        
-        
+
+        # SQLs is a list of (SQL, values) pairs.
+        sqls = [(self.alter_string_set_type % params, [])]
+
         # Next, set any default
-        params = (
-            qn(name),
-        )
-        
         if not field.null and field.has_default():
             default = field.get_default()
-            if isinstance(default, basestring):
-                default = "'%s'" % default
-            params += ("SET DEFAULT %s",)
+            sqls.append(('ALTER COLUMN %s SET DEFAULT %%s ' % (qn(name),), [default]))
         else:
-            params += ("DROP DEFAULT",)
-        
-        sqls.append('ALTER COLUMN %s %s ' % params)
-        
-        
+            sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (qn(name),), []))
+
+
         # Next, nullity
         params = {
             "column": qn(name),
             "type": field.db_type(),
         }
         if field.null:
-            sqls.append(self.alter_string_drop_null % params)
+            sqls.append((self.alter_string_set_null % params, []))
         else:
-            sqls.append(self.alter_string_set_null % params)
-        
-        
+            sqls.append((self.alter_string_drop_null % params, []))
+
+
         # TODO: Unique
-        
-        self.execute("ALTER TABLE %s %s;" % (qn(table_name), ", ".join(sqls)))
+
+        if self.allows_combined_alters:
+            sqls, values = zip(*sqls)
+            self.execute(
+                "ALTER TABLE %s %s;" % (qn(table_name), ", ".join(sqls)),
+                flatten(values),
+            )
+        else:
+            # Some databases (e.g. MySQL) don't like more than one alter at once.
+            for sql, values in sqls:
+                self.execute("ALTER TABLE %s %s;" % (qn(table_name), sql), values)
 
 
     def column_sql(self, table_name, field_name, field, tablespace=''):
@@ -176,13 +244,13 @@ class DatabaseOperations(object):
         Creates the SQL snippet for a column. Used by add_column and add_table.
         """
         qn = connection.ops.quote_name
-        
+
         field.set_attributes_from_name(field_name)
-        
+
         # hook for the field to do any resolution prior to it's attributes being queried
         if hasattr(field, 'south_init'):
             field.south_init()
-        
+
         sql = field.db_type()
         if sql:        
             field_output = [qn(field.column), sql]
@@ -190,26 +258,40 @@ class DatabaseOperations(object):
             if field.primary_key:
                 field_output.append('PRIMARY KEY')
             elif field.unique:
-                field_output.append('UNIQUE')
-        
+                # Instead of using UNIQUE, add a unique index with a predictable name
+                self.add_deferred_sql(
+                    self.create_index_sql(
+                        table_name,
+                        [field.column],
+                        unique = True,
+                        db_tablespace = tablespace,
+                    )
+                )
+
             tablespace = field.db_tablespace or tablespace
             if tablespace and connection.features.supports_tablespaces and field.unique:
                 # We must specify the index tablespace inline, because we
                 # won't be generating a CREATE INDEX statement for this field.
                 field_output.append(connection.ops.tablespace_sql(tablespace, inline=True))
-            
+
             sql = ' '.join(field_output)
             sqlparams = ()
             # if the field is "NOT NULL" and a default value is provided, create the column with it
             # this allows the addition of a NOT NULL field to a table with existing rows
             if not field.null and field.has_default():
                 default = field.get_default()
+                # If the default is a callable, then call it!
+                if callable(default):
+                    default = default()
+                # Now do some very cheap quoting. TODO: Redesign return values to avoid this.
                 if isinstance(default, basestring):
                     default = "'%s'" % default.replace("'", "''")
+                elif isinstance(default, datetime.date):
+                    default = "'%s'" % default
                 sql += " DEFAULT %s"
                 sqlparams = (default)
-            
-            if field.rel:
+
+            if field.rel and self.supports_foreign_keys:
                 self.add_deferred_sql(
                     self.foreign_key_sql(
                         table_name,
@@ -218,10 +300,10 @@ class DatabaseOperations(object):
                         field.rel.to._meta.get_field(field.rel.field_name).column
                     )
                 )
-            
+
             if field.db_index and not field.unique:
                 self.add_deferred_sql(self.create_index_sql(table_name, [field.column]))
-            
+
         if hasattr(field, 'post_create_sql'):
             style = no_style()
             for stmt in field.post_create_sql(style, table_name):
@@ -231,21 +313,26 @@ class DatabaseOperations(object):
             return sql % sqlparams
         else:
             return None
-        
+
+
+    supports_foreign_keys = True
+
     def foreign_key_sql(self, from_table_name, from_column_name, to_table_name, to_column_name):
         """
         Generates a full SQL statement to add a foreign key constraint
         """
+        qn = connection.ops.quote_name
         constraint_name = '%s_refs_%s_%x' % (from_column_name, to_column_name, abs(hash((from_table_name, to_table_name))))
         return 'ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % (
-            from_table_name,
-            truncate_name(constraint_name, connection.ops.max_name_length()),
-            from_column_name,
-            to_table_name,
-            to_column_name,
+            qn(from_table_name),
+            qn(truncate_name(constraint_name, connection.ops.max_name_length())),
+            qn(from_column_name),
+            qn(to_table_name),
+            qn(to_column_name),
             connection.ops.deferrable_sql() # Django knows this
         )
-        
+
+
     def create_index_name(self, table_name, column_names):
         """
         Generate a unique name for the index
@@ -256,45 +343,55 @@ class DatabaseOperations(object):
 
         return '%s_%s%s' % (table_name, column_names[0], index_unique_name)
 
+
     def create_index_sql(self, table_name, column_names, unique=False, db_tablespace=''):
         """
         Generates a create index statement on 'table_name' for a list of 'column_names'
         """
+        qn = connection.ops.quote_name
         if not column_names:
             print "No column names supplied on which to create an index"
             return ''
-            
+
         if db_tablespace and connection.features.supports_tablespaces:
             tablespace_sql = ' ' + connection.ops.tablespace_sql(db_tablespace)
         else:
             tablespace_sql = ''
-        
+
         index_name = self.create_index_name(table_name, column_names)
         qn = connection.ops.quote_name
         return 'CREATE %sINDEX %s ON %s (%s)%s;' % (
             unique and 'UNIQUE ' or '',
-            index_name,
-            table_name,
+            qn(index_name),
+            qn(table_name),
             ','.join([qn(field) for field in column_names]),
             tablespace_sql
-            )
-        
+        )
+
     def create_index(self, table_name, column_names, unique=False, db_tablespace=''):
         """ Executes a create index statement """
         sql = self.create_index_sql(table_name, column_names, unique, db_tablespace)
         self.execute(sql)
 
 
+    drop_index_string = 'DROP INDEX %(index_name)s'
+
     def delete_index(self, table_name, column_names, db_tablespace=''):
         """
         Deletes an index created with create_index.
         This is possible using only columns due to the deterministic
         index naming function which relies on column names.
         """
+        if isinstance(column_names, (str, unicode)):
+            column_names = [column_names]
         name = self.create_index_name(table_name, column_names)
-        sql = "DROP INDEX %s" % name
+        qn = connection.ops.quote_name
+        sql = self.drop_index_string % {"index_name": qn(name), "table_name": qn(table_name)}
         self.execute(sql)
 
+    drop_index = alias('delete_index')
+
+    delete_column_string = 'ALTER TABLE %s DROP COLUMN %s CASCADE;'
 
     def delete_column(self, table_name, name):
         """
@@ -302,7 +399,9 @@ class DatabaseOperations(object):
         """
         qn = connection.ops.quote_name
         params = (qn(table_name), qn(name))
-        self.execute('ALTER TABLE %s DROP COLUMN %s CASCADE;' % params, [])
+        self.execute(self.delete_column_string % params, [])
+
+    drop_column = alias('delete_column')
 
 
     def rename_column(self, table_name, old, new):
@@ -317,6 +416,8 @@ class DatabaseOperations(object):
         Makes sure the following commands are inside a transaction.
         Must be followed by a (commit|rollback)_transaction call.
         """
+        if self.dry_run:
+            return
         transaction.commit_unless_managed()
         transaction.enter_transaction_management()
         transaction.managed(True)
@@ -327,6 +428,8 @@ class DatabaseOperations(object):
         Commits the current transaction.
         Must be preceded by a start_transaction call.
         """
+        if self.dry_run:
+            return
         transaction.commit()
         transaction.leave_transaction_management()
 
@@ -336,53 +439,67 @@ class DatabaseOperations(object):
         Rolls back the current transaction.
         Must be preceded by a start_transaction call.
         """
+        if self.dry_run:
+            return
         transaction.rollback()
         transaction.leave_transaction_management()
-    
-    
+
+
     def send_create_signal(self, app_label, model_names):
+        self.pending_create_signals.append((app_label, model_names))
+
+
+    def send_pending_create_signals(self):
+        for (app_label, model_names) in self.pending_create_signals:
+            self.really_send_create_signal(app_label, model_names)
+        self.pending_create_signals = []
+
+
+    def really_send_create_signal(self, app_label, model_names):
         """
         Sends a post_syncdb signal for the model specified.
-        
+
         If the model is not found (perhaps it's been deleted?),
         no signal is sent.
-        
+
         TODO: The behavior of django.contrib.* apps seems flawed in that
         they don't respect created_models.  Rather, they blindly execute
         over all models within the app sending the signal.  This is a
         patch we should push Django to make  For now, this should work.
         """
+        if self.debug:
+            print " - Sending post_syncdb signal for %s: %s" % (app_label, model_names)
         app = models.get_app(app_label)
         if not app:
             return
-            
+
         created_models = []
         for model_name in model_names:
             model = models.get_model(app_label, model_name)
             if model:
                 created_models.append(model)
-                
+
         if created_models:
             # syncdb defaults -- perhaps take these as options?
             verbosity = 1
             interactive = True
-            
+
             if hasattr(dispatcher, "send"):
                 dispatcher.send(signal=models.signals.post_syncdb, sender=app,
-                app=app, created_models=created_models,
-                verbosity=verbosity, interactive=interactive)
+                                app=app, created_models=created_models,
+                                verbosity=verbosity, interactive=interactive)
             else:
                 models.signals.post_syncdb.send(sender=app,
-                app=app, created_models=created_models,
-                verbosity=verbosity, interactive=interactive)
-                
+                                                app=app, created_models=created_models,
+                                                verbosity=verbosity, interactive=interactive)
+
     def mock_model(self, model_name, db_table, db_tablespace='', 
-                    pk_field_name='id', pk_field_type=models.AutoField,
-                    pk_field_kwargs={}):
+                   pk_field_name='id', pk_field_type=models.AutoField,
+                   pk_field_args=[], pk_field_kwargs={}):
         """
         Generates a MockModel class that provides enough information
         to be used by a foreign key/many-to-many relationship.
-        
+
         Migrations should prefer to use these rather than actual models
         as models could get deleted over time, but these can remain in
         migration files forever.
@@ -397,7 +514,7 @@ class DatabaseOperations(object):
                 if pk_field_type == models.AutoField:
                     pk_field_kwargs['primary_key'] = True
 
-                self.pk = pk_field_type(**pk_field_kwargs)
+                self.pk = pk_field_type(*pk_field_args, **pk_field_kwargs)
                 self.pk.set_attributes_from_name(pk_field_name)
                 self.abstract = False
 
@@ -416,3 +533,11 @@ class DatabaseOperations(object):
         MockModel._meta = MockOptions()
         MockModel._meta.model = MockModel
         return MockModel
+
+# Single-level flattening of lists
+def flatten(ls):
+    nl = []
+    for l in ls:
+        nl += l
+    return nl
+