@@ -162,36 +162,98 @@ def test_tool_error_content_matches_reason(self) -> None:
162162
163163
164164# ---------------------------------------------------------------------------
165- # TestToolPolicyStats
165+ # TestToolPolicyMiddleware
166166# ---------------------------------------------------------------------------
167167
168168
169- class TestToolPolicyStats :
169+ class TestToolPolicyMiddleware :
170+ def test_flat_constructor (self ) -> None :
171+ # Given arguments passed directly (no explicit config object)
172+ mw = ToolPolicyMiddleware (
173+ blocked_tools = ["delete_all" ],
174+ allowed_tools = ["search" ],
175+ )
176+
177+ # Then the internal config is built from those arguments
178+ assert mw ._config .blocked_tools == ("delete_all" ,)
179+ assert mw ._config .allowed_tools == ("search" ,)
180+
181+ def test_default_constructor (self ) -> None :
182+ # Given no arguments
183+ mw = ToolPolicyMiddleware ()
184+
185+ # Then defaults are empty blocklist and None allowlist (no restriction)
186+ assert mw ._config .blocked_tools == ()
187+ assert mw ._config .allowed_tools is None
188+
170189 @pytest .mark .asyncio ()
171- async def test_call_and_blocked_counters (self ) -> None :
172- # Given a middleware factory with a blocklist
173- config = ToolPolicyConfig (blocked_tools = ["bad_tool" ])
174- factory = ToolPolicyMiddleware (config )
190+ async def test_on_blocked_callback_fires_on_denial (self ) -> None :
191+ # Given a middleware with a blocklist and an on_blocked callback
192+ recorded : list [tuple [str , str ]] = []
193+
194+ def audit (call : ToolCallEvent , reason : str ) -> None :
195+ recorded .append ((call .name , reason ))
196+
197+ factory = ToolPolicyMiddleware (blocked_tools = ["bad_tool" ], on_blocked = audit )
175198
176199 ctx = _make_context ()
177200 initial_event = _make_event ()
178201
179- # Prepare a real ToolResultEvent to return from call_next
202+ async def call_next (event : ToolCallEvent , context : mock .MagicMock ) -> ToolResultEvent :
203+ return ToolResultEvent (parent_id = event .id , name = event .name , result = ToolResult ("ok" ))
204+
205+ # When a blocked call is denied
206+ blocked_call = ToolCallEvent (id = "c1" , name = "bad_tool" )
207+ instance = factory (initial_event , ctx )
208+ result = await instance .on_tool_execution (call_next , blocked_call , ctx )
209+
210+ # Then the callback was invoked with the call and reason, and a ToolErrorEvent was returned
211+ assert len (recorded ) == 1
212+ assert recorded [0 ][0 ] == "bad_tool"
213+ assert "blocked" in recorded [0 ][1 ]
214+ assert isinstance (result , ToolErrorEvent )
215+
216+ @pytest .mark .asyncio ()
217+ async def test_on_blocked_not_called_when_allowed (self ) -> None :
218+ # Given a middleware with a callback and an allowed call
219+ recorded : list [tuple [str , str ]] = []
220+
221+ def audit (call : ToolCallEvent , reason : str ) -> None :
222+ recorded .append ((call .name , reason ))
223+
224+ factory = ToolPolicyMiddleware (allowed_tools = ["good_tool" ], on_blocked = audit )
225+
226+ ctx = _make_context ()
227+ initial_event = _make_event ()
180228 good_call = ToolCallEvent (id = "c1" , name = "good_tool" )
181229 good_result = ToolResultEvent (parent_id = "c1" , name = "good_tool" , result = ToolResult ("ok" ))
182230
183231 async def call_next (event : ToolCallEvent , context : mock .MagicMock ) -> ToolResultEvent :
184232 return good_result
185233
186- # When processing one allowed call and one blocked call
187- instance_allow = factory (initial_event , ctx )
188- await instance_allow .on_tool_execution (call_next , good_call , ctx )
234+ # When the call passes the policy
235+ instance = factory (initial_event , ctx )
236+ result = await instance .on_tool_execution (call_next , good_call , ctx )
237+
238+ # Then the callback was not invoked and the downstream result was returned
239+ assert recorded == []
240+ assert result is good_result
241+
242+ @pytest .mark .asyncio ()
243+ async def test_no_callback_means_no_error (self ) -> None :
244+ # Given a middleware without an on_blocked callback
245+ factory = ToolPolicyMiddleware (blocked_tools = ["bad_tool" ])
246+
247+ ctx = _make_context ()
248+ initial_event = _make_event ()
249+
250+ async def call_next (event : ToolCallEvent , context : mock .MagicMock ) -> ToolResultEvent :
251+ return ToolResultEvent (parent_id = event .id , name = event .name , result = ToolResult ("ok" ))
189252
190- blocked_call = ToolCallEvent (id = "c2" , name = "bad_tool" )
191- instance_block = factory (initial_event , ctx )
192- result = await instance_block .on_tool_execution (call_next , blocked_call , ctx )
253+ # When a blocked call is denied without a callback
254+ blocked_call = ToolCallEvent (id = "c1" , name = "bad_tool" )
255+ instance = factory (initial_event , ctx )
256+ result = await instance .on_tool_execution (call_next , blocked_call , ctx )
193257
194- # Then counters reflect the activity
195- assert factory .total_tool_calls == 1
196- assert factory .total_blocked == 1
258+ # Then denial still works, just without observation
197259 assert isinstance (result , ToolErrorEvent )
0 commit comments