Skip to content

Commit c4862e7

Browse files
committed
connection: release stream ids after send failures
1 parent 0842348 commit c4862e7

2 files changed

Lines changed: 49 additions & 12 deletions

File tree

cassandra/connection.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,15 +1219,19 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
12191219
# queue the decoder function with the request
12201220
# this allows us to inject custom functions per request to encode, decode messages
12211221
self._requests[request_id] = (cb, decoder, result_metadata)
1222-
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
1223-
allow_beta_protocol_version=self.allow_beta_protocol_version)
1222+
try:
1223+
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
1224+
allow_beta_protocol_version=self.allow_beta_protocol_version)
12241225

1225-
if self._is_checksumming_enabled:
1226-
buffer = io.BytesIO()
1227-
self._segment_codec.encode(buffer, msg)
1228-
msg = buffer.getvalue()
1226+
if self._is_checksumming_enabled:
1227+
buffer = io.BytesIO()
1228+
self._segment_codec.encode(buffer, msg)
1229+
msg = buffer.getvalue()
12291230

1230-
self.push(msg)
1231+
self.push(msg)
1232+
except Exception:
1233+
self._requests.pop(request_id, None)
1234+
raise
12311235
return len(msg)
12321236

12331237
def wait_for_response(self, msg, timeout=None, **kwargs):
@@ -1262,9 +1266,16 @@ def wait_for_responses(self, *msgs, **kwargs):
12621266
self.in_flight += available
12631267

12641268
for i, request_id in enumerate(request_ids):
1265-
self.send_msg(msgs[messages_sent + i],
1266-
request_id,
1267-
partial(waiter.got_response, index=messages_sent + i))
1269+
try:
1270+
self.send_msg(msgs[messages_sent + i],
1271+
request_id,
1272+
partial(waiter.got_response, index=messages_sent + i))
1273+
except Exception:
1274+
unsent_request_ids = request_ids[i:]
1275+
with self.lock:
1276+
self.in_flight -= len(unsent_request_ids)
1277+
self.request_ids.extend(unsent_request_ids)
1278+
raise
12681279
messages_sent += available
12691280

12701281
if messages_sent == len(msgs):

tests/unit/test_connection.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
from threading import Lock
1919
from unittest.mock import Mock, ANY, call, patch
2020

21-
from cassandra import OperationTimedOut
21+
from cassandra import ConsistencyLevel, OperationTimedOut
2222
from cassandra.cluster import Cluster
2323
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
2424
locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager,
2525
ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator)
2626
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
2727
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
28-
SupportedMessage, ProtocolHandler, ResultMessage,
28+
SupportedMessage, ProtocolHandler, ResultMessage, QueryMessage,
2929
RESULT_KIND_SET_KEYSPACE)
3030

3131
from tests.util import wait_until, assertRegex
@@ -363,6 +363,32 @@ def test_wait_for_responses_shutdown_includes_last_error(self):
363363
assert "already closed" in error_message
364364
assert "Bad file descriptor" in error_message
365365

366+
def test_wait_for_responses_releases_request_id_when_send_fails(self):
367+
c = self.make_connection()
368+
c._socket_writable = False
369+
initial_in_flight = c.in_flight
370+
initial_request_ids = len(c.request_ids)
371+
372+
with pytest.raises(ConnectionBusy):
373+
c.wait_for_responses(Mock())
374+
375+
assert c.in_flight == initial_in_flight
376+
assert len(c.request_ids) == initial_request_ids
377+
assert not c._requests
378+
379+
def test_wait_for_responses_releases_request_id_when_send_raises_after_registration(self):
380+
c = self.make_connection()
381+
c.push = Mock(side_effect=ConnectionException("write failed"))
382+
initial_in_flight = c.in_flight
383+
initial_request_ids = len(c.request_ids)
384+
385+
with pytest.raises(ConnectionException):
386+
c.wait_for_responses(QueryMessage("SELECT * FROM system.local", ConsistencyLevel.ONE))
387+
388+
assert c.in_flight == initial_in_flight
389+
assert len(c.request_ids) == initial_request_ids
390+
assert not c._requests
391+
366392

367393
@patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped')
368394
class ConnectionHeartbeatTest(unittest.TestCase):

0 commit comments

Comments
 (0)