Skip to content

Commit da84238

Browse files
Refactor Redis wake scheduling for reliability
Co-authored-by: luke <luke@smartshare.io>
1 parent 29c5e59 commit da84238

2 files changed

Lines changed: 198 additions & 132 deletions

File tree

fastloop/state/state_redis.py

Lines changed: 105 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,19 @@ class RedisKeys:
3939
LOOP_NONCE = f"{KEY_PREFIX}:{{app_name}}:nonce:{{loop_id}}"
4040
LOOP_EVENT_CHANNEL = f"{KEY_PREFIX}:{{app_name}}:events:{{loop_id}}:notify"
4141
LOOP_WAKE_KEY = f"{KEY_PREFIX}:{{app_name}}:wake:{{loop_id}}"
42-
LOOP_WAKE_INDEX = f"{KEY_PREFIX}:{{app_name}}:wake_index"
42+
# Sorted set: member=loop_id, score=wake_timestamp (source of truth for scheduling)
43+
LOOP_WAKE_SCHEDULE = f"{KEY_PREFIX}:{{app_name}}:wake_schedule"
4344
LOOP_MAPPING = f"{KEY_PREFIX}:{{app_name}}:mapping:{{external_ref_id}}"
4445
LOOP_CONNECTION_INDEX = f"{KEY_PREFIX}:{{app_name}}:connection_index:{{loop_id}}"
4546
LOOP_CONNECTION_KEY = (
4647
f"{KEY_PREFIX}:{{app_name}}:connection:{{loop_id}}:{{connection_id}}"
4748
)
4849

4950

