diff --git a/smoketests/tests/replication.py b/smoketests/tests/replication.py index d3e3a2d9d69..168ebe49efe 100644 --- a/smoketests/tests/replication.py +++ b/smoketests/tests/replication.py @@ -1,17 +1,10 @@ from .. import COMPOSE_FILE, Smoketest, requires_docker, spacetime from ..docker import DockerManager -import re import time from typing import Callable import unittest -def get_int(text): - digits = re.search(r'\d+', text) - if digits is None: - raise Exception("no numbers found in string") - return int(digits.group()) - def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2): """Retry a function on failure with delay.""" for attempt in range(1, max_retries + 1): @@ -25,6 +18,21 @@ def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2): print("Max retries reached. Skipping the exception.") return False +def parse_sql_result(res: str) -> list[dict]: + """Parse tabular output from an SQL query into a list of dicts.""" + lines = res.splitlines() + headers = lines[0].split('|') if '|' in lines[0] else [lines[0]] + headers = [header.strip() for header in headers] + rows = [] + for row in lines[2:]: + cols = [col.strip() for col in row.split('|')] + rows.append(dict(zip(headers, cols))) + return rows + +def int_vals(rows: list[dict]) -> list[dict]: + """For all dicts in list, cast all values in dict to int.""" + return [{k: int(v) for k, v in row.items()} for row in rows] + class Cluster: """Manages leader-related operations and state for SpaceTime database cluster.""" @@ -35,56 +43,47 @@ def __init__(self, docker_manager, smoketest: Smoketest): # Ensure all containers are up. self.docker.compose("up", "-d") - def read_controldb(self, sql): - """Helper method to read from control database.""" - return self.test.spacetime("sql", "spacetime-control", sql) + def sql(self, sql: str) -> list[dict]: + """Query the test database.""" + res = self.test.sql(sql) + return parse_sql_result(str(res)) + + def read_controldb(self, sql: str) -> list[dict]: + """Query the control database.""" + res = self.test.spacetime("sql", "spacetime-control", sql) + return parse_sql_result(str(res)) def get_db_id(self): """Query database ID.""" sql = f"select id from database where database_identity=0x{self.test.database_identity}" - db_id_tb = self.read_controldb(sql) - return get_int(db_id_tb) - + res = self.read_controldb(sql) + return int(res[0]['id']) def get_all_replicas(self): """Get all replica nodes in the cluster.""" database_id = self.get_db_id() sql = f"select id, node_id from replica where database_id={database_id}" - replica_tb = self.read_controldb(sql) - replicas = [] - for line in str(replica_tb).splitlines()[2:]: - replica_id, node_id = line.split('|') - replicas.append({ - 'replica_id': int(replica_id), - 'node_id': int(node_id) - }) - return replicas + return int_vals(self.read_controldb(sql)) def get_leader_info(self): """Get current leader's node information including ID, hostname, and container ID.""" database_id = self.get_db_id() - # Query leader replica ID - sql = f"select leader from replication_state where database_id={database_id}" - leader_tb = self.read_controldb(sql) - leader_id = get_int(leader_tb) - - # Query leader node ID - sql = f"select node_id from replica where id={leader_id}" - leader_node_tb = self.read_controldb(sql) - leader_node_id = get_int(leader_node_tb) - - # Query leader hostname - sql = f"select network_addr from node_v2 where id={leader_node_id}" - leader_host_tb = str(self.read_controldb(sql)) - lines = leader_host_tb.splitlines() + sql = f""" \ +select node_v2.id, node_v2.network_addr from node_v2 \ +join replica on replica.node_id=node_v2.id \ +join replication_state on replication_state.leader=replica.id \ +where replication_state.database_id={database_id} \ +""" + rows = self.read_controldb(sql) + if not rows: + raise Exception("Could not find current leader's node") + leader_node_id = int(rows[0]['id']) hostname = "" - if len(lines) == 3: # actual row starts from 3rd line - leader_row = lines[2] - if "(some =" in leader_row: - address = leader_row.split('"')[1] - hostname = address.split(':')[0] + if "(some =" in rows[0]['network_addr']: + address = rows[0]['network_addr'].split('"')[1] + hostname = address.split(':')[0] # Find container ID container_id = "" @@ -114,15 +113,16 @@ def wait_for_leader_change(self, previous_leader_node, max_attempts=10, delay=2) time.sleep(delay) return None - def ensure_leader_health(self, id, wait_time=2): + def ensure_leader_health(self, id): """Verify leader is healthy by inserting a row.""" - if wait_time: - time.sleep(wait_time) retry(lambda: self.test.call("start", id, 1)) - add_table = str(self.test.sql(f"SELECT id FROM counter where id={id}")) - if str(id) not in add_table: + rows = self.sql(f"select id from counter where id={id}") + if len(rows) < 1 or int(rows[0]['id']) != id: raise ValueError(f"Could not find {id} in counter table") + # Wait for at least one tick to ensure buffers are flushed. + # TODO: Replace with confirmed read. + time.sleep(0.6) def fail_leader(self, action='kill'): @@ -247,6 +247,10 @@ def start(self, id: int, count: int): """Send a message to the database.""" retry(lambda: self.call("start", id, count)) + def collect_counter_rows(self): + return int_vals(self.cluster.sql("select * from counter")) + + class LeaderElection(ReplicationTest): def test_leader_election_in_loop(self): """This test fails a leader, wait for new leader to be elected and verify if commits replicated to new leader""" @@ -254,9 +258,10 @@ def test_leader_election_in_loop(self): row_ids = [101 + i for i in range(iterations * 2)] for (first_id, second_id) in zip(row_ids[::2], row_ids[1::2]): cur_leader = self.cluster.wait_for_leader_change(None) + print(f"ensure leader health {first_id}") self.cluster.ensure_leader_health(first_id) - print("killing current leader: {}", cur_leader) + print(f"killing current leader: {cur_leader}") container_id = self.cluster.fail_leader() self.assertIsNotNone(container_id) @@ -264,14 +269,20 @@ def test_leader_election_in_loop(self): next_leader = self.cluster.wait_for_leader_change(cur_leader) self.assertNotEqual(cur_leader, next_leader) # this check if leader election happened + print(f"ensure_leader_health {second_id}") self.cluster.ensure_leader_health(second_id) # restart the old leader, so that we can maintain quorum for next iteration + print(f"reconnect leader {container_id}") self.cluster.restore_leader(container_id, 'start') - # verify if all past rows are present in new leader - for row_id in row_ids: - table = self.sql(f"SELECT * FROM counter WHERE id = {row_id}") - self.assertIn(f"{row_id}", str(table)) + # Ensure we have a current leader + last_row_id = row_ids[-1] + 1 + self.cluster.ensure_leader_health(row_ids[-1] + 1) + row_ids.append(last_row_id) + + # Verify that all inserted rows are present + stored_row_ids = [row['id'] for row in self.collect_counter_rows()] + self.assertEqual(set(stored_row_ids), set(row_ids)) class LeaderDisconnect(ReplicationTest): def test_leader_c_disconnect_in_loop(self): @@ -300,12 +311,15 @@ def test_leader_c_disconnect_in_loop(self): # restart the old leader, so that we can maintain quorum for next iteration print(f"reconnect leader {container_id}") self.cluster.restore_leader(container_id, 'connect') - time.sleep(1) - # verify if all past rows are present in new leader - for row_id in row_ids: - table = self.sql(f"SELECT * FROM counter WHERE id = {row_id}") - self.assertIn(f"{row_id}", str(table)) + # Ensure we have a current leader + last_row_id = row_ids[-1] + 1 + self.cluster.ensure_leader_health(last_row_id) + row_ids.append(last_row_id) + + # Verify that all inserted rows are present + stored_row_ids = [row['id'] for row in self.collect_counter_rows()] + self.assertEqual(set(stored_row_ids), set(row_ids)) @unittest.skip("drain_node not yet supported") @@ -342,18 +356,16 @@ def test_prefer_leader(self): if replica['node_id'] != cur_leader_node_id: prefer_replica = replica break - prefer_replica_id = prefer_replica['replica_id'] + prefer_replica_id = prefer_replica['id'] self.spacetime("call", "spacetime-control", "prefer_leader", f"{prefer_replica_id}") next_leader_node_id = self.cluster.wait_for_leader_change(cur_leader_node_id) self.cluster.ensure_leader_health(402) self.assertEqual(prefer_replica['node_id'], next_leader_node_id) - # verify if all past rows are present in new leader - for row_id in [401, 402]: - table = self.sql(f"SELECT * FROM counter WHERE id = {row_id}") - self.assertIn(f"{row_id}", str(table)) + stored_row_ids = [row['id'] for row in self.collect_counter_rows()] + self.assertEqual(set(stored_row_ids), set([401, 402])) class ManyTransactions(ReplicationTest):