Skip to content

Commit fe072f5

Browse files
committed
refactor(beta): simplify ToolPolicyMiddleware per review
Address review feedback from @Lancetnik on ag2ai#2533: - Flatten constructor API: pass blocked_tools/allowed_tools/on_blocked directly instead of requiring callers to build a ToolPolicyConfig - Drop stateful counters (_total_tool_calls, _total_blocked, _lock) and their accessor properties -- the middleware is now stateless - Add optional on_blocked callback so callers can wire their own observability (metrics, audit log) without the library retaining any per-factory state The ToolPolicyConfig dataclass is kept as an internal struct for immutability (tuple freezing in __post_init__).
1 parent d53df4a commit fe072f5

2 files changed

Lines changed: 106 additions & 45 deletions

File tree

autogen/beta/middleware/builtin/tool_policy.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44

55
"""Tool policy middleware -- blocks disallowed tool calls before execution."""
66

7-
import threading
7+
from collections.abc import Callable, Sequence
88
from dataclasses import dataclass, field
99

1010
from autogen.beta.annotations import Context
1111
from autogen.beta.events import BaseEvent, ToolCallEvent, ToolErrorEvent
1212
from autogen.beta.middleware.base import BaseMiddleware, MiddlewareFactory, ToolExecution, ToolResultType
1313
from autogen.beta.tools import ToolResult
1414

15+
OnBlockedCallback = Callable[[ToolCallEvent, str], None]
16+
1517

1618
@dataclass
1719
class ToolPolicyConfig:
@@ -71,40 +73,39 @@ def check(self, tool_name: str) -> tuple[bool, str]:
7173
class ToolPolicyMiddleware(MiddlewareFactory):
7274
"""Factory that creates per-invocation tool policy middleware instances.
7375
74-
A single ToolPolicyMiddleware shares counters across all instances it
75-
creates, so statistics accumulate over the lifetime of the factory.
76+
The middleware is stateless: no counters or shared mutable state are
77+
retained across invocations. To observe denied calls, pass an
78+
``on_blocked`` callback when constructing the middleware.
7679
7780
Example::
7881
79-
config = ToolPolicyConfig(
82+
def audit(call: ToolCallEvent, reason: str) -> None:
83+
logger.info("blocked %s: %s", call.name, reason)
84+
85+
86+
mw = ToolPolicyMiddleware(
8087
blocked_tools=["delete_all"],
8188
allowed_tools=["search", "calc"],
89+
on_blocked=audit,
8290
)
83-
mw = ToolPolicyMiddleware(config)
8491
agent = MyAgent(middleware=[mw])
8592
"""
8693

87-
def __init__(self, config: ToolPolicyConfig | None = None) -> None:
88-
self._config = config or ToolPolicyConfig()
94+
def __init__(
95+
self,
96+
blocked_tools: Sequence[str] | None = None,
97+
allowed_tools: Sequence[str] | None = None,
98+
on_blocked: OnBlockedCallback | None = None,
99+
) -> None:
100+
self._config = ToolPolicyConfig(
101+
blocked_tools=list(blocked_tools) if blocked_tools else [],
102+
allowed_tools=list(allowed_tools) if allowed_tools is not None else None,
103+
)
89104
self._policy = _ToolPolicy(self._config)
90-
self._total_tool_calls: int = 0
91-
self._total_blocked: int = 0
92-
self._lock = threading.Lock()
93-
94-
@property
95-
def total_tool_calls(self) -> int:
96-
"""Number of tool calls that passed the policy check and were forwarded."""
97-
with self._lock:
98-
return self._total_tool_calls
99-
100-
@property
101-
def total_blocked(self) -> int:
102-
"""Number of tool calls that were blocked."""
103-
with self._lock:
104-
return self._total_blocked
105+
self._on_blocked = on_blocked
105106

106107
def __call__(self, event: "BaseEvent", context: "Context") -> "BaseMiddleware":
107-
return _ToolPolicyInstance(event, context, self._policy, self)
108+
return _ToolPolicyInstance(event, context, self._policy, self._on_blocked)
108109

109110

110111
class _ToolPolicyInstance(BaseMiddleware):
@@ -115,11 +116,11 @@ def __init__(
115116
event: "BaseEvent",
116117
context: "Context",
117118
policy: _ToolPolicy,
118-
factory: ToolPolicyMiddleware,
119+
on_blocked: OnBlockedCallback | None,
119120
) -> None:
120121
super().__init__(event, context)
121122
self._policy = policy
122-
self._factory = factory
123+
self._on_blocked = on_blocked
123124

