diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py
index c1208a83f0..ca15ad7150 100644
--- a/haystack/core/pipeline/breakpoint.py
+++ b/haystack/core/pipeline/breakpoint.py
@@ -375,29 +375,12 @@ def _create_agent_snapshot(
     :param agent_breakpoint: AgentBreakpoint object containing breakpoints
     :return: An AgentSnapshot containing the agent's state and component visits.
     """
-    try:
-        serialized_chat_generator = _serialize_value_with_schema(
-            _deepcopy_with_exceptions(component_inputs["chat_generator"])
-        )
-    except Exception as error:
-        logger.warning(
-            "Failed to serialize the agent's chat_generator inputs. "
-            "The inputs in the snapshot will be replaced with an empty dictionary. Error: {e}",
-            e=error,
-        )
-        serialized_chat_generator = {}
-
-    try:
-        serialized_tool_invoker = _serialize_value_with_schema(
-            _deepcopy_with_exceptions(component_inputs["tool_invoker"])
-        )
-    except Exception as error:
-        logger.warning(
-            "Failed to serialize the agent's tool_invoker inputs. "
-            "The inputs in the snapshot will be replaced with an empty dictionary. Error: {e}",
-            e=error,
-        )
-        serialized_tool_invoker = {}
+    serialized_chat_generator = _serialize_agent_component_inputs(
+        component_name="chat_generator", component_inputs=component_inputs["chat_generator"]
+    )
+    serialized_tool_invoker = _serialize_agent_component_inputs(
+        component_name="tool_invoker", component_inputs=component_inputs["tool_invoker"]
+    )
 
     return AgentSnapshot(
         component_inputs={"chat_generator": serialized_chat_generator, "tool_invoker": serialized_tool_invoker},
@@ -407,6 +390,56 @@ def _create_agent_snapshot(
     )
 
 
+def _serialize_agent_component_inputs(component_name: str, component_inputs: dict[str, Any]) -> dict[str, Any]:
+    """
+    Serialize agent component inputs while preserving resumable fields whenever possible.
+
+    If serializing the whole input mapping fails (for example due to a non-serializable callback),
+    we retry field-by-field and omit only the failing fields. This keeps snapshots resumable when
+    required fields like ``messages`` or ``state`` are still serializable.
+
+    :param component_name: Name of the agent sub-component (e.g. ``chat_generator`` or ``tool_invoker``).
+    :param component_inputs: Runtime inputs for that sub-component.
+    :returns: A serialized payload that is always a structurally valid ``{"serialization_schema", "serialized_data"}``
+        pair. When every field fails to serialize, an empty-but-valid object payload is returned so that
+        ``_deserialize_value_with_schema`` can still load it (e.g. when resuming from a ``ToolBreakpoint`` where the
+        sub-component's inputs are not strictly required).
+    """
+    try:
+        return _serialize_value_with_schema(_deepcopy_with_exceptions(component_inputs))
+    except Exception as error:
+        logger.warning(
+            "Failed to serialize the agent's {component_name} inputs. "
+            "Haystack will omit only the non-serializable fields when possible. Error: {e}",
+            component_name=component_name,
+            e=error,
+        )
+
+    serialized_properties: dict[str, Any] = {}
+    serialized_data: dict[str, Any] = {}
+
+    for field_name, value in component_inputs.items():
+        try:
+            serialized_value = _serialize_value_with_schema(_deepcopy_with_exceptions(value))
+        except Exception as field_error:
+            logger.warning(
+                "Failed to serialize the agent's {component_name}.{field_name} input. "
+                "The field will be omitted from the snapshot. Error: {e}",
+                component_name=component_name,
+                field_name=field_name,
+                e=field_error,
+            )
+            continue
+
+        serialized_properties[field_name] = serialized_value["serialization_schema"]
+        serialized_data[field_name] = serialized_value["serialized_data"]
+
+    return {
+        "serialization_schema": {"type": "object", "properties": serialized_properties},
+        "serialized_data": serialized_data,
+    }
+
+
 def _validate_tool_breakpoint_is_valid(agent_breakpoint: AgentBreakpoint, tools: "ToolsType") -> None:
     """
     Validates the AgentBreakpoint passed to the agent.
diff --git a/releasenotes/notes/fix-agent-snapshot-resume-after-fallback-7fd7ff9a0f8f8b87.yaml b/releasenotes/notes/fix-agent-snapshot-resume-after-fallback-7fd7ff9a0f8f8b87.yaml
new file mode 100644
index 0000000000..c698c866ef
--- /dev/null
+++ b/releasenotes/notes/fix-agent-snapshot-resume-after-fallback-7fd7ff9a0f8f8b87.yaml
@@ -0,0 +1,11 @@
+---
+fixes:
+  - |
+    Preserve resumable agent snapshots when some ``chat_generator`` or ``tool_invoker`` inputs are
+    non-serializable. Haystack now omits only the failing runtime-only fields (for example
+    non-serializable callbacks) instead of replacing the whole payload with an empty dictionary.
+    When every field of a sub-component input fails to serialize, the snapshot still stores a
+    structurally valid empty payload (``{"serialization_schema": {"type": "object", "properties": {}},
+    "serialized_data": {}}``) so that resuming the snapshot does not raise ``DeserializationError`` –
+    for example when resuming from a ``ToolBreakpoint`` where the sub-component's inputs are not
+    strictly required.
diff --git a/test/components/agents/test_agent_breakpoints.py b/test/components/agents/test_agent_breakpoints.py
index b092a04cf1..a1730b0da2 100644
--- a/test/components/agents/test_agent_breakpoints.py
+++ b/test/components/agents/test_agent_breakpoints.py
@@ -472,6 +472,40 @@ def test_resume_from_tool_invoker(self, agent, tmp_path, monkeypatch):
         assert "last_message" in result
         assert len(result["messages"]) > 0
 
+    def test_resume_from_tool_invoker_omits_non_serializable_runtime_callback(self, agent, tmp_path, monkeypatch):
+        monkeypatch.setenv(HAYSTACK_PIPELINE_SNAPSHOT_SAVE_ENABLED, "true")
+        debug_path = str(tmp_path / "debug_snapshots")
+        tool_bp = ToolBreakpoint(component_name="tool_invoker", tool_name="weather_tool", snapshot_file_path=debug_path)
+        agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent")
+
+        try:
+            agent.run(
+                messages=[ChatMessage.from_user("What's the weather in Berlin?")],
+                break_point=agent_breakpoint,
+                streaming_callback=lambda chunk: None,
+            )
+        except BreakpointException:
+            pass
+
+        snapshot_files = list(Path(debug_path).glob("test_agent_tool_invoker_*.json"))
+        assert len(snapshot_files) > 0
+        latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime))
+        agent_snapshot = load_pipeline_snapshot(latest_snapshot_file).agent_snapshot
+
+        assert agent_snapshot is not None
+        assert "streaming_callback" not in agent_snapshot.component_inputs["chat_generator"]["serialized_data"]
+        assert "streaming_callback" not in agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]
+        assert "state" in agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]
+
+        result = agent.run(
+            messages=[ChatMessage.from_user("This is actually ignored when resuming from snapshot.")],
+            snapshot=agent_snapshot,
+        )
+
+        assert "messages" in result
+        assert "last_message" in result
+        assert len(result["messages"]) == 4
+
     def test_resume_from_tool_invoker_and_new_breakpoint(self, weather_tool, tmp_path, monkeypatch):
         monkeypatch.setenv(HAYSTACK_PIPELINE_SNAPSHOT_SAVE_ENABLED, "true")
         agent = Agent(
diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py
index 953fd26667..0c5d5d82fe 100644
--- a/test/core/pipeline/test_breakpoint.py
+++ b/test/core/pipeline/test_breakpoint.py
@@ -21,6 +21,9 @@
 )
 from haystack.dataclasses import ChatMessage
 from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, PipelineSnapshot, PipelineState
