Skip to content

Commit e2a9448

Browse files
authored
fix: preserve JSON-safe checkpoint serialization for msgpack async saves (#182)
* fix: normalize msgpack checkpoints for RedisJSON
* fix: preserve msgpack checkpoints with non-string keys
1 parent 1b533ee commit e2a9448

3 files changed

Lines changed: 148 additions & 10 deletions

File tree

langgraph/checkpoint/redis/base.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,17 +276,14 @@ def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]:
276276
if type_ == "json":
277277
checkpoint_data = cast(dict, orjson.loads(data))
278278
else:
279-
# For msgpack or other types, deserialize with loads_typed
280279
checkpoint_data = cast(dict, self.serde.loads_typed((type_, data)))
281-
282-
# When using msgpack, bytes are preserved - but Redis JSON.SET can't handle them
283-
# Encode bytes in channel_values with type marker for JSON storage
284-
if "channel_values" in checkpoint_data:
285-
for key, value in checkpoint_data["channel_values"].items():
286-
if isinstance(value, bytes):
287-
checkpoint_data["channel_values"][key] = {
288-
"__bytes__": self._encode_blob(value)
289-
}
280+
if type_ == "msgpack":
281+
# Msgpack fallback can rehydrate LangChain messages as live Python
282+
# objects. Normalize the checkpoint back through the JSON serializer
283+
# so RedisJSON only sees JSON-safe constructor dictionaries.
284+
checkpoint_data = cast(
285+
dict, self._msgpack_to_redis_json(checkpoint_data)
286+
)
290287

291288
# Ensure channel_versions are always strings to fix issue #40
292289
if "channel_versions" in checkpoint_data:
@@ -296,6 +293,30 @@ def _dump_checkpoint(self, checkpoint: Checkpoint) -> dict[str, Any]:
296293

297294
return {"type": type_, **checkpoint_data, "pending_sends": []}
298295

296+
def _msgpack_to_redis_json(self, value: Any) -> dict[str, Any]:
    """Normalize a msgpack-deserialized checkpoint into JSON-safe data.

    The value is first stripped of raw ``bytes`` (replaced by
    ``{"__bytes__": ...}`` markers), then passed through the serializer's
    interrupt preprocessing, and finally round-tripped through orjson so
    that only plain JSON types remain.

    Args:
        value: The checkpoint data produced by the msgpack deserializer.

    Returns:
        The result of decoding the orjson-encoded checkpoint again.
    """
    # Raw bytes are swapped for marker dicts before anything reaches orjson.
    without_bytes = self._replace_binary_markers(value)

    json_serde = cast(JsonPlusRedisSerializer, self.serde)
    encoded = orjson.dumps(
        json_serde._preprocess_interrupts(without_bytes),
        # Objects orjson cannot encode natively are delegated to the serde.
        default=json_serde._default_handler,
        # Permit dict keys that are not strings (e.g. integer keys).
        option=orjson.OPT_NON_STR_KEYS,
    )

    decoded = orjson.loads(encoded)
    return cast(dict, decoded)
307+
308+
def _replace_binary_markers(self, value: Any) -> Any:
    """Return a copy of *value* with every ``bytes`` object replaced by a marker.

    Dicts, lists and tuples are walked recursively; each ``bytes`` leaf
    becomes ``{"__bytes__": <encoded blob>}``. Any other value is returned
    unchanged.

    Args:
        value: Arbitrary (possibly nested) data.

    Returns:
        The converted structure; containers are rebuilt as plain ``dict``,
        ``list`` or ``tuple`` objects.
    """
    convert = self._replace_binary_markers

    if isinstance(value, bytes):
        return {"__bytes__": self._encode_blob(value)}

    if isinstance(value, dict):
        # Only values are converted; keys are carried over untouched.
        converted_mapping = {}
        for key, item in value.items():
            converted_mapping[key] = convert(item)
        return converted_mapping

    if isinstance(value, (list, tuple)):
        converted_items = [convert(item) for item in value]
        # Preserve the list/tuple distinction of the input container.
        if isinstance(value, list):
            return converted_items
        return tuple(converted_items)

    return value
319+
299320
def _deserialize_channel_values(
300321
self, channel_values: dict[str, Any]
301322
) -> dict[str, Any]:

tests/test_checkpoint_serialization.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,91 @@ def test_checkpoint_with_messages(redis_url: str) -> None:
561561
assert loaded_messages[1].content == "Let me check that for you."
562562

563563

