Skip to content

Commit 509ff7d

Browse files
authored
Add real-provider create-time contract coverage (#30)
Squash merge PR #30 after green default CI and Live Provider Streaming.
1 parent f160c4a commit 509ff7d

14 files changed

Lines changed: 1001 additions & 142 deletions

src/agentseek_api/api/runs.py

Lines changed: 178 additions & 47 deletions
Large diffs are not rendered by default.

src/agentseek_api/api/stateless_runs.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from datetime import UTC, datetime
2-
from typing import Any
32

43
from sqlalchemy import select
54

65
from fastapi import APIRouter, Depends, Response
7-
from fastapi.responses import JSONResponse, StreamingResponse
6+
from fastapi.responses import StreamingResponse
87

98
from agentseek_api.core.auth_deps import get_current_user
109
from agentseek_api.core.database import db_manager
@@ -24,9 +23,8 @@
2423
_protocol_stream_location,
2524
_stream_response_headers,
2625
_validate_supported_run_controls,
27-
_wait_response_payload,
26+
_wait_json_stream_response,
2827
create_run,
29-
wait_run,
3028
)
3129

3230
router = APIRouter(prefix="/runs", tags=["Stateless Runs"])
@@ -48,7 +46,7 @@ async def create_stateless_run(payload: RunCreateStateless, user: User = Depends
4846

4947
@router.post(
5048
"/wait",
51-
response_class=JSONResponse,
49+
response_class=StreamingResponse,
5250
responses={
5351
200: {
5452
"content": {"application/json": {"schema": {}}},
@@ -59,12 +57,12 @@ async def create_stateless_run(payload: RunCreateStateless, user: User = Depends
5957
}
6058
},
6159
)
62-
async def create_stateless_run_wait(payload: RunCreateStreamingStateless, user: User = Depends(get_current_user)) -> JSONResponse:
60+
async def create_stateless_run_wait(payload: RunCreateStreamingStateless, user: User = Depends(get_current_user)) -> StreamingResponse:
6361
_normalize_stream_modes(payload.stream_mode)
6462
created = await create_stateless_run(payload, user)
65-
final_run = created if created.status in {"success", "error", "interrupted"} else await wait_run(created.thread_id, created.run_id, user)
66-
return JSONResponse(
67-
await _wait_response_payload(final_run, user=user),
63+
return _wait_json_stream_response(
64+
run=created,
65+
user=user,
6866
headers=_stream_response_headers(
6967
location=f"/threads/{created.thread_id}/runs/{created.run_id}/join",
7068
content_location=f"/threads/{created.thread_id}/runs/{created.run_id}",

src/agentseek_api/api/threads.py

Lines changed: 99 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
load_thread_stream_events,
2828
parse_last_event_id,
2929
)
30+
from agentseek_api.services.sse import iter_with_sse_keepalives, sse_keepalive_comment
3031
from agentseek_api.services.thread_checkpoint_store import (
3132
checkpoint_to_payload,
3233
copy_checkpoints,
@@ -245,6 +246,59 @@ def _empty_thread_state_payload(thread: Thread) -> dict[str, object]:
245246
}
246247

247248

249+
def _checkpoint_namespace(payload: dict[str, object]) -> str:
250+
checkpoint = payload.get("checkpoint")
251+
if not isinstance(checkpoint, dict):
252+
return ""
253+
checkpoint_ns = checkpoint.get("checkpoint_ns", "")
254+
return str(checkpoint_ns) if checkpoint_ns is not None else ""
255+
256+
257+
async def _find_thread_state_payload(
258+
*,
259+
thread_id: str,
260+
user: User,
261+
checkpoint_ns: str | None = None,
262+
) -> dict[str, object] | None:
263+
session_factory = db_manager.get_session_factory()
264+
async with session_factory() as session:
265+
thread = await session.scalar(select(Thread).where(Thread.thread_id == thread_id, Thread.user_id == user.identity))
266+
if thread is None:
267+
raise HTTPException(status_code=404, detail="Thread not found")
268+
runs = (
269+
await session.scalars(
270+
select(Run).where(Run.thread_id == thread_id, Run.user_id == user.identity).order_by(Run.created_at.asc())
271+
)
272+
).all()
273+
visible = _visible_checkpoint_payloads(
274+
thread,
275+
runs,
276+
[checkpoint_to_payload(item) for item in await list_checkpoints(thread_id)],
277+
)
278+
if checkpoint_ns is not None:
279+
visible = [payload for payload in visible if _checkpoint_namespace(payload) == checkpoint_ns]
280+
if not visible:
281+
return None
282+
return visible[0]
283+
284+
285+
async def _get_thread_state_payload(
286+
*,
287+
thread_id: str,
288+
user: User,
289+
checkpoint_ns: str | None = None,
290+
) -> dict[str, object]:
291+
payload = await _find_thread_state_payload(thread_id=thread_id, user=user, checkpoint_ns=checkpoint_ns)
292+
if payload is not None:
293+
return payload
294+
session_factory = db_manager.get_session_factory()
295+
async with session_factory() as session:
296+
thread = await session.scalar(select(Thread).where(Thread.thread_id == thread_id, Thread.user_id == user.identity))
297+
if thread is None:
298+
raise HTTPException(status_code=404, detail="Thread not found")
299+
return _empty_thread_state_payload(thread)
300+
301+
248302
@router.post("", response_model=ThreadRead)
249303
async def create_thread(payload: ThreadCreate, user: User = Depends(get_current_user)) -> ThreadRead:
250304
session_factory = db_manager.get_session_factory()
@@ -437,25 +491,7 @@ async def copy_thread(thread_id: str, user: User = Depends(get_current_user)) ->
437491

438492
@router.get("/{thread_id}/state")
439493
async def get_thread_state(thread_id: str, user: User = Depends(get_current_user)) -> dict[str, object]:
440-
session_factory = db_manager.get_session_factory()
441-
async with session_factory() as session:
442-
thread = await session.scalar(select(Thread).where(Thread.thread_id == thread_id, Thread.user_id == user.identity))
443-
if thread is None:
444-
raise HTTPException(status_code=404, detail="Thread not found")
445-
runs = (
446-
await session.scalars(
447-
select(Run).where(Run.thread_id == thread_id, Run.user_id == user.identity).order_by(Run.created_at.asc())
448-
)
449-
).all()
450-
checkpoints = await list_checkpoints(thread_id)
451-
visible = _visible_checkpoint_payloads(
452-
thread,
453-
runs,
454-
[checkpoint_to_payload(item) for item in checkpoints],
455-
)
456-
if not visible:
457-
return _empty_thread_state_payload(thread)
458-
return visible[0]
494+
return await _get_thread_state_payload(thread_id=thread_id, user=user)
459495

460496

461497
@router.get("/{thread_id}/history")
@@ -586,27 +622,37 @@ async def _event_iter() -> AsyncIterator[str]:
586622
yield f"id: {seq}\nevent: {event_name}\ndata: {json.dumps(event)}\n\n"
587623

588624
if _uses_redis_executor():
589-
async for event in _iter_persisted_thread_events(
590-
thread_id=thread_id,
591-
payload=payload,
592-
user_id=user.identity,
593-
after_seq=current_seq,
594-
wait_for_future_runs=True,
625+
async for event in iter_with_sse_keepalives(
626+
_iter_persisted_thread_events(
627+
thread_id=thread_id,
628+
payload=payload,
629+
user_id=user.identity,
630+
after_seq=current_seq,
631+
wait_for_future_runs=True,
632+
)
595633
):
634+
if event is None:
635+
yield sse_keepalive_comment()
636+
continue
596637
seq = int(event.get("seq", 0))
597638
current_seq = max(current_seq, seq)
598639
event_name = str(event.get("method", "event"))
599640
yield f"id: {seq}\nevent: {event_name}\ndata: {json.dumps(event)}\n\n"
600641
return
601642

602-
async for event in thread_protocol_broker.stream(
603-
thread_id,
604-
channels=payload.channels,
605-
namespaces=payload.namespaces,
606-
depth=payload.depth,
607-
since=current_seq,
608-
wait_for_future_runs=True,
643+
async for event in iter_with_sse_keepalives(
644+
thread_protocol_broker.stream(
645+
thread_id=thread_id,
646+
channels=payload.channels,
647+
namespaces=payload.namespaces,
648+
depth=payload.depth,
649+
since=current_seq,
650+
wait_for_future_runs=True,
651+
)
609652
):
653+
if event is None:
654+
yield sse_keepalive_comment()
655+
continue
610656
seq = int(event.get("seq", 0))
611657
current_seq = max(current_seq, seq)
612658
event_name = str(event.get("method", "event"))
@@ -768,25 +814,35 @@ async def _event_iter() -> AsyncIterator[str]:
768814
yield f"id: {seq}\nevent: {method}\ndata: {body}\n\n"
769815

770816
if _uses_redis_executor():
771-
async for event in _iter_persisted_thread_events(
772-
thread_id=thread_id,
773-
payload=payload,
774-
user_id=user.identity,
775-
after_seq=current_seq,
817+
async for event in iter_with_sse_keepalives(
818+
_iter_persisted_thread_events(
819+
thread_id=thread_id,
820+
payload=payload,
821+
user_id=user.identity,
822+
after_seq=current_seq,
823+
)
776824
):
825+
if event is None:
826+
yield sse_keepalive_comment()
827+
continue
777828
seq = int(event.get("seq", 0))
778829
method = str(event.get("method", "event"))
779830
body = json.dumps(event)
780831
yield f"id: {seq}\nevent: {method}\ndata: {body}\n\n"
781832
return
782833

783-
async for event in thread_protocol_broker.stream(
784-
thread_id,
785-
channels=payload.channels,
786-
namespaces=payload.namespaces,
787-
depth=payload.depth,
788-
since=current_seq,
834+
async for event in iter_with_sse_keepalives(
835+
thread_protocol_broker.stream(
836+
thread_id,
837+
channels=payload.channels,
838+
namespaces=payload.namespaces,
839+
depth=payload.depth,
840+
since=current_seq,
841+
)
789842
):
843+
if event is None:
844+
yield sse_keepalive_comment()
845+
continue
790846
seq = int(event.get("seq", 0))
791847
method = str(event.get("method", "event"))
792848
body = json.dumps(event)

src/agentseek_api/models/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class ThreadRead(BaseModel):
119119

120120

121121
class RunCreateStateful(BaseModel):
122-
model_config = ConfigDict(extra="allow")
122+
model_config = ConfigDict(extra="forbid")
123123

124124
assistant_id: str
125125
checkpoint: dict[str, Any] | None = None
@@ -147,7 +147,7 @@ class RunCreateStreamingStateful(RunCreateStateful):
147147

148148

149149
class RunCreateStateless(BaseModel):
150-
model_config = ConfigDict(extra="allow")
150+
model_config = ConfigDict(extra="forbid")
151151

152152
assistant_id: str
153153
input: Any = None

src/agentseek_api/services/run_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,7 @@ async def execute_run(
860860
interrupt_id=str(item.get("id", "")),
861861
payload=item.get("value"),
862862
namespace=interrupt_namespace,
863+
run_id=run_id,
863864
)
864865

865866
output = entry.extract_output(result, payload)

src/agentseek_api/services/run_jobs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
persist_run_stream_event,
1919
persist_thread_stream_event,
2020
)
21+
from agentseek_api.services.thread_checkpoint_store import checkpoint_to_payload, get_latest_checkpoint
2122
from agentseek_api.services.thread_protocol import publish_lifecycle_event, thread_protocol_broker
2223

2324
RUN_EXECUTION_JOB_KIND = "run.execute"
2425
TERMINAL_RUN_STATUSES = {"success", "error", "interrupted"}
26+
RUN_CHECKPOINT_ID_METADATA_KEY = "__agentseek_checkpoint_id"
2527

2628

2729
@dataclass(slots=True)
@@ -173,6 +175,17 @@ async def execute_run_job(job: RunExecutionJob) -> None:
173175
await _persist_thread_snapshot(job.thread_id)
174176
await execution_session.refresh(db_run)
175177
if not _is_cancelled_run(db_run):
178+
# A missing checkpoint lookup should not turn a successful run into a failed one.
179+
try:
180+
latest_checkpoint = await get_latest_checkpoint(job.thread_id)
181+
except Exception: # noqa: BLE001
182+
latest_checkpoint = None
183+
if latest_checkpoint is not None:
184+
checkpoint_id = checkpoint_to_payload(latest_checkpoint)["checkpoint"]["checkpoint_id"]
185+
db_run.metadata_json = {
186+
**(db_run.metadata_json or {}),
187+
RUN_CHECKPOINT_ID_METADATA_KEY: checkpoint_id,
188+
}
176189
_apply_execution_result(db_run, result)
177190
except Exception as exc: # noqa: BLE001
178191
await execution_session.refresh(db_run)

src/agentseek_api/services/sse.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import AsyncIterator
5+
from typing import TypeVar
6+
7+
T = TypeVar("T")
8+
9+
DEFAULT_SSE_KEEPALIVE_INTERVAL_SECONDS = 15.0
10+
11+
12+
def sse_keepalive_comment() -> str:
13+
return ": keepalive\n\n"
14+
15+
16+
async def iter_with_sse_keepalives(
17+
source: AsyncIterator[T],
18+
*,
19+
interval_seconds: float | None = None,
20+
) -> AsyncIterator[T | None]:
21+
interval = DEFAULT_SSE_KEEPALIVE_INTERVAL_SECONDS if interval_seconds is None else interval_seconds
22+
iterator = source.__aiter__()
23+
pending: asyncio.Task[T] | None = None
24+
try:
25+
while True:
26+
if pending is None:
27+
pending = asyncio.create_task(anext(iterator))
28+
try:
29+
item = await asyncio.wait_for(asyncio.shield(pending), timeout=interval)
30+
except TimeoutError:
31+
yield None
32+
continue
33+
except StopAsyncIteration:
34+
return
35+
pending = None
36+
yield item
37+
finally:
38+
if pending is not None and not pending.done():
39+
pending.cancel()
40+
try:
41+
await pending
42+
except (asyncio.CancelledError, StopAsyncIteration):
43+
pass
44+
aclose = getattr(iterator, "aclose", None)
45+
if callable(aclose):
46+
await aclose()

0 commit comments

Comments
 (0)