diff --git a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py index 519046979..ba0e8be8c 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py @@ -1,13 +1,15 @@ from typing import Any, Final from smithy_core.codecs import Codec +from smithy_core.exceptions import DiscriminatorError from smithy_core.schemas import APIOperation -from smithy_core.shapes import ShapeID +from smithy_core.shapes import ShapeID, ShapeType from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPResponse from smithy_http.aio.protocols import HttpBindingClientProtocol -from smithy_json import JSONCodec +from smithy_json import JSONCodec, JSONDocument from ..traits import RestJson1Trait +from ..utils import parse_document_discriminator, parse_error_code class AWSErrorIdentifier(HTTPErrorIdentifier): @@ -24,20 +26,29 @@ def identify( error_field = response.fields[self._HEADER_KEY] code = error_field.values[0] if len(error_field.values) > 0 else None - if not code: - return None + if code is not None: + return parse_error_code(code, operation.schema.id.namespace) + return None + - code = code.split(":")[0] - if "#" in code: - return ShapeID(code) - return ShapeID.from_parts(name=code, namespace=operation.schema.id.namespace) +class AWSJSONDocument(JSONDocument): + @property + def discriminator(self) -> ShapeID: + if self.shape_type is ShapeType.STRUCTURE: + return self._schema.id + parsed = parse_document_discriminator(self, self._settings.default_namespace) + if parsed is None: + raise DiscriminatorError( + f"Unable to parse discriminator for {self.shape_type} document." + ) + return parsed class RestJsonClientProtocol(HttpBindingClientProtocol): """An implementation of the aws.protocols#restJson1 protocol.""" _id: Final = RestJson1Trait.id - _codec: Final = JSONCodec() + _codec: Final = JSONCodec(document_class=AWSJSONDocument) _contentType: Final = "application/json" _error_identifier: Final = AWSErrorIdentifier() diff --git a/packages/smithy-aws-core/src/smithy_aws_core/utils.py b/packages/smithy-aws-core/src/smithy_aws_core/utils.py new file mode 100644 index 000000000..940160e05 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/utils.py @@ -0,0 +1,32 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from smithy_core.documents import Document +from smithy_core.shapes import ShapeID, ShapeType + + +def parse_document_discriminator( + document: Document, default_namespace: str | None +) -> ShapeID | None: + if document.shape_type is ShapeType.MAP: + map_document = document.as_map() + code = map_document.get("__type") + if code is None: + code = map_document.get("code") + if code is not None and code.shape_type is ShapeType.STRING: + return parse_error_code(code.as_string(), default_namespace) + + return None + + +def parse_error_code(code: str, default_namespace: str | None) -> ShapeID | None: + if not code: + return None + + code = code.split(":")[0] + if "#" in code: + return ShapeID(code) + + if not code or not default_namespace: + return None + + return ShapeID.from_parts(name=code, namespace=default_namespace) diff --git a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py index 82cf7d1e8..7b767a080 100644 --- a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py +++ b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py @@ -4,11 +4,13 @@ from unittest.mock import Mock import pytest -from smithy_aws_core.aio.protocols import AWSErrorIdentifier +from smithy_aws_core.aio.protocols import AWSErrorIdentifier, AWSJSONDocument +from smithy_core.exceptions import DiscriminatorError from smithy_core.schemas import APIOperation, Schema from smithy_core.shapes import ShapeID, ShapeType from smithy_http import Fields, tuples_to_fields from smithy_http.aio import HTTPResponse +from smithy_json import JSONSettings @pytest.mark.parametrize( @@ -24,6 +26,7 @@ "com.test#FooError", ), ("", None), + (":", None), (None, None), ], ) @@ -42,3 +45,55 @@ def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> N actual = error_identifier.identify(operation=operation, response=http_response) assert actual == expected + + +@pytest.mark.parametrize( + "document, expected", + [ + ({"__type": "FooError"}, "com.test#FooError"), + ({"__type": "com.test#FooError"}, "com.test#FooError"), + ( + { + "__type": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/" + }, + "com.test#FooError", + ), + ( + { + "__type": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate" + }, + "com.test#FooError", + ), + ({"code": "FooError"}, "com.test#FooError"), + ({"code": "com.test#FooError"}, "com.test#FooError"), + ( + { + "code": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/" + }, + "com.test#FooError", + ), + ( + { + "code": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate" + }, + "com.test#FooError", + ), + ({"__type": "FooError", "code": "BarError"}, "com.test#FooError"), + ("FooError", None), + ({"__type": None}, None), + ({"__type": ""}, None), + ({"__type": ":"}, None), + ], +) +def test_aws_json_document_discriminator( + document: dict[str, str], expected: ShapeID | None +) -> None: + settings = JSONSettings( + document_class=AWSJSONDocument, default_namespace="com.test" + ) + if expected is None: + with pytest.raises(DiscriminatorError): + AWSJSONDocument(document, settings=settings).discriminator + else: + discriminator = AWSJSONDocument(document, settings=settings).discriminator + assert discriminator == expected diff --git a/packages/smithy-aws-core/tests/unit/test_utils.py b/packages/smithy-aws-core/tests/unit/test_utils.py new file mode 100644 index 000000000..6927a2fce --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/test_utils.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from smithy_aws_core.utils import parse_document_discriminator, parse_error_code +from smithy_core.documents import Document +from smithy_core.shapes import ShapeID + + +@pytest.mark.parametrize( + "document, expected", + [ + ({"__type": "FooError"}, "com.test#FooError"), + ({"__type": "com.test#FooError"}, "com.test#FooError"), + ( + { + "__type": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/" + }, + "com.test#FooError", + ), + ( + { + "__type": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate" + }, + "com.test#FooError", + ), + ({"code": "FooError"}, "com.test#FooError"), + ({"code": "com.test#FooError"}, "com.test#FooError"), + ( + { + "code": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/" + }, + "com.test#FooError", + ), + ( + { + "code": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate" + }, + "com.test#FooError", + ), + ({"__type": "FooError", "code": "BarError"}, "com.test#FooError"), + ("FooError", None), + ({"__type": None}, None), + ({"__type": ""}, None), + ({"__type": ":"}, None), + ], +) +def test_aws_json_document_discriminator( + document: dict[str, str], expected: ShapeID | None +) -> None: + actual = parse_document_discriminator(Document(document), "com.test") + assert actual == expected + + +@pytest.mark.parametrize( + "code, expected", + [ + ("FooError", "com.test#FooError"), + ( + "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/", + "com.test#FooError", + ), + ( + "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate", + "com.test#FooError", + ), + ("", None), + (":", None), + ], +) +def test_parse_error_code(code: str, expected: ShapeID | None) -> None: + actual = parse_error_code(code, "com.test") + assert actual == expected + + +def test_parse_error_code_without_default_namespace() -> None: + actual = parse_error_code("FooError", None) + assert actual is None diff --git a/packages/smithy-core/src/smithy_core/documents.py b/packages/smithy-core/src/smithy_core/documents.py index 2166bae1e..1c565fa9e 100644 --- a/packages/smithy-core/src/smithy_core/documents.py +++ b/packages/smithy-core/src/smithy_core/documents.py @@ -5,7 +5,7 @@ from typing import TypeGuard, override from .deserializers import DeserializeableShape, ShapeDeserializer -from .exceptions import ExpectationNotMetError, SmithyError +from .exceptions import DiscriminatorError, ExpectationNotMetError, SmithyError from .schemas import Schema from .serializers import ( InterceptingSerializer, @@ -146,7 +146,9 @@ def shape_type(self) -> ShapeType: @property def discriminator(self) -> ShapeID: """The shape ID that corresponds to the contents of the document.""" - return self._schema.id + if self._type is ShapeType.STRUCTURE: + return self._schema.id + raise DiscriminatorError(f"{self._type} document has no discriminator.") def is_none(self) -> bool: """Indicates whether the document contains a null value.""" diff --git a/packages/smithy-core/src/smithy_core/exceptions.py b/packages/smithy-core/src/smithy_core/exceptions.py index 7d320b76c..0e28bd530 100644 --- a/packages/smithy-core/src/smithy_core/exceptions.py +++ b/packages/smithy-core/src/smithy_core/exceptions.py @@ -65,6 +65,11 @@ class SerializationError(SmithyError): """Base exception type for exceptions raised during serialization.""" +class DiscriminatorError(SmithyError): + """Exception indicating something went wrong when attempting to find the + discriminator in a document.""" + + class RetryError(SmithyError): """Base exception type for all exceptions raised in retry strategies.""" diff --git a/packages/smithy-core/tests/unit/test_documents.py b/packages/smithy-core/tests/unit/test_documents.py index 1ae24eb94..8a1e3ccf1 100644 --- a/packages/smithy-core/tests/unit/test_documents.py +++ b/packages/smithy-core/tests/unit/test_documents.py @@ -12,7 +12,7 @@ _DocumentDeserializer, _DocumentSerializer, ) -from smithy_core.exceptions import ExpectationNotMetError +from smithy_core.exceptions import DiscriminatorError, ExpectationNotMetError from smithy_core.prelude import ( BIG_DECIMAL, BLOB, @@ -938,3 +938,13 @@ def _read_optional_map(k: str, d: ShapeDeserializer): actual = given.as_shape(DocumentSerdeShape) case _: raise Exception(f"Unexpected type: {type(given)}") + + +def test_document_has_no_discriminator_by_default() -> None: + with pytest.raises(DiscriminatorError): + Document().discriminator + + +def test_struct_document_has_discriminator() -> None: + document = Document({"integerMember": 1}, schema=SCHEMA) + assert document.discriminator == SCHEMA.id diff --git a/packages/smithy-core/tests/unit/test_type_registry.py b/packages/smithy-core/tests/unit/test_type_registry.py index 3c7ca98c3..6a2a4ffc2 100644 --- a/packages/smithy-core/tests/unit/test_type_registry.py +++ b/packages/smithy-core/tests/unit/test_type_registry.py @@ -1,8 +1,10 @@ import pytest from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer from smithy_core.documents import Document, TypeRegistry +from smithy_core.prelude import STRING from smithy_core.schemas import Schema -from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.shapes import ShapeID +from smithy_core.traits import RequiredTrait def test_get(): @@ -59,11 +61,16 @@ def test_deserialize(): class TestShape(DeserializeableShape): __test__ = False - schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING) + schema = Schema.collection( + id=ShapeID("com.example#Test"), + members={"value": {"index": 0, "target": STRING, "traits": [RequiredTrait()]}}, + ) def __init__(self, value: str): self.value = value @classmethod def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape": - return TestShape(deserializer.read_string(schema=TestShape.schema)) + return TestShape( + value=deserializer.read_string(schema=cls.schema.members["value"]) + ) diff --git a/packages/smithy-json/src/smithy_json/__init__.py b/packages/smithy-json/src/smithy_json/__init__.py index c90653d03..2993deed0 100644 --- a/packages/smithy-json/src/smithy_json/__init__.py +++ b/packages/smithy-json/src/smithy_json/__init__.py @@ -10,23 +10,24 @@ from smithy_core.types import TimestampFormat from ._private.deserializers import JSONShapeDeserializer as _JSONShapeDeserializer +from ._private.documents import JSONDocument from ._private.serializers import JSONShapeSerializer as _JSONShapeSerializer +from .settings import JSONSettings __version__: str = importlib.metadata.version("smithy-json") +__all__ = ("JSONCodec", "JSONDocument", "JSONSettings") class JSONCodec(Codec): """A codec for converting shapes to/from JSON.""" - _use_json_name: bool - _use_timestamp_format: bool - _default_timestamp_format: TimestampFormat - def __init__( self, use_json_name: bool = True, use_timestamp_format: bool = True, default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + default_namespace: str | None = None, + document_class: type[JSONDocument] = JSONDocument, ) -> None: """Initializes a JSONCodec. @@ -36,29 +37,26 @@ def __init__( `smithy.api#timestampFormat` trait, if present. :param default_timestamp_format: The default timestamp format to use if the `smithy.api#timestampFormat` trait is not enabled or not present. + :param default_namespace: The default namespace to use when determining a + document's discriminator. + :param document_class: The document class to deserialize to. """ - self._use_json_name = use_json_name - self._use_timestamp_format = use_timestamp_format - self._default_timestamp_format = default_timestamp_format + self._settings = JSONSettings( + use_json_name=use_json_name, + use_timestamp_format=use_timestamp_format, + default_timestamp_format=default_timestamp_format, + default_namespace=default_namespace, + document_class=document_class, + ) @property def media_type(self) -> str: return "application/json" def create_serializer(self, sink: BytesWriter) -> "ShapeSerializer": - return _JSONShapeSerializer( - sink, - use_json_name=self._use_json_name, - use_timestamp_format=self._use_timestamp_format, - default_timestamp_format=self._default_timestamp_format, - ) + return _JSONShapeSerializer(sink, settings=self._settings) def create_deserializer(self, source: bytes | BytesReader) -> "ShapeDeserializer": if isinstance(source, bytes): source = BytesIO(source) - return _JSONShapeDeserializer( - source, - use_json_name=self._use_json_name, - use_timestamp_format=self._use_timestamp_format, - default_timestamp_format=self._default_timestamp_format, - ) + return _JSONShapeDeserializer(source, settings=self._settings) diff --git a/packages/smithy-json/src/smithy_json/_private/__init__.py b/packages/smithy-json/src/smithy_json/_private/__init__.py index d5e52e68c..04559d52a 100644 --- a/packages/smithy-json/src/smithy_json/_private/__init__.py +++ b/packages/smithy-json/src/smithy_json/_private/__init__.py @@ -1,6 +1,5 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 - from typing import Protocol, runtime_checkable diff --git a/packages/smithy-json/src/smithy_json/_private/deserializers.py b/packages/smithy-json/src/smithy_json/_private/deserializers.py index ac79f646e..bbbb16927 100644 --- a/packages/smithy-json/src/smithy_json/_private/deserializers.py +++ b/packages/smithy-json/src/smithy_json/_private/deserializers.py @@ -18,7 +18,7 @@ from smithy_core.traits import JSONNameTrait, TimestampFormatTrait from smithy_core.types import TimestampFormat -from .documents import JSONDocument +from ..settings import JSONSettings # TODO: put these type hints in a pyi somewhere. There here because ijson isn't # typed. @@ -88,18 +88,9 @@ def peek(self) -> JSONParseEvent: class JSONShapeDeserializer(ShapeDeserializer): - def __init__( - self, - source: BytesReader, - *, - use_json_name: bool = True, - use_timestamp_format: bool = True, - default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, - ) -> None: + def __init__(self, source: BytesReader, settings: JSONSettings) -> None: self._stream = BufferedParser(ijson.parse(source)) - self._use_json_name = use_json_name - self._use_timestamp_format = use_timestamp_format - self._default_timestamp_format = default_timestamp_format + self._settings = settings # A mapping of json name to member name for each shape. Since the deserializer # is shared and we don't know which shapes will be deserialized, this is @@ -164,12 +155,8 @@ def read_string(self, schema: Schema) -> str: def read_document(self, schema: Schema) -> Document: start = next(self._stream) if start.type not in ("start_map", "start_array"): - return JSONDocument( - start.value, - schema=schema, - use_json_name=self._use_json_name, - default_timestamp_format=self._default_timestamp_format, - use_timestamp_format=self._use_timestamp_format, + return self._settings.document_class( + value=start.value, schema=schema, settings=self._settings ) end_type = "end_map" if start.type == "start_map" else "end_array" @@ -180,17 +167,13 @@ def read_document(self, schema: Schema) -> Document: ).path != start.path or event.type != end_type: builder.event(event.type, event.value) - return JSONDocument( - builder.value, - schema=schema, - use_json_name=self._use_json_name, - default_timestamp_format=self._default_timestamp_format, - use_timestamp_format=self._use_timestamp_format, + return self._settings.document_class( + value=builder.value, schema=schema, settings=self._settings ) def read_timestamp(self, schema: Schema) -> datetime.datetime: - format = self._default_timestamp_format - if self._use_timestamp_format: + format = self._settings.default_timestamp_format + if self._settings.use_timestamp_format: if format_trait := schema.get_trait(TimestampFormatTrait): format = format_trait.format @@ -221,7 +204,7 @@ def read_struct( next(self._stream) def _resolve_member(self, schema: Schema, key: str) -> Schema | None: - if self._use_json_name: + if self._settings.use_json_name: if schema.id not in self._json_names: self._cache_json_names(schema=schema) if key in self._json_names[schema.id]: diff --git a/packages/smithy-json/src/smithy_json/_private/documents.py b/packages/smithy-json/src/smithy_json/_private/documents.py index 0ad962a43..7c0138515 100644 --- a/packages/smithy-json/src/smithy_json/_private/documents.py +++ b/packages/smithy-json/src/smithy_json/_private/documents.py @@ -9,32 +9,28 @@ from smithy_core.documents import Document, DocumentValue from smithy_core.prelude import DOCUMENT from smithy_core.schemas import Schema -from smithy_core.shapes import ShapeType +from smithy_core.shapes import ShapeID, ShapeType from smithy_core.traits import JSONNameTrait, TimestampFormatTrait -from smithy_core.types import TimestampFormat from smithy_core.utils import expect_type +from ..settings import JSONSettings + class JSONDocument(Document): - _schema: Schema - _json_names: dict[str, str] + _discriminator: ShapeID | None = None def __init__( self, value: DocumentValue | dict[str, "Document"] | list["Document"], + settings: JSONSettings | None = None, *, schema: Schema = DOCUMENT, - use_json_name: bool = True, - use_timestamp_format: bool = True, - default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, ) -> None: super().__init__(value, schema=schema) - self._use_json_name = use_json_name - self._use_timestamp_format = use_timestamp_format - self._default_timestamp_format = default_timestamp_format - self._json_names = {} + self._settings = settings or JSONSettings(document_class=type(self)) + self._json_names: dict[str, str] = {} - if use_json_name and schema.shape_type in ( + if self._settings.use_json_name and schema.shape_type in ( ShapeType.STRUCTURE, ShapeType.UNION, ): @@ -51,8 +47,8 @@ def as_float(self) -> float: return float(expect_type(Decimal, self._value)) def as_timestamp(self) -> datetime: - format = self._default_timestamp_format - if self._use_timestamp_format: + format = self._settings.default_timestamp_format + if self._settings.use_timestamp_format: if format_trait := self._schema.get_trait(TimestampFormatTrait): format = format_trait.format @@ -106,13 +102,7 @@ def _new_document( value: DocumentValue | dict[str, "Document"] | list["Document"], schema: Schema, ) -> "Document": - return JSONDocument( - value, - schema=schema, - use_json_name=self._use_json_name, - use_timestamp_format=self._use_timestamp_format, - default_timestamp_format=self._default_timestamp_format, - ) + return JSONDocument(value, schema=schema, settings=self._settings) def _wrap_map(self, value: Mapping[str, DocumentValue]) -> dict[str, "Document"]: if self._schema.shape_type not in (ShapeType.STRUCTURE, ShapeType.UNION): diff --git a/packages/smithy-json/src/smithy_json/_private/serializers.py b/packages/smithy-json/src/smithy_json/_private/serializers.py index 9c56a9dad..c1cd3df70 100644 --- a/packages/smithy-json/src/smithy_json/_private/serializers.py +++ b/packages/smithy-json/src/smithy_json/_private/serializers.py @@ -21,6 +21,7 @@ from smithy_core.traits import JSONNameTrait, TimestampFormatTrait from smithy_core.types import TimestampFormat +from ..settings import JSONSettings from . import Flushable _INF: float = float("inf") @@ -28,27 +29,14 @@ class JSONShapeSerializer(ShapeSerializer): - _stream: "StreamingJSONEncoder" - _use_json_name: bool - _use_timestamp_format: bool - _default_timestamp_format: TimestampFormat - - def __init__( - self, - sink: BytesWriter, - use_json_name: bool = True, - use_timestamp_format: bool = True, - default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, - ) -> None: + def __init__(self, sink: BytesWriter, settings: JSONSettings) -> None: self._stream = StreamingJSONEncoder(sink) - self._use_json_name = use_json_name - self._use_timestamp_format = use_timestamp_format - self._default_timestamp_format = default_timestamp_format + self._settings = settings def begin_struct( self, schema: "Schema" ) -> AbstractContextManager["ShapeSerializer"]: - return JSONStructSerializer(self._stream, self, self._use_json_name) + return JSONStructSerializer(self._stream, self, self._settings.use_json_name) def begin_list( self, schema: "Schema", size: int @@ -82,8 +70,8 @@ def write_blob(self, schema: "Schema", value: bytes) -> None: self._stream.write_string(b64encode(value).decode("utf-8")) def write_timestamp(self, schema: "Schema", value: datetime) -> None: - format = self._default_timestamp_format - if self._use_timestamp_format: + format = self._settings.default_timestamp_format + if self._settings.use_timestamp_format: if format_trait := schema.get_trait(TimestampFormatTrait): format = format_trait.format diff --git a/packages/smithy-json/src/smithy_json/settings.py b/packages/smithy-json/src/smithy_json/settings.py new file mode 100644 index 000000000..6ebde62cf --- /dev/null +++ b/packages/smithy-json/src/smithy_json/settings.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from smithy_core.types import TimestampFormat + +if TYPE_CHECKING: + from ._private.documents import JSONDocument + + +@dataclass(slots=True) +class JSONSettings: + """Settings for the JSON codec.""" + + document_class: type["JSONDocument"] + """The document class to deserialize to.""" + + use_json_name: bool = True + """Whether the codec should use `smithy.api#jsonName` trait, if present.""" + + use_timestamp_format: bool = True + """Whether the codec should use the `smithy.api#timestampFormat` trait, if + present.""" + + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME + """The default timestamp format to use if the `smithy.api#timestampFormat` trait is + not enabled or not present.""" + + default_namespace: str | None = None + """The default namespace to use when determining a document's discriminator.""" diff --git a/packages/smithy-json/tests/unit/test_deserializers.py b/packages/smithy-json/tests/unit/test_deserializers.py index cee2fcd27..f67e9c380 100644 --- a/packages/smithy-json/tests/unit/test_deserializers.py +++ b/packages/smithy-json/tests/unit/test_deserializers.py @@ -9,12 +9,13 @@ BIG_DECIMAL, BLOB, BOOLEAN, + DOCUMENT, FLOAT, INTEGER, STRING, TIMESTAMP, ) -from smithy_json import JSONCodec +from smithy_json import JSONCodec, JSONDocument from . import ( JSON_SERDE_CASES, @@ -88,3 +89,13 @@ def _read_optional_map(k: str, d: ShapeDeserializer): assert actual_value == expected_value else: assert actual == expected + + +class CustomDocument(JSONDocument): + pass + + +def test_uses_custom_document() -> None: + codec = JSONCodec(document_class=CustomDocument) + actual = codec.create_deserializer(b'{"foo": "bar"}').read_document(DOCUMENT) + assert isinstance(actual, CustomDocument)