Skip to content

Commit c2d0c73

Browse files
committed
fix: stop streamed tool execution after known input guardrail tripwire
1 parent a5cd815 commit c2d0c73

File tree

5 files changed: 84 additions (+84) and 1 deletion (-1)

src/agents/result.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -454,6 +454,7 @@ class RunResultStreaming(RunResultBase):
454454
# Store the asyncio tasks that we're waiting on
455455
run_loop_task: asyncio.Task[Any] | None = field(default=None, repr=False)
456456
_input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
457+
_triggered_input_guardrail_result: InputGuardrailResult | None = field(default=None, repr=False)
457458
_output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
458459
_stored_exception: Exception | None = field(default=None, repr=False)
459460
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)

src/agents/run_internal/guardrails.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,7 @@ async def run_input_guardrails_with_queue(
8484
},
8585
),
8686
)
87+
streamed_result._triggered_input_guardrail_result = result
8788
queue.put_nowait(result)
8889
guardrail_results.append(result)
8990
break

src/agents/run_internal/run_loop.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1014,7 +1014,9 @@ async def _save_stream_items_without_count(
10141014
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
10151015
break
10161016
except Exception as e:
1017-
if current_span and not isinstance(e, ModelBehaviorError):
1017+
if current_span and not isinstance(
1018+
e, (ModelBehaviorError, InputGuardrailTripwireTriggered)
1019+
):
10181020
_error_tracing.attach_error_to_span(
10191021
current_span,
10201022
SpanError(
@@ -1100,6 +1102,24 @@ async def run_single_turn_streamed(
11001102
reasoning_item_id_policy: ReasoningItemIdPolicy | None = None,
11011103
) -> SingleStepResult:
11021104
"""Run a single streamed turn and emit events as results arrive."""
1105+
1106+
async def raise_if_input_guardrail_tripwire_known() -> None:
1107+
tripwire_result = streamed_result._triggered_input_guardrail_result
1108+
if tripwire_result is not None:
1109+
raise InputGuardrailTripwireTriggered(tripwire_result)
1110+
1111+
task = streamed_result._input_guardrails_task
1112+
if task is None or not task.done():
1113+
return
1114+
1115+
guardrail_exception = task.exception()
1116+
if guardrail_exception is not None:
1117+
raise guardrail_exception
1118+
1119+
tripwire_result = streamed_result._triggered_input_guardrail_result
1120+
if tripwire_result is not None:
1121+
raise InputGuardrailTripwireTriggered(tripwire_result)
1122+
11031123
emitted_tool_call_ids: set[str] = set()
11041124
emitted_reasoning_item_ids: set[str] = set()
11051125
emitted_tool_search_fingerprints: set[str] = set()
@@ -1433,6 +1453,7 @@ async def rewind_model_request() -> None:
14331453
run_config=run_config,
14341454
tool_use_tracker=tool_use_tracker,
14351455
event_queue=streamed_result._event_queue,
1456+
before_side_effects=raise_if_input_guardrail_tripwire_known,
14361457
)
14371458

14381459
items_to_filter = session_items_for_turn(single_step_result)

src/agents/run_internal/turn_resolution.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1696,6 +1696,7 @@ async def get_single_step_result_from_response(
16961696
run_config: RunConfig,
16971697
tool_use_tracker,
16981698
event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None,
1699+
before_side_effects: Callable[[], Awaitable[None]] | None = None,
16991700
) -> SingleStepResult:
17001701
processed_response = process_model_response(
17011702
agent=agent,
@@ -1706,6 +1707,9 @@ async def get_single_step_result_from_response(
17061707
existing_items=pre_step_items,
17071708
)
17081709

1710+
if before_side_effects is not None:
1711+
await before_side_effects()
1712+
17091713
tool_use_tracker.record_processed_response(agent, processed_response)
17101714

17111715
if event_queue is not None and processed_response.new_items:

tests/test_guardrails.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -658,6 +658,62 @@ async def slow_parallel_check(
658658
assert model.first_turn_args is not None, "Model should have been called in parallel mode"
659659

660660

661+
@pytest.mark.asyncio
662+
async def test_parallel_guardrail_trip_before_tool_execution_stops_streaming_turn():
663+
tool_was_executed = False
664+
model_started = asyncio.Event()
665+
guardrail_tripped = asyncio.Event()
666+
667+
@function_tool
668+
def dangerous_tool() -> str:
669+
nonlocal tool_was_executed
670+
tool_was_executed = True
671+
return "tool_executed"
672+
673+
@input_guardrail(run_in_parallel=True)
674+
async def tripwire_before_tool_execution(
675+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
676+
) -> GuardrailFunctionOutput:
677+
await asyncio.wait_for(model_started.wait(), timeout=1)
678+
guardrail_tripped.set()
679+
return GuardrailFunctionOutput(
680+
output_info="parallel_trip_before_tool_execution",
681+
tripwire_triggered=True,
682+
)
683+
684+
model = FakeModel()
685+
original_stream_response = model.stream_response
686+
687+
async def delayed_stream_response(*args, **kwargs):
688+
model_started.set()
689+
await asyncio.wait_for(guardrail_tripped.wait(), timeout=1)
690+
await asyncio.sleep(SHORT_DELAY)
691+
async for event in original_stream_response(*args, **kwargs):
692+
yield event
693+
694+
agent = Agent(
695+
name="streaming_guardrail_hardening_agent",
696+
instructions="Call the dangerous_tool immediately",
697+
tools=[dangerous_tool],
698+
input_guardrails=[tripwire_before_tool_execution],
699+
model=model,
700+
)
701+
model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")])
702+
model.set_next_output([get_text_message("done")])
703+
704+
with patch.object(model, "stream_response", side_effect=delayed_stream_response):
705+
result = Runner.run_streamed(agent, "trigger guardrail")
706+
707+
with pytest.raises(InputGuardrailTripwireTriggered):
708+
async for _event in result.stream_events():
709+
pass
710+
711+
assert model_started.is_set() is True
712+
assert guardrail_tripped.is_set() is True
713+
assert tool_was_executed is False
714+
assert model.first_turn_args is not None, "Model should have been called in parallel mode"
715+
716+
661717
@pytest.mark.asyncio
662718
async def test_blocking_guardrail_prevents_tool_execution():
663719
tool_was_executed = False

0 commit comments

Comments (0)