Skip to content

Commit 15c9571

Browse files
authored
fix: ensure LLM stats tracking is accurate by including completed subagents (#441)
1 parent 62e9af3 commit 15c9571

5 files changed

Lines changed: 151 additions & 32 deletions

File tree

strix/agents/base_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def _add_to_agents_graph(self) -> None:
134134
}
135135
agents_graph_actions._agent_graph["nodes"][self.state.agent_id] = node
136136

137-
agents_graph_actions._agent_instances[self.state.agent_id] = self
137+
with agents_graph_actions._agent_llm_stats_lock:
138+
agents_graph_actions._agent_instances[self.state.agent_id] = self
138139
agents_graph_actions._agent_states[self.state.agent_id] = self.state
139140

140141
if self.state.parent_id:

strix/telemetry/tracer.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -799,17 +799,25 @@ def get_real_tool_count(self) -> int:
799799
)
800800

801801
def get_total_llm_stats(self) -> dict[str, Any]:
802-
from strix.tools.agents_graph.agents_graph_actions import _agent_instances
802+
from strix.tools.agents_graph.agents_graph_actions import (
803+
_agent_instances,
804+
_completed_agent_llm_totals,
805+
_agent_llm_stats_lock,
806+
)
807+
808+
with _agent_llm_stats_lock:
809+
completed_totals = dict(_completed_agent_llm_totals)
810+
active_agents = list(_agent_instances.values())
803811

804812
total_stats = {
805-
"input_tokens": 0,
806-
"output_tokens": 0,
807-
"cached_tokens": 0,
808-
"cost": 0.0,
809-
"requests": 0,
813+
"input_tokens": int(completed_totals.get("input_tokens", 0) or 0),
814+
"output_tokens": int(completed_totals.get("output_tokens", 0) or 0),
815+
"cached_tokens": int(completed_totals.get("cached_tokens", 0) or 0),
816+
"cost": float(completed_totals.get("cost", 0.0) or 0.0),
817+
"requests": int(completed_totals.get("requests", 0) or 0),
810818
}
811819

812-
for agent_instance in _agent_instances.values():
820+
for agent_instance in active_agents:
813821
if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"):
814822
agent_stats = agent_instance.llm._total_stats
815823
total_stats["input_tokens"] += agent_stats.input_tokens

strix/tools/agents_graph/agents_graph_actions.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,55 @@
1919

2020
_agent_instances: dict[str, Any] = {}
2121

22+
_agent_llm_stats_lock = threading.Lock()
23+
24+
25+
def _empty_llm_stats_totals() -> dict[str, int | float]:
26+
return {
27+
"input_tokens": 0,
28+
"output_tokens": 0,
29+
"cached_tokens": 0,
30+
"cost": 0.0,
31+
"requests": 0,
32+
}
33+
34+
35+
_completed_agent_llm_totals: dict[str, int | float] = _empty_llm_stats_totals()
36+
2237
_agent_states: dict[str, Any] = {}
2338

2439

40+
def _snapshot_agent_llm_stats(agent: Any) -> dict[str, int | float] | None:
41+
if not hasattr(agent, "llm") or not hasattr(agent.llm, "_total_stats"):
42+
return None
43+
44+
stats = agent.llm._total_stats
45+
return {
46+
"input_tokens": stats.input_tokens,
47+
"output_tokens": stats.output_tokens,
48+
"cached_tokens": stats.cached_tokens,
49+
"cost": stats.cost,
50+
"requests": stats.requests,
51+
}
52+
53+
54+
def _finalize_agent_llm_stats(agent_id: str, agent: Any) -> None:
    """Move a finished agent's LLM usage into the completed totals and unregister it.

    The accumulation into ``_completed_agent_llm_totals`` and the removal from
    ``_agent_instances`` happen under ``_agent_llm_stats_lock`` in a single
    critical section, so a reader holding the same lock sees the agent's usage
    in exactly one of the two places.

    Args:
        agent_id: Key of the agent in ``_agent_instances`` and in
            ``_agent_graph["nodes"]``.
        agent: The agent object whose usage is snapshotted.

    Side effects:
        Adds the snapshot to ``_completed_agent_llm_totals``; stores the
        snapshot (or ``None``) under ``"llm_stats"`` on the agent's graph node
        when that node exists; removes ``agent_id`` from ``_agent_instances``.
    """
    stats = _snapshot_agent_llm_stats(agent)

    with _agent_llm_stats_lock:
        try:
            if stats is not None:
                # Convert every counter before mutating the shared totals so a
                # bad value cannot leave them half-updated. A counter that is
                # None is treated as zero.
                increments: dict[str, int | float] = {
                    key: int(stats[key] or 0)
                    for key in (
                        "input_tokens",
                        "output_tokens",
                        "cached_tokens",
                        "requests",
                    )
                }
                increments["cost"] = float(stats["cost"] or 0.0)

                for key, amount in increments.items():
                    _completed_agent_llm_totals[key] += amount

            node = _agent_graph["nodes"].get(agent_id)
            if node is not None:
                node["llm_stats"] = stats
        finally:
            # Always drop the instance, even if bookkeeping above failed, so a
            # finished agent is never left registered as active.
            _agent_instances.pop(agent_id, None)
