2626
2727from .._async import run_async
2828from ..agent import Agent
29+ from ..agent .base import AgentBase
2930from ..agent .state import AgentState
3031from ..hooks .events import (
3132 AfterMultiAgentInvocationEvent ,
@@ -65,16 +66,19 @@ class SwarmNode:
6566 """Represents a node (e.g. Agent) in the swarm."""
6667
6768 node_id : str
68- executor : Agent
69+ executor : AgentBase
6970 swarm : Optional ["Swarm" ] = None
7071 _initial_messages : Messages = field (default_factory = list , init = False )
7172 _initial_state : AgentState = field (default_factory = AgentState , init = False )
7273
7374 def __post_init__ (self ) -> None :
7475 """Capture initial executor state after initialization."""
7576 # Deep copy the initial messages and state to preserve them
76- self ._initial_messages = copy .deepcopy (self .executor .messages )
77- self ._initial_state = AgentState (self .executor .state .get ())
77+ if hasattr (self .executor , "messages" ):
78+ self ._initial_messages = copy .deepcopy (self .executor .messages )
79+
80+ if hasattr (self .executor , "state" ) and hasattr (self .executor .state , "get" ):
81+ self ._initial_state = AgentState (self .executor .state .get ())
7882
7983 def __hash__ (self ) -> int :
8084 """Return hash for SwarmNode based on node_id."""
@@ -99,15 +103,20 @@ def reset_executor_state(self) -> None:
99103
100104 If Swarm is resuming from an interrupt, we reset the executor state from the interrupt context.
101105 """
102- if self .swarm and self .swarm ._interrupt_state .activated :
106+ # Handle interrupt state restoration (Agent-specific)
107+ if self .swarm and self .swarm ._interrupt_state .activated and isinstance (self .executor , Agent ):
103108 context = self .swarm ._interrupt_state .context [self .node_id ]
104109 self .executor .messages = context ["messages" ]
105110 self .executor .state = AgentState (context ["state" ])
106111 self .executor ._interrupt_state = _InterruptState .from_dict (context ["interrupt_state" ])
107112 return
108113
109- self .executor .messages = copy .deepcopy (self ._initial_messages )
110- self .executor .state = AgentState (self ._initial_state .get ())
114+ # Reset to initial state (works with any AgentBase that has these attributes)
115+ if hasattr (self .executor , "messages" ):
116+ self .executor .messages = copy .deepcopy (self ._initial_messages )
117+
118+ if hasattr (self .executor , "state" ):
119+ self .executor .state = AgentState (self ._initial_state .get ())
111120
112121
113122@dataclass
@@ -232,9 +241,9 @@ class Swarm(MultiAgentBase):
232241
233242 def __init__ (
234243 self ,
235- nodes : list [Agent ],
244+ nodes : list [AgentBase ],
236245 * ,
237- entry_point : Agent | None = None ,
246+ entry_point : AgentBase | None = None ,
238247 max_handoffs : int = 20 ,
239248 max_iterations : int = 20 ,
240249 execution_timeout : float = 900.0 ,
@@ -458,19 +467,20 @@ async def _stream_with_timeout(
458467 except asyncio .TimeoutError as err :
459468 raise Exception (timeout_message ) from err
460469
461- def _setup_swarm (self , nodes : list [Agent ]) -> None :
470+ def _setup_swarm (self , nodes : list [AgentBase ]) -> None :
462471 """Initialize swarm configuration."""
463472 # Validate nodes before setup
464473 self ._validate_swarm (nodes )
465474
466475 # Validate agents have names and create SwarmNode objects
467476 for i , node in enumerate (nodes ):
468- if not node .name :
477+ # Only access name if it exists (AgentBase protocol doesn't guarantee it)
478+ node_name = getattr (node , "name" , None )
479+ if not node_name :
469480 node_id = f"node_{ i } "
470- node .name = node_id
471- logger .debug ("node_id=<%s> | agent has no name, dynamically generating one" , node_id )
472-
473- node_id = str (node .name )
481+ logger .debug ("node_id=<%s> | agent has no name, using generated id" , node_id )
482+ else :
483+ node_id = str (node_name )
474484
475485 # Ensure node IDs are unique
476486 if node_id in self .nodes :
@@ -480,7 +490,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
480490
481491 # Validate entry point if specified
482492 if self .entry_point is not None :
483- entry_point_node_id = str (self .entry_point . name )
493+ entry_point_node_id = str (getattr ( self .entry_point , " name" , None ) )
484494 if (
485495 entry_point_node_id not in self .nodes
486496 or self .nodes [entry_point_node_id ].executor is not self .entry_point
@@ -500,7 +510,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
500510 first_node = next (iter (self .nodes .keys ()))
501511 logger .debug ("entry_point=<%s> | using first node as entry point" , first_node )
502512
503- def _validate_swarm (self , nodes : list [Agent ]) -> None :
513+ def _validate_swarm (self , nodes : list [AgentBase ]) -> None :
504514 """Validate swarm structure and nodes."""
505515 # Check for duplicate object instances
506516 seen_instances = set ()
@@ -509,18 +519,31 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
509519 raise ValueError ("Duplicate node instance detected. Each node must have a unique object instance." )
510520 seen_instances .add (id (node ))
511521
512- # Check for session persistence
513- if node ._session_manager is not None :
522+ # Check for session persistence (only Agent has _session_manager attribute)
523+ if isinstance ( node , Agent ) and node ._session_manager is not None :
514524 raise ValueError ("Session persistence is not supported for Swarm agents yet." )
515525
516526 def _inject_swarm_tools (self ) -> None :
517- """Add swarm coordination tools to each agent."""
527+ """Add swarm coordination tools to each agent.
528+
529+ Note: Only Agent instances can receive swarm tools. AgentBase implementations
530+ without tool_registry will not have handoff capabilities.
531+ """
518532 # Create tool functions with proper closures
519533 swarm_tools = [
520534 self ._create_handoff_tool (),
521535 ]
522536
537+ injected_count = 0
523538 for node in self .nodes .values ():
539+ # Only Agent (not generic AgentBase) has tool_registry attribute
540+ if not isinstance (node .executor , Agent ):
541+ logger .debug (
542+ "node_id=<%s> | skipping tool injection for non-Agent node" ,
543+ node .node_id ,
544+ )
545+ continue
546+
524547 # Check for existing tools with conflicting names
525548 existing_tools = node .executor .tool_registry .registry
526549 conflicting_tools = []
@@ -536,11 +559,13 @@ def _inject_swarm_tools(self) -> None:
536559
537560 # Use the agent's tool registry to process and register the tools
538561 node .executor .tool_registry .process_tools (swarm_tools )
562+ injected_count += 1
539563
540564 logger .debug (
541- "tool_count=<%d>, node_count=<%d> | injected coordination tools into agents " ,
565+ "tool_count=<%d>, node_count=<%d>, injected_count=<%d> | injected coordination tools" ,
542566 len (swarm_tools ),
543567 len (self .nodes ),
568+ injected_count ,
544569 )
545570
546571 def _create_handoff_tool (self ) -> Callable [..., Any ]:
@@ -692,12 +717,14 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M
692717 logger .debug ("node=<%s> | node interrupted" , node .node_id )
693718 self .state .completion_status = Status .INTERRUPTED
694719
695- self ._interrupt_state .context [node .node_id ] = {
696- "activated" : node .executor ._interrupt_state .activated ,
697- "interrupt_state" : node .executor ._interrupt_state .to_dict (),
698- "state" : node .executor .state .get (),
699- "messages" : node .executor .messages ,
700- }
720+ # Only Agent (not generic AgentBase) has _interrupt_state, state, and messages attributes
721+ if isinstance (node .executor , Agent ):
722+ self ._interrupt_state .context [node .node_id ] = {
723+ "activated" : node .executor ._interrupt_state .activated ,
724+ "interrupt_state" : node .executor ._interrupt_state .to_dict (),
725+ "state" : node .executor .state .get (),
726+ "messages" : node .executor .messages ,
727+ }
701728
702729 self ._interrupt_state .interrupts .update ({interrupt .id : interrupt for interrupt in interrupts })
703730 self ._interrupt_state .activate ()
@@ -1037,5 +1064,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None:
10371064
10381065 def _initial_node (self ) -> SwarmNode :
10391066 if self .entry_point :
1040- return self .nodes [str (self .entry_point .name )]
1067+ entry_point_name = getattr (self .entry_point , "name" , None )
1068+ if entry_point_name and str (entry_point_name ) in self .nodes :
1069+ return self .nodes [str (entry_point_name )]
10411070 return next (iter (self .nodes .values ())) # First SwarmNode
0 commit comments