-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathworkflow.py.j2
More file actions
146 lines (123 loc) · 6.21 KB
/
Copy pathworkflow.py.j2
File metadata and controls
146 lines (123 loc) · 6.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Temporal workflow for {{ agent_name }}.
The workflow holds task state durably across crashes. Its signal handler
delegates the actual agent run to ``temporal_agent.run(...)`` — which
internally schedules model and tool activities, each independently
durable. The ``event_stream_handler`` registered on ``temporal_agent``
pushes streaming deltas to Redis while the model activity runs.
Multi-turn memory is kept on the workflow instance itself
(``self._message_history``). Temporal's workflow state is already durable
and replay-safe, so unlike the async-base template we don't need an
external ``adk.state`` round-trip — the message list survives crashes
because Temporal replays the activity results that produced it.
"""
from __future__ import annotations

import os
import json
import asyncio
from typing import TYPE_CHECKING

from temporalio import workflow
from project.agent import TaskDeps, temporal_agent

from agentex.lib import adk
from agentex.protocol.acp import SendEventParams, CreateTaskParams
from agentex.lib.types.tracing import SGPTracingProcessorConfig
from agentex.lib.utils.logging import make_logger
from agentex.types.text_content import TextContent
from agentex.lib.environment_variables import EnvironmentVariables
from agentex.lib.core.temporal.types.workflow import SignalName
from agentex.lib.core.temporal.workflows.workflow import BaseWorkflow
from agentex.lib.core.tracing.tracing_processor_manager import add_tracing_processor_config

if TYPE_CHECKING:
    from pydantic_ai.messages import ModelMessage
def _register_sgp_tracing(api_key: str, account_id: str) -> None:
    """Register the SGP tracing exporter when both SGP credentials are non-empty.

    Spans also reach the AgentEx backend via the default Agentex processor,
    which is lazy-initialised on the first span; this only adds the SGP one.

    Args:
        api_key: Value of ``SGP_API_KEY`` ("" when unset).
        account_id: Value of ``SGP_ACCOUNT_ID`` ("" when unset).
    """
    if not (api_key and account_id):
        return
    # The base URL is only looked up once we know the exporter is wanted.
    sgp_config = SGPTracingProcessorConfig(
        sgp_api_key=api_key,
        sgp_account_id=account_id,
        sgp_base_url=os.environ.get("SGP_CLIENT_BASE_URL", ""),
    )
    add_tracing_processor_config(sgp_config)


SGP_API_KEY = os.environ.get("SGP_API_KEY", "")
SGP_ACCOUNT_ID = os.environ.get("SGP_ACCOUNT_ID", "")
_register_sgp_tracing(SGP_API_KEY, SGP_ACCOUNT_ID)


def _require_env(value: object, var_name: str) -> None:
    """Fail fast at import time when a required setting is ``None``.

    Raises:
        ValueError: If ``value`` is ``None``.
    """
    if value is None:
        raise ValueError(f"Environment variable {var_name} is not set")


environment_variables = EnvironmentVariables.refresh()
# Both values are consumed below: WORKFLOW_NAME by @workflow.defn and
# AGENT_NAME as the workflow's display name.
_require_env(environment_variables.WORKFLOW_NAME, "WORKFLOW_NAME")
_require_env(environment_variables.AGENT_NAME, "AGENT_NAME")

logger = make_logger(__name__)
@workflow.defn(name=environment_variables.WORKFLOW_NAME)
class {{ workflow_class }}(BaseWorkflow):
"""Long-running Temporal workflow that delegates each turn to a Pydantic AI TemporalAgent.
The ``__pydantic_ai_agents__`` attribute is the marker the
``PydanticAIPlugin`` looks for at worker startup: it pulls
``temporal_agent.temporal_activities`` off this list and registers
every model/tool activity on the worker automatically — so we don't
have to enumerate activities by hand in ``run_worker.py``.
"""
__pydantic_ai_agents__ = [temporal_agent]
def __init__(self):
super().__init__(display_name=environment_variables.AGENT_NAME)
self._complete_task = False
self._turn_number = 0
# Conversation history accumulated across turns. Each entry is a
# pydantic-ai ``ModelMessage``. Temporal replays the activity that
# produced these messages, so the list is rebuilt deterministically
# if the workflow ever recovers from a crash.
self._message_history: list["ModelMessage"] = []
@workflow.signal(name=SignalName.RECEIVE_EVENT)
async def on_task_event_send(self, params: SendEventParams) -> None:
"""Handle a new user message: echo it, then run the agent durably."""
logger.info(f"Received task event: {params.task.id}")
self._turn_number += 1
# Echo the user's message so it shows up in the UI as a chat bubble.
await adk.messages.create(task_id=params.task.id, content=params.event.content)
async with adk.tracing.span(
trace_id=params.task.id,
task_id=params.task.id,
name=f"Turn {self._turn_number}",
input={"message": params.event.content.content},
) as span:
# temporal_agent.run() is the magic line. Internally it schedules
# a model activity (LLM HTTP call) and, for each tool the model
# invokes, a separate tool activity. Each is independently
# durable and retried. While the model activity runs, the
# event_stream_handler on temporal_agent pushes deltas to Redis
# so the UI sees tokens stream live.
#
# Passing ``message_history`` makes the run remember prior turns;
# without it the agent would respond to each user message as if
# it had never seen the conversation before.
result = await temporal_agent.run(
params.event.content.content,
message_history=self._message_history,
deps=TaskDeps(
task_id=params.task.id,
parent_span_id=span.id if span else None,
),
)
# Persist the new full history (user + assistant + any tool
# rounds) so the next turn picks up from here.
self._message_history = list(result.all_messages())
if span:
span.output = {"final_output": result.output}
@workflow.run
async def on_task_create(self, params: CreateTaskParams) -> str:
"""Workflow entry point — keep the conversation alive for incoming signals."""
logger.info(f"Task created: {params.task.id}")
await adk.messages.create(
task_id=params.task.id,
content=TextContent(
author="agent",
content=(
f"Task initialized with params:\n{json.dumps(params.params, indent=2)}\n"
f"Send me a message and I'll respond using a Pydantic AI agent backed by Temporal."
),
),
)
await workflow.wait_condition(lambda: self._complete_task, timeout=None)
return "Task completed"
@workflow.signal
async def complete_task_signal(self) -> None:
"""Graceful workflow shutdown signal."""
logger.info("Received complete_task signal")
self._complete_task = True