2424
2525from google .adk .platform import time as platform_time
2626from google .genai import types
27+ from opentelemetry import trace
2728from websockets .exceptions import ConnectionClosed
2829from websockets .exceptions import ConnectionClosedOK
2930
@@ -306,6 +307,7 @@ async def _run_and_handle_error(
306307 invocation_context : InvocationContext ,
307308 llm_request : LlmRequest ,
308309 model_response_event : Event ,
310+ call_llm_span : Optional [trace .Span ] = None ,
309311) -> AsyncGenerator [LlmResponse , None ]:
310312 """Wraps an LLM response generator with error callback handling.
311313
@@ -319,6 +321,9 @@ async def _run_and_handle_error(
319321 invocation_context: The invocation context.
320322 llm_request: The LLM request.
321323 model_response_event: The model response event.
324+ call_llm_span: The call_llm span to rebind error callbacks to.
325+ When provided, on_model_error callbacks run under this span so
326+ plugins observe the same span as before/after model callbacks.
322327
323328 Yields:
324329 LlmResponse objects from the generator.
@@ -380,11 +385,19 @@ async def _run_on_model_error_callbacks(
380385 callback_context = CallbackContext (
381386 invocation_context , event_actions = model_response_event .actions
382387 )
383- error_response = await _run_on_model_error_callbacks (
384- callback_context = callback_context ,
385- llm_request = llm_request ,
386- error = model_error ,
387- )
388+ if call_llm_span is not None :
389+ with trace .use_span (call_llm_span , end_on_exit = False ):
390+ error_response = await _run_on_model_error_callbacks (
391+ callback_context = callback_context ,
392+ llm_request = llm_request ,
393+ error = model_error ,
394+ )
395+ else :
396+ error_response = await _run_on_model_error_callbacks (
397+ callback_context = callback_context ,
398+ llm_request = llm_request ,
399+ error = model_error ,
400+ )
388401 if error_response is not None :
389402 yield error_response
390403 else :
@@ -1102,28 +1115,30 @@ async def _call_llm_async(
11021115 llm_request : LlmRequest ,
11031116 model_response_event : Event ,
11041117 ) -> AsyncGenerator [LlmResponse , None ]:
1105- # Runs before_model_callback if it exists.
1106- if response := await self ._handle_before_model_callback (
1107- invocation_context , llm_request , model_response_event
1108- ):
1109- yield response
1110- return
1111-
1112- llm_request .config = llm_request .config or types .GenerateContentConfig ()
1113- llm_request .config .labels = llm_request .config .labels or {}
1114-
1115- # Add agent name as a label to the llm_request. This will help with slicing
1116- # the billing reports on a per-agent basis.
1117- if _ADK_AGENT_NAME_LABEL_KEY not in llm_request .config .labels :
1118- llm_request .config .labels [_ADK_AGENT_NAME_LABEL_KEY ] = (
1119- invocation_context .agent .name
1120- )
1121-
1122- # Calls the LLM.
1123- llm = self .__get_llm (invocation_context )
11241118
11251119 async def _call_llm_with_tracing () -> AsyncGenerator [LlmResponse , None ]:
11261120 with tracer .start_as_current_span ('call_llm' ) as span :
1121+ # Runs before_model_callback inside the call_llm span so
1122+ # plugins observe the same span as after/error callbacks.
1123+ if response := await self ._handle_before_model_callback (
1124+ invocation_context , llm_request , model_response_event
1125+ ):
1126+ yield response
1127+ return
1128+
1129+ llm_request .config = llm_request .config or types .GenerateContentConfig ()
1130+ llm_request .config .labels = llm_request .config .labels or {}
1131+
1132+ # Add agent name as a label to the llm_request. This will help
1133+ # with slicing billing reports on a per-agent basis.
1134+ if _ADK_AGENT_NAME_LABEL_KEY not in llm_request .config .labels :
1135+ llm_request .config .labels [_ADK_AGENT_NAME_LABEL_KEY ] = (
1136+ invocation_context .agent .name
1137+ )
1138+
1139+ # Calls the LLM.
1140+ llm = self .__get_llm (invocation_context )
1141+
11271142 if invocation_context .run_config .support_cfc :
11281143 invocation_context .live_request_queue = LiveRequestQueue ()
11291144 responses_generator = self .run_live (invocation_context )
@@ -1133,14 +1148,20 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11331148 invocation_context ,
11341149 llm_request ,
11351150 model_response_event ,
1151+ call_llm_span = span ,
11361152 )
11371153 ) as agen :
11381154 async for llm_response in agen :
1139- # Runs after_model_callback if it exists.
1140- if altered_llm_response := await self ._handle_after_model_callback (
1141- invocation_context , llm_response , model_response_event
1142- ):
1143- llm_response = altered_llm_response
1155+ # Rebind to call_llm span for after_model_callback.
1156+ with trace .use_span (span , end_on_exit = False ):
1157+ if altered := (
1158+ await self ._handle_after_model_callback (
1159+ invocation_context ,
1160+ llm_response ,
1161+ model_response_event ,
1162+ )
1163+ ):
1164+ llm_response = altered
11441165 # only yield partial response in SSE streaming mode
11451166 if (
11461167 invocation_context .run_config .streaming_mode
@@ -1151,9 +1172,9 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11511172 if llm_response .turn_complete :
11521173 invocation_context .live_request_queue .close ()
11531174 else :
1154- # Check if we can make this llm call or not. If the current call
1155- # pushes the counter beyond the max set value, then the execution is
1156- # stopped right here, and exception is thrown.
1175+ # Check if we can make this llm call or not. If the current
1176+ # call pushes the counter beyond the max set value, then the
1177+ # execution is stopped right here, and exception is thrown.
11571178 invocation_context .increment_llm_call_count ()
11581179 responses_generator = llm .generate_content_async (
11591180 llm_request ,
@@ -1166,6 +1187,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11661187 invocation_context ,
11671188 llm_request ,
11681189 model_response_event ,
1190+ call_llm_span = span ,
11691191 )
11701192 ) as agen :
11711193 async for llm_response in agen :
@@ -1176,11 +1198,16 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
11761198 llm_response ,
11771199 span ,
11781200 )
1179- # Runs after_model_callback if it exists.
1180- if altered_llm_response := await self ._handle_after_model_callback (
1181- invocation_context , llm_response , model_response_event
1182- ):
1183- llm_response = altered_llm_response
1201+ # Rebind to call_llm span for after_model_callback.
1202+ with trace .use_span (span , end_on_exit = False ):
1203+ if altered := (
1204+ await self ._handle_after_model_callback (
1205+ invocation_context ,
1206+ llm_response ,
1207+ model_response_event ,
1208+ )
1209+ ):
1210+ llm_response = altered
11841211
11851212 yield llm_response
11861213
@@ -1235,13 +1262,15 @@ async def _run_and_handle_error(
12351262 invocation_context : InvocationContext ,
12361263 llm_request : LlmRequest ,
12371264 model_response_event : Event ,
1265+ call_llm_span : Optional [trace .Span ] = None ,
12381266 ) -> AsyncGenerator [LlmResponse , None ]:
12391267 async with Aclosing (
12401268 _run_and_handle_error (
12411269 response_generator ,
12421270 invocation_context ,
12431271 llm_request ,
12441272 model_response_event ,
1273+ call_llm_span = call_llm_span ,
12451274 )
12461275 ) as agen :
12471276 async for response in agen :
0 commit comments