diff --git a/django_mongodb_backend/constraints.py b/django_mongodb_backend/constraints.py index 60eef58a4..8e9566178 100644 --- a/django_mongodb_backend/constraints.py +++ b/django_mongodb_backend/constraints.py @@ -1,3 +1,4 @@ +import datetime from collections import defaultdict from django.core import checks @@ -10,6 +11,31 @@ from .indexes import EmbeddedFieldIndexMixin, _get_condition_mql, get_field +def _get_partial_unique_filter(field, connection): + field = getattr(field, "field", field) + db_type = field.db_type(connection) + + match db_type: + case "string": + return {"$gte": ""} + case "int": + return {"$gte": -2147483648, "$lte": 2147483647} + case "long": + return { + "$gte": -9223372036854775808, + "$lte": 9223372036854775807, + } + case "bool": + return {"$in": [True, False]} + case "date": + return { + "$gte": datetime.datetime.min, + "$lte": datetime.datetime.max, + } + case _: + return {"$type": db_type} + + def get_pymongo_index_model(self, model, schema_editor, field=None, column_prefix=""): """Return a pymongo IndexModel for this UniqueConstraint.""" if self.contains_expressions: @@ -25,12 +51,14 @@ def get_pymongo_index_model(self, model, schema_editor, field=None, column_prefi # Field(unique=True) or Meta.unique_together. if field: column = column_prefix + field.column - filter_expression[column].update({"$type": field.db_type(schema_editor.connection)}) + filter_expression[column].update( + _get_partial_unique_filter(field, schema_editor.connection) + ) else: for field_name in self.fields: field_ = get_field(model, field_name) filter_expression[field_.column].update( - {"$type": field_.field.db_type(schema_editor.connection)} + _get_partial_unique_filter(field_, schema_editor.connection) ) if filter_expression: kwargs["partialFilterExpression"] = filter_expression diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 18d0e7551..a07d3cd25 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -2,6 +2,7 @@ from django.db.models.fields.related_lookups import In, RelatedIn from django.db.models.lookups import ( BuiltinLookup, + Exact, FieldGetDbPrepValueIterableMixin, IsNull, LessThan, @@ -14,6 +15,30 @@ from .query_utils import is_constant_value, process_lhs, process_rhs +def _exact_partial_filter(self, compiler, connection): + if self.rhs is None: + return None + + output_field = getattr(self.lhs, "output_field", None) + if output_field is None: + return None + + db_type = output_field.db_type(connection) + if db_type != "string": + return None + + lhs_mql = process_lhs(self, compiler, connection) + return {lhs_mql: {"$type": "string"}} + + +def exact_path(self, compiler, connection): + query = builtin_lookup_path(self, compiler, connection) + partial_filter = _exact_partial_filter(self, compiler, connection) + if partial_filter is None: + return query + return {"$and": [query, partial_filter]} + + def builtin_lookup_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_expr=True) value = process_rhs(self, compiler, connection, as_expr=True) @@ -172,6 +197,7 @@ def uuid_text_mixin(self, compiler, connection, as_expr=False): # noqa: ARG001 def register_lookups(): BuiltinLookup.as_mql_expr = builtin_lookup_expr BuiltinLookup.as_mql_path = builtin_lookup_path + Exact.as_mql_path = exact_path FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = ( field_resolve_expression_parameter ) diff --git a/tests/constraints_/test_unique_indexes.py b/tests/constraints_/test_unique_indexes.py new file mode 100644 index 000000000..bef2de3ea --- /dev/null +++ b/tests/constraints_/test_unique_indexes.py @@ -0,0 +1,58 @@ +from django.db import connection, models +from django.test import SimpleTestCase +from django.test.utils import isolate_apps + + +@isolate_apps("constraints_") +class UniqueIndexTests(SimpleTestCase): + def test_single_field_unique_index_filter(self): + class Author(models.Model): + name = models.TextField(unique=True) + + class Meta: + app_label = "constraints_" + + field = Author._meta.get_field("name") + constraint = models.UniqueConstraint(fields=["name"], name="author_name_uniq") + + with connection.schema_editor() as editor: + index = constraint.get_pymongo_index_model( + Author, + schema_editor=editor, + field=field, + ) + + self.assertEqual( + dict(index.document["partialFilterExpression"]), + {"name": {"$gte": ""}}, + ) + + def test_multi_field_unique_index_filter(self): + class Book(models.Model): + version = models.IntegerField() + name = models.TextField() + + class Meta: + app_label = "constraints_" + constraints = [ + models.UniqueConstraint( + fields=["version", "name"], + name="unique_book_version", + ) + ] + + constraint = Book._meta.constraints[0] + + with connection.schema_editor() as editor: + index = constraint.get_pymongo_index_model(Book, schema_editor=editor) + + self.assertEqual( + dict(index.document["partialFilterExpression"]), + { + "version": { + "$gte": -9223372036854775808, + "$lte": 9223372036854775807, + }, + "name": {"$gte": ""}, + }, + ) diff --git a/tests/lookup_/models.py b/tests/lookup_/models.py index e91582aa5..9cc07cdc4 100644 --- a/tests/lookup_/models.py +++ b/tests/lookup_/models.py @@ -17,3 +17,21 @@ class Meta: def __str__(self): return str(self.num) + + +class UniqueAuthor(models.Model): + name = models.TextField(unique=True) + + +class UniqueBook(models.Model): + author = models.ForeignKey(UniqueAuthor, on_delete=models.CASCADE) + version = models.IntegerField() + name = models.TextField() + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["version", "name"], + name="unique_book_version", + ) + ] diff --git a/tests/lookup_/tests.py b/tests/lookup_/tests.py index 94b63f833..018004ca5 100644 --- a/tests/lookup_/tests.py +++ b/tests/lookup_/tests.py @@ -1,10 +1,10 @@ -from bson import SON +from bson import SON, json_util from django.db.models import Sum from django.test import TestCase from django_mongodb_backend.test import MongoTestCaseMixin -from .models import Book, Number +from .models import Book, Number, UniqueAuthor, UniqueBook class NumericLookupTests(MongoTestCaseMixin, TestCase): @@ -170,3 +170,42 @@ def test_subquery_filter_constant(self): {"$sort": SON([("num", 1)])}, ], ) + + +class PartialUniqueIndexLookupTests(TestCase): + def _find_ixscan(self, winning_plan): + plan = winning_plan + while plan: + if plan.get("stage") == "IXSCAN": + return plan + plan = plan.get("inputStage") + return None + + def test_exact_lookup_uses_partial_unique_index(self): + UniqueAuthor.objects.create(name="JK Rowling") + + plan = json_util.loads(UniqueAuthor.objects.filter(name="JK Rowling").explain())[ + "queryPlanner" + ]["winningPlan"] + ixscan = self._find_ixscan(plan) + + self.assertIsNotNone(ixscan) + self.assertEqual(ixscan["keyPattern"], {"name": 1}) + self.assertTrue(ixscan["isUnique"]) + self.assertTrue(ixscan["isPartial"]) + + def test_compound_exact_lookup_uses_partial_unique_index(self): + author = UniqueAuthor.objects.create(name="JK Rowling") + UniqueBook.objects.create(author=author, version=3, name="Harry Potter") + UniqueBook.objects.create(author=author, version=4, name="Harry Potter") + + plan = json_util.loads(UniqueBook.objects.filter(version=3, name="Harry Potter").explain())[ + "queryPlanner" + ]["winningPlan"] + ixscan = self._find_ixscan(plan) + + self.assertIsNotNone(ixscan) + self.assertEqual(ixscan["indexName"], "unique_book_version") + self.assertEqual(ixscan["keyPattern"], {"version": 1, "name": 1}) + self.assertTrue(ixscan["isUnique"]) + self.assertTrue(ixscan["isPartial"])