Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 56 additions & 23 deletions haystack/core/pipeline/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,29 +375,12 @@ def _create_agent_snapshot(
:param agent_breakpoint: AgentBreakpoint object containing breakpoints
:return: An AgentSnapshot containing the agent's state and component visits.
"""
try:
serialized_chat_generator = _serialize_value_with_schema(
_deepcopy_with_exceptions(component_inputs["chat_generator"])
)
except Exception as error:
logger.warning(
"Failed to serialize the agent's chat_generator inputs. "
"The inputs in the snapshot will be replaced with an empty dictionary. Error: {e}",
e=error,
)
serialized_chat_generator = {}

try:
serialized_tool_invoker = _serialize_value_with_schema(
_deepcopy_with_exceptions(component_inputs["tool_invoker"])
)
except Exception as error:
logger.warning(
"Failed to serialize the agent's tool_invoker inputs. "
"The inputs in the snapshot will be replaced with an empty dictionary. Error: {e}",
e=error,
)
serialized_tool_invoker = {}
serialized_chat_generator = _serialize_agent_component_inputs(
component_name="chat_generator", component_inputs=component_inputs["chat_generator"]
)
serialized_tool_invoker = _serialize_agent_component_inputs(
component_name="tool_invoker", component_inputs=component_inputs["tool_invoker"]
)

return AgentSnapshot(
component_inputs={"chat_generator": serialized_chat_generator, "tool_invoker": serialized_tool_invoker},
Expand All @@ -407,6 +390,56 @@ def _create_agent_snapshot(
)


def _serialize_agent_component_inputs(component_name: str, component_inputs: dict[str, Any]) -> dict[str, Any]:
"""
Serialize agent component inputs while preserving resumable fields whenever possible.

If serializing the whole input mapping fails (for example due to a non-serializable callback),
we retry field-by-field and omit only the failing fields. This keeps snapshots resumable when
required fields like ``messages`` or ``state`` are still serializable.

:param component_name: Name of the agent sub-component (e.g. ``chat_generator`` or ``tool_invoker``).
:param component_inputs: Runtime inputs for that sub-component.
:returns: A serialized payload that is always a structurally valid ``{"serialization_schema", "serialized_data"}``
pair. When every field fails to serialize, an empty-but-valid object payload is returned so that
``_deserialize_value_with_schema`` can still load it (e.g. when resuming from a ``ToolBreakpoint`` where the
sub-component's inputs are not strictly required).
"""
try:
return _serialize_value_with_schema(_deepcopy_with_exceptions(component_inputs))
except Exception as error:
logger.warning(
"Failed to serialize the agent's {component_name} inputs. "
"Haystack will omit only the non-serializable fields when possible. Error: {e}",
component_name=component_name,
e=error,
)

serialized_properties: dict[str, Any] = {}
serialized_data: dict[str, Any] = {}

for field_name, value in component_inputs.items():
try:
serialized_value = _serialize_value_with_schema(_deepcopy_with_exceptions(value))
except Exception as field_error:
logger.warning(
"Failed to serialize the agent's {component_name}.{field_name} input. "
"The field will be omitted from the snapshot. Error: {e}",
component_name=component_name,
field_name=field_name,
e=field_error,
)
continue

serialized_properties[field_name] = serialized_value["serialization_schema"]
serialized_data[field_name] = serialized_value["serialized_data"]

return {
"serialization_schema": {"type": "object", "properties": serialized_properties},
"serialized_data": serialized_data,
}


def _validate_tool_breakpoint_is_valid(agent_breakpoint: AgentBreakpoint, tools: "ToolsType") -> None:
"""
Validates the AgentBreakpoint passed to the agent.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
fixes:
- |
Preserve resumable agent snapshots when some ``chat_generator`` or ``tool_invoker`` inputs are
non-serializable. Haystack now omits only the failing runtime-only fields (for example
non-serializable callbacks) instead of replacing the whole payload with an empty dictionary.
When every field of a sub-component input fails to serialize, the snapshot still stores a
structurally valid empty payload (``{"serialization_schema": {"type": "object", "properties": {}},
"serialized_data": {}}``) so that resuming the snapshot does not raise ``DeserializationError`` –
for example when resuming from a ``ToolBreakpoint`` where the sub-component's inputs are not
strictly required.
34 changes: 34 additions & 0 deletions test/components/agents/test_agent_breakpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,40 @@ def test_resume_from_tool_invoker(self, agent, tmp_path, monkeypatch):
assert "last_message" in result
assert len(result["messages"]) > 0

def test_resume_from_tool_invoker_omits_non_serializable_runtime_callback(self, agent, tmp_path, monkeypatch):
monkeypatch.setenv(HAYSTACK_PIPELINE_SNAPSHOT_SAVE_ENABLED, "true")
debug_path = str(tmp_path / "debug_snapshots")
tool_bp = ToolBreakpoint(component_name="tool_invoker", tool_name="weather_tool", snapshot_file_path=debug_path)
agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent")

try:
agent.run(
messages=[ChatMessage.from_user("What's the weather in Berlin?")],
break_point=agent_breakpoint,
streaming_callback=lambda chunk: None,
)
except BreakpointException:
pass

snapshot_files = list(Path(debug_path).glob("test_agent_tool_invoker_*.json"))
assert len(snapshot_files) > 0
latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime))
agent_snapshot = load_pipeline_snapshot(latest_snapshot_file).agent_snapshot

