Skip to content

Commit 0ec9074

Browse files
Emaan Khanemaan-c
authored andcommitted
feat(swarm): add AgentBase protocol support
1 parent 50b2c79 commit 0ec9074

File tree

5 files changed

+636
-46
lines changed

5 files changed

+636
-46
lines changed

src/strands/multiagent/swarm.py

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from .._async import run_async
2828
from ..agent import Agent
29+
from ..agent.base import AgentBase
2930
from ..agent.state import AgentState
3031
from ..hooks.events import (
3132
AfterMultiAgentInvocationEvent,
@@ -65,7 +66,7 @@ class SwarmNode:
6566
"""Represents a node (e.g. Agent) in the swarm."""
6667

6768
node_id: str
68-
executor: Agent
69+
executor: AgentBase
6970
swarm: Optional["Swarm"] = None
7071
_initial_messages: Messages = field(default_factory=list, init=False)
7172
_initial_state: AgentState = field(default_factory=AgentState, init=False)
@@ -74,9 +75,14 @@ class SwarmNode:
7475
def __post_init__(self) -> None:
7576
"""Capture initial executor state after initialization."""
7677
# Deep copy the initial messages and state to preserve them
77-
self._initial_messages = copy.deepcopy(self.executor.messages)
78-
self._initial_state = AgentState(self.executor.state.get())
79-
self._initial_model_state = copy.deepcopy(self.executor._model_state)
78+
if hasattr(self.executor, "messages"):
79+
self._initial_messages = copy.deepcopy(self.executor.messages)
80+
81+
if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"):
82+
self._initial_state = AgentState(self.executor.state.get())
83+
84+
if hasattr(self.executor, "_model_state"):
85+
self._initial_model_state = copy.deepcopy(self.executor._model_state)
8086

8187
def __hash__(self) -> int:
8288
"""Return hash for SwarmNode based on node_id."""
@@ -101,17 +107,26 @@ def reset_executor_state(self) -> None:
101107
102108
If Swarm is resuming from an interrupt, we reset the executor state from the interrupt context.
103109
"""
104-
if self.swarm and self.swarm._interrupt_state.activated:
110+
# Handle interrupt state restoration (Agent-specific)
111+
if self.swarm and self.swarm._interrupt_state.activated and isinstance(self.executor, Agent):
112+
if self.node_id not in self.swarm._interrupt_state.context:
113+
return
105114
context = self.swarm._interrupt_state.context[self.node_id]
106115
self.executor.messages = context["messages"]
107116
self.executor.state = AgentState(context["state"])
108117
self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
109118
self.executor._model_state = context.get("model_state", {})
110119
return
111120

112-
self.executor.messages = copy.deepcopy(self._initial_messages)
113-
self.executor.state = AgentState(self._initial_state.get())
114-
self.executor._model_state = copy.deepcopy(self._initial_model_state)
121+
# Reset to initial state (works with any AgentBase that has these attributes)
122+
if hasattr(self.executor, "messages"):
123+
self.executor.messages = copy.deepcopy(self._initial_messages)
124+
125+
if hasattr(self.executor, "state"):
126+
self.executor.state = AgentState(self._initial_state.get())
127+
128+
if hasattr(self.executor, "_model_state"):
129+
self.executor._model_state = copy.deepcopy(self._initial_model_state)
115130

116131

117132
@dataclass
@@ -236,9 +251,9 @@ class Swarm(MultiAgentBase):
236251

237252
def __init__(
238253
self,
239-
nodes: list[Agent],
254+
nodes: list[AgentBase],
240255
*,
241-
entry_point: Agent | None = None,
256+
entry_point: AgentBase | None = None,
242257
max_handoffs: int = 20,
243258
max_iterations: int = 20,
244259
execution_timeout: float = 900.0,
@@ -301,6 +316,7 @@ def __init__(
301316

302317
self._resume_from_session = False
303318

319+
self._handoff_capable_nodes: set[str] = set()
304320
self._setup_swarm(nodes)
305321
self._inject_swarm_tools()
306322
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
@@ -462,33 +478,35 @@ async def _stream_with_timeout(
462478
except asyncio.TimeoutError as err:
463479
raise Exception(timeout_message) from err
464480

465-
def _setup_swarm(self, nodes: list[Agent]) -> None:
481+
def _setup_swarm(self, nodes: list[AgentBase]) -> None:
466482
"""Initialize swarm configuration."""
467483
# Validate nodes before setup
468484
self._validate_swarm(nodes)
469485

470486
# Validate agents have names and create SwarmNode objects
471487
for i, node in enumerate(nodes):
472-
if not node.name:
488+
# Only access name if it exists (AgentBase protocol doesn't guarantee it)
489+
node_name = getattr(node, "name", None)
490+
if not node_name:
473491
node_id = f"node_{i}"
474-
node.name = node_id
475-
logger.debug("node_id=<%s> | agent has no name, dynamically generating one", node_id)
476-
477-
node_id = str(node.name)
492+
logger.debug("node_id=<%s> | agent has no name, using generated id", node_id)
493+
else:
494+
node_id = str(node_name)
478495

479496
# Ensure node IDs are unique
480497
if node_id in self.nodes:
481498
raise ValueError(f"Node ID '{node_id}' is not unique. Each agent must have a unique name.")
482499

483500
self.nodes[node_id] = SwarmNode(node_id, node, swarm=self)
484501

485-
# Validate entry point if specified
502+
# Validate entry point if specified (use identity-based lookup to handle nameless AgentBase)
486503
if self.entry_point is not None:
487-
entry_point_node_id = str(self.entry_point.name)
488-
if (
489-
entry_point_node_id not in self.nodes
490-
or self.nodes[entry_point_node_id].executor is not self.entry_point
491-
):
504+
entry_node = None
505+
for swarm_node in self.nodes.values():
506+
if swarm_node.executor is self.entry_point:
507+
entry_node = swarm_node
508+
break
509+
if entry_node is None:
492510
available_agents = [
493511
f"{node_id} ({type(node.executor).__name__})" for node_id, node in self.nodes.items()
494512
]
@@ -504,7 +522,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
504522
first_node = next(iter(self.nodes.keys()))
505523
logger.debug("entry_point=<%s> | using first node as entry point", first_node)
506524

507-
def _validate_swarm(self, nodes: list[Agent]) -> None:
525+
def _validate_swarm(self, nodes: list[AgentBase]) -> None:
508526
"""Validate swarm structure and nodes."""
509527
# Check for duplicate object instances
510528
seen_instances = set()
@@ -513,18 +531,31 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
513531
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
514532
seen_instances.add(id(node))
515533

516-
# Check for session persistence
517-
if node._session_manager is not None:
534+
# Check for session persistence (only Agent has _session_manager attribute)
535+
if isinstance(node, Agent) and node._session_manager is not None:
518536
raise ValueError("Session persistence is not supported for Swarm agents yet.")
519537

520538
def _inject_swarm_tools(self) -> None:
521-
"""Add swarm coordination tools to each agent."""
539+
"""Add swarm coordination tools to each agent.
540+
541+
Note: Only Agent instances can receive swarm tools. AgentBase implementations
542+
without tool_registry will not have handoff capabilities.
543+
"""
522544
# Create tool functions with proper closures
523545
swarm_tools = [
524546
self._create_handoff_tool(),
525547
]
526548

549+
injected_count = 0
527550
for node in self.nodes.values():
551+
# Only Agent (not generic AgentBase) has tool_registry attribute
552+
if not isinstance(node.executor, Agent):
553+
logger.debug(
554+
"node_id=<%s> | skipping tool injection for non-Agent node",
555+
node.node_id,
556+
)
557+
continue
558+
528559
# Check for existing tools with conflicting names
529560
existing_tools = node.executor.tool_registry.registry
530561
conflicting_tools = []
@@ -540,11 +571,14 @@ def _inject_swarm_tools(self) -> None:
540571

541572
# Use the agent's tool registry to process and register the tools
542573
node.executor.tool_registry.process_tools(swarm_tools)
574+
self._handoff_capable_nodes.add(node.node_id)
575+
injected_count += 1
543576

544577
logger.debug(
545-
"tool_count=<%d>, node_count=<%d> | injected coordination tools into agents",
578+
"tool_count=<%d>, node_count=<%d>, injected_count=<%d> | injected coordination tools",
546579
len(swarm_tools),
547580
len(self.nodes),
581+
injected_count,
548582
)
549583

550584
def _create_handoff_tool(self) -> Callable[..., Any]:
@@ -673,10 +707,13 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
673707
context_text += "\n"
674708
context_text += "\n"
675709

676-
context_text += (
677-
"You have access to swarm coordination tools if you need help from other agents. "
678-
"If you don't hand off to another agent, the swarm will consider the task complete."
679-
)
710+
if target_node.node_id in self._handoff_capable_nodes:
711+
context_text += (
712+
"You have access to swarm coordination tools if you need help from other agents. "
713+
"If you don't hand off to another agent, the swarm will consider the task complete."
714+
)
715+
else:
716+
context_text += "If you complete your task, the swarm will consider the task complete."
680717

681718
return context_text
682719

@@ -696,13 +733,19 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M
696733
logger.debug("node=<%s> | node interrupted", node.node_id)
697734
self.state.completion_status = Status.INTERRUPTED
698735

736+
# Only Agent (not generic AgentBase) has _interrupt_state, state, and messages attributes
699737
self._interrupt_state.context[node.node_id] = {
700-
"activated": node.executor._interrupt_state.activated,
701-
"interrupt_state": node.executor._interrupt_state.to_dict(),
702-
"state": node.executor.state.get(),
703-
"messages": node.executor.messages,
704-
"model_state": node.executor._model_state,
738+
"activated": isinstance(node.executor, Agent) and node.executor._interrupt_state.activated,
705739
}
740+
if isinstance(node.executor, Agent):
741+
self._interrupt_state.context[node.node_id].update(
742+
{
743+
"interrupt_state": node.executor._interrupt_state.to_dict(),
744+
"state": node.executor.state.get(),
745+
"messages": node.executor.messages,
746+
"model_state": node.executor._model_state,
747+
}
748+
)
706749

707750
self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
708751
self._interrupt_state.activate()
@@ -1042,5 +1085,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None:
10421085

10431086
def _initial_node(self) -> SwarmNode:
10441087
if self.entry_point:
1045-
return self.nodes[str(self.entry_point.name)]
1088+
for node in self.nodes.values():
1089+
if node.executor is self.entry_point:
1090+
return node
10461091
return next(iter(self.nodes.values())) # First SwarmNode

src/strands/telemetry/tracer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,7 @@ def start_event_loop_cycle_span(
527527
event_loop_cycle_id = str(invocation_state.get("event_loop_cycle_id"))
528528
parent_span = parent_span if parent_span else invocation_state.get("event_loop_parent_span")
529529

530-
attributes: dict[str, AttributeValue] = self._get_common_attributes(
531-
operation_name="execute_event_loop_cycle"
532-
)
530+
attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_event_loop_cycle")
533531
attributes["event_loop.cycle_id"] = event_loop_cycle_id
534532

535533
if custom_trace_attributes:

0 commit comments

Comments
 (0)