@@ -58,24 +58,34 @@ def _setup_real_tracer_provider():
5858
5959
6060class SpanCapturingPlugin (BasePlugin ):
61- """Plugin that captures the current span ID in each model callback."""
61+ """Plugin that captures span ID and name in each model callback."""
6262
6363 def __init__ (self ):
6464 super ().__init__ (name = 'span_capturing_plugin' )
6565 self .before_model_span_id : Optional [int ] = None
66+ self .before_model_span_name : Optional [str ] = None
6667 self .after_model_span_id : Optional [int ] = None
68+ self .after_model_span_name : Optional [str ] = None
6769 self .on_model_error_span_id : Optional [int ] = None
70+ self .on_model_error_span_name : Optional [str ] = None
71+
72+ def _capture_span (self ):
73+ """Capture the current span's ID and name."""
74+ span = trace .get_current_span ()
75+ ctx = span .get_span_context ()
76+ span_id = ctx .span_id if ctx and ctx .span_id else None
77+ span_name = getattr (span , 'name' , None )
78+ return span_id , span_name
6879
6980 async def before_model_callback (
7081 self ,
7182 * ,
7283 callback_context : CallbackContext ,
7384 llm_request : LlmRequest ,
7485 ) -> Optional [LlmResponse ]:
75- span = trace .get_current_span ()
76- ctx = span .get_span_context ()
77- if ctx and ctx .span_id :
78- self .before_model_span_id = ctx .span_id
86+ self .before_model_span_id , self .before_model_span_name = (
87+ self ._capture_span ()
88+ )
7989 return None
8090
8191 async def after_model_callback (
@@ -84,10 +94,9 @@ async def after_model_callback(
8494 callback_context : CallbackContext ,
8595 llm_response : LlmResponse ,
8696 ) -> Optional [LlmResponse ]:
87- span = trace .get_current_span ()
88- ctx = span .get_span_context ()
89- if ctx and ctx .span_id :
90- self .after_model_span_id = ctx .span_id
97+ self .after_model_span_id , self .after_model_span_name = (
98+ self ._capture_span ()
99+ )
91100 return None
92101
93102 async def on_model_error_callback (
@@ -97,10 +106,9 @@ async def on_model_error_callback(
97106 llm_request : LlmRequest ,
98107 error : Exception ,
99108 ) -> Optional [LlmResponse ]:
100- span = trace .get_current_span ()
101- ctx = span .get_span_context ()
102- if ctx and ctx .span_id :
103- self .on_model_error_span_id = ctx .span_id
109+ self .on_model_error_span_id , self .on_model_error_span_name = (
110+ self ._capture_span ()
111+ )
104112 return LlmResponse (
105113 content = testing_utils .ModelContent (
106114 [types .Part .from_text (text = 'error handled' )]
@@ -110,7 +118,7 @@ async def on_model_error_callback(
110118
111119@pytest .mark .asyncio
112120async def test_before_and_after_model_callbacks_share_span_id ():
113- """Verify before_model_callback and after_model_callback share the same span.
121+ """Verify before_model_callback and after_model_callback share call_llm span.
114122
115123 This is the core regression test for issue #4851. Before the fix,
116124 before_model_callback ran outside the call_llm span, causing a span_id
@@ -125,16 +133,26 @@ async def test_before_and_after_model_callbacks_share_span_id():
125133 )
126134
127135 runner = testing_utils .TestInMemoryRunner (agent , plugins = [plugin ])
128- events = await runner .run_async_with_new_session ('test' )
136+ await runner .run_async_with_new_session ('test' )
129137
130- # Both callbacks should have captured a span ID
138+ # Both callbacks must have captured a valid span ID
131139 assert (
132140 plugin .before_model_span_id is not None
133141 ), 'before_model_callback did not capture a span ID'
134142 assert (
135143 plugin .after_model_span_id is not None
136144 ), 'after_model_callback did not capture a span ID'
137145
146+ # Both must observe the call_llm span specifically
147+ assert plugin .before_model_span_name == 'call_llm' , (
148+ f'before_model_callback saw span "{ plugin .before_model_span_name } ", '
149+ 'expected "call_llm"'
150+ )
151+ assert plugin .after_model_span_name == 'call_llm' , (
152+ f'after_model_callback saw span "{ plugin .after_model_span_name } ", '
153+ 'expected "call_llm"'
154+ )
155+
138156 # The span IDs must match — this is the core assertion for issue #4851
139157 assert plugin .before_model_span_id == plugin .after_model_span_id , (
140158 'Span ID mismatch: before_model_callback span_id='
@@ -146,10 +164,10 @@ async def test_before_and_after_model_callbacks_share_span_id():
146164
147165@pytest .mark .asyncio
148166async def test_before_and_on_error_model_callbacks_share_span_id ():
149- """Verify before_model_callback and on_model_error_callback share span.
167+ """Verify before_model_callback and on_model_error_callback share call_llm span.
150168
151169 When the model raises an error, on_model_error_callback should see the
152- same span as before_model_callback.
170+ same call_llm span as before_model_callback.
153171 """
154172 plugin = SpanCapturingPlugin ()
155173
@@ -162,16 +180,26 @@ async def test_before_and_on_error_model_callbacks_share_span_id():
162180 )
163181
164182 runner = testing_utils .TestInMemoryRunner (agent , plugins = [plugin ])
165- events = await runner .run_async_with_new_session ('test' )
183+ await runner .run_async_with_new_session ('test' )
166184
167- # Both callbacks should have captured a span ID
185+ # Both callbacks must have captured a valid span ID
168186 assert (
169187 plugin .before_model_span_id is not None
170188 ), 'before_model_callback did not capture a span ID'
171189 assert (
172190 plugin .on_model_error_span_id is not None
173191 ), 'on_model_error_callback did not capture a span ID'
174192
193+ # Both must observe the call_llm span specifically
194+ assert plugin .before_model_span_name == 'call_llm' , (
195+ f'before_model_callback saw span "{ plugin .before_model_span_name } ", '
196+ 'expected "call_llm"'
197+ )
198+ assert plugin .on_model_error_span_name == 'call_llm' , (
199+ f'on_model_error_callback saw span "{ plugin .on_model_error_span_name } ", '
200+ 'expected "call_llm"'
201+ )
202+
175203 # The span IDs must match
176204 assert plugin .before_model_span_id == plugin .on_model_error_span_id , (
177205 'Span ID mismatch: before_model_callback span_id='
@@ -184,13 +212,14 @@ async def test_before_and_on_error_model_callbacks_share_span_id():
184212
185213@pytest .mark .asyncio
186214async def test_before_model_callback_short_circuit_has_span ():
187- """Verify before_model_callback has a valid span when short-circuiting."""
215+ """Verify before_model_callback has a valid call_llm span when short-circuiting."""
188216
189217 class ShortCircuitPlugin (BasePlugin ):
190218
191219 def __init__ (self ):
192220 super ().__init__ (name = 'short_circuit_plugin' )
193221 self .span_id : Optional [int ] = None
222+ self .span_name : Optional [str ] = None
194223
195224 async def before_model_callback (
196225 self ,
@@ -202,6 +231,7 @@ async def before_model_callback(
202231 ctx = span .get_span_context ()
203232 if ctx and ctx .span_id :
204233 self .span_id = ctx .span_id
234+ self .span_name = getattr (span , 'name' , None )
205235 return LlmResponse (
206236 content = testing_utils .ModelContent (
207237 [types .Part .from_text (text = 'short-circuited' )]
@@ -225,6 +255,12 @@ async def before_model_callback(
225255 'short-circuiting the LLM call'
226256 )
227257
258+ # Must be specifically the call_llm span
259+ assert plugin .span_name == 'call_llm' , (
260+ f'before_model_callback saw span "{ plugin .span_name } ", '
261+ 'expected "call_llm"'
262+ )
263+
228264 # Verify the short-circuit response was received
229265 simplified = testing_utils .simplify_events (events )
230266 assert any ('short-circuited' in str (e ) for e in simplified )
0 commit comments