|
27 | 27 | load_thread_stream_events, |
28 | 28 | parse_last_event_id, |
29 | 29 | ) |
| 30 | +from agentseek_api.services.sse import iter_with_sse_keepalives, sse_keepalive_comment |
30 | 31 | from agentseek_api.services.thread_checkpoint_store import ( |
31 | 32 | checkpoint_to_payload, |
32 | 33 | copy_checkpoints, |
@@ -245,6 +246,59 @@ def _empty_thread_state_payload(thread: Thread) -> dict[str, object]: |
245 | 246 | } |
246 | 247 |
|
247 | 248 |
|
| 249 | +def _checkpoint_namespace(payload: dict[str, object]) -> str: |
| 250 | + checkpoint = payload.get("checkpoint") |
| 251 | + if not isinstance(checkpoint, dict): |
| 252 | + return "" |
| 253 | + checkpoint_ns = checkpoint.get("checkpoint_ns", "") |
| 254 | + return str(checkpoint_ns) if checkpoint_ns is not None else "" |
| 255 | + |
| 256 | + |
| 257 | +async def _find_thread_state_payload( |
| 258 | + *, |
| 259 | + thread_id: str, |
| 260 | + user: User, |
| 261 | + checkpoint_ns: str | None = None, |
| 262 | +) -> dict[str, object] | None: |
| 263 | + session_factory = db_manager.get_session_factory() |
| 264 | + async with session_factory() as session: |
| 265 | + thread = await session.scalar(select(Thread).where(Thread.thread_id == thread_id, Thread.user_id == user.identity)) |
| 266 | + if thread is None: |
| 267 | + raise HTTPException(status_code=404, detail="Thread not found") |
| 268 | + runs = ( |
| 269 | + await session.scalars( |
| 270 | + select(Run).where(Run.thread_id == thread_id, Run.user_id == user.identity).order_by(Run.created_at.asc()) |
| 271 | + ) |
| 272 | + ).all() |
| 273 | + visible = _visible_checkpoint_payloads( |
| 274 | + thread, |
| 275 | + runs, |
| 276 | + [checkpoint_to_payload(item) for item in await list_checkpoints(thread_id)], |
| 277 | + ) |
| 278 | + if checkpoint_ns is not None: |
| 279 | + visible = [payload for payload in visible if _checkpoint_namespace(payload) == checkpoint_ns] |
| 280 | + if not visible: |
| 281 | + return None |
| 282 | + return visible[0] |
| 283 | + |
| 284 | + |
| 285 | +async def _get_thread_state_payload( |
| 286 | + *, |
| 287 | + thread_id: str, |
| 288 | + user: User, |
| 289 | + checkpoint_ns: str | None = None, |
| 290 | +) -> dict[str, object]: |
| 291 | + payload = await _find_thread_state_payload(thread_id=thread_id, user=user, checkpoint_ns=checkpoint_ns) |
| 292 | + if payload is not None: |
| 293 | + return payload |
| 294 | + session_factory = db_manager.get_session_factory() |
| 295 | + async with session_factory() as session: |
| 296 | + thread = await session.scalar(select(Thread).where(Thread.thread_id == thread_id, Thread.user_id == user.identity)) |
| 297 | + if thread is None: |
| 298 | + raise HTTPException(status_code=404, detail="Thread not found") |
| 299 | + return _empty_thread_state_payload(thread) |
| 300 | + |
| 301 | + |
248 | 302 | @router.post("", response_model=ThreadRead) |
249 | 303 | async def create_thread(payload: ThreadCreate, user: User = Depends(get_current_user)) -> ThreadRead: |
250 | 304 | session_factory = db_manager.get_session_factory() |
@@ -437,25 +491,7 @@ async def copy_thread(thread_id: str, user: User = Depends(get_current_user)) -> |
437 | 491 |
|
438 | 492 | @router.get("/{thread_id}/state") |
439 | 493 | async def get_thread_state(thread_id: str, user: User = Depends(get_current_user)) -> dict[str, object]: |
440 | | - session_factory = db_manager.get_session_factory() |
441 | | - async with session_factory() as session: |
442 | | - thread = await session.scalar(select(Thread).where(Thread.thread_id == thread_id, Thread.user_id == user.identity)) |
443 | | - if thread is None: |
444 | | - raise HTTPException(status_code=404, detail="Thread not found") |
445 | | - runs = ( |
446 | | - await session.scalars( |
447 | | - select(Run).where(Run.thread_id == thread_id, Run.user_id == user.identity).order_by(Run.created_at.asc()) |
448 | | - ) |
449 | | - ).all() |
450 | | - checkpoints = await list_checkpoints(thread_id) |
451 | | - visible = _visible_checkpoint_payloads( |
452 | | - thread, |
453 | | - runs, |
454 | | - [checkpoint_to_payload(item) for item in checkpoints], |
455 | | - ) |
456 | | - if not visible: |
457 | | - return _empty_thread_state_payload(thread) |
458 | | - return visible[0] |
| 494 | + return await _get_thread_state_payload(thread_id=thread_id, user=user) |
459 | 495 |
|
460 | 496 |
|
461 | 497 | @router.get("/{thread_id}/history") |
@@ -586,27 +622,37 @@ async def _event_iter() -> AsyncIterator[str]: |
586 | 622 | yield f"id: {seq}\nevent: {event_name}\ndata: {json.dumps(event)}\n\n" |
587 | 623 |
|
588 | 624 | if _uses_redis_executor(): |
589 | | - async for event in _iter_persisted_thread_events( |
590 | | - thread_id=thread_id, |
591 | | - payload=payload, |
592 | | - user_id=user.identity, |
593 | | - after_seq=current_seq, |
594 | | - wait_for_future_runs=True, |
| 625 | + async for event in iter_with_sse_keepalives( |
| 626 | + _iter_persisted_thread_events( |
| 627 | + thread_id=thread_id, |
| 628 | + payload=payload, |
| 629 | + user_id=user.identity, |
| 630 | + after_seq=current_seq, |
| 631 | + wait_for_future_runs=True, |
| 632 | + ) |
595 | 633 | ): |
| 634 | + if event is None: |
| 635 | + yield sse_keepalive_comment() |
| 636 | + continue |
596 | 637 | seq = int(event.get("seq", 0)) |
597 | 638 | current_seq = max(current_seq, seq) |
598 | 639 | event_name = str(event.get("method", "event")) |
599 | 640 | yield f"id: {seq}\nevent: {event_name}\ndata: {json.dumps(event)}\n\n" |
600 | 641 | return |
601 | 642 |
|
602 | | - async for event in thread_protocol_broker.stream( |
603 | | - thread_id, |
604 | | - channels=payload.channels, |
605 | | - namespaces=payload.namespaces, |
606 | | - depth=payload.depth, |
607 | | - since=current_seq, |
608 | | - wait_for_future_runs=True, |
| 643 | + async for event in iter_with_sse_keepalives( |
| 644 | + thread_protocol_broker.stream( |
| 645 | + thread_id=thread_id, |
| 646 | + channels=payload.channels, |
| 647 | + namespaces=payload.namespaces, |
| 648 | + depth=payload.depth, |
| 649 | + since=current_seq, |
| 650 | + wait_for_future_runs=True, |
| 651 | + ) |
609 | 652 | ): |
| 653 | + if event is None: |
| 654 | + yield sse_keepalive_comment() |
| 655 | + continue |
610 | 656 | seq = int(event.get("seq", 0)) |
611 | 657 | current_seq = max(current_seq, seq) |
612 | 658 | event_name = str(event.get("method", "event")) |
@@ -768,25 +814,35 @@ async def _event_iter() -> AsyncIterator[str]: |
768 | 814 | yield f"id: {seq}\nevent: {method}\ndata: {body}\n\n" |
769 | 815 |
|
770 | 816 | if _uses_redis_executor(): |
771 | | - async for event in _iter_persisted_thread_events( |
772 | | - thread_id=thread_id, |
773 | | - payload=payload, |
774 | | - user_id=user.identity, |
775 | | - after_seq=current_seq, |
| 817 | + async for event in iter_with_sse_keepalives( |
| 818 | + _iter_persisted_thread_events( |
| 819 | + thread_id=thread_id, |
| 820 | + payload=payload, |
| 821 | + user_id=user.identity, |
| 822 | + after_seq=current_seq, |
| 823 | + ) |
776 | 824 | ): |
| 825 | + if event is None: |
| 826 | + yield sse_keepalive_comment() |
| 827 | + continue |
777 | 828 | seq = int(event.get("seq", 0)) |
778 | 829 | method = str(event.get("method", "event")) |
779 | 830 | body = json.dumps(event) |
780 | 831 | yield f"id: {seq}\nevent: {method}\ndata: {body}\n\n" |
781 | 832 | return |
782 | 833 |
|
783 | | - async for event in thread_protocol_broker.stream( |
784 | | - thread_id, |
785 | | - channels=payload.channels, |
786 | | - namespaces=payload.namespaces, |
787 | | - depth=payload.depth, |
788 | | - since=current_seq, |
| 834 | + async for event in iter_with_sse_keepalives( |
| 835 | + thread_protocol_broker.stream( |
| 836 | + thread_id, |
| 837 | + channels=payload.channels, |
| 838 | + namespaces=payload.namespaces, |
| 839 | + depth=payload.depth, |
| 840 | + since=current_seq, |
| 841 | + ) |
789 | 842 | ): |
| 843 | + if event is None: |
| 844 | + yield sse_keepalive_comment() |
| 845 | + continue |
790 | 846 | seq = int(event.get("seq", 0)) |
791 | 847 | method = str(event.get("method", "event")) |
792 | 848 | body = json.dumps(event) |
|
0 commit comments