99import logging
1010import os
1111import warnings
12+ from collections import defaultdict
1213from typing import Dict , TypeVar
1314
1415from django .contrib .auth .hashers import make_password
1516from django .core import files as django_files
16- from django .db import IntegrityError
17+ from django .db import IntegrityError , connections , models
18+ from django .db .models .sql import InsertQuery
1719
18- from . import base , declarations , errors
20+ from . import base , builder , declarations , enums , errors
1921
2022logger = logging .getLogger ('factory.generate' )
2123
22-
2324DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS
2425T = TypeVar ("T" )
25-
2626_LAZY_LOADS : Dict [str , object ] = {}
2727
2828
@@ -45,11 +45,29 @@ def _lazy_load_get_model():
4545 _LAZY_LOADS ['get_model' ] = django_apps .apps .get_model
4646
4747
48+ def connection_supports_bulk_insert (using ):
49+ """
50+ Does the database support bulk_insert
51+
52+ There are 2 pieces to this puzzle:
53+ * The database needs to support `bulk_insert`
54+ * AND it also needs to be capable of returning all the newly minted objects' id
55+
56+ If any of these is `False`, the database does NOT support bulk_insert
57+ """
58+ db_features = connections [using ].features
59+ return (
60+ db_features .has_bulk_insert
61+ and db_features .can_return_rows_from_bulk_insert
62+ )
63+
64+
4865class DjangoOptions (base .FactoryOptions ):
4966 def _build_default_options (self ):
5067 return super ()._build_default_options () + [
5168 base .OptionDefault ('django_get_or_create' , (), inherit = True ),
5269 base .OptionDefault ('database' , DEFAULT_DB_ALIAS , inherit = True ),
70+ base .OptionDefault ('use_bulk_create' , False , inherit = True ),
5371 base .OptionDefault ('skip_postgeneration_save' , False , inherit = True ),
5472 ]
5573
@@ -165,6 +183,89 @@ def _get_or_create(cls, model_class, *args, **kwargs):
165183
166184 return instance
167185
186+ @classmethod
187+ def supports_bulk_insert (cls ):
188+ return (cls ._meta .use_bulk_create
189+ and connection_supports_bulk_insert (cls ._meta .database ))
190+
191+ @classmethod
192+ def create (cls , ** kwargs ):
193+ """Create an instance of the associated class, with overridden attrs."""
194+ if not cls .supports_bulk_insert ():
195+ return super ().create (** kwargs )
196+
197+ return cls ._bulk_create (1 , ** kwargs )[0 ]
198+
199+ @classmethod
200+ def create_batch (cls , size , ** kwargs ):
201+ if not cls .supports_bulk_insert ():
202+ return super ().create_batch (size , ** kwargs )
203+
204+ return cls ._bulk_create (size , ** kwargs )
205+
206+ @classmethod
207+ def _refresh_database_pks (cls , model_cls , objs ):
208+ # Avoid causing a django.core.exceptions.AppRegistryNotReady throughout all the tests.
209+ # TODO: remove the `from . import django` from the `__init__.py`
210+ from django .contrib .contenttypes .fields import GenericForeignKey
211+
212+ def get_field_value (instance , field ):
213+ if isinstance (field , GenericForeignKey ) and field .is_cached (instance ):
214+ return field .get_cached_value (instance )
215+ return getattr (instance , field .name )
216+
217+ # Current Django version's GenericForeignKey is not made to work with bulk_insert.
218+ #
219+ # The issue is that it caches the object referenced, once the object is
220+ # saved and receives a pk, the cache no longer matches. It doesn't
221+ # matter that it's the same obj reference. This is to bypass that pk
222+ # check and reset it.
223+ fields_to_reset = (GenericForeignKey , models .OneToOneField )
224+
225+ fields = [f for f in model_cls ._meta .get_fields () if isinstance (f , fields_to_reset )]
226+ if not fields :
227+ return
228+
229+ for obj in objs :
230+ for field in fields :
231+ setattr (obj , field .name , get_field_value (obj , field ))
232+
233+ @classmethod
234+ def _bulk_create (cls , size , ** kwargs ):
235+ if cls ._meta .abstract :
236+ raise errors .FactoryError (
237+ "Cannot generate instances of abstract factory %(f)s; "
238+ "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
239+ "is either not set or False." % dict (f = cls .__name__ ))
240+
241+ models_to_return = []
242+ instances = []
243+ for _ in range (size ):
244+ step = builder .StepBuilder (cls ._meta , kwargs , enums .BUILD_STRATEGY )
245+ models_to_return .append (step .build (collect_instances = instances ))
246+
247+ for model_cls , objs in dependency_insert_order (instances ):
248+ manager = cls ._get_manager (model_cls )
249+ cls ._refresh_database_pks (model_cls , objs )
250+
251+ concrete_model = True
252+ for parent in model_cls ._meta .get_parent_list ():
253+ if parent ._meta .concrete_model is not model_cls ._meta .concrete_model :
254+ concrete_model = False
255+
256+ if concrete_model :
257+ manager .bulk_create (objs )
258+ else :
259+ concrete_fields = model_cls ._meta .local_fields
260+ connection = connections [cls ._meta .database ]
261+
262+ # Avoids writing the INSERT INTO sql script manually
263+ query = InsertQuery (model_cls )
264+ query .insert_values (concrete_fields , objs )
265+ query .get_compiler (connection = connection ).execute_sql ()
266+
267+ return models_to_return
268+
168269 @classmethod
169270 def _create (cls , model_class , * args , ** kwargs ):
170271 """Create an instance of the model, and save it to the database."""
@@ -272,6 +373,82 @@ def _make_data(self, params):
272373 return thumb_io .getvalue ()
273374
274375
376+ def dependency_insert_order (data ):
377+ """This is almost the same function from django/core/serializers/__init__.py:sort_dependencies with a slight
378+ modification on `if hasattr(rel_model, 'natural_key') and rel_model != model:` that was removed, so we have the
379+ REAL dependency order. The original implementation was setup to only write to fields in order if they had a known
380+ dependency, we always want it in order regardless of the natural_key.
381+ """
382+
383+ lookup = []
384+ model_cls_by_data = defaultdict (list )
385+ for instance in data :
386+ # Instance has been persisted in the database
387+ if not instance ._state .adding :
388+ continue
389+ # Instance already in the list
390+ if id (instance ) in lookup :
391+ continue
392+ model_cls_by_data [type (instance )].append (instance )
393+
394+ # Avoid data leaks
395+ del lookup
396+ del data
397+
398+ # Process the list of models, and get the list of dependencies
399+ model_dependencies = []
400+ models = list (model_cls_by_data .keys ())
401+
402+ for model in models :
403+ deps = set ()
404+
405+ # Now add a dependency for any FK relation with a model that
406+ # defines a natural key
407+ for field in model ._meta .fields :
408+ rel_model = field .related_model
409+ if rel_model and rel_model != model :
410+ deps .add (rel_model )
411+
412+ model_dependencies .append ((model , deps ))
413+
414+ model_dependencies .reverse ()
415+ # Now sort the models to ensure that dependencies are met. This
416+ # is done by repeatedly iterating over the input list of models.
417+ # If all the dependencies of a given model are in the final list,
418+ # that model is promoted to the end of the final list. This process
419+ # continues until the input list is empty, or we do a full iteration
420+ # over the input models without promoting a model to the final list.
421+ # If we do a full iteration without a promotion, that means there are
422+ # circular dependencies in the list.
423+ model_list = []
424+ while model_dependencies :
425+ skipped = []
426+ changed = False
427+ while model_dependencies :
428+ model , deps = model_dependencies .pop ()
429+
430+ # If all of the models in the dependency list are either already
431+ # on the final model list, or not on the original serialization list,
432+ # then we've found another model with all it's dependencies satisfied.
433+ found = True
434+ for candidate in ((d not in models or d in model_list ) for d in deps ):
435+ if not candidate :
436+ found = False
437+ if found :
438+ model_list .append (model )
439+ changed = True
440+ else :
441+ skipped .append ((model , deps ))
442+ if not changed :
443+ unresolved_models = (f'{ model ._meta .app_label } .{ model ._meta .object_name } '
444+ for model , _ in sorted (skipped , key = lambda obj : obj [0 ].__name__ ))
445+ message = f"Can't resolve dependencies for { ', ' .join (unresolved_models )} ."
446+ raise RuntimeError (message )
447+ model_dependencies = skipped
448+
449+ return [(model_cls , model_cls_by_data [model_cls ]) for model_cls in model_list ]
450+
451+
275452class mute_signals :
276453 """Temporarily disables and then restores any django signals.
277454
@@ -327,6 +504,7 @@ def __call__(self, callable_obj):
327504 if isinstance (callable_obj , base .FactoryMetaClass ):
328505 # Retrieve __func__, the *actual* callable object.
329506 callable_obj ._create = self .wrap_method (callable_obj ._create .__func__ )
507+ callable_obj ._bulk_create = self .wrap_method (callable_obj ._bulk_create .__func__ )
330508 callable_obj ._generate = self .wrap_method (callable_obj ._generate .__func__ )
331509 callable_obj ._after_postgeneration = self .wrap_method (
332510 callable_obj ._after_postgeneration .__func__
0 commit comments