Skip to content

Commit 784c708

Browse files
committed
feat: update method signatures to return response types for consistency
Change-Id: I702ad8086d4da8253e8caca8d11ec8346e690bff Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 2c5041b commit 784c708

14 files changed

Lines changed: 581 additions & 164 deletions

File tree

go/acp_test.go

Lines changed: 69 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,25 @@ import (
1010
)
1111

1212
type clientFuncs struct {
13-
WriteTextFileFunc func(context.Context, WriteTextFileRequest) error
13+
WriteTextFileFunc func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error)
1414
ReadTextFileFunc func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error)
1515
RequestPermissionFunc func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error)
1616
SessionUpdateFunc func(context.Context, SessionNotification) error
1717
// Terminal-related handlers
1818
CreateTerminalFunc func(context.Context, CreateTerminalRequest) (CreateTerminalResponse, error)
19-
KillTerminalCommandFunc func(context.Context, KillTerminalCommandRequest) error
20-
ReleaseTerminalFunc func(context.Context, ReleaseTerminalRequest) error
19+
KillTerminalCommandFunc func(context.Context, KillTerminalCommandRequest) (KillTerminalCommandResponse, error)
20+
ReleaseTerminalFunc func(context.Context, ReleaseTerminalRequest) (ReleaseTerminalResponse, error)
2121
TerminalOutputFunc func(context.Context, TerminalOutputRequest) (TerminalOutputResponse, error)
2222
WaitForTerminalExitFunc func(context.Context, WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error)
2323
}
2424

2525
var _ Client = (*clientFuncs)(nil)
2626

