-
Notifications
You must be signed in to change notification settings - Fork 836
Expand file tree
/
Copy path_converters.py
More file actions
179 lines (147 loc) · 6.42 KB
/
_converters.py
File metadata and controls
179 lines (147 loc) · 6.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Conversion functions between Strands and A2A types."""
from typing import cast
from uuid import uuid4
from a2a.types import Message as A2AMessage
from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent, TextPart
from ...agent.agent_result import AgentResult
from ...telemetry.metrics import EventLoopMetrics
from ...types.a2a import A2AResponse
from ...types.agent import AgentInput
from ...types.content import ContentBlock, Message
from ...types.event_loop import StopReason
# Mapping from A2A TaskState to Strands stop_reason
_STATE_TO_STOP_REASON: dict[TaskState, StopReason] = {
TaskState.completed: "end_turn",
TaskState.failed: "end_turn",
TaskState.canceled: "end_turn",
TaskState.rejected: "end_turn",
TaskState.input_required: "interrupt",
TaskState.auth_required: "interrupt",
}
def convert_input_to_message(prompt: AgentInput) -> A2AMessage:
"""Convert AgentInput to A2A Message.
Args:
prompt: Input in various formats (string, message list, or content blocks).
Returns:
A2AMessage ready to send to the remote agent.
Raises:
ValueError: If prompt format is unsupported.
"""
message_id = uuid4().hex
if isinstance(prompt, str):
return A2AMessage(
kind="message",
role=Role.user,
parts=[Part(TextPart(kind="text", text=prompt))],
message_id=message_id,
)
if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)):
# Check for interrupt responses - not supported in A2A
if "interruptResponse" in prompt[0]:
raise ValueError("InterruptResponseContent is not supported for A2AAgent")
if "role" in prompt[0]:
for msg in reversed(prompt):
if msg.get("role") == "user":
content = cast(list[ContentBlock], msg.get("content", []))
parts = convert_content_blocks_to_parts(content)
return A2AMessage(
kind="message",
role=Role.user,
parts=parts,
message_id=message_id,
)
else:
parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt))
return A2AMessage(
kind="message",
role=Role.user,
parts=parts,
message_id=message_id,
)
raise ValueError(f"Unsupported input type: {type(prompt)}")
def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]:
"""Convert Strands ContentBlocks to A2A Parts.
Args:
content_blocks: List of Strands content blocks.
Returns:
List of A2A Part objects.
"""
parts = []
for block in content_blocks:
if "text" in block:
parts.append(Part(TextPart(kind="text", text=block["text"])))
return parts
def _extract_task_state(response: A2AResponse) -> TaskState | None:
"""Extract the task state from an A2A response.
Args:
response: A2A response (either A2AMessage or tuple of task and update event).
Returns:
The TaskState if available, None otherwise.
"""
if isinstance(response, tuple) and len(response) == 2:
_task, update_event = response
if isinstance(update_event, TaskStatusUpdateEvent):
if update_event.status and hasattr(update_event.status, "state"):
return update_event.status.state
return None
def convert_response_to_agent_result(response: A2AResponse) -> AgentResult:
"""Convert A2A response to AgentResult.
Maps A2A task lifecycle states to appropriate Strands stop_reasons:
- completed → end_turn
- failed → end_turn (with error content)
- canceled → end_turn (with cancellation info)
- rejected → end_turn (with rejection info)
- input_required → interrupt (agent needs user input)
- auth_required → interrupt (agent needs authentication)
Args:
response: A2A response (either A2AMessage or tuple of task and update event).
Returns:
AgentResult with extracted content and metadata.
"""
content: list[ContentBlock] = []
task_state = _extract_task_state(response)
stop_reason: StopReason = _STATE_TO_STOP_REASON.get(task_state, "end_turn") if task_state else "end_turn"
if isinstance(response, tuple) and len(response) == 2:
task, update_event = response
# Handle artifact updates
if isinstance(update_event, TaskArtifactUpdateEvent):
if update_event.artifact and hasattr(update_event.artifact, "parts") and update_event.artifact.parts:
for part in update_event.artifact.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
# Handle status updates with messages
elif isinstance(update_event, TaskStatusUpdateEvent):
if (
update_event.status
and hasattr(update_event.status, "message")
and update_event.status.message
and update_event.status.message.parts
):
for part in update_event.status.message.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
# Use task.artifacts when no content was extracted from the event
if not content and task and hasattr(task, "artifacts") and task.artifacts is not None:
for artifact in task.artifacts:
if hasattr(artifact, "parts") and artifact.parts:
for part in artifact.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
elif isinstance(response, A2AMessage):
for part in response.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
message: Message = {
"role": "assistant",
"content": content,
}
# Build state dict with A2A metadata
state: dict[str, str] = {}
if task_state is not None:
state["a2a_task_state"] = task_state.value
return AgentResult(
stop_reason=stop_reason,
message=message,
metrics=EventLoopMetrics(),
state=state,
)