diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py
index dc961e631..bcc6b0fe9 100644
--- a/src/marshmallow/fields.py
+++ b/src/marshmallow/fields.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import abc
+import base64
 import collections
 import copy
 import datetime as dt
@@ -48,6 +49,7 @@
     "AwareDateTime",
     "Bool",
     "Boolean",
+    "Bytes",
     "Constant",
     "Date",
     "DateTime",
@@ -878,6 +880,118 @@ def _deserialize(self, value, attr, data, **kwargs) -> str:
             raise self.make_error("invalid_utf8") from error
 
 
+class Bytes(Field[bytes]):
+    """A field that deserializes input into :class:`bytes`.
+
+    Deserialization accepts a :class:`str` (encoded with ``encoding`` and
+    ``errors``), an :class:`int` (converted to its shortest big-endian form,
+    two's complement when negative) or any other object accepted by the
+    :class:`bytes` constructor.
+
+    :param encoding: Codec used to convert between :class:`str` and
+        :class:`bytes`.
+    :param errors: Error handling scheme passed to the codec, as for
+        :meth:`str.encode`.
+    :param serialize: Representation produced when serializing: ``"base64"``
+        returns Base64 text, ``"str"`` decodes the value with ``encoding``,
+        ``"int"`` reads the value as an unsigned big-endian integer and
+        ``"bytes"`` returns the value unchanged.
+    :param kwargs: The same keyword arguments that :class:`Field` receives.
+
+    .. versionadded:: 4.3.0
+    """
+
+    #: Default error messages.
+    default_error_messages = {
+        "not_bytes": "Not a bytes-like object.",
+        "unicode": "Invalid unicode string.",
+    }
+
+    def __init__(
+        self,
+        encoding: str = "utf-8",
+        errors: str = "strict",
+        serialize: typing.Literal["int", "str", "bytes", "base64"] = "base64",
+        **kwargs: Unpack[_BaseFieldKwargs],
+    ):
+        super().__init__(**kwargs)
+        self.encoding = encoding
+        self.errors = errors
+        # Stored as ``serialize_as``: assigning to ``self.serialize`` would
+        # replace the ``Field.serialize`` method with a string on the instance.
+        self.serialize_as = serialize
+
+    def _deserialize(
+        self,
+        value: typing.Any,
+        attr: str | None,
+        data: typing.Mapping[str, typing.Any] | None,
+        **kwargs: typing.Any,
+    ) -> bytes:
+        """Convert ``value`` to :class:`bytes`.
+
+        :raises ValidationError: If ``value`` cannot be converted.
+        """
+        try:
+            match value:
+                case str() as s:
+                    return s.encode(self.encoding, self.errors)
+                case int() as i:
+                    # Two's complement needs a sign bit on top of the bits
+                    # of ``~i``; non-negative values are stored unsigned.
+                    bits = (~i).bit_length() + 1 if i < 0 else i.bit_length()
+                    return i.to_bytes(
+                        length=max(1, (bits + 7) // 8),
+                        byteorder="big",
+                        signed=i < 0,
+                    )
+                case _:
+                    return bytes(value)
+        except UnicodeError as e:
+            # Checked first: UnicodeError is a subclass of ValueError.
+            raise self.make_error("unicode") from e
+        except (TypeError, ValueError) as e:
+            # ValueError covers iterables holding integers outside range(256).
+            raise self.make_error("not_bytes") from e
+
+    def _serialize(
+        self,
+        value: bytes,
+        attr: str | None,
+        obj: typing.Any,
+        **kwargs: typing.Any,
+    ) -> str | int | bytes:
+        """Render ``value`` in the representation chosen at construction.
+
+        :raises ValidationError: If ``value`` cannot be decoded as text.
+        :raises ValueError: If the configured serialize mode is not supported.
+        """
+        try:
+            match self.serialize_as:
+                case "str":
+                    return str(
+                        value,
+                        encoding=self.encoding,
+                        errors=self.errors,
+                    )
+                case "base64":
+                    # b64encode returns bytes; its alphabet is pure ASCII.
+                    return base64.standard_b64encode(value).decode("ascii")
+                case "int":
+                    return int.from_bytes(
+                        value,
+                        byteorder="big",
+                    )
+                case "bytes":
+                    return value
+                case _:
+                    raise ValueError(
+                        f"Unsupported serialize mode: {self.serialize_as!r}"
+                    )
+        except UnicodeError as e:
+            raise self.make_error("unicode") from e
+
+
 class UUID(Field[uuid.UUID]):
     """A UUID field."""
 
diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py
index 7cbe28e48..7712e62bd 100644
--- a/tests/test_deserialization.py
+++ b/tests/test_deserialization.py
@@ -322,6 +322,30 @@ def test_string_field_deserialization(self):
         with pytest.raises(ValidationError):
             field.deserialize({})
 
+    def test_bytes_field_deserialization(self):
+        field = fields.Bytes()
+        assert field.deserialize(b"foo") == b"foo"
+        assert field.deserialize(bytearray(b"foo")) == b"foo"
+        assert field.deserialize("foo") == b"foo"
+        assert field.deserialize(0xDEAD) == b"\xde\xad"
+        # Negative values need an extra byte once the sign bit no longer fits.
+        assert field.deserialize(-129) == b"\xff\x7f"
+        assert field.deserialize([0xBE, 0xEF]) == b"\xbe\xef"
+        assert field.deserialize((0xB, 0xA, 0xB, 0xE)) == b"\x0b\x0a\x0b\x0e"
+
+        with pytest.raises(ValidationError) as excinfo:
+            field.deserialize({"hi": 222})
+        assert excinfo.value.args[0] == "Not a bytes-like object."
+
+        with pytest.raises(ValidationError) as excinfo:
+            field.deserialize(["12345"])
+        assert excinfo.value.args[0] == "Not a bytes-like object."
+
+        # Integers outside range(256) cannot be stored in a single byte.
+        with pytest.raises(ValidationError) as excinfo:
+            field.deserialize([256])
+        assert excinfo.value.args[0] == "Not a bytes-like object."
+
     def test_boolean_field_deserialization(self):
         field = fields.Boolean()
         assert field.deserialize(True) is True