Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions django/contrib/gis/db/models/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def process_band_indices(self, only_lhs=False):

def get_db_prep_lookup(self, value, connection):
# get_db_prep_lookup is called by process_rhs from super class
return ("%s", [connection.ops.Adapter(value)])
return ("%s", (connection.ops.Adapter(value),))

def process_rhs(self, compiler, connection):
if isinstance(self.rhs, Query):
Expand Down Expand Up @@ -284,7 +284,7 @@ def process_rhs(self, compiler, connection):
elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
sql, params = super().process_rhs(compiler, connection)
return sql, [*params, pattern]
return sql, (*params, pattern)


@BaseSpatialField.register_lookup
Expand Down Expand Up @@ -352,7 +352,7 @@ def process_rhs(self, compiler, connection):
dist_sql, dist_params = self.process_distance(compiler, connection)
self.template_params["value"] = dist_sql
rhs_sql, params = super().process_rhs(compiler, connection)
return rhs_sql, params + dist_params
return rhs_sql, (*params, *dist_params)


class DistanceLookupFromFunction(DistanceLookupBase):
Expand All @@ -367,7 +367,7 @@ def as_sql(self, compiler, connection):
dist_sql, dist_params = self.process_distance(compiler, connection)
return (
"%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql},
params + dist_params,
(*params, *dist_params),
)


Expand Down
10 changes: 5 additions & 5 deletions django/contrib/postgres/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def process_rhs(self, qn, connection):
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
params = (*lhs_params, *rhs_params)
return "%s @@ %s" % (lhs, rhs), params


Expand Down Expand Up @@ -148,7 +148,7 @@ def as_sql(self, compiler, connection, function=None, template=None):
weight_sql, extra_params = compiler.compile(clone.weight)
sql = "setweight({}, {})".format(sql, weight_sql)

return sql, config_params + params + extra_params
return sql, (*config_params, *params, *extra_params)


