-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathprotocols.py
More file actions
205 lines (177 loc) · 6.93 KB
/
protocols.py
File metadata and controls
205 lines (177 loc) · 6.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import os
from inspect import iscoroutinefunction
from io import BytesIO
from typing import Any
from smithy_core.aio.interfaces import ClientProtocol
from smithy_core.codecs import Codec
from smithy_core.deserializers import DeserializeableShape
from smithy_core.documents import TypeRegistry
from smithy_core.exceptions import CallError, ExpectationNotMetError, ModeledError
from smithy_core.interfaces import (
Endpoint,
SeekableBytesReader,
TypedProperties,
URI,
is_streaming_blob,
)
from smithy_core.interfaces import StreamingBlob as SyncStreamingBlob
from smithy_core.prelude import DOCUMENT
from smithy_core.schemas import APIOperation
from smithy_core.serializers import SerializeableShape
from smithy_core.traits import EndpointTrait, HTTPTrait
from ..deserializers import HTTPResponseDeserializer
from ..serializers import HTTPRequestSerializer
from .interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse
class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]):
    """An HTTP-based protocol."""

    def set_service_endpoint(
        self,
        *,
        request: HTTPRequest,
        endpoint: Endpoint,
    ) -> HTTPRequest:
        """Apply the resolved service endpoint to ``request``'s destination.

        The endpoint's scheme, host, and port each overwrite the request's only
        when set (a port must additionally be a non-zero value greater than -1).
        A non-empty endpoint path is prepended to the request path, with the
        request path treated as relative to it: endpoint path ``/prefix`` plus
        request path ``/op`` yields ``/prefix/op``.

        The update is made by assigning attributes on ``request.destination``;
        the same request object is returned.

        :param request: The request whose destination will be updated.
        :param endpoint: The resolved endpoint to apply.
        :returns: ``request``, after its destination has been updated.
        """
        uri = endpoint.uri
        uri_builder = request.destination

        if uri.scheme:
            uri_builder.scheme = uri.scheme
        if uri.host:
            uri_builder.host = uri.host
        if uri.port and uri.port > -1:
            uri_builder.port = uri.port
        if uri.path:
            # Join as URL paths instead of using os.path.join. os.path.join
            # discards the endpoint prefix whenever the request path is absolute
            # ("/prefix" + "/op" -> "/op"), and it uses "\\" on Windows.
            # Relative/empty request paths give the same result as before,
            # e.g. "/prefix" + "" -> "/prefix/".
            prefix = uri.path if uri.path.endswith("/") else f"{uri.path}/"
            uri_builder.path = prefix + (uri_builder.path or "").lstrip("/")

        # TODO: merge headers from the endpoint properties bag
        return request
class HttpBindingClientProtocol(HttpClientProtocol):
"""An HTTP-based protocol that uses HTTP binding traits."""
@property
def payload_codec(self) -> Codec:
"""The codec used for the serde of input and output payloads."""
raise NotImplementedError()
@property
def content_type(self) -> str:
"""The media type of the http payload."""
raise NotImplementedError()
@property
def error_identifier(self) -> HTTPErrorIdentifier:
"""The class used to identify the shape IDs of errors based on fields or other
response information."""
raise NotImplementedError()
def serialize_request[
OperationInput: "SerializeableShape",
OperationOutput: "DeserializeableShape",
](
self,
*,
operation: APIOperation[OperationInput, OperationOutput],
input: OperationInput,
endpoint: URI,
context: TypedProperties,
) -> HTTPRequest:
# TODO(optimization): request binding cache like done in SJ
serializer = HTTPRequestSerializer(
payload_codec=self.payload_codec,
http_trait=operation.schema.expect_trait(HTTPTrait),
endpoint_trait=operation.schema.get_trait(EndpointTrait),
)
input.serialize(serializer=serializer)
request = serializer.result
if request is None:
raise ExpectationNotMetError(
"Expected request to be serialized, but was None"
)
return request
async def deserialize_response[
OperationInput: "SerializeableShape",
OperationOutput: "DeserializeableShape",
](
self,
*,
operation: APIOperation[OperationInput, OperationOutput],
request: HTTPRequest,
response: HTTPResponse,
error_registry: TypeRegistry,
context: TypedProperties,
) -> OperationOutput:
body = response.body
# if body is not streaming and is async, we have to buffer it
if not operation.output_stream_member and not is_streaming_blob(body):
if (
read := getattr(body, "read", None)
) is not None and iscoroutinefunction(read):
body = BytesIO(await read())
if not self._is_success(operation, context, response):
raise await self._create_error(
operation=operation,
request=request,
response=response,
response_body=body, # type: ignore
error_registry=error_registry,
context=context,
)
# TODO(optimization): response binding cache like done in SJ
deserializer = HTTPResponseDeserializer(
payload_codec=self.payload_codec,
http_trait=operation.schema.expect_trait(HTTPTrait),
response=response,
body=body, # type: ignore
)
return operation.output.deserialize(deserializer)
def _is_success(
self,
operation: APIOperation[Any, Any],
context: TypedProperties,
response: HTTPResponse,
) -> bool:
return 200 <= response.status < 300
async def _create_error(
self,
operation: APIOperation[Any, Any],
request: HTTPRequest,
response: HTTPResponse,
response_body: SyncStreamingBlob,
error_registry: TypeRegistry,
context: TypedProperties,
) -> CallError:
error_id = self.error_identifier.identify(
operation=operation, response=response
)
if error_id is None:
if isinstance(response_body, bytearray):
response_body = bytes(response_body)
deserializer = self.payload_codec.create_deserializer(source=response_body)
document = deserializer.read_document(schema=DOCUMENT)
if document.discriminator in error_registry:
error_id = document.discriminator
if isinstance(response_body, SeekableBytesReader):
response_body.seek(0)
if error_id is not None and error_id in error_registry:
error_shape = error_registry.get(error_id)
# make sure the error shape is derived from modeled exception
if not issubclass(error_shape, ModeledError):
raise ExpectationNotMetError(
f"Modeled errors must be derived from 'ModeledError', "
f"but got {error_shape}"
)
deserializer = HTTPResponseDeserializer(
payload_codec=self.payload_codec,
http_trait=operation.schema.expect_trait(HTTPTrait),
response=response,
body=response_body,
)
return error_shape.deserialize(deserializer)
is_throttle = response.status == 429
message = (
f"Unknown error for operation {operation.schema.id} "
f"- status: {response.status}"
)
if error_id is not None:
message += f" - id: {error_id}"
if response.reason is not None:
message += f" - reason: {response.status}"
return CallError(
message=message,
fault="client" if response.status < 500 else "server",
is_throttling_error=is_throttle,
is_retry_safe=is_throttle or None,
)