99
1010from smithy_core import URI
1111from smithy_core .codecs import Codec
12+ from smithy_core .exceptions import SerializationError
1213from smithy_core .schemas import Schema
1314from smithy_core .serializers import (
1415 InterceptingSerializer ,
2425 HTTPQueryTrait ,
2526 HTTPTrait ,
2627 MediaTypeTrait ,
28+ RequiresLengthTrait ,
2729 TimestampFormatTrait ,
2830)
2931from smithy_core .types import PathPattern , TimestampFormat
3032from smithy_core .utils import serialize_float
3133
32- from . import tuples_to_fields
34+ from . import Field , tuples_to_fields
3335from .aio import HTTPRequest as _HTTPRequest
3436from .aio import HTTPResponse as _HTTPResponse
3537from .aio .interfaces import HTTPRequest , HTTPResponse
4345__all__ = ["HTTPRequestSerializer" , "HTTPResponseSerializer" ]
4446
4547
48+ # TODO: refactor this to share code with response serializer
4649class HTTPRequestSerializer (SpecificShapeSerializer ):
4750 """Binds a serializable shape to an HTTP request.
4851
@@ -82,8 +85,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
8285 host_prefix = self ._endpoint_trait .host_prefix
8386
8487 content_type = self ._payload_codec .media_type
88+ content_length : int | None = None
89+ content_length_required = False
90+
8591 binding_matcher = RequestBindingMatcher (schema )
8692 if (payload_member := binding_matcher .payload_member ) is not None :
93+ content_length_required = RequiresLengthTrait in payload_member
8794 if payload_member .shape_type in (
8895 ShapeType .BLOB ,
8996 ShapeType .STRING ,
@@ -105,6 +112,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
105112 )
106113 yield binding_serializer
107114 payload = payload_serializer .payload
115+ try :
116+ content_length = len (payload )
117+ except TypeError :
118+ pass
108119 else :
109120 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
110121 content_type = media_type .value
@@ -117,6 +128,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
117128 binding_matcher ,
118129 )
119130 yield binding_serializer
131+ content_length = payload .tell ()
120132 else :
121133 payload = BytesIO ()
122134 payload_serializer = self ._payload_codec .create_serializer (payload )
@@ -131,8 +143,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
131143 binding_matcher ,
132144 )
133145 yield binding_serializer
146+ content_length = payload .tell ()
134147 else :
135148 content_type = None
149+ content_length = 0
136150 binding_serializer = HTTPRequestBindingSerializer (
137151 payload_serializer ,
138152 self ._http_trait .path ,
@@ -141,15 +155,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
141155 )
142156 yield binding_serializer
143157
144- if (
145- seek := getattr (payload , "seek" , None )
146- ) is not None and not iscoroutinefunction (seek ):
147- seek (0 )
148-
158+ self ._seek (payload , 0 )
149159 headers = binding_serializer .header_serializer .headers
150160 if content_type is not None :
151161 headers .append (("content-type" , content_type ))
152162
163+ if content_length is not None :
164+ headers .append (("content-length" , str (content_length )))
165+
166+ fields = tuples_to_fields (headers )
167+ if content_length_required and "content-length" not in fields :
168+ content_length = self ._compute_content_length (payload )
169+ if content_length is None :
170+ raise SerializationError (
171+ "This operation requires the the content length of the input "
172+ "stream, but it was not provided and was unable to be computed."
173+ )
174+ fields .set_field (Field (name = "content-length" , values = [str (content_length )]))
175+
153176 self .result = _HTTPRequest (
154177 method = self ._http_trait .method ,
155178 destination = URI (
@@ -160,10 +183,22 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
160183 prefix = self ._http_trait .query or "" ,
161184 ),
162185 ),
163- fields = tuples_to_fields ( headers ) ,
186+ fields = fields ,
164187 body = payload ,
165188 )
166189
190+ def _seek (self , payload : Any , pos : int , whence : int = 0 ) -> None :
191+ if (
192+ seek := getattr (payload , "seek" , None )
193+ ) is not None and not iscoroutinefunction (seek ):
194+ seek (pos , whence )
195+
196+ def _compute_content_length (self , payload : Any ) -> int | None :
197+ content_length = self ._seek (payload , 0 , 2 )
198+ if content_length is not None :
199+ self ._seek (payload , 0 , 0 )
200+ return content_length
201+
167202
168203class HTTPRequestBindingSerializer (InterceptingSerializer ):
169204 """Delegates HTTP request bindings to binding-location-specific serializers."""
@@ -235,8 +270,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
235270 binding_serializer : HTTPResponseBindingSerializer
236271
237272 content_type : str | None = self ._payload_codec .media_type
273+ content_length : int | None = None
274+ content_length_required = False
275+
238276 binding_matcher = ResponseBindingMatcher (schema )
239277 if (payload_member := binding_matcher .payload_member ) is not None :
278+ content_length_required = RequiresLengthTrait in payload_member
240279 if payload_member .shape_type in (ShapeType .BLOB , ShapeType .STRING ):
241280 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
242281 content_type = media_type .value
@@ -250,6 +289,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
250289 )
251290 yield binding_serializer
252291 payload = payload_serializer .payload
292+ try :
293+ content_length = len (payload )
294+ except TypeError :
295+ pass
253296 else :
254297 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
255298 content_type = media_type .value
@@ -259,6 +302,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
259302 payload_serializer , binding_matcher
260303 )
261304 yield binding_serializer
305+ content_length = payload .tell ()
262306 else :
263307 payload = BytesIO ()
264308 payload_serializer = self ._payload_codec .create_serializer (payload )
@@ -270,23 +314,34 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
270314 body_serializer , binding_matcher
271315 )
272316 yield binding_serializer
317+ content_length = payload .tell ()
273318 else :
274319 content_type = None
320+ content_length = 0
275321 binding_serializer = HTTPResponseBindingSerializer (
276322 payload_serializer ,
277323 binding_matcher ,
278324 )
279325 yield binding_serializer
280326
281- if (
282- seek := getattr (payload , "seek" , None )
283- ) is not None and not iscoroutinefunction (seek ):
284- seek (0 )
285-
327+ self ._seek (payload , 0 )
286328 headers = binding_serializer .header_serializer .headers
287329 if content_type is not None :
288330 headers .append (("content-type" , content_type ))
289331
332+ if content_length is not None :
333+ headers .append (("content-length" , str (content_length )))
334+
335+ fields = tuples_to_fields (headers )
336+ if content_length_required and "content-length" not in fields :
337+ content_length = self ._compute_content_length (payload )
338+ if content_length is None :
339+ raise SerializationError (
340+ "This operation requires the the content length of the input "
341+ "stream, but it was not provided and was unable to be computed."
342+ )
343+ fields .set_field (Field (name = "content-length" , values = [str (content_length )]))
344+
290345 status = binding_serializer .response_code_serializer .response_code
291346 if status is None :
292347 if binding_matcher .response_status > 0 :
@@ -300,6 +355,18 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
300355 status = status ,
301356 )
302357
358+ def _seek (self , payload : Any , pos : int , whence : int = 0 ) -> int | None :
359+ if (
360+ seek := getattr (payload , "seek" , None )
361+ ) is not None and not iscoroutinefunction (seek ):
362+ return seek (pos , whence )
363+
364+ def _compute_content_length (self , payload : Any ) -> int | None :
365+ content_length = self ._seek (payload , 0 , 2 )
366+ if content_length is not None :
367+ self ._seek (payload , 0 , 0 )
368+ return content_length
369+
303370
304371class HTTPResponseBindingSerializer (InterceptingSerializer ):
305372 """Delegates HTTP response bindings to binding-location-specific serializers."""
0 commit comments