Skip to content

Commit 771b7c4

Browse files
authored
Handle unknown event types gracefully instead of crashing (#680)
1 parent 58d5188 commit 771b7c4

6 files changed

Lines changed: 61 additions & 10 deletions

File tree

codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,10 @@ private void generateDeserializer() {
144144
writer.addImport("smithy_core.deserializers", "ShapeDeserializer");
145145
writer.addImport("smithy_core.exceptions", "SerializationError");
146146

147-
// TODO: add in unknown handling
148-
149147
var symbol = symbolProvider.toSymbol(shape);
150148
var deserializerSymbol = symbol.expectProperty(SymbolProperties.DESERIALIZER);
151149
var schemaSymbol = symbol.expectProperty(SymbolProperties.SCHEMA);
150+
var unknownSymbol = symbol.expectProperty(SymbolProperties.UNION_UNKNOWN);
152151
writer.putContext("schema", schemaSymbol);
153152
writer.write("""
154153
class $1L:
@@ -167,7 +166,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
167166
match schema.expect_member_index():
168167
${4C|}
169168
case _:
170-
logger.debug("Unexpected member schema: %s", schema)
169+
self._set_result($5L(tag=schema.expect_member_name()))
171170
172171
def _set_result(self, value: $2T) -> None:
173172
if self._result is not None:
@@ -177,7 +176,8 @@ raise SerializationError("Unions must have exactly one value, but found more tha
177176
deserializerSymbol.getName(),
178177
symbol,
179178
schemaSymbol,
180-
writer.consumer(w -> deserializeMembers()));
179+
writer.consumer(w -> deserializeMembers()),
180+
unknownSymbol.getName());
181181
}
182182

183183
private void deserializeMembers() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "enhancement",
3+
"description": "Handle unknown event types gracefully instead of crashing."
4+
}

packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def read_struct(
5050
message_deserializer = self._create_deserializer(schema, headers)
5151
message_deserializer.read_struct(schema, consumer)
5252
else:
53-
member_schema = schema.members[member_name]
53+
member_schema = self._resolve_member_schema(schema, member_name)
5454
message_deserializer = self._create_deserializer(
5555
member_schema, headers
5656
)
@@ -71,6 +71,21 @@ def read_struct(
7171
case _:
7272
raise EventError(f"Unknown event structure: {self._event}")
7373

74+
def _resolve_member_schema(self, schema: Schema, member_name: str) -> Schema:
75+
if member_schema := schema.members.get(member_name):
76+
return member_schema
77+
78+
logger.debug(
79+
"Received unmodeled event stream member %s for union %s",
80+
member_name,
81+
schema.id,
82+
)
83+
return Schema.member(
84+
id=schema.id.with_member(member_name),
85+
target=schema,
86+
index=-1,
87+
)
88+
7489
def _create_deserializer(
7590
self, schema: Schema, headers: HEADERS_DICT
7691
) -> ShapeDeserializer:

packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,9 @@ async def receive(self) -> E | None:
132132
)
133133
result = self._deserializer(deserializer)
134134
logger.debug("Successfully deserialized event: %s", result)
135-
if isinstance(getattr(result, "value"), Exception):
136-
raise result.value # type: ignore
135+
value = getattr(result, "value", None)
136+
if isinstance(value, Exception):
137+
raise value
137138
return result
138139

139140
async def close(self) -> None:

packages/smithy-aws-event-stream/tests/unit/_private/__init__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def serialize_members(self, serializer: ShapeSerializer):
381381

382382

383383
@dataclass
384-
class EventStreamUnknownEvent:
384+
class EventStreamUnknown:
385385
tag: str
386386

387387
def serialize(self, serializer: ShapeSerializer):
@@ -396,7 +396,7 @@ def serialize_members(self, serializer: ShapeSerializer):
396396
| EventStreamPayloadEvent
397397
| EventStreamBlobPayloadEvent
398398
| EventStreamErrorEvent
399-
| EventStreamUnknownEvent
399+
| EventStreamUnknown
400400
)
401401

402402

@@ -429,7 +429,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
429429
self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de)))
430430

431431
case _:
432-
raise SmithyError(f"Unexpected member schema: {schema}")
432+
self._set_result(EventStreamUnknown(tag=schema.expect_member_name()))
433433

434434
def _set_result(self, value: EventStream) -> None:
435435
if self._result is not None:
@@ -635,6 +635,19 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
635635
]
636636

637637

638+
UNKNOWN_EVENT_CASE = (
639+
EventStreamUnknown(tag="unmodeledEvent"),
640+
EventMessage(
641+
headers={
642+
":message-type": "event",
643+
":event-type": "unmodeledEvent",
644+
":content-type": "application/json",
645+
},
646+
payload=b"{}",
647+
),
648+
)
649+
650+
638651
INITIAL_REQUEST_CASE = (
639652
EventStreamOperationInputOutput(message="The initial request!"),
640653
EventMessage(

packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
EventStreamDeserializer,
2121
EventStreamErrorEvent,
2222
EventStreamOperationInputOutput,
23+
EventStreamUnknown,
2324
)
2425

2526

@@ -126,3 +127,20 @@ async def test_read_closed_receiver_source() -> None:
126127
with pytest.raises(IOError):
127128
await receiver.receive()
128129
assert receiver.closed
130+
131+
132+
def test_deserialize_unknown_event_type():
133+
message = EventMessage(
134+
headers={
135+
":message-type": "event",
136+
":event-type": "unmodeledEvent",
137+
":content-type": "application/json",
138+
},
139+
payload=b"{}",
140+
)
141+
source = Event.decode(BytesIO(message.encode()))
142+
assert source is not None
143+
deserializer = EventDeserializer(event=source, payload_codec=JSONCodec())
144+
result = EventStreamDeserializer().deserialize(deserializer)
145+
assert isinstance(result, EventStreamUnknown)
146+
assert result.tag == "unmodeledEvent"

0 commit comments

Comments
 (0)