3030from pymongo .network_layer import PyMongoProtocol , _async_socket_receive
3131
3232
33- def _make_protocol (timeout = None ):
34- # PyMongoProtocol.__init__ calls asyncio.get_running_loop(), so this helper
35- # must be called from inside an async test method.
33+ async def _make_protocol (timeout = None ):
3634 protocol = PyMongoProtocol (timeout = timeout )
3735 mock_transport = MagicMock ()
3836 mock_transport .is_closing .return_value = False
@@ -45,88 +43,109 @@ def _make_header(length, request_id, response_to, op_code):
4543
4644
4745class TestPyMongoProtocol (AsyncUnitTest ):
48- def _make_proto_with_header (self , header_bytes , max_size = MAX_MESSAGE_SIZE ):
49- protocol = _make_protocol ()
46+ async def _make_proto_with_header (self , header_bytes , max_size = MAX_MESSAGE_SIZE ):
47+ protocol = await _make_protocol ()
5048 protocol ._max_message_size = max_size
5149 protocol ._header = memoryview (bytearray (header_bytes ))
5250 return protocol
5351
54- def test_normal_op_msg (self ):
52+ async def test_normal_op_msg (self ):
5553 header = _make_header (length = 32 , request_id = 1 , response_to = 99 , op_code = 2013 )
56- protocol = self ._make_proto_with_header (header )
54+ protocol = await self ._make_proto_with_header (header )
5755 body_len , op_code , response_to , expecting_compression = protocol .process_header ()
5856 self .assertEqual (body_len , 16 )
5957 self .assertEqual (op_code , 2013 )
6058 self .assertEqual (response_to , 99 )
6159 self .assertFalse (expecting_compression )
6260
63- def test_op_compressed (self ):
61+ async def test_op_compressed (self ):
6462 # OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
6563 # (op code + uncompressed size + compressor id), then the 16-byte standard header.
6664 # length=35 → after compression sub-header: 26 → body: 10
6765 header = _make_header (length = 35 , request_id = 1 , response_to = 0 , op_code = 2012 )
68- protocol = self ._make_proto_with_header (header )
66+ protocol = await self ._make_proto_with_header (header )
6967 body_len , op_code , _response_to , expecting_compression = protocol .process_header ()
7068 self .assertEqual (body_len , 10 )
7169 self .assertEqual (op_code , 2012 )
7270 self .assertTrue (expecting_compression )
7371
74- def test_op_compressed_length_too_small_raises (self ):
72+ async def test_op_compressed_length_too_small_raises (self ):
7573 header = _make_header (length = 25 , request_id = 1 , response_to = 0 , op_code = 2012 )
76- protocol = self ._make_proto_with_header (header )
74+ protocol = await self ._make_proto_with_header (header )
7775 with self .assertRaises (ProtocolError ):
7876 protocol .process_header ()
7977
80- def test_non_compressed_length_too_small_raises (self ):
78+ async def test_non_compressed_length_too_small_raises (self ):
8179 header = _make_header (length = 16 , request_id = 1 , response_to = 0 , op_code = 2013 )
82- protocol = self ._make_proto_with_header (header )
80+ protocol = await self ._make_proto_with_header (header )
8381 with self .assertRaises (ProtocolError ):
8482 protocol .process_header ()
8583
86- def test_length_exceeds_max_raises (self ):
84+ async def test_length_exceeds_max_raises (self ):
8785 header = _make_header (
8886 length = MAX_MESSAGE_SIZE + 1 , request_id = 1 , response_to = 0 , op_code = 2013
8987 )
90- protocol = self ._make_proto_with_header (header )
88+ protocol = await self ._make_proto_with_header (header )
9189 with self .assertRaises (ProtocolError ):
9290 protocol .process_header ()
9391
94- def test_op_reply_op_code (self ):
92+ async def test_op_reply_op_code (self ):
9593 header = _make_header (length = 20 , request_id = 0 , response_to = 0 , op_code = 1 )
96- protocol = self ._make_proto_with_header (header )
94+ protocol = await self ._make_proto_with_header (header )
9795 body_len , op_code , _response_to , expecting_compression = protocol .process_header ()
9896 self .assertEqual (body_len , 4 )
9997 self .assertEqual (op_code , 1 )
10098 self .assertFalse (expecting_compression )
10199
102- def test_compression_header_snappy_compressor_id (self ):
103- protocol = _make_protocol ()
100+ async def test_compression_header_snappy_compressor_id (self ):
101+ protocol = await _make_protocol ()
104102 # <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
105103 data = struct .pack ("<iiB" , 2013 , 0 , 1 )
106104 protocol ._compression_header = memoryview (bytearray (data ))
107105 op_code , compressor_id = protocol .process_compression_header ()
108106 self .assertEqual (op_code , 2013 )
109107 self .assertEqual (compressor_id , 1 )
110108
111- def test_compression_header_zlib_compressor_id (self ):
112- protocol = _make_protocol ()
109+ async def test_compression_header_zlib_compressor_id (self ):
110+ protocol = await _make_protocol ()
113111 data = struct .pack ("<iiB" , 2013 , 0 , 2 )
114112 protocol ._compression_header = memoryview (bytearray (data ))
115113 _ , compressor_id = protocol .process_compression_header ()
116114 self .assertEqual (compressor_id , 2 )
117115
118- def test_close_aborts_transport (self ):
119- protocol = _make_protocol ()
116+ async def test_message_complete_resolves_pending_future (self ):
117+ protocol = await _make_protocol ()
118+ protocol ._expecting_header = False
119+ protocol ._expecting_compression = False
120+ protocol ._message_size = 10
121+ protocol ._message = memoryview (bytearray (10 ))
122+ protocol ._message_index = 0
123+ protocol ._op_code = 2013
124+ protocol ._compressor_id = None
125+ protocol ._response_to = 42
126+
127+ future = asyncio .get_running_loop ().create_future ()
128+ protocol ._pending_messages .append (future )
129+
130+ protocol .buffer_updated (10 )
131+ self .assertTrue (future .done ())
132+ op_code , compressor_id , response_to , _ = future .result ()
133+ self .assertEqual (op_code , 2013 )
134+ self .assertIsNone (compressor_id )
135+ self .assertEqual (response_to , 42 )
136+
137+ async def test_close_aborts_transport (self ):
138+ protocol = await _make_protocol ()
120139 protocol .close ()
121140 self .assertTrue (protocol .transport .abort .called )
122141
123- def test_connection_lost_twice_does_not_raise (self ):
124- protocol = _make_protocol ()
142+ async def test_connection_lost_twice_does_not_raise (self ):
143+ protocol = await _make_protocol ()
125144 protocol .connection_lost (None )
126145 protocol .connection_lost (None )
127146
128147 async def test_close_with_exception_propagates_to_pending (self ):
129- protocol = _make_protocol ()
148+ protocol = await _make_protocol ()
130149 future = asyncio .get_running_loop ().create_future ()
131150 protocol ._pending_messages .append (future )
132151 exc = OSError ("connection reset" )
@@ -136,6 +155,31 @@ async def test_close_with_exception_propagates_to_pending(self):
136155
137156
138157class TestAsyncSocketReceive (AsyncUnitTest ):
158+ async def test_reads_data_in_multiple_chunks (self ):
159+ # Covers the loop in _async_socket_receive that accumulates short reads
160+ # until the requested length has been received.
161+ data = b"abcdefgh"
162+ length = len (data )
163+ chunk1 , chunk2 = data [:4 ], data [4 :]
164+ mock_socket = MagicMock ()
165+ loop = asyncio .get_running_loop ()
166+ calls = 0
167+
168+ async def fake_recv_into (sock , buf ):
169+ nonlocal calls
170+ if calls == 0 :
171+ buf [: len (chunk1 )] = chunk1
172+ calls += 1
173+ return len (chunk1 )
174+ buf [: len (chunk2 )] = chunk2
175+ calls += 1
176+ return len (chunk2 )
177+
178+ with patch .object (loop , "sock_recv_into" , new = AsyncMock (side_effect = fake_recv_into )):
179+ result = await _async_socket_receive (mock_socket , length , loop )
180+ self .assertEqual (bytes (result ), data )
181+ self .assertEqual (calls , 2 )
182+
139183 async def test_raises_on_connection_closed (self ):
140184 # Covers the explicit `raise OSError("connection closed")` branch when
141185 # sock_recv_into returns 0.
0 commit comments