Skip to content

Commit 4e02b3f

Browse files
committed
Noah feedback
1 parent fccf737 commit 4e02b3f

2 files changed

Lines changed: 26 additions & 93 deletions

File tree

test/asynchronous/test_async_network_layer.py

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive
3131

3232

33-
async def _make_protocol(timeout=None):
33+
def _make_protocol(timeout=None):
34+
# PyMongoProtocol.__init__ calls asyncio.get_running_loop(), so this helper
35+
# must be called from inside an async test method.
3436
protocol = PyMongoProtocol(timeout=timeout)
3537
mock_transport = MagicMock()
3638
mock_transport.is_closing.return_value = False
@@ -43,109 +45,88 @@ def _make_header(length, request_id, response_to, op_code):
4345

4446

4547
class TestPyMongoProtocol(AsyncUnitTest):
46-
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
47-
protocol = await _make_protocol()
48+
def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
49+
protocol = _make_protocol()
4850
protocol._max_message_size = max_size
4951
protocol._header = memoryview(bytearray(header_bytes))
5052
return protocol
5153

52-
async def test_normal_op_msg(self):
54+
def test_normal_op_msg(self):
5355
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
54-
protocol = await self._make_proto_with_header(header)
56+
protocol = self._make_proto_with_header(header)
5557
body_len, op_code, response_to, expecting_compression = protocol.process_header()
5658
self.assertEqual(body_len, 16)
5759
self.assertEqual(op_code, 2013)
5860
self.assertEqual(response_to, 99)
5961
self.assertFalse(expecting_compression)
6062

61-
async def test_op_compressed(self):
63+
def test_op_compressed(self):
6264
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
6365
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
6466
# length=35 → after compression sub-header: 26 → body: 10
6567
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
66-
protocol = await self._make_proto_with_header(header)
68+
protocol = self._make_proto_with_header(header)
6769
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
6870
self.assertEqual(body_len, 10)
6971
self.assertEqual(op_code, 2012)
7072
self.assertTrue(expecting_compression)
7173

72-
async def test_op_compressed_length_too_small_raises(self):
74+
def test_op_compressed_length_too_small_raises(self):
7375
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
74-
protocol = await self._make_proto_with_header(header)
76+
protocol = self._make_proto_with_header(header)
7577
with self.assertRaises(ProtocolError):
7678
protocol.process_header()
7779

78-
async def test_non_compressed_length_too_small_raises(self):
80+
def test_non_compressed_length_too_small_raises(self):
7981
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
80-
protocol = await self._make_proto_with_header(header)
82+
protocol = self._make_proto_with_header(header)
8183
with self.assertRaises(ProtocolError):
8284
protocol.process_header()
8385

84-
async def test_length_exceeds_max_raises(self):
86+
def test_length_exceeds_max_raises(self):
8587
header = _make_header(
8688
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
8789
)
88-
protocol = await self._make_proto_with_header(header)
90+
protocol = self._make_proto_with_header(header)
8991
with self.assertRaises(ProtocolError):
9092
protocol.process_header()
9193

92-
async def test_op_reply_op_code(self):
94+
def test_op_reply_op_code(self):
9395
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
94-
protocol = await self._make_proto_with_header(header)
96+
protocol = self._make_proto_with_header(header)
9597
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
9698
self.assertEqual(body_len, 4)
9799
self.assertEqual(op_code, 1)
98100
self.assertFalse(expecting_compression)
99101

100-
async def test_compression_header_snappy_compressor_id(self):
101-
protocol = await _make_protocol()
102+
def test_compression_header_snappy_compressor_id(self):
103+
protocol = _make_protocol()
102104
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
103105
data = struct.pack("<iiB", 2013, 0, 1)
104106
protocol._compression_header = memoryview(bytearray(data))
105107
op_code, compressor_id = protocol.process_compression_header()
106108
self.assertEqual(op_code, 2013)
107109
self.assertEqual(compressor_id, 1)
108110

109-
async def test_compression_header_zlib_compressor_id(self):
110-
protocol = await _make_protocol()
111+
def test_compression_header_zlib_compressor_id(self):
112+
protocol = _make_protocol()
111113
data = struct.pack("<iiB", 2013, 0, 2)
112114
protocol._compression_header = memoryview(bytearray(data))
113115
_, compressor_id = protocol.process_compression_header()
114116
self.assertEqual(compressor_id, 2)
115117

