Skip to content

Commit 41794db

Browse files
committed
Revert "Noah feedback"
This reverts commit 4e02b3f.
1 parent e8df6d8 commit 41794db

2 files changed

Lines changed: 93 additions & 26 deletions

File tree

test/asynchronous/test_async_network_layer.py

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@
3030
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive
3131

3232

33-
def _make_protocol(timeout=None):
34-
# PyMongoProtocol.__init__ calls asyncio.get_running_loop(), so this helper
35-
# must be called from inside an async test method.
33+
async def _make_protocol(timeout=None):
3634
protocol = PyMongoProtocol(timeout=timeout)
3735
mock_transport = MagicMock()
3836
mock_transport.is_closing.return_value = False
@@ -45,88 +43,109 @@ def _make_header(length, request_id, response_to, op_code):
4543

4644

4745
class TestPyMongoProtocol(AsyncUnitTest):
48-
def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
49-
protocol = _make_protocol()
46+
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
47+
protocol = await _make_protocol()
5048
protocol._max_message_size = max_size
5149
protocol._header = memoryview(bytearray(header_bytes))
5250
return protocol
5351

54-
def test_normal_op_msg(self):
52+
async def test_normal_op_msg(self):
5553
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
56-
protocol = self._make_proto_with_header(header)
54+
protocol = await self._make_proto_with_header(header)
5755
body_len, op_code, response_to, expecting_compression = protocol.process_header()
5856
self.assertEqual(body_len, 16)
5957
self.assertEqual(op_code, 2013)
6058
self.assertEqual(response_to, 99)
6159
self.assertFalse(expecting_compression)
6260

63-
def test_op_compressed(self):
61+
async def test_op_compressed(self):
6462
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
6563
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
6664
# length=35 → after compression sub-header: 26 → body: 10
6765
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
68-
protocol = self._make_proto_with_header(header)
66+
protocol = await self._make_proto_with_header(header)
6967
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
7068
self.assertEqual(body_len, 10)
7169
self.assertEqual(op_code, 2012)
7270
self.assertTrue(expecting_compression)
7371

74-
def test_op_compressed_length_too_small_raises(self):
72+
async def test_op_compressed_length_too_small_raises(self):
7573
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
76-
protocol = self._make_proto_with_header(header)
74+
protocol = await self._make_proto_with_header(header)
7775
with self.assertRaises(ProtocolError):
7876
protocol.process_header()
7977

80-
def test_non_compressed_length_too_small_raises(self):
78+
async def test_non_compressed_length_too_small_raises(self):
8179
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
82-
protocol = self._make_proto_with_header(header)
80+
protocol = await self._make_proto_with_header(header)
8381
with self.assertRaises(ProtocolError):
8482
protocol.process_header()
8583

86-
def test_length_exceeds_max_raises(self):
84+
async def test_length_exceeds_max_raises(self):
8785
header = _make_header(
8886
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
8987
)
90-
protocol = self._make_proto_with_header(header)
88+
protocol = await self._make_proto_with_header(header)
9189
with self.assertRaises(ProtocolError):
9290
protocol.process_header()
9391

94-
def test_op_reply_op_code(self):
92+
async def test_op_reply_op_code(self):
9593
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
96-
protocol = self._make_proto_with_header(header)
94+
protocol = await self._make_proto_with_header(header)
9795
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
9896
self.assertEqual(body_len, 4)
9997
self.assertEqual(op_code, 1)
10098
self.assertFalse(expecting_compression)
10199

102-
def test_compression_header_snappy_compressor_id(self):
103-
protocol = _make_protocol()
100+
async def test_compression_header_snappy_compressor_id(self):
101+
protocol = await _make_protocol()
104102
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
105103
data = struct.pack("<iiB", 2013, 0, 1)
106104
protocol._compression_header = memoryview(bytearray(data))
107105
op_code, compressor_id = protocol.process_compression_header()
108106
self.assertEqual(op_code, 2013)
109107
self.assertEqual(compressor_id, 1)
110108

111-
def test_compression_header_zlib_compressor_id(self):
112-
protocol = _make_protocol()
109+
async def test_compression_header_zlib_compressor_id(self):
110+
protocol = await _make_protocol()
113111
data = struct.pack("<iiB", 2013, 0, 2)
114112
protocol._compression_header = memoryview(bytearray(data))
115113
_, compressor_id = protocol.process_compression_header()
116114
self.assertEqual(compressor_id, 2)
117115

