Skip to content

Commit 73f02f4

Browse files
committed
Support per-connection agent factories
1 parent 2a6d4c5 commit 73f02f4

3 files changed

Lines changed: 489 additions & 11 deletions

File tree

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
import { describe, expect, it } from "vitest";
2+
3+
import { PROTOCOL_VERSION } from "./acp.js";
4+
import { HEADER_CONNECTION_ID } from "./protocol.js";
5+
import { AcpServer } from "./server.js";
6+
import { TestAgent } from "./test-support/test-agent.js";
7+
8+
import type { Agent, AgentSideConnection } from "./acp.js";
9+
import type { AnyMessage } from "./jsonrpc.js";
10+
import type { WebSocketServerSocket } from "./ws-server.js";
11+
12+
const initializeRequest = {
13+
jsonrpc: "2.0",
14+
id: 0,
15+
method: "initialize",
16+
params: {
17+
protocolVersion: PROTOCOL_VERSION,
18+
clientCapabilities: {},
19+
},
20+
} satisfies AnyMessage;
21+
22+
describe("AcpServer prepared WebSocket upgrades", () => {
23+
it("uses the default factory when no per-upgrade override is provided", async () => {
24+
const createdBy: string[] = [];
25+
const server = new AcpServer({
26+
createAgent: recordingFactory(createdBy, "default"),
27+
});
28+
const socket = new FakeServerSocket();
29+
30+
try {
31+
server.prepareWebSocketUpgrade().accept(socket);
32+
socket.receive(JSON.stringify(initializeRequest));
33+
34+
await expect(readSentMessage(socket)).resolves.toMatchObject({
35+
jsonrpc: "2.0",
36+
id: initializeRequest.id,
37+
result: {
38+
protocolVersion: PROTOCOL_VERSION,
39+
},
40+
});
41+
expect(createdBy).toEqual(["default"]);
42+
} finally {
43+
socket.close();
44+
await server.close();
45+
}
46+
});
47+
48+
it("uses a per-upgrade factory override for that WebSocket connection", async () => {
49+
const createdBy: string[] = [];
50+
const server = new AcpServer({
51+
createAgent: recordingFactory(createdBy, "default"),
52+
});
53+
const socket = new FakeServerSocket();
54+
55+
try {
56+
server
57+
.prepareWebSocketUpgrade({
58+
createAgent: recordingFactory(createdBy, "override"),
59+
})
60+
.accept(socket);
61+
socket.receive(JSON.stringify(initializeRequest));
62+
63+
await readSentMessage(socket);
64+
expect(createdBy).toEqual(["override"]);
65+
} finally {
66+
socket.close();
67+
await server.close();
68+
}
69+
});
70+
71+
it("does not leak WebSocket factory overrides to later prepared upgrades", async () => {
72+
const createdBy: string[] = [];
73+
const server = new AcpServer({
74+
createAgent: recordingFactory(createdBy, "default"),
75+
});
76+
const overrideSocket = new FakeServerSocket();
77+
const defaultSocket = new FakeServerSocket();
78+
79+
try {
80+
server
81+
.prepareWebSocketUpgrade({
82+
createAgent: recordingFactory(createdBy, "override"),
83+
})
84+
.accept(overrideSocket);
85+
server.prepareWebSocketUpgrade().accept(defaultSocket);
86+
87+
overrideSocket.receive(JSON.stringify(initializeRequest));
88+
defaultSocket.receive(JSON.stringify({ ...initializeRequest, id: 1 }));
89+
90+
await Promise.all([
91+
readSentMessage(overrideSocket),
92+
readSentMessage(defaultSocket),
93+
]);
94+
expect(createdBy).toEqual(["override", "default"]);
95+
} finally {
96+
overrideSocket.close();
97+
defaultSocket.close();
98+
await server.close();
99+
}
100+
});
101+
102+
it("keeps concurrent WebSocket factory overrides isolated", async () => {
103+
const createdBy: string[] = [];
104+
const server = new AcpServer({
105+
createAgent: recordingFactory(createdBy, "default"),
106+
});
107+
const firstSocket = new FakeServerSocket();
108+
const secondSocket = new FakeServerSocket();
109+
110+
try {
111+
const first = server.prepareWebSocketUpgrade({
112+
createAgent: recordingFactory(createdBy, "first"),
113+
});
114+
const second = server.prepareWebSocketUpgrade({
115+
createAgent: recordingFactory(createdBy, "second"),
116+
});
117+
118+
second.accept(secondSocket);
119+
first.accept(firstSocket);
120+
secondSocket.receive(JSON.stringify({ ...initializeRequest, id: 2 }));
121+
firstSocket.receive(JSON.stringify({ ...initializeRequest, id: 1 }));
122+
123+
await Promise.all([
124+
readSentMessage(firstSocket),
125+
readSentMessage(secondSocket),
126+
]);
127+
expect(createdBy).toEqual(expect.arrayContaining(["first", "second"]));
128+
expect(createdBy).toHaveLength(2);
129+
} finally {
130+
firstSocket.close();
131+
secondSocket.close();
132+
await server.close();
133+
}
134+
});
135+
136+
it("removes rejected prepared WebSocket connections", async () => {
137+
const server = new AcpServer({
138+
createAgent: (conn) => new TestAgent(conn),
139+
});
140+
const prepared = server.prepareWebSocketUpgrade();
141+
142+
try {
143+
prepared.reject();
144+
const response = await server.handleRequest(
145+
new Request("http://127.0.0.1/acp", {
146+
method: "GET",
147+
headers: {
148+
Accept: "text/event-stream",
149+
[HEADER_CONNECTION_ID]: prepared.connectionId,
150+
},
151+
}),
152+
);
153+
154+
expect(response.status).toBe(404);
155+
} finally {
156+
await server.close();
157+
}
158+
});
159+
160+
it("keeps existing double-settle behavior for prepared WebSocket upgrades", async () => {
161+
const server = new AcpServer({
162+
createAgent: (conn) => new TestAgent(conn),
163+
});
164+
const rejected = server.prepareWebSocketUpgrade();
165+
const accepted = server.prepareWebSocketUpgrade();
166+
const socket = new FakeServerSocket();
167+
168+
try {
169+
rejected.reject();
170+
expect(() => rejected.accept(new FakeServerSocket())).toThrow(
171+
"ACP WebSocket upgrade has already been settled",
172+
);
173+
174+
accepted.accept(socket);
175+
expect(() => accepted.accept(new FakeServerSocket())).toThrow(
176+
"ACP WebSocket upgrade has already been settled",
177+
);
178+
expect(() => accepted.reject()).not.toThrow();
179+
} finally {
180+
socket.close();
181+
await server.close();
182+
}
183+
});
184+
});
185+
186+
function recordingFactory(
187+
createdBy: string[],
188+
label: string,
189+
): (conn: AgentSideConnection) => Agent {
190+
return (conn) => {
191+
createdBy.push(label);
192+
return new TestAgent(conn);
193+
};
194+
}
195+
196+
function readSentMessage(socket: FakeServerSocket): Promise<AnyMessage> {
197+
const message = socket.sent.shift();
198+
199+
if (message) {
200+
return Promise.resolve(JSON.parse(message));
201+
}
202+
203+
return new Promise((resolve) => {
204+
socket.onSend = (data) => {
205+
resolve(JSON.parse(data));
206+
};
207+
});
208+
}
209+
210+
class FakeServerSocket implements WebSocketServerSocket {
211+
readonly sent: string[] = [];
212+
readonly listeners = new Map<string, Set<(event: unknown) => void>>();
213+
onSend: ((data: string) => void) | undefined;
214+
215+
send(data: string): void {
216+
this.sent.push(data);
217+
this.onSend?.(data);
218+
this.onSend = undefined;
219+
}
220+
221+
close(_code?: number, _reason?: string): void {
222+
this.emit("close", {});
223+
}
224+
225+
addEventListener(type: string, listener: (event: unknown) => void): void {
226+
this.listeners.set(type, this.listeners.get(type) ?? new Set());
227+
this.listeners.get(type)?.add(listener);
228+
}
229+
230+
removeEventListener(type: string, listener: (event: unknown) => void): void {
231+
this.listeners.get(type)?.delete(listener);
232+
}
233+
234+
receive(data: string): void {
235+
this.emit("message", { data });
236+
}
237+
238+
private emit(type: string, event: unknown): void {
239+
for (const listener of this.listeners.get(type) ?? []) {
240+
listener(event);
241+
}
242+
}
243+
}

0 commit comments

Comments
 (0)