|
18 | 18 | from threading import Lock |
19 | 19 | from unittest.mock import Mock, ANY, call, patch |
20 | 20 |
|
21 | | -from cassandra import OperationTimedOut |
| 21 | +from cassandra import ConsistencyLevel, OperationTimedOut |
22 | 22 | from cassandra.cluster import Cluster |
23 | 23 | from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, |
24 | 24 | locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, |
25 | 25 | ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) |
26 | 26 | from cassandra.marshal import uint8_pack, uint32_pack, int32_pack |
27 | 27 | from cassandra.protocol import (write_stringmultimap, write_int, write_string, |
28 | | - SupportedMessage, ProtocolHandler, ResultMessage, |
| 28 | + SupportedMessage, ProtocolHandler, ResultMessage, QueryMessage, |
29 | 29 | RESULT_KIND_SET_KEYSPACE) |
30 | 30 |
|
31 | 31 | from tests.util import wait_until, assertRegex |
@@ -363,6 +363,32 @@ def test_wait_for_responses_shutdown_includes_last_error(self): |
363 | 363 | assert "already closed" in error_message |
364 | 364 | assert "Bad file descriptor" in error_message |
365 | 365 |
|
| 366 | + def test_wait_for_responses_releases_request_id_when_send_fails(self): |
| 367 | + c = self.make_connection() |
| 368 | + c._socket_writable = False |
| 369 | + initial_in_flight = c.in_flight |
| 370 | + initial_request_ids = len(c.request_ids) |
| 371 | + |
| 372 | + with pytest.raises(ConnectionBusy): |
| 373 | + c.wait_for_responses(Mock()) |
| 374 | + |
| 375 | + assert c.in_flight == initial_in_flight |
| 376 | + assert len(c.request_ids) == initial_request_ids |
| 377 | + assert not c._requests |
| 378 | + |
| 379 | + def test_wait_for_responses_releases_request_id_when_send_raises_after_registration(self): |
| 380 | + c = self.make_connection() |
| 381 | + c.push = Mock(side_effect=ConnectionException("write failed")) |
| 382 | + initial_in_flight = c.in_flight |
| 383 | + initial_request_ids = len(c.request_ids) |
| 384 | + |
| 385 | + with pytest.raises(ConnectionException): |
| 386 | + c.wait_for_responses(QueryMessage("SELECT * FROM system.local", ConsistencyLevel.ONE)) |
| 387 | + |
| 388 | + assert c.in_flight == initial_in_flight |
| 389 | + assert len(c.request_ids) == initial_request_ids |
| 390 | + assert not c._requests |
| 391 | + |
366 | 392 |
|
367 | 393 | @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') |
368 | 394 | class ConnectionHeartbeatTest(unittest.TestCase): |
|
0 commit comments