Skip to content

Commit fd569dd

Browse files
jacobtylerwallscharettes
authored andcommitted
Fixed #36210, Refs #36181 -- Allowed Subquery usage in further lookups against composite pks.
Follow-up to 8561100. Co-authored-by: Simon Charette <charette.s@gmail.com>
1 parent de7bb7e commit fd569dd

7 files changed

Lines changed: 101 additions & 6 deletions

File tree

django/db/backends/base/features.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,10 @@ class BaseDatabaseFeatures:
385385
# Does the backend support native tuple lookups (=, >, <, IN)?
386386
supports_tuple_lookups = True
387387

388+
# Does the backend support native tuple gt(e), lt(e) comparisons against
389+
# subqueries?
390+
supports_tuple_comparison_against_subquery = True
391+
388392
# Collation names for use by the Django test suite.
389393
test_collations = {
390394
"ci": None, # Case-insensitive.

django/db/backends/oracle/features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2121
can_return_columns_from_insert = True
2222
supports_subqueries_in_group_by = False
2323
ignores_unnecessary_order_by_in_subqueries = False
24+
supports_tuple_comparison_against_subquery = False
2425
supports_transactions = True
2526
supports_timezones = False
2627
has_native_duration_field = True

django/db/models/expressions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,6 +1781,7 @@ def __init__(self, queryset, output_field=None, **extra):
17811781
# Allow the usage of both QuerySet and sql.Query objects.
17821782
self.query = getattr(queryset, "query", queryset).clone()
17831783
self.query.subquery = True
1784+
self.template = extra.pop("template", self.template)
17841785
self.extra = extra
17851786
super().__init__(output_field)
17861787

@@ -1793,6 +1794,21 @@ def set_source_expressions(self, exprs):
17931794
def _resolve_output_field(self):
17941795
return self.query.output_field
17951796

1797+
def resolve_expression(self, *args, **kwargs):
1798+
resolved = super().resolve_expression(*args, **kwargs)
1799+
if type(self) is Subquery and self.template == Subquery.template:
1800+
resolved.query.contains_subquery = True
1801+
# Subquery is an unnecessary shim for a resolved query as it
1802+
# complexifies the lookup's right-hand-side introspection.
1803+
try:
1804+
self.output_field
1805+
except AttributeError:
1806+
return resolved.query
1807+
if self.output_field and self.output_field != resolved.query.output_field:
1808+
return ExpressionWrapper(resolved.query, output_field=self.output_field)
1809+
return resolved.query
1810+
return resolved
1811+
17961812
def copy(self):
17971813
clone = super().copy()
17981814
clone.query = clone.query.clone()

django/db/models/fields/tuple_lookups.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import itertools
22

33
from django.core.exceptions import EmptyResultSet
4-
from django.db import models
4+
from django.db import NotSupportedError, models
55
from django.db.models.expressions import (
66
ColPairs,
77
Exists,
@@ -129,6 +129,20 @@ def get_fallback_sql(self, compiler, connection):
129129
)
130130

131131
def as_sql(self, compiler, connection):
132+
if (
133+
not connection.features.supports_tuple_comparison_against_subquery
134+
and isinstance(self.rhs, Query)
135+
and self.rhs.subquery
136+
and isinstance(
137+
self, (GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual)
138+
)
139+
):
140+
lookup = self.lookup_name
141+
msg = (
142+
f'"{lookup}" cannot be used to target composite fields '
143+
"through subqueries on this backend"
144+
)
145+
raise NotSupportedError(msg)
132146
if not connection.features.supports_tuple_lookups:
133147
return self.get_fallback_sql(compiler, connection)
134148
return super().as_sql(compiler, connection)

django/db/models/sql/query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class Query(BaseExpression):
242242

243243
filter_is_sticky = False
244244
subquery = False
245+
contains_subquery = False
245246

246247
# SQL-related attributes.
247248
# Select and related select clauses are expressions to use in the SELECT

tests/composite_pk/test_filter.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from unittest.mock import patch
22

3-
from django.db import connection
3+
from django.db import NotSupportedError, connection
44
from django.db.models import (
55
Case,
66
F,
@@ -14,7 +14,7 @@
1414
)
1515
from django.db.models.functions import Cast
1616
from django.db.models.lookups import Exact
17-
from django.test import TestCase, skipUnlessDBFeature
17+
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
1818

1919
from .models import Comment, Tenant, User
2020

@@ -492,6 +492,39 @@ def test_outer_ref_pk(self):
492492
queryset = Comment.objects.filter(**{f"id{lookup}": subquery})
493493
self.assertEqual(queryset.count(), expected_count)
494494

495+
def test_outer_ref_pk_filter_on_pk_exact(self):
496+
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
497+
qs = Comment.objects.filter(pk=subquery)
498+
self.assertEqual(qs.count(), 2)
499+
500+
@skipUnlessDBFeature("supports_tuple_comparison_against_subquery")
501+
def test_outer_ref_pk_filter_on_pk_comparison(self):
502+
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
503+
tests = [
504+
("gt", 0),
505+
("gte", 2),
506+
("lt", 0),
507+
("lte", 2),
508+
]
509+
for lookup, expected_count in tests:
510+
with self.subTest(f"pk__{lookup}"):
511+
qs = Comment.objects.filter(**{f"pk__{lookup}": subquery})
512+
self.assertEqual(qs.count(), expected_count)
513+
514+
@skipIfDBFeature("supports_tuple_comparison_against_subquery")
515+
def test_outer_ref_pk_filter_on_pk_comparison_unsupported(self):
516+
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
517+
tests = ["gt", "gte", "lt", "lte"]
518+
for lookup in tests:
519+
with self.subTest(f"pk__{lookup}"):
520+
qs = Comment.objects.filter(**{f"pk__{lookup}": subquery})
521+
with self.assertRaisesMessage(
522+
NotSupportedError,
523+
f'"{lookup}" cannot be used to target composite fields '
524+
"through subqueries on this backend",
525+
):
526+
qs.count()
527+
495528
def test_unsupported_rhs(self):
496529
pk = Exact(F("tenant_id"), 1)
497530
msg = (
@@ -561,7 +594,11 @@ def test_filter_by_tuple_containing_expression(self):
561594
@skipUnlessDBFeature("supports_tuple_lookups")
562595
class CompositePKFilterTupleLookupFallbackTests(CompositePKFilterTests):
563596
def setUp(self):
564-
feature_patch = patch.object(
597+
feature_patch_1 = patch.object(
565598
connection.features, "supports_tuple_lookups", False
566599
)
567-
self.enterContext(feature_patch)
600+
feature_patch_2 = patch.object(
601+
connection.features, "supports_tuple_comparison_against_subquery", False
602+
)
603+
self.enterContext(feature_patch_1)
604+
self.enterContext(feature_patch_2)

tests/expressions/tests.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,11 +988,24 @@ def test_annotation_with_outerref(self):
988988
)
989989
.order_by("-salary_raise")
990990
.values("salary_raise")[:1],
991-
output_field=IntegerField(),
992991
),
993992
).get(pk=self.gmbh.pk)
994993
self.assertEqual(gmbh_salary.max_ceo_salary_raise, 2332)
995994

