diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index d4a853b24..8b4172859 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -690,7 +690,21 @@ def getter( } for key in set(data) - fields: value = data[key] - if unknown == INCLUDE: + if isinstance(unknown, ma_fields.Field): + + def getter(val, unknown_field=unknown, field_name=key): + return unknown_field.deserialize(val, field_name, data) + + deserialized = self._call_and_store( + getter_func=getter, + data=value, + field_name=key, + error_store=error_store, + index=index, + ) + if deserialized is not missing: + ret_d[key] = deserialized + elif unknown == INCLUDE: ret_d[key] = value elif unknown == RAISE: error_store.store_error( diff --git a/src/marshmallow/types.py b/src/marshmallow/types.py index 4c5d98da1..9eb32ae28 100644 --- a/src/marshmallow/types.py +++ b/src/marshmallow/types.py @@ -9,14 +9,25 @@ import typing +if typing.TYPE_CHECKING: + from marshmallow.fields import Field + #: A type that can be either a sequence of strings or a set of strings StrSequenceOrSet: typing.TypeAlias = typing.Sequence[str] | typing.AbstractSet[str] #: Type for validator functions Validator: typing.TypeAlias = typing.Callable[[typing.Any], typing.Any] -#: A valid option for the ``unknown`` schema option and argument -UnknownOption: typing.TypeAlias = typing.Literal["exclude", "include", "raise"] +#: A valid option for the ``unknown`` schema option and argument. +#: Can be a string constant (``"exclude"``, ``"include"``, ``"raise"``) +#: or a :class:`Field ` instance to deserialize unknown +#: field values through. +if typing.TYPE_CHECKING: + UnknownOption: typing.TypeAlias = ( + typing.Literal["exclude", "include", "raise"] | Field + ) +else: + UnknownOption: typing.TypeAlias = typing.Literal["exclude", "include", "raise"] class SchemaValidator(typing.Protocol): diff --git a/tests/test_schema.py b/tests/test_schema.py index ff2811737..a15f4a78e 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -638,6 +638,75 @@ class ErrorSchema(Schema): assert "Invalid email" in errors["email"] +class TestUnknownFieldOption: + """Tests for passing a Field instance to ``unknown``.""" + + def test_unknown_field_deserializes_values(self): + class MySchema(Schema): + name = fields.String() + + schema = MySchema(unknown=fields.Int()) + result = schema.load({"name": "Joe", "age": "42"}) + assert result == {"name": "Joe", "age": 42} + + def test_unknown_field_validation_error(self): + class MySchema(Schema): + name = fields.String() + + schema = MySchema(unknown=fields.Int()) + with pytest.raises(ValidationError) as excinfo: + schema.load({"name": "Joe", "age": "not_a_number"}) + assert "age" in excinfo.value.messages + + def test_unknown_field_in_meta(self): + class MySchema(Schema): + class Meta: + unknown = fields.String() + + name = fields.String() + + result = MySchema().load({"name": "Joe", "extra": "hello"}) + assert result == {"name": "Joe", "extra": "hello"} + + def test_unknown_field_with_many(self): + class MySchema(Schema): + name = fields.String() + + schema = MySchema(unknown=fields.Int()) + result = schema.load( + [{"name": "Joe", "age": "42"}, {"name": "Jane", "score": "99"}], + many=True, + ) + assert result == [{"name": "Joe", "age": 42}, {"name": "Jane", "score": 99}] + + def test_unknown_field_in_load_kwarg(self): + class MySchema(Schema): + name = fields.String() + + schema = MySchema() + result = schema.load({"name": "Joe", "extra": "42"}, unknown=fields.Int()) + assert result == {"name": "Joe", "extra": 42} + + def test_unknown_field_nested(self): + class ChildSchema(Schema): + num = fields.Int() + + class ParentSchema(Schema): + child = fields.Nested(ChildSchema, unknown=fields.String()) + + data = {"child": {"num": 1, "extra": "hello"}} + result = ParentSchema().load(data) + assert result == {"child": {"num": 1, "extra": "hello"}} + + def test_unknown_field_excludes_nothing(self): + class MySchema(Schema): + name = fields.String() + + schema = MySchema(unknown=fields.Field()) + result = schema.load({"name": "Joe", "extra": "value", "more": 123}) + assert result == {"name": "Joe", "extra": "value", "more": 123} + + def test_custom_unknown_error_message(): custom_message = "custom error message."