Skip to content

Commit 30e3020

Browse files
feat: add plugins parameter to Agent (#1734)
Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com>
1 parent 029c77a commit 30e3020

2 files changed

Lines changed: 84 additions & 0 deletions

File tree

src/strands/agent/agent.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from ..interrupt import _InterruptState
4747
from ..models.bedrock import BedrockModel
4848
from ..models.model import Model
49+
from ..plugins import Plugin
50+
from ..plugins.registry import _PluginRegistry
4951
from ..session.session_manager import SessionManager
5052
from ..telemetry.metrics import EventLoopMetrics
5153
from ..telemetry.tracer import get_tracer, serialize
@@ -126,6 +128,7 @@ def __init__(
126128
name: str | None = None,
127129
description: str | None = None,
128130
state: AgentState | dict | None = None,
131+
plugins: list[Plugin] | None = None,
129132
hooks: list[HookProvider] | None = None,
130133
session_manager: SessionManager | None = None,
131134
structured_output_prompt: str | None = None,
@@ -176,6 +179,10 @@ def __init__(
176179
Defaults to None.
177180
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
178181
Defaults to an empty AgentState object.
182+
plugins: List of Plugin instances to extend agent functionality.
183+
Plugins are initialized with the agent instance after construction and can register hooks,
184+
modify agent attributes, or perform other setup tasks.
185+
Defaults to None.
179186
hooks: hooks to be added to the agent hook registry
180187
Defaults to None.
181188
session_manager: Manager for handling agent sessions including conversation history and state.
@@ -265,6 +272,8 @@ def __init__(
265272

266273
self.hooks = HookRegistry()
267274

275+
self._plugin_registry = _PluginRegistry(self)
276+
268277
self._interrupt_state = _InterruptState()
269278

270279
# Initialize lock for guarding concurrent invocations
@@ -311,6 +320,11 @@ def __init__(
311320
if hooks:
312321
for hook in hooks:
313322
self.hooks.add_hook(hook)
323+
324+
if plugins:
325+
for plugin in plugins:
326+
self._plugin_registry.add_and_init(plugin)
327+
314328
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
315329

316330
@property

tests/strands/agent/test_agent.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2619,3 +2619,73 @@ def untyped_callback(event):
26192619

26202620
with pytest.raises(ValueError, match="cannot infer event type"):
26212621
agent.add_hook(untyped_callback)
2622+
2623+
2624+
def test_agent_plugins_sync_initialization():
2625+
"""Test that plugins with sync init_plugin are initialized correctly."""
2626+
plugin_mock = unittest.mock.Mock()
2627+
plugin_mock.name = "test-plugin"
2628+
plugin_mock.init_plugin = unittest.mock.Mock()
2629+
2630+
agent = Agent(
2631+
model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]),
2632+
plugins=[plugin_mock],
2633+
)
2634+
2635+
plugin_mock.init_plugin.assert_called_once_with(agent)
2636+
2637+
2638+
def test_agent_plugins_async_initialization():
2639+
"""Test that plugins with async init_plugin are initialized correctly."""
2640+
plugin_mock = unittest.mock.Mock()
2641+
plugin_mock.name = "async-plugin"
2642+
plugin_mock.init_plugin = unittest.mock.AsyncMock()
2643+
2644+
agent = Agent(
2645+
model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]),
2646+
plugins=[plugin_mock],
2647+
)
2648+
2649+
plugin_mock.init_plugin.assert_called_once_with(agent)
2650+
2651+
2652+
def test_agent_plugins_multiple_in_order():
2653+
"""Test that multiple plugins are initialized in order."""
2654+
call_order = []
2655+
2656+
plugin1 = unittest.mock.Mock()
2657+
plugin1.name = "plugin1"
2658+
plugin1.init_plugin = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin1"))
2659+
2660+
plugin2 = unittest.mock.Mock()
2661+
plugin2.name = "plugin2"
2662+
plugin2.init_plugin = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin2"))
2663+
2664+
Agent(
2665+
model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]),
2666+
plugins=[plugin1, plugin2],
2667+
)
2668+
2669+
assert call_order == ["plugin1", "plugin2"]
2670+
2671+
2672+
def test_agent_plugins_can_register_hooks():
2673+
"""Test that plugins can register hooks during initialization."""
2674+
hook_called = []
2675+
2676+
class TestPlugin:
2677+
name = "hook-plugin"
2678+
2679+
def init_plugin(self, agent):
2680+
def hook_callback(event: BeforeModelCallEvent):
2681+
hook_called.append(True)
2682+
2683+
agent.add_hook(hook_callback)
2684+
2685+
agent = Agent(
2686+
model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]),
2687+
plugins=[TestPlugin()],
2688+
)
2689+
2690+
agent("test")
2691+
assert len(hook_called) == 1

0 commit comments

Comments
 (0)