assert agent_snapshot is not None
assert "streaming_callback" not in agent_snapshot.component_inputs["chat_generator"]["serialized_data"]
assert "streaming_callback" not in agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]
assert "state" in agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]

result = agent.run(
messages=[ChatMessage.from_user("This is actually ignored when resuming from snapshot.")],
snapshot=agent_snapshot,
)

assert "messages" in result
assert "last_message" in result
assert len(result["messages"]) == 4

def test_resume_from_tool_invoker_and_new_breakpoint(self, weather_tool, tmp_path, monkeypatch):
monkeypatch.setenv(HAYSTACK_PIPELINE_SNAPSHOT_SAVE_ENABLED, "true")
agent = Agent(
Expand Down
50 changes: 44 additions & 6 deletions test/core/pipeline/test_breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
)
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, PipelineSnapshot, PipelineState
from haystack.utils import _deserialize_value_with_schema

_EMPTY_OBJECT_PAYLOAD = {"serialization_schema": {"type": "object", "properties": {}}, "serialized_data": {}}


def test_transform_json_structure_unwraps_sender_value():
Expand Down Expand Up @@ -257,8 +260,8 @@ def to_dict(self):
component_inputs={"chat_generator": {"messages": NonSerializable()}, "tool_invoker": {"messages": []}},
)

assert snapshot.component_inputs["chat_generator"] == {}
assert snapshot.component_inputs["tool_invoker"] != {}
assert snapshot.component_inputs["chat_generator"] == _EMPTY_OBJECT_PAYLOAD
assert snapshot.component_inputs["tool_invoker"] != _EMPTY_OBJECT_PAYLOAD
assert "Failed to serialize the agent's chat_generator inputs" in caplog.text

def test_create_agent_snapshot_non_serializable_tool_invoker(self, caplog):
Expand All @@ -277,8 +280,8 @@ def to_dict(self):
component_inputs={"chat_generator": {"messages": []}, "tool_invoker": {"messages": NonSerializable()}},
)

assert snapshot.component_inputs["tool_invoker"] == {}
assert snapshot.component_inputs["chat_generator"] != {}
assert snapshot.component_inputs["tool_invoker"] == _EMPTY_OBJECT_PAYLOAD
assert snapshot.component_inputs["chat_generator"] != _EMPTY_OBJECT_PAYLOAD
assert "Failed to serialize the agent's tool_invoker inputs" in caplog.text

def test_create_agent_snapshot_both_non_serializable(self, caplog):
Expand All @@ -300,13 +303,48 @@ def to_dict(self):
},
)

assert snapshot.component_inputs["chat_generator"] == {}
assert snapshot.component_inputs["tool_invoker"] == {}
assert snapshot.component_inputs["chat_generator"] == _EMPTY_OBJECT_PAYLOAD
assert snapshot.component_inputs["tool_invoker"] == _EMPTY_OBJECT_PAYLOAD
assert "Failed to serialize the agent's chat_generator inputs" in caplog.text
assert "Failed to serialize the agent's tool_invoker inputs" in caplog.text
assert snapshot.component_visits == {"chat_generator": 1, "tool_invoker": 0}
assert snapshot.break_point == agent_breakpoint

def test_create_agent_snapshot_all_fields_non_serializable_payload_is_deserializable(self, caplog):
"""
When every field of a sub-component input fails to serialize, the resulting payload must still be a
structurally valid ``{"serialization_schema", "serialized_data"}`` pair so that
``_deserialize_value_with_schema`` can load it back (rather than raising ``DeserializationError`` as it would
for a bare ``{}``). This guards against the snapshot being silently non-resumable in the all-fields-fail path.
"""

class NonSerializable:
def to_dict(self):
raise TypeError("Cannot serialize")

agent_breakpoint = AgentBreakpoint(
agent_name="agent", break_point=Breakpoint(component_name="chat_generator", visit_count=1)
)

with caplog.at_level(logging.WARNING):
snapshot = _create_agent_snapshot(
component_visits={"chat_generator": 1, "tool_invoker": 0},
agent_breakpoint=agent_breakpoint,
component_inputs={
"chat_generator": {"streaming_callback": NonSerializable()},
"tool_invoker": {"streaming_callback": NonSerializable()},
},
)

for component_name in ("chat_generator", "tool_invoker"):
payload = snapshot.component_inputs[component_name]
assert "serialization_schema" in payload
assert "serialized_data" in payload
assert payload["serialization_schema"] == {"type": "object", "properties": {}}
assert payload["serialized_data"] == {}
# Round-trip: deserializer must accept the empty-but-valid payload without raising.
assert _deserialize_value_with_schema(payload) == {}


def test_save_pipeline_snapshot_raises_on_failure(tmp_path, caplog, monkeypatch):
monkeypatch.setenv(HAYSTACK_PIPELINE_SNAPSHOT_SAVE_ENABLED, "true")
Expand Down