diff --git a/packages/smithy-core/src/smithy_core/traits.py b/packages/smithy-core/src/smithy_core/traits.py index 4be81aba1..b2e6f795c 100644 --- a/packages/smithy-core/src/smithy_core/traits.py +++ b/packages/smithy-core/src/smithy_core/traits.py @@ -154,6 +154,12 @@ def __post_init__(self): assert self.document_value is None +@dataclass(init=False, frozen=True) +class RequiresLengthTrait(Trait, id=ShapeID("smithy.api#requiresLength")): + def __post_init__(self): + assert self.document_value is None + + @dataclass(init=False, frozen=True) class UnitTypeTrait(Trait, id=ShapeID("smithy.api#UnitTypeTrait")): def __post_init__(self): diff --git a/packages/smithy-http/src/smithy_http/serializers.py b/packages/smithy-http/src/smithy_http/serializers.py index 577ce24f0..b000f5549 100644 --- a/packages/smithy-http/src/smithy_http/serializers.py +++ b/packages/smithy-http/src/smithy_http/serializers.py @@ -9,6 +9,7 @@ from smithy_core import URI from smithy_core.codecs import Codec +from smithy_core.exceptions import SerializationError from smithy_core.schemas import Schema from smithy_core.serializers import ( InterceptingSerializer, @@ -24,12 +25,13 @@ HTTPQueryTrait, HTTPTrait, MediaTypeTrait, + RequiresLengthTrait, TimestampFormatTrait, ) from smithy_core.types import PathPattern, TimestampFormat from smithy_core.utils import serialize_float -from . import tuples_to_fields +from . import Field, tuples_to_fields from .aio import HTTPRequest as _HTTPRequest from .aio import HTTPResponse as _HTTPResponse from .aio.interfaces import HTTPRequest, HTTPResponse @@ -43,6 +45,7 @@ __all__ = ["HTTPRequestSerializer", "HTTPResponseSerializer"] +# TODO: refactor this to share code with response serializer class HTTPRequestSerializer(SpecificShapeSerializer): """Binds a serializable shape to an HTTP request. @@ -82,15 +85,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: host_prefix = self._endpoint_trait.host_prefix content_type = self._payload_codec.media_type + content_length: int | None = None + content_length_required = False binding_matcher = RequestBindingMatcher(schema) if (payload_member := binding_matcher.payload_member) is not None: - if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): - content_type = ( - "application/octet-stream" - if payload_member.shape_type is ShapeType.BLOB - else "text/plain" - ) + content_length_required = RequiresLengthTrait in payload_member + if payload_member.shape_type in ( + ShapeType.BLOB, + ShapeType.STRING, + ShapeType.ENUM, + ): + if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None: + content_type = media_type.value + elif payload_member.shape_type is ShapeType.BLOB: + content_type = "application/octet-stream" + else: + content_type = "text/plain" + payload_serializer = RawPayloadSerializer() binding_serializer = HTTPRequestBindingSerializer( payload_serializer, @@ -100,6 +112,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: ) yield binding_serializer payload = payload_serializer.payload + try: + content_length = len(payload) + except TypeError: + pass else: if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None: content_type = media_type.value @@ -112,12 +128,14 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: binding_matcher, ) yield binding_serializer + content_length = payload.tell() + payload.seek(0) else: - if binding_matcher.event_stream_member is not None: - content_type = "application/vnd.amazon.eventstream" payload = BytesIO() payload_serializer = self._payload_codec.create_serializer(payload) if binding_matcher.should_write_body(self._omit_empty_payload): + if binding_matcher.event_stream_member is not None: + content_type = "application/vnd.amazon.eventstream" with payload_serializer.begin_struct(schema) as body_serializer: binding_serializer = HTTPRequestBindingSerializer( body_serializer, @@ -126,7 +144,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: binding_matcher, ) yield binding_serializer + content_length = payload.tell() else: + content_type = None + content_length = 0 binding_serializer = HTTPRequestBindingSerializer( payload_serializer, self._http_trait.path, @@ -134,15 +155,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: binding_matcher, ) yield binding_serializer + payload.seek(0) - if ( - seek := getattr(payload, "seek", None) - ) is not None and not iscoroutinefunction(seek): - seek(0) - - # TODO: conditional on empty-ness and based on the protocol headers = binding_serializer.header_serializer.headers - headers.append(("content-type", content_type)) + if content_type is not None: + headers.append(("content-type", content_type)) + + if content_length is not None: + headers.append(("content-length", str(content_length))) + + fields = tuples_to_fields(headers) + if content_length_required and "content-length" not in fields: + content_length = _compute_content_length(payload) + if content_length is None: + raise SerializationError( + "This operation requires the the content length of the input " + "stream, but it was not provided and was unable to be computed." + ) + fields.set_field(Field(name="content-length", values=[str(content_length)])) self.result = _HTTPRequest( method=self._http_trait.method, @@ -154,11 +184,30 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: prefix=self._http_trait.query or "", ), ), - fields=tuples_to_fields(headers), + fields=fields, body=payload, ) +def _compute_content_length(payload: Any) -> int | None: + if (tell := getattr(payload, "tell", None)) is not None and not iscoroutinefunction( + tell + ): + start: int = tell() + if (end := _seek(payload, 0, 2)) is not None: + content_length: int = end - start + _seek(payload, start, 0) + return content_length + return None + + +def _seek(payload: Any, pos: int, whence: int = 0) -> None: + if (seek := getattr(payload, "seek", None)) is not None and not iscoroutinefunction( + seek + ): + seek(pos, whence) + + class HTTPRequestBindingSerializer(InterceptingSerializer): """Delegates HTTP request bindings to binding-location-specific serializers.""" @@ -228,42 +277,79 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: payload: Any binding_serializer: HTTPResponseBindingSerializer + content_type: str | None = self._payload_codec.media_type + content_length: int | None = None + content_length_required = False + binding_matcher = ResponseBindingMatcher(schema) if (payload_member := binding_matcher.payload_member) is not None: + content_length_required = RequiresLengthTrait in payload_member if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): + if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None: + content_type = media_type.value + elif payload_member.shape_type is ShapeType.BLOB: + content_type = "application/octet-stream" + else: + content_type = "text/plain" payload_serializer = RawPayloadSerializer() binding_serializer = HTTPResponseBindingSerializer( payload_serializer, binding_matcher ) yield binding_serializer payload = payload_serializer.payload + try: + content_length = len(payload) + except TypeError: + pass else: + if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None: + content_type = media_type.value payload = BytesIO() payload_serializer = self._payload_codec.create_serializer(payload) binding_serializer = HTTPResponseBindingSerializer( payload_serializer, binding_matcher ) yield binding_serializer + content_length = payload.tell() + payload.seek(0) else: payload = BytesIO() payload_serializer = self._payload_codec.create_serializer(payload) if binding_matcher.should_write_body(self._omit_empty_payload): + if binding_matcher.event_stream_member is not None: + content_type = "application/vnd.amazon.eventstream" with payload_serializer.begin_struct(schema) as body_serializer: binding_serializer = HTTPResponseBindingSerializer( body_serializer, binding_matcher ) yield binding_serializer + content_length = payload.tell() else: + content_type = None + content_length = 0 binding_serializer = HTTPResponseBindingSerializer( payload_serializer, binding_matcher, ) yield binding_serializer + payload.seek(0) - if ( - seek := getattr(payload, "seek", None) - ) is not None and not iscoroutinefunction(seek): - seek(0) + headers = binding_serializer.header_serializer.headers + if content_type is not None: + headers.append(("content-type", content_type)) + + if content_length is not None: + headers.append(("content-length", str(content_length))) + + fields = tuples_to_fields(headers) + if content_length_required and "content-length" not in fields: + content_length = _compute_content_length(payload) + if content_length is None: + raise SerializationError( + "This operation requires the the content length of the input " + "stream, but it was not provided and was unable to be computed." + ) + fields.set_field(Field(name="content-length", values=[str(content_length)])) status = binding_serializer.response_code_serializer.response_code if status is None: diff --git a/packages/smithy-http/tests/unit/test_serializers.py b/packages/smithy-http/tests/unit/test_serializers.py index cc396336f..d17df2348 100644 --- a/packages/smithy-http/tests/unit/test_serializers.py +++ b/packages/smithy-http/tests/unit/test_serializers.py @@ -40,7 +40,7 @@ TimestampFormatTrait, Trait, ) -from smithy_http import Field, Fields, tuples_to_fields +from smithy_http import Fields, tuples_to_fields from smithy_http.aio import HTTPResponse as _HTTPResponse from smithy_http.deserializers import HTTPResponseDeserializer from smithy_http.serializers import HTTPRequestSerializer, HTTPResponseSerializer @@ -1092,66 +1092,92 @@ def header_cases() -> list[HTTPMessageTestCase]: HTTPMessageTestCase( HTTPHeaders(boolean_member=True), HTTPMessage( - fields=tuples_to_fields([("boolean", "true")]), + fields=tuples_to_fields([("boolean", "true"), ("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPHeaders(boolean_list_member=[True, False]), HTTPMessage( fields=tuples_to_fields( - [("booleanList", "true"), ("booleanList", "false")] + [ + ("booleanList", "true"), + ("booleanList", "false"), + ("content-length", "0"), + ] ), ), ), HTTPMessageTestCase( HTTPHeaders(integer_member=1), HTTPMessage( - fields=tuples_to_fields([("integer", "1")]), + fields=tuples_to_fields([("integer", "1"), ("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPHeaders(integer_list_member=[1, 2]), HTTPMessage( - fields=tuples_to_fields([("integerList", "1"), ("integerList", "2")]), + fields=tuples_to_fields( + [ + ("integerList", "1"), + ("integerList", "2"), + ("content-length", "0"), + ] + ), ), ), HTTPMessageTestCase( HTTPHeaders(float_member=1.1), HTTPMessage( - fields=tuples_to_fields([("float", "1.1")]), + fields=tuples_to_fields([("float", "1.1"), ("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPHeaders(float_list_member=[1.1, 2.2]), HTTPMessage( - fields=tuples_to_fields([("floatList", "1.1"), ("floatList", "2.2")]), + fields=tuples_to_fields( + [ + ("floatList", "1.1"), + ("floatList", "2.2"), + ("content-length", "0"), + ] + ), ), ), HTTPMessageTestCase( HTTPHeaders(big_decimal_member=Decimal("1.1")), HTTPMessage( - fields=tuples_to_fields([("bigDecimal", "1.1")]), + fields=tuples_to_fields( + [("bigDecimal", "1.1"), ("content-length", "0")] + ), ), ), HTTPMessageTestCase( HTTPHeaders(big_decimal_list_member=[Decimal("1.1"), Decimal("2.2")]), HTTPMessage( fields=tuples_to_fields( - [("bigDecimalList", "1.1"), ("bigDecimalList", "2.2")] + [ + ("bigDecimalList", "1.1"), + ("bigDecimalList", "2.2"), + ("content-length", "0"), + ] ), ), ), HTTPMessageTestCase( HTTPHeaders(string_member="foo"), HTTPMessage( - fields=tuples_to_fields([("string", "foo")]), + fields=tuples_to_fields([("string", "foo"), ("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPHeaders(string_list_member=["spam", "eggs"]), HTTPMessage( fields=tuples_to_fields( - [("stringList", "spam"), ("stringList", "eggs")] + [ + ("stringList", "spam"), + ("stringList", "eggs"), + ("content-length", "0"), + ] ), ), ), @@ -1161,7 +1187,10 @@ def header_cases() -> list[HTTPMessageTestCase]: ), HTTPMessage( fields=tuples_to_fields( - [("defaultTimestamp", "Wed, 01 Jan 2025 00:00:00 GMT")] + [ + ("defaultTimestamp", "Wed, 01 Jan 2025 00:00:00 GMT"), + ("content-length", "0"), + ] ), ), ), @@ -1171,7 +1200,10 @@ def header_cases() -> list[HTTPMessageTestCase]: ), HTTPMessage( fields=tuples_to_fields( - [("httpDateTimestamp", "Wed, 01 Jan 2025 00:00:00 GMT")] + [ + ("httpDateTimestamp", "Wed, 01 Jan 2025 00:00:00 GMT"), + ("content-length", "0"), + ] ), ), ), @@ -1187,6 +1219,7 @@ def header_cases() -> list[HTTPMessageTestCase]: [ ("httpDateListTimestamp", "Wed, 01 Jan 2025 00:00:00 GMT"), ("httpDateListTimestamp", "Mon, 01 Jan 2024 00:00:00 GMT"), + ("content-length", "0"), ] ), ), @@ -1197,7 +1230,10 @@ def header_cases() -> list[HTTPMessageTestCase]: ), HTTPMessage( fields=tuples_to_fields( - [("dateTimeTimestamp", "2025-01-01T00:00:00Z")] + [ + ("dateTimeTimestamp", "2025-01-01T00:00:00Z"), + ("content-length", "0"), + ] ), ), ), @@ -1213,6 +1249,7 @@ def header_cases() -> list[HTTPMessageTestCase]: [ ("dateTimeListTimestamp", "2025-01-01T00:00:00Z"), ("dateTimeListTimestamp", "2024-01-01T00:00:00Z"), + ("content-length", "0"), ] ), ), @@ -1222,7 +1259,9 @@ def header_cases() -> list[HTTPMessageTestCase]: epoch_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC) ), HTTPMessage( - fields=tuples_to_fields([("epochTimestamp", "1735689600")]), + fields=tuples_to_fields( + [("epochTimestamp", "1735689600"), ("content-length", "0")] + ), ), ), HTTPMessageTestCase( @@ -1237,6 +1276,7 @@ def header_cases() -> list[HTTPMessageTestCase]: [ ("epochListTimestamp", "1735689600"), ("epochListTimestamp", "1704067200"), + ("content-length", "0"), ] ), ), @@ -1244,7 +1284,9 @@ def header_cases() -> list[HTTPMessageTestCase]: HTTPMessageTestCase( HTTPHeaders(string_map_member={"foo": "bar", "baz": "bam"}), HTTPMessage( - fields=tuples_to_fields([("x-foo", "bar"), ("x-baz", "bam")]), + fields=tuples_to_fields( + [("x-foo", "bar"), ("x-baz", "bam"), ("content-length", "0")] + ), ), ), ] @@ -1342,7 +1384,12 @@ def empty_prefix_header_ser_cases() -> list[HTTPMessageTestCase]: ), HTTPMessage( fields=tuples_to_fields( - [("foo", "bar"), ("baz", "bam"), ("string", "string")] + [ + ("foo", "bar"), + ("baz", "bam"), + ("string", "string"), + ("content-length", "0"), + ] ), ), ), @@ -1354,11 +1401,21 @@ def empty_prefix_header_deser_cases() -> list[HTTPMessageTestCase]: HTTPMessageTestCase( HTTPEmptyPrefixHeaders( string_member="string", - string_map_member={"foo": "bar", "baz": "bam", "string": "string"}, + string_map_member={ + "foo": "bar", + "baz": "bam", + "string": "string", + "content-length": "0", + }, ), HTTPMessage( fields=tuples_to_fields( - [("foo", "bar"), ("baz", "bam"), ("string", "string")] + [ + ("foo", "bar"), + ("baz", "bam"), + ("string", "string"), + ("content-length", "0"), + ] ), ), ), @@ -1371,6 +1428,7 @@ def query_cases() -> list[HTTPMessageTestCase]: HTTPQuery(boolean_member=True), HTTPMessage( destination=URI(host="", path="/", query="boolean=true"), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( @@ -1379,55 +1437,75 @@ def query_cases() -> list[HTTPMessageTestCase]: destination=URI( host="", path="/", query="booleanList=true&booleanList=false" ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPQuery(integer_member=1), - HTTPMessage(destination=URI(host="", path="/", query="integer=1")), + HTTPMessage( + destination=URI(host="", path="/", query="integer=1"), + fields=tuples_to_fields([("content-length", "0")]), + ), ), HTTPMessageTestCase( HTTPQuery(integer_list_member=[1, 2]), HTTPMessage( - destination=URI(host="", path="/", query="integerList=1&integerList=2") + destination=URI(host="", path="/", query="integerList=1&integerList=2"), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPQuery(float_member=1.1), - HTTPMessage(destination=URI(host="", path="/", query="float=1.1")), + HTTPMessage( + destination=URI(host="", path="/", query="float=1.1"), + fields=tuples_to_fields([("content-length", "0")]), + ), ), HTTPMessageTestCase( HTTPQuery(float_list_member=[1.1, 2.2]), HTTPMessage( - destination=URI(host="", path="/", query="floatList=1.1&floatList=2.2") + destination=URI(host="", path="/", query="floatList=1.1&floatList=2.2"), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPQuery(big_decimal_member=Decimal("1.1")), - HTTPMessage(destination=URI(host="", path="/", query="bigDecimal=1.1")), + HTTPMessage( + destination=URI(host="", path="/", query="bigDecimal=1.1"), + fields=tuples_to_fields([("content-length", "0")]), + ), ), HTTPMessageTestCase( HTTPQuery(big_decimal_list_member=[Decimal("1.1"), Decimal("2.2")]), HTTPMessage( destination=URI( host="", path="/", query="bigDecimalList=1.1&bigDecimalList=2.2" - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPQuery(string_member="foo"), - HTTPMessage(destination=URI(host="", path="/", query="string=foo")), + HTTPMessage( + destination=URI(host="", path="/", query="string=foo"), + fields=tuples_to_fields([("content-length", "0")]), + ), ), HTTPMessageTestCase( HTTPQuery(string_list_member=["spam", "eggs"]), HTTPMessage( destination=URI( host="", path="/", query="stringList=spam&stringList=eggs" - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPQuery(string_member="foo bar"), - HTTPMessage(destination=URI(host="", path="/", query="string=foo%20bar")), + HTTPMessage( + destination=URI(host="", path="/", query="string=foo%20bar"), + fields=tuples_to_fields([("content-length", "0")]), + ), ), HTTPMessageTestCase( HTTPQuery(string_list_member=["spam eggs", "eggs spam"]), @@ -1436,7 +1514,8 @@ def query_cases() -> list[HTTPMessageTestCase]: host="", path="/", query="stringList=spam%20eggs&stringList=eggs%20spam", - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( @@ -1448,7 +1527,8 @@ def query_cases() -> list[HTTPMessageTestCase]: host="", path="/", query="defaultTimestamp=2025-01-01T00%3A00%3A00Z", - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( @@ -1460,7 +1540,8 @@ def query_cases() -> list[HTTPMessageTestCase]: host="", path="/", query="httpDateTimestamp=Wed%2C%2001%20Jan%202025%2000%3A00%3A00%20GMT", - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( @@ -1478,7 +1559,8 @@ def query_cases() -> list[HTTPMessageTestCase]: "httpDateListTimestamp=Wed%2C%2001%20Jan%202025%2000%3A00%3A00%20GMT" "&httpDateListTimestamp=Mon%2C%2001%20Jan%202024%2000%3A00%3A00%20GMT" ), - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( @@ -1490,7 +1572,8 @@ def query_cases() -> list[HTTPMessageTestCase]: host="", path="/", query="dateTimeTimestamp=2025-01-01T00%3A00%3A00Z", - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( @@ -1508,13 +1591,15 @@ def query_cases() -> list[HTTPMessageTestCase]: "dateTimeListTimestamp=2025-01-01T00%3A00%3A00Z" "&dateTimeListTimestamp=2024-01-01T00%3A00%3A00Z" ), - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPQuery(epoch_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC)), HTTPMessage( - destination=URI(host="", path="/", query="epochTimestamp=1735689600") + destination=URI(host="", path="/", query="epochTimestamp=1735689600"), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( @@ -1529,23 +1614,31 @@ def query_cases() -> list[HTTPMessageTestCase]: host="", path="/", query="epochListTimestamp=1735689600&epochListTimestamp=1704067200", - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), ), HTTPMessageTestCase( HTTPQuery(string_map_member={"foo": "bar", "baz": "bam"}), - HTTPMessage(destination=URI(host="", path="/", query="foo=bar&baz=bam")), + HTTPMessage( + destination=URI(host="", path="/", query="foo=bar&baz=bam"), + fields=tuples_to_fields([("content-length", "0")]), + ), ), HTTPMessageTestCase( HTTPQuery(string_member="foo"), HTTPMessage( - destination=URI(host="", path="/", query="spam=eggs&string=foo") + destination=URI(host="", path="/", query="spam=eggs&string=foo"), + fields=tuples_to_fields([("content-length", "0")]), ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/?spam=eggs"}), ), HTTPMessageTestCase( HTTPQuery(string_member="foo"), - HTTPMessage(destination=URI(host="", path="/", query="spam&string=foo")), + HTTPMessage( + destination=URI(host="", path="/", query="spam&string=foo"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/?spam"}), ), ] @@ -1555,42 +1648,66 @@ def label_cases() -> list[HTTPMessageTestCase]: return [ HTTPMessageTestCase( HTTPStringLabel(label="foo/bar"), - HTTPMessage(destination=URI(host="", path="/foo%2Fbar")), + HTTPMessage( + destination=URI(host="", path="/foo%2Fbar"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), ), HTTPMessageTestCase( HTTPStringLabel(label="foo/bar"), - HTTPMessage(destination=URI(host="", path="/foo/bar")), + HTTPMessage( + destination=URI(host="", path="/foo/bar"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label+}"}), ), HTTPMessageTestCase( HTTPFloatLabel(label=1.1), - HTTPMessage(destination=URI(host="", path="/1.1")), + HTTPMessage( + destination=URI(host="", path="/1.1"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), ), HTTPMessageTestCase( HTTPBigDecimalLabel(label=Decimal("1.1")), - HTTPMessage(destination=URI(host="", path="/1.1")), + HTTPMessage( + destination=URI(host="", path="/1.1"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), ), HTTPMessageTestCase( HTTPBooleanLabel(label=True), - HTTPMessage(destination=URI(host="", path="/true")), + HTTPMessage( + destination=URI(host="", path="/true"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), ), HTTPMessageTestCase( HTTPDefaultTimestampLabel(label=datetime.datetime(2025, 1, 1, tzinfo=UTC)), - HTTPMessage(destination=URI(host="", path="/2025-01-01T00%3A00%3A00Z")), + HTTPMessage( + destination=URI(host="", path="/2025-01-01T00%3A00%3A00Z"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), ), HTTPMessageTestCase( HTTPEpochTimestampLabel(label=datetime.datetime(2025, 1, 1, tzinfo=UTC)), - HTTPMessage(destination=URI(host="", path="/1735689600")), + HTTPMessage( + destination=URI(host="", path="/1735689600"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), ), HTTPMessageTestCase( HTTPDateTimeTimestampLabel(label=datetime.datetime(2025, 1, 1, tzinfo=UTC)), - HTTPMessage(destination=URI(host="", path="/2025-01-01T00%3A00%3A00Z")), + HTTPMessage( + destination=URI(host="", path="/2025-01-01T00%3A00%3A00Z"), + fields=tuples_to_fields([("content-length", "0")]), + ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), ), HTTPMessageTestCase( @@ -1598,7 +1715,8 @@ def label_cases() -> list[HTTPMessageTestCase]: HTTPMessage( destination=URI( host="", path="/Wed%2C%2001%20Jan%202025%2000%3A00%3A00%20GMT" - ) + ), + fields=tuples_to_fields([("content-length", "0")]), ), http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), ), @@ -1610,13 +1728,20 @@ def host_cases() -> list[HTTPMessageTestCase]: HTTPMessageTestCase( HostLabel("foo"), HTTPMessage( - destination=URI(host="foo.", path="/"), body=BytesIO(b'{"label":"foo"}') + destination=URI(host="foo.", path="/"), + body=BytesIO(b'{"label":"foo"}'), + fields=tuples_to_fields( + [("content-type", "application/json"), ("content-length", "15")] + ), ), endpoint_trait=EndpointTrait({"hostPrefix": "{label}."}), ), HTTPMessageTestCase( HTTPHeaders(), - HTTPMessage(destination=URI(host="foo.", path="/")), + HTTPMessage( + destination=URI(host="foo.", path="/"), + fields=tuples_to_fields([("content-length", "0")]), + ), endpoint_trait=EndpointTrait({"hostPrefix": "foo."}), ), ] @@ -1627,30 +1752,54 @@ def payload_cases() -> list[HTTPMessageTestCase]: HTTPMessageTestCase( HTTPImplicitPayload(header="foo", payload_member="bar"), HTTPMessage( - fields=tuples_to_fields([("header", "foo")]), + fields=tuples_to_fields( + [ + ("header", "foo"), + ("content-type", "application/json"), + ("content-length", "24"), + ] + ), body=BytesIO(b'{"payload_member":"bar"}'), ), ), HTTPMessageTestCase( HTTPStringPayload(payload="foo"), HTTPMessage( - fields=tuples_to_fields([("content-type", "text/plain")]), body=b"foo" + fields=tuples_to_fields( + [("content-type", "text/plain"), ("content-length", "3")] + ), + body=b"foo", ), ), HTTPMessageTestCase( HTTPBlobPayload(payload=b"\xde\xad\xbe\xef"), HTTPMessage( - fields=tuples_to_fields([("content-type", "application/octet-stream")]), + fields=tuples_to_fields( + [ + ("content-type", "application/octet-stream"), + ("content-length", "4"), + ] + ), body=b"\xde\xad\xbe\xef", ), ), HTTPMessageTestCase( HTTPStructuredPayload(payload=HTTPStringPayload(payload="foo")), - HTTPMessage(body=BytesIO(b'{"payload":"foo"}')), + HTTPMessage( + body=BytesIO(b'{"payload":"foo"}'), + fields=tuples_to_fields( + [("content-type", "application/json"), ("content-length", "17")] + ), + ), ), HTTPMessageTestCase( HTTPStructuredPayload(HTTPStringPayload()), - HTTPMessage(body=BytesIO(b"{}")), + HTTPMessage( + body=BytesIO(b"{}"), + fields=tuples_to_fields( + [("content-type", "application/json"), ("content-length", "2")] + ), + ), ), ] @@ -1720,8 +1869,6 @@ def async_streaming_payload_cases() -> list[HTTPMessageTestCase]: + async_streaming_payload_cases() ) -CONTENT_TYPE_FIELD = Field(name="content-type", values=["application/json"]) - @pytest.mark.parametrize("case", REQUEST_SER_CASES) async def test_serialize_http_request(case: HTTPMessageTestCase) -> None: @@ -1741,10 +1888,6 @@ async def test_serialize_http_request(case: HTTPMessageTestCase) -> None: actual_query = actual.destination.query or "" expected_query = case.request.destination.query or "" assert actual_query == expected_query - # set the content-type field here, otherwise cases would have to duplicate it everywhere, - # but if the field is already set in the case, don't override it - if expected.fields.get(CONTENT_TYPE_FIELD.name) is None: - expected.fields.set_field(CONTENT_TYPE_FIELD) assert actual.fields == expected.fields if case.request.body: @@ -1783,9 +1926,6 @@ async def test_serialize_http_response(case: HTTPMessageTestCase) -> None: expected = case.request assert actual is not None - # Remove content-type from expected, we're re-using the request cases for brevity - if expected.fields.get(CONTENT_TYPE_FIELD.name) is not None: - del expected.fields[CONTENT_TYPE_FIELD.name] assert actual.fields == expected.fields assert actual.status == expected.status