|
3 | 3 | """Tests for FastAPI endpoint creation (_endpoint.py).""" |
4 | 4 |
|
5 | 5 | import json |
6 | | -from typing import Any |
| 6 | +from typing import Any, cast |
7 | 7 |
|
8 | 8 | import pytest |
9 | 9 | from ag_ui.core import MessagesSnapshotEvent, RunStartedEvent, StateSnapshotEvent |
@@ -82,7 +82,7 @@ async def test_add_endpoint_with_workflow_protocol(): |
82 | 82 | """Test adding endpoint with native Workflow support.""" |
83 | 83 |
|
84 | 84 | @executor(id="start") |
85 | | - async def start(message: Any, ctx: WorkflowContext) -> None: |
| 85 | + async def start(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
86 | 86 | await ctx.yield_output("Workflow response") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] |
87 | 87 |
|
88 | 88 | app = FastAPI() |
@@ -695,7 +695,7 @@ async def stream_fn(messages: Any, options: Any, **kwargs: Any): |
695 | 695 | path="/snapshots", |
696 | 696 | state_schema={"tenant": {"type": "string"}}, |
697 | 697 | snapshot_store=store, |
698 | | - snapshot_scope_resolver=lambda request: request.forwarded_props["tenant"], |
| 698 | + snapshot_scope_resolver=lambda request: cast("dict[str, Any]", request.forwarded_props)["tenant"], |
699 | 699 | ) |
700 | 700 | client = TestClient(app) |
701 | 701 |
|
@@ -994,14 +994,16 @@ async def test_workflow_endpoint_hydrates_emitted_snapshots_without_invoking_wor |
994 | 994 | call_count = 0 |
995 | 995 |
|
996 | 996 | @executor(id="snapshotter") |
997 | | - async def snapshotter(message: Any, ctx: WorkflowContext) -> None: |
| 997 | + async def snapshotter(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
998 | 998 | nonlocal call_count |
999 | 999 | del message |
1000 | 1000 | call_count += 1 |
1001 | 1001 | await ctx.yield_output(StateSnapshotEvent(snapshot={"active_agent": "flights"})) |
1002 | 1002 | await ctx.yield_output( |
1003 | 1003 | MessagesSnapshotEvent( |
1004 | | - messages=[{"id": "assistant-snapshot", "role": "assistant", "content": "Stored workflow reply"}] |
| 1004 | + messages=cast( |
| 1005 | + Any, [{"id": "assistant-snapshot", "role": "assistant", "content": "Stored workflow reply"}] |
| 1006 | + ) |
1005 | 1007 | ) |
1006 | 1008 | ) |
1007 | 1009 |
|
@@ -1046,7 +1048,7 @@ async def test_workflow_endpoint_hydrates_synthesized_text_and_tool_snapshot(): |
1046 | 1048 | call_count = 0 |
1047 | 1049 |
|
1048 | 1050 | @executor(id="responder") |
1049 | | - async def responder(message: Any, ctx: WorkflowContext) -> None: |
| 1051 | + async def responder(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
1050 | 1052 | nonlocal call_count |
1051 | 1053 | del message |
1052 | 1054 | call_count += 1 |
@@ -1114,7 +1116,7 @@ async def test_workflow_endpoint_hydrates_interrupted_thread_without_invoking_wo |
1114 | 1116 | call_count = 0 |
1115 | 1117 |
|
1116 | 1118 | @executor(id="requester") |
1117 | | - async def requester(message: Any, ctx: WorkflowContext) -> None: |
| 1119 | + async def requester(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
1118 | 1120 | nonlocal call_count |
1119 | 1121 | del message |
1120 | 1122 | call_count += 1 |
@@ -1167,7 +1169,7 @@ async def test_workflow_endpoint_run_error_does_not_overwrite_previous_snapshot( |
1167 | 1169 | call_count = 0 |
1168 | 1170 |
|
1169 | 1171 | @executor(id="responder") |
1170 | | - async def responder(message: Any, ctx: WorkflowContext) -> None: |
| 1172 | + async def responder(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
1171 | 1173 | nonlocal call_count |
1172 | 1174 | del message |
1173 | 1175 | call_count += 1 |
@@ -1464,7 +1466,7 @@ async def test_workflow_preserves_history_across_turns(): |
1464 | 1466 | call_count = 0 |
1465 | 1467 |
|
1466 | 1468 | @executor(id="responder") |
1467 | | - async def responder(message: Any, ctx: WorkflowContext) -> None: |
| 1469 | + async def responder(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
1468 | 1470 | nonlocal call_count |
1469 | 1471 | del message |
1470 | 1472 | call_count += 1 |
@@ -1688,7 +1690,7 @@ async def fake_run_workflow_stream(input_data: Any, workflow: Any): |
1688 | 1690 | monkeypatch.setattr(workflow_module, "run_workflow_stream", fake_run_workflow_stream) |
1689 | 1691 |
|
1690 | 1692 | @executor(id="noop") |
1691 | | - async def noop(message: Any, ctx: WorkflowContext) -> None: |
| 1693 | + async def noop(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
1692 | 1694 | del message, ctx |
1693 | 1695 |
|
1694 | 1696 | runner = AgentFrameworkWorkflow( |
@@ -1759,7 +1761,7 @@ async def test_workflow_endpoint_snapshot_save_failure_does_not_emit_run_error() |
1759 | 1761 | """A failing snapshot save after RUN_FINISHED must not emit a second terminal RUN_ERROR.""" |
1760 | 1762 |
|
1761 | 1763 | @executor(id="responder") |
1762 | | - async def responder(message: Any, ctx: WorkflowContext) -> None: |
| 1764 | + async def responder(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
1763 | 1765 | del message |
1764 | 1766 | await ctx.yield_output("Workflow reply") |
1765 | 1767 |
|
@@ -1822,7 +1824,7 @@ def test_workflow_factory_cache_is_scoped_by_snapshot_scope(): |
1822 | 1824 | """The same thread id under different Snapshot Scopes must not share a workflow instance.""" |
1823 | 1825 |
|
1824 | 1826 | @executor(id="noop") |
1825 | | - async def noop(message: Any, ctx: WorkflowContext) -> None: |
| 1827 | + async def noop(message: Any, ctx: WorkflowContext[Any, Any]) -> None: |
1826 | 1828 | del message, ctx |
1827 | 1829 |
|
1828 | 1830 | def factory(thread_id: str) -> Any: |
|
0 commit comments