f52efe7aac64b884ddb0260462004ae8db7f00d0
[wolnelektury.git] / apps / south / management / commands / startmigration.py
1 from django.core.management.base import BaseCommand
2 from django.core.management.color import no_style
3 from django.db import models
4 from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
5 from django.contrib.contenttypes.generic import GenericRelation
6 from optparse import make_option
7 from south import migration
8 import sys
9 import os
10 import re
11 import string
12 import random
13 import inspect
14 import parser
15
16 class Command(BaseCommand):
17     option_list = BaseCommand.option_list + (
18         make_option('--model', action='append', dest='model_list', type='string',
19             help='Generate a Create Table migration for the specified model.  Add multiple models to this migration with subsequent --model parameters.'),
20         make_option('--initial', action='store_true', dest='initial', default=False,
21             help='Generate the initial schema for the app.'),
22     )
23     help = "Creates a new template migration for the given app"
24     
25     def handle(self, app=None, name="", model_list=None, initial=False, **options):
26         
27         # If model_list is None, then it's an empty list
28         model_list = model_list or []
29         
30         # make sure --model and --all aren't both specified
31         if initial and model_list:
32             print "You cannot use --initial and other options together"
33             return
34             
35         # specify the default name 'initial' if a name wasn't specified and we're
36         # doing a migration for an entire app
37         if not name and initial:
38             name = 'initial'
39             
40         # if not name, there's an error
41         if not name:
42             print "You must name this migration"
43             return
44         
45         if not app:
46             print "Please provide an app in which to create the migration."
47             return
48             
49         # See if the app exists
50         app_models_module = models.get_app(app)
51         if not app_models_module:
52             print "App '%s' doesn't seem to exist, isn't in INSTALLED_APPS, or has no models." % app
53             return
54             
55         # Determine what models should be included in this migration.
56         models_to_migrate = []
57         if initial:
58             models_to_migrate = models.get_models(app_models_module)
59             if not models_to_migrate:
60                 print "No models found in app '%s'" % (app)
61                 return
62         else:
63             for model_name in model_list:
64                 model = models.get_model(app, model_name)
65                 if not model:
66                     print "Couldn't find model '%s' in app '%s'" % (model_name, app)
67                     return
68                     
69                 models_to_migrate.append(model)
70                 
71         # Make the migrations directory if it's not there
72         app_module_path = app_models_module.__name__.split('.')[0:-1]
73         try:
74             app_module = __import__('.'.join(app_module_path), {}, {}, [''])
75         except ImportError:
76             print "Couldn't find path to App '%s'." % app
77             return
78             
79         migrations_dir = os.path.join(
80             os.path.dirname(app_module.__file__),
81             "migrations",
82         )
83         if not os.path.isdir(migrations_dir):
84             print "Creating migrations directory at '%s'..." % migrations_dir
85             os.mkdir(migrations_dir)
86             # Touch the init py file
87             open(os.path.join(migrations_dir, "__init__.py"), "w").close()
88         # See what filename is next in line. We assume they use numbers.
89         migrations = migration.get_migration_names(migration.get_app(app))
90         highest_number = 0
91         for migration_name in migrations:
92             try:
93                 number = int(migration_name.split("_")[0])
94                 highest_number = max(highest_number, number)
95             except ValueError:
96                 pass
97         # Make the new filename
98         new_filename = "%04i%s_%s.py" % (
99             highest_number + 1,
100             "".join([random.choice(string.letters.lower()) for i in range(0)]), # Possible random stuff insertion
101             name,
102         )
103         # If there's a model, make the migration skeleton, else leave it bare
104         forwards, backwards = '', ''
105         if models_to_migrate:
106             for model in models_to_migrate:
107                 table_name = model._meta.db_table
108                 mock_models = []
109                 fields = []
110                 for f in model._meta.local_fields:
111                     # look up the field definition to see how this was created
112                     field_definition = generate_field_definition(model, f)
113                     if field_definition:
114                         
115                         if isinstance(f, models.ForeignKey):
116                             mock_models.append(create_mock_model(f.rel.to))
117                             field_definition = related_field_definition(f, field_definition)
118                             
119                     else:
120                         print "Warning: Could not generate field definition for %s.%s, manual editing of migration required." % \
121                                 (model._meta.object_name, f.name)
122                                 
123                         field_definition = '<<< REPLACE THIS WITH FIELD DEFINITION FOR %s.%s >>>' % (model._meta.object_name, f.name)
124                                                 
125                     fields.append((f.name, field_definition))
126                     
127                 if mock_models:
128                     forwards += '''
129         
130         # Mock Models
131         %s
132         ''' % "\n        ".join(mock_models)
133         
134                 forwards += '''
135         # Model '%s'
136         db.create_table('%s', (
137             %s
138         ))''' % (
139                     model._meta.object_name,
140                     table_name,
141                     "\n            ".join(["('%s', %s)," % (f[0], f[1]) for f in fields]),
142                 )
143
144                 backwards = ('''db.delete_table('%s')
145         ''' % table_name) + backwards
146         
147                 # Now go through local M2Ms and add extra stuff for them
148                 for m in model._meta.local_many_to_many:
149                     # ignore generic relations
150                     if isinstance(m, GenericRelation):
151                         continue
152
153                     # if the 'through' option is specified, the table will
154                     # be created through the normal model creation above.
155                     if m.rel.through:
156                         continue
157                         
158                     mock_models = [create_mock_model(model), create_mock_model(m.rel.to)]
159                     
160                     forwards += '''
161         # Mock Models
162         %s
163         
164         # M2M field '%s.%s'
165         db.create_table('%s', (
166             ('id', models.AutoField(verbose_name='ID', primary_key=True, auto_created=True)),
167             ('%s', models.ForeignKey(%s, null=False)),
168             ('%s', models.ForeignKey(%s, null=False))
169         )) ''' % (
170                         "\n        ".join(mock_models),
171                         model._meta.object_name,
172                         m.name,
173                         m.m2m_db_table(),
174                         m.m2m_column_name()[:-3], # strip off the '_id' at the end
175                         model._meta.object_name,
176                         m.m2m_reverse_name()[:-3], # strip off the '_id' at the ned
177                         m.rel.to._meta.object_name
178                 )
179                 
180                     backwards = '''db.delete_table('%s')
181         ''' % m.m2m_db_table() + backwards
182                 
183                 if model._meta.unique_together:
184                     ut = model._meta.unique_together
185                     if not isinstance(ut[0], (list, tuple)):
186                         ut = (ut,)
187                         
188                     for unique in ut:
189                         columns = ["'%s'" % model._meta.get_field(f).column for f in unique]
190                         
191                         forwards += '''
192         db.create_index('%s', [%s], unique=True, db_tablespace='%s')
193         ''' %   (
194                         table_name,
195                         ','.join(columns),
196                         model._meta.db_tablespace
197                 )
198                 
199                 
200             forwards += '''
201         
202         db.send_create_signal('%s', ['%s'])''' % (
203                 app, 
204                 "','".join(model._meta.object_name for model in models_to_migrate)
205                 )
206         
207         else:
208             forwards = '"Write your forwards migration here"'
209             backwards = '"Write your backwards migration here"'
210         fp = open(os.path.join(migrations_dir, new_filename), "w")
211         fp.write("""
212 from south.db import db
213 from %s.models import *
214
215 class Migration:
216     
217     def forwards(self):
218         %s
219     
220     def backwards(self):
221         %s
222 """ % ('.'.join(app_module_path), forwards, backwards))
223         fp.close()
224         print "Created %s." % new_filename
225
226
227 def generate_field_definition(model, field):
228     """
229     Inspects the source code of 'model' to find the code used to generate 'field'
230     """
231     def test_field(field_definition):
232         try:
233             parser.suite(field_definition)
234             return True
235         except SyntaxError:
236             return False
237             
238     def strip_comments(field_definition):
239         # remove any comments at the end of the field definition string.
240         field_definition = field_definition.strip()
241         if '#' not in field_definition:
242             return field_definition
243             
244         index = field_definition.index('#')
245         while index:
246             stripped_definition = field_definition[:index].strip()
247             # if the stripped definition is parsable, then we've removed
248             # the correct comment.
249             if test_field(stripped_definition):
250                 return stripped_definition
251                 
252             index = field_definition.index('#', index+1)
253             
254         return field_definition
255         
256     # give field subclasses a chance to do anything tricky
257     # with the field definition
258     if hasattr(field, 'south_field_definition'):
259         return field.south_field_definition()
260     
261     field_pieces = []
262     found_field = False
263     source = inspect.getsourcelines(model)
264     if not source:
265         raise Exception("Could not find source to model: '%s'" % (model.__name__))
266         
267     # look for a line starting with the field name
268     start_field_re = re.compile(r'\s*%s\s*=\s*(.*)' % field.name)
269     for line in source[0]:
270         # if the field was found during a previous iteration, 
271         # we're here because the field spans across multiple lines
272         # append the current line and try again
273         if found_field:
274             field_pieces.append(line.strip())
275             if test_field(' '.join(field_pieces)):
276                 return strip_comments(' '.join(field_pieces))
277             continue
278         
279         match = start_field_re.match(line)
280         if match:
281             found_field = True
282             field_pieces.append(match.groups()[0].strip())
283             if test_field(' '.join(field_pieces)):
284                 return strip_comments(' '.join(field_pieces))
285     
286     # the 'id' field never gets defined, so return what django does by default
287     # django.db.models.options::_prepare
288     if field.name == 'id' and field.__class__ == models.AutoField:
289         return "models.AutoField(verbose_name='ID', primary_key=True, auto_created=True)"
290     
291     # search this classes parents
292     for base in model.__bases__:
293         # we don't want to scan the django base model
294         if base == models.Model:
295             continue
296             
297         field_definition = generate_field_definition(base, field)
298         if field_definition:
299             return field_definition
300             
301     return None
302     
303 def replace_model_string(field_definition, search_string, model_name):
304     # wrap 'search_string' in both ' and " chars when searching
305     quotes = ["'", '"']
306     for quote in quotes:
307         test = "%s%s%s" % (quote, search_string, quote)
308         if test in field_definition:
309             return field_definition.replace(test, model_name)
310             
311     return None
312         
313 def related_field_definition(field, field_definition):
314     # if the field definition contains any of the following strings,
315     # replace them with the model definition:
316     #   applabel.modelname
317     #   modelname
318     #   django.db.models.fields.related.RECURSIVE_RELATIONSHIP_CONSTANT
319     strings = [
320         '%s.%s' % (field.rel.to._meta.app_label, field.rel.to._meta.object_name),
321         '%s' % field.rel.to._meta.object_name,
322         RECURSIVE_RELATIONSHIP_CONSTANT
323     ]
324     
325     for test in strings:
326         fd = replace_model_string(field_definition, test, field.rel.to._meta.object_name)
327         if fd:
328             return fd
329     
330     return field_definition
331
332 def create_mock_model(model):
333     # produce a string representing the python syntax necessary for creating
334     # a mock model using the supplied real model
335     if model._meta.pk.__class__.__module__ != 'django.db.models.fields':
336         # we can fix this with some clever imports, but it doesn't seem necessary to
337         # spend time on just yet
338         print "Can't generate a mock model for %s because it's primary key isn't a default django field" % model
339         sys.exit()
340     
341     return "%s = db.mock_model(model_name='%s', db_table='%s', db_tablespace='%s', pk_field_name='%s', pk_field_type=models.%s)" % \
342         (
343         model._meta.object_name,
344         model._meta.object_name,
345         model._meta.db_table,
346         model._meta.db_tablespace,
347         model._meta.pk.name,
348         model._meta.pk.__class__.__name__
349         )