diff --git a/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts b/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts new file mode 100644 index 000000000..5429c6ce8 --- /dev/null +++ b/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts @@ -0,0 +1,180 @@ +import { describe, expect, test } from 'bun:test' +import type { Message } from 'src/types/message.js' +import { filterIncompleteToolCalls } from '../filterIncompleteToolCalls.js' + +describe('filterIncompleteToolCalls', () => { + test('drops assistant tool uses that do not have matching results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', id: 'missing', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { role: 'user', content: 'continue' }, + }, + ] as unknown as Message[] + + expect( + filterIncompleteToolCalls(messages).map(message => String(message.uuid)), + ).toEqual(['u1']) + }) + + test('preserves assistant text when dropping orphan tool uses', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [ + { type: 'text', text: 'I will read the file.' }, + { type: 'tool_use', id: 'missing', name: 'Read' }, + ], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered).toHaveLength(1) + const first = filtered[0]! + const content = first.message!.content + expect( + Array.isArray(content) ? 
content.map(block => block.type) : [], + ).toEqual(['text']) + }) + + test('keeps completed parallel tool calls when dropping an orphan', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [ + { type: 'tool_use', id: 'done', name: 'Read' }, + { type: 'tool_use', id: 'missing', name: 'Grep' }, + ], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', tool_use_id: 'done', content: 'ok' }], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered.map(message => String(message.uuid))).toEqual(['a1', 'u1']) + const first = filtered[0]! + const content = first.message!.content + expect( + Array.isArray(content) + ? content.map(block => + block.type === 'tool_use' ? block.id : block.type, + ) + : [], + ).toEqual(['done']) + }) + + test('keeps assistant tool uses that have matching results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', id: 'done', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', tool_use_id: 'done', content: 'ok' }], + }, + }, + ] as unknown as Message[] + + expect( + filterIncompleteToolCalls(messages).map(message => String(message.uuid)), + ).toEqual(['a1', 'u1']) + }) + + test('drops orphan tool results when their tool use was removed', () => { + const messages = [ + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'tool_result', tool_use_id: 'missing', content: 'late' }, + ], + }, + }, + ] as unknown as Message[] + + expect(filterIncompleteToolCalls(messages)).toEqual([]) + }) + + test('keeps user text while dropping orphan tool results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { role: 'assistant', content: 'done' }, + }, + { + type: 'user', + 
uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'text', text: 'keep this' }, + { type: 'tool_result', tool_use_id: 'missing', content: 'late' }, + ], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered.map(message => String(message.uuid))).toEqual(['a1', 'u1']) + const content = filtered[1]!.message!.content + expect(Array.isArray(content) ? content : []).toEqual([ + { type: 'text', text: 'keep this' }, + ]) + }) + + test('drops malformed tool blocks without ids', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', content: 'late' }], + }, + }, + ] as unknown as Message[] + + expect(filterIncompleteToolCalls(messages)).toEqual([]) + }) +}) diff --git a/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts b/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts new file mode 100644 index 000000000..7e30754ee --- /dev/null +++ b/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts @@ -0,0 +1,110 @@ +import type { + AssistantMessage, + Message, + UserMessage, +} from 'src/types/message.js' + +/** + * Removes invalid or orphaned tool_use/tool_result blocks while preserving + * completed tool-call pairs. This is intentionally block-level, not + * message-level, so completed parallel tool calls stay paired with results. 
+ */ +export function filterIncompleteToolCalls(messages: Message[]): Message[] { + const toolUseIdsWithResults = new Set() + + for (const message of messages) { + if (message?.type === 'user') { + const userMessage = message as UserMessage + const content = userMessage.message.content + if (Array.isArray(content)) { + for (const block of content) { + if (block.type === 'tool_result' && block.tool_use_id) { + toolUseIdsWithResults.add(block.tool_use_id) + } + } + } + } + } + + const retainedToolUseIds = new Set() + const withoutOrphanToolUses: Message[] = [] + + for (const message of messages) { + if (message?.type === 'assistant') { + const assistantMessage = message as AssistantMessage + const content = assistantMessage.message.content + if (Array.isArray(content)) { + let changed = false + const filteredContent = content.filter(block => { + if (block.type !== 'tool_use') return true + if (!block.id) { + changed = true + return false + } + if (toolUseIdsWithResults.has(block.id)) { + retainedToolUseIds.add(block.id) + return true + } + changed = true + return false + }) + + if (!changed) { + withoutOrphanToolUses.push(message) + continue + } + if (filteredContent.length > 0) { + withoutOrphanToolUses.push({ + ...assistantMessage, + message: { + ...assistantMessage.message, + content: filteredContent, + }, + }) + } + continue + } + } + withoutOrphanToolUses.push(message) + } + + const filteredMessages: Message[] = [] + for (const message of withoutOrphanToolUses) { + if (message?.type !== 'user') { + filteredMessages.push(message) + continue + } + const userMessage = message as UserMessage + const content = userMessage.message.content + if (!Array.isArray(content)) { + filteredMessages.push(message) + continue + } + let changed = false + const filteredContent = content.filter(block => { + if (block.type !== 'tool_result') return true + if (!block.tool_use_id) { + changed = true + return false + } + if (retainedToolUseIds.has(block.tool_use_id)) return true + 
changed = true + return false + }) + if (!changed) { + filteredMessages.push(message) + continue + } + if (filteredContent.length > 0) { + filteredMessages.push({ + ...userMessage, + message: { + ...userMessage.message, + content: filteredContent, + }, + }) + } + } + + return filteredMessages +} diff --git a/packages/builtin-tools/src/tools/AgentTool/runAgent.ts b/packages/builtin-tools/src/tools/AgentTool/runAgent.ts index baeed9022..de55b53f8 100644 --- a/packages/builtin-tools/src/tools/AgentTool/runAgent.ts +++ b/packages/builtin-tools/src/tools/AgentTool/runAgent.ts @@ -86,8 +86,11 @@ import { import type { ContentReplacementState } from 'src/utils/toolResultStorage.js' import { createAgentId } from 'src/utils/uuid.js' import { resolveAgentTools } from './agentToolUtils.js' +import { filterIncompleteToolCalls } from './filterIncompleteToolCalls.js' import { type AgentDefinition, isBuiltInAgent } from './loadAgentsDir.js' +export { filterIncompleteToolCalls } from './filterIncompleteToolCalls.js' + /** * Initialize agent-specific MCP servers * Agents can define their own MCP servers in their frontmatter that are additive @@ -886,50 +889,6 @@ export async function* runAgent({ } } -/** - * Filters out assistant messages with incomplete tool calls (tool uses without results). - * This prevents API errors when sending messages with orphaned tool calls. 
- */ -export function filterIncompleteToolCalls(messages: Message[]): Message[] { - // Build a set of tool use IDs that have results - const toolUseIdsWithResults = new Set() - - for (const message of messages) { - if (message?.type === 'user') { - const userMessage = message as UserMessage - const content = userMessage.message.content - if (Array.isArray(content)) { - for (const block of content) { - if (block.type === 'tool_result' && block.tool_use_id) { - toolUseIdsWithResults.add(block.tool_use_id) - } - } - } - } - } - - // Filter out assistant messages that contain tool calls without results - return messages.filter(message => { - if (message?.type === 'assistant') { - const assistantMessage = message as AssistantMessage - const content = assistantMessage.message.content - if (Array.isArray(content)) { - // Check if this assistant message has any tool uses without results - const hasIncompleteToolCall = content.some( - block => - block.type === 'tool_use' && - block.id && - !toolUseIdsWithResults.has(block.id), - ) - // Exclude messages with incomplete tool calls - return !hasIncompleteToolCall - } - } - // Keep all non-assistant messages and assistant messages without tool calls - return true - }) -} - async function getAgentSystemPrompt( agentDefinition: AgentDefinition, toolUseContext: Pick, diff --git a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts index e520243a5..f6c52ea80 100644 --- a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts +++ b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts @@ -84,22 +84,48 @@ Use this tool to discover messaging targets before sending cross-session message // UDS socket directory. The implementation scans for live sockets // and optionally includes Remote Control bridge peers. 
const peers: PeerInfo[] = [] + const seen = new Set() + const addPeer = (peer: PeerInfo): void => { + if (seen.has(peer.address)) return + seen.add(peer.address) + peers.push(peer) + } + + /* eslint-disable @typescript-eslint/no-require-imports */ + const udsMessaging = + require('src/utils/udsMessaging.js') as typeof import('src/utils/udsMessaging.js') + const udsClient = + require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') + const bridgePeers = + require('src/bridge/peerSessions.js') as typeof import('src/bridge/peerSessions.js') + /* eslint-enable @typescript-eslint/no-require-imports */ - // Discovery is handled by the UDS messaging subsystem initialized in setup.ts. - // Return discovered peers from the app state. - const appState = context.getAppState() - const messagingSocketPath = (appState as Record).messagingSocketPath as string | undefined + const messagingSocketPath = udsMessaging.getUdsMessagingSocketPath() if (messagingSocketPath) { // Self entry for reference if (_input.include_self) { - peers.push({ - address: `uds:${messagingSocketPath}`, + addPeer({ + address: udsMessaging.formatUdsAddress(messagingSocketPath), name: 'self', pid: process.pid, }) } } + for (const peer of await udsClient.listPeers()) { + if (!peer.messagingSocketPath) continue + addPeer({ + address: udsMessaging.formatUdsAddress(peer.messagingSocketPath), + name: peer.name ?? 
peer.kind, + cwd: peer.cwd, + pid: peer.pid, + }) + } + + for (const peer of await bridgePeers.listBridgePeers()) { + addPeer(peer) + } + return { data: { peers }, } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts index 4e9737051..cab0a03c5 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts @@ -130,6 +130,41 @@ export type SendMessageToolOutput = | RequestOutput | ResponseOutput +const UDS_INLINE_TOKEN_MARKER = '#token=' + +function stripInlineUdsToken(target: string): string { + const markerIndex = target.indexOf(UDS_INLINE_TOKEN_MARKER) + return markerIndex === -1 ? target : target.slice(0, markerIndex) +} + +function hasInlineUdsToken(to: string): boolean { + const addr = parseAddress(to) + // Empty-token markers are still inline-token attempts. Observable input + // redaction preserves "#token=" so cloned inputs remain rejected. 
+ return ( + addr.scheme === 'uds' && addr.target.includes(UDS_INLINE_TOKEN_MARKER) + ) +} + +function recipientForDisplay(to: string): string { + const addr = parseAddress(to) + if (addr.scheme !== 'uds') return to + return `uds:${stripInlineUdsToken(addr.target)}` +} + +function redactInlineUdsTokenForRejection(to: string): string { + const addr = parseAddress(to) + if (addr.scheme !== 'uds') return to + const markerIndex = addr.target.indexOf(UDS_INLINE_TOKEN_MARKER) + if (markerIndex === -1) return to + return `uds:${addr.target.slice(0, markerIndex)}${UDS_INLINE_TOKEN_MARKER}` +} + +function redactObservableInlineUdsToken(input: { to: string }): void { + if (!hasInlineUdsToken(input.to)) return + input.to = redactInlineUdsTokenForRejection(input.to) +} + function findTeammateColor( appState: { teamContext?: { teammates: { [id: string]: { color?: string } } } @@ -541,15 +576,17 @@ export const SendMessageTool: Tool = }, backfillObservableInput(input) { - if ('type' in input) return if (typeof input.to !== 'string') return + redactObservableInlineUdsToken(input as { to: string }) + if ('type' in input) return + if (input.to === '*') { input.type = 'broadcast' if (typeof input.message === 'string') input.content = input.message } else if (typeof input.message === 'string') { input.type = 'message' - input.recipient = input.to + input.recipient = recipientForDisplay(input.to) input.content = input.message } else if (typeof input.message === 'object' && input.message !== null) { const msg = input.message as { @@ -560,7 +597,7 @@ export const SendMessageTool: Tool = feedback?: string } input.type = msg.type - input.recipient = input.to + input.recipient = recipientForDisplay(input.to) if (msg.request_id !== undefined) input.request_id = msg.request_id if (msg.approve !== undefined) input.approve = msg.approve const content = msg.reason ?? 
msg.feedback @@ -569,16 +606,17 @@ export const SendMessageTool: Tool = }, toAutoClassifierInput(input) { + const recipient = recipientForDisplay(input.to) if (typeof input.message === 'string') { - return `to ${input.to}: ${input.message}` + return `to ${recipient}: ${input.message}` } switch (input.message.type) { case 'shutdown_request': - return `shutdown_request to ${input.to}` + return `shutdown_request to ${recipient}` case 'shutdown_response': return `shutdown_response ${input.message.approve ? 'approve' : 'reject'} ${input.message.request_id}` case 'plan_approval_response': - return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${input.to}` + return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${recipient}` } }, @@ -630,6 +668,17 @@ export const SendMessageTool: Tool = errorCode: 9, } } + if ( + addr.scheme === 'uds' && + hasInlineUdsToken(input.to) + ) { + return { + result: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + errorCode: 9, + } + } if (input.to.includes('@')) { return { result: false, @@ -753,6 +802,19 @@ export const SendMessageTool: Tool = }, async call(input, context, canUseTool, assistantMessage) { + if (typeof input.message === 'string') { + const addr = parseAddress(input.to) + if (addr.scheme === 'uds' && hasInlineUdsToken(input.to)) { + return { + data: { + success: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + }, + } + } + } + if (feature('UDS_INBOX') && typeof input.message === 'string') { const addr = parseAddress(input.to) if (addr.scheme === 'bridge') { @@ -772,10 +834,10 @@ export const SendMessageTool: Tool = const { postInterClaudeMessage } = require('src/bridge/peerSessions.js') as typeof import('src/bridge/peerSessions.js') /* eslint-enable @typescript-eslint/no-require-imports */ - const result = await postInterClaudeMessage( + const result = (await postInterClaudeMessage( 
addr.target, input.message, - ) as { ok: boolean; error?: string } + )) as { ok: boolean; error?: string } const preview = input.summary || truncate(input.message, 50) return { data: { @@ -787,6 +849,7 @@ export const SendMessageTool: Tool = } } if (addr.scheme === 'uds') { + const recipient = recipientForDisplay(input.to) /* eslint-disable @typescript-eslint/no-require-imports */ const { sendToUdsSocket } = require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') @@ -797,14 +860,14 @@ export const SendMessageTool: Tool = return { data: { success: true, - message: `”${preview}” → ${input.to}`, + message: `”${preview}” → ${recipient}`, }, } } catch (e) { return { data: { success: false, - message: `Failed to send to ${input.to}: ${errorMessage(e)}`, + message: `Failed to send to ${recipient}: ${errorMessage(e)}`, }, } } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts new file mode 100644 index 000000000..e0ce1a823 --- /dev/null +++ b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts @@ -0,0 +1,181 @@ +import { describe, expect, test } from 'bun:test' +import { SendMessageTool } from '../SendMessageTool.js' + +describe('SendMessageTool UDS recipient handling', () => { + test('redacts inline UDS tokens before classifier and observable paths', async () => { + const tokenAddress = 'uds:/tmp/peer.sock#token=secret-token' + + const observableInput = { + to: tokenAddress, + message: 'hello', + } as Record + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') + expect(JSON.stringify(observableInput)).not.toContain('secret-token') + expect( + SendMessageTool.toAutoClassifierInput({ + to: tokenAddress, + message: 'hello', + }), + 
).toBe('to uds:/tmp/peer.sock: hello') + }) + + test('keeps redacted UDS token rejection through observable backfill', async () => { + const observableInput = { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: { + type: 'plan_approval_response', + request_id: 'req-1', + approve: false, + reason: 'needs tests', + }, + } as Record + + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(observableInput.type).toBe('plan_approval_response') + expect(observableInput.request_id).toBe('req-1') + expect(observableInput.approve).toBe(false) + expect(observableInput.content).toBe('needs tests') + expect(JSON.stringify(observableInput)).not.toContain('secret-token') + + const result = await SendMessageTool.validateInput!( + observableInput as never, + {} as never, + ) + + expect(result.result).toBe(false) + if (result.result !== false) { + throw new Error('expected validation to reject redacted inline UDS token') + } + expect(result.message).toContain('inline auth tokens') + }) + + test('keeps inline-token rejection when observable input is cloned', async () => { + const observableInput = { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + } as Record + + SendMessageTool.backfillObservableInput!(observableInput) + const clonedInput = { + to: observableInput.to, + message: observableInput.message, + summary: 'hello peer', + } + + const validation = await SendMessageTool.validateInput!( + clonedInput as never, + {} as never, + ) + const result = await SendMessageTool.call( + clonedInput as never, + {} as never, + undefined as never, + undefined as never, + ) + + expect(validation.result).toBe(false) + expect(result.data.success).toBe(false) + expect(JSON.stringify(clonedInput)).not.toContain('secret-token') + expect(JSON.stringify(result)).not.toContain('secret-token') + }) + + test('redacts UDS tokens in 
structured classifier text', async () => { + const to = 'uds:/tmp/peer.sock#token=secret-token' + + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { type: 'shutdown_request' }, + }), + ).toBe('shutdown_request to uds:/tmp/peer.sock') + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { + type: 'plan_approval_response', + request_id: 'req-1', + approve: true, + }, + }), + ).toBe('plan_approval approve to uds:/tmp/peer.sock') + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { + type: 'plan_approval_response', + request_id: 'req-2', + approve: false, + }, + }), + ).toBe('plan_approval reject to uds:/tmp/peer.sock') + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { + type: 'shutdown_response', + request_id: 'shutdown-1', + approve: false, + }, + }), + ).toBe('shutdown_response reject shutdown-1') + }) + + test('redacts from the first inline UDS token marker', async () => { + const tokenAddress = 'uds:/tmp/peer.sock#token=first#token=second' + + const observableInput = { + to: tokenAddress, + message: 'hello', + } as Record + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(JSON.stringify(observableInput)).not.toContain('first') + expect(JSON.stringify(observableInput)).not.toContain('second') + expect( + SendMessageTool.toAutoClassifierInput({ + to: tokenAddress, + message: 'hello', + }), + ).toBe('to uds:/tmp/peer.sock: hello') + }) + + test('rejects inline UDS tokens during validation', async () => { + const result = await SendMessageTool.validateInput!( + { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + }, + {} as never, + ) + + expect(result.result).toBe(false) + if (result.result !== false) { + throw new Error('expected validation to reject inline UDS token') + } + expect(result.message).toContain('inline auth tokens') + 
expect(JSON.stringify(result)).not.toContain('secret-token') + }) + + test('rejects inline UDS tokens during execution without leaking them', async () => { + const result = await SendMessageTool.call( + { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + }, + {} as never, + undefined as never, + undefined as never, + ) + + expect(result.data.success).toBe(false) + expect(JSON.stringify(result)).not.toContain('secret-token') + }) +}) diff --git a/src/bridge/peerSessions.ts b/src/bridge/peerSessions.ts index c194c9b62..716c879de 100644 --- a/src/bridge/peerSessions.ts +++ b/src/bridge/peerSessions.ts @@ -6,6 +6,38 @@ import { getBridgeAccessToken } from './bridgeConfig.js' import { getReplBridgeHandle } from './replBridgeHandle.js' import { toCompatSessionId } from './sessionIdCompat.js' +export type BridgePeerSession = { + address: string + name?: string + cwd?: string + pid?: number +} + +/** + * List locally registered sessions that have published a Remote Control + * session ID. The PID registry is the local source of truth for bridge peers + * already known to this machine; SendMessage can use these bridge: + * addresses when the current process has an active bridge handle. + */ +export async function listBridgePeers(): Promise { + const { listAllLiveSessions } = await import('../utils/udsClient.js') + const sessions = await listAllLiveSessions() + const peers: BridgePeerSession[] = [] + + for (const session of sessions) { + if (session.pid === process.pid || !session.bridgeSessionId) continue + const compatId = toCompatSessionId(session.bridgeSessionId) + peers.push({ + address: `bridge:${compatId}`, + name: session.name ?? session.kind, + cwd: session.cwd, + pid: session.pid, + }) + } + + return peers +} + /** * Send a plain-text message to another Claude session via the bridge API. 
* diff --git a/src/cli/print.ts b/src/cli/print.ts index eb2b543f8..c4e8c4569 100644 --- a/src/cli/print.ts +++ b/src/cli/print.ts @@ -2763,13 +2763,37 @@ function runHeadlessStreaming( // when a message arrives via the UDS socket in headless mode. if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ - const { setOnEnqueue } = require('../utils/udsMessaging.js') + const { drainInbox, setOnEnqueue } = + require('../utils/udsMessaging.js') as typeof import('../utils/udsMessaging.js') /* eslint-enable @typescript-eslint/no-require-imports */ + + const enqueueUdsInboxMessages = (): boolean => { + const entries = drainInbox() + for (const entry of entries) { + const value = + typeof entry.message.data === 'string' + ? entry.message.data + : jsonStringify(entry.message.data) + enqueue({ + mode: 'prompt', + value, + uuid: randomUUID(), + }) + } + return entries.length > 0 + } + setOnEnqueue(() => { if (!inputClosed) { - void run() + if (enqueueUdsInboxMessages()) { + void run() + } } }) + + if (enqueueUdsInboxMessages()) { + void run() + } } // Cron scheduler: runs scheduled_tasks.json tasks in SDK/-p mode. diff --git a/src/commands/peers/peers.ts b/src/commands/peers/peers.ts index aed37d327..fcb7e17a7 100644 --- a/src/commands/peers/peers.ts +++ b/src/commands/peers/peers.ts @@ -1,6 +1,9 @@ import type { LocalCommandCall } from '../../types/command.js' import { listPeers, isPeerAlive } from '../../utils/udsClient.js' -import { getUdsMessagingSocketPath } from '../../utils/udsMessaging.js' +import { + formatUdsAddress, + getUdsMessagingSocketPath, +} from '../../utils/udsMessaging.js' export const call: LocalCommandCall = async (_args, _context) => { const mySocket = getUdsMessagingSocketPath() @@ -29,11 +32,11 @@ export const call: LocalCommandCall = async (_args, _context) => { ? 
` started: ${formatAge(peer.startedAt)}` : '' - lines.push( - ` [${status}] PID ${peer.pid} (${label})${cwd}${age}`, - ) + lines.push(` [${status}] PID ${peer.pid} (${label})${cwd}${age}`) if (peer.messagingSocketPath) { - lines.push(` socket: ${peer.messagingSocketPath}`) + lines.push( + ` socket: ${formatUdsAddress(peer.messagingSocketPath)}`, + ) } if (peer.sessionId) { lines.push(` session: ${peer.sessionId}`) @@ -43,7 +46,7 @@ export const call: LocalCommandCall = async (_args, _context) => { lines.push('') lines.push( - 'To message a peer: use SendMessage with to="uds:"', + 'To message a peer: use SendMessage with the shown uds: address', ) return { type: 'text', value: lines.join('\n') } diff --git a/src/commands/poor/__tests__/poorMode.test.ts b/src/commands/poor/__tests__/poorMode.test.ts index c2a80f3cf..539c804e1 100644 --- a/src/commands/poor/__tests__/poorMode.test.ts +++ b/src/commands/poor/__tests__/poorMode.test.ts @@ -5,7 +5,8 @@ * After the fix, it reads from / writes to settings.json via * getInitialSettings() and updateSettingsForSource(). 
*/ -import { describe, expect, test, beforeEach, mock } from 'bun:test' +import { afterAll, describe, expect, test, beforeEach, mock } from 'bun:test' +import * as settingsModule from '../../../utils/settings/settings.js' // ── Mocks must be declared before the module under test is imported ────────── @@ -13,24 +14,48 @@ let mockSettings: Record = {} let lastUpdate: { source: string; patch: Record } | null = null mock.module('src/utils/settings/settings.js', () => ({ + loadManagedFileSettings: () => ({ settings: null, errors: [] }), + getManagedFileSettingsPresence: () => ({ + hasBase: false, + hasDropIns: false, + }), + parseSettingsFile: () => ({ settings: null, errors: [] }), + getSettingsRootPathForSource: () => '', + getSettingsFilePathForSource: () => undefined, + getRelativeSettingsFilePathForSource: () => '', getInitialSettings: () => mockSettings, + getSettingsForSource: () => mockSettings, + getPolicySettingsOrigin: () => null, + getSettingsWithErrors: () => ({ settings: mockSettings, errors: [] }), + getSettingsWithSources: () => ({ effective: mockSettings, sources: [] }), + getSettings_DEPRECATED: () => mockSettings, + settingsMergeCustomizer: () => undefined, + getManagedSettingsKeysForLogging: () => [], + // Keep unrelated exports aligned with the real settings module so this + // full-surface mock cannot change later test files if Bun keeps it alive. 
+ hasAutoModeOptIn: () => true, + hasSkipDangerousModePermissionPrompt: () => false, + getAutoModeConfig: () => undefined, + getUseAutoModeDuringPlan: () => true, + rawSettingsContainsKey: (key: string) => key in mockSettings, updateSettingsForSource: (source: string, patch: Record) => { lastUpdate = { source, patch } mockSettings = { ...mockSettings, ...patch } }, })) -// Import AFTER mocks are registered -const { isPoorModeActive, setPoorMode } = await import('../poorMode.js') - -// ── Helpers ────────────────────────────────────────────────────────────────── +afterAll(() => { + mock.restore() + mock.module('src/utils/settings/settings.js', () => settingsModule) +}) -/** Reset module-level singleton between tests by re-importing a fresh copy. */ -async function freshModule() { - // Bun caches modules; we manipulate the exported functions directly since - // the singleton `poorModeActive` is reset to null only on first import. - // Instead we test the observable behaviour through set/get pairs. -} +// Import AFTER mocks are registered. The query suffix gives this file its own +// module instance so cross-file poorMode.js mocks cannot replace the subject +// under test during Bun's shared coverage run. 
+const poorModeModulePath = '../poorMode.js?poorModeTest' +const { isPoorModeActive, setPoorMode } = (await import( + poorModeModulePath +)) as typeof import('../poorMode.js') // ── Tests ──────────────────────────────────────────────────────────────────── diff --git a/src/services/AgentSummary/__tests__/agentSummary.test.ts b/src/services/AgentSummary/__tests__/agentSummary.test.ts new file mode 100644 index 000000000..368671f03 --- /dev/null +++ b/src/services/AgentSummary/__tests__/agentSummary.test.ts @@ -0,0 +1,152 @@ +import { beforeEach, describe, expect, test } from 'bun:test' +import { asAgentId } from '../../../types/ids.js' +import type { Message } from '../../../types/message.js' +import type { + CacheSafeParams, + ForkedAgentResult, +} from '../../../utils/forkedAgent.js' +import { startAgentSummarization } from '../agentSummary.js' + +const transcriptMessages = [ + { type: 'user', message: { content: 'start' }, uuid: 'u1' }, + { + type: 'assistant', + message: { content: [{ type: 'text', text: 'working' }] }, + uuid: 'a1', + }, + { type: 'user', message: { content: 'continue' }, uuid: 'u2' }, +] as unknown as Message[] + +type ForkCall = { + cacheSafeParams: CacheSafeParams +} + +describe('startAgentSummarization', () => { + let scheduled: (() => void | Promise) | undefined + let handle: { stop: () => void } | undefined + let forkCalls: ForkCall[] + let updateCalls: Array<{ taskId: string; summary: string }> + let transcriptMessagesForTest: Message[] + + beforeEach(() => { + forkCalls = [] + updateCalls = [] + scheduled = undefined + handle = undefined + transcriptMessagesForTest = transcriptMessages + }) + + test('summarizes bounded transcript once and skips unchanged fingerprints', async () => { + handle = startAgentSummarization( + 'task-1', + asAgentId('a0000000000000000'), + { + forkContextMessages: [ + { type: 'user', message: { content: 'stale' }, uuid: 'old' }, + ], + model: 'claude-test', + } as unknown as CacheSafeParams, + () => undefined, 
+ { + clearTimeout: () => undefined, + getAgentTranscript: async () => ({ + messages: transcriptMessagesForTest, + contentReplacements: [], + }), + isPoorModeActive: () => false, + logError: () => undefined, + logForDebugging: () => undefined, + runForkedAgent: async (args: ForkCall) => { + forkCalls.push(args) + return { + messages: [ + { + type: 'assistant', + message: { + content: [{ type: 'text', text: 'Reading udsClient.ts' }], + }, + }, + ], + } as unknown as ForkedAgentResult + }, + setTimeout: ((callback: TimerHandler) => { + if (typeof callback !== 'function') { + throw new Error('Expected timer callback') + } + scheduled = callback as () => void | Promise + return 1 as unknown as ReturnType + }) as unknown as typeof setTimeout, + updateAgentSummary: (taskId: string, summary: string) => { + updateCalls.push({ taskId, summary }) + }, + }, + ) + + expect(typeof scheduled).toBe('function') + await scheduled!() + + expect(forkCalls).toHaveLength(1) + expect(updateCalls).toEqual([ + { taskId: 'task-1', summary: 'Reading udsClient.ts' }, + ]) + + const forkContext = forkCalls[0].cacheSafeParams.forkContextMessages ?? 
[] + expect(forkContext.map(message => String(message.uuid))).toEqual([ + 'u1', + 'a1', + 'u2', + ]) + expect(forkContext.some(message => String(message.uuid) === 'old')).toBe( + false, + ) + + await scheduled!() + + expect(forkCalls).toHaveLength(1) + expect(updateCalls).toHaveLength(1) + }) + + test('skips summarization when bounded context is too small', async () => { + transcriptMessagesForTest = transcriptMessages.slice(0, 2) + + handle = startAgentSummarization( + 'task-1', + asAgentId('a0000000000000000'), + { + forkContextMessages: transcriptMessages, + model: 'claude-test', + } as unknown as CacheSafeParams, + () => undefined, + { + clearTimeout: () => undefined, + getAgentTranscript: async () => ({ + messages: transcriptMessagesForTest, + contentReplacements: [], + }), + isPoorModeActive: () => false, + logError: () => undefined, + logForDebugging: () => undefined, + runForkedAgent: async (args: ForkCall) => { + forkCalls.push(args) + return { messages: [] } as unknown as ForkedAgentResult + }, + setTimeout: ((callback: TimerHandler) => { + if (typeof callback !== 'function') { + throw new Error('Expected timer callback') + } + scheduled = callback as () => void | Promise + return 1 as unknown as ReturnType + }) as unknown as typeof setTimeout, + updateAgentSummary: (taskId: string, summary: string) => { + updateCalls.push({ taskId, summary }) + }, + }, + ) + + expect(typeof scheduled).toBe('function') + await scheduled!() + + expect(forkCalls).toEqual([]) + expect(updateCalls).toEqual([]) + }) +}) diff --git a/src/services/AgentSummary/__tests__/summaryContext.test.ts b/src/services/AgentSummary/__tests__/summaryContext.test.ts new file mode 100644 index 000000000..fe0eb3057 --- /dev/null +++ b/src/services/AgentSummary/__tests__/summaryContext.test.ts @@ -0,0 +1,261 @@ +import { describe, expect, test } from 'bun:test' +import type { Message } from '../../../types/message.js' +import { + buildSummaryContext, + estimateMessageChars, + 
getSummaryContextFingerprint, + MAX_SUMMARY_CONTEXT_CHARS, + selectSummaryContextMessages, +} from '../summaryContext.js' + +function makeMessage( + type: 'user' | 'assistant', + uuid: string, + content: string, +): Message { + return { + type, + uuid, + message: { + role: type, + content, + }, + } as unknown as Message +} + +describe('selectSummaryContextMessages', () => { + test('keeps a bounded recent suffix that starts with a user message', () => { + const messages = [ + makeMessage('assistant', 'a0', 'older assistant'), + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + makeMessage('user', 'u2', 'second prompt'), + makeMessage('assistant', 'a2', 'second response'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 3, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2', 'a2']) + }) + + test('returns no context when the newest message exceeds the byte budget', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'x'.repeat(100)), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 10, + }) + + expect(selected).toEqual([]) + }) + + test('uses serialized message size for nested content budgets', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + { + ...makeMessage('assistant', 'a1', 'short'), + nested: { + payload: Array.from({ length: 50 }, (_value, index) => ({ + index, + text: 'x'.repeat(20), + })), + }, + } as unknown as Message, + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 200, + }) + + expect(selected).toEqual([]) + }) + + test('stops at an older oversized message after keeping the recent suffix', () => { + const messages = [ + makeMessage('user', 'u1', 'x'.repeat(5_000)), + makeMessage('user', 'u2', 'small prompt'), + makeMessage('assistant', 'a2', 'small 
answer'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2', 'a2']) + }) + + test('drops leading orphan tool results after bounding', () => { + const messages = [ + makeMessage('assistant', 'a0', 'older assistant'), + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'tool_result', tool_use_id: 'tool-1', content: 'ok' }, + ], + }, + } as unknown as Message, + makeMessage('assistant', 'a1', 'after orphan'), + makeMessage('user', 'u2', 'next prompt'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 3, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2']) + }) +}) + +describe('getSummaryContextFingerprint', () => { + test('estimates circular messages as unbounded', () => { + const circular = makeMessage('assistant', 'a1', 'cycle') as Message & { + self?: unknown + } + circular.self = circular + + expect(estimateMessageChars(circular)).toBe(Number.POSITIVE_INFINITY) + }) + + test('ignores non-json primitive fields in size estimates', () => { + const message = makeMessage('assistant', 'a1', 'metadata') as Message & { + skipUndefined?: undefined + skipFunction?: () => void + skipSymbol?: symbol + } + message.skipUndefined = undefined + message.skipFunction = () => undefined + message.skipSymbol = Symbol('ignored') + + expect(estimateMessageChars(message)).toBeGreaterThan(0) + }) + + test('returns null for an empty transcript', () => { + expect(getSummaryContextFingerprint([])).toBeNull() + }) + + test('changes when the transcript grows', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + ] + + const first = getSummaryContextFingerprint(messages) + const second = getSummaryContextFingerprint([ + ...messages, + makeMessage('user', 'u2', 'next prompt'), + ]) + 
expect(first?.startsWith('2:a1:')).toBe(true) + expect(second?.startsWith('3:u2:')).toBe(true) + expect(first).not.toBe(second) + }) + + test('changes when message content changes under the same uuid', () => { + const first = getSummaryContextFingerprint([ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + ]) + const second = getSummaryContextFingerprint([ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'updated response'), + ]) + + expect(first).not.toBe(second) + }) + + test('includes a truncation marker for oversized primitive values', () => { + const prefix = 'x'.repeat(MAX_SUMMARY_CONTEXT_CHARS + 100) + const first = getSummaryContextFingerprint([ + makeMessage('assistant', 'a1', `${prefix}a`), + ]) + const second = getSummaryContextFingerprint([ + makeMessage('assistant', 'a1', `${prefix}b`), + ]) + + expect(first).not.toBe(second) + }) + + test('fingerprints circular message references without recursing forever', () => { + const circular = makeMessage('assistant', 'a1', 'cycle') as Message & { + self?: unknown + } + circular.self = circular + + expect(getSummaryContextFingerprint([circular])).toContain(':a1:') + }) +}) + +describe('buildSummaryContext', () => { + test('returns bounded messages and fingerprint for summarizable context', () => { + const messages = [ + { type: 'user', uuid: 'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { content: [{ type: 'text', text: 'working' }] }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + + const result = buildSummaryContext(messages, null) + + expect(result.skipReason).toBeUndefined() + expect(result.messages.map(message => String(message.uuid))).toEqual([ + 'u1', + 'a1', + 'u2', + ]) + expect(result.fingerprint).toContain('3:u2:') + }) + + test('reports unchanged contexts by fingerprint', () => { + const messages = [ + { type: 'user', uuid: 
'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { content: [{ type: 'text', text: 'working' }] }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + const first = buildSummaryContext(messages, null) + + const second = buildSummaryContext(messages, first.fingerprint) + + expect(second.skipReason).toBe('unchanged') + expect(second.fingerprint).toBe(first.fingerprint) + }) + + test('filters incomplete tool calls before deciding context is too small', () => { + const messages = [ + { type: 'user', uuid: 'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { + content: [{ type: 'tool_use', id: 'missing', name: 'Read' }], + }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + + const result = buildSummaryContext(messages, null) + + expect(result.skipReason).toBe('too_small') + expect(result.messages.map(message => String(message.uuid))).toEqual([ + 'u1', + 'u2', + ]) + }) +}) diff --git a/src/services/AgentSummary/__tests__/summaryPrompt.test.ts b/src/services/AgentSummary/__tests__/summaryPrompt.test.ts new file mode 100644 index 000000000..9e8f03cac --- /dev/null +++ b/src/services/AgentSummary/__tests__/summaryPrompt.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, test } from 'bun:test' +import { + buildSummaryPrompt, + createSummaryPromptMessage, +} from '../summaryPrompt.js' + +describe('buildSummaryPrompt', () => { + test('builds the first summary prompt without previous-summary pressure', () => { + const prompt = buildSummaryPrompt(null) + + expect(prompt).toContain('Describe your most recent action') + expect(prompt).toContain('Good: "Reading runAgent.ts"') + expect(prompt).not.toContain('Previous:') + }) + + test('asks for a new summary when a previous one exists', () => { + const prompt = buildSummaryPrompt('Reading udsMessaging.ts') + + expect(prompt).toContain('Previous: "Reading 
udsMessaging.ts"') + expect(prompt).toContain('say something NEW') + }) +}) + +describe('createSummaryPromptMessage', () => { + test('creates the minimal user message shape used by forked summaries', () => { + const message = createSummaryPromptMessage('Summarize progress') + + expect(message.type).toBe('user') + expect(message.message.role).toBe('user') + expect(message.message.content).toBe('Summarize progress') + expect(message.uuid).toBeString() + expect(message.timestamp).toBeString() + }) +}) diff --git a/src/services/AgentSummary/agentSummary.ts b/src/services/AgentSummary/agentSummary.ts index 50146b3c7..d212a5c72 100644 --- a/src/services/AgentSummary/agentSummary.ts +++ b/src/services/AgentSummary/agentSummary.ts @@ -13,7 +13,6 @@ import type { TaskContext } from '../../Task.js' import { isPoorModeActive } from '../../commands/poor/poorMode.js' import { updateAgentSummary } from '../../tasks/LocalAgentTask/LocalAgentTask.js' -import { filterIncompleteToolCalls } from '@claude-code-best/builtin-tools/tools/AgentTool/runAgent.js' import type { AgentId } from '../../types/ids.js' import { logForDebugging } from '../../utils/debug.js' import { @@ -21,34 +20,32 @@ import { runForkedAgent, } from '../../utils/forkedAgent.js' import { logError } from '../../utils/log.js' -import { createUserMessage } from '../../utils/messages.js' import { getAgentTranscript } from '../../utils/sessionStorage.js' +import { buildSummaryContext } from './summaryContext.js' +import { + buildSummaryPrompt, + createSummaryPromptMessage, +} from './summaryPrompt.js' const SUMMARY_INTERVAL_MS = 30_000 -function buildSummaryPrompt(previousSummary: string | null): string { - const prevLine = previousSummary - ? `\nPrevious: "${previousSummary}" — say something NEW.\n` - : '' - - return `Describe your most recent action in 3-5 words using present tense (-ing). Name the file or function, not the branch. Do not use tools. 
-${prevLine} -Good: "Reading runAgent.ts" -Good: "Fixing null check in validate.ts" -Good: "Running auth module tests" -Good: "Adding retry logic to fetchUser" - -Bad (past tense): "Analyzed the branch diff" -Bad (too vague): "Investigating the issue" -Bad (too long): "Reviewing full branch diff and AgentTool.tsx integration" -Bad (branch name): "Analyzed adam/background-summary branch diff"` -} +export type AgentSummaryDependencies = Partial<{ + clearTimeout: typeof clearTimeout + getAgentTranscript: typeof getAgentTranscript + isPoorModeActive: typeof isPoorModeActive + logError: typeof logError + logForDebugging: typeof logForDebugging + runForkedAgent: typeof runForkedAgent + setTimeout: typeof setTimeout + updateAgentSummary: typeof updateAgentSummary +}> export function startAgentSummarization( taskId: string, agentId: AgentId, cacheSafeParams: CacheSafeParams, setAppState: TaskContext['setAppState'], + dependencies: AgentSummaryDependencies = {}, ): { stop: () => void } { // Drop forkContextMessages from the closure — runSummary rebuilds it each // tick from getAgentTranscript(). Without this, the original fork messages @@ -58,39 +55,67 @@ export function startAgentSummarization( let timeoutId: ReturnType | null = null let stopped = false let previousSummary: string | null = null + let lastHandledTranscriptFingerprint: string | null = null + const clearTimeoutImpl = dependencies.clearTimeout ?? clearTimeout + const getAgentTranscriptImpl = + dependencies.getAgentTranscript ?? getAgentTranscript + const isPoorModeActiveImpl = + dependencies.isPoorModeActive ?? isPoorModeActive + const logErrorImpl = dependencies.logError ?? logError + const logForDebuggingImpl = + dependencies.logForDebugging ?? logForDebugging + const runForkedAgentImpl = dependencies.runForkedAgent ?? runForkedAgent + const setTimeoutImpl = dependencies.setTimeout ?? setTimeout + const updateAgentSummaryImpl = + dependencies.updateAgentSummary ?? 
updateAgentSummary async function runSummary(): Promise { if (stopped) return - if (isPoorModeActive()) { - logForDebugging('[AgentSummary] Skipping summary — poor mode active') + if (isPoorModeActiveImpl()) { + logForDebuggingImpl('[AgentSummary] Skipping summary — poor mode active') scheduleNext() return } - logForDebugging(`[AgentSummary] Timer fired for agent ${agentId}`) + logForDebuggingImpl(`[AgentSummary] Timer fired for agent ${agentId}`) try { // Read current messages from transcript - const transcript = await getAgentTranscript(agentId) + const transcript = await getAgentTranscriptImpl(agentId) if (!transcript || transcript.messages.length < 3) { // Not enough context yet — finally block will schedule next attempt - logForDebugging( + logForDebuggingImpl( `[AgentSummary] Skipping summary for ${taskId}: not enough messages (${transcript?.messages.length ?? 0})`, ) return } - // Filter to clean message state - const cleanMessages = filterIncompleteToolCalls(transcript.messages) + const summaryContext = buildSummaryContext( + transcript.messages, + lastHandledTranscriptFingerprint, + ) + if (summaryContext.skipReason === 'unchanged') { + logForDebuggingImpl( + `[AgentSummary] Skipping summary for ${taskId}: transcript unchanged`, + ) + return + } + + if (summaryContext.skipReason === 'too_small') { + logForDebuggingImpl( + `[AgentSummary] Skipping summary for ${taskId}: no bounded context available`, + ) + return + } // Build fork params with current messages const forkParams: CacheSafeParams = { ...baseParams, - forkContextMessages: cleanMessages, + forkContextMessages: summaryContext.messages, } - logForDebugging( - `[AgentSummary] Forking for summary, ${cleanMessages.length} messages in context`, + logForDebuggingImpl( + `[AgentSummary] Forking for summary, ${summaryContext.messages.length} messages in context`, ) // Create abort controller for this summary @@ -112,9 +137,9 @@ export function startAgentSummarization( // ContentReplacementState is cloned 
by default in createSubagentContext // from forkParams.toolUseContext (the subagent's LIVE state captured at // onCacheSafeParams time). No explicit override needed. - const result = await runForkedAgent({ + const result = await runForkedAgentImpl({ promptMessages: [ - createUserMessage({ content: buildSummaryPrompt(previousSummary) }), + createSummaryPromptMessage(buildSummaryPrompt(previousSummary)), ], cacheSafeParams: forkParams, canUseTool, @@ -136,21 +161,24 @@ export function startAgentSummarization( ) continue } - const contentArr = Array.isArray(msg.message!.content) ? msg.message!.content : [] + const contentArr = Array.isArray(msg.message!.content) + ? msg.message!.content + : [] const textBlock = contentArr.find(b => b.type === 'text') if (textBlock?.type === 'text' && textBlock.text.trim()) { const summaryText = textBlock.text.trim() - logForDebugging( + logForDebuggingImpl( `[AgentSummary] Summary result for ${taskId}: ${summaryText}`, ) + lastHandledTranscriptFingerprint = summaryContext.fingerprint previousSummary = summaryText - updateAgentSummary(taskId, summaryText, setAppState) + updateAgentSummaryImpl(taskId, summaryText, setAppState) break } } } catch (e) { if (!stopped && e instanceof Error) { - logError(e) + logErrorImpl(e) } } finally { summaryAbortController = null @@ -163,14 +191,14 @@ export function startAgentSummarization( function scheduleNext(): void { if (stopped) return - timeoutId = setTimeout(runSummary, SUMMARY_INTERVAL_MS) + timeoutId = setTimeoutImpl(runSummary, SUMMARY_INTERVAL_MS) } function stop(): void { - logForDebugging(`[AgentSummary] Stopping summarization for ${taskId}`) + logForDebuggingImpl(`[AgentSummary] Stopping summarization for ${taskId}`) stopped = true if (timeoutId) { - clearTimeout(timeoutId) + clearTimeoutImpl(timeoutId) timeoutId = null } if (summaryAbortController) { diff --git a/src/services/AgentSummary/summaryContext.ts b/src/services/AgentSummary/summaryContext.ts new file mode 100644 index 
000000000..d4c00e1d4 --- /dev/null +++ b/src/services/AgentSummary/summaryContext.ts @@ -0,0 +1,219 @@ +import { createHash } from 'node:crypto' +import { filterIncompleteToolCalls } from '@claude-code-best/builtin-tools/tools/AgentTool/filterIncompleteToolCalls.js' +import type { Message } from '../../types/message.js' + +export const MAX_SUMMARY_CONTEXT_MESSAGES = 120 +export const MAX_SUMMARY_CONTEXT_CHARS = 200_000 + +function estimateJsonChars( + value: unknown, + limit: number, + seen = new Set(), +): number { + if (value === null) return 4 + switch (typeof value) { + case 'string': + return value.length + 2 + case 'number': + case 'boolean': + return String(value).length + case 'undefined': + case 'function': + case 'symbol': + return 0 + case 'object': { + if (seen.has(value)) return Number.POSITIVE_INFINITY + seen.add(value) + let total = 2 + if (Array.isArray(value)) { + for (let index = 0; index < value.length; index++) { + total += String(index).length + 3 + total += estimateJsonChars(value[index], limit - total, seen) + if (total > limit) return total + } + } else { + const record = value as Record + for (const key in record) { + if (!Object.hasOwn(record, key)) continue + total += key.length + 3 + total += estimateJsonChars(record[key], limit - total, seen) + if (total > limit) return total + } + } + seen.delete(value) + return total + } + } + return 0 +} + +function updateFingerprintHash( + hash: ReturnType, + value: unknown, + limit: { remaining: number }, + seen = new Set(), +): void { + if (limit.remaining <= 0) return + if (value === null || typeof value !== 'object') { + const text = String(value) + const consumed = Math.min(text.length, limit.remaining) + if (consumed <= 0) return + hash.update(typeof value) + hash.update(':') + hash.update(text.slice(0, consumed)) + if (consumed < text.length) { + hash.update(`#truncated:${text.length}:${text.slice(-64)}`) + } + limit.remaining -= consumed + return + } + if (seen.has(value)) { + 
hash.update('[Circular]') + return + } + seen.add(value) + if (Array.isArray(value)) { + for (let index = 0; index < value.length; index++) { + if (limit.remaining <= 0) break + const key = String(index) + hash.update(key) + limit.remaining -= key.length + updateFingerprintHash(hash, value[index], limit, seen) + } + } else { + const record = value as Record + for (const key in record) { + if (limit.remaining <= 0) break + if (!Object.hasOwn(record, key)) continue + hash.update(key) + limit.remaining -= key.length + updateFingerprintHash(hash, record[key], limit, seen) + } + } + seen.delete(value) +} + +export function estimateMessageChars( + message: Message, + limit = Number.POSITIVE_INFINITY, +): number { + const estimated = estimateJsonChars(message, limit) + if (!Number.isFinite(estimated)) { + return Number.POSITIVE_INFINITY + } + return estimated +} + +function hasToolResultBlock(message: Message): boolean { + if (message.type !== 'user') return false + const content = message.message?.content + return ( + Array.isArray(content) && + content.some(block => { + return Boolean( + block && + typeof block === 'object' && + 'type' in block && + block.type === 'tool_result', + ) + }) + ) +} + +export function getSummaryContextFingerprint( + messages: Message[], +): string | null { + const lastMessage = messages.at(-1) + if (!lastMessage) return null + const hash = createHash('sha256') + updateFingerprintHash(hash, messages, { + remaining: MAX_SUMMARY_CONTEXT_CHARS, + }) + return `${messages.length}:${lastMessage.uuid}:${hash.digest('hex').slice(0, 16)}` +} + +export function selectSummaryContextMessages( + messages: Message[], + limits: { + maxMessages?: number + maxChars?: number + } = {}, +): Message[] { + const maxMessages = limits.maxMessages ?? MAX_SUMMARY_CONTEXT_MESSAGES + const maxChars = limits.maxChars ?? 
MAX_SUMMARY_CONTEXT_CHARS + if (maxMessages <= 0 || maxChars <= 0) return [] + + const selected: Message[] = [] + let selectedChars = 0 + + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i] + if (!message) continue + + const messageChars = estimateMessageChars(message, maxChars - selectedChars) + if (messageChars > maxChars) { + if (selected.length === 0) return [] + break + } + + if ( + selected.length >= maxMessages || + selectedChars + messageChars > maxChars + ) { + break + } + + selected.unshift(message) + selectedChars += messageChars + } + + while (selected.length > 0) { + const first = selected[0] + if (!first) break + if (first.type !== 'user' || hasToolResultBlock(first)) { + selected.shift() + continue + } + break + } + + return selected +} + +export type SummaryContextBuildResult = { + messages: Message[] + fingerprint: string | null + skipReason?: 'too_small' | 'unchanged' +} + +export function buildSummaryContext( + messages: Message[], + previousFingerprint: string | null, +): SummaryContextBuildResult { + const cleanMessages = filterIncompleteToolCalls(messages) + const boundedMessages = filterIncompleteToolCalls( + selectSummaryContextMessages(cleanMessages), + ) + const fingerprint = getSummaryContextFingerprint(boundedMessages) + + if (fingerprint && fingerprint === previousFingerprint) { + return { + messages: boundedMessages, + fingerprint, + skipReason: 'unchanged', + } + } + + if (boundedMessages.length < 3) { + return { + messages: boundedMessages, + fingerprint, + skipReason: 'too_small', + } + } + + return { + messages: boundedMessages, + fingerprint, + } +} diff --git a/src/services/AgentSummary/summaryPrompt.ts b/src/services/AgentSummary/summaryPrompt.ts new file mode 100644 index 000000000..ce3138f2a --- /dev/null +++ b/src/services/AgentSummary/summaryPrompt.ts @@ -0,0 +1,32 @@ +import { randomUUID, type UUID } from 'node:crypto' +import type { UserMessage } from '../../types/message.js' + +export function 
buildSummaryPrompt(previousSummary: string | null): string { + const prevLine = previousSummary + ? `\nPrevious: "${previousSummary}" — say something NEW.\n` + : '' + + return `Describe your most recent action in 3-5 words using present tense (-ing). Name the file or function, not the branch. Do not use tools. +${prevLine} +Good: "Reading runAgent.ts" +Good: "Fixing null check in validate.ts" +Good: "Running auth module tests" +Good: "Adding retry logic to fetchUser" + +Bad (past tense): "Analyzed the branch diff" +Bad (too vague): "Investigating the issue" +Bad (too long): "Reviewing full branch diff and AgentTool.tsx integration" +Bad (branch name): "Analyzed adam/background-summary branch diff"` +} + +export function createSummaryPromptMessage(content: string): UserMessage { + return { + type: 'user', + message: { + role: 'user', + content, + }, + uuid: randomUUID() as UUID, + timestamp: new Date().toISOString(), + } +} diff --git a/src/utils/__tests__/ndjsonFramer.test.ts b/src/utils/__tests__/ndjsonFramer.test.ts new file mode 100644 index 000000000..344c1e58c --- /dev/null +++ b/src/utils/__tests__/ndjsonFramer.test.ts @@ -0,0 +1,153 @@ +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' +import { describe, expect, test } from 'bun:test' +import { attachNdjsonFramer } from '../ndjsonFramer.js' + +type TestSocket = Socket & { + destroyed: boolean + emitData: (chunk: Buffer) => void +} + +function createTestSocket(): TestSocket { + const emitter = new EventEmitter() as TestSocket + emitter.destroyed = false + emitter.destroy = ((_error?: Error) => { + emitter.destroyed = true + emitter.emit('close') + return emitter + }) as TestSocket['destroy'] + emitter.emitData = (chunk: Buffer) => { + emitter.emit('data', chunk) + } + return emitter +} + +describe('attachNdjsonFramer', () => { + test('accepts a complete frame at the configured byte limit', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const 
errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: Buffer.byteLength('{"a":1}', 'utf8'), + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{"a":1}\n')) + + expect(messages).toEqual([{ a: 1 }]) + expect(errors).toEqual([]) + expect(socket.destroyed).toBe(false) + }) + + test('destroys a complete frame over the configured byte limit', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{"long":true}\n')) + + expect(messages).toEqual([]) + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(true) + }) + + test('destroys oversized no-newline input before a frame can form', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('x'.repeat(9))) + + expect(messages).toEqual([]) + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(true) + }) + + test('lets callers own oversized-frame shutdown when configured', () => { + const socket = createTestSocket() + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + () => undefined, + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + destroyOnFrameError: false, + }, + ) + + socket.emitData(Buffer.from('{"long":true}\n')) + + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(false) + }) + + 
test('reports malformed non-empty frames without changing default compatibility', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + onInvalidFrame: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{not-json\n')) + + expect(messages).toEqual([]) + expect(errors).toHaveLength(1) + expect(socket.destroyed).toBe(false) + }) + + test('destroys malformed frames when configured by the caller', () => { + const socket = createTestSocket() + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + () => undefined, + text => JSON.parse(text) as unknown, + { + destroyOnInvalidFrame: true, + onInvalidFrame: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{not-json\n')) + + expect(errors).toHaveLength(1) + expect(socket.destroyed).toBe(true) + }) +}) diff --git a/src/utils/__tests__/teammateMailbox.test.ts b/src/utils/__tests__/teammateMailbox.test.ts new file mode 100644 index 000000000..f6279dab7 --- /dev/null +++ b/src/utils/__tests__/teammateMailbox.test.ts @@ -0,0 +1,462 @@ +import { afterEach, beforeEach, describe, expect, test } from 'bun:test' +import { mkdir, readFile, rm, writeFile } from 'node:fs/promises' +import { mkdtempSync } from 'node:fs' +import { tmpdir } from 'node:os' +import { dirname, join } from 'node:path' +import type { Message } from 'src/types/message.js' +import { + compactMailboxMessages, + getLastPeerDmSummary, + getInboxPath, + markMessageAsReadByIndex, + markMessageAsReadByIdentity, + markMessagesAsRead, + markMessagesAsReadByPredicate, + MAX_MAILBOX_MESSAGE_TEXT_BYTES, + MAX_MAILBOX_FILE_BYTES, + MAX_MAILBOX_MESSAGES, + MAX_READ_MAILBOX_MESSAGES, + MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES, + readMailbox, + type TeammateMessage, + writeToMailbox, +} from 'src/utils/teammateMailbox.js' + +let tempHome = '' +let previousConfigDir: string | 
undefined + +function message( + text: string, + read: boolean, + timestamp = new Date(0).toISOString(), +): TeammateMessage { + return { + from: 'team-lead', + text, + timestamp, + read, + } +} + +async function seedMailbox( + agentName: string, + teamName: string, + messages: TeammateMessage[], +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify(messages, null, 2), 'utf-8') +} + +async function readRawMailbox( + agentName: string, + teamName: string, +): Promise { + const content = await readFile(getInboxPath(agentName, teamName), 'utf-8') + return JSON.parse(content) as TeammateMessage[] +} + +describe('compactMailboxMessages', () => { + test('prioritizes unread messages and keeps only recent read history', () => { + const compacted = compactMailboxMessages( + [ + message('read-1', true), + message('read-2', true), + message('unread-1', false), + message('read-3', true), + message('unread-2', false), + message('read-4', true), + message('read-5', true), + message('unread-3', false), + ], + { maxMessages: 5, maxReadMessages: 2 }, + ) + + expect(compacted.map(m => m.text)).toEqual([ + 'unread-1', + 'unread-2', + 'read-4', + 'read-5', + 'unread-3', + ]) + }) + + test('retains unread protocol messages separately from regular cap', () => { + const protocol = message( + JSON.stringify({ type: 'permission_response', request_id: 'req-1' }), + false, + ) + const compacted = compactMailboxMessages( + [ + protocol, + ...Array.from({ length: 5 }, (_value, index) => + message(`regular-${index}`, false), + ), + ], + { + maxMessages: 2, + maxReadMessages: 0, + maxUnreadProtocolMessages: 1, + }, + ) + + expect(compacted.map(m => m.text)).toEqual([ + protocol.text, + 'regular-3', + 'regular-4', + ]) + }) + + test('does not prioritize malformed JSON-like unread messages as protocol', () => { + const compacted = compactMailboxMessages( + [ + message('{not-json', 
false), + message('regular-1', false), + message('regular-2', false), + ], + { + maxMessages: 1, + maxReadMessages: 0, + maxUnreadProtocolMessages: 10, + }, + ) + + expect(compacted.map(m => m.text)).toEqual(['regular-2']) + }) + + test('caps unread protocol messages with an independent bound', () => { + const compacted = compactMailboxMessages( + Array.from( + { length: MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES + 1 }, + (_value, index) => + message( + JSON.stringify({ + type: 'permission_response', + request_id: `req-${index}`, + }), + false, + ), + ), + ) + + expect(compacted).toHaveLength(MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES) + expect(compacted[0]?.text).toContain('req-1') + }) + + test('keeps retained mailbox bytes under an explicit budget', () => { + const compacted = compactMailboxMessages( + Array.from({ length: 20 }, (_value, index) => + message(`msg-${index}-${'x'.repeat(200)}`, false), + ), + { + maxMessages: 20, + maxReadMessages: 0, + maxRetainedBytes: 1_000, + }, + ) + + expect( + Buffer.byteLength(JSON.stringify(compacted), 'utf8'), + ).toBeLessThanOrEqual(1_000) + expect(compacted.length).toBeLessThan(20) + expect(compacted.at(-1)?.text).toContain('msg-19') + }) + + test('returns an empty mailbox when even one message exceeds retained budget', () => { + const compacted = compactMailboxMessages([message('too-large', false)], { + maxMessages: 10, + maxReadMessages: 0, + maxRetainedBytes: 1, + }) + + expect(compacted).toEqual([]) + }) +}) + +describe('teammate mailbox retention', () => { + beforeEach(() => { + previousConfigDir = process.env.CLAUDE_CONFIG_DIR + tempHome = mkdtempSync(join(tmpdir(), 'teammate-mailbox-')) + process.env.CLAUDE_CONFIG_DIR = tempHome + }) + + afterEach(async () => { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + tempHome = '' + }) + + test('writeToMailbox compacts oversized 
unread inbox files', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 20 }, + (_value, index) => message(`old-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'newest', + timestamp: new Date(1).toISOString(), + }, + 'alpha', + ) + + const after = await readMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_MAILBOX_MESSAGES) + expect(after[0]?.text).toBe('old-21') + expect(after.at(-1)?.text).toBe('newest') + }) + + test('markMessagesAsRead compacts read history after consumption', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 20 }, + (_value, index) => message(`msg-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await markMessagesAsRead('worker', 'alpha') + + const after = await readRawMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_READ_MAILBOX_MESSAGES) + expect(after.every(m => m.read)).toBe(true) + expect(after[0]?.text).toBe( + `msg-${MAX_MAILBOX_MESSAGES + 20 - MAX_READ_MAILBOX_MESSAGES}`, + ) + }) + + test('markMessagesAsReadByPredicate leaves structured messages unread', async () => { + await seedMailbox('worker', 'alpha', [ + message('plain', false), + message(JSON.stringify({ type: 'permission_request' }), false), + ]) + + await markMessagesAsReadByPredicate( + 'worker', + m => !m.text.includes('permission_request'), + 'alpha', + ) + + const after = await readRawMailbox('worker', 'alpha') + expect(after.map(m => m.read)).toEqual([true, false]) + }) + + test('markMessageAsReadByIdentity survives compaction shifting indexes', async () => { + const permissionResponse = message( + JSON.stringify({ type: 'permission_response', request_id: 'req-1' }), + false, + ) + await seedMailbox('worker', 'alpha', [ + permissionResponse, + ...Array.from({ length: MAX_MAILBOX_MESSAGES + 20 }, (_value, index) => + message(`regular-${index}`, false), + ), + ]) + + await 
writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'newest', + timestamp: new Date(2).toISOString(), + }, + 'alpha', + ) + const marked = await markMessageAsReadByIdentity( + 'worker', + 'alpha', + permissionResponse, + ) + + const after = await readRawMailbox('worker', 'alpha') + expect(marked).toBe(true) + expect(after.some(m => m.text === permissionResponse.text && !m.read)).toBe( + false, + ) + }) + + test('markMessageAsReadByIndex also compacts through the compatibility path', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 10 }, + (_value, index) => message(`msg-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await markMessageAsReadByIndex('worker', 'alpha', existing.length - 1) + + const after = await readRawMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_MAILBOX_MESSAGES) + expect(after.some(m => m.text === `msg-${existing.length - 1}`)).toBe(false) + expect(after.at(-1)?.text).toBe(`msg-${existing.length - 2}`) + }) + + test('writeToMailbox rejects oversized message text instead of storing it', async () => { + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'x'.repeat(MAX_MAILBOX_MESSAGE_TEXT_BYTES + 1), + timestamp: new Date(3).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow('Mailbox message text exceeds') + + expect(await readRawMailbox('worker', 'alpha')).toEqual([]) + }) + + test('writeToMailbox fails closed when an existing mailbox is corrupt', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'new', + timestamp: new Date(4).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow() + + expect(await readFile(inboxPath, 'utf-8')).toBe('{not-json') + }) + + test('readMailbox fails closed on corrupt mailbox content', async () => { + const 
inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow() + }) + + test('readMailbox rejects non-array mailbox files', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify({ text: 'not an array' }), 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'expected message array', + ) + }) + + test('readMailbox rejects malformed stored message shapes', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile( + inboxPath, + JSON.stringify([{ from: 'lead', text: 'missing timestamp' }]), + 'utf-8', + ) + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'Invalid mailbox message shape', + ) + }) + + test('readMailbox rejects non-object stored messages', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify(['not an object']), 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'expected object', + ) + }) + + test('readMailbox rejects oversized mailbox files before parsing', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, `[${' '.repeat(MAX_MAILBOX_FILE_BYTES)}]`, 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'Mailbox file exceeds', + ) + }) + + test('markMessageAsReadByIdentity returns false for missing mailbox files', async () => { + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('absent', false)), + ).resolves.toBe(false) + }) + + test('markMessageAsReadByIdentity returns false when the expected 
message moved out', async () => { + await seedMailbox('worker', 'alpha', [message('other', false)]) + + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('missing', false)), + ).resolves.toBe(false) + + expect((await readRawMailbox('worker', 'alpha'))[0]?.read).toBe(false) + }) + + test('markMessageAsReadByIdentity returns false on corrupt mailbox content', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('missing', false)), + ).resolves.toBe(false) + }) +}) + +describe('getLastPeerDmSummary', () => { + test('extracts the final peer direct-message summary from assistant tool use', () => { + const messages = [ + { type: 'user', message: { content: 'wake up' } }, + { + type: 'assistant', + message: { + content: [ + { + type: 'tool_use', + name: 'SendMessage', + input: { + to: 'worker-1', + message: 'please check the UDS bounds', + summary: 'Checking UDS bounds', + }, + }, + ], + }, + }, + ] as unknown as Message[] + + expect(getLastPeerDmSummary(messages)).toBe( + '[to worker-1] Checking UDS bounds', + ) + }) + + test('stops peer direct-message summary search at the wake-up boundary', () => { + const messages = [ + { + type: 'assistant', + message: { + content: [ + { + type: 'tool_use', + name: 'SendMessage', + input: { + to: 'worker-1', + message: 'old message', + }, + }, + ], + }, + }, + { type: 'user', message: { content: 'new prompt' } }, + ] as unknown as Message[] + + expect(getLastPeerDmSummary(messages)).toBeUndefined() + }) +}) diff --git a/src/utils/__tests__/udsMessaging.test.ts b/src/utils/__tests__/udsMessaging.test.ts new file mode 100644 index 000000000..392daa1ac --- /dev/null +++ b/src/utils/__tests__/udsMessaging.test.ts @@ -0,0 +1,614 @@ +import { afterEach, beforeEach, describe, expect, test } from 'bun:test' +import { + chmod, + 
mkdir, + mkdtemp, + readdir, + rm, + stat, + symlink, + unlink, + writeFile, +} from 'node:fs/promises' +import { createHash } from 'node:crypto' +import { createConnection, createServer } from 'node:net' +import { dirname, join } from 'node:path' +import { tmpdir } from 'node:os' +import { + drainInbox, + getDefaultUdsSocketPath, + MAX_UDS_INBOX_ENTRIES, + MAX_UDS_INBOX_BYTES, + MAX_UDS_FRAME_BYTES, + MAX_UDS_CLIENTS, + formatUdsAddress, + parseUdsTarget, + sendUdsMessage, + setOnEnqueue, + startUdsMessaging, + stopUdsMessaging, + UDS_AUTH_TIMEOUT_MS, +} from '../udsMessaging.js' + +let previousConfigDir: string | undefined +let tempConfigDir = '' + +function socketPath(label: string): string { + const suffix = `${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}-${label}` + if (process.platform === 'win32') { + return `\\\\.\\pipe\\claude-code-test-${suffix}` + } + return join(tmpdir(), 'claude-code-test', `${suffix}.sock`) +} + +function sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)) +} + +async function waitForEnqueues( + expected: number, + sendMessages: () => Promise, +): Promise { + let count = 0 + let resolveDone: (() => void) | undefined + const done = new Promise(resolve => { + resolveDone = resolve + }) + + setOnEnqueue(() => { + count++ + if (count >= expected) resolveDone?.() + }) + + await sendMessages() + await Promise.race([ + done, + sleep(5_000).then(() => { + throw new Error(`Timed out waiting for ${expected} UDS enqueues`) + }), + ]) + setOnEnqueue(null) +} + +beforeEach(async () => { + previousConfigDir = process.env.CLAUDE_CONFIG_DIR + tempConfigDir = await mkdtemp(join(tmpdir(), 'uds-messaging-home-')) + process.env.CLAUDE_CONFIG_DIR = tempConfigDir +}) + +afterEach(async () => { + setOnEnqueue(null) + drainInbox() + await stopUdsMessaging() + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + 
if (tempConfigDir) { + await rm(tempConfigDir, { recursive: true, force: true }) + tempConfigDir = '' + } +}) + +async function closeServer(server: ReturnType): Promise { + await new Promise(resolve => { + server.close(() => resolve()) + }) +} + +describe('UDS inbox retention', () => { + test('drainInbox returns each pending socket message once', async () => { + const path = socketPath('drain') + await startUdsMessaging(path, { isExplicit: true }) + expect(process.env.CLAUDE_CODE_MESSAGING_TOKEN).toBeUndefined() + + await waitForEnqueues(2, async () => { + await sendUdsMessage(path, { type: 'text', data: 'one' }) + await sendUdsMessage(path, { type: 'text', data: 'two' }) + }) + + const drained = drainInbox() + expect(drained.map(entry => entry.message.data)).toEqual(['one', 'two']) + expect(drained.every(entry => entry.status === 'processed')).toBe(true) + expect(drainInbox()).toEqual([]) + }) + + test('inbox is capped when messages arrive faster than they are drained', async () => { + const path = socketPath('cap') + await startUdsMessaging(path, { isExplicit: true }) + + await waitForEnqueues(MAX_UDS_INBOX_ENTRIES, async () => { + for (let i = 0; i < MAX_UDS_INBOX_ENTRIES; i++) { + await sendUdsMessage(path, { type: 'text', data: String(i) }) + } + }) + await expect( + sendUdsMessage(path, { type: 'text', data: 'overflow' }), + ).rejects.toThrow('inbox full') + + const drained = drainInbox() + expect(drained).toHaveLength(MAX_UDS_INBOX_ENTRIES) + expect(drained[0]?.message.data).toBe('0') + expect(drained.at(-1)?.message.data).toBe(String(MAX_UDS_INBOX_ENTRIES - 1)) + }) + + test('inbox is capped by retained bytes before entry count', async () => { + const path = socketPath('byte-cap') + await startUdsMessaging(path, { isExplicit: true }) + + const payload = 'x'.repeat(32 * 1024) + let accepted = 0 + for (;;) { + try { + await sendUdsMessage(path, { type: 'text', data: payload }) + accepted++ + if (accepted > MAX_UDS_INBOX_BYTES / payload.length + 20) { + throw 
new Error('byte cap was not enforced') + } + } catch (error) { + expect(error).toBeInstanceOf(Error) + expect((error as Error).message).toContain('inbox full') + break + } + } + + const drained = drainInbox() + expect(drained.length).toBe(accepted) + expect(drained.length).toBeLessThan(MAX_UDS_INBOX_ENTRIES) + }) + + test('ping replies with pong without enqueueing inbox work', async () => { + const path = socketPath('ping') + await startUdsMessaging(path, { isExplicit: true }) + + await sendUdsMessage(path, { type: 'ping' }) + expect(drainInbox()).toEqual([]) + }) + + test('udsClient helpers authenticate through the capability file', async () => { + const path = socketPath('uds-client') + await startUdsMessaging(path, { isExplicit: true }) + const { isPeerAlive, sendToUdsSocket } = await import('../udsClient.js') + + expect(await isPeerAlive(path)).toBe(true) + await waitForEnqueues(1, async () => { + await sendToUdsSocket(path, 'hello from client') + }) + + const drained = drainInbox() + expect(drained).toHaveLength(1) + expect(drained[0]?.message.data).toBe('hello from client') + expect(drained[0]?.message.meta).toBeUndefined() + }) + + test('udsClient peer probe fails closed on oversized pong frames', async () => { + const path = socketPath('uds-client-oversized-pong') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + const { isPeerAlive } = await import('../udsClient.js') + expect(await isPeerAlive(path, 3_000, 'test-token')).toBe(false) + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('udsClient send fails closed when no capability token exists', async () => { 
+ const path = socketPath('uds-client-no-token') + const { sendToUdsSocket } = await import('../udsClient.js') + + await expect(sendToUdsSocket(path, 'hello')).rejects.toThrow( + 'No auth token found', + ) + }) + + test('sendUdsMessage fails closed before connecting without an auth token', async () => { + await expect( + sendUdsMessage(socketPath('no-auth-token'), { type: 'text', data: 'x' }), + ).rejects.toThrow('without auth token') + }) + + test('drained entries never expose the UDS auth token', async () => { + const path = socketPath('strip-token') + await startUdsMessaging(path, { isExplicit: true }) + + await waitForEnqueues(1, async () => { + await sendUdsMessage(path, { + type: 'notification', + meta: { keep: 'visible' }, + }) + }) + + const drained = drainInbox() + expect(drained).toHaveLength(1) + expect(drained[0]?.message.meta).toEqual({ keep: 'visible' }) + expect(drained[0]?.message.meta).not.toHaveProperty('authToken') + }) + + test('rejects unauthenticated socket messages', async () => { + const path = socketPath('auth') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + let responseText = '' + const conn = createConnection(path, () => { + conn.write(`${JSON.stringify({ type: 'text', data: 'bad' })}\n`) + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for auth rejection')) + }) + conn.on('data', chunk => { + const text = chunk.toString('utf-8') + if (text.includes('\n')) { + responseText = text + } + }) + conn.on('close', () => resolve(responseText)) + conn.on('error', reject) + }) + + expect(JSON.parse(response).type).toBe('error') + expect(drainInbox()).toEqual([]) + }) + + test('disconnects malformed JSON clients without enqueueing inbox work', async () => { + const path = socketPath('malformed-client') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + let responseText = 
'' + const conn = createConnection(path, () => { + conn.write('{not-json\n') + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for malformed frame close')) + }) + conn.on('data', chunk => { + responseText += chunk.toString('utf-8') + }) + conn.on('close', () => resolve(responseText)) + conn.on('error', reject) + }) + + const parsed = JSON.parse(response) + expect(parsed.type).toBe('error') + expect(parsed.data).toBe('invalid frame') + expect(drainInbox()).toEqual([]) + }) + + test('disconnects idle unauthenticated clients', async () => { + const path = socketPath('idle-client') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + let responseText = '' + const conn = createConnection(path) + conn.setTimeout(UDS_AUTH_TIMEOUT_MS + 2_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for auth timeout close')) + }) + conn.on('data', chunk => { + responseText += chunk.toString('utf-8') + }) + conn.on('close', () => resolve(responseText)) + conn.on('error', reject) + }) + + const parsed = JSON.parse(response) + expect(parsed.type).toBe('error') + expect(parsed.data).toBe('authentication timeout') + expect(drainInbox()).toEqual([]) + }) + + test('destroys oversized frames before enqueueing inbox work', async () => { + const path = socketPath('oversized') + await startUdsMessaging(path, { isExplicit: true }) + + await new Promise((resolve, reject) => { + const conn = createConnection(path, () => { + conn.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for oversized frame close')) + }) + conn.on('close', () => resolve()) + conn.on('error', () => resolve()) + }) + + expect(drainInbox()).toEqual([]) + }) + + test('default socket path is regenerated after stop', async () => { + const firstPath = getDefaultUdsSocketPath() + await startUdsMessaging(firstPath) + await 
stopUdsMessaging() + + expect(getDefaultUdsSocketPath()).not.toBe(firstPath) + }) + + test('rejects oversized receiver responses before retaining them', async () => { + const path = socketPath('oversized-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('UDS response frame exceeded size limit') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('rejects closed receiver responses without waiting for timeout', async () => { + const path = socketPath('closed-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.end() + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('before response') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('rejects malformed receiver responses without waiting for timeout', async () => { + const path = socketPath('malformed-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('{not-json\n') + }) + }) + await new Promise((resolve, reject) => { + 
receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('Invalid UDS response frame') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('rejects inline auth token UDS targets instead of parsing them', async () => { + const path = socketPath('inline-token') + + expect(formatUdsAddress(path)).toBe(`uds:${path}`) + + const targetWithToken = `${path}#token=secret` + expect(() => parseUdsTarget(targetWithToken)).toThrow('inline auth token') + try { + parseUdsTarget(targetWithToken) + } catch (error) { + expect((error as Error).message).not.toContain('secret') + } + + const { sendToUdsSocket } = await import('../udsClient.js') + await expect(sendToUdsSocket(targetWithToken, 'hello')).rejects.toThrow( + 'inline auth token', + ) + }) + + test('fails closed and cleans temp files when capability target is occupied', async () => { + const path = socketPath('capability-target-dir') + const capabilityDir = join(tempConfigDir, 'messaging-capabilities') + const capabilityName = `${createHash('sha256').update(path).digest('hex')}.json` + await mkdir(join(capabilityDir, capabilityName), { + recursive: true, + mode: 0o700, + }) + + await expect( + startUdsMessaging(path, { isExplicit: true }), + ).rejects.toThrow() + + expect(process.env.CLAUDE_CODE_MESSAGING_SOCKET).toBeUndefined() + expect(await readdir(capabilityDir)).toEqual([capabilityName]) + }) + + if (process.platform !== 'win32') { + test('creates the listening socket with owner-only permissions', async () => { + const path = socketPath('socket-mode') + await startUdsMessaging(path, { isExplicit: true }) + + const mode = (await stat(path)).mode & 0o777 + expect(mode).toBe(0o600) + }) + + test('fails closed when the capability directory is not private', async () => { + 
const previousConfigDir = process.env.CLAUDE_CONFIG_DIR + const tempHome = join( + tmpdir(), + `uds-capability-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + process.env.CLAUDE_CONFIG_DIR = tempHome + const capabilityDir = join(tempHome, 'messaging-capabilities') + await mkdir(capabilityDir, { recursive: true, mode: 0o755 }) + await chmod(capabilityDir, 0o755) + + try { + const path = socketPath('broad-capdir') + await expect( + startUdsMessaging(path, { isExplicit: true }), + ).rejects.toThrow('permissions are too broad') + await expect(stat(path)).rejects.toThrow() + } finally { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + } + }) + + test('fails closed when the capability directory is a symlink', async () => { + const previousConfigDir = process.env.CLAUDE_CONFIG_DIR + const tempHome = join( + tmpdir(), + `uds-capability-link-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + const target = join(tempHome, 'target') + process.env.CLAUDE_CONFIG_DIR = tempHome + await mkdir(target, { recursive: true, mode: 0o700 }) + await symlink(target, join(tempHome, 'messaging-capabilities'), 'dir') + + try { + await expect( + startUdsMessaging(socketPath('symlink-capdir'), { isExplicit: true }), + ).rejects.toThrow('not a private directory') + } finally { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + } + }) + + test('fails closed when an explicit socket parent is not private', async () => { + const parent = join( + tmpdir(), + `uds-socket-parent-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + await mkdir(parent, { recursive: true, mode: 0o755 }) + await chmod(parent, 0o755) + + try { + 
await expect( + startUdsMessaging(join(parent, 'messaging.sock'), { + isExplicit: true, + }), + ).rejects.toThrow('socket parent permissions are too broad') + } finally { + await rm(parent, { recursive: true, force: true }) + } + }) + + test('fails closed when an explicit socket parent is a file', async () => { + const parentFile = join( + tmpdir(), + `uds-socket-parent-file-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + await writeFile(parentFile, 'not a directory', 'utf-8') + + try { + await expect( + startUdsMessaging(join(parentFile, 'messaging.sock'), { + isExplicit: true, + }), + ).rejects.toThrow('socket parent is not a directory') + } finally { + await rm(parentFile, { force: true }) + } + }) + + test('stop tolerates an already removed socket path', async () => { + const path = socketPath('already-removed') + await startUdsMessaging(path, { isExplicit: true }) + await unlink(path) + + await stopUdsMessaging() + + expect(process.env.CLAUDE_CODE_MESSAGING_SOCKET).toBeUndefined() + }) + + test('rejects clients over the configured connection cap', async () => { + const path = socketPath('client-cap') + await startUdsMessaging(path, { isExplicit: true }) + const sockets: ReturnType[] = [] + + try { + for (let i = 0; i < MAX_UDS_CLIENTS; i++) { + const socket = await new Promise>( + (resolve, reject) => { + const conn = createConnection(path, () => resolve(conn)) + conn.on('error', reject) + }, + ) + sockets.push(socket) + } + + await new Promise((resolve, reject) => { + const extra = createConnection(path) + extra.on('close', () => resolve()) + extra.on('error', reject) + extra.setTimeout(5_000, () => { + extra.destroy() + reject(new Error('Timed out waiting for client cap close')) + }) + }) + } finally { + for (const socket of sockets) { + socket.destroy() + } + } + }) + } +}) diff --git a/src/utils/__tests__/udsResponseReader.test.ts b/src/utils/__tests__/udsResponseReader.test.ts new file mode 100644 index 000000000..71203da62 --- 
/dev/null +++ b/src/utils/__tests__/udsResponseReader.test.ts @@ -0,0 +1,171 @@ +import { describe, expect, test } from 'bun:test' +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' +import { attachUdsResponseReader } from '../udsResponseReader.js' + +class FakeSocket extends EventEmitter { + destroyed = false + ended = false + + destroy(): this { + this.destroyed = true + this.emit('close', true) + return this + } + + end(): this { + this.ended = true + this.emit('close', false) + return this + } + + emitData(chunk: Buffer): void { + this.emit('data', chunk) + } +} + +function asSocket(socket: FakeSocket): Socket { + return socket as unknown as Socket +} + +describe('attachUdsResponseReader', () => { + test('tracks byte limits across split multibyte response chunks', () => { + const socket = new FakeSocket() + let settled = false + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settled = true + settledError = error + }, + }) + + const multibyte = String.fromCodePoint(0x20ac) + const frame = Buffer.from( + JSON.stringify({ type: 'response', data: `ok ${multibyte}` }) + '\n', + 'utf8', + ) + const multibyteStart = frame.indexOf(Buffer.from(multibyte, 'utf8')[0]) + + socket.emitData(frame.subarray(0, multibyteStart + 1)) + expect(settled).toBe(false) + + socket.emitData(frame.subarray(multibyteStart + 1)) + expect(settled).toBe(true) + expect(settledError).toBeUndefined() + expect(socket.ended).toBe(true) + }) + + test('rejects malformed response frames immediately', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emitData(Buffer.from('{bad-json}\n')) + + expect(settledError?.message).toBe('Invalid UDS response frame') + expect(socket.destroyed).toBe(true) + }) + + test('skips blank 
frames before a valid response', () => { + const socket = new FakeSocket() + let settled = false + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settled = true + settledError = error + }, + }) + + socket.emitData(Buffer.from('\n \n')) + expect(settled).toBe(false) + + socket.emitData(Buffer.from(`${JSON.stringify({ type: 'response' })}\n`)) + expect(settled).toBe(true) + expect(settledError).toBeUndefined() + expect(socket.ended).toBe(true) + }) + + test('rejects receiver error frames', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emitData( + Buffer.from(`${JSON.stringify({ type: 'error', data: 'denied' })}\n`), + ) + + expect(settledError?.message).toBe('denied') + expect(socket.destroyed).toBe(true) + }) + + test('uses custom socket error formatting', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + formatSocketError: error => + new Error(`wrapped:${(error as Error).message}`), + }) + + socket.emit('error', new Error('connect failed')) + + expect(settledError?.message).toBe('wrapped:connect failed') + expect(socket.destroyed).toBe(true) + }) + + test('rejects socket end before response', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emit('end') + + expect(settledError?.message).toBe('UDS socket ended before response') + expect(socket.destroyed).toBe(true) + }) + + test('rejects clean socket close before response', () => { + const socket = new FakeSocket() + let settledError: Error | 
/**
 * Optional limits and error hooks for `attachNdjsonFramer`.
 */
