Skip to content

Commit d1fc057

Browse files
committed
Address PR feedback
1 parent 697654f commit d1fc057

2 files changed

Lines changed: 38 additions & 20 deletions

File tree

packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -85,23 +85,26 @@ async def close(self) -> None:
8585
return
8686
self._closed = True
8787

88-
# Send a signed empty frame to signal stream completion.
89-
if self._signing_config is not None:
90-
end_frame = EventMessage()
91-
identity = await self._signing_config.identity_resolver.get_identity(
92-
properties=self._signing_config.identity_properties
93-
)
94-
end_frame = await self._signing_config.signer.sign_empty(
95-
event=end_frame,
96-
identity=identity,
97-
properties=self._signing_config.signing_properties,
98-
)
99-
logger.debug("Sending signed empty message to terminate the event stream.")
100-
await self._writer.write(end_frame.encode())
101-
102-
if (close := getattr(self._writer, "close", None)) is not None:
103-
if asyncio.iscoroutine(result := close()):
104-
await result
88+
try:
89+
# Send a signed empty frame to signal stream completion.
90+
if self._signing_config is not None:
91+
end_frame = EventMessage()
92+
identity = await self._signing_config.identity_resolver.get_identity(
93+
properties=self._signing_config.identity_properties
94+
)
95+
end_frame = await self._signing_config.signer.sign_empty(
96+
event=end_frame,
97+
identity=identity,
98+
properties=self._signing_config.signing_properties,
99+
)
100+
logger.debug(
101+
"Sending signed empty message to terminate the event stream."
102+
)
103+
await self._writer.write(end_frame.encode())
104+
finally:
105+
if (close := getattr(self._writer, "close", None)) is not None:
106+
if asyncio.iscoroutine(result := close()):
107+
await result
105108

106109
@property
107110
def closed(self) -> bool:

packages/smithy-aws-event-stream/tests/unit/_private/test_serializers.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
import asyncio
4+
from io import BytesIO
35
from typing import Any
46
from unittest.mock import AsyncMock
57

68
import pytest
79
from smithy_aws_event_stream._private.serializers import EventSerializer
810
from smithy_aws_event_stream.aio import AWSEventPublisher
9-
from smithy_aws_event_stream.events import EventMessage
11+
from smithy_aws_event_stream.events import Event, EventMessage
1012
from smithy_core.aio.types import AsyncBytesProvider
1113
from smithy_core.serializers import SerializeableShape
1214
from smithy_json import JSONCodec
@@ -81,7 +83,7 @@ async def test_send_to_closed_writer():
8183

8284

8385
async def test_close_sends_empty_end_frame_when_signing():
84-
writer = AsyncMock()
86+
writer = AsyncBytesProvider()
8587
end_frame = EventMessage()
8688
signing_config = AsyncMock()
8789
signing_config.signer.sign_empty.return_value = end_frame
@@ -90,8 +92,21 @@ async def test_close_sends_empty_end_frame_when_signing():
9092
payload_codec=JSONCodec(), async_writer=writer, signing_config=signing_config
9193
)
9294

95+
# Read from the writer concurrently with close(), since close() flushes
96+
# and blocks until all chunks are consumed.
97+
reader = asyncio.create_task(_read(writer))
9398
await publisher.close()
99+
written = await reader
94100

95101
signing_config.signer.sign_empty.assert_awaited_once()
96-
writer.write.assert_awaited_once_with(end_frame.encode())
97102
assert publisher.closed
103+
assert writer.closed
104+
105+
decoded = Event.decode(BytesIO(written))
106+
assert decoded is not None
107+
assert decoded.message.payload == b""
108+
assert decoded.message.headers == {}
109+
110+
111+
async def _read(source: AsyncBytesProvider) -> bytes:
112+
return b"".join([chunk async for chunk in source])

0 commit comments

Comments
 (0)