Skip to content

Commit 3adc7fa

Browse files
committed
fix(python): add timeout parameter to generated RPC methods
Every generated async RPC method now accepts an optional `timeout` keyword argument that is forwarded to `JsonRpcClient.request()`. This lets callers override the default 30s timeout for long-running RPCs like `session.fleet.start` without bypassing the typed API. Fixes #539
1 parent 388f2f3 commit 3adc7fa

File tree

3 files changed

+114
-49
lines changed

3 files changed

+114
-49
lines changed

python/copilot/generated/rpc.py

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,20 @@
33
Generated from: api.schema.json
44
"""
55

6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, Optional
77

88
if TYPE_CHECKING:
99
from ..jsonrpc import JsonRpcClient
1010

1111

12+
13+
def _timeout_kwargs(timeout: Optional[float]) -> dict:
14+
"""Build keyword arguments for optional timeout forwarding."""
15+
if timeout is not None:
16+
return {"timeout": timeout}
17+
return {}
18+
19+
1220
from dataclasses import dataclass
1321
from typing import Any, Optional, List, Dict, TypeVar, Type, cast, Callable
1422
from enum import Enum
@@ -1150,25 +1158,25 @@ class ModelsApi:
11501158
def __init__(self, client: "JsonRpcClient"):
11511159
self._client = client
11521160

1153-
async def list(self) -> ModelsListResult:
1154-
return ModelsListResult.from_dict(await self._client.request("models.list", {}))
1161+
async def list(self, *, timeout: Optional[float] = None) -> ModelsListResult:
1162+
return ModelsListResult.from_dict(await self._client.request("models.list", {}, **_timeout_kwargs(timeout)))
11551163

11561164

11571165
class ToolsApi:
11581166
def __init__(self, client: "JsonRpcClient"):
11591167
self._client = client
11601168

1161-
async def list(self, params: ToolsListParams) -> ToolsListResult:
1169+
async def list(self, params: ToolsListParams, *, timeout: Optional[float] = None) -> ToolsListResult:
11621170
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
1163-
return ToolsListResult.from_dict(await self._client.request("tools.list", params_dict))
1171+
return ToolsListResult.from_dict(await self._client.request("tools.list", params_dict, **_timeout_kwargs(timeout)))
11641172

11651173

11661174
class AccountApi:
11671175
def __init__(self, client: "JsonRpcClient"):
11681176
self._client = client
11691177

1170-
async def get_quota(self) -> AccountGetQuotaResult:
1171-
return AccountGetQuotaResult.from_dict(await self._client.request("account.getQuota", {}))
1178+
async def get_quota(self, *, timeout: Optional[float] = None) -> AccountGetQuotaResult:
1179+
return AccountGetQuotaResult.from_dict(await self._client.request("account.getQuota", {}, **_timeout_kwargs(timeout)))
11721180

11731181

11741182
class ServerRpc:
@@ -1179,113 +1187,113 @@ def __init__(self, client: "JsonRpcClient"):
11791187
self.tools = ToolsApi(client)
11801188
self.account = AccountApi(client)
11811189

1182-
async def ping(self, params: PingParams) -> PingResult:
1190+
async def ping(self, params: PingParams, *, timeout: Optional[float] = None) -> PingResult:
11831191
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
1184-
return PingResult.from_dict(await self._client.request("ping", params_dict))
1192+
return PingResult.from_dict(await self._client.request("ping", params_dict, **_timeout_kwargs(timeout)))
11851193

11861194

11871195
class ModelApi:
11881196
def __init__(self, client: "JsonRpcClient", session_id: str):
11891197
self._client = client
11901198
self._session_id = session_id
11911199

1192-
async def get_current(self) -> SessionModelGetCurrentResult:
1193-
return SessionModelGetCurrentResult.from_dict(await self._client.request("session.model.getCurrent", {"sessionId": self._session_id}))
1200+
async def get_current(self, *, timeout: Optional[float] = None) -> SessionModelGetCurrentResult:
1201+
return SessionModelGetCurrentResult.from_dict(await self._client.request("session.model.getCurrent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
11941202

1195-
async def switch_to(self, params: SessionModelSwitchToParams) -> SessionModelSwitchToResult:
1203+
async def switch_to(self, params: SessionModelSwitchToParams, *, timeout: Optional[float] = None) -> SessionModelSwitchToResult:
11961204
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
11971205
params_dict["sessionId"] = self._session_id
1198-
return SessionModelSwitchToResult.from_dict(await self._client.request("session.model.switchTo", params_dict))
1206+
return SessionModelSwitchToResult.from_dict(await self._client.request("session.model.switchTo", params_dict, **_timeout_kwargs(timeout)))
11991207

12001208

12011209
class ModeApi:
12021210
def __init__(self, client: "JsonRpcClient", session_id: str):
12031211
self._client = client
12041212
self._session_id = session_id
12051213

1206-
async def get(self) -> SessionModeGetResult:
1207-
return SessionModeGetResult.from_dict(await self._client.request("session.mode.get", {"sessionId": self._session_id}))
1214+
async def get(self, *, timeout: Optional[float] = None) -> SessionModeGetResult:
1215+
return SessionModeGetResult.from_dict(await self._client.request("session.mode.get", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12081216

1209-
async def set(self, params: SessionModeSetParams) -> SessionModeSetResult:
1217+
async def set(self, params: SessionModeSetParams, *, timeout: Optional[float] = None) -> SessionModeSetResult:
12101218
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12111219
params_dict["sessionId"] = self._session_id
1212-
return SessionModeSetResult.from_dict(await self._client.request("session.mode.set", params_dict))
1220+
return SessionModeSetResult.from_dict(await self._client.request("session.mode.set", params_dict, **_timeout_kwargs(timeout)))
12131221

12141222

12151223
class PlanApi:
12161224
def __init__(self, client: "JsonRpcClient", session_id: str):
12171225
self._client = client
12181226
self._session_id = session_id
12191227

1220-
async def read(self) -> SessionPlanReadResult:
1221-
return SessionPlanReadResult.from_dict(await self._client.request("session.plan.read", {"sessionId": self._session_id}))
1228+
async def read(self, *, timeout: Optional[float] = None) -> SessionPlanReadResult:
1229+
return SessionPlanReadResult.from_dict(await self._client.request("session.plan.read", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12221230

1223-
async def update(self, params: SessionPlanUpdateParams) -> SessionPlanUpdateResult:
1231+
async def update(self, params: SessionPlanUpdateParams, *, timeout: Optional[float] = None) -> SessionPlanUpdateResult:
12241232
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12251233
params_dict["sessionId"] = self._session_id
1226-
return SessionPlanUpdateResult.from_dict(await self._client.request("session.plan.update", params_dict))
1234+
return SessionPlanUpdateResult.from_dict(await self._client.request("session.plan.update", params_dict, **_timeout_kwargs(timeout)))
12271235

1228-
async def delete(self) -> SessionPlanDeleteResult:
1229-
return SessionPlanDeleteResult.from_dict(await self._client.request("session.plan.delete", {"sessionId": self._session_id}))
1236+
async def delete(self, *, timeout: Optional[float] = None) -> SessionPlanDeleteResult:
1237+
return SessionPlanDeleteResult.from_dict(await self._client.request("session.plan.delete", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12301238

12311239

12321240
class WorkspaceApi:
12331241
def __init__(self, client: "JsonRpcClient", session_id: str):
12341242
self._client = client
12351243
self._session_id = session_id
12361244

1237-
async def list_files(self) -> SessionWorkspaceListFilesResult:
1238-
return SessionWorkspaceListFilesResult.from_dict(await self._client.request("session.workspace.listFiles", {"sessionId": self._session_id}))
1245+
async def list_files(self, *, timeout: Optional[float] = None) -> SessionWorkspaceListFilesResult:
1246+
return SessionWorkspaceListFilesResult.from_dict(await self._client.request("session.workspace.listFiles", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12391247

1240-
async def read_file(self, params: SessionWorkspaceReadFileParams) -> SessionWorkspaceReadFileResult:
1248+
async def read_file(self, params: SessionWorkspaceReadFileParams, *, timeout: Optional[float] = None) -> SessionWorkspaceReadFileResult:
12411249
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12421250
params_dict["sessionId"] = self._session_id
1243-
return SessionWorkspaceReadFileResult.from_dict(await self._client.request("session.workspace.readFile", params_dict))
1251+
return SessionWorkspaceReadFileResult.from_dict(await self._client.request("session.workspace.readFile", params_dict, **_timeout_kwargs(timeout)))
12441252

1245-
async def create_file(self, params: SessionWorkspaceCreateFileParams) -> SessionWorkspaceCreateFileResult:
1253+
async def create_file(self, params: SessionWorkspaceCreateFileParams, *, timeout: Optional[float] = None) -> SessionWorkspaceCreateFileResult:
12461254
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12471255
params_dict["sessionId"] = self._session_id
1248-
return SessionWorkspaceCreateFileResult.from_dict(await self._client.request("session.workspace.createFile", params_dict))
1256+
return SessionWorkspaceCreateFileResult.from_dict(await self._client.request("session.workspace.createFile", params_dict, **_timeout_kwargs(timeout)))
12491257

12501258

12511259
class FleetApi:
12521260
def __init__(self, client: "JsonRpcClient", session_id: str):
12531261
self._client = client
12541262
self._session_id = session_id
12551263

1256-
async def start(self, params: SessionFleetStartParams) -> SessionFleetStartResult:
1264+
async def start(self, params: SessionFleetStartParams, *, timeout: Optional[float] = None) -> SessionFleetStartResult:
12571265
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12581266
params_dict["sessionId"] = self._session_id
1259-
return SessionFleetStartResult.from_dict(await self._client.request("session.fleet.start", params_dict))
1267+
return SessionFleetStartResult.from_dict(await self._client.request("session.fleet.start", params_dict, **_timeout_kwargs(timeout)))
12601268

12611269

12621270
class AgentApi:
12631271
def __init__(self, client: "JsonRpcClient", session_id: str):
12641272
self._client = client
12651273
self._session_id = session_id
12661274

1267-
async def list(self) -> SessionAgentListResult:
1268-
return SessionAgentListResult.from_dict(await self._client.request("session.agent.list", {"sessionId": self._session_id}))
1275+
async def list(self, *, timeout: Optional[float] = None) -> SessionAgentListResult:
1276+
return SessionAgentListResult.from_dict(await self._client.request("session.agent.list", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12691277

1270-
async def get_current(self) -> SessionAgentGetCurrentResult:
1271-
return SessionAgentGetCurrentResult.from_dict(await self._client.request("session.agent.getCurrent", {"sessionId": self._session_id}))
1278+
async def get_current(self, *, timeout: Optional[float] = None) -> SessionAgentGetCurrentResult:
1279+
return SessionAgentGetCurrentResult.from_dict(await self._client.request("session.agent.getCurrent", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12721280

1273-
async def select(self, params: SessionAgentSelectParams) -> SessionAgentSelectResult:
1281+
async def select(self, params: SessionAgentSelectParams, *, timeout: Optional[float] = None) -> SessionAgentSelectResult:
12741282
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
12751283
params_dict["sessionId"] = self._session_id
1276-
return SessionAgentSelectResult.from_dict(await self._client.request("session.agent.select", params_dict))
1284+
return SessionAgentSelectResult.from_dict(await self._client.request("session.agent.select", params_dict, **_timeout_kwargs(timeout)))
12771285

1278-
async def deselect(self) -> SessionAgentDeselectResult:
1279-
return SessionAgentDeselectResult.from_dict(await self._client.request("session.agent.deselect", {"sessionId": self._session_id}))
1286+
async def deselect(self, *, timeout: Optional[float] = None) -> SessionAgentDeselectResult:
1287+
return SessionAgentDeselectResult.from_dict(await self._client.request("session.agent.deselect", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12801288

12811289

12821290
class CompactionApi:
12831291
def __init__(self, client: "JsonRpcClient", session_id: str):
12841292
self._client = client
12851293
self._session_id = session_id
12861294

1287-
async def compact(self) -> SessionCompactionCompactResult:
1288-
return SessionCompactionCompactResult.from_dict(await self._client.request("session.compaction.compact", {"sessionId": self._session_id}))
1295+
async def compact(self, *, timeout: Optional[float] = None) -> SessionCompactionCompactResult:
1296+
return SessionCompactionCompactResult.from_dict(await self._client.request("session.compaction.compact", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))
12891297

12901298

12911299
class SessionRpc:

python/test_rpc_timeout.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Tests for timeout parameter on generated RPC methods."""
2+
from unittest.mock import AsyncMock
3+
4+
import pytest
5+
6+
from copilot.generated.rpc import (
7+
FleetApi,
8+
Mode,
9+
ModeApi,
10+
SessionFleetStartParams,
11+
SessionModeSetParams,
12+
)
13+
14+
15+
class TestRpcTimeout:
16+
@pytest.mark.asyncio
17+
async def test_default_timeout_not_forwarded(self):
18+
client = AsyncMock()
19+
client.request = AsyncMock(return_value={"started": True})
20+
api = FleetApi(client, "sess-1")
21+
22+
await api.start(SessionFleetStartParams(prompt="go"))
23+
24+
client.request.assert_called_once()
25+
_, kwargs = client.request.call_args
26+
assert "timeout" not in kwargs
27+
28+
@pytest.mark.asyncio
29+
async def test_custom_timeout_forwarded(self):
30+
client = AsyncMock()
31+
client.request = AsyncMock(return_value={"started": True})
32+
api = FleetApi(client, "sess-1")
33+
34+
await api.start(SessionFleetStartParams(prompt="go"), timeout=600.0)
35+
36+
_, kwargs = client.request.call_args
37+
assert kwargs["timeout"] == 600.0
38+
39+
@pytest.mark.asyncio
40+
async def test_timeout_on_other_methods(self):
41+
client = AsyncMock()
42+
client.request = AsyncMock(return_value={"mode": "plan"})
43+
api = ModeApi(client, "sess-1")
44+
45+
await api.set(SessionModeSetParams(mode=Mode.PLAN), timeout=120.0)
46+
47+
_, kwargs = client.request.call_args
48+
assert kwargs["timeout"] == 120.0

