Skip to content

Commit 49ebb86

Browse files
committed
make ACPConversation.Send() sync to match PTYConversation behaviour
1 parent dda3717 commit 49ebb86

4 files changed

Lines changed: 72 additions & 54 deletions

File tree

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ require (
1616
github.com/spf13/viper v1.20.1
1717
github.com/stretchr/testify v1.11.1
1818
github.com/tmaxmax/go-sse v0.10.0
19+
go.uber.org/goleak v1.3.0
1920
golang.org/x/term v0.30.0
2021
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da
2122
)

x/acpio/acp_conversation.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ func (c *ACPConversation) Messages() []st.ConversationMessage {
7878
return slices.Clone(c.messages)
7979
}
8080

81-
// Send sends a message to the agent asynchronously.
82-
// It returns immediately after recording the user message and starts
83-
// the agent request in a background goroutine. Returns an error if
84-
// a message is already being processed.
81+
// Send sends a message to the agent synchronously.
82+
// It blocks until the agent has finished processing and returns any error
83+
// from the underlying write. Returns a validation error immediately if
84+
// the message is invalid or another message is already being processed.
8585
func (c *ACPConversation) Send(messageParts ...st.MessagePart) error {
8686
message := ""
8787
for _, part := range messageParts {
@@ -126,10 +126,7 @@ func (c *ACPConversation) Send(messageParts ...st.MessagePart) error {
126126

127127
c.logger.Debug("ACPConversation sending message", "message", message)
128128

129-
// Run the blocking write in a goroutine so HTTP returns immediately
130-
go c.executePrompt(messageParts)
131-
132-
return nil
129+
return c.executePrompt(messageParts)
133130
}
134131

135132
// Start sets up chunk handling and sends the initial prompt if provided.
@@ -139,10 +136,14 @@ func (c *ACPConversation) Start(ctx context.Context) {
139136

140137
// Send initial prompt if provided
141138
if len(c.initialPrompt) > 0 {
142-
err := c.Send(c.initialPrompt...)
143-
if err != nil {
144-
c.logger.Error("ACPConversation failed to send initial prompt", "error", err)
145-
}
139+
// Run in a goroutine because Send blocks until the prompt completes,
140+
// and Start must return immediately per the Conversation interface.
141+
go func() {
142+
err := c.Send(c.initialPrompt...)
143+
if err != nil {
144+
c.logger.Error("ACPConversation failed to send initial prompt", "error", err)
145+
}
146+
}()
146147
} else {
147148
// No initial prompt means we start in stable state
148149
c.emitter.EmitStatus(c.Status())
@@ -203,8 +204,8 @@ func (c *ACPConversation) handleChunk(chunk string) {
203204
c.emitter.EmitScreen(screen)
204205
}
205206

206-
// executePrompt runs the actual agent request in background
207-
func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) {
207+
// executePrompt runs the actual agent request and returns any error.
208+
func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) error {
208209
var err error
209210
for _, part := range messageParts {
210211
if c.ctx.Err() != nil {
@@ -234,7 +235,7 @@ func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) {
234235
c.emitter.EmitMessages(messages)
235236
c.emitter.EmitStatus(status)
236237
c.emitter.EmitScreen(screen)
237-
return
238+
return err
238239
}
239240

240241
// Final response should already be in the last message via streaming
@@ -253,4 +254,5 @@ func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) {
253254
c.emitter.EmitScreen(screen)
254255

255256
c.logger.Debug("ACPConversation message complete", "responseLen", len(response))
257+
return nil
256258
}

x/acpio/acp_conversation_test.go

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,17 @@ func Test_Status_InitiallyStable(t *testing.T) {
207207
func Test_Send_AddsUserMessage(t *testing.T) {
208208
mClock := quartz.NewMock(t)
209209
mock := newMockAgentIO()
210-
// Set up blocking to synchronize with the goroutine
210+
// Set up blocking so we can inspect state mid-flight
211211
started, done := mock.BlockWrite()
212212

213213
conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock)
214214
conv.Start(context.Background())
215215

216-
err := conv.Send(screentracker.MessagePartText{Content: "hello"})
217-
require.NoError(t, err)
216+
// Send blocks until completion, so run in a goroutine
217+
errCh := make(chan error, 1)
218+
go func() { errCh <- conv.Send(screentracker.MessagePartText{Content: "hello"}) }()
218219

219-
// Wait for the write goroutine to start
220+
// Wait for the write to start
220221
<-started
221222

222223
messages := conv.Messages()
@@ -226,8 +227,9 @@ func Test_Send_AddsUserMessage(t *testing.T) {
226227
assert.Equal(t, "hello", messages[0].Message)
227228
assert.Equal(t, screentracker.ConversationRoleAgent, messages[1].Role)
228229

229-
// Unblock the write to let the test complete cleanly
230+
// Unblock the write to let Send complete
230231
close(done)
232+
require.NoError(t, <-errCh)
231233
}
232234

233235
func Test_Send_RejectsEmptyMessage(t *testing.T) {
@@ -277,19 +279,20 @@ func Test_Send_RejectsDuplicateSend(t *testing.T) {
277279
conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock)
278280
conv.Start(context.Background())
279281

280-
// First send should succeed
281-
err := conv.Send(screentracker.MessagePartText{Content: "first"})
282-
require.NoError(t, err)
282+
// First send blocks, so run in a goroutine
283+
errCh := make(chan error, 1)
284+
go func() { errCh <- conv.Send(screentracker.MessagePartText{Content: "first"}) }()
283285

284286
// Wait for the write to start (ensuring we're in "prompting" state)
285287
<-started
286288

287289
// Second send while first is processing should fail
288-
err = conv.Send(screentracker.MessagePartText{Content: "second"})
290+
err := conv.Send(screentracker.MessagePartText{Content: "second"})
289291
assert.ErrorIs(t, err, screentracker.ErrMessageValidationChanging)
290292

291293
// Unblock the write to let the test complete cleanly
292294
close(done)
295+
require.NoError(t, <-errCh)
293296
}
294297

295298
func Test_Status_ChangesWhileProcessing(t *testing.T) {
@@ -305,9 +308,9 @@ func Test_Status_ChangesWhileProcessing(t *testing.T) {
305308
conv := acpio.NewACPConversation(ctx, mock, nil, nil, emitter, mClock)
306309
conv.Start(ctx)
307310

308-
// Send a message
309-
err := conv.Send(screentracker.MessagePartText{Content: "test"})
310-
require.NoError(t, err)
311+
// Send blocks, so run in a goroutine
312+
errCh := make(chan error, 1)
313+
go func() { errCh <- conv.Send(screentracker.MessagePartText{Content: "test"}) }()
311314

312315
// Wait for write to start
313316
<-started
@@ -318,7 +321,8 @@ func Test_Status_ChangesWhileProcessing(t *testing.T) {
318321
// Unblock the write
319322
close(done)
320323

321-
// Wait for the goroutine to complete - status should then be stable.
324+
// Wait for Send to complete - status should then be stable.
325+
require.NoError(t, <-errCh)
322326
emitter.WaitForStatus(ctx, t, screentracker.ConversationStatusStable)
323327
}
324328

@@ -334,9 +338,9 @@ func Test_Text_ReturnsStreamingContent(t *testing.T) {
334338
// Initially empty
335339
assert.Equal(t, "", conv.Text())
336340

337-
// Send a message
338-
err := conv.Send(screentracker.MessagePartText{Content: "question"})
339-
require.NoError(t, err)
341+
// Send blocks, so run in a goroutine
342+
errCh := make(chan error, 1)
343+
go func() { errCh <- conv.Send(screentracker.MessagePartText{Content: "question"}) }()
340344

341345
// Wait for write to start
342346
<-started
@@ -352,8 +356,9 @@ func Test_Text_ReturnsStreamingContent(t *testing.T) {
352356
require.Len(t, messages, 2)
353357
assert.Equal(t, "Hello world!", messages[1].Message)
354358

355-
// Unblock the write to let the test complete cleanly
359+
// Unblock the write to let Send complete
356360
close(done)
361+
require.NoError(t, <-errCh)
357362
}
358363

359364
func Test_Emitter_CalledOnChanges(t *testing.T) {
@@ -370,9 +375,9 @@ func Test_Emitter_CalledOnChanges(t *testing.T) {
370375
conv := acpio.NewACPConversation(ctx, mock, nil, nil, emitter, mClock)
371376
conv.Start(ctx)
372377

373-
// Send a message
374-
err := conv.Send(screentracker.MessagePartText{Content: "test"})
375-
require.NoError(t, err)
378+
// Send blocks, so run in a goroutine
379+
errCh := make(chan error, 1)
380+
go func() { errCh <- conv.Send(screentracker.MessagePartText{Content: "test"}) }()
376381

377382
// Wait for write to start
378383
<-started
@@ -390,6 +395,7 @@ func Test_Emitter_CalledOnChanges(t *testing.T) {
390395

391396
// Unblock the write to complete processing
392397
close(done)
398+
require.NoError(t, <-errCh)
393399

394400
// Wait for completion emit
395401
emitter.WaitForStatus(ctx, t, screentracker.ConversationStatusStable)
@@ -413,7 +419,7 @@ func Test_InitialPrompt_SentOnStart(t *testing.T) {
413419
conv := acpio.NewACPConversation(context.Background(), mock, nil, initialPrompt, nil, mClock)
414420
conv.Start(context.Background())
415421

416-
// Wait for write to start (initial prompt is being sent)
422+
// Wait for write to start (initial prompt is being sent via Start's goroutine)
417423
<-started
418424

419425
// Should have user message from initial prompt
@@ -429,14 +435,15 @@ func Test_InitialPrompt_SentOnStart(t *testing.T) {
429435
func Test_Messages_AreCopied(t *testing.T) {
430436
mClock := quartz.NewMock(t)
431437
mock := newMockAgentIO()
432-
// Set up blocking to synchronize
438+
// Set up blocking so we can inspect state mid-flight
433439
started, done := mock.BlockWrite()
434440

435441
conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock)
436442
conv.Start(context.Background())
437443

438-
err := conv.Send(screentracker.MessagePartText{Content: "test"})
439-
require.NoError(t, err)
444+
// Send blocks, so run in a goroutine
445+
errCh := make(chan error, 1)
446+
go func() { errCh <- conv.Send(screentracker.MessagePartText{Content: "test"}) }()
440447

441448
// Wait for write to start
442449
<-started
@@ -450,8 +457,9 @@ func Test_Messages_AreCopied(t *testing.T) {
450457
originalMessages := conv.Messages()
451458
assert.Equal(t, "test", originalMessages[0].Message)
452459

453-
// Unblock the write to let the test complete cleanly
460+
// Unblock the write to let Send complete
454461
close(done)
462+
require.NoError(t, <-errCh)
455463
}
456464

457465
func Test_ErrorRemovesPartialMessage(t *testing.T) {
@@ -467,9 +475,9 @@ func Test_ErrorRemovesPartialMessage(t *testing.T) {
467475
conv := acpio.NewACPConversation(ctx, mock, nil, nil, emitter, mClock)
468476
conv.Start(ctx)
469477

470-
// Send a message
471-
err := conv.Send(screentracker.MessagePartText{Content: "test"})
472-
require.NoError(t, err)
478+
// Send blocks, so run in a goroutine
479+
errCh := make(chan error, 1)
480+
go func() { errCh <- conv.Send(screentracker.MessagePartText{Content: "test"}) }()
473481

474482
// Wait for write to start
475483
<-started
@@ -494,8 +502,8 @@ func Test_ErrorRemovesPartialMessage(t *testing.T) {
494502
mock.mu.Unlock()
495503
close(done)
496504

497-
// Wait for the conversation to stabilize after the error
498-
emitter.WaitForStatus(ctx, t, screentracker.ConversationStatusStable)
505+
// Send should return the error
506+
require.ErrorIs(t, <-errCh, assert.AnError)
499507

500508
// The partial agent message should be removed on error.
501509
// Only the user message should remain.
@@ -506,28 +514,24 @@ func Test_ErrorRemovesPartialMessage(t *testing.T) {
506514
}
507515

508516
func Test_LateChunkAfterError_DoesNotCorruptUserMessage(t *testing.T) {
509-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
510-
defer cancel()
511-
512517
mClock := quartz.NewMock(t)
513518
mock := newMockAgentIO()
514-
emitter := newMockEmitter()
515519
started, done := mock.BlockWrite()
516520

517-
conv := acpio.NewACPConversation(ctx, mock, nil, nil, emitter, mClock)
518-
conv.Start(ctx)
521+
conv := acpio.NewACPConversation(context.Background(), mock, nil, nil, nil, mClock)
522+
conv.Start(context.Background())
519523

520524
// Given: a send that fails with an error, removing the agent placeholder
521-
err := conv.Send(screentracker.MessagePartText{Content: "hello"})
522-
require.NoError(t, err)
525+
errCh := make(chan error, 1)
526+
go func() { errCh <- conv.Send(screentracker.MessagePartText{Content: "hello"}) }()
523527
<-started
524528

525529
mock.mu.Lock()
526530
mock.writeErr = assert.AnError
527531
mock.mu.Unlock()
528532
close(done)
529533

530-
emitter.WaitForStatus(ctx, t, screentracker.ConversationStatusStable)
534+
require.ErrorIs(t, <-errCh, assert.AnError)
531535

532536
messages := conv.Messages()
533537
require.Len(t, messages, 1, "agent placeholder should be removed after error")

x/acpio/main_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package acpio_test
2+
3+
import (
4+
"testing"
5+
6+
"go.uber.org/goleak"
7+
)
8+
9+
func TestMain(m *testing.M) {
10+
goleak.VerifyTestMain(m)
11+
}

0 commit comments

Comments
 (0)