|
| 1 | +package acpio_test |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "io" |
| 6 | + "os" |
| 7 | + "sync" |
| 8 | + "testing" |
| 9 | + |
| 10 | + acp "github.com/coder/acp-go-sdk" |
| 11 | + "github.com/stretchr/testify/assert" |
| 12 | + "github.com/stretchr/testify/require" |
| 13 | + |
| 14 | + "github.com/coder/agentapi/x/acpio" |
| 15 | +) |
| 16 | + |
| 17 | +// testAgent implements acp.Agent for testing. |
| 18 | +type testAgent struct { |
| 19 | + conn *acp.AgentSideConnection |
| 20 | + onPrompt func(ctx context.Context, conn *acp.AgentSideConnection, p acp.PromptRequest) (acp.PromptResponse, error) |
| 21 | +} |
| 22 | + |
| 23 | +var _ acp.Agent = (*testAgent)(nil) |
| 24 | + |
| 25 | +func (a *testAgent) SetAgentConnection(c *acp.AgentSideConnection) { a.conn = c } |
| 26 | + |
| 27 | +func (a *testAgent) Authenticate(context.Context, acp.AuthenticateRequest) (acp.AuthenticateResponse, error) { |
| 28 | + return acp.AuthenticateResponse{}, nil |
| 29 | +} |
| 30 | + |
| 31 | +func (a *testAgent) Initialize(context.Context, acp.InitializeRequest) (acp.InitializeResponse, error) { |
| 32 | + return acp.InitializeResponse{ |
| 33 | + ProtocolVersion: acp.ProtocolVersionNumber, |
| 34 | + AgentCapabilities: acp.AgentCapabilities{}, |
| 35 | + }, nil |
| 36 | +} |
| 37 | + |
| 38 | +func (a *testAgent) Cancel(context.Context, acp.CancelNotification) error { return nil } |
| 39 | + |
| 40 | +func (a *testAgent) NewSession(context.Context, acp.NewSessionRequest) (acp.NewSessionResponse, error) { |
| 41 | + return acp.NewSessionResponse{SessionId: "test-session"}, nil |
| 42 | +} |
| 43 | + |
| 44 | +func (a *testAgent) SetSessionMode(context.Context, acp.SetSessionModeRequest) (acp.SetSessionModeResponse, error) { |
| 45 | + return acp.SetSessionModeResponse{}, nil |
| 46 | +} |
| 47 | + |
| 48 | +func (a *testAgent) Prompt(ctx context.Context, p acp.PromptRequest) (acp.PromptResponse, error) { |
| 49 | + if a.onPrompt != nil { |
| 50 | + return a.onPrompt(ctx, a.conn, p) |
| 51 | + } |
| 52 | + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil |
| 53 | +} |
| 54 | + |
| 55 | +// newTestPair creates an ACPAgentIO connected to a testAgent via pipes. |
| 56 | +func newTestPair(t *testing.T, agent *testAgent) *acpio.ACPAgentIO { |
| 57 | + t.Helper() |
| 58 | + |
| 59 | + // Two pipe pairs: client writes → agent reads, agent writes → client reads. |
| 60 | + clientToAgentR, clientToAgentW := io.Pipe() |
| 61 | + agentToClientR, agentToClientW := io.Pipe() |
| 62 | + |
| 63 | + // Client side: peerInput=clientToAgentW (writes to agent), peerOutput=agentToClientR (reads from agent) |
| 64 | + // Agent side: peerInput=agentToClientW (writes to client), peerOutput=clientToAgentR (reads from client) |
| 65 | + asc := acp.NewAgentSideConnection(agent, agentToClientW, clientToAgentR) |
| 66 | + agent.SetAgentConnection(asc) |
| 67 | + |
| 68 | + agentIO, err := acpio.NewWithPipes( |
| 69 | + context.Background(), |
| 70 | + clientToAgentW, agentToClientR, |
| 71 | + nil, |
| 72 | + func() (string, error) { return os.TempDir(), nil }, |
| 73 | + ) |
| 74 | + require.NoError(t, err) |
| 75 | + |
| 76 | + t.Cleanup(func() { |
| 77 | + _ = clientToAgentW.Close() |
| 78 | + _ = agentToClientW.Close() |
| 79 | + }) |
| 80 | + |
| 81 | + return agentIO |
| 82 | +} |
| 83 | + |
| 84 | +// chunkCollector collects chunks from SetOnChunk in a thread-safe way |
| 85 | +// and provides a method to wait for a specific number of chunks. |
| 86 | +type chunkCollector struct { |
| 87 | + mu sync.Mutex |
| 88 | + cond *sync.Cond |
| 89 | + chunks []string |
| 90 | +} |
| 91 | + |
| 92 | +func newChunkCollector() *chunkCollector { |
| 93 | + c := &chunkCollector{} |
| 94 | + c.cond = sync.NewCond(&c.mu) |
| 95 | + return c |
| 96 | +} |
| 97 | + |
| 98 | +func (c *chunkCollector) callback(chunk string) { |
| 99 | + c.mu.Lock() |
| 100 | + defer c.mu.Unlock() |
| 101 | + c.chunks = append(c.chunks, chunk) |
| 102 | + c.cond.Broadcast() |
| 103 | +} |
| 104 | + |
| 105 | +func (c *chunkCollector) waitForN(t *testing.T, n int) []string { |
| 106 | + t.Helper() |
| 107 | + c.mu.Lock() |
| 108 | + defer c.mu.Unlock() |
| 109 | + for len(c.chunks) < n { |
| 110 | + c.cond.Wait() |
| 111 | + } |
| 112 | + return append([]string(nil), c.chunks...) |
| 113 | +} |
| 114 | + |
| 115 | +func Test_ACPAgentIO_WriteAndReadScreen(t *testing.T) { |
| 116 | + collector := newChunkCollector() |
| 117 | + agent := &testAgent{ |
| 118 | + onPrompt: func(ctx context.Context, conn *acp.AgentSideConnection, p acp.PromptRequest) (acp.PromptResponse, error) { |
| 119 | + _ = conn.SessionUpdate(ctx, acp.SessionNotification{ |
| 120 | + SessionId: p.SessionId, |
| 121 | + Update: acp.UpdateAgentMessageText("Hello from agent!"), |
| 122 | + }) |
| 123 | + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil |
| 124 | + }, |
| 125 | + } |
| 126 | + agentIO := newTestPair(t, agent) |
| 127 | + agentIO.SetOnChunk(collector.callback) |
| 128 | + |
| 129 | + n, err := agentIO.Write([]byte("test prompt")) |
| 130 | + require.NoError(t, err) |
| 131 | + assert.Equal(t, len("test prompt"), n) |
| 132 | + |
| 133 | + // SessionUpdate notifications are async — wait for the chunk to arrive. |
| 134 | + collector.waitForN(t, 1) |
| 135 | + assert.Equal(t, "Hello from agent!", agentIO.ReadScreen()) |
| 136 | +} |
| 137 | + |
| 138 | +func Test_ACPAgentIO_StreamingChunks(t *testing.T) { |
| 139 | + collector := newChunkCollector() |
| 140 | + agent := &testAgent{ |
| 141 | + onPrompt: func(ctx context.Context, conn *acp.AgentSideConnection, p acp.PromptRequest) (acp.PromptResponse, error) { |
| 142 | + for _, text := range []string{"Hello", " ", "world!"} { |
| 143 | + _ = conn.SessionUpdate(ctx, acp.SessionNotification{ |
| 144 | + SessionId: p.SessionId, |
| 145 | + Update: acp.UpdateAgentMessageText(text), |
| 146 | + }) |
| 147 | + } |
| 148 | + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil |
| 149 | + }, |
| 150 | + } |
| 151 | + agentIO := newTestPair(t, agent) |
| 152 | + agentIO.SetOnChunk(collector.callback) |
| 153 | + |
| 154 | + _, err := agentIO.Write([]byte("test")) |
| 155 | + require.NoError(t, err) |
| 156 | + |
| 157 | + // All three chunks should arrive (order may vary due to async notification handling). |
| 158 | + chunks := collector.waitForN(t, 3) |
| 159 | + assert.Len(t, chunks, 3) |
| 160 | + assert.ElementsMatch(t, []string{"Hello", " ", "world!"}, chunks) |
| 161 | +} |
| 162 | + |
| 163 | +func Test_ACPAgentIO_StripsEscapeSequences(t *testing.T) { |
| 164 | + received := make(chan string, 1) |
| 165 | + agent := &testAgent{ |
| 166 | + onPrompt: func(ctx context.Context, conn *acp.AgentSideConnection, p acp.PromptRequest) (acp.PromptResponse, error) { |
| 167 | + defer close(received) |
| 168 | + for _, block := range p.Prompt { |
| 169 | + if block.Text != nil { |
| 170 | + received <- block.Text.Text |
| 171 | + } |
| 172 | + } |
| 173 | + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil |
| 174 | + }, |
| 175 | + } |
| 176 | + agentIO := newTestPair(t, agent) |
| 177 | + |
| 178 | + // Bracketed paste sequences should be stripped |
| 179 | + _, err := agentIO.Write([]byte("\x1b[200~hello world\x1b[201~")) |
| 180 | + require.NoError(t, err) |
| 181 | + assert.Equal(t, "hello world", <-received) |
| 182 | +} |
| 183 | + |
| 184 | +func Test_ACPAgentIO_IgnoresEmptyPrompt(t *testing.T) { |
| 185 | + agent := &testAgent{ |
| 186 | + onPrompt: func(ctx context.Context, conn *acp.AgentSideConnection, p acp.PromptRequest) (acp.PromptResponse, error) { |
| 187 | + assert.Fail(t, "empty prompt should not reach the agent") |
| 188 | + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil |
| 189 | + }, |
| 190 | + } |
| 191 | + agentIO := newTestPair(t, agent) |
| 192 | + |
| 193 | + // Empty after stripping should be a no-op |
| 194 | + n, err := agentIO.Write([]byte(" \t\n ")) |
| 195 | + require.NoError(t, err) |
| 196 | + assert.Equal(t, len(" \t\n "), n) |
| 197 | +} |
| 198 | + |
| 199 | +func Test_ACPAgentIO_ToolCallFormattedAsText(t *testing.T) { |
| 200 | + collector := newChunkCollector() |
| 201 | + agent := &testAgent{ |
| 202 | + onPrompt: func(ctx context.Context, conn *acp.AgentSideConnection, p acp.PromptRequest) (acp.PromptResponse, error) { |
| 203 | + _ = conn.SessionUpdate(ctx, acp.SessionNotification{ |
| 204 | + SessionId: p.SessionId, |
| 205 | + Update: acp.StartToolCall( |
| 206 | + "call_1", |
| 207 | + "Reading file", |
| 208 | + acp.WithStartKind(acp.ToolKindRead), |
| 209 | + ), |
| 210 | + }) |
| 211 | + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil |
| 212 | + }, |
| 213 | + } |
| 214 | + agentIO := newTestPair(t, agent) |
| 215 | + agentIO.SetOnChunk(collector.callback) |
| 216 | + |
| 217 | + _, err := agentIO.Write([]byte("do something")) |
| 218 | + require.NoError(t, err) |
| 219 | + |
| 220 | + collector.waitForN(t, 1) |
| 221 | + assert.Contains(t, agentIO.ReadScreen(), "[Tool: read]") |
| 222 | + assert.Contains(t, agentIO.ReadScreen(), "Reading file") |
| 223 | +} |
| 224 | + |
| 225 | +func Test_ACPAgentIO_ResetsResponseBetweenWrites(t *testing.T) { |
| 226 | + collector := newChunkCollector() |
| 227 | + callCount := 0 |
| 228 | + agent := &testAgent{ |
| 229 | + onPrompt: func(ctx context.Context, conn *acp.AgentSideConnection, p acp.PromptRequest) (acp.PromptResponse, error) { |
| 230 | + callCount++ |
| 231 | + _ = conn.SessionUpdate(ctx, acp.SessionNotification{ |
| 232 | + SessionId: p.SessionId, |
| 233 | + Update: acp.UpdateAgentMessageText("response " + string(rune('0'+callCount))), |
| 234 | + }) |
| 235 | + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil |
| 236 | + }, |
| 237 | + } |
| 238 | + agentIO := newTestPair(t, agent) |
| 239 | + agentIO.SetOnChunk(collector.callback) |
| 240 | + |
| 241 | + _, err := agentIO.Write([]byte("first")) |
| 242 | + require.NoError(t, err) |
| 243 | + collector.waitForN(t, 1) |
| 244 | + assert.Equal(t, "response 1", agentIO.ReadScreen()) |
| 245 | + |
| 246 | + _, err = agentIO.Write([]byte("second")) |
| 247 | + require.NoError(t, err) |
| 248 | + collector.waitForN(t, 2) |
| 249 | + // Response should be reset, not accumulated |
| 250 | + assert.Equal(t, "response 2", agentIO.ReadScreen()) |
| 251 | +} |
0 commit comments