+from haystack.utils import _deserialize_value_with_schema
+
+_EMPTY_OBJECT_PAYLOAD = {"serialization_schema": {"type": "object", "properties": {}}, "serialized_data": {}}
 
 
 def test_transform_json_structure_unwraps_sender_value():
@@ -257,8 +260,8 @@ def to_dict(self):
                 component_inputs={"chat_generator": {"messages": NonSerializable()}, "tool_invoker": {"messages": []}},
             )
 
-        assert snapshot.component_inputs["chat_generator"] == {}
-        assert snapshot.component_inputs["tool_invoker"] != {}
+        assert snapshot.component_inputs["chat_generator"] == _EMPTY_OBJECT_PAYLOAD
+        assert snapshot.component_inputs["tool_invoker"] != _EMPTY_OBJECT_PAYLOAD
         assert "Failed to serialize the agent's chat_generator inputs" in caplog.text
 
     def test_create_agent_snapshot_non_serializable_tool_invoker(self, caplog):
@@ -277,8 +280,8 @@ def to_dict(self):
                 component_inputs={"chat_generator": {"messages": []}, "tool_invoker": {"messages": NonSerializable()}},
             )
 
-        assert snapshot.component_inputs["tool_invoker"] == {}
-        assert snapshot.component_inputs["chat_generator"] != {}
+        assert snapshot.component_inputs["tool_invoker"] == _EMPTY_OBJECT_PAYLOAD
+        assert snapshot.component_inputs["chat_generator"] != _EMPTY_OBJECT_PAYLOAD
         assert "Failed to serialize the agent's tool_invoker inputs" in caplog.text
 
     def test_create_agent_snapshot_both_non_serializable(self, caplog):
