Skip to content

Commit 1c417c3

Browse files
committed
cluster: add Session.wait_for_schema_agreement
Add Session.wait_for_schema_agreement() as a session-scoped schema agreement check. The new API queries schema_version from system.local on the connected hosts selected by the requested rack, dc, or cluster scope, respects Cluster.max_schema_agreement_wait and the control-connection metadata timeouts, and queries the selected hosts one at a time against a single shared deadline. Update the public Session docs and switch the integration callers that were explicitly waiting on schema agreement to use the session API. Add unit coverage for agreement, retries, busy connections, missing pools, in-order host querying, scope filtering, and invalid scope handling.
1 parent 0842348 commit 1c417c3

6 files changed

Lines changed: 459 additions & 4 deletions

File tree

cassandra/cluster.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from itertools import groupby, count, chain
3030
import json
3131
import logging
32-
from typing import Any, Dict, Optional, Union
32+
from typing import Any, Dict, Optional, Union, Literal
3333
from warnings import warn
3434
from random import random
3535
import re
@@ -3374,6 +3374,125 @@ def pool_finished_setting_keyspace(pool, host_errors):
33743374
for pool in tuple(self._pools.values()):
33753375
pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace)
33763376

3377+
def wait_for_schema_agreement(self, wait_time=None, scope: Literal['rack', 'dc', 'cluster']='dc'):
    """
    Wait for connected hosts in the selected scope to report the same
    schema version from ``system.local``.

    By default, the timeout for this operation is governed by
    :attr:`~.Cluster.max_schema_agreement_wait` and
    :attr:`~.Cluster.control_connection_timeout`.

    Passing ``wait_time`` here overrides
    :attr:`~.Cluster.max_schema_agreement_wait`. Setting ``wait_time <= 0``
    will bypass schema agreement waits.

    ``scope`` determines which connected hosts participate in the check.
    Accepted values are ``'rack'``, ``'dc'``, and ``'cluster'``. The
    default ``'dc'`` scope queries connected hosts in the local rack and
    local datacenter. ``'rack'`` narrows the check to connected hosts in
    the local rack only. ``'cluster'`` queries every host this session has
    a live connection pool for, across all datacenters.

    A pass in which no connected host matches the scope, or in which any
    selected host cannot be queried, counts as a disagreement and is
    retried (at most every 0.2 seconds) until the timeout expires.

    :param wait_time: Override for
        :attr:`~.Cluster.max_schema_agreement_wait`.
    :param scope: Restricts the check to connected hosts in the local rack,
        local datacenter, or whole connected cluster.
    :returns: ``True`` when the selected connected hosts agree on schema,
        otherwise ``False``.
    :raises ValueError: If ``scope`` is not one of ``'rack'``, ``'dc'``,
        or ``'cluster'``.
    """
    # Validate before the bypass below so a bad scope is always reported.
    if scope not in ('rack', 'dc', 'cluster'):
        raise ValueError("Invalid schema agreement scope: %s" % (scope,))

    total_timeout = wait_time if wait_time is not None else self.cluster.max_schema_agreement_wait
    if total_timeout <= 0:
        return True

    deadline = time.time() + total_timeout
    schema_mismatches = None

    while time.time() < deadline:
        schema_mismatches = self._get_schema_mismatches_for_scope(deadline, scope)
        if schema_mismatches is None:
            return True

        # Name the scope: "local" would be misleading for scope='cluster'.
        log.debug("[session] Schemas mismatched for scope %r, trying again", scope)
        remaining = deadline - time.time()
        if remaining > 0:
            # Never sleep past the deadline.
            time.sleep(min(0.2, remaining))

    log.warning("Hosts in scope %r are reporting a schema disagreement: %s", scope, schema_mismatches)
    return False
