Skip to content

Commit 2fc8209

Browse files
author
Emaan Khan
committed
feat(swarm): add AgentBase protocol support
1 parent 9ecd960 commit 2fc8209

File tree

2 files changed

+1
-438
lines changed

2 files changed

+1
-438
lines changed

tests/strands/multiagent/test_swarm_agentbase.py

Lines changed: 0 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -156,24 +156,6 @@ def test_swarm_node_state_management_with_agentbase():
156156
node.reset_executor_state() # Should complete without error
157157

158158

159-
def test_swarm_node_state_management_with_agent():
160-
"""Test SwarmNode state management with Agent (hasattr should find attributes)."""
161-
# Create Agent with messages/state attributes
162-
agent = create_mock_agent("agent_with_state")
163-
164-
# Add actual messages and state
165-
agent.messages = [{"role": "user", "content": [{"text": "Hello"}]}]
166-
agent.state = Mock()
167-
agent.state.get = Mock(return_value={"key": "value"})
168-
169-
# Create SwarmNode
170-
node = SwarmNode(node_id="test_node", executor=agent)
171-
172-
# __post_init__ should capture state
173-
assert len(node._initial_messages) == 1
174-
assert node._initial_state.get()["key"] == "value"
175-
176-
177159
@pytest.mark.asyncio
178160
async def test_swarm_execution_with_agentbase():
179161
"""Test Swarm execution with AgentBase implementations."""
@@ -195,27 +177,6 @@ async def test_swarm_execution_with_agentbase():
195177
assert agentbase._call_count >= 1
196178

197179

198-
@pytest.mark.asyncio
199-
async def test_swarm_execution_mixed_agents():
200-
"""Test Swarm execution with mixed Agent and AgentBase nodes."""
201-
# Create mixed nodes
202-
agent = create_mock_agent("regular_agent", "Agent response")
203-
agentbase = MockAgentBase("agentbase_node", "AgentBase response")
204-
205-
# Create swarm with agent as entry point (it has tools for handoff)
206-
swarm = Swarm(nodes=[agent, agentbase], entry_point=agent)
207-
208-
# Execute swarm
209-
result = await swarm.invoke_async("Test mixed execution")
210-
211-
# Verify execution completed
212-
assert result.status == Status.COMPLETED
213-
assert len(result.results) >= 1
214-
215-
# At least the entry point should have executed
216-
assert "regular_agent" in result.results
217-
218-
219180
def test_swarm_agentbase_without_name_attribute():
220181
"""Test Swarm handles AgentBase without name attribute."""
221182

@@ -247,96 +208,6 @@ async def stream_async(self, prompt: Any = None, **kwargs: Any):
247208
assert "node_0" in swarm.nodes
248209

249210

