|
23 | 23 | InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion |
24 | 24 | from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ |
25 | 25 | ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT |
| 26 | +from cassandra.connection import ConnectionBusy |
26 | 27 | from cassandra.pool import Host |
27 | 28 | from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy |
28 | 29 | from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory |
@@ -247,11 +248,90 @@ def test_event_delay_timing(self, *_): |
247 | 248 |
|
248 | 249 |
|
249 | 250 | class SessionTest(unittest.TestCase): |
| 251 | + class FakeTime(object): |
| 252 | + |
| 253 | + def __init__(self): |
| 254 | + self.clock = 0 |
| 255 | + |
| 256 | + def time(self): |
| 257 | + return self.clock |
| 258 | + |
| 259 | + def sleep(self, amount): |
| 260 | + self.clock += amount |
| 261 | + |
| 262 | + class MockPool(object): |
| 263 | + |
| 264 | + def __init__(self, host, connection): |
| 265 | + self.host = host |
| 266 | + self.host_distance = HostDistance.LOCAL |
| 267 | + self.is_shutdown = False |
| 268 | + self.connection = connection |
| 269 | + |
| 270 | + def _get_connection_for_routing_key(self): |
| 271 | + return self.connection |
| 272 | + |
| 273 | + class MockResultSet(object): |
| 274 | + |
| 275 | + def __init__(self, schema_version): |
| 276 | + self._schema_version = schema_version |
| 277 | + |
| 278 | + def one(self): |
| 279 | + return Mock(schema_version=self._schema_version) |
| 280 | + |
250 | 281 | def setUp(self): |
251 | 282 | if connection_class is None: |
252 | 283 | raise unittest.SkipTest('libev does not appear to be installed correctly') |
253 | 284 | connection_class.initialize_reactor() |
254 | 285 |
|
| 286 | + def _mock_schema_future(self, outcome): |
| 287 | + future = Mock() |
| 288 | + if isinstance(outcome, Exception): |
| 289 | + future.result.side_effect = outcome |
| 290 | + else: |
| 291 | + future.result.return_value = self.MockResultSet(outcome) |
| 292 | + return future |
| 293 | + |
| 294 | + def _host_query_count(self, session, target_host): |
| 295 | + return sum(1 for call in session.execute.call_args_list if call.kwargs.get('host') is target_host) |
| 296 | + |
| 297 | + def _new_schema_agreement_session(self, schema_versions, distances=None): |
| 298 | + hosts = [] |
| 299 | + connections = {} |
| 300 | + distance_map = {} |
| 301 | + if distances is None: |
| 302 | + distances = [HostDistance.LOCAL] * len(schema_versions) |
| 303 | + |
| 304 | + for index, schema_version in enumerate(schema_versions): |
| 305 | + host = Host("127.0.0.%d" % (index + 1), SimpleConvictionPolicy, host_id=uuid.uuid4()) |
| 306 | + host.set_up() |
| 307 | + hosts.append(host) |
| 308 | + distance_map[host] = distances[index] |
| 309 | + |
| 310 | + cluster = Cluster(protocol_version=4) |
| 311 | + for host in hosts: |
| 312 | + cluster.metadata.add_or_return_host(host) |
| 313 | + |
| 314 | + session = Session(cluster, hosts) |
| 315 | + session._profile_manager.distance = Mock(side_effect=lambda host: distance_map.get(host, HostDistance.LOCAL)) |
| 316 | + session._pools = {} |
| 317 | + for host, schema_version in zip(hosts, schema_versions): |
| 318 | + connection = Mock(endpoint=host.endpoint) |
| 319 | + connection.future_outcomes = [schema_version] |
| 320 | + session._pools[host] = self.MockPool(host, connection) |
| 321 | + connections[host] = connection |
| 322 | + |
| 323 | + def execute(query, parameters=None, timeout=None, trace=False, |
| 324 | + custom_payload=None, execution_profile=None, |
| 325 | + paging_state=None, host=None, execute_as=None): |
| 326 | + connection = connections[host] |
| 327 | + outcome = connection.future_outcomes.pop(0) if len(connection.future_outcomes) > 1 else connection.future_outcomes[0] |
| 328 | + future = self._mock_schema_future(outcome) |
| 329 | + return future.result() |
| 330 | + |
| 331 | + session.execute = Mock(side_effect=execute) |
| 332 | + |
| 333 | + return session, hosts, connections |
| 334 | + |
255 | 335 | # TODO: this suite could be expanded; for now just adding a test covering a PR |
256 | 336 | @mock_session_pools |
257 | 337 | def test_default_serial_consistency_level_ep(self, *_): |
@@ -339,6 +419,89 @@ def test_set_keyspace_escapes_quotes(self, *_): |
339 | 419 | assert query == 'USE simple_ks', ( |
340 | 420 | "Simple keyspace names should not be quoted, got: %r" % query) |
341 | 421 |
|
| 422 | + @mock_session_pools |
| 423 | + def test_wait_for_schema_agreement_queries_all_local_hosts(self, *_): |
| 424 | + session, hosts, _ = self._new_schema_agreement_session(["a", "a"]) |
| 425 | + |
| 426 | + assert session.wait_for_schema_agreement(wait_time=1) |
| 427 | + |
| 428 | + for host in hosts: |
| 429 | + assert self._host_query_count(session, host) == 1 |
| 430 | + |
| 431 | + @mock_session_pools |
| 432 | + def test_wait_for_schema_agreement_retries_until_local_hosts_match(self, *_): |
| 433 | + session, hosts, connections = self._new_schema_agreement_session(["a", "b"]) |
| 434 | + clock = self.FakeTime() |
| 435 | + connections[hosts[1]].future_outcomes = ["b", "a"] |
| 436 | + |
| 437 | + with patch('cassandra.cluster.time', new=clock): |
| 438 | + assert session.wait_for_schema_agreement(wait_time=1) |
| 439 | + for host in hosts: |
| 440 | + assert self._host_query_count(session, host) == 2 |
| 441 | + assert clock.clock == 0.2 |
| 442 | + |
| 443 | + @mock_session_pools |
| 444 | + def test_wait_for_schema_agreement_retries_when_local_connection_is_busy(self, *_): |
| 445 | + session, hosts, connections = self._new_schema_agreement_session(["a", "a"]) |
| 446 | + clock = self.FakeTime() |
| 447 | + connections[hosts[1]].future_outcomes = [ |
| 448 | + ConnectionBusy("connection overloaded"), |
| 449 | + "a"] |
| 450 | + |
| 451 | + with patch('cassandra.cluster.time', new=clock): |
| 452 | + assert session.wait_for_schema_agreement(wait_time=1) |
| 453 | + for host in hosts: |
| 454 | + assert self._host_query_count(session, host) == 2 |
| 455 | + assert clock.clock == 0.2 |
| 456 | + |
| 457 | + @mock_session_pools |
| 458 | + def test_wait_for_schema_agreement_ignores_local_hosts_without_session_pool(self, *_): |
| 459 | + session, hosts, _ = self._new_schema_agreement_session(["a"]) |
| 460 | + |
| 461 | + unconnected_host = Host("127.0.0.2", SimpleConvictionPolicy, host_id=uuid.uuid4()) |
| 462 | + unconnected_host.set_up() |
| 463 | + session.cluster.metadata.add_or_return_host(unconnected_host) |
| 464 | + |
| 465 | + assert session.wait_for_schema_agreement(wait_time=1) |
| 466 | + assert self._host_query_count(session, hosts[0]) == 1 |
| 467 | + |
| 468 | + @mock_session_pools |
| 469 | + def test_wait_for_schema_agreement_queries_hosts_in_order(self, *_): |
| 470 | + session, hosts, _ = self._new_schema_agreement_session(["a"] * 11) |
| 471 | + |
| 472 | + assert session.wait_for_schema_agreement(wait_time=1) |
| 473 | + assert [call.kwargs['host'] for call in session.execute.call_args_list] == list(hosts) |
| 474 | + |
| 475 | + @mock_session_pools |
| 476 | + def test_wait_for_schema_agreement_rack_scope_only_queries_local_rack_connections(self, *_): |
| 477 | + session, hosts, _ = self._new_schema_agreement_session( |
| 478 | + ["a", "a", "a"], |
| 479 | + distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]) |
| 480 | + |
| 481 | + assert session.wait_for_schema_agreement(wait_time=1, scope='rack') |
| 482 | + |
| 483 | + assert self._host_query_count(session, hosts[0]) == 1 |
| 484 | + assert self._host_query_count(session, hosts[1]) == 0 |
| 485 | + assert self._host_query_count(session, hosts[2]) == 0 |
| 486 | + |
| 487 | + @mock_session_pools |
| 488 | + def test_wait_for_schema_agreement_cluster_scope_queries_all_connected_hosts(self, *_): |
| 489 | + session, hosts, _ = self._new_schema_agreement_session( |
| 490 | + ["a", "a", "a"], |
| 491 | + distances=[HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]) |
| 492 | + |
| 493 | + assert session.wait_for_schema_agreement(wait_time=1, scope='cluster') |
| 494 | + |
| 495 | + for host in hosts: |
| 496 | + assert self._host_query_count(session, host) == 1 |
| 497 | + |
| 498 | + @mock_session_pools |
| 499 | + def test_wait_for_schema_agreement_rejects_unknown_scope(self, *_): |
| 500 | + session, _, _ = self._new_schema_agreement_session(["a"]) |
| 501 | + |
| 502 | + with pytest.raises(ValueError): |
| 503 | + session.wait_for_schema_agreement(wait_time=1, scope='planet') |
| 504 | + |
342 | 505 | class ProtocolVersionTests(unittest.TestCase): |
343 | 506 |
|
344 | 507 | def test_protocol_downgrade_test(self): |
|
0 commit comments