Skip to content

Commit faad564

Browse files
feat(plugins): improve plugin creation devex with @hook and @tool decorators (#1740)
Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com>
1 parent 1a3b429 commit faad564

13 files changed

Lines changed: 1249 additions & 194 deletions

File tree

AGENTS.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,12 @@ strands-agents/
124124
│ │
125125
│ ├── hooks/ # Event hooks system
126126
│ │ ├── events.py # Hook event definitions
127-
│ │ └── registry.py # Hook registration
127+
│ │ ├── registry.py # Hook registration
128+
│ │ └── _type_inference.py # Event type inference from type hints
128129
│ │
129130
│ ├── plugins/ # Plugin system
130-
│ │ ├── plugin.py # Plugin definition
131+
│ │ ├── plugin.py # Plugin base class
132+
│ │ ├── decorator.py # @hook decorator
131133
│ │ └── registry.py # PluginRegistry for tracking plugins
132134
│ │
133135
│ ├── handlers/ # Event handlers

src/strands/experimental/steering/core/handler.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from typing import TYPE_CHECKING, Any
3939

4040
from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent
41-
from ....plugins.plugin import Plugin
41+
from ....plugins import Plugin, hook
4242
from ....types.content import Message
4343
from ....types.streaming import StopReason
4444
from ....types.tools import ToolUse
@@ -66,6 +66,7 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non
6666
Args:
6767
context_providers: List of context providers for context updates
6868
"""
69+
super().__init__()
6970
self.steering_context = SteeringContext()
7071
self._context_callbacks = []
7172

@@ -87,13 +88,8 @@ def init_agent(self, agent: "Agent") -> None:
8788
for callback in self._context_callbacks:
8889
agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type)
8990

90-
# Register tool steering guidance
91-
agent.add_hook(self._provide_tool_steering_guidance, BeforeToolCallEvent)
92-
93-
# Register model steering guidance
94-
agent.add_hook(self._provide_model_steering_guidance, AfterModelCallEvent)
95-
96-
async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None:
91+
@hook
92+
async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None:
9793
"""Provide steering guidance for tool call."""
9894
tool_name = event.tool_use["name"]
9995
logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name)
@@ -133,7 +129,8 @@ def _handle_tool_steering_action(
133129
else:
134130
raise ValueError(f"Unknown steering action type for tool call: {action}")
135131

136-
async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None:
132+
@hook
133+
async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None:
137134
"""Provide steering guidance for model response."""
138135
logger.debug("providing model steering guidance")
139136

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Utility for inferring event types from callback type hints."""
2+
3+
import inspect
4+
import logging
5+
import types
6+
from typing import TYPE_CHECKING, Union, cast, get_args, get_origin, get_type_hints
7+
8+
if TYPE_CHECKING:
9+
from .registry import HookCallback, TEvent
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
def infer_event_types(callback: "HookCallback[TEvent]") -> "list[type[TEvent]]":
15+
"""Infer the event type(s) from a callback's type hints.
16+
17+
Supports both single types and union types (A | B or Union[A, B]).
18+
19+
Args:
20+
callback: The callback function to inspect.
21+
22+
Returns:
23+
A list of event types inferred from the callback's first parameter type hint.
24+
25+
Raises:
26+
ValueError: If the event type cannot be inferred from the callback's type hints,
27+
or if a union contains None or non-BaseHookEvent types.
28+
"""
29+
# Import here to avoid circular dependency
30+
from .registry import BaseHookEvent
31+
32+
try:
33+
hints = get_type_hints(callback)
34+
except Exception as e:
35+
logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e)
36+
raise ValueError(
37+
"failed to get type hints for callback | cannot infer event type, please provide event_type explicitly"
38+
) from e
39+
40+
# Get the first parameter's type hint
41+
sig = inspect.signature(callback)
42+
params = list(sig.parameters.values())
43+
44+
if not params:
45+
raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly")
46+
47+
# Skip 'self' and 'cls' parameters for methods
48+
first_param = params[0]
49+
if first_param.name in ("self", "cls") and len(params) > 1:
50+
first_param = params[1]
51+
52+
type_hint = hints.get(first_param.name)
53+
54+
if type_hint is None:
55+
raise ValueError(
56+
f"parameter=<{first_param.name}> has no type hint | "
57+
"cannot infer event type, please provide event_type explicitly"
58+
)
59+
60+
# Check if it's a Union type (Union[A, B] or A | B)
61+
origin = get_origin(type_hint)
62+
if origin is Union or origin is types.UnionType:
63+
event_types: list[type[TEvent]] = []
64+
for arg in get_args(type_hint):
65+
if arg is type(None):
66+
raise ValueError("None is not a valid event type in union")
67+
if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)):
68+
raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent")
69+
event_types.append(cast("type[TEvent]", arg))
70+
return event_types
71+
72+
# Handle single type
73+
if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent):
74+
return [cast("type[TEvent]", type_hint)]
75+
76+
raise ValueError(
77+
f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent"
78+
)

src/strands/hooks/registry.py

Lines changed: 3 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,12 @@
99

1010
import inspect
1111
import logging
12-
import types
1312
from collections.abc import Awaitable, Generator
1413
from dataclasses import dataclass
15-
from typing import (
16-
TYPE_CHECKING,
17-
Any,
18-
Generic,
19-
Protocol,
20-
TypeVar,
21-
Union,
22-
cast,
23-
get_args,
24-
get_origin,
25-
get_type_hints,
26-
runtime_checkable,
27-
)
14+
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable
2815

2916
from ..interrupt import Interrupt, InterruptException
17+
from ._type_inference import infer_event_types
3018