995+
def test_annotation_with_outerref_and_output_field(self):
996+
gmbh_salary = Company.objects.annotate(
997+
max_ceo_salary_raise=Subquery(
998+
Company.objects.annotate(
999+
salary_raise=OuterRef("num_employees") + F("num_employees"),
1000+
)
1001+
.order_by("-salary_raise")
1002+
.values("salary_raise")[:1],
1003+
output_field=DecimalField(),
1004+
),
1005+
).get(pk=self.gmbh.pk)
1006+
self.assertEqual(gmbh_salary.max_ceo_salary_raise, 2332.0)
1007+
self.assertIsInstance(gmbh_salary.max_ceo_salary_raise, Decimal)
1008+
9961009
def test_annotation_with_nested_outerref(self):
9971010
self.gmbh.point_of_contact = Employee.objects.get(lastname="Meyer")
9981011
self.gmbh.save()
@@ -2542,6 +2555,15 @@ def test_filter_by_empty_exists(self):
25422555
self.assertSequenceEqual(qs, [manager])
25432556
self.assertIs(qs.get().exists, False)
25442557

2558+
def test_annotate_by_empty_custom_exists(self):
2559+
class CustomExists(Exists):
2560+
template = Subquery.template
2561+
2562+
manager = Manager.objects.create()
2563+
qs = Manager.objects.annotate(exists=CustomExists(Manager.objects.none()))
2564+
self.assertSequenceEqual(qs, [manager])
2565+
self.assertIs(qs.get().exists, False)
2566+
25452567

25462568
class FieldTransformTests(TestCase):
25472569
@classmethod

0 commit comments

Comments
 (0)