3428+
3429+
def _get_schema_mismatches_for_scope(self, deadline, scope: Literal['rack', 'dc', 'cluster']):
3430+
hosts = self._get_schema_agreement_hosts(scope)
3431+
versions = defaultdict(set)
3432+
errors = {}
3433+
3434+
if not hosts:
3435+
return {'unavailable': 'No local hosts available'}
3436+
3437+
metadata_request_timeout = self.cluster.control_connection._metadata_request_timeout
3438+
query = maybe_add_timeout_to_query(ControlConnection._SELECT_SCHEMA_LOCAL, metadata_request_timeout)
3439+
3440+
for index, host in enumerate(hosts):
3441+
schema_version, error = self._query_local_schema_version(host, query, deadline)
3442+
if error is not None:
3443+
errors[host.endpoint] = error
3444+
if error == "Timed out before querying host":
3445+
for timed_out_host in hosts[index + 1:]:
3446+
errors[timed_out_host.endpoint] = error
3447+
break
3448+
continue
3449+
versions[schema_version].add(host.endpoint)
3450+
3451+
if len(versions) == 1 and None not in versions and not errors:
3452+
log.debug("[session] Local schemas match")
3453+
return None
3454+
3455+
mismatches = dict((version, list(nodes)) for version, nodes in versions.items())
3456+
if errors:
3457+
mismatches['unavailable'] = dict((endpoint, str(error)) for endpoint, error in errors.items())
3458+
return mismatches
3459+
3460+
def _get_schema_agreement_hosts(self, scope):
    """
    Return the hosts that take part in a schema agreement check.

    A host is selected when it has an entry in ``self._pools``, is not
    explicitly marked down (``is_up is False``), its pool is not shut down,
    and its distance is allowed for ``scope``. ``'cluster'`` accepts every
    distance.

    :param scope: ``'rack'``, ``'dc'`` or ``'cluster'``.
    :returns: A tuple of the selected hosts.
    """
    distances_by_scope = {
        'rack': (HostDistance.LOCAL_RACK,),
        'dc': (HostDistance.LOCAL_RACK, HostDistance.LOCAL),
    }

    selected = []
    # Iterate over a snapshot of the pool mapping.
    for host, pool in tuple(self._pools.items()):
        if host.is_up is False or pool.is_shutdown:
            continue
        if scope != 'cluster' and self._profile_manager.distance(host) not in distances_by_scope[scope]:
            continue
        selected.append(host)
    return tuple(selected)
