Skip to content

Commit 134998a

Browse files
sjrldavidsbatista
authored andcommitted
fix: Fix initializing an Agent from an AgentSnapshot (#9826)
* Fix initializing agent from snapshot. Refactoring tests * Fixing tests * Add integration test * Add reno * linting * Update releasenotes/notes/fix-openai-agent-snapshot-init-1ca26789564a53fe.yaml Co-authored-by: David S. Batista <dsbatista@gmail.com> --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent bdd5ab1 commit 134998a

4 files changed

Lines changed: 457 additions & 114 deletions

File tree

haystack/components/agents/agent.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -359,15 +359,7 @@ def _initialize_from_snapshot(
359359
state_data = current_inputs["tool_invoker"]["state"].data
360360
state = State(schema=self.state_schema, data=state_data)
361361

362-
if isinstance(snapshot.break_point.break_point, ToolBreakpoint):
363-
messages = current_inputs["tool_invoker"]["messages"]
364-
skip_chat_generator = True
365-
else:
366-
messages = current_inputs["chat_generator"]["messages"]
367-
skip_chat_generator = False
368-
369-
state.set("messages", messages)
370-
362+
skip_chat_generator = isinstance(snapshot.break_point.break_point, ToolBreakpoint)
371363
streaming_callback = current_inputs["chat_generator"].get("streaming_callback", streaming_callback)
372364
streaming_callback = select_streaming_callback( # type: ignore[call-overload]
373365
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=requires_async
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Prevent duplication of the last assistant message in the chat history when initializing from an `AgentSnapshot`.

test/components/agents/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def to_dict(self) -> dict[str, Any]:
141141
return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}
142142

143143
@classmethod
144-
def from_dict(cls, data: dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
144+
def from_dict(cls, data: dict[str, Any]) -> "MockChatGenerator":
145145
return cls()
146146

147147
@component.output_types(replies=list[ChatMessage])

0 commit comments

Comments
 (0)