Skip to content

Commit 62e6a6b

Browse files
committed
PYTHON-5781 Drop trivial tests, add sync-only receive_message/receive_data coverage
Remove the NetworkingInterfaceBase NotImplementedError tests and the PyMongoProtocol timeout getter/setter tests (implementation, not behavior), along with the now-empty synchro'd test_network_layer.py pair and its synchro.py registration (including the unused AsyncMock->MagicMock rule). Add a hand-maintained sync-only test/test_network_layer.py covering receive_message and receive_data, mirroring the async-only protocol tests. Neither file goes through synchro: the async (PyMongoProtocol) and sync (receive_message) paths share no implementation to mirror.
1 parent aeef0bf commit 62e6a6b

4 files changed

Lines changed: 120 additions & 114 deletions

File tree

test/asynchronous/test_async_network_layer.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,6 @@ async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE)
4949
protocol._header = memoryview(bytearray(header_bytes))
5050
return protocol
5151

52-
async def test_initial_timeout_from_constructor(self):
53-
protocol = await _make_protocol(timeout=3.0)
54-
self.assertEqual(protocol.gettimeout, 3.0)
55-
56-
async def test_settimeout_updates_value(self):
57-
protocol = await _make_protocol()
58-
protocol.settimeout(7.5)
59-
self.assertEqual(protocol.gettimeout, 7.5)
60-
61-
async def test_default_timeout_is_none(self):
62-
protocol = await _make_protocol()
63-
self.assertIsNone(protocol.gettimeout)
64-
6552
async def test_normal_op_msg(self):
6653
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
6754
protocol = await self._make_proto_with_header(header)

test/asynchronous/test_network_layer.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

test/test_network_layer.py

Lines changed: 120 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,52 +12,137 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Unit tests for code in network_layer.py shared between sync and async APIs.
15+
"""Sync-only unit tests for network_layer.py.
1616
17-
Async-only tests live in ``test_async_network_layer.py``.
17+
These cover ``receive_message`` and ``receive_data``, which only exist on the
18+
synchronous receive path (the async path uses ``PyMongoProtocol`` instead).
19+
The async-only tests live in ``test/asynchronous/test_async_network_layer.py``.
1820
"""
1921

2022
from __future__ import annotations
2123

24+
import struct
2225
import sys
23-
from unittest.mock import MagicMock
26+
from unittest.mock import MagicMock, patch
2427

2528
sys.path[0:0] = [""]
2629

2730
from test import UnitTest, unittest
2831