3470+
3471+
def _query_local_schema_version(self, host, query, deadline):
3472+
remaining = deadline - time.time()
3473+
if remaining <= 0:
3474+
return None, "Timed out before querying host"
3475+
3476+
try:
3477+
rows = self.execute(
3478+
query,
3479+
timeout=self._schema_agreement_query_timeout(remaining),
3480+
host=host,
3481+
)
3482+
except OperationTimedOut as timeout:
3483+
log.debug("[session] Timed out waiting for schema version from %s: %s", host, timeout)
3484+
return None, timeout
3485+
except Exception as exc:
3486+
log.debug("[session] Error querying schema version from %s: %s", host, exc)
3487+
return None, exc
3488+
3489+
row = rows.one()
3490+
return (getattr(row, "schema_version", None) if row is not None else None), None
3491+
3492+
def _schema_agreement_query_timeout(self, remaining):
3493+
control_timeout = self.cluster.control_connection._timeout
3494+
return min(control_timeout, remaining) if control_timeout is not None else remaining
3495+
33773496
def user_type_registered(self, keyspace, user_type, klass):
33783497
"""
33793498
Called by the parent Cluster instance when the user registers a new
@@ -4079,7 +4198,6 @@ def _handle_schema_change(self, event):
40794198
self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event)
40804199

40814200
def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):
4082-
40834201
total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait
40844202
if total_timeout <= 0:
40854203
return True

docs/api/cassandra/cluster.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ Clusters and Sessions
169169

170170
.. automethod:: set_keyspace(keyspace)
171171

172+
.. automethod:: wait_for_schema_agreement
173+
172174
.. automethod:: get_execution_profile
173175

174176
.. automethod:: execution_profile_clone_update

tests/integration/long/test_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,4 @@ def check_and_wait_for_agreement(self, session, rs, exepected):
158158
time.sleep(1)
159159
assert rs.response_future.is_schema_agreed == exepected
160160
if not rs.response_future.is_schema_agreed:
161-
session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000)
161+
session.wait_for_schema_agreement(wait_time=1000)

tests/integration/standard/test_udts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_can_register_udt_before_connecting(self):
147147
c.register_user_type("udt_test_register_before_connecting2", "user", User2)
148148

149149
s = c.connect(wait_for_all_pools=True)
150-
c.control_connection.wait_for_schema_agreement()
150+
s.wait_for_schema_agreement()
151151

152152
s.execute("INSERT INTO udt_test_register_before_connecting.mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob')))
153153
result = s.execute("SELECT b FROM udt_test_register_before_connecting.mytable WHERE a=0")

tests/unit/test_cluster.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion
2424
from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \
2525
ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT
26+
from cassandra.connection import ConnectionBusy
2627
from cassandra.pool import Host
2728
from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy
2829
from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory
@@ -247,11 +248,90 @@ def test_event_delay_timing(self, *_):
247248

248249

249250
class SessionTest(unittest.TestCase):
251+
class FakeTime(object):
    """Stand-in exposing ``time()`` and ``sleep()``; sleeping only advances a counter."""

    def __init__(self):
        # Current fake timestamp; starts at zero.
        self.clock = 0

    def time(self):
        """Return the current fake timestamp."""
        return self.clock

    def sleep(self, amount):
        """Advance the fake timestamp by ``amount`` without blocking."""
        self.clock = self.clock + amount
261+
262+
class MockPool(object):
    """Bare-bones pool stand-in that wraps a single connection object."""

    def __init__(self, host, connection):
        self.connection = connection
        self.host = host
        self.is_shutdown = False
        self.host_distance = HostDistance.LOCAL

    def _get_connection_for_routing_key(self):
        """Return the single wrapped connection."""
        return self.connection
272+
273+
class MockResultSet(object):
    """Result-set stand-in whose only row carries a fixed ``schema_version``."""

    def __init__(self, schema_version):
        self._schema_version = schema_version

    def one(self):
        """Return a mock row exposing the stored ``schema_version``."""
        row = Mock(schema_version=self._schema_version)
        return row
280+
250281
def setUp(self):
    """Skip the test when no connection class is available; otherwise initialize its reactor."""
    reactor_backend = connection_class
    if reactor_backend is None:
        raise unittest.SkipTest('libev does not appear to be installed correctly')
    reactor_backend.initialize_reactor()
254285

286+
def _mock_schema_future(self, outcome):
287+
future = Mock()
288+
if isinstance(outcome, Exception):
289+
future.result.side_effect = outcome
290+
else:
291+
future.result.return_value = self.MockResultSet(outcome)
292+
return future
293+
294+
def _host_query_count(self, session, target_host):
295+
return sum(1 for call in session.execute.call_args_list if call.kwargs.get('host') is target_host)
296+
297+
def _new_schema_agreement_session(self, schema_versions, distances=None):
    """
    Build a Session with one stubbed pool per entry in ``schema_versions``.

    ``session.execute`` is replaced by a Mock that answers from the target
    host's ``connection.future_outcomes`` list. Tests may reassign that list
    after this helper returns.

    :param schema_versions: Initial schema version (or exception) per host.
    :param distances: Optional per-host ``HostDistance``; defaults to
        ``HostDistance.LOCAL`` for every host.
    :returns: ``(session, hosts, connections)`` where ``connections`` maps
        each host to its mock connection.
    """
    if distances is None:
        distances = [HostDistance.LOCAL] * len(schema_versions)

    hosts = []
    distance_map = {}
    for position, _version in enumerate(schema_versions):
        new_host = Host("127.0.0.%d" % (position + 1), SimpleConvictionPolicy, host_id=uuid.uuid4())
        new_host.set_up()
        distance_map[new_host] = distances[position]
        hosts.append(new_host)

    cluster = Cluster(protocol_version=4)
    for known_host in hosts:
        cluster.metadata.add_or_return_host(known_host)

    session = Session(cluster, hosts)
    session._profile_manager.distance = Mock(side_effect=lambda host: distance_map.get(host, HostDistance.LOCAL))

    connections = {}
    pools = {}
    for pooled_host, schema_version in zip(hosts, schema_versions):
        connection = Mock(endpoint=pooled_host.endpoint)
        connection.future_outcomes = [schema_version]
        connections[pooled_host] = connection
        pools[pooled_host] = self.MockPool(pooled_host, connection)
    session._pools = pools

    def execute(query, parameters=None, timeout=None, trace=False,
                custom_payload=None, execution_profile=None,
                paging_state=None, host=None, execute_as=None):
        pending = connections[host].future_outcomes
        # Consume queued outcomes in order; the last one is kept and
        # repeated for every later call.
        outcome = pending.pop(0) if len(pending) > 1 else pending[0]
        return self._mock_schema_future(outcome).result()

    session.execute = Mock(side_effect=execute)

    return session, hosts, connections
334+
255335
# TODO: this suite could be expanded; for now just adding a test covering a PR
256336
@mock_session_pools
257337
def test_default_serial_consistency_level_ep(self, *_):
@@ -339,6 +419,89 @@ def test_set_keyspace_escapes_quotes(self, *_):
339419
assert query == 'USE simple_ks', (
340420
"Simple keyspace names should not be quoted, got: %r" % query)
341421

422+
@mock_session_pools
def test_wait_for_schema_agreement_queries_all_local_hosts(self, *_):
    """Two connected hosts on the same version agree after one query each."""
    session, hosts, _connections = self._new_schema_agreement_session(["a", "a"])

    agreed = session.wait_for_schema_agreement(wait_time=1)
    assert agreed

    query_counts = [self._host_query_count(session, host) for host in hosts]
    assert query_counts == [1] * len(hosts)
430+
431+
@mock_session_pools
def test_wait_for_schema_agreement_retries_until_local_hosts_match(self, *_):
    """A host that first reports a different version is re-queried after one 0.2s pause."""
    session, hosts, connections = self._new_schema_agreement_session(["a", "b"])
    # Second host disagrees on the first pass and matches on the second.
    connections[hosts[1]].future_outcomes = ["b", "a"]
    fake_time = self.FakeTime()

    with patch('cassandra.cluster.time', new=fake_time):
        agreed = session.wait_for_schema_agreement(wait_time=1)

    assert agreed
    query_counts = [self._host_query_count(session, host) for host in hosts]
    assert query_counts == [2] * len(hosts)
    assert fake_time.clock == 0.2
442+
443+
@mock_session_pools
def test_wait_for_schema_agreement_retries_when_local_connection_is_busy(self, *_):
    """A ConnectionBusy error from one host triggers a second pass that then succeeds."""
    session, hosts, connections = self._new_schema_agreement_session(["a", "a"])
    busy_then_ok = [ConnectionBusy("connection overloaded"), "a"]
    connections[hosts[1]].future_outcomes = busy_then_ok
    fake_time = self.FakeTime()

    with patch('cassandra.cluster.time', new=fake_time):
        agreed = session.wait_for_schema_agreement(wait_time=1)

    assert agreed
    query_counts = [self._host_query_count(session, host) for host in hosts]
    assert query_counts == [2] * len(hosts)
    assert fake_time.clock == 0.2
456+
457+
@mock_session_pools
def test_wait_for_schema_agreement_ignores_local_hosts_without_session_pool(self, *_):
    """A host known to cluster metadata but absent from the session pools is not consulted."""
    session, hosts, _connections = self._new_schema_agreement_session(["a"])

    poolless_host = Host("127.0.0.2", SimpleConvictionPolicy, host_id=uuid.uuid4())
    poolless_host.set_up()
    session.cluster.metadata.add_or_return_host(poolless_host)

    agreed = session.wait_for_schema_agreement(wait_time=1)

    assert agreed
    assert self._host_query_count(session, hosts[0]) == 1
467+
468+
@mock_session_pools
def test_wait_for_schema_agreement_queries_hosts_in_order(self, *_):
    """Eleven agreeing hosts are each queried once, in pool order."""
    session, hosts, _connections = self._new_schema_agreement_session(["a"] * 11)

    agreed = session.wait_for_schema_agreement(wait_time=1)

    assert agreed
    queried_hosts = [recorded.kwargs['host'] for recorded in session.execute.call_args_list]
    assert queried_hosts == list(hosts)
474+
475+
@mock_session_pools
def test_wait_for_schema_agreement_rack_scope_only_queries_local_rack_connections(self, *_):
    """With scope='rack', only the LOCAL_RACK host is queried."""
    host_distances = [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]
    session, hosts, _connections = self._new_schema_agreement_session(
        ["a", "a", "a"], distances=host_distances)

    agreed = session.wait_for_schema_agreement(wait_time=1, scope='rack')

    assert agreed
    query_counts = [self._host_query_count(session, host) for host in hosts]
    assert query_counts == [1, 0, 0]
486+
487+
@mock_session_pools
def test_wait_for_schema_agreement_cluster_scope_queries_all_connected_hosts(self, *_):
    """With scope='cluster', hosts at every distance are queried once."""
    host_distances = [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]
    session, hosts, _connections = self._new_schema_agreement_session(
        ["a", "a", "a"], distances=host_distances)

    agreed = session.wait_for_schema_agreement(wait_time=1, scope='cluster')

    assert agreed
    query_counts = [self._host_query_count(session, host) for host in hosts]
    assert query_counts == [1] * len(hosts)
497+
498+
@mock_session_pools
def test_wait_for_schema_agreement_rejects_unknown_scope(self, *_):
    """An unsupported scope value raises ValueError."""
    session, _hosts, _connections = self._new_schema_agreement_session(["a"])
    bad_scope = 'planet'

    with pytest.raises(ValueError):
        session.wait_for_schema_agreement(wait_time=1, scope=bad_scope)
504+
342505
class ProtocolVersionTests(unittest.TestCase):
343506

344507
def test_protocol_downgrade_test(self):

0 commit comments

Comments
 (0)