Skip to content

Commit 62cd6f5

Browse files
committed
Rebase
1 parent 4209372 commit 62cd6f5

6 files changed

Lines changed: 515 additions & 22 deletions

File tree

factory/builder.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,18 @@ def chain(self):
218218
parent_chain = ()
219219
return (self.stub,) + parent_chain
220220

221-
def recurse(self, factory, declarations, force_sequence=None):
221+
def recurse(self, factory, declarations, force_sequence=None, collect_instances=None):
222222
from . import base
223223
if not issubclass(factory, base.BaseFactory):
224224
raise errors.AssociatedClassError(
225225
"%r: Attempting to recursing into a non-factory object %r"
226226
% (self, factory))
227227
builder = self.builder.recurse(factory._meta, declarations)
228-
return builder.build(parent_step=self, force_sequence=force_sequence)
228+
return builder.build(
229+
parent_step=self,
230+
force_sequence=force_sequence,
231+
collect_instances=collect_instances,
232+
)
229233

230234
def __repr__(self):
231235
return f"<BuildStep for {self.builder!r}>"
@@ -246,7 +250,7 @@ def __init__(self, factory_meta, extras, strategy):
246250
self.extras = extras
247251
self.force_init_sequence = extras.pop('__sequence', None)
248252

249-
def build(self, parent_step=None, force_sequence=None):
253+
def build(self, parent_step=None, force_sequence=None, collect_instances=None):
250254
"""Build a factory instance."""
251255
# TODO: Handle "batch build" natively
252256
pre, post = parse_declarations(
@@ -277,19 +281,23 @@ def build(self, parent_step=None, force_sequence=None):
277281
kwargs=kwargs,
278282
)
279283

280-
postgen_results = {}
281-
for declaration_name in post.sorted():
282-
declaration = post[declaration_name]
283-
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
284+
if collect_instances is None:
285+
postgen_results = {}
286+
for declaration_name in post.sorted():
287+
declaration = post[declaration_name]
288+
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
289+
instance=instance,
290+
step=step,
291+
overrides=declaration.context,
292+
)
293+
self.factory_meta.use_postgeneration_results(
284294
instance=instance,
285295
step=step,
286-
overrides=declaration.context,
296+
results=postgen_results,
287297
)
288-
self.factory_meta.use_postgeneration_results(
289-
instance=instance,
290-
step=step,
291-
results=postgen_results,
292-
)
298+
else:
299+
collect_instances.append(instance)
300+
293301
return instance
294302

295303
def recurse(self, factory_meta, extras):

factory/django.py

Lines changed: 182 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
import logging
1010
import os
1111
import warnings
12+
from collections import defaultdict
1213
from typing import Dict, TypeVar
1314

1415
from django.contrib.auth.hashers import make_password
1516
from django.core import files as django_files
16-
from django.db import IntegrityError
17+
from django.db import IntegrityError, connections, models
18+
from django.db.models.sql import InsertQuery
1719

18-
from . import base, declarations, errors
20+
from . import base, builder, declarations, enums, errors
1921

2022
logger = logging.getLogger('factory.generate')
2123

22-
2324
DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS
2425
T = TypeVar("T")
25-
2626
_LAZY_LOADS: Dict[str, object] = {}
2727

2828

@@ -45,11 +45,29 @@ def _lazy_load_get_model():
4545
_LAZY_LOADS['get_model'] = django_apps.apps.get_model
4646

4747

48+
def connection_supports_bulk_insert(using):
    """Tell whether the ``using`` connection can serve a bulk insert.

    Two backend feature flags are read from ``connections[using].features``:

    * ``has_bulk_insert``: the database supports bulk inserts;
    * ``can_return_rows_from_bulk_insert``: it can also hand back the ids of
      the newly inserted objects.

    The result is truthy only when both flags are truthy.
    """
    features = connections[using].features
    bulk_capable = features.has_bulk_insert
    return bulk_capable and features.can_return_rows_from_bulk_insert
