Skip to content

Commit 245fed5

Browse files
committed
Add nulls_distinct support to UniqueTogetherValidator
1 parent 3f190b7 commit 245fed5

3 files changed

Lines changed: 157 additions & 12 deletions

File tree

rest_framework/serializers.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,20 +1435,26 @@ def get_extra_kwargs(self):
14351435

14361436
def get_unique_together_constraints(self, model):
14371437
"""
1438-
Returns iterator of (fields, queryset, condition_fields, condition),
1438+
Returns iterator of (fields, queryset, condition_fields, condition, nulls_distinct),
14391439
each entry describes an unique together constraint on `fields` in `queryset`
1440-
with respect of constraint's `condition`.
1440+
with respect of constraint's `condition` and `nulls_distinct` option.
14411441
"""
14421442
for parent_class in [model] + list(model._meta.parents):
14431443
for unique_together in parent_class._meta.unique_together:
1444-
yield unique_together, model._default_manager, [], None
1444+
yield unique_together, model._default_manager, [], None, None
14451445
for constraint in parent_class._meta.constraints:
14461446
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
14471447
if constraint.condition is None:
14481448
condition_fields = []
14491449
else:
14501450
condition_fields = list(get_referenced_base_fields_from_q(constraint.condition))
1451-
yield (constraint.fields, model._default_manager, condition_fields, constraint.condition)
1451+
yield (
1452+
constraint.fields,
1453+
model._default_manager,
1454+
condition_fields,
1455+
constraint.condition,
1456+
getattr(constraint, 'nulls_distinct', None),
1457+
)
14521458

14531459
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
14541460
"""
@@ -1481,7 +1487,7 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs
14811487

14821488
# Include each of the `unique_together` and `UniqueConstraint` field names,
14831489
# so long as all the field names are included on the serializer.
1484-
for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model):
1490+
for unique_together_list, queryset, condition_fields, condition, nulls_distinct in self.get_unique_together_constraints(model):
14851491
unique_together_list_and_condition_fields = set(unique_together_list) | set(condition_fields)
14861492
if model_fields_names.issuperset(unique_together_list_and_condition_fields):
14871493
unique_constraint_names |= unique_together_list_and_condition_fields
@@ -1624,7 +1630,7 @@ def get_unique_together_validators(self):
16241630
# Note that we make sure to check `unique_together` both on the
16251631
# base model class, but also on any parent classes.
16261632
validators = []
1627-
for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model):
1633+
for unique_together, queryset, condition_fields, condition, nulls_distinct in self.get_unique_together_constraints(self.Meta.model):
16281634
# Skip if serializer does not map to all unique together sources
16291635
unique_together_and_condition_fields = set(unique_together) | set(condition_fields)
16301636
if not set(source_map).issuperset(unique_together_and_condition_fields):
@@ -1658,6 +1664,7 @@ def get_unique_together_validators(self):
16581664
condition=condition,
16591665
message=violation_error_message,
16601666
code=getattr(constraint, 'violation_error_code', None),
1667+
nulls_distinct=nulls_distinct,
16611668
)
16621669
validators.append(validator)
16631670
return validators

rest_framework/validators.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,14 @@ class UniqueTogetherValidator:
113113
requires_context = True
114114
code = 'unique'
115115

116-
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None):
116+
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None, nulls_distinct=None):
117117
self.queryset = queryset
118118
self.fields = fields
119119
self.message = message or self.message
120120
self.condition_fields = [] if condition_fields is None else condition_fields
121121
self.condition = condition
122122
self.code = code or self.code
123+
self.nulls_distinct = nulls_distinct
123124

124125
def enforce_required_fields(self, attrs, serializer):
125126
"""
@@ -197,17 +198,21 @@ def __call__(self, attrs, serializer):
197198
else getattr(serializer.instance, source)
198199
for source in condition_sources
199200
}
200-
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
201-
field_names = ', '.join(self.fields)
202-
message = self.message.format(field_names=field_names)
203-
raise ValidationError(message, code=self.code)
201+
if checked_values:
202+
# Skip validation for None values unless nulls_distinct is False
203+
if self.nulls_distinct is not False and None in checked_values:
204+
return
205+
if qs_exists_with_condition(queryset, self.condition, condition_kwargs):
206+
field_names = ', '.join(self.fields)
207+
message = self.message.format(field_names=field_names)
208+
raise ValidationError(message, code=self.code)
204209

205210
def __repr__(self):
206211
return '<{}({})>'.format(
207212
self.__class__.__name__,
208213
', '.join(
209214
f'{attr}={smart_repr(getattr(self, attr))}'
210-
for attr in ('queryset', 'fields', 'condition')
215+
for attr in ('queryset', 'fields', 'condition', 'nulls_distinct')
211216
if getattr(self, attr) is not None)
212217
)
213218

@@ -220,6 +225,7 @@ def __eq__(self, other):
220225
and self.queryset == other.queryset
221226
and self.fields == other.fields
222227
and self.code == other.code
228+
and self.nulls_distinct == other.nulls_distinct
223229
)
224230

225231

tests/test_validators.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,23 @@ class Meta:
616616
]
617617

618618

