@@ -39,14 +39,19 @@ class RedisKeys:
3939 LOOP_NONCE = f"{ KEY_PREFIX } :{{app_name}}:nonce:{{loop_id}}"
4040 LOOP_EVENT_CHANNEL = f"{ KEY_PREFIX } :{{app_name}}:events:{{loop_id}}:notify"
4141 LOOP_WAKE_KEY = f"{ KEY_PREFIX } :{{app_name}}:wake:{{loop_id}}"
42- LOOP_WAKE_INDEX = f"{ KEY_PREFIX } :{{app_name}}:wake_index"
42+ # Sorted set: member=loop_id, score=wake_timestamp (source of truth for scheduling)
43+ LOOP_WAKE_SCHEDULE = f"{ KEY_PREFIX } :{{app_name}}:wake_schedule"
4344 LOOP_MAPPING = f"{ KEY_PREFIX } :{{app_name}}:mapping:{{external_ref_id}}"
4445 LOOP_CONNECTION_INDEX = f"{ KEY_PREFIX } :{{app_name}}:connection_index:{{loop_id}}"
4546 LOOP_CONNECTION_KEY = (
4647 f"{ KEY_PREFIX } :{{app_name}}:connection:{{loop_id}}:{{connection_id}}"
4748 )
4849
4950
51+ # How often to check for missed wake events (seconds)
52+ WAKE_RECONCILIATION_INTERVAL_S = 1.0
53+
54+
5055class RedisStateManager (StateManager ):
5156 def __init__ (
5257 self ,
@@ -73,6 +78,8 @@ def __init__(
7378 )
7479
7580 self .wake_queue : Queue [str ] = wake_queue
81+ self ._stop_wake_monitor = threading .Event ()
82+
7683 if self .wake_queue :
7784 self .wake_thread = threading .Thread (
7885 target = self ._run_wake_monitoring , daemon = True
@@ -81,80 +88,109 @@ def __init__(
8188
8289 def _run_wake_monitoring (self ):
8390 """
84- Background thread that monitors Redis key expiration events for wake-up scheduling.
91+ Background thread for reliable wake scheduling.
92+
93+ Uses a hybrid approach for reliability:
94+ 1. ZSET (sorted set) is the source of truth for all scheduled wakes
95+ 2. Periodic reconciliation checks for due wakes every WAKE_RECONCILIATION_INTERVAL_S
96+ 3. TTL keys + keyspace notifications provide low-latency wake for normal operation
8597
86- This runs in a separate thread because:
87- 1. Redis pub/sub requires a blocking connection
88- 2. We need to react to key expiration events in real-time
98+ This ensures wakes are never missed even if:
99+ - The service was down when a wake was due
100+ - Keyspace notifications were missed
101+ - Redis pub/sub disconnected temporarily
89102 """
90- import redis
103+ import redis as sync_redis
91104 from ..logging import setup_logger
92105
93106 logger = setup_logger (__name__ )
94107
95108 try :
96- rdb = redis .Redis (
109+ rdb = sync_redis .Redis (
97110 host = self .config .host ,
98111 port = self .config .port ,
99112 db = self .config .database ,
100113 password = self .config .password ,
101114 ssl = self .config .ssl ,
102115 )
103116
104- # Enable keyspace notifications for expired events
105- with suppress (redis .exceptions .ResponseError ):
106- rdb .config_set ("notify-keyspace-events" , "Ex" ) # type: ignore
117+ # Enable keyspace notifications (best-effort, not required for reliability)
118+ with suppress (sync_redis .exceptions .ResponseError ):
119+ rdb .config_set ("notify-keyspace-events" , "Ex" )
107120
108- # Check for any wake events that may have been missed during downtime
109- self ._check_missed_wake_events_sync (rdb ) # type: ignore
121+ # Process any wakes that were missed while we were down
122+ self ._process_due_wakes (rdb )
110123
111- pubsub : PubSub = rdb .pubsub () # type: ignore
112- pubsub .psubscribe ("__keyevent@*__:expired" ) # type: ignore
124+ # Set up pub/sub for low-latency wake notifications
125+ pubsub = rdb .pubsub ()
126+ pubsub .psubscribe ("__keyevent@*__:expired" )
113127
114- for message in pubsub .listen (): # type: ignore
115- try :
116- if message ["type" ] == "pmessage" :
117- key : str = message ["data" ].decode ("utf-8" ) # type: ignore
118- if f":{ self .app_name } :wake:" in key :
119- loop_id : str = key .split (":" )[- 1 ] # type: ignore
120-
121- if self .wake_queue :
122- self .wake_queue .put (loop_id ) # type: ignore
123-
124- # Remove the full wake key from the index (matches what we add in set_wake_time)
125- rdb .srem (
126- RedisKeys .LOOP_WAKE_INDEX .format (app_name = self .app_name ),
127- key , # Use full key, not just loop_id
128- )
129- except Exception as e :
130- logger .error (
131- f"Error processing wake event: { e } " ,
132- extra = {"error" : str (e )},
133- )
134-
135- except Exception as e :
136- logger .error (
137- f"Wake monitoring thread error: { e } " ,
138- extra = {"error" : str (e )},
139- )
128+ last_reconciliation = time .time ()
140129
141- def _check_missed_wake_events_sync (self , rdb : redis .Redis ):
142- wake_index : list [bytes ] = rdb .smembers ( # type: ignore
143- RedisKeys .LOOP_WAKE_INDEX .format (app_name = self .app_name )
144- )
130+ while not self ._stop_wake_monitor .is_set ():
131+ # Non-blocking check for keyspace notifications
132+ message = pubsub .get_message (timeout = 0.1 )
133+
134+ if message and message ["type" ] == "pmessage" :
135+ try :
136+ key = message ["data" ].decode ("utf-8" )
137+ if f":{ self .app_name } :wake:" in key :
138+ loop_id = key .split (":" )[- 1 ]
139+ self ._queue_wake (rdb , loop_id )
140+ except Exception as e :
141+ logger .error (f"Error processing wake notification: { e } " )
145142
146- for wake_key_bytes in wake_index :
147- wake_key = wake_key_bytes .decode ("utf-8" )
143+ # Periodic reconciliation - the reliability guarantee
144+ now = time .time ()
145+ if now - last_reconciliation >= WAKE_RECONCILIATION_INTERVAL_S :
146+ self ._process_due_wakes (rdb )
147+ last_reconciliation = now
148148
149- if not rdb . exists ( wake_key ) :
150- loop_id = wake_key . split ( ":" )[ - 1 ]
149+ except Exception as e :
150+ logger . error ( f"Wake monitoring thread error: { e } " )
151151
152- if self .wake_queue :
153- self .wake_queue .put (loop_id )
152+ def _process_due_wakes (self , rdb ) -> int :
153+ """
154+ Process all wakes that are due (score <= now).
155+
156+ Uses ZRANGEBYSCORE to atomically get and remove due entries.
157+ Returns the number of wakes processed.
158+ """
159+ schedule_key = RedisKeys .LOOP_WAKE_SCHEDULE .format (app_name = self .app_name )
160+ now = time .time ()
161+ processed = 0
154162
155- rdb .srem (
156- RedisKeys .LOOP_WAKE_INDEX .format (app_name = self .app_name ), wake_key
157- )
163+ # Get all due wakes (score <= now)
164+ due_wakes : list [bytes ] = rdb .zrangebyscore (schedule_key , "-inf" , now )
165+
166+ for loop_id_bytes in due_wakes :
167+ loop_id = loop_id_bytes .decode ("utf-8" )
168+
169+ # Atomically remove from schedule (only if still there with same score)
170+ # This prevents double-processing in multi-replica scenarios
171+ removed = rdb .zrem (schedule_key , loop_id )
172+
173+ if removed :
174+ self .wake_queue .put (loop_id )
175+ processed += 1
176+
177+ return processed
178+
179+ def _queue_wake (self , rdb , loop_id : str ) -> bool :
180+ """
181+ Queue a wake for a loop, removing it from the schedule.
182+
183+ Returns True if the wake was queued, False if already processed.
184+ """
185+ schedule_key = RedisKeys .LOOP_WAKE_SCHEDULE .format (app_name = self .app_name )
186+
187+ # Remove from schedule - if it was there, queue the wake
188+ removed = rdb .zrem (schedule_key , loop_id )
189+
190+ if removed :
191+ self .wake_queue .put (loop_id )
192+ return True
193+ return False
158194
159195 async def set_loop_mapping (self , external_ref_id : str , loop_id : str ):
160196 await self .rdb .set (
@@ -439,26 +475,30 @@ async def pop_event(
439475
440476 async def set_wake_time (self , loop_id : str , timestamp : float ) -> None :
441477 """
442- Set a wake time for a loop. Uses Redis key expiration with millisecond precision .
478+ Schedule a wake time for a loop.
443479
444- Note: Redis requires TTL >= 1ms. For very short sleeps (< 1ms), we use 1ms minimum.
445- """
446- ttl_seconds = timestamp - time . time ()
480+ Uses two mechanisms for reliability:
481+ 1. ZSET (sorted set) - Source of truth, survives restarts
482+ 2. TTL key - Triggers keyspace notification for low-latency wake
447483
448- if ttl_seconds <= 0 :
484+ The periodic reconciliation in _process_due_wakes ensures wakes
485+ are never missed even if keyspace notifications fail.
486+ """
487+ if timestamp <= time .time ():
449488 raise ValueError ("Timestamp is in the past" )
450489
451- # Convert to milliseconds for pexpire (px parameter), minimum 1ms
452- ttl_ms = max (1 , int (ttl_seconds * 1000 ))
453-
490+ schedule_key = RedisKeys .LOOP_WAKE_SCHEDULE .format (app_name = self .app_name )
454491 wake_key = RedisKeys .LOOP_WAKE_KEY .format (
455492 app_name = self .app_name , loop_id = loop_id
456493 )
457- wake_index = RedisKeys .LOOP_WAKE_INDEX .format (app_name = self .app_name )
458494
459- # Use px (milliseconds) instead of ex (seconds) for better precision
460- await self .rdb .set (wake_key , "wake" , px = ttl_ms )
461- await self .rdb .sadd (wake_index , wake_key ) # pyright: ignore
495+ ttl_ms = max (1 , int ((timestamp - time .time ()) * 1000 ))
496+
497+ # Atomic: add to schedule and set TTL key
498+ async with self .rdb .pipeline (transaction = True ) as pipe :
499+ pipe .zadd (schedule_key , {loop_id : timestamp })
500+ pipe .set (wake_key , "1" , px = ttl_ms )
501+ await pipe .execute ()
462502
463503 async def get_initial_event (self , loop_id : str ) -> "LoopEvent | None" :
464504 """Get the initial event for a loop."""
0 commit comments