Skip to content

Commit 4d6c0c9

Browse files
Igor Terekhov authored and bsbodden committed
feat: add adelete_thread method implementation for async shallow saver
1 parent 6456581 commit 4d6c0c9

2 files changed

Lines changed: 212 additions & 0 deletions

File tree

langgraph/checkpoint/redis/ashallow.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,3 +876,66 @@ def _extract_fallback_timestamp(self, checkpoint: Checkpoint) -> float:
876876
else:
877877
return ts_value
878878
return time.time() * MILLISECONDS_PER_SECOND
879+
880+
async def adelete_thread(self, thread_id: str) -> None:
    """Delete all checkpoints and pending writes stored for *thread_id*.

    Both search indices (checkpoints and checkpoint writes) are queried for
    documents tagged with the thread ID; every Redis key they map to is
    removed, together with the per-namespace ``write_keys_zset`` registry
    keys, in a single pipeline round trip.

    Args:
        thread_id: The thread ID whose checkpoints and writes should be
            deleted.
    """
    storage_safe_thread_id = to_storage_safe_id(thread_id)

    # Despite the fact that the shallow saver stores only the current
    # version of a checkpoint, there may be several documents for one
    # thread while using subgraphs (one per checkpoint namespace).
    checkpoint_query = FilterQuery(
        filter_expression=Tag("thread_id") == thread_id,
        return_fields=["checkpoint_ns", "checkpoint_id"],
        num_results=10000,
    )

    checkpoint_results = await self.checkpoints_index.search(checkpoint_query)

    # Collect keys in a set: the shallow checkpoint key depends only on
    # (thread_id, checkpoint_ns), so several index documents in the same
    # namespace would otherwise enqueue the same DEL more than once.
    keys_to_delete = set()
    checkpoint_namespaces = set()

    for doc in checkpoint_results.docs:
        checkpoint_ns = getattr(doc, "checkpoint_ns", "")
        # Collect namespaces to clean write_keys_zset later.
        checkpoint_namespaces.add(checkpoint_ns)
        keys_to_delete.add(
            self._make_shallow_redis_checkpoint_key(thread_id, checkpoint_ns)
        )

    checkpoint_writes_query = FilterQuery(
        filter_expression=Tag("thread_id") == thread_id,
        return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
        num_results=10000,
    )
    checkpoint_writes_results = await self.checkpoint_writes_index.search(
        checkpoint_writes_query
    )
    for doc in checkpoint_writes_results.docs:
        checkpoint_ns = getattr(doc, "checkpoint_ns", "")
        checkpoint_id = getattr(doc, "checkpoint_id", "")
        task_id = getattr(doc, "task_id", "")
        idx = getattr(doc, "idx", 0)
        keys_to_delete.add(
            self._make_redis_checkpoint_writes_key(
                thread_id, checkpoint_ns, checkpoint_id, task_id, idx
            )
        )
        checkpoint_namespaces.add(checkpoint_ns)

    # Also drop the per-namespace registry of write keys.
    for checkpoint_ns in checkpoint_namespaces:
        keys_to_delete.add(
            f"write_keys_zset:{storage_safe_thread_id}:{to_storage_safe_str(checkpoint_ns)}:shallow"
        )

    if not keys_to_delete:
        # Nothing stored for this thread; skip the empty round trip.
        return

    # Use a pipeline so all deletions go out in one round trip.
    pipeline = self._redis.pipeline()
    for key in keys_to_delete:
        pipeline.delete(key)
    await pipeline.execute()
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Integration tests for AsyncShallowRedisSaver.adelete_thread."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Dict, List, Tuple
6+
7+
import pytest
8+
from langchain_core.runnables import RunnableConfig
9+
from langgraph.checkpoint.base import (
10+
WRITES_IDX_MAP,
11+
CheckpointMetadata,
12+
create_checkpoint,
13+
empty_checkpoint,
14+
)
15+
16+
from langgraph.checkpoint.redis.ashallow import AsyncShallowRedisSaver
17+
from langgraph.checkpoint.redis.util import to_storage_safe_id, to_storage_safe_str
18+
19+
20+
def _expected_write_keys(
    *,
    saver: AsyncShallowRedisSaver,
    thread_id: str,
    checkpoint_ns: str,
    checkpoint_id: str,
    task_id: str,
    writes: List[Tuple[str, Any]],
) -> List[str]:
    """Return the exact Redis keys that ``aput_writes`` creates for *writes*.

    Special channels take their index from ``WRITES_IDX_MAP``; everything
    else falls back to the write's position in the sequence.
    """
    return [
        saver._make_redis_checkpoint_writes_key_cached(  # noqa: SLF001
            thread_id,
            checkpoint_ns,
            checkpoint_id,
            task_id,
            WRITES_IDX_MAP.get(channel, position),
        )
        for position, (channel, _payload) in enumerate(writes)
    ]
