Skip to content

Commit 296b485

Browse files
author
Emaan Khan
committed
feat(swarm): add AgentBase protocol support
1 parent 194c69b commit 296b485

File tree

3 files changed

+540
-27
lines changed

3 files changed

+540
-27
lines changed

src/strands/multiagent/swarm.py

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from .._async import run_async
2828
from ..agent import Agent
29+
from ..agent.base import AgentBase
2930
from ..agent.state import AgentState
3031
from ..hooks.events import (
3132
AfterMultiAgentInvocationEvent,
@@ -65,16 +66,19 @@ class SwarmNode:
6566
"""Represents a node (e.g. Agent) in the swarm."""
6667

6768
node_id: str
68-
executor: Agent
69+
executor: AgentBase
6970
swarm: Optional["Swarm"] = None
7071
_initial_messages: Messages = field(default_factory=list, init=False)
7172
_initial_state: AgentState = field(default_factory=AgentState, init=False)
7273

7374
def __post_init__(self) -> None:
7475
"""Capture initial executor state after initialization."""
7576
# Deep copy the initial messages and state to preserve them
76-
self._initial_messages = copy.deepcopy(self.executor.messages)
77-
self._initial_state = AgentState(self.executor.state.get())
77+
if hasattr(self.executor, "messages"):
78+
self._initial_messages = copy.deepcopy(self.executor.messages)
79+
80+
if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"):
81+
self._initial_state = AgentState(self.executor.state.get())
7882

7983
def __hash__(self) -> int:
8084
"""Return hash for SwarmNode based on node_id."""
@@ -99,15 +103,20 @@ def reset_executor_state(self) -> None:
99103
100104
If Swarm is resuming from an interrupt, we reset the executor state from the interrupt context.
101105
"""
102-
if self.swarm and self.swarm._interrupt_state.activated:
106+
# Handle interrupt state restoration (Agent-specific)
107+
if self.swarm and self.swarm._interrupt_state.activated and isinstance(self.executor, Agent):
103108
context = self.swarm._interrupt_state.context[self.node_id]
104109
self.executor.messages = context["messages"]
105110
self.executor.state = AgentState(context["state"])
106111
self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
107112
return
108113

109-
self.executor.messages = copy.deepcopy(self._initial_messages)
110-
self.executor.state = AgentState(self._initial_state.get())
114+
# Reset to initial state (works with any AgentBase that has these attributes)
115+
if hasattr(self.executor, "messages"):
116+
self.executor.messages = copy.deepcopy(self._initial_messages)
117+
118+
if hasattr(self.executor, "state"):
119+
self.executor.state = AgentState(self._initial_state.get())
111120

112121

113122
@dataclass
@@ -232,9 +241,9 @@ class Swarm(MultiAgentBase):
232241

233242
def __init__(
234243
self,
235-
nodes: list[Agent],
244+
nodes: list[AgentBase],
236245
*,
237-
entry_point: Agent | None = None,
246+
entry_point: AgentBase | None = None,
238247
max_handoffs: int = 20,
239248
max_iterations: int = 20,
240249
execution_timeout: float = 900.0,
@@ -458,19 +467,20 @@ async def _stream_with_timeout(
458467
except asyncio.TimeoutError as err:
459468
raise Exception(timeout_message) from err
460469

461-
def _setup_swarm(self, nodes: list[Agent]) -> None:
470+
def _setup_swarm(self, nodes: list[AgentBase]) -> None:
462471
"""Initialize swarm configuration."""
463472
# Validate nodes before setup
464473
self._validate_swarm(nodes)
465474

466475
# Validate agents have names and create SwarmNode objects
467476
for i, node in enumerate(nodes):
468-
if not node.name:
477+
# Only access name if it exists (AgentBase protocol doesn't guarantee it)
478+
node_name = getattr(node, "name", None)
479+
if not node_name:
469480
node_id = f"node_{i}"
470-
node.name = node_id
471-
logger.debug("node_id=<%s> | agent has no name, dynamically generating one", node_id)
472-
473-
node_id = str(node.name)
481+
logger.debug("node_id=<%s> | agent has no name, using generated id", node_id)
482+
else:
483+
node_id = str(node_name)
474484

475485
# Ensure node IDs are unique
476486
if node_id in self.nodes:
@@ -480,7 +490,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
480490

481491
# Validate entry point if specified
482492
if self.entry_point is not None:
483-
entry_point_node_id = str(self.entry_point.name)
493+
entry_point_node_id = str(getattr(self.entry_point, "name", None))
484494
if (
485495
entry_point_node_id not in self.nodes
486496
or self.nodes[entry_point_node_id].executor is not self.entry_point
@@ -500,7 +510,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
500510
first_node = next(iter(self.nodes.keys()))
501511
logger.debug("entry_point=<%s> | using first node as entry point", first_node)
502512

503-
def _validate_swarm(self, nodes: list[Agent]) -> None:
513+
def _validate_swarm(self, nodes: list[AgentBase]) -> None:
504514
"""Validate swarm structure and nodes."""
505515
# Check for duplicate object instances
506516
seen_instances = set()
@@ -509,18 +519,31 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
509519
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
510520
seen_instances.add(id(node))
511521

512-
# Check for session persistence
513-
if node._session_manager is not None:
522+
# Check for session persistence (only Agent has _session_manager attribute)
523+
if isinstance(node, Agent) and node._session_manager is not None:
514524
raise ValueError("Session persistence is not supported for Swarm agents yet.")
515525

516526
def _inject_swarm_tools(self) -> None:
517-
"""Add swarm coordination tools to each agent."""
527+
"""Add swarm coordination tools to each agent.
528+
529+
Note: Only Agent instances can receive swarm tools. AgentBase implementations
530+
without tool_registry will not have handoff capabilities.
531+
"""
518532
# Create tool functions with proper closures
519533
swarm_tools = [
520534
self._create_handoff_tool(),
521535
]
522536

537+
injected_count = 0
523538
for node in self.nodes.values():
539+
# Only Agent (not generic AgentBase) has tool_registry attribute
540+
if not isinstance(node.executor, Agent):
541+
logger.debug(
542+
"node_id=<%s> | skipping tool injection for non-Agent node",
543+
node.node_id,
544+
)
545+
continue
546+
524547
# Check for existing tools with conflicting names
525548
existing_tools = node.executor.tool_registry.registry
526549
conflicting_tools = []
@@ -536,11 +559,13 @@ def _inject_swarm_tools(self) -> None:
536559

537560
# Use the agent's tool registry to process and register the tools
538561
node.executor.tool_registry.process_tools(swarm_tools)
562+
injected_count += 1
539563

540564
logger.debug(
541-
"tool_count=<%d>, node_count=<%d> | injected coordination tools into agents",
565+
"tool_count=<%d>, node_count=<%d>, injected_count=<%d> | injected coordination tools",
542566
len(swarm_tools),
543567
len(self.nodes),
568+
injected_count,
544569
)
545570

546571
def _create_handoff_tool(self) -> Callable[..., Any]:
@@ -692,12 +717,14 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M
692717
logger.debug("node=<%s> | node interrupted", node.node_id)
693718
self.state.completion_status = Status.INTERRUPTED
694719

695-
self._interrupt_state.context[node.node_id] = {
696-
"activated": node.executor._interrupt_state.activated,
697-
"interrupt_state": node.executor._interrupt_state.to_dict(),
698-
"state": node.executor.state.get(),
699-
"messages": node.executor.messages,
700-
}
720+
# Only Agent (not generic AgentBase) has _interrupt_state, state, and messages attributes
721+
if isinstance(node.executor, Agent):
722+
self._interrupt_state.context[node.node_id] = {
723+
"activated": node.executor._interrupt_state.activated,
724+
"interrupt_state": node.executor._interrupt_state.to_dict(),
725+
"state": node.executor.state.get(),
726+
"messages": node.executor.messages,
727+
}
701728

702729
self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
703730
self._interrupt_state.activate()
@@ -1037,5 +1064,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None:
10371064

10381065
def _initial_node(self) -> SwarmNode:
10391066
if self.entry_point:
1040-
return self.nodes[str(self.entry_point.name)]
1067+
entry_point_name = getattr(self.entry_point, "name", None)
1068+
if entry_point_name and str(entry_point_name) in self.nodes:
1069+
return self.nodes[str(entry_point_name)]
10411070
return next(iter(self.nodes.values())) # First SwarmNode

0 commit comments

Comments
 (0)