63+
64+
4865
class DjangoOptions(base.FactoryOptions):
4966
def _build_default_options(self):
    """Return the inherited option defaults followed by the Django ones."""
    inherited = super()._build_default_options()
    django_defaults = [
        base.OptionDefault('django_get_or_create', (), inherit=True),
        base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True),
        # Opt-in switch read by supports_bulk_insert(); off by default.
        base.OptionDefault('use_bulk_create', False, inherit=True),
        base.OptionDefault('skip_postgeneration_save', False, inherit=True),
    ]
    return inherited + django_defaults
5573

@@ -165,6 +183,89 @@ def _get_or_create(cls, model_class, *args, **kwargs):
165183

166184
return instance
167185

186+
@classmethod
def supports_bulk_insert(cls):
    """Tell whether this factory should create instances via bulk insert.

    Truthy only when the factory opted in through ``Meta.use_bulk_create``
    and the configured database passes ``connection_supports_bulk_insert``.
    The database is not consulted when the factory did not opt in.
    """
    opted_in = cls._meta.use_bulk_create
    return opted_in and connection_supports_bulk_insert(cls._meta.database)
190+
191+
@classmethod
def create(cls, **kwargs):
    """Create one instance of the associated class, with overridden attrs.

    When bulk insert is supported, the instance is produced as a batch of
    one; otherwise the parent implementation is used.
    """
    if cls.supports_bulk_insert():
        return cls._bulk_create(1, **kwargs)[0]
    return super().create(**kwargs)
198+
199+
@classmethod
def create_batch(cls, size, **kwargs):
    """Create ``size`` instances, through a bulk insert when supported.

    Falls back to the parent implementation when bulk insert is unavailable.
    """
    if cls.supports_bulk_insert():
        return cls._bulk_create(size, **kwargs)
    return super().create_batch(size, **kwargs)
205+
206+
@classmethod
def _refresh_database_pks(cls, model_cls, objs):
    """Re-assign cached relation values on ``objs`` before they are inserted.

    For every ``GenericForeignKey`` / ``OneToOneField`` found on
    ``model_cls``, the current value is read back from each object and set
    again on the same attribute. Nothing happens when the model has no such
    field.
    """
    # Imported here to avoid a django.core.exceptions.AppRegistryNotReady
    # throughout all the tests.
    # TODO: remove the `from . import django` from the `__init__.py`
    from django.contrib.contenttypes.fields import GenericForeignKey

    # Rationale kept from the original author: Django's GenericForeignKey is
    # not made to work with bulk_insert. It caches the referenced object, and
    # once that object is saved and receives a pk the cache no longer
    # matches, even though it is the same object reference. Re-setting the
    # value bypasses that pk check and resets the cache.
    resettable_types = (GenericForeignKey, models.OneToOneField)
    relation_fields = [
        field
        for field in model_cls._meta.get_fields()
        if isinstance(field, resettable_types)
    ]
    if not relation_fields:
        return

    for obj in objs:
        for field in relation_fields:
            if isinstance(field, GenericForeignKey) and field.is_cached(obj):
                # Read the cached object directly for a generic relation.
                current = field.get_cached_value(obj)
            else:
                current = getattr(obj, field.name)
            setattr(obj, field.name, current)