118-
def test_close_aborts_transport(self):
119-
protocol = _make_protocol()
116+
async def test_message_complete_resolves_pending_future(self):
117+
protocol = await _make_protocol()
118+
protocol._expecting_header = False
119+
protocol._expecting_compression = False
120+
protocol._message_size = 10
121+
protocol._message = memoryview(bytearray(10))
122+
protocol._message_index = 0
123+
protocol._op_code = 2013
124+
protocol._compressor_id = None
125+
protocol._response_to = 42
126+
127+
future = asyncio.get_running_loop().create_future()
128+
protocol._pending_messages.append(future)
129+
130+
protocol.buffer_updated(10)
131+
self.assertTrue(future.done())
132+
op_code, compressor_id, response_to, _ = future.result()
133+
self.assertEqual(op_code, 2013)
134+
self.assertIsNone(compressor_id)
135+
self.assertEqual(response_to, 42)
136+
137+
async def test_close_aborts_transport(self):
138+
protocol = await _make_protocol()
120139
protocol.close()
121140
self.assertTrue(protocol.transport.abort.called)
122141

123-
def test_connection_lost_twice_does_not_raise(self):
124-
protocol = _make_protocol()
142+
async def test_connection_lost_twice_does_not_raise(self):
143+
protocol = await _make_protocol()
125144
protocol.connection_lost(None)
126145
protocol.connection_lost(None)
127146

128147
async def test_close_with_exception_propagates_to_pending(self):
129-
protocol = _make_protocol()
148+
protocol = await _make_protocol()
130149
future = asyncio.get_running_loop().create_future()
131150
protocol._pending_messages.append(future)
132151
exc = OSError("connection reset")
@@ -136,6 +155,31 @@ async def test_close_with_exception_propagates_to_pending(self):
136155

137156

138157
class TestAsyncSocketReceive(AsyncUnitTest):
158+
async def test_reads_data_in_multiple_chunks(self):
159+
# Covers the loop in _async_socket_receive that accumulates short reads
160+
# until the requested length has been received.
161+
data = b"abcdefgh"
162+
length = len(data)
163+
chunk1, chunk2 = data[:4], data[4:]
164+
mock_socket = MagicMock()
165+
loop = asyncio.get_running_loop()
166+
calls = 0
167+
168+
async def fake_recv_into(sock, buf):
169+
nonlocal calls
170+
if calls == 0:
171+
buf[: len(chunk1)] = chunk1
172+
calls += 1
173+
return len(chunk1)
174+
buf[: len(chunk2)] = chunk2
175+
calls += 1
176+
return len(chunk2)
177+
178+
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
179+
result = await _async_socket_receive(mock_socket, length, loop)
180+
self.assertEqual(bytes(result), data)
181+
self.assertEqual(calls, 2)
182+
139183
async def test_raises_on_connection_closed(self):
140184
# Covers the explicit `raise OSError("connection closed")` branch when
141185
# sock_recv_into returns 0.

test/test_network_layer.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,29 @@ def test_unknown_opcode_raises(self):
117117

118118

119119
class TestReceiveData(UnitTest):
120+
def test_reads_data_in_multiple_chunks(self):
121+
# Covers the loop in receive_data that accumulates short reads until the
122+
# requested length has been received.
123+
data = b"abcdefgh"
124+
chunk1, chunk2 = data[:4], data[4:]
125+
conn = _make_conn()
126+
calls = 0
127+
128+
def fake_recv_into(buf):
129+
nonlocal calls
130+
if calls == 0:
131+
buf[: len(chunk1)] = chunk1
132+
calls += 1
133+
return len(chunk1)
134+
buf[: len(chunk2)] = chunk2
135+
calls += 1
136+
return len(chunk2)
137+
138+
conn.conn.recv_into.side_effect = fake_recv_into
139+
result = network_layer.receive_data(conn, len(data), deadline=None)
140+
self.assertEqual(bytes(result), data)
141+
self.assertEqual(calls, 2)
142+
120143
def test_raises_on_connection_closed(self):
121144
# Covers the explicit `raise OSError("connection closed")` branch when
122145
# recv_into returns 0.

0 commit comments

Comments
 (0)