Skip to content
133 changes: 118 additions & 15 deletions packages/vscode-ide-companion/src/ide-server.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type * as vscode from 'vscode';
import * as fs from 'node:fs/promises';
import type * as os from 'node:os';
import * as path from 'node:path';
import * as http from 'node:http';
import { IDEServer } from './ide-server.js';
import type { DiffManager } from './diff-manager.js';

Expand Down Expand Up @@ -62,26 +63,26 @@ vi.mock('./open-files-manager', () => {
return { OpenFilesManager };
});

const getPortFromMock = (
replaceMock: ReturnType<
() => vscode.ExtensionContext['environmentVariableCollection']['replace']
>,
) => {
const port = vi
.mocked(replaceMock)
.mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1];

if (port === undefined) {
expect.fail('Port was not set');
}
return port;
};

describe('IDEServer', () => {
let ideServer: IDEServer;
let mockContext: vscode.ExtensionContext;
let mockLog: (message: string) => void;

const getPortFromMock = (
replaceMock: ReturnType<
() => vscode.ExtensionContext['environmentVariableCollection']['replace']
>,
) => {
const port = vi
.mocked(replaceMock)
.mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1];

if (port === undefined) {
expect.fail('Port was not set');
}
return port;
};

beforeEach(() => {
mockLog = vi.fn();
ideServer = new IDEServer(mockLog, mocks.diffManager);
Expand Down Expand Up @@ -456,3 +457,105 @@ describe('IDEServer', () => {
});
});
});

const request = (
port: string,
options: http.RequestOptions,
body?: string,
): Promise<http.IncomingMessage> =>
new Promise((resolve, reject) => {
const req = http.request(
{
hostname: '127.0.0.1',
port,
...options,
},
(res) => {
res.resume(); // Consume response data to free up memory
resolve(res);
},
);
req.on('error', reject);
if (body) {
req.write(body);
}
req.end();
});

describe('IDEServer HTTP endpoints', () => {
let ideServer: IDEServer;
let mockContext: vscode.ExtensionContext;
let mockLog: (message: string) => void;
let port: string;

beforeEach(async () => {
mockLog = vi.fn();
ideServer = new IDEServer(mockLog, mocks.diffManager);
mockContext = {
subscriptions: [],
environmentVariableCollection: {
replace: vi.fn(),
clear: vi.fn(),
},
} as unknown as vscode.ExtensionContext;
await ideServer.start(mockContext);
const replaceMock = mockContext.environmentVariableCollection.replace;
port = getPortFromMock(replaceMock);
});

afterEach(async () => {
await ideServer.stop();
vi.restoreAllMocks();
});

it('should deny requests with an origin header', async () => {
const response = await request(
port,
{
path: '/mcp',
method: 'POST',
headers: {
Host: `localhost:${port}`,
Origin: 'https://evil.com',
'Content-Type': 'application/json',
},
},
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
);
expect(response.statusCode).toBe(403);
});

it('should deny requests with an invalid host header', async () => {
const response = await request(
port,
{
path: '/mcp',
method: 'POST',
headers: {
Host: 'evil.com',
'Content-Type': 'application/json',
},
},
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
);
expect(response.statusCode).toBe(403);
});

it('should allow requests with a valid host header', async () => {
const response = await request(
port,
{
path: '/mcp',
method: 'POST',
headers: {
Host: `localhost:${port}`,
'Content-Type': 'application/json',
},
},
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
);
// We expect a 400 here because we are not sending a valid MCP request,
// but it's not a host error, which is what we are testing.
expect(response.statusCode).toBe(400);
});
});
54 changes: 51 additions & 3 deletions packages/vscode-ide-companion/src/ide-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ import {
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import express, { type Request, type Response } from 'express';
import express, {
type Request,
type Response,
type NextFunction,
} from 'express';
import cors from 'cors';
import { randomUUID } from 'node:crypto';
import { type Server as HTTPServer } from 'node:http';
import * as path from 'node:path';
Expand All @@ -23,6 +28,13 @@ import type { z } from 'zod';
import type { DiffManager } from './diff-manager.js';
import { OpenFilesManager } from './open-files-manager.js';

class CORSError extends Error {
constructor(message: string) {
super(message);
this.name = 'CORSError';
}
}

const MCP_SESSION_ID_HEADER = 'mcp-session-id';
const IDE_SERVER_PORT_ENV_VAR = 'GEMINI_CLI_IDE_SERVER_PORT';
const IDE_WORKSPACE_PATH_ENV_VAR = 'GEMINI_CLI_IDE_WORKSPACE_PATH';
Expand Down Expand Up @@ -131,6 +143,34 @@ export class IDEServer {

const app = express();
app.use(express.json({ limit: '10mb' }));

app.use(
cors({
origin: (origin, callback) => {
// Only allow non-browser requests with no origin.
if (!origin) {
return callback(null, true);
}
return callback(
new CORSError('Request denied by CORS policy.'),
false,
);
},
}),
);

app.use((req, res, next) => {
const host = req.headers.host || '';
Comment thread
skeshive marked this conversation as resolved.
const allowedHosts = [
`localhost:${this.port}`,
`127.0.0.1:${this.port}`,
];
if (!allowedHosts.includes(host)) {
return res.status(403).json({ error: 'Invalid Host header' });
}
next();
});

app.use((req, res, next) => {
const authHeader = req.headers.authorization;
if (authHeader) {
Expand Down Expand Up @@ -274,7 +314,15 @@ export class IDEServer {

app.get('/mcp', handleSessionRequest);

this.server = app.listen(0, async () => {
app.use((err: Error, req: Request, res: Response, next: NextFunction) => {
if (err instanceof CORSError) {
res.status(403).json({ error: 'Request denied by CORS policy.' });
} else {
next(err);
}
});

this.server = app.listen(0, '127.0.0.1', async () => {
const address = (this.server as HTTPServer).address();
if (address && typeof address !== 'string') {
this.port = address.port;
Expand All @@ -286,7 +334,7 @@ export class IDEServer {
os.tmpdir(),
`gemini-ide-server-${process.ppid}.json`,
);
this.log(`IDE server listening on port ${this.port}`);
this.log(`IDE server listening on http://127.0.0.1:${this.port}`);

if (this.authToken) {
await writePortAndWorkspace({
Expand Down
Loading