2424
2525from google .adk .platform import time as platform_time
2626from google .genai import types
27+ from opentelemetry import trace
2728from websockets .exceptions import ConnectionClosed
2829from websockets .exceptions import ConnectionClosedOK
2930
@@ -259,9 +260,7 @@ async def _maybe_add_grounding_metadata(
259260 tools = await agent .canonical_tools (readonly_context )
260261 invocation_context .canonical_tools_cache = tools
261262
262- if not any (
263- getattr (tool , 'propagate_grounding_metadata' , False ) for tool in tools
264- ):
263+ if not any (tool .name == 'google_search_agent' for tool in tools ):
265264 return response
266265 ground_metadata = invocation_context .session .state .get (
267266 'temp:_adk_grounding_metadata' , None
@@ -308,6 +307,7 @@ async def _run_and_handle_error(
308307 invocation_context : InvocationContext ,
309308 llm_request : LlmRequest ,
310309 model_response_event : Event ,
310+ call_llm_span : Optional [trace .Span ] = None ,
311311) -> AsyncGenerator [LlmResponse , None ]:
312312 """Wraps an LLM response generator with error callback handling.
313313
@@ -321,6 +321,9 @@ async def _run_and_handle_error(
321321 invocation_context: The invocation context.
322322 llm_request: The LLM request.
323323 model_response_event: The model response event.
324+ call_llm_span: The call_llm span to rebind error callbacks to. When
325+ provided, on_model_error callbacks run under this span so plugins observe
326+ the same span as before/after model callbacks.
324327
325328 Yields:
326329 LlmResponse objects from the generator.
@@ -382,11 +385,19 @@ async def _run_on_model_error_callbacks(
382385 callback_context = CallbackContext (
383386 invocation_context , event_actions = model_response_event .actions
384387 )
385- error_response = await _run_on_model_error_callbacks (
386- callback_context = callback_context ,
387- llm_request = llm_request ,
388- error = model_error ,
389- )
388+ if call_llm_span is not None :
389+ with trace .use_span (call_llm_span , end_on_exit = False ):
390+ error_response = await _run_on_model_error_callbacks (
391+ callback_context = callback_context ,
392+ llm_request = llm_request ,
393+ error = model_error ,
394+ )
395+ else :
396+ error_response = await _run_on_model_error_callbacks (
397+ callback_context = callback_context ,
398+ llm_request = llm_request ,
399+ error = model_error ,
400+ )
390401 if error_response is not None :
391402 yield error_response
392403 else :
@@ -1104,28 +1115,30 @@ async def _call_llm_async(
11041115 llm_request : LlmRequest ,
11051116 model_response_event : Event ,
11061117 ) -> AsyncGenerator [LlmResponse , None ]:
1107- # Runs before_model_callback if it exists.
1108- if response := await self ._handle_before_model_callback (
1109- invocation_context , llm_request , model_response_event
1110- ):
1111- yield response
1112- return
1113-
1114- llm_request .config = llm_request .config or types .GenerateContentConfig ()
1115- llm_request .config .labels = llm_request .config .labels or {}
1116-
1117- # Add agent name as a label to the llm_request. This will help with slicing
1118- # the billing reports on a per-agent basis.
1119- if _ADK_AGENT_NAME_LABEL_KEY not in llm_request .config .labels :
1120- llm_request .config .labels [_ADK_AGENT_NAME_LABEL_KEY ] = (
1121- invocation_context .agent .name
1122- )
1123-
1124- # Calls the LLM.
1125- llm = self .__get_llm (invocation_context )
11261118
11271119 async def _call_llm_with_tracing () -> AsyncGenerator [LlmResponse , None ]:
11281120 with tracer .start_as_current_span ('call_llm' ) as span :
1121+ # Runs before_model_callback inside the call_llm span so
1122+ # plugins observe the same span as after/error callbacks.
1123+ if response := await self ._handle_before_model_callback (
1124+ invocation_context , llm_request , model_response_event
1125+ ):
1126+ yield response
1127+ return
1128+
1129+ llm_request .config = llm_request .config or types .GenerateContentConfig ()
1130+ llm_request .config .labels = llm_request .config .labels or {}
1131+
1132+ # Add agent name as a label to the llm_request. This will help
1133+ # with slicing billing reports on a per-agent basis.
1134+ if _ADK_AGENT_NAME_LABEL_KEY not in llm_request .config .labels :
1135+ llm_request .config .labels [_ADK_AGENT_NAME_LABEL_KEY ] = (
1136+ invocation_context .agent .name
1137+ )
1138+
1139+ # Calls the LLM.
1140+ llm = self .__get_llm (invocation_context )
1141+
11291142 if invocation_context .run_config .support_cfc :
11301143 invocation_context .live_request_queue = LiveRequestQueue ()
11311144 responses_generator = self .run_live (invocation_context )
@@ -1135,14 +1148,20 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11351148 invocation_context ,
11361149 llm_request ,
11371150 model_response_event ,
1151+ call_llm_span = span ,
11381152 )
11391153 ) as agen :
11401154 async for llm_response in agen :
1141- # Runs after_model_callback if it exists.
1142- if altered_llm_response := await self ._handle_after_model_callback (
1143- invocation_context , llm_response , model_response_event
1144- ):
1145- llm_response = altered_llm_response
1155+ # Rebind to call_llm span for after_model_callback.
1156+ with trace .use_span (span , end_on_exit = False ):
1157+ if altered := (
1158+ await self ._handle_after_model_callback (
1159+ invocation_context ,
1160+ llm_response ,
1161+ model_response_event ,
1162+ )
1163+ ):
1164+ llm_response = altered
11461165 # only yield partial response in SSE streaming mode
11471166 if (
11481167 invocation_context .run_config .streaming_mode
@@ -1153,9 +1172,9 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11531172 if llm_response .turn_complete :
11541173 invocation_context .live_request_queue .close ()
11551174 else :
1156- # Check if we can make this llm call or not. If the current call
1157- # pushes the counter beyond the max set value, then the execution is
1158- # stopped right here, and exception is thrown.
1175+ # Check if we can make this llm call or not. If the current
1176+ # call pushes the counter beyond the max set value, then the
1177+ # execution is stopped right here, and exception is thrown.
11591178 invocation_context .increment_llm_call_count ()
11601179 responses_generator = llm .generate_content_async (
11611180 llm_request ,
@@ -1168,6 +1187,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11681187 invocation_context ,
11691188 llm_request ,
11701189 model_response_event ,
1190+ call_llm_span = span ,
11711191 )
11721192 ) as agen :
11731193 async for llm_response in agen :
@@ -1178,11 +1198,16 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11781198 llm_response ,
11791199 span ,
11801200 )
1181- # Runs after_model_callback if it exists.
1182- if altered_llm_response := await self ._handle_after_model_callback (
1183- invocation_context , llm_response , model_response_event
1184- ):
1185- llm_response = altered_llm_response
1201+ # Rebind to call_llm span for after_model_callback.
1202+ with trace .use_span (span , end_on_exit = False ):
1203+ if altered := (
1204+ await self ._handle_after_model_callback (
1205+ invocation_context ,
1206+ llm_response ,
1207+ model_response_event ,
1208+ )
1209+ ):
1210+ llm_response = altered
11861211
11871212 yield llm_response
11881213
@@ -1237,13 +1262,15 @@ async def _run_and_handle_error(
12371262 invocation_context : InvocationContext ,
12381263 llm_request : LlmRequest ,
12391264 model_response_event : Event ,
1265+ call_llm_span : Optional [trace .Span ] = None ,
12401266 ) -> AsyncGenerator [LlmResponse , None ]:
12411267 async with Aclosing (
12421268 _run_and_handle_error (
12431269 response_generator ,
12441270 invocation_context ,
12451271 llm_request ,
12461272 model_response_event ,
1273+ call_llm_span = call_llm_span ,
12471274 )
12481275 ) as agen :
12491276 async for response in agen :
0 commit comments