619+
# Only define nulls_distinct model for Django 5.0+
620+
if django_version >= (5, 0):
621+
class UniqueConstraintNullsDistinctModel(models.Model):
622+
name = models.CharField(max_length=100)
623+
code = models.CharField(max_length=100, null=True)
624+
category = models.CharField(max_length=100, null=True)
625+
626+
class Meta:
627+
constraints = [
628+
models.UniqueConstraint(
629+
name='unique_code_category_nulls_not_distinct',
630+
fields=('code', 'category'),
631+
nulls_distinct=False,
632+
),
633+
]
634+
635+
619636
class UniqueConstraintCustomMessageCodeModel(models.Model):
620637
username = models.CharField(max_length=32)
621638
company_id = models.IntegerField()
@@ -1063,3 +1080,118 @@ def test_equality_operator(self):
10631080
assert validator == validator2
10641081
validator2.date_field = "bar2"
10651082
assert validator != validator2
1083+
1084+
1085+
# Tests for `nulls_distinct` option (Django 5.0+)
1086+
# -----------------------------------------------
1087+
1088+
@pytest.mark.skipif(
1089+
django_version < (5, 0),
1090+
reason="nulls_distinct requires Django 5.0+"
1091+
)
1092+
class TestUniqueConstraintNullsDistinct(TestCase):
1093+
"""
1094+
Tests for UniqueConstraint with nulls_distinct=False option.
1095+
When nulls_distinct=False, NULL values should be treated as equal
1096+
for uniqueness validation.
1097+
"""
1098+
1099+
def setUp(self):
1100+
from tests.test_validators import UniqueConstraintNullsDistinctModel
1101+
1102+
class UniqueConstraintNullsDistinctSerializer(serializers.ModelSerializer):
1103+
class Meta:
1104+
model = UniqueConstraintNullsDistinctModel
1105+
fields = ('name', 'code', 'category')
1106+
1107+
self.serializer_class = UniqueConstraintNullsDistinctSerializer
1108+
1109+
def test_nulls_distinct_false_validates_null_as_duplicate(self):
1110+
"""
1111+
When nulls_distinct=False, creating a second record with NULL values
1112+
in the constrained fields should fail validation.
1113+
"""
1114+
from tests.test_validators import UniqueConstraintNullsDistinctModel
1115+
1116+
# Create first record with NULL values
1117+
UniqueConstraintNullsDistinctModel.objects.create(
1118+
name='First',
1119+
code=None,
1120+
category=None
1121+
)
1122+
1123+
# Attempt to create second record with same NULL values
1124+
serializer = self.serializer_class(data={
1125+
'name': 'Second',
1126+
'code': None,
1127+
'category': None
1128+
})
1129+
1130+
# Should fail validation because nulls_distinct=False
1131+
assert not serializer.is_valid()
1132+
1133+
def test_nulls_distinct_false_allows_different_non_null_values(self):
1134+
"""
1135+
Non-NULL values should still work normally with uniqueness validation.
1136+
"""
1137+
from tests.test_validators import UniqueConstraintNullsDistinctModel
1138+
1139+
# Create first record with non-NULL values
1140+
UniqueConstraintNullsDistinctModel.objects.create(
1141+
name='First',
1142+
code='A',
1143+
category='X'
1144+
)
1145+
1146+
# Create second record with different values - should pass
1147+
serializer = self.serializer_class(data={
1148+
'name': 'Second',
1149+
'code': 'B',
1150+
'category': 'Y'
1151+
})
1152+
assert serializer.is_valid(), serializer.errors
1153+
1154+
def test_nulls_distinct_false_rejects_duplicate_non_null_values(self):
1155+
"""
1156+
Duplicate non-NULL values should still fail validation.
1157+
"""
1158+
from tests.test_validators import UniqueConstraintNullsDistinctModel
1159+
1160+
# Create first record
1161+
UniqueConstraintNullsDistinctModel.objects.create(
1162+
name='First',
1163+
code='A',
1164+
category='X'
1165+
)
1166+
1167+
# Attempt to create duplicate - should fail
1168+
serializer = self.serializer_class(data={
1169+
'name': 'Second',
1170+
'code': 'A',
1171+
'category': 'X'
1172+
})
1173+
assert not serializer.is_valid()
1174+
1175+
def test_unique_together_validator_nulls_distinct_equality(self):
1176+
"""
1177+
Test that UniqueTogetherValidator equality considers nulls_distinct.
1178+
"""
1179+
mock_queryset = MagicMock()
1180+
validator1 = UniqueTogetherValidator(
1181+
queryset=mock_queryset,
1182+
fields=('a', 'b'),
1183+
nulls_distinct=False
1184+
)
1185+
validator2 = UniqueTogetherValidator(
1186+
queryset=mock_queryset,
1187+
fields=('a', 'b'),
1188+
nulls_distinct=False
1189+
)
1190+
validator3 = UniqueTogetherValidator(
1191+
queryset=mock_queryset,
1192+
fields=('a', 'b'),
1193+
nulls_distinct=True
1194+
)
1195+
1196+
assert validator1 == validator2
1197+
assert validator1 != validator3

0 commit comments

Comments
 (0)