232+
233+
@classmethod
def _bulk_create(cls, size, **kwargs):
    """Build ``size`` instances and persist them with bulk inserts.

    Each instance is built with the BUILD strategy while the builder appends
    the model instances it produces to a shared ``collect_instances`` list.
    When that list is supplied, the builder records the instance instead of
    evaluating post-generation declarations. The collected instances are then
    grouped and ordered by ``dependency_insert_order`` and written one model
    class at a time.

    Args:
        size: number of top-level instances to build.
        **kwargs: declaration overrides applied to every instance.

    Returns:
        The list of the ``size`` top-level instances, in build order.

    Raises:
        errors.FactoryError: if the factory is abstract.
    """
    if cls._meta.abstract:
        raise errors.FactoryError(
            "Cannot generate instances of abstract factory %(f)s; "
            "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
            "is either not set or False." % dict(f=cls.__name__))

    models_to_return = []
    instances = []
    for _ in range(size):
        # StepBuilder's constructor pops '__sequence' from the mapping it
        # receives. Give it a fresh copy so every instance of the batch sees
        # the same overrides, not only the first one.
        step = builder.StepBuilder(cls._meta, dict(kwargs), enums.BUILD_STRATEGY)
        models_to_return.append(step.build(collect_instances=instances))

    for model_cls, objs in dependency_insert_order(instances):
        manager = cls._get_manager(model_cls)
        cls._refresh_database_pks(model_cls, objs)

        own_concrete_model = model_cls._meta.concrete_model
        shares_concrete_model = all(
            parent._meta.concrete_model is own_concrete_model
            for parent in model_cls._meta.get_parent_list()
        )

        if shares_concrete_model:
            manager.bulk_create(objs)
        else:
            # Some parent has a different concrete model: insert only the
            # fields local to this model, letting Django's InsertQuery
            # produce the INSERT statement rather than writing SQL by hand.
            connection = connections[cls._meta.database]
            query = InsertQuery(model_cls)
            query.insert_values(model_cls._meta.local_fields, objs)
            query.get_compiler(connection=connection).execute_sql()

    return models_to_return
268+
168269
@classmethod
169270
def _create(cls, model_class, *args, **kwargs):
170271
"""Create an instance of the model, and save it to the database."""
@@ -272,6 +373,82 @@ def _make_data(self, params):
272373
return thumb_io.getvalue()
273374

274375

376+
def dependency_insert_order(data):
    """Group unsaved instances per model, ordered so dependencies come first.

    Adapted from ``sort_dependencies`` in ``django/core/serializers/__init__.py``.
    The ``hasattr(rel_model, 'natural_key')`` condition was removed, so every
    relation to another model counts as a dependency and the result reflects
    the real insertion order.

    Args:
        data: iterable of model instances. Instances whose ``_state.adding``
            is false (already persisted) are ignored, and an instance that
            appears several times is kept only once.

    Returns:
        A list of ``(model_class, [instances])`` tuples. A model appears
        after every model in ``data`` that one of its fields relates to.

    Raises:
        RuntimeError: if the models in ``data`` depend on each other in a
            cycle.
    """
    # Collect each not-yet-persisted instance exactly once, per model class.
    seen_ids = set()
    instances_by_model = defaultdict(list)
    for instance in data:
        if not instance._state.adding:
            continue
        instance_id = id(instance)
        if instance_id in seen_ids:
            continue
        seen_ids.add(instance_id)
        instances_by_model[type(instance)].append(instance)

    model_classes = list(instances_by_model)
    known_models = set(model_classes)

    # A model depends on every other model one of its fields relates to.
    model_dependencies = []
    for model in model_classes:
        deps = set()
        for field in model._meta.fields:
            rel_model = field.related_model
            if rel_model and rel_model != model:
                deps.add(rel_model)
        model_dependencies.append((model, deps))

    model_dependencies.reverse()
    # Sort the models so that dependencies are met, by repeatedly sweeping the
    # pending list. A model is moved to the final list once all of its
    # dependencies are there (or are not part of this batch at all). A full
    # sweep without any promotion means the remaining models depend on each
    # other in a cycle.
    model_list = []
    resolved = set()
    while model_dependencies:
        skipped = []
        changed = False
        while model_dependencies:
            model, deps = model_dependencies.pop()
            if all(dep not in known_models or dep in resolved for dep in deps):
                model_list.append(model)
                resolved.add(model)
                changed = True
            else:
                skipped.append((model, deps))
        if not changed:
            unresolved_models = (
                f'{stuck._meta.app_label}.{stuck._meta.object_name}'
                for stuck, _ in sorted(skipped, key=lambda obj: obj[0].__name__)
            )
            message = f"Can't resolve dependencies for {', '.join(unresolved_models)}."
            raise RuntimeError(message)
        model_dependencies = skipped

    return [(model_cls, instances_by_model[model_cls]) for model_cls in model_list]
