3030from pymongo .network_layer import PyMongoProtocol , _async_socket_receive
3131
3232
33- async def _make_protocol (timeout = None ):
33+ def _make_protocol (timeout = None ):
34+ # PyMongoProtocol.__init__ calls asyncio.get_running_loop(), so this helper
35+ # must be called from inside an async test method.
3436 protocol = PyMongoProtocol (timeout = timeout )
3537 mock_transport = MagicMock ()
3638 mock_transport .is_closing .return_value = False
@@ -43,109 +45,88 @@ def _make_header(length, request_id, response_to, op_code):
4345
4446
4547class TestPyMongoProtocol (AsyncUnitTest ):
46- async def _make_proto_with_header (self , header_bytes , max_size = MAX_MESSAGE_SIZE ):
47- protocol = await _make_protocol ()
48+ def _make_proto_with_header (self , header_bytes , max_size = MAX_MESSAGE_SIZE ):
49+ protocol = _make_protocol ()
4850 protocol ._max_message_size = max_size
4951 protocol ._header = memoryview (bytearray (header_bytes ))
5052 return protocol
5153
52- async def test_normal_op_msg (self ):
54+ def test_normal_op_msg (self ):
5355 header = _make_header (length = 32 , request_id = 1 , response_to = 99 , op_code = 2013 )
54- protocol = await self ._make_proto_with_header (header )
56+ protocol = self ._make_proto_with_header (header )
5557 body_len , op_code , response_to , expecting_compression = protocol .process_header ()
5658 self .assertEqual (body_len , 16 )
5759 self .assertEqual (op_code , 2013 )
5860 self .assertEqual (response_to , 99 )
5961 self .assertFalse (expecting_compression )
6062
61- async def test_op_compressed (self ):
63+ def test_op_compressed (self ):
6264 # OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
6365 # (op code + uncompressed size + compressor id), then the 16-byte standard header.
6466 # length=35 → after compression sub-header: 26 → body: 10
6567 header = _make_header (length = 35 , request_id = 1 , response_to = 0 , op_code = 2012 )
66- protocol = await self ._make_proto_with_header (header )
68+ protocol = self ._make_proto_with_header (header )
6769 body_len , op_code , _response_to , expecting_compression = protocol .process_header ()
6870 self .assertEqual (body_len , 10 )
6971 self .assertEqual (op_code , 2012 )
7072 self .assertTrue (expecting_compression )
7173
72- async def test_op_compressed_length_too_small_raises (self ):
74+ def test_op_compressed_length_too_small_raises (self ):
7375 header = _make_header (length = 25 , request_id = 1 , response_to = 0 , op_code = 2012 )
74- protocol = await self ._make_proto_with_header (header )
76+ protocol = self ._make_proto_with_header (header )
7577 with self .assertRaises (ProtocolError ):
7678 protocol .process_header ()
7779
78- async def test_non_compressed_length_too_small_raises (self ):
80+ def test_non_compressed_length_too_small_raises (self ):
7981 header = _make_header (length = 16 , request_id = 1 , response_to = 0 , op_code = 2013 )
80- protocol = await self ._make_proto_with_header (header )
82+ protocol = self ._make_proto_with_header (header )
8183 with self .assertRaises (ProtocolError ):
8284 protocol .process_header ()
8385
84- async def test_length_exceeds_max_raises (self ):
86+ def test_length_exceeds_max_raises (self ):
8587 header = _make_header (
8688 length = MAX_MESSAGE_SIZE + 1 , request_id = 1 , response_to = 0 , op_code = 2013
8789 )
88- protocol = await self ._make_proto_with_header (header )
90+ protocol = self ._make_proto_with_header (header )
8991 with self .assertRaises (ProtocolError ):
9092 protocol .process_header ()
9193
92- async def test_op_reply_op_code (self ):
94+ def test_op_reply_op_code (self ):
9395 header = _make_header (length = 20 , request_id = 0 , response_to = 0 , op_code = 1 )
94- protocol = await self ._make_proto_with_header (header )
96+ protocol = self ._make_proto_with_header (header )
9597 body_len , op_code , _response_to , expecting_compression = protocol .process_header ()
9698 self .assertEqual (body_len , 4 )
9799 self .assertEqual (op_code , 1 )
98100 self .assertFalse (expecting_compression )
99101
100- async def test_compression_header_snappy_compressor_id (self ):
101- protocol = await _make_protocol ()
102+ def test_compression_header_snappy_compressor_id (self ):
103+ protocol = _make_protocol ()
102104 # <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
103105 data = struct .pack ("<iiB" , 2013 , 0 , 1 )
104106 protocol ._compression_header = memoryview (bytearray (data ))
105107 op_code , compressor_id = protocol .process_compression_header ()
106108 self .assertEqual (op_code , 2013 )
107109 self .assertEqual (compressor_id , 1 )
108110
109- async def test_compression_header_zlib_compressor_id (self ):
110- protocol = await _make_protocol ()
111+ def test_compression_header_zlib_compressor_id (self ):
112+ protocol = _make_protocol ()
111113 data = struct .pack ("<iiB" , 2013 , 0 , 2 )
112114 protocol ._compression_header = memoryview (bytearray (data ))
113115 _ , compressor_id = protocol .process_compression_header ()
114116 self .assertEqual (compressor_id , 2 )
115117
116- async def test_message_complete_resolves_pending_future (self ):
117- protocol = await _make_protocol ()
118- protocol ._expecting_header = False
119- protocol ._expecting_compression = False
120- protocol ._message_size = 10
121- protocol ._message = memoryview (bytearray (10 ))
122- protocol ._message_index = 0
123- protocol ._op_code = 2013
124- protocol ._compressor_id = None
125- protocol ._response_to = 42
126-
127- future = asyncio .get_running_loop ().create_future ()
128- protocol ._pending_messages .append (future )
129-
130- protocol .buffer_updated (10 )
131- self .assertTrue (future .done ())
132- op_code , compressor_id , response_to , _ = future .result ()
133- self .assertEqual (op_code , 2013 )
134- self .assertIsNone (compressor_id )
135- self .assertEqual (response_to , 42 )
136-
137- async def test_close_aborts_transport (self ):
138- protocol = await _make_protocol ()
118+ def test_close_aborts_transport (self ):
119+ protocol = _make_protocol ()
139120 protocol .close ()
140121 self .assertTrue (protocol .transport .abort .called )
141122
142- async def test_connection_lost_twice_does_not_raise (self ):
143- protocol = await _make_protocol ()
123+ def test_connection_lost_twice_does_not_raise (self ):
124+ protocol = _make_protocol ()
144125 protocol .connection_lost (None )
145126 protocol .connection_lost (None )
146127
147128 async def test_close_with_exception_propagates_to_pending (self ):
148- protocol = await _make_protocol ()
129+ protocol = _make_protocol ()
149130 future = asyncio .get_running_loop ().create_future ()
150131 protocol ._pending_messages .append (future )
151132 exc = OSError ("connection reset" )
@@ -155,31 +136,6 @@ async def test_close_with_exception_propagates_to_pending(self):
155136
156137
157138class TestAsyncSocketReceive (AsyncUnitTest ):
158- async def test_reads_data_in_multiple_chunks (self ):
159- # Covers the loop in _async_socket_receive that accumulates short reads
160- # until the requested length has been received.
161- data = b"abcdefgh"
162- length = len (data )
163- chunk1 , chunk2 = data [:4 ], data [4 :]
164- mock_socket = MagicMock ()
165- loop = asyncio .get_running_loop ()
166- calls = 0
167-
168- async def fake_recv_into (sock , buf ):
169- nonlocal calls
170- if calls == 0 :
171- buf [: len (chunk1 )] = chunk1
172- calls += 1
173- return len (chunk1 )
174- buf [: len (chunk2 )] = chunk2
175- calls += 1
176- return len (chunk2 )
177-
178- with patch .object (loop , "sock_recv_into" , new = AsyncMock (side_effect = fake_recv_into )):
179- result = await _async_socket_receive (mock_socket , length , loop )
180- self .assertEqual (bytes (result ), data )
181- self .assertEqual (calls , 2 )
182-
183139 async def test_raises_on_connection_closed (self ):
184140 # Covers the explicit `raise OSError("connection closed")` branch when
185141 # sock_recv_into returns 0.
0 commit comments