# SPDX-FileCopyrightText: 2022-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Any

from haystack import component
from haystack.components.builders import ChatPromptBuilder
from haystack.components.generators.chat.types import ChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.human_in_the_loop.strategies import ConfirmationStrategy
from haystack.tools import ToolsType

from haystack_experimental.chat_message_stores.types import ChatMessageStore
from haystack_experimental.components.agents import Agent
from haystack_experimental.memory_stores.types import MemoryStore


@component
class UnifiedAgent(Agent):
    """An :class:`Agent` that builds its own user messages from a Jinja2 prompt template.

    Instead of requiring pre-built ``messages``, ``run``/``run_async`` accept template
    variables as keyword arguments, render them through a ``ChatPromptBuilder``, and
    forward the resulting :class:`ChatMessage` list to the parent ``Agent``. Template
    variables are also registered as dynamic pipeline input sockets so that other
    components' outputs can be connected to them in a ``Pipeline``.
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        user_prompt: str,
        system_prompt: str | None = None,
        tools: ToolsType | None = None,
        exit_conditions: list[str] | None = None,
        state_schema: dict[str, Any] | None = None,
        max_agent_steps: int = 100,
        streaming_callback: StreamingCallbackT | None = None,
        raise_on_tool_invocation_failure: bool = False,
        confirmation_strategies: dict[str, ConfirmationStrategy] | None = None,
        tool_invoker_kwargs: dict[str, Any] | None = None,
        chat_message_store: ChatMessageStore | None = None,
        memory_store: MemoryStore | None = None,
    ):
        """
        :param chat_generator: Chat generator used by the underlying ``Agent`` loop.
        :param user_prompt: Jinja2 template (``ChatPromptBuilder`` message-block syntax)
            rendered into the user messages on every run. Can be overridden per call.

        All remaining parameters are forwarded unchanged to ``Agent.__init__``.
        """
        super().__init__(
            chat_generator=chat_generator,
            tools=tools,
            system_prompt=system_prompt,
            exit_conditions=exit_conditions,
            state_schema=state_schema,
            max_agent_steps=max_agent_steps,
            streaming_callback=streaming_callback,
            raise_on_tool_invocation_failure=raise_on_tool_invocation_failure,
            confirmation_strategies=confirmation_strategies,
            tool_invoker_kwargs=tool_invoker_kwargs,
            chat_message_store=chat_message_store,
            memory_store=memory_store,
        )
        self.user_prompt: str = user_prompt
        self._user_prompt_builder: ChatPromptBuilder = ChatPromptBuilder(template=user_prompt)

        # Register template variables as dynamic input sockets so that pipelines can
        # connect other components' outputs. Variables whose names collide with the
        # explicit input sockets defined by the run() signature are skipped.
        run_params = set(inspect.signature(self.run).parameters.keys())
        for var in self._user_prompt_builder.variables:
            if var not in run_params:
                component.set_input_type(self, var, Any, "")

    def _render_messages_and_split_kwargs(
        self, user_prompt: str | None, kwargs: dict[str, Any]
    ) -> tuple[list[ChatMessage], dict[str, Any]]:
        """Render the user prompt and partition ``kwargs`` between builder and agent state.

        Template variables go to the prompt builder; state schema keys go to the parent
        ``Agent`` run. A kwarg present in both sets is passed to both. Unknown kwargs
        (in neither set) are passed only to the prompt builder, where Jinja2 silently
        ignores them, rather than to State, which would reject unknown keys.

        :param user_prompt: Optional runtime template overriding ``self.user_prompt``.
        :param kwargs: The free-form keyword arguments received by ``run``/``run_async``.
        :returns: The rendered chat messages and the kwargs destined for agent state.
        """
        template_variable_names = set(self._user_prompt_builder.variables)
        state_schema_keys = set(self.state_schema.keys())

        # NOTE(review): the variable set comes from the init-time template; a runtime
        # ``user_prompt`` override with *different* variables is still partitioned
        # against the init-time set — confirm this is the intended behavior.
        prompt_kwargs = {
            k: v for k, v in kwargs.items() if k in template_variable_names or k not in state_schema_keys
        }
        agent_state_kwargs = {k: v for k, v in kwargs.items() if k in state_schema_keys}

        prompt_builder_result = self._user_prompt_builder.run(
            template=user_prompt or self.user_prompt,
            **prompt_kwargs,
        )
        return prompt_builder_result["prompt"], agent_state_kwargs

    def run(  # type: ignore[override]
        self,
        *,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        generation_kwargs: dict[str, Any] | None = None,
        break_point: AgentBreakpoint | None = None,
        snapshot: AgentSnapshot | None = None,
        tools: ToolsType | list[str] | None = None,
        confirmation_strategy_context: dict[str, Any] | None = None,
        chat_message_store_kwargs: dict[str, Any] | None = None,
        memory_store_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Render the user prompt from ``kwargs`` and execute the agent loop synchronously.

        :param user_prompt: Optional per-call template overriding the init-time one.
        :param kwargs: Template variables and/or state schema values (see
            ``_render_messages_and_split_kwargs`` for how they are routed).
        :returns: The parent ``Agent.run`` result (``messages``, ``last_message``, ...).
        """
        messages, agent_state_kwargs = self._render_messages_and_split_kwargs(user_prompt, kwargs)
        return super().run(
            messages=messages,
            streaming_callback=streaming_callback,
            generation_kwargs=generation_kwargs,
            break_point=break_point,
            snapshot=snapshot,
            system_prompt=system_prompt or self.system_prompt,
            tools=tools,
            confirmation_strategy_context=confirmation_strategy_context,
            chat_message_store_kwargs=chat_message_store_kwargs,
            memory_store_kwargs=memory_store_kwargs,
            **agent_state_kwargs,
        )

    async def run_async(  # type: ignore[override]
        self,
        *,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        generation_kwargs: dict[str, Any] | None = None,
        break_point: AgentBreakpoint | None = None,
        snapshot: AgentSnapshot | None = None,
        tools: ToolsType | list[str] | None = None,
        confirmation_strategy_context: dict[str, Any] | None = None,
        chat_message_store_kwargs: dict[str, Any] | None = None,
        memory_store_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Async counterpart of :meth:`run`; same kwarg routing, same return shape."""
        messages, agent_state_kwargs = self._render_messages_and_split_kwargs(user_prompt, kwargs)
        return await super().run_async(
            messages=messages,
            streaming_callback=streaming_callback,
            generation_kwargs=generation_kwargs,
            break_point=break_point,
            snapshot=snapshot,
            system_prompt=system_prompt or self.system_prompt,
            tools=tools,
            confirmation_strategy_context=confirmation_strategy_context,
            chat_message_store_kwargs=chat_message_store_kwargs,
            memory_store_kwargs=memory_store_kwargs,
            **agent_state_kwargs,
        )
def _user(text: str) -> str:
    """Wrap plain text in the Jinja2 message-block syntax ChatPromptBuilder expects."""
    return f'{{% message role="user" %}}{text}{{% endmessage %}}'


REPLY_PREFIX = "Reply: "


@component
class MockChatGenerator:
    """Echoes back the last user message."""

    @component.output_types(replies=list[ChatMessage])
    def run(self, messages: list[ChatMessage], tools: Any = None) -> dict[str, list[ChatMessage]]:
        # Scan from the end for the most recent message that carries text.
        last_user_text = next((m.text for m in reversed(messages) if m.text is not None), "")
        return {"replies": [ChatMessage.from_assistant(f"{REPLY_PREFIX}{last_user_text}")]}


@component
class MockRetriever:
    """Returns a fixed set of documents."""

    @component.output_types(documents=list[Document])
    def run(self, query: str) -> dict[str, list[Document]]:
        contents = [
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
            "Madrid is the capital of Spain.",
        ]
        return {"documents": [Document(content=text) for text in contents]}


def _run_agent(agent: UnifiedAgent, **kwargs) -> dict[str, Any]:
    """Warm-up and run an agent in one call."""
    agent.warm_up()
    return agent.run(**kwargs)


def _make_agent(chat_generator=None, user_prompt=None, **kwargs) -> UnifiedAgent:
    """Shorthand to create a UnifiedAgent with sensible defaults."""
    return UnifiedAgent(
        chat_generator=chat_generator or MockChatGenerator(),
        user_prompt=user_prompt or _user("Hello"),
        **kwargs,
    )


class TestChatPromptBuilderUnification:
    def test_run_static_template(self):
        """Static template should produce a user+assistant pair with correct content."""
        result = _run_agent(_make_agent(user_prompt=_user("Hello, world!")))

        messages = result["messages"]
        assert len(messages) == 2
        user_msg, assistant_msg = messages
        assert user_msg.role.value == "user"
        assert user_msg.text == "Hello, world!"
        assert assistant_msg.role.value == "assistant"
        assert assistant_msg.text == f"{REPLY_PREFIX}Hello, world!"
        assert result["last_message"].text == f"{REPLY_PREFIX}Hello, world!"

    def test_run_static_string_template(self):
        """A raw message-block string (no helper) works the same as _user()."""
        user_prompt = '{% message role="user" %}What is the meaning of life?{% endmessage %}'
        result = _run_agent(_make_agent(user_prompt=user_prompt))

        messages = result["messages"]
        assert len(messages) == 2
        user_msg, assistant_msg = messages
        assert user_msg.role.value == "user"
        assert user_msg.text == "What is the meaning of life?"
        assert assistant_msg.role.value == "assistant"
        assert assistant_msg.text == f"{REPLY_PREFIX}What is the meaning of life?"

    def test_run_static_template_multi_message_with_system_prompt(self):
        """Multiple user message blocks plus an init-time system prompt all appear in order."""
        user_prompt = (
            '{% message role="user" %}First message{% endmessage %}'
            '{% message role="user" %}Second message{% endmessage %}'
        )
        agent = _make_agent(user_prompt=user_prompt, system_prompt="You are a helpful assistant.")
        result = _run_agent(agent)

        messages = result["messages"]
        assert len(messages) == 4
        expected = [
            ("system", "You are a helpful assistant."),
            ("user", "First message"),
            ("user", "Second message"),
            ("assistant", f"{REPLY_PREFIX}Second message"),
        ]
        for msg, (role, text) in zip(messages, expected):
            assert msg.role.value == role
            assert msg.text == text

    def test_run_string_template_with_multiple_variables(self):
        """Template variables passed as kwargs are rendered into the right messages."""
        user_prompt = (
            '{% message role="system" %}You speak {{language}}.{% endmessage %}'
            '{% message role="user" %}{{query}}{% endmessage %}'
        )
        result = _run_agent(
            _make_agent(user_prompt=user_prompt),
            language="English",
            query="What is AI?",
        )

        messages = result["messages"]
        assert len(messages) == 3
        expected = [
            ("system", "You speak English."),
            ("user", "What is AI?"),
            ("assistant", f"{REPLY_PREFIX}What is AI?"),
        ]
        for msg, (role, text) in zip(messages, expected):
            assert msg.role.value == role
            assert msg.text == text

    def test_run_template_override_at_runtime(self):
        """Passing user_prompt at runtime overrides the init-time template."""
        result = _run_agent(
            _make_agent(user_prompt=_user("Original {{topic}}")),
            user_prompt=_user("Overridden: {{concept}}"),
            concept="deep learning",
        )
        assert result["last_message"].text == f"{REPLY_PREFIX}Overridden: deep learning"

    def test_run_optional_variable_defaults_to_empty(self):
        """An unfilled template variable renders as empty (dynamic socket default '')."""
        result = _run_agent(_make_agent(user_prompt=_user("Hello {{name}}")))
        assert result["last_message"].text == f"{REPLY_PREFIX}Hello"

    def test_run_rag_pipeline_with_documents_and_query(self):
        """RAG pipeline: a retriever feeds documents into the agent via a Pipeline."""
        user_prompt = (
            '{% message role="user" %}'
            "Answer the question based on the following documents:\n"
            "{% for doc in documents %}"
            "- {{ doc.content }}\n"
            "{% endfor %}\n"
            "Question: {{ query }}"
            "{% endmessage %}"
        )

        pipeline = Pipeline()
        pipeline.add_component("retriever", MockRetriever())
        pipeline.add_component("agent", _make_agent(user_prompt=user_prompt))
        pipeline.connect("retriever.documents", "agent.documents")

        query = "What is the capital of France?"
        result = pipeline.run({"retriever": {"query": query}, "agent": {"query": query}})

        messages = result["agent"]["messages"]
        assert len(messages) == 2
        user_msg, assistant_msg = messages
        assert user_msg.role.value == "user"

        # Verify all document contents from the retriever are rendered in the prompt.
        for snippet in (
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
            "Madrid is the capital of Spain.",
            query,
        ):
            assert snippet in user_msg.text

        # The mock generator echoes back the rendered user message.
        assert assistant_msg.role.value == "assistant"
        assert assistant_msg.text == f"{REPLY_PREFIX}{user_msg.text}"