export type NdjsonFramerOptions = {
  /** Largest accepted frame in bytes, newline excluded. Unlimited when omitted. */
  maxFrameBytes?: number
  /** Called once for each frame that exceeds `maxFrameBytes`. */
  onFrameError?: (error: Error) => void
  /** Destroy the socket when a frame is oversized. Defaults to `true`. */
  destroyOnFrameError?: boolean
  /** Called when `parse` or `onMessage` throws for a line. */
  onInvalidFrame?: (error: Error) => void
  /** Destroy the socket when a line is invalid. Defaults to `false`. */
  destroyOnInvalidFrame?: boolean
}

/**
 * Attach an NDJSON framer to a socket. Calls `onMessage` for each
 * complete JSON line received. Malformed lines are skipped by default;
 * callers may opt into error callbacks or socket destruction.
 *
 * Frame bytes are buffered raw and decoded only once the terminating newline
 * arrives, so a multi-byte UTF-8 character that straddles two `data` events
 * is decoded intact. After an oversized frame is rejected the partial frame
 * is discarded and input is skipped up to the next newline, so a socket that
 * is kept open never receives a frame spliced from unrelated bytes.
 *
 * @param socket - Socket whose `data` events are framed.
 * @param onMessage - Receives each successfully parsed, non-blank line.
 * @param parse - Optional custom JSON parser (defaults to JSON.parse).
 * @param options - Frame size limit and error handling hooks.
 */