class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
Expand Down Expand Up @@ -318,13 +318,13 @@ def __init__(

def as_sql(self, compiler, connection, function=None, template=None):
options_sql = ""
options_params = []
options_params = ()
if self.options:
options_params.append(
options_params = (
", ".join(
connection.ops.compose_sql(f"{option}=%s", [value])
for option, value in self.options.items()
)
),
)
options_sql = ", %s"
sql, params = super().as_sql(
Expand Down
3 changes: 2 additions & 1 deletion django/db/backends/base/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import operator
from datetime import datetime
from itertools import chain

from django.conf import settings
from django.core.exceptions import FieldError
Expand Down Expand Up @@ -1160,7 +1161,7 @@ def _alter_field(
# Combine actions together if we can (e.g. postgres)
if self.connection.features.supports_combined_alters and actions:
sql, params = tuple(zip(*actions))
actions = [(", ".join(sql), sum(params, []))]
actions = [(", ".join(sql), tuple(chain(*params)))]
# Apply those actions
for sql, params in actions:
self.execute(
Expand Down
4 changes: 2 additions & 2 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ def as_sql(
template = template or data.get("template", self.template)
arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
return template % data, params
return template % data, tuple(params)

def copy(self):
copy = super().copy()
Expand Down Expand Up @@ -1323,7 +1323,7 @@ def as_sql(self, compiler, connection):
alias, column = self.alias, self.target.column
identifiers = (alias, column) if alias else (column,)
sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
return sql, []
return sql, ()

def relabeled_clone(self, relabels):
if self.alias is None:
Expand Down
2 changes: 1 addition & 1 deletion django/db/models/functions/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def as_sqlite(self, compiler, connection, **extra_context):
compiler, connection, template=template, **extra_context
)
format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
params.insert(0, format_string)
params = (format_string, *params)
return sql, params
elif db_type == "date":
template = "date(%(expressions)s)"
Expand Down
12 changes: 7 additions & 5 deletions django/db/models/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_prep_lhs(self):
return Value(self.lhs)

def get_db_prep_lookup(self, value, connection):
return ("%s", [value])
return ("%s", (value,))

def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs
Expand Down Expand Up @@ -415,7 +415,7 @@ class IExact(BuiltinLookup):
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if params:
params[0] = connection.ops.prep_for_iexact_query(params[0])
params = (connection.ops.prep_for_iexact_query(params[0]), *params[1:])
return rhs, params


Expand Down Expand Up @@ -603,8 +603,9 @@ def get_rhs_op(self, connection, rhs):
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
params[0] = self.param_pattern % connection.ops.prep_for_like_query(
params[0]
params = (
self.param_pattern % connection.ops.prep_for_like_query(params[0]),
*params[1:],
)
return rhs, params

Expand Down Expand Up @@ -686,8 +687,9 @@ def as_sql(self, compiler, connection):
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = (*lhs_params, *rhs_params)
sql_template = connection.ops.regex_lookup(self.lookup_name)
return sql_template % (lhs, rhs), lhs_params + rhs_params
return sql_template % (lhs, rhs), params


@Field.register_lookup
Expand Down
41 changes: 40 additions & 1 deletion django/db/models/query_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from contextlib import nullcontext

from django.core.exceptions import FieldError
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, transaction
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models, transaction
from django.db.models.constants import LOOKUP_SEP
from django.utils import tree
from django.utils.functional import cached_property
Expand Down Expand Up @@ -99,6 +99,45 @@ def resolve_expression(
query.promote_joins(joins)
return clause

def replace_expressions(self, replacements):
if not replacements:
return self
clone = self.create(connector=self.connector, negated=self.negated)
for child in self.children:
child_replacement = child
if isinstance(child, tuple):
lhs, rhs = child
if LOOKUP_SEP in lhs:
path, lookup = lhs.rsplit(LOOKUP_SEP, 1)
else:
path = lhs
lookup = None
field = models.F(path)
if (
field_replacement := field.replace_expressions(replacements)
) is not field:
# Handle the implicit __exact case by falling back to an
# extra transform when get_lookup returns no match for the
# last component of the path.
if lookup is None:
lookup = "exact"
if (lookup_class := field_replacement.get_lookup(lookup)) is None:
if (
transform_class := field_replacement.get_transform(lookup)
) is not None:
field_replacement = transform_class(field_replacement)
lookup = "exact"
lookup_class = field_replacement.get_lookup(lookup)
if rhs is None and lookup == "exact":
lookup_class = field_replacement.get_lookup("isnull")
rhs = True
if lookup_class is not None:
child_replacement = lookup_class(field_replacement, rhs)
else:
child_replacement = child.replace_expressions(replacements)
clone.children.append(child_replacement)
return clone

def flatten(self):
"""
Recursively yield this Q object and all subexpressions, in depth-first
Expand Down
4 changes: 3 additions & 1 deletion docs/ref/models/expressions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,8 @@ calling the appropriate methods on the wrapped expression.
:meth:`~django.db.models.query.QuerySet.reverse()` is called on a
queryset.

.. _writing-your-own-query-expressions:

Writing your own Query Expressions
----------------------------------

Expand Down Expand Up @@ -1262,7 +1264,7 @@ Next, we write the method responsible for generating the SQL::
sql_params.extend(params)
template = template or self.template
data = {"expressions": ",".join(sql_expressions)}
return template % data, sql_params
return template % data, tuple(sql_params)


def as_oracle(self, compiler, connection):
Expand Down
4 changes: 4 additions & 0 deletions docs/releases/5.2.5.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ Bugfixes

* Fixed a crash in Django 5.2 when filtering against a composite primary key
using a tuple containing expressions (:ticket:`36522`).

* Fixed a crash in Django 5.2 when validating a model that uses
``GeneratedField`` or constraints composed of ``Q`` and ``Case`` lookups
(:ticket:`36518`).
18 changes: 18 additions & 0 deletions docs/releases/6.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,24 @@ Email
significantly, closely examine any custom subclasses that rely on overriding
undocumented, internal underscore methods.

Custom ORM expressions should return params as a tuple
------------------------------------------------------

Prior to Django 6.0, :doc:`custom lookups </howto/custom-lookups>` and
:ref:`custom expressions <writing-your-own-query-expressions>` implementing the
``as_sql()`` method (and its supporting methods ``process_lhs()`` and
``process_rhs()``) were allowed to return a sequence of params in either a list
or a tuple. To address the interoperability problems that resulted, the second
return element of the ``as_sql()`` method should now be a tuple::

def as_sql(self, compiler, connection) -> tuple[str, tuple]: ...

If your custom expressions support multiple versions of Django, you should
adjust any pre-processing of parameters to be resilient against either tuples
or lists. For instance, prefer unpacking like this::

params = (*lhs_params, *rhs_params)

Miscellaneous
-------------

Expand Down
19 changes: 18 additions & 1 deletion tests/constraints/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models
from django.db.models import F
from django.db.models import Case, F, When
from django.db.models.constraints import BaseConstraint, UniqueConstraint
from django.db.models.functions import Abs, Lower, Sqrt, Upper
from django.db.transaction import atomic
Expand Down Expand Up @@ -1064,6 +1064,23 @@ def test_validate_field_transform(self):
UniqueConstraintProduct(updated=updated_date + timedelta(days=1)),
)

def test_validate_case_when(self):
UniqueConstraintProduct.objects.create(name="p1")
constraint = models.UniqueConstraint(
Case(When(color__isnull=True, then=F("name"))),
name="name_without_color_uniq",
)
msg = "Constraint “name_without_color_uniq” is violated."
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(
UniqueConstraintProduct,
UniqueConstraintProduct(name="p1"),
)
constraint.validate(
UniqueConstraintProduct,
UniqueConstraintProduct(name="p1", color="green"),
)

def test_validate_ordered_expression(self):
constraint = models.UniqueConstraint(
Lower("name").desc(), name="name_lower_uniq_desc"
Expand Down
42 changes: 38 additions & 4 deletions tests/custom_lookups/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.core.exceptions import FieldError
from django.db import connection, models
from django.db.models.fields.related_lookups import RelatedGreaterThan
from django.db.models.functions import Lower
from django.db.models.lookups import EndsWith, StartsWith
from django.test import SimpleTestCase, TestCase, override_settings
from django.test.utils import register_lookup
Expand All @@ -17,15 +18,15 @@ class Div3Lookup(models.Lookup):
lookup_name = "div3"

def as_sql(self, compiler, connection):
lhs, params = self.process_lhs(compiler, connection)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
params = (*lhs_params, *rhs_params)
return "(%s) %%%% 3 = %s" % (lhs, rhs), params

def as_oracle(self, compiler, connection):
lhs, params = self.process_lhs(compiler, connection)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
params = (*lhs_params, *rhs_params)
return "mod(%s, 3) = %s" % (lhs, rhs), params


Expand Down Expand Up @@ -249,6 +250,39 @@ def test_custom_name_lookup(self):
self.assertSequenceEqual(qs1, [a1])
self.assertSequenceEqual(qs2, [a1])

def test_custom_lookup_with_subquery(self):
class NotEqual(models.Lookup):
lookup_name = "ne"

def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
# Although combining via (*lhs_params, *rhs_params) would be
# more resilient, the "simple" way works too.
params = lhs_params + rhs_params
return "%s <> %s" % (lhs, rhs), params

author = Author.objects.create(name="Isabella")

with register_lookup(models.Field, NotEqual):
qs = Author.objects.annotate(
unknown_age=models.Subquery(
Author.objects.filter(age__isnull=True)
.order_by("name")
.values("name")[:1]
)
).filter(unknown_age__ne="Plato")
self.assertSequenceEqual(qs, [author])

qs = Author.objects.annotate(
unknown_age=Lower(
Author.objects.filter(age__isnull=True)
.order_by("name")
.values("name")[:1]
)
).filter(unknown_age__ne="plato")
self.assertSequenceEqual(qs, [author])

def test_custom_exact_lookup_none_rhs(self):
"""
__exact=None is transformed to __isnull=True if a custom lookup class
Expand Down
18 changes: 18 additions & 0 deletions tests/expressions/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,24 @@ def test_in_subquery(self):
)
self.assertCountEqual(subquery_test2, [self.foobar_ltd])

def test_lookups_subquery(self):
smallest_company = Company.objects.order_by("num_employees").values("name")[:1]
for lookup in CharField.get_lookups():
if lookup == "isnull":
continue # not allowed, rhs must be a literal boolean.
if (
lookup == "in"
and not connection.features.allow_sliced_subqueries_with_in
):
continue
if lookup == "range":
rhs = (Subquery(smallest_company), Subquery(smallest_company))
else:
rhs = Subquery(smallest_company)
with self.subTest(lookup=lookup):
qs = Company.objects.filter(**{f"name__{lookup}": rhs})
self.assertGreater(len(qs), 0)

def test_uuid_pk_subquery(self):
u = UUIDPK.objects.create()
UUID.objects.create(uuid_fk=u)
Expand Down
Loading
Loading