250-
def test_swarm_agentbase_entry_point():
251-
"""Test Swarm with AgentBase as entry point."""
252-
agentbase1 = MockAgentBase("agentbase1")
253-
agentbase2 = MockAgentBase("agentbase2")
254-
255-
# Set agentbase2 as entry point
256-
swarm = Swarm(nodes=[agentbase1, agentbase2], entry_point=agentbase2)
257-
258-
# Verify entry point is set
259-
assert swarm.entry_point is agentbase2
260-
261-
# Verify initial node uses entry point
262-
initial_node = swarm._initial_node()
263-
assert initial_node.node_id == "agentbase2"
264-
265-
266-
def test_swarm_agentbase_entry_point_without_name():
267-
"""Test entry point handling when AgentBase has no name attribute."""
268-
269-
class AgentBaseNoName:
270-
"""AgentBase without name attribute."""
271-
272-
async def invoke_async(self, prompt: Any = None, **kwargs: Any) -> AgentResult:
273-
return AgentResult(
274-
message={"role": "assistant", "content": [{"text": "response"}]},
275-
stop_reason="end_turn",
276-
state={},
277-
metrics=None,
278-
)
279-
280-
def __call__(self, prompt: Any = None, **kwargs: Any) -> AgentResult:
281-
return asyncio.run(self.invoke_async(prompt, **kwargs))
282-
283-
async def stream_async(self, prompt: Any = None, **kwargs: Any):
284-
result = await self.invoke_async(prompt, **kwargs)
285-
yield {"result": result}
286-
287-
agentbase = AgentBaseNoName()
288-
289-
# Should create swarm and fallback to first node
290-
swarm = Swarm(nodes=[agentbase])
291-
292-
# Should fallback to first node when entry point has no name
293-
initial_node = swarm._initial_node()
294-
assert initial_node.node_id == "node_0"
295-
296-
297-
def test_swarm_node_hasattr_pattern_for_messages():
298-
"""Test that SwarmNode uses hasattr for messages (Graph pattern)."""
299-
# Create AgentBase without messages
300-
agentbase = MockAgentBase("agentbase_node")
301-
302-
# Create SwarmNode
303-
node = SwarmNode(node_id="test_node", executor=agentbase)
304-
305-
# Should not have captured messages (none exist)
306-
assert node._initial_messages == []
307-
308-
# Add messages attribute dynamically
309-
agentbase.messages = [{"role": "user", "content": [{"text": "Test"}]}]
310-
311-
# Create new node with messages
312-
node_with_messages = SwarmNode(node_id="test_node2", executor=agentbase)
313-
314-
# Should capture messages via hasattr
315-
assert len(node_with_messages._initial_messages) == 1
316-
317-
318-
def test_swarm_node_hasattr_pattern_for_state():
319-
"""Test that SwarmNode uses hasattr for state (Graph pattern)."""
320-
# Create AgentBase without state
321-
agentbase = MockAgentBase("agentbase_node")
322-
323-
# Create SwarmNode
324-
node = SwarmNode(node_id="test_node", executor=agentbase)
325-
326-
# Should not have captured state (none exists)
327-
assert node._initial_state.get() == {}
328-
329-
# Add state attribute dynamically
330-
agentbase.state = Mock()
331-
agentbase.state.get = Mock(return_value={"test": "data"})
332-
333-
# Create new node with state
334-
node_with_state = SwarmNode(node_id="test_node2", executor=agentbase)
335-
336-
# Should capture state via hasattr
337-
assert node_with_state._initial_state.get()["test"] == "data"
338-
339-
340211
def test_swarm_interrupt_handling_with_agentbase():
341212
"""Test that interrupt handling only saves Agent-specific context."""
342213
from strands.interrupt import Interrupt
@@ -368,75 +239,3 @@ def test_swarm_interrupt_handling_with_agentbase():
368239
# Should NOT have saved AgentBase context (isinstance check should prevent it)
369240
# The interrupt is registered but no Agent-specific context is saved
370241
assert swarm._interrupt_state.activated
371-
372-
373-
@pytest.mark.asyncio
374-
async def test_swarm_agentbase_streaming():
375-
"""Test that AgentBase streaming works correctly."""
376-
# Create AgentBase
377-
agentbase = MockAgentBase("streaming_agentbase", "Streaming response")
378-
379-
# Create swarm
380-
swarm = Swarm(nodes=[agentbase])
381-
382-
# Collect events
383-
events = []
384-
async for event in swarm.stream_async("Test streaming"):
385-
events.append(event)
386-
387-
# Should have received events
388-
assert len(events) > 0
389-
390-
# Should have node start/stop events
391-
node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"]
392-
node_stop_events = [e for e in events if e.get("type") == "multiagent_node_stop"]
393-
394-
assert len(node_start_events) >= 1
395-
assert len(node_stop_events) >= 1
396-
397-
# Should have final result
398-
result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"]
399-
assert len(result_events) == 1
400-
401-
402-
def test_swarm_duplicate_agentbase_instances():
403-
"""Test that Swarm rejects duplicate AgentBase instances."""
404-
agentbase = MockAgentBase("agentbase_node")
405-
406-
# Try to create swarm with same instance twice
407-
with pytest.raises(ValueError, match="Duplicate node instance detected"):
408-
Swarm(nodes=[agentbase, agentbase])
409-
410-
411-
def test_swarm_agentbase_without_session_manager():
412-
"""Test that AgentBase nodes don't trigger session manager validation."""
413-
# Create AgentBase (no _session_manager attribute)
414-
agentbase = MockAgentBase("agentbase_node")
415-
416-
# Should not raise error (validation only applies to Agent instances)
417-
swarm = Swarm(nodes=[agentbase])
418-
assert len(swarm.nodes) == 1
419-
420-
421-
def test_swarm_node_reset_with_agentbase():
422-
"""Test SwarmNode reset_executor_state with AgentBase."""
423-
# Create AgentBase
424-
agentbase = MockAgentBase("agentbase_node")
425-
426-
# Add state attributes that can be reset
427-
agentbase.messages = [{"role": "user", "content": [{"text": "Original"}]}]
428-
agentbase.state = Mock()
429-
agentbase.state.get = Mock(return_value={"original": "data"})
430-
431-
# Create node
432-
node = SwarmNode(node_id="test_node", executor=agentbase)
433-
434-
# Modify state
435-
agentbase.messages.append({"role": "assistant", "content": [{"text": "Modified"}]})
436-
437-
# Reset state
438-
node.reset_executor_state()
439-
440-
# Should have reset to initial state
441-
assert len(agentbase.messages) == 1
442-
assert agentbase.messages[0]["content"][0]["text"] == "Original"

0 commit comments

Comments
 (0)