Skip to content

Commit f160c4a

Browse files
authored
fix: align create-time run contract with official LangGraph API (#29)
* fix: align create-time run contract with langgraph api * fix: repair create-time stream continuation contract * fix: align join and stream continuation contract * test: fix join expectation in live compat suite * fix: reject unsupported run control fields
1 parent fc3bad6 commit f160c4a

13 files changed

Lines changed: 1360 additions & 170 deletions

src/agentseek_api/api/runs.py

Lines changed: 364 additions & 17 deletions
Large diffs are not rendered by default.

src/agentseek_api/api/stateless_runs.py

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
11
from datetime import UTC, datetime
2+
from typing import Any
23

34
from sqlalchemy import select
45

56
from fastapi import APIRouter, Depends, Response
7+
from fastapi.responses import JSONResponse, StreamingResponse
68

79
from agentseek_api.core.auth_deps import get_current_user
810
from agentseek_api.core.database import db_manager
911
from agentseek_api.core.orm import Run, Thread
10-
from agentseek_api.models.api import RunCreate, RunRead, RunsCancelRequest, ThreadCreate
12+
from agentseek_api.models.api import (
13+
RunCreateStateless,
14+
RunCreateStreamingStateless,
15+
RunRead,
16+
RunsCancelRequest,
17+
ThreadCreate,
18+
)
1119
from agentseek_api.models.auth import User
1220
from agentseek_api.services.thread_service import create_thread_for_user
13-
from agentseek_api.api.runs import create_run, create_run_stream, wait_run
21+
from agentseek_api.api.runs import (
22+
_build_create_run_stream_response,
23+
_normalize_stream_modes,
24+
_protocol_stream_location,
25+
_stream_response_headers,
26+
_validate_supported_run_controls,
27+
_wait_response_payload,
28+
create_run,
29+
wait_run,
30+
)
1431

1532
router = APIRouter(prefix="/runs", tags=["Stateless Runs"])
1633

@@ -23,27 +40,69 @@ async def _best_effort_delete_for_runs(run_ids: list[str]) -> None:
2340

2441

2542
@router.post("", response_model=RunRead)
26-
async def create_stateless_run(payload: RunCreate, user: User = Depends(get_current_user)) -> RunRead:
43+
async def create_stateless_run(payload: RunCreateStateless, user: User = Depends(get_current_user)) -> RunRead:
44+
_validate_supported_run_controls(payload, stateless=True)
2745
thread = await create_thread_for_user(payload=ThreadCreate(metadata={"stateless": True}), user=user)
2846
return await create_run(thread.thread_id, payload, user)
2947

3048

31-
@router.post("/wait", response_model=RunRead)
32-
async def create_stateless_run_wait(payload: RunCreate, user: User = Depends(get_current_user)) -> RunRead:
49+
@router.post(
50+
"/wait",
51+
response_class=JSONResponse,
52+
responses={
53+
200: {
54+
"content": {"application/json": {"schema": {}}},
55+
"headers": {
56+
"Location": {"schema": {"type": "string"}},
57+
"Content-Location": {"schema": {"type": "string"}},
58+
},
59+
}
60+
},
61+
)
62+
async def create_stateless_run_wait(payload: RunCreateStreamingStateless, user: User = Depends(get_current_user)) -> JSONResponse:
63+
_normalize_stream_modes(payload.stream_mode)
3364
created = await create_stateless_run(payload, user)
34-
if created.status in {"success", "error", "interrupted"}:
35-
return created
36-
return await wait_run(created.thread_id, created.run_id, user)
37-
38-
39-
@router.post("/stream")
40-
async def create_stateless_run_stream(payload: RunCreate, user: User = Depends(get_current_user)):
65+
final_run = created if created.status in {"success", "error", "interrupted"} else await wait_run(created.thread_id, created.run_id, user)
66+
return JSONResponse(
67+
await _wait_response_payload(final_run, user=user),
68+
headers=_stream_response_headers(
69+
location=f"/threads/{created.thread_id}/runs/{created.run_id}/join",
70+
content_location=f"/threads/{created.thread_id}/runs/{created.run_id}",
71+
),
72+
)
73+
74+
75+
@router.post(
76+
"/stream",
77+
response_class=StreamingResponse,
78+
responses={
79+
200: {
80+
"content": {"text/event-stream": {"schema": {"type": "string"}}},
81+
"headers": {
82+
"Location": {"schema": {"type": "string"}},
83+
"Content-Location": {"schema": {"type": "string"}},
84+
},
85+
}
86+
},
87+
)
88+
async def create_stateless_run_stream(payload: RunCreateStreamingStateless, user: User = Depends(get_current_user)):
89+
stream_modes = _normalize_stream_modes(payload.stream_mode)
90+
_validate_supported_run_controls(payload, stateless=True)
4191
thread = await create_thread_for_user(payload=ThreadCreate(metadata={"stateless": True}), user=user)
42-
return await create_run_stream(thread.thread_id, payload, user)
92+
created = await create_run(thread.thread_id, payload, user)
93+
return _build_create_run_stream_response(
94+
thread_id=thread.thread_id,
95+
created=created,
96+
user=user,
97+
stream_modes=stream_modes,
98+
after_seq=0,
99+
location=_protocol_stream_location(thread_id=thread.thread_id, run_id=created.run_id, stream_modes=stream_modes),
100+
content_location=f"/threads/{thread.thread_id}/runs/{created.run_id}",
101+
)
43102

44103

45104
@router.post("/batch", response_model=list[RunRead])
46-
async def create_run_batch(payload: list[RunCreate], user: User = Depends(get_current_user)) -> list[RunRead]:
105+
async def create_run_batch(payload: list[RunCreateStateless], user: User = Depends(get_current_user)) -> list[RunRead]:
47106
return [await create_stateless_run(item, user) for item in payload]
48107

49108

src/agentseek_api/models/api.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from datetime import datetime
22
from typing import Any
3+
from typing import Literal
34

45
from pydantic import BaseModel, ConfigDict, Field
56

@@ -98,13 +99,75 @@ class ThreadRead(BaseModel):
9899
status: str = "idle"
99100

100101

101-
class RunCreate(BaseModel):
102+
RunStreamMode = Literal[
103+
"values",
104+
"messages",
105+
"messages-tuple",
106+
"tasks",
107+
"checkpoints",
108+
"updates",
109+
"events",
110+
"debug",
111+
"custom",
112+
]
113+
RunInterrupt = Literal["*"] | list[str]
114+
RunDurability = Literal["sync", "async", "exit"]
115+
RunOnDisconnect = Literal["cancel", "continue"]
116+
RunOnCompletion = Literal["delete", "keep"]
117+
RunMultitaskStrategy = Literal["reject", "rollback", "interrupt", "enqueue"]
118+
RunIfNotExists = Literal["create", "reject"]
119+
120+
121+
class RunCreateStateful(BaseModel):
122+
model_config = ConfigDict(extra="allow")
123+
102124
assistant_id: str
103-
input: Any
125+
checkpoint: dict[str, Any] | None = None
126+
input: Any = None
127+
command: dict[str, Any] | None = None
104128
metadata: dict[str, Any] = Field(default_factory=dict)
105129
config: dict[str, Any] = Field(default_factory=dict)
106130
context: dict[str, Any] = Field(default_factory=dict)
107-
multitask_strategy: str = "enqueue"
131+
webhook: str | None = None
132+
interrupt_before: RunInterrupt | None = None
133+
interrupt_after: RunInterrupt | None = None
134+
stream_mode: RunStreamMode | list[RunStreamMode] | None = Field(default_factory=lambda: ["values"])
135+
stream_subgraphs: bool = False
136+
stream_resumable: bool = False
137+
feedback_keys: list[str] | None = None
138+
multitask_strategy: RunMultitaskStrategy = "enqueue"
139+
if_not_exists: RunIfNotExists = "reject"
140+
after_seconds: float | None = None
141+
checkpoint_during: bool = False
142+
durability: RunDurability = "async"
143+
144+
145+
class RunCreateStreamingStateful(RunCreateStateful):
146+
on_disconnect: RunOnDisconnect = "continue"
147+
148+
149+
class RunCreateStateless(BaseModel):
150+
model_config = ConfigDict(extra="allow")
151+
152+
assistant_id: str
153+
input: Any = None
154+
command: dict[str, Any] | None = None
155+
metadata: dict[str, Any] = Field(default_factory=dict)
156+
config: dict[str, Any] = Field(default_factory=dict)
157+
context: dict[str, Any] = Field(default_factory=dict)
158+
webhook: str | None = None
159+
stream_mode: RunStreamMode | list[RunStreamMode] | None = Field(default_factory=lambda: ["values"])
160+
feedback_keys: list[str] | None = None
161+
stream_subgraphs: bool = False
162+
stream_resumable: bool = False
163+
on_completion: RunOnCompletion = "keep"
164+
after_seconds: float | None = None
165+
checkpoint_during: bool = False
166+
durability: RunDurability = "async"
167+
168+
169+
class RunCreateStreamingStateless(RunCreateStateless):
170+
on_disconnect: RunOnDisconnect = "continue"
108171

109172

110173
class RunsCancelRequest(BaseModel):

src/agentseek_api/services/run_executor.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,9 @@ class _OpenMessage:
128128
open_blocks: dict[int, str] = field(default_factory=dict)
129129
text_contents: dict[int, str] = field(default_factory=dict)
130130

131-
def __init__(self, *, thread_id: str) -> None:
131+
def __init__(self, *, thread_id: str, run_id: str) -> None:
132132
self.thread_id = thread_id
133+
self.run_id = run_id
133134
self._open_message_ids: dict[str, _ProtocolMessageStreamState._OpenMessage] = {}
134135
self.saw_live_messages = False
135136

@@ -148,6 +149,7 @@ def _finish_blocks(
148149
self.thread_id,
149150
index=index,
150151
namespace=effective_namespace,
152+
run_id=self.run_id,
151153
)
152154
del state.open_blocks[index]
153155

@@ -166,6 +168,7 @@ def _publish_text_block(
166168
index=index,
167169
content={"type": "text", "text": ""},
168170
namespace=effective_namespace,
171+
run_id=self.run_id,
169172
)
170173
state.open_blocks[index] = "text"
171174
previous_text = state.text_contents.get(index, "")
@@ -178,6 +181,7 @@ def _publish_text_block(
178181
index=index,
179182
delta={"type": "text-delta", "text": delta_text},
180183
namespace=effective_namespace,
184+
run_id=self.run_id,
181185
)
182186
state.text_contents[index] = text
183187

@@ -197,13 +201,15 @@ def _publish_nontext_block(
197201
index=index,
198202
content=block,
199203
namespace=effective_namespace,
204+
run_id=self.run_id,
200205
)
201206
if final:
202207
publish_content_block_finish(
203208
self.thread_id,
204209
index=index,
205210
content=block,
206211
namespace=effective_namespace,
212+
run_id=self.run_id,
207213
)
208214
return
209215
state.open_blocks[index] = str(block.get("type", "block"))
@@ -214,13 +220,15 @@ def _publish_nontext_block(
214220
index=index,
215221
delta=block,
216222
namespace=effective_namespace,
223+
run_id=self.run_id,
217224
)
218225
if final:
219226
publish_content_block_finish(
220227
self.thread_id,
221228
index=index,
222229
content=block,
223230
namespace=effective_namespace,
231+
run_id=self.run_id,
224232
)
225233
del state.open_blocks[index]
226234

@@ -239,6 +247,7 @@ def publish_blocks(
239247
message_id=message_id,
240248
role=role,
241249
namespace=namespace,
250+
run_id=self.run_id,
242251
)
243252
state = self._OpenMessage(
244253
role=role,
@@ -317,7 +326,7 @@ def finish_all(self, *, namespace: list[str] | None = None) -> None:
317326
state = self._open_message_ids.pop(message_id)
318327
message_namespace = state.namespace or namespace
319328
self._finish_blocks(state, namespace=message_namespace)
320-
publish_message_complete(self.thread_id, namespace=message_namespace)
329+
publish_message_complete(self.thread_id, namespace=message_namespace, run_id=self.run_id)
321330

322331
async def afinish_blocks(
323332
self,
@@ -334,6 +343,7 @@ async def afinish_blocks(
334343
self.thread_id,
335344
index=index,
336345
namespace=effective_namespace,
346+
run_id=self.run_id,
337347
)
338348
del state.open_blocks[index]
339349

@@ -352,6 +362,7 @@ async def apublish_text_block(
352362
index=index,
353363
content={"type": "text", "text": ""},
354364
namespace=effective_namespace,
365+
run_id=self.run_id,
355366
)
356367
state.open_blocks[index] = "text"
357368
previous_text = state.text_contents.get(index, "")
@@ -364,6 +375,7 @@ async def apublish_text_block(
364375
index=index,
365376
delta={"type": "text-delta", "text": delta_text},
366377
namespace=effective_namespace,
378+
run_id=self.run_id,
367379
)
368380
state.text_contents[index] = text
369381

@@ -383,13 +395,15 @@ async def apublish_nontext_block(
383395
index=index,
384396
content=block,
385397
namespace=effective_namespace,
398+
run_id=self.run_id,
386399
)
387400
if final:
388401
await apublish_content_block_finish(
389402
self.thread_id,
390403
index=index,
391404
content=block,
392405
namespace=effective_namespace,
406+
run_id=self.run_id,
393407
)
394408
return
395409
state.open_blocks[index] = str(block.get("type", "block"))
@@ -400,13 +414,15 @@ async def apublish_nontext_block(
400414
index=index,
401415
delta=block,
402416
namespace=effective_namespace,
417+
run_id=self.run_id,
403418
)
404419
if final:
405420
await apublish_content_block_finish(
406421
self.thread_id,
407422
index=index,
408423
content=block,
409424
namespace=effective_namespace,
425+
run_id=self.run_id,
410426
)
411427
del state.open_blocks[index]
412428

@@ -425,6 +441,7 @@ async def apublish_blocks(
425441
message_id=message_id,
426442
role=role,
427443
namespace=namespace,
444+
run_id=self.run_id,
428445
)
429446
state = self._OpenMessage(
430447
role=role,
@@ -503,7 +520,7 @@ async def afinish_all(self, *, namespace: list[str] | None = None) -> None:
503520
state = self._open_message_ids.pop(message_id)
504521
message_namespace = state.namespace or namespace
505522
await self.afinish_blocks(state, namespace=message_namespace)
506-
await apublish_message_complete(self.thread_id, namespace=message_namespace)
523+
await apublish_message_complete(self.thread_id, namespace=message_namespace, run_id=self.run_id)
507524

508525

509526
def _protocol_blocks_for_message(message: BaseMessage) -> list[dict[str, Any]]:
@@ -734,7 +751,7 @@ async def execute_run(
734751
result: Any = None
735752
interrupt_chunk: Any = None
736753
interrupt_namespace: list[str] | None = None
737-
protocol_messages = _ProtocolMessageStreamState(thread_id=thread_id)
754+
protocol_messages = _ProtocolMessageStreamState(thread_id=thread_id, run_id=run_id)
738755
async for stream_event in graph.astream_events(invocation, config, version="v2"):
739756
protocol_namespace = _protocol_namespace_for_event(stream_event)
740757
for event_name, event_payload in _translate_stream_events(stream_event):
@@ -804,7 +821,12 @@ async def execute_run(
804821
if isinstance(normalized_chunk, dict):
805822
normalized_chunk.pop("__interrupt__", None)
806823
if normalized_chunk:
807-
await apublish_updates_event(thread_id, values=normalized_chunk, namespace=protocol_namespace)
824+
await apublish_updates_event(
825+
thread_id,
826+
values=normalized_chunk,
827+
namespace=protocol_namespace,
828+
run_id=run_id,
829+
)
808830
if stream_event.get("event") == "on_chain_end" and _is_root_stream_event(stream_event):
809831
data = stream_event.get("data", {})
810832
if isinstance(data, dict) and "output" in data:
@@ -818,7 +840,12 @@ async def execute_run(
818840
else:
819841
await apublish_message_transcript(thread_id, run_id=run_id, messages=messages)
820842
await protocol_messages.afinish_all()
821-
await apublish_values_event(thread_id, values=normalized_result, namespace=protocol_namespace)
843+
await apublish_values_event(
844+
thread_id,
845+
values=normalized_result,
846+
namespace=protocol_namespace,
847+
run_id=run_id,
848+
)
822849

823850
if interrupt_chunk is not None:
824851
if isinstance(result, dict):

0 commit comments

Comments
 (0)