@@ -10,25 +10,25 @@ import (
1010)
1111
1212type clientFuncs struct {
13- WriteTextFileFunc func (context.Context , WriteTextFileRequest ) error
13+ WriteTextFileFunc func (context.Context , WriteTextFileRequest ) ( WriteTextFileResponse , error )
1414 ReadTextFileFunc func (context.Context , ReadTextFileRequest ) (ReadTextFileResponse , error )
1515 RequestPermissionFunc func (context.Context , RequestPermissionRequest ) (RequestPermissionResponse , error )
1616 SessionUpdateFunc func (context.Context , SessionNotification ) error
1717 // Terminal-related handlers
1818 CreateTerminalFunc func (context.Context , CreateTerminalRequest ) (CreateTerminalResponse , error )
19- KillTerminalCommandFunc func (context.Context , KillTerminalCommandRequest ) error
20- ReleaseTerminalFunc func (context.Context , ReleaseTerminalRequest ) error
19+ KillTerminalCommandFunc func (context.Context , KillTerminalCommandRequest ) ( KillTerminalCommandResponse , error )
20+ ReleaseTerminalFunc func (context.Context , ReleaseTerminalRequest ) ( ReleaseTerminalResponse , error )
2121 TerminalOutputFunc func (context.Context , TerminalOutputRequest ) (TerminalOutputResponse , error )
2222 WaitForTerminalExitFunc func (context.Context , WaitForTerminalExitRequest ) (WaitForTerminalExitResponse , error )
2323}
2424
2525var _ Client = (* clientFuncs )(nil )
2626
27- func (c clientFuncs ) WriteTextFile (ctx context.Context , p WriteTextFileRequest ) error {
27+ func (c clientFuncs ) WriteTextFile (ctx context.Context , p WriteTextFileRequest ) ( WriteTextFileResponse , error ) {
2828 if c .WriteTextFileFunc != nil {
2929 return c .WriteTextFileFunc (ctx , p )
3030 }
31- return nil
31+ return WriteTextFileResponse {}, nil
3232}
3333
3434func (c clientFuncs ) ReadTextFile (ctx context.Context , p ReadTextFileRequest ) (ReadTextFileResponse , error ) {
@@ -61,19 +61,19 @@ func (c *clientFuncs) CreateTerminal(ctx context.Context, params CreateTerminalR
6161}
6262
6363// KillTerminalCommand implements Client.
64- func (c * clientFuncs ) KillTerminalCommand (ctx context.Context , params KillTerminalCommandRequest ) error {
64+ func (c clientFuncs ) KillTerminalCommand (ctx context.Context , params KillTerminalCommandRequest ) ( KillTerminalCommandResponse , error ) {
6565 if c .KillTerminalCommandFunc != nil {
6666 return c .KillTerminalCommandFunc (ctx , params )
6767 }
68- return nil
68+ return KillTerminalCommandResponse {}, nil
6969}
7070
7171// ReleaseTerminal implements Client.
72- func (c * clientFuncs ) ReleaseTerminal (ctx context.Context , params ReleaseTerminalRequest ) error {
72+ func (c clientFuncs ) ReleaseTerminal (ctx context.Context , params ReleaseTerminalRequest ) ( ReleaseTerminalResponse , error ) {
7373 if c .ReleaseTerminalFunc != nil {
7474 return c .ReleaseTerminalFunc (ctx , params )
7575 }
76- return nil
76+ return ReleaseTerminalResponse {}, nil
7777}
7878
7979// TerminalOutput implements Client.
@@ -93,12 +93,14 @@ func (c *clientFuncs) WaitForTerminalExit(ctx context.Context, params WaitForTer
9393}
9494
9595type agentFuncs struct {
96- InitializeFunc func (context.Context , InitializeRequest ) (InitializeResponse , error )
97- NewSessionFunc func (context.Context , NewSessionRequest ) (NewSessionResponse , error )
98- LoadSessionFunc func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error )
99- AuthenticateFunc func (context.Context , AuthenticateRequest ) error
100- PromptFunc func (context.Context , PromptRequest ) (PromptResponse , error )
101- CancelFunc func (context.Context , CancelNotification ) error
96+ InitializeFunc func (context.Context , InitializeRequest ) (InitializeResponse , error )
97+ NewSessionFunc func (context.Context , NewSessionRequest ) (NewSessionResponse , error )
98+ LoadSessionFunc func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error )
99+ AuthenticateFunc func (context.Context , AuthenticateRequest ) (AuthenticateResponse , error )
100+ PromptFunc func (context.Context , PromptRequest ) (PromptResponse , error )
101+ CancelFunc func (context.Context , CancelNotification ) error
102+ SetSessionModeFunc func (ctx context.Context , params SetSessionModeRequest ) (SetSessionModeResponse , error )
103+ SetSessionModelFunc func (ctx context.Context , params SetSessionModelRequest ) (SetSessionModelResponse , error )
102104}
103105
104106var (
@@ -127,11 +129,11 @@ func (a agentFuncs) LoadSession(ctx context.Context, p LoadSessionRequest) (Load
127129 return LoadSessionResponse {}, nil
128130}
129131
130- func (a agentFuncs ) Authenticate (ctx context.Context , p AuthenticateRequest ) error {
132+ func (a agentFuncs ) Authenticate (ctx context.Context , p AuthenticateRequest ) ( AuthenticateResponse , error ) {
131133 if a .AuthenticateFunc != nil {
132134 return a .AuthenticateFunc (ctx , p )
133135 }
134- return nil
136+ return AuthenticateResponse {}, nil
135137}
136138
137139func (a agentFuncs ) Prompt (ctx context.Context , p PromptRequest ) (PromptResponse , error ) {
@@ -148,15 +150,31 @@ func (a agentFuncs) Cancel(ctx context.Context, n CancelNotification) error {
148150 return nil
149151}
150152
153+ // SetSessionMode implements Agent.
154+ func (a agentFuncs ) SetSessionMode (ctx context.Context , params SetSessionModeRequest ) (SetSessionModeResponse , error ) {
155+ if a .SetSessionModeFunc != nil {
156+ return a .SetSessionModeFunc (ctx , params )
157+ }
158+ return SetSessionModeResponse {}, nil
159+ }
160+
161+ // SetSessionModel implements Agent.
162+ func (a agentFuncs ) SetSessionModel (ctx context.Context , params SetSessionModelRequest ) (SetSessionModelResponse , error ) {
163+ if a .SetSessionModelFunc != nil {
164+ return a .SetSessionModelFunc (ctx , params )
165+ }
166+ return SetSessionModelResponse {}, nil
167+ }
168+
151169// Test bidirectional error handling similar to typescript/acp.test.ts
152170func TestConnectionHandlesErrorsBidirectional (t * testing.T ) {
153171 ctx := context .Background ()
154172 c2aR , c2aW := io .Pipe ()
155173 a2cR , a2cW := io .Pipe ()
156174
157175 c := NewClientSideConnection (& clientFuncs {
158- WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) error {
159- return & RequestError {Code : - 32603 , Message : "Write failed" }
176+ WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) ( WriteTextFileResponse , error ) {
177+ return WriteTextFileResponse {}, & RequestError {Code : - 32603 , Message : "Write failed" }
160178 },
161179 ReadTextFileFunc : func (context.Context , ReadTextFileRequest ) (ReadTextFileResponse , error ) {
162180 return ReadTextFileResponse {}, & RequestError {Code : - 32603 , Message : "Read failed" }
@@ -176,8 +194,8 @@ func TestConnectionHandlesErrorsBidirectional(t *testing.T) {
176194 LoadSessionFunc : func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error ) {
177195 return LoadSessionResponse {}, & RequestError {Code : - 32603 , Message : "Failed to load session" }
178196 },
179- AuthenticateFunc : func (context.Context , AuthenticateRequest ) error {
180- return & RequestError {Code : - 32603 , Message : "Authentication failed" }
197+ AuthenticateFunc : func (context.Context , AuthenticateRequest ) ( AuthenticateResponse , error ) {
198+ return AuthenticateResponse {}, & RequestError {Code : - 32603 , Message : "Authentication failed" }
181199 },
182200 PromptFunc : func (context.Context , PromptRequest ) (PromptResponse , error ) {
183201 return PromptResponse {}, & RequestError {Code : - 32603 , Message : "Prompt failed" }
@@ -186,7 +204,7 @@ func TestConnectionHandlesErrorsBidirectional(t *testing.T) {
186204 }, a2cW , c2aR )
187205
188206 // Client->Agent direction: expect error
189- if err := agentConn .WriteTextFile (ctx , WriteTextFileRequest {Path : "/test.txt" , Content : "test" , SessionId : "test-session" }); err == nil {
207+ if _ , err := agentConn .WriteTextFile (ctx , WriteTextFileRequest {Path : "/test.txt" , Content : "test" , SessionId : "test-session" }); err == nil {
190208 t .Fatalf ("expected error for writeTextFile, got nil" )
191209 }
192210
@@ -205,12 +223,12 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) {
205223 requestCount := 0
206224
207225 _ = NewClientSideConnection (& clientFuncs {
208- WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) error {
226+ WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) ( WriteTextFileResponse , error ) {
209227 mu .Lock ()
210228 requestCount ++
211229 mu .Unlock ()
212230 time .Sleep (40 * time .Millisecond )
213- return nil
231+ return WriteTextFileResponse {}, nil
214232 },
215233 ReadTextFileFunc : func (_ context.Context , req ReadTextFileRequest ) (ReadTextFileResponse , error ) {
216234 return ReadTextFileResponse {Content : "Content of " + req .Path }, nil
@@ -230,7 +248,9 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) {
230248 LoadSessionFunc : func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error ) {
231249 return LoadSessionResponse {}, nil
232250 },
233- AuthenticateFunc : func (context.Context , AuthenticateRequest ) error { return nil },
251+ AuthenticateFunc : func (context.Context , AuthenticateRequest ) (AuthenticateResponse , error ) {
252+ return AuthenticateResponse {}, nil
253+ },
234254 PromptFunc : func (context.Context , PromptRequest ) (PromptResponse , error ) {
235255 return PromptResponse {StopReason : "end_turn" }, nil
236256 },
@@ -249,7 +269,7 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) {
249269 req := p
250270 go func () {
251271 defer wg .Done ()
252- errs [idx ] = agentConn .WriteTextFile (context .Background (), req )
272+ _ , errs [idx ] = agentConn .WriteTextFile (context .Background (), req )
253273 }()
254274 }
255275 wg .Wait ()
@@ -276,9 +296,9 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) {
276296 push := func (s string ) { mu .Lock (); defer mu .Unlock (); log = append (log , s ) }
277297
278298 cs := NewClientSideConnection (& clientFuncs {
279- WriteTextFileFunc : func (_ context.Context , req WriteTextFileRequest ) error {
299+ WriteTextFileFunc : func (_ context.Context , req WriteTextFileRequest ) ( WriteTextFileResponse , error ) {
280300 push ("writeTextFile called: " + req .Path )
281- return nil
301+ return WriteTextFileResponse {}, nil
282302 },
283303 ReadTextFileFunc : func (_ context.Context , req ReadTextFileRequest ) (ReadTextFileResponse , error ) {
284304 push ("readTextFile called: " + req .Path )
@@ -306,9 +326,9 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) {
306326 push ("loadSession called: " + string (p .SessionId ))
307327 return LoadSessionResponse {}, nil
308328 },
309- AuthenticateFunc : func (_ context.Context , p AuthenticateRequest ) error {
329+ AuthenticateFunc : func (_ context.Context , p AuthenticateRequest ) ( AuthenticateResponse , error ) {
310330 push ("authenticate called: " + string (p .MethodId ))
311- return nil
331+ return AuthenticateResponse {}, nil
312332 },
313333 PromptFunc : func (_ context.Context , p PromptRequest ) (PromptResponse , error ) {
314334 push ("prompt called: " + string (p .SessionId ))
@@ -323,7 +343,7 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) {
323343 if _ , err := cs .NewSession (context .Background (), NewSessionRequest {Cwd : "/test" , McpServers : []McpServer {}}); err != nil {
324344 t .Fatalf ("newSession error: %v" , err )
325345 }
326- if err := as .WriteTextFile (context .Background (), WriteTextFileRequest {Path : "/test.txt" , Content : "test" , SessionId : "test-session" }); err != nil {
346+ if _ , err := as .WriteTextFile (context .Background (), WriteTextFileRequest {Path : "/test.txt" , Content : "test" , SessionId : "test-session" }); err != nil {
327347 t .Fatalf ("writeTextFile error: %v" , err )
328348 }
329349 if _ , err := as .ReadTextFile (context .Background (), ReadTextFileRequest {Path : "/test.txt" , SessionId : "test-session" }); err != nil {
@@ -376,7 +396,9 @@ func TestConnectionHandlesNotifications(t *testing.T) {
376396 push := func (s string ) { mu .Lock (); logs = append (logs , s ); mu .Unlock () }
377397
378398 clientSide := NewClientSideConnection (& clientFuncs {
379- WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) error { return nil },
399+ WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) (WriteTextFileResponse , error ) {
400+ return WriteTextFileResponse {}, nil
401+ },
380402 ReadTextFileFunc : func (context.Context , ReadTextFileRequest ) (ReadTextFileResponse , error ) {
381403 return ReadTextFileResponse {Content : "test" }, nil
382404 },
@@ -405,7 +427,9 @@ func TestConnectionHandlesNotifications(t *testing.T) {
405427 LoadSessionFunc : func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error ) {
406428 return LoadSessionResponse {}, nil
407429 },
408- AuthenticateFunc : func (context.Context , AuthenticateRequest ) error { return nil },
430+ AuthenticateFunc : func (context.Context , AuthenticateRequest ) (AuthenticateResponse , error ) {
431+ return AuthenticateResponse {}, nil
432+ },
409433 PromptFunc : func (context.Context , PromptRequest ) (PromptResponse , error ) {
410434 return PromptResponse {StopReason : "end_turn" }, nil
411435 },
@@ -447,7 +471,9 @@ func TestConnectionHandlesInitialize(t *testing.T) {
447471 a2cR , a2cW := io .Pipe ()
448472
449473 agentConn := NewClientSideConnection (& clientFuncs {
450- WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) error { return nil },
474+ WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) (WriteTextFileResponse , error ) {
475+ return WriteTextFileResponse {}, nil
476+ },
451477 ReadTextFileFunc : func (context.Context , ReadTextFileRequest ) (ReadTextFileResponse , error ) {
452478 return ReadTextFileResponse {Content : "test" }, nil
453479 },
@@ -478,7 +504,9 @@ func TestConnectionHandlesInitialize(t *testing.T) {
478504 LoadSessionFunc : func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error ) {
479505 return LoadSessionResponse {}, nil
480506 },
481- AuthenticateFunc : func (context.Context , AuthenticateRequest ) error { return nil },
507+ AuthenticateFunc : func (context.Context , AuthenticateRequest ) (AuthenticateResponse , error ) {
508+ return AuthenticateResponse {}, nil
509+ },
482510 PromptFunc : func (context.Context , PromptRequest ) (PromptResponse , error ) {
483511 return PromptResponse {StopReason : "end_turn" }, nil
484512 },
@@ -527,7 +555,9 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) {
527555 LoadSessionFunc : func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error ) {
528556 return LoadSessionResponse {}, nil
529557 },
530- AuthenticateFunc : func (context.Context , AuthenticateRequest ) error { return nil },
558+ AuthenticateFunc : func (context.Context , AuthenticateRequest ) (AuthenticateResponse , error ) {
559+ return AuthenticateResponse {}, nil
560+ },
531561 PromptFunc : func (ctx context.Context , p PromptRequest ) (PromptResponse , error ) {
532562 <- ctx .Done ()
533563 // mark that prompt finished due to cancellation
@@ -548,7 +578,9 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) {
548578
549579 // Client side
550580 cs := NewClientSideConnection (& clientFuncs {
551- WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) error { return nil },
581+ WriteTextFileFunc : func (context.Context , WriteTextFileRequest ) (WriteTextFileResponse , error ) {
582+ return WriteTextFileResponse {}, nil
583+ },
552584 ReadTextFileFunc : func (context.Context , ReadTextFileRequest ) (ReadTextFileResponse , error ) {
553585 return ReadTextFileResponse {Content : "" }, nil
554586 },
0 commit comments