Skip to content

Commit b2daf83

Browse files
haiyuan-eng-google and copybara-github
authored and committed
fix: sync callbacks with call_llm span
Moved `before_model_callback` inside the `call_llm` span. Wrapped both `after_model_callback` and error callbacks with `trace.use_span()` to bind callbacks to the correct overarching span. Added regression tests to verify span ID consistency. Co-authored-by: Haiyuan Cao <haiyuan@google.com> PiperOrigin-RevId: 891952170
1 parent 92cad99 commit b2daf83

File tree

2 files changed

+454
-40
lines changed

2 files changed

+454
-40
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 67 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424

2525
from google.adk.platform import time as platform_time
2626
from google.genai import types
27+
from opentelemetry import trace
2728
from websockets.exceptions import ConnectionClosed
2829
from websockets.exceptions import ConnectionClosedOK
2930

@@ -259,9 +260,7 @@ async def _maybe_add_grounding_metadata(
259260
tools = await agent.canonical_tools(readonly_context)
260261
invocation_context.canonical_tools_cache = tools
261262

262-
if not any(
263-
getattr(tool, 'propagate_grounding_metadata', False) for tool in tools
264-
):
263+
if not any(tool.name == 'google_search_agent' for tool in tools):
265264
return response
266265
ground_metadata = invocation_context.session.state.get(
267266
'temp:_adk_grounding_metadata', None
@@ -308,6 +307,7 @@ async def _run_and_handle_error(
308307
invocation_context: InvocationContext,
309308
llm_request: LlmRequest,
310309
model_response_event: Event,
310+
call_llm_span: Optional[trace.Span] = None,
311311
) -> AsyncGenerator[LlmResponse, None]:
312312
"""Wraps an LLM response generator with error callback handling.
313313
@@ -321,6 +321,9 @@ async def _run_and_handle_error(
321321
invocation_context: The invocation context.
322322
llm_request: The LLM request.
323323
model_response_event: The model response event.
324+
call_llm_span: The call_llm span to rebind error callbacks to. When
325+
provided, on_model_error callbacks run under this span so plugins observe
326+
the same span as before/after model callbacks.
324327
325328
Yields:
326329
LlmResponse objects from the generator.
@@ -382,11 +385,19 @@ async def _run_on_model_error_callbacks(
382385
callback_context = CallbackContext(
383386
invocation_context, event_actions=model_response_event.actions
384387
)
385-
error_response = await _run_on_model_error_callbacks(
386-
callback_context=callback_context,
387-
llm_request=llm_request,
388-
error=model_error,
389-
)
388+
if call_llm_span is not None:
389+
with trace.use_span(call_llm_span, end_on_exit=False):
390+
error_response = await _run_on_model_error_callbacks(
391+
callback_context=callback_context,
392+
llm_request=llm_request,
393+
error=model_error,
394+
)
395+
else:
396+
error_response = await _run_on_model_error_callbacks(
397+
callback_context=callback_context,
398+
llm_request=llm_request,
399+
error=model_error,
400+
)
390401
if error_response is not None:
391402
yield error_response
392403
else:
@@ -1104,28 +1115,30 @@ async def _call_llm_async(
11041115
llm_request: LlmRequest,
11051116
model_response_event: Event,
11061117
) -> AsyncGenerator[LlmResponse, None]:
1107-
# Runs before_model_callback if it exists.
1108-
if response := await self._handle_before_model_callback(
1109-
invocation_context, llm_request, model_response_event
1110-
):
1111-
yield response
1112-
return
1113-
1114-
llm_request.config = llm_request.config or types.GenerateContentConfig()
1115-
llm_request.config.labels = llm_request.config.labels or {}
1116-
1117-
# Add agent name as a label to the llm_request. This will help with slicing
1118-
# the billing reports on a per-agent basis.
1119-
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
1120-
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
1121-
invocation_context.agent.name
1122-
)
1123-
1124-
# Calls the LLM.
1125-
llm = self.__get_llm(invocation_context)
11261118

11271119
async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11281120
with tracer.start_as_current_span('call_llm') as span:
1121+
# Runs before_model_callback inside the call_llm span so
1122+
# plugins observe the same span as after/error callbacks.
1123+
if response := await self._handle_before_model_callback(
1124+
invocation_context, llm_request, model_response_event
1125+
):
1126+
yield response
1127+
return
1128+
1129+
llm_request.config = llm_request.config or types.GenerateContentConfig()
1130+
llm_request.config.labels = llm_request.config.labels or {}
1131+
1132+
# Add agent name as a label to the llm_request. This will help
1133+
# with slicing billing reports on a per-agent basis.
1134+
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
1135+
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
1136+
invocation_context.agent.name
1137+
)
1138+
1139+
# Calls the LLM.
1140+
llm = self.__get_llm(invocation_context)
1141+
11291142
if invocation_context.run_config.support_cfc:
11301143
invocation_context.live_request_queue = LiveRequestQueue()
11311144
responses_generator = self.run_live(invocation_context)
@@ -1135,14 +1148,20 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11351148
invocation_context,
11361149
llm_request,
11371150
model_response_event,
1151+
call_llm_span=span,
11381152
)
11391153
) as agen:
11401154
async for llm_response in agen:
1141-
# Runs after_model_callback if it exists.
1142-
if altered_llm_response := await self._handle_after_model_callback(
1143-
invocation_context, llm_response, model_response_event
1144-
):
1145-
llm_response = altered_llm_response
1155+
# Rebind to call_llm span for after_model_callback.
1156+
with trace.use_span(span, end_on_exit=False):
1157+
if altered := (
1158+
await self._handle_after_model_callback(
1159+
invocation_context,
1160+
llm_response,
1161+
model_response_event,
1162+
)
1163+
):
1164+
llm_response = altered
11461165
# only yield partial response in SSE streaming mode
11471166
if (
11481167
invocation_context.run_config.streaming_mode
@@ -1153,9 +1172,9 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11531172
if llm_response.turn_complete:
11541173
invocation_context.live_request_queue.close()
11551174
else:
1156-
# Check if we can make this llm call or not. If the current call
1157-
# pushes the counter beyond the max set value, then the execution is
1158-
# stopped right here, and exception is thrown.
1175+
# Check if we can make this llm call or not. If the current
1176+
# call pushes the counter beyond the max set value, then the
1177+
# execution is stopped right here, and exception is thrown.
11591178
invocation_context.increment_llm_call_count()
11601179
responses_generator = llm.generate_content_async(
11611180
llm_request,
@@ -1168,6 +1187,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11681187
invocation_context,
11691188
llm_request,
11701189
model_response_event,
1190+
call_llm_span=span,
11711191
)
11721192
) as agen:
11731193
async for llm_response in agen:
@@ -1178,11 +1198,16 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11781198
llm_response,
11791199
span,
11801200
)
1181-
# Runs after_model_callback if it exists.
1182-
if altered_llm_response := await self._handle_after_model_callback(
1183-
invocation_context, llm_response, model_response_event
1184-
):
1185-
llm_response = altered_llm_response
1201+
# Rebind to call_llm span for after_model_callback.
1202+
with trace.use_span(span, end_on_exit=False):
1203+
if altered := (
1204+
await self._handle_after_model_callback(
1205+
invocation_context,
1206+
llm_response,
1207+
model_response_event,
1208+
)
1209+
):
1210+
llm_response = altered
11861211

11871212
yield llm_response
11881213

@@ -1237,13 +1262,15 @@ async def _run_and_handle_error(
12371262
invocation_context: InvocationContext,
12381263
llm_request: LlmRequest,
12391264
model_response_event: Event,
1265+
call_llm_span: Optional[trace.Span] = None,
12401266
) -> AsyncGenerator[LlmResponse, None]:
12411267
async with Aclosing(
12421268
_run_and_handle_error(
12431269
response_generator,
12441270
invocation_context,
12451271
llm_request,
12461272
model_response_event,
1273+
call_llm_span=call_llm_span,
12471274
)
12481275
) as agen:
12491276
async for response in agen:

0 commit comments

Comments (0)