3119
if TYPE_CHECKING:
3220
from ..agent import Agent
@@ -225,7 +213,7 @@ def multi_handler(event):
225213
resolved_event_types = self._validate_event_type_list(event_type)
226214
elif event_type is None:
227215
# Infer event type(s) from callback type hints
228-
resolved_event_types = self._infer_event_types(callback)
216+
resolved_event_types = infer_event_types(callback)
229217
else:
230218
# Single event type provided explicitly
231219
resolved_event_types = [event_type]
@@ -261,67 +249,6 @@ def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[typ
261249
validated.append(et)
262250
return validated
263251

264-
def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]:
265-
"""Infer the event type(s) from a callback's type hints.
266-
267-
Supports both single types and union types (A | B or Union[A, B]).
268-
269-
Args:
270-
callback: The callback function to inspect.
271-
272-
Returns:
273-
A list of event types inferred from the callback's first parameter type hint.
274-
275-
Raises:
276-
ValueError: If the event type cannot be inferred from the callback's type hints,
277-
or if a union contains None or non-BaseHookEvent types.
278-
"""
279-
try:
280-
hints = get_type_hints(callback)
281-
except Exception as e:
282-
logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e)
283-
raise ValueError(
284-
"failed to get type hints for callback | cannot infer event type, please provide event_type explicitly"
285-
) from e
286-
287-
# Get the first parameter's type hint
288-
sig = inspect.signature(callback)
289-
params = list(sig.parameters.values())
290-
291-
if not params:
292-
raise ValueError(
293-
"callback has no parameters | cannot infer event type, please provide event_type explicitly"
294-
)
295-
296-
first_param = params[0]
297-
type_hint = hints.get(first_param.name)
298-
299-
if type_hint is None:
300-
raise ValueError(
301-
f"parameter=<{first_param.name}> has no type hint | "
302-
"cannot infer event type, please provide event_type explicitly"
303-
)
304-
305-
# Check if it's a Union type (Union[A, B] or A | B)
306-
origin = get_origin(type_hint)
307-
if origin is Union or origin is types.UnionType:
308-
event_types: list[type[TEvent]] = []
309-
for arg in get_args(type_hint):
310-
if arg is type(None):
311-
raise ValueError("None is not a valid event type in union")
312-
if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)):
313-
raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent")
314-
event_types.append(cast(type[TEvent], arg))
315-
return event_types
316-
317-
# Handle single type
318-
if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent):
319-
return [cast(type[TEvent], type_hint)]
320-
321-
raise ValueError(
322-
f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent"
323-
)
324-
325252
def add_hook(self, hook: HookProvider) -> None:
326253
"""Register all callbacks from a hook provider.
327254

src/strands/plugins/__init__.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,13 @@
11
"""Plugin system for extending agent functionality.
22
33
This module provides a composable mechanism for building objects that can
4-
extend agent behavior through a standardized initialization pattern.
5-
6-
Example Usage:
7-
```python
8-
from strands.plugins import Plugin
9-
10-
class LoggingPlugin(Plugin):
11-
name = "logging"
12-
13-
def init_agent(self, agent: Agent) -> None:
14-
agent.add_hook(self.on_model_call, BeforeModelCallEvent)
15-
16-
def on_model_call(self, event: BeforeModelCallEvent) -> None:
17-
print(f"Model called for {event.agent.name}")
18-
```
4+
extend agent behavior through automatic hook and tool registration.
195
"""
206

7+
from .decorator import hook
218
from .plugin import Plugin
229

2310
__all__ = [
2411
"Plugin",
12+
"hook",
2513
]

src/strands/plugins/decorator.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Hook decorator for Plugin methods.
2+
3+
Marks methods as hook callbacks for automatic registration when the plugin
4+
is attached to an agent. Infers event types from type hints and supports
5+
union types for multiple events.
6+
7+
Example:
8+
```python
9+
class MyPlugin(Plugin):
10+
@hook
11+
def on_model_call(self, event: BeforeModelCallEvent):
12+
print(event)
13+
```
14+
"""
15+
16+
from collections.abc import Callable
17+
from typing import Generic, cast, overload
18+
19+
from ..hooks._type_inference import infer_event_types
20+
from ..hooks.registry import HookCallback, TEvent
21+
22+
23+
class _WrappedHookCallable(HookCallback, Generic[TEvent]):
24+
"""Wrapped version of HookCallback that includes a `_hook_event_types` attribute."""
25+
26+
_hook_event_types: list[type[TEvent]]
27+
28+
29+
# Handle @hook
30+
@overload
31+
def hook(__func: HookCallback) -> _WrappedHookCallable: ...
32+
33+
34+
# Handle @hook()
35+
@overload
36+
def hook() -> Callable[[HookCallback], _WrappedHookCallable]: ...
37+
38+
39+
def hook(
40+
func: HookCallback | None = None,
41+
) -> _WrappedHookCallable | Callable[[HookCallback], _WrappedHookCallable]:
42+
"""Mark a method as a hook callback for automatic registration.
43+
44+
Infers event type from the callback's type hint. Supports union types
45+
for multiple events. Can be used as @hook or @hook().
46+
47+
Args:
48+
func: The function to decorate.
49+
50+
Returns:
51+
The decorated function with hook metadata.
52+
53+
Raises:
54+
ValueError: If event type cannot be inferred from type hints.
55+
"""
56+
57+
def decorator(f: HookCallback[TEvent]) -> _WrappedHookCallable[TEvent]:
58+
# Infer event types from type hints
59+
event_types: list[type[TEvent]] = infer_event_types(f)
60+
61+
# Store hook metadata on the function
62+
f_wrapped = cast(_WrappedHookCallable, f)
63+
f_wrapped._hook_event_types = event_types
64+
65+
return f_wrapped
66+
67+
if func is None:
68+
return decorator
69+
return decorator(func)

0 commit comments

Comments
 (0)