Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,16 +689,19 @@ def end_agent_span(
if hasattr(response, "metrics") and hasattr(response.metrics, "accumulated_usage"):
if self.is_langfuse:
attributes.update({"langfuse.observation.type": "span"})
accumulated_usage = response.metrics.accumulated_usage
# Use the latest invocation's usage so each agent span reports only its
# own tokens, not the session-lifetime accumulated total (issue #2010).
invocation = response.metrics.latest_agent_invocation
usage = invocation.usage if invocation is not None else response.metrics.accumulated_usage
attributes.update(
{
"gen_ai.usage.prompt_tokens": accumulated_usage["inputTokens"],
"gen_ai.usage.completion_tokens": accumulated_usage["outputTokens"],
"gen_ai.usage.input_tokens": accumulated_usage["inputTokens"],
"gen_ai.usage.output_tokens": accumulated_usage["outputTokens"],
"gen_ai.usage.total_tokens": accumulated_usage["totalTokens"],
"gen_ai.usage.cache_read_input_tokens": accumulated_usage.get("cacheReadInputTokens", 0),
"gen_ai.usage.cache_write_input_tokens": accumulated_usage.get("cacheWriteInputTokens", 0),
"gen_ai.usage.prompt_tokens": usage["inputTokens"],
"gen_ai.usage.completion_tokens": usage["outputTokens"],
"gen_ai.usage.input_tokens": usage["inputTokens"],
"gen_ai.usage.output_tokens": usage["outputTokens"],
"gen_ai.usage.total_tokens": usage["totalTokens"],
"gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0),
"gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0),
}
)

Expand Down
41 changes: 41 additions & 0 deletions tests/strands/telemetry/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ def test_end_agent_span(mock_span):
# Mock AgentResult with metrics
mock_metrics = mock.MagicMock()
mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics.latest_agent_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
Expand Down Expand Up @@ -958,6 +959,7 @@ def test_end_agent_span_with_langfuse_observation_type(mock_span, monkeypatch):
# Mock AgentResult with metrics
mock_metrics = mock.MagicMock()
mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics.latest_agent_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
Expand Down Expand Up @@ -994,6 +996,7 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch):
# Mock AgentResult with metrics
mock_metrics = mock.MagicMock()
mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics.latest_agent_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
Expand Down Expand Up @@ -1077,6 +1080,13 @@ def test_end_agent_span_with_cache_metrics(mock_span):
"cacheReadInputTokens": 25,
"cacheWriteInputTokens": 10,
}
mock_metrics.latest_agent_invocation.usage = {
"inputTokens": 50,
"outputTokens": 100,
"totalTokens": 150,
"cacheReadInputTokens": 25,
"cacheWriteInputTokens": 10,
}

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
Expand All @@ -1100,6 +1110,37 @@ def test_end_agent_span_with_cache_metrics(mock_span):
mock_span.end.assert_called_once()


def test_end_agent_span_uses_invocation_not_accumulated_usage(mock_span):
"""Test that the agent span reports per-invocation usage, not session-lifetime accumulated usage."""
tracer = Tracer()

mock_metrics = mock.MagicMock()
# Accumulated usage is larger (simulating a multi-request session)
mock_metrics.accumulated_usage = {"inputTokens": 300, "outputTokens": 600, "totalTokens": 900}
# Latest invocation used only 100/200/300 tokens
mock_metrics.latest_agent_invocation.usage = {"inputTokens": 100, "outputTokens": 200, "totalTokens": 300}

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
mock_response.stop_reason = "end_turn"
mock_response.__str__ = mock.MagicMock(return_value="Agent response")

tracer.end_agent_span(mock_span, mock_response)

# Span should report the per-invocation values, not the inflated accumulated values
mock_span.set_attributes.assert_called_once_with(
{
"gen_ai.usage.prompt_tokens": 100,
"gen_ai.usage.input_tokens": 100,
"gen_ai.usage.completion_tokens": 200,
"gen_ai.usage.output_tokens": 200,
"gen_ai.usage.total_tokens": 300,
"gen_ai.usage.cache_read_input_tokens": 0,
"gen_ai.usage.cache_write_input_tokens": 0,
}
)


def test_get_tracer_singleton():
"""Test that get_tracer returns a singleton instance."""
# Reset the singleton first
Expand Down