11from .. import COMPOSE_FILE , Smoketest , requires_docker , spacetime
22from ..docker import DockerManager
33
4- import re
54import time
65from typing import Callable
76import unittest
87
9- def get_int (text ):
10- digits = re .search (r'\d+' , text )
11- if digits is None :
12- raise Exception ("no numbers found in string" )
13- return int (digits .group ())
14-
158def retry (func : Callable , max_retries : int = 3 , retry_delay : int = 2 ):
169 """Retry a function on failure with delay."""
1710 for attempt in range (1 , max_retries + 1 ):
@@ -25,6 +18,21 @@ def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2):
2518 print ("Max retries reached. Skipping the exception." )
2619 return False
2720
def parse_sql_result(res: str) -> list[dict]:
    """Parse tabular output from an SQL query into a list of dicts.

    The first line is read as the header: column names separated by ``|``
    (a header without ``|`` is a single column). The second line is skipped
    (presumably the header/body separator -- confirm against the CLI output).
    Every following non-blank line becomes one dict mapping header name to
    the whitespace-stripped cell text.

    Args:
        res: Raw text output of the query.

    Returns:
        One dict per row, with all values left as strings. Empty list when
        ``res`` is empty or contains no rows.
    """
    lines = res.splitlines()
    if not lines:
        # No output at all: there is no header line to index.
        return []
    # str.split returns [whole_line] when the delimiter is absent, so the
    # single-column case needs no special handling.
    headers = [header.strip() for header in lines[0].split('|')]
    rows = []
    for row in lines[2:]:
        if not row.strip():
            # A blank line is padding, not a row; parsing it would yield a
            # bogus {first_header: ''} entry that breaks later int() casts.
            continue
        cols = [col.strip() for col in row.split('|')]
        rows.append(dict(zip(headers, cols)))
    return rows
31+
def int_vals(rows: list[dict]) -> list[dict]:
    """Return new dicts mirroring ``rows`` with every value passed through ``int``.

    Keys are kept as-is; the input dicts are not modified. Any value that
    ``int`` rejects raises the exception ``int`` raises.
    """
    converted = []
    for row in rows:
        converted.append({key: int(value) for key, value in row.items()})
    return converted
35+
2836class Cluster :
2937 """Manages leader-related operations and state for SpaceTime database cluster."""
3038
@@ -35,56 +43,47 @@ def __init__(self, docker_manager, smoketest: Smoketest):
3543 # Ensure all containers are up.
3644 self .docker .compose ("up" , "-d" )
3745
38- def read_controldb (self , sql ):
39- """Helper method to read from control database."""
40- return self .test .spacetime ("sql" , "spacetime-control" , sql )
def sql(self, sql: str) -> list[dict]:
    """Run ``sql`` against the test database and return the parsed rows.

    The query is issued through ``self.test.sql``; its result is converted
    to text and handed to ``parse_sql_result``.
    """
    raw_output = str(self.test.sql(sql))
    return parse_sql_result(raw_output)
50+
def read_controldb(self, sql: str) -> list[dict]:
    """Run ``sql`` against the ``spacetime-control`` database and return the parsed rows.

    The query goes through ``self.test.spacetime("sql", ...)``; its result
    is converted to text and handed to ``parse_sql_result``.
    """
    raw_output = str(self.test.spacetime("sql", "spacetime-control", sql))
    return parse_sql_result(raw_output)
4155
def get_db_id(self):
    """Query the control database for the ID of the test database.

    The row is looked up by ``self.test.database_identity``.

    Returns:
        The ``id`` column of the first matching row, as an int.

    Raises:
        IndexError: if the control database returns no row for the identity.
    """
    identity = self.test.database_identity
    sql = f"select id from database where database_identity=0x{identity}"
    res = self.read_controldb(sql)
    if not res:
        # Same exception type as an unguarded res[0], but the message names
        # what was missing instead of "list index out of range".
        raise IndexError(f"no database found with identity 0x{identity}")
    return int(res[0]['id'])
4861
def get_all_replicas(self):
    """Return every replica row of the test database from the control database.

    Each entry is a dict of the selected columns (``id`` and ``node_id``)
    with all values cast to int.
    """
    db_id = self.get_db_id()
    replica_rows = self.read_controldb(
        f"select id, node_id from replica where database_id={db_id}"
    )
    return int_vals(replica_rows)
6267
6368 def get_leader_info (self ):
6469 """Get current leader's node information including ID, hostname, and container ID."""
6570
6671 database_id = self .get_db_id ()
67- # Query leader replica ID
68- sql = f"select leader from replication_state where database_id={ database_id } "
69- leader_tb = self .read_controldb (sql )
70- leader_id = get_int (leader_tb )
71-
72- # Query leader node ID
73- sql = f"select node_id from replica where id={ leader_id } "
74- leader_node_tb = self .read_controldb (sql )
75- leader_node_id = get_int (leader_node_tb )
76-
77- # Query leader hostname
78- sql = f"select network_addr from node_v2 where id={ leader_node_id } "
79- leader_host_tb = str (self .read_controldb (sql ))
80- lines = leader_host_tb .splitlines ()
72+ sql = f""" \
73+ select node_v2.id, node_v2.network_addr from node_v2 \
74+ join replica on replica.node_id=node_v2.id \
75+ join replication_state on replication_state.leader=replica.id \
76+ where replication_state.database_id={ database_id } \
77+ """
78+ rows = self .read_controldb (sql )
79+ if not rows :
80+ raise Exception ("Could not find current leader's node" )
8181
82+ leader_node_id = int (rows [0 ]['id' ])
8283 hostname = ""
83- if len (lines ) == 3 : # actual row starts from 3rd line
84- leader_row = lines [2 ]
85- if "(some =" in leader_row :
86- address = leader_row .split ('"' )[1 ]
87- hostname = address .split (':' )[0 ]
84+ if "(some =" in rows [0 ]['network_addr' ]:
85+ address = rows [0 ]['network_addr' ].split ('"' )[1 ]
86+ hostname = address .split (':' )[0 ]
8887
8988 # Find container ID
9089 container_id = ""
@@ -114,15 +113,16 @@ def wait_for_leader_change(self, previous_leader_node, max_attempts=10, delay=2)
114113 time .sleep (delay )
115114 return None
116115
def ensure_leader_health(self, id):
    """Verify the leader is healthy by inserting a row and reading it back.

    Calls the ``start`` reducer with ``id`` (via ``retry``), then queries
    the ``counter`` table for that id.

    Raises:
        ValueError: if the first returned row does not carry ``id``, or no
            row is returned at all.
    """
    retry(lambda: self.test.call("start", id, 1))
    matches = self.sql(f"select id from counter where id={id}")
    found = bool(matches) and int(matches[0]['id']) == id
    if not found:
        raise ValueError(f"Could not find {id} in counter table")
    # Wait for at least one tick to ensure buffers are flushed.
    # TODO: Replace with confirmed read.
    time.sleep(0.6)
