Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion haystack_experimental/components/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from lazy_imports import LazyImporter

# Maps each submodule of this package to the public names it exports.
_import_structure = {"agent": ["Agent"], "unified_agent": ["UnifiedAgent"]}

if TYPE_CHECKING:
    # Only static type checkers take this branch; nothing is imported eagerly at runtime.
    from .agent import Agent as Agent
    from .unified_agent import UnifiedAgent as UnifiedAgent

else:
    # At runtime, replace this package's module object with a LazyImporter built from _import_structure.
    sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
155 changes: 155 additions & 0 deletions haystack_experimental/components/agents/unified_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Any

from haystack import component
from haystack.components.builders import ChatPromptBuilder
from haystack.components.generators.chat.types import ChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.human_in_the_loop.strategies import ConfirmationStrategy
from haystack.tools import ToolsType

from haystack_experimental.chat_message_stores.types import ChatMessageStore
from haystack_experimental.components.agents import Agent
from haystack_experimental.memory_stores.types import MemoryStore


@component
class UnifiedAgent(Agent):
    """
    Agent whose input messages are rendered from a chat prompt template.

    The `user_prompt` template is rendered with a `ChatPromptBuilder` on every call to `run()` /
    `run_async()`, and the rendered messages are handed to the parent `Agent` as its `messages` input.
    Template variables that do not collide with an explicit `run()` parameter are registered as input
    sockets so that other pipeline components can be connected to them.
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        user_prompt: str,
        system_prompt: str | None = None,
        tools: ToolsType | None = None,
        exit_conditions: list[str] | None = None,
        state_schema: dict[str, Any] | None = None,
        max_agent_steps: int = 100,
        streaming_callback: StreamingCallbackT | None = None,
        raise_on_tool_invocation_failure: bool = False,
        confirmation_strategies: dict[str, ConfirmationStrategy] | None = None,
        tool_invoker_kwargs: dict[str, Any] | None = None,
        chat_message_store: ChatMessageStore | None = None,
        memory_store: MemoryStore | None = None,
    ) -> None:
        """
        Create the agent and the prompt builder that renders its input messages.

        :param chat_generator: Forwarded unchanged to `Agent.__init__`.
        :param user_prompt: Template string used to build a `ChatPromptBuilder`; its variables become
            input sockets of this component unless they share a name with a `run()` parameter.
        :param system_prompt: Forwarded unchanged to `Agent.__init__`.
        :param tools: Forwarded unchanged to `Agent.__init__`.
        :param exit_conditions: Forwarded unchanged to `Agent.__init__`.
        :param state_schema: Forwarded unchanged to `Agent.__init__`.
        :param max_agent_steps: Forwarded unchanged to `Agent.__init__`.
        :param streaming_callback: Forwarded unchanged to `Agent.__init__`.
        :param raise_on_tool_invocation_failure: Forwarded unchanged to `Agent.__init__`.
        :param confirmation_strategies: Forwarded unchanged to `Agent.__init__`.
        :param tool_invoker_kwargs: Forwarded unchanged to `Agent.__init__`.
        :param chat_message_store: Forwarded unchanged to `Agent.__init__`.
        :param memory_store: Forwarded unchanged to `Agent.__init__`.
        """
        # NOTE(review): the explicit two-argument super() is kept on purpose; presumably the @component
        # decorator re-creates the class, which would break zero-argument super() -- confirm before simplifying.
        super(UnifiedAgent, self).__init__(
            chat_generator=chat_generator,
            tools=tools,
            system_prompt=system_prompt,
            exit_conditions=exit_conditions,
            state_schema=state_schema,
            max_agent_steps=max_agent_steps,
            streaming_callback=streaming_callback,
            raise_on_tool_invocation_failure=raise_on_tool_invocation_failure,
            confirmation_strategies=confirmation_strategies,
            tool_invoker_kwargs=tool_invoker_kwargs,
            chat_message_store=chat_message_store,
            memory_store=memory_store,
        )
        # NOTE(review): `user_prompt` is a required init argument that the parent does not know about;
        # confirm that the inherited to_dict()/from_dict() round-trips it.
        self.user_prompt: str = user_prompt
        self._user_prompt_builder: ChatPromptBuilder = ChatPromptBuilder(template=user_prompt)

        # Register template variables as dynamic input sockets so that pipelines can connect other components'
        # outputs. Variables whose names collide with a parameter of the run() signature are skipped.
        run_params = inspect.signature(self.run).parameters
        for var in self._user_prompt_builder.variables:
            if var not in run_params:
                component.set_input_type(self, var, Any, "")

    def _render_messages(
        self, user_prompt: str | None, run_kwargs: dict[str, Any]
    ) -> tuple[list[ChatMessage], dict[str, Any]]:
        """
        Render the user prompt and split the extra run-time kwargs between prompt builder and agent state.

        Shared by `run()` and `run_async()` so both paths route kwargs identically.

        :param user_prompt: Template override for this call; when falsy, the init-time template is used.
        :param run_kwargs: The `**kwargs` received by `run()` / `run_async()`.
        :returns: A tuple of the rendered messages and the kwargs to forward to the parent agent.
        """
        template_variable_names = set(self._user_prompt_builder.variables)
        state_schema_keys = set(self.state_schema)

        # Template variables go to the prompt builder; state schema keys go to the parent agent.
        # A kwarg present in both sets is passed to both. Unknown kwargs (in neither set) go only to the
        # prompt builder (per the original author's note, Jinja2 ignores unused values whereas State
        # would reject unknown keys).
        prompt_kwargs = {
            key: value
            for key, value in run_kwargs.items()
            if key in template_variable_names or key not in state_schema_keys
        }
        agent_state_kwargs = {key: value for key, value in run_kwargs.items() if key in state_schema_keys}

        prompt_builder_result = self._user_prompt_builder.run(
            template=user_prompt or self.user_prompt,
            **prompt_kwargs,
        )
        messages: list[ChatMessage] = prompt_builder_result["prompt"]
        return messages, agent_state_kwargs

    def run(  # type: ignore[override]
        self,
        *,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        generation_kwargs: dict[str, Any] | None = None,
        break_point: AgentBreakpoint | None = None,
        snapshot: AgentSnapshot | None = None,
        tools: ToolsType | list[str] | None = None,
        confirmation_strategy_context: dict[str, Any] | None = None,
        chat_message_store_kwargs: dict[str, Any] | None = None,
        memory_store_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Render the user prompt into messages and run the parent agent on them.

        :param system_prompt: Override for this call; when falsy, the instance's `system_prompt` is used.
        :param user_prompt: Template override for this call; when falsy, the init-time template is used.
        :param streaming_callback: Forwarded unchanged to `Agent.run`.
        :param generation_kwargs: Forwarded unchanged to `Agent.run`.
        :param break_point: Forwarded unchanged to `Agent.run`.
        :param snapshot: Forwarded unchanged to `Agent.run`.
        :param tools: Forwarded unchanged to `Agent.run`.
        :param confirmation_strategy_context: Forwarded unchanged to `Agent.run`.
        :param chat_message_store_kwargs: Forwarded unchanged to `Agent.run`.
        :param memory_store_kwargs: Forwarded unchanged to `Agent.run`.
        :param kwargs: Template variables and/or state values. Keys in the state schema are forwarded to
            `Agent.run`; template variables and keys unknown to the state schema go to the prompt builder.
        :returns: Whatever `Agent.run` returns.
        """
        messages, agent_state_kwargs = self._render_messages(user_prompt, kwargs)

        return super(UnifiedAgent, self).run(
            messages=messages,
            streaming_callback=streaming_callback,
            generation_kwargs=generation_kwargs,
            break_point=break_point,
            snapshot=snapshot,
            system_prompt=system_prompt or self.system_prompt,
            tools=tools,
            confirmation_strategy_context=confirmation_strategy_context,
            chat_message_store_kwargs=chat_message_store_kwargs,
            memory_store_kwargs=memory_store_kwargs,
            **agent_state_kwargs,
        )

    async def run_async(  # type: ignore[override]
        self,
        *,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        generation_kwargs: dict[str, Any] | None = None,
        break_point: AgentBreakpoint | None = None,
        snapshot: AgentSnapshot | None = None,
        tools: ToolsType | list[str] | None = None,
        confirmation_strategy_context: dict[str, Any] | None = None,
        chat_message_store_kwargs: dict[str, Any] | None = None,
        memory_store_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Render the user prompt into messages and await the parent agent's `run_async` on them.

        :param system_prompt: Override for this call; when falsy, the instance's `system_prompt` is used.
        :param user_prompt: Template override for this call; when falsy, the init-time template is used.
        :param streaming_callback: Forwarded unchanged to `Agent.run_async`.
        :param generation_kwargs: Forwarded unchanged to `Agent.run_async`.
        :param break_point: Forwarded unchanged to `Agent.run_async`.
        :param snapshot: Forwarded unchanged to `Agent.run_async`.
        :param tools: Forwarded unchanged to `Agent.run_async`.
        :param confirmation_strategy_context: Forwarded unchanged to `Agent.run_async`.
        :param chat_message_store_kwargs: Forwarded unchanged to `Agent.run_async`.
        :param memory_store_kwargs: Forwarded unchanged to `Agent.run_async`.
        :param kwargs: Template variables and/or state values, routed exactly as in `run()`.
        :returns: Whatever `Agent.run_async` returns.
        """
        messages, agent_state_kwargs = self._render_messages(user_prompt, kwargs)

        return await super(UnifiedAgent, self).run_async(
            messages=messages,
            streaming_callback=streaming_callback,
            generation_kwargs=generation_kwargs,
            break_point=break_point,
            snapshot=snapshot,
            system_prompt=system_prompt or self.system_prompt,
            tools=tools,
            confirmation_strategy_context=confirmation_strategy_context,
            chat_message_store_kwargs=chat_message_store_kwargs,
            memory_store_kwargs=memory_store_kwargs,
            **agent_state_kwargs,
        )
186 changes: 186 additions & 0 deletions test/components/agents/test_unified_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest
from haystack import Document, Pipeline, component
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool, create_tool_from_function

from haystack_experimental.components.agents.unified_agent import UnifiedAgent


def _user(text: str, role: str = "user") -> str:
    """
    Wrap plain text in the Jinja2 message-block syntax required by ChatPromptBuilder string templates.

    :param text: Content placed between the message tags; it may itself contain Jinja2 expressions.
    :param role: Value written into the block's ``role`` attribute. Defaults to ``"user"``.
    :returns: The string ``{% message role="<role>" %}<text>{% endmessage %}``.
    """
    return f'{{% message role="{role}" %}}{text}{{% endmessage %}}'


# Prefix that MockChatGenerator puts in front of the text it echoes; the tests use it to build expected replies.
REPLY_PREFIX = "Reply: "

@component
class MockChatGenerator:
    """Chat generator stub that echoes the most recent message carrying text, prefixed with REPLY_PREFIX."""

    @component.output_types(replies=list[ChatMessage])
    def run(self, messages: list[ChatMessage], tools: Any = None) -> dict[str, list[ChatMessage]]:
        """
        Build a single assistant reply from the incoming messages.

        Messages are scanned from newest to oldest and the first one whose `text` is not None is echoed,
        whatever its role; if none has text, an empty string is echoed. `tools` is accepted but unused.
        """
        echoed_text = next((msg.text for msg in reversed(messages) if msg.text is not None), "")
        reply = ChatMessage.from_assistant(f"{REPLY_PREFIX}{echoed_text}")
        return {"replies": [reply]}

@component
class MockRetriever:
    """Retriever stub that always yields the same three documents."""

    @component.output_types(documents=list[Document])
    def run(self, query: str) -> dict[str, list[Document]]:
        """Return the fixed documents under the ``documents`` key; `query` is accepted but not used."""
        contents = (
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
            "Madrid is the capital of Spain.",
        )
        return {"documents": [Document(content=text) for text in contents]}


def _run_agent(agent: UnifiedAgent, **kwargs) -> dict[str, Any]:
    """Call ``warm_up()`` on the agent, then return the result of ``agent.run(**kwargs)``."""
    agent.warm_up()
    outputs = agent.run(**kwargs)
    return outputs


def _make_agent(chat_generator=None, user_prompt=None, **kwargs) -> UnifiedAgent:
    """
    Build a UnifiedAgent, filling in test defaults for anything left out.

    A falsy `chat_generator` is replaced by a fresh MockChatGenerator and a falsy `user_prompt` by a
    user message saying "Hello"; all remaining keyword arguments are passed straight to UnifiedAgent.
    """
    if not chat_generator:
        chat_generator = MockChatGenerator()
    if not user_prompt:
        user_prompt = _user("Hello")
    return UnifiedAgent(chat_generator=chat_generator, user_prompt=user_prompt, **kwargs)


class TestChatPromptBuilderUnification:
    """Checks how UnifiedAgent turns its user prompt template into the messages sent to the generator."""

    def test_run_static_template(self):
        """A template without variables yields one user message followed by the echoed assistant reply."""
        agent = _make_agent(user_prompt=_user("Hello, world!"))
        result = _run_agent(agent)

        expected_reply = f"{REPLY_PREFIX}Hello, world!"
        assert [(m.role.value, m.text) for m in result["messages"]] == [
            ("user", "Hello, world!"),
            ("assistant", expected_reply),
        ]
        assert result["last_message"].text == expected_reply

    def test_run_static_string_template(self):
        """Same as the static case, but with the message block written out by hand."""
        template = '{% message role="user" %}What is the meaning of life?{% endmessage %}'
        result = _run_agent(_make_agent(user_prompt=template))

        assert [(m.role.value, m.text) for m in result["messages"]] == [
            ("user", "What is the meaning of life?"),
            ("assistant", f"{REPLY_PREFIX}What is the meaning of life?"),
        ]

    def test_run_static_template_multi_message_with_system_prompt(self):
        """Two user blocks plus an init-time system prompt give system, user, user, assistant."""
        template = "".join(
            [
                '{% message role="user" %}First message{% endmessage %}',
                '{% message role="user" %}Second message{% endmessage %}',
            ]
        )
        agent = _make_agent(user_prompt=template, system_prompt="You are a helpful assistant.")
        result = _run_agent(agent)

        assert [(m.role.value, m.text) for m in result["messages"]] == [
            ("system", "You are a helpful assistant."),
            ("user", "First message"),
            ("user", "Second message"),
            ("assistant", f"{REPLY_PREFIX}Second message"),
        ]

    def test_run_string_template_with_multiple_variables(self):
        """Variables passed to run() are substituted into both the system and the user block."""
        template = "".join(
            [
                '{% message role="system" %}You speak {{language}}.{% endmessage %}',
                '{% message role="user" %}{{query}}{% endmessage %}',
            ]
        )
        agent = _make_agent(user_prompt=template)
        result = _run_agent(agent, language="English", query="What is AI?")

        assert [(m.role.value, m.text) for m in result["messages"]] == [
            ("system", "You speak English."),
            ("user", "What is AI?"),
            ("assistant", f"{REPLY_PREFIX}What is AI?"),
        ]

    def test_run_template_override_at_runtime(self):
        """A user_prompt given to run() replaces the template supplied at construction time."""
        agent = _make_agent(user_prompt=_user("Original {{topic}}"))
        result = _run_agent(
            agent,
            user_prompt=_user("Overridden: {{concept}}"),
            concept="deep learning",
        )

        expected_reply = f"{REPLY_PREFIX}Overridden: deep learning"
        assert result["last_message"].text == expected_reply

    def test_run_optional_variable_defaults_to_empty(self):
        """With no value supplied for `name`, the echoed reply is expected to contain just "Hello"."""
        agent = _make_agent(user_prompt=_user("Hello {{name}}"))
        result = _run_agent(agent)

        assert result["last_message"].text == f"{REPLY_PREFIX}Hello"

    def test_run_rag_pipeline_with_documents_and_query(self):
        """RAG pipeline: a retriever's documents reach the agent's template through a Pipeline connection."""
        template = "".join(
            [
                '{% message role="user" %}',
                "Answer the question based on the following documents:\n",
                "{% for doc in documents %}",
                "- {{ doc.content }}\n",
                "{% endfor %}\n",
                "Question: {{ query }}",
                "{% endmessage %}",
            ]
        )

        rag = Pipeline()
        rag.add_component("retriever", MockRetriever())
        rag.add_component("agent", _make_agent(user_prompt=template))
        rag.connect("retriever.documents", "agent.documents")

        query = "What is the capital of France?"
        pipeline_inputs = {
            "retriever": {"query": query},
            "agent": {"query": query},
        }
        result = rag.run(pipeline_inputs)

        messages = result["agent"]["messages"]
        assert len(messages) == 2
        user_message, assistant_message = messages
        assert user_message.role.value == "user"

        # Every document returned by the retriever, and the query itself, must appear in the rendered prompt.
        expected_contents = (
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
            "Madrid is the capital of Spain.",
        )
        for content in expected_contents:
            assert content in user_message.text
        assert query in user_message.text

        # The mock generator echoes back the rendered user message.
        assert assistant_message.role.value == "assistant"
        assert assistant_message.text == f"{REPLY_PREFIX}{user_message.text}"
Loading