diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index 4ab27605cb6b..52e925d27a4d 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -1,3 +1,4 @@ +from django.db.backends.postgresql.psycopg_any import is_psycopg3 from django.db.models import ( CharField, Expression, @@ -10,9 +11,45 @@ ) from django.db.models.expressions import CombinedExpression, register_combinable_fields from django.db.models.functions import Cast, Coalesce +from django.utils.regex_helper import _lazy_re_compile from .utils import CheckPostgresInstalledMixin +if is_psycopg3: + from psycopg.adapt import Dumper + + class UTF8Dumper(Dumper): + def dump(self, obj): + return bytes(obj, "utf-8") + + def quote_lexeme(value): + return UTF8Dumper(str).quote(psql_escape(value)).decode() + +else: + from psycopg2.extensions import adapt + + def quote_lexeme(value): + adapter = adapt(psql_escape(value)) + adapter.encoding = "utf-8" + return adapter.getquoted().decode() + + +spec_chars_re = _lazy_re_compile(r"['\0\[\]()|&:*!@<>\\]") +multiple_spaces_re = _lazy_re_compile(r"\s{2,}") + + +def normalize_spaces(val): + """Convert multiple spaces to single and strip from both sides.""" + if not (val := val.strip()): + return None + return multiple_spaces_re.sub(" ", val) + + +def psql_escape(query): + """Replace chars not fit for use in search queries with a single space.""" + query = spec_chars_re.sub(" ", query) + return normalize_spaces(query) + class SearchVectorExact(Lookup): lookup_name = "exact" @@ -205,6 +242,9 @@ def __init__( invert=False, search_type="plain", ): + if isinstance(value, LexemeCombinable): + search_type = "raw" + self.function = self.SEARCH_TYPES.get(search_type) if self.function is None: raise ValueError("Unknown search_type argument '%s'." % search_type) @@ -383,3 +423,104 @@ class TrigramWordSimilarity(TrigramWordBase): class TrigramStrictWordSimilarity(TrigramWordBase): function = "STRICT_WORD_SIMILARITY" + + +class LexemeCombinable: + BITAND = "&" + BITOR = "|" + + def _combine(self, other, connector, reversed): + if not isinstance(other, LexemeCombinable): + raise TypeError( + "A Lexeme can only be combined with another Lexeme, " + f"got {other.__class__.__name__}." + ) + if reversed: + return CombinedLexeme(other, connector, self) + return CombinedLexeme(self, connector, other) + + # On Combinable, these are not implemented to reduce confusion with Q. In + # this case we are actually (ab)using them to do logical combination so + # it's consistent with other usage in Django. + def __or__(self, other): + return self._combine(other, self.BITOR, False) + + def __ror__(self, other): + return self._combine(other, self.BITOR, True) + + def __and__(self, other): + return self._combine(other, self.BITAND, False) + + def __rand__(self, other): + return self._combine(other, self.BITAND, True) + + +class Lexeme(LexemeCombinable, Value): + _output_field = SearchQueryField() + + def __init__( + self, value, output_field=None, *, invert=False, prefix=False, weight=None + ): + if value == "": + raise ValueError("Lexeme value cannot be empty.") + + if not isinstance(value, str): + raise TypeError( + f"Lexeme value must be a string, got {value.__class__.__name__}." + ) + + if weight is not None and ( + not isinstance(weight, str) or weight.lower() not in {"a", "b", "c", "d"} + ): + raise ValueError( + f"Weight must be one of 'A', 'B', 'C', and 'D', got {weight!r}." + ) + + self.prefix = prefix + self.invert = invert + self.weight = weight + super().__init__(value, output_field=output_field) + + def as_sql(self, compiler, connection): + param = quote_lexeme(self.value) + label = "" + if self.prefix: + label += "*" + if self.weight: + label += self.weight + + if label: + param = f"{param}:{label}" + if self.invert: + param = f"!{param}" + + return "%s", (param,) + + def __invert__(self): + cloned = self.copy() + cloned.invert = not self.invert + return cloned + + +class CombinedLexeme(LexemeCombinable, CombinedExpression): + _output_field = SearchQueryField() + + def as_sql(self, compiler, connection): + value_params = [] + lsql, params = compiler.compile(self.lhs) + value_params.extend(params) + + rsql, params = compiler.compile(self.rhs) + value_params.extend(params) + + combined_sql = f"({lsql} {self.connector} {rsql})" + combined_value = combined_sql % tuple(value_params) + return "%s", (combined_value,) + + def __invert__(self): + # Apply De Morgan's theorem. + cloned = self.copy() + cloned.connector = self.BITAND if self.connector == self.BITOR else self.BITOR + cloned.lhs = ~self.lhs + cloned.rhs = ~self.rhs + return cloned diff --git a/django/db/models/base.py b/django/db/models/base.py index 518cfc44a252..82ea520065e6 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -1153,7 +1153,8 @@ def _save_table( getattr(self, field.attname) if raw else field.pre_save(self, False) ) if hasattr(value, "resolve_expression"): - returning_fields.append(field) + if field not in returning_fields: + returning_fields.append(field) elif field.db_returning: returning_fields.remove(field) results = self._do_insert( @@ -1357,7 +1358,7 @@ def _get_field_expression_map(self, meta, exclude=None): meta = meta or self._meta field_map = {} generated_fields = [] - for field in meta.local_concrete_fields: + for field in meta.local_fields: if field.name in exclude: continue if field.generated: @@ -1368,7 +1369,19 @@ def _get_field_expression_map(self, meta, exclude=None): continue generated_fields.append(field) continue - value = getattr(self, field.attname) + if ( + isinstance(field.remote_field, ForeignObjectRel) + and field not in meta.local_concrete_fields + ): + value = tuple( + getattr(self, from_field) for from_field in field.from_fields + ) + if len(value) == 1: + value = value[0] + elif field.concrete: + value = getattr(self, field.attname) + else: + continue if not value or not hasattr(value, "resolve_expression"): value = Value(value, field) field_map[field.name] = value diff --git a/django/db/models/query.py b/django/db/models/query.py index 2359ee3bb4a3..0de5787f426a 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -2333,8 +2333,8 @@ def normalize_prefetch_lookups(lookups, prefix=None): def prefetch_related_objects(model_instances, *related_lookups): """ - Populate prefetched object caches for a list of model instances based on - the lookups/Prefetch instances given. + Populate prefetched object caches for an iterable of model instances based + on the lookups/Prefetch instances given. """ if not model_instances: return # nothing to do @@ -2402,7 +2402,7 @@ def prefetch_related_objects(model_instances, *related_lookups): # We assume that objects retrieved are homogeneous (which is the # premise of prefetch_related), so what applies to first object # applies to all. - first_obj = obj_list[0] + first_obj = next(iter(obj_list)) to_attr = lookup.get_current_to_attr(level)[0] prefetcher, descriptor, attr_found, is_fetched = get_prefetcher( first_obj, through_attr, to_attr diff --git a/docs/ref/contrib/postgres/search.txt b/docs/ref/contrib/postgres/search.txt index 4647fcbfa26a..88e3cfaeb045 100644 --- a/docs/ref/contrib/postgres/search.txt +++ b/docs/ref/contrib/postgres/search.txt @@ -96,7 +96,7 @@ Examples: .. code-block:: pycon - >>> from django.contrib.postgres.search import SearchQuery + >>> from django.contrib.postgres.search import SearchQuery, Lexeme >>> SearchQuery("red tomato") # two keywords >>> SearchQuery("tomato red") # same results as above >>> SearchQuery("red tomato", search_type="phrase") # a phrase @@ -105,6 +105,7 @@ Examples: >>> SearchQuery( ... "'tomato' ('red' OR 'green')", search_type="websearch" ... ) # websearch operators + >>> SearchQuery(Lexeme("tomato") & (Lexeme("red") | Lexeme("green"))) # Lexeme objects ``SearchQuery`` terms can be combined logically to provide more flexibility: @@ -118,6 +119,10 @@ Examples: See :ref:`postgresql-fts-search-configuration` for an explanation of the ``config`` parameter. +.. versionchanged:: 6.0 + + :class:`Lexeme` objects were added. + ``SearchRank`` ============== @@ -276,6 +281,53 @@ floats to :class:`SearchRank` as ``weights`` in the same order above: >>> rank = SearchRank(vector, query, weights=[0.2, 0.4, 0.6, 0.8]) >>> Entry.objects.annotate(rank=rank).filter(rank__gte=0.3).order_by("-rank") +``Lexeme`` +========== + +.. versionadded:: 6.0 + +.. class:: Lexeme(value, output_field=None, *, invert=False, prefix=False, weight=None) + +``Lexeme`` objects allow search operators to be safely used with strings from +an untrusted source. The content of each lexeme is escaped so that any +operators that may exist in the string itself will not be interpreted. + +You can combine lexemes with other lexemes using the ``&`` and ``|`` operators +and also negate them with the ``~`` operator. For example: + +.. code-block:: pycon + + >>> from django.contrib.postgres.search import SearchQuery, SearchVector, Lexeme + >>> vector = SearchVector("body_text", "blog__tagline") + >>> Entry.objects.annotate(search=vector).filter( + ... search=SearchQuery(Lexeme("fruit") & Lexeme("dessert")) + ... ) + , ]> + +.. code-block:: pycon + + >>> Entry.objects.annotate(search=vector).filter( + ... search=SearchQuery(Lexeme("fruit") & Lexeme("dessert") & ~Lexeme("banana")) + ... ) + ]> + +Lexeme objects also support term weighting and prefixes: + +.. code-block:: pycon + + >>> Entry.objects.annotate(search=vector).filter( + ... search=SearchQuery(Lexeme("Pizza") | Lexeme("Cheese")) + ... ) + , ]> + >>> Entry.objects.annotate(search=vector).filter( + ... search=SearchQuery(Lexeme("Pizza") | Lexeme("Cheese", weight="A")) + ... ) + ]> + >>> Entry.objects.annotate(search=vector).filter( + ... search=SearchQuery(Lexeme("za", prefix=True)) + ... ) + + Performance =========== diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index baeb3e8746e0..59550e669097 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -4223,8 +4223,9 @@ Prefetches the given lookups on an iterable of model instances. This is useful in code that receives a list of model instances as opposed to a ``QuerySet``; for example, when fetching models from a cache or instantiating them manually. -Pass an iterable of model instances (must all be of the same class) and the -lookups or :class:`Prefetch` objects you want to prefetch for. For example: +Pass an iterable of model instances (must all be of the same class and able to +be iterated multiple times) and the lookups or :class:`Prefetch` objects you +want to prefetch for. For example: .. code-block:: pycon diff --git a/docs/releases/6.0.txt b/docs/releases/6.0.txt index adfac83b8da2..fba0935a2bbe 100644 --- a/docs/releases/6.0.txt +++ b/docs/releases/6.0.txt @@ -171,6 +171,12 @@ Minor features :mod:`django.contrib.postgres` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* The new :class:`Lexeme ` expression + for full text search provides fine-grained control over search terms. + ``Lexeme`` objects automatically escape their input and support logical + combination operators (``&``, ``|``, ``~``), prefix matching, and term + weighting. + * Model fields, indexes, and constraints from :mod:`django.contrib.postgres` now include system checks to verify that ``django.contrib.postgres`` is an installed app. diff --git a/tests/foreign_object/models/__init__.py b/tests/foreign_object/models/__init__.py index 69778d3dddd0..d2ddc6864672 100644 --- a/tests/foreign_object/models/__init__.py +++ b/tests/foreign_object/models/__init__.py @@ -1,5 +1,5 @@ from .article import Article, ArticleIdea, ArticleTag, ArticleTranslation, NewsArticle -from .customers import Address, Contact, Customer +from .customers import Address, Contact, Customer, CustomerTab from .empty_join import SlugPage from .person import Country, Friendship, Group, Membership, Person @@ -12,6 +12,7 @@ "Contact", "Country", "Customer", + "CustomerTab", "Friendship", "Group", "Membership", diff --git a/tests/foreign_object/models/customers.py b/tests/foreign_object/models/customers.py index 91ac0915242b..085b7272e98d 100644 --- a/tests/foreign_object/models/customers.py +++ b/tests/foreign_object/models/customers.py @@ -39,3 +39,22 @@ class Contact(models.Model): to_fields=["customer_id", "company"], from_fields=["customer_code", "company_code"], ) + + +class CustomerTab(models.Model): + customer_id = models.IntegerField() + customer = models.ForeignObject( + Customer, + from_fields=["customer_id"], + to_fields=["id"], + on_delete=models.CASCADE, + ) + + class Meta: + required_db_features = {"supports_table_check_constraints"} + constraints = [ + models.CheckConstraint( + condition=models.Q(customer__lt=1000), + name="customer_id_limit", + ), + ] diff --git a/tests/foreign_object/tests.py b/tests/foreign_object/tests.py index b4072d500d13..09fb47e771dc 100644 --- a/tests/foreign_object/tests.py +++ b/tests/foreign_object/tests.py @@ -3,10 +3,10 @@ import pickle from operator import attrgetter -from django.core.exceptions import FieldError +from django.core.exceptions import FieldError, ValidationError from django.db import connection, models from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature -from django.test.utils import isolate_apps +from django.test.utils import CaptureQueriesContext, isolate_apps from django.utils import translation from .models import ( @@ -15,6 +15,7 @@ ArticleTag, ArticleTranslation, Country, + CustomerTab, Friendship, Group, Membership, @@ -767,3 +768,33 @@ def test_pickling_foreignobject(self): foreign_object_restored = pickle.loads(pickle.dumps(foreign_object)) self.assertIn("path_infos", foreign_object_restored.__dict__) self.assertIn("reverse_path_infos", foreign_object_restored.__dict__) + + +class ForeignObjectModelValidationTests(TestCase): + @skipUnlessDBFeature("supports_table_check_constraints") + def test_validate_constraints_with_foreign_object(self): + customer_tab = CustomerTab(customer_id=1500) + with self.assertRaisesMessage(ValidationError, "customer_id_limit"): + customer_tab.validate_constraints() + + @skipUnlessDBFeature("supports_table_check_constraints") + def test_validate_constraints_success_case_single_query(self): + customer_tab = CustomerTab(customer_id=500) + with CaptureQueriesContext(connection) as ctx: + customer_tab.validate_constraints() + select_queries = [ + query["sql"] + for query in ctx.captured_queries + if "select" in query["sql"].lower() + ] + self.assertEqual(len(select_queries), 1) + + @skipUnlessDBFeature("supports_table_check_constraints") + def test_validate_constraints_excluding_foreign_object(self): + customer_tab = CustomerTab(customer_id=150) + customer_tab.validate_constraints(exclude={"customer"}) + + @skipUnlessDBFeature("supports_table_check_constraints") + def test_validate_constraints_excluding_foreign_object_member(self): + customer_tab = CustomerTab(customer_id=150) + customer_tab.validate_constraints(exclude={"customer_id"}) diff --git a/tests/postgres_tests/test_search.py b/tests/postgres_tests/test_search.py index a7118e7c79d7..c206c69747d2 100644 --- a/tests/postgres_tests/test_search.py +++ b/tests/postgres_tests/test_search.py @@ -6,6 +6,7 @@ transcript. """ +from django.db import connection from django.db.models import F, Value from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase @@ -13,11 +14,13 @@ try: from django.contrib.postgres.search import ( + Lexeme, SearchConfig, SearchHeadline, SearchQuery, SearchRank, SearchVector, + quote_lexeme, ) except ImportError: pass @@ -769,3 +772,223 @@ def test_headline_fragments_words_options(self): "Brave, brave, brave...
" "brave Sir Robin", ) + + +class TestLexemes(GrailTestData, PostgreSQLTestCase): + def test_and(self): + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue"), + ).filter(search=SearchQuery(Lexeme("bedemir") & Lexeme("scales"))) + self.assertSequenceEqual(searched, [self.bedemir0]) + + def test_multiple_and(self): + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue"), + ).filter( + search=SearchQuery( + Lexeme("bedemir") & Lexeme("scales") & Lexeme("nostrils") + ) + ) + self.assertSequenceEqual(searched, []) + + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue"), + ).filter(search=SearchQuery(Lexeme("shall") & Lexeme("use") & Lexeme("larger"))) + self.assertSequenceEqual(searched, [self.bedemir0]) + + def test_or(self): + searched = Line.objects.annotate(search=SearchVector("dialogue")).filter( + search=SearchQuery(Lexeme("kneecaps") | Lexeme("nostrils")) + ) + self.assertCountEqual(searched, [self.verse1, self.verse2]) + + def test_multiple_or(self): + searched = Line.objects.annotate(search=SearchVector("dialogue")).filter( + search=SearchQuery( + Lexeme("kneecaps") | Lexeme("nostrils") | Lexeme("Sir Robin") + ) + ) + self.assertCountEqual(searched, [self.verse1, self.verse2, self.verse0]) + + def test_advanced(self): + """ + Combination of & and | + This is mainly helpful for checking the test_advanced_invert below + """ + searched = Line.objects.annotate(search=SearchVector("dialogue")).filter( + search=SearchQuery( + Lexeme("shall") & Lexeme("use") & Lexeme("larger") | Lexeme("nostrils") + ) + ) + self.assertCountEqual(searched, [self.bedemir0, self.verse2]) + + def test_invert(self): + searched = Line.objects.annotate(search=SearchVector("dialogue")).filter( + character=self.minstrel, search=SearchQuery(~Lexeme("kneecaps")) + ) + self.assertCountEqual(searched, [self.verse0, self.verse2]) + + def test_advanced_invert(self): + """ + Inverting a query that uses a combination of & and | + should return the opposite of test_advanced. + """ + searched = Line.objects.annotate(search=SearchVector("dialogue")).filter( + search=SearchQuery( + ~( + Lexeme("shall") & Lexeme("use") & Lexeme("larger") + | Lexeme("nostrils") + ) + ) + ) + expected_result = Line.objects.exclude( + id__in=[self.bedemir0.id, self.verse2.id] + ) + self.assertCountEqual(searched, expected_result) + + def test_as_sql(self): + query = Line.objects.all().query + compiler = query.get_compiler(connection.alias) + + tests = ( + (Lexeme("a"), ("'a'",)), + (Lexeme("a", invert=True), ("!'a'",)), + (~Lexeme("a"), ("!'a'",)), + (Lexeme("a", prefix=True), ("'a':*",)), + (Lexeme("a", weight="D"), ("'a':D",)), + (Lexeme("a", invert=True, prefix=True, weight="D"), ("!'a':*D",)), + (Lexeme("a") | Lexeme("b") & ~Lexeme("c"), ("('a' | ('b' & !'c'))",)), + ( + ~(Lexeme("a") | Lexeme("b") & ~Lexeme("c")), + ("(!'a' & (!'b' | 'c'))",), + ), + ) + + for expression, expected_params in tests: + with self.subTest(expression=expression, expected_params=expected_params): + _, params = expression.as_sql(compiler, connection) + self.assertEqual(params, expected_params) + + def test_quote_lexeme(self): + tests = ( + ("L'amour piqué par une abeille", "'L amour piqué par une abeille'"), + ("'starting quote", "'starting quote'"), + ("ending quote'", "'ending quote'"), + ("double quo''te", "'double quo te'"), + ("triple quo'''te", "'triple quo te'"), + ("backslash\\", "'backslash'"), + ("exclamation!", "'exclamation'"), + ("ampers&nd", "'ampers nd'"), + ) + for lexeme, quoted in tests: + with self.subTest(lexeme=lexeme): + self.assertEqual(quote_lexeme(lexeme), quoted) + + def test_prefix_searching(self): + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue"), + ).filter(search=SearchQuery(Lexeme("hear", prefix=True))) + + self.assertSequenceEqual(searched, [self.verse2]) + + def test_inverse_prefix_searching(self): + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue"), + ).filter(search=SearchQuery(Lexeme("Robi", prefix=True, invert=True))) + self.assertEqual( + set(searched), + { + self.verse2, + self.bedemir0, + self.bedemir1, + self.french, + self.crowd, + self.witch, + self.duck, + }, + ) + + def test_lexemes_multiple_and(self): + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue"), + ).filter( + search=SearchQuery( + Lexeme("Robi", prefix=True) & Lexeme("Camel", prefix=True) + ) + ) + + self.assertSequenceEqual(searched, [self.verse0]) + + def test_lexemes_multiple_or(self): + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue"), + ).filter( + search=SearchQuery( + Lexeme("kneecap", prefix=True) | Lexeme("afrai", prefix=True) + ) + ) + + self.assertSequenceEqual(searched, [self.verse0, self.verse1]) + + def test_config_query_explicit(self): + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue", config="french"), + ).filter(search=SearchQuery(Lexeme("cadeaux"), config="french")) + + self.assertSequenceEqual(searched, [self.french]) + + def test_config_query_implicit(self): + searched = Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue", config="french"), + ).filter(search=Lexeme("cadeaux")) + + self.assertSequenceEqual(searched, [self.french]) + + def test_config_from_field_explicit(self): + searched = Line.objects.annotate( + search=SearchVector( + "scene__setting", "dialogue", config=F("dialogue_config") + ), + ).filter(search=SearchQuery(Lexeme("cadeaux"), config=F("dialogue_config"))) + self.assertSequenceEqual(searched, [self.french]) + + def test_config_from_field_implicit(self): + searched = Line.objects.annotate( + search=SearchVector( + "scene__setting", "dialogue", config=F("dialogue_config") + ), + ).filter(search=Lexeme("cadeaux")) + self.assertSequenceEqual(searched, [self.french]) + + def test_invalid_combinations(self): + msg = "A Lexeme can only be combined with another Lexeme, got NoneType." + with self.assertRaisesMessage(TypeError, msg): + Line.objects.filter(dialogue__search=None | Lexeme("kneecaps")) + + with self.assertRaisesMessage(TypeError, msg): + Line.objects.filter(dialogue__search=None & Lexeme("kneecaps")) + + def test_invalid_weights(self): + invalid_weights = ["E", "Drandom", "AB", "C ", 0, "", " ", [1, 2, 3]] + for weight in invalid_weights: + with self.subTest(weight=weight): + with self.assertRaisesMessage( + ValueError, + f"Weight must be one of 'A', 'B', 'C', and 'D', got {weight!r}.", + ): + Line.objects.filter( + dialogue__search=Lexeme("kneecaps", weight=weight) + ) + + def test_empty(self): + with self.assertRaisesMessage(ValueError, "Lexeme value cannot be empty."): + Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue") + ).filter(search=SearchQuery(Lexeme(""))) + + def test_non_string_values(self): + msg = "Lexeme value must be a string, got NoneType." + with self.assertRaisesMessage(TypeError, msg): + Line.objects.annotate( + search=SearchVector("scene__setting", "dialogue") + ).filter(search=SearchQuery(Lexeme(None))) diff --git a/tests/prefetch_related/test_prefetch_related_objects.py b/tests/prefetch_related/test_prefetch_related_objects.py index eea9a7fff78c..20f620417aea 100644 --- a/tests/prefetch_related/test_prefetch_related_objects.py +++ b/tests/prefetch_related/test_prefetch_related_objects.py @@ -1,3 +1,5 @@ +from collections import deque + from django.db.models import Prefetch, prefetch_related_objects from django.test import TestCase @@ -221,3 +223,32 @@ def test_prefetch_queryset(self): with self.assertNumQueries(0): self.assertCountEqual(book1.authors.all(), [self.author1, self.author2]) + + def test_prefetch_related_objects_with_various_iterables(self): + book = self.book1 + + class MyIterable: + def __iter__(self): + yield book + + cases = { + "set": {book}, + "tuple": (book,), + "dict_values": {"a": book}.values(), + "frozenset": frozenset([book]), + "deque": deque([book]), + "custom iterator": MyIterable(), + } + for case_type, case in cases.items(): + with self.subTest(case=case_type): + # Clear the prefetch cache. + book._prefetched_objects_cache = {} + with self.assertNumQueries(1): + prefetch_related_objects(case, "authors") + with self.assertNumQueries(0): + self.assertCountEqual( + book.authors.all(), [self.author1, self.author2, self.author3] + ) + + def test_prefetch_related_objects_empty(self): + prefetch_related_objects([], "authors") diff --git a/tests/queries/test_db_returning.py b/tests/queries/test_db_returning.py index 50c164a57f94..06efe023e101 100644 --- a/tests/queries/test_db_returning.py +++ b/tests/queries/test_db_returning.py @@ -42,6 +42,16 @@ def test_insert_returning_multiple(self): ), captured_queries[-1]["sql"], ) + self.assertEqual( + captured_queries[-1]["sql"] + .split("RETURNING ")[1] + .count( + connection.ops.quote_name( + ReturningModel._meta.get_field("created").column + ), + ), + 1, + ) self.assertTrue(obj.pk) self.assertIsInstance(obj.created, datetime.datetime)