Commit 278c599

sjrl and julian-risch authored
feat!: Track step_count, token_usage and tool_call_counts in Agent's State (#11427)
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
1 parent c487176 commit 278c599

7 files changed

Lines changed: 392 additions & 83 deletions

MIGRATION.md

Lines changed: 26 additions & 0 deletions
@@ -337,6 +337,32 @@ pipeline.run(data={"retriever": {"query": query}, "agent": {"messages": [], "que
 If the prompt itself must still be assembled per run, build `ChatMessage` objects before the `Agent` (e.g. with a `ChatPromptBuilder`) and pass them through the `messages` input.
 For a runtime system prompt, construct an `Agent` without `system_prompt` or `user_prompt` and include a system message at the start of `messages`.

+#### Reserved `state_schema` keys for built-in run metadata
+
+**What changed:** `Agent` now auto-populates three new outputs (`step_count`, `token_usage`, and `tool_call_counts`) and reserves those names in its `state_schema`. Passing any of them as a `state_schema` key now raises `ValueError`.
+
+**Why:** These keys are managed by `Agent` itself and exposed as outputs only; allowing users to redefine them would let an input shadow the value the Agent is trying to write.
+
+**How to migrate:** Rename any clashing `state_schema` entries.
+
+Before (v2.x):
+```python
+agent = Agent(
+    chat_generator=...,
+    tools=[...],
+    state_schema={"token_usage": {"type": dict}},
+)
+```
+
+After (v3.0):
+```python
+agent = Agent(
+    chat_generator=...,
+    tools=[...],
+    state_schema={"my_token_usage": {"type": dict}},
+)
+```
+
 ### LLM

 #### Runtime `user_prompt` and `system_prompt` removed from `LLM.run` / `LLM.run_async`
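
Review note: the reserved-key check described in this migration entry can be tried in isolation. The sketch below copies the set-intersection logic from the `Agent.__init__` hunk of this commit. `RESERVED_KEYS` and `check_state_schema` are hypothetical stand-in names used only for illustration; they are not part of Haystack.

```python
# Stand-in for the commit's `_INTERNAL_STATE_KEYS`; only the key names matter here.
RESERVED_KEYS = {"step_count", "token_usage", "tool_call_counts"}


def check_state_schema(state_schema: dict) -> None:
    """Hypothetical helper mirroring the reserved-key check added to Agent.__init__."""
    reserved_used = sorted(set(state_schema) & RESERVED_KEYS)
    if reserved_used:
        raise ValueError(
            f"state_schema keys {reserved_used} are reserved for Agent internal state and "
            f"cannot be redefined. Reserved keys: {sorted(RESERVED_KEYS)}."
        )


# The renamed key passes silently; the reserved key is rejected.
check_state_schema({"my_token_usage": {"type": dict}})
try:
    check_state_schema({"token_usage": {"type": dict}})
except ValueError as err:
    print(err)
```

Sorting the intersection keeps the error message deterministic when several reserved keys clash at once.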

haystack/components/agents/agent.py

Lines changed: 118 additions & 15 deletions
@@ -6,6 +6,7 @@
 import contextvars
 import inspect
 import re
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Literal, cast

@@ -48,6 +49,77 @@
 # Regex to extract the role from a Jinja2 message block, e.g. {% message role="user" %}
 _JINJA2_MESSAGE_ROLE_RE = re.compile(r'\{%\s*message\s+role\s*=\s*["\'](\w+)["\']')

+# State keys that the Agent populates automatically during a run.
+# Users may not define them in their own `state_schema`, and they are exposed only as Agent outputs.
+_INTERNAL_STATE_KEYS: dict[str, dict[str, Any]] = {
+    "step_count": {"type": int, "handler": replace_values},
+    "token_usage": {"type": dict[str, Any], "handler": replace_values},
+    "tool_call_counts": {"type": dict[str, int], "handler": replace_values},
+}
+
+
+def _accumulate_usage(current: Any, new: Any) -> Any:
+    """
+    Recursively sum numeric leaf values across two usage-like dicts.
+
+    Used to aggregate `ChatMessage.meta["usage"]` payloads across LLM calls in a run. Nested dicts (e.g. OpenAI's
+    `completion_tokens_details`) are merged recursively; numeric leaves are summed; other types fall back to the new
+    value.
+
+    :param current: The current accumulated usage data.
+    :param new: The new usage data to merge in.
+    """
+    if isinstance(current, dict) and isinstance(new, dict):
+        result = dict(current)
+        for k, v in new.items():
+            result[k] = _accumulate_usage(result[k], v) if k in result else deepcopy(v)
+        return result
+    if isinstance(current, (int, float)) and isinstance(new, (int, float)):
+        return current + new
+    return new
+
+
+def _record_llm_usage(state: State, llm_messages: list[ChatMessage]) -> None:
+    """
+    Aggregate token usage from the latest LLM messages into the State.
+
+    Only writes when at least one message reports `meta["usage"]`, so generators that don't surface usage data
+    leave `token_usage` at its default empty dict rather than overwriting it.
+
+    :param state: The Agent's State, used to read the running `token_usage` total and write back the new total.
+    :param llm_messages: The ChatMessage objects returned from the latest LLM call. Token usage is read from each
+        message's `meta["usage"]` field, if present.
+    """
+    current = state.get("token_usage")
+    updated = False
+    for msg in llm_messages:
+        usage = msg.meta.get("usage")
+        if isinstance(usage, dict):
+            current = _accumulate_usage(current or {}, usage)
+            updated = True
+    if updated:
+        state.set("token_usage", current)
+
+
+def _record_tool_calls(state: State, tool_messages: list[ChatMessage]) -> None:
+    """
+    Increment per-tool call counts in the State for every successfully dispatched tool.
+
+    :param state: The Agent's State, used to read the running `tool_call_counts` map and write back the new totals.
+    :param tool_messages: The ChatMessage objects returned from the latest tool execution. Per-tool counts are
+        incremented based on each message's `tool_call_result.origin.tool_name`.
+    """
+    counts = state.get("tool_call_counts") or {}
+    updated = False
+    for tm in tool_messages:
+        if tm.tool_call_result is None:
+            continue
+        name = tm.tool_call_result.origin.tool_name
+        counts[name] = counts.get(name, 0) + 1
+        updated = True
+    if updated:
+        state.set("tool_call_counts", counts)
+

 def _get_run_method_params(instance: "Agent") -> set[str]:
     """Derive the parameter names of the Agent.run method via introspection."""
@@ -292,7 +364,8 @@ def __init__(
             with `"type"` (required) and an optional `"handler"` for merging values across tool calls.
             Tools can read from and write to state keys using `inputs_from_state` and `outputs_to_state`.
         :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
-            If the agent exceeds this number of steps, it will stop and return the current state.
+            A step is one chat-generator call plus the execution of every tool call the model requested in
+            that call (if any). If the agent reaches this number of steps it stops and returns the current state.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
             The same callback can be configured to emit tool results when a tool is called.
         :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
@@ -324,6 +397,12 @@ def __init__(
             )

         if state_schema is not None:
+            reserved_used = sorted(set(state_schema) & _INTERNAL_STATE_KEYS.keys())
+            if reserved_used:
+                raise ValueError(
+                    f"state_schema keys {reserved_used} are reserved for Agent internal state and "
+                    f"cannot be redefined. Reserved keys: {sorted(_INTERNAL_STATE_KEYS)}."
+                )
             _validate_schema(state_schema)
         _validate_prompt_message_blocks(user_prompt, system_prompt)
         if tool_concurrency_limit < 1:
@@ -350,13 +429,16 @@ def __init__(
         self.state_schema = dict(self._state_schema)
         if self.state_schema.get("messages") is None:
             self.state_schema["messages"] = {"type": list[ChatMessage], "handler": merge_lists}
+        for key, config in _INTERNAL_STATE_KEYS.items():
+            self.state_schema[key] = dict(config)

         # --- Component I/O ---
         self._run_method_params = _get_run_method_params(self)
-        output_types = {"last_message": ChatMessage}
+        output_types: dict[str, Any] = {"last_message": ChatMessage}
         for param, config in self.state_schema.items():
             output_types[param] = config["type"]
-            if param not in self._run_method_params:
+            # Internal state keys are populated internally by the Agent itself and are not exposed as inputs
+            if param not in self._run_method_params and param not in _INTERNAL_STATE_KEYS:
                 component.set_input_type(self, name=param, type=config["type"], default=None)
         component.set_output_types(self, **output_types)

@@ -569,15 +651,18 @@ def _initialize_fresh_execution(
         if all(m.is_from(ChatRole.SYSTEM) for m in messages):
             logger.warning("All messages provided to the Agent component are system messages. This is not recommended.")

+        selected_tools = self._select_tools(tools)
+
         state_kwargs: dict[str, Any] = {key: kwargs[key] for key in self.state_schema.keys() if key in kwargs}
         state = State(schema=self.state_schema, data=state_kwargs)
         state.set("messages", messages)
+        state.set("step_count", 0)
+        state.set("token_usage", {})
+        state.set("tool_call_counts", {tool.name: 0 for tool in flatten_tools_or_toolsets(selected_tools)})

         streaming_callback = select_streaming_callback(  # type: ignore[call-overload]
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=requires_async
         )
-
-        selected_tools = self._select_tools(tools)
         generator_inputs: dict[str, Any] = {}
         if self._chat_generator_supports_tools:
             generator_inputs["tools"] = selected_tools
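
Review note: because `tool_call_counts` is seeded with a zero for every selected tool, tools that are never invoked still appear in the output. A minimal sketch of that behaviour follows, using plain tool names instead of `ChatMessage` objects. `record_tool_calls` and the tool names are hypothetical simplifications of `_record_tool_calls`, not Haystack APIs.

```python
def record_tool_calls(counts: dict[str, int], called_tool_names: list[str]) -> dict[str, int]:
    """Hypothetical, simplified counterpart of `_record_tool_calls` that works on plain names."""
    updated = dict(counts)  # copy so the caller's dict is left untouched
    for name in called_tool_names:
        updated[name] = updated.get(name, 0) + 1
    return updated


# Seed every selected tool with zero, as the fresh-execution setup in this hunk does.
counts = {name: 0 for name in ["search", "calculator", "translate"]}

# Two steps: the first calls one tool, the second calls two.
counts = record_tool_calls(counts, ["search"])
counts = record_tool_calls(counts, ["search", "calculator"])

print(counts)  # {'search': 2, 'calculator': 1, 'translate': 0}
```

The `get(name, 0)` fallback mirrors the commit's handling of a tool name that was not present when the counts were seeded.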
@@ -669,6 +754,12 @@ def run(
             A dictionary with the following keys:
             - "messages": List of all messages exchanged during the agent's run.
             - "last_message": The last message exchanged during the agent's run.
+            - "step_count": The number of steps the agent ran. A step is one chat-generator call plus the
+              execution of every tool call the model requested in that call (if any). The counter is incremented
+              after each step completes, including the final step that hits an exit condition or `max_agent_steps`.
+            - "token_usage": Aggregated token usage from every LLM call in the run, summed from each LLM message's
+              `meta["usage"]`.
+            - "tool_call_counts": Mapping of tool name to the number of times that tool was invoked.
             - Any additional keys defined in the `state_schema`.
         """
         agent_inputs = {"messages": messages, "streaming_callback": streaming_callback, **kwargs}
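
Review note: the `step_count` wording in the docstring above can be checked against a toy loop. The sketch below is an assumption-laden model, not the Agent's actual loop: the surrounding run loop is not part of this diff, so the `max_agent_steps` guard and the `count_steps` helper are hypothetical. Only the "increment, then decide whether to continue" ordering is taken from the `_run_step` changes in this commit.

```python
def count_steps(continue_flags: list[bool], max_agent_steps: int) -> int:
    """Toy model: continue_flags[i] is what step i would return (True = keep looping)."""
    step_count = 0
    for keep_going in continue_flags:
        if step_count >= max_agent_steps:  # assumed guard; the real loop is not shown in the diff
            break
        step_count += 1  # counted for every completed step, including the final one
        if not keep_going:
            break
    return step_count


print(count_steps([True, True, False], max_agent_steps=100))  # 3
print(count_steps([True, True, True, True], max_agent_steps=2))  # 2
```

In the first call the third step triggers an exit condition and is still counted. In the second call the step limit stops the loop after two counted steps.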
@@ -738,6 +829,12 @@ async def run_async(
             A dictionary with the following keys:
             - "messages": List of all messages exchanged during the agent's run.
             - "last_message": The last message exchanged during the agent's run.
+            - "step_count": The number of steps the agent ran. A step is one chat-generator call plus the
+              execution of every tool call the model requested in that call (if any). The counter is incremented
+              after each step completes, including the final step that hits an exit condition or `max_agent_steps`.
+            - "token_usage": Aggregated token usage from every LLM call in the run, summed from each LLM message's
+              `meta["usage"]`.
+            - "tool_call_counts": Mapping of tool name to the number of times that tool was invoked.
             - Any additional keys defined in the `state_schema`.
         """
         agent_inputs = {"messages": messages, "streaming_callback": streaming_callback, **kwargs}
@@ -787,9 +884,11 @@ def _run_step(self, exe_context: _ExecutionContext, agent_span: tracing.Span) ->
             llm_span.set_content_tag("haystack.agent.step.llm.output", result)
         llm_messages = result["replies"]
         exe_context.state.set("messages", llm_messages)
+        _record_llm_usage(exe_context.state, llm_messages)

         if not any(msg.tool_call for msg in llm_messages) or not self.tools:
             exe_context.counter += 1
+            exe_context.state.set("step_count", exe_context.counter)
             return False

         modified_tool_call_messages, new_chat_history = _process_confirmation_strategies(
@@ -815,13 +914,14 @@ def _run_step(self, exe_context: _ExecutionContext, agent_span: tracing.Span) ->
                 "haystack.agent.step.tool.output", {"tool_messages": tool_messages, "state": exe_context.state}
             )
         exe_context.state.set("messages", tool_messages)
-
-        if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
-            exe_context.counter += 1
-            return False
+        _record_tool_calls(exe_context.state, tool_messages)

         exe_context.counter += 1
-        return True
+        exe_context.state.set("step_count", exe_context.counter)
+        exit_triggered = self.exit_conditions != ["text"] and self._check_exit_conditions(
+            llm_messages, tool_messages
+        )
+        return not exit_triggered

     async def _run_step_async(self, exe_context: _ExecutionContext, agent_span: tracing.Span) -> bool:
         """Execute one agent step asynchronously. Returns True to continue the loop, False to stop."""
@@ -848,9 +948,11 @@ async def _run_step_async(self, exe_context: _ExecutionContext, agent_span: trac
             llm_span.set_content_tag("haystack.agent.step.llm.output", result)
         llm_messages = result["replies"]
         exe_context.state.set("messages", llm_messages)
+        _record_llm_usage(exe_context.state, llm_messages)

         if not any(msg.tool_call for msg in llm_messages) or not self.tools:
             exe_context.counter += 1
+            exe_context.state.set("step_count", exe_context.counter)
             return False

         modified_tool_call_messages, new_chat_history = await _process_confirmation_strategies_async(
@@ -876,13 +978,14 @@ async def _run_step_async(self, exe_context: _ExecutionContext, agent_span: trac
                 "haystack.agent.step.tool.output", {"tool_messages": tool_messages, "state": exe_context.state}
             )
         exe_context.state.set("messages", tool_messages)
-
-        if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
-            exe_context.counter += 1
-            return False
+        _record_tool_calls(exe_context.state, tool_messages)

         exe_context.counter += 1
-        return True
+        exe_context.state.set("step_count", exe_context.counter)
+        exit_triggered = self.exit_conditions != ["text"] and self._check_exit_conditions(
+            llm_messages, tool_messages
+        )
+        return not exit_triggered

     def _check_exit_conditions(self, llm_messages: list[ChatMessage], tool_messages: list[ChatMessage]) -> bool:
         """

haystack/components/generators/chat/llm.py

Lines changed: 21 additions & 2 deletions
@@ -86,6 +86,13 @@ def __init__(
         )
         component.set_input_type(self, "messages", list[ChatMessage], None)

+        # The Agent base class declares `step_count` and `tool_call_counts` as outputs, but an LLM never has tools
+        # and always runs exactly one step; those values are uninformative, so drop them from the public surface.
+        # `token_usage` is still meaningful and stays exposed.
+        component.set_output_types(
+            self, messages=list[ChatMessage], last_message=ChatMessage, token_usage=dict[str, Any]
+        )
+
     def to_dict(self) -> dict[str, Any]:
         """
         Serialize the LLM component to a dictionary.
@@ -140,16 +147,22 @@ def run( # type: ignore[override] # `messages` is in **kwargs to allow dynamic
             A dictionary with the following keys:
             - "messages": List of all messages exchanged during the LLM's run.
             - "last_message": The last message exchanged during the LLM's run.
+            - "token_usage": Token usage from the LLM call (e.g. prompt_tokens, completion_tokens). Empty if the
+              chat generator did not return usage data.
         """
         # `messages` is intentionally omitted from the signature so the framework can treat it as required
         # or optional depending on init configuration. See __init__ for details.
         messages = kwargs.pop("messages", None)
-        return super(LLM, self).run(  # noqa: UP008
+        result = super(LLM, self).run(  # noqa: UP008
             messages=messages or [],
             streaming_callback=streaming_callback,
             generation_kwargs=generation_kwargs,
             **kwargs,
         )
+        # Inherited Agent-internal bookkeeping that isn't useful at the LLM surface.
+        result.pop("step_count", None)
+        result.pop("tool_call_counts", None)
+        return result

     async def run_async(  # type: ignore[override]  # `messages` is in **kwargs to allow dynamic required/optional status
         self,
@@ -174,13 +187,19 @@ async def run_async( # type: ignore[override] # `messages` is in **kwargs to a
             A dictionary with the following keys:
             - "messages": List of all messages exchanged during the LLM's run.
             - "last_message": The last message exchanged during the LLM's run.
+            - "token_usage": Token usage from the LLM call (e.g. prompt_tokens, completion_tokens). Empty if the
+              chat generator did not return usage data.
         """
         # `messages` is intentionally omitted from the signature so the framework can treat it as required
         # or optional depending on init configuration. See __init__ for details.
         messages = kwargs.pop("messages", None)
-        return await super(LLM, self).run_async(  # noqa: UP008
+        result = await super(LLM, self).run_async(  # noqa: UP008
             messages=messages or [],
             streaming_callback=streaming_callback,
             generation_kwargs=generation_kwargs,
             **kwargs,
         )
+        # Inherited Agent-internal bookkeeping that isn't useful at the LLM surface.
+        result.pop("step_count", None)
+        result.pop("tool_call_counts", None)
+        return result
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+---
+enhancements:
+  - |
+    ``Agent`` now exposes three new outputs that are populated automatically during a
+    run and made available alongside ``messages`` and ``last_message`` in the result dict:
+
+    - ``step_count`` (``int``): the number of steps the agent ran.
+    - ``token_usage`` (``dict[str, Any]``): aggregated token usage summed across every LLM call in the run.
+    - ``tool_call_counts`` (``dict[str, int]``): number of times each tool was invoked, keyed by tool name.
+
+    These fields are added to ``Agent.state_schema`` automatically so that tools registered via ``inputs_from_state`` can read them mid-run.
+    They are exposed only as Agent outputs, so they cannot be passed in as inputs to ``Agent.run`` / ``Agent.run_async``.
+  - |
+    ``LLM`` now exposes a ``token_usage`` output alongside ``messages`` and ``last_message``. Because ``LLM`` never
+    invokes tools and always runs exactly one step, ``step_count`` and ``tool_call_counts`` inherited from ``Agent``
+    are not exposed on ``LLM``.
+upgrade:
+  - |
+    ``step_count``, ``token_usage``, and ``tool_call_counts`` are now reserved keys in ``Agent.state_schema``.
+    Passing any of them via the ``state_schema`` argument now raises ``ValueError``.
+    Rename the conflicting state key (e.g. ``my_token_usage``) to migrate.