29-
from pymongo.network_layer import NetworkingInterfaceBase
30-
31-
_IS_SYNC = True
32-
33-
34-
class TestNetworkingInterfaceBase(UnitTest):
35-
def setUp(self):
36-
self.base = NetworkingInterfaceBase(MagicMock())
37-
38-
def test_gettimeout_raises(self):
39-
with self.assertRaises(NotImplementedError):
40-
_ = self.base.gettimeout
41-
42-
def test_settimeout_raises(self):
43-
with self.assertRaises(NotImplementedError):
44-
self.base.settimeout(1.0)
45-
46-
def test_close_raises(self):
47-
with self.assertRaises(NotImplementedError):
48-
self.base.close()
49-
50-
def test_is_closing_raises(self):
51-
with self.assertRaises(NotImplementedError):
52-
self.base.is_closing()
53-
54-
def test_get_conn_raises(self):
55-
with self.assertRaises(NotImplementedError):
56-
_ = self.base.get_conn
57-
58-
def test_sock_raises(self):
59-
with self.assertRaises(NotImplementedError):
60-
_ = self.base.sock
32+
from pymongo import network_layer
33+
from pymongo.common import MAX_MESSAGE_SIZE
34+
from pymongo.errors import ProtocolError
35+
36+
37+
def _make_header(length, request_id, response_to, op_code):
38+
return struct.pack("<iiii", length, request_id, response_to, op_code)
39+
40+
41+
def _make_compression_header(op_code, uncompressed_size, compressor_id):
42+
return struct.pack("<iiB", op_code, uncompressed_size, compressor_id)
43+
44+
45+
def _make_conn():
46+
conn = MagicMock()
47+
conn.conn.gettimeout.return_value = None
48+
return conn
49+
50+
51+
class TestReceiveMessage(UnitTest):
52+
def _patch_receive_data(self, *chunks):
53+
"""Make receive_data return the given byte strings on successive calls."""
54+
mock = patch.object(network_layer, "receive_data", side_effect=list(chunks))
55+
self.addCleanup(mock.stop)
56+
return mock.start()
57+
58+
def test_request_id_mismatch_raises(self):
59+
self._patch_receive_data(
60+
_make_header(length=32, request_id=0, response_to=99, op_code=2013)
61+
)
62+
with self.assertRaises(ProtocolError):
63+
network_layer.receive_message(_make_conn(), request_id=1)
64+
65+
def test_length_too_small_raises(self):
66+
self._patch_receive_data(_make_header(length=16, request_id=0, response_to=0, op_code=2013))
67+
with self.assertRaisesRegex(ProtocolError, "not longer than standard message header"):
68+
network_layer.receive_message(_make_conn(), request_id=None)
69+
70+
def test_length_exceeds_max_raises(self):
71+
self._patch_receive_data(
72+
_make_header(length=MAX_MESSAGE_SIZE + 1, request_id=0, response_to=0, op_code=2013)
73+
)
74+
with self.assertRaisesRegex(ProtocolError, "larger than server max"):
75+
network_layer.receive_message(_make_conn(), request_id=None)
76+
77+
def test_normal_op_msg_unpacks(self):
78+
body = b"x" * 16
79+
self._patch_receive_data(
80+
_make_header(length=32, request_id=0, response_to=0, op_code=2013), body
81+
)
82+
unpack = MagicMock(return_value="REPLY")
83+
with patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}):
84+
result = network_layer.receive_message(_make_conn(), request_id=None)
85+
unpack.assert_called_once_with(body)
86+
self.assertEqual(result, "REPLY")
87+
88+
def test_op_compressed_decompresses(self):
89+
# length=35 -> body length = 35 - 25 = 10 (header 16 + compression sub-header 9).
90+
compressed_body = b"y" * 10
91+
self._patch_receive_data(
92+
_make_header(length=35, request_id=0, response_to=0, op_code=2012),
93+
_make_compression_header(op_code=2013, uncompressed_size=0, compressor_id=1),
94+
compressed_body,
95+
)
96+
unpack = MagicMock(return_value="REPLY")
97+
with (
98+
patch.object(network_layer, "decompress", return_value=b"decompressed") as decompress,
99+
patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}),
100+
):
101+
result = network_layer.receive_message(_make_conn(), request_id=None)
102+
decompress.assert_called_once_with(compressed_body, 1)
103+
unpack.assert_called_once_with(b"decompressed")
104+
self.assertEqual(result, "REPLY")
105+
106+
def test_unknown_opcode_raises(self):
107+
self._patch_receive_data(
108+
_make_header(length=20, request_id=0, response_to=0, op_code=9999), b"data"
109+
)
110+
with patch.object(network_layer, "_UNPACK_REPLY", {2013: MagicMock()}):
111+
with self.assertRaises(ProtocolError):
112+
network_layer.receive_message(_make_conn(), request_id=None)
113+
114+
115+
class TestReceiveData(UnitTest):
116+
def test_reads_data_in_multiple_chunks(self):
117+
# Covers the loop in receive_data that accumulates short reads until the
118+
# requested length has been received.
119+
data = b"abcdefgh"
120+
chunk1, chunk2 = data[:4], data[4:]
121+
conn = _make_conn()
122+
calls = 0
123+
124+
def fake_recv_into(buf):
125+
nonlocal calls
126+
if calls == 0:
127+
buf[: len(chunk1)] = chunk1
128+
calls += 1
129+
return len(chunk1)
130+
buf[: len(chunk2)] = chunk2
131+
calls += 1
132+
return len(chunk2)
133+
134+
conn.conn.recv_into.side_effect = fake_recv_into
135+
result = network_layer.receive_data(conn, len(data), deadline=None)
136+
self.assertEqual(bytes(result), data)
137+
self.assertEqual(calls, 2)
138+
139+
def test_raises_on_connection_closed(self):
140+
# Covers the explicit `raise OSError("connection closed")` branch when
141+
# recv_into returns 0.
142+
conn = _make_conn()
143+
conn.conn.recv_into.return_value = 0
144+
with self.assertRaisesRegex(OSError, "connection closed"):
145+
network_layer.receive_data(conn, 10, deadline=None)
61146

62147

63148
if __name__ == "__main__":

tools/synchro.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@
128128
"SpecRunnerTask": "SpecRunnerThread",
129129
"AsyncMockConnection": "MockConnection",
130130
"AsyncMockPool": "MockPool",
131-
"AsyncMock": "MagicMock",
132131
"StopAsyncIteration": "StopIteration",
133132
"create_async_event": "create_event",
134133
"async_create_barrier": "create_barrier",
@@ -255,7 +254,6 @@ def async_only_test(f: str) -> bool:
255254
"test_monitor.py",
256255
"test_monitoring.py",
257256
"test_mongos_load_balancing.py",
258-
"test_network_layer.py",
259257
"test_on_demand_csfle.py",
260258
"test_periodic_executor.py",
261259
"test_pooling.py",

0 commit comments

Comments
 (0)