69+
70+
2571
def _is_whitebox_agent(agent_id: str) -> bool:
    """Report whether the registered agent's LLM config has ``is_whitebox`` set.

    Returns ``False`` when the agent is not in ``_agent_instances``, has no
    ``llm_config``, or the config has no ``is_whitebox`` attribute.
    """
    instance = _agent_instances.get(agent_id)
    llm_config = getattr(instance, "llm_config", None)
    whitebox_flag = getattr(llm_config, "is_whitebox", False)
    return bool(whitebox_flag)
@@ -237,7 +283,7 @@ def _run_agent_in_thread(
237283
_agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
238284
_agent_graph["nodes"][state.agent_id]["result"] = {"error": str(e)}
239285
_running_agents.pop(state.agent_id, None)
240-
_agent_instances.pop(state.agent_id, None)
286+
_finalize_agent_llm_stats(state.agent_id, agent)
241287
raise
242288
else:
243289
if state.stop_requested:
@@ -247,7 +293,7 @@ def _run_agent_in_thread(
247293
_agent_graph["nodes"][state.agent_id]["finished_at"] = datetime.now(UTC).isoformat()
248294
_agent_graph["nodes"][state.agent_id]["result"] = result
249295
_running_agents.pop(state.agent_id, None)
250-
_agent_instances.pop(state.agent_id, None)
296+
_finalize_agent_llm_stats(state.agent_id, agent)
251297

252298
return {"result": result}
253299

@@ -418,7 +464,8 @@ def create_agent(
418464
if inherit_context:
419465
inherited_messages = agent_state.get_conversation_history()
420466

421-
_agent_instances[state.agent_id] = agent
467+
with _agent_llm_stats_lock:
468+
_agent_instances[state.agent_id] = agent
422469

423470
thread = threading.Thread(
424471
target=_run_agent_in_thread,

tests/telemetry/test_tracer.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from strix.telemetry import tracer as tracer_module
1111
from strix.telemetry import utils as telemetry_utils
1212
from strix.telemetry.tracer import Tracer, set_global_tracer
13+
from strix.tools.agents_graph import agents_graph_actions
1314

1415

1516
def _load_events(events_path: Path) -> list[dict[str, Any]]:
@@ -255,6 +256,75 @@ def test_events_with_agent_id_include_agent_name(monkeypatch, tmp_path) -> None:
255256
assert chat_event["actor"]["agent_name"] == "Root Agent"
256257

257258

259+
def test_get_total_llm_stats_includes_completed_subagents(monkeypatch, tmp_path) -> None:
    """Totals must combine still-registered agents with already-completed usage."""
    monkeypatch.chdir(tmp_path)

    class DummyStats:
        """Stand-in for the per-LLM cumulative usage object."""

        def __init__(
            self,
            *,
            input_tokens: int,
            output_tokens: int,
            cached_tokens: int,
            cost: float,
            requests: int,
        ) -> None:
            self.input_tokens = input_tokens
            self.output_tokens = output_tokens
            self.cached_tokens = cached_tokens
            self.cost = cost
            self.requests = requests

    class DummyLLM:
        """Exposes only the ``_total_stats`` attribute."""

        def __init__(self, stats: DummyStats) -> None:
            self._total_stats = stats

    class DummyAgent:
        """Exposes only the ``llm`` attribute."""

        def __init__(self, stats: DummyStats) -> None:
            self.llm = DummyLLM(stats)

    tracer = Tracer("cost-rollup")
    set_global_tracer(tracer)

    # One agent is still registered as active...
    root_usage = DummyStats(
        input_tokens=1_000,
        output_tokens=250,
        cached_tokens=100,
        cost=0.12345,
        requests=2,
    )
    active_registry = {"root-agent": DummyAgent(root_usage)}
    monkeypatch.setattr(agents_graph_actions, "_agent_instances", active_registry)

    # ...and some usage has already been rolled up from completed agents.
    completed_usage = {
        "input_tokens": 2_000,
        "output_tokens": 500,
        "cached_tokens": 400,
        "cost": 0.54321,
        "requests": 3,
    }
    monkeypatch.setattr(
        agents_graph_actions, "_completed_agent_llm_totals", completed_usage
    )

    stats = tracer.get_total_llm_stats()

    expected_total = {
        "input_tokens": 3_000,
        "output_tokens": 750,
        "cached_tokens": 500,
        "cost": 0.6667,
        "requests": 5,
    }
    assert stats["total"] == expected_total
    assert stats["total_tokens"] == 3_750
327+
258328
def test_run_metadata_is_only_on_run_lifecycle_events(monkeypatch, tmp_path) -> None:
259329
monkeypatch.chdir(tmp_path)
260330

tests/tools/test_agents_graph_whitebox.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,24 @@
55
from strix.tools.agents_graph import agents_graph_actions
66

77

8-
def test_create_agent_inherits_parent_whitebox_flag(monkeypatch) -> None:
9-
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
10-
8+
def _reset_agent_graph_state() -> None:
    """Empty the module-level agent-graph registries and zero the completed totals.

    Every container is cleared in place rather than rebound.
    """
    actions = agents_graph_actions
    graph = actions._agent_graph

    in_place_registries = (
        graph["nodes"],
        graph["edges"],
        actions._agent_messages,
        actions._running_agents,
        actions._agent_instances,
    )
    for registry in in_place_registries:
        registry.clear()

    # Reset the totals dict in place back to all-zero counters.
    completed_totals = actions._completed_agent_llm_totals
    completed_totals.clear()
    completed_totals.update(actions._empty_llm_stats_totals())

    actions._agent_states.clear()
1719

20+
21+
def test_create_agent_inherits_parent_whitebox_flag(monkeypatch) -> None:
22+
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
23+
24+
_reset_agent_graph_state()
25+
1826
parent_id = "parent-agent"
1927
parent_llm = LLMConfig(timeout=123, scan_mode="standard", is_whitebox=True)
2028
agents_graph_actions._agent_instances[parent_id] = SimpleNamespace(
@@ -66,12 +74,7 @@ def start(self) -> None:
6674
def test_delegation_prompt_includes_wiki_memory_instruction_in_whitebox(monkeypatch) -> None:
6775
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
6876

69-
agents_graph_actions._agent_graph["nodes"].clear()
70-
agents_graph_actions._agent_graph["edges"].clear()
71-
agents_graph_actions._agent_messages.clear()
72-
agents_graph_actions._running_agents.clear()
73-
agents_graph_actions._agent_instances.clear()
74-
agents_graph_actions._agent_states.clear()
77+
_reset_agent_graph_state()
7578

7679
parent_id = "parent-1"
7780
child_id = "child-1"
@@ -116,12 +119,7 @@ async def agent_loop(self, _task: str) -> dict[str, bool]:
116119
def test_agent_finish_appends_wiki_update_for_whitebox(monkeypatch) -> None:
117120
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
118121

119-
agents_graph_actions._agent_graph["nodes"].clear()
120-
agents_graph_actions._agent_graph["edges"].clear()
121-
agents_graph_actions._agent_messages.clear()
122-
agents_graph_actions._running_agents.clear()
123-
agents_graph_actions._agent_instances.clear()
124-
agents_graph_actions._agent_states.clear()
122+
_reset_agent_graph_state()
125123

126124
parent_id = "parent-2"
127125
child_id = "child-2"
@@ -192,12 +190,7 @@ def fake_append_note_content(note_id: str, delta: str):
192190
def test_run_agent_in_thread_injects_shared_wiki_context_in_whitebox(monkeypatch) -> None:
193191
monkeypatch.setenv("STRIX_LLM", "openai/gpt-5")
194192

195-
agents_graph_actions._agent_graph["nodes"].clear()
196-
agents_graph_actions._agent_graph["edges"].clear()
197-
agents_graph_actions._agent_messages.clear()
198-
agents_graph_actions._running_agents.clear()
199-
agents_graph_actions._agent_instances.clear()
200-
agents_graph_actions._agent_states.clear()
193+
_reset_agent_graph_state()
201194

202195
parent_id = "parent-3"
203196
child_id = "child-3"

0 commit comments

Comments
 (0)