Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions django/db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import operator
import warnings
from contextlib import nullcontext
from functools import reduce
from itertools import chain, islice

Expand Down Expand Up @@ -802,7 +803,11 @@ def bulk_create(
fields = [f for f in opts.concrete_fields if not f.generated]
objs = list(objs)
objs_with_pk, objs_without_pk = self._prepare_for_bulk_create(objs)
with transaction.atomic(using=self.db, savepoint=False):
if objs_with_pk and objs_without_pk:
context = transaction.atomic(using=self.db, savepoint=False)
else:
context = nullcontext()
with context:
self._handle_order_with_respect_to(objs)
if objs_with_pk:
returned_columns = self._batched_insert(
Expand Down Expand Up @@ -1918,11 +1923,21 @@ def _batched_insert(
max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
inserted_rows = []
bulk_return = connection.features.can_return_rows_from_bulk_insert
for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
if bulk_return and (
on_conflict is None or on_conflict == OnConflict.UPDATE
):
returning_fields = (
self.model._meta.db_returning_fields
if (
connection.features.can_return_rows_from_bulk_insert
and (on_conflict is None or on_conflict == OnConflict.UPDATE)
)
else None
)
batches = [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]
if len(batches) > 1:
context = transaction.atomic(using=self.db, savepoint=False)
else:
context = nullcontext()
with context:
for item in batches:
inserted_rows.extend(
self._insert(
item,
Expand All @@ -1931,18 +1946,9 @@ def _batched_insert(
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
returning_fields=self.model._meta.db_returning_fields,
returning_fields=returning_fields,
)
)
else:
self._insert(
item,
fields=fields,
using=self.db,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
return inserted_rows

def _chain(self):
Expand Down
45 changes: 38 additions & 7 deletions tests/bulk_create/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from django.db.models.functions import Lower, Now
from django.test import (
TestCase,
override_settings,
TransactionTestCase,
skipIfDBFeature,
skipUnlessDBFeature,
)
from django.test.utils import CaptureQueriesContext
from django.utils import timezone

from .models import (
Expand Down Expand Up @@ -216,12 +217,11 @@ def test_large_single_field_batch(self):

@skipUnlessDBFeature("has_bulk_insert")
def test_large_batch_efficiency(self):
    """
    bulk_create() of a large batch is split into a small number of
    INSERT statements rather than one query per object.
    """
    # CaptureQueriesContext records the queries executed inside the
    # block, replacing the older override_settings(DEBUG=True) +
    # connection.queries_log bookkeeping.
    with CaptureQueriesContext(connection) as ctx:
        TwoFields.objects.bulk_create(
            [TwoFields(f1=i, f2=i + 1) for i in range(0, 1001)]
        )
    # 1001 rows must collapse into fewer than 10 batched INSERTs.
    self.assertLess(len(ctx), 10)

def test_large_batch_mixed(self):
"""
Expand All @@ -247,15 +247,14 @@ def test_large_batch_mixed_efficiency(self):
Test inserting a large batch with objects having primary key set
mixed together with objects without PK set.
"""
with override_settings(DEBUG=True):
connection.queries_log.clear()
with CaptureQueriesContext(connection) as ctx:
TwoFields.objects.bulk_create(
[
TwoFields(id=i if i % 2 == 0 else None, f1=i, f2=i + 1)
for i in range(100000, 101000)
]
)
self.assertLess(len(connection.queries), 10)
self.assertLess(len(ctx), 10)

def test_explicit_batch_size(self):
objs = [TwoFields(f1=i, f2=i) for i in range(0, 4)]
Expand Down Expand Up @@ -884,3 +883,35 @@ def test_db_default_field_excluded(self):
def test_db_default_primary_key(self):
    """A db_default primary key is returned and populated on the instance."""
    # Exactly one object is created; list-unpacking asserts that implicitly.
    [obj] = DbDefaultPrimaryKey.objects.bulk_create([DbDefaultPrimaryKey()])
    self.assertIsInstance(obj.id, datetime)


@skipUnlessDBFeature("supports_transactions", "has_bulk_insert")
class BulkCreateTransactionTests(TransactionTestCase):
    """
    Query counts around bulk_create()'s transaction handling: a wrapping
    transaction (BEGIN/COMMIT) is only opened when more than one INSERT
    must be issued.
    """

    available_apps = ["bulk_create"]

    def test_no_unnecessary_transaction(self):
        """A single-statement bulk_create runs without BEGIN/COMMIT."""
        # One object with an explicit pk: exactly one query.
        with self.assertNumQueries(1):
            Country.objects.bulk_create(
                [Country(id=1, name="France", iso_two_letter="FR")]
            )
        # One object without a pk: still exactly one query.
        with self.assertNumQueries(1):
            Country.objects.bulk_create([Country(name="Canada", iso_two_letter="CA")])

    def test_objs_with_and_without_pk(self):
        """Mixed pk/no-pk objects need two INSERTs inside a transaction."""
        mixed = [
            Country(id=1, name="France", iso_two_letter="FR"),
            Country(name="Canada", iso_two_letter="CA"),
        ]
        # BEGIN + INSERT (with pk) + INSERT (without pk) + COMMIT = 4.
        with self.assertNumQueries(4):
            Country.objects.bulk_create(mixed)

    def test_multiple_batches(self):
        """Multiple batches are wrapped in a single transaction."""
        objs = [
            Country(name="France", iso_two_letter="FR"),
            Country(name="Canada", iso_two_letter="CA"),
        ]
        # batch_size=1 forces two INSERTs: BEGIN + 2 INSERTs + COMMIT = 4.
        with self.assertNumQueries(4):
            Country.objects.bulk_create(objs, batch_size=1)
Loading