Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 175 additions & 67 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@
from sentry_sdk.tracing import Span
from sentry_sdk._types import TextPart

from openai.types.responses import ResponseInputParam, SequenceNotStr
from openai.types.responses import ResponseStreamEvent
from openai.types.responses import (
ResponseInputParam,
SequenceNotStr,
ResponseStreamEvent,
)
from openai.types import CompletionUsage
from openai import Omit

try:
Expand Down Expand Up @@ -144,44 +148,48 @@ def _capture_exception(exc: "Any", manual_span_cleanup: bool = True) -> None:
sentry_sdk.capture_event(event, hint=hint)


def _get_usage(usage: "Any", names: "List[str]") -> int:
for name in names:
if hasattr(usage, name) and isinstance(getattr(usage, name), int):
return getattr(usage, name)
return 0


def _calculate_token_usage(
def _calculate_completions_token_usage(
messages: "Optional[Iterable[ChatCompletionMessageParam]]",
response: "Any",
span: "Span",
streaming_message_responses: "Optional[List[str]]",
streaming_message_token_usage: "Optional[CompletionUsage]",
count_tokens: "Callable[..., Any]",
) -> None:
"""Extract and record token usage from a Chat Completions API response."""
input_tokens: "Optional[int]" = 0
input_tokens_cached: "Optional[int]" = 0
output_tokens: "Optional[int]" = 0
output_tokens_reasoning: "Optional[int]" = 0
total_tokens: "Optional[int]" = 0

if hasattr(response, "usage"):
input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
if hasattr(response.usage, "input_tokens_details"):
input_tokens_cached = _get_usage(
response.usage.input_tokens_details, ["cached_tokens"]
usage = None

if streaming_message_token_usage:
usage = streaming_message_token_usage
elif hasattr(response, "usage"):
usage = response.usage

if usage is not None:
if hasattr(usage, "prompt_tokens") and isinstance(usage.prompt_tokens, int):
input_tokens = usage.prompt_tokens
if hasattr(usage, "prompt_tokens_details"):
cached = getattr(usage.prompt_tokens_details, "cached_tokens", None)
if isinstance(cached, int):
input_tokens_cached = cached
if hasattr(usage, "completion_tokens") and isinstance(
usage.completion_tokens, int
):
output_tokens = usage.completion_tokens
if hasattr(usage, "completion_tokens_details"):
reasoning = getattr(
usage.completion_tokens_details, "reasoning_tokens", None
)
if isinstance(reasoning, int):
output_tokens_reasoning = reasoning
if hasattr(usage, "total_tokens") and isinstance(usage.total_tokens, int):
total_tokens = usage.total_tokens

output_tokens = _get_usage(
response.usage, ["output_tokens", "completion_tokens"]
)
if hasattr(response.usage, "output_tokens_details"):
output_tokens_reasoning = _get_usage(
response.usage.output_tokens_details, ["reasoning_tokens"]
)

total_tokens = _get_usage(response.usage, ["total_tokens"])

# Manually count tokens
# Manually count input tokens
if input_tokens == 0:
for message in messages or []:
if isinstance(message, str):
Expand All @@ -191,11 +199,11 @@ def _calculate_token_usage(
message_content = message.get("content")
if message_content is None:
continue
# Deliberate use of Completions function for both Completions and Responses input format.
text_items = _get_text_items(message_content)
input_tokens += sum(count_tokens(text) for text in text_items)
continue

# Manually count output tokens
if output_tokens == 0:
if streaming_message_responses is not None:
for message in streaming_message_responses:
Expand All @@ -222,6 +230,75 @@ def _calculate_token_usage(
)


def _calculate_responses_token_usage(
    input: "Any",
    response: "Any",
    span: "Span",
    streaming_message_responses: "Optional[List[str]]",
    count_tokens: "Callable[..., Any]",
) -> None:
    """Extract and record token usage from a Responses API response.

    Token counts are read from ``response.usage`` when the API reports them.
    When it does not, input tokens are estimated by running ``count_tokens``
    over the request ``input`` and, for streamed calls, output tokens over
    ``streaming_message_responses``. Counts of 0 are recorded as ``None`` so
    absent data is not reported as zero usage.

    :param input: The ``input`` argument of the Responses API call. May be a
        plain string or an iterable of message strings/dicts.
    :param response: The API response object (may expose ``usage``).
    :param span: Span on which the token usage is recorded.
    :param streaming_message_responses: Accumulated response texts of a
        streamed call, used for the manual output-token fallback.
    :param count_tokens: Fallback tokenizer callable.
    """
    input_tokens: "Optional[int]" = 0
    input_tokens_cached: "Optional[int]" = 0
    output_tokens: "Optional[int]" = 0
    output_tokens_reasoning: "Optional[int]" = 0
    total_tokens: "Optional[int]" = 0

    if hasattr(response, "usage"):
        usage = response.usage
        if hasattr(usage, "input_tokens") and isinstance(usage.input_tokens, int):
            input_tokens = usage.input_tokens
        if hasattr(usage, "input_tokens_details"):
            cached = getattr(usage.input_tokens_details, "cached_tokens", None)
            if isinstance(cached, int):
                input_tokens_cached = cached
        if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
            output_tokens = usage.output_tokens
        if hasattr(usage, "output_tokens_details"):
            reasoning = getattr(usage.output_tokens_details, "reasoning_tokens", None)
            if isinstance(reasoning, int):
                output_tokens_reasoning = reasoning
        if hasattr(usage, "total_tokens") and isinstance(usage.total_tokens, int):
            total_tokens = usage.total_tokens

    # Manually count input tokens (the API reported none)
    if input_tokens == 0:
        if isinstance(input, str):
            # The Responses API accepts a bare string as input; iterating it
            # directly would tokenize one character at a time and grossly
            # overcount, so tokenize the whole string at once.
            input_tokens = count_tokens(input)
        else:
            for message in input or []:
                if isinstance(message, str):
                    input_tokens += count_tokens(message)
                    continue
                elif isinstance(message, dict):
                    message_content = message.get("content")
                    if message_content is None:
                        continue
                    # Deliberate use of Completions function for both Completions and Responses input format.
                    text_items = _get_text_items(message_content)
                    input_tokens += sum(count_tokens(text) for text in text_items)
                    continue

    # Manually count output tokens (streamed responses carry no usage here)
    if output_tokens == 0:
        if streaming_message_responses is not None:
            for message in streaming_message_responses:
                output_tokens += count_tokens(message)

    # Do not set token data if it is 0
    input_tokens = input_tokens or None
    input_tokens_cached = input_tokens_cached or None
    output_tokens = output_tokens or None
    output_tokens_reasoning = output_tokens_reasoning or None
    total_tokens = total_tokens or None

    record_token_usage(
        span,
        input_tokens=input_tokens,
        input_tokens_cached=input_tokens_cached,
        output_tokens=output_tokens,
        output_tokens_reasoning=output_tokens_reasoning,
        total_tokens=total_tokens,
    )


def _set_responses_api_input_data(
span: "Span",
kwargs: "dict[str, Any]",
Expand Down Expand Up @@ -486,6 +563,7 @@ def _set_common_output_data(
if hasattr(response, "model"):
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)

# Chat Completions API
if hasattr(response, "choices"):
if should_send_default_pii() and integration.include_prompts:
response_text = [
Expand All @@ -496,11 +574,19 @@ def _set_common_output_data(
if len(response_text) > 0:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)

_calculate_token_usage(input, response, span, None, integration.count_tokens)
_calculate_completions_token_usage(
messages=input,
response=response,
span=span,
streaming_message_responses=None,
streaming_message_token_usage=None,
count_tokens=integration.count_tokens,
)

if finish_span:
span.__exit__(None, None, None)

# Responses API
elif hasattr(response, "output"):
if should_send_default_pii() and integration.include_prompts:
output_messages: "dict[str, list[Any]]" = {
Expand Down Expand Up @@ -532,12 +618,26 @@ def _set_common_output_data(
span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
)

_calculate_token_usage(input, response, span, None, integration.count_tokens)
_calculate_responses_token_usage(
input=input,
response=response,
span=span,
streaming_message_responses=None,
count_tokens=integration.count_tokens,
)

if finish_span:
span.__exit__(None, None, None)
# Embeddings API (fallback for responses with neither choices nor output)
else:
_calculate_token_usage(input, response, span, None, integration.count_tokens)
_calculate_completions_token_usage(
messages=input,
response=response,
span=span,
streaming_message_responses=None,
streaming_message_token_usage=None,
count_tokens=integration.count_tokens,
)
if finish_span:
span.__exit__(None, None, None)

Expand Down Expand Up @@ -655,6 +755,7 @@ def _wrap_synchronous_completions_chunk_iterator(
"""
ttft = None
data_buf: "list[list[str]]" = [] # one for each choice
streaming_message_token_usage = None

for x in old_iterator:
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, x.model)
Expand All @@ -671,6 +772,8 @@ def _wrap_synchronous_completions_chunk_iterator(
data_buf.append([])
data_buf[choice_index].append(content or "")
choice_index += 1
if hasattr(x, "usage"):
streaming_message_token_usage = x.usage

yield x

Expand All @@ -683,12 +786,13 @@ def _wrap_synchronous_completions_chunk_iterator(
all_responses = ["".join(chunk) for chunk in data_buf]
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses)
_calculate_token_usage(
messages,
response,
span,
all_responses,
integration.count_tokens,
_calculate_completions_token_usage(
messages=messages,
response=response,
span=span,
streaming_message_responses=all_responses,
streaming_message_token_usage=streaming_message_token_usage,
count_tokens=integration.count_tokens,
)

if finish_span:
Expand All @@ -711,6 +815,7 @@ async def _wrap_asynchronous_completions_chunk_iterator(
"""
ttft = None
data_buf: "list[list[str]]" = [] # one for each choice
streaming_message_token_usage = None

async for x in old_iterator:
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, x.model)
Expand All @@ -727,6 +832,8 @@ async def _wrap_asynchronous_completions_chunk_iterator(
data_buf.append([])
data_buf[choice_index].append(content or "")
choice_index += 1
if hasattr(x, "usage"):
streaming_message_token_usage = x.usage

yield x

Expand All @@ -739,12 +846,13 @@ async def _wrap_asynchronous_completions_chunk_iterator(
all_responses = ["".join(chunk) for chunk in data_buf]
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses)
_calculate_token_usage(
messages,
response,
span,
all_responses,
integration.count_tokens,
_calculate_completions_token_usage(
messages=messages,
response=response,
span=span,
streaming_message_responses=all_responses,
streaming_message_token_usage=streaming_message_token_usage,
count_tokens=integration.count_tokens,
)

if finish_span:
Expand Down Expand Up @@ -781,12 +889,12 @@ def _wrap_synchronous_responses_event_iterator(
if isinstance(x, ResponseCompletedEvent):
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, x.response.model)

_calculate_token_usage(
input,
x.response,
span,
None,
integration.count_tokens,
_calculate_responses_token_usage(
input=input,
response=x.response,
span=span,
streaming_message_responses=None,
count_tokens=integration.count_tokens,
)
count_tokens_manually = False

Expand All @@ -802,12 +910,12 @@ def _wrap_synchronous_responses_event_iterator(
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses)
if count_tokens_manually:
_calculate_token_usage(
input,
response,
span,
all_responses,
integration.count_tokens,
_calculate_responses_token_usage(
input=input,
response=response,
span=span,
streaming_message_responses=all_responses,
count_tokens=integration.count_tokens,
)

if finish_span:
Expand Down Expand Up @@ -844,12 +952,12 @@ async def _wrap_asynchronous_responses_event_iterator(
if isinstance(x, ResponseCompletedEvent):
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, x.response.model)

_calculate_token_usage(
input,
x.response,
span,
None,
integration.count_tokens,
_calculate_responses_token_usage(
input=input,
response=x.response,
span=span,
streaming_message_responses=None,
count_tokens=integration.count_tokens,
)
count_tokens_manually = False

Expand All @@ -865,12 +973,12 @@ async def _wrap_asynchronous_responses_event_iterator(
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses)
if count_tokens_manually:
_calculate_token_usage(
input,
response,
span,
all_responses,
integration.count_tokens,
_calculate_responses_token_usage(
input=input,
response=response,
span=span,
streaming_message_responses=all_responses,
count_tokens=integration.count_tokens,
)
if finish_span:
span.__exit__(None, None, None)
Expand Down
Loading
Loading