Skip to content

Commit 2a6d4c5

Browse files
committed
fix: Align HTTP session routing with RFD
1 parent 66da6ed commit 2a6d4c5

5 files changed

Lines changed: 246 additions & 3 deletions

File tree

src/connection.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ export class ConnectionState {
9191
readonly allOutbound = new OutboundStream();
9292
readonly sessionStreams = new Map<string, OutboundStream>();
9393
readonly pendingRoutes = new Map<string, ResponseRoute>();
94+
readonly clientResponseRoutes = new Map<string, ResponseRoute>();
9495

9596
private hasStartedRouter = false;
9697
private outboundReader: ReadableStreamDefaultReader<AnyMessage> | undefined;
@@ -163,6 +164,7 @@ export class ConnectionState {
163164

164165
this.sessionStreams.clear();
165166
this.pendingRoutes.clear();
167+
this.clientResponseRoutes.clear();
166168

167169
await Promise.allSettled([
168170
this.inboundTx.close(),
@@ -231,13 +233,29 @@ export class ConnectionState {
231233
private routeOutboundRequestOrNotification(message: AnyMessage): void {
232234
const sessionId = sessionIdFromMessageParams(message);
233235
if (sessionId) {
236+
this.trackClientResponseRoute(message, { session: sessionId });
234237
this.ensureSession(sessionId).push(message);
235238
return;
236239
}
237240

241+
this.trackClientResponseRoute(message, "connection");
238242
this.connectionStream.push(message);
239243
}
240244

245+
private trackClientResponseRoute(
246+
message: AnyMessage,
247+
route: ResponseRoute,
248+
): void {
249+
if (!("id" in message) || !("method" in message)) {
250+
return;
251+
}
252+
253+
const key = messageIdKey(message.id);
254+
if (key) {
255+
this.clientResponseRoutes.set(key, route);
256+
}
257+
}
258+
241259
private pushToRoute(route: ResponseRoute, message: AnyMessage): void {
242260
if (route === "connection") {
243261
this.connectionStream.push(message);

src/http-stream.test.ts

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,48 @@ const promptRequest = {
5959
},
6060
} satisfies AnyMessage;
6161

62+
const loadSessionRequest = {
63+
jsonrpc: "2.0",
64+
id: 3,
65+
method: "session/load",
66+
params: {
67+
cwd: "/tmp",
68+
mcpServers: [],
69+
sessionId: "existing-session",
70+
},
71+
} satisfies AnyMessage;
72+
73+
const permissionRequest = {
74+
jsonrpc: "2.0",
75+
id: 99,
76+
method: "session/request_permission",
77+
params: {
78+
sessionId: "session-1",
79+
toolCall: {
80+
toolCallId: "permission-tool",
81+
title: "Permission tool",
82+
},
83+
options: [
84+
{
85+
kind: "allow_once",
86+
name: "Allow once",
87+
optionId: "allow",
88+
},
89+
],
90+
},
91+
} satisfies AnyMessage;
92+
93+
const permissionResponse = {
94+
jsonrpc: "2.0",
95+
id: 99,
96+
result: {
97+
outcome: {
98+
outcome: "selected",
99+
optionId: "allow",
100+
},
101+
},
102+
} satisfies AnyMessage;
103+
62104
describe("createHttpStream", () => {
63105
it("posts initialize with custom headers, opens connection SSE, and emits the initialize response", async () => {
64106
const controlledFetch = createControlledFetch();
@@ -137,6 +179,68 @@ describe("createHttpStream", () => {
137179
}
138180
});
139181

182+
it("opens session SSE before posting session/load for an existing session", async () => {
183+
const controlledFetch = createControlledFetch();
184+
const stream = createHttpStream("https://agent.example/acp", {
185+
fetch: controlledFetch.fetch,
186+
});
187+
const writer = stream.writable.getWriter();
188+
const reader = stream.readable.getReader();
189+
190+
try {
191+
await writer.write(initializeRequest);
192+
await readMessage(reader);
193+
await writer.write(loadSessionRequest);
194+
195+
const sessionGet = requestAt(controlledFetch.requests, 2);
196+
const loadPost = requestAt(controlledFetch.requests, 3);
197+
198+
expect(sessionGet.method).toBe("GET");
199+
expect(sessionGet.headers.get(HEADER_CONNECTION_ID)).toBe("connection-1");
200+
expect(sessionGet.headers.get(HEADER_SESSION_ID)).toBe(
201+
"existing-session",
202+
);
203+
expect(loadPost.method).toBe("POST");
204+
expect(loadPost.headers.get(HEADER_CONNECTION_ID)).toBe("connection-1");
205+
expect(loadPost.headers.get(HEADER_SESSION_ID)).toBe("existing-session");
206+
} finally {
207+
reader.releaseLock();
208+
writer.releaseLock();
209+
await stream.writable.close();
210+
}
211+
});
212+
213+
it("includes the session header on responses to session-scoped server requests", async () => {
214+
const controlledFetch = createControlledFetch();
215+
const stream = createHttpStream("https://agent.example/acp", {
216+
fetch: controlledFetch.fetch,
217+
});
218+
const writer = stream.writable.getWriter();
219+
const reader = stream.readable.getReader();
220+
221+
try {
222+
await writer.write(initializeRequest);
223+
await readMessage(reader);
224+
await controlledFetch.sendSse(0, sessionNewResponse);
225+
await readMessage(reader);
226+
await controlledFetch.sendSse(1, permissionRequest);
227+
await readMessage(reader);
228+
await writer.write(permissionResponse);
229+
230+
const responsePost = requestAt(controlledFetch.requests, 3);
231+
expect(responsePost.method).toBe("POST");
232+
expect(responsePost.headers.get(HEADER_CONNECTION_ID)).toBe(
233+
"connection-1",
234+
);
235+
expect(responsePost.headers.get(HEADER_SESSION_ID)).toBe("session-1");
236+
expect(JSON.parse(responsePost.body)).toEqual(permissionResponse);
237+
} finally {
238+
reader.releaseLock();
239+
writer.releaseLock();
240+
await stream.writable.close();
241+
}
242+
});
243+
140244
it("propagates cookies across initialize, SSE, session POST, and DELETE", async () => {
141245
const controlledFetch = createControlledFetch({
142246
initializeCookies: ["transport=alpha; Path=/"],

src/http-stream.ts

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
HEADER_SESSION_ID,
66
JSON_MIME_TYPE,
77
isInitializeRequest,
8+
messageIdKey,
89
sessionIdFromMessageParams,
910
sessionIdFromResponseResult,
1011
} from "./protocol.js";
@@ -44,6 +45,7 @@ class HttpStreamTransport {
4445
private readonly cookieJar = new ConnectionCookieJar();
4546
private readonly abortController = new AbortController();
4647
private readonly knownSessions = new Set<string>();
48+
private readonly pendingResponseSessions = new Map<string, string>();
4749

4850
private readableController:
4951
| ReadableStreamDefaultController<AnyMessage>
@@ -133,7 +135,11 @@ class HttpStreamTransport {
133135
throw new Error("ACP HTTP stream is not initialized");
134136
}
135137

136-
const sessionId = sessionIdFromMessageParams(message);
138+
const sessionId = this.sessionIdForOutboundMessage(message);
139+
if (sessionId) {
140+
this.openSessionSse(sessionId);
141+
}
142+
137143
const response = await this.fetchRequest({
138144
method: "POST",
139145
headers: {
@@ -149,6 +155,20 @@ class HttpStreamTransport {
149155
}
150156
}
151157

158+
private sessionIdForOutboundMessage(message: AnyMessage): string | undefined {
159+
const paramsSessionId = sessionIdFromMessageParams(message);
160+
if (paramsSessionId) {
161+
return paramsSessionId;
162+
}
163+
164+
if (!("id" in message) || "method" in message) {
165+
return undefined;
166+
}
167+
168+
const key = messageIdKey(message.id);
169+
return key ? this.pendingResponseSessions.get(key) : undefined;
170+
}
171+
152172
private openConnectionSse(): void {
153173
const connectionId = this.connectionId;
154174
if (!connectionId) {
@@ -207,6 +227,7 @@ class HttpStreamTransport {
207227
this.openSessionSse(sessionId);
208228
}
209229

230+
this.trackServerRequestRoute(message, headers[HEADER_SESSION_ID]);
210231
this.enqueue(message);
211232
}
212233
} catch (error) {
@@ -218,6 +239,20 @@ class HttpStreamTransport {
218239
}
219240
}
220241

242+
private trackServerRequestRoute(
243+
message: AnyMessage,
244+
streamSessionId: string | undefined,
245+
): void {
246+
if (!streamSessionId || !("method" in message) || !("id" in message)) {
247+
return;
248+
}
249+
250+
const key = messageIdKey(message.id);
251+
if (key) {
252+
this.pendingResponseSessions.set(key, streamSessionId);
253+
}
254+
}
255+
221256
private async fetchRequest(init: RequestInit): Promise<Response> {
222257
const response = await this.fetchImpl(this.serverUrl, {
223258
...init,

src/server-permission.test.ts

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,64 @@ function createPromptRequest(id: number, sessionId: string) {
4545
}
4646

4747
describe("AcpServer permission requests over HTTP", () => {
48+
it("rejects session-scoped client responses without a session header", async () => {
49+
const server = await startTestServer(
50+
(conn: AgentSideConnection) =>
51+
new TestAgent(conn, { enablePermission: true }),
52+
);
53+
54+
try {
55+
const connectionId = await initialize(server.url);
56+
const sessionId = await createSession(server.url, connectionId);
57+
const sessionSse = await openSessionSse(
58+
server.url,
59+
connectionId,
60+
sessionId,
61+
);
62+
const sessionEvents = createSseMessageIterator(sessionSse);
63+
64+
expect(
65+
await postJson(server.url, createPromptRequest(3, sessionId), {
66+
[HEADER_CONNECTION_ID]: connectionId,
67+
[HEADER_SESSION_ID]: sessionId,
68+
}),
69+
).toMatchObject({ status: 202 });
70+
71+
await readNextSseMessage(sessionEvents);
72+
const permissionRequest = await readNextSseMessage(sessionEvents);
73+
74+
const permissionResponse = {
75+
jsonrpc: "2.0",
76+
id: readMessageId(permissionRequest),
77+
result: {
78+
outcome: {
79+
outcome: "selected",
80+
optionId: "allow",
81+
},
82+
},
83+
};
84+
85+
expect(
86+
await postJson(server.url, permissionResponse, {
87+
[HEADER_CONNECTION_ID]: connectionId,
88+
}),
89+
).toMatchObject({ status: 400 });
90+
expect(
91+
await postJson(server.url, permissionResponse, {
92+
[HEADER_CONNECTION_ID]: connectionId,
93+
[HEADER_SESSION_ID]: sessionId,
94+
}),
95+
).toMatchObject({ status: 202 });
96+
97+
await readNextSseMessage(sessionEvents);
98+
await readNextSseMessage(sessionEvents);
99+
await sessionEvents.return?.();
100+
await sessionSse.body?.cancel();
101+
} finally {
102+
await server.close();
103+
}
104+
}, 10_000);
105+
48106
it("routes permission requests over session SSE and accepts client responses", async () => {
49107
const server = await startTestServer(
50108
(conn: AgentSideConnection) =>
@@ -122,7 +180,10 @@ describe("AcpServer permission requests over HTTP", () => {
122180
},
123181
},
124182
},
125-
{ [HEADER_CONNECTION_ID]: connectionId },
183+
{
184+
[HEADER_CONNECTION_ID]: connectionId,
185+
[HEADER_SESSION_ID]: sessionId,
186+
},
126187
),
127188
).toMatchObject({ status: 202 });
128189

src/server.ts

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ export class AcpServer {
255255
headers: Headers,
256256
): Promise<ForwardResult> {
257257
if (isResponseMessage(message)) {
258-
return await forwardClientResponse(connection, message);
258+
return await forwardClientResponse(connection, message, headers);
259259
}
260260

261261
return await forwardClientMethodMessage(connection, message, headers);
@@ -351,7 +351,32 @@ async function forwardClientMethodMessage(
351351
async function forwardClientResponse(
352352
connection: ConnectionState,
353353
message: AnyResponse,
354+
headers: Headers,
354355
): Promise<ForwardResult> {
356+
const key = messageIdKey(message.id);
357+
const route = key ? connection.clientResponseRoutes.get(key) : undefined;
358+
const headerSessionId = headers.get(HEADER_SESSION_ID);
359+
360+
if (route && route !== "connection" && !headerSessionId) {
361+
return {
362+
ok: false,
363+
status: 400,
364+
message: "Missing Acp-Session-Id",
365+
};
366+
}
367+
368+
if (route && route !== "connection" && headerSessionId !== route.session) {
369+
return {
370+
ok: false,
371+
status: 400,
372+
message: "Mismatched Acp-Session-Id",
373+
};
374+
}
375+
376+
if (key) {
377+
connection.clientResponseRoutes.delete(key);
378+
}
379+
355380
await writeInbound(connection, message);
356381
return { ok: true };
357382
}

0 commit comments

Comments
 (0)