|
22 | 22 | from cassandra.cluster import Cluster |
23 | 23 | from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, |
24 | 24 | locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, |
25 | | - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) |
| 25 | + ConnectionException, ConnectionShutdown, ConnectionBusy, DefaultEndPoint, ShardAwarePortGenerator) |
26 | 26 | from cassandra.marshal import uint8_pack, uint32_pack, int32_pack |
27 | 27 | from cassandra.protocol import (write_stringmultimap, write_int, write_string, |
28 | 28 | SupportedMessage, ProtocolHandler, ResultMessage, QueryMessage, |
@@ -389,6 +389,26 @@ def test_wait_for_responses_releases_request_id_when_send_raises_after_registrat |
389 | 389 | assert len(c.request_ids) == initial_request_ids |
390 | 390 | assert not c._requests |
391 | 391 |
|
| 392 | + def test_set_keyspace_async_reports_send_failure_and_releases_request_id(self): |
| 393 | + c = self.make_connection() |
| 394 | + c.push = Mock(side_effect=ConnectionException("write failed")) |
| 395 | + initial_in_flight = c.in_flight |
| 396 | + initial_request_ids = len(c.request_ids) |
| 397 | + callback_errors = [] |
| 398 | + |
| 399 | + def callback(conn, error): |
| 400 | + callback_errors.append(error) |
| 401 | + with conn.lock: |
| 402 | + conn.in_flight -= 1 |
| 403 | + |
| 404 | + c.set_keyspace_async("ks", callback) |
| 405 | + |
| 406 | + assert len(callback_errors) == 1 |
| 407 | + assert isinstance(callback_errors[0], ConnectionException) |
| 408 | + assert c.in_flight == initial_in_flight |
| 409 | + assert len(c.request_ids) == initial_request_ids |
| 410 | + assert not c._requests |
| 411 | + |
392 | 412 |
|
393 | 413 | @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') |
394 | 414 | class ConnectionHeartbeatTest(unittest.TestCase): |
|
0 commit comments