2626
2727from .._async import run_async
2828from ..agent import Agent
29+ from ..agent .base import AgentBase
2930from ..agent .state import AgentState
3031from ..hooks .events import (
3132 AfterMultiAgentInvocationEvent ,
@@ -65,7 +66,7 @@ class SwarmNode:
6566 """Represents a node (e.g. Agent) in the swarm."""
6667
6768 node_id : str
68- executor : Agent
69+ executor : AgentBase
6970 swarm : Optional ["Swarm" ] = None
7071 _initial_messages : Messages = field (default_factory = list , init = False )
7172 _initial_state : AgentState = field (default_factory = AgentState , init = False )
@@ -74,9 +75,14 @@ class SwarmNode:
7475 def __post_init__ (self ) -> None :
7576 """Capture initial executor state after initialization."""
7677 # Deep copy the initial messages and state to preserve them
77- self ._initial_messages = copy .deepcopy (self .executor .messages )
78- self ._initial_state = AgentState (self .executor .state .get ())
79- self ._initial_model_state = copy .deepcopy (self .executor ._model_state )
78+ if hasattr (self .executor , "messages" ):
79+ self ._initial_messages = copy .deepcopy (self .executor .messages )
80+
81+ if hasattr (self .executor , "state" ) and hasattr (self .executor .state , "get" ):
82+ self ._initial_state = AgentState (self .executor .state .get ())
83+
84+ if hasattr (self .executor , "_model_state" ):
85+ self ._initial_model_state = copy .deepcopy (self .executor ._model_state )
8086
8187 def __hash__ (self ) -> int :
8288 """Return hash for SwarmNode based on node_id."""
@@ -101,17 +107,26 @@ def reset_executor_state(self) -> None:
101107
102108 If Swarm is resuming from an interrupt, we reset the executor state from the interrupt context.
103109 """
104- if self .swarm and self .swarm ._interrupt_state .activated :
110+ # Handle interrupt state restoration (Agent-specific)
111+ if self .swarm and self .swarm ._interrupt_state .activated and isinstance (self .executor , Agent ):
112+ if self .node_id not in self .swarm ._interrupt_state .context :
113+ return
105114 context = self .swarm ._interrupt_state .context [self .node_id ]
106115 self .executor .messages = context ["messages" ]
107116 self .executor .state = AgentState (context ["state" ])
108117 self .executor ._interrupt_state = _InterruptState .from_dict (context ["interrupt_state" ])
109118 self .executor ._model_state = context .get ("model_state" , {})
110119 return
111120
112- self .executor .messages = copy .deepcopy (self ._initial_messages )
113- self .executor .state = AgentState (self ._initial_state .get ())
114- self .executor ._model_state = copy .deepcopy (self ._initial_model_state )
121+ # Reset to initial state (works with any AgentBase that has these attributes)
122+ if hasattr (self .executor , "messages" ):
123+ self .executor .messages = copy .deepcopy (self ._initial_messages )
124+
125+ if hasattr (self .executor , "state" ):
126+ self .executor .state = AgentState (self ._initial_state .get ())
127+
128+ if hasattr (self .executor , "_model_state" ):
129+ self .executor ._model_state = copy .deepcopy (self ._initial_model_state )
115130
116131
117132@dataclass
@@ -236,9 +251,9 @@ class Swarm(MultiAgentBase):
236251
237252 def __init__ (
238253 self ,
239- nodes : list [Agent ],
254+ nodes : list [AgentBase ],
240255 * ,
241- entry_point : Agent | None = None ,
256+ entry_point : AgentBase | None = None ,
242257 max_handoffs : int = 20 ,
243258 max_iterations : int = 20 ,
244259 execution_timeout : float = 900.0 ,
@@ -301,6 +316,7 @@ def __init__(
301316
302317 self ._resume_from_session = False
303318
319+ self ._handoff_capable_nodes : set [str ] = set ()
304320 self ._setup_swarm (nodes )
305321 self ._inject_swarm_tools ()
306322 run_async (lambda : self .hooks .invoke_callbacks_async (MultiAgentInitializedEvent (self )))
@@ -462,33 +478,35 @@ async def _stream_with_timeout(
462478 except asyncio .TimeoutError as err :
463479 raise Exception (timeout_message ) from err
464480
465- def _setup_swarm (self , nodes : list [Agent ]) -> None :
481+ def _setup_swarm (self , nodes : list [AgentBase ]) -> None :
466482 """Initialize swarm configuration."""
467483 # Validate nodes before setup
468484 self ._validate_swarm (nodes )
469485
470486 # Validate agents have names and create SwarmNode objects
471487 for i , node in enumerate (nodes ):
472- if not node .name :
488+ # Only access name if it exists (AgentBase protocol doesn't guarantee it)
489+ node_name = getattr (node , "name" , None )
490+ if not node_name :
473491 node_id = f"node_{ i } "
474- node .name = node_id
475- logger .debug ("node_id=<%s> | agent has no name, dynamically generating one" , node_id )
476-
477- node_id = str (node .name )
492+ logger .debug ("node_id=<%s> | agent has no name, using generated id" , node_id )
493+ else :
494+ node_id = str (node_name )
478495
479496 # Ensure node IDs are unique
480497 if node_id in self .nodes :
481498 raise ValueError (f"Node ID '{ node_id } ' is not unique. Each agent must have a unique name." )
482499
483500 self .nodes [node_id ] = SwarmNode (node_id , node , swarm = self )
484501
485- # Validate entry point if specified
502+ # Validate entry point if specified (use identity-based lookup to handle nameless AgentBase)
486503 if self .entry_point is not None :
487- entry_point_node_id = str (self .entry_point .name )
488- if (
489- entry_point_node_id not in self .nodes
490- or self .nodes [entry_point_node_id ].executor is not self .entry_point
491- ):
504+ entry_node = None
505+ for swarm_node in self .nodes .values ():
506+ if swarm_node .executor is self .entry_point :
507+ entry_node = swarm_node
508+ break
509+ if entry_node is None :
492510 available_agents = [
493511 f"{ node_id } ({ type (node .executor ).__name__ } )" for node_id , node in self .nodes .items ()
494512 ]
@@ -504,7 +522,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
504522 first_node = next (iter (self .nodes .keys ()))
505523 logger .debug ("entry_point=<%s> | using first node as entry point" , first_node )
506524
507- def _validate_swarm (self , nodes : list [Agent ]) -> None :
525+ def _validate_swarm (self , nodes : list [AgentBase ]) -> None :
508526 """Validate swarm structure and nodes."""
509527 # Check for duplicate object instances
510528 seen_instances = set ()
@@ -513,18 +531,31 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
513531 raise ValueError ("Duplicate node instance detected. Each node must have a unique object instance." )
514532 seen_instances .add (id (node ))
515533
516- # Check for session persistence
517- if node ._session_manager is not None :
534+ # Check for session persistence (only Agent has _session_manager attribute)
535+ if isinstance ( node , Agent ) and node ._session_manager is not None :
518536 raise ValueError ("Session persistence is not supported for Swarm agents yet." )
519537
520538 def _inject_swarm_tools (self ) -> None :
521- """Add swarm coordination tools to each agent."""
539+ """Add swarm coordination tools to each agent.
540+
541+ Note: Only Agent instances can receive swarm tools. AgentBase implementations
542+ without tool_registry will not have handoff capabilities.
543+ """
522544 # Create tool functions with proper closures
523545 swarm_tools = [
524546 self ._create_handoff_tool (),
525547 ]
526548
549+ injected_count = 0
527550 for node in self .nodes .values ():
551+ # Only Agent (not generic AgentBase) has tool_registry attribute
552+ if not isinstance (node .executor , Agent ):
553+ logger .debug (
554+ "node_id=<%s> | skipping tool injection for non-Agent node" ,
555+ node .node_id ,
556+ )
557+ continue
558+
528559 # Check for existing tools with conflicting names
529560 existing_tools = node .executor .tool_registry .registry
530561 conflicting_tools = []
@@ -540,11 +571,14 @@ def _inject_swarm_tools(self) -> None:
540571
541572 # Use the agent's tool registry to process and register the tools
542573 node .executor .tool_registry .process_tools (swarm_tools )
574+ self ._handoff_capable_nodes .add (node .node_id )
575+ injected_count += 1
543576
544577 logger .debug (
545- "tool_count=<%d>, node_count=<%d> | injected coordination tools into agents " ,
578+ "tool_count=<%d>, node_count=<%d>, injected_count=<%d> | injected coordination tools" ,
546579 len (swarm_tools ),
547580 len (self .nodes ),
581+ injected_count ,
548582 )
549583
550584 def _create_handoff_tool (self ) -> Callable [..., Any ]:
@@ -673,10 +707,13 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
673707 context_text += "\n "
674708 context_text += "\n "
675709
676- context_text += (
677- "You have access to swarm coordination tools if you need help from other agents. "
678- "If you don't hand off to another agent, the swarm will consider the task complete."
679- )
710+ if target_node .node_id in self ._handoff_capable_nodes :
711+ context_text += (
712+ "You have access to swarm coordination tools if you need help from other agents. "
713+ "If you don't hand off to another agent, the swarm will consider the task complete."
714+ )
715+ else :
716+ context_text += "If you complete your task, the swarm will consider the task complete."
680717
681718 return context_text
682719
@@ -696,13 +733,19 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M
696733 logger .debug ("node=<%s> | node interrupted" , node .node_id )
697734 self .state .completion_status = Status .INTERRUPTED
698735
736+ # Only Agent (not generic AgentBase) has _interrupt_state, state, and messages attributes
699737 self ._interrupt_state .context [node .node_id ] = {
700- "activated" : node .executor ._interrupt_state .activated ,
701- "interrupt_state" : node .executor ._interrupt_state .to_dict (),
702- "state" : node .executor .state .get (),
703- "messages" : node .executor .messages ,
704- "model_state" : node .executor ._model_state ,
738+ "activated" : isinstance (node .executor , Agent ) and node .executor ._interrupt_state .activated ,
705739 }
740+ if isinstance (node .executor , Agent ):
741+ self ._interrupt_state .context [node .node_id ].update (
742+ {
743+ "interrupt_state" : node .executor ._interrupt_state .to_dict (),
744+ "state" : node .executor .state .get (),
745+ "messages" : node .executor .messages ,
746+ "model_state" : node .executor ._model_state ,
747+ }
748+ )
706749
707750 self ._interrupt_state .interrupts .update ({interrupt .id : interrupt for interrupt in interrupts })
708751 self ._interrupt_state .activate ()
@@ -1042,5 +1085,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None:
10421085
10431086 def _initial_node (self ) -> SwarmNode :
10441087 if self .entry_point :
1045- return self .nodes [str (self .entry_point .name )]
1088+ for node in self .nodes .values ():
1089+ if node .executor is self .entry_point :
1090+ return node
10461091 return next (iter (self .nodes .values ())) # First SwarmNode
0 commit comments