564+
def test_issue_181_msgpack_checkpoint_keeps_messages_json_safe() -> None:
    """Check that a msgpack-typed dump holds only JSON-safe values and reloads intact."""
    from langchain_core.messages import AIMessage, HumanMessage

    from langgraph.checkpoint.redis.base import BaseRedisSaver

    class DummySaver(BaseRedisSaver):
        # Both hooks are stubbed out; this test only exercises dump/load helpers.
        def create_indexes(self) -> None:
            pass

        def configure_client(self, **kwargs) -> None:
            self._redis = None

    saver = DummySaver(redis_client=object())

    conversation = [
        HumanMessage(content="What is the weather in SF?"),
        AIMessage(content="Let me check that for you."),
    ]
    channel_state = {"messages": conversation, "binary": b"abc"}

    checkpoint = create_checkpoint(
        checkpoint=empty_checkpoint(),
        channels=dict(channel_state),
        step=1,
    )
    # Pin the channel values explicitly to the raw messages and bytes.
    checkpoint["channel_values"].update(channel_state)

    dumped = saver._dump_checkpoint(checkpoint)

    # The bytes entry is expected to push serialization onto the msgpack
    # path; the dumped document must nevertheless contain JSON-safe data only.
    assert dumped["type"] == "msgpack"
    stored_values = dumped["channel_values"]
    first_stored_message = stored_values["messages"][0]
    assert isinstance(first_stored_message, dict)
    assert first_stored_message["id"][-1] == "HumanMessage"
    assert stored_values["binary"] == {"__bytes__": "YWJj"}

    restored = saver._load_checkpoint(
        dumped,
        saver._deserialize_channel_values(stored_values),
        [],
    )

    restored_values = restored["channel_values"]
    restored_messages = restored_values["messages"]
    for position, expected_type in enumerate((HumanMessage, AIMessage)):
        assert isinstance(restored_messages[position], expected_type)
    assert restored_values["binary"] == b"abc"
610+
611+
612+
def test_msgpack_checkpoint_with_non_string_keys_remains_json_safe() -> None:
    """Check that integer dict keys are stringified when a msgpack checkpoint is dumped."""
    from langchain_core.messages import HumanMessage

    from langgraph.checkpoint.redis.base import BaseRedisSaver

    class DummySaver(BaseRedisSaver):
        # Both hooks are stubbed out; this test only exercises the dump helper.
        def create_indexes(self) -> None:
            pass

        def configure_client(self, **kwargs) -> None:
            self._redis = None

    saver = DummySaver(redis_client=object())

    conversation = [HumanMessage(content="hello")]
    channel_state = {
        "messages": conversation,
        "binary": b"abc",
        "mapping": {1: "one", 2: "two"},
    }

    checkpoint = create_checkpoint(
        checkpoint=empty_checkpoint(),
        channels={
            "messages": conversation,
            "binary": b"abc",
            "mapping": {1: "one", 2: "two"},
        },
        step=1,
    )
    # Pin the channel values explicitly, including the integer-keyed mapping.
    for channel_name, channel_value in channel_state.items():
        checkpoint["channel_values"][channel_name] = channel_value

    dumped = saver._dump_checkpoint(checkpoint)

    assert dumped["type"] == "msgpack"
    stored_values = dumped["channel_values"]
    assert stored_values["mapping"] == {"1": "one", "2": "two"}
    assert isinstance(stored_values["messages"][0], dict)
    assert stored_values["binary"] == {"__bytes__": "YWJj"}
647+
648+
564649
def test_subgraph_state_history_pending_sends(redis_url: str) -> None:
565650
"""Test that get_state_history with subgraphs properly handles pending_sends.
566651

tests/test_issue_87_async_deserialization.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,38 @@ async def test_async_deserializes_langchain_messages(redis_url: str):
115115
assert loaded_messages[3].tool_call_id == "call-1"
116116

117117

118+
@pytest.mark.asyncio
async def test_issue_181_async_checkpoint_with_messages_and_bytes(redis_url: str):
    """Round-trip a checkpoint holding both messages and raw bytes via the async saver."""
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = str(uuid4())
        conversation = [
            HumanMessage(content="What's the weather like?", id="human-1"),
            AIMessage(content="I'll help you check the weather.", id="ai-1"),
        ]
        channel_state = {"messages": conversation, "binary": b"abc"}

        checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(),
            channels=dict(channel_state),
            step=1,
        )
        # Pin the channel values explicitly to the raw messages and bytes.
        checkpoint["channel_values"].update(channel_state)

        config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
        metadata = {"source": "test", "step": 1, "writes": {}}

        saved_config = await saver.aput(config, checkpoint, metadata, {})
        loaded_tuple = await saver.aget_tuple(saved_config)

        assert loaded_tuple is not None
        loaded_values = loaded_tuple.checkpoint["channel_values"]
        loaded_messages = loaded_values["messages"]
        for position, expected_type in enumerate((HumanMessage, AIMessage)):
            assert isinstance(loaded_messages[position], expected_type)
        assert loaded_values["binary"] == b"abc"
148+
149+
118150
@pytest.mark.asyncio
119151
async def test_async_handles_serialized_langchain_format(redis_url: str):
120152
"""Test that async handles the serialized LangChain format that causes MESSAGE_COERCION_FAILURE.

0 commit comments

Comments
 (0)