Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 63 additions & 29 deletions containers/api-proxy/proxy-request.js
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,35 @@ function isValidRequestId(id) {
return typeof id === 'string' && id.length <= 128 && /^[\w\-\.]+$/.test(id);
}

function handleRequestError(err, {
res,
requestId,
provider,
req,
targetHost,
startTime,
statusCode,
clientMessage,
extraMetrics,
onHeadersSent,
}) {
const duration = Date.now() - startTime;
metrics.gaugeDec('active_requests', { provider });
metrics.increment('requests_errors_total', { provider });
if (extraMetrics) extraMetrics(duration);
logRequest('error', 'request_error', {
request_id: requestId, provider, method: req.method,
path: sanitizeForLog(req.url), duration_ms: duration,
error: sanitizeForLog(err.message), upstream_host: targetHost,
Comment on lines +134 to +141
});
if (res.headersSent) {
if (onHeadersSent) onHeadersSent(err);
return;
}
res.writeHead(statusCode, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: clientMessage, message: err.message }));
}

const checkRateLimit = createRateLimitChecker({
limiter,
metrics,
Expand Down Expand Up @@ -201,16 +230,16 @@ function proxyRequest(req, res, targetHost, injectHeaders, provider, basePath =
req.on('error', (err) => {
if (errored) return;
errored = true;
const duration = Date.now() - startTime;
metrics.gaugeDec('active_requests', { provider });
metrics.increment('requests_errors_total', { provider });
logRequest('error', 'request_error', {
request_id: requestId, provider, method: req.method,
path: sanitizeForLog(req.url), duration_ms: duration,
error: sanitizeForLog(err.message), upstream_host: targetHost,
handleRequestError(err, {
res,
requestId,
provider,
req,
targetHost,
startTime,
statusCode: 400,
clientMessage: 'Client error',
});
if (!res.headersSent) res.writeHead(400, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'Client error', message: err.message }));
});

req.on('data', chunk => {
Expand Down Expand Up @@ -356,16 +385,19 @@ function proxyRequest(req, res, targetHost, injectHeaders, provider, basePath =
proxyRes.on('data', (chunk) => { responseBytes += chunk.length; });

proxyRes.on('error', (err) => {
const duration = Date.now() - startTime;
metrics.gaugeDec('active_requests', { provider });
metrics.increment('requests_errors_total', { provider });
logRequest('error', 'request_error', {
request_id: requestId, provider, method: req.method,
path: sanitizeForLog(req.url), duration_ms: duration,
error: sanitizeForLog(err.message), upstream_host: targetHost,
handleRequestError(err, {
res,
requestId,
provider,
req,
targetHost,
startTime,
statusCode: 502,
clientMessage: 'Response stream error',
onHeadersSent: () => {
if (typeof res.destroy === 'function') res.destroy(err);
},
});
Comment on lines 387 to 400
if (!res.headersSent) res.writeHead(502, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'Response stream error', message: err.message }));
});

const billingInfo = extractBillingHeaders(proxyRes.headers);
Expand Down Expand Up @@ -419,18 +451,20 @@ function proxyRequest(req, res, targetHost, injectHeaders, provider, basePath =
});

proxyReq.on('error', (err) => {
const duration = Date.now() - startTime;
metrics.gaugeDec('active_requests', { provider });
metrics.increment('requests_errors_total', { provider });
metrics.increment('requests_total', { provider, method: req.method, status_class: '5xx' });
metrics.observe('request_duration_ms', duration, { provider });
logRequest('error', 'request_error', {
request_id: requestId, provider, method: req.method,
path: sanitizeForLog(req.url), duration_ms: duration,
error: sanitizeForLog(err.message), upstream_host: targetHost,
handleRequestError(err, {
res,
requestId,
provider,
req,
targetHost,
startTime,
statusCode: 502,
clientMessage: 'Proxy error',
extraMetrics: (duration) => {
metrics.increment('requests_total', { provider, method: req.method, status_class: '5xx' });
metrics.observe('request_duration_ms', duration, { provider });
},
});
if (!res.headersSent) res.writeHead(502, { 'Content-Type': 'application/json' });
res.end(JSON.stringify({ error: 'Proxy error', message: err.message }));
});

if (body.length > 0) proxyReq.write(body);
Expand Down
153 changes: 152 additions & 1 deletion containers/api-proxy/server.proxy.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ const { resetEffectiveTokenGuardForTests, resetMaxRunsGuardForTests, resetTimeou
const originalHttpsProxy = process.env.HTTPS_PROXY;
let proxyRequest;
let proxyWebSocket;
let healthResponse;

beforeAll(() => {
delete process.env.HTTPS_PROXY;
jest.resetModules();
({ proxyRequest, proxyWebSocket } = require('./server'));
({ proxyRequest, proxyWebSocket, healthResponse } = require('./server'));
});

afterAll(() => {
Expand Down Expand Up @@ -472,6 +473,156 @@ describe('proxyRequest X-Initiator injection', () => {
});
});

describe('proxyRequest error handling', () => {
function makeReq(headers = {}) {
const req = new EventEmitter();
req.url = '/v1/chat/completions';
req.method = 'POST';
req.headers = { 'content-type': 'application/json', ...headers };
return req;
}

function makeRes() {
const res = {
headersSent: false,
setHeader: jest.fn(),
writeHead: jest.fn(() => {
res.headersSent = true;
}),
end: jest.fn(),
Comment on lines +485 to +492
destroy: jest.fn(),
};
return res;
}

function getRequestErrorLog(writeSpy) {
for (const [line] of writeSpy.mock.calls) {
try {
const parsed = JSON.parse(line);
if (parsed.event === 'request_error') return parsed;
} catch {
// ignore non-JSON writes
}
}
return null;
}

let stdoutWriteSpy;

beforeEach(() => {
stdoutWriteSpy = jest.spyOn(process.stdout, 'write').mockImplementation(() => true);
});

afterEach(() => {
jest.restoreAllMocks();
});

it('returns 400 when the client request stream errors', () => {
const before = healthResponse().metrics_summary;
const req = makeReq();
const res = makeRes();

proxyRequest(req, res, 'api.openai.com', { Authorization: 'Bearer token' }, 'openai');
req.emit('error', new Error('client stream failed\ninjected'));

expect(res.writeHead).toHaveBeenCalledWith(400, { 'Content-Type': 'application/json' });
expect(JSON.parse(res.end.mock.calls[0][0])).toEqual({
error: 'Client error',
message: 'client stream failed\ninjected',
});
const after = healthResponse().metrics_summary;
expect(after.total_errors).toBe(before.total_errors + 1);
expect(after.active_requests).toBe(before.active_requests);
const errorLog = getRequestErrorLog(stdoutWriteSpy);
expect(errorLog).toMatchObject({
event: 'request_error',
provider: 'openai',
method: 'POST',
path: '/v1/chat/completions',
error: 'client stream failedinjected',
upstream_host: 'api.openai.com',
});
});

it('destroys response when upstream response stream errors after headers are sent', () => {
const before = healthResponse().metrics_summary;
let responseHandler;
const upstreamRequest = new EventEmitter();
upstreamRequest.end = jest.fn();
upstreamRequest.write = jest.fn();
upstreamRequest.destroy = jest.fn();

jest.spyOn(https, 'request').mockImplementation((_options, cb) => {
responseHandler = cb;
return upstreamRequest;
});

const req = makeReq();
const res = makeRes();
proxyRequest(req, res, 'api.openai.com', { Authorization: 'Bearer token' }, 'openai');
req.emit('end');

const proxyRes = new EventEmitter();
proxyRes.statusCode = 200;
proxyRes.headers = {};
proxyRes.pipe = jest.fn();
responseHandler(proxyRes);
proxyRes.emit('error', new Error('upstream stream failed'));

expect(res.writeHead).toHaveBeenNthCalledWith(1, 200, { 'x-request-id': expect.any(String) });
expect(res.writeHead).toHaveBeenCalledTimes(1);
expect(res.end).not.toHaveBeenCalled();
expect(res.destroy).toHaveBeenCalledWith(expect.any(Error));
const after = healthResponse().metrics_summary;
expect(after.total_errors).toBe(before.total_errors + 1);
expect(after.active_requests).toBe(before.active_requests);
const errorLog = getRequestErrorLog(stdoutWriteSpy);
expect(errorLog).toMatchObject({
event: 'request_error',
provider: 'openai',
method: 'POST',
path: '/v1/chat/completions',
error: 'upstream stream failed',
upstream_host: 'api.openai.com',
});
});

it('returns 502 when the upstream proxy request errors', () => {
const before = healthResponse().metrics_summary;
const upstreamRequest = new EventEmitter();
upstreamRequest.end = jest.fn();
upstreamRequest.write = jest.fn();
upstreamRequest.destroy = jest.fn();

jest.spyOn(https, 'request').mockImplementation(() => upstreamRequest);

const req = makeReq();
const res = makeRes();
proxyRequest(req, res, 'api.openai.com', { Authorization: 'Bearer token' }, 'openai');
req.emit('end');
upstreamRequest.emit('error', new Error('upstream connect failed'));

expect(res.writeHead).toHaveBeenCalledWith(502, { 'Content-Type': 'application/json' });
expect(JSON.parse(res.end.mock.calls[0][0])).toEqual({
error: 'Proxy error',
message: 'upstream connect failed',
});
const after = healthResponse().metrics_summary;
expect(after.total_errors).toBe(before.total_errors + 1);
expect(after.total_requests).toBe(before.total_requests + 1);
expect(after.active_requests).toBe(before.active_requests);
const errorLog = getRequestErrorLog(stdoutWriteSpy);
expect(errorLog).toMatchObject({
event: 'request_error',
provider: 'openai',
method: 'POST',
path: '/v1/chat/completions',
error: 'upstream connect failed',
upstream_host: 'api.openai.com',
});
});
});

describe('proxyRequest effective token guard', () => {
function makeReq(headers = {}) {
const req = new EventEmitter();
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/one-shot-tokens.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ describe('One-Shot Token Protection', () => {
{
allowDomains: ['localhost'],
logLevel: 'debug',
timeout: 240000,
timeout: 480000,
buildLocal: true, // Build container locally to include one-shot-token.so
env: {
GITHUB_TOKEN: 'ghp_test_token_12345',
Expand All @@ -80,7 +80,7 @@ describe('One-Shot Token Protection', () => {
expect(result.stdout).toContain('Second read: [ghp_test_token_12345]');
// Note: printenv reads from environ array directly, not via getenv().
// The LD_PRELOAD library only intercepts getenv() calls, so no debug output appears here.
}, 240000);
}, 480000);

test('should cache COPILOT_GITHUB_TOKEN and clear from environment', async () => {
const testScript = `
Expand Down
14 changes: 7 additions & 7 deletions tests/integration/protocol-support.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ describe('Protocol Support', () => {
describe('HTTPS Connections', () => {
test('should allow HTTPS to allowed domain', async () => {
const result = await runner.runWithSudo(
'curl -fsS https://api.github.com/zen',
'curl -fsS https://github.com',
{
allowDomains: ['github.com'],
logLevel: 'debug',
Expand All @@ -55,7 +55,7 @@ describe('Protocol Support', () => {

test('should handle HTTPS with verbose output', async () => {
const result = await runner.runWithSudo(
'curl -v https://api.github.com/zen 2>&1 | grep -E "SSL|TLS" | head -5 || true',
'curl -v https://github.com 2>&1 | grep -E "SSL|TLS" | head -5 || true',
{
allowDomains: ['github.com'],
logLevel: 'debug',
Expand Down Expand Up @@ -117,7 +117,7 @@ describe('Protocol Support', () => {
describe('Connection Headers', () => {
test('should pass custom headers', async () => {
const result = await runner.runWithSudo(
'curl -fsS -H "Accept: application/vnd.github+json" https://api.github.com/zen',
'curl -fsS -H "Accept: text/html" https://github.com',
{
allowDomains: ['github.com'],
logLevel: 'debug',
Expand All @@ -130,7 +130,7 @@ describe('Protocol Support', () => {

test('should pass User-Agent header', async () => {
const result = await runner.runWithSudo(
'curl -fsS -A "Test-Agent/1.0" https://api.github.com/zen',
'curl -fsS -A "Test-Agent/1.0" https://github.com',
{
allowDomains: ['github.com'],
logLevel: 'debug',
Expand All @@ -145,7 +145,7 @@ describe('Protocol Support', () => {
describe('IPv4/IPv6', () => {
test('should support IPv4 connections', async () => {
const result = await runner.runWithSudo(
'curl -fsS -4 https://api.github.com/zen',
'curl -fsS -4 https://github.com',
{
allowDomains: ['github.com'],
logLevel: 'debug',
Expand All @@ -159,7 +159,7 @@ describe('Protocol Support', () => {
test('should handle IPv6 (may not be available)', async () => {
// IPv6 may not be available in all environments
const result = await runner.runWithSudo(
'curl -fsS -6 https://api.github.com/zen || exit 0',
'curl -fsS -6 https://github.com || exit 0',
{
allowDomains: ['github.com'],
logLevel: 'debug',
Expand All @@ -175,7 +175,7 @@ describe('Protocol Support', () => {
describe('Connection Timeouts', () => {
test('should respect curl max-time option', async () => {
const result = await runner.runWithSudo(
'curl -f --max-time 5 https://api.github.com/zen',
'curl -f --max-time 5 https://github.com',
{
allowDomains: ['github.com'],
logLevel: 'debug',
Expand Down
Loading