27-
func (c clientFuncs) WriteTextFile(ctx context.Context, p WriteTextFileRequest) error {
27+
func (c clientFuncs) WriteTextFile(ctx context.Context, p WriteTextFileRequest) (WriteTextFileResponse, error) {
2828
if c.WriteTextFileFunc != nil {
2929
return c.WriteTextFileFunc(ctx, p)
3030
}
31-
return nil
31+
return WriteTextFileResponse{}, nil
3232
}
3333

3434
func (c clientFuncs) ReadTextFile(ctx context.Context, p ReadTextFileRequest) (ReadTextFileResponse, error) {
@@ -61,19 +61,19 @@ func (c *clientFuncs) CreateTerminal(ctx context.Context, params CreateTerminalR
6161
}
6262

6363
// KillTerminalCommand implements Client.
64-
func (c *clientFuncs) KillTerminalCommand(ctx context.Context, params KillTerminalCommandRequest) error {
64+
func (c clientFuncs) KillTerminalCommand(ctx context.Context, params KillTerminalCommandRequest) (KillTerminalCommandResponse, error) {
6565
if c.KillTerminalCommandFunc != nil {
6666
return c.KillTerminalCommandFunc(ctx, params)
6767
}
68-
return nil
68+
return KillTerminalCommandResponse{}, nil
6969
}
7070

7171
// ReleaseTerminal implements Client.
72-
func (c *clientFuncs) ReleaseTerminal(ctx context.Context, params ReleaseTerminalRequest) error {
72+
func (c clientFuncs) ReleaseTerminal(ctx context.Context, params ReleaseTerminalRequest) (ReleaseTerminalResponse, error) {
7373
if c.ReleaseTerminalFunc != nil {
7474
return c.ReleaseTerminalFunc(ctx, params)
7575
}
76-
return nil
76+
return ReleaseTerminalResponse{}, nil
7777
}
7878

7979
// TerminalOutput implements Client.
@@ -93,12 +93,14 @@ func (c *clientFuncs) WaitForTerminalExit(ctx context.Context, params WaitForTer
9393
}
9494

9595
type agentFuncs struct {
96-
InitializeFunc func(context.Context, InitializeRequest) (InitializeResponse, error)
97-
NewSessionFunc func(context.Context, NewSessionRequest) (NewSessionResponse, error)
98-
LoadSessionFunc func(context.Context, LoadSessionRequest) (LoadSessionResponse, error)
99-
AuthenticateFunc func(context.Context, AuthenticateRequest) error
100-
PromptFunc func(context.Context, PromptRequest) (PromptResponse, error)
101-
CancelFunc func(context.Context, CancelNotification) error
96+
InitializeFunc func(context.Context, InitializeRequest) (InitializeResponse, error)
97+
NewSessionFunc func(context.Context, NewSessionRequest) (NewSessionResponse, error)
98+
LoadSessionFunc func(context.Context, LoadSessionRequest) (LoadSessionResponse, error)
99+
AuthenticateFunc func(context.Context, AuthenticateRequest) (AuthenticateResponse, error)
100+
PromptFunc func(context.Context, PromptRequest) (PromptResponse, error)
101+
CancelFunc func(context.Context, CancelNotification) error
102+
SetSessionModeFunc func(ctx context.Context, params SetSessionModeRequest) (SetSessionModeResponse, error)
103+
SetSessionModelFunc func(ctx context.Context, params SetSessionModelRequest) (SetSessionModelResponse, error)
102104
}
103105

104106
var (
@@ -127,11 +129,11 @@ func (a agentFuncs) LoadSession(ctx context.Context, p LoadSessionRequest) (Load
127129
return LoadSessionResponse{}, nil
128130
}
129131

130-
func (a agentFuncs) Authenticate(ctx context.Context, p AuthenticateRequest) error {
132+
func (a agentFuncs) Authenticate(ctx context.Context, p AuthenticateRequest) (AuthenticateResponse, error) {
131133
if a.AuthenticateFunc != nil {
132134
return a.AuthenticateFunc(ctx, p)
133135
}
134-
return nil
136+
return AuthenticateResponse{}, nil
135137
}
136138

137139
func (a agentFuncs) Prompt(ctx context.Context, p PromptRequest) (PromptResponse, error) {
@@ -148,15 +150,31 @@ func (a agentFuncs) Cancel(ctx context.Context, n CancelNotification) error {
148150
return nil
149151
}
150152

153+
// SetSessionMode implements Agent.
154+
func (a agentFuncs) SetSessionMode(ctx context.Context, params SetSessionModeRequest) (SetSessionModeResponse, error) {
155+
if a.SetSessionModeFunc != nil {
156+
return a.SetSessionModeFunc(ctx, params)
157+
}
158+
return SetSessionModeResponse{}, nil
159+
}
160+
161+
// SetSessionModel implements Agent.
162+
func (a agentFuncs) SetSessionModel(ctx context.Context, params SetSessionModelRequest) (SetSessionModelResponse, error) {
163+
if a.SetSessionModelFunc != nil {
164+
return a.SetSessionModelFunc(ctx, params)
165+
}
166+
return SetSessionModelResponse{}, nil
167+
}
168+
151169
// Test bidirectional error handling similar to typescript/acp.test.ts
152170
func TestConnectionHandlesErrorsBidirectional(t *testing.T) {
153171
ctx := context.Background()
154172
c2aR, c2aW := io.Pipe()
155173
a2cR, a2cW := io.Pipe()
156174

157175
c := NewClientSideConnection(&clientFuncs{
158-
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error {
159-
return &RequestError{Code: -32603, Message: "Write failed"}
176+
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) {
177+
return WriteTextFileResponse{}, &RequestError{Code: -32603, Message: "Write failed"}
160178
},
161179
ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) {
162180
return ReadTextFileResponse{}, &RequestError{Code: -32603, Message: "Read failed"}
@@ -176,8 +194,8 @@ func TestConnectionHandlesErrorsBidirectional(t *testing.T) {
176194
LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) {
177195
return LoadSessionResponse{}, &RequestError{Code: -32603, Message: "Failed to load session"}
178196
},
179-
AuthenticateFunc: func(context.Context, AuthenticateRequest) error {
180-
return &RequestError{Code: -32603, Message: "Authentication failed"}
197+
AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) {
198+
return AuthenticateResponse{}, &RequestError{Code: -32603, Message: "Authentication failed"}
181199
},
182200
PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) {
183201
return PromptResponse{}, &RequestError{Code: -32603, Message: "Prompt failed"}
@@ -186,7 +204,7 @@ func TestConnectionHandlesErrorsBidirectional(t *testing.T) {
186204
}, a2cW, c2aR)
187205

188206
// Client->Agent direction: expect error
189-
if err := agentConn.WriteTextFile(ctx, WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err == nil {
207+
if _, err := agentConn.WriteTextFile(ctx, WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err == nil {
190208
t.Fatalf("expected error for writeTextFile, got nil")
191209
}
192210

@@ -205,12 +223,12 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) {
205223
requestCount := 0
206224

207225
_ = NewClientSideConnection(&clientFuncs{
208-
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error {
226+
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) {
209227
mu.Lock()
210228
requestCount++
211229
mu.Unlock()
212230
time.Sleep(40 * time.Millisecond)
213-
return nil
231+
return WriteTextFileResponse{}, nil
214232
},
215233
ReadTextFileFunc: func(_ context.Context, req ReadTextFileRequest) (ReadTextFileResponse, error) {
216234
return ReadTextFileResponse{Content: "Content of " + req.Path}, nil
@@ -230,7 +248,9 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) {
230248
LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) {
231249
return LoadSessionResponse{}, nil
232250
},
233-
AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil },
251+
AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) {
252+
return AuthenticateResponse{}, nil
253+
},
234254
PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) {
235255
return PromptResponse{StopReason: "end_turn"}, nil
236256
},
@@ -249,7 +269,7 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) {
249269
req := p
250270
go func() {
251271
defer wg.Done()
252-
errs[idx] = agentConn.WriteTextFile(context.Background(), req)
272+
_, errs[idx] = agentConn.WriteTextFile(context.Background(), req)
253273
}()
254274
}
255275
wg.Wait()
@@ -276,9 +296,9 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) {
276296
push := func(s string) { mu.Lock(); defer mu.Unlock(); log = append(log, s) }
277297

278298
cs := NewClientSideConnection(&clientFuncs{
279-
WriteTextFileFunc: func(_ context.Context, req WriteTextFileRequest) error {
299+
WriteTextFileFunc: func(_ context.Context, req WriteTextFileRequest) (WriteTextFileResponse, error) {
280300
push("writeTextFile called: " + req.Path)
281-
return nil
301+
return WriteTextFileResponse{}, nil
282302
},
283303
ReadTextFileFunc: func(_ context.Context, req ReadTextFileRequest) (ReadTextFileResponse, error) {
284304
push("readTextFile called: " + req.Path)
@@ -306,9 +326,9 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) {
306326
push("loadSession called: " + string(p.SessionId))
307327
return LoadSessionResponse{}, nil
308328
},
309-
AuthenticateFunc: func(_ context.Context, p AuthenticateRequest) error {
329+
AuthenticateFunc: func(_ context.Context, p AuthenticateRequest) (AuthenticateResponse, error) {
310330
push("authenticate called: " + string(p.MethodId))
311-
return nil
331+
return AuthenticateResponse{}, nil
312332
},
313333
PromptFunc: func(_ context.Context, p PromptRequest) (PromptResponse, error) {
314334
push("prompt called: " + string(p.SessionId))
@@ -323,7 +343,7 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) {
323343
if _, err := cs.NewSession(context.Background(), NewSessionRequest{Cwd: "/test", McpServers: []McpServer{}}); err != nil {
324344
t.Fatalf("newSession error: %v", err)
325345
}
326-
if err := as.WriteTextFile(context.Background(), WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err != nil {
346+
if _, err := as.WriteTextFile(context.Background(), WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err != nil {
327347
t.Fatalf("writeTextFile error: %v", err)
328348
}
329349
if _, err := as.ReadTextFile(context.Background(), ReadTextFileRequest{Path: "/test.txt", SessionId: "test-session"}); err != nil {
@@ -376,7 +396,9 @@ func TestConnectionHandlesNotifications(t *testing.T) {
376396
push := func(s string) { mu.Lock(); logs = append(logs, s); mu.Unlock() }
377397

378398
clientSide := NewClientSideConnection(&clientFuncs{
379-
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil },
399+
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) {
400+
return WriteTextFileResponse{}, nil
401+
},
380402
ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) {
381403
return ReadTextFileResponse{Content: "test"}, nil
382404
},
@@ -405,7 +427,9 @@ func TestConnectionHandlesNotifications(t *testing.T) {
405427
LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) {
406428
return LoadSessionResponse{}, nil
407429
},
408-
AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil },
430+
AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) {
431+
return AuthenticateResponse{}, nil
432+
},
409433
PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) {
410434
return PromptResponse{StopReason: "end_turn"}, nil
411435
},
@@ -447,7 +471,9 @@ func TestConnectionHandlesInitialize(t *testing.T) {
447471
a2cR, a2cW := io.Pipe()
448472

449473
agentConn := NewClientSideConnection(&clientFuncs{
450-
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil },
474+
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) {
475+
return WriteTextFileResponse{}, nil
476+
},
451477
ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) {
452478
return ReadTextFileResponse{Content: "test"}, nil
453479
},
@@ -478,7 +504,9 @@ func TestConnectionHandlesInitialize(t *testing.T) {
478504
LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) {
479505
return LoadSessionResponse{}, nil
480506
},
481-
AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil },
507+
AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) {
508+
return AuthenticateResponse{}, nil
509+
},
482510
PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) {
483511
return PromptResponse{StopReason: "end_turn"}, nil
484512
},
@@ -527,7 +555,9 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) {
527555
LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) {
528556
return LoadSessionResponse{}, nil
529557
},
530-
AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil },
558+
AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) {
559+
return AuthenticateResponse{}, nil
560+
},
531561
PromptFunc: func(ctx context.Context, p PromptRequest) (PromptResponse, error) {
532562
<-ctx.Done()
533563
// mark that prompt finished due to cancellation
@@ -548,7 +578,9 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) {
548578

549579
// Client side
550580
cs := NewClientSideConnection(&clientFuncs{
551-
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil },
581+
WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) {
582+
return WriteTextFileResponse{}, nil
583+
},
552584
ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) {
553585
return ReadTextFileResponse{Content: ""}, nil
554586
},

go/agent_gen.go

Lines changed: 38 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)