Skip to content

Commit 66da6ed

Browse files
committed
Enforce ACP transport routing validation
1 parent 50fdedf commit 66da6ed

3 files changed

Lines changed: 100 additions & 36 deletions

File tree

src/server-session-sse.test.ts

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ function createLoadSessionRequest(id: number, sessionId: string) {
7777
};
7878
}
7979

80+
function createCancelNotification(sessionId: string) {
81+
return {
82+
jsonrpc: "2.0",
83+
method: "session/cancel",
84+
params: {
85+
sessionId,
86+
},
87+
};
88+
}
89+
8090
class LoadSessionAgent extends TestAgent {
8191
constructor(private readonly agentConnection: AgentSideConnection) {
8292
super(agentConnection);
@@ -218,6 +228,47 @@ describe("AcpServer session SSE", () => {
218228
}
219229
});
220230

231+
it("rejects session-scoped notifications without a session header", async () => {
232+
const server = await startTestServer();
233+
234+
try {
235+
const connectionId = await initialize(server.url);
236+
const sessionId = await createSession(server.url, connectionId);
237+
const response = await postJson(
238+
server.url,
239+
createCancelNotification(sessionId),
240+
{
241+
[HEADER_CONNECTION_ID]: connectionId,
242+
},
243+
);
244+
245+
expect(response.status).toBe(400);
246+
} finally {
247+
await server.close();
248+
}
249+
});
250+
251+
it("rejects session-scoped notifications with mismatched session header and params", async () => {
252+
const server = await startTestServer();
253+
254+
try {
255+
const connectionId = await initialize(server.url);
256+
const sessionId = await createSession(server.url, connectionId);
257+
const response = await postJson(
258+
server.url,
259+
createCancelNotification("other-session"),
260+
{
261+
[HEADER_CONNECTION_ID]: connectionId,
262+
[HEADER_SESSION_ID]: sessionId,
263+
},
264+
);
265+
266+
expect(response.status).toBe(400);
267+
} finally {
268+
await server.close();
269+
}
270+
});
271+
221272
it("rejects session-scoped requests without any session identifier", async () => {
222273
const server = await startTestServer();
223274

@@ -262,11 +313,6 @@ describe("AcpServer session SSE", () => {
262313
const connectionId = await initialize(server.url);
263314
const sessionId = "existing-session";
264315
const connectionSse = await openConnectionSse(server.url, connectionId);
265-
const sessionSse = await openSessionSse(
266-
server.url,
267-
connectionId,
268-
sessionId,
269-
);
270316
const accepted = await postJson(
271317
server.url,
272318
createLoadSessionRequest(3, sessionId),
@@ -275,9 +321,14 @@ describe("AcpServer session SSE", () => {
275321
[HEADER_SESSION_ID]: sessionId,
276322
},
277323
);
324+
const sessionSse = await openSessionSse(
325+
server.url,
326+
connectionId,
327+
sessionId,
328+
);
278329

279-
expect(sessionSse.status).toBe(200);
280330
expect(accepted.status).toBe(202);
331+
expect(sessionSse.status).toBe(200);
281332
expect(await readSseMessages(sessionSse, 1)).toMatchObject([
282333
{
283334
jsonrpc: "2.0",

src/server.test.ts

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ describe("AcpServer", () => {
176176
}
177177
});
178178

179-
it("opens session-scoped GETs for sessions without local streams", async () => {
179+
it("rejects session-scoped GETs for unknown sessions", async () => {
180180
const server = await startTestServer();
181181

182182
try {
@@ -187,7 +187,7 @@ describe("AcpServer", () => {
187187
globalThis.crypto.randomUUID(),
188188
);
189189

190-
expect(response.status).toBe(200);
190+
expect(response.status).toBe(404);
191191
} finally {
192192
await server.close();
193193
}
@@ -385,6 +385,21 @@ describe("AcpServer", () => {
385385
}
386386
});
387387

388+
it("rejects initialize requests on existing connections", async () => {
389+
const server = await startTestServer();
390+
391+
try {
392+
const connectionId = await initialize(server.url);
393+
const response = await postJson(server.url, initializeRequest, {
394+
[HEADER_CONNECTION_ID]: connectionId,
395+
});
396+
397+
expect(response.status).toBe(400);
398+
} finally {
399+
await server.close();
400+
}
401+
});
402+
388403
it("rejects unknown connection IDs", async () => {
389404
const server = await startTestServer();
390405

src/server.ts

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@ import {
99
methodRequiresSessionHeader,
1010
sessionIdFromParams,
1111
} from "./protocol.js";
12-
import {
13-
isJsonRpcMessage,
14-
isRequestMessage,
15-
isResponseMessage,
16-
} from "./jsonrpc.js";
12+
import { isJsonRpcMessage, isResponseMessage } from "./jsonrpc.js";
1713
import { AGENT_METHODS } from "./schema/index.js";
1814
import { serializeSseEvent, serializeSseKeepAlive } from "./sse.js";
1915
import { handleWebSocketConnection } from "./ws-server.js";
@@ -25,7 +21,12 @@ import type {
2521
ResponseRoute,
2622
} from "./connection.js";
2723
import type { Agent, AgentSideConnection } from "./acp.js";
28-
import type { AnyMessage, AnyRequest, AnyResponse } from "./jsonrpc.js";
24+
import type {
25+
AnyMessage,
26+
AnyNotification,
27+
AnyRequest,
28+
AnyResponse,
29+
} from "./jsonrpc.js";
2930

3031
/** Options for creating an ACP server transport. */
3132
export interface AcpServerOptions {
@@ -129,8 +130,12 @@ export class AcpServer {
129130

130131
const connectionId = req.headers.get(HEADER_CONNECTION_ID);
131132

132-
if (isInitializeRequest(body.value) && !connectionId) {
133-
return await this.handleInitialize(body.value);
133+
if (isInitializeRequest(body.value)) {
134+
if (!connectionId) {
135+
return await this.handleInitialize(body.value);
136+
}
137+
138+
return textResponse("Initialize not allowed on existing connection", 400);
134139
}
135140

136141
if (!connectionId) {
@@ -180,7 +185,12 @@ export class AcpServer {
180185

181186
const sessionId = req.headers.get(HEADER_SESSION_ID);
182187
if (sessionId) {
183-
return sseResponse(connection.ensureSession(sessionId).subscribe());
188+
const sessionStream = connection.sessionStreams.get(sessionId);
189+
if (!sessionStream) {
190+
return textResponse("Unknown Acp-Session-Id", 404);
191+
}
192+
193+
return sseResponse(sessionStream.subscribe());
184194
}
185195

186196
return sseResponse(connection.connectionStream.subscribe());
@@ -244,15 +254,11 @@ export class AcpServer {
244254
message: AnyMessage,
245255
headers: Headers,
246256
): Promise<ForwardResult> {
247-
if (isRequestMessage(message)) {
248-
return await forwardClientRequest(connection, message, headers);
249-
}
250-
251257
if (isResponseMessage(message)) {
252258
return await forwardClientResponse(connection, message);
253259
}
254260

255-
return await forwardClientNotification(connection, message);
261+
return await forwardClientMethodMessage(connection, message, headers);
256262
}
257263
}
258264

@@ -286,7 +292,7 @@ type RouteResult =
286292
message: string;
287293
};
288294

289-
type ClientRequestMessage = AnyRequest;
295+
type ClientMethodMessage = AnyRequest | AnyNotification;
290296

291297
async function readJson(req: Request): Promise<JsonResult> {
292298
try {
@@ -314,9 +320,9 @@ async function writeInbound(
314320
}
315321
}
316322

317-
async function forwardClientRequest(
323+
async function forwardClientMethodMessage(
318324
connection: ConnectionState,
319-
message: ClientRequestMessage,
325+
message: ClientMethodMessage,
320326
headers: Headers,
321327
): Promise<ForwardResult> {
322328
const route = determineRoute(message, headers);
@@ -329,7 +335,7 @@ async function forwardClientRequest(
329335
connection.ensureSession(route.value.session);
330336
}
331337

332-
const key = messageIdKey(message.id);
338+
const key = "id" in message ? messageIdKey(message.id) : undefined;
333339

334340
if (key) {
335341
connection.pendingRoutes.set(
@@ -350,23 +356,15 @@ async function forwardClientResponse(
350356
return { ok: true };
351357
}
352358

353-
async function forwardClientNotification(
354-
connection: ConnectionState,
355-
message: AnyMessage,
356-
): Promise<ForwardResult> {
357-
await writeInbound(connection, message);
358-
return { ok: true };
359-
}
360-
361359
function pendingResponseRoute(
362-
message: ClientRequestMessage,
360+
message: ClientMethodMessage,
363361
route: ResponseRoute,
364362
): ResponseRoute {
365363
return message.method === AGENT_METHODS.session_load ? "connection" : route;
366364
}
367365

368366
function determineRoute(
369-
message: ClientRequestMessage,
367+
message: ClientMethodMessage,
370368
headers: Headers,
371369
): RouteResult {
372370
const headerSessionId = headers.get(HEADER_SESSION_ID);

0 commit comments

Comments
 (0)