Skip to content

Commit 207b85b

Browse files
patnikoRasabounCopilot
authored
fix(python): add timeout parameter to generated RPC methods (#681)
* fix(python): add timeout parameter to generated RPC methods Every generated async RPC method now accepts an optional `timeout` keyword argument that is forwarded to `JsonRpcClient.request()`. This lets callers override the default 30s timeout for long-running RPCs like `session.fleet.start` without bypassing the typed API. Fixes #539 * test: cover no-params and server-scoped RPC timeout branches Add tests for PlanApi.read (session, no params) and ModelsApi.list (server, no params) to exercise all four codegen branches. * fix: move _timeout_kwargs after quicktype imports, add server+params test - Move _timeout_kwargs helper after the quicktype-generated import block to avoid duplicate Optional import and keep preamble conventional - Add ToolsApi.list tests covering the server-scoped + params branch - All four codegen branches now have test coverage * style: fix ruff format in test_rpc_timeout.py * fix: use float | None instead of Optional[float] in generated RPC Optional is not imported in the generated rpc.py file. On Python 3.11 (used in CI), annotations are eagerly evaluated, so Optional[float] would cause NameError at import time. Use the modern float | None union syntax which requires no import and matches the rest of the generated code. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: wrap long docstring line in client.py to pass ruff E501 Pre-existing ruff line-length violation in get_last_session_id docstring example that was failing CI on main. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: resolve pre-existing CI failures in docs and ruff lint - docs/guides/custom-agents.md: use typed enum constants instead of string literals for PermissionRequestResultKind (C# and Go examples) - python/copilot/client.py: wrap long docstring line to satisfy E501 Both issues pre-exist on main and block all PR CI runs. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Rasaboun <40967731+Rasaboun@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 87a54de commit 207b85b

File tree

5 files changed

+200
-51
lines changed

5 files changed

+200
-51
lines changed

docs/guides/custom-agents.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ session, _ := client.CreateSession(ctx, &copilot.SessionConfig{
121121
},
122122
},
123123
OnPermissionRequest: func(req copilot.PermissionRequest, inv copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) {
124-
return copilot.PermissionRequestResult{Kind: "approved"}, nil
124+
return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindApproved}, nil
125125
},
126126
})
127127
```
@@ -158,7 +158,7 @@ await using var session = await client.CreateSessionAsync(new SessionConfig
158158
},
159159
},
160160
OnPermissionRequest = (req, inv) =>
161-
Task.FromResult(new PermissionRequestResult { Kind = "approved" }),
161+
Task.FromResult(new PermissionRequestResult { Kind = PermissionRequestResultKind.Approved }),
162162
});
163163
```
164164

