diff --git a/MIGRATION.md b/MIGRATION.md index 17ee14b2d5..6faea706a5 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -323,6 +323,32 @@ pipeline.run(data={"retriever": {"query": query}, "agent": {"messages": [], "que If the prompt itself must still be assembled per run, build `ChatMessage` objects before the `Agent` (e.g. with a `ChatPromptBuilder`) and pass them through the `messages` input. For a runtime system prompt, construct an `Agent` without `system_prompt` or `user_prompt` and include a system message at the start of `messages`. +#### Reserved `state_schema` keys for built-in run metadata + +**What changed:** `Agent` now auto-populates three new outputs — `step_count`, `token_usage`, and `tool_call_counts` — and reserves those names in its `state_schema`. Passing any of them as a `state_schema` key now raises `ValueError`. + +**Why:** These keys are managed by `Agent` itself and exposed as outputs only; allowing users to redefine them would let an input shadow the value the Agent is trying to write. + +**How to migrate:** Rename any clashing `state_schema` entries. + +Before (v2.x): +```python +agent = Agent( + chat_generator=..., + tools=[...], + state_schema={"token_usage": {"type": dict}}, +) +``` + +After (v3.0): +```python +agent = Agent( + chat_generator=..., + tools=[...], + state_schema={"my_token_usage": {"type": dict}}, +) +``` + ### LLM #### Runtime `user_prompt` and `system_prompt` removed from `LLM.run` / `LLM.run_async` diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index babf3be5ac..4cdbc5f232 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -6,6 +6,7 @@ import contextvars import inspect import re +from copy import deepcopy from dataclasses import dataclass from typing import Any, Literal, cast @@ -48,6 +49,77 @@ # Regex to extract the role from a Jinja2 message block, e.g. 
{% message role="user" %} _JINJA2_MESSAGE_ROLE_RE = re.compile(r'\{%\s*message\s+role\s*=\s*["\'](\w+)["\']') +# State keys that the Agent populates automatically during a run. +# Users may not define them in their own `state_schema`, and they are exposed only as Agent outputs. +_INTERNAL_STATE_KEYS: dict[str, dict[str, Any]] = { + "step_count": {"type": int, "handler": replace_values}, + "token_usage": {"type": dict[str, Any], "handler": replace_values}, + "tool_call_counts": {"type": dict[str, int], "handler": replace_values}, +} + + +def _accumulate_usage(current: Any, new: Any) -> Any: + """ + Recursively sum numeric leaf values across two usage-like dicts. + + Used to aggregate `ChatMessage.meta["usage"]` payloads across LLM calls in a run. Nested dicts (e.g. OpenAI's + `completion_tokens_details`) are merged recursively; numeric leaves are summed; other types fall back to the new + value. + + :param current: The current accumulated usage data. + :param new: The new usage data to merge in. + """ + if isinstance(current, dict) and isinstance(new, dict): + result = dict(current) + for k, v in new.items(): + result[k] = _accumulate_usage(result[k], v) if k in result else deepcopy(v) + return result + if isinstance(current, (int, float)) and isinstance(new, (int, float)): + return current + new + return new + + +def _record_llm_usage(state: State, llm_messages: list[ChatMessage]) -> None: + """ + Aggregate token usage from the latest LLM messages into the State. + + Only writes when at least one message reports `meta["usage"]`, so generators that don't surface usage data + leave `token_usage` at its default empty dict rather than overwriting it. + + :param state: The Agent's State, used to read the running `token_usage` total and write back the new total. + :param llm_messages: The ChatMessage objects returned from the latest LLM call. Token usage is read from each + message's `meta["usage"]` field, if present. 
+ """ + current = state.get("token_usage") + updated = False + for msg in llm_messages: + usage = msg.meta.get("usage") + if isinstance(usage, dict): + current = _accumulate_usage(current or {}, usage) + updated = True + if updated: + state.set("token_usage", current) + + +def _record_tool_calls(state: State, tool_messages: list[ChatMessage]) -> None: + """ + Increment per-tool call counts in the State for every successfully dispatched tool. + + :param state: The Agent's State, used to read the running `tool_call_counts` map and write back the new totals. + :param tool_messages: The ChatMessage objects returned from the latest tool execution. Per-tool counts are + incremented based on each message's `tool_call_result.origin.tool_name`. + """ + counts = state.get("tool_call_counts") or {} + updated = False + for tm in tool_messages: + if tm.tool_call_result is None: + continue + name = tm.tool_call_result.origin.tool_name + counts[name] = counts.get(name, 0) + 1 + updated = True + if updated: + state.set("tool_call_counts", counts) + def _get_run_method_params(instance: "Agent") -> set[str]: """Derive the parameter names of the Agent.run method via introspection.""" @@ -292,7 +364,8 @@ def __init__( with `"type"` (required) and an optional `"handler"` for merging values across tool calls. Tools can read from and write to state keys using `inputs_from_state` and `outputs_to_state`. :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100. - If the agent exceeds this number of steps, it will stop and return the current state. + A step is one chat-generator call plus the execution of every tool call the model requested in + that call (if any). If the agent reaches this number of steps it stops and returns the current state. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. 
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? @@ -324,6 +397,12 @@ def __init__( ) if state_schema is not None: + reserved_used = sorted(set(state_schema) & _INTERNAL_STATE_KEYS.keys()) + if reserved_used: + raise ValueError( + f"state_schema keys {reserved_used} are reserved for Agent internal state and " + f"cannot be redefined. Reserved keys: {sorted(_INTERNAL_STATE_KEYS)}." + ) _validate_schema(state_schema) _validate_prompt_message_blocks(user_prompt, system_prompt) if tool_concurrency_limit < 1: @@ -350,13 +429,16 @@ def __init__( self.state_schema = dict(self._state_schema) if self.state_schema.get("messages") is None: self.state_schema["messages"] = {"type": list[ChatMessage], "handler": merge_lists} + for key, config in _INTERNAL_STATE_KEYS.items(): + self.state_schema[key] = dict(config) # --- Component I/O --- self._run_method_params = _get_run_method_params(self) - output_types = {"last_message": ChatMessage} + output_types: dict[str, Any] = {"last_message": ChatMessage} for param, config in self.state_schema.items(): output_types[param] = config["type"] - if param not in self._run_method_params: + # Internal state keys are populated internally by the Agent itself and are not exposed as inputs + if param not in self._run_method_params and param not in _INTERNAL_STATE_KEYS: component.set_input_type(self, name=param, type=config["type"], default=None) component.set_output_types(self, **output_types) @@ -569,15 +651,18 @@ def _initialize_fresh_execution( if all(m.is_from(ChatRole.SYSTEM) for m in messages): logger.warning("All messages provided to the Agent component are system messages. 
This is not recommended.") + selected_tools = self._select_tools(tools) + state_kwargs: dict[str, Any] = {key: kwargs[key] for key in self.state_schema.keys() if key in kwargs} state = State(schema=self.state_schema, data=state_kwargs) state.set("messages", messages) + state.set("step_count", 0) + state.set("token_usage", {}) + state.set("tool_call_counts", {tool.name: 0 for tool in flatten_tools_or_toolsets(selected_tools)}) streaming_callback = select_streaming_callback( # type: ignore[call-overload] init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=requires_async ) - - selected_tools = self._select_tools(tools) generator_inputs: dict[str, Any] = {} if self._chat_generator_supports_tools: generator_inputs["tools"] = selected_tools @@ -669,6 +754,12 @@ def run( A dictionary with the following keys: - "messages": List of all messages exchanged during the agent's run. - "last_message": The last message exchanged during the agent's run. + - "step_count": The number of steps the agent ran. A step is one chat-generator call plus the + execution of every tool call the model requested in that call (if any). The counter is incremented + after each step completes, including the final step that hits an exit condition or `max_agent_steps`. + - "token_usage": Aggregated token usage from every LLM call in the run, summed from each LLM message's + `meta["usage"]`. + - "tool_call_counts": Mapping of tool name to the number of times that tool was invoked. - Any additional keys defined in the `state_schema`. """ agent_inputs = {"messages": messages, "streaming_callback": streaming_callback, **kwargs} @@ -738,6 +829,12 @@ async def run_async( A dictionary with the following keys: - "messages": List of all messages exchanged during the agent's run. - "last_message": The last message exchanged during the agent's run. + - "step_count": The number of steps the agent ran. 
A step is one chat-generator call plus the + execution of every tool call the model requested in that call (if any). The counter is incremented + after each step completes, including the final step that hits an exit condition or `max_agent_steps`. + - "token_usage": Aggregated token usage from every LLM call in the run, summed from each LLM message's + `meta["usage"]`. + - "tool_call_counts": Mapping of tool name to the number of times that tool was invoked. - Any additional keys defined in the `state_schema`. """ agent_inputs = {"messages": messages, "streaming_callback": streaming_callback, **kwargs} @@ -787,9 +884,11 @@ def _run_step(self, exe_context: _ExecutionContext, agent_span: tracing.Span) -> llm_span.set_content_tag("haystack.agent.step.llm.output", result) llm_messages = result["replies"] exe_context.state.set("messages", llm_messages) + _record_llm_usage(exe_context.state, llm_messages) if not any(msg.tool_call for msg in llm_messages) or not self.tools: exe_context.counter += 1 + exe_context.state.set("step_count", exe_context.counter) return False modified_tool_call_messages, new_chat_history = _process_confirmation_strategies( @@ -815,13 +914,14 @@ def _run_step(self, exe_context: _ExecutionContext, agent_span: tracing.Span) -> "haystack.agent.step.tool.output", {"tool_messages": tool_messages, "state": exe_context.state} ) exe_context.state.set("messages", tool_messages) - - if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages): - exe_context.counter += 1 - return False + _record_tool_calls(exe_context.state, tool_messages) exe_context.counter += 1 - return True + exe_context.state.set("step_count", exe_context.counter) + exit_triggered = self.exit_conditions != ["text"] and self._check_exit_conditions( + llm_messages, tool_messages + ) + return not exit_triggered async def _run_step_async(self, exe_context: _ExecutionContext, agent_span: tracing.Span) -> bool: """Execute one agent step asynchronously. 
Returns True to continue the loop, False to stop.""" @@ -848,9 +948,11 @@ async def _run_step_async(self, exe_context: _ExecutionContext, agent_span: trac llm_span.set_content_tag("haystack.agent.step.llm.output", result) llm_messages = result["replies"] exe_context.state.set("messages", llm_messages) + _record_llm_usage(exe_context.state, llm_messages) if not any(msg.tool_call for msg in llm_messages) or not self.tools: exe_context.counter += 1 + exe_context.state.set("step_count", exe_context.counter) return False modified_tool_call_messages, new_chat_history = await _process_confirmation_strategies_async( @@ -876,13 +978,14 @@ async def _run_step_async(self, exe_context: _ExecutionContext, agent_span: trac "haystack.agent.step.tool.output", {"tool_messages": tool_messages, "state": exe_context.state} ) exe_context.state.set("messages", tool_messages) - - if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages): - exe_context.counter += 1 - return False + _record_tool_calls(exe_context.state, tool_messages) exe_context.counter += 1 - return True + exe_context.state.set("step_count", exe_context.counter) + exit_triggered = self.exit_conditions != ["text"] and self._check_exit_conditions( + llm_messages, tool_messages + ) + return not exit_triggered def _check_exit_conditions(self, llm_messages: list[ChatMessage], tool_messages: list[ChatMessage]) -> bool: """ diff --git a/haystack/components/generators/chat/llm.py b/haystack/components/generators/chat/llm.py index be57fd6283..2e2f1adad0 100644 --- a/haystack/components/generators/chat/llm.py +++ b/haystack/components/generators/chat/llm.py @@ -86,6 +86,13 @@ def __init__( ) component.set_input_type(self, "messages", list[ChatMessage], None) + # The Agent base class declares `step_count` and `tool_call_counts` as outputs, but an LLM never has tools + # and always runs exactly one step — those values are uninformative, so drop them from the public surface. 
+ # `token_usage` is still meaningful and stays exposed. + component.set_output_types( + self, messages=list[ChatMessage], last_message=ChatMessage, token_usage=dict[str, Any] + ) + def to_dict(self) -> dict[str, Any]: """ Serialize the LLM component to a dictionary. @@ -140,16 +147,22 @@ def run( # type: ignore[override] # `messages` is in **kwargs to allow dynamic A dictionary with the following keys: - "messages": List of all messages exchanged during the LLM's run. - "last_message": The last message exchanged during the LLM's run. + - "token_usage": Token usage from the LLM call (e.g. prompt_tokens, completion_tokens). Empty if the + chat generator did not return usage data. """ # `messages` is intentionally omitted from the signature so the framework can treat it as required # or optional depending on init configuration. See __init__ for details. messages = kwargs.pop("messages", None) - return super(LLM, self).run( # noqa: UP008 + result = super(LLM, self).run( # noqa: UP008 messages=messages or [], streaming_callback=streaming_callback, generation_kwargs=generation_kwargs, **kwargs, ) + # Inherited Agent-internal bookkeeping that isn't useful at the LLM surface. + result.pop("step_count", None) + result.pop("tool_call_counts", None) + return result async def run_async( # type: ignore[override] # `messages` is in **kwargs to allow dynamic required/optional status self, @@ -174,13 +187,19 @@ async def run_async( # type: ignore[override] # `messages` is in **kwargs to a A dictionary with the following keys: - "messages": List of all messages exchanged during the LLM's run. - "last_message": The last message exchanged during the LLM's run. + - "token_usage": Token usage from the LLM call (e.g. prompt_tokens, completion_tokens). Empty if the + chat generator did not return usage data. """ # `messages` is intentionally omitted from the signature so the framework can treat it as required # or optional depending on init configuration. See __init__ for details. 
messages = kwargs.pop("messages", None) - return await super(LLM, self).run_async( # noqa: UP008 + result = await super(LLM, self).run_async( # noqa: UP008 messages=messages or [], streaming_callback=streaming_callback, generation_kwargs=generation_kwargs, **kwargs, ) + # Inherited Agent-internal bookkeeping that isn't useful at the LLM surface. + result.pop("step_count", None) + result.pop("tool_call_counts", None) + return result diff --git a/releasenotes/notes/expose-agent-run-metadata-8f2942aba14b4f0c.yaml b/releasenotes/notes/expose-agent-run-metadata-8f2942aba14b4f0c.yaml new file mode 100644 index 0000000000..9cf04e9877 --- /dev/null +++ b/releasenotes/notes/expose-agent-run-metadata-8f2942aba14b4f0c.yaml @@ -0,0 +1,21 @@ +--- +enhancements: + - | + ``Agent`` now exposes three new outputs that are populated automatically during a + run and made available alongside ``messages`` and ``last_message`` in the result dict: + + - ``step_count`` (``int``): the number of steps the agent ran. + - ``token_usage`` (``dict[str, Any]``): aggregated token usage summed across every LLM call in the run. + - ``tool_call_counts`` (``dict[str, int]``): number of times each tool was invoked, keyed by tool name. + + These fields are added to ``Agent.state_schema`` automatically so that tools registered via ``inputs_from_state`` can read them mid-run. + They are exposed only as Agent outputs, so they cannot be passed in as inputs to ``Agent.run`` / ``Agent.run_async``. + - | + ``LLM`` now exposes a ``token_usage`` output alongside ``messages`` and ``last_message``. Because ``LLM`` never + invokes tools and always runs exactly one step, ``step_count`` and ``tool_call_counts`` inherited from ``Agent`` + are not exposed on ``LLM``. +upgrade: + - | + ``step_count``, ``token_usage``, and ``tool_call_counts`` are now reserved keys in ``Agent.state_schema``. + Passing any of them via the ``state_schema`` argument now raises ``ValueError``. + Rename the conflicting state key (e.g. 
``my_token_usage``) to migrate. diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 3eb60d629a..8751b4d086 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -16,8 +16,8 @@ from openai.types.chat import ChatCompletionChunk, chat_completion_chunk from haystack import Document, Pipeline, component, tracing -from haystack.components.agents.agent import Agent -from haystack.components.agents.state import merge_lists +from haystack.components.agents.agent import Agent, _accumulate_usage +from haystack.components.agents.state import merge_lists, replace_values from haystack.components.agents.tool_calling import _run_tool from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder from haystack.components.builders.prompt_builder import PromptBuilder @@ -45,6 +45,16 @@ def _sys_msg(text: str) -> str: return f'{{% message role="system" %}}{text}{{% endmessage %}}' +def _assistant_with_usage(text: str | None = None, *, tool_calls=None, usage: dict[str, Any] | None = None): + """Build an assistant ChatMessage with optional tool_calls and `meta['usage']` populated.""" + meta: dict[str, Any] = {} + if usage is not None: + meta["usage"] = usage + if tool_calls is not None: + return ChatMessage.from_assistant(tool_calls=tool_calls, meta=meta or None) + return ChatMessage.from_assistant(text or "", meta=meta or None) + + def sync_streaming_callback(chunk: StreamingChunk) -> None: """A synchronous streaming callback.""" pass @@ -194,7 +204,14 @@ def test_output_types(self, weather_tool, component_tool, monkeypatch): assert agent.__haystack_output__._sockets_dict == { "messages": OutputSocket(name="messages", type=list[ChatMessage], receivers=[]), "last_message": OutputSocket(name="last_message", type=ChatMessage, receivers=[]), + "step_count": OutputSocket(name="step_count", type=int, receivers=[]), + "token_usage": OutputSocket(name="token_usage", type=dict[str, Any], 
receivers=[]), + "tool_call_counts": OutputSocket(name="tool_call_counts", type=dict[str, int], receivers=[]), } + # Check that the internal-state keys are not set up as input sockets + assert {"step_count", "token_usage", "tool_call_counts"}.isdisjoint( + agent.__haystack_input__._sockets_dict.keys() + ) def test_to_dict(self, weather_tool, component_tool, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-key") @@ -443,6 +460,9 @@ def test_from_dict(self, monkeypatch): assert agent.state_schema == { "foo": {"type": str}, "messages": {"handler": merge_lists, "type": list[ChatMessage]}, + "step_count": {"type": int, "handler": replace_values}, + "token_usage": {"type": dict[str, Any], "handler": replace_values}, + "tool_call_counts": {"type": dict[str, int], "handler": replace_values}, } assert agent.tool_concurrency_limit == 5 assert agent.tool_streaming_callback_passthrough is True @@ -532,55 +552,24 @@ def test_from_dict_state_schema_none(self, monkeypatch): "http_client_kwargs": None, }, }, - "tools": [ - { - "type": "haystack.tools.tool.Tool", - "data": { - "name": "weather_tool", - "description": "Provides weather information for a given location.", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }, - "function": "test_agent.weather_function", - "outputs_to_string": None, - "inputs_from_state": None, - "outputs_to_state": None, - }, - }, - { - "type": "haystack.tools.component_tool.ComponentTool", - "data": { - "component": { - "type": "haystack.components.builders.prompt_builder.PromptBuilder", - "init_parameters": { - "template": "{{parrot}}", - "variables": None, - "required_variables": "*", - }, - }, - "name": "parrot", - "description": "This is a parrot.", - "parameters": None, - "outputs_to_string": None, - "inputs_from_state": None, - "outputs_to_state": None, - }, - }, - ], + "tools": None, "system_prompt": None, - "exit_conditions": ["text", "weather_tool"], + 
"exit_conditions": ["text"], "state_schema": None, "max_agent_steps": 100, "raise_on_tool_invocation_failure": False, "streaming_callback": None, - "tool_concurrency_limit": 5, - "tool_streaming_callback_passthrough": True, + "tool_concurrency_limit": 4, + "tool_streaming_callback_passthrough": False, }, } agent = Agent.from_dict(data) - assert agent.state_schema == {"messages": {"type": list[ChatMessage], "handler": merge_lists}} + assert agent.state_schema == { + "messages": {"type": list[ChatMessage], "handler": merge_lists}, + "step_count": {"type": int, "handler": replace_values}, + "token_usage": {"type": dict[str, Any], "handler": replace_values}, + "tool_call_counts": {"type": dict[str, int], "handler": replace_values}, + } def test_serde(self, weather_tool, component_tool, monkeypatch): monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key") @@ -590,18 +579,18 @@ def test_serde(self, weather_tool, component_tool, monkeypatch): tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"], state_schema={"foo": {"type": str}}, + streaming_callback=sync_streaming_callback, ) serialized_agent = agent.to_dict() init_parameters = serialized_agent["init_parameters"] - assert serialized_agent["type"] == "haystack.components.agents.agent.Agent" assert ( init_parameters["chat_generator"]["type"] == "haystack.components.generators.chat.openai.OpenAIChatGenerator" ) - assert init_parameters["streaming_callback"] is None + assert init_parameters["streaming_callback"] == "test_agent.sync_streaming_callback" assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function) assert ( init_parameters["tools"][1]["data"]["component"]["type"] @@ -610,7 +599,6 @@ def test_serde(self, weather_tool, component_tool, monkeypatch): assert init_parameters["exit_conditions"] == ["text", "weather_tool"] deserialized_agent = Agent.from_dict(serialized_agent) - assert isinstance(deserialized_agent, Agent) assert 
isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator) assert deserialized_agent.tools[0].function is weather_function @@ -619,21 +607,10 @@ def test_serde(self, weather_tool, component_tool, monkeypatch): assert deserialized_agent.state_schema == { "foo": {"type": str}, "messages": {"handler": merge_lists, "type": list[ChatMessage]}, + "step_count": {"type": int, "handler": replace_values}, + "token_usage": {"type": dict[str, Any], "handler": replace_values}, + "tool_call_counts": {"type": dict[str, int], "handler": replace_values}, } - - def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch): - monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key") - generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY")) - agent = Agent( - chat_generator=generator, tools=[weather_tool, component_tool], streaming_callback=sync_streaming_callback - ) - - serialized_agent = agent.to_dict() - - init_parameters = serialized_agent["init_parameters"] - assert init_parameters["streaming_callback"] == "test_agent.sync_streaming_callback" - - deserialized_agent = Agent.from_dict(serialized_agent) assert deserialized_agent.streaming_callback is sync_streaming_callback def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch): @@ -773,8 +750,9 @@ def test_exceed_max_steps(self, monkeypatch, weather_tool, caplog): agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages}) with caplog.at_level(logging.WARNING): - agent.run([ChatMessage.from_user("Hello")]) + result = agent.run([ChatMessage.from_user("Hello")]) assert "Agent reached maximum agent steps" in caplog.text + assert result["step_count"] == 0 def test_exit_condition_exits(self, monkeypatch, weather_tool): monkeypatch.setenv("OPENAI_API_KEY", "fake-key") @@ -950,6 +928,13 @@ def test_run(self, weather_tool): assert "last_message" in response assert isinstance(response["last_message"], ChatMessage) assert 
response["messages"][-1] == response["last_message"] + # Auto-populated run outputs: + # 4 messages → tool call + final answer = 2 LLM calls = 2 steps; one weather_tool invocation. + assert response["step_count"] == 2 + assert response["tool_call_counts"] == {"weather_tool": 1} + assert response["token_usage"]["prompt_tokens"] > 0 + assert response["token_usage"]["completion_tokens"] > 0 + assert response["token_usage"]["total_tokens"] > 0 @pytest.mark.asyncio async def test_generation_kwargs(self): @@ -1025,13 +1010,21 @@ def streaming_callback(chunk: StreamingChunk) -> None: streaming_callback_called = True result = agent.run( - [ChatMessage.from_user("What's the weather in Paris?")], streaming_callback=streaming_callback + [ChatMessage.from_user("What's the weather in Paris?")], + streaming_callback=streaming_callback, + generation_kwargs={"stream_options": {"include_usage": True}}, ) assert result is not None assert result["messages"] is not None assert result["last_message"] is not None assert streaming_callback_called + # Auto-populated run outputs. 
+ assert result["step_count"] == 2 + assert result["tool_call_counts"] == {"weather_tool": 1} + assert result["token_usage"]["prompt_tokens"] > 0 + assert result["token_usage"]["completion_tokens"] > 0 + assert result["token_usage"]["total_tokens"] > 0 @pytest.mark.asyncio async def test_run_async_with_async_streaming_callback(self, weather_tool): @@ -1060,6 +1053,140 @@ async def test_run_async_with_sync_streaming_callback_fails(self, weather_tool): with pytest.raises(ValueError, match="The init callback must be async compatible"): await agent.run_async([ChatMessage.from_user("Hello")]) + def test_reserved_state_schema_keys_raise(self, monkeypatch, weather_tool): + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + for reserved in ("step_count", "token_usage", "tool_call_counts"): + with pytest.raises(ValueError, match="reserved for Agent internal state"): + Agent( + chat_generator=OpenAIChatGenerator(), tools=[weather_tool], state_schema={reserved: {"type": int}} + ) + + def test_run_populates_token_usage_and_tool_call_counts(self, monkeypatch, weather_tool, component_tool): + """A multi-step run aggregates step_count, token_usage (incl. nested details), and tool_call_counts.""" + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[weather_tool, component_tool]) + # Step 1: two parallel tool calls + usage with nested detail dicts. + # Step 2: one more weather_tool call + flat usage. + # Step 3: final text answer + usage. 
+ first_step = [ + _assistant_with_usage( + tool_calls=[ + ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}), + ToolCall(tool_name="parrot", arguments={"parrot": "hi"}), + ], + usage={ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + "completion_tokens_details": {"reasoning_tokens": 2}, + }, + ) + ] + second_step = [ + _assistant_with_usage( + tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Paris"})], + usage={"prompt_tokens": 6, "completion_tokens": 3, "total_tokens": 9}, + ) + ] + third_step = [ + _assistant_with_usage( + "Done.", + usage={ + "prompt_tokens": 4, + "completion_tokens": 2, + "total_tokens": 6, + "completion_tokens_details": {"reasoning_tokens": 1}, + }, + ) + ] + agent.chat_generator.run = MagicMock( + side_effect=[{"replies": first_step}, {"replies": second_step}, {"replies": third_step}] + ) + + result = agent.run([ChatMessage.from_user("Hi")]) + assert result["step_count"] == 3 + assert result["tool_call_counts"] == {"weather_tool": 2, "parrot": 1} + assert result["token_usage"] == { + "prompt_tokens": 20, + "completion_tokens": 10, + "total_tokens": 30, + "completion_tokens_details": {"reasoning_tokens": 3}, + } + + @pytest.mark.asyncio + async def test_run_async_populates_token_usage_and_tool_call_counts(self, monkeypatch, weather_tool): + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[weather_tool]) + first_step = [ + _assistant_with_usage( + tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})], + usage={"prompt_tokens": 4, "completion_tokens": 2, "total_tokens": 6}, + ) + ] + second_step = [ + _assistant_with_usage("Done.", usage={"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}) + ] + agent.chat_generator.run_async = AsyncMock(side_effect=[{"replies": first_step}, {"replies": second_step}]) + + result = await agent.run_async([ChatMessage.from_user("Hi")]) + assert 
result["step_count"] == 2 + assert result["tool_call_counts"] == {"weather_tool": 1} + assert result["token_usage"] == {"prompt_tokens": 7, "completion_tokens": 3, "total_tokens": 10} + + def test_metadata_outputs_show_defaults_when_no_data(self, weather_tool): + """`token_usage` stays empty and `tool_call_counts` reports zero for every tool when nothing happens.""" + agent = Agent(chat_generator=MockChatGenerator(), tools=[weather_tool]) + result = agent.run([ChatMessage.from_user("Hi")]) + # MockChatGenerator returns a text-only reply with no `usage` meta and no tool calls. + assert result["step_count"] == 1 + assert result["token_usage"] == {} + assert result["tool_call_counts"] == {"weather_tool": 0} + + +class TestAccumulateUsage: + """Unit tests for the `_accumulate_usage` helper used to merge ChatGenerator usage dicts.""" + + def test_sums_flat_numeric_keys(self): + current = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + new = {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5} + assert _accumulate_usage(current, new) == {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20} + + def test_merges_nested_detail_dicts_recursively(self): + current = {"prompt_tokens": 10, "completion_tokens_details": {"reasoning_tokens": 2, "audio_tokens": 0}} + new = { + "prompt_tokens": 4, + "completion_tokens_details": {"reasoning_tokens": 3, "audio_tokens": 1}, + "prompt_tokens_details": {"cached_tokens": 6}, + } + assert _accumulate_usage(current, new) == { + "prompt_tokens": 14, + "completion_tokens_details": {"reasoning_tokens": 5, "audio_tokens": 1}, + "prompt_tokens_details": {"cached_tokens": 6}, + } + + def test_adds_keys_missing_in_current(self): + assert _accumulate_usage({"prompt_tokens": 5}, {"completion_tokens": 7}) == { + "prompt_tokens": 5, + "completion_tokens": 7, + } + + def test_empty_current_dict_returns_copy_of_new(self): + new = {"prompt_tokens": 5, "details": {"cached_tokens": 1}} + result = 
_accumulate_usage({}, new) + assert result == new + # Nested dicts must be deep-copied so future merges don't mutate the source. + new["details"]["cached_tokens"] = 999 + assert result["details"]["cached_tokens"] == 1 + + def test_non_dict_non_numeric_falls_back_to_new(self): + # Strings, lists, or any other type that isn't a dict-or-number pair returns `new` unchanged. + assert _accumulate_usage("old-model", "new-model") == "new-model" + assert _accumulate_usage(5, "stringified") == "stringified" + assert _accumulate_usage({"model": "gpt-x"}, {"model": "gpt-y"}) == {"model": "gpt-y"} + + def test_sums_floats(self): + assert _accumulate_usage(1.5, 2.25) == 3.75 + class TestAgentTracing: def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool): @@ -1111,9 +1238,9 @@ def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool): "haystack.agent.max_steps": 100, "haystack.agent.tools": '[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]', # noqa: E501 "haystack.agent.exit_conditions": '["text"]', - "haystack.agent.state_schema": '{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}', # noqa: E501 + "haystack.agent.state_schema": '{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}, "step_count": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"}, "token_usage": {"type": "dict[str, typing.Any]", "handler": "haystack.components.agents.state.state_utils.replace_values"}, "tool_call_counts": {"type": 
"dict[str, int]", "handler": "haystack.components.agents.state.state_utils.replace_values"}}', # noqa: E501 "haystack.agent.input": '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}', # noqa: E501 - "haystack.agent.output": '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}', # noqa: E501 + "haystack.agent.output": '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}], "step_count": 1, "token_usage": {}, "tool_call_counts": {"weather_tool": 0}}', # noqa: E501 "haystack.agent.steps_taken": 1, } @@ -1217,9 +1344,9 @@ async def test_agent_tracing_span_async_run(self, caplog, monkeypatch, weather_t "haystack.agent.max_steps": 100, "haystack.agent.tools": '[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]', # noqa: E501 "haystack.agent.exit_conditions": '["text"]', - "haystack.agent.state_schema": '{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}', # noqa: E501 + "haystack.agent.state_schema": '{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}, "step_count": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"}, "token_usage": {"type": "dict[str, 
typing.Any]", "handler": "haystack.components.agents.state.state_utils.replace_values"}, "tool_call_counts": {"type": "dict[str, int]", "handler": "haystack.components.agents.state.state_utils.replace_values"}}', # noqa: E501 "haystack.agent.input": '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}', # noqa: E501 - "haystack.agent.output": '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}', # noqa: E501 + "haystack.agent.output": '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}], "step_count": 1, "token_usage": {}, "tool_call_counts": {"weather_tool": 0}}', # noqa: E501 "haystack.agent.steps_taken": 1, } diff --git a/test/components/agents/test_agent_hitl.py b/test/components/agents/test_agent_hitl.py index e13e7b14d9..1944aa0b33 100644 --- a/test/components/agents/test_agent_hitl.py +++ b/test/components/agents/test_agent_hitl.py @@ -166,6 +166,12 @@ def test_run_blocking_confirmation_strategy_modify(self, tools): assert isinstance(result["last_message"], ChatMessage) assert result["last_message"].text is not None assert "5" in result["last_message"].text + # Auto-populated run-metadata outputs: at least one tool call plus a final answer. 
+ assert result["step_count"] >= 2 + assert result["tool_call_counts"]["addition_tool"] >= 1 + assert result["token_usage"]["prompt_tokens"] > 0 + assert result["token_usage"]["completion_tokens"] > 0 + assert result["token_usage"]["total_tokens"] > 0 @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration @@ -188,3 +194,9 @@ async def test_run_async_blocking_confirmation_strategy_modify(self, tools): assert isinstance(result["last_message"], ChatMessage) assert result["last_message"].text is not None assert "5" in result["last_message"].text + # Auto-populated run-metadata outputs: at least one tool call plus a final answer. + assert result["step_count"] >= 2 + assert result["tool_call_counts"]["addition_tool"] >= 1 + assert result["token_usage"]["prompt_tokens"] > 0 + assert result["token_usage"]["completion_tokens"] > 0 + assert result["token_usage"]["total_tokens"] > 0 diff --git a/test/components/generators/chat/test_llm.py b/test/components/generators/chat/test_llm.py index e05bba316f..451bd549eb 100644 --- a/test/components/generators/chat/test_llm.py +++ b/test/components/generators/chat/test_llm.py @@ -90,6 +90,7 @@ def test_output_sockets(self): assert llm.__haystack_output__._sockets_dict == { "messages": OutputSocket(name="messages", type=list[ChatMessage], receivers=[]), "last_message": OutputSocket(name="last_message", type=ChatMessage, receivers=[]), + "token_usage": OutputSocket(name="token_usage", type=dict[str, Any], receivers=[]), } def test_detects_no_tools_support(self):