Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,42 @@ def _set_span_attribute(span: Span, key: str, value: Any) -> None:
span.set_attribute(key, "")


def _get_token_count(usage: dict[str, Any], *keys: str) -> int:
for key in keys:
value = usage.get(key)
if isinstance(value, (int, float)):
return int(value)
return 0


def _extract_message_usage(message: BaseMessage) -> dict[str, Any] | None:
usage_metadata = getattr(message, "usage_metadata", None)
if usage_metadata is not None:
return usage_metadata

response_metadata = getattr(message, "response_metadata", None)
if not isinstance(response_metadata, dict):
return None

response_usage = response_metadata.get("usage")
if isinstance(response_usage, dict):
return response_usage

if any(
key in response_metadata
for key in (
"prompt_tokens",
"completion_tokens",
"total_tokens",
"input_tokens",
"output_tokens",
)
):
return response_metadata

return None


def _content_to_parts(content) -> list[dict]:
"""Convert LangChain message content (str or list-of-blocks) into OTel parts."""
if isinstance(content, str):
Expand Down Expand Up @@ -405,25 +441,26 @@ def set_chat_response_usage(
for generation in generations:
if (
hasattr(generation, "message")
and hasattr(generation.message, "usage_metadata")
and generation.message.usage_metadata is not None
and (usage := _extract_message_usage(generation.message)) is not None
):
input_tokens += (
generation.message.usage_metadata.get("input_tokens")
or generation.message.usage_metadata.get("prompt_tokens")
or 0
generation_input_tokens = _get_token_count(
usage, "input_tokens", "prompt_tokens", "input_token_count"
)
generation_output_tokens = _get_token_count(
usage, "output_tokens", "completion_tokens", "generated_token_count"
)
output_tokens += (
generation.message.usage_metadata.get("output_tokens")
or generation.message.usage_metadata.get("completion_tokens")
or 0
generation_total_tokens = _get_token_count(
usage, "total_tokens", "total_token_count"
)

input_tokens += generation_input_tokens
output_tokens += generation_output_tokens
total_tokens += generation_total_tokens or (
generation_input_tokens + generation_output_tokens
)
total_tokens = input_tokens + output_tokens

if generation.message.usage_metadata.get("input_token_details"):
input_token_details = generation.message.usage_metadata.get(
"input_token_details", {}
)
if usage.get("input_token_details"):
input_token_details = usage.get("input_token_details", {})
raw_cache_read = input_token_details.get("cache_read")
if isinstance(raw_cache_read, (int, float)):
cache_read_tokens = (cache_read_tokens or 0) + raw_cache_read
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
from unittest.mock import Mock

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from opentelemetry.instrumentation.langchain.span_utils import set_chat_response_usage
from opentelemetry.semconv._incubating.attributes import gen_ai_attributes as GenAIAttributes
from opentelemetry.semconv_ai import SpanAttributes


def _mock_span():
span = Mock()
span.is_recording.return_value = True
span.attributes = {}

def set_attribute(key, value):
span.attributes[key] = value

span.set_attribute = set_attribute
return span


@pytest.mark.parametrize(
"response_metadata",
[
{
"usage": {
"prompt_tokens": 10,
"completion_tokens": 16,
"total_tokens": 26,
}
},
{
"prompt_tokens": 10,
"completion_tokens": 16,
"total_tokens": 26,
},
{
"usage": {
"prompt_tokens": 10,
"completion_tokens": 16,
}
},
{
"usage": {
"prompt_tokens": 10,
"completion_tokens": 16,
"total_token_count": 26,
}
},
],
)
def test_chat_response_usage_reads_databricks_response_metadata(response_metadata):
span = _mock_span()
response = LLMResult(
generations=[
[
ChatGeneration(
message=AIMessage(
content="Hello!",
response_metadata=response_metadata,
)
)
]
]
)

set_chat_response_usage(
span,
response,
token_histogram=Mock(),
record_token_usage=False,
model_name="databricks-claude-sonnet",
)

assert span.attributes[GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS] == 10
assert span.attributes[GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS] == 16
assert span.attributes[SpanAttributes.GEN_AI_USAGE_TOTAL_TOKENS] == 26
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def extract_response_id(raw: Any) -> Optional[str]:


def extract_token_usage(raw: Any) -> TokenUsage:
"""Extract token usage from raw response. Handles OpenAI, Anthropic, Cohere, and dict formats."""
"""Extract token usage from raw response. Handles OpenAI, Anthropic, Cohere, VertexAI, and dict formats."""
usage = _get_nested(raw, "usage")
if usage:
result = _extract_openai_usage(usage)
Expand All @@ -75,6 +75,12 @@ def extract_token_usage(raw: Any) -> TokenUsage:
if result.input_tokens is not None:
return result

usage_metadata = _get_nested(raw, "usage_metadata") or _get_nested(raw, "usageMetadata")
if usage_metadata:
result = _extract_google_usage_metadata(usage_metadata)
if result.input_tokens is not None:
return result

meta = _get_nested(raw, "meta")
if meta:
return _extract_cohere_usage(meta)
Expand Down Expand Up @@ -146,6 +152,29 @@ def _extract_anthropic_usage(usage: Any) -> TokenUsage:
return TokenUsage()


def _extract_google_usage_metadata(usage_metadata: Any) -> TokenUsage:
"""Extract tokens from Google Gemini / VertexAI usage_metadata."""
input_tokens = _get_int(
usage_metadata, "prompt_token_count", "promptTokenCount"
)
output_tokens = _get_int(
usage_metadata, "candidates_token_count", "candidatesTokenCount"
)
total_tokens = _get_int(
usage_metadata, "total_token_count", "totalTokenCount"
)

return TokenUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=(
total_tokens
if total_tokens is not None
else _safe_sum(input_tokens, output_tokens)
),
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _extract_cohere_usage(meta: Any) -> TokenUsage:
"""Extract tokens from Cohere-style meta.tokens or meta.billed_units."""
tokens = _get_nested(meta, "tokens")
Expand All @@ -165,12 +194,15 @@ def _extract_cohere_usage(meta: Any) -> TokenUsage:
return TokenUsage()


def _get_int(obj: Any, key: str) -> Optional[int]:
def _get_int(obj: Any, *keys: str) -> Optional[int]:
"""Get an integer attribute or dict key from obj."""
val = getattr(obj, key, None)
if val is None and isinstance(obj, dict):
val = obj.get(key)
return int(val) if val is not None else None
for key in keys:
val = getattr(obj, key, None)
if val is None and isinstance(obj, dict):
val = obj.get(key)
if val is not None:
return int(val)
return None


def _safe_sum(a: Optional[int], b: Optional[int]) -> Optional[int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,106 @@ def test_openai_format_dict(self):
result = extract_token_usage(raw)
assert result == TokenUsage(input_tokens=10, output_tokens=20, total_tokens=30)

@pytest.mark.parametrize(
"raw, expected_total_tokens",
[
(
{
"usage_metadata": {
"prompt_token_count": 10,
"candidates_token_count": 20,
"total_token_count": 30,
}
},
30,
),
(
SimpleNamespace(
usage_metadata=SimpleNamespace(
prompt_token_count=10,
candidates_token_count=20,
total_token_count=30,
)
),
30,
),
(
{
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 20,
"totalTokenCount": 30,
}
},
30,
),
(
{
"usage_metadata": {
"prompt_token_count": 10,
"candidates_token_count": 20,
"total_token_count": 0,
}
},
0,
),
(
SimpleNamespace(
usage_metadata=SimpleNamespace(
prompt_token_count=10,
candidates_token_count=20,
total_token_count=0,
)
),
0,
),
(
{
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 20,
"totalTokenCount": 0,
}
},
0,
),
(
{
"usage_metadata": {
"prompt_token_count": 10,
"candidates_token_count": 20,
}
},
30,
),
(
SimpleNamespace(
usage_metadata=SimpleNamespace(
prompt_token_count=10,
candidates_token_count=20,
)
),
30,
),
(
{
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 20,
}
},
30,
),
],
)
def test_google_vertexai_usage_metadata(self, raw, expected_total_tokens):
result = extract_token_usage(raw)
assert result == TokenUsage(
input_tokens=10,
output_tokens=20,
total_tokens=expected_total_tokens,
)

def test_cohere_meta_tokens_format(self):
raw = SimpleNamespace(
meta=SimpleNamespace(tokens=SimpleNamespace(input_tokens=5, output_tokens=15))
Expand Down