2929from test .asynchronous import AsyncUnitTest , unittest
3030
3131
32- async def _make_protocol (timeout = None ):
32+ def _make_protocol (timeout = None ):
3333 protocol = PyMongoProtocol (timeout = timeout )
3434 mock_transport = MagicMock ()
3535 mock_transport .is_closing .return_value = False
@@ -42,15 +42,15 @@ def _make_header(length, request_id, response_to, op_code):
4242
4343
4444class TestPyMongoProtocol (AsyncUnitTest ):
45- async def _make_proto_with_header (self , header_bytes , max_size = MAX_MESSAGE_SIZE ):
46- protocol = await _make_protocol ()
45+ def _make_proto_with_header (self , header_bytes , max_size = MAX_MESSAGE_SIZE ):
46+ protocol = _make_protocol ()
4747 protocol ._max_message_size = max_size
4848 protocol ._header = memoryview (bytearray (header_bytes ))
4949 return protocol
5050
5151 async def test_normal_op_msg (self ):
5252 header = _make_header (length = 32 , request_id = 1 , response_to = 99 , op_code = 2013 )
53- protocol = await self ._make_proto_with_header (header )
53+ protocol = self ._make_proto_with_header (header )
5454 body_len , op_code , response_to , expecting_compression = protocol .process_header ()
5555 self .assertEqual (body_len , 16 )
5656 self .assertEqual (op_code , 2013 )
@@ -62,68 +62,48 @@ async def test_op_compressed(self):
6262 # (op code + uncompressed size + compressor id), then the 16-byte standard header.
6363 # length=35 → after compression sub-header: 26 → body: 10
6464 header = _make_header (length = 35 , request_id = 1 , response_to = 0 , op_code = 2012 )
65- protocol = await self ._make_proto_with_header (header )
65+ protocol = self ._make_proto_with_header (header )
6666 body_len , op_code , _response_to , expecting_compression = protocol .process_header ()
6767 self .assertEqual (body_len , 10 )
6868 self .assertEqual (op_code , 2012 )
6969 self .assertTrue (expecting_compression )
7070
7171 async def test_op_compressed_length_too_small_raises (self ):
7272 header = _make_header (length = 25 , request_id = 1 , response_to = 0 , op_code = 2012 )
73- protocol = await self ._make_proto_with_header (header )
74- with self .assertRaises (ProtocolError ):
73+ protocol = self ._make_proto_with_header (header )
74+ with self .assertRaisesRegex (ProtocolError , "not longer than standard OP_COMPRESSED" ):
7575 protocol .process_header ()
7676
7777 async def test_non_compressed_length_too_small_raises (self ):
7878 header = _make_header (length = 16 , request_id = 1 , response_to = 0 , op_code = 2013 )
79- protocol = await self ._make_proto_with_header (header )
80- with self .assertRaises (ProtocolError ):
79+ protocol = self ._make_proto_with_header (header )
80+ with self .assertRaisesRegex (ProtocolError , "not longer than standard message header size" ):
8181 protocol .process_header ()
8282
8383 async def test_length_exceeds_max_raises (self ):
8484 header = _make_header (
8585 length = MAX_MESSAGE_SIZE + 1 , request_id = 1 , response_to = 0 , op_code = 2013
8686 )
87- protocol = await self ._make_proto_with_header (header )
88- with self .assertRaises (ProtocolError ):
87+ protocol = self ._make_proto_with_header (header )
88+ with self .assertRaisesRegex (ProtocolError , "larger than server max" ):
8989 protocol .process_header ()
9090
91- async def test_op_reply_op_code (self ):
92- header = _make_header (length = 20 , request_id = 0 , response_to = 0 , op_code = 1 )
93- protocol = await self ._make_proto_with_header (header )
94- body_len , op_code , _response_to , expecting_compression = protocol .process_header ()
95- self .assertEqual (body_len , 4 )
96- self .assertEqual (op_code , 1 )
97- self .assertFalse (expecting_compression )
98-
9991 async def test_compression_header_snappy_compressor_id (self ):
100- protocol = await _make_protocol ()
92+ protocol = _make_protocol ()
10193 # <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
10294 data = struct .pack ("<iiB" , 2013 , 0 , 1 )
10395 protocol ._compression_header = memoryview (bytearray (data ))
10496 op_code , compressor_id = protocol .process_compression_header ()
10597 self .assertEqual (op_code , 2013 )
10698 self .assertEqual (compressor_id , 1 )
10799
108- async def test_compression_header_zlib_compressor_id (self ):
109- protocol = await _make_protocol ()
110- data = struct .pack ("<iiB" , 2013 , 0 , 2 )
111- protocol ._compression_header = memoryview (bytearray (data ))
112- _ , compressor_id = protocol .process_compression_header ()
113- self .assertEqual (compressor_id , 2 )
114-
115100 async def test_close_aborts_transport (self ):
116- protocol = await _make_protocol ()
101+ protocol = _make_protocol ()
117102 protocol .close ()
118103 self .assertTrue (protocol .transport .abort .called )
119104
120- async def test_connection_lost_twice_does_not_raise (self ):
121- protocol = await _make_protocol ()
122- protocol .connection_lost (None )
123- protocol .connection_lost (None )
124-
125105 async def test_close_with_exception_propagates_to_pending (self ):
126- protocol = await _make_protocol ()
106+ protocol = _make_protocol ()
127107 future = asyncio .get_running_loop ().create_future ()
128108 protocol ._pending_messages .append (future )
129109 exc = OSError ("connection reset" )
0 commit comments