Skip to content

Commit 9b180b0

Browse files
Fix rebase issues: restore missing _conversation_state.py and checkpoint decode logic
- Add back _conversation_state.py (encode/decode_chat_messages) lost in rebase - Fix on_checkpoint_restore to decode cache/conversation with decode_chat_messages - Fix on_checkpoint_restore to use decode_checkpoint_value for pending requests - Add tests/workflow/__init__.py for relative import support - Fix test_agent_executor checkpoint selection (checkpoints[1] not superstep)
1 parent 4aa3f96 commit 9b180b0

4 files changed

Lines changed: 101 additions & 5 deletions

File tree

python/packages/core/agent_framework/_workflows/_agent_executor.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .._sessions import AgentSession
1515
from .._types import AgentResponse, AgentResponseUpdate, Message
1616
from ._agent_utils import resolve_agent_id
17-
from ._checkpoint_encoding import encode_checkpoint_value
17+
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
1818
from ._const import WORKFLOW_RUN_KWARGS_KEY
1919
from ._conversation_state import encode_chat_messages
2020
from ._executor import Executor, handler
@@ -245,11 +245,27 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
245245
Args:
246246
state: Checkpoint data dict
247247
"""
248+
from ._conversation_state import decode_chat_messages
249+
248250
cache_payload = state.get("cache")
249-
self._cache = cache_payload or []
251+
if cache_payload:
252+
try:
253+
self._cache = decode_chat_messages(cache_payload)
254+
except Exception as exc:
255+
logger.warning("Failed to restore cache: %s", exc)
256+
self._cache = []
257+
else:
258+
self._cache = []
250259

251260
full_conversation_payload = state.get("full_conversation")
252-
self._full_conversation = full_conversation_payload or []
261+
if full_conversation_payload:
262+
try:
263+
self._full_conversation = decode_chat_messages(full_conversation_payload)
264+
except Exception as exc:
265+
logger.warning("Failed to restore full conversation: %s", exc)
266+
self._full_conversation = []
267+
else:
268+
self._full_conversation = []
253269

254270
session_payload = state.get("agent_session")
255271
if session_payload:
@@ -263,11 +279,11 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
263279

264280
pending_requests_payload = state.get("pending_agent_requests")
265281
if pending_requests_payload:
266-
self._pending_agent_requests = pending_requests_payload
282+
self._pending_agent_requests = decode_checkpoint_value(pending_requests_payload)
267283

268284
pending_responses_payload = state.get("pending_responses_to_agent")
269285
if pending_responses_payload:
270-
self._pending_responses_to_agent = pending_responses_payload
286+
self._pending_responses_to_agent = decode_checkpoint_value(pending_responses_payload)
271287

272288
def reset(self) -> None:
273289
"""Reset the internal cache of the executor."""
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from collections.abc import Iterable
4+
from typing import Any, cast
5+
6+
from agent_framework import Message
7+
8+
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
9+
10+
"""Utilities for serializing and deserializing chat conversations for persistence.
11+
12+
These helpers convert rich `Message` instances to checkpoint-friendly payloads
13+
using the same encoding primitives as the workflow runner. This preserves
14+
`additional_properties` and other metadata without relying on unsafe mechanisms
15+
such as pickling.
16+
"""
17+
18+
19+
def encode_chat_messages(messages: Iterable[Message]) -> list[dict[str, Any]]:
20+
"""Serialize chat messages into checkpoint-safe payloads."""
21+
encoded: list[dict[str, Any]] = []
22+
for message in messages:
23+
encoded.append({
24+
"role": encode_checkpoint_value(message.role),
25+
"contents": [encode_checkpoint_value(content) for content in message.contents],
26+
"author_name": message.author_name,
27+
"message_id": message.message_id,
28+
"additional_properties": {
29+
key: encode_checkpoint_value(value) for key, value in message.additional_properties.items()
30+
},
31+
})
32+
return encoded
33+
34+
35+
def decode_chat_messages(payload: Iterable[dict[str, Any]]) -> list[Message]:
36+
"""Restore chat messages from checkpoint-safe payloads."""
37+
restored: list[Message] = []
38+
for item in payload:
39+
if not isinstance(item, dict):
40+
continue
41+
42+
role_value = decode_checkpoint_value(item.get("role"))
43+
if isinstance(role_value, str):
44+
role = role_value
45+
elif isinstance(role_value, dict) and "value" in role_value:
46+
# Handle legacy serialization format
47+
role = role_value["value"]
48+
else:
49+
role = "assistant"
50+
51+
contents_field = item.get("contents", [])
52+
contents: list[Any] = []
53+
if isinstance(contents_field, list):
54+
contents_iter: list[Any] = contents_field # type: ignore[assignment]
55+
for entry in contents_iter:
56+
decoded_entry: Any = decode_checkpoint_value(entry)
57+
contents.append(decoded_entry)
58+
59+
additional_field = item.get("additional_properties", {})
60+
additional: dict[str, Any] = {}
61+
if isinstance(additional_field, dict):
62+
additional_dict = cast(dict[str, Any], additional_field)
63+
for key, value in additional_dict.items():
64+
additional[key] = decode_checkpoint_value(value)
65+
66+
restored.append(
67+
Message( # type: ignore[call-overload]
68+
role=role,
69+
contents=contents,
70+
author_name=item.get("author_name"),
71+
message_id=item.get("message_id"),
72+
additional_properties=additional,
73+
)
74+
)
75+
return restored

python/packages/core/tests/workflow/__init__.py

Whitespace-only changes.

python/packages/core/tests/workflow/test_agent_executor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
8989
"and the second one is after the agent execution."
9090
)
9191

92+
# Get the second checkpoint which should contain the state after processing
93+
# the first message by the start executor in the sequential workflow
94+
checkpoints.sort(key=lambda cp: cp.timestamp)
95+
restore_checkpoint = checkpoints[1]
96+
9297
# Verify checkpoint contains executor state with both cache and session
9398
assert "_executor_state" in restore_checkpoint.state
9499
executor_states = restore_checkpoint.state["_executor_state"]

0 commit comments

Comments
 (0)