-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathacp.py.j2
More file actions
166 lines (132 loc) · 5.89 KB
/
Copy pathacp.py.j2
File metadata and controls
166 lines (132 loc) · 5.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""ACP handler for async Pydantic AI agent.
Uses the async ACP model with Redis streaming instead of HTTP yields.
Text and reasoning tokens stream as Redis deltas; tool requests and
responses are persisted as discrete full messages.
Multi-turn memory is persisted via ``adk.state``: on each turn we load the
previous pydantic-ai ``message_history`` from state, run the agent with it,
then save the updated history back. Without this, every turn would be a
fresh stateless run and the agent would forget the prior conversation.
"""
from __future__ import annotations
import os
from typing import Any, AsyncIterator
from dotenv import load_dotenv
load_dotenv()
from project.agent import create_agent
from pydantic_ai.run import AgentRunResultEvent
from pydantic_ai.messages import ModelMessagesTypeAdapter
import agentex.lib.adk as adk
from agentex.lib.adk import (
stream_pydantic_ai_events,
create_pydantic_ai_tracing_handler,
)
from agentex.protocol.acp import SendEventParams, CancelTaskParams, CreateTaskParams
from agentex.lib.types.fastacp import AsyncACPConfig
from agentex.lib.types.tracing import SGPTracingProcessorConfig
from agentex.lib.utils.logging import make_logger
from agentex.lib.utils.model_utils import BaseModel
from agentex.lib.sdk.fastacp.fastacp import FastACP
from agentex.lib.core.tracing.tracing_processor_manager import add_tracing_processor_config
logger = make_logger(__name__)

# Optionally export traces to SGP. The exporter is registered only when both
# the API key and the account id are present in the environment. Regardless
# of this, spans still reach the AgentEx backend through the default Agentex
# processor (lazy-initialised on first span), so they show up in the per-task
# spans dropdown out of the box.
SGP_API_KEY = os.getenv("SGP_API_KEY", "")
SGP_ACCOUNT_ID = os.getenv("SGP_ACCOUNT_ID", "")

if all((SGP_API_KEY, SGP_ACCOUNT_ID)):
    add_tracing_processor_config(
        SGPTracingProcessorConfig(
            sgp_api_key=SGP_API_KEY,
            sgp_account_id=SGP_ACCOUNT_ID,
            sgp_base_url=os.getenv("SGP_CLIENT_BASE_URL", ""),
        )
    )
# Async ACP instance built with the "base" config; the task handlers in this
# module register themselves on it through its ``on_task_*`` decorators.
acp = FastACP.create(acp_type="async", config=AsyncACPConfig(type="base"))
# Module-level cache for the agent; filled in by the first ``get_agent()`` call.
_agent = None


def get_agent():
    """Return the shared Pydantic AI agent, building it lazily.

    The agent is created with ``create_agent()`` on the first call and the
    same instance is handed back on every later call.
    """
    global _agent
    if _agent is not None:
        return _agent
    _agent = create_agent()
    return _agent
class ConversationState(BaseModel):
    """Conversation state stored per task through ``adk.state``.

    The pydantic-ai message history is kept as a JSON string produced by
    ``ModelMessagesTypeAdapter`` instead of a ``list[ModelMessage]`` field:
    ``ModelMessage`` is a discriminated union of runtime types, not a stable
    Pydantic schema, and the type adapter is pydantic-ai's official way to
    round-trip those objects through JSON.
    """

    # JSON-encoded message history; "[]" represents an empty history.
    history_json: str = "[]"
    # Counter incremented once for every handled user message of the task.
    turn_number: int = 0
@acp.on_task_create
async def handle_task_create(params: CreateTaskParams):
    """Create the empty per-task conversation state when a task is created.

    Args:
        params: Task-creation parameters; the task id and agent id are used
            as the key under which the fresh ``ConversationState`` is stored.
    """
    logger.info(f"Task created: {params.task.id}")

    # Start from the defaults: empty history, turn counter at zero.
    initial_state = ConversationState()
    await adk.state.create(
        task_id=params.task.id,
        agent_id=params.agent.id,
        state=initial_state,
    )
@acp.on_task_event_send
async def handle_task_event_send(params: SendEventParams):
    """Handle one user message: load prior history, run the agent, save history.

    Steps, in order:

    1. Echo the incoming message into the task's message history.
    2. Load the task's ``ConversationState`` (creating one if none exists)
       and deserialize the stored pydantic-ai message history.
    3. Run the agent inside a ``"Turn N"`` tracing span, streaming its events
       through ``stream_pydantic_ai_events``.
    4. Persist the updated state so the next turn continues the conversation.

    Args:
        params: Event parameters carrying the task, the agent and the event
            whose content is the user's message.
    """
    task_id = params.task.id
    agent_id = params.agent.id

    # Without content there is nothing to echo or to run the agent on; skip
    # the turn instead of failing with an AttributeError further down.
    # NOTE(review): assumes an event may legitimately arrive with no content.
    content = params.event.content
    if content is None:
        logger.warning(f"Ignoring event without content for task {task_id}")
        return
    user_message = content.content

    agent = get_agent()
    logger.info(f"Processing message for task {task_id}")

    # Echo the user's message into the task history.
    await adk.messages.create(task_id=task_id, content=content)

    # Load prior conversation state. Fall back to a fresh state if missing
    # (e.g. the task wasn't initialised through on_task_create).
    task_state = await adk.state.get_by_task_and_agent(task_id=task_id, agent_id=agent_id)
    if task_state is None:
        state = ConversationState()
        task_state = await adk.state.create(task_id=task_id, agent_id=agent_id, state=state)
    else:
        state = ConversationState.model_validate(task_state.state)

    state.turn_number += 1
    previous_messages = ModelMessagesTypeAdapter.validate_json(state.history_json)

    async with adk.tracing.span(
        trace_id=task_id,
        task_id=task_id,
        name=f"Turn {state.turn_number}",
        input={"message": user_message},
        data={"__span_type__": "AGENT_WORKFLOW"},
    ) as turn_span:
        tracing_handler = create_pydantic_ai_tracing_handler(
            trace_id=task_id,
            parent_span_id=turn_span.id if turn_span else None,
            task_id=task_id,
        )

        # Wrap the pydantic-ai event stream so we can capture the final
        # AgentRunResultEvent (which carries the full message list for the
        # next turn) without changing the streaming helper's signature.
        captured_messages: list[Any] = []

        async def tee_messages(upstream: AsyncIterator[Any]) -> AsyncIterator[Any]:
            """Pass every event through, remembering the final run's messages."""
            async for event in upstream:
                if isinstance(event, AgentRunResultEvent):
                    # Slice assignment mutates the enclosing list in place,
                    # so no ``nonlocal`` declaration is needed.
                    captured_messages[:] = event.result.all_messages()
                yield event

        async with agent.run_stream_events(user_message, message_history=previous_messages) as stream:
            final_output = await stream_pydantic_ai_events(
                tee_messages(stream), task_id, tracing_handler=tracing_handler
            )

        # Save the updated message history so the next turn picks up here.
        # The history is only replaced when the run produced one; the state
        # itself is always written so the incremented turn counter persists.
        if captured_messages:
            state.history_json = ModelMessagesTypeAdapter.dump_json(captured_messages).decode()
        await adk.state.update(
            state_id=task_state.id,
            task_id=task_id,
            agent_id=agent_id,
            state=state,
        )

        if turn_span:
            turn_span.output = {"final_output": final_output}
@acp.on_task_cancel
async def handle_task_canceled(params: CancelTaskParams):
    """Log that a task was canceled; nothing else is done here.

    Args:
        params: Cancellation parameters; only the task id is read.
    """
    canceled_task_id = params.task.id
    logger.info(f"Task canceled: {canceled_task_id}")