My favorites | Sign in
Project Home Wiki Issues
Checkout   Browse   Changes  
Changes to /trunk/django_evolution/tests/utils.py
r196 vs. r207 Compare: vs.  Format:
Revision r207
Go to: 
/trunk/django_evolution/tests/utils.py   r196 /trunk/django_evolution/tests/utils.py   r207
1 from datetime import datetime 1 from datetime import datetime
2 from django.core.management import sql 2 from django.core.management import sql
3 from django.core.management.color import no_style 3 from django.core.management.color import no_style
4 from django.db import connection, transaction, settings, models 4 from django.db import connection, transaction, settings, models
5 from django.db.backends.util import truncate_name 5 from django.db.backends.util import truncate_name
6 from django.db.models.loading import cache 6 from django.db.models.loading import cache
7 from django.utils.datastructures import SortedDict 7 from django.utils.datastructures import SortedDict
8 from django.utils.functional import curry 8 from django.utils.functional import curry
9 from django_evolution import signature, is_multi_db 9 from django_evolution import signature, is_multi_db
10 from django_evolution.tests import models as evo_test 10 from django_evolution.tests import models as evo_test
11 from django_evolution.utils import write_sql, execute_sql 11 from django_evolution.utils import write_sql, execute_sql
12 import copy 12 import copy
13 13
14 if is_multi_db(): 14 if is_multi_db():
15 from django.db import connections 15 from django.db import connections
16 from django.db.utils import DEFAULT_DB_ALIAS 16 from django.db.utils import DEFAULT_DB_ALIAS
17 17
18 18
19 DEFAULT_TEST_ATTRIBUTE_VALUES = { 19 DEFAULT_TEST_ATTRIBUTE_VALUES = {
20 models.CharField: 'TestCharField', 20 models.CharField: 'TestCharField',
21 models.IntegerField: '123', 21 models.IntegerField: '123',
22 models.AutoField: None, 22 models.AutoField: None,
23 models.DateTimeField: datetime.now(), 23 models.DateTimeField: datetime.now(),
24 models.PositiveIntegerField: '42' 24 models.PositiveIntegerField: '42'
25 } 25 }
26 26
27 27
28 def wrap_sql_func(func, evo_test, style, db_name=None): 28 def wrap_sql_func(func, evo_test, style, db_name=None):
29 if is_multi_db(): 29 if is_multi_db():
30 return func(evo_test, style, connections[db_name or DEFAULT_DB_ALIAS]) 30 return func(evo_test, style, connections[db_name or DEFAULT_DB_ALIAS])
31 else: 31 else:
32 return func(evo_test, style) 32 return func(evo_test, style)
33 33
34 # Wrap the sql.* functions to work with the multi-db support 34 # Wrap the sql.* functions to work with the multi-db support
35 sql_create = curry(wrap_sql_func, sql.sql_create) 35 sql_create = curry(wrap_sql_func, sql.sql_create)
36 sql_indexes = curry(wrap_sql_func, sql.sql_indexes) 36 sql_indexes = curry(wrap_sql_func, sql.sql_indexes)
37 sql_delete = curry(wrap_sql_func, sql.sql_delete) 37 sql_delete = curry(wrap_sql_func, sql.sql_delete)
38 38
39 39
40 def _register_models(app_label='tests', db_name='default', *models): 40 def _register_models(app_label='tests', db_name='default', *models):
41 app_cache = SortedDict() 41 app_cache = SortedDict()
42 42
43 my_connection = connection 43 my_connection = connection
44 44
45 if is_multi_db(): 45 if is_multi_db():
46 my_connection = connections[db_name or DEFAULT_DB_ALIAS] 46 my_connection = connections[db_name or DEFAULT_DB_ALIAS]
47 47
48 max_name_length = my_connection.ops.max_name_length() 48 max_name_length = my_connection.ops.max_name_length()
49 49
50 for name, model in reversed(models): 50 for name, model in reversed(models):
51 if model._meta.module_name in cache.app_models['django_evolution']: 51 if model._meta.module_name in cache.app_models['django_evolution']:
52 del cache.app_models['django_evolution'][model._meta.module_name] 52 del cache.app_models['django_evolution'][model._meta.module_name]
53 53
54 orig_db_table = model._meta.db_table 54 orig_db_table = model._meta.db_table
55 orig_object_name = model._meta.object_name 55 orig_object_name = model._meta.object_name
56 orig_module_name = model._meta.module_name 56 orig_module_name = model._meta.module_name
57 57
58 generated_db_table = truncate_name( 58 generated_db_table = truncate_name(
59 '%s_%s' % (model._meta.app_label, model._meta.module_name), 59 '%s_%s' % (model._meta.app_label, model._meta.module_name),
60 max_name_length) 60 max_name_length)
61 61
62 if orig_db_table.startswith(generated_db_table): 62 if orig_db_table.startswith(generated_db_table):
63 model._meta.db_table = '%s_%s' % (app_label, name.lower()) 63 model._meta.db_table = '%s_%s' % (app_label, name.lower())
64 64
65 model._meta.db_table = truncate_name(model._meta.db_table, 65 model._meta.db_table = truncate_name(model._meta.db_table,
66 max_name_length) 66 max_name_length)
67 model._meta.app_label = app_label 67 model._meta.app_label = app_label
68 model._meta.object_name = name 68 model._meta.object_name = name
69 model._meta.module_name = name.lower() 69 model._meta.module_name = name.lower()
70 70
71 add_app_test_model(model, app_label=app_label) 71 add_app_test_model(model, app_label=app_label)
72 72
73 for field in model._meta.local_many_to_many: 73 for field in model._meta.local_many_to_many:
74 if not field.rel.through: 74 if not field.rel.through:
75 continue 75 continue
76 76
77 through = field.rel.through 77 through = field.rel.through
78 78
79 generated_db_table = truncate_name( 79 generated_db_table = truncate_name(
80 '%s_%s' % (orig_db_table, field.name), 80 '%s_%s' % (orig_db_table, field.name),
81 max_name_length) 81 max_name_length)
82 82
83 if through._meta.db_table == generated_db_table: 83 if through._meta.db_table == generated_db_table:
84 through._meta.app_label = app_label 84 through._meta.app_label = app_label
85 85
86 # Transform the 'through' table information only 86 # Transform the 'through' table information only
87 # if we've transformed the parent db_table. 87 # if we've transformed the parent db_table.
88 if model._meta.db_table != orig_db_table: 88 if model._meta.db_table != orig_db_table:
89 through._meta.db_table = \ 89 through._meta.db_table = \
90 '%s_%s' % (model._meta.db_table, field.name) 90 '%s_%s' % (model._meta.db_table, field.name)
91 91
92 through._meta.object_name = \ 92 through._meta.object_name = \
93 through._meta.object_name.replace( 93 through._meta.object_name.replace(
94 orig_object_name, 94 orig_object_name,
95 model._meta.object_name) 95 model._meta.object_name)
96 96
97 through._meta.module_name = \ 97 through._meta.module_name = \
98 through._meta.module_name.replace( 98 through._meta.module_name.replace(
99 orig_module_name, 99 orig_module_name,
100 model._meta.module_name) 100 model._meta.module_name)
101 101
102 through._meta.db_table = \ 102 through._meta.db_table = \
103 truncate_name(through._meta.db_table, max_name_length) 103 truncate_name(through._meta.db_table, max_name_length)
104 104
105 for field in through._meta.local_fields: 105 for field in through._meta.local_fields:
106 if field.rel and field.rel.to: 106 if field.rel and field.rel.to:
107 column = field.column 107 column = field.column
108 108
109 if (column.startswith(orig_module_name) or 109 if (column.startswith(orig_module_name) or
110 column.startswith('to_%s' % orig_module_name) or 110 column.startswith('to_%s' % orig_module_name) or
111 column.startswith('from_%s' % orig_module_name)): 111 column.startswith('from_%s' % orig_module_name)):
112 112
113 field.column = column.replace( 113 field.column = column.replace(
114 orig_module_name, 114 orig_module_name,
115 model._meta.module_name) 115 model._meta.module_name)
116 116
117 if (through._meta.module_name in 117 if (through._meta.module_name in
118 cache.app_models['django_evolution']): 118 cache.app_models['django_evolution']):
119 del cache.app_models['django_evolution'][ 119 del cache.app_models['django_evolution'][
120 through._meta.module_name] 120 through._meta.module_name]
121 121
122 app_cache[through._meta.module_name] = through 122 app_cache[through._meta.module_name] = through
123 add_app_test_model(through, app_label=app_label) 123 add_app_test_model(through, app_label=app_label)
124 124
125 app_cache[model._meta.module_name] = model 125 app_cache[model._meta.module_name] = model
126 126
127 if evo_test not in cache.app_store:
128 cache.app_store[evo_test] = len(cache.app_store)
129
130 if hasattr(cache, 'app_labels'):
131 cache.app_labels[app_label] = evo_test
132
127 return app_cache 133 return app_cache
128 134
129 135
130 def register_models(*models): 136 def register_models(*models):
131 return _register_models('tests', 'default', *models) 137 return _register_models('tests', 'default', *models)
132 138
133 139
134 def register_models_multi(app_label, db_name, *models): 140 def register_models_multi(app_label, db_name, *models):
135 return _register_models(app_label, db_name, *models) 141 return _register_models(app_label, db_name, *models)
136 142
137 143
138 def _test_proj_sig(app_label, *models, **kwargs): 144 def _test_proj_sig(app_label, *models, **kwargs):
139 "Generate a dummy project signature based around a single model" 145 "Generate a dummy project signature based around a single model"
140 version = kwargs.get('version', 1) 146 version = kwargs.get('version', 1)
141 proj_sig = { 147 proj_sig = {
142 app_label: SortedDict(), 148 app_label: SortedDict(),
143 '__version__': version, 149 '__version__': version,
144 } 150 }
145 151
146 # Compute the project siguature 152 # Compute the project siguature
147 for full_name, model in models: 153 for full_name, model in models:
148 parts = full_name.split('.') 154 parts = full_name.split('.')
149 155
150 if len(parts) == 1: 156 if len(parts) == 1:
151 name = parts[0] 157 name = parts[0]
152 app = app_label 158 app = app_label
153 else: 159 else:
154 app, name = parts 160 app, name = parts
155 161
156 proj_sig.setdefault(app, SortedDict())[name] = \ 162 proj_sig.setdefault(app, SortedDict())[name] = \
157 signature.create_model_sig(model) 163 signature.create_model_sig(model)
158 164
159 return proj_sig 165 return proj_sig
160 166
161 167
162 def test_proj_sig(*models, **kwargs): 168 def test_proj_sig(*models, **kwargs):
163 return _test_proj_sig('tests', *models, **kwargs) 169 return _test_proj_sig('tests', *models, **kwargs)
164 170
165 171
166 def test_proj_sig_multi(app_label, *models, **kwargs): 172 def test_proj_sig_multi(app_label, *models, **kwargs):
167 return _test_proj_sig(app_label, *models, **kwargs) 173 return _test_proj_sig(app_label, *models, **kwargs)
168 174
169 175
170 def execute_transaction(sql, output=False, database='default'): 176 def execute_transaction(sql, output=False, database='default'):
171 "A transaction wrapper for executing a list of SQL statements" 177 "A transaction wrapper for executing a list of SQL statements"
172 my_connection = connection 178 my_connection = connection
173 using_args = {} 179 using_args = {}
174 180
175 if is_multi_db(): 181 if is_multi_db():
176 if not database: 182 if not database:
177 database = DEFAULT_DB_ALIAS 183 database = DEFAULT_DB_ALIAS
178 184
179 my_connection = connections[database] 185 my_connection = connections[database]
180 using_args['using'] = database 186 using_args['using'] = database
181 187
182 try: 188 try:
183 # Begin Transaction 189 # Begin Transaction
184 transaction.enter_transaction_management(**using_args) 190 transaction.enter_transaction_management(**using_args)
185 transaction.managed(True, **using_args) 191 transaction.managed(True, **using_args)
186 192
187 cursor = my_connection.cursor() 193 cursor = my_connection.cursor()
188 194
189 # Perform the SQL 195 # Perform the SQL
190 if output: 196 if output:
191 write_sql(sql, database) 197 write_sql(sql, database)
192 198
193 execute_sql(cursor, sql) 199 execute_sql(cursor, sql)
194 200
195 transaction.commit(**using_args) 201 transaction.commit(**using_args)
196 transaction.leave_transaction_management(**using_args) 202 transaction.leave_transaction_management(**using_args)
197 except Exception: 203 except Exception:
198 transaction.rollback(**using_args) 204 transaction.rollback(**using_args)
199 raise 205 raise
200 206
201 207
202 def execute_test_sql(start, end, sql, debug=False, app_label='tests', 208 def execute_test_sql(start, end, sql, debug=False, app_label='tests',
203 database='default'): 209 database='default'):
204 """ 210 """
205 Execute a test SQL sequence. This method also creates and destroys the 211 Execute a test SQL sequence. This method also creates and destroys the
206 database tables required by the models registered against the test 212 database tables required by the models registered against the test
207 application. 213 application.
208 214
209 start and end are the start- and end-point states of the application cache. 215 start and end are the start- and end-point states of the application cache.
210 216
211 sql is the list of sql statements to execute. 217 sql is the list of sql statements to execute.
212 218
213 cleanup is a list of extra sql statements required to clean up. This is 219 cleanup is a list of extra sql statements required to clean up. This is
214 primarily for any extra m2m tables that were added during a test that won't 220 primarily for any extra m2m tables that were added during a test that won't
215 be cleaned up by Django's sql_delete() implementation. 221 be cleaned up by Django's sql_delete() implementation.
216 222
217 debug is a helper flag. It displays the ALL the SQL that would be executed, 223 debug is a helper flag. It displays the ALL the SQL that would be executed,
218 (including setup and teardown SQL), and executes the Django-derived 224 (including setup and teardown SQL), and executes the Django-derived
219 setup/teardown SQL. 225 setup/teardown SQL.
220 """ 226 """
221 # Set up the initial state of the app cache 227 # Set up the initial state of the app cache
222 set_app_test_models(copy.deepcopy(start), app_label=app_label) 228 set_app_test_models(copy.deepcopy(start), app_label=app_label)
223 229
224 # Install the initial tables and indicies 230 # Install the initial tables and indicies
225 style = no_style() 231 style = no_style()
226 execute_transaction(sql_create(evo_test, style, database), 232 execute_transaction(sql_create(evo_test, style, database),
227 output=debug, database=database) 233 output=debug, database=database)
228 execute_transaction(sql_indexes(evo_test, style, database), 234 execute_transaction(sql_indexes(evo_test, style, database),
229 output=debug, database=database) 235 output=debug, database=database)
230 create_test_data(models.get_models(evo_test), database) 236 create_test_data(models.get_models(evo_test), database)
231 237
232 # Set the app cache to the end state 238 # Set the app cache to the end state
233 set_app_test_models(copy.deepcopy(end), app_label=app_label) 239 set_app_test_models(copy.deepcopy(end), app_label=app_label)
234 240
235 try: 241 try:
236 # Execute the test sql 242 # Execute the test sql
237 if debug: 243 if debug:
238 write_sql(sql, database) 244 write_sql(sql, database)
239 else: 245 else:
240 execute_transaction(sql, output=True, database=database) 246 execute_transaction(sql, output=True, database=database)
241 finally: 247 finally:
242 # Cleanup the apps. 248 # Cleanup the apps.
243 if debug: 249 if debug:
244 print sql_delete(evo_test, style, database) 250 print sql_delete(evo_test, style, database)
245 else: 251 else:
246 execute_transaction(sql_delete(evo_test, style, database), 252 execute_transaction(sql_delete(evo_test, style, database),
247 output=debug, database=database) 253 output=debug, database=database)
248 254
249 def create_test_data(app_models, database): 255 def create_test_data(app_models, database):
250 deferred_models = [] 256 deferred_models = []
251 deferred_fields = {} 257 deferred_fields = {}
252 258
253 using_args = {} 259 using_args = {}
254 260
255 if is_multi_db(): 261 if is_multi_db():
256 using_args['using'] = database 262 using_args['using'] = database
257 263
258 for model in app_models: 264 for model in app_models:
259 params = {} 265 params = {}
260 deferred = False 266 deferred = False
261 for field in model._meta.fields: 267 for field in model._meta.fields:
262 if not deferred: 268 if not deferred:
263 if type(field) in (models.ForeignKey, models.ManyToManyField): 269 if type(field) in (models.ForeignKey, models.ManyToManyField):
264 related_model = field.rel.to 270 related_model = field.rel.to
265 271
266 related_q = related_model.objects.all() 272 related_q = related_model.objects.all()
267 273
268 if is_multi_db(): 274 if is_multi_db():
269 related_q = related_q.using(database) 275 related_q = related_q.using(database)
270 276
271 if related_q.count(): 277 if related_q.count():
272 related_instance = related_q[0] 278 related_instance = related_q[0]
273 else: 279 else:
274 if field.null == False: 280 if field.null == False:
275 # Field cannot be null yet the related object 281 # Field cannot be null yet the related object
276 # hasn't been created yet Defer the creation of 282 # hasn't been created yet Defer the creation of
277 # this model 283 # this model
278 deferred = True 284 deferred = True
279 deferred_models.append(model) 285 deferred_models.append(model)
280 else: 286 else:
281 # Field cannot be set yet but null is acceptable 287 # Field cannot be set yet but null is acceptable
282 # for the moment 288 # for the moment
283 deferred_fields[type(model)] = \ 289 deferred_fields[type(model)] = \
284 deferred_fields.get(type(model), 290 deferred_fields.get(type(model),
285 []).append(field) 291 []).append(field)
286 related_instance = None 292 related_instance = None
287 293
288 if not deferred: 294 if not deferred:
289 if type(field) == models.ForeignKey: 295 if type(field) == models.ForeignKey:
290 params[field.name] = related_instance 296 params[field.name] = related_instance
291 else: 297 else:
292 params[field.name] = [related_instance] 298 params[field.name] = [related_instance]
293 else: 299 else:
294 params[field.name] = \ 300 params[field.name] = \
295 DEFAULT_TEST_ATTRIBUTE_VALUES[type(field)] 301 DEFAULT_TEST_ATTRIBUTE_VALUES[type(field)]
296 302
297 if not deferred: 303 if not deferred:
298 model(**params).save(**using_args) 304 model(**params).save(**using_args)
299 305
300 # Create all deferred models. 306 # Create all deferred models.
301 if deferred_models: 307 if deferred_models:
302 create_test_data(deferred_models, database) 308 create_test_data(deferred_models, database)
303 309
304 # All models should be created (Not all deferred fields have been populated 310 # All models should be created (Not all deferred fields have been populated
305 # yet) Populate deferred fields that we know about. Here lies untested 311 # yet) Populate deferred fields that we know about. Here lies untested
306 # code! 312 # code!
307 if deferred_fields: 313 if deferred_fields:
308 for model, field_list in deferred_fields.items(): 314 for model, field_list in deferred_fields.items():
309 for field in field_list: 315 for field in field_list:
310 related_model = field.rel.to 316 related_model = field.rel.to
311 related_instance = related_model.objects.using(database)[0] 317 related_instance = related_model.objects.using(database)[0]
312 318
313 if type(field) == models.ForeignKey: 319 if type(field) == models.ForeignKey:
314 setattr(model, field.name, related_instance) 320 setattr(model, field.name, related_instance)
315 else: 321 else:
316 getattr(model, field.name).add(related_instance, 322 getattr(model, field.name).add(related_instance,
317 **using_args) 323 **using_args)
318 324
319 model.save(**using_args) 325 model.save(**using_args)
320 326
321 327
322 def test_sql_mapping(test_field_name, db_name='default'): 328 def test_sql_mapping(test_field_name, db_name='default'):
323 if is_multi_db(): 329 if is_multi_db():
324 engine = settings.DATABASES[db_name]['ENGINE'].split('.')[-1] 330 engine = settings.DATABASES[db_name]['ENGINE'].split('.')[-1]
325 else: 331 else:
326 engine = settings.DATABASE_ENGINE 332 engine = settings.DATABASE_ENGINE
327 333
328 sql_for_engine = __import__('django_evolution.tests.db.%s' % (engine), 334 sql_for_engine = __import__('django_evolution.tests.db.%s' % (engine),
329 {}, {}, ['']) 335 {}, {}, [''])
330 336
331 return getattr(sql_for_engine, test_field_name) 337 return getattr(sql_for_engine, test_field_name)
332 338
333 339
334 def deregister_models(app_label='tests'): 340 def deregister_models(app_label='tests'):
335 "Clear the test section of the app cache" 341 "Clear the test section of the app cache"
336 del cache.app_models[app_label] 342 del cache.app_models[app_label]
337 clear_models_cache() 343 clear_models_cache()
338 344
339 345
340 def clear_models_cache(): 346 def clear_models_cache():
341 """Clears the Django models cache. 347 """Clears the Django models cache.
342 348
343 This cache is used in Django >= 1.2 to quickly return results from 349 This cache is used in Django >= 1.2 to quickly return results from
344 cache.get_models(). It needs to be cleared when modifying the model 350 cache.get_models(). It needs to be cleared when modifying the model
345 registry. 351 registry.
346 """ 352 """
347 if hasattr(cache, '_get_models_cache'): 353 if hasattr(cache, '_get_models_cache'):
348 # On Django 1.2, we need to clear this cache when unregistering models. 354 # On Django 1.2, we need to clear this cache when unregistering models.
349 cache._get_models_cache.clear() 355 cache._get_models_cache.clear()
350 356
351 357
352 def set_app_test_models(models, app_label): 358 def set_app_test_models(models, app_label):
353 """Sets the list of models in the Django test models registry.""" 359 """Sets the list of models in the Django test models registry."""
354 cache.app_models[app_label] = models 360 cache.app_models[app_label] = models
355 clear_models_cache() 361 clear_models_cache()
356 362
357 363
358 def add_app_test_model(model, app_label): 364 def add_app_test_model(model, app_label):
359 """Adds a model to the Django test models registry.""" 365 """Adds a model to the Django test models registry."""
360 key = model._meta.object_name.lower() 366 key = model._meta.object_name.lower()
361 cache.app_models.setdefault(app_label, SortedDict())[key] = model 367 cache.app_models.setdefault(app_label, SortedDict())[key] = model
362 clear_models_cache() 368 clear_models_cache()
Powered by Google Project Hosting