@@ -300,13 +303,48 @@ def to_dict(self):
                 },
             )
 
-        assert snapshot.component_inputs["chat_generator"] == {}
-        assert snapshot.component_inputs["tool_invoker"] == {}
+        assert snapshot.component_inputs["chat_generator"] == _EMPTY_OBJECT_PAYLOAD
+        assert snapshot.component_inputs["tool_invoker"] == _EMPTY_OBJECT_PAYLOAD
         assert "Failed to serialize the agent's chat_generator inputs" in caplog.text
         assert "Failed to serialize the agent's tool_invoker inputs" in caplog.text
         assert snapshot.component_visits == {"chat_generator": 1, "tool_invoker": 0}
         assert snapshot.break_point == agent_breakpoint
 
+    def test_create_agent_snapshot_all_fields_non_serializable_payload_is_deserializable(self, caplog):
+        """
+        When every field of a sub-component input fails to serialize, the resulting payload must still be a
+        structurally valid ``{"serialization_schema", "serialized_data"}`` pair so that
+        ``_deserialize_value_with_schema`` can load it back (rather than raising ``DeserializationError`` as it would
+        for a bare ``{}``). This guards against the snapshot being silently non-resumable in the all-fields-fail path.
+        """
+
+        class NonSerializable:
+            def to_dict(self):
+                raise TypeError("Cannot serialize")
+
+        agent_breakpoint = AgentBreakpoint(
+            agent_name="agent", break_point=Breakpoint(component_name="chat_generator", visit_count=1)
+        )
+
+        with caplog.at_level(logging.WARNING):
+            snapshot = _create_agent_snapshot(
+                component_visits={"chat_generator": 1, "tool_invoker": 0},
+                agent_breakpoint=agent_breakpoint,
+                component_inputs={
+                    "chat_generator": {"streaming_callback": NonSerializable()},
+                    "tool_invoker": {"streaming_callback": NonSerializable()},
+                },
+            )
+
+        for component_name in ("chat_generator", "tool_invoker"):
+            payload = snapshot.component_inputs[component_name]
+            assert "serialization_schema" in payload
+            assert "serialized_data" in payload
+            assert payload["serialization_schema"] == {"type": "object", "properties": {}}
+            assert payload["serialized_data"] == {}
+            # Round-trip: deserializer must accept the empty-but-valid payload without raising.
+            assert _deserialize_value_with_schema(payload) == {}
+
 
 def test_save_pipeline_snapshot_raises_on_failure(tmp_path, caplog, monkeypatch):
     monkeypatch.setenv(HAYSTACK_PIPELINE_SNAPSHOT_SAVE_ENABLED, "true")