Skip to content

Commit 6b2d86b

Browse files
committed
PYTHON-5781 Address Noah's code review feedback
- Make _make_protocol and _make_proto_with_header sync (no await needed) - Remove await from test method calls to those helpers - Use assertRaisesRegex with specific messages for all ProtocolError tests - Remove test_op_reply_op_code (op_code is directly from struct unpack, no extra logic) - Remove test_compression_header_zlib_compressor_id (no code-path difference between compressors) - Remove test_connection_lost_twice_does_not_raise (tests internal implementation details)
1 parent 9ee2ab3 commit 6b2d86b

2 files changed

Lines changed: 15 additions & 35 deletions

File tree

test/asynchronous/test_async_network_layer.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from test.asynchronous import AsyncUnitTest, unittest
3030

3131

32-
async def _make_protocol(timeout=None):
32+
def _make_protocol(timeout=None):
3333
protocol = PyMongoProtocol(timeout=timeout)
3434
mock_transport = MagicMock()
3535
mock_transport.is_closing.return_value = False
@@ -42,15 +42,15 @@ def _make_header(length, request_id, response_to, op_code):
4242

4343

4444
class TestPyMongoProtocol(AsyncUnitTest):
45-
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
46-
protocol = await _make_protocol()
45+
def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
46+
protocol = _make_protocol()
4747
protocol._max_message_size = max_size
4848
protocol._header = memoryview(bytearray(header_bytes))
4949
return protocol
5050

5151
async def test_normal_op_msg(self):
5252
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
53-
protocol = await self._make_proto_with_header(header)
53+
protocol = self._make_proto_with_header(header)
5454
body_len, op_code, response_to, expecting_compression = protocol.process_header()
5555
self.assertEqual(body_len, 16)
5656
self.assertEqual(op_code, 2013)
@@ -62,68 +62,48 @@ async def test_op_compressed(self):
6262
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
6363
# length=35 → after compression sub-header: 26 → body: 10
6464
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
65-
protocol = await self._make_proto_with_header(header)
65+
protocol = self._make_proto_with_header(header)
6666
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
6767
self.assertEqual(body_len, 10)
6868
self.assertEqual(op_code, 2012)
6969
self.assertTrue(expecting_compression)
7070

7171
async def test_op_compressed_length_too_small_raises(self):
7272
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
73-
protocol = await self._make_proto_with_header(header)
74-
with self.assertRaises(ProtocolError):
73+
protocol = self._make_proto_with_header(header)
74+
with self.assertRaisesRegex(ProtocolError, "not longer than standard OP_COMPRESSED"):
7575
protocol.process_header()
7676

7777
async def test_non_compressed_length_too_small_raises(self):
7878
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
79-
protocol = await self._make_proto_with_header(header)
80-
with self.assertRaises(ProtocolError):
79+
protocol = self._make_proto_with_header(header)
80+
with self.assertRaisesRegex(ProtocolError, "not longer than standard message header size"):
8181
protocol.process_header()
8282

8383
async def test_length_exceeds_max_raises(self):
8484
header = _make_header(
8585
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
8686
)
87-
protocol = await self._make_proto_with_header(header)
88-
with self.assertRaises(ProtocolError):
87+
protocol = self._make_proto_with_header(header)
88+
with self.assertRaisesRegex(ProtocolError, "larger than server max"):
8989
protocol.process_header()
9090

91-
async def test_op_reply_op_code(self):
92-
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
93-
protocol = await self._make_proto_with_header(header)
94-
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
95-
self.assertEqual(body_len, 4)
96-
self.assertEqual(op_code, 1)
97-
self.assertFalse(expecting_compression)
98-
9991
async def test_compression_header_snappy_compressor_id(self):
100-
protocol = await _make_protocol()
92+
protocol = _make_protocol()
10193
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
10294
data = struct.pack("<iiB", 2013, 0, 1)
10395
protocol._compression_header = memoryview(bytearray(data))
10496
op_code, compressor_id = protocol.process_compression_header()
10597
self.assertEqual(op_code, 2013)
10698
self.assertEqual(compressor_id, 1)
10799

108-
async def test_compression_header_zlib_compressor_id(self):
109-
protocol = await _make_protocol()
110-
data = struct.pack("<iiB", 2013, 0, 2)
111-
protocol._compression_header = memoryview(bytearray(data))
112-
_, compressor_id = protocol.process_compression_header()
113-
self.assertEqual(compressor_id, 2)
114-
115100
async def test_close_aborts_transport(self):
116-
protocol = await _make_protocol()
101+
protocol = _make_protocol()
117102
protocol.close()
118103
self.assertTrue(protocol.transport.abort.called)
119104

120-
async def test_connection_lost_twice_does_not_raise(self):
121-
protocol = await _make_protocol()
122-
protocol.connection_lost(None)
123-
protocol.connection_lost(None)
124-
125105
async def test_close_with_exception_propagates_to_pending(self):
126-
protocol = await _make_protocol()
106+
protocol = _make_protocol()
127107
future = asyncio.get_running_loop().create_future()
128108
protocol._pending_messages.append(future)
129109
exc = OSError("connection reset")

test/test_network_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_request_id_mismatch_raises(self):
6262
self._patch_receive_data(
6363
_make_header(length=32, request_id=0, response_to=99, op_code=2013)
6464
)
65-
with self.assertRaises(ProtocolError):
65+
with self.assertRaisesRegex(ProtocolError, "Got response id"):
6666
network_layer.receive_message(_make_conn(), request_id=1)
6767

6868
def test_length_too_small_raises(self):

0 commit comments

Comments
 (0)