@@ -658,6 +658,62 @@ async def slow_parallel_check(
658658 assert model .first_turn_args is not None , "Model should have been called in parallel mode"
659659
660660
661+ @pytest .mark .asyncio
662+ async def test_parallel_guardrail_trip_before_tool_execution_stops_streaming_turn ():
663+ tool_was_executed = False
664+ model_started = asyncio .Event ()
665+ guardrail_tripped = asyncio .Event ()
666+
667+ @function_tool
668+ def dangerous_tool () -> str :
669+ nonlocal tool_was_executed
670+ tool_was_executed = True
671+ return "tool_executed"
672+
673+ @input_guardrail (run_in_parallel = True )
674+ async def tripwire_before_tool_execution (
675+ ctx : RunContextWrapper [Any ], agent : Agent [Any ], input : str | list [TResponseInputItem ]
676+ ) -> GuardrailFunctionOutput :
677+ await asyncio .wait_for (model_started .wait (), timeout = 1 )
678+ guardrail_tripped .set ()
679+ return GuardrailFunctionOutput (
680+ output_info = "parallel_trip_before_tool_execution" ,
681+ tripwire_triggered = True ,
682+ )
683+
684+ model = FakeModel ()
685+ original_stream_response = model .stream_response
686+
687+ async def delayed_stream_response (* args , ** kwargs ):
688+ model_started .set ()
689+ await asyncio .wait_for (guardrail_tripped .wait (), timeout = 1 )
690+ await asyncio .sleep (SHORT_DELAY )
691+ async for event in original_stream_response (* args , ** kwargs ):
692+ yield event
693+
694+ agent = Agent (
695+ name = "streaming_guardrail_hardening_agent" ,
696+ instructions = "Call the dangerous_tool immediately" ,
697+ tools = [dangerous_tool ],
698+ input_guardrails = [tripwire_before_tool_execution ],
699+ model = model ,
700+ )
701+ model .set_next_output ([get_function_tool_call ("dangerous_tool" , arguments = "{}" )])
702+ model .set_next_output ([get_text_message ("done" )])
703+
704+ with patch .object (model , "stream_response" , side_effect = delayed_stream_response ):
705+ result = Runner .run_streamed (agent , "trigger guardrail" )
706+
707+ with pytest .raises (InputGuardrailTripwireTriggered ):
708+ async for _event in result .stream_events ():
709+ pass
710+
711+ assert model_started .is_set () is True
712+ assert guardrail_tripped .is_set () is True
713+ assert tool_was_executed is False
714+ assert model .first_turn_args is not None , "Model should have been called in parallel mode"
715+
716+
661717@pytest .mark .asyncio
662718async def test_blocking_guardrail_prevents_tool_execution ():
663719 tool_was_executed = False
0 commit comments