Skip to content
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Custom Temporal Model Provider with streaming support for OpenAI agents."""
from __future__ import annotations

import time
import uuid
from typing import Any, List, Union, Optional, override

Expand All @@ -26,6 +27,7 @@
CodeInterpreterTool,
ImageGenerationTool,
)
from opentelemetry import metrics
from agents.computer import Computer, AsyncComputer

# Re-export the canonical StreamingMode literal from the streaming service so
Expand Down Expand Up @@ -78,6 +80,63 @@
logger = make_logger("agentex.temporal.streaming")


# OTel metrics for LLM streaming behavior. Instruments are created lazily on
# first use so the meter resolves to whatever MeterProvider the application
# eventually configures, even if that happens after this module is imported.
# All metrics carry only a ``model`` attribute to keep cardinality bounded;
# resource attributes (service.name, k8s.*, etc.) come from the application's
# OTel resource configuration.
class _StreamingMetrics:
"""Lazily-created OTel instruments for streaming LLM telemetry."""

def __init__(self) -> None:
meter = metrics.get_meter("agentex.openai_agents.streaming")
self.ttft_ms = meter.create_histogram(
name="agentex.llm.ttft",
unit="ms",
description="Time from request submission to first content token (ms)",
)
# Note: TPS denominator is the model-generation window
# (last_token_time - first_token_time), not total stream wall time.
# This isolates raw model throughput from event-loop / tool-call latency.
self.tps = meter.create_histogram(
name="agentex.llm.tps",
unit="tokens/s",
description="Output tokens per second over the generation window",
)
self.input_tokens = meter.create_counter(
name="agentex.llm.input_tokens",
unit="tokens",
description="Total input tokens sent to the LLM",
)
self.output_tokens = meter.create_counter(
name="agentex.llm.output_tokens",
unit="tokens",
description="Total output tokens returned by the LLM",
)
self.cached_input_tokens = meter.create_counter(
name="agentex.llm.cached_input_tokens",
unit="tokens",
description="Subset of input tokens served from prompt cache",
)
self.reasoning_tokens = meter.create_counter(
name="agentex.llm.reasoning_tokens",
unit="tokens",
description="Output tokens spent on reasoning (subset of output_tokens)",
)


_streaming_metrics: Optional[_StreamingMetrics] = None


def _get_streaming_metrics() -> _StreamingMetrics:
"""Return the streaming metrics singleton, creating it on first use."""
global _streaming_metrics
if _streaming_metrics is None:
_streaming_metrics = _StreamingMetrics()
return _streaming_metrics


def _serialize_item(item: Any) -> dict[str, Any]:
"""
Universal serializer for any item type from OpenAI Agents SDK.
Expand Down Expand Up @@ -592,7 +651,11 @@ async def get_response(
# endpoints recognize this parameter, so we don't auto-inject a default.
prompt_cache_key = extra_args.pop("prompt_cache_key", NOT_GIVEN)

# Create the response stream using Responses API
# Create the response stream using Responses API.
# Bookmark request start *before* the await so ttft captures the full
# user-perceived latency (HTTP round-trip + model TTFB), not just the
# post-connect event-loop delay.
stream_start_perf = time.perf_counter()
logger.debug(f"[TemporalStreamingModel] Creating response stream with Responses API")
stream = await self.client.responses.create( # type: ignore[call-overload]

Expand Down Expand Up @@ -642,6 +705,12 @@ async def get_response(
reasoning_summaries = []
reasoning_contents = []
event_count = 0
# ttft / tps instrumentation. ``stream_start_perf`` is set above,
# before the responses.create() await, so it captures the full
# request-to-first-token latency. ``first_token_at`` and
# ``last_token_at`` bracket the model-generation window for tps.
first_token_at: Optional[float] = None
Comment thread
greptile-apps[bot] marked this conversation as resolved.
last_token_at: Optional[float] = None

# We expect task_id to always be provided for streaming
if not task_id:
Expand All @@ -656,6 +725,20 @@ async def get_response(
# Log event type
logger.debug(f"[TemporalStreamingModel] Event {event_count}: {type(event).__name__}")

# Bookmark first/last token-producing events for ttft and tps.
# Includes function-call argument deltas so the generation window
# covers every event type whose tokens land in usage.output_tokens.
if isinstance(event, (
ResponseTextDeltaEvent,
ResponseReasoningTextDeltaEvent,
ResponseReasoningSummaryTextDeltaEvent,
ResponseFunctionCallArgumentsDeltaEvent,
)):
now_perf = time.perf_counter()
if first_token_at is None:
first_token_at = now_perf
last_token_at = now_perf

# Handle different event types using isinstance for type safety
if isinstance(event, ResponseOutputItemAddedEvent):
# New output item (reasoning, function call, or message)
Expand Down Expand Up @@ -983,6 +1066,32 @@ async def get_response(

span.output = output_data

# Emit LLM metrics derived from the captured stream. The meter is a
# no-op if the application hasn't configured a MeterProvider, so this
# is safe to do unconditionally. We only emit ttft / tps when their
# input data is actually meaningful (got a content delta, got tokens).
m = _get_streaming_metrics()
metric_attrs = {"model": self.model_name}
m.input_tokens.add(usage.input_tokens or 0, metric_attrs)
m.output_tokens.add(usage.output_tokens or 0, metric_attrs)
m.cached_input_tokens.add(usage.input_tokens_details.cached_tokens or 0, metric_attrs)
m.reasoning_tokens.add(usage.output_tokens_details.reasoning_tokens or 0, metric_attrs)
if first_token_at is not None:
m.ttft_ms.record((first_token_at - stream_start_perf) * 1000, metric_attrs)
# tps denominator is the generation window (first→last delta), not
# total stream wall time — see _StreamingMetrics for rationale.
# Note: single-token responses (where first_token_at == last_token_at,
# e.g. a one-token tool-result acknowledgement) collapse the window
# to 0 and are intentionally skipped — TPS is undefined in that case.
if (
first_token_at is not None
and last_token_at is not None
and last_token_at > first_token_at
and (usage.output_tokens or 0) > 0
):
generation_window_s = last_token_at - first_token_at
m.tps.record(usage.output_tokens / generation_window_s, metric_attrs)
Comment thread
greptile-apps[bot] marked this conversation as resolved.
Outdated

# Return the response. response_id is the server-issued id from
# ResponseCompletedEvent.response.id, or None when the stream ended
# without a completed event (error path) — matching the documented
Expand Down
Loading