diff --git a/containers/api-proxy/proxy-request.js b/containers/api-proxy/proxy-request.js index 3bb210e21..8f8cd546a 100644 --- a/containers/api-proxy/proxy-request.js +++ b/containers/api-proxy/proxy-request.js @@ -119,6 +119,35 @@ function isValidRequestId(id) { return typeof id === 'string' && id.length <= 128 && /^[\w\-\.]+$/.test(id); } +function handleRequestError(err, { + res, + requestId, + provider, + req, + targetHost, + startTime, + statusCode, + clientMessage, + extraMetrics, + onHeadersSent, +}) { + const duration = Date.now() - startTime; + metrics.gaugeDec('active_requests', { provider }); + metrics.increment('requests_errors_total', { provider }); + if (extraMetrics) extraMetrics(duration); + logRequest('error', 'request_error', { + request_id: requestId, provider, method: req.method, + path: sanitizeForLog(req.url), duration_ms: duration, + error: sanitizeForLog(err.message), upstream_host: targetHost, + }); + if (res.headersSent) { + if (onHeadersSent) onHeadersSent(err); + return; + } + res.writeHead(statusCode, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: clientMessage, message: err.message })); +} + const checkRateLimit = createRateLimitChecker({ limiter, metrics, @@ -201,16 +230,16 @@ function proxyRequest(req, res, targetHost, injectHeaders, provider, basePath = req.on('error', (err) => { if (errored) return; errored = true; - const duration = Date.now() - startTime; - metrics.gaugeDec('active_requests', { provider }); - metrics.increment('requests_errors_total', { provider }); - logRequest('error', 'request_error', { - request_id: requestId, provider, method: req.method, - path: sanitizeForLog(req.url), duration_ms: duration, - error: sanitizeForLog(err.message), upstream_host: targetHost, + handleRequestError(err, { + res, + requestId, + provider, + req, + targetHost, + startTime, + statusCode: 400, + clientMessage: 'Client error', }); - if (!res.headersSent) res.writeHead(400, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ error: 'Client error', message: err.message })); }); req.on('data', chunk => { @@ -356,16 +385,19 @@ function proxyRequest(req, res, targetHost, injectHeaders, provider, basePath = proxyRes.on('data', (chunk) => { responseBytes += chunk.length; }); proxyRes.on('error', (err) => { - const duration = Date.now() - startTime; - metrics.gaugeDec('active_requests', { provider }); - metrics.increment('requests_errors_total', { provider }); - logRequest('error', 'request_error', { - request_id: requestId, provider, method: req.method, - path: sanitizeForLog(req.url), duration_ms: duration, - error: sanitizeForLog(err.message), upstream_host: targetHost, + handleRequestError(err, { + res, + requestId, + provider, + req, + targetHost, + startTime, + statusCode: 502, + clientMessage: 'Response stream error', + onHeadersSent: () => { + if (typeof res.destroy === 'function') res.destroy(err); + }, }); - if (!res.headersSent) res.writeHead(502, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ error: 'Response stream error', message: err.message })); }); const billingInfo = extractBillingHeaders(proxyRes.headers); @@ -419,18 +451,20 @@ function proxyRequest(req, res, targetHost, injectHeaders, provider, basePath = }); proxyReq.on('error', (err) => { - const duration = Date.now() - startTime; - metrics.gaugeDec('active_requests', { provider }); - metrics.increment('requests_errors_total', { provider }); - metrics.increment('requests_total', { provider, method: req.method, status_class: '5xx' }); - metrics.observe('request_duration_ms', duration, { provider }); - logRequest('error', 'request_error', { - request_id: requestId, provider, method: req.method, - path: sanitizeForLog(req.url), duration_ms: duration, - error: sanitizeForLog(err.message), upstream_host: targetHost, + handleRequestError(err, { + res, + requestId, + provider, + req, + targetHost, + startTime, + statusCode: 502, + clientMessage: 'Proxy error', + extraMetrics: (duration) => { + metrics.increment('requests_total', { provider, method: req.method, status_class: '5xx' }); + metrics.observe('request_duration_ms', duration, { provider }); + }, }); - if (!res.headersSent) res.writeHead(502, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ error: 'Proxy error', message: err.message })); }); if (body.length > 0) proxyReq.write(body); diff --git a/containers/api-proxy/server.proxy.test.js b/containers/api-proxy/server.proxy.test.js index 4e267dc5b..d5b444c1c 100644 --- a/containers/api-proxy/server.proxy.test.js +++ b/containers/api-proxy/server.proxy.test.js @@ -14,11 +14,12 @@ const { resetEffectiveTokenGuardForTests, resetMaxRunsGuardForTests, resetTimeou const originalHttpsProxy = process.env.HTTPS_PROXY; let proxyRequest; let proxyWebSocket; +let healthResponse; beforeAll(() => { delete process.env.HTTPS_PROXY; jest.resetModules(); - ({ proxyRequest, proxyWebSocket } = require('./server')); + ({ proxyRequest, proxyWebSocket, healthResponse } = require('./server')); }); afterAll(() => { @@ -472,6 +473,156 @@ describe('proxyRequest X-Initiator injection', () => { }); }); +describe('proxyRequest error handling', () => { + function makeReq(headers = {}) { + const req = new EventEmitter(); + req.url = '/v1/chat/completions'; + req.method = 'POST'; + req.headers = { 'content-type': 'application/json', ...headers }; + return req; + } + + function makeRes() { + const res = { + headersSent: false, + setHeader: jest.fn(), + writeHead: jest.fn(() => { + res.headersSent = true; + }), + end: jest.fn(), + destroy: jest.fn(), + }; + return res; + } + + function getRequestErrorLog(writeSpy) { + for (const [line] of writeSpy.mock.calls) { + try { + const parsed = JSON.parse(line); + if (parsed.event === 'request_error') return parsed; + } catch { + // ignore non-JSON writes + } + } + return null; + } + + let stdoutWriteSpy; + + beforeEach(() => { + stdoutWriteSpy = jest.spyOn(process.stdout, 'write').mockImplementation(() => true); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('returns 400 when the client request stream errors', () => { + const before = healthResponse().metrics_summary; + const req = makeReq(); + const res = makeRes(); + + proxyRequest(req, res, 'api.openai.com', { Authorization: 'Bearer token' }, 'openai'); + req.emit('error', new Error('client stream failed\ninjected')); + + expect(res.writeHead).toHaveBeenCalledWith(400, { 'Content-Type': 'application/json' }); + expect(JSON.parse(res.end.mock.calls[0][0])).toEqual({ + error: 'Client error', + message: 'client stream failed\ninjected', + }); + const after = healthResponse().metrics_summary; + expect(after.total_errors).toBe(before.total_errors + 1); + expect(after.active_requests).toBe(before.active_requests); + const errorLog = getRequestErrorLog(stdoutWriteSpy); + expect(errorLog).toMatchObject({ + event: 'request_error', + provider: 'openai', + method: 'POST', + path: '/v1/chat/completions', + error: 'client stream failedinjected', + upstream_host: 'api.openai.com', + }); + }); + + it('destroys response when upstream response stream errors after headers are sent', () => { + const before = healthResponse().metrics_summary; + let responseHandler; + const upstreamRequest = new EventEmitter(); + upstreamRequest.end = jest.fn(); + upstreamRequest.write = jest.fn(); + upstreamRequest.destroy = jest.fn(); + + jest.spyOn(https, 'request').mockImplementation((_options, cb) => { + responseHandler = cb; + return upstreamRequest; + }); + + const req = makeReq(); + const res = makeRes(); + proxyRequest(req, res, 'api.openai.com', { Authorization: 'Bearer token' }, 'openai'); + req.emit('end'); + + const proxyRes = new EventEmitter(); + proxyRes.statusCode = 200; + proxyRes.headers = {}; + proxyRes.pipe = jest.fn(); + responseHandler(proxyRes); + proxyRes.emit('error', new Error('upstream stream failed')); + + expect(res.writeHead).toHaveBeenNthCalledWith(1, 200, { 'x-request-id': expect.any(String) }); + expect(res.writeHead).toHaveBeenCalledTimes(1); + expect(res.end).not.toHaveBeenCalled(); + expect(res.destroy).toHaveBeenCalledWith(expect.any(Error)); + const after = healthResponse().metrics_summary; + expect(after.total_errors).toBe(before.total_errors + 1); + expect(after.active_requests).toBe(before.active_requests); + const errorLog = getRequestErrorLog(stdoutWriteSpy); + expect(errorLog).toMatchObject({ + event: 'request_error', + provider: 'openai', + method: 'POST', + path: '/v1/chat/completions', + error: 'upstream stream failed', + upstream_host: 'api.openai.com', + }); + }); + + it('returns 502 when the upstream proxy request errors', () => { + const before = healthResponse().metrics_summary; + const upstreamRequest = new EventEmitter(); + upstreamRequest.end = jest.fn(); + upstreamRequest.write = jest.fn(); + upstreamRequest.destroy = jest.fn(); + + jest.spyOn(https, 'request').mockImplementation(() => upstreamRequest); + + const req = makeReq(); + const res = makeRes(); + proxyRequest(req, res, 'api.openai.com', { Authorization: 'Bearer token' }, 'openai'); + req.emit('end'); + upstreamRequest.emit('error', new Error('upstream connect failed')); + + expect(res.writeHead).toHaveBeenCalledWith(502, { 'Content-Type': 'application/json' }); + expect(JSON.parse(res.end.mock.calls[0][0])).toEqual({ + error: 'Proxy error', + message: 'upstream connect failed', + }); + const after = healthResponse().metrics_summary; + expect(after.total_errors).toBe(before.total_errors + 1); + expect(after.total_requests).toBe(before.total_requests + 1); + expect(after.active_requests).toBe(before.active_requests); + const errorLog = getRequestErrorLog(stdoutWriteSpy); + expect(errorLog).toMatchObject({ + event: 'request_error', + provider: 'openai', + method: 'POST', + path: '/v1/chat/completions', + error: 'upstream connect failed', + upstream_host: 'api.openai.com', + }); + }); +}); + describe('proxyRequest effective token guard', () => { function makeReq(headers = {}) { const req = new EventEmitter(); diff --git a/tests/integration/one-shot-tokens.test.ts b/tests/integration/one-shot-tokens.test.ts index 120b538a7..f9ae7b7a6 100644 --- a/tests/integration/one-shot-tokens.test.ts +++ b/tests/integration/one-shot-tokens.test.ts @@ -65,7 +65,7 @@ describe('One-Shot Token Protection', () => { { allowDomains: ['localhost'], logLevel: 'debug', - timeout: 240000, + timeout: 480000, buildLocal: true, // Build container locally to include one-shot-token.so env: { GITHUB_TOKEN: 'ghp_test_token_12345', @@ -80,7 +80,7 @@ describe('One-Shot Token Protection', () => { expect(result.stdout).toContain('Second read: [ghp_test_token_12345]'); // Note: printenv reads from environ array directly, not via getenv(). // The LD_PRELOAD library only intercepts getenv() calls, so no debug output appears here. - }, 240000); + }, 480000); test('should cache COPILOT_GITHUB_TOKEN and clear from environment', async () => { const testScript = ` diff --git a/tests/integration/protocol-support.test.ts b/tests/integration/protocol-support.test.ts index 779146511..0f3cfa682 100644 --- a/tests/integration/protocol-support.test.ts +++ b/tests/integration/protocol-support.test.ts @@ -29,7 +29,7 @@ describe('Protocol Support', () => { describe('HTTPS Connections', () => { test('should allow HTTPS to allowed domain', async () => { const result = await runner.runWithSudo( - 'curl -fsS https://api.github.com/zen', + 'curl -fsS https://github.com', { allowDomains: ['github.com'], logLevel: 'debug', @@ -55,7 +55,7 @@ describe('Protocol Support', () => { test('should handle HTTPS with verbose output', async () => { const result = await runner.runWithSudo( - 'curl -v https://api.github.com/zen 2>&1 | grep -E "SSL|TLS" | head -5 || true', + 'curl -v https://github.com 2>&1 | grep -E "SSL|TLS" | head -5 || true', { allowDomains: ['github.com'], logLevel: 'debug', @@ -117,7 +117,7 @@ describe('Protocol Support', () => { describe('Connection Headers', () => { test('should pass custom headers', async () => { const result = await runner.runWithSudo( - 'curl -fsS -H "Accept: application/vnd.github+json" https://api.github.com/zen', + 'curl -fsS -H "Accept: text/html" https://github.com', { allowDomains: ['github.com'], logLevel: 'debug', @@ -130,7 +130,7 @@ describe('Protocol Support', () => { test('should pass User-Agent header', async () => { const result = await runner.runWithSudo( - 'curl -fsS -A "Test-Agent/1.0" https://api.github.com/zen', + 'curl -fsS -A "Test-Agent/1.0" https://github.com', { allowDomains: ['github.com'], logLevel: 'debug', @@ -145,7 +145,7 @@ describe('Protocol Support', () => { describe('IPv4/IPv6', () => { test('should support IPv4 connections', async () => { const result = await runner.runWithSudo( - 'curl -fsS -4 https://api.github.com/zen', + 'curl -fsS -4 https://github.com', { allowDomains: ['github.com'], logLevel: 'debug', @@ -159,7 +159,7 @@ describe('Protocol Support', () => { test('should handle IPv6 (may not be available)', async () => { // IPv6 may not be available in all environments const result = await runner.runWithSudo( - 'curl -fsS -6 https://api.github.com/zen || exit 0', + 'curl -fsS -6 https://github.com || exit 0', { allowDomains: ['github.com'], logLevel: 'debug', @@ -175,7 +175,7 @@ describe('Protocol Support', () => { describe('Connection Timeouts', () => { test('should respect curl max-time option', async () => { const result = await runner.runWithSudo( - 'curl -f --max-time 5 https://api.github.com/zen', + 'curl -f --max-time 5 https://github.com', { allowDomains: ['github.com'], logLevel: 'debug',