Skip to content

Commit fad3046

Browse files
committed
feat: add time index functionality to ListRedisScheduleSource
- Introduced `populate_time_index` parameter to backfill the time index from existing keys. - Updated `startup` method to populate the time index if `populate_time_index` is set to True. - Modified schedule addition and deletion to manage the time index sorted set. - Added tests to verify time index population and cleanup behavior.
1 parent 374c789 commit fad3046

2 files changed

Lines changed: 287 additions & 13 deletions

File tree

taskiq_redis/list_schedule_source.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
serializer: TaskiqSerializer | None = None,
2424
buffer_size: int = 50,
2525
skip_past_schedules: bool = False,
26+
populate_time_index: bool = False,
2627
**connection_kwargs: Any,
2728
) -> None:
2829
"""
@@ -34,6 +35,11 @@ def __init__(
3435
:param serializer: Serializer to use for the schedules
3536
:param buffer_size: Buffer size for getting schedules
3637
:param skip_past_schedules: Skip schedules that are in the past.
38+
:param populate_time_index: If True, on startup run a one-time SCAN
39+
to populate the time index sorted set from existing time keys.
40+
This is needed for migrating from an older version that did not
41+
maintain the time index. Set this to True once to backfill the
42+
index, then set it back to False for subsequent runs.
3743
:param connection_kwargs: Additional connection kwargs
3844
"""
3945
super().__init__()
@@ -51,6 +57,7 @@ def __init__(
5157
self._previous_schedule_source: ScheduleSource | None = None
5258
self._delete_schedules_after_migration: bool = True
5359
self._skip_past_schedules = skip_past_schedules
60+
self._populate_time_index = populate_time_index
5461

5562
async def startup(self) -> None:
5663
"""
@@ -59,6 +66,9 @@ async def startup(self) -> None:
5966
By default this function does nothing.
6067
But if the previous schedule source is set,
6168
it will try to migrate schedules from it.
69+
70+
If populate_time_index is True, it will scan for existing
71+
time keys and populate the time index sorted set.
6272
"""
6373
if self._previous_schedule_source is not None:
6474
logger.info("Migrating schedules from previous source")
@@ -74,13 +84,36 @@ async def startup(self) -> None:
7484
await self._previous_schedule_source.shutdown()
7585
logger.info("Migration complete")
7686

87+
if self._populate_time_index:
88+
logger.info("Populating time index from existing keys via scan")
89+
async with Redis(connection_pool=self._connection_pool) as redis:
90+
batch: dict[str, float] = {}
91+
async for key in redis.scan_iter(f"{self._prefix}:time:*"):
92+
key_str = key.decode()
93+
key_time = self._parse_time_key(key_str)
94+
if key_time:
95+
batch[key_str] = key_time.timestamp()
96+
if len(batch) >= self._buffer_size:
97+
await redis.zadd(
98+
self._get_time_index_key(),
99+
batch,
100+
)
101+
batch = {}
102+
if batch:
103+
await redis.zadd(self._get_time_index_key(), batch)
104+
logger.info("Time index population complete")
105+
77106
def _get_time_key(self, time: datetime.datetime) -> str:
78107
"""Get the key for a time-based schedule."""
79108
if time.tzinfo is None:
80109
time = time.replace(tzinfo=datetime.timezone.utc)
81110
iso_time = time.astimezone(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M")
82111
return f"{self._prefix}:time:{iso_time}"
83112

113+
def _get_time_index_key(self) -> str:
114+
"""Get the key for the time index sorted set."""
115+
return f"{self._prefix}:time_index"
116+
84117
def _get_cron_key(self) -> str:
85118
"""Get the key for a cron-based schedule."""
86119
return f"{self._prefix}:cron"
@@ -111,8 +144,8 @@ async def _get_previous_time_schedules(self) -> list[bytes]:
111144
we need to get all the schedules that are in the past and haven't
112145
been sent yet.
113146
114-
We do this by getting all the time keys and checking if the time
115-
is less than the current time.
147+
Uses the time index sorted set to look up past time keys
148+
instead of scanning all Redis keys.
116149
117150
This function is called only during the first run to minimize
118151
the number of requests to the Redis server.
@@ -125,13 +158,12 @@ async def _get_previous_time_schedules(self) -> list[bytes]:
125158
)
126159
schedules = []
127160
async with Redis(connection_pool=self._connection_pool) as redis:
128-
time_keys: list[str] = []
129-
# We need to get all the time keys and check if the time is less than
130-
# the current time.
131-
async for key in redis.scan_iter(f"{self._prefix}:time:*"):
132-
key_time = self._parse_time_key(key.decode())
133-
if key_time and key_time <= minute_before:
134-
time_keys.append(key.decode())
161+
max_score = minute_before.timestamp()
162+
time_keys: list[bytes] = await redis.zrangebyscore(
163+
self._get_time_index_key(),
164+
"-inf",
165+
max_score,
166+
)
135167
for key in time_keys:
136168
schedules.extend(await redis.lrange(key, 0, -1)) # type: ignore[misc]
137169

@@ -153,6 +185,14 @@ async def delete_schedule(self, schedule_id: str) -> None:
153185
elif schedule.time is not None:
154186
time_key = self._get_time_key(schedule.time)
155187
await redis.lrem(time_key, 0, schedule_id) # type: ignore[misc]
188+
# If the time key list is now empty, clean up both
189+
# the list key and its entry in the time index.
190+
if await redis.llen(time_key) == 0:
191+
await redis.delete(time_key)
192+
await redis.zrem(
193+
self._get_time_index_key(),
194+
time_key,
195+
)
156196
elif schedule.interval:
157197
await redis.lrem(self._get_interval_key(), 0, schedule_id) # type: ignore[misc]
158198

@@ -170,9 +210,21 @@ async def add_schedule(self, schedule: "ScheduledTask") -> None:
170210
if schedule.cron is not None:
171211
await redis.rpush(self._get_cron_key(), schedule.schedule_id) # type: ignore[misc]
172212
elif schedule.time is not None:
173-
await redis.rpush( # type: ignore[misc]
174-
self._get_time_key(schedule.time),
175-
schedule.schedule_id,
213+
time_key = self._get_time_key(schedule.time)
214+
await redis.rpush(time_key, schedule.schedule_id) # type: ignore[misc]
215+
# Add to the time index sorted set so we can look up
216+
# past time keys without scanning all Redis keys.
217+
time_val = schedule.time
218+
if time_val.tzinfo is None:
219+
time_val = time_val.replace(tzinfo=datetime.timezone.utc)
220+
score = (
221+
time_val.astimezone(datetime.timezone.utc)
222+
.replace(second=0, microsecond=0)
223+
.timestamp()
224+
)
225+
await redis.zadd( # type: ignore[misc]
226+
self._get_time_index_key(),
227+
{time_key: score},
176228
)
177229
elif schedule.interval:
178230
await redis.rpush( # type: ignore[misc]
@@ -200,7 +252,7 @@ async def get_schedules(self) -> list["ScheduledTask"]:
200252
current_time = datetime.datetime.now(datetime.timezone.utc)
201253
timed: list[bytes] = []
202254
# Only during first run, we need to get previous time schedules
203-
if not self._skip_past_schedules:
255+
if not self._skip_past_schedules and self._is_first_run:
204256
timed = await self._get_previous_time_schedules()
205257
self._is_first_run = False
206258
async with Redis(connection_pool=self._connection_pool) as redis:

tests/test_list_schedule_source.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from freezegun import freeze_time
6+
from redis.asyncio import BlockingConnectionPool, Redis
67
from taskiq import ScheduledTask
78

89
from taskiq_redis.list_schedule_source import ListRedisScheduleSource
@@ -179,3 +180,224 @@ async def test_migration(redis_url: str) -> None:
179180
for old_schedule in old_schedules:
180181
with freeze_time(old_schedule.time):
181182
assert await source.get_schedules() == [old_schedule]
183+
184+
185+
@pytest.mark.anyio
186+
@freeze_time("2025-01-01 00:00:00")
187+
async def test_time_index_populated_on_add(redis_url: str) -> None:
188+
"""Test that adding a time schedule populates the time index sorted set."""
189+
prefix = uuid.uuid4().hex
190+
source = ListRedisScheduleSource(redis_url, prefix=prefix)
191+
schedule = ScheduledTask(
192+
task_name="test_task",
193+
labels={},
194+
args=[],
195+
kwargs={},
196+
time=datetime.datetime.now(datetime.timezone.utc)
197+
+ datetime.timedelta(minutes=5),
198+
)
199+
await source.add_schedule(schedule)
200+
201+
# Verify the time index sorted set has an entry.
202+
async with Redis(connection_pool=source._connection_pool) as redis:
203+
members = await redis.zrange(source._get_time_index_key(), 0, -1)
204+
assert len(members) == 1
205+
assert members[0].decode() == source._get_time_key(schedule.time)
206+
207+
208+
@pytest.mark.anyio
209+
@freeze_time("2025-01-01 00:00:00")
210+
async def test_time_index_cleaned_on_delete(redis_url: str) -> None:
211+
"""Test that deleting last schedule from a time key cleans the index."""
212+
prefix = uuid.uuid4().hex
213+
source = ListRedisScheduleSource(redis_url, prefix=prefix)
214+
schedule = ScheduledTask(
215+
task_name="test_task",
216+
labels={},
217+
args=[],
218+
kwargs={},
219+
time=datetime.datetime.now(datetime.timezone.utc)
220+
+ datetime.timedelta(minutes=5),
221+
)
222+
await source.add_schedule(schedule)
223+
224+
# Index has 1 entry.
225+
async with Redis(connection_pool=source._connection_pool) as redis:
226+
assert await redis.zcard(source._get_time_index_key()) == 1
227+
228+
await source.delete_schedule(schedule.schedule_id)
229+
230+
# After deletion, the index should be empty.
231+
async with Redis(connection_pool=source._connection_pool) as redis:
232+
assert await redis.zcard(source._get_time_index_key()) == 0
233+
# The time key list itself should also be deleted.
234+
assert not await redis.exists(source._get_time_key(schedule.time))
235+
236+
237+
@pytest.mark.anyio
238+
@freeze_time("2025-01-01 00:00:00")
239+
async def test_time_index_not_cleaned_when_other_schedules_remain(
240+
redis_url: str,
241+
) -> None:
242+
"""Test that deleting one schedule doesn't remove the index entry
243+
when other schedules still exist at the same time."""
244+
prefix = uuid.uuid4().hex
245+
source = ListRedisScheduleSource(redis_url, prefix=prefix)
246+
schedule_time = datetime.datetime.now(
247+
datetime.timezone.utc,
248+
) + datetime.timedelta(minutes=5)
249+
schedule1 = ScheduledTask(
250+
task_name="test_task_1",
251+
labels={},
252+
args=[],
253+
kwargs={},
254+
time=schedule_time,
255+
)
256+
schedule2 = ScheduledTask(
257+
task_name="test_task_2",
258+
labels={},
259+
args=[],
260+
kwargs={},
261+
time=schedule_time,
262+
)
263+
await source.add_schedule(schedule1)
264+
await source.add_schedule(schedule2)
265+
266+
await source.delete_schedule(schedule1.schedule_id)
267+
268+
# Index should still have the entry because schedule2 remains.
269+
async with Redis(connection_pool=source._connection_pool) as redis:
270+
assert await redis.zcard(source._get_time_index_key()) == 1
271+
272+
await source.delete_schedule(schedule2.schedule_id)
273+
274+
# Now the index should be empty.
275+
async with Redis(connection_pool=source._connection_pool) as redis:
276+
assert await redis.zcard(source._get_time_index_key()) == 0
277+
278+
279+
@pytest.mark.anyio
280+
@freeze_time("2025-01-01 00:00:00")
281+
async def test_past_schedules_found_via_time_index(redis_url: str) -> None:
282+
"""Test that past schedules are discovered via the time index
283+
instead of a full SCAN."""
284+
prefix = uuid.uuid4().hex
285+
source = ListRedisScheduleSource(redis_url, prefix=prefix)
286+
past_time = datetime.datetime.now(
287+
datetime.timezone.utc,
288+
) - datetime.timedelta(minutes=5)
289+
schedule = ScheduledTask(
290+
task_name="test_task",
291+
labels={},
292+
args=[],
293+
kwargs={},
294+
time=past_time,
295+
)
296+
await source.add_schedule(schedule)
297+
298+
# First call to get_schedules should find the past schedule via time index.
299+
schedules = await source.get_schedules()
300+
assert schedules == [schedule]
301+
302+
303+
@pytest.mark.anyio
304+
@freeze_time("2025-01-01 00:00:00")
305+
async def test_populate_time_index_from_existing_keys(redis_url: str) -> None:
306+
"""Test that populate_time_index=True backfills the sorted set
307+
from existing time keys created without the index."""
308+
prefix = uuid.uuid4().hex
309+
310+
# Simulate old-style data: create time key lists directly in Redis
311+
# without populating the time index sorted set.
312+
pool = BlockingConnectionPool.from_url(url=redis_url)
313+
past_times = [
314+
datetime.datetime(2024, 12, 31, 23, 55, tzinfo=datetime.timezone.utc),
315+
datetime.datetime(2024, 12, 31, 23, 56, tzinfo=datetime.timezone.utc),
316+
datetime.datetime(2024, 12, 31, 23, 57, tzinfo=datetime.timezone.utc),
317+
]
318+
319+
source_for_keys = ListRedisScheduleSource(redis_url, prefix=prefix)
320+
async with Redis(connection_pool=pool) as redis:
321+
for t in past_times:
322+
time_key = source_for_keys._get_time_key(t)
323+
# Push a dummy schedule ID directly (bypassing add_schedule
324+
# to simulate old behavior without time index).
325+
await redis.rpush(time_key, f"sched_{t.minute}") # type: ignore[misc]
326+
327+
# Verify no time index exists yet.
328+
assert await redis.zcard(source_for_keys._get_time_index_key()) == 0
329+
await pool.disconnect()
330+
331+
# Now create a source with populate_time_index=True.
332+
source = ListRedisScheduleSource(
333+
redis_url,
334+
prefix=prefix,
335+
populate_time_index=True,
336+
)
337+
await source.startup()
338+
339+
# The time index should now be populated.
340+
async with Redis(connection_pool=source._connection_pool) as redis:
341+
count = await redis.zcard(source._get_time_index_key())
342+
assert count == len(past_times)
343+
344+
345+
@pytest.mark.anyio
346+
@freeze_time("2025-01-01 00:00:00")
347+
async def test_post_send_cleans_time_index(redis_url: str) -> None:
348+
"""Test that post_send (which calls delete_schedule for time tasks)
349+
properly cleans up the time index."""
350+
prefix = uuid.uuid4().hex
351+
source = ListRedisScheduleSource(redis_url, prefix=prefix)
352+
schedule = ScheduledTask(
353+
task_name="test_task",
354+
labels={},
355+
args=[],
356+
kwargs={},
357+
time=datetime.datetime.now(datetime.timezone.utc)
358+
- datetime.timedelta(minutes=3),
359+
)
360+
await source.add_schedule(schedule)
361+
362+
# First run picks up past schedules.
363+
schedules = await source.get_schedules()
364+
assert schedules == [schedule]
365+
366+
# Simulate sending the task.
367+
for s in schedules:
368+
await source.post_send(s)
369+
370+
# Time index should be empty now.
371+
async with Redis(connection_pool=source._connection_pool) as redis:
372+
assert await redis.zcard(source._get_time_index_key()) == 0
373+
374+
# Second run should return nothing.
375+
schedules = await source.get_schedules()
376+
assert schedules == []
377+
378+
379+
@pytest.mark.anyio
380+
@freeze_time("2025-01-01 00:00:00")
381+
async def test_cron_and_interval_not_in_time_index(redis_url: str) -> None:
382+
"""Test that cron and interval schedules do not affect the time index."""
383+
prefix = uuid.uuid4().hex
384+
source = ListRedisScheduleSource(redis_url, prefix=prefix)
385+
cron_schedule = ScheduledTask(
386+
task_name="cron_task",
387+
labels={},
388+
args=[],
389+
kwargs={},
390+
cron="* * * * *",
391+
)
392+
interval_schedule = ScheduledTask(
393+
task_name="interval_task",
394+
labels={},
395+
args=[],
396+
kwargs={},
397+
interval=datetime.timedelta(seconds=30),
398+
)
399+
await source.add_schedule(cron_schedule)
400+
await source.add_schedule(interval_schedule)
401+
402+
async with Redis(connection_pool=source._connection_pool) as redis:
403+
assert await redis.zcard(source._get_time_index_key()) == 0

0 commit comments

Comments
 (0)