124125
async def on_tool_execution(
125126
self,
@@ -129,10 +130,8 @@ async def on_tool_execution(
129130
) -> ToolResultType:
130131
allowed, reason = self._policy.check(event.name)
131132
if not allowed:
132-
with self._factory._lock:
133-
self._factory._total_blocked += 1
133+
if self._on_blocked is not None:
134+
self._on_blocked(event, reason)
134135
return _make_tool_error(event, reason)
135136

136-
with self._factory._lock:
137-
self._factory._total_tool_calls += 1
138137
return await call_next(event, context)

test/beta/middleware/builtins/test_tool_policy_middleware.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -162,36 +162,98 @@ def test_tool_error_content_matches_reason(self) -> None:
162162

163163

164164
# ---------------------------------------------------------------------------
165-
# TestToolPolicyStats
165+
# TestToolPolicyMiddleware
166166
# ---------------------------------------------------------------------------
167167

168168

169-
class TestToolPolicyStats:
169+
class TestToolPolicyMiddleware:
170+
def test_flat_constructor(self) -> None:
171+
# Given arguments passed directly (no explicit config object)
172+
mw = ToolPolicyMiddleware(
173+
blocked_tools=["delete_all"],
174+
allowed_tools=["search"],
175+
)
176+
177+
# Then the internal config is built from those arguments
178+
assert mw._config.blocked_tools == ("delete_all",)
179+
assert mw._config.allowed_tools == ("search",)
180+
181+
def test_default_constructor(self) -> None:
182+
# Given no arguments
183+
mw = ToolPolicyMiddleware()
184+
185+
# Then defaults are empty blocklist and None allowlist (no restriction)
186+
assert mw._config.blocked_tools == ()
187+
assert mw._config.allowed_tools is None
188+
170189
@pytest.mark.asyncio()
171-
async def test_call_and_blocked_counters(self) -> None:
172-
# Given a middleware factory with a blocklist
173-
config = ToolPolicyConfig(blocked_tools=["bad_tool"])
174-
factory = ToolPolicyMiddleware(config)
190+
async def test_on_blocked_callback_fires_on_denial(self) -> None:
191+
# Given a middleware with a blocklist and an on_blocked callback
192+
recorded: list[tuple[str, str]] = []
193+
194+
def audit(call: ToolCallEvent, reason: str) -> None:
195+
recorded.append((call.name, reason))
196+
197+
factory = ToolPolicyMiddleware(blocked_tools=["bad_tool"], on_blocked=audit)
175198

176199
ctx = _make_context()
177200
initial_event = _make_event()
178201

179-
# Prepare a real ToolResultEvent to return from call_next
202+
async def call_next(event: ToolCallEvent, context: mock.MagicMock) -> ToolResultEvent:
203+
return ToolResultEvent(parent_id=event.id, name=event.name, result=ToolResult("ok"))
204+
205+
# When a blocked call is denied
206+
blocked_call = ToolCallEvent(id="c1", name="bad_tool")
207+
instance = factory(initial_event, ctx)
208+
result = await instance.on_tool_execution(call_next, blocked_call, ctx)
209+
210+
# Then the callback was invoked with the call and reason, and a ToolErrorEvent was returned
211+
assert len(recorded) == 1
212+
assert recorded[0][0] == "bad_tool"
213+
assert "blocked" in recorded[0][1]
214+
assert isinstance(result, ToolErrorEvent)
215+
216+
@pytest.mark.asyncio()
217+
async def test_on_blocked_not_called_when_allowed(self) -> None:
218+
# Given a middleware with a callback and an allowed call
219+
recorded: list[tuple[str, str]] = []
220+
221+
def audit(call: ToolCallEvent, reason: str) -> None:
222+
recorded.append((call.name, reason))
223+
224+
factory = ToolPolicyMiddleware(allowed_tools=["good_tool"], on_blocked=audit)
225+
226+
ctx = _make_context()
227+
initial_event = _make_event()
180228
good_call = ToolCallEvent(id="c1", name="good_tool")
181229
good_result = ToolResultEvent(parent_id="c1", name="good_tool", result=ToolResult("ok"))
182230

183231
async def call_next(event: ToolCallEvent, context: mock.MagicMock) -> ToolResultEvent:
184232
return good_result
185233

186-
# When processing one allowed call and one blocked call
187-
instance_allow = factory(initial_event, ctx)
188-
await instance_allow.on_tool_execution(call_next, good_call, ctx)
234+
# When the call passes the policy
235+
instance = factory(initial_event, ctx)
236+
result = await instance.on_tool_execution(call_next, good_call, ctx)
237+
238+
# Then the callback was not invoked and the downstream result was returned
239+
assert recorded == []
240+
assert result is good_result
241+
242+
@pytest.mark.asyncio()
243+
async def test_no_callback_means_no_error(self) -> None:
244+
# Given a middleware without an on_blocked callback
245+
factory = ToolPolicyMiddleware(blocked_tools=["bad_tool"])
246+
247+
ctx = _make_context()
248+
initial_event = _make_event()
249+
250+
async def call_next(event: ToolCallEvent, context: mock.MagicMock) -> ToolResultEvent:
251+
return ToolResultEvent(parent_id=event.id, name=event.name, result=ToolResult("ok"))
189252

190-
blocked_call = ToolCallEvent(id="c2", name="bad_tool")
191-
instance_block = factory(initial_event, ctx)
192-
result = await instance_block.on_tool_execution(call_next, blocked_call, ctx)
253+
# When a blocked call is denied without a callback
254+
blocked_call = ToolCallEvent(id="c1", name="bad_tool")
255+
instance = factory(initial_event, ctx)
256+
result = await instance.on_tool_execution(call_next, blocked_call, ctx)
193257

194-
# Then counters reflect the activity
195-
assert factory.total_tool_calls == 1
196-
assert factory.total_blocked == 1
258+
# Then denial still works, just without observation
197259
assert isinstance(result, ToolErrorEvent)

0 commit comments

Comments
 (0)