forked from strands-agents/sdk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase.py
More file actions
119 lines (93 loc) · 4.22 KB
/
base.py
File metadata and controls
119 lines (93 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Multi-Agent Base Class.
Provides a minimal foundation for multi-agent patterns (Swarm, Graph).
"""
import asyncio
import contextvars
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union

from ..agent import AgentResult
from ..types.content import ContentBlock
from ..types.event_loop import Metrics, Usage
class Status(Enum):
    """Lifecycle state shared by whole multi-agent executions and their individual nodes."""

    # Initial state; the default status of NodeResult and MultiAgentResult.
    PENDING = "pending"
    # Work is in progress.
    EXECUTING = "executing"
    # Work finished successfully.
    COMPLETED = "completed"
    # Work failed or produced an error.
    FAILED = "failed"
@dataclass
class NodeResult:
    """Outcome of running one node, whether it wrapped an Agent or a nested MultiAgentBase.

    ``status`` records what the node's work amounted to:
    - COMPLETED: the node accomplished its task
    - FAILED: the node's task failed or produced an error
    """

    # What the node produced: a single AgentResult, a nested MultiAgentResult, or an Exception.
    result: Union[AgentResult, "MultiAgentResult", Exception]

    # Execution bookkeeping.
    execution_time: int = 0
    status: Status = Status.PENDING

    # Usage and metrics accumulated from this node and all of its children.
    accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
    accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
    execution_count: int = 0

    def get_agent_results(self) -> list[AgentResult]:
        """Collect every AgentResult reachable from this node.

        Returns:
            An empty list when ``result`` is an Exception, a one-element list when
            it is an AgentResult, and otherwise the depth-first concatenation of
            the agent results of every nested node in ``result.results``.
        """
        outcome = self.result
        if isinstance(outcome, Exception):
            # An exception carries no agent output.
            return []
        if isinstance(outcome, AgentResult):
            return [outcome]
        # Remaining case: a nested MultiAgentResult; recurse into each child node.
        return [
            agent_result
            for child in outcome.results.values()
            for agent_result in child.get_agent_results()
        ]
@dataclass
class MultiAgentResult:
    """Aggregate outcome of a multi-agent run: per-node results plus accumulated metrics.

    ``status`` describes how the MultiAgentBase execution ended:
    - COMPLETED: the execution was successfully accomplished
    - FAILED: the execution failed or produced an error
    """

    # Overall outcome of the run; defaults to PENDING.
    status: Status = Status.PENDING
    # Per-node results (presumably keyed by node id -- confirm with orchestrators).
    results: dict[str, NodeResult] = field(default_factory=dict)
    # Usage and metrics accumulated over the run.
    accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
    accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
    execution_count: int = 0
    execution_time: int = 0
class MultiAgentBase(ABC):
    """Base class for multi-agent helpers.

    This class integrates with existing Strands Agent instances and provides
    multi-agent orchestration capabilities. Subclasses implement ``invoke_async``;
    ``__call__`` is a blocking wrapper around it.
    """

    @abstractmethod
    async def invoke_async(
        self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
    ) -> MultiAgentResult:
        """Invoke asynchronously.

        Args:
            task: The task to execute
            invocation_state: Additional state/context passed to underlying agents.
                Defaults to None to avoid mutable default argument issues.
            **kwargs: Additional keyword arguments passed to underlying agents.

        Returns:
            The MultiAgentResult produced by the implementation.
        """
        raise NotImplementedError("invoke_async not implemented")

    def __call__(
        self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
    ) -> MultiAgentResult:
        """Invoke synchronously.

        ``invoke_async`` is run with ``asyncio.run`` on a separate worker thread.
        As a result, this call also works when the calling thread already has a
        running event loop, which is a case where ``asyncio.run`` refuses to run
        directly. The caller's ``contextvars`` context is copied into that worker.
        Context-bound state set by the caller (for example tracing spans) is
        therefore visible to the underlying agents.

        Args:
            task: The task to execute
            invocation_state: Additional state/context passed to underlying agents.
                Defaults to None to avoid mutable default argument issues.
            **kwargs: Additional keyword arguments passed to underlying agents.

        Returns:
            The MultiAgentResult returned by ``invoke_async``.
        """
        if invocation_state is None:
            invocation_state = {}

        def execute() -> MultiAgentResult:
            return asyncio.run(self.invoke_async(task, invocation_state, **kwargs))

        # Executor threads do not inherit the submitting thread's context variables;
        # snapshot them here and run the worker inside that snapshot.
        context = contextvars.copy_context()
        # Only a single task is ever submitted, so one worker thread suffices.
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(context.run, execute)
            return future.result()