450+
451+
275452
class mute_signals:
276453
"""Temporarily disables and then restores any django signals.
277454
@@ -327,6 +504,7 @@ def __call__(self, callable_obj):
327504
if isinstance(callable_obj, base.FactoryMetaClass):
328505
# Retrieve __func__, the *actual* callable object.
329506
callable_obj._create = self.wrap_method(callable_obj._create.__func__)
507+
callable_obj._bulk_create = self.wrap_method(callable_obj._bulk_create.__func__)
330508
callable_obj._generate = self.wrap_method(callable_obj._generate.__func__)
331509
callable_obj._after_postgeneration = self.wrap_method(
332510
callable_obj._after_postgeneration.__func__

tests/djapp/models.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import os.path
66

77
from django.conf import settings
8+
from django.contrib.contenttypes.fields import GenericForeignKey
9+
from django.contrib.contenttypes.models import ContentType
810
from django.db import models
911
from django.db.models import signals
1012

@@ -137,3 +139,46 @@ class FromAbstractWithCustomManager(AbstractWithCustomManager):
137139

138140
class HasMultifieldModel(models.Model):
139141
multifield = models.ForeignKey(to=MultifieldModel, on_delete=models.CASCADE)
142+
143+
144+
class P(models.Model):
    # Field-less model; the relations declared on R, A and AA point at it.
    pass
146+
147+
148+
class R(models.Model):
    # A boolean flag plus an optional (nullable) foreign key to P.
    is_default = models.BooleanField(default=False)
    p = models.ForeignKey(to=P, on_delete=models.CASCADE, null=True)
151+
152+
153+
class S(models.Model):
    # Required foreign key to R.
    r = models.ForeignKey(to=R, on_delete=models.CASCADE)
155+
156+
157+
class T(models.Model):
    # Required foreign key to S.
    s = models.ForeignKey(to=S, on_delete=models.CASCADE)
159+
160+
161+
class U(models.Model):
    # Required foreign key to T, the last link of the U -> T -> S -> R chain.
    t = models.ForeignKey(to=T, on_delete=models.CASCADE)
163+
164+
165+
class RChild(R):
    # Subclass of the concrete model R that adds a short text column.
    text = models.CharField(max_length=10)
167+
168+
169+
class A(models.Model):
    # Relates to P in three ways: one-to-one, foreign key and many-to-many.
    # related_name="+" is set on the first two relations.
    p_o = models.OneToOneField(to='P', on_delete=models.CASCADE, related_name="+")
    p_f = models.ForeignKey(to='P', on_delete=models.CASCADE, related_name="+")
    p_m = models.ManyToManyField(to='P')
173+
174+
175+
class AA(models.Model):
    # One-to-one links to A, U and P.
    a = models.OneToOneField(to=A, on_delete=models.CASCADE)
    u = models.OneToOneField(to=U, on_delete=models.CASCADE)
    p = models.OneToOneField(to=P, on_delete=models.CASCADE)
179+
180+
181+
class GenericModel(models.Model):
    # Generic relation: content_type and object_id are the two columns that
    # generic_obj is built on.
    content_type = models.ForeignKey(to=ContentType, on_delete=models.CASCADE)
    object_id = models.PositiveIntegerField()
    generic_obj = GenericForeignKey(ct_field="content_type", fk_field="object_id")

tests/djapp/settings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424

2525
INSTALLED_APPS = [
26-
'tests.djapp'
26+
'django.contrib.contenttypes',
27+
'tests.djapp',
2728
]
2829

2930
MIDDLEWARE_CLASSES = ()

0 commit comments

Comments
 (0)