Skip to content

Commit b219c9b

Browse files
committed
fix: pass call_llm span to error callbacks and use real TracerProvider in tests
Address review feedback from caohy1988: 1. Pass the call_llm span explicitly to _run_and_handle_error and re-activate it with trace.use_span(parent_span) for error callbacks. This ensures on_model_error_callback reliably sees the same span context as before_model_callback, defending against async context propagation issues across generator yield boundaries. 2. Replace mock.patch-based tracer setup in tests with a real global TracerProvider. The original tests masked the production code path by mocking the tracer — now they validate the same proxy tracer behavior used at runtime. Fixes #4851
1 parent c1b3e41 commit b219c9b

2 files changed

Lines changed: 50 additions & 32 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ async def _run_and_handle_error(
307307
invocation_context: InvocationContext,
308308
llm_request: LlmRequest,
309309
model_response_event: Event,
310+
parent_span: Optional[trace.Span] = None,
310311
) -> AsyncGenerator[LlmResponse, None]:
311312
"""Wraps an LLM response generator with error callback handling.
312313
@@ -320,6 +321,9 @@ async def _run_and_handle_error(
320321
invocation_context: The invocation context.
321322
llm_request: The LLM request.
322323
model_response_event: The model response event.
324+
parent_span: Optional parent span (e.g. call_llm) to re-activate for
325+
error callbacks, ensuring on_model_error_callback sees the same
326+
span context as before_model_callback (issue #4851).
323327
324328
Yields:
325329
LlmResponse objects from the generator.
@@ -381,11 +385,23 @@ async def _run_on_model_error_callbacks(
381385
callback_context = CallbackContext(
382386
invocation_context, event_actions=model_response_event.actions
383387
)
384-
error_response = await _run_on_model_error_callbacks(
385-
callback_context=callback_context,
386-
llm_request=llm_request,
387-
error=model_error,
388-
)
388+
# Re-activate the parent span (call_llm) so on_model_error_callback
389+
# sees the same span_id as before_model_callback (issue #4851).
390+
# This is necessary because the inference span has exited and async
391+
# context propagation across generator yields can be unreliable.
392+
if parent_span is not None:
393+
with trace.use_span(parent_span):
394+
error_response = await _run_on_model_error_callbacks(
395+
callback_context=callback_context,
396+
llm_request=llm_request,
397+
error=model_error,
398+
)
399+
else:
400+
error_response = await _run_on_model_error_callbacks(
401+
callback_context=callback_context,
402+
llm_request=llm_request,
403+
error=model_error,
404+
)
389405
if error_response is not None:
390406
yield error_response
391407
else:
@@ -1153,6 +1169,7 @@ async def _apply_after_model_callback(
11531169
invocation_context,
11541170
llm_request,
11551171
model_response_event,
1172+
parent_span=span,
11561173
)
11571174
) as agen:
11581175
async for llm_response in agen:
@@ -1182,6 +1199,7 @@ async def _apply_after_model_callback(
11821199
invocation_context,
11831200
llm_request,
11841201
model_response_event,
1202+
parent_span=span,
11851203
)
11861204
) as agen:
11871205
async for llm_response in agen:
@@ -1247,13 +1265,15 @@ async def _run_and_handle_error(
12471265
invocation_context: InvocationContext,
12481266
llm_request: LlmRequest,
12491267
model_response_event: Event,
1268+
parent_span: Optional[trace.Span] = None,
12501269
) -> AsyncGenerator[LlmResponse, None]:
12511270
async with Aclosing(
12521271
_run_and_handle_error(
12531272
response_generator,
12541273
invocation_context,
12551274
llm_request,
12561275
model_response_event,
1276+
parent_span=parent_span,
12571277
)
12581278
) as agen:
12591279
async for response in agen:

tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,18 @@
1818
after_model_callback, and on_model_error_callback must all execute within
1919
the same call_llm span so that plugins (e.g. BigQueryAgentAnalyticsPlugin)
2020
see a consistent span_id for LLM_REQUEST and LLM_RESPONSE events.
21+
22+
These tests set up a real TracerProvider globally — rather than mocking
23+
the tracer — so that they validate the same code path used in production.
2124
"""
2225

2326
from typing import Optional
24-
from unittest import mock
2527

