Skip to content

Commit 39a360c

Browse files
committed
Noah + Copilot review
1 parent d64fc74 commit 39a360c

4 files changed

Lines changed: 224 additions & 559 deletions

File tree

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright 2026-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Async-only unit tests for network_layer.py."""
16+
17+
from __future__ import annotations
18+
19+
import asyncio
20+
import struct
21+
import sys
22+
from unittest.mock import AsyncMock, MagicMock, patch
23+
24+
sys.path[0:0] = [""]
25+
26+
from test.asynchronous import AsyncUnitTest, unittest
27+
28+
from pymongo.common import MAX_MESSAGE_SIZE
29+
from pymongo.errors import ProtocolError
30+
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive
31+
32+
33+
async def _make_protocol(timeout=None):
34+
protocol = PyMongoProtocol(timeout=timeout)
35+
mock_transport = MagicMock()
36+
mock_transport.is_closing.return_value = False
37+
protocol.transport = mock_transport
38+
return protocol
39+
40+
41+
def _make_header(length, request_id, response_to, op_code):
42+
return struct.pack("<iiii", length, request_id, response_to, op_code)
43+
44+
45+
class TestPyMongoProtocol(AsyncUnitTest):
46+
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
47+
protocol = await _make_protocol()
48+
protocol._max_message_size = max_size
49+
protocol._header = memoryview(bytearray(header_bytes))
50+
return protocol
51+
52+
async def test_initial_timeout_from_constructor(self):
53+
protocol = await _make_protocol(timeout=3.0)
54+
self.assertEqual(protocol.gettimeout, 3.0)
55+
56+
async def test_settimeout_updates_value(self):
57+
protocol = await _make_protocol()
58+
protocol.settimeout(7.5)
59+
self.assertEqual(protocol.gettimeout, 7.5)
60+
61+
async def test_default_timeout_is_none(self):
62+
protocol = await _make_protocol()
63+
self.assertIsNone(protocol.gettimeout)
64+
65+
async def test_normal_op_msg(self):
66+
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
67+
protocol = await self._make_proto_with_header(header)
68+
body_len, op_code, response_to, expecting_compression = protocol.process_header()
69+
self.assertEqual(body_len, 16)
70+
self.assertEqual(op_code, 2013)
71+
self.assertEqual(response_to, 99)
72+
self.assertFalse(expecting_compression)
73+
74+
async def test_op_compressed(self):
75+
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
76+
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
77+
# length=35 → after compression sub-header: 26 → body: 10
78+
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
79+
protocol = await self._make_proto_with_header(header)
80+
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
81+
self.assertEqual(body_len, 10)
82+
self.assertEqual(op_code, 2012)
83+
self.assertTrue(expecting_compression)
84+
85+
async def test_op_compressed_length_too_small_raises(self):
86+
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
87+
protocol = await self._make_proto_with_header(header)
88+
with self.assertRaises(ProtocolError):
89+
protocol.process_header()
90+
91+
async def test_non_compressed_length_too_small_raises(self):
92+
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
93+
protocol = await self._make_proto_with_header(header)
94+
with self.assertRaises(ProtocolError):
95+
protocol.process_header()
96+
97+
async def test_length_exceeds_max_raises(self):
98+
header = _make_header(
99+
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
100+
)
101+
protocol = await self._make_proto_with_header(header)
102+
with self.assertRaises(ProtocolError):
103+
protocol.process_header()
104+
105+
async def test_op_reply_op_code(self):
106+
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
107+
protocol = await self._make_proto_with_header(header)
108+
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
109+
self.assertEqual(body_len, 4)
110+
self.assertEqual(op_code, 1)
111+
self.assertFalse(expecting_compression)
112+
113+
async def test_compression_header_snappy_compressor_id(self):
114+
protocol = await _make_protocol()
115+
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
116+
data = struct.pack("<iiB", 2013, 0, 1)
117+
protocol._compression_header = memoryview(bytearray(data))
118+
op_code, compressor_id = protocol.process_compression_header()
119+
self.assertEqual(op_code, 2013)
120+
self.assertEqual(compressor_id, 1)
121+
122+
async def test_compression_header_zlib_compressor_id(self):
123+
protocol = await _make_protocol()
124+
data = struct.pack("<iiB", 2013, 0, 2)
125+
protocol._compression_header = memoryview(bytearray(data))
126+
_, compressor_id = protocol.process_compression_header()
127+
self.assertEqual(compressor_id, 2)
128+
129+
async def test_message_complete_resolves_pending_future(self):
130+
protocol = await _make_protocol()
131+
protocol._expecting_header = False
132+
protocol._expecting_compression = False
133+
protocol._message_size = 10
134+
protocol._message = memoryview(bytearray(10))
135+
protocol._message_index = 0
136+
protocol._op_code = 2013
137+
protocol._compressor_id = None
138+
protocol._response_to = 42
139+
140+
future = asyncio.get_running_loop().create_future()
141+
protocol._pending_messages.append(future)
142+
143+
protocol.buffer_updated(10)
144+
self.assertTrue(future.done())
145+
op_code, compressor_id, response_to, _ = future.result()
146+
self.assertEqual(op_code, 2013)
147+
self.assertIsNone(compressor_id)
148+
self.assertEqual(response_to, 42)
149+
150+
async def test_close_aborts_transport(self):
151+
protocol = await _make_protocol()
152+
protocol.close()
153+
self.assertTrue(protocol.transport.abort.called)
154+
155+
async def test_connection_lost_twice_does_not_raise(self):
156+
protocol = await _make_protocol()
157+
protocol.connection_lost(None)
158+
protocol.connection_lost(None)
159+
160+
async def test_close_with_exception_propagates_to_pending(self):
161+
protocol = await _make_protocol()
162+
future = asyncio.get_running_loop().create_future()
163+
protocol._pending_messages.append(future)
164+
exc = OSError("connection reset")
165+
protocol.close(exc)
166+
with self.assertRaisesRegex(OSError, "connection reset"):
167+
await future
168+
169+
170+
class TestAsyncSocketReceive(AsyncUnitTest):
171+
async def test_reads_data_in_multiple_chunks(self):
172+
# Covers the loop in _async_socket_receive that accumulates short reads
173+
# until the requested length has been received.
174+
data = b"abcdefgh"
175+
length = len(data)
176+
chunk1, chunk2 = data[:4], data[4:]
177+
mock_socket = MagicMock()
178+
loop = asyncio.get_running_loop()
179+
calls = 0
180+
181+
async def fake_recv_into(sock, buf):
182+
nonlocal calls
183+
if calls == 0:
184+
buf[: len(chunk1)] = chunk1
185+
calls += 1
186+
return len(chunk1)
187+
buf[: len(chunk2)] = chunk2
188+
calls += 1
189+
return len(chunk2)
190+
191+
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
192+
result = await _async_socket_receive(mock_socket, length, loop)
193+
self.assertEqual(bytes(result), data)
194+
self.assertEqual(calls, 2)
195+
196+
async def test_raises_on_connection_closed(self):
197+
# Covers the explicit `raise OSError("connection closed")` branch when
198+
# sock_recv_into returns 0.
199+
mock_socket = MagicMock()
200+
loop = asyncio.get_running_loop()
201+
202+
async def fake_recv_into(sock, buf):
203+
return 0
204+
205+
with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
206+
with self.assertRaisesRegex(OSError, "connection closed"):
207+
await _async_socket_receive(mock_socket, 10, loop)
208+
209+
210+
if __name__ == "__main__":
211+
unittest.main()

0 commit comments

Comments
 (0)