51+
# How often to check for missed wake events (seconds)
52+
WAKE_RECONCILIATION_INTERVAL_S = 1.0
53+
54+
5055
class RedisStateManager(StateManager):
5156
def __init__(
5257
self,
@@ -73,6 +78,8 @@ def __init__(
7378
)
7479

7580
self.wake_queue: Queue[str] = wake_queue
81+
self._stop_wake_monitor = threading.Event()
82+
7683
if self.wake_queue:
7784
self.wake_thread = threading.Thread(
7885
target=self._run_wake_monitoring, daemon=True
@@ -81,80 +88,109 @@ def __init__(
8188

8289
def _run_wake_monitoring(self):
8390
"""
84-
Background thread that monitors Redis key expiration events for wake-up scheduling.
91+
Background thread for reliable wake scheduling.
92+
93+
Uses a hybrid approach for reliability:
94+
1. ZSET (sorted set) is the source of truth for all scheduled wakes
95+
2. Periodic reconciliation checks for due wakes every WAKE_RECONCILIATION_INTERVAL_S
96+
3. TTL keys + keyspace notifications provide low-latency wake for normal operation
8597
86-
This runs in a separate thread because:
87-
1. Redis pub/sub requires a blocking connection
88-
2. We need to react to key expiration events in real-time
98+
This ensures wakes are never missed even if:
99+
- The service was down when a wake was due
100+
- Keyspace notifications were missed
101+
- Redis pub/sub disconnected temporarily
89102
"""
90-
import redis
103+
import redis as sync_redis
91104
from ..logging import setup_logger
92105

93106
logger = setup_logger(__name__)
94107

95108
try:
96-
rdb = redis.Redis(
109+
rdb = sync_redis.Redis(
97110
host=self.config.host,
98111
port=self.config.port,
99112
db=self.config.database,
100113
password=self.config.password,
101114
ssl=self.config.ssl,
102115
)
103116

104-
# Enable keyspace notifications for expired events
105-
with suppress(redis.exceptions.ResponseError):
106-
rdb.config_set("notify-keyspace-events", "Ex") # type: ignore
117+
# Enable keyspace notifications (best-effort, not required for reliability)
118+
with suppress(sync_redis.exceptions.ResponseError):
119+
rdb.config_set("notify-keyspace-events", "Ex")
107120

108-
# Check for any wake events that may have been missed during downtime
109-
self._check_missed_wake_events_sync(rdb) # type: ignore
121+
# Process any wakes that were missed while we were down
122+
self._process_due_wakes(rdb)
110123

111-
pubsub: PubSub = rdb.pubsub() # type: ignore
112-
pubsub.psubscribe("__keyevent@*__:expired") # type: ignore
124+
# Set up pub/sub for low-latency wake notifications
125+
pubsub = rdb.pubsub()
126+
pubsub.psubscribe("__keyevent@*__:expired")
113127

114-
for message in pubsub.listen(): # type: ignore
115-
try:
116-
if message["type"] == "pmessage":
117-
key: str = message["data"].decode("utf-8") # type: ignore
118-
if f":{self.app_name}:wake:" in key:
119-
loop_id: str = key.split(":")[-1] # type: ignore
120-
121-
if self.wake_queue:
122-
self.wake_queue.put(loop_id) # type: ignore
123-
124-
# Remove the full wake key from the index (matches what we add in set_wake_time)
125-
rdb.srem(
126-
RedisKeys.LOOP_WAKE_INDEX.format(app_name=self.app_name),
127-
key, # Use full key, not just loop_id
128-
)
129-
except Exception as e:
130-
logger.error(
131-
f"Error processing wake event: {e}",
132-
extra={"error": str(e)},
133-
)
134-
135-
except Exception as e:
136-
logger.error(
137-
f"Wake monitoring thread error: {e}",
138-
extra={"error": str(e)},
139-
)
128+
last_reconciliation = time.time()
140129

141-
def _check_missed_wake_events_sync(self, rdb: redis.Redis):
142-
wake_index: list[bytes] = rdb.smembers( # type: ignore
143-
RedisKeys.LOOP_WAKE_INDEX.format(app_name=self.app_name)
144-
)
130+
while not self._stop_wake_monitor.is_set():
131+
# Non-blocking check for keyspace notifications
132+
message = pubsub.get_message(timeout=0.1)
133+
134+
if message and message["type"] == "pmessage":
135+
try:
136+
key = message["data"].decode("utf-8")
137+
if f":{self.app_name}:wake:" in key:
138+
loop_id = key.split(":")[-1]
139+
self._queue_wake(rdb, loop_id)
140+
except Exception as e:
141+
logger.error(f"Error processing wake notification: {e}")
145142

146-
for wake_key_bytes in wake_index:
147-
wake_key = wake_key_bytes.decode("utf-8")
143+
# Periodic reconciliation - the reliability guarantee
144+
now = time.time()
145+
if now - last_reconciliation >= WAKE_RECONCILIATION_INTERVAL_S:
146+
self._process_due_wakes(rdb)
147+
last_reconciliation = now
148148

149-
if not rdb.exists(wake_key):
150-
loop_id = wake_key.split(":")[-1]
149+
except Exception as e:
150+
logger.error(f"Wake monitoring thread error: {e}")
151151

152-
if self.wake_queue:
153-
self.wake_queue.put(loop_id)
152+
def _process_due_wakes(self, rdb) -> int:
153+
"""
154+
Process all wakes that are due (score <= now).
155+
156+
Uses ZRANGEBYSCORE to atomically get and remove due entries.
157+
Returns the number of wakes processed.
158+
"""
159+
schedule_key = RedisKeys.LOOP_WAKE_SCHEDULE.format(app_name=self.app_name)
160+
now = time.time()
161+
processed = 0
154162

155-
rdb.srem(
156-
RedisKeys.LOOP_WAKE_INDEX.format(app_name=self.app_name), wake_key
157-
)
163+
# Get all due wakes (score <= now)
164+
due_wakes: list[bytes] = rdb.zrangebyscore(schedule_key, "-inf", now)
165+
166+
for loop_id_bytes in due_wakes:
167+
loop_id = loop_id_bytes.decode("utf-8")
168+
169+
# Atomically remove from schedule (only if still there with same score)
170+
# This prevents double-processing in multi-replica scenarios
171+
removed = rdb.zrem(schedule_key, loop_id)
172+
173+
if removed:
174+
self.wake_queue.put(loop_id)
175+
processed += 1
176+
177+
return processed
178+
179+
def _queue_wake(self, rdb, loop_id: str) -> bool:
180+
"""
181+
Queue a wake for a loop, removing it from the schedule.
182+
183+
Returns True if the wake was queued, False if already processed.
184+
"""
185+
schedule_key = RedisKeys.LOOP_WAKE_SCHEDULE.format(app_name=self.app_name)
186+
187+
# Remove from schedule - if it was there, queue the wake
188+
removed = rdb.zrem(schedule_key, loop_id)
189+
190+
if removed:
191+
self.wake_queue.put(loop_id)
192+
return True
193+
return False
158194

159195
async def set_loop_mapping(self, external_ref_id: str, loop_id: str):
160196
await self.rdb.set(
@@ -439,26 +475,30 @@ async def pop_event(
439475

440476
async def set_wake_time(self, loop_id: str, timestamp: float) -> None:
441477
"""
442-
Set a wake time for a loop. Uses Redis key expiration with millisecond precision.
478+
Schedule a wake time for a loop.
443479
444-
Note: Redis requires TTL >= 1ms. For very short sleeps (< 1ms), we use 1ms minimum.
445-
"""
446-
ttl_seconds = timestamp - time.time()
480+
Uses two mechanisms for reliability:
481+
1. ZSET (sorted set) - Source of truth, survives restarts
482+
2. TTL key - Triggers keyspace notification for low-latency wake
447483
448-
if ttl_seconds <= 0:
484+
The periodic reconciliation in _process_due_wakes ensures wakes
485+
are never missed even if keyspace notifications fail.
486+
"""
487+
if timestamp <= time.time():
449488
raise ValueError("Timestamp is in the past")
450489

451-
# Convert to milliseconds for pexpire (px parameter), minimum 1ms
452-
ttl_ms = max(1, int(ttl_seconds * 1000))
453-
490+
schedule_key = RedisKeys.LOOP_WAKE_SCHEDULE.format(app_name=self.app_name)
454491
wake_key = RedisKeys.LOOP_WAKE_KEY.format(
455492
app_name=self.app_name, loop_id=loop_id
456493
)
457-
wake_index = RedisKeys.LOOP_WAKE_INDEX.format(app_name=self.app_name)
458494

459-
# Use px (milliseconds) instead of ex (seconds) for better precision
460-
await self.rdb.set(wake_key, "wake", px=ttl_ms)
461-
await self.rdb.sadd(wake_index, wake_key) # pyright: ignore
495+
ttl_ms = max(1, int((timestamp - time.time()) * 1000))
496+
497+
# Atomic: add to schedule and set TTL key
498+
async with self.rdb.pipeline(transaction=True) as pipe:
499+
pipe.zadd(schedule_key, {loop_id: timestamp})
500+
pipe.set(wake_key, "1", px=ttl_ms)
501+
await pipe.execute()
462502

463503
async def get_initial_event(self, loop_id: str) -> "LoopEvent | None":
464504
"""Get the initial event for a loop."""

0 commit comments

Comments
 (0)