Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions django/db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,8 +1166,8 @@ def in_bulk(self, id_list=None, *, field_name="pk"):
"""
if self.query.is_sliced:
raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().")
if not issubclass(self._iterable_class, ModelIterable):
raise TypeError("in_bulk() cannot be used with values() or values_list().")
if id_list is not None and not id_list:
return {}
opts = self.model._meta
unique_fields = [
constraint.fields[0]
Expand All @@ -1184,24 +1184,76 @@ def in_bulk(self, id_list=None, *, field_name="pk"):
"in_bulk()'s field_name must be a unique field but %r isn't."
% field_name
)

qs = self

def get_obj(obj):
return obj

if issubclass(self._iterable_class, ModelIterable):
# Raise an AttributeError if field_name is deferred.
get_key = operator.attrgetter(field_name)

elif issubclass(self._iterable_class, ValuesIterable):
if field_name not in self.query.values_select:
qs = qs.values(field_name, *self.query.values_select)

def get_obj(obj): # noqa: F811
# We can safely mutate the dictionaries returned by
# ValuesIterable here, since they are limited to the scope
# of this function, and get_key runs before get_obj.
del obj[field_name]
return obj

get_key = operator.itemgetter(field_name)

elif issubclass(self._iterable_class, ValuesListIterable):
try:
field_index = self.query.values_select.index(field_name)
except ValueError:
# field_name is missing from values_select, so add it.
field_index = 0
if issubclass(self._iterable_class, NamedValuesListIterable):
kwargs = {"named": True}
else:
kwargs = {}
get_obj = operator.itemgetter(slice(1, None))
qs = qs.values_list(field_name, *self.query.values_select, **kwargs)

get_key = operator.itemgetter(field_index)

elif issubclass(self._iterable_class, FlatValuesListIterable):
if self.query.values_select == (field_name,):
# Mapping field_name to itself.
get_key = get_obj
else:
# Transform it back into a non-flat values_list().
qs = qs.values_list(field_name, *self.query.values_select)
get_key = operator.itemgetter(0)
get_obj = operator.itemgetter(1)

else:
raise TypeError(
f"in_bulk() cannot be used with {self._iterable_class.__name__}."
)

if id_list is not None:
if not id_list:
return {}
filter_key = "{}__in".format(field_name)
id_list = tuple(id_list)
batch_size = connections[self.db].ops.bulk_batch_size([opts.pk], id_list)
# If the database has a limit on the number of query parameters
# (e.g. SQLite), retrieve objects in batches if necessary.
if batch_size and batch_size < len(id_list):
qs = ()
results = ()
for offset in range(0, len(id_list), batch_size):
batch = id_list[offset : offset + batch_size]
qs += tuple(self.filter(**{filter_key: batch}))
results += tuple(qs.filter(**{filter_key: batch}))
qs = results
else:
qs = self.filter(**{filter_key: id_list})
qs = qs.filter(**{filter_key: id_list})
else:
qs = self._chain()
return {getattr(obj, field_name): obj for obj in qs}
qs = qs._chain()
return {get_key(obj): get_obj(obj) for obj in qs}

async def ain_bulk(self, id_list=None, *, field_name="pk"):
return await sync_to_async(self.in_bulk)(
Expand Down
5 changes: 5 additions & 0 deletions docs/ref/models/querysets.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,11 @@ Example:

If you pass ``in_bulk()`` an empty list, you'll get an empty dictionary.

.. versionchanged:: 6.1

Support for chaining ``in_bulk()`` after :meth:`values` or
:meth:`values_list` was added.

``iterator()``
~~~~~~~~~~~~~~

Expand Down
3 changes: 2 additions & 1 deletion docs/releases/6.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ Migrations
Models
~~~~~~

* ...
* :meth:`.QuerySet.in_bulk` now supports chaining after
:meth:`.QuerySet.values` and :meth:`.QuerySet.values_list`.

Pagination
~~~~~~~~~~
Expand Down
61 changes: 61 additions & 0 deletions tests/composite_pk/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,67 @@ def test_in_bulk_batching(self):
comment_dict = Comment.objects.in_bulk(id_list=id_list)
self.assertQuerySetEqual(comment_dict, id_list)

def test_in_bulk_values(self):
result = Comment.objects.values().in_bulk([self.comment.pk])
self.assertEqual(
result,
{
self.comment.pk: {
"tenant_id": self.comment.tenant_id,
"id": self.comment.id,
"user_id": self.comment.user_id,
"text": self.comment.text,
"integer": self.comment.integer,
}
},
)

def test_in_bulk_values_field(self):
result = Comment.objects.values("text").in_bulk([self.comment.pk])
self.assertEqual(
result,
{self.comment.pk: {"text": self.comment.text}},
)

def test_in_bulk_values_fields(self):
result = Comment.objects.values("pk", "text").in_bulk([self.comment.pk])
self.assertEqual(
result,
{self.comment.pk: {"pk": self.comment.pk, "text": self.comment.text}},
)

def test_in_bulk_values_list(self):
result = Comment.objects.values_list("text").in_bulk([self.comment.pk])
self.assertEqual(result, {self.comment.pk: (self.comment.text,)})

def test_in_bulk_values_list_multiple_fields(self):
result = Comment.objects.values_list("pk", "text").in_bulk([self.comment.pk])
self.assertEqual(
result, {self.comment.pk: (self.comment.pk, self.comment.text)}
)

def test_in_bulk_values_list_fields_are_pk(self):
result = Comment.objects.values_list("tenant", "id").in_bulk([self.comment.pk])
self.assertEqual(
result, {self.comment.pk: (self.comment.tenant_id, self.comment.id)}
)

def test_in_bulk_values_list_flat(self):
result = Comment.objects.values_list("text", flat=True).in_bulk(
[self.comment.pk]
)
self.assertEqual(result, {self.comment.pk: self.comment.text})

def test_in_bulk_values_list_flat_pk(self):
result = Comment.objects.values_list("pk", flat=True).in_bulk([self.comment.pk])
self.assertEqual(result, {self.comment.pk: self.comment.pk})

def test_in_bulk_values_list_flat_tenant(self):
result = Comment.objects.values_list("tenant", flat=True).in_bulk(
[self.comment.pk]
)
self.assertEqual(result, {self.comment.pk: self.tenant.id})

def test_iterator(self):
"""
Test the .iterator() method of composite_pk models.
Expand Down
Loading
Loading