Skip to content

Commit 952e4df

Browse files
committed
test: verify callbacks observe the call_llm span by name and ID
Add span name assertions so tests prove each callback sees specifically the 'call_llm' span (not just any span with a matching ID). This directly addresses the reviewer's request for proof that before_model_callback, after_model_callback, and on_model_error_callback all observe the same valid call_llm span ID.
1 parent b219c9b commit 952e4df

1 file changed

Lines changed: 57 additions & 21 deletions

File tree

tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,34 @@ def _setup_real_tracer_provider():
5858

5959

6060
class SpanCapturingPlugin(BasePlugin):
61-
"""Plugin that captures the current span ID in each model callback."""
61+
"""Plugin that captures span ID and name in each model callback."""
6262

6363
def __init__(self):
6464
super().__init__(name='span_capturing_plugin')
6565
self.before_model_span_id: Optional[int] = None
66+
self.before_model_span_name: Optional[str] = None
6667
self.after_model_span_id: Optional[int] = None
68+
self.after_model_span_name: Optional[str] = None
6769
self.on_model_error_span_id: Optional[int] = None
70+
self.on_model_error_span_name: Optional[str] = None
71+
72+
def _capture_span(self):
73+
"""Capture the current span's ID and name."""
74+
span = trace.get_current_span()
75+
ctx = span.get_span_context()
76+
span_id = ctx.span_id if ctx and ctx.span_id else None
77+
span_name = getattr(span, 'name', None)
78+
return span_id, span_name
6879

6980
async def before_model_callback(
7081
self,
7182
*,
7283
callback_context: CallbackContext,
7384
llm_request: LlmRequest,
7485
) -> Optional[LlmResponse]:
75-
span = trace.get_current_span()
76-
ctx = span.get_span_context()
77-
if ctx and ctx.span_id:
78-
self.before_model_span_id = ctx.span_id
86+
self.before_model_span_id, self.before_model_span_name = (
87+
self._capture_span()
88+
)
7989
return None
8090

