2020
2121import atexit
2222import datetime
23+ from enum import Enum
2324from binascii import hexlify
2425from collections import defaultdict
2526from collections .abc import Mapping
2930from itertools import groupby , count , chain
3031import json
3132import logging
32- from typing import Any , Dict , Optional , Union , Literal
33+ from typing import Any , Dict , Optional , Union , Tuple
3334from warnings import warn
3435from random import random
3536import re
@@ -214,6 +215,27 @@ def __init__(self, message, errors):
214215 self .errors = errors
215216
216217
218+ class SchemaAgreementScope (str , Enum ):
219+ """Scope selectors for :meth:`.Session.wait_for_schema_agreement`."""
220+
221+ RACK = 'rack'
222+ DC = 'dc'
223+ CLUSTER = 'cluster'
224+
225+
226+ _SCHEMA_AGREEMENT_SCOPE_LABELS = {
227+ SchemaAgreementScope .RACK : 'local rack' ,
228+ SchemaAgreementScope .DC : 'local datacenter' ,
229+ SchemaAgreementScope .CLUSTER : 'cluster' ,
230+ }
231+
232+ _SCHEMA_AGREEMENT_ALLOWED_DISTANCES = {
233+ SchemaAgreementScope .RACK : (HostDistance .LOCAL_RACK ,),
234+ SchemaAgreementScope .DC : (HostDistance .LOCAL_RACK , HostDistance .LOCAL ),
235+ SchemaAgreementScope .CLUSTER : (HostDistance .LOCAL_RACK , HostDistance .LOCAL , HostDistance .REMOTE ),
236+ }
237+
238+
217239def _future_completed (future ):
218240 """ Helper for run_in_executor() """
219241 exc = future .exception ()
@@ -3374,7 +3396,8 @@ def pool_finished_setting_keyspace(pool, host_errors):
33743396 for pool in tuple (self ._pools .values ()):
33753397 pool ._set_keyspace_for_all_conns (keyspace , pool_finished_setting_keyspace )
33763398
3377- def wait_for_schema_agreement (self , wait_time = None , scope : Literal ['rack' , 'dc' , 'cluster' ]= 'dc' ):
3399+ def wait_for_schema_agreement (self , wait_time : Optional [float ] = None ,
3400+ scope : SchemaAgreementScope = SchemaAgreementScope .CLUSTER ) -> bool :
33783401 """
33793402 Wait for connected hosts in the selected scope to report the same
33803403 schema version from ``system.local``.
@@ -3388,11 +3411,13 @@ def wait_for_schema_agreement(self, wait_time=None, scope: Literal['rack', 'dc',
33883411 must be greater than 0.
33893412
33903413 ``scope`` determines which connected hosts participate in the check.
3391- Accepted values are ``'rack'``, ``'dc'``, and ``'cluster'``. The
3392- default ``'dc'`` scope queries connected hosts in the local rack and
3393- local datacenter. ``'rack'`` narrows the check to connected hosts in
3394- the local rack only. ``'cluster'`` queries every host this session has
3395- a live connection pool for, across all datacenters.
3414+ Pass :attr:`SchemaAgreementScope.RACK`, :attr:`SchemaAgreementScope.DC`,
3415+ or :attr:`SchemaAgreementScope.CLUSTER`. String values ``'rack'``,
3416+ ``'dc'``, and ``'cluster'`` are accepted for backward compatibility.
3417+ The default is :attr:`SchemaAgreementScope.CLUSTER`. ``RACK`` narrows
3418+ the check to connected hosts in the local rack only. ``DC`` checks
3419+ connected hosts in the local datacenter. ``CLUSTER`` queries every
3420+ connected host across all datacenters.
33963421
33973422 :param wait_time: Override for
33983423 :attr:`~.Cluster.max_schema_agreement_wait`.
@@ -3402,10 +3427,12 @@ def wait_for_schema_agreement(self, wait_time=None, scope: Literal['rack', 'dc',
34023427 otherwise ``False``.
34033428 :raises ValueError: If ``wait_time`` is provided and is not greater
34043429 than 0.
3405- :raises ValueError: If ``scope`` is not one of ``'rack'``, ``'dc'``,
3406- or ``'cluster'`` .
3430+ :raises ValueError: If ``scope`` is not one of the schema agreement
3431+ scope values .
34073432 """
3408- if scope not in ('rack' , 'dc' , 'cluster' ):
3433+ try :
3434+ scope = SchemaAgreementScope (scope )
3435+ except ValueError :
34093436 raise ValueError ("Invalid schema agreement scope: %s" % (scope ,))
34103437
34113438 if wait_time is not None and wait_time <= 0 :
@@ -3417,43 +3444,50 @@ def wait_for_schema_agreement(self, wait_time=None, scope: Literal['rack', 'dc',
34173444
34183445 deadline = time .time () + total_timeout
34193446 schema_mismatches = None
3447+ scope_label = _SCHEMA_AGREEMENT_SCOPE_LABELS [scope ]
34203448
34213449 while time .time () < deadline :
34223450 schema_mismatches = self ._get_schema_mismatches_for_scope (deadline , scope )
34233451 if schema_mismatches is None :
34243452 return True
34253453
3426- log .debug ("[session] Local schemas mismatched , trying again" )
3454+ log .debug ("[session] Connected hosts in the %s still disagree on schema , trying again" , scope_label )
34273455 remaining = deadline - time .time ()
34283456 if remaining > 0 :
34293457 time .sleep (min (0.2 , remaining ))
34303458
3431- log .warning ("Local nodes are reporting a schema disagreement: %s" , schema_mismatches )
3459+ log .warning ("[session] Connected hosts in the %s are reporting a schema disagreement: %s" ,
3460+ scope_label , schema_mismatches )
34323461 return False
34333462
3434- def _get_schema_mismatches_for_scope (self , deadline , scope : Literal ['rack' , 'dc' , 'cluster' ]):
3463+ def _get_schema_mismatches_for_scope (self , deadline : float ,
3464+ scope : SchemaAgreementScope ) -> Optional [Dict [Any , Any ]]:
34353465 hosts = self ._get_schema_agreement_hosts (scope )
3436- versions = defaultdict (set )
3466+ mismatches = defaultdict (list )
34373467 errors = {}
34383468
34393469 if not hosts :
3440- return {'unavailable' : 'No local hosts available' }
3470+ errors [scope .value ] = ConnectionException (
3471+ "No connected hosts available in the %s" % (_SCHEMA_AGREEMENT_SCOPE_LABELS [scope ],)
3472+ )
3473+ return {'unavailable' : errors }
34413474
34423475 metadata_request_timeout = self .cluster .control_connection ._metadata_request_timeout
34433476 query = maybe_add_timeout_to_query (ControlConnection ._SELECT_SCHEMA_LOCAL , metadata_request_timeout )
34443477
34453478 schema_version_futures = []
34463479 for host in hosts :
3447- schema_version , error = self ._query_local_schema_version (host , query , deadline )
3448- if error is not None :
3449- errors [host .endpoint ] = error
3480+ try :
3481+ schema_version_future = self ._query_local_schema_version (host , query , deadline )
3482+ except Exception as exc :
3483+ errors [host .endpoint ] = exc
34503484 continue
34513485
3452- schema_version_futures .append ((host , schema_version ))
3486+ schema_version_futures .append ((host , schema_version_future ))
34533487
34543488 if schema_version_futures :
34553489 # Start all host queries first, then wait for the whole batch.
3456- remaining = deadline - time .time ()
3490+ remaining = max ( 0.0 , deadline - time .time () )
34573491 if remaining > 0 :
34583492 wait_futures ([future for _ , future in schema_version_futures ], timeout = remaining )
34593493
@@ -3467,35 +3501,27 @@ def _get_schema_mismatches_for_scope(self, deadline, scope: Literal['rack', 'dc'
34673501
34683502 row = rows .one ()
34693503 schema_version = getattr (row , "schema_version" , None ) if row is not None else None
3470- versions [schema_version ].add (host .endpoint )
3504+ mismatches [schema_version ].append (host .endpoint )
34713505 else :
3472- errors [host .endpoint ] = "Timed out before querying host"
3506+ errors [host .endpoint ] = OperationTimedOut ( last_host = host , timeout = max ( 0.0 , deadline - time . time ()))
34733507
3474- if len (versions ) == 1 and None not in versions and not errors :
3475- log .debug ("[session] Local schemas match" )
3508+ if len (mismatches ) == 1 and None not in mismatches and not errors :
3509+ log .debug ("[session] Connected hosts in the %s agree on schema" , _SCHEMA_AGREEMENT_SCOPE_LABELS [ scope ] )
34763510 return None
34773511
3478- mismatches = dict ((version , list (nodes )) for version , nodes in versions .items ())
34793512 if errors :
3480- mismatches ['unavailable' ] = dict (( endpoint , str ( error )) for endpoint , error in errors . items ())
3481- return mismatches
3513+ mismatches ['unavailable' ] = errors
3514+ return dict ( mismatches )
34823515
3483- def _get_schema_agreement_hosts (self , scope ):
3484- allowed_distances = {
3485- 'rack' : (HostDistance .LOCAL_RACK ,),
3486- 'dc' : (HostDistance .LOCAL_RACK , HostDistance .LOCAL ),
3487- }
3516+ def _get_schema_agreement_hosts (self , scope : SchemaAgreementScope ) -> Tuple [Host , ...]:
34883517 return tuple (
34893518 host for host , pool in tuple (self ._pools .items ())
34903519 if host .is_up is not False
34913520 and not pool .is_shutdown
3492- and (scope == 'cluster' or self ._profile_manager .distance (host ) in allowed_distances [scope ]))
3493-
3494- def _query_local_schema_version (self , host , query , deadline ):
3495- remaining = deadline - time .time ()
3496- if remaining <= 0 :
3497- return None , "Timed out before querying host"
3521+ and self ._profile_manager .distance (host ) in _SCHEMA_AGREEMENT_ALLOWED_DISTANCES [scope ])
34983522
3523+ def _query_local_schema_version (self , host : Host , query : str , deadline : float ) -> Future :
3524+ remaining = max (0.0 , deadline - time .time ())
34993525 try :
35003526 response_future = self .execute_async (
35013527 query ,
@@ -3504,10 +3530,10 @@ def _query_local_schema_version(self, host, query, deadline):
35043530 )
35053531 except OperationTimedOut as timeout :
35063532 log .debug ("[session] Timed out waiting for schema version from %s: %s" , host , timeout )
3507- return None , timeout
3533+ raise
35083534 except Exception as exc :
35093535 log .debug ("[session] Error querying schema version from %s: %s" , host , exc )
3510- return None , exc
3536+ raise
35113537
35123538 schema_version_future = Future ()
35133539
@@ -3528,13 +3554,15 @@ def _set_exception(exc, result_future=schema_version_future):
35283554 response_future .add_callbacks (_set_result , _set_exception )
35293555 except Exception as exc :
35303556 log .debug ("[session] Error registering schema version callback from %s: %s" , host , exc )
3531- return None , exc
3557+ raise
35323558
3533- return schema_version_future , None
3559+ return schema_version_future
35343560
3535- def _schema_agreement_query_timeout (self , remaining ) :
3561+ def _schema_agreement_query_timeout (self , remaining : float ) -> float :
35363562 control_timeout = self .cluster .control_connection ._timeout
3537- return min (control_timeout , remaining ) if control_timeout is not None else remaining
3563+ if control_timeout is None :
3564+ return max (0.0 , remaining )
3565+ return max (0.0 , min (control_timeout , remaining ))
35383566
35393567 def user_type_registered (self , keyspace , user_type , klass ):
35403568 """
0 commit comments