Skip to content
32 changes: 30 additions & 2 deletions django_mongodb_backend/constraints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from collections import defaultdict

from django.core import checks
Expand All @@ -10,6 +11,31 @@
from .indexes import EmbeddedFieldIndexMixin, _get_condition_mql, get_field


def _get_partial_unique_filter(field, connection):
field = getattr(field, "field", field)
db_type = field.db_type(connection)

match db_type:
case "string":
return {"$gte": ""}
case "int":
return {"$gte": -2147483648, "$lte": 2147483647}
case "long":
return {
"$gte": -9223372036854775808,
"$lte": 9223372036854775807,
}
case "bool":
return {"$in": [True, False]}
case "date":
return {
"$gte": datetime.datetime.min,
"$lte": datetime.datetime.max,
}
case _:
return {"$type": db_type}


def get_pymongo_index_model(self, model, schema_editor, field=None, column_prefix=""):
"""Return a pymongo IndexModel for this UniqueConstraint."""
if self.contains_expressions:
Expand All @@ -25,12 +51,14 @@ def get_pymongo_index_model(self, model, schema_editor, field=None, column_prefi
# Field(unique=True) or Meta.unique_together.
if field:
column = column_prefix + field.column
filter_expression[column].update({"$type": field.db_type(schema_editor.connection)})
filter_expression[column].update(
_get_partial_unique_filter(field, schema_editor.connection)
)
else:
for field_name in self.fields:
field_ = get_field(model, field_name)
filter_expression[field_.column].update(
{"$type": field_.field.db_type(schema_editor.connection)}
_get_partial_unique_filter(field_, schema_editor.connection)
)
if filter_expression:
kwargs["partialFilterExpression"] = filter_expression
Expand Down
26 changes: 26 additions & 0 deletions django_mongodb_backend/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django.db.models.fields.related_lookups import In, RelatedIn
from django.db.models.lookups import (
BuiltinLookup,
Exact,
FieldGetDbPrepValueIterableMixin,
IsNull,
LessThan,
Expand All @@ -14,6 +15,30 @@
from .query_utils import is_constant_value, process_lhs, process_rhs


def _exact_partial_filter(self, compiler, connection):
if self.rhs is None:
return None

output_field = getattr(self.lhs, "output_field", None)
if output_field is None:
return None

db_type = output_field.db_type(connection)
if db_type != "string":
return None

lhs_mql = process_lhs(self, compiler, connection)
return {lhs_mql: {"$type": "string"}}


def exact_path(self, compiler, connection):
query = builtin_lookup_path(self, compiler, connection)
partial_filter = _exact_partial_filter(self, compiler, connection)
if partial_filter is None:
return query
return {"$and": [query, partial_filter]}


def builtin_lookup_expr(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
value = process_rhs(self, compiler, connection, as_expr=True)
Expand Down Expand Up @@ -172,6 +197,7 @@ def uuid_text_mixin(self, compiler, connection, as_expr=False): # noqa: ARG001
def register_lookups():
BuiltinLookup.as_mql_expr = builtin_lookup_expr
BuiltinLookup.as_mql_path = builtin_lookup_path
Exact.as_mql_path = exact_path
FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = (
field_resolve_expression_parameter
)
Expand Down
58 changes: 58 additions & 0 deletions tests/constraints_/test_unique_indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from django.db import connection, models
from django.test import SimpleTestCase
from django.test.utils import isolate_apps


@isolate_apps("constraints_")
class UniqueIndexTests(SimpleTestCase):
def test_single_field_unique_index_filter(self):
class Author(models.Model):
name = models.TextField(unique=True)

class Meta:
app_label = "constraints_"

field = Author._meta.get_field("name")
constraint = models.UniqueConstraint(fields=["name"], name="author_name_uniq")

with connection.schema_editor() as editor:
index = constraint.get_pymongo_index_model(
Author,
schema_editor=editor,
field=field,
)

self.assertEqual(
dict(index.document["partialFilterExpression"]),
{"name": {"$gte": ""}},
)

def test_multi_field_unique_index_filter(self):
class Book(models.Model):
version = models.IntegerField()
name = models.TextField()

class Meta:
app_label = "constraints_"
constraints = [
models.UniqueConstraint(
fields=["version", "name"],
name="unique_book_version",
)
]

constraint = Book._meta.constraints[0]

with connection.schema_editor() as editor:
index = constraint.get_pymongo_index_model(Book, schema_editor=editor)

self.assertEqual(
dict(index.document["partialFilterExpression"]),
{
"version": {
"$gte": -9223372036854775808,
"$lte": 9223372036854775807,
},
"name": {"$gte": ""},
},
)
18 changes: 18 additions & 0 deletions tests/lookup_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,21 @@ class Meta:

def __str__(self):
return str(self.num)


class UniqueAuthor(models.Model):
name = models.TextField(unique=True)


class UniqueBook(models.Model):
author = models.ForeignKey(UniqueAuthor, on_delete=models.CASCADE)
version = models.IntegerField()
name = models.TextField()

class Meta:
constraints = [
models.UniqueConstraint(
fields=["version", "name"],
name="unique_book_version",
)
]
43 changes: 41 additions & 2 deletions tests/lookup_/tests.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from bson import SON
from bson import SON, json_util
from django.db.models import Sum
from django.test import TestCase

from django_mongodb_backend.test import MongoTestCaseMixin

from .models import Book, Number
from .models import Book, Number, UniqueAuthor, UniqueBook


class NumericLookupTests(MongoTestCaseMixin, TestCase):
Expand Down Expand Up @@ -170,3 +170,42 @@ def test_subquery_filter_constant(self):
{"$sort": SON([("num", 1)])},
],
)


class PartialUniqueIndexLookupTests(TestCase):
def _find_ixscan(self, winning_plan):
plan = winning_plan
while plan:
if plan.get("stage") == "IXSCAN":
return plan
plan = plan.get("inputStage")
return None

def test_exact_lookup_uses_partial_unique_index(self):
UniqueAuthor.objects.create(name="JK Rowling")

plan = json_util.loads(UniqueAuthor.objects.filter(name="JK Rowling").explain())[
"queryPlanner"
]["winningPlan"]
ixscan = self._find_ixscan(plan)

self.assertIsNotNone(ixscan)
self.assertEqual(ixscan["keyPattern"], {"name": 1})
self.assertTrue(ixscan["isUnique"])
self.assertTrue(ixscan["isPartial"])

def test_compound_exact_lookup_uses_partial_unique_index(self):
author = UniqueAuthor.objects.create(name="JK Rowling")
UniqueBook.objects.create(author=author, version=3, name="Harry Potter")
UniqueBook.objects.create(author=author, version=4, name="Harry Potter")

plan = json_util.loads(UniqueBook.objects.filter(version=3, name="Harry Potter").explain())[
"queryPlanner"
]["winningPlan"]
ixscan = self._find_ixscan(plan)

self.assertIsNotNone(ixscan)
self.assertEqual(ixscan["indexName"], "unique_book_version")
self.assertEqual(ixscan["keyPattern"], {"version": 1, "name": 1})
self.assertTrue(ixscan["isUnique"])
self.assertTrue(ixscan["isPartial"])
Loading