Skip to content

Commit 9a47b6f

Browse files
caohy1988 and claude
committed
fix(telemetry): ensure all model callbacks observe the same call_llm span
Move before_model_callback inside the call_llm span, wrap after_model_callback and on_model_error_callback with trace.use_span(span, end_on_exit=False) to rebind to the call_llm span. Thread call_llm_span into _run_and_handle_error so error callbacks also run under the correct span context. Fixes #4851 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f973673 commit 9a47b6f

File tree

2 files changed

+319
-37
lines changed

2 files changed

+319
-37
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from google.adk.platform import time as platform_time
2626
from google.genai import types
27+
from opentelemetry import trace
2728
from websockets.exceptions import ConnectionClosed
2829
from websockets.exceptions import ConnectionClosedOK
2930

@@ -306,6 +307,7 @@ async def _run_and_handle_error(
306307
invocation_context: InvocationContext,
307308
llm_request: LlmRequest,
308309
model_response_event: Event,
310+
call_llm_span: Optional[trace.Span] = None,
309311
) -> AsyncGenerator[LlmResponse, None]:
310312
"""Wraps an LLM response generator with error callback handling.
311313
@@ -319,6 +321,9 @@ async def _run_and_handle_error(
319321
invocation_context: The invocation context.
320322
llm_request: The LLM request.
321323
model_response_event: The model response event.
324+
call_llm_span: The call_llm span to rebind error callbacks to.
325+
When provided, on_model_error callbacks run under this span so
326+
plugins observe the same span as before/after model callbacks.
322327
323328
Yields:
324329
LlmResponse objects from the generator.
@@ -380,11 +385,19 @@ async def _run_on_model_error_callbacks(
380385
callback_context = CallbackContext(
381386
invocation_context, event_actions=model_response_event.actions
382387
)
383-
error_response = await _run_on_model_error_callbacks(
384-
callback_context=callback_context,
385-
llm_request=llm_request,
386-
error=model_error,
387-
)
388+
if call_llm_span is not None:
389+
with trace.use_span(call_llm_span, end_on_exit=False):
390+
error_response = await _run_on_model_error_callbacks(
391+
callback_context=callback_context,
392+
llm_request=llm_request,
393+
error=model_error,
394+
)
395+
else:
396+
error_response = await _run_on_model_error_callbacks(
397+
callback_context=callback_context,
398+
llm_request=llm_request,
399+
error=model_error,
400+
)
388401
if error_response is not None:
389402
yield error_response
390403
else:
@@ -1102,28 +1115,30 @@ async def _call_llm_async(
11021115
llm_request: LlmRequest,
11031116
model_response_event: Event,
11041117
) -> AsyncGenerator[LlmResponse, None]:
1105-
# Runs before_model_callback if it exists.
1106-
if response := await self._handle_before_model_callback(
1107-
invocation_context, llm_request, model_response_event
1108-
):
1109-
yield response
1110-
return
1111-
1112-
llm_request.config = llm_request.config or types.GenerateContentConfig()
1113-
llm_request.config.labels = llm_request.config.labels or {}
1114-
1115-
# Add agent name as a label to the llm_request. This will help with slicing
1116-
# the billing reports on a per-agent basis.
1117-
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
1118-
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
1119-
invocation_context.agent.name
1120-
)
1121-
1122-
# Calls the LLM.
1123-
llm = self.__get_llm(invocation_context)
11241118

11251119
async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11261120
with tracer.start_as_current_span('call_llm') as span:
1121+
# Runs before_model_callback inside the call_llm span so
1122+
# plugins observe the same span as after/error callbacks.
1123+
if response := await self._handle_before_model_callback(
1124+
invocation_context, llm_request, model_response_event
1125+
):
1126+
yield response
1127+
return
1128+
1129+
llm_request.config = llm_request.config or types.GenerateContentConfig()
1130+
llm_request.config.labels = llm_request.config.labels or {}
1131+
1132+
# Add agent name as a label to the llm_request. This will help
1133+
# with slicing billing reports on a per-agent basis.
1134+
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
1135+
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
1136+
invocation_context.agent.name
1137+
)
1138+
1139+
# Calls the LLM.
1140+
llm = self.__get_llm(invocation_context)
1141+
11271142
if invocation_context.run_config.support_cfc:
11281143
invocation_context.live_request_queue = LiveRequestQueue()
11291144
responses_generator = self.run_live(invocation_context)
@@ -1133,14 +1148,20 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11331148
invocation_context,
11341149
llm_request,
11351150
model_response_event,
1151+
call_llm_span=span,
11361152
)
11371153
) as agen:
11381154
async for llm_response in agen:
1139-
# Runs after_model_callback if it exists.
1140-
if altered_llm_response := await self._handle_after_model_callback(
1141-
invocation_context, llm_response, model_response_event
1142-
):
1143-
llm_response = altered_llm_response
1155+
# Rebind to call_llm span for after_model_callback.
1156+
with trace.use_span(span, end_on_exit=False):
1157+
if altered := (
1158+
await self._handle_after_model_callback(
1159+
invocation_context,
1160+
llm_response,
1161+
model_response_event,
1162+
)
1163+
):
1164+
llm_response = altered
11441165
# only yield partial response in SSE streaming mode
11451166
if (
11461167
invocation_context.run_config.streaming_mode
@@ -1151,9 +1172,9 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11511172
if llm_response.turn_complete:
11521173
invocation_context.live_request_queue.close()
11531174
else:
1154-
# Check if we can make this llm call or not. If the current call
1155-
# pushes the counter beyond the max set value, then the execution is
1156-
# stopped right here, and exception is thrown.
1175+
# Check if we can make this llm call or not. If the current
1176+
# call pushes the counter beyond the max set value, then the
1177+
# execution is stopped right here, and exception is thrown.
11571178
invocation_context.increment_llm_call_count()
11581179
responses_generator = llm.generate_content_async(
11591180
llm_request,
@@ -1166,6 +1187,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11661187
invocation_context,
11671188
llm_request,
11681189
model_response_event,
1190+
call_llm_span=span,
11691191
)
11701192
) as agen:
11711193
async for llm_response in agen:
@@ -1176,11 +1198,16 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11761198
llm_response,
11771199
span,
11781200
)
1179-
# Runs after_model_callback if it exists.
1180-
if altered_llm_response := await self._handle_after_model_callback(
1181-
invocation_context, llm_response, model_response_event
1182-
):
1183-
llm_response = altered_llm_response
1201+
# Rebind to call_llm span for after_model_callback.
1202+
with trace.use_span(span, end_on_exit=False):
1203+
if altered := (
1204+
await self._handle_after_model_callback(
1205+
invocation_context,
1206+
llm_response,
1207+
model_response_event,
1208+
)
1209+
):
1210+
llm_response = altered
11841211

11851212
yield llm_response
11861213

@@ -1235,13 +1262,15 @@ async def _run_and_handle_error(
12351262
invocation_context: InvocationContext,
12361263
llm_request: LlmRequest,
12371264
model_response_event: Event,
1265+
call_llm_span: Optional[trace.Span] = None,
12381266
) -> AsyncGenerator[LlmResponse, None]:
12391267
async with Aclosing(
12401268
_run_and_handle_error(
12411269
response_generator,
12421270
invocation_context,
12431271
llm_request,
12441272
model_response_event,
1273+
call_llm_span=call_llm_span,
12451274
)
12461275
) as agen:
12471276
async for response in agen:

0 commit comments

Comments (0)