diff --git a/.changeset/fix-continuation-chunk-emission.md b/.changeset/fix-continuation-chunk-emission.md new file mode 100644 index 00000000..a0dbd033 --- /dev/null +++ b/.changeset/fix-continuation-chunk-emission.md @@ -0,0 +1,5 @@ +--- +'@tanstack/ai': patch +--- + +Emit TOOL_CALL_START and TOOL_CALL_ARGS for pending tool calls during continuation re-executions diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index c7ec866d..24cc4152 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -731,6 +731,13 @@ class TextEngine< needsClientExecution: executionResult.needsClientExecution, }) + // Build args lookup so buildToolResultChunks can emit TOOL_CALL_START + + // TOOL_CALL_ARGS before TOOL_CALL_END during continuation re-executions. + const argsMap = new Map() + for (const tc of pendingToolCalls) { + argsMap.set(tc.id, tc.function.arguments) + } + if ( executionResult.needsApproval.length > 0 || executionResult.needsClientExecution.length > 0 @@ -739,6 +746,7 @@ class TextEngine< for (const chunk of this.buildToolResultChunks( executionResult.results, finishEvent, + argsMap, )) { yield chunk } @@ -765,6 +773,7 @@ class TextEngine< const toolResultChunks = this.buildToolResultChunks( executionResult.results, finishEvent, + argsMap, ) for (const chunk of toolResultChunks) { @@ -1080,12 +1089,35 @@ class TextEngine< private buildToolResultChunks( results: Array, finishEvent: RunFinishedEvent, + argsMap?: Map, ): Array { const chunks: Array = [] for (const result of results) { const content = JSON.stringify(result.result) + // Emit TOOL_CALL_START + TOOL_CALL_ARGS before TOOL_CALL_END so that + // the client can reconstruct the full tool call during continuations. + if (argsMap) { + chunks.push({ + type: 'TOOL_CALL_START', + timestamp: Date.now(), + model: finishEvent.model, + toolCallId: result.toolCallId, + toolName: result.toolName, + }) + + const args = argsMap.get(result.toolCallId) ?? '{}' + chunks.push({ + type: 'TOOL_CALL_ARGS', + timestamp: Date.now(), + model: finishEvent.model, + toolCallId: result.toolCallId, + delta: args, + args, + }) + } + chunks.push({ type: 'TOOL_CALL_END', timestamp: Date.now(), diff --git a/packages/typescript/ai/tests/chat.test.ts b/packages/typescript/ai/tests/chat.test.ts index 89374348..60d6a964 100644 --- a/packages/typescript/ai/tests/chat.test.ts +++ b/packages/typescript/ai/tests/chat.test.ts @@ -656,6 +656,232 @@ describe('chat()', () => { expect(executeSpy).not.toHaveBeenCalled() expect(calls).toHaveLength(1) }) + + it('should emit TOOL_CALL_START and TOOL_CALL_ARGS before TOOL_CALL_END for pending tool calls', async () => { + const executeSpy = vi.fn().mockReturnValue({ temp: 72 }) + + const { adapter } = createMockAdapter({ + iterations: [ + // After pending tool is executed, the engine calls the adapter for the next response + [ + ev.runStarted(), + ev.textStart(), + ev.textContent('72F in NYC'), + ev.textEnd(), + ev.runFinished('stop'), + ], + ], + }) + + const stream = chat({ + adapter, + messages: [ + { role: 'user', content: 'Weather?' }, + { + role: 'assistant', + content: 'Let me check.', + toolCalls: [ + { + id: 'call_1', + type: 'function' as const, + function: { name: 'getWeather', arguments: '{"city":"NYC"}' }, + }, + ], + }, + // No tool result message -> pending! + ], + tools: [serverTool('getWeather', executeSpy)], + }) + + const chunks = await collectChunks(stream as AsyncIterable) + + // Tool should have been executed + expect(executeSpy).toHaveBeenCalledTimes(1) + + // The continuation re-execution should emit the full chunk sequence: + // TOOL_CALL_START -> TOOL_CALL_ARGS -> TOOL_CALL_END + // Without the fix, only TOOL_CALL_END is emitted, causing the client + // to store the tool call with empty arguments {}. + const toolStartChunks = chunks.filter( + (c) => + c.type === 'TOOL_CALL_START' && (c as any).toolCallId === 'call_1', + ) + expect(toolStartChunks).toHaveLength(1) + expect((toolStartChunks[0] as any).toolName).toBe('getWeather') + + const toolArgsChunks = chunks.filter( + (c) => + c.type === 'TOOL_CALL_ARGS' && (c as any).toolCallId === 'call_1', + ) + expect(toolArgsChunks).toHaveLength(1) + expect((toolArgsChunks[0] as any).delta).toBe('{"city":"NYC"}') + expect((toolArgsChunks[0] as any).args).toBe('{"city":"NYC"}') + + const toolEndChunks = chunks.filter( + (c) => c.type === 'TOOL_CALL_END' && (c as any).toolCallId === 'call_1', + ) + expect(toolEndChunks).toHaveLength(1) + + // Verify ordering: START before ARGS before END + const startIdx = chunks.indexOf(toolStartChunks[0]!) + const argsIdx = chunks.indexOf(toolArgsChunks[0]!) + const endIdx = chunks.indexOf(toolEndChunks[0]!) + expect(startIdx).toBeLessThan(argsIdx) + expect(argsIdx).toBeLessThan(endIdx) + }) + + it('should emit TOOL_CALL_START and TOOL_CALL_ARGS for each pending tool call in a batch', async () => { + const weatherSpy = vi.fn().mockReturnValue({ temp: 72 }) + const timeSpy = vi.fn().mockReturnValue({ time: '3pm' }) + + const { adapter } = createMockAdapter({ + iterations: [ + [ + ev.runStarted(), + ev.textStart(), + ev.textContent('Done.'), + ev.textEnd(), + ev.runFinished('stop'), + ], + ], + }) + + const stream = chat({ + adapter, + messages: [ + { role: 'user', content: 'Weather and time?' }, + { + role: 'assistant', + content: '', + toolCalls: [ + { + id: 'call_weather', + type: 'function' as const, + function: { name: 'getWeather', arguments: '{"city":"NYC"}' }, + }, + { + id: 'call_time', + type: 'function' as const, + function: { name: 'getTime', arguments: '{"tz":"EST"}' }, + }, + ], + }, + // No tool results -> both pending + ], + tools: [ + serverTool('getWeather', weatherSpy), + serverTool('getTime', timeSpy), + ], + }) + + const chunks = await collectChunks(stream as AsyncIterable) + + // Both tools should have been executed + expect(weatherSpy).toHaveBeenCalledTimes(1) + expect(timeSpy).toHaveBeenCalledTimes(1) + + // Each pending tool should get the full START -> ARGS -> END sequence + for (const { id, name, args } of [ + { id: 'call_weather', name: 'getWeather', args: '{"city":"NYC"}' }, + { id: 'call_time', name: 'getTime', args: '{"tz":"EST"}' }, + ]) { + const starts = chunks.filter( + (c) => c.type === 'TOOL_CALL_START' && (c as any).toolCallId === id, + ) + expect(starts).toHaveLength(1) + expect((starts[0] as any).toolName).toBe(name) + + const argChunks = chunks.filter( + (c) => c.type === 'TOOL_CALL_ARGS' && (c as any).toolCallId === id, + ) + expect(argChunks).toHaveLength(1) + expect((argChunks[0] as any).delta).toBe(args) + + const ends = chunks.filter( + (c) => c.type === 'TOOL_CALL_END' && (c as any).toolCallId === id, + ) + expect(ends).toHaveLength(1) + + // Verify ordering + const startIdx = chunks.indexOf(starts[0]!) + const argsIdx = chunks.indexOf(argChunks[0]!) + const endIdx = chunks.indexOf(ends[0]!) + expect(startIdx).toBeLessThan(argsIdx) + expect(argsIdx).toBeLessThan(endIdx) + } + }) + + it('should emit TOOL_CALL_START and TOOL_CALL_ARGS for the server tool in a mixed pending batch', async () => { + const weatherSpy = vi.fn().mockReturnValue({ temp: 72 }) + + const { adapter } = createMockAdapter({ iterations: [] }) + + const stream = chat({ + adapter, + messages: [ + { role: 'user', content: 'Weather and notify?' }, + { + role: 'assistant', + content: '', + toolCalls: [ + { + id: 'call_server', + type: 'function' as const, + function: { name: 'getWeather', arguments: '{"city":"NYC"}' }, + }, + { + id: 'call_client', + type: 'function' as const, + function: { + name: 'showNotification', + arguments: '{"message":"done"}', + }, + }, + ], + }, + // No tool results -> both pending + ], + tools: [ + serverTool('getWeather', weatherSpy), + clientTool('showNotification'), + ], + }) + + const chunks = await collectChunks(stream as AsyncIterable) + + // Server tool should have executed + expect(weatherSpy).toHaveBeenCalledTimes(1) + + // The executed server tool should get the full START -> ARGS -> END + const starts = chunks.filter( + (c) => + c.type === 'TOOL_CALL_START' && + (c as any).toolCallId === 'call_server', + ) + expect(starts).toHaveLength(1) + expect((starts[0] as any).toolName).toBe('getWeather') + + const argChunks = chunks.filter( + (c) => + c.type === 'TOOL_CALL_ARGS' && + (c as any).toolCallId === 'call_server', + ) + expect(argChunks).toHaveLength(1) + expect((argChunks[0] as any).delta).toBe('{"city":"NYC"}') + + const ends = chunks.filter( + (c) => + c.type === 'TOOL_CALL_END' && (c as any).toolCallId === 'call_server', + ) + expect(ends).toHaveLength(1) + + // Verify ordering + const startIdx = chunks.indexOf(starts[0]!) + const argsIdx = chunks.indexOf(argChunks[0]!) + const endIdx = chunks.indexOf(ends[0]!) + expect(startIdx).toBeLessThan(argsIdx) + expect(argsIdx).toBeLessThan(endIdx) + }) }) // ========================================================================== diff --git a/testing/e2e/tests/tools-test/continuation-args.spec.ts b/testing/e2e/tests/tools-test/continuation-args.spec.ts new file mode 100644 index 00000000..4e8914cb --- /dev/null +++ b/testing/e2e/tests/tools-test/continuation-args.spec.ts @@ -0,0 +1,170 @@ +import { test, expect } from '../fixtures' +import { + selectScenario, + runTest, + waitForTestComplete, + getMetadata, + getEventLog, + getToolCalls, + getToolCallParts, +} from './helpers' + +/** + * Continuation Re-execution — Tool Call Arguments E2E Tests + * + * These tests verify that tool call arguments are correctly preserved during + * continuation re-executions. When a client tool completes and the conversation + * continues, the server re-processes message history containing pending tool + * calls. Without emitting TOOL_CALL_START + TOOL_CALL_ARGS before + * TOOL_CALL_END, tool-call parts arrive at the client with empty + * arguments {}, potentially causing infinite re-execution loops. + */ + +test.describe('Continuation Re-execution — Tool Call Arguments', () => { + test('single client tool arguments preserved after continuation', async ({ + page, + testId, + aimockPort, + }) => { + await selectScenario(page, 'client-tool-single', testId, aimockPort) + await runTest(page) + await waitForTestComplete(page) + + const metadata = await getMetadata(page) + expect(metadata.testComplete).toBe('true') + expect(parseInt(metadata.toolCallCount)).toBeGreaterThanOrEqual(1) + + const parts = await getToolCallParts(page) + expect(parts.length).toBeGreaterThanOrEqual(1) + + const notificationCall = parts.find((tc) => tc.name === 'show_notification') + expect(notificationCall).toBeDefined() + expect(notificationCall?.arguments).toEqual({ + message: 'Hello from the AI!', + type: 'info', + }) + }) + + test('sequential client tool arguments preserved across multiple continuations', async ({ + page, + testId, + aimockPort, + }) => { + await selectScenario(page, 'sequential-client-tools', testId, aimockPort) + await runTest(page) + await waitForTestComplete(page, 15000, 2) + + // Wait for execution events to propagate + await page.waitForFunction( + () => { + const el = document.querySelector('#test-metadata') + return ( + parseInt(el?.getAttribute('data-execution-complete-count') || '0') >= + 2 + ) + }, + { timeout: 10000 }, + ) + + const metadata = await getMetadata(page) + expect(parseInt(metadata.toolCallCount)).toBeGreaterThanOrEqual(2) + + const parts = await getToolCallParts(page) + const notificationCalls = parts.filter( + (tc) => tc.name === 'show_notification', + ) + expect(notificationCalls.length).toBeGreaterThanOrEqual(2) + + // Both sets of arguments must be present (order may vary) + const allArgs = notificationCalls.map((tc) => tc.arguments) + expect(allArgs).toContainEqual({ + message: 'First notification', + type: 'info', + }) + expect(allArgs).toContainEqual({ + message: 'Second notification', + type: 'warning', + }) + + // No tool call should have empty arguments + expect( + notificationCalls.every((tc) => Object.keys(tc.arguments).length > 0), + ).toBe(true) + }) + + test('parallel client tool arguments preserved in batch continuation', async ({ + page, + testId, + aimockPort, + }) => { + await selectScenario(page, 'parallel-client-tools', testId, aimockPort) + await runTest(page) + await waitForTestComplete(page, 15000, 2) + + const metadata = await getMetadata(page) + expect(parseInt(metadata.toolCallCount)).toBeGreaterThanOrEqual(2) + + const parts = await getToolCallParts(page) + expect(parts.length).toBeGreaterThanOrEqual(2) + + const notificationCall = parts.find((tc) => tc.name === 'show_notification') + const chartCall = parts.find((tc) => tc.name === 'display_chart') + + expect(notificationCall).toBeDefined() + expect(chartCall).toBeDefined() + + expect(notificationCall?.arguments).toEqual({ + message: 'Parallel 1', + type: 'info', + }) + expect(chartCall?.arguments).toEqual({ + type: 'bar', + data: [1, 2, 3], + }) + }) + + test('mixed server and client tool arguments preserved in sequence', async ({ + page, + testId, + aimockPort, + }) => { + // Server tool (fetch_data) followed by client tool (display_chart) + await selectScenario(page, 'sequence-server-client', testId, aimockPort) + await runTest(page) + await waitForTestComplete(page, 20000, 1) + + const metadata = await getMetadata(page) + expect(parseInt(metadata.toolCallCount)).toBeGreaterThanOrEqual(2) + + const parts = await getToolCallParts(page) + expect(parts.length).toBeGreaterThanOrEqual(2) + + const fetchCall = parts.find((tc) => tc.name === 'fetch_data') + const chartCall = parts.find((tc) => tc.name === 'display_chart') + + expect(fetchCall).toBeDefined() + expect(chartCall).toBeDefined() + + expect(fetchCall?.arguments).toEqual({ source: 'api' }) + expect(chartCall?.arguments).toEqual({ type: 'bar', data: [1, 2, 3] }) + }) + + // Screenshot on failure + test.afterEach(async ({ page }, testInfo) => { + if (testInfo.status !== testInfo.expectedStatus) { + await page.screenshot({ + path: `test-results/continuation-args-failure-${testInfo.title.replace(/\s+/g, '-')}.png`, + fullPage: true, + }) + + const events = await getEventLog(page) + const toolCalls = await getToolCalls(page) + const metadata = await getMetadata(page) + + console.log('Test failed. Debug info:') + console.log('Metadata:', metadata) + console.log('Events:', events) + console.log('Tool calls:', toolCalls) + } + }) +}) diff --git a/testing/e2e/tests/tools-test/helpers.ts b/testing/e2e/tests/tools-test/helpers.ts index 94a1fb27..a83853cd 100644 --- a/testing/e2e/tests/tools-test/helpers.ts +++ b/testing/e2e/tests/tools-test/helpers.ts @@ -179,3 +179,36 @@ export async function getToolCalls( } }) } + +/** + * Extract tool-call parts with parsed arguments from #messages-json-content. + */ +export async function getToolCallParts( + page: Page, +): Promise }>> { + return page.evaluate(() => { + const el = document.getElementById('messages-json-content') + if (!el) return [] + try { + const messages = JSON.parse(el.textContent || '[]') + const parts: Array<{ name: string; arguments: Record }> = + [] + for (const msg of messages) { + for (const part of msg.parts || []) { + if (part.type === 'tool-call') { + parts.push({ + name: part.name, + arguments: + typeof part.arguments === 'string' + ? JSON.parse(part.arguments) + : part.arguments, + }) + } + } + } + return parts + } catch { + return [] + } + }) +}