python/copilot/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,8 @@ async def get_last_session_id(self) -> str | None:
977977
Example:
978978
>>> last_id = await client.get_last_session_id()
979979
>>> if last_id:
980-
... session = await client.resume_session(last_id, {"on_permission_request": PermissionHandler.approve_all})
980+
... config = {"on_permission_request": PermissionHandler.approve_all}
981+
... session = await client.resume_session(last_id, config)
981982
"""
982983
if not self._client:
983984
raise RuntimeError("Client not connected")

python/copilot/generated/rpc.py

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,29 +1149,36 @@ def session_compaction_compact_result_to_dict(x: SessionCompactionCompactResult)
11491149
return to_class(SessionCompactionCompactResult, x)
11501150

11511151

1152+
def _timeout_kwargs(timeout: float | None) -> dict:
1153+
"""Build keyword arguments for optional timeout forwarding."""
1154+
if timeout is not None:
1155+
return {"timeout": timeout}
1156+
return {}
1157+
1158+
11521159
class ModelsApi:
11531160
def __init__(self, client: "JsonRpcClient"):
11541161
self._client = client
11551162

1156-
async def list(self) -> ModelsListResult:
1157-
return ModelsListResult.from_dict(await self._client.request("models.list", {}))
1163+
async def list(self, *, timeout: float | None = None) -> ModelsListResult:
1164+
return ModelsListResult.from_dict(await self._client.request("models.list", {}, **_timeout_kwargs(timeout)))
11581165

11591166

11601167
class ToolsApi:
11611168
def __init__(self, client: "JsonRpcClient"):
11621169
self._client = client
11631170

1164-
async def list(self, params: ToolsListParams) -> ToolsListResult:
1171+
async def list(self, params: ToolsListParams, *, timeout: float | None = None) -> ToolsListResult:
11651172
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
1166-
return ToolsListResult.from_dict(await self._client.request("tools.list", params_dict))
1173+
return ToolsListResult.from_dict(await self._client.request("tools.list", params_dict, **_timeout_kwargs(timeout)))
11671174

11681175

11691176
class AccountApi:
11701177
def __init__(self, client: "JsonRpcClient"):
11711178
self._client = client
11721179

1173-
async def get_quota(self) -> AccountGetQuotaResult:
1174-
return AccountGetQuotaResult.from_dict(await self._client.request("account.getQuota", {}))
1180+
async def get_quota(self, *, timeout: float | None = None) -> AccountGetQuotaResult:
1181+
return AccountGetQuotaResult.from_dict(await self._client.request("account.getQuota", {}, **_timeout_kwargs(timeout)))
11751182

11761183

11771184
class ServerRpc:
@@ -1182,113 +1189,113 @@ def __init__(self, client: "JsonRpcClient"):
11821189
self.tools = ToolsApi(client)
11831190
self.account = AccountApi(client)
11841191

1185-
async def ping(self, params: PingParams) -> PingResult:
1192+
async def ping(self, params: PingParams, *, timeout: float | None = None) -> PingResult:
11861193
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
1187-
return PingResult.from_dict(await self._client.request("ping", params_dict))
1194+
return PingResult.from_dict(await self._client.request("ping", params_dict, **_timeout_kwargs(timeout)))
11881195

11891196

11901197
class ModelApi:
11911198
def __init__(self, client: "JsonRpcClient", session_id: str):
11921199
self._client = client
11931200
self._session_id = session_id
11941201

1195-
async def get_current(self) -> SessionModelGetCurrentResult:
1196-
return SessionModelGetCurrentResult.from_dict(await self._client.request("session.model.getCurrent", {"sessionId": self._session_id}))
1202+
async def get_current(self, *, timeout: float | None = None) -> SessionModelGetCurrentResult:
1203+
return SessionModelGetCurrentResult.from_dict(await self._client.request("session.model.getCurrent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
11971204

1198-
async def switch_to(self, params: SessionModelSwitchToParams) -> SessionModelSwitchToResult:
1205+
async def switch_to(self, params: SessionModelSwitchToParams, *, timeout: float | None = None) -> SessionModelSwitchToResult:
11991206
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12001207
params_dict["sessionId"] = self._session_id
1201-
return SessionModelSwitchToResult.from_dict(await self._client.request("session.model.switchTo", params_dict))
1208+
return SessionModelSwitchToResult.from_dict(await self._client.request("session.model.switchTo", params_dict, **_timeout_kwargs(timeout)))
12021209

12031210

12041211
class ModeApi:
12051212
def __init__(self, client: "JsonRpcClient", session_id: str):
12061213
self._client = client
12071214
self._session_id = session_id
12081215

1209-
async def get(self) -> SessionModeGetResult:
1210-
return SessionModeGetResult.from_dict(await self._client.request("session.mode.get", {"sessionId": self._session_id}))
1216+
async def get(self, *, timeout: float | None = None) -> SessionModeGetResult:
1217+
return SessionModeGetResult.from_dict(await self._client.request("session.mode.get", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12111218

1212-
async def set(self, params: SessionModeSetParams) -> SessionModeSetResult:
1219+
async def set(self, params: SessionModeSetParams, *, timeout: float | None = None) -> SessionModeSetResult:
12131220
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12141221
params_dict["sessionId"] = self._session_id
1215-
return SessionModeSetResult.from_dict(await self._client.request("session.mode.set", params_dict))
1222+
return SessionModeSetResult.from_dict(await self._client.request("session.mode.set", params_dict, **_timeout_kwargs(timeout)))
12161223

12171224

12181225
class PlanApi:
12191226
def __init__(self, client: "JsonRpcClient", session_id: str):
12201227
self._client = client
12211228
self._session_id = session_id
12221229

1223-
async def read(self) -> SessionPlanReadResult:
1224-
return SessionPlanReadResult.from_dict(await self._client.request("session.plan.read", {"sessionId": self._session_id}))
1230+
async def read(self, *, timeout: float | None = None) -> SessionPlanReadResult:
1231+
return SessionPlanReadResult.from_dict(await self._client.request("session.plan.read", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12251232

1226-
async def update(self, params: SessionPlanUpdateParams) -> SessionPlanUpdateResult:
1233+
async def update(self, params: SessionPlanUpdateParams, *, timeout: float | None = None) -> SessionPlanUpdateResult:
12271234
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12281235
params_dict["sessionId"] = self._session_id
1229-
return SessionPlanUpdateResult.from_dict(await self._client.request("session.plan.update", params_dict))
1236+
return SessionPlanUpdateResult.from_dict(await self._client.request("session.plan.update", params_dict, **_timeout_kwargs(timeout)))
12301237

1231-
async def delete(self) -> SessionPlanDeleteResult:
1232-
return SessionPlanDeleteResult.from_dict(await self._client.request("session.plan.delete", {"sessionId": self._session_id}))
1238+
async def delete(self, *, timeout: float | None = None) -> SessionPlanDeleteResult:
1239+
return SessionPlanDeleteResult.from_dict(await self._client.request("session.plan.delete", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12331240

12341241

12351242
class WorkspaceApi:
12361243
def __init__(self, client: "JsonRpcClient", session_id: str):
12371244
self._client = client
12381245
self._session_id = session_id
12391246

1240-
async def list_files(self) -> SessionWorkspaceListFilesResult:
1241-
return SessionWorkspaceListFilesResult.from_dict(await self._client.request("session.workspace.listFiles", {"sessionId": self._session_id}))
1247+
async def list_files(self, *, timeout: float | None = None) -> SessionWorkspaceListFilesResult:
1248+
return SessionWorkspaceListFilesResult.from_dict(await self._client.request("session.workspace.listFiles", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12421249

1243-
async def read_file(self, params: SessionWorkspaceReadFileParams) -> SessionWorkspaceReadFileResult:
1250+
async def read_file(self, params: SessionWorkspaceReadFileParams, *, timeout: float | None = None) -> SessionWorkspaceReadFileResult:
12441251
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12451252
params_dict["sessionId"] = self._session_id
1246-
return SessionWorkspaceReadFileResult.from_dict(await self._client.request("session.workspace.readFile", params_dict))
1253+
return SessionWorkspaceReadFileResult.from_dict(await self._client.request("session.workspace.readFile", params_dict, **_timeout_kwargs(timeout)))
12471254

1248-
async def create_file(self, params: SessionWorkspaceCreateFileParams) -> SessionWorkspaceCreateFileResult:
1255+
async def create_file(self, params: SessionWorkspaceCreateFileParams, *, timeout: float | None = None) -> SessionWorkspaceCreateFileResult:
12491256
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12501257
params_dict["sessionId"] = self._session_id
1251-
return SessionWorkspaceCreateFileResult.from_dict(await self._client.request("session.workspace.createFile", params_dict))
1258+
return SessionWorkspaceCreateFileResult.from_dict(await self._client.request("session.workspace.createFile", params_dict, **_timeout_kwargs(timeout)))
12521259

12531260

12541261
class FleetApi:
12551262
def __init__(self, client: "JsonRpcClient", session_id: str):
12561263
self._client = client
12571264
self._session_id = session_id
12581265

1259-
async def start(self, params: SessionFleetStartParams) -> SessionFleetStartResult:
1266+
async def start(self, params: SessionFleetStartParams, *, timeout: float | None = None) -> SessionFleetStartResult:
12601267
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12611268
params_dict["sessionId"] = self._session_id
1262-
return SessionFleetStartResult.from_dict(await self._client.request("session.fleet.start", params_dict))
1269+
return SessionFleetStartResult.from_dict(await self._client.request("session.fleet.start", params_dict, **_timeout_kwargs(timeout)))
12631270

12641271

12651272
class AgentApi:
12661273
def __init__(self, client: "JsonRpcClient", session_id: str):
12671274
self._client = client
12681275
self._session_id = session_id
12691276

1270-
async def list(self) -> SessionAgentListResult:
1271-
return SessionAgentListResult.from_dict(await self._client.request("session.agent.list", {"sessionId": self._session_id}))
1277+
async def list(self, *, timeout: float | None = None) -> SessionAgentListResult:
1278+
return SessionAgentListResult.from_dict(await self._client.request("session.agent.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12721279

1273-
async def get_current(self) -> SessionAgentGetCurrentResult:
1274-
return SessionAgentGetCurrentResult.from_dict(await self._client.request("session.agent.getCurrent", {"sessionId": self._session_id}))
1280+
async def get_current(self, *, timeout: float | None = None) -> SessionAgentGetCurrentResult:
1281+
return SessionAgentGetCurrentResult.from_dict(await self._client.request("session.agent.getCurrent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12751282

1276-
async def select(self, params: SessionAgentSelectParams) -> SessionAgentSelectResult:
1283+
async def select(self, params: SessionAgentSelectParams, *, timeout: float | None = None) -> SessionAgentSelectResult:
12771284
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12781285
params_dict["sessionId"] = self._session_id
1279-
return SessionAgentSelectResult.from_dict(await self._client.request("session.agent.select", params_dict))
1286+
return SessionAgentSelectResult.from_dict(await self._client.request("session.agent.select", params_dict, **_timeout_kwargs(timeout)))
12801287

1281-
async def deselect(self) -> SessionAgentDeselectResult:
1282-
return SessionAgentDeselectResult.from_dict(await self._client.request("session.agent.deselect", {"sessionId": self._session_id}))
1288+
async def deselect(self, *, timeout: float | None = None) -> SessionAgentDeselectResult:
1289+
return SessionAgentDeselectResult.from_dict(await self._client.request("session.agent.deselect", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12831290

12841291

12851292
class CompactionApi:
12861293
def __init__(self, client: "JsonRpcClient", session_id: str):
12871294
self._client = client
12881295
self._session_id = session_id
12891296

1290-
async def compact(self) -> SessionCompactionCompactResult:
1291-
return SessionCompactionCompactResult.from_dict(await self._client.request("session.compaction.compact", {"sessionId": self._session_id}))
1297+
async def compact(self, *, timeout: float | None = None) -> SessionCompactionCompactResult:
1298+
return SessionCompactionCompactResult.from_dict(await self._client.request("session.compaction.compact", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12921299

12931300

12941301
class SessionRpc:

python/test_rpc_timeout.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Tests for timeout parameter on generated RPC methods."""
2+
3+
from unittest.mock import AsyncMock
4+
5+
import pytest
6+
7+
from copilot.generated.rpc import (
8+
FleetApi,
9+
Mode,
10+
ModeApi,
11+
ModelsApi,
12+
PlanApi,
13+
SessionFleetStartParams,
14+
SessionModeSetParams,
15+
ToolsApi,
16+
ToolsListParams,
17+
)
18+
19+
20+
class TestRpcTimeout:
21+
"""Tests for timeout forwarding across all four codegen branches:
22+
- session-scoped with params
23+
- session-scoped without params
24+
- server-scoped with params
25+
- server-scoped without params
26+
"""
27+
28+
# ── session-scoped, with params ──────────────────────────────────
29+
30+
@pytest.mark.asyncio
31+
async def test_default_timeout_not_forwarded(self):
32+
client = AsyncMock()
33+
client.request = AsyncMock(return_value={"started": True})
34+
api = FleetApi(client, "sess-1")
35+
36+
await api.start(SessionFleetStartParams(prompt="go"))
37+
38+
client.request.assert_called_once()
39+
_, kwargs = client.request.call_args
40+
assert "timeout" not in kwargs
41+
42+
@pytest.mark.asyncio
43+
async def test_custom_timeout_forwarded(self):
44+
client = AsyncMock()
45+
client.request = AsyncMock(return_value={"started": True})
46+
api = FleetApi(client, "sess-1")
47+
48+
await api.start(SessionFleetStartParams(prompt="go"), timeout=600.0)
49+
50+
_, kwargs = client.request.call_args
51+
assert kwargs["timeout"] == 600.0
52+
53+
@pytest.mark.asyncio
54+
async def test_timeout_on_session_params_method(self):
55+
client = AsyncMock()
56+
client.request = AsyncMock(return_value={"mode": "plan"})
57+
api = ModeApi(client, "sess-1")
58+
59+
await api.set(SessionModeSetParams(mode=Mode.PLAN), timeout=120.0)
60+
61+
_, kwargs = client.request.call_args
62+
assert kwargs["timeout"] == 120.0
63+
64+
# ── session-scoped, no params ────────────────────────────────────
65+
66+
@pytest.mark.asyncio
67+
async def test_timeout_on_session_no_params_method(self):
68+
client = AsyncMock()
69+
client.request = AsyncMock(return_value={"exists": True})
70+
api = PlanApi(client, "sess-1")
71+
72+
await api.read(timeout=90.0)
73+
74+
_, kwargs = client.request.call_args
75+
assert kwargs["timeout"] == 90.0
76+
77+
@pytest.mark.asyncio
78+
async def test_default_timeout_on_session_no_params_method(self):
79+
client = AsyncMock()
80+
client.request = AsyncMock(return_value={"exists": True})
81+
api = PlanApi(client, "sess-1")
82+
83+
await api.read()
84+
85+
_, kwargs = client.request.call_args
86+
assert "timeout" not in kwargs
87+
88+
# ── server-scoped, with params ─────────────────────────────────────
89+
90+
@pytest.mark.asyncio
91+
async def test_timeout_on_server_params_method(self):
92+
client = AsyncMock()
93+
client.request = AsyncMock(return_value={"tools": []})
94+
api = ToolsApi(client)
95+
96+
await api.list(ToolsListParams(), timeout=60.0)
97+
98+
_, kwargs = client.request.call_args
99+
assert kwargs["timeout"] == 60.0
100+
101+
@pytest.mark.asyncio
102+
async def test_default_timeout_on_server_params_method(self):
103+
client = AsyncMock()
104+
client.request = AsyncMock(return_value={"tools": []})
105+
api = ToolsApi(client)
106+
107+
await api.list(ToolsListParams())
108+
109+
_, kwargs = client.request.call_args
110+
assert "timeout" not in kwargs
111+
112+
# ── server-scoped, no params ─────────────────────────────────────
113+
114+
@pytest.mark.asyncio
115+
async def test_timeout_on_server_no_params_method(self):
116+
client = AsyncMock()
117+
client.request = AsyncMock(return_value={"models": []})
118+
api = ModelsApi(client)
119+
120+
await api.list(timeout=45.0)
121+
122+
_, kwargs = client.request.call_args
123+
assert kwargs["timeout"] == 45.0
124+
125+
@pytest.mark.asyncio
126+
async def test_default_timeout_on_server_no_params_method(self):
127+
client = AsyncMock()
128+
client.request = AsyncMock(return_value={"models": []})
129+
api = ModelsApi(client)
130+
131+
await api.list()
132+
133+
_, kwargs = client.request.call_args
134+
assert "timeout" not in kwargs

0 commit comments

Comments
 (0)