116-
async def test_message_complete_resolves_pending_future(self):
117-
protocol = await _make_protocol()
118-
protocol._expecting_header = False
119-
protocol._expecting_compression = False
120-
protocol._message_size = 10
121-
protocol._message = memoryview(bytearray(10))
122-
protocol._message_index = 0
123-
protocol._op_code = 2013
124-
protocol._compressor_id = None
125-
protocol._response_to = 42
126-
127-
future = asyncio.get_running_loop().create_future()
128-
protocol._pending_messages.append(future)
129-
130-
protocol.buffer_updated(10)
131-
self.assertTrue(future.done())
132-
op_code, compressor_id, response_to, _ = future.result()
133-
self.assertEqual(op_code, 2013)
134-
self.assertIsNone(compressor_id)
135-
self.assertEqual(response_to, 42)
136-
137-
async def test_close_aborts_transport(self):
138-
protocol = await _make_protocol()
118+
def test_close_aborts_transport(self):
119+
protocol = _make_protocol()
139120
protocol.close()
140121
self.assertTrue(protocol.transport.abort.called)
141122

142-
async def test_connection_lost_twice_does_not_raise(self):
143-
protocol = await _make_protocol()
123+
def test_connection_lost_twice_does_not_raise(self):
124+
protocol = _make_protocol()
144125
protocol.connection_lost(None)
145126
protocol.connection_lost(None)
146127

147128
async def test_close_with_exception_propagates_to_pending(self):
148-
protocol = await _make_protocol()
129+
protocol = _make_protocol()
149130
future = asyncio.get_running_loop().create_future()
150131
protocol._pending_messages.append(future)
151132
exc = OSError("connection reset")
@@ -155,31 +136,6 @@ async def test_close_with_exception_propagates_to_pending(self):
155136

156137

157138
class TestAsyncSocketReceive(AsyncUnitTest):
158-
async def test_reads_data_in_multiple_chunks(self):
159-
# Covers the loop in _async_socket_receive that accumulates short reads
160-
# until the requested length has been received.
161-
data = b"abcdefgh"
162-
length = len(data)
163-
chunk1, chunk2 = data[:4], data[4:]
164-
mock_socket = MagicMock()
165-
loop = asyncio.get_running_loop()
166-
calls = 0
167-
168-
async def fake_recv_into(sock, buf):
169-
nonlocal calls
170-
if calls == 0:
171-
buf[: len(chunk1)] = chunk1
172-
calls += 1
173-
return len(chunk1)
174-
buf[: len(chunk2)] = chunk2
175-
calls += 1
176-
return len(chunk2)
177-
178-
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
179-
result = await _async_socket_receive(mock_socket, length, loop)
180-
self.assertEqual(bytes(result), data)
181-
self.assertEqual(calls, 2)
182-
183139
async def test_raises_on_connection_closed(self):
184140
# Covers the explicit `raise OSError("connection closed")` branch when
185141
# sock_recv_into returns 0.

test/test_network_layer.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -117,29 +117,6 @@ def test_unknown_opcode_raises(self):
117117

118118

119119
class TestReceiveData(UnitTest):
120-
def test_reads_data_in_multiple_chunks(self):
121-
# Covers the loop in receive_data that accumulates short reads until the
122-
# requested length has been received.
123-
data = b"abcdefgh"
124-
chunk1, chunk2 = data[:4], data[4:]
125-
conn = _make_conn()
126-
calls = 0
127-
128-
def fake_recv_into(buf):
129-
nonlocal calls
130-
if calls == 0:
131-
buf[: len(chunk1)] = chunk1
132-
calls += 1
133-
return len(chunk1)
134-
buf[: len(chunk2)] = chunk2
135-
calls += 1
136-
return len(chunk2)
137-
138-
conn.conn.recv_into.side_effect = fake_recv_into
139-
result = network_layer.receive_data(conn, len(data), deadline=None)
140-
self.assertEqual(bytes(result), data)
141-
self.assertEqual(calls, 2)
142-
143120
def test_raises_on_connection_closed(self):
144121
# Covers the explicit `raise OSError("connection closed")` branch when
145122
# recv_into returns 0.

0 commit comments

Comments
 (0)