forked from a2aproject/A2A
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathremote_agent_connection.py
More file actions
94 lines (85 loc) · 3.17 KB
/
Copy pathremote_agent_connection.py
File metadata and controls
94 lines (85 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Callable
import uuid
from common.types import (
AgentCard,
Task,
TaskSendParams,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
TaskStatus,
TaskState,
)
from common.client import A2AClient
# Payload types that can be handed to a task-update callback.
TaskCallbackArg = (
    Task
    | TaskStatusUpdateEvent
    | TaskArtifactUpdateEvent
)
# Hook invoked with one such payload; annotated to return a Task.
TaskUpdateCallback = Callable[[TaskCallbackArg], Task]
class RemoteAgentConnections:
    """A class to hold the connections to the remote agents.

    Each instance wraps an ``A2AClient`` built from a single remote agent's
    ``AgentCard`` and sends tasks to that agent. It uses the streaming API
    when the card advertises ``capabilities.streaming`` and a single
    request/response call otherwise.
    """

    def __init__(self, agent_card: AgentCard):
        self.agent_client = A2AClient(agent_card)
        self.card = agent_card
        # The three attributes below are not read or written anywhere else in
        # this class; presumably callers manage them -- verify before removing.
        self.conversation_name = None
        self.conversation = None
        self.pending_tasks = set()

    def get_agent(self) -> AgentCard:
        """Return the ``AgentCard`` this connection was created with."""
        return self.card

    @staticmethod
    def _propagate_metadata(result, request: TaskSendParams) -> None:
        """Propagate request metadata onto ``result`` and stamp a message id.

        Merges ``request.metadata`` into ``result.metadata``. If ``result``
        carries a non-empty ``status.message``, this also:

        * merges ``request.message.metadata`` into that message;
        * preserves any existing ``message_id`` as ``last_message_id``;
        * assigns a fresh random ``message_id``.

        Objects without the relevant attributes (including ``None``) are left
        untouched.
        """
        merge_metadata(result, request)
        status = getattr(result, 'status', None)
        message = getattr(status, 'message', None)
        if not message:
            return
        # NOTE(review): merge_metadata lets request.message's keys win on
        # collision, so a 'message_id' seen below may be the outgoing
        # request's id rather than the remote reply's -- confirm intended.
        merge_metadata(message, request.message)
        if not message.metadata:
            message.metadata = {}
        if 'message_id' in message.metadata:
            message.metadata['last_message_id'] = message.metadata['message_id']
        # Every propagated status message gets its own unique id.
        message.metadata['message_id'] = str(uuid.uuid4())

    async def send_task(
        self,
        request: TaskSendParams,
        task_callback: TaskUpdateCallback | None,
    ) -> Task | None:
        """Send ``request`` to the remote agent, reporting progress via callback.

        Args:
            request: Parameters of the task to send. Its ``metadata`` (and
                ``message.metadata``) are propagated onto every result.
            task_callback: Optional hook. When streaming, it is first called
                with a locally built SUBMITTED ``Task`` echoing
                ``request.message``, then once per streamed event. When not
                streaming, it is called once with the response result.

        Returns:
            Streaming: the value ``task_callback`` returned for the last
            processed event. This is ``None`` when no callback is given or no
            event arrives; the SUBMITTED-task call's return value is discarded.
            Non-streaming: ``response.result`` from the client.
        """
        if self.card.capabilities.streaming:
            task = None
            if task_callback:
                task_callback(Task(
                    id=request.id,
                    sessionId=request.sessionId,
                    status=TaskStatus(
                        state=TaskState.SUBMITTED,
                        message=request.message,
                    ),
                    history=[request.message],
                ))
            async for response in self.agent_client.send_task_streaming(
                request.model_dump()
            ):
                self._propagate_metadata(response.result, request)
                if task_callback:
                    task = task_callback(response.result)
                # Stop consuming once an event marks itself as final.
                if getattr(response.result, 'final', False):
                    break
            return task

        # Non-streaming: one request, one response.
        response = await self.agent_client.send_task(request.model_dump())
        self._propagate_metadata(response.result, request)
        if task_callback:
            task_callback(response.result)
        return response.result
def merge_metadata(target, source):
    """Merge ``source.metadata`` into ``target.metadata``.

    Does nothing if either object lacks a ``metadata`` attribute, or if
    ``source.metadata`` is empty or ``None``.

    * If ``target`` already has non-empty metadata, it is updated in place,
      and ``source`` values win on key collisions.
    * Otherwise ``target.metadata`` is replaced with a shallow copy of
      ``source.metadata``, so the two objects do not share one dict.

    Args:
        target: Object whose ``metadata`` receives the entries.
        source: Object whose ``metadata`` supplies the entries.
    """
    if not hasattr(target, 'metadata') or not hasattr(source, 'metadata'):
        return
    if not source.metadata:
        return
    if target.metadata:
        target.metadata.update(source.metadata)
    else:
        # dict(m) copies any mapping; dict(**m) would reject non-string keys.
        target.metadata = dict(source.metadata)