diff --git a/django/db/models/query.py b/django/db/models/query.py index cc4d9c1f22dd..9f245b02ca35 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -5,6 +5,7 @@ import copy import operator import warnings +from contextlib import nullcontext from functools import reduce from itertools import chain, islice @@ -802,7 +803,11 @@ def bulk_create( fields = [f for f in opts.concrete_fields if not f.generated] objs = list(objs) objs_with_pk, objs_without_pk = self._prepare_for_bulk_create(objs) - with transaction.atomic(using=self.db, savepoint=False): + if objs_with_pk and objs_without_pk: + context = transaction.atomic(using=self.db, savepoint=False) + else: + context = nullcontext() + with context: self._handle_order_with_respect_to(objs) if objs_with_pk: returned_columns = self._batched_insert( @@ -1918,11 +1923,21 @@ def _batched_insert( max_batch_size = max(ops.bulk_batch_size(fields, objs), 1) batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size inserted_rows = [] - bulk_return = connection.features.can_return_rows_from_bulk_insert - for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]: - if bulk_return and ( - on_conflict is None or on_conflict == OnConflict.UPDATE - ): + returning_fields = ( + self.model._meta.db_returning_fields + if ( + connection.features.can_return_rows_from_bulk_insert + and (on_conflict is None or on_conflict == OnConflict.UPDATE) + ) + else None + ) + batches = [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)] + if len(batches) > 1: + context = transaction.atomic(using=self.db, savepoint=False) + else: + context = nullcontext() + with context: + for item in batches: inserted_rows.extend( self._insert( item, @@ -1931,18 +1946,9 @@ def _batched_insert( on_conflict=on_conflict, update_fields=update_fields, unique_fields=unique_fields, - returning_fields=self.model._meta.db_returning_fields, + returning_fields=returning_fields, ) ) - else: - self._insert( - item, - fields=fields, - using=self.db, - on_conflict=on_conflict, - update_fields=update_fields, - unique_fields=unique_fields, - ) return inserted_rows def _chain(self): diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index d590a292de26..4fd9e6c7bfa1 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -14,10 +14,11 @@ from django.db.models.functions import Lower, Now from django.test import ( TestCase, - override_settings, + TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, ) +from django.test.utils import CaptureQueriesContext from django.utils import timezone from .models import ( @@ -216,12 +217,11 @@ def test_large_single_field_batch(self): @skipUnlessDBFeature("has_bulk_insert") def test_large_batch_efficiency(self): - with override_settings(DEBUG=True): - connection.queries_log.clear() + with CaptureQueriesContext(connection) as ctx: TwoFields.objects.bulk_create( [TwoFields(f1=i, f2=i + 1) for i in range(0, 1001)] ) - self.assertLess(len(connection.queries), 10) + self.assertLess(len(ctx), 10) def test_large_batch_mixed(self): """ @@ -247,15 +247,14 @@ def test_large_batch_mixed_efficiency(self): Test inserting a large batch with objects having primary key set mixed together with objects without PK set. """ - with override_settings(DEBUG=True): - connection.queries_log.clear() + with CaptureQueriesContext(connection) as ctx: TwoFields.objects.bulk_create( [ TwoFields(id=i if i % 2 == 0 else None, f1=i, f2=i + 1) for i in range(100000, 101000) ] ) - self.assertLess(len(connection.queries), 10) + self.assertLess(len(ctx), 10) def test_explicit_batch_size(self): objs = [TwoFields(f1=i, f2=i) for i in range(0, 4)] @@ -884,3 +883,35 @@ def test_db_default_field_excluded(self): def test_db_default_primary_key(self): (obj,) = DbDefaultPrimaryKey.objects.bulk_create([DbDefaultPrimaryKey()]) self.assertIsInstance(obj.id, datetime) + + +@skipUnlessDBFeature("supports_transactions", "has_bulk_insert") +class BulkCreateTransactionTests(TransactionTestCase): + available_apps = ["bulk_create"] + + def test_no_unnecessary_transaction(self): + with self.assertNumQueries(1): + Country.objects.bulk_create( + [Country(id=1, name="France", iso_two_letter="FR")] + ) + with self.assertNumQueries(1): + Country.objects.bulk_create([Country(name="Canada", iso_two_letter="CA")]) + + def test_objs_with_and_without_pk(self): + with self.assertNumQueries(4): + Country.objects.bulk_create( + [ + Country(id=1, name="France", iso_two_letter="FR"), + Country(name="Canada", iso_two_letter="CA"), + ] + ) + + def test_multiple_batches(self): + with self.assertNumQueries(4): + Country.objects.bulk_create( + [ + Country(name="France", iso_two_letter="FR"), + Country(name="Canada", iso_two_letter="CA"), + ], + batch_size=1, + )