export function attachNdjsonFramer<T>(
  socket: Socket,
  onMessage: (msg: T) => void,
  parse: (text: string) => T = text => JSON.parse(text) as T,
  options: NdjsonFramerOptions = {},
): void {
  // Raw pieces of the frame currently being assembled, and their total size.
  let pending: Buffer[] = []
  let pendingBytes = 0
  // True while the remainder of a rejected frame is being skipped.
  let discarding = false
  const maxFrameBytes = options.maxFrameBytes ?? Number.POSITIVE_INFINITY

  const resetFrame = (): void => {
    pending = []
    pendingBytes = 0
  }

  const rejectOversizedFrame = (bytes: number, chunk: Buffer): void => {
    resetFrame()
    // The rest of this chunk is dropped by the caller. Unless the chunk ends
    // exactly on a frame boundary, resynchronise at the next newline.
    discarding = chunk[chunk.length - 1] !== 0x0a
    const error = new Error(
      `NDJSON frame exceeded ${maxFrameBytes} bytes (${bytes})`,
    )
    options.onFrameError?.(error)
    if (options.destroyOnFrameError ?? true) {
      socket.destroy(error)
    }
  }

  const rejectInvalidFrame = (error: unknown): void => {
    const frameError =
      error instanceof Error ? error : new Error('Invalid NDJSON frame')
    options.onInvalidFrame?.(frameError)
    if (options.destroyOnInvalidFrame ?? false) {
      socket.destroy(frameError)
    }
  }

  const emitLine = (line: string): void => {
    if (!line.trim()) return
    try {
      onMessage(parse(line))
    } catch (error) {
      rejectInvalidFrame(error)
    }
  }

  socket.on('data', (chunk: Buffer) => {
    let start = 0
    if (discarding) {
      const newline = chunk.indexOf(0x0a)
      if (newline === -1) return
      discarding = false
      start = newline + 1
    }

    for (let index = start; index < chunk.length; index++) {
      if (chunk[index] !== 0x0a) continue

      const frameBytes = pendingBytes + (index - start)
      // With the default (infinite) limit this comparison is always false.
      if (frameBytes > maxFrameBytes) {
        rejectOversizedFrame(frameBytes, chunk)
        return
      }

      pending.push(chunk.subarray(start, index))
      const line = Buffer.concat(pending, frameBytes).toString('utf8')
      resetFrame()
      start = index + 1
      emitLine(line)
    }

    const tailBytes = chunk.length - start
    if (pendingBytes + tailBytes > maxFrameBytes) {
      rejectOversizedFrame(pendingBytes + tailBytes, chunk)
      return
    }

    if (tailBytes > 0) {
      // Copy the tail so the retained bytes do not alias the caller's chunk.
      pending.push(Buffer.from(chunk.subarray(start)))
      pendingBytes += tailBytes
    }
  })
}
processMailboxPermissionResponse({ @@ -801,10 +801,10 @@ async function waitForNextPromptOrShutdown( logForDebugging( `[inProcessRunner] ${identity.agentName} received shutdown request from ${shutdownParsed?.from} (prioritized over ${skippedUnread} unread messages)`, ) - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - shutdownIndex, + msg, ) return { type: 'shutdown_request', @@ -839,10 +839,10 @@ async function waitForNextPromptOrShutdown( logForDebugging( `[inProcessRunner] ${identity.agentName} received new message from ${msg.from} (index ${selectedIndex})`, ) - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - selectedIndex, + msg, ) return { type: 'new_message', diff --git a/src/utils/teammateMailbox.ts b/src/utils/teammateMailbox.ts index eb72fcc21..ad9b22f93 100644 --- a/src/utils/teammateMailbox.ts +++ b/src/utils/teammateMailbox.ts @@ -7,7 +7,8 @@ * Note: Inboxes are keyed by agent name within a team. 
*/ -import { mkdir, readFile, writeFile } from 'fs/promises' +import { randomBytes } from 'crypto' +import { mkdir, readFile, rename, stat, unlink, writeFile } from 'fs/promises' import { join } from 'path' import { z } from 'zod/v4' import { TEAMMATE_MESSAGE_TAG } from '../constants/xml.js' @@ -40,6 +41,13 @@ const LOCK_OPTIONS = { }, } +export const MAX_MAILBOX_MESSAGES = 1_000 +export const MAX_READ_MAILBOX_MESSAGES = 200 +export const MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES = 2_000 +export const MAX_MAILBOX_MESSAGE_TEXT_BYTES = 64 * 1024 +export const MAX_MAILBOX_RETAINED_BYTES = 2 * 1024 * 1024 +export const MAX_MAILBOX_FILE_BYTES = 4 * 1024 * 1024 + export type TeammateMessage = { from: string text: string @@ -49,6 +57,223 @@ export type TeammateMessage = { summary?: string // 5-10 word summary shown as preview in the UI } +function isJsonLikeMessage(text: string): boolean { + const trimmed = text.trimStart() + return trimmed.startsWith('{') || trimmed.startsWith('[') +} + +function shouldRetainUnreadAsProtocolMessage( + message: TeammateMessage, +): boolean { + if (message.read) return false + if (isStructuredProtocolMessage(message.text)) return true + if (!isJsonLikeMessage(message.text)) return false + + try { + const parsed = jsonParse(message.text) + return Boolean( + parsed && + typeof parsed === 'object' && + 'type' in (parsed as Record), + ) + } catch { + return false + } +} + +function sameMailboxMessage(a: TeammateMessage, b: TeammateMessage): boolean { + return a.from === b.from && a.timestamp === b.timestamp && a.text === b.text +} + +function mailboxMessageStorageBytes(message: TeammateMessage): number { + return Buffer.byteLength(jsonStringify(message), 'utf8') +} + +function assertMailboxMessageSize(message: TeammateMessage): void { + const textBytes = Buffer.byteLength(message.text, 'utf8') + if (textBytes > MAX_MAILBOX_MESSAGE_TEXT_BYTES) { + throw new Error( + `Mailbox message text exceeds ${MAX_MAILBOX_MESSAGE_TEXT_BYTES} bytes`, + ) + } +} + 
/**
 * Validate an untyped value and rebuild it as a `TeammateMessage`.
 *
 * Only the known fields are copied. `color` and `summary` are kept only when
 * they are strings; any other property on the input is dropped.
 *
 * @throws Error when the value is not an object, when `from`, `text` or
 *   `timestamp` is not a string, when `read` is not a boolean, or when
 *   `assertMailboxMessageSize` rejects the message.
 */
function toMailboxMessage(value: unknown): TeammateMessage {
  if (typeof value !== 'object' || value === null) {
    throw new Error('Invalid mailbox message: expected object')
  }

  const { from, text, timestamp, read, color, summary } = value as Record<
    string,
    unknown
  >
  if (
    !(
      typeof from === 'string' &&
      typeof text === 'string' &&
      typeof timestamp === 'string' &&
      typeof read === 'boolean'
    )
  ) {
    throw new Error('Invalid mailbox message shape')
  }

  // Required fields first, then optional ones, to keep the stored key order.
  const message: TeammateMessage = { from, text, timestamp, read }
  if (typeof color === 'string') message.color = color
  if (typeof summary === 'string') message.summary = summary

  assertMailboxMessageSize(message)
  return message
}
/**
 * Select the subset of `messages` to retain, preserving original order.
 *
 * Messages are admitted newest-first in three passes that share one byte
 * budget: unread protocol messages, then other unread messages, then read
 * messages. A message that does not fit in the remaining budget is skipped
 * and older, smaller messages may still be admitted.
 *
 * @param messages - Mailbox contents, oldest first.
 * @param limits - Overrides for the module-level default caps.
 * @returns The retained messages (same object identities) in original order,
 *   or an empty array when the caps allow nothing to be kept.
 */
export function compactMailboxMessages(
  messages: TeammateMessage[],
  limits: {
    maxMessages?: number
    maxReadMessages?: number
    maxUnreadProtocolMessages?: number
    maxRetainedBytes?: number
  } = {},
): TeammateMessage[] {
  const messageCap = limits.maxMessages ?? MAX_MAILBOX_MESSAGES
  const readCap = limits.maxReadMessages ?? MAX_READ_MAILBOX_MESSAGES
  const protocolCap =
    limits.maxUnreadProtocolMessages ?? MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES
  const byteCap = limits.maxRetainedBytes ?? MAX_MAILBOX_RETAINED_BYTES

  if (byteCap <= 0) return []
  if (messageCap <= 0 && protocolCap <= 0) return []

  // Classify every message once, collecting indexes newest-first per category.
  const protocolNewestFirst: number[] = []
  const unreadNewestFirst: number[] = []
  const readNewestFirst: number[] = []
  for (let i = messages.length - 1; i >= 0; i--) {
    const candidate = messages[i]
    if (!candidate) continue
    if (shouldRetainUnreadAsProtocolMessage(candidate)) {
      protocolNewestFirst.push(i)
    } else if (candidate.read) {
      readNewestFirst.push(i)
    } else {
      unreadNewestFirst.push(i)
    }
  }

  const kept = new Set<number>()
  let usedBytes = 0
  // Admit a message only if it fits in what is left of the byte budget.
  const admit = (index: number): boolean => {
    const size = mailboxMessageStorageBytes(messages[index]!)
    if (usedBytes + size > byteCap) return false
    kept.add(index)
    usedBytes += size
    return true
  }

  let protocolKept = 0
  for (const index of protocolNewestFirst) {
    if (protocolKept >= protocolCap) break
    if (admit(index)) protocolKept++
  }

  // Unread and read non-protocol messages share the `messageCap` counter.
  let ordinaryKept = 0
  for (const index of unreadNewestFirst) {
    if (ordinaryKept >= messageCap) break
    if (admit(index)) ordinaryKept++
  }

  let readKept = 0
  for (const index of readNewestFirst) {
    if (ordinaryKept >= messageCap || readKept >= readCap) break
    if (admit(index)) {
      readKept++
      ordinaryKept++
    }
  }

  return messages.filter((_message, index) => kept.has(index))
}
/**
 * Mark one unread message as read, locating it by content instead of index.
 *
 * Under the inbox lock the mailbox is re-read and the first unread message
 * that `sameMailboxMessage` considers equal to `expectedMessage` is flagged
 * as read; the mailbox is then written back through `writeCompactedMailbox`.
 *
 * @returns `true` when a message was updated and written. `false` when no
 *   unread match exists or when any step throws; errors other than `ENOENT`
 *   are passed to `logError` first.
 */
export async function markMessageAsReadByIdentity(
  agentName: string,
  teamName: string | undefined,
  expectedMessage: TeammateMessage,
): Promise<boolean> {
  const inboxPath = getInboxPath(agentName, teamName)

  let unlock: (() => Promise<void>) | undefined
  try {
    unlock = await lockfile.lock(inboxPath, {
      lockfilePath: `${inboxPath}.lock`,
      ...LOCK_OPTIONS,
    })

    const mailbox = await readMailboxForMutation(agentName, teamName)
    const position = mailbox.findIndex(
      candidate =>
        !candidate.read && sameMailboxMessage(candidate, expectedMessage),
    )
    if (position === -1) return false

    const target = mailbox[position]!
    mailbox[position] = { ...target, read: true }
    await writeCompactedMailbox(
      inboxPath,
      mailbox,
      'markMessageAsReadByIdentity',
    )
    return true
  } catch (error) {
    // A missing inbox is reported as "nothing to mark" without logging.
    if (getErrnoCode(error) !== 'ENOENT') {
      logError(error)
    }
    return false
  } finally {
    if (unlock) {
      await unlock()
    }
  }
}
markMessagesAsReadByPredicate( !m.read && predicate(m) ? { ...m, read: true } : m, ) - await writeFile(inboxPath, jsonStringify(updatedMessages, null, 2), 'utf-8') + await writeCompactedMailbox( + inboxPath, + updatedMessages, + 'markMessagesAsReadByPredicate', + ) } catch (error) { const code = getErrnoCode(error) if (code === 'ENOENT') { @@ -1161,7 +1430,12 @@ export function getLastPeerDmSummary(messages: Message[]): string | undefined { if (!Array.isArray(content)) continue for (const block of content) { if (typeof block === 'string') continue - const b = block as unknown as { type: string; name?: string; input?: Record; [key: string]: unknown } + const b = block as unknown as { + type: string + name?: string + input?: Record + [key: string]: unknown + } if ( b.type === 'tool_use' && b.name === SEND_MESSAGE_TOOL_NAME && @@ -1177,7 +1451,7 @@ export function getLastPeerDmSummary(messages: Message[]): string | undefined { const to = b.input.to as string const summary = 'summary' in b.input && typeof b.input.summary === 'string' - ? b.input.summary as string + ? 
(b.input.summary as string) : (b.input.message as string).slice(0, 80) return `[to ${to}] ${summary}` } diff --git a/src/utils/udsClient.ts b/src/utils/udsClient.ts index 781f3ddd1..e33ef3fdb 100644 --- a/src/utils/udsClient.ts +++ b/src/utils/udsClient.ts @@ -16,7 +16,8 @@ import { errorMessage, isFsInaccessible } from './errors.js' import { isProcessRunning } from './genericProcessUtils.js' import { jsonParse, jsonStringify } from './slowOperations.js' import type { SessionKind } from './concurrentSessions.js' -import type { UdsMessage } from './udsMessaging.js' +import { MAX_UDS_FRAME_BYTES, type UdsMessage } from './udsMessaging.js' +import { attachUdsResponseReader, getChunkBytes } from './udsResponseReader.js' // --------------------------------------------------------------------------- // Types @@ -104,9 +105,14 @@ export async function listAllLiveSessions(): Promise { */ export async function listPeers(): Promise { const all = await listAllLiveSessions() - return all.filter( - s => s.pid !== process.pid && s.messagingSocketPath != null, - ) + return all.filter(s => s.pid !== process.pid && s.messagingSocketPath != null) +} + +async function findAuthTokenForSocketPath( + socketPath: string, +): Promise { + const { readUdsCapabilityToken } = await import('./udsMessaging.js') + return readUdsCapabilityToken(socketPath) } // --------------------------------------------------------------------------- @@ -117,10 +123,21 @@ export async function listPeers(): Promise { * Probe a UDS socket to check if a server is listening (ping/pong). * Returns true if the peer responds within the timeout. */ -export async function isPeerAlive(socketPath: string, timeoutMs = 3000): Promise { - return new Promise((resolve) => { +export async function isPeerAlive( + socketPath: string, + timeoutMs = 3000, + authToken?: string, +): Promise { + const token = authToken ?? 
(await findAuthTokenForSocketPath(socketPath)) + if (!token) return false + + return new Promise(resolve => { const conn = createConnection(socketPath, () => { - const ping: UdsMessage = { type: 'ping', ts: new Date().toISOString() } + const ping: UdsMessage = { + type: 'ping', + ts: new Date().toISOString(), + meta: { authToken: token }, + } conn.write(jsonStringify(ping) + '\n') }) @@ -135,7 +152,19 @@ export async function isPeerAlive(socketPath: string, timeoutMs = 3000): Promise }, timeoutMs) let buffer = '' - conn.on('data', (chunk) => { + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + if (!resolved) { + resolved = true + clearTimeout(timer) + conn.destroy() + resolve(false) + } + return + } buffer += chunk.toString() if (buffer.includes('"pong"')) { if (!resolved) { @@ -165,6 +194,13 @@ export async function sendToUdsSocket( targetSocketPath: string, message: string | Record, ): Promise { + const { parseUdsTarget } = await import('./udsMessaging.js') + const target = parseUdsTarget(targetSocketPath) + const authToken = await findAuthTokenForSocketPath(target.socketPath) + if (!authToken) { + throw new Error(`No auth token found for peer at ${target.socketPath}`) + } + const data = typeof message === 'string' ? 
message : jsonStringify(message) const udsMsg: UdsMessage = { type: 'text', @@ -177,18 +213,36 @@ export async function sendToUdsSocket( udsMsg.from = getUdsMessagingSocketPath() return new Promise((resolve, reject) => { - const conn = createConnection(targetSocketPath, () => { - conn.write(jsonStringify(udsMsg) + '\n', (err) => { + let settled = false + let conn: ReturnType + const finish = (error?: Error): void => { + if (settled) return + settled = true + if (error) { + conn.destroy(error) + reject(error) + } else { conn.end() - if (err) reject(err) - else resolve() + resolve() + } + } + + conn = createConnection(target.socketPath, () => { + udsMsg.meta = { ...udsMsg.meta, authToken } + conn.write(jsonStringify(udsMsg) + '\n', err => { + if (err) finish(err) }) }) - conn.on('error', (err) => { - reject(new Error(`Failed to connect to peer at ${targetSocketPath}: ${errorMessage(err)}`)) + attachUdsResponseReader(conn, { + maxFrameBytes: MAX_UDS_FRAME_BYTES, + onSettled: finish, + formatSocketError: err => + new Error( + `Failed to connect to peer at ${target.socketPath}: ${errorMessage(err)}`, + ), }) conn.setTimeout(5000, () => { - conn.destroy(new Error('Connection timed out')) + finish(new Error('Connection timed out')) }) }) } diff --git a/src/utils/udsMessaging.ts b/src/utils/udsMessaging.ts index 1c95ab63c..b30cba137 100644 --- a/src/utils/udsMessaging.ts +++ b/src/utils/udsMessaging.ts @@ -8,14 +8,26 @@ * but can be overridden via --messaging-socket-path. 
*/ +import { createHash, randomBytes, timingSafeEqual } from 'crypto' import { createServer, type Server, type Socket } from 'net' -import { mkdir, unlink } from 'fs/promises' +import { + chmod, + lstat, + mkdir, + open, + readFile, + rename, + unlink, +} from 'fs/promises' import { dirname, join } from 'path' import { tmpdir } from 'os' import { registerCleanup } from './cleanupRegistry.js' import { logForDebugging } from './debug.js' import { errorMessage } from './errors.js' +import { getClaudeConfigHomeDir } from './envUtils.js' import { attachNdjsonFramer } from './ndjsonFramer.js' +import { attachUdsResponseReader } from './udsResponseReader.js' +import { logError } from './log.js' import { jsonParse, jsonStringify } from './slowOperations.js' // --------------------------------------------------------------------------- @@ -27,6 +39,7 @@ export type UdsMessageType = | 'notification' | 'query' | 'response' + | 'error' | 'ping' | 'pong' @@ -60,6 +73,17 @@ let onEnqueueCb: (() => void) | null = null const clients = new Set() const inbox: UdsInboxEntry[] = [] let nextId = 1 +let defaultSocketPath: string | null = null +let authToken: string | null = null +let capabilityFilePath: string | null = null +let inboxBytes = 0 + +export const MAX_UDS_INBOX_ENTRIES = 1_000 +export const MAX_UDS_FRAME_BYTES = 64 * 1024 +export const MAX_UDS_INBOX_BYTES = 2 * 1024 * 1024 +export const MAX_UDS_CLIENTS = 128 +export const UDS_AUTH_TIMEOUT_MS = 2_000 +export const UDS_IDLE_TIMEOUT_MS = 30_000 // --------------------------------------------------------------------------- // Public API — socket path helpers @@ -74,10 +98,19 @@ let nextId = 1 * transparently, but we use the pipe format on Windows for Node.js compat. 
/**
 * Return this process's default messaging socket path. The path is generated
 * on the first call and the cached value is returned afterwards.
 *
 * On Windows the result is a named-pipe path. Elsewhere it is
 * `<tmpdir>/claude-code-socks/<pid>-<nonce>/messaging.sock`. Unix domain
 * socket paths are limited by `sizeof(sockaddr_un.sun_path)` (typically 107
 * bytes on Linux and 103 on macOS), and a long temporary directory pushes
 * this layout past that limit. When the preferred path would not fit, the
 * same layout is rooted at `/tmp` instead.
 */
export function getDefaultUdsSocketPath(): string {
  if (defaultSocketPath) return defaultSocketPath
  const nonce = randomBytes(16).toString('hex')
  if (process.platform === 'win32') {
    defaultSocketPath = `\\\\.\\pipe\\claude-code-${process.pid}-${nonce}`
    return defaultSocketPath
  }

  const socketPathUnder = (baseDir: string): string =>
    join(
      baseDir,
      'claude-code-socks',
      `${process.pid}-${nonce}`,
      'messaging.sock',
    )
  const maxSocketPathBytes = process.platform === 'linux' ? 107 : 103
  const preferred = socketPathUnder(tmpdir())
  // Fall back only when the preferred path could not be bound anyway.
  defaultSocketPath =
    Buffer.byteLength(preferred, 'utf8') <= maxSocketPathBytes
      ? preferred
      : socketPathUnder('/tmp')
  return defaultSocketPath
}
new Error( + `[udsMessaging] ${label} is not a private directory: ${dir}`, + ) + } + if (process.platform !== 'win32') { + const broadMode = Number(stat.mode) & 0o077 + if (broadMode !== 0) { + throw new Error( + `[udsMessaging] ${label} permissions are too broad: ${dir}`, + ) + } + if ( + typeof process.getuid === 'function' && + Number(stat.uid) !== process.getuid() + ) { + throw new Error( + `[udsMessaging] ${label} owner does not match current user: ${dir}`, + ) + } + } +} + +async function writePrivateFileExclusive( + path: string, + content: string, +): Promise { + const handle = await open(path, 'wx', 0o600) + try { + await handle.writeFile(content, 'utf-8') + } finally { + await handle.close() + } + await chmod(path, 0o600) +} + +async function ensureSocketParent(path: string): Promise { + const dir = dirname(path) + try { + const stat = await lstat(dir) + if (!stat.isDirectory() || stat.isSymbolicLink()) { + throw new Error( + `[udsMessaging] socket parent is not a directory: ${dir}`, + ) + } + assertPrivateDirectory(stat, dir, 'socket parent') + return + } catch (error) { + if (!isNotFound(error)) throw error + } + + await mkdir(dir, { recursive: true, mode: 0o700 }) + await chmod(dir, 0o700) +} + +async function writeCapabilityFile( + socket: string, + token: string, +): Promise { + const dir = getCapabilityDir() + await assertPrivateCapabilityDir(dir) + const target = getCapabilityPath(socket) + const temp = `${target}.${process.pid}.${randomBytes(8).toString('hex')}.tmp` + try { + await writePrivateFileExclusive( + temp, + jsonStringify({ socketPath: socket, authToken: token }), + ) + await rename(temp, target) + } catch (error) { + try { + await unlink(temp) + } catch { + // Temp file may not exist if exclusive creation failed. 
+ } + throw error + } + capabilityFilePath = target +} + +export async function readUdsCapabilityToken( + socket: string, +): Promise { + try { + const parsed = jsonParse( + await readFile(getCapabilityPath(socket), 'utf-8'), + ) as Record + if (parsed.socketPath === socket && typeof parsed.authToken === 'string') { + return parsed.authToken + } + } catch { + // Missing or unreadable capability file means the peer is not addressable. + } + return undefined +} + // --------------------------------------------------------------------------- // Inbox // --------------------------------------------------------------------------- @@ -101,16 +281,121 @@ export function setOnEnqueue(cb: (() => void) | null): void { } /** - * Drain all pending inbox messages, marking them processed. + * Drain all pending inbox messages and release retained history. */ export function drainInbox(): UdsInboxEntry[] { - const pending = inbox.filter(e => e.status === 'pending') + const pending = inbox.splice(0, inbox.length) + inboxBytes = 0 for (const entry of pending) { entry.status = 'processed' } return pending } +function getMessageBytes(message: UdsMessage): number { + return Buffer.byteLength(jsonStringify(message), 'utf8') +} + +function enqueueInboxEntry(entry: UdsInboxEntry): boolean { + const entryBytes = getMessageBytes(entry.message) + if ( + entryBytes > MAX_UDS_FRAME_BYTES || + inbox.length >= MAX_UDS_INBOX_ENTRIES || + inboxBytes + entryBytes > MAX_UDS_INBOX_BYTES + ) { + logError( + new Error( + `[udsMessaging] inbox full (${inbox.length}/${MAX_UDS_INBOX_ENTRIES}, ${inboxBytes}/${MAX_UDS_INBOX_BYTES} bytes); dropping message type=${entry.message.type}`, + ), + ) + return false + } + inbox.push(entry) + inboxBytes += entryBytes + return true +} + +function ensureAuthToken(): string { + if (!authToken) { + authToken = randomBytes(32).toString('hex') + } + return authToken +} + +function getMessageAuthToken(message: UdsMessage): string | undefined { + const token = 
message.meta?.authToken + return typeof token === 'string' ? token : undefined +} + +function isAuthorizedMessage(message: UdsMessage): boolean { + const provided = getMessageAuthToken(message) + if (!provided || !authToken) return false + const providedBuffer = Buffer.from(provided, 'utf8') + const expectedBuffer = Buffer.from(authToken, 'utf8') + if (providedBuffer.length !== expectedBuffer.length) return false + return timingSafeEqual(providedBuffer, expectedBuffer) +} + +function writeSocketMessage(socket: Socket, message: UdsMessage): void { + if (socket.destroyed) return + socket.write(jsonStringify(message) + '\n') +} + +function writeSocketMessageAndDestroy(socket: Socket, message: UdsMessage): void { + if (socket.destroyed) return + socket.write(jsonStringify(message) + '\n', () => { + if (!socket.destroyed) socket.destroy() + }) +} + +function writeSocketErrorAndDestroy(socket: Socket, data: string): void { + writeSocketMessageAndDestroy(socket, { + type: 'error', + data, + ts: new Date().toISOString(), + }) +} + +function unrefTimer(timer: ReturnType): void { + const maybeUnref = (timer as { unref?: () => void }).unref + if (typeof maybeUnref === 'function') { + maybeUnref.call(timer) + } +} + +async function closeServer(serverToClose: Server): Promise { + await new Promise(resolve => { + serverToClose.close(() => resolve()) + }) +} + +async function removeSocketPath(path: string): Promise { + if (process.platform === 'win32') return + try { + await unlink(path) + } catch { + // Already gone. + } +} + +function stripAuthToken(message: UdsMessage): UdsMessage { + const { authToken: _authToken, ...metaWithoutAuth } = message.meta ?? {} + return { + ...message, + meta: Object.keys(metaWithoutAuth).length > 0 ? 
metaWithoutAuth : undefined, + } +} + +function withRequestAuthToken(message: UdsMessage, token: string): UdsMessage { + return { + ...message, + meta: { + ...message.meta, + authToken: token, + }, + } +} + // --------------------------------------------------------------------------- // Server // --------------------------------------------------------------------------- @@ -132,7 +417,7 @@ export async function startUdsMessaging( // Ensure parent directory exists (skip on Windows — pipe paths aren't files) if (process.platform !== 'win32') { - await mkdir(dirname(path), { recursive: true }) + await ensureSocketParent(path) } // Clean up stale socket file (skip on Windows — pipe paths aren't files) @@ -144,69 +429,195 @@ export async function startUdsMessaging( } } - socketPath = path + const token = ensureAuthToken() + let startedServer: Server | null = null + let exportedSocketEnv = false + try { + await new Promise((resolve, reject) => { + const srv = createServer(socket => { + if (clients.size >= MAX_UDS_CLIENTS) { + logForDebugging( + `[udsMessaging] rejected client: ${clients.size}/${MAX_UDS_CLIENTS} clients already connected`, + ) + socket.destroy() + return + } + clients.add(socket) + logForDebugging( + `[udsMessaging] client connected (total: ${clients.size})`, + ) + let authenticated = false + let closing = false + const closeWithError = (data: string): void => { + if (closing || socket.destroyed) return + closing = true + socket.pause() + writeSocketErrorAndDestroy(socket, data) + } + const authTimer = setTimeout(() => { + if (authenticated || socket.destroyed) return + logForDebugging('[udsMessaging] closing unauthenticated idle client') + closeWithError('authentication timeout') + }, UDS_AUTH_TIMEOUT_MS) + unrefTimer(authTimer) + socket.setTimeout(UDS_IDLE_TIMEOUT_MS, () => { + logForDebugging('[udsMessaging] closing idle client') + closeWithError('idle timeout') + }) + + attachNdjsonFramer( + socket, + msg => { + if (!isAuthorizedMessage(msg)) { + 
logForDebugging( + `[udsMessaging] rejected unauthenticated message type=${msg.type}`, + ) + closeWithError('unauthorized') + return + } + if (!authenticated) { + authenticated = true + clearTimeout(authTimer) + } - await new Promise((resolve, reject) => { - const srv = createServer(socket => { - clients.add(socket) - logForDebugging( - `[udsMessaging] client connected (total: ${clients.size})`, - ) + // Handle ping with automatic pong + if (msg.type === 'ping') { + writeSocketMessage(socket, { + type: 'pong', + from: socketPath ?? undefined, + ts: new Date().toISOString(), + }) + return + } - attachNdjsonFramer( - socket, - msg => { - // Handle ping with automatic pong - if (msg.type === 'ping') { - const pong: UdsMessage = { - type: 'pong', - from: socketPath ?? undefined, - ts: new Date().toISOString(), + // Enqueue into inbox + const sanitizedMessage = stripAuthToken(msg) + const entry: UdsInboxEntry = { + id: `uds-${nextId++}`, + message: sanitizedMessage, + receivedAt: Date.now(), + status: 'pending', } - if (!socket.destroyed) { - socket.write(jsonStringify(pong) + '\n') + if (!enqueueInboxEntry(entry)) { + closeWithError('inbox full') + return } - return - } + logForDebugging( + `[udsMessaging] enqueued message type=${msg.type} from=${msg.from ?? 
'unknown'}`, + ) + writeSocketMessage(socket, { + type: 'response', + data: 'ok', + ts: new Date().toISOString(), + meta: { id: entry.id }, + }) + onEnqueueCb?.() + }, + text => jsonParse(text) as UdsMessage, + { + maxFrameBytes: MAX_UDS_FRAME_BYTES, + onFrameError: error => { + logForDebugging(`[udsMessaging] ${error.message}`) + closeWithError(error.message) + }, + onInvalidFrame: error => { + logForDebugging( + `[udsMessaging] invalid client frame: ${errorMessage(error)}`, + ) + closeWithError('invalid frame') + }, + destroyOnFrameError: false, + }, + ) + + socket.on('close', () => { + clearTimeout(authTimer) + clients.delete(socket) + }) + + socket.on('error', err => { + clearTimeout(authTimer) + clients.delete(socket) + logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) + }) + }) - // Enqueue into inbox - const entry: UdsInboxEntry = { - id: `uds-${nextId++}`, - message: msg, - receivedAt: Date.now(), - status: 'pending', - } - inbox.push(entry) - logForDebugging( - `[udsMessaging] enqueued message type=${msg.type} from=${msg.from ?? 'unknown'}`, - ) - onEnqueueCb?.() - }, - text => jsonParse(text) as UdsMessage, - ) + const rejectBeforeListen = (error: Error): void => { + reject(error) + } + const logRuntimeError = (error: Error): void => { + logForDebugging( + `[udsMessaging] server error on ${path}${opts?.isExplicit ? 
' (explicit)' : ''}: ${errorMessage(error)}`, + ) + } - socket.on('close', () => { - clients.delete(socket) - }) + srv.once('error', rejectBeforeListen) - socket.on('error', err => { - clients.delete(socket) - logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) + srv.listen(path, () => { + void (async () => { + try { + if (process.platform !== 'win32') { + await chmod(path, 0o600) + } + srv.off('error', rejectBeforeListen) + srv.on('error', logRuntimeError) + server = srv + startedServer = srv + resolve() + } catch (error) { + srv.off('error', rejectBeforeListen) + const closeError = + error instanceof Error ? error : new Error(errorMessage(error)) + let rejected = false + const rejectOnce = (): void => { + if (rejected) return + rejected = true + reject(closeError) + } + const fallback = setTimeout(rejectOnce, 1_000) + unrefTimer(fallback) + srv.close(() => { + clearTimeout(fallback) + rejectOnce() + }) + } + })() }) }) - srv.on('error', reject) - - srv.listen(path, () => { - server = srv - // Export so child processes can discover the socket - process.env.CLAUDE_CODE_MESSAGING_SOCKET = path - logForDebugging( - `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, - ) - resolve() - }) - }) + await writeCapabilityFile(path, token) + socketPath = path + // Export so child processes can discover the socket only after the + // capability file exists and the listener is ready. + process.env.CLAUDE_CODE_MESSAGING_SOCKET = path + exportedSocketEnv = true + logForDebugging( + `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, + ) + } catch (error) { + if (capabilityFilePath) { + try { + await unlink(capabilityFilePath) + } catch { + // Already gone. 
+ } + capabilityFilePath = null + } + if (startedServer) { + await closeServer(startedServer) + } + if (server === startedServer) { + server = null + } + await removeSocketPath(path) + if (exportedSocketEnv) { + delete process.env.CLAUDE_CODE_MESSAGING_SOCKET + } + socketPath = null + defaultSocketPath = null + authToken = null + throw error + } // Register cleanup so the socket file is removed on exit registerCleanup(async () => { @@ -218,6 +629,7 @@ export async function startUdsMessaging( * Stop the UDS messaging server and clean up the socket file. */ export async function stopUdsMessaging(): Promise { + defaultSocketPath = null if (!server) return // Close all connected clients @@ -230,21 +642,27 @@ export async function stopUdsMessaging(): Promise { server!.close(() => resolve()) }) server = null + inbox.length = 0 + inboxBytes = 0 + onEnqueueCb = null // Remove socket file (skip on Windows — pipe paths aren't files) if (socketPath) { - if (process.platform !== 'win32') { - try { - await unlink(socketPath) - } catch { - // Already gone - } - } + await removeSocketPath(socketPath) delete process.env.CLAUDE_CODE_MESSAGING_SOCKET logForDebugging( `[udsMessaging] server stopped, socket removed: ${socketPath}`, ) socketPath = null + authToken = null + } + if (capabilityFilePath) { + try { + await unlink(capabilityFilePath) + } catch { + // Already gone + } + capabilityFilePath = null } } @@ -255,23 +673,50 @@ export async function stopUdsMessaging(): Promise { export async function sendUdsMessage( targetSocketPath: string, message: UdsMessage, + opts: { authToken?: string } = {}, ): Promise { const { createConnection } = await import('net') - message.from = message.from ?? socketPath ?? undefined - message.ts = message.ts ?? new Date().toISOString() + const token = opts.authToken ?? authToken + if (!token) { + throw new Error('Cannot send UDS message without auth token') + } + const outbound = withRequestAuthToken( + { + ...message, + from: message.from ?? 
socketPath ?? undefined, + ts: message.ts ?? new Date().toISOString(), + }, + token, + ) return new Promise((resolve, reject) => { - const conn = createConnection(targetSocketPath, () => { - conn.write(jsonStringify(message) + '\n', err => { + let settled = false + let conn: ReturnType + const finish = (error?: Error): void => { + if (settled) return + settled = true + if (error) { + conn.destroy(error) + reject(error) + } else { conn.end() - if (err) reject(err) - else resolve() + resolve() + } + } + + conn = createConnection(targetSocketPath, () => { + conn.write(jsonStringify(outbound) + '\n', err => { + if (err) finish(err) }) }) - conn.on('error', reject) + attachUdsResponseReader(conn, { + maxFrameBytes: MAX_UDS_FRAME_BYTES, + acceptPong: true, + onSettled: finish, + }) // Timeout so we don't hang on unreachable sockets conn.setTimeout(5000, () => { - conn.destroy(new Error('Connection timed out')) + finish(new Error('Connection timed out')) }) }) } diff --git a/src/utils/udsResponseReader.ts b/src/utils/udsResponseReader.ts new file mode 100644 index 000000000..d86328aab --- /dev/null +++ b/src/utils/udsResponseReader.ts @@ -0,0 +1,120 @@ +import type { Socket } from 'net' +import { StringDecoder } from 'node:string_decoder' +import { errorMessage } from './errors.js' +import { jsonParse } from './slowOperations.js' +import type { UdsMessage } from './udsMessaging.js' + +type UdsResponseReaderOptions = { + maxFrameBytes: number + acceptPong?: boolean + onSettled: (error?: Error) => void + formatSocketError?: (error: unknown) => Error +} + +export function getChunkBytes(chunk: string | Buffer): number { + return typeof chunk === 'string' + ? 
Buffer.byteLength(chunk, 'utf8') + : chunk.byteLength +} + +function parseResponseLine(line: string): UdsMessage { + try { + return jsonParse(line) as UdsMessage + } catch { + throw new Error('Invalid UDS response frame') + } +} + +export function attachUdsResponseReader( + socket: Socket, + options: UdsResponseReaderOptions, +): void { + let buffer = '' + let bufferBytes = 0 + let settled = false + const decoder = new StringDecoder('utf8') + + function cleanupListeners(): void { + socket.off('data', onData) + socket.off('error', onError) + socket.off('end', onEnd) + socket.off('close', onClose) + } + + function finish(error?: Error): void { + if (settled) return + settled = true + buffer = '' + bufferBytes = 0 + cleanupListeners() + if (error) { + socket.destroy() + } else { + socket.end() + } + options.onSettled(error) + } + + function onData(chunk: Buffer): void { + const decoded = decoder.write(chunk) + const decodedBytes = Buffer.byteLength(decoded, 'utf8') + if (bufferBytes + decodedBytes > options.maxFrameBytes) { + finish(new Error('UDS response frame exceeded size limit')) + return + } + + buffer += decoded + bufferBytes += decodedBytes + let newlineIndex = buffer.indexOf('\n') + while (newlineIndex !== -1) { + const line = buffer.slice(0, newlineIndex) + const consumed = buffer.slice(0, newlineIndex + 1) + buffer = buffer.slice(newlineIndex + 1) + bufferBytes -= Buffer.byteLength(consumed, 'utf8') + if (!line.trim()) { + newlineIndex = buffer.indexOf('\n') + continue + } + let response: UdsMessage + try { + response = parseResponseLine(line) + } catch (error) { + finish(error instanceof Error ? error : new Error(errorMessage(error))) + return + } + if ( + response.type === 'response' || + (options.acceptPong === true && response.type === 'pong') + ) { + finish() + return + } + if (response.type === 'error') { + finish(new Error(response.data ?? 
'UDS receiver rejected message')) + return + } + newlineIndex = buffer.indexOf('\n') + } + } + + function onError(error: Error): void { + finish( + options.formatSocketError?.(error) ?? + (error instanceof Error ? error : new Error(errorMessage(error))), + ) + } + + function onEnd(): void { + finish(new Error('UDS socket ended before response')) + } + + function onClose(hadError: boolean): void { + if (hadError) return + finish(new Error('UDS socket closed before response')) + } + + socket.on('data', onData) + socket.on('error', onError) + socket.on('end', onEnd) + socket.on('close', onClose) +}