diff --git a/packages/core/src/code_assist/experiments/flagNames.ts b/packages/core/src/code_assist/experiments/flagNames.ts index 125ff005a97..1b0c4b34afd 100644 --- a/packages/core/src/code_assist/experiments/flagNames.ts +++ b/packages/core/src/code_assist/experiments/flagNames.ts @@ -20,6 +20,7 @@ export const ExperimentFlags = { PRO_MODEL_NO_ACCESS: 45768879, GEMINI_3_1_FLASH_LITE_LAUNCHED: 45771641, DEFAULT_REQUEST_TIMEOUT: 45773134, + COMPRESSION_STRATEGY: 45768880, } as const; export type ExperimentFlagName = diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 9dbf0f8115d..8e0f35ed731 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -3073,6 +3073,17 @@ export class Config implements McpContext, AgentLoopContext { return remoteThreshold; } + async getCompressionStrategy(): Promise<'flat' | 'union-find'> { + await this.ensureExperimentsLoaded(); + const remoteStrategy = + this.experiments?.flags[ExperimentFlags.COMPRESSION_STRATEGY] + ?.stringValue; + if (remoteStrategy === 'union-find' || remoteStrategy === 'flat') { + return remoteStrategy; + } + return 'flat'; + } + async getUserCaching(): Promise { await this.ensureExperimentsLoaded(); diff --git a/packages/core/src/context/agentHistoryProvider.ts b/packages/core/src/context/agentHistoryProvider.ts index 94218088473..dcb711d1d84 100644 --- a/packages/core/src/context/agentHistoryProvider.ts +++ b/packages/core/src/context/agentHistoryProvider.ts @@ -19,6 +19,7 @@ import { truncateProportionally, normalizeFunctionResponse, } from './truncation.js'; +import { sanitizePromptValue } from '../utils/sanitizePromptInput.js'; export class AgentHistoryProvider { // TODO(joshualitt): just pass the BaseLlmClient instead of the whole Config. @@ -379,10 +380,10 @@ Distill these into a high-density Markdown block that orientates the agent on th - **Brevity:** Maximum 15 lines. No conversational preamble. ${hasPreviousSummary ? 
'PREVIOUS SUMMARY AND TRUNCATED HISTORY:' : 'TRUNCATED HISTORY:'} -${JSON.stringify(messagesToTruncate)} +${sanitizePromptValue(messagesToTruncate)} ACTIVE BRIDGE (LOOKAHEAD): -${JSON.stringify(bridge)}`; +${sanitizePromptValue(bridge)}`; const summaryResponse = await this.config .getBaseLlmClient() diff --git a/packages/core/src/context/chatCompressionService.test.ts b/packages/core/src/context/chatCompressionService.test.ts index c4f26dedc08..1af44add3e2 100644 --- a/packages/core/src/context/chatCompressionService.test.ts +++ b/packages/core/src/context/chatCompressionService.test.ts @@ -29,6 +29,13 @@ vi.mock('../telemetry/loggers.js'); vi.mock('../utils/environmentContext.js'); vi.mock('../core/tokenLimits.js'); +function makeHistory(n: number): Content[] { + return Array.from({ length: n }, (_, i) => ({ + role: (i % 2 === 0 ? 'user' : 'model'), + parts: [{ text: `message ${i}` }], + })); +} + describe('findCompressSplitPoint', () => { it('should throw an error for non-positive numbers', () => { expect(() => findCompressSplitPoint([], 0)).toThrow( @@ -193,6 +200,7 @@ describe('ChatCompressionService', () => { getProjectTempDir: vi.fn().mockReturnValue(testTempDir), }, getApprovedPlanPath: vi.fn().mockReturnValue('/path/to/plan.md'), + getCompressionStrategy: vi.fn().mockReturnValue('flat'), } as unknown as Config; vi.mocked(getInitialChatHistory).mockImplementation( @@ -897,4 +905,139 @@ describe('ChatCompressionService', () => { ); }); }); + + describe('Compression strategy dispatch', () => { + it('should route to flat compression when strategy is flat', async () => { + vi.mocked( + mockConfig as unknown as { + getCompressionStrategy: () => 'flat' | 'union-find'; + }, + ).getCompressionStrategy = vi.fn().mockReturnValue('flat'); + + const history: Content[] = [ + { role: 'user', parts: [{ text: 'msg1' }] }, + { role: 'model', parts: [{ text: 'msg2' }] }, + { role: 'user', parts: [{ text: 'msg3' }] }, + { role: 'model', parts: [{ text: 'msg4' }] }, + ]; + 
vi.mocked(mockChat.getHistory).mockReturnValue(history); + vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000); + + const result = await service.compress( + mockChat, + mockPromptId, + false, + mockModel, + mockConfig, + false, + ); + + expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); + // Flat uses 2 LLM calls (generate + verify) + expect( + mockConfig.getBaseLlmClient().generateContent, + ).toHaveBeenCalledTimes(2); + }); + + it('should route to union-find compression when strategy is union-find', async () => { + vi.mocked( + mockConfig as unknown as { + getCompressionStrategy: () => 'flat' | 'union-find'; + }, + ).getCompressionStrategy = vi.fn().mockReturnValue('union-find'); + + // Mock for cluster summarization + const mockLlmClient = { + generateContent: vi.fn().mockResolvedValue({ + candidates: [ + { + content: { + parts: [{ text: 'Cluster summary' }], + }, + }, + ], + } as unknown as GenerateContentResponse), + }; + vi.mocked(mockConfig.getBaseLlmClient).mockReturnValue( + mockLlmClient as unknown as BaseLlmClient, + ); + + const history = makeHistory(35); + vi.mocked(mockChat.getHistory).mockReturnValue(history); + vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000); + + const result = await service.compress( + mockChat, + mockPromptId, + false, + mockModel, + mockConfig, + false, + ); + + expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); + expect(result.newHistory).not.toBeNull(); + }); + + it('should return valid ChatCompressionInfo for union-find path', async () => { + vi.mocked( + mockConfig as unknown as { + getCompressionStrategy: () => 'flat' | 'union-find'; + }, + ).getCompressionStrategy = vi.fn().mockReturnValue('union-find'); + + const mockLlmClient = { + generateContent: vi.fn().mockResolvedValue({ + candidates: [ + { + content: { + parts: [{ text: 'Summary' }], + }, + }, + ], + } as unknown as GenerateContentResponse), + }; + 
vi.mocked(mockConfig.getBaseLlmClient).mockReturnValue( + mockLlmClient as unknown as BaseLlmClient, + ); + + const history = makeHistory(35); + vi.mocked(mockChat.getHistory).mockReturnValue(history); + vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000); + + const result = await service.compress( + mockChat, + mockPromptId, + false, + mockModel, + mockConfig, + false, + ); + + expect(result.info.originalTokenCount).toBe(600000); + expect(result.info.newTokenCount).toBeDefined(); + expect(typeof result.info.newTokenCount).toBe('number'); + }); + + it('should return NOOP for union-find with empty history', async () => { + vi.mocked( + mockConfig as unknown as { + getCompressionStrategy: () => 'flat' | 'union-find'; + }, + ).getCompressionStrategy = vi.fn().mockReturnValue('union-find'); + vi.mocked(mockChat.getHistory).mockReturnValue([]); + + const result = await service.compress( + mockChat, + mockPromptId, + false, + mockModel, + mockConfig, + false, + ); + + expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); + expect(result.newHistory).toBeNull(); + }); + }); }); diff --git a/packages/core/src/context/chatCompressionService.ts b/packages/core/src/context/chatCompressionService.ts index 992ca67cf94..c2c4a18610a 100644 --- a/packages/core/src/context/chatCompressionService.ts +++ b/packages/core/src/context/chatCompressionService.ts @@ -33,6 +33,10 @@ import { PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL, } from '../config/models.js'; import { PreCompressTrigger } from '../hooks/types.js'; +import { ContextWindow } from '../services/contextWindow.js'; +import { TFIDFEmbedder } from '../services/embeddingService.js'; +import { ClusterSummarizer } from '../services/clusterSummarizer.js'; +import { sanitizePromptString } from '../utils/sanitizePromptInput.js'; /** * Default threshold for compression token count as a fraction of the model's @@ -159,14 +163,12 @@ async function truncateHistoryToBudget( } else if (responseObj && typeof 
responseObj === 'object') { if ( 'output' in responseObj && - // eslint-disable-next-line no-restricted-syntax - typeof responseObj['output'] === 'string' + typeof responseObj['output'] === 'string' // eslint-disable-line no-restricted-syntax ) { contentStr = responseObj['output']; } else if ( 'content' in responseObj && - // eslint-disable-next-line no-restricted-syntax - typeof responseObj['content'] === 'string' + typeof responseObj['content'] === 'string' // eslint-disable-line no-restricted-syntax ) { contentStr = responseObj['content']; } else { @@ -234,6 +236,19 @@ async function truncateHistoryToBudget( return truncatedHistory; } +// graduateAt < evictAt: messages enter the cold forest early so clusters +// have time to form before the hot zone forces an eviction. +const UNION_FIND_GRADUATE_AT = 26; +const UNION_FIND_EVICT_AT = 30; +const UNION_FIND_MAX_COLD_CLUSTERS = 10; +const UNION_FIND_MERGE_THRESHOLD = 0.15; +const UNION_FIND_RETRIEVE_K = 3; +const UNION_FIND_RETRIEVE_MIN_SIM = 0.05; + +// Cap tool-response previews fed to the embedder/summarizer. Full output +// is preserved in the original Content objects for the hot zone. 
+const TOOL_RESPONSE_PREVIEW_GRAPHEMES = 500; + export class ChatCompressionService { async compress( chat: GeminiChat, @@ -243,6 +258,40 @@ export class ChatCompressionService { config: Config, hasFailedCompressionAttempt: boolean, abortSignal?: AbortSignal, + ): Promise<{ newHistory: Content[] | null; info: ChatCompressionInfo }> { + const strategy = await config.getCompressionStrategy(); + + if (strategy === 'union-find') { + return this.compactWithUnionFind( + chat, + promptId, + force, + model, + config, + hasFailedCompressionAttempt, + abortSignal, + ); + } + + return this.compressWithFlat( + chat, + promptId, + force, + model, + config, + hasFailedCompressionAttempt, + abortSignal, + ); + } + + private async compressWithFlat( + chat: GeminiChat, + promptId: string, + force: boolean, + model: string, + config: Config, + hasFailedCompressionAttempt: boolean, + abortSignal?: AbortSignal, ): Promise<{ newHistory: Content[] | null; info: ChatCompressionInfo }> { const curatedHistory = chat.getHistory(true); @@ -476,4 +525,207 @@ export class ChatCompressionService { }; } } + + private async compactWithUnionFind( + chat: GeminiChat, + promptId: string, + force: boolean, + model: string, + config: Config, + hasFailedCompressionAttempt: boolean, + abortSignal?: AbortSignal, + ): Promise<{ newHistory: Content[] | null; info: ChatCompressionInfo }> { + const curatedHistory = chat.getHistory(true); + + if (curatedHistory.length === 0) { + return { + newHistory: null, + info: { + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }, + }; + } + + const trigger = force ? PreCompressTrigger.Manual : PreCompressTrigger.Auto; + await config.getHookSystem()?.firePreCompressEvent(trigger); + + const originalTokenCount = chat.getLastPromptTokenCount(); + + if (!force) { + const threshold = + (await config.getCompressionThreshold()) ?? 
+ DEFAULT_COMPRESSION_TOKEN_THRESHOLD; + if (originalTokenCount < threshold * tokenLimit(model)) { + return { + newHistory: null, + info: { + originalTokenCount, + newTokenCount: originalTokenCount, + compressionStatus: CompressionStatus.NOOP, + }, + }; + } + } + + // Apply token-based truncation before graduation + const truncatedHistory = await truncateHistoryToBudget( + curatedHistory, + config, + ); + + // If summarization previously failed, fall back to truncation only + if (hasFailedCompressionAttempt && !force) { + const truncatedTokenCount = estimateTokenCountSync( + truncatedHistory.flatMap((c) => c.parts || []), + ); + if (truncatedTokenCount < originalTokenCount) { + return { + newHistory: truncatedHistory, + info: { + originalTokenCount, + newTokenCount: truncatedTokenCount, + compressionStatus: CompressionStatus.CONTENT_TRUNCATED, + }, + }; + } + return { + newHistory: null, + info: { + originalTokenCount, + newTokenCount: originalTokenCount, + compressionStatus: CompressionStatus.NOOP, + }, + }; + } + + // Build ContextWindow from truncated history + const embedder = new TFIDFEmbedder(); + const summarizer = new ClusterSummarizer( + config.getBaseLlmClient(), + modelStringToModelConfigAlias(model), + abortSignal, + ); + const contextWindow = new ContextWindow(embedder, summarizer, { + graduateAt: UNION_FIND_GRADUATE_AT, + evictAt: UNION_FIND_EVICT_AT, + maxColdClusters: UNION_FIND_MAX_COLD_CLUSTERS, + mergeThreshold: UNION_FIND_MERGE_THRESHOLD, + }); + + // Feed all truncated history into the context window. Sanitize untrusted + // tool output so it cannot inject the `` delimiter the + // renderer wraps cold summaries in. Trusted conversation text (p.text) is + // passed through unchanged to preserve formatting in the reconstructed + // history; the cold summaries themselves are sanitized at the wrap site. 
+ for (const content of truncatedHistory) { + const text = content.parts + ?.map((p) => { + if (p.text) return p.text; + if (p.functionCall) return `[Tool call: ${p.functionCall.name}]`; + if (p.functionResponse) { + const responseStr = JSON.stringify( + p.functionResponse.response ?? '', + ); + const graphemes = Array.from(responseStr); + const preview = + graphemes.length > TOOL_RESPONSE_PREVIEW_GRAPHEMES + ? graphemes.slice(0, TOOL_RESPONSE_PREVIEW_GRAPHEMES).join('') + + '...' + : responseStr; + return `[Tool response: ${p.functionResponse.name}] ${sanitizePromptString(preview)}`; + } + if (p.fileData) { + return `[File attachment omitted: ${p.fileData.mimeType ?? 'unknown'}]`; + } + if (p.inlineData) { + return `[Inline data omitted: ${p.inlineData.mimeType ?? 'unknown'}]`; + } + return ''; + }) + .join(' ') + .trim(); + if (text) { + contextWindow.append(text); + } + } + + // Resolve dirty clusters before rendering so summaries are available + await contextWindow.resolveDirty(); + + const rendered = contextWindow.render( + null, + UNION_FIND_RETRIEVE_K, + UNION_FIND_RETRIEVE_MIN_SIM, + ); + + // Build new history: cold summaries as a single user message, then hot messages + const coldSummaries = rendered.slice( + 0, + rendered.length - contextWindow.hotCount, + ); + + const extraHistory: Content[] = []; + + if (coldSummaries.length > 0) { + extraHistory.push({ + role: 'user', + parts: [ + { + text: `\n${coldSummaries.map(sanitizePromptString).join('\n---\n')}\n`, + }, + ], + }); + extraHistory.push({ + role: 'model', + parts: [{ text: 'Got it. I have the context from previous messages.' 
}], + }); + } + + // Map hot messages back to their original Content objects + // Use the last N items from truncatedHistory where N = hotCount + const hotStart = Math.max( + 0, + truncatedHistory.length - contextWindow.hotCount, + ); + extraHistory.push(...truncatedHistory.slice(hotStart)); + + const fullNewHistory = await getInitialChatHistory(config, extraHistory); + + const newTokenCount = await calculateRequestTokenCount( + fullNewHistory.flatMap((c) => c.parts || []), + config.getContentGenerator(), + model, + ); + + logChatCompression( + config, + makeChatCompressionEvent({ + tokens_before: originalTokenCount, + tokens_after: newTokenCount, + }), + ); + + if (newTokenCount > originalTokenCount) { + return { + newHistory: null, + info: { + originalTokenCount, + newTokenCount, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }, + }; + } + + return { + newHistory: extraHistory, + info: { + originalTokenCount, + newTokenCount, + compressionStatus: CompressionStatus.COMPRESSED, + }, + }; + } } diff --git a/packages/core/src/context/toolDistillationService.ts b/packages/core/src/context/toolDistillationService.ts index 43ea12d7f16..2c28ea26cad 100644 --- a/packages/core/src/context/toolDistillationService.ts +++ b/packages/core/src/context/toolDistillationService.ts @@ -26,10 +26,11 @@ import { estimateCharsFromTokens, normalizeFunctionResponse, } from './truncation.js'; +import { sanitizePromptString } from '../utils/sanitizePromptInput.js'; -// Skip structural map generation for outputs larger than this threshold (in characters) -// as it consumes excessive tokens and may not be representative of the full content. -const MAX_DISTILLATION_SIZE = 1_000_000; +// ~16K tokens at 4 chars/token. Larger outputs overwhelm the utility +// compressor model with latency and cost for diminishing returns. 
+const MAX_DISTILLATION_CHARS = 64_000; export interface DistilledToolOutput { truncatedContent: PartListUnion; @@ -53,6 +54,7 @@ export class ToolOutputDistillationService { toolName: string, callId: string, content: PartListUnion, + abortSignal?: AbortSignal, ): Promise { // Explicitly bypass escape hatches that natively handle large outputs if (this.isExemptFromDistillation(toolName)) { @@ -74,6 +76,7 @@ export class ToolOutputDistillationService { content, originalContentLength, thresholdChars, + abortSignal, ); } @@ -119,6 +122,7 @@ export class ToolOutputDistillationService { content: PartListUnion, originalContentLength: number, threshold: number, + abortSignal?: AbortSignal, ): Promise { const stringifiedContent = this.stringifyContent(content); @@ -139,12 +143,13 @@ export class ToolOutputDistillationService { if ( originalContentLength > summarizationThresholdChars && - originalContentLength <= MAX_DISTILLATION_SIZE + originalContentLength <= MAX_DISTILLATION_CHARS ) { const summary = await this.generateIntentSummary( toolName, stringifiedContent, - Math.floor(MAX_DISTILLATION_SIZE), + Math.floor(MAX_DISTILLATION_CHARS), + abortSignal, ); if (summary) { @@ -254,10 +259,13 @@ export class ToolOutputDistillationService { toolName: string, stringifiedContent: string, maxPreviewLen: number, + abortSignal?: AbortSignal, ): Promise { try { - const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), 15000); // 15s timeout + const timeoutSignal = AbortSignal.timeout(15000); + const summaryAbortSignal = abortSignal + ? AbortSignal.any([abortSignal, timeoutSignal]) + : timeoutSignal; const promptText = `The following output from the tool '${toolName}' is large and has been truncated. Extract the most critical factual information from this output so the main agent doesn't lose context. @@ -269,17 +277,15 @@ Focus strictly on concrete data points: Do not philosophize about the strategic intent. 
Keep the extraction under 10 lines and use exact quotes where helpful. Output to summarize: -${stringifiedContent.slice(0, maxPreviewLen)}...`; +${sanitizePromptString(Array.from(stringifiedContent).slice(0, maxPreviewLen).join(''))}...`; const summaryResponse = await this.geminiClient.generateContent( { model: 'agent-history-provider-summarizer' }, [{ role: 'user', parts: [{ text: promptText }] }], - controller.signal, + summaryAbortSignal, LlmRole.UTILITY_COMPRESSOR, ); - clearTimeout(timeoutId); - return summaryResponse.candidates?.[0]?.content?.parts?.[0]?.text; } catch (e) { // Fail gracefully, summarization is a progressive enhancement diff --git a/packages/core/src/services/clusterSummarizer.test.ts b/packages/core/src/services/clusterSummarizer.test.ts new file mode 100644 index 00000000000..67838678d9e --- /dev/null +++ b/packages/core/src/services/clusterSummarizer.test.ts @@ -0,0 +1,103 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi } from 'vitest'; +import { ClusterSummarizer } from './clusterSummarizer.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; +import type { GenerateContentResponse } from '@google/genai'; + +function mockLlmClient(responseText: string): BaseLlmClient { + return { + generateContent: vi.fn().mockResolvedValue({ + candidates: [ + { + content: { + parts: [{ text: responseText }], + }, + }, + ], + } as unknown as GenerateContentResponse), + } as unknown as BaseLlmClient; +} + +describe('ClusterSummarizer', () => { + it('should call generateContent with cluster messages', async () => { + const client = mockLlmClient('Summary of cluster'); + const summarizer = new ClusterSummarizer( + client, + 'chat-compression-default', + ); + + const result = await summarizer.summarize([ + 'User asked about file structure', + 'Model described src/ directory', + ]); + + expect(result).toBe('Summary of cluster'); + 
expect(client.generateContent).toHaveBeenCalledTimes(1); + + const callArgs = vi.mocked(client.generateContent).mock.calls[0][0]; + expect(callArgs.modelConfigKey).toEqual({ + model: 'chat-compression-default', + }); + // The user prompt should contain the messages + const userContent = callArgs.contents[callArgs.contents.length - 1]; + expect(userContent.parts![0].text).toContain( + 'User asked about file structure', + ); + }); + + it('should return fallback text on empty response', async () => { + const client = { + generateContent: vi.fn().mockResolvedValue({ + candidates: [ + { + content: { + parts: [{ text: '' }], + }, + }, + ], + } as unknown as GenerateContentResponse), + } as unknown as BaseLlmClient; + + const summarizer = new ClusterSummarizer( + client, + 'chat-compression-default', + ); + + const result = await summarizer.summarize(['msg1', 'msg2']); + // Should fall back to joining messages + expect(result).toContain('msg1'); + expect(result).toContain('msg2'); + }); + + it('should handle single message', async () => { + const client = mockLlmClient('Single message summary'); + const summarizer = new ClusterSummarizer( + client, + 'chat-compression-default', + ); + + const result = await summarizer.summarize(['just one message']); + expect(result).toBe('Single message summary'); + }); + + it('should handle generateContent throwing an error', async () => { + const client = { + generateContent: vi.fn().mockRejectedValue(new Error('API Error')), + } as unknown as BaseLlmClient; + + const summarizer = new ClusterSummarizer( + client, + 'chat-compression-default', + ); + + const result = await summarizer.summarize(['msg1', 'msg2']); + // Should fall back to joining messages on error + expect(result).toContain('msg1'); + expect(result).toContain('msg2'); + }); +}); diff --git a/packages/core/src/services/clusterSummarizer.ts b/packages/core/src/services/clusterSummarizer.ts new file mode 100644 index 00000000000..ac30d1dd086 --- /dev/null +++ 
b/packages/core/src/services/clusterSummarizer.ts @@ -0,0 +1,62 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { BaseLlmClient } from '../core/baseLlmClient.js'; +import { getResponseText } from '../utils/partUtils.js'; +import { LlmRole } from '../telemetry/types.js'; +import type { Summarizer } from './contextWindow.js'; +import { sanitizePromptString } from '../utils/sanitizePromptInput.js'; + +/** + * Cluster summarizer using BaseLlmClient for LLM-generated summaries. + * + * Single-phase summarization (no verification) since clusters are small. + */ +export class ClusterSummarizer implements Summarizer { + private _client: BaseLlmClient; + private _modelConfigKey: string; + private _abortSignal?: AbortSignal; + + constructor( + client: BaseLlmClient, + modelConfigKey: string, + abortSignal?: AbortSignal, + ) { + this._client = client; + this._modelConfigKey = modelConfigKey; + this._abortSignal = abortSignal; + } + + async summarize(messages: string[]): Promise { + const fallback = messages.join('\n---\n'); + + try { + const response = await this._client.generateContent({ + modelConfigKey: { model: this._modelConfigKey }, + contents: [ + { + role: 'user', + parts: [ + { + text: `Summarize the following conversation messages into a concise, information-dense paragraph. Preserve all specific technical details, file paths, tool results, variable names, and user constraints.\n\nMessages:\n${messages.map((m, i) => `[${i + 1}] ${sanitizePromptString(m)}`).join('\n')}`, + }, + ], + }, + ], + promptId: 'cluster-summarize', + role: LlmRole.UTILITY_COMPRESSOR, + abortSignal: this._abortSignal ?? 
new AbortController().signal, + }); + + const text = getResponseText(response)?.trim(); + if (!text) return fallback; + return text; + } catch (e) { + if (this._abortSignal?.aborted) throw e; + return fallback; + } + } +} diff --git a/packages/core/src/services/contextWindow.test.ts b/packages/core/src/services/contextWindow.test.ts new file mode 100644 index 00000000000..54d5926db30 --- /dev/null +++ b/packages/core/src/services/contextWindow.test.ts @@ -0,0 +1,853 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi } from 'vitest'; +import { + Forest, + ContextWindow, + cosineSimilarity, + findClosestPair, + type Embedder, + type Summarizer, +} from './contextWindow.js'; + +// -- Helpers -- + +/** Stub embedder: one-hot encoding based on first char code. */ +const stubEmbedder: Embedder = { + embed(text: string): number[] { + const vec = new Array(128).fill(0); + if (text.length > 0) { + vec[text.charCodeAt(0) % 128] = 1; + } + return vec; + }, +}; + +/** Stub summarizer: joins messages with ' | '. 
*/ +const stubSummarizer: Summarizer = { + async summarize(messages: string[]): Promise { + return messages.join(' | '); + }, +}; + +// -- cosineSimilarity -- + +describe('cosineSimilarity', () => { + it('should return 1 for identical vectors', () => { + expect(cosineSimilarity([1, 0, 0], [1, 0, 0])).toBeCloseTo(1.0); + }); + + it('should return 0 for orthogonal vectors', () => { + expect(cosineSimilarity([1, 0], [0, 1])).toBeCloseTo(0.0); + }); + + it('should return 0 when a vector is zero', () => { + expect(cosineSimilarity([0, 0], [1, 1])).toBe(0.0); + }); + + it('should handle negative values', () => { + expect(cosineSimilarity([1, 0], [-1, 0])).toBeCloseTo(-1.0); + }); + + it('should handle mismatched dimensions without NaN', () => { + const short = [1, 0]; + const long = [1, 0, 0.5, 0.3]; + const sim = cosineSimilarity(long, short); + expect(Number.isNaN(sim)).toBe(false); + expect(sim).toBeGreaterThan(0); + // Symmetric + expect(cosineSimilarity(short, long)).toBeCloseTo(sim); + }); + + it('should return 0 for mismatched zero-overlap vectors', () => { + // short has values only in dims 0-1, long only in dims 2-3 + const a = [1, 0]; + const b = [0, 0, 1, 0]; + const sim = cosineSimilarity(a, b); + expect(Number.isNaN(sim)).toBe(false); + expect(sim).toBeCloseTo(0.0); + }); +}); + +// -- Forest -- + +describe('Forest', () => { + it('should insert a message as a singleton', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + const id = forest.insert(0, 'hello'); + expect(id).toBe(0); + expect(forest.find(0)).toBe(0); + expect(forest.size()).toBe(1); + expect(forest.clusterCount()).toBe(1); + }); + + it('should find root with path compression', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + forest.insert(2, 'c'); + + // Chain: 2 -> 1 -> 0 + forest.union(0, 1); + forest.union(0, 2); + + // After path compression, find(2) should return root directly + const root = 
forest.find(2); + expect(root).toBe(forest.find(0)); + }); + + it('should union by rank', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + forest.insert(2, 'c'); + + forest.union(0, 1); + const root01 = forest.find(0); + + forest.union(root01, 2); + // root01 had higher rank, so it should stay root + expect(forest.find(2)).toBe(root01); + }); + + it('should NOT call summarizer on union (synchronous, structural only)', () => { + const summarizer = { + summarize: vi.fn().mockResolvedValue('summary'), + }; + const forest = new Forest(stubEmbedder, summarizer); + forest.insert(0, 'alpha'); + forest.insert(1, 'beta'); + + const root = forest.union(0, 1); + + expect(summarizer.summarize).not.toHaveBeenCalled(); + // No summary generated yet — cluster is dirty + expect(forest.isDirty(root)).toBe(true); + }); + + it('should return number (not Promise) from union', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + + const result = forest.union(0, 1); + expect(typeof result).toBe('number'); + }); + + it('should mark cluster as dirty after union', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + + forest.union(0, 1); + const root = forest.find(0); + + expect(forest.isDirty(root)).toBe(true); + expect(forest.dirtyRoots()).toContain(root); + }); + + it('should not mark singleton as dirty', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'hello'); + + expect(forest.isDirty(0)).toBe(false); + expect(forest.dirtyRoots()).toHaveLength(0); + }); + + it('should resolve dirty clusters via resolveDirty', async () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'alpha'); + forest.insert(1, 'beta'); + forest.union(0, 1); + + await forest.resolveDirty(); + + const root = forest.find(0); + 
expect(forest.isDirty(root)).toBe(false); + const summary = forest.summary(root); + expect(summary).toContain('alpha'); + expect(summary).toContain('beta'); + }); + + it('should pass raw content of singletons to summarizer when resolving', async () => { + const recorder: string[][] = []; + const recSummarizer: Summarizer = { + async summarize(messages: string[]): Promise { + recorder.push([...messages]); + return messages.join('; '); + }, + }; + const forest = new Forest(stubEmbedder, recSummarizer); + forest.insert(0, 'msg0'); + forest.insert(1, 'msg1'); + forest.union(0, 1); + + await forest.resolveDirty(); + + // Both raw messages should be passed to summarizer + expect(recorder).toHaveLength(1); + expect(recorder[0]).toContain('msg0'); + expect(recorder[0]).toContain('msg1'); + }); + + it('should pass clean summary + new raw messages after second resolve', async () => { + const recorder: string[][] = []; + const recSummarizer: Summarizer = { + async summarize(messages: string[]): Promise { + recorder.push([...messages]); + return messages.join('; '); + }, + }; + const forest = new Forest(stubEmbedder, recSummarizer); + forest.insert(0, 'msg0'); + forest.insert(1, 'msg1'); + forest.union(0, 1); + + await forest.resolveDirty(); // resolve: summarize([msg0, msg1]) + + // Now insert a third and merge into the same cluster + forest.insert(2, 'msg2'); + forest.union(forest.find(0), 2); + + await forest.resolveDirty(); // resolve: summarize([cleanSummary, msg2]) + + const lastCall = recorder[recorder.length - 1]; + // First item should be the clean summary from first resolve + expect(lastCall[0]).toContain('msg0'); + expect(lastCall[0]).toContain('msg1'); + // Second item should be the new raw message + expect(lastCall[1]).toBe('msg2'); + }); + + it('should batch multiple dirty clusters in one resolveDirty call', async () => { + const summarizer = { + summarize: vi.fn().mockResolvedValue('summary'), + }; + const forest = new Forest(stubEmbedder, summarizer); + + // 
Create 3 separate pairs + forest.insert(0, 'a0'); + forest.insert(1, 'a1'); + forest.union(0, 1); + + forest.insert(2, 'b0'); + forest.insert(3, 'b1'); + forest.union(2, 3); + + forest.insert(4, 'c0'); + forest.insert(5, 'c1'); + forest.union(4, 5); + + expect(forest.dirtyRoots()).toHaveLength(3); + + await forest.resolveDirty(); + + expect(summarizer.summarize).toHaveBeenCalledTimes(3); + expect(forest.dirtyRoots()).toHaveLength(0); + }); + + it('should merge two clean cluster summaries on union', async () => { + const recorder: string[][] = []; + const recSummarizer: Summarizer = { + async summarize(messages: string[]): Promise { + recorder.push([...messages]); + return `summary(${messages.join('+')})`; + }, + }; + const forest = new Forest(stubEmbedder, recSummarizer); + + // Create and resolve two separate clusters + forest.insert(0, 'a0'); + forest.insert(1, 'a1'); + forest.union(0, 1); + await forest.resolveDirty(); + const summaryA = forest.summary(forest.find(0))!; + + forest.insert(2, 'b0'); + forest.insert(3, 'b1'); + forest.union(2, 3); + await forest.resolveDirty(); + const summaryB = forest.summary(forest.find(2))!; + + // Merge two clean clusters + forest.union(forest.find(0), forest.find(2)); + + await forest.resolveDirty(); + + // The last summarize call should receive both summaries + const lastCall = recorder[recorder.length - 1]; + expect(lastCall).toContain(summaryA); + expect(lastCall).toContain(summaryB); + }); + + it('should update centroid on union', () => { + const embedder: Embedder = { + embed(text: string): number[] { + return text === 'a' ? 
[1, 0] : [0, 1]; + }, + }; + const forest = new Forest(embedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + forest.union(0, 1); + // Centroid should be average: [0.5, 0.5] + const root = forest.find(0); + const roots = forest.nearest([0.5, 0.5], 1); + expect(roots).toContain(root); + }); + + it('should handle centroid merging with mismatched embedding dimensions', () => { + const embedder: Embedder = { + embed(text: string): number[] { + // Simulate growing vocab: earlier messages have shorter embeddings + if (text === 'early') return [1, 0]; + return [0.5, 0.5, 0.3]; // later messages have longer embeddings + }, + }; + const forest = new Forest(embedder, stubSummarizer); + forest.insert(0, 'early'); + forest.insert(1, 'later'); + forest.union(0, 1); + + const root = forest.find(0); + const centroid = forest.getCentroid(root); + expect(centroid).toBeDefined(); + expect(centroid!.every((v) => !Number.isNaN(v))).toBe(true); + // Merged centroid should have max dimension length + expect(centroid!.length).toBe(3); + }); + + it('should return no-op for union of same cluster', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'a'); + const root = forest.union(0, 0); + expect(root).toBe(0); + expect(forest.clusterCount()).toBe(1); + expect(forest.isDirty(0)).toBe(false); + }); + + it('should compact a singleton to its content', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'hello world'); + expect(forest.compact(0)).toBe('hello world'); + }); + + it('should compact a resolved cluster to its summary', async () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'foo'); + forest.insert(1, 'bar'); + forest.union(0, 1); + + await forest.resolveDirty(); + + const root = forest.find(0); + expect(forest.compact(root)).toContain('foo'); + expect(forest.compact(root)).toContain('bar'); + }); + + it('should compact a dirty cluster to stale summary or 
raw content', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'foo'); + forest.insert(1, 'bar'); + forest.union(0, 1); + + // Not resolved yet — compact returns raw content of root node + const root = forest.find(0); + const compacted = forest.compact(root); + // Should be either 'foo' or 'bar' (whichever is root) + expect(['foo', 'bar']).toContain(compacted); + }); + + it('should expand a cluster to source messages', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'x'); + forest.insert(1, 'y'); + forest.union(0, 1); + const root = forest.find(0); + const expanded = forest.expand(root); + expect(expanded).toContain('x'); + expect(expanded).toContain('y'); + }); + + it('should retrieve nearest roots by cosine similarity', () => { + const embedder: Embedder = { + embed(text: string): number[] { + if (text.startsWith('cat')) return [1, 0, 0]; + if (text.startsWith('dog')) return [0.9, 0.1, 0]; + return [0, 0, 1]; + }, + }; + const forest = new Forest(embedder, stubSummarizer); + forest.insert(0, 'cat food'); + forest.insert(1, 'dog park'); + forest.insert(2, 'javascript'); + + const results = forest.nearest([1, 0, 0], 2); + expect(results).toHaveLength(2); + // 'cat food' should be closest + expect(results[0]).toBe(0); + }); + + it('should filter by min_sim in nearest', () => { + const embedder: Embedder = { + embed(text: string): number[] { + return text === 'match' ? [1, 0] : [0, 1]; + }, + }; + const forest = new Forest(embedder, stubSummarizer); + forest.insert(0, 'match'); + forest.insert(1, 'other'); + + // Query for [1, 0] with high min_sim + const results = forest.nearest([1, 0], 5, 0.9); + expect(results).toHaveLength(1); + expect(results[0]).toBe(0); + }); + + it('should return nearestRoot', () => { + const embedder: Embedder = { + embed(text: string): number[] { + return text === 'a' ? 
[1, 0] : [0, 1]; + }, + }; + const forest = new Forest(embedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + + const result = forest.nearestRoot([1, 0]); + expect(result).not.toBeNull(); + expect(result![0]).toBe(0); + expect(result![1]).toBeCloseTo(1.0); + }); + + it('should return null for nearestRoot on empty forest', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + expect(forest.nearestRoot([1, 0])).toBeNull(); + }); + + it('should list members of a cluster', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + forest.union(0, 1); + const root = forest.find(0); + const members = forest.members(root); + expect(members).toContain(0); + expect(members).toContain(1); + }); + + it('should not drop dirty state added by union during resolveDirty', async () => { + const slowSummarizer: Summarizer = { + async summarize(messages: string[]): Promise { + // Simulate slow LLM call + await new Promise((resolve) => setTimeout(resolve, 10)); + return messages.join('; '); + }, + }; + const forest = new Forest(stubEmbedder, slowSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + forest.union(0, 1); // dirty cluster {0,1} + + // Start resolving — the await inside gives us a window + const resolvePromise = forest.resolveDirty(); + + // While resolve is in flight, add new dirty state + forest.insert(2, 'c'); + forest.insert(3, 'd'); + forest.union(2, 3); // new dirty cluster {2,3} + + await resolvePromise; + + // The new dirty cluster should NOT have been wiped + expect(forest.isDirty(forest.find(2))).toBe(true); + + // Resolve it now + await forest.resolveDirty(); + expect(forest.isDirty(forest.find(2))).toBe(false); + }); + + it('should not overwrite merged cluster dirty state when in-flight root is merged', async () => { + const slowSummarizer: Summarizer = { + async summarize(messages: string[]): Promise { + await new Promise((resolve) => 
setTimeout(resolve, 10)); + return messages.join('; '); + }, + }; + const forest = new Forest(stubEmbedder, slowSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + forest.union(0, 1); // dirty cluster {0,1} + const originalRoot = forest.find(0); + + // Start resolving {0,1} + const resolvePromise = forest.resolveDirty(); + + // While {0,1} is being summarized, merge it into a new cluster + forest.insert(2, 'c'); + forest.union(originalRoot, 2); // now {0,1,2} is dirty with combined inputs + + await resolvePromise; + + // The merged cluster should still be dirty — the stale summary + // from the in-flight call should NOT have resolved it + const mergedRoot = forest.find(0); + expect(forest.isDirty(mergedRoot)).toBe(true); + + // Resolve it properly now + await forest.resolveDirty(); + expect(forest.isDirty(forest.find(0))).toBe(false); + // Summary should include all three messages + const summary = forest.summary(forest.find(0))!; + expect(summary).toBeDefined(); + }); + + it('should list all roots', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + forest.insert(2, 'c'); + forest.union(0, 1); + const roots = forest.roots(); + expect(roots).toHaveLength(2); + }); +}); + +// -- findClosestPair -- + +describe('findClosestPair', () => { + it('should return null for fewer than 2 roots', () => { + const forest = new Forest(stubEmbedder, stubSummarizer); + forest.insert(0, 'a'); + expect(findClosestPair(forest)).toBeNull(); + }); + + it('should find the closest pair', () => { + const embedder: Embedder = { + embed(text: string): number[] { + if (text === 'a') return [1, 0, 0]; + if (text === 'b') return [0.95, 0.05, 0]; + return [0, 0, 1]; + }, + }; + const forest = new Forest(embedder, stubSummarizer); + forest.insert(0, 'a'); + forest.insert(1, 'b'); + forest.insert(2, 'c'); + + const pair = findClosestPair(forest); + expect(pair).not.toBeNull(); + // a and b are closest + 
expect(pair).toContain(0); + expect(pair).toContain(1); + }); +}); + +// -- ContextWindow -- + +describe('ContextWindow', () => { + it('should keep messages in hot zone when under graduateAt', () => { + const cw = new ContextWindow(stubEmbedder, stubSummarizer, { + graduateAt: 5, + evictAt: 7, + }); + cw.append('msg1'); + cw.append('msg2'); + expect(cw.hotCount).toBe(2); + expect(cw.coldClusterCount).toBe(0); + }); + + it('append should return a number (synchronous, not a Promise)', () => { + const cw = new ContextWindow(stubEmbedder, stubSummarizer); + const result = cw.append('test'); + expect(typeof result).toBe('number'); + }); + + it('should graduate oldest when hot exceeds graduateAt (overlap window)', () => { + const cw = new ContextWindow(stubEmbedder, stubSummarizer, { + graduateAt: 3, + evictAt: 5, + maxColdClusters: 10, + mergeThreshold: 0.0, // never merge by similarity + }); + + cw.append('msg0'); + cw.append('msg1'); + cw.append('msg2'); + cw.append('msg3'); // msg0 graduates but stays in hot (overlap) + + expect(cw.hotCount).toBe(4); // msg0 still in hot + expect(cw.coldClusterCount).toBe(1); // msg0 also in cold + }); + + it('should evict from hot when hot exceeds evictAt', () => { + const cw = new ContextWindow(stubEmbedder, stubSummarizer, { + graduateAt: 2, + evictAt: 4, + maxColdClusters: 10, + mergeThreshold: 0.0, + }); + + cw.append('msg0'); + cw.append('msg1'); + cw.append('msg2'); // msg0 graduates, stays in hot (3 items, overlap) + cw.append('msg3'); // msg1 graduates, stays in hot (4 items) + cw.append('msg4'); // msg2 graduates, msg0 evicted from hot (5>4) + + expect(cw.hotCount).toBe(4); // msg1, msg2, msg3, msg4 + expect(cw.coldClusterCount).toBeGreaterThanOrEqual(1); + }); + + it('should never call summarizer during append', () => { + const summarizer = { + summarize: vi.fn().mockResolvedValue('summary'), + }; + const cw = new ContextWindow(stubEmbedder, summarizer, { + graduateAt: 2, + evictAt: 4, + maxColdClusters: 10, + }); + + for 
(let i = 0; i < 20; i++) { + cw.append(`msg${i}`); + } + + expect(summarizer.summarize).not.toHaveBeenCalled(); + }); + + it('should never call summarizer during render', () => { + const summarizer = { + summarize: vi.fn().mockResolvedValue('summary'), + }; + const cw = new ContextWindow(stubEmbedder, summarizer, { + graduateAt: 2, + evictAt: 4, + }); + + for (let i = 0; i < 10; i++) { + cw.append(`msg${i}`); + } + + const rendered = cw.render(); + expect(summarizer.summarize).not.toHaveBeenCalled(); + expect(rendered.length).toBeGreaterThan(0); + }); + + it('should merge graduated message into nearest cluster when similar', () => { + // Use an embedder that makes all messages identical + const sameEmbedder: Embedder = { + embed(): number[] { + return [1, 0, 0]; + }, + }; + + const cw = new ContextWindow(sameEmbedder, stubSummarizer, { + graduateAt: 2, + evictAt: 4, + maxColdClusters: 10, + mergeThreshold: 0.5, // will merge since similarity is 1.0 + }); + + cw.append('a'); + cw.append('b'); + cw.append('c'); // 'a' graduates as singleton + cw.append('d'); // 'b' graduates, merges with 'a' (sim = 1.0) + + expect(cw.coldClusterCount).toBe(1); // merged into one cluster + }); + + it('should enforce hard cap on cold clusters via forced merging', () => { + // Each message gets a unique embedding so nothing merges naturally + let counter = 0; + const uniqueEmbedder: Embedder = { + embed(): number[] { + const vec = new Array(10).fill(0); + vec[counter % 10] = 1; + counter++; + return vec; + }, + }; + + const cw = new ContextWindow(uniqueEmbedder, stubSummarizer, { + graduateAt: 2, + evictAt: 4, + maxColdClusters: 3, + mergeThreshold: 2.0, // never merge naturally (sim max is 1.0) + }); + + for (let i = 0; i < 10; i++) { + cw.append(`msg${i}`); + } + + expect(cw.coldClusterCount).toBeLessThanOrEqual(3); + }); + + it('should render cold summaries + hot messages without query', () => { + const cw = new ContextWindow(stubEmbedder, stubSummarizer, { + graduateAt: 2, + 
evictAt: 3, + maxColdClusters: 10, + mergeThreshold: 0.0, + }); + + cw.append('old1'); + cw.append('old2'); + cw.append('hot1'); // old1 graduates + cw.append('hot2'); // old2 graduates, old1 evicted + + const rendered = cw.render(); + expect(rendered.length).toBeGreaterThanOrEqual(2); + // Hot messages should be at the end + expect(rendered[rendered.length - 1]).toBe('hot2'); + }); + + it('should render with query-based retrieval', async () => { + const embedder: Embedder = { + embed(text: string): number[] { + if (text.includes('cat')) return [1, 0, 0]; + if (text.includes('dog')) return [0.9, 0.1, 0]; + return [0, 0, 1]; + }, + }; + + const cw = new ContextWindow(embedder, stubSummarizer, { + graduateAt: 2, + evictAt: 4, + maxColdClusters: 10, + mergeThreshold: 2.0, // never merge by similarity (max sim is 1.0) + }); + + cw.append('cat info'); + cw.append('dog info'); + cw.append('javascript info'); + cw.append('hot message 1'); + cw.append('hot message 2'); + + // Resolve so cold clusters have proper summaries + await cw.resolveDirty(); + + // Query about cats should retrieve cat cluster from cold + const rendered = cw.render('cat question', 1, 0.5); + expect(rendered.some((r) => r.includes('cat'))).toBe(true); + }); + + it('should resolveDirty to batch-summarize dirty clusters', async () => { + const summarizer = { + summarize: vi.fn().mockResolvedValue('resolved summary'), + }; + const cw = new ContextWindow(stubEmbedder, summarizer, { + graduateAt: 2, + evictAt: 4, + maxColdClusters: 10, + mergeThreshold: 0.0, + }); + + for (let i = 0; i < 10; i++) { + cw.append(`msg${i}`); + } + + expect(summarizer.summarize).not.toHaveBeenCalled(); + + await cw.resolveDirty(); + + // Should have called summarize for dirty clusters + expect(summarizer.summarize).toHaveBeenCalled(); + // After resolve, no dirty clusters + expect(cw.forest.dirtyRoots()).toHaveLength(0); + }); + + it('should show graduated messages verbatim in render via overlap window', () => { + const cw = 
new ContextWindow(stubEmbedder, stubSummarizer, { + graduateAt: 3, + evictAt: 5, + maxColdClusters: 10, + mergeThreshold: 0.0, + }); + + cw.append('msg0'); + cw.append('msg1'); + cw.append('msg2'); + cw.append('msg3'); // msg0 graduates but stays in hot + + // msg0 should appear in render as verbatim hot zone content + const rendered = cw.render(); + expect(rendered).toContain('msg0'); + // msg0 is also in cold, but it still appears from hot + expect(cw.coldClusterCount).toBe(1); + expect(cw.hotCount).toBe(4); + }); + + it('should return correct counts', () => { + const cw = new ContextWindow(stubEmbedder, stubSummarizer, { + graduateAt: 3, + evictAt: 5, + }); + + cw.append('a'); + cw.append('b'); + expect(cw.hotCount).toBe(2); + expect(cw.coldClusterCount).toBe(0); + expect(cw.totalMessages).toBe(2); + }); + + it('render(query) should not mutate the embedder corpus', () => { + const embedder = { + embed(text: string): number[] { + if (text.includes('cat')) return [1, 0, 0]; + return [0, 0, 1]; + }, + embedQuery: vi.fn().mockReturnValue([1, 0, 0]), + }; + + const cw = new ContextWindow(embedder, stubSummarizer, { + graduateAt: 2, + evictAt: 4, + maxColdClusters: 10, + mergeThreshold: 0.0, + }); + + cw.append('cat info'); + cw.append('dog info'); + cw.append('hot1'); + + // render with query should call embedQuery, not embed + cw.render('cat question', 1, 0.0); + expect(embedder.embedQuery).toHaveBeenCalledWith('cat question'); + }); + + it('should throw if evictAt < graduateAt', () => { + expect(() => { + new ContextWindow(stubEmbedder, stubSummarizer, { + graduateAt: 5, + evictAt: 3, + }); + }).toThrow('evictAt (3) must be >= graduateAt (5)'); + }); + + it('should expose forest for direct access', () => { + const cw = new ContextWindow(stubEmbedder, stubSummarizer); + expect(cw.forest).toBeInstanceOf(Forest); + }); + + it('should expand a cold cluster to source messages', () => { + const cw = new ContextWindow(stubEmbedder, stubSummarizer, { + graduateAt: 2, + 
evictAt: 3, + maxColdClusters: 10, + mergeThreshold: 0.0, + }); + + cw.append('graduated'); + cw.append('h1'); + cw.append('h2'); // 'graduated' enters cold + + const roots = cw.forest.roots(); + expect(roots.length).toBe(1); + const expanded = cw.expand(roots[0]); + expect(expanded).toContain('graduated'); + }); +}); diff --git a/packages/core/src/services/contextWindow.ts b/packages/core/src/services/contextWindow.ts new file mode 100644 index 00000000000..300da3feef5 --- /dev/null +++ b/packages/core/src/services/contextWindow.ts @@ -0,0 +1,497 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Union-find context compaction with overlap window and deferred summarization. + * + * Reading order for reviewers: + * 1. cosineSimilarity() — handles mismatched vector dimensions safely + * 2. Forest class — union-find with path compression, deferred summarization + * 3. ContextWindow class — overlap window, graduation/eviction + * 4. Integration: chatCompressionService.ts compactWithUnionFind() + * + * v2 architecture: + * - append() is synchronous — no LLM calls. Graduation triggers structural + * union() only. + * - render() is synchronous — returns cached summaries + hot zone messages. + * - resolveDirty() is async fire-and-forget — batch-summarizes dirty clusters + * in background during main LLM call wait. + * - Overlap window (graduateAt/evictAt): graduated messages stay in hot zone + * for ~2 turns. By the time they evict, background resolveDirty() has + * resolved their cluster summaries. + * + * Design doc: https://github.com/kimjune01/union-find-compaction-for-gemini-cli/blob/main/transformation-design.md + */ + +// -- Interfaces -- + +export interface Embedder { + embed(text: string): number[]; + /** Embed without mutating internal state. Used for queries/retrieval. 
*/
+  embedQuery?(text: string): number[];
+}
+
+export interface Summarizer {
+  summarize(messages: string[]): Promise<string>;
+}
+
+// -- Data structures --
+
+export interface Message {
+  id: number;
+  content: string;
+  embedding: number[];
+  timestamp: string | null;
+  _parent: number | null;
+  _rank: number;
+}
+
+// -- Helpers --
+
+// TF-IDF vocabulary grows over time, so newer vectors are longer than older
+// ones. We handle mismatched dimensions by treating missing entries as zero:
+// only shared dimensions contribute to the dot product, but trailing dimensions
+// still contribute to the norm (lowering similarity, as expected).
+export function cosineSimilarity(a: number[], b: number[]): number {
+  const len = Math.min(a.length, b.length);
+  let dot = 0;
+  let normA = 0;
+  let normB = 0;
+  for (let i = 0; i < len; i++) {
+    dot += a[i] * b[i];
+    normA += a[i] * a[i];
+    normB += b[i] * b[i];
+  }
+  // Include trailing dimensions from the longer vector in its norm
+  for (let i = len; i < a.length; i++) {
+    normA += a[i] * a[i];
+  }
+  for (let i = len; i < b.length; i++) {
+    normB += b[i] * b[i];
+  }
+  normA = Math.sqrt(normA);
+  normB = Math.sqrt(normB);
+  if (normA === 0 || normB === 0) return 0.0;
+  return dot / (normA * normB);
+}
+
+export function findClosestPair(forest: Forest): [number, number] | null {
+  const roots = forest.roots();
+  if (roots.length < 2) return null;
+
+  let bestSim = -1.0;
+  let bestPair: [number, number] = [roots[0], roots[1]];
+
+  for (let i = 0; i < roots.length; i++) {
+    const ca = forest.getCentroid(roots[i]);
+    if (!ca) continue;
+    for (let j = i + 1; j < roots.length; j++) {
+      const cb = forest.getCentroid(roots[j]);
+      if (!cb) continue;
+      const sim = cosineSimilarity(ca, cb);
+      if (sim > bestSim) {
+        bestSim = sim;
+        bestPair = [roots[i], roots[j]];
+      }
+    }
+  }
+
+  return bestPair;
+}
+
+// -- Forest --
+
+export class Forest {
+  private _nodes: Map<number, Message> = new Map();
+  private _summaries: Map<number, string> = new Map();
+  private _children: Map<number, number[]> = 
new Map();
+  private _centroids: Map<number, number[]> = new Map();
+  private _dirtyInputs: Map<number, string[]> = new Map();
+  private _embedder: Embedder;
+  private _summarizer: Summarizer;
+
+  constructor(embedder: Embedder, summarizer: Summarizer) {
+    this._embedder = embedder;
+    this._summarizer = summarizer;
+  }
+
+  insert(
+    msgId: number,
+    content: string,
+    embedding?: number[],
+    timestamp?: string | null,
+  ): number {
+    if (embedding === undefined) {
+      embedding = this._embedder.embed(content);
+    }
+    const msg: Message = {
+      id: msgId,
+      content,
+      embedding,
+      timestamp: timestamp ?? null,
+      _parent: null,
+      _rank: 0,
+    };
+    this._nodes.set(msgId, msg);
+    this._children.set(msgId, [msgId]);
+    this._centroids.set(msgId, [...embedding]);
+    // Singleton is NOT dirty — raw content serves as compact() output.
+    return msgId;
+  }
+
+  find(msgId: number): number {
+    const node = this._nodes.get(msgId);
+    if (!node) throw new Error(`Node ${msgId} not found`);
+    if (node._parent === null) return msgId;
+    const root = this.find(node._parent);
+    node._parent = root; // path compression
+    return root;
+  }
+
+  /**
+   * Synchronous structural merge. No LLM calls.
+   * Merges parent pointers, children, centroids.
+   * Collects dirty inputs for later batch summarization.
+   */
+  union(idA: number, idB: number): number {
+    let rootA = this.find(idA);
+    let rootB = this.find(idB);
+    if (rootA === rootB) return rootA;
+
+    let nodeA = this._nodes.get(rootA)!;
+    let nodeB = this._nodes.get(rootB)!;
+
+    // Union by rank
+    if (nodeA._rank < nodeB._rank) {
+      [rootA, rootB] = [rootB, rootA];
+      [nodeA, nodeB] = [nodeB, nodeA];
+    }
+    nodeB._parent = rootA;
+    if (nodeA._rank === nodeB._rank) {
+      nodeA._rank += 1;
+    }
+
+    // Merge children lists
+    const membersB = this._children.get(rootB) ?? [];
+    this._children.delete(rootB);
+    const membersA = this._children.get(rootA) ?? 
[]; + membersA.push(...membersB); + this._children.set(rootA, membersA); + + // Update centroid (weighted average) + const ca = this._centroids.get(rootA); + const cb = this._centroids.get(rootB); + this._centroids.delete(rootB); + + if (ca && cb) { + const na = membersA.length - membersB.length; + const nb = membersB.length; + const total = na + nb; + const maxLen = Math.max(ca.length, cb.length); + const merged = new Array(maxLen); + for (let i = 0; i < maxLen; i++) { + merged[i] = ((ca[i] ?? 0) * na + (cb[i] ?? 0) * nb) / total; + } + this._centroids.set(rootA, merged); + } + + // Collect dirty inputs: what A represents + what B represents + let inputsA: string[]; + if (this._dirtyInputs.has(rootA)) { + inputsA = this._dirtyInputs.get(rootA)!; + } else if (this._summaries.has(rootA)) { + inputsA = [this._summaries.get(rootA)!]; + } else { + inputsA = [nodeA.content]; + } + + let inputsB: string[]; + if (this._dirtyInputs.has(rootB)) { + inputsB = this._dirtyInputs.get(rootB)!; + } else if (this._summaries.has(rootB)) { + inputsB = [this._summaries.get(rootB)!]; + } else { + inputsB = [this._nodes.get(rootB)!.content]; + } + + this._dirtyInputs.set(rootA, [...inputsA, ...inputsB]); + this._dirtyInputs.delete(rootB); + this._summaries.delete(rootB); + + return rootA; + } + + /** + * Batch-summarize all dirty clusters. One LLM call per dirty root. + * Called as fire-and-forget after render(), runs during main LLM call wait. + * + * Concurrency safety: union() can run between awaits (JS is single-threaded + * but yields at each await). When union() merges into a dirty root, it + * replaces _dirtyInputs with a new array containing combined content. + * We detect this via reference equality (=== check on the inputs array). + * If the array changed, we skip — the combined entry resolves next call. 
+   */
+  async resolveDirty(): Promise<void> {
+    const entries = [...this._dirtyInputs.entries()];
+    for (const [root, inputs] of entries) {
+      if (!this._dirtyInputs.has(root)) continue;
+      const summary = await this._summarizer.summarize(inputs);
+      if (this._dirtyInputs.get(root) === inputs) {
+        this._summaries.set(root, summary);
+        this._dirtyInputs.delete(root);
+      }
+      // If union() replaced the inputs (merged new content into this root)
+      // or merged this root away, skip — the combined dirty entry will be
+      // resolved in a future resolveDirty() call.
+    }
+  }
+
+  /** Whether this cluster has unsummarized content. */
+  isDirty(rootId: number): boolean {
+    return this._dirtyInputs.has(this.find(rootId));
+  }
+
+  /** All roots with unsummarized content. */
+  dirtyRoots(): number[] {
+    return [...this._dirtyInputs.keys()];
+  }
+
+  compact(rootId: number): string {
+    const root = this.find(rootId);
+    const summary = this._summaries.get(root);
+    if (summary === undefined) {
+      return this._nodes.get(root)!.content;
+    }
+    return summary;
+  }
+
+  expand(rootId: number): string[] {
+    const root = this.find(rootId);
+    const memberIds = this._children.get(root) ?? 
[root]; + return memberIds.map((mid) => this._nodes.get(mid)!.content); + } + + nearest( + queryEmbedding: number[], + k: number = 3, + minSim: number = 0.0, + ): number[] { + const scored: Array<[number, number]> = []; + for (const root of this._children.keys()) { + const centroid = this._centroids.get(root); + if (centroid) { + const sim = cosineSimilarity(queryEmbedding, centroid); + if (sim >= minSim) { + scored.push([sim, root]); + } + } + } + scored.sort((a, b) => b[0] - a[0]); + return scored.slice(0, k).map(([, root]) => root); + } + + nearestRoot(queryEmbedding: number[]): [number, number] | null { + let bestSim = -1.0; + let bestRoot: number | null = null; + for (const root of this._children.keys()) { + const centroid = this._centroids.get(root); + if (centroid) { + const sim = cosineSimilarity(queryEmbedding, centroid); + if (sim > bestSim) { + bestSim = sim; + bestRoot = root; + } + } + } + if (bestRoot === null) return null; + return [bestRoot, bestSim]; + } + + roots(): number[] { + return [...this._children.keys()]; + } + + members(rootId: number): number[] { + const root = this.find(rootId); + return [...(this._children.get(root) ?? 
[root])]; + } + + summary(rootId: number): string | undefined { + const root = this.find(rootId); + return this._summaries.get(root); + } + + size(): number { + return this._nodes.size; + } + + clusterCount(): number { + return this._children.size; + } + + getCentroid(rootId: number): number[] | undefined { + return this._centroids.get(rootId); + } +} + +// -- ContextWindow -- + +export interface ContextWindowOptions { + graduateAt?: number; + evictAt?: number; + maxColdClusters?: number; + mergeThreshold?: number; +} + +export class ContextWindow { + private _embedder: Embedder; + private _forest: Forest; + private _hot: Message[] = []; + private _graduateAt: number; + private _evictAt: number; + private _maxColdClusters: number; + private _mergeThreshold: number; + private _nextId = 0; + private _graduatedIndex = 0; + + constructor( + embedder: Embedder, + summarizer: Summarizer, + options: ContextWindowOptions = {}, + ) { + this._embedder = embedder; + this._forest = new Forest(embedder, summarizer); + this._graduateAt = options.graduateAt ?? 26; + this._evictAt = options.evictAt ?? 30; + this._maxColdClusters = options.maxColdClusters ?? 10; + this._mergeThreshold = options.mergeThreshold ?? 0.15; + if (this._evictAt < this._graduateAt) { + throw new Error( + `evictAt (${this._evictAt}) must be >= graduateAt (${this._graduateAt})`, + ); + } + } + + /** + * Synchronous append. No LLM calls. + * Embeds locally (TF-IDF), pushes to hot, graduates and evicts as needed. + */ + append(content: string, timestamp?: string | null): number { + const msgId = this._nextId++; + const embedding = this._embedder.embed(content); + const msg: Message = { + id: msgId, + content, + embedding, + timestamp: timestamp ?? 
null, + _parent: null, + _rank: 0, + }; + this._hot.push(msg); + + // Graduate: ensure ungraduated count <= graduateAt + while (this._hot.length - this._graduatedIndex > this._graduateAt) { + this._graduate(this._hot[this._graduatedIndex]); + this._graduatedIndex++; + } + + // Evict: ensure hot.length <= evictAt + while (this._hot.length > this._evictAt) { + this._hot.shift(); + this._graduatedIndex--; + } + + return msgId; + } + + /** + * Synchronous graduation. No LLM calls. + * Inserts into forest, merges with nearest cluster if similar enough, + * enforces hard cap on cluster count. + */ + private _graduate(msg: Message): void { + this._forest.insert(msg.id, msg.content, msg.embedding, msg.timestamp); + + if (this._forest.clusterCount() <= 1) return; + + // Find nearest existing e-class (excluding the singleton we just inserted) + const nearest = this._forest.nearest(msg.embedding, 2); + if (nearest.length === 0) return; + + const nearestRoot = nearest[0] === msg.id ? nearest[1] : nearest[0]; + if (nearestRoot === undefined) return; + + const centroid = this._forest.getCentroid(nearestRoot); + if (!centroid) return; + const sim = cosineSimilarity(msg.embedding, centroid); + + if (sim >= this._mergeThreshold) { + this._forest.union(msg.id, nearestRoot); + } + + // Enforce hard cap on cluster count + while (this._forest.clusterCount() > this._maxColdClusters) { + const pair = findClosestPair(this._forest); + if (!pair) break; + this._forest.union(pair[0], pair[1]); + } + } + + /** + * Synchronous render. No LLM calls. + * Returns cached cold summaries + hot zone messages. + * Overlap window ensures graduated messages still appear verbatim from hot. + */ + render( + query?: string | null, + k: number = 3, + minSim: number = 0.05, + ): string[] { + let cold: string[]; + + if (query != null && this._forest.clusterCount() > 0) { + // Use embedQuery (non-mutating) to avoid contaminating the TF-IDF corpus. 
+      // embed() would add query terms to the vocabulary, changing future embeddings.
+      const embedFn = this._embedder.embedQuery ?? this._embedder.embed;
+      const queryEmb = embedFn.call(this._embedder, query);
+      const topRoots = this._forest.nearest(queryEmb, k, minSim);
+      cold = topRoots.map((r) => this._forest.compact(r));
+    } else {
+      cold = this._forest.roots().map((r) => this._forest.compact(r));
+    }
+
+    const hot = this._hot.map((m) => m.content);
+    return [...cold, ...hot];
+  }
+
+  /**
+   * Async fire-and-forget. Batch-summarizes dirty clusters via LLM calls.
+   * Called after render(), runs during main LLM call wait.
+   */
+  async resolveDirty(): Promise<void> {
+    await this._forest.resolveDirty();
+  }
+
+  expand(rootId: number): string[] {
+    return this._forest.expand(rootId);
+  }
+
+  get hotCount(): number {
+    return this._hot.length;
+  }
+
+  get coldClusterCount(): number {
+    return this._forest.clusterCount();
+  }
+
+  get totalMessages(): number {
+    return this._forest.size() + this._hot.length;
+  }
+
+  get forest(): Forest {
+    return this._forest;
+  }
+}
diff --git a/packages/core/src/services/embeddingService.test.ts b/packages/core/src/services/embeddingService.test.ts
new file mode 100644
index 00000000000..3b6b8591fb8
--- /dev/null
+++ b/packages/core/src/services/embeddingService.test.ts
@@ -0,0 +1,123 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect } from 'vitest';
+import { TFIDFEmbedder } from './embeddingService.js';
+import { cosineSimilarity } from './contextWindow.js';
+
+describe('TFIDFEmbedder', () => {
+  it('should produce a non-zero vector for non-empty text', () => {
+    const embedder = new TFIDFEmbedder();
+    const vec = embedder.embed('hello world');
+    expect(vec.some((v) => v !== 0)).toBe(true);
+  });
+
+  it('should produce a zero vector for empty text', () => {
+    const embedder = new TFIDFEmbedder();
+    const vec = embedder.embed('');
+    expect(vec.every((v) => v === 
0)).toBe(true); + }); + + it('should produce similar embeddings for similar texts', () => { + const embedder = new TFIDFEmbedder(); + // Embed related texts to build vocabulary + embedder.embed('the cat sat on the mat'); + embedder.embed('the dog ran in the park'); + embedder.embed('javascript typescript programming'); + + const v1 = embedder.embed('the cat sat on the mat'); + const v2 = embedder.embed('the cat lay on the mat'); + const v3 = embedder.embed('javascript typescript programming'); + + const simCats = cosineSimilarity(v1, v2); + const simCatJs = cosineSimilarity(v1, v3); + expect(simCats).toBeGreaterThan(simCatJs); + }); + + it('should produce different embeddings for different topics', () => { + const embedder = new TFIDFEmbedder(); + embedder.embed('quantum physics particle accelerator'); + embedder.embed('chocolate cake recipe baking'); + + const v1 = embedder.embed('quantum physics particle accelerator'); + const v2 = embedder.embed('chocolate cake recipe baking'); + + const sim = cosineSimilarity(v1, v2); + expect(sim).toBeLessThan(0.3); + }); + + it('should handle single-word text', () => { + const embedder = new TFIDFEmbedder(); + const vec = embedder.embed('hello'); + expect(vec.some((v) => v !== 0)).toBe(true); + }); + + it('should build vocabulary incrementally', () => { + const embedder = new TFIDFEmbedder(); + const v1 = embedder.embed('alpha'); + const vocabSize1 = v1.length; + + const v2 = embedder.embed('beta gamma'); + const vocabSize2 = v2.length; + + // Vocabulary should grow + expect(vocabSize2).toBeGreaterThan(vocabSize1); + }); + + it('should apply IDF weighting to reduce common word importance', () => { + const embedder = new TFIDFEmbedder(); + // 'the' appears in many documents, 'quantum' in few + embedder.embed('the cat'); + embedder.embed('the dog'); + embedder.embed('the fish'); + embedder.embed('quantum physics'); + + const vec = embedder.embed('the quantum'); + const vocab = embedder.getVocabulary(); + const theIdx = 
vocab.indexOf('the'); + const quantumIdx = vocab.indexOf('quantum'); + + // 'quantum' should have higher weight because it appears in fewer docs + expect(theIdx).toBeGreaterThanOrEqual(0); + expect(quantumIdx).toBeGreaterThanOrEqual(0); + expect(Math.abs(vec[quantumIdx])).toBeGreaterThan(Math.abs(vec[theIdx])); + }); + + it('embedQuery should not mutate vocabulary or doc count', () => { + const embedder = new TFIDFEmbedder(); + embedder.embed('alpha beta'); + embedder.embed('gamma delta'); + const vocabBefore = embedder.getVocabulary().length; + + // embedQuery with new terms should not grow vocab + const qvec = embedder.embedQuery('epsilon zeta'); + const vocabAfter = embedder.getVocabulary().length; + + expect(vocabAfter).toBe(vocabBefore); + // Unknown terms should produce a zero vector (no known terms matched) + expect(qvec.every((v) => v === 0)).toBe(true); + }); + + it('embedQuery should use existing vocabulary for known terms', () => { + const embedder = new TFIDFEmbedder(); + embedder.embed('cat dog fish'); + + const qvec = embedder.embedQuery('cat'); + // Should produce non-zero vector since 'cat' is in vocab + expect(qvec.some((v) => v !== 0)).toBe(true); + // Same dimension as current vocab + expect(qvec.length).toBe(embedder.getVocabulary().length); + }); + + it('should normalize vectors', () => { + const embedder = new TFIDFEmbedder(); + const vec = embedder.embed('test normalization vector'); + const norm = Math.sqrt(vec.reduce((sum, v) => sum + v * v, 0)); + // Normalized vector should have unit length + expect(norm).toBeGreaterThan(0); + expect(norm).toBeCloseTo(1.0, 3); + }); +}); diff --git a/packages/core/src/services/embeddingService.ts b/packages/core/src/services/embeddingService.ts new file mode 100644 index 00000000000..d5283069310 --- /dev/null +++ b/packages/core/src/services/embeddingService.ts @@ -0,0 +1,126 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Embedder } from 
'./contextWindow.js'; + +/** + * TF-IDF embedder for local, deterministic text similarity. + * + * Builds vocabulary incrementally from the conversation corpus. + * No API calls — purely local computation. + */ +export class TFIDFEmbedder implements Embedder { + private _vocab: Map = new Map(); // term -> index + private _docCount = 0; + private _termDocFreq: Map = new Map(); // term -> # docs containing it + + embed(text: string): number[] { + const tokens = this._tokenize(text); + if (tokens.length === 0) { + return new Array(Math.max(this._vocab.size, 1)).fill(0); + } + + // Update vocabulary with new terms + const uniqueTokens = new Set(tokens); + for (const token of uniqueTokens) { + if (!this._vocab.has(token)) { + this._vocab.set(token, this._vocab.size); + } + } + + // Update document frequency + this._docCount++; + for (const token of uniqueTokens) { + this._termDocFreq.set(token, (this._termDocFreq.get(token) ?? 0) + 1); + } + + // Compute TF-IDF vector + const vec = new Array(this._vocab.size).fill(0); + const termFreq = new Map(); + + for (const token of tokens) { + termFreq.set(token, (termFreq.get(token) ?? 0) + 1); + } + + for (const [term, count] of termFreq) { + const idx = this._vocab.get(term); + if (idx === undefined) continue; + + const tf = count / tokens.length; + const df = this._termDocFreq.get(term) ?? 1; + const idf = Math.log(1 + this._docCount / df); + vec[idx] = tf * idf; + } + + // L2 normalize + const norm = Math.sqrt( + vec.reduce((sum: number, v: number) => sum + v * v, 0), + ); + if (norm > 0) { + for (let i = 0; i < vec.length; i++) { + vec[i] /= norm; + } + } + + return vec; + } + + /** + * Embed without mutating vocabulary, docCount, or termDocFreq. + * Used for queries/retrieval so searching doesn't contaminate the corpus. 
+ */ + embedQuery(text: string): number[] { + const tokens = this._tokenize(text); + if (tokens.length === 0) { + return new Array(Math.max(this._vocab.size, 1)).fill(0); + } + + const vec = new Array(this._vocab.size).fill(0); + const termFreq = new Map(); + + for (const token of tokens) { + termFreq.set(token, (termFreq.get(token) ?? 0) + 1); + } + + for (const [term, count] of termFreq) { + const idx = this._vocab.get(term); + if (idx === undefined) continue; // unknown terms ignored + + const tf = count / tokens.length; + const df = this._termDocFreq.get(term) ?? 1; + const idf = Math.log(1 + this._docCount / df); + vec[idx] = tf * idf; + } + + // L2 normalize + const norm = Math.sqrt( + vec.reduce((sum: number, v: number) => sum + v * v, 0), + ); + if (norm > 0) { + for (let i = 0; i < vec.length; i++) { + vec[i] /= norm; + } + } + + return vec; + } + + getVocabulary(): string[] { + const vocab = new Array(this._vocab.size); + for (const [term, idx] of this._vocab) { + vocab[idx] = term; + } + return vocab; + } + + private _tokenize(text: string): string[] { + return text + .toLowerCase() + .replace(/[^a-z0-9\s]/g, ' ') + .split(/\s+/) + .filter((t) => t.length > 0); + } +} diff --git a/packages/core/src/utils/sanitizePromptInput.ts b/packages/core/src/utils/sanitizePromptInput.ts new file mode 100644 index 00000000000..f5798e45cb1 --- /dev/null +++ b/packages/core/src/utils/sanitizePromptInput.ts @@ -0,0 +1,18 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export function sanitizePromptString(value: string): string { + return value + .replace(/\\[rn]/g, ' ') + .replace(/[\r\n\u2028\u2029]+/g, ' ') + .replace(/```/g, "'''") + .replace(/[<>]/g, (char) => (char === '<' ? '<' : '>')) + .replace(/[\x00-\x1f\x7f]/g, ''); // eslint-disable-line no-control-regex +} + +export function sanitizePromptValue(value: unknown): string { + return sanitizePromptString(JSON.stringify(value) ?? ''); +}