-
Notifications
You must be signed in to change notification settings - Fork 860
Expand file tree
/
Copy pathregistry.py
More file actions
335 lines (251 loc) · 12.1 KB
/
registry.py
File metadata and controls
335 lines (251 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
"""Hook registry system for managing event callbacks in the Strands Agent SDK.
This module provides the core infrastructure for the typed hook system, enabling
composable extension of agent functionality through strongly-typed event callbacks.
The registry manages the mapping between event types and their associated callback
functions, supporting both individual callback registration and bulk registration
via hook provider objects.
"""
import inspect
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
from ..interrupt import Interrupt, InterruptException
if TYPE_CHECKING:
    # Needed only for the "Agent" forward reference used in annotations below;
    # TYPE_CHECKING is False at runtime, so this import never executes.
    from ..agent import Agent

# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
@dataclass
class BaseHookEvent:
    """Base class for all hook events.

    Instances become write-protected once dataclass initialization finishes:
    after ``__post_init__`` has run, assigning to any attribute raises
    ``AttributeError`` unless ``_can_write`` returns True for that attribute
    name. Subclasses opt individual attributes back into mutability by
    overriding ``_can_write``. A subclass that defines its own
    ``__post_init__`` must call ``super().__post_init__()``, otherwise the
    write lock is never engaged.
    """

    @property
    def should_reverse_callbacks(self) -> bool:
        """Determine if callbacks for this event should be invoked in reverse order.

        Returns:
            False by default. Override to return True for events that should
            invoke callbacks in reverse order (e.g., cleanup/teardown events).
        """
        return False

    def _can_write(self, name: str) -> bool:
        """Check if the given property can be written to after initialization.

        Args:
            name: The name of the property to check.

        Returns:
            True if the property can be written to, False otherwise. The base
            implementation allows no post-initialization writes.
        """
        return False

    def __post_init__(self) -> None:
        """Disallow writes to non-approved properties."""
        # The generated dataclass __init__ assigns every field through
        # __setattr__, so the lock can only be engaged after those assignments
        # have happened. The marker itself is set via the parent __setattr__
        # to bypass the check below.
        super().__setattr__("_disallow_writes", True)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, enforcing the post-initialization write lock.

        Args:
            name: The attribute being assigned.
            value: The value to assign.

        Raises:
            AttributeError: If initialization has completed (``__post_init__``
                has run) and ``_can_write(name)`` returns False.
        """
        # Allow setting attributes:
        # - during init, i.e. before __post_init__ has set the `_disallow_writes` marker
        # - if the subclass specifically said the property is writable
        if not hasattr(self, "_disallow_writes") or self._can_write(name):
            return super().__setattr__(name, value)

        raise AttributeError(f"Property {name} is not writable")
@dataclass
class HookEvent(BaseHookEvent):
    """Hook event emitted on behalf of one specific agent.

    Attributes:
        agent: The agent whose activity produced this event.
    """

    # Kept as a string annotation because Agent is imported only under TYPE_CHECKING.
    agent: "Agent"
TEvent = TypeVar("TEvent", contravariant=True, bound=BaseHookEvent)
"""Type variable used when registering callbacks; contravariant so handlers accepting a base event class can be added."""

TInvokeEvent = TypeVar("TInvokeEvent", bound=BaseHookEvent)
"""Type variable used when invoking events; not contravariant, so the dispatched event can be returned to the caller."""
class HookProvider(Protocol):
    """Structural interface for objects that contribute hook callbacks to an agent.

    A hook provider bundles related callbacks into a single reusable component.
    When handed a registry, it subscribes each of its handlers to the event
    types it is interested in, extending the agent's lifecycle behavior.

    Example:
        ```python
        class MyHookProvider(HookProvider):
            def register_hooks(self, registry: HookRegistry) -> None:
                registry.add_callback(StartRequestEvent, self.on_request_start)
                registry.add_callback(EndRequestEvent, self.on_request_end)

        agent = Agent(hooks=[MyHookProvider()])
        ```
    """

    def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
        """Subscribe this provider's callbacks on ``registry``.

        Args:
            registry: The registry that receives the callback registrations.
            **kwargs: Extra keyword arguments, reserved for future extensibility.
        """
        ...
class HookCallback(Protocol, Generic[TEvent]):
    """Protocol for callback functions that handle hook events.

    Hook callbacks receive a single strongly-typed event argument and perform
    some action in response. Synchronous callbacks return ``None``;
    ``async def`` callbacks return an awaitable and must be dispatched through
    the registry's asynchronous invocation path. Exceptions raised by a
    callback propagate to the caller, except ``InterruptException``, which the
    registry collects as an interrupt instead.

    Example:
        ```python
        def my_callback(event: StartRequestEvent) -> None:
            print(f"Request started for agent: {event.agent.name}")

        # Or, as a coroutine function:
        async def my_callback(event: StartRequestEvent) -> None:
            await some_async_operation(event)
        ```
    """

    def __call__(self, event: TEvent) -> None | Awaitable[None]:
        """Handle a hook event.

        Args:
            event: The strongly-typed event to handle.

        Returns:
            None for synchronous callbacks, or an awaitable for asynchronous ones.
        """
        ...
class HookRegistry:
    """Registry for managing hook callbacks associated with event types.

    The HookRegistry maintains a mapping of event types to callback functions
    and provides methods for registering callbacks and invoking them when
    events occur.

    The registry handles callback ordering, including reverse ordering for
    cleanup events, and provides type-safe event dispatching. Callbacks are
    looked up by the event's exact type; callbacks registered for a base class
    are not invoked for subclass events.
    """

    def __init__(self) -> None:
        """Initialize an empty hook registry."""
        self._registered_callbacks: dict[Type, list[HookCallback]] = {}

    @staticmethod
    def _is_async_callback(callback: Any) -> bool:
        """Return True if ``callback`` is a coroutine function or an object with an ``async def __call__``."""
        if inspect.iscoroutinefunction(callback):
            return True
        # iscoroutinefunction() alone misses callable instances whose __call__ is
        # a coroutine function. Look __call__ up on the type, the same way a
        # call expression resolves it, so classes themselves are not misreported.
        return inspect.iscoroutinefunction(getattr(type(callback), "__call__", None))

    @staticmethod
    def _record_interrupt(interrupts: dict[str, Interrupt], exception: InterruptException) -> None:
        """Add the interrupt carried by ``exception`` to ``interrupts``, rejecting duplicate names.

        Args:
            interrupts: Interrupts collected so far during one invocation, keyed by name.
            exception: The interrupt exception raised by a callback.

        Raises:
            ValueError: If an interrupt with the same name was already recorded.
        """
        interrupt = exception.interrupt
        if interrupt.name in interrupts:
            message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once"
            logger.error(message)
            raise ValueError(message) from exception

        # Each callback is allowed to raise their own interrupt.
        interrupts[interrupt.name] = interrupt

    def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None:
        """Register a callback function for a specific event type.

        Args:
            event_type: The class type of events this callback should handle.
            callback: The callback function to invoke when events of this type occur.

        Raises:
            ValueError: If an asynchronous callback (a coroutine function or an
                object with an async ``__call__``) is registered for
                ``AgentInitializedEvent``.

        Example:
            ```python
            def my_handler(event: StartRequestEvent):
                print("Request started")

            registry.add_callback(StartRequestEvent, my_handler)
            ```
        """
        # Related issue: https://github.com/strands-agents/sdk-python/issues/330
        # The event is matched by class name rather than by type, presumably to
        # avoid importing the event module here (NOTE(review): confirm).
        if event_type.__name__ == "AgentInitializedEvent" and self._is_async_callback(callback):
            raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")

        callbacks = self._registered_callbacks.setdefault(event_type, [])
        callbacks.append(callback)

    def add_hook(self, hook: HookProvider) -> None:
        """Register all callbacks from a hook provider.

        This method allows bulk registration of callbacks by delegating to
        the hook provider's register_hooks method. This is the preferred
        way to register multiple related callbacks.

        Args:
            hook: The hook provider containing callbacks to register.

        Example:
            ```python
            class MyHooks(HookProvider):
                def register_hooks(self, registry: HookRegistry):
                    registry.add_callback(StartRequestEvent, self.on_start)
                    registry.add_callback(EndRequestEvent, self.on_end)

            registry.add_hook(MyHooks())
            ```
        """
        hook.register_hooks(self)

    async def invoke_callbacks_async(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]:
        """Invoke all registered callbacks for the given event.

        This method finds all callbacks registered for the event's type and
        invokes them in the appropriate order. For events with should_reverse_callbacks=True,
        callbacks are invoked in reverse registration order. Any awaitable returned by a
        callback is awaited before the next callback runs. This covers ``async def`` functions,
        objects with an async ``__call__``, and sync wrappers that return a coroutine.
        Any exceptions raised by callback functions will propagate to the caller.

        Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows.

        Args:
            event: The event to dispatch to registered callbacks.

        Returns:
            The event dispatched to registered callbacks and any interrupts raised by the user.

        Raises:
            ValueError: If interrupt name is used more than once.

        Example:
            ```python
            event = StartRequestEvent(agent=my_agent)
            await registry.invoke_callbacks_async(event)
            ```
        """
        interrupts: dict[str, Interrupt] = {}

        for callback in self.get_callbacks_for(event):
            try:
                # Await based on what the callback returned rather than on how it
                # was declared. A coroutine-function check alone would call async
                # callable objects and coroutine-returning wrappers without
                # awaiting them, so they would silently never run.
                result = callback(event)
                if inspect.isawaitable(result):
                    await result
            except InterruptException as exception:
                self._record_interrupt(interrupts, exception)

        return event, list(interrupts.values())

    def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]:
        """Invoke all registered callbacks for the given event.

        This method finds all callbacks registered for the event's type and
        invokes them in the appropriate order. For events with should_reverse_callbacks=True,
        callbacks are invoked in reverse registration order. Any exceptions raised by callback
        functions will propagate to the caller.

        Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows.

        Args:
            event: The event to dispatch to registered callbacks.

        Returns:
            The event dispatched to registered callbacks and any interrupts raised by the user.

        Raises:
            RuntimeError: If at least one callback is async (a coroutine function or an
                object with an async ``__call__``); checked before any callback runs.
            ValueError: If interrupt name is used more than once.

        Example:
            ```python
            event = StartRequestEvent(agent=my_agent)
            registry.invoke_callbacks(event)
            ```
        """
        callbacks = list(self.get_callbacks_for(event))
        interrupts: dict[str, Interrupt] = {}

        # Fail fast before running anything: async callbacks cannot be awaited here.
        if any(self._is_async_callback(callback) for callback in callbacks):
            raise RuntimeError(f"event=<{event}> | use invoke_callbacks_async to invoke async callback")

        for callback in callbacks:
            try:
                callback(event)
            except InterruptException as exception:
                self._record_interrupt(interrupts, exception)

        return event, list(interrupts.values())

    def has_callbacks(self) -> bool:
        """Check if the registry has any registered callbacks.

        Returns:
            True if there are any registered callbacks, False otherwise.

        Example:
            ```python
            if registry.has_callbacks():
                print("Registry has callbacks registered")
            ```
        """
        return bool(self._registered_callbacks)

    def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
        """Get callbacks registered for the given event in the appropriate order.

        This method returns callbacks in registration order for normal events,
        or reverse registration order for events that have should_reverse_callbacks=True.
        This enables proper cleanup ordering for teardown events. Lookup uses the
        event's exact type.

        Args:
            event: The event to get callbacks for.

        Yields:
            Callback functions registered for this event type, in the appropriate order.

        Example:
            ```python
            event = EndRequestEvent(agent=my_agent)
            for callback in registry.get_callbacks_for(event):
                callback(event)
            ```
        """
        event_type = type(event)

        callbacks = self._registered_callbacks.get(event_type, [])
        if event.should_reverse_callbacks:
            yield from reversed(callbacks)
        else:
            yield from callbacks