126126
127127
128128 def fail_leader (self , action = 'kill' ):
@@ -247,31 +247,42 @@ def start(self, id: int, count: int):
247247 """Send a message to the database."""
248248 retry (lambda : self .call ("start" , id , count ))
249249
def collect_counter_rows(self):
    """Fetch all rows of the ``counter`` table with every value cast to int."""
    counter_rows = self.cluster.sql("select * from counter")
    return int_vals(counter_rows)
252+
253+
class LeaderElection(ReplicationTest):
    def test_leader_election_in_loop(self):
        """Fail the leader repeatedly and verify that commits survive each election.

        Every iteration inserts a row on the current leader, kills that
        leader, waits for a different leader, inserts a second row on it and
        then starts the old leader again. Afterwards one more row is inserted
        and the set of ids in ``counter`` must equal the set of inserted ids.
        """
        iterations = 5
        row_ids = list(range(101, 101 + iterations * 2))
        # Consecutive ids form one (before-failover, after-failover) pair.
        for first_id, second_id in zip(row_ids[::2], row_ids[1::2]):
            cur_leader = self.cluster.wait_for_leader_change(None)
            print(f"ensure leader health {first_id}")
            self.cluster.ensure_leader_health(first_id)

            print(f"killing current leader: {cur_leader}")
            container_id = self.cluster.fail_leader()
            self.assertIsNotNone(container_id)

            # A different leader proves that an election took place.
            next_leader = self.cluster.wait_for_leader_change(cur_leader)
            self.assertNotEqual(cur_leader, next_leader)
            print(f"ensure_leader_health {second_id}")
            self.cluster.ensure_leader_health(second_id)

            # Restart the old leader so quorum is kept for the next iteration.
            print(f"reconnect leader {container_id}")
            self.cluster.restore_leader(container_id, 'start')

        # Ensure we have a current leader
        last_row_id = row_ids[-1] + 1
        self.cluster.ensure_leader_health(last_row_id)
        row_ids.append(last_row_id)

        # Verify that all inserted rows are present
        stored_row_ids = {row['id'] for row in self.collect_counter_rows()}
        self.assertEqual(stored_row_ids, set(row_ids))
275286
276287class LeaderDisconnect (ReplicationTest ):
277288 def test_leader_c_disconnect_in_loop (self ):
@@ -300,12 +311,15 @@ def test_leader_c_disconnect_in_loop(self):
300311 # restart the old leader, so that we can maintain quorum for next iteration
301312 print (f"reconnect leader { container_id } " )
302313 self .cluster .restore_leader (container_id , 'connect' )
303- time .sleep (1 )
304314
305- # verify if all past rows are present in new leader
306- for row_id in row_ids :
307- table = self .sql (f"SELECT * FROM counter WHERE id = { row_id } " )
308- self .assertIn (f"{ row_id } " , str (table ))
315+ # Ensure we have a current leader
316+ last_row_id = row_ids [- 1 ] + 1
317+ self .cluster .ensure_leader_health (last_row_id )
318+ row_ids .append (last_row_id )
319+
320+ # Verify that all inserted rows are present
321+ stored_row_ids = [row ['id' ] for row in self .collect_counter_rows ()]
322+ self .assertEqual (set (stored_row_ids ), set (row_ids ))
309323
310324
311325@unittest .skip ("drain_node not yet supported" )
@@ -342,18 +356,16 @@ def test_prefer_leader(self):
342356 if replica ['node_id' ] != cur_leader_node_id :
343357 prefer_replica = replica
344358 break
345- prefer_replica_id = prefer_replica ['replica_id ' ]
359+ prefer_replica_id = prefer_replica ['id ' ]
346360 self .spacetime ("call" , "spacetime-control" , "prefer_leader" , f"{ prefer_replica_id } " )
347361
348362 next_leader_node_id = self .cluster .wait_for_leader_change (cur_leader_node_id )
349363 self .cluster .ensure_leader_health (402 )
350364 self .assertEqual (prefer_replica ['node_id' ], next_leader_node_id )
351365
352-
353366 # verify if all past rows are present in new leader
354- for row_id in [401 , 402 ]:
355- table = self .sql (f"SELECT * FROM counter WHERE id = { row_id } " )
356- self .assertIn (f"{ row_id } " , str (table ))
367+ stored_row_ids = [row ['id' ] for row in self .collect_counter_rows ()]
368+ self .assertEqual (set (stored_row_ids ), set ([401 , 402 ]))
357369
358370
359371class ManyTransactions (ReplicationTest ):
0 commit comments