Skip to content

Commit 6936cce

Browse files
Refactor setModel to be thin wrappers around session.Rpc.Model.SwitchTo()
Co-authored-by: SteveSandersonMS <1101362+SteveSandersonMS@users.noreply.github.com>
1 parent a0367b8 commit 6936cce

8 files changed

Lines changed: 12 additions & 55 deletions

File tree

dotnet/src/Session.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,7 @@ await InvokeRpcAsync<object>(
553553
/// </example>
554554
public async Task SetModelAsync(string model, CancellationToken cancellationToken = default)
555555
{
556-
await InvokeRpcAsync<object>(
557-
"session.setModel", [new SetModelRequest { SessionId = SessionId, Model = model }], cancellationToken);
556+
await Rpc.Model.SwitchToAsync(model, cancellationToken);
558557
}
559558

560559
/// <summary>
@@ -655,12 +654,6 @@ internal record SessionDestroyRequest
655654
public string SessionId { get; init; } = string.Empty;
656655
}
657656

658-
internal record SetModelRequest
659-
{
660-
public string SessionId { get; init; } = string.Empty;
661-
public string Model { get; init; } = string.Empty;
662-
}
663-
664657
[JsonSourceGenerationOptions(
665658
JsonSerializerDefaults.Web,
666659
AllowOutOfOrderMetadataProperties = true,
@@ -673,7 +666,6 @@ internal record SetModelRequest
673666
[JsonSerializable(typeof(SendMessageResponse))]
674667
[JsonSerializable(typeof(SessionAbortRequest))]
675668
[JsonSerializable(typeof(SessionDestroyRequest))]
676-
[JsonSerializable(typeof(SetModelRequest))]
677669
[JsonSerializable(typeof(UserMessageDataAttachmentsItem))]
678670
[JsonSerializable(typeof(PreToolUseHookInput))]
679671
[JsonSerializable(typeof(PreToolUseHookOutput))]

go/client_test.go

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -448,26 +448,6 @@ func TestResumeSessionRequest_ClientName(t *testing.T) {
448448
})
449449
}
450450

451-
func TestSetModelRequest(t *testing.T) {
452-
t.Run("includes sessionId and model in JSON", func(t *testing.T) {
453-
req := sessionSetModelRequest{SessionID: "s1", Model: "gpt-4.1"}
454-
data, err := json.Marshal(req)
455-
if err != nil {
456-
t.Fatalf("Failed to marshal: %v", err)
457-
}
458-
var m map[string]any
459-
if err := json.Unmarshal(data, &m); err != nil {
460-
t.Fatalf("Failed to unmarshal: %v", err)
461-
}
462-
if m["sessionId"] != "s1" {
463-
t.Errorf("Expected sessionId 's1', got %v", m["sessionId"])
464-
}
465-
if m["model"] != "gpt-4.1" {
466-
t.Errorf("Expected model 'gpt-4.1', got %v", m["model"])
467-
}
468-
})
469-
}
470-
471451
func TestClient_CreateSession_RequiresPermissionHandler(t *testing.T) {
472452
t.Run("returns error when config is nil", func(t *testing.T) {
473453
client := NewClient(nil)

go/session.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,10 +586,7 @@ func (s *Session) Abort(ctx context.Context) error {
586586
// log.Printf("Failed to set model: %v", err)
587587
// }
588588
func (s *Session) SetModel(ctx context.Context, model string) error {
589-
_, err := s.client.Request("session.setModel", sessionSetModelRequest{
590-
SessionID: s.SessionID,
591-
Model: model,
592-
})
589+
_, err := s.RPC.Model.SwitchTo(ctx, &rpc.SessionModelSwitchToParams{ModelID: model})
593590
if err != nil {
594591
return fmt.Errorf("failed to set model: %w", err)
595592
}

go/types.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -828,12 +828,6 @@ type sessionAbortRequest struct {
828828
SessionID string `json:"sessionId"`
829829
}
830830

831-
// sessionSetModelRequest is the request for session.setModel
832-
type sessionSetModelRequest struct {
833-
SessionID string `json:"sessionId"`
834-
Model string `json:"model"`
835-
}
836-
837831
type sessionSendRequest struct {
838832
SessionID string `json:"sessionId"`
839833
Prompt string `json:"prompt"`

nodejs/src/session.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,9 +562,6 @@ export class CopilotSession {
562562
* ```
563563
*/
564564
async setModel(model: string): Promise<void> {
565-
await this.connection.sendRequest("session.setModel", {
566-
sessionId: this.sessionId,
567-
model,
568-
});
565+
await this.rpc.model.switchTo({ modelId: model });
569566
}
570567
}

nodejs/test/client.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ describe("CopilotClient", () => {
8080
);
8181
});
8282

83-
it("sends session.setModel RPC with correct params", async () => {
83+
it("sends session.model.switchTo RPC with correct params", async () => {
8484
const client = new CopilotClient();
8585
await client.start();
8686
onTestFinished(() => client.forceStop());
@@ -90,16 +90,16 @@ describe("CopilotClient", () => {
9090
// Mock sendRequest to capture the call without hitting the runtime
9191
const spy = vi.spyOn((client as any).connection!, "sendRequest")
9292
.mockImplementation(async (method: string, params: any) => {
93-
if (method === "session.setModel") return {};
93+
if (method === "session.model.switchTo") return {};
9494
// Fall through for other methods (shouldn't be called)
9595
throw new Error(`Unexpected method: ${method}`);
9696
});
9797

9898
await session.setModel("gpt-4.1");
9999

100100
expect(spy).toHaveBeenCalledWith(
101-
"session.setModel",
102-
{ sessionId: session.sessionId, model: "gpt-4.1" }
101+
"session.model.switchTo",
102+
{ sessionId: session.sessionId, modelId: "gpt-4.1" }
103103
);
104104

105105
spy.mockRestore();

python/copilot/session.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable
1212
from typing import Any, cast
1313

14-
from .generated.rpc import SessionRpc
14+
from .generated.rpc import SessionModelSwitchToParams, SessionRpc
1515
from .generated.session_events import SessionEvent, SessionEventType, session_event_from_dict
1616
from .types import (
1717
MessageOptions,
@@ -537,7 +537,4 @@ async def set_model(self, model: str) -> None:
537537
Example:
538538
>>> await session.set_model("gpt-4.1")
539539
"""
540-
await self._client.request(
541-
"session.setModel",
542-
{"sessionId": self.session_id, "model": model},
543-
)
540+
await self.rpc.model.switch_to(SessionModelSwitchToParams(model_id=model))

python/test_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,13 @@ async def test_set_model_sends_correct_rpc(self):
242242

243243
async def mock_request(method, params):
244244
captured[method] = params
245-
if method == "session.setModel":
245+
if method == "session.model.switchTo":
246246
return {}
247247
return await original_request(method, params)
248248

249249
client._client.request = mock_request
250250
await session.set_model("gpt-4.1")
251-
assert captured["session.setModel"]["sessionId"] == session.session_id
252-
assert captured["session.setModel"]["model"] == "gpt-4.1"
251+
assert captured["session.model.switchTo"]["sessionId"] == session.session_id
252+
assert captured["session.model.switchTo"]["modelId"] == "gpt-4.1"
253253
finally:
254254
await client.force_stop()

0 commit comments

Comments
 (0)