1a8da997783ae434fb62eeb2a2139942be16d95a
[wolnelektury.git] / apps / south / management / commands / startmigration.py
1 from django.core.management.base import BaseCommand
2 from django.core.management.color import no_style
3 from django.db import models
4 from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
5 from django.contrib.contenttypes.generic import GenericRelation
6 from django.db.models.fields import FieldDoesNotExist
7 from optparse import make_option
8 from south import migration
9 import sys
10 import os
11 import re
12 import string
13 import random
14 import inspect
15 import parser
16
17 class Command(BaseCommand):
18     option_list = BaseCommand.option_list + (
19         make_option('--model', action='append', dest='model_list', type='string',
20             help='Generate a Create Table migration for the specified model.  Add multiple models to this migration with subsequent --model parameters.'),
21         make_option('--add-field', action='append', dest='field_list', type='string',
22             help='Generate an Add Column migration for the specified modelname.fieldname - you can use this multiple times to add more than one column.'),
23         make_option('--initial', action='store_true', dest='initial', default=False,
24             help='Generate the initial schema for the app.'),
25     )
26     help = "Creates a new template migration for the given app"
27     
28     def handle(self, app=None, name="", model_list=None, field_list=None, initial=False, **options):
29         
30         # If model_list is None, then it's an empty list
31         model_list = model_list or []
32         
33         # If field_list is None, then it's an empty list
34         field_list = field_list or []
35         
36         # make sure --model and --all aren't both specified
37         if initial and (model_list or field_list):
38             print "You cannot use --initial and other options together"
39             return
40             
41         # specify the default name 'initial' if a name wasn't specified and we're
42         # doing a migration for an entire app
43         if not name and initial:
44             name = 'initial'
45             
46         # if not name, there's an error
47         if not name:
48             print "You must name this migration"
49             return
50         
51         if not app:
52             print "Please provide an app in which to create the migration."
53             return
54             
55         # See if the app exists
56         app_models_module = models.get_app(app)
57         if not app_models_module:
58             print "App '%s' doesn't seem to exist, isn't in INSTALLED_APPS, or has no models." % app
59             return
60             
61         # Determine what models should be included in this migration.
62         models_to_migrate = []
63         if initial:
64             models_to_migrate = models.get_models(app_models_module)
65             if not models_to_migrate:
66                 print "No models found in app '%s'" % (app)
67                 return
68         else:
69             for model_name in model_list:
70                 model = models.get_model(app, model_name)
71                 if not model:
72                     print "Couldn't find model '%s' in app '%s'" % (model_name, app)
73                     return
74                     
75                 models_to_migrate.append(model)
76         
77         # See what fields need to be included
78         fields_to_add = []
79         for field_spec in field_list:
80             model_name, field_name = field_spec.split(".", 1)
81             model = models.get_model(app, model_name)
82             if not model:
83                 print "Couldn't find model '%s' in app '%s'" % (model_name, app)
84                 return
85             try:
86                 field = model._meta.get_field(field_name)
87             except FieldDoesNotExist:
88                 print "Model '%s' doesn't have a field '%s'" % (model_name, field_name)
89                 return
90             fields_to_add.append((model, field_name, field))
91         
92         # Make the migrations directory if it's not there
93         app_module_path = app_models_module.__name__.split('.')[0:-1]
94         try:
95             app_module = __import__('.'.join(app_module_path), {}, {}, [''])
96         except ImportError:
97             print "Couldn't find path to App '%s'." % app
98             return
99             
100         migrations_dir = os.path.join(
101             os.path.dirname(app_module.__file__),
102             "migrations",
103         )
104         # Make sure there's a migrations directory and __init__.py
105         if not os.path.isdir(migrations_dir):
106             print "Creating migrations directory at '%s'..." % migrations_dir
107             os.mkdir(migrations_dir)
108         init_path = os.path.join(migrations_dir, "__init__.py")
109         if not os.path.isfile(init_path):
110             # Touch the init py file
111             print "Creating __init__.py in '%s'..." % migrations_dir
112             open(init_path, "w").close()
113         # See what filename is next in line. We assume they use numbers.
114         migrations = migration.get_migration_names(migration.get_app(app))
115         highest_number = 0
116         for migration_name in migrations:
117             try:
118                 number = int(migration_name.split("_")[0])
119                 highest_number = max(highest_number, number)
120             except ValueError:
121                 pass
122         # Make the new filename
123         new_filename = "%04i%s_%s.py" % (
124             highest_number + 1,
125             "".join([random.choice(string.letters.lower()) for i in range(0)]), # Possible random stuff insertion
126             name,
127         )
128         # If there's a model, make the migration skeleton, else leave it bare
129         forwards, backwards = '', ''
130         if fields_to_add:
131             # First, do the added fields
132             for model, field_name, field in fields_to_add:
133                 field_definition = generate_field_definition(model, field)
134                 
135                 if isinstance(field, models.ManyToManyField):
136                     # Make a mock model for each side
137                     mock_model = "\n".join([
138                         create_mock_model(model, "        "), 
139                         create_mock_model(field.rel.to, "        ")
140                     ])
141                     # And a field defn, that's actually a table creation
142                     forwards += '''
143         # Mock Model
144 %s
145         # Adding ManyToManyField '%s.%s'
146         db.create_table('%s', (
147             ('id', models.AutoField(verbose_name='ID', primary_key=True, auto_created=True)),
148             ('%s', models.ForeignKey(%s, null=False)),
149             ('%s', models.ForeignKey(%s, null=False))
150         )) ''' % (
151                 mock_model,
152                 model._meta.object_name,
153                 field.name,
154                 field.m2m_db_table(),
155                 field.m2m_column_name()[:-3], # strip off the '_id' at the end
156                 model._meta.object_name,
157                 field.m2m_reverse_name()[:-3], # strip off the '_id' at the ned
158                 field.rel.to._meta.object_name
159                 )
160                     backwards += '''
161         # Dropping ManyToManyField '%s.%s'
162         db.drop_table('%s')''' % (
163                         model._meta.object_name,
164                         field.name,
165                         field.m2m_db_table()
166                     )
167                     continue
168                 elif field.rel: # ForeignKey, etc.
169                     mock_model = create_mock_model(field.rel.to, "        ")
170                     field_definition = related_field_definition(field, field_definition)
171                 else:
172                     mock_model = None
173                 
174                 # If we can't get it (inspect madness?) then insert placeholder
175                 if not field_definition:
176                     print "Warning: Could not generate field definition for %s.%s, manual editing of migration required." % \
177                                 (model._meta.object_name, field.name)
178                     field_definition = '<<< REPLACE THIS WITH FIELD DEFINITION FOR %s.%s >>>' % (model._meta.object_name, f.name)
179                 
180                 if mock_model:
181                     forwards += '''
182         # Mock model
183 %s
184         ''' % (mock_model)
185                 
186                 forwards += '''
187         # Adding field '%s.%s'
188         db.add_column(%r, %r, %s)
189         ''' % (
190             model._meta.object_name,
191             field.name,
192             model._meta.db_table,
193             field.name,
194             field_definition,
195         )
196                 backwards += '''
197         # Deleting field '%s.%s'
198         db.delete_column(%r, %r)
199         ''' % (
200             model._meta.object_name,
201             field.name,
202             model._meta.db_table,
203             field.column,
204         )
205         
206         if models_to_migrate:
207             # Now, do the added models
208             for model in models_to_migrate:
209                 table_name = model._meta.db_table
210                 mock_models = []
211                 fields = []
212                 for f in model._meta.local_fields:
213                     
214                     # Look up the field definition to see how this was created
215                     field_definition = generate_field_definition(model, f)
216                     
217                     # If it's a OneToOneField, and ends in _ptr, just use it
218                     if isinstance(f, models.OneToOneField) and f.name.endswith("_ptr"):
219                         mock_models.append(create_mock_model(f.rel.to, "        "))
220                         field_definition = "models.OneToOneField(%s)" % f.rel.to.__name__
221                     
222                     # It's probably normal then
223                     elif field_definition:
224                         
225                         if isinstance(f, models.ForeignKey):
226                             mock_models.append(create_mock_model(f.rel.to, "        "))
227                             field_definition = related_field_definition(f, field_definition)
228                     
229                     # Oh noes, no defn found
230                     else:
231                         print "Warning: Could not generate field definition for %s.%s, manual editing of migration required." % \
232                                 (model._meta.object_name, f.name)
233                         print f, type(f)
234                                 
235                         field_definition = '<<< REPLACE THIS WITH FIELD DEFINITION FOR %s.%s >>>' % (model._meta.object_name, f.name)
236                                                 
237                     fields.append((f.name, field_definition))
238                     
239                 if mock_models:
240                     forwards += '''
241         
242         # Mock Models
243 %s
244         ''' % "\n".join(mock_models)
245         
246                 forwards += '''
247         # Model '%s'
248         db.create_table(%r, (
249             %s
250         ))''' % (
251                     model._meta.object_name,
252                     table_name,
253                     "\n            ".join(["('%s', %s)," % (f[0], f[1]) for f in fields]),
254                 )
255
256                 backwards = ('''db.delete_table('%s')
257         ''' % table_name) + backwards
258         
259                 # Now go through local M2Ms and add extra stuff for them
260                 for m in model._meta.local_many_to_many:
261                     # ignore generic relations
262                     if isinstance(m, GenericRelation):
263                         continue
264
265                     # if the 'through' option is specified, the table will
266                     # be created through the normal model creation above.
267                     if m.rel.through:
268                         continue
269                         
270                     mock_models = [create_mock_model(model, "        "), create_mock_model(m.rel.to, "        ")]
271                     
272                     forwards += '''
273         # Mock Models
274 %s
275         
276         # M2M field '%s.%s'
277         db.create_table('%s', (
278             ('id', models.AutoField(verbose_name='ID', primary_key=True, auto_created=True)),
279             ('%s', models.ForeignKey(%s, null=False)),
280             ('%s', models.ForeignKey(%s, null=False))
281         )) ''' % (
282                         "\n".join(mock_models),
283                         model._meta.object_name,
284                         m.name,
285                         m.m2m_db_table(),
286                         m.m2m_column_name()[:-3], # strip off the '_id' at the end
287                         model._meta.object_name,
288                         m.m2m_reverse_name()[:-3], # strip off the '_id' at the ned
289                         m.rel.to._meta.object_name
290                 )
291                 
292                     backwards = '''db.delete_table('%s')
293         ''' % m.m2m_db_table() + backwards
294                 
295                 if model._meta.unique_together:
296                     ut = model._meta.unique_together
297                     if not isinstance(ut[0], (list, tuple)):
298                         ut = (ut,)
299                         
300                     for unique in ut:
301                         columns = ["'%s'" % model._meta.get_field(f).column for f in unique]
302                         
303                         forwards += '''
304         db.create_index('%s', [%s], unique=True, db_tablespace='%s')
305         ''' %   (
306                         table_name,
307                         ','.join(columns),
308                         model._meta.db_tablespace
309                 )
310                 
311                 
312             forwards += '''
313         
314         db.send_create_signal('%s', ['%s'])''' % (
315                 app, 
316                 "','".join(model._meta.object_name for model in models_to_migrate)
317                 )
318         
319         # Try sniffing the encoding using PEP 0263's method
320         encoding = None
321         first_two_lines = inspect.getsourcelines(app_models_module)[0][:2]
322         for line in first_two_lines:
323             if re.search("coding[:=]\s*([-\w.]+)", line):
324                 encoding = line
325         
326         if (not forwards) and (not backwards):
327             forwards = '"Write your forwards migration here"'
328             backwards = '"Write your backwards migration here"'
329         fp = open(os.path.join(migrations_dir, new_filename), "w")
330         fp.write("""%s
331 from south.db import db
332 from django.db import models
333 from %s.models import *
334
335 class Migration:
336     
337     def forwards(self):
338         %s
339     
340     def backwards(self):
341         %s
342 """ % (encoding or "", '.'.join(app_module_path), forwards, backwards))
343         fp.close()
344         print "Created %s." % new_filename
345
346
347 def generate_field_definition(model, field):
348     """
349     Inspects the source code of 'model' to find the code used to generate 'field'
350     """
351     def test_field(field_definition):
352         try:
353             parser.suite(field_definition)
354             return True
355         except SyntaxError:
356             return False
357             
358     def strip_comments(field_definition):
359         # remove any comments at the end of the field definition string.
360         field_definition = field_definition.strip()
361         if '#' not in field_definition:
362             return field_definition
363             
364         index = field_definition.index('#')
365         while index:
366             stripped_definition = field_definition[:index].strip()
367             # if the stripped definition is parsable, then we've removed
368             # the correct comment.
369             if test_field(stripped_definition):
370                 return stripped_definition
371             
372             try:    
373                 index = field_definition.index('#', index+1)
374             except ValueError:
375                 break
376             
377         return field_definition
378         
379     # give field subclasses a chance to do anything tricky
380     # with the field definition
381     if hasattr(field, 'south_field_definition'):
382         return field.south_field_definition()
383     
384     field_pieces = []
385     found_field = False
386     source = inspect.getsourcelines(model)
387     if not source:
388         raise Exception("Could not find source to model: '%s'" % (model.__name__))
389     
390     # look for a line starting with the field name
391     start_field_re = re.compile(r'\s*%s\s*=\s*(.*)' % field.name)
392     for line in source[0]:
393         # if the field was found during a previous iteration, 
394         # we're here because the field spans across multiple lines
395         # append the current line and try again
396         if found_field:
397             field_pieces.append(line.strip())
398             if test_field(' '.join(field_pieces)):
399                 return strip_comments(' '.join(field_pieces))
400             continue
401         
402         match = start_field_re.match(line)
403         if match:
404             found_field = True
405             field_pieces.append(match.groups()[0].strip())
406             if test_field(' '.join(field_pieces)):
407                 return strip_comments(' '.join(field_pieces))
408     
409     # the 'id' field never gets defined, so return what django does by default
410     # django.db.models.options::_prepare
411     if field.name == 'id' and field.__class__ == models.AutoField:
412         return "models.AutoField(verbose_name='ID', primary_key=True, auto_created=True)"
413     
414     # search this classes parents
415     for base in model.__bases__:
416         # we don't want to scan the django base model
417         if base == models.Model:
418             continue
419             
420         field_definition = generate_field_definition(base, field)
421         if field_definition:
422             return field_definition
423             
424     return None
425     
426 def replace_model_string(field_definition, search_string, model_name):
427     # wrap 'search_string' in both ' and " chars when searching
428     quotes = ["'", '"']
429     for quote in quotes:
430         test = "%s%s%s" % (quote, search_string, quote)
431         if test in field_definition:
432             return field_definition.replace(test, model_name)
433             
434     return None
435         
436 def related_field_definition(field, field_definition):
437     # if the field definition contains any of the following strings,
438     # replace them with the model definition:
439     #   applabel.modelname
440     #   modelname
441     #   django.db.models.fields.related.RECURSIVE_RELATIONSHIP_CONSTANT
442     strings = [
443         '%s.%s' % (field.rel.to._meta.app_label, field.rel.to._meta.object_name),
444         '%s' % field.rel.to._meta.object_name,
445         RECURSIVE_RELATIONSHIP_CONSTANT
446     ]
447     
448     for test in strings:
449         fd = replace_model_string(field_definition, test, field.rel.to._meta.object_name)
450         if fd:
451             return fd
452     
453     return field_definition
454
455 def create_mock_model(model, indent="        "):
456     # produce a string representing the python syntax necessary for creating
457     # a mock model using the supplied real model
458     if not model._meta.pk.__class__.__module__.startswith('django.db.models.fields'):
459         # we can fix this with some clever imports, but it doesn't seem necessary to
460         # spend time on just yet
461         print "Can't generate a mock model for %s because it's primary key isn't a default django field; it's type %s." % (model, model._meta.pk.__class__)
462         sys.exit()
463     
464     pk_field_args = []
465     pk_field_kwargs = {}
466     other_mocks = []
467     # If it's a OneToOneField or ForeignKey, take it's first arg
468     if model._meta.pk.__class__.__name__ in ["OneToOneField", "ForeignKey"]:
469         if model._meta.pk.rel.to == model:
470             pk_field_args += ["'self'"]
471         else:
472             pk_field_args += [model._meta.pk.rel.to._meta.object_name]
473             other_mocks += [model._meta.pk.rel.to]
474     
475     # Perhaps it has a max_length set?
476     if model._meta.pk.max_length:
477         pk_field_kwargs["max_length"] = model._meta.pk.max_length
478     
479     return "%s%s%s = db.mock_model(model_name='%s', db_table='%s', db_tablespace='%s', pk_field_name='%s', pk_field_type=models.%s, pk_field_args=[%s], pk_field_kwargs=%r)" % \
480         (
481         "\n".join([create_mock_model(m, indent) for m in other_mocks]+[""]),
482         indent,
483         model._meta.object_name,
484         model._meta.object_name,
485         model._meta.db_table,
486         model._meta.db_tablespace,
487         model._meta.pk.name,
488         model._meta.pk.__class__.__name__,
489         ", ".join(pk_field_args),
490         pk_field_kwargs,
491         )