-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathstate.py
More file actions
127 lines (117 loc) · 4.1 KB
/
Copy pathstate.py
File metadata and controls
127 lines (117 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Dict

from agentex import AsyncAgentex
from agentex.types.state import State
from agentex.lib.utils.logging import make_logger
from agentex.lib.core.tracing.tracer import AsyncTracer
# Module-level logger named after this module (not referenced by StateService itself).
logger = make_logger(__name__)
class StateService:
    """Async facade over ``AsyncAgentex.states`` that traces each operation.

    Every public method opens a span named after the operation (when a tracer
    is configured). It records the call arguments as the span input and the
    resulting state's ``model_dump()`` as the span output. Without a tracer
    the client call is made directly and nothing is recorded.
    """

    def __init__(
        self, agentex_client: AsyncAgentex, tracer: AsyncTracer | None
    ):
        """Store the API client and the tracer.

        Args:
            agentex_client: Client whose ``states`` resource performs the requests.
            tracer: Tracer used to open spans. A falsy value (e.g. ``None``)
                disables tracing for every method.
        """
        self._agentex_client = agentex_client
        self._tracer = tracer

    @asynccontextmanager
    async def _span(
        self,
        name: str,
        span_input: dict[str, Any],
        trace_id: str | None,
        parent_span_id: str | None,
    ) -> AsyncIterator[Any]:
        """Open a span called *name*, or yield ``None`` when tracing is off.

        The yielded span may be falsy, so callers guard ``span.output``
        assignments with ``if span:``.
        """
        trace = self._tracer.trace(trace_id) if self._tracer else None
        if trace is None:
            # No tracer configured: run the operation untraced.
            yield None
            return
        async with trace.span(
            parent_id=parent_span_id,
            name=name,
            input=span_input,
        ) as span:
            yield span

    async def _fetch_state(
        self,
        state_id: str | None,
        task_id: str | None,
        agent_id: str | None,
    ) -> State | None:
        """Look up a state by id, or else by the ``(task_id, agent_id)`` pair.

        Returns:
            The matching state. For a ``(task_id, agent_id)`` lookup this is
            the first listed entry, or ``None`` when the list is empty.

        Raises:
            ValueError: if neither ``state_id`` nor both ``task_id`` and
                ``agent_id`` are provided.
        """
        if state_id:
            return await self._agentex_client.states.retrieve(state_id=state_id)
        if task_id and agent_id:
            states = await self._agentex_client.states.list(
                task_id=task_id,
                agent_id=agent_id,
            )
            return states[0] if states else None
        raise ValueError(
            "Must provide either state_id or both task_id and agent_id"
        )

    async def create_state(
        self,
        task_id: str,
        agent_id: str,
        state: dict[str, Any],
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> State:
        """Create a state for ``(task_id, agent_id)``, traced as ``create_state``.

        Args:
            task_id: Task the state belongs to.
            agent_id: Agent the state belongs to.
            state: State payload to store.
            trace_id: Trace to attach the span to.
            parent_span_id: Parent span of the ``create_state`` span.

        Returns:
            The state model returned by the client.
        """
        async with self._span(
            "create_state",
            {"task_id": task_id, "agent_id": agent_id, "state": state},
            trace_id,
            parent_span_id,
        ) as span:
            state_model = await self._agentex_client.states.create(
                task_id=task_id,
                agent_id=agent_id,
                state=state,
            )
            if span:
                span.output = state_model.model_dump()
            return state_model

    async def get_state(
        self,
        state_id: str | None = None,
        task_id: str | None = None,
        agent_id: str | None = None,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> State | None:
        """Fetch a state by id or by ``(task_id, agent_id)``, traced as ``get_state``.

        ``state_id`` takes precedence. Otherwise both ``task_id`` and
        ``agent_id`` are required, and the first matching state is returned.
        The same lookup rules apply whether or not a tracer is configured.

        Returns:
            The state, or ``None`` when the ``(task_id, agent_id)`` lookup
            finds nothing.

        Raises:
            ValueError: if neither lookup key is fully provided.
        """
        async with self._span(
            "get_state",
            {
                "state_id": state_id,
                "task_id": task_id,
                "agent_id": agent_id,
            },
            trace_id,
            parent_span_id,
        ) as span:
            state = await self._fetch_state(state_id, task_id, agent_id)
            if span:
                span.output = state.model_dump() if state else None
            return state

    async def update_state(
        self,
        state_id: str,
        task_id: str,
        agent_id: str,
        state: Dict[str, object],
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> State:
        """Replace the payload of state *state_id*, traced as ``update_state``.

        Args:
            state_id: Id of the state to update.
            task_id: Owning task; sent in the request body for older backends.
            agent_id: Owning agent; sent in the request body for older backends.
            state: New state payload.
            trace_id: Trace to attach the span to.
            parent_span_id: Parent span of the ``update_state`` span.

        Returns:
            The updated state model returned by the client.
        """
        async with self._span(
            "update_state",
            {
                "state_id": state_id,
                "task_id": task_id,
                "agent_id": agent_id,
                "state": state,
            },
            trace_id,
            parent_span_id,
        ) as span:
            # Send task_id/agent_id in the body for backends predating
            # scale-agentex#278, which still require them (newer ones ignore them).
            state_model = await self._agentex_client.states.update(
                state_id=state_id,
                state=state,
                extra_body={"task_id": task_id, "agent_id": agent_id},
            )
            if span:
                span.output = state_model.model_dump()
            return state_model

    async def delete_state(
        self,
        state_id: str,
        trace_id: str | None = None,
        parent_span_id: str | None = None,
    ) -> State:
        """Delete state *state_id*, traced as ``delete_state``.

        Returns:
            The state model returned by the client's delete call.
        """
        async with self._span(
            "delete_state",
            {"state_id": state_id},
            trace_id,
            parent_span_id,
        ) as span:
            state = await self._agentex_client.states.delete(state_id)
            if span:
                span.output = state.model_dump()
            return state