2628
from google.adk.agents.callback_context import CallbackContext
2729
from google.adk.agents.llm_agent import Agent
28-
from google.adk.flows.llm_flows import base_llm_flow
2930
from google.adk.models.llm_request import LlmRequest
3031
from google.adk.models.llm_response import LlmResponse
3132
from google.adk.plugins.base_plugin import BasePlugin
32-
from google.adk.telemetry import tracing as adk_tracing
3333
from google.genai import types
3434
from opentelemetry import trace
3535
from opentelemetry.sdk.trace import TracerProvider
@@ -38,10 +38,23 @@
3838
from ... import testing_utils
3939

4040

41-
def _make_real_tracer():
42-
"""Create a real tracer that produces valid span IDs."""
41+
@pytest.fixture(autouse=True)
42+
def _setup_real_tracer_provider():
43+
"""Set up a real TracerProvider globally for realistic span validation.
44+
45+
This ensures that all code paths — including the module-level ``tracer``
46+
in ``tracing.py`` — produce real spans with valid span IDs, matching
47+
production behavior when a TracerProvider is configured.
48+
49+
Note: ``trace.get_tracer()`` returns a proxy that delegates to the
50+
currently set TracerProvider, so setting the provider after import
51+
correctly affects all existing tracer references.
52+
"""
4353
provider = TracerProvider()
44-
return provider.get_tracer('test_tracer')
54+
previous_provider = trace.get_tracer_provider()
55+
trace.set_tracer_provider(provider)
56+
yield
57+
trace.set_tracer_provider(previous_provider)
4558

4659

4760
class SpanCapturingPlugin(BasePlugin):
@@ -104,20 +117,15 @@ async def test_before_and_after_model_callbacks_share_span_id():
104117
mismatch between LLM_REQUEST and LLM_RESPONSE events.
105118
"""
106119
plugin = SpanCapturingPlugin()
107-
real_tracer = _make_real_tracer()
108120

109121
mock_model = testing_utils.MockModel.create(responses=['model_response'])
110122
agent = Agent(
111123
name='test_agent',
112124
model=mock_model,
113125
)
114126

115-
with (
116-
mock.patch.object(base_llm_flow, 'tracer', real_tracer),
117-
mock.patch.object(adk_tracing, 'tracer', real_tracer),
118-
):
119-
runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
120-
events = await runner.run_async_with_new_session('test')
127+
runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
128+
events = await runner.run_async_with_new_session('test')
121129

122130
# Both callbacks should have captured a span ID
123131
assert (
@@ -144,7 +152,6 @@ async def test_before_and_on_error_model_callbacks_share_span_id():
144152
same span as before_model_callback.
145153
"""
146154
plugin = SpanCapturingPlugin()
147-
real_tracer = _make_real_tracer()
148155

149156
mock_model = testing_utils.MockModel.create(
150157
responses=[], error=SystemError('model error')
@@ -154,12 +161,8 @@ async def test_before_and_on_error_model_callbacks_share_span_id():
154161
model=mock_model,
155162
)
156163

157-
with (
158-
mock.patch.object(base_llm_flow, 'tracer', real_tracer),
159-
mock.patch.object(adk_tracing, 'tracer', real_tracer),
160-
):
161-
runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
162-
events = await runner.run_async_with_new_session('test')
164+
runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
165+
events = await runner.run_async_with_new_session('test')
163166

164167
# Both callbacks should have captured a span ID
165168
assert (
@@ -206,20 +209,15 @@ async def before_model_callback(
206209
)
207210

208211
plugin = ShortCircuitPlugin()
209-
real_tracer = _make_real_tracer()
210212

211213
mock_model = testing_utils.MockModel.create(responses=['model_response'])
212214
agent = Agent(
213215
name='test_agent',
214216
model=mock_model,
215217
)
216218

217-
with (
218-
mock.patch.object(base_llm_flow, 'tracer', real_tracer),
219-
mock.patch.object(adk_tracing, 'tracer', real_tracer),
220-
):
221-
runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
222-
events = await runner.run_async_with_new_session('test')
219+
runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
220+
events = await runner.run_async_with_new_session('test')
223221

224222
# The callback should have a valid (non-zero) span ID from the call_llm span
225223
assert plugin.span_id is not None and plugin.span_id != 0, (

0 commit comments

Comments
 (0)