|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -"""Unit tests for code in network_layer.py shared between sync and async APIs. |
| 15 | +"""Sync-only unit tests for network_layer.py. |
16 | 16 |
|
17 | | -Async-only tests live in ``test_async_network_layer.py``. |
| 17 | +These cover ``receive_message`` and ``receive_data``, which only exist on the |
| 18 | +synchronous receive path (the async path uses ``PyMongoProtocol`` instead). |
| 19 | +The async-only tests live in ``test/asynchronous/test_async_network_layer.py``. |
18 | 20 | """ |
19 | 21 |
|
20 | 22 | from __future__ import annotations |
21 | 23 |
|
| 24 | +import struct |
22 | 25 | import sys |
23 | | -from unittest.mock import MagicMock |
| 26 | +from unittest.mock import MagicMock, patch |
24 | 27 |
|
25 | 28 | sys.path[0:0] = [""] |
26 | 29 |
|
27 | 30 | from test import UnitTest, unittest |
28 | 31 |
|
29 | | -from pymongo.network_layer import NetworkingInterfaceBase |
30 | | - |
31 | | -_IS_SYNC = True |
32 | | - |
33 | | - |
34 | | -class TestNetworkingInterfaceBase(UnitTest): |
35 | | - def setUp(self): |
36 | | - self.base = NetworkingInterfaceBase(MagicMock()) |
37 | | - |
38 | | - def test_gettimeout_raises(self): |
39 | | - with self.assertRaises(NotImplementedError): |
40 | | - _ = self.base.gettimeout |
41 | | - |
42 | | - def test_settimeout_raises(self): |
43 | | - with self.assertRaises(NotImplementedError): |
44 | | - self.base.settimeout(1.0) |
45 | | - |
46 | | - def test_close_raises(self): |
47 | | - with self.assertRaises(NotImplementedError): |
48 | | - self.base.close() |
49 | | - |
50 | | - def test_is_closing_raises(self): |
51 | | - with self.assertRaises(NotImplementedError): |
52 | | - self.base.is_closing() |
53 | | - |
54 | | - def test_get_conn_raises(self): |
55 | | - with self.assertRaises(NotImplementedError): |
56 | | - _ = self.base.get_conn |
57 | | - |
58 | | - def test_sock_raises(self): |
59 | | - with self.assertRaises(NotImplementedError): |
60 | | - _ = self.base.sock |
| 32 | +from pymongo import network_layer |
| 33 | +from pymongo.common import MAX_MESSAGE_SIZE |
| 34 | +from pymongo.errors import ProtocolError |
| 35 | + |
| 36 | + |
| 37 | +def _make_header(length, request_id, response_to, op_code): |
| 38 | + return struct.pack("<iiii", length, request_id, response_to, op_code) |
| 39 | + |
| 40 | + |
| 41 | +def _make_compression_header(op_code, uncompressed_size, compressor_id): |
| 42 | + return struct.pack("<iiB", op_code, uncompressed_size, compressor_id) |
| 43 | + |
| 44 | + |
| 45 | +def _make_conn(): |
| 46 | + conn = MagicMock() |
| 47 | + conn.conn.gettimeout.return_value = None |
| 48 | + return conn |
| 49 | + |
| 50 | + |
| 51 | +class TestReceiveMessage(UnitTest): |
| 52 | + def _patch_receive_data(self, *chunks): |
| 53 | + """Make receive_data return the given byte strings on successive calls.""" |
| 54 | + mock = patch.object(network_layer, "receive_data", side_effect=list(chunks)) |
| 55 | + self.addCleanup(mock.stop) |
| 56 | + return mock.start() |
| 57 | + |
| 58 | + def test_request_id_mismatch_raises(self): |
| 59 | + self._patch_receive_data( |
| 60 | + _make_header(length=32, request_id=0, response_to=99, op_code=2013) |
| 61 | + ) |
| 62 | + with self.assertRaises(ProtocolError): |
| 63 | + network_layer.receive_message(_make_conn(), request_id=1) |
| 64 | + |
| 65 | + def test_length_too_small_raises(self): |
| 66 | + self._patch_receive_data(_make_header(length=16, request_id=0, response_to=0, op_code=2013)) |
| 67 | + with self.assertRaisesRegex(ProtocolError, "not longer than standard message header"): |
| 68 | + network_layer.receive_message(_make_conn(), request_id=None) |
| 69 | + |
| 70 | + def test_length_exceeds_max_raises(self): |
| 71 | + self._patch_receive_data( |
| 72 | + _make_header(length=MAX_MESSAGE_SIZE + 1, request_id=0, response_to=0, op_code=2013) |
| 73 | + ) |
| 74 | + with self.assertRaisesRegex(ProtocolError, "larger than server max"): |
| 75 | + network_layer.receive_message(_make_conn(), request_id=None) |
| 76 | + |
| 77 | + def test_normal_op_msg_unpacks(self): |
| 78 | + body = b"x" * 16 |
| 79 | + self._patch_receive_data( |
| 80 | + _make_header(length=32, request_id=0, response_to=0, op_code=2013), body |
| 81 | + ) |
| 82 | + unpack = MagicMock(return_value="REPLY") |
| 83 | + with patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}): |
| 84 | + result = network_layer.receive_message(_make_conn(), request_id=None) |
| 85 | + unpack.assert_called_once_with(body) |
| 86 | + self.assertEqual(result, "REPLY") |
| 87 | + |
| 88 | + def test_op_compressed_decompresses(self): |
| 89 | + # length=35 -> body length = 35 - 25 = 10 (header 16 + compression sub-header 9). |
| 90 | + compressed_body = b"y" * 10 |
| 91 | + self._patch_receive_data( |
| 92 | + _make_header(length=35, request_id=0, response_to=0, op_code=2012), |
| 93 | + _make_compression_header(op_code=2013, uncompressed_size=0, compressor_id=1), |
| 94 | + compressed_body, |
| 95 | + ) |
| 96 | + unpack = MagicMock(return_value="REPLY") |
| 97 | + with ( |
| 98 | + patch.object(network_layer, "decompress", return_value=b"decompressed") as decompress, |
| 99 | + patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}), |
| 100 | + ): |
| 101 | + result = network_layer.receive_message(_make_conn(), request_id=None) |
| 102 | + decompress.assert_called_once_with(compressed_body, 1) |
| 103 | + unpack.assert_called_once_with(b"decompressed") |
| 104 | + self.assertEqual(result, "REPLY") |
| 105 | + |
| 106 | + def test_unknown_opcode_raises(self): |
| 107 | + self._patch_receive_data( |
| 108 | + _make_header(length=20, request_id=0, response_to=0, op_code=9999), b"data" |
| 109 | + ) |
| 110 | + with patch.object(network_layer, "_UNPACK_REPLY", {2013: MagicMock()}): |
| 111 | + with self.assertRaises(ProtocolError): |
| 112 | + network_layer.receive_message(_make_conn(), request_id=None) |
| 113 | + |
| 114 | + |
| 115 | +class TestReceiveData(UnitTest): |
| 116 | + def test_reads_data_in_multiple_chunks(self): |
| 117 | + # Covers the loop in receive_data that accumulates short reads until the |
| 118 | + # requested length has been received. |
| 119 | + data = b"abcdefgh" |
| 120 | + chunk1, chunk2 = data[:4], data[4:] |
| 121 | + conn = _make_conn() |
| 122 | + calls = 0 |
| 123 | + |
| 124 | + def fake_recv_into(buf): |
| 125 | + nonlocal calls |
| 126 | + if calls == 0: |
| 127 | + buf[: len(chunk1)] = chunk1 |
| 128 | + calls += 1 |
| 129 | + return len(chunk1) |
| 130 | + buf[: len(chunk2)] = chunk2 |
| 131 | + calls += 1 |
| 132 | + return len(chunk2) |
| 133 | + |
| 134 | + conn.conn.recv_into.side_effect = fake_recv_into |
| 135 | + result = network_layer.receive_data(conn, len(data), deadline=None) |
| 136 | + self.assertEqual(bytes(result), data) |
| 137 | + self.assertEqual(calls, 2) |
| 138 | + |
| 139 | + def test_raises_on_connection_closed(self): |
| 140 | + # Covers the explicit `raise OSError("connection closed")` branch when |
| 141 | + # recv_into returns 0. |
| 142 | + conn = _make_conn() |
| 143 | + conn.conn.recv_into.return_value = 0 |
| 144 | + with self.assertRaisesRegex(OSError, "connection closed"): |
| 145 | + network_layer.receive_data(conn, 10, deadline=None) |
61 | 146 |
|
62 | 147 |
|
63 | 148 | if __name__ == "__main__": |
|
0 commit comments