diff --git a/containers/api-proxy/server.js b/containers/api-proxy/server.js index 76098cd87..02fd98afe 100644 --- a/containers/api-proxy/server.js +++ b/containers/api-proxy/server.js @@ -19,12 +19,14 @@ const { generateRequestId, sanitizeForLog, logRequest } = require('./logging'); const metrics = require('./metrics'); const rateLimiter = require('./rate-limiter'); let trackTokenUsage; +let trackWebSocketTokenUsage; let closeLogStream; try { - ({ trackTokenUsage, closeLogStream } = require('./token-tracker')); + ({ trackTokenUsage, trackWebSocketTokenUsage, closeLogStream } = require('./token-tracker')); } catch (err) { if (err && err.code === 'MODULE_NOT_FOUND') { trackTokenUsage = () => {}; + trackWebSocketTokenUsage = () => {}; closeLogStream = () => {}; } else { throw err; @@ -672,6 +674,15 @@ function proxyWebSocket(req, socket, head, targetHost, injectHeaders, provider, tlsSocket.pipe(socket); socket.pipe(tlsSocket); + // Attach WebSocket token usage tracking (non-blocking, sniffs upstream frames) + trackWebSocketTokenUsage(tlsSocket, { + requestId, + provider, + path: sanitizeForLog(req.url), + startTime, + metrics, + }); + // Finalize once when either side closes; destroy the other side. socket.once('close', () => { finalize(false); tlsSocket.destroy(); }); tlsSocket.once('close', () => { finalize(false); socket.destroy(); }); diff --git a/containers/api-proxy/token-tracker.js b/containers/api-proxy/token-tracker.js index 46d73c8e3..0d71d0825 100644 --- a/containers/api-proxy/token-tracker.js +++ b/containers/api-proxy/token-tracker.js @@ -22,6 +22,7 @@ const fs = require('fs'); const path = require('path'); +const zlib = require('zlib'); const { logRequest } = require('./logging'); // Max response body to buffer for non-streaming usage extraction (5 MB). 
// Token usage log file path (inside the mounted log volume)
const TOKEN_LOG_DIR = process.env.AWF_TOKEN_LOG_DIR || '/var/log/api-proxy';
const TOKEN_LOG_FILE = path.join(TOKEN_LOG_DIR, 'token-usage.jsonl');
// Diagnostics log; only written when AWF_DEBUG_TOKENS=1.
const DIAG_LOG_FILE = path.join(TOKEN_LOG_DIR, 'token-diag.log');
const DIAG_ENABLED = process.env.AWF_DEBUG_TOKENS === '1';

let logStream = null;
let diagStream = null;

/**
 * Append one diagnostic line to the diagnostics log file.
 *
 * No-op unless the AWF_DEBUG_TOKENS=1 environment variable is set. The log
 * directory and the append stream are created lazily on first use. The whole
 * function is best-effort: any synchronous failure is swallowed, and a stream
 * 'error' simply drops the stream so the next call re-creates it.
 *
 * Sanitization is a deny-list, not an allow-list: when `data` is an object its
 * `raw_sample` property is dropped and every other property is written as-is.
 * Callers must therefore not pass raw network content under any other key.
 *
 * @param {string} msg - Event label written after the ISO timestamp
 * @param {*} [data] - Optional details; serialized with JSON.stringify when truthy
 * @returns {void}
 */
function diag(msg, data) {
  if (!DIAG_ENABLED) return;
  try {
    if (!diagStream) {
      fs.mkdirSync(TOKEN_LOG_DIR, { recursive: true });
      diagStream = fs.createWriteStream(DIAG_LOG_FILE, { flags: 'a' });
      diagStream.on('error', () => { diagStream = null; });
    }
    let safeData = data;
    if (data && typeof data === 'object') {
      // Strip the one field known to carry raw response bytes.
      const { raw_sample, ...rest } = data;
      safeData = rest;
    }
    const suffix = safeData ? ` ${JSON.stringify(safeData)}` : '';
    diagStream.write(`${new Date().toISOString()} ${msg}${suffix}\n`);
  } catch { /* best-effort */ }
}

/**
 * Check whether a response declares a content coding this module knows about:
 * gzip (including its legacy alias x-gzip), deflate, or br (Brotli).
 *
 * The header value is trimmed and compared case-insensitively. A missing
 * header, an empty value, `identity`, or a list of several codings all yield
 * false.
 *
 * @param {object} headers - Response headers keyed by lower-case name
 * @returns {boolean} true when content-encoding is exactly one known coding
 */
function isCompressedResponse(headers) {
  const encoding = String(headers['content-encoding'] || '').trim().toLowerCase();
  return ['gzip', 'x-gzip', 'deflate', 'br'].includes(encoding);
}
/**
 * Create a decompression transform stream matching a response's
 * content-encoding header.
 *
 * The header value is trimmed and compared case-insensitively; `x-gzip` is
 * treated as an alias of `gzip`. A non-string header value is stringified
 * instead of throwing.
 *
 * @param {object} headers - Response headers keyed by lower-case name
 * @returns {import('stream').Transform|null} A new zlib stream (Gunzip,
 *   Inflate or BrotliDecompress), or null when the header is absent or names
 *   an encoding that is not handled here
 */
function createDecompressor(headers) {
  const encoding = String(headers['content-encoding'] || '').trim().toLowerCase();
  switch (encoding) {
    case 'gzip':
    case 'x-gzip':
      return zlib.createGunzip();
    case 'deflate':
      return zlib.createInflate();
    case 'br':
      return zlib.createBrotliDecompress();
    default:
      return null;
  }
}
// Decoded bytes seen so far for a non-streaming body. This is tracked
// separately from totalBytes, which counts the raw (possibly compressed)
// bytes received from upstream.
let bufferedBytes = 0;

/**
 * Consume one piece of decoded response text (decompressor output when the
 * response is compressed, raw chunk text otherwise).
 *
 * Streaming responses: complete lines are handed to parseSseDataLines and each
 * resulting data line to extractUsageFromSseLine; any usage fields found are
 * merged into streamingUsage and the first model seen is kept. Text after the
 * last newline is carried over in partialLine for the next call.
 *
 * Non-streaming responses: the text is buffered in `chunks` until more than
 * MAX_BUFFER_SIZE decoded bytes have been seen. At that point `overflow` is set
 * and the buffer is released, so a huge (or highly compressible) body cannot
 * grow memory without bound.
 * NOTE(review): `overflow` and `chunks` are declared outside the visible part
 * of the enclosing function; this assumes the finalizer skips JSON extraction
 * when `overflow` is set — confirm.
 *
 * @param {string} text - Decoded UTF-8 text
 */
function handleDecodedChunk(text) {
  if (streaming) {
    const combined = partialLine + text;
    const lastNewline = combined.lastIndexOf('\n');
    if (lastNewline < 0) {
      // No complete line yet; keep accumulating.
      partialLine = combined;
      return;
    }

    partialLine = combined.slice(lastNewline + 1);
    const dataLines = parseSseDataLines(combined.slice(0, lastNewline));
    for (const line of dataLines) {
      const { usage, model } = extractUsageFromSseLine(line);
      if (model && !streamingModel) streamingModel = model;
      if (usage) {
        // Later events overwrite earlier values for the same key.
        for (const [k, v] of Object.entries(usage)) {
          streamingUsage[k] = v;
        }
      }
    }
    return;
  }

  if (overflow) return;

  const piece = Buffer.from(text, 'utf8');
  bufferedBytes += piece.length;
  if (bufferedBytes > MAX_BUFFER_SIZE) {
    overflow = true;
    chunks.length = 0; // free memory
    return;
  }
  chunks.push(piece);
}
chunk.length; + try { decompressor.write(chunk); } catch { /* ignore write errors */ } + }); + + proxyRes.on('end', () => { + try { decompressor.end(); } catch { /* ignore */ } + }); + + // Finalize on decompressor end + decompressor.on('end', () => { + finalizeTracking(); + }); + } else { + // No compression — parse raw chunks directly + proxyRes.on('data', (chunk) => { + totalBytes += chunk.length; + handleDecodedChunk(chunk.toString('utf8')); + }); - proxyRes.on('end', () => { + proxyRes.on('end', () => { + finalizeTracking(); + }); + } + + function finalizeTracking() { // Only process successful responses (2xx) - if (proxyRes.statusCode < 200 || proxyRes.statusCode >= 300) return; + if (proxyRes.statusCode < 200 || proxyRes.statusCode >= 300) { + logRequest('debug', 'token_track_skip_status', { + request_id: requestId, + provider, + status: proxyRes.statusCode, + }); + diag('HTTP_TRACK_SKIP_STATUS', { request_id: requestId, provider, status: proxyRes.statusCode }); + return; + } const duration = Date.now() - startTime; let usage = null; @@ -334,6 +445,19 @@ function trackTokenUsage(proxyRes, opts) { model = result.model; } + logRequest('debug', 'token_track_end', { + request_id: requestId, + provider, + streaming, + total_bytes: totalBytes, + overflow, + has_usage: !!usage, + usage_keys: usage ? Object.keys(usage) : [], + model, + compressed, + }); + diag('HTTP_TRACK_END', { request_id: requestId, provider, streaming, total_bytes: totalBytes, overflow, has_usage: !!usage, usage_keys: usage ? Object.keys(usage) : [], model, compressed, content_encoding: contentEncoding }); + const normalized = normalizeUsage(usage); if (!normalized) return; @@ -374,7 +498,224 @@ function trackTokenUsage(proxyRes, opts) { cache_write_tokens: normalized.cache_write_tokens, streaming, }); + } +} + +/** + * Parse WebSocket frames from a buffer (server→client direction, unmasked). 
/**
 * Parse the complete WebSocket frames at the start of a buffer.
 *
 * Frame headers are decoded for every frame (7-bit, 16-bit and 64-bit payload
 * lengths, with or without a masking key), but a payload is only returned for
 * a single-frame text message: opcode 1 with FIN set and RSV1 clear. Binary,
 * ping, pong, close, continuation and fragmented text frames are consumed
 * without being returned. Frames with RSV1 set are also skipped: that bit
 * marks an extension-transformed payload (permessage-deflate compression), so
 * the bytes are not UTF-8 text.
 *
 * Parsing stops at the first frame that is not fully contained in `buf`; the
 * caller should keep `buf.subarray(consumed)` and retry once more data arrives.
 *
 * @param {Buffer} buf - Buffer positioned at a frame boundary
 * @returns {{ messages: string[], consumed: number }} Decoded text payloads in
 *   order, and the number of bytes covered by complete frames
 */
function parseWebSocketFrames(buf) {
  const messages = [];
  let pos = 0;

  while (pos + 2 <= buf.length) {
    const firstByte = buf[pos];
    const secondByte = buf[pos + 1];
    const fin = (firstByte & 0x80) !== 0;
    const rsv1 = (firstByte & 0x40) !== 0;
    const opcode = firstByte & 0x0f;
    const masked = (secondByte & 0x80) !== 0;
    let payloadLength = secondByte & 0x7f;
    let headerSize = 2;

    // 126 / 127 are escape values: the real length follows in 2 / 8 bytes.
    if (payloadLength === 126) {
      if (pos + 4 > buf.length) break;
      payloadLength = buf.readUInt16BE(pos + 2);
      headerSize = 4;
    } else if (payloadLength === 127) {
      if (pos + 10 > buf.length) break;
      payloadLength = Number(buf.readBigUInt64BE(pos + 2));
      headerSize = 10;
    }

    if (masked) {
      if (pos + headerSize + 4 > buf.length) break;
      headerSize += 4; // masking key length
    }

    const payloadStart = pos + headerSize;
    const frameEnd = payloadStart + payloadLength;
    if (frameEnd > buf.length) break; // incomplete frame: wait for more data

    if (opcode === 1 && fin && !rsv1) {
      const payload = buf.subarray(payloadStart, frameEnd);
      if (masked) {
        // The 4-byte masking key sits immediately before the payload.
        const maskingKey = buf.subarray(payloadStart - 4, payloadStart);
        const unmasked = Buffer.allocUnsafe(payloadLength);
        for (let i = 0; i < payloadLength; i++) {
          unmasked[i] = payload[i] ^ maskingKey[i % 4];
        }
        messages.push(unmasked.toString('utf8'));
      } else {
        messages.push(payload.toString('utf8'));
      }
    }

    pos = frameEnd;
  }

  return { messages, consumed: pos };
}
/**
 * Attach token usage tracking to the upstream side of a proxied WebSocket.
 *
 * Only listeners are added to `upstreamSocket`; nothing is written to it. The
 * byte stream is expected to start with the HTTP upgrade response header
 * (terminated by a blank line), followed by WebSocket frames. Each text
 * message returned by parseWebSocketFrames is passed to
 * extractUsageFromSseLine, and any usage fields found are merged (later values
 * win) until the socket ends or closes. At that point, if any usage was seen
 * and normalizeUsage accepts it, the input/output token metrics are
 * incremented, a record is written with writeTokenUsage, and a `token_usage`
 * line is logged.
 *
 * If more than 1 MB accumulates without forming a complete header or frame,
 * parsing is switched off for the rest of the connection. Once buffered bytes
 * are discarded the position of the next frame boundary is unknown, so
 * continuing would decode arbitrary payload bytes as frame headers. Usage
 * collected before that point is still reported.
 *
 * @param {import('tls').TLSSocket} upstreamSocket - Upstream TLS socket
 * @param {object} opts
 * @param {string} opts.requestId - Request ID for correlation
 * @param {string} opts.provider - Provider name (anthropic, copilot, etc.)
 * @param {string} opts.path - Request path
 * @param {number} opts.startTime - Request start time (Date.now())
 * @param {object} opts.metrics - Metrics module reference
 */
function trackWebSocketTokenUsage(upstreamSocket, opts) {
  const { requestId, provider, path: reqPath, startTime, metrics: metricsRef } = opts;

  logRequest('debug', 'ws_token_track_start', {
    request_id: requestId,
    provider,
    path: reqPath,
  });
  diag('WS_TRACK_START', { request_id: requestId, provider, path: reqPath });

  // Max buffer to prevent unbounded memory growth (1 MB)
  const MAX_WS_BUFFER = 1 * 1024 * 1024;

  const streamingUsage = {};
  let streamingModel = null;
  let buffer = Buffer.alloc(0);
  let httpHeaderParsed = false;
  let trackingDisabled = false;
  let finalized = false;
  let totalBytes = 0;
  let headerBytes = 0;
  // NOTE: parseWebSocketFrames only reports text messages, so frameCount and
  // textMessageCount currently always hold the same value.
  let frameCount = 0;
  let textMessageCount = 0;

  upstreamSocket.on('data', (chunk) => {
    totalBytes += chunk.length;
    if (trackingDisabled) return;

    buffer = Buffer.concat([buffer, chunk]);

    // Safety: give up on this connection if the buffer grows too large
    // (oversized or malformed frame, or a header that never terminates).
    if (buffer.length > MAX_WS_BUFFER) {
      trackingDisabled = true;
      buffer = Buffer.alloc(0);
      logRequest('debug', 'ws_token_track_overflow', {
        request_id: requestId,
        provider,
        total_bytes: totalBytes,
      });
      return;
    }

    // Skip the HTTP 101 Switching Protocols response header
    if (!httpHeaderParsed) {
      const headerEnd = buffer.indexOf('\r\n\r\n');
      if (headerEnd === -1) return; // need more data for full header
      headerBytes = headerEnd + 4;
      buffer = buffer.subarray(headerBytes);
      httpHeaderParsed = true;
    }

    // Parse any complete WebSocket frames; keep the unparsed tail.
    const { messages, consumed } = parseWebSocketFrames(buffer);
    if (consumed > 0) {
      buffer = buffer.subarray(consumed);
    }
    frameCount += messages.length;

    for (const text of messages) {
      textMessageCount++;
      const { usage, model } = extractUsageFromSseLine(text);
      if (model && !streamingModel) streamingModel = model;
      if (usage) {
        logRequest('debug', 'ws_token_usage_found', {
          request_id: requestId,
          provider,
          usage_keys: Object.keys(usage),
          model,
        });
        for (const [k, v] of Object.entries(usage)) {
          streamingUsage[k] = v;
        }
      }
    }
  });

  // Runs at most once, whichever of 'close' / 'end' fires first.
  function doFinalize() {
    if (finalized) return;
    finalized = true;

    const usageKeys = Object.keys(streamingUsage);
    const hasUsage = usageKeys.length > 0;

    logRequest('debug', 'ws_token_track_end', {
      request_id: requestId,
      provider,
      total_bytes: totalBytes,
      frame_count: frameCount,
      text_message_count: textMessageCount,
      has_usage: hasUsage,
      usage_keys: usageKeys,
      model: streamingModel,
    });
    diag('WS_TRACK_END', { request_id: requestId, provider, total_bytes: totalBytes, frame_count: frameCount, text_message_count: textMessageCount, has_usage: hasUsage, usage_keys: usageKeys, model: streamingModel });

    if (!hasUsage) return;

    const duration = Date.now() - startTime;
    const normalized = normalizeUsage(streamingUsage);
    if (!normalized) return;

    if (metricsRef) {
      metricsRef.increment('input_tokens_total', { provider }, normalized.input_tokens);
      metricsRef.increment('output_tokens_total', { provider }, normalized.output_tokens);
    }

    const record = {
      timestamp: new Date().toISOString(),
      request_id: requestId,
      provider,
      model: streamingModel || 'unknown',
      path: reqPath,
      status: 101,
      streaming: true,
      input_tokens: normalized.input_tokens,
      output_tokens: normalized.output_tokens,
      cache_read_tokens: normalized.cache_read_tokens,
      cache_write_tokens: normalized.cache_write_tokens,
      duration_ms: duration,
      response_bytes: totalBytes - headerBytes,
    };

    writeTokenUsage(record);

    logRequest('info', 'token_usage', {
      request_id: requestId,
      provider,
      model: streamingModel || 'unknown',
      input_tokens: normalized.input_tokens,
      output_tokens: normalized.output_tokens,
      cache_read_tokens: normalized.cache_read_tokens,
      cache_write_tokens: normalized.cache_write_tokens,
      streaming: true,
      transport: 'websocket',
    });
  }

  upstreamSocket.on('close', doFinalize);
  upstreamSocket.on('end', doFinalize);
}

/**
 * Close the token usage log stream and, if it was opened, the diagnostics
 * log stream.
 *
 * @returns {Promise<void>} Resolves once every open stream's end() callback
 *   has run; resolves immediately when no stream is open
 */
function closeLogStream() {
  const closers = [];
  if (logStream) {
    closers.push(new Promise((resolve) => {
      logStream.end(() => { logStream = null; resolve(); });
    }));
  }
  if (diagStream) {
    closers.push(new Promise((resolve) => {
      diagStream.end(() => { diagStream = null; resolve(); });
    }));
  }
  return Promise.all(closers).then(() => undefined);
}
isStreamingResponse, + isCompressedResponse, trackTokenUsage, + trackWebSocketTokenUsage, } = require('./token-tracker'); const { EventEmitter } = require('events'); const os = require('os'); const path = require('path'); const fs = require('fs'); +const zlib = require('zlib'); // Redirect token log output to a temp dir to avoid /var/log permission errors let tmpLogDir; @@ -448,3 +452,512 @@ describe('trackTokenUsage', () => { }, 10); }); }); + +// ── isCompressedResponse ────────────────────────────────────────────── + +describe('isCompressedResponse', () => { + test('detects gzip encoding', () => { + expect(isCompressedResponse({ 'content-encoding': 'gzip' })).toBe(true); + }); + + test('detects deflate encoding', () => { + expect(isCompressedResponse({ 'content-encoding': 'deflate' })).toBe(true); + }); + + test('detects br (brotli) encoding', () => { + expect(isCompressedResponse({ 'content-encoding': 'br' })).toBe(true); + }); + + test('returns false for no encoding', () => { + expect(isCompressedResponse({})).toBe(false); + expect(isCompressedResponse({ 'content-encoding': '' })).toBe(false); + expect(isCompressedResponse({ 'content-encoding': 'identity' })).toBe(false); + }); +}); + +// ── trackTokenUsage with compressed responses ───────────────────────── + +describe('trackTokenUsage (compressed responses)', () => { + test('decompresses gzip SSE streaming response and extracts usage', (done) => { + const proxyRes = new EventEmitter(); + proxyRes.headers = { + 'content-type': 'text/event-stream; charset=utf-8', + 'content-encoding': 'gzip', + }; + proxyRes.statusCode = 200; + + const metricsRef = { increment: jest.fn() }; + + trackTokenUsage(proxyRes, { + requestId: 'test-gzip-sse', + provider: 'anthropic', + path: '/v1/messages?beta=true', + startTime: Date.now(), + metrics: metricsRef, + }); + + // Build Anthropic SSE data (plaintext) + const sseText = + 'event: message_start\ndata: ' + JSON.stringify({ + type: 'message_start', + message: { model: 
'claude-sonnet-4-20250514', usage: { input_tokens: 1000, cache_read_input_tokens: 800 } }, + }) + '\n\n' + + 'event: content_block_delta\ndata: ' + JSON.stringify({ + type: 'content_block_delta', + delta: { type: 'text_delta', text: 'Hello' }, + }) + '\n\n' + + 'event: message_delta\ndata: ' + JSON.stringify({ + type: 'message_delta', + usage: { output_tokens: 42 }, + }) + '\n\ndata: [DONE]\n\n'; + + // Compress the SSE data with gzip + zlib.gzip(Buffer.from(sseText), (err, compressed) => { + expect(err).toBeNull(); + + // Emit compressed data (simulating Anthropic API response) + proxyRes.emit('data', compressed); + proxyRes.emit('end'); + + // Allow time for decompression pipeline + setTimeout(() => { + expect(metricsRef.increment).toHaveBeenCalledWith( + 'input_tokens_total', + { provider: 'anthropic' }, + 1000, + ); + expect(metricsRef.increment).toHaveBeenCalledWith( + 'output_tokens_total', + { provider: 'anthropic' }, + 42, + ); + done(); + }, 50); + }); + }); + + test('decompresses gzip non-streaming JSON and extracts usage', (done) => { + const proxyRes = new EventEmitter(); + proxyRes.headers = { + 'content-type': 'application/json', + 'content-encoding': 'gzip', + }; + proxyRes.statusCode = 200; + + const metricsRef = { increment: jest.fn() }; + + trackTokenUsage(proxyRes, { + requestId: 'test-gzip-json', + provider: 'anthropic', + path: '/v1/messages', + startTime: Date.now(), + metrics: metricsRef, + }); + + const body = JSON.stringify({ + model: 'claude-sonnet-4-20250514', + usage: { input_tokens: 200, output_tokens: 30 }, + }); + + zlib.gzip(Buffer.from(body), (err, compressed) => { + expect(err).toBeNull(); + proxyRes.emit('data', compressed); + proxyRes.emit('end'); + + setTimeout(() => { + expect(metricsRef.increment).toHaveBeenCalledWith( + 'input_tokens_total', + { provider: 'anthropic' }, + 200, + ); + expect(metricsRef.increment).toHaveBeenCalledWith( + 'output_tokens_total', + { provider: 'anthropic' }, + 30, + ); + done(); + }, 50); + }); + 
}); + + test('handles multi-chunk gzip SSE response', (done) => { + const proxyRes = new EventEmitter(); + proxyRes.headers = { + 'content-type': 'text/event-stream; charset=utf-8', + 'content-encoding': 'gzip', + }; + proxyRes.statusCode = 200; + + const metricsRef = { increment: jest.fn() }; + + trackTokenUsage(proxyRes, { + requestId: 'test-gzip-multi', + provider: 'anthropic', + path: '/v1/messages?beta=true', + startTime: Date.now(), + metrics: metricsRef, + }); + + const sseText = + 'event: message_start\ndata: ' + JSON.stringify({ + type: 'message_start', + message: { model: 'claude-sonnet-4-20250514', usage: { input_tokens: 5000 } }, + }) + '\n\n' + + 'event: message_delta\ndata: ' + JSON.stringify({ + type: 'message_delta', + usage: { output_tokens: 100 }, + }) + '\n\n'; + + zlib.gzip(Buffer.from(sseText), (err, compressed) => { + expect(err).toBeNull(); + + // Split compressed data into multiple chunks to simulate network delivery + const mid = Math.floor(compressed.length / 2); + proxyRes.emit('data', compressed.slice(0, mid)); + proxyRes.emit('data', compressed.slice(mid)); + proxyRes.emit('end'); + + setTimeout(() => { + expect(metricsRef.increment).toHaveBeenCalledWith( + 'input_tokens_total', + { provider: 'anthropic' }, + 5000, + ); + expect(metricsRef.increment).toHaveBeenCalledWith( + 'output_tokens_total', + { provider: 'anthropic' }, + 100, + ); + done(); + }, 50); + }); + }); + + test('still works with uncompressed SSE (no content-encoding)', (done) => { + // Verify existing uncompressed path still works + const proxyRes = new EventEmitter(); + proxyRes.headers = { 'content-type': 'text/event-stream' }; + proxyRes.statusCode = 200; + + const metricsRef = { increment: jest.fn() }; + + trackTokenUsage(proxyRes, { + requestId: 'test-uncompressed', + provider: 'anthropic', + path: '/v1/messages', + startTime: Date.now(), + metrics: metricsRef, + }); + + const chunk = 'event: message_start\ndata: ' + JSON.stringify({ + type: 'message_start', + 
/**
 * Test helper: wrap `text` in a single unmasked WebSocket text frame
 * (FIN set, opcode 1), choosing the 7-bit, 16-bit or 64-bit length form
 * according to the UTF-8 byte length of the payload.
 *
 * @param {string} text - Payload text
 * @returns {Buffer} Frame header followed by the payload bytes
 */
function buildTextFrame(text) {
  const FIN_TEXT = 0x81; // FIN + text opcode
  const body = Buffer.from(text, 'utf8');
  const size = body.length;

  let head;
  if (size >= 65536) {
    head = Buffer.alloc(10);
    head.writeUInt8(FIN_TEXT, 0);
    head.writeUInt8(127, 1);
    head.writeBigUInt64BE(BigInt(size), 2);
  } else if (size >= 126) {
    head = Buffer.alloc(4);
    head.writeUInt8(FIN_TEXT, 0);
    head.writeUInt8(126, 1);
    head.writeUInt16BE(size, 2);
  } else {
    head = Buffer.from([FIN_TEXT, size]);
  }

  return Buffer.concat([head, body]);
}
expect(consumed).toBe(buf.length); + }); + + test('handles partial frame (not enough data)', () => { + const frame = buildTextFrame('{"type":"test"}'); + // Give only half the frame + const partial = frame.slice(0, Math.floor(frame.length / 2)); + const { messages, consumed } = parseWebSocketFrames(partial); + expect(messages).toHaveLength(0); + expect(consumed).toBe(0); + }); + + test('handles medium payload (126-byte extended length)', () => { + const text = 'x'.repeat(200); + const frame = buildTextFrame(text); + // Verify 4-byte header was used (126 extended) + expect(frame[1] & 0x7F).toBe(126); + const { messages, consumed } = parseWebSocketFrames(frame); + expect(messages).toEqual([text]); + expect(consumed).toBe(frame.length); + }); + + test('skips binary frames (opcode 2)', () => { + const payload = Buffer.from([1, 2, 3, 4]); + const header = Buffer.alloc(2); + header[0] = 0x82; // FIN + binary opcode + header[1] = payload.length; + const binaryFrame = Buffer.concat([header, payload]); + + const textFrame = buildTextFrame('{"type":"text"}'); + const buf = Buffer.concat([binaryFrame, textFrame]); + + const { messages, consumed } = parseWebSocketFrames(buf); + expect(messages).toEqual(['{"type":"text"}']); + expect(consumed).toBe(buf.length); + }); + + test('skips ping frames (opcode 9)', () => { + const header = Buffer.alloc(2); + header[0] = 0x89; // FIN + ping opcode + header[1] = 0; // empty payload + const pingFrame = header; + + const textFrame = buildTextFrame('{"type":"data"}'); + const buf = Buffer.concat([pingFrame, textFrame]); + + const { messages, consumed } = parseWebSocketFrames(buf); + expect(messages).toEqual(['{"type":"data"}']); + expect(consumed).toBe(buf.length); + }); + + test('handles empty buffer', () => { + const { messages, consumed } = parseWebSocketFrames(Buffer.alloc(0)); + expect(messages).toHaveLength(0); + expect(consumed).toBe(0); + }); + + test('handles buffer with only 1 byte', () => { + const { messages, consumed } = 
parseWebSocketFrames(Buffer.alloc(1)); + expect(messages).toHaveLength(0); + expect(consumed).toBe(0); + }); + + test('unmasks masked text frames correctly', () => { + const text = '{"type":"message_start"}'; + const payload = Buffer.from(text, 'utf8'); + const maskingKey = Buffer.from([0x37, 0xfa, 0x21, 0x3d]); + + // Build masked frame: FIN + text opcode, masked bit + length, key, masked payload + const header = Buffer.alloc(2 + 4); + header[0] = 0x81; // FIN + text + header[1] = 0x80 | payload.length; // masked bit set + length + maskingKey.copy(header, 2); + + const maskedPayload = Buffer.allocUnsafe(payload.length); + for (let i = 0; i < payload.length; i++) { + maskedPayload[i] = payload[i] ^ maskingKey[i % 4]; + } + + const frame = Buffer.concat([header, maskedPayload]); + const { messages, consumed } = parseWebSocketFrames(frame); + expect(messages).toEqual([text]); + expect(consumed).toBe(frame.length); + }); +}); + +// ── trackWebSocketTokenUsage ────────────────────────────────────────── + +describe('trackWebSocketTokenUsage', () => { + test('extracts Anthropic token usage from WebSocket frames', (done) => { + const socket = new EventEmitter(); + const metricsRef = { increment: jest.fn() }; + + trackWebSocketTokenUsage(socket, { + requestId: 'ws-test-1', + provider: 'anthropic', + path: '/v1/messages', + startTime: Date.now(), + metrics: metricsRef, + }); + + // Send HTTP 101 response header + socket.emit('data', Buffer.from( + 'HTTP/1.1 101 Switching Protocols\r\n' + + 'Upgrade: websocket\r\n' + + 'Connection: Upgrade\r\n' + + '\r\n' + )); + + // Send message_start with input tokens + const msgStart = JSON.stringify({ + type: 'message_start', + message: { + model: 'claude-sonnet-4.6', + usage: { input_tokens: 1500, cache_creation_input_tokens: 0, cache_read_input_tokens: 200 }, + }, + }); + socket.emit('data', buildTextFrame(msgStart)); + + // Send message_delta with output tokens + const msgDelta = JSON.stringify({ + type: 'message_delta', + usage: { 
output_tokens: 350 }, + }); + socket.emit('data', buildTextFrame(msgDelta)); + + // Close socket + socket.emit('close'); + + setTimeout(() => { + expect(metricsRef.increment).toHaveBeenCalledWith( + 'input_tokens_total', { provider: 'anthropic' }, 1500 + ); + expect(metricsRef.increment).toHaveBeenCalledWith( + 'output_tokens_total', { provider: 'anthropic' }, 350 + ); + done(); + }, 10); + }); + + test('handles HTTP 101 header and frames in same chunk', (done) => { + const socket = new EventEmitter(); + const metricsRef = { increment: jest.fn() }; + + trackWebSocketTokenUsage(socket, { + requestId: 'ws-test-2', + provider: 'anthropic', + path: '/v1/messages', + startTime: Date.now(), + metrics: metricsRef, + }); + + // Send 101 header + frame in a single chunk + const header = 'HTTP/1.1 101 Switching Protocols\r\n\r\n'; + const frame = buildTextFrame(JSON.stringify({ + type: 'message_start', + message: { + model: 'claude-sonnet-4.6', + usage: { input_tokens: 500 }, + }, + })); + socket.emit('data', Buffer.concat([Buffer.from(header), frame])); + + const deltaFrame = buildTextFrame(JSON.stringify({ + type: 'message_delta', + usage: { output_tokens: 100 }, + })); + socket.emit('data', deltaFrame); + socket.emit('end'); + + setTimeout(() => { + expect(metricsRef.increment).toHaveBeenCalledWith( + 'input_tokens_total', { provider: 'anthropic' }, 500 + ); + expect(metricsRef.increment).toHaveBeenCalledWith( + 'output_tokens_total', { provider: 'anthropic' }, 100 + ); + done(); + }, 10); + }); + + test('does not log when no usage data is found', (done) => { + const socket = new EventEmitter(); + const metricsRef = { increment: jest.fn() }; + + trackWebSocketTokenUsage(socket, { + requestId: 'ws-test-3', + provider: 'anthropic', + path: '/v1/messages', + startTime: Date.now(), + metrics: metricsRef, + }); + + socket.emit('data', Buffer.from('HTTP/1.1 101 Switching Protocols\r\n\r\n')); + // Send a content_block_delta (no usage data) + socket.emit('data', 
buildTextFrame(JSON.stringify({ + type: 'content_block_delta', + delta: { type: 'text_delta', text: 'Hello' }, + }))); + socket.emit('close'); + + setTimeout(() => { + expect(metricsRef.increment).not.toHaveBeenCalled(); + done(); + }, 10); + }); + + test('only finalizes once (close + end)', (done) => { + const socket = new EventEmitter(); + const metricsRef = { increment: jest.fn() }; + + trackWebSocketTokenUsage(socket, { + requestId: 'ws-test-4', + provider: 'anthropic', + path: '/v1/messages', + startTime: Date.now(), + metrics: metricsRef, + }); + + socket.emit('data', Buffer.from('HTTP/1.1 101 Switching Protocols\r\n\r\n')); + socket.emit('data', buildTextFrame(JSON.stringify({ + type: 'message_start', + message: { model: 'claude-sonnet-4.6', usage: { input_tokens: 100 } }, + }))); + socket.emit('data', buildTextFrame(JSON.stringify({ + type: 'message_delta', + usage: { output_tokens: 50 }, + }))); + + // Both close and end fire + socket.emit('close'); + socket.emit('end'); + + setTimeout(() => { + // Should only be called once despite both events + expect(metricsRef.increment).toHaveBeenCalledTimes(2); + done(); + }, 10); + }); +});