8191
async def after_model_callback(
@@ -84,10 +94,9 @@ async def after_model_callback(
8494
callback_context: CallbackContext,
8595
llm_response: LlmResponse,
8696
) -> Optional[LlmResponse]:
87-
span = trace.get_current_span()
88-
ctx = span.get_span_context()
89-
if ctx and ctx.span_id:
90-
self.after_model_span_id = ctx.span_id
97+
self.after_model_span_id, self.after_model_span_name = (
98+
self._capture_span()
99+
)
91100
return None
92101

93102
async def on_model_error_callback(
@@ -97,10 +106,9 @@ async def on_model_error_callback(
97106
llm_request: LlmRequest,
98107
error: Exception,
99108
) -> Optional[LlmResponse]:
100-
span = trace.get_current_span()
101-
ctx = span.get_span_context()
102-
if ctx and ctx.span_id:
103-
self.on_model_error_span_id = ctx.span_id
109+
self.on_model_error_span_id, self.on_model_error_span_name = (
110+
self._capture_span()
111+
)
104112
return LlmResponse(
105113
content=testing_utils.ModelContent(
106114
[types.Part.from_text(text='error handled')]
@@ -110,7 +118,7 @@ async def on_model_error_callback(
110118

111119
@pytest.mark.asyncio
112120
async def test_before_and_after_model_callbacks_share_span_id():
113-
"""Verify before_model_callback and after_model_callback share the same span.
121+
"""Verify before_model_callback and after_model_callback share call_llm span.
114122
115123
This is the core regression test for issue #4851. Before the fix,
116124
before_model_callback ran outside the call_llm span, causing a span_id
@@ -125,16 +133,26 @@ async def test_before_and_after_model_callbacks_share_span_id():
125133
)
126134

127135
runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
128-
events = await runner.run_async_with_new_session('test')
136+
await runner.run_async_with_new_session('test')
129137

130-
# Both callbacks should have captured a span ID
138+
# Both callbacks must have captured a valid span ID
131139
assert (
132140
plugin.before_model_span_id is not None
133141
), 'before_model_callback did not capture a span ID'
134142
assert (
135143
plugin.after_model_span_id is not None
136144
), 'after_model_callback did not capture a span ID'
137145

146+
# Both must observe the call_llm span specifically
147+
assert plugin.before_model_span_name == 'call_llm', (
148+
f'before_model_callback saw span "{plugin.before_model_span_name}", '
149+
'expected "call_llm"'
150+
)
151+
assert plugin.after_model_span_name == 'call_llm', (
152+
f'after_model_callback saw span "{plugin.after_model_span_name}", '
153+
'expected "call_llm"'
154+
)
155+
138156
# The span IDs must match — this is the core assertion for issue #4851
139157
assert plugin.before_model_span_id == plugin.after_model_span_id, (
140158
'Span ID mismatch: before_model_callback span_id='
@@ -146,10 +164,10 @@ async def test_before_and_after_model_callbacks_share_span_id():
146164

147165
@pytest.mark.asyncio
148166
async def test_before_and_on_error_model_callbacks_share_span_id():
149-
"""Verify before_model_callback and on_model_error_callback share span.
167+
"""Verify before_model_callback and on_model_error_callback share call_llm span.
150168
151169
When the model raises an error, on_model_error_callback should see the
152-
same span as before_model_callback.
170+
same call_llm span as before_model_callback.
153171
"""
154172
plugin = SpanCapturingPlugin()
155173

@@ -162,16 +180,26 @@ async def test_before_and_on_error_model_callbacks_share_span_id():
162180
)
163181

164182
runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
165-
events = await runner.run_async_with_new_session('test')
183+
await runner.run_async_with_new_session('test')
166184

167-
# Both callbacks should have captured a span ID
185+
# Both callbacks must have captured a valid span ID
168186
assert (
169187
plugin.before_model_span_id is not None
170188
), 'before_model_callback did not capture a span ID'
171189
assert (
172190
plugin.on_model_error_span_id is not None
173191
), 'on_model_error_callback did not capture a span ID'
174192

193+
# Both must observe the call_llm span specifically
194+
assert plugin.before_model_span_name == 'call_llm', (
195+
f'before_model_callback saw span "{plugin.before_model_span_name}", '
196+
'expected "call_llm"'
197+
)
198+
assert plugin.on_model_error_span_name == 'call_llm', (
199+
f'on_model_error_callback saw span "{plugin.on_model_error_span_name}", '
200+
'expected "call_llm"'
201+
)
202+
175203
# The span IDs must match
176204
assert plugin.before_model_span_id == plugin.on_model_error_span_id, (
177205
'Span ID mismatch: before_model_callback span_id='
@@ -184,13 +212,14 @@ async def test_before_and_on_error_model_callbacks_share_span_id():
184212

185213
@pytest.mark.asyncio
186214
async def test_before_model_callback_short_circuit_has_span():
187-
"""Verify before_model_callback has a valid span when short-circuiting."""
215+
"""Verify before_model_callback has a valid call_llm span when short-circuiting."""
188216

189217
class ShortCircuitPlugin(BasePlugin):
190218

191219
def __init__(self):
192220
super().__init__(name='short_circuit_plugin')
193221
self.span_id: Optional[int] = None
222+
self.span_name: Optional[str] = None
194223

195224
async def before_model_callback(
196225
self,
@@ -202,6 +231,7 @@ async def before_model_callback(
202231
ctx = span.get_span_context()
203232
if ctx and ctx.span_id:
204233
self.span_id = ctx.span_id
234+
self.span_name = getattr(span, 'name', None)
205235
return LlmResponse(
206236
content=testing_utils.ModelContent(
207237
[types.Part.from_text(text='short-circuited')]
@@ -225,6 +255,12 @@ async def before_model_callback(
225255
'short-circuiting the LLM call'
226256
)
227257

258+
# Must be specifically the call_llm span
259+
assert plugin.span_name == 'call_llm', (
260+
f'before_model_callback saw span "{plugin.span_name}", '
261+
'expected "call_llm"'
262+
)
263+
228264
# Verify the short-circuit response was received
229265
simplified = testing_utils.simplify_events(events)
230266
assert any('short-circuited' in str(e) for e in simplified)

0 commit comments

Comments
 (0)