40+
41+
@pytest.mark.asyncio
async def test_adelete_thread_cleans_shallow_checkpoints_writes_and_registry(
    redis_url: str, async_client
) -> None:
    """adelete_thread removes checkpoints, writes, and zset registry keys.

    Seeds two namespaces (simulating subgraphs) plus an unrelated thread,
    then verifies deletion is complete for the target thread and leaves the
    other thread untouched.
    """
    thread_id = "test-ashallow-adelete-thread"
    other_thread_id = "test-ashallow-adelete-thread-other"

    # Two namespaces to simulate subgraph usage in shallow mode.
    namespaces = ["", "inner"]

    async with AsyncShallowRedisSaver.from_conn_string(redis_url) as saver:
        # Per-namespace record of the saved config and every key we expect
        # adelete_thread to remove.
        created: Dict[str, Dict[str, Any]] = {}

        for checkpoint_ns in namespaces:
            config: RunnableConfig = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                }
            }
            checkpoint = create_checkpoint(empty_checkpoint(), {}, 1)
            metadata: CheckpointMetadata = {"source": "input", "step": 1, "writes": {}}

            saved_config = await saver.aput(config, checkpoint, metadata, {})
            checkpoint_id = saved_config["configurable"]["checkpoint_id"]

            # Create a couple writes and record expected keys.
            writes = [("channel1", "value1"), ("channel2", "value2")]
            task_id = f"task-{checkpoint_ns or 'root'}"
            await saver.aput_writes(saved_config, writes, task_id)

            # Private key builders are used on purpose so the assertions can
            # hit Redis directly instead of going through the search index.
            checkpoint_key = (
                saver._make_shallow_redis_checkpoint_key_cached(  # noqa: SLF001
                    thread_id, checkpoint_ns
                )
            )
            zset_key = (
                f"write_keys_zset:{to_storage_safe_id(thread_id)}:"
                f"{to_storage_safe_str(checkpoint_ns)}:shallow"
            )
            write_keys = _expected_write_keys(
                saver=saver,
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
                checkpoint_id=checkpoint_id,
                task_id=task_id,
                writes=writes,
            )

            created[checkpoint_ns] = {
                "saved_config": saved_config,
                "checkpoint_key": checkpoint_key,
                "zset_key": zset_key,
                "write_keys": write_keys,
            }

        # Also create a checkpoint for a different thread that must not be deleted.
        other_config: RunnableConfig = {
            "configurable": {"thread_id": other_thread_id, "checkpoint_ns": ""}
        }
        other_checkpoint = create_checkpoint(empty_checkpoint(), {}, 1)
        other_saved_config = await saver.aput(
            other_config,
            other_checkpoint,
            {"source": "input", "step": 1, "writes": {}},
            {},
        )
        other_checkpoint_key = (
            saver._make_shallow_redis_checkpoint_key_cached(  # noqa: SLF001
                other_thread_id, ""
            )
        )

        # Assert keys exist before deletion (direct key checks; avoids index lag).
        assert await async_client.exists(other_checkpoint_key) == 1
        assert await saver.aget_tuple(other_saved_config) is not None

        for checkpoint_ns in namespaces:
            checkpoint_key = created[checkpoint_ns]["checkpoint_key"]
            zset_key = created[checkpoint_ns]["zset_key"]
            write_keys = created[checkpoint_ns]["write_keys"]

            assert await async_client.exists(checkpoint_key) == 1
            assert await async_client.exists(zset_key) == 1
            # The registry zset should track exactly one member per write.
            assert await async_client.zcard(zset_key) == len(write_keys)

            for key in write_keys:
                assert await async_client.exists(key) == 1

        # Delete everything for thread_id.
        await saver.adelete_thread(thread_id)

        # The other thread should still exist.
        assert await async_client.exists(other_checkpoint_key) == 1
        assert await saver.aget_tuple(other_saved_config) is not None

        # Keys for thread_id should be gone.
        for checkpoint_ns in namespaces:
            saved_config = created[checkpoint_ns]["saved_config"]
            checkpoint_key = created[checkpoint_ns]["checkpoint_key"]
            zset_key = created[checkpoint_ns]["zset_key"]
            write_keys = created[checkpoint_ns]["write_keys"]

            # Checked both through the saver API and raw Redis keys.
            assert await saver.aget_tuple(saved_config) is None
            assert await async_client.exists(checkpoint_key) == 0
            assert await async_client.exists(zset_key) == 0

            for key in write_keys:
                assert await async_client.exists(key) == 0

0 commit comments

Comments
 (0)