diff --git a/AUTHORS.rst b/AUTHORS.rst index 3d019c0fb..ebdd8425f 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -163,3 +163,4 @@ Contributors (chronological) - Javier Fernández `@jfernandz `_ - Michael Dimchuk `@michaeldimchuk `_ - Jochen Kupperschmidt `@homeworkprod `_ +- Midokura `@midokura `_ diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2d459d7cf..d33b03711 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,14 @@ Changelog --------- +3.13.0 (unreleased) +******************* + +Features: + +- Add ``validate.Unique`` (:pr:`1793`). + Thanks :user:`bonastreyair` for the PR. + 3.12.2 (2021-07-06) ******************* diff --git a/src/marshmallow/validate.py b/src/marshmallow/validate.py index 9c637fc92..143bc639a 100644 --- a/src/marshmallow/validate.py +++ b/src/marshmallow/validate.py @@ -6,6 +6,7 @@ from operator import attrgetter from marshmallow import types +from marshmallow import utils from marshmallow.exceptions import ValidationError _T = typing.TypeVar("_T") @@ -644,3 +645,56 @@ def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]: if val in self.iterable: raise ValidationError(self._format_error(value)) return value + + +class Unique(Validator): + """Validator which succeeds if the ``value`` is an ``iterable`` and has unique + elements. In case of a list of objects, it can easy check an internal + attribute by passing the ``attribute`` parameter. + Validator which fails if ``value`` is not a member of ``iterable``. + + :param str attribute: The name of the attribute of the object you want to check. + """ + + default_message = "Invalid input. Supported lists or str." + error = "Found a duplicate value: {value}." + attribute_error = "Found a duplicate object attribute ({attribute}): {value}." + + def __init__(self, attribute: typing.Optional[str] = None): + self.attribute = attribute + + def _repr_args(self) -> str: + return "attribute={!r}".format(self.attribute) + + def _format_error(self, value) -> str: + if self.attribute: + return self.attribute_error.format(attribute=self.attribute, value=value) + return self.error.format(value=value) + + def __call__(self, value: typing.Iterable) -> typing.Iterable: + if not isinstance(value, typing.Iterable): + raise ValidationError(self.default_message) + ids = [ + utils.get_value(item, self.attribute) if self.attribute else item + for item in value + ] + try: + self._duplicate_hash(ids) + except TypeError: + self._duplicate_equal(ids) + + return value + + def _duplicate_hash(self, ids: typing.List) -> None: + used = set() + for _id in ids: + if _id in used: + raise ValidationError(self._format_error(_id)) + used.add(_id) + + def _duplicate_equal(self, ids: typing.List) -> None: + used = [] + for _id in ids: + if _id in used: + raise ValidationError(self._format_error(_id)) + used.append(_id) diff --git a/tests/test_validate.py b/tests/test_validate.py index 0dc70c1e3..e6dd99755 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -912,3 +912,75 @@ def test_and(): errors = excinfo.value.messages assert errors == ["Not an even value.", "Must be less than or equal to 6."] + + +def test_unique(): + class Bar: + def __init__(self, num): + self.num = num + + class Mock: + def __init__(self, name, bar): + self.name = name + self.bar = bar + + mock_object_a_1 = Mock("a", Bar(1)) + mock_object_a_2 = Mock("a", Bar(2)) + mock_object_b_1 = Mock("b", Bar(1)) + mock_dict_a_1 = {"name": "a", "bar": {"num": 1}} + mock_dict_a_2 = {"name": "a", "bar": {"num": 2}} + mock_dict_b_1 = {"name": "b", "bar": {"num": 1}} + + assert validate.Unique()("d") == "d" + assert validate.Unique()([]) == [] + assert validate.Unique()({}) == {} + assert validate.Unique()(["a", "b"]) == ["a", "b"] + assert validate.Unique()([1, 2]) == [1, 2] + assert validate.Unique(attribute="name")([mock_object_a_1, mock_object_b_1]) == [ + mock_object_a_1, + mock_object_b_1, + ] + assert validate.Unique(attribute="bar.num")([mock_object_a_1, mock_object_a_2]) == [ + mock_object_a_1, + mock_object_a_2, + ] + assert validate.Unique(attribute="name")([mock_dict_a_1, mock_dict_b_1]) == [ + mock_dict_a_1, + mock_dict_b_1, + ] + assert validate.Unique(attribute="bar.num")([mock_dict_a_1, mock_dict_a_2]) == [ + mock_dict_a_1, + mock_dict_a_2, + ] + assert validate.Unique()([[1, 2], [3, 4]]) == [[1, 2], [3, 4]] + assert validate.Unique()([{1, 2}, {3, 4}]) == [{1, 2}, {3, 4}] + assert validate.Unique()([{"a": 1}, {"b": 2}]) == [{"a": 1}, {"b": 2}] + + with pytest.raises(ValidationError, match="Invalid input."): + validate.Unique()(3) + with pytest.raises(ValidationError, match="Invalid input."): + validate.Unique()(1.1) + with pytest.raises(ValidationError, match="Invalid input."): + validate.Unique()(True) + with pytest.raises(ValidationError, match="Invalid input."): + validate.Unique()(None) + with pytest.raises(ValidationError, match="Found a duplicate value"): + validate.Unique()([1, 1, 2]) + with pytest.raises(ValidationError, match="Found a duplicate value"): + validate.Unique()("aab") + with pytest.raises(ValidationError, match="Found a duplicate value"): + validate.Unique()(["a", "a", "b"]) + with pytest.raises(ValidationError, match="Found a duplicate object attribute"): + validate.Unique(attribute="name")([mock_object_a_1, mock_object_a_2]) + with pytest.raises(ValidationError, match="Found a duplicate object attribute"): + validate.Unique(attribute="bar.num")([mock_object_a_1, mock_object_b_1]) + with pytest.raises(ValidationError, match="Found a duplicate object attribute"): + validate.Unique(attribute="name")([mock_dict_a_1, mock_dict_a_2]) + with pytest.raises(ValidationError, match="Found a duplicate object attribute"): + validate.Unique(attribute="bar.num")([mock_dict_a_1, mock_dict_b_1]) + with pytest.raises(ValidationError, match="Found a duplicate value"): + validate.Unique()([[1, 2], [1, 2]]) + with pytest.raises(ValidationError, match="Found a duplicate value"): + validate.Unique()([{1, 2}, {1, 2}]) + with pytest.raises(ValidationError, match="Found a duplicate value"): + validate.Unique()([{"a": 1, "b": 2}, {"a": 1, "b": 2}])