Skip to content

Commit 58d3ed2

Browse files
authored
fix: preserve streamed output guardrail tripwires in the run loop (#2758)
1 parent 85929c0 commit 58d3ed2

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

src/agents/run_internal/run_loop.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
InputGuardrailTripwireTriggered,
3131
MaxTurnsExceeded,
3232
ModelBehaviorError,
33+
OutputGuardrailTripwireTriggered,
3334
RunErrorDetails,
3435
UserError,
3536
)
@@ -230,6 +231,17 @@
230231
]
231232

232233

234+
def _should_attach_generic_agent_error(exc: Exception) -> bool:
235+
return not isinstance(
236+
exc,
237+
(
238+
ModelBehaviorError,
239+
InputGuardrailTripwireTriggered,
240+
OutputGuardrailTripwireTriggered,
241+
),
242+
)
243+
244+
233245
async def _should_persist_stream_items(
234246
*,
235247
session: Session | None,
@@ -344,7 +356,12 @@ async def _run_output_guardrails_for_stream(
344356

345357
try:
346358
return cast(list[Any], await streamed_result._output_guardrails_task)
359+
except OutputGuardrailTripwireTriggered:
360+
raise
361+
except asyncio.CancelledError:
362+
raise
347363
except Exception:
364+
logger.error("Unexpected error in output guardrails", exc_info=True)
348365
return []
349366

350367

@@ -1014,7 +1031,7 @@ async def _save_stream_items_without_count(
10141031
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
10151032
break
10161033
except Exception as e:
1017-
if current_span and not isinstance(e, ModelBehaviorError):
1034+
if current_span and _should_attach_generic_agent_error(e):
10181035
_error_tracing.attach_error_to_span(
10191036
current_span,
10201037
SpanError(
@@ -1037,7 +1054,7 @@ async def _save_stream_items_without_count(
10371054
)
10381055
raise
10391056
except Exception as e:
1040-
if current_span and not isinstance(e, ModelBehaviorError):
1057+
if current_span and _should_attach_generic_agent_error(e):
10411058
_error_tracing.attach_error_to_span(
10421059
current_span,
10431060
SpanError(

tests/test_agent_runner_streamed.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,34 @@ def guardrail_function(
12901290
pass
12911291

12921292

1293+
@pytest.mark.asyncio
1294+
async def test_output_guardrail_tripwire_raises_from_run_loop_task_before_stream_consumption():
1295+
def guardrail_function(
1296+
context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any
1297+
) -> GuardrailFunctionOutput:
1298+
return GuardrailFunctionOutput(
1299+
output_info=None,
1300+
tripwire_triggered=True,
1301+
)
1302+
1303+
model = FakeModel(initial_output=[get_text_message("first_test")])
1304+
1305+
agent = Agent(
1306+
name="test",
1307+
output_guardrails=[OutputGuardrail(guardrail_function=guardrail_function)],
1308+
model=model,
1309+
)
1310+
1311+
result = Runner.run_streamed(agent, input="user_message")
1312+
1313+
assert result.run_loop_task is not None
1314+
with pytest.raises(OutputGuardrailTripwireTriggered):
1315+
await result.run_loop_task
1316+
1317+
assert result.final_output is None
1318+
assert result.is_complete is True
1319+
1320+
12931321
@pytest.mark.asyncio
12941322
async def test_run_input_guardrail_tripwire_triggered_causes_exception_streamed():
12951323
def guardrail_function(

0 commit comments

Comments
 (0)