scripts/codegen/python.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,20 @@ AUTO-GENERATED FILE - DO NOT EDIT
169169
Generated from: api.schema.json
170170
"""
171171
172-
from typing import TYPE_CHECKING
172+
from typing import TYPE_CHECKING, Optional
173173
174174
if TYPE_CHECKING:
175175
from ..jsonrpc import JsonRpcClient
176176
177+
`);
178+
179+
lines.push(`
180+
def _timeout_kwargs(timeout: Optional[float]) -> dict:
181+
"""Build keyword arguments for optional timeout forwarding."""
182+
if timeout is not None:
183+
return {"timeout": timeout}
184+
return {}
185+
177186
`);
178187
lines.push(typesCode);
179188
lines.push(``);
@@ -255,10 +264,10 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession:
255264
const hasParams = isSession ? nonSessionParams.length > 0 : Object.keys(paramProps).length > 0;
256265
const paramsType = toPascalCase(method.rpcMethod) + "Params";
257266

258-
// Build signature with typed params
267+
// Build signature with typed params + optional timeout
259268
const sig = hasParams
260-
? ` async def ${methodName}(self, params: ${paramsType}) -> ${resultType}:`
261-
: ` async def ${methodName}(self) -> ${resultType}:`;
269+
? ` async def ${methodName}(self, params: ${paramsType}, *, timeout: Optional[float] = None) -> ${resultType}:`
270+
: ` async def ${methodName}(self, *, timeout: Optional[float] = None) -> ${resultType}:`;
262271

263272
lines.push(sig);
264273

@@ -267,16 +276,16 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession:
267276
if (hasParams) {
268277
lines.push(` params_dict = {k: v for k, v in params.to_dict().items() if v is not None}`);
269278
lines.push(` params_dict["sessionId"] = self._session_id`);
270-
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict))`);
279+
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`);
271280
} else {
272-
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {"sessionId": self._session_id}))`);
281+
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))`);
273282
}
274283
} else {
275284
if (hasParams) {
276285
lines.push(` params_dict = {k: v for k, v in params.to_dict().items() if v is not None}`);
277-
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict))`);
286+
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`);
278287
} else {
279-
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {}))`);
288+
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {}, **_timeout_kwargs(timeout)))`);
280289
}
281290
}
282291
lines.push(``);

0 commit comments

Comments
 (0)