Skip to content

Commit c61ea69

Browse files
jmoseleyCopilot
andcommitted
Teach codegen shell exec hooks
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 74b8784 commit c61ea69

File tree

5 files changed

+60
-28
lines changed

5 files changed

+60
-28
lines changed

dotnet/src/Generated/Rpc.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1977,4 +1977,4 @@ public async Task<SessionShellKillResult> KillAsync(string processId, SessionShe
19771977
[JsonSerializable(typeof(Tool))]
19781978
[JsonSerializable(typeof(ToolsListRequest))]
19791979
[JsonSerializable(typeof(ToolsListResult))]
1980-
internal partial class RpcJsonContext : JsonSerializerContext;
1980+
internal partial class RpcJsonContext : JsonSerializerContext;

python/copilot/generated/rpc.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2836,22 +2836,15 @@ async def handle_pending_permission_request(self, params: SessionPermissionsHand
28362836

28372837

28382838
class ShellApi:
2839-
def __init__(
2840-
self,
2841-
client: "JsonRpcClient",
2842-
session_id: str,
2843-
on_exec: Callable[[str], None] | None = None,
2844-
):
2839+
def __init__(self, client: "JsonRpcClient", session_id: str, on_exec: Callable[[str], None] | None = None):
28452840
self._client = client
28462841
self._session_id = session_id
28472842
self._on_exec = on_exec
28482843

28492844
async def exec(self, params: SessionShellExecParams, *, timeout: float | None = None) -> SessionShellExecResult:
28502845
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
28512846
params_dict["sessionId"] = self._session_id
2852-
result = SessionShellExecResult.from_dict(
2853-
await self._client.request("session.shell.exec", params_dict, **_timeout_kwargs(timeout))
2854-
)
2847+
result = SessionShellExecResult.from_dict(await self._client.request("session.shell.exec", params_dict, **_timeout_kwargs(timeout)))
28552848
if self._on_exec is not None:
28562849
self._on_exec(result.process_id)
28572850
return result
@@ -2864,12 +2857,7 @@ async def kill(self, params: SessionShellKillParams, *, timeout: float | None =
28642857

28652858
class SessionRpc:
28662859
"""Typed session-scoped RPC methods."""
2867-
def __init__(
2868-
self,
2869-
client: "JsonRpcClient",
2870-
session_id: str,
2871-
on_shell_exec: Callable[[str], None] | None = None,
2872-
):
2860+
def __init__(self, client: "JsonRpcClient", session_id: str, on_shell_exec: Callable[[str], None] | None = None):
28732861
self._client = client
28742862
self._session_id = session_id
28752863
self.model = ModelApi(client, session_id)
@@ -2893,3 +2881,4 @@ async def log(self, params: SessionLogParams, *, timeout: float | None = None) -
28932881
params_dict = {k: v for k, v in params.to_dict().items() if v is not None}
28942882
params_dict["sessionId"] = self._session_id
28952883
return SessionLogResult.from_dict(await self._client.request("session.log", params_dict, **_timeout_kwargs(timeout)))
2884+

scripts/codegen/csharp.ts

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -828,9 +828,11 @@ function emitSessionRpcClasses(node: Record<string, unknown>, classes: string[])
828828
const groups = Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v));
829829
const topLevelMethods = Object.entries(node).filter(([, v]) => isRpcMethod(v));
830830

831-
const srLines = [`/// <summary>Provides typed session-scoped RPC methods.</summary>`, `public class SessionRpc`, `{`, ` private readonly JsonRpc _rpc;`, ` private readonly string _sessionId;`, ""];
832-
srLines.push(` internal SessionRpc(JsonRpc rpc, string sessionId)`, ` {`, ` _rpc = rpc;`, ` _sessionId = sessionId;`);
833-
for (const [groupName] of groups) srLines.push(` ${toPascalCase(groupName)} = new ${toPascalCase(groupName)}Api(rpc, sessionId);`);
831+
const srLines = [`/// <summary>Provides typed session-scoped RPC methods.</summary>`, `public class SessionRpc`, `{`, ` private readonly JsonRpc _rpc;`, ` private readonly string _sessionId;`, ` private readonly Action<string>? _onShellExec;`, ""];
832+
srLines.push(` internal SessionRpc(JsonRpc rpc, string sessionId, Action<string>? onShellExec = null)`, ` {`, ` _rpc = rpc;`, ` _sessionId = sessionId;`, ` _onShellExec = onShellExec;`);
833+
for (const [groupName] of groups) srLines.push(
834+
` ${toPascalCase(groupName)} = new ${toPascalCase(groupName)}Api(rpc, sessionId${groupName === "shell" ? ", _onShellExec" : ""});`
835+
);
834836
srLines.push(` }`);
835837
for (const [groupName] of groups) srLines.push("", ` /// <summary>${toPascalCase(groupName)} APIs.</summary>`, ` public ${toPascalCase(groupName)}Api ${toPascalCase(groupName)} { get; }`);
836838

@@ -896,15 +898,22 @@ function emitSessionMethod(key: string, method: RpcMethod, lines: string[], clas
896898

897899
lines.push(`${indent}public async Task<${resultClassName}> ${methodName}Async(${sigParams.join(", ")})`);
898900
lines.push(`${indent}{`, `${indent} var request = new ${requestClassName} { ${bodyAssignments.join(", ")} };`);
899-
lines.push(`${indent} return await CopilotClient.InvokeRpcAsync<${resultClassName}>(_rpc, "${method.rpcMethod}", [request], cancellationToken);`, `${indent}}`);
901+
if (method.rpcMethod === "session.shell.exec") {
902+
lines.push(`${indent} var result = await CopilotClient.InvokeRpcAsync<${resultClassName}>(_rpc, "${method.rpcMethod}", [request], cancellationToken);`);
903+
lines.push(`${indent} _onExec?.Invoke(result.ProcessId);`);
904+
lines.push(`${indent} return result;`, `${indent}}`);
905+
} else {
906+
lines.push(`${indent} return await CopilotClient.InvokeRpcAsync<${resultClassName}>(_rpc, "${method.rpcMethod}", [request], cancellationToken);`, `${indent}}`);
907+
}
900908
}
901909

902910
function emitSessionApiClass(className: string, node: Record<string, unknown>, classes: string[]): string {
903911
const displayName = className.replace(/Api$/, "");
904912
const groupExperimental = isNodeFullyExperimental(node);
905913
const experimentalAttr = groupExperimental ? `[Experimental(Diagnostics.Experimental)]\n` : "";
906-
const lines = [`/// <summary>Provides session-scoped ${displayName} APIs.</summary>`, `${experimentalAttr}public class ${className}`, `{`, ` private readonly JsonRpc _rpc;`, ` private readonly string _sessionId;`, ""];
907-
lines.push(` internal ${className}(JsonRpc rpc, string sessionId)`, ` {`, ` _rpc = rpc;`, ` _sessionId = sessionId;`, ` }`);
914+
const ctorSuffix = className === "ShellApi" ? ", Action<string>? onExec = null" : "";
915+
const lines = [`/// <summary>Provides session-scoped ${displayName} APIs.</summary>`, `${experimentalAttr}public class ${className}`, `{`, ` private readonly JsonRpc _rpc;`, ` private readonly string _sessionId;`, ...(className === "ShellApi" ? [` private readonly Action<string>? _onExec;`] : []), ""];
916+
lines.push(` internal ${className}(JsonRpc rpc, string sessionId${ctorSuffix})`, ` {`, ` _rpc = rpc;`, ` _sessionId = sessionId;`, ...(className === "ShellApi" ? [` _onExec = onExec;`] : []), ` }`);
908917

909918
for (const [key, value] of Object.entries(node)) {
910919
if (!isRpcMethod(value)) continue;

scripts/codegen/go.ts

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ function emitRpcWrapper(lines: string[], node: Record<string, unknown>, isSessio
315315
if (isSession) {
316316
lines.push(`\tclient *jsonrpc2.Client`);
317317
lines.push(`\tsessionID string`);
318+
if (groupName === "shell") {
319+
lines.push(`\tonExec func(string)`);
320+
}
318321
} else {
319322
lines.push(`\tclient *jsonrpc2.Client`);
320323
}
@@ -355,14 +358,22 @@ function emitRpcWrapper(lines: string[], node: Record<string, unknown>, isSessio
355358
const padKey = (name: string) => (name + ":").padEnd(maxKeyLen + 1); // +1 for min trailing space
356359

357360
// Constructor
358-
const ctorParams = isSession ? "client *jsonrpc2.Client, sessionID string" : "client *jsonrpc2.Client";
361+
const ctorParams = isSession ? "client *jsonrpc2.Client, sessionID string, onShellExec ...func(string)" : "client *jsonrpc2.Client";
359362
const ctorFields = isSession ? "client: client, sessionID: sessionID," : "client: client,";
360363
lines.push(`func New${wrapperName}(${ctorParams}) *${wrapperName} {`);
364+
if (isSession) {
365+
lines.push(`\tvar shellExecHandler func(string)`);
366+
lines.push(`\tif len(onShellExec) > 0 {`);
367+
lines.push(`\t\tshellExecHandler = onShellExec[0]`);
368+
lines.push(`\t}`);
369+
}
361370
lines.push(`\treturn &${wrapperName}{${ctorFields}`);
362371
for (const [groupName] of groups) {
363372
const prefix = isSession ? "" : "Server";
364373
const apiInit = isSession
365-
? `&${toPascalCase(groupName)}${apiSuffix}{client: client, sessionID: sessionID}`
374+
? groupName === "shell"
375+
? `&${toPascalCase(groupName)}${apiSuffix}{client: client, sessionID: sessionID, onExec: shellExecHandler}`
376+
: `&${toPascalCase(groupName)}${apiSuffix}{client: client, sessionID: sessionID}`
366377
: `&${prefix}${toPascalCase(groupName)}${apiSuffix}{client: client}`;
367378
lines.push(`\t\t${padKey(toPascalCase(groupName))}${apiInit},`);
368379
}
@@ -421,6 +432,11 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc
421432
lines.push(`\tif err := json.Unmarshal(raw, &result); err != nil {`);
422433
lines.push(`\t\treturn nil, err`);
423434
lines.push(`\t}`);
435+
if (method.rpcMethod === "session.shell.exec") {
436+
lines.push(`\tif a.onExec != nil {`);
437+
lines.push(`\t\ta.onExec(result.ProcessID)`);
438+
lines.push(`\t}`);
439+
}
424440
lines.push(`\treturn &result, nil`);
425441
lines.push(`}`);
426442
lines.push(``);

scripts/codegen/python.ts

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,16 @@ function emitRpcWrapper(lines: string[], node: Record<string, unknown>, isSessio
319319
lines.push(`# Experimental: this API group is experimental and may change or be removed.`);
320320
}
321321
lines.push(`class ${apiName}:`);
322-
lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`);
322+
if (groupName === "shell") {
323+
lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str, on_exec: Callable[[str], None] | None = None):`);
324+
} else {
325+
lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`);
326+
}
323327
lines.push(` self._client = client`);
324328
lines.push(` self._session_id = session_id`);
329+
if (groupName === "shell") {
330+
lines.push(` self._on_exec = on_exec`);
331+
}
325332
} else {
326333
if (groupExperimental) {
327334
lines.push(`# Experimental: this API group is experimental and may change or be removed.`);
@@ -342,11 +349,15 @@ function emitRpcWrapper(lines: string[], node: Record<string, unknown>, isSessio
342349
if (isSession) {
343350
lines.push(`class ${wrapperName}:`);
344351
lines.push(` """Typed session-scoped RPC methods."""`);
345-
lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`);
352+
lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str, on_shell_exec: Callable[[str], None] | None = None):`);
346353
lines.push(` self._client = client`);
347354
lines.push(` self._session_id = session_id`);
348355
for (const [groupName] of groups) {
349-
lines.push(` self.${toSnakeCase(groupName)} = ${toPascalCase(groupName)}Api(client, session_id)`);
356+
if (groupName === "shell") {
357+
lines.push(` self.${toSnakeCase(groupName)} = ${toPascalCase(groupName)}Api(client, session_id, on_shell_exec)`);
358+
} else {
359+
lines.push(` self.${toSnakeCase(groupName)} = ${toPascalCase(groupName)}Api(client, session_id)`);
360+
}
350361
}
351362
} else {
352363
lines.push(`class ${wrapperName}:`);
@@ -392,7 +403,14 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession:
392403
if (hasParams) {
393404
lines.push(` params_dict = {k: v for k, v in params.to_dict().items() if v is not None}`);
394405
lines.push(` params_dict["sessionId"] = self._session_id`);
395-
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`);
406+
if (method.rpcMethod === "session.shell.exec") {
407+
lines.push(` result = ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`);
408+
lines.push(` if self._on_exec is not None:`);
409+
lines.push(` self._on_exec(result.process_id)`);
410+
lines.push(` return result`);
411+
} else {
412+
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", params_dict, **_timeout_kwargs(timeout)))`);
413+
}
396414
} else {
397415
lines.push(` return ${resultType}.from_dict(await self._client.request("${method.rpcMethod}", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)))`);
398416
}

0 commit comments

Comments
 (0)