diff --git a/.changeset/usechat-fetcher-server-functions.md b/.changeset/usechat-fetcher-server-functions.md new file mode 100644 index 000000000..b192b4c39 --- /dev/null +++ b/.changeset/usechat-fetcher-server-functions.md @@ -0,0 +1,23 @@ +--- +'@tanstack/ai-client': minor +'@tanstack/ai-react': minor +'@tanstack/ai-preact': patch +'@tanstack/ai-solid': patch +'@tanstack/ai-svelte': patch +'@tanstack/ai-vue': patch +--- + +Add a `fetcher` option to `ChatClient` and the framework chat hooks +(`useChat` / `createChat`), mirroring the `fetcher` option on the +generation hooks. Pass either `connection` or `fetcher` — the XOR is +enforced at the type level via `ChatTransport`. + +```ts +useChat({ + fetcher: ({ messages }, { signal }) => chatFn({ data: { messages }, signal }), +}) +``` + +The fetcher may return either a `Response` (parsed as SSE) or an +`AsyncIterable` (yielded directly). `stream()`, +`fetchServerSentEvents`, `fetchHttpStream`, and `rpcStream` are unchanged. diff --git a/examples/ts-react-chat/src/components/Header.tsx b/examples/ts-react-chat/src/components/Header.tsx index 4cd9fc4d8..edd44fd63 100644 --- a/examples/ts-react-chat/src/components/Header.tsx +++ b/examples/ts-react-chat/src/components/Header.tsx @@ -11,6 +11,7 @@ import { Menu, Mic, Music, + Server, Video, X, } from 'lucide-react' @@ -197,6 +198,19 @@ export default function Header() { Voice Chat (Realtime) + + setIsOpen(false)} + className="flex items-center gap-3 p-3 rounded-lg hover:bg-gray-800 transition-colors mb-2" + activeProps={{ + className: + 'flex items-center gap-3 p-3 rounded-lg bg-cyan-600 hover:bg-cyan-700 transition-colors mb-2', + }} + > + + Server Function Chat + diff --git a/examples/ts-react-chat/src/lib/server-fns.ts b/examples/ts-react-chat/src/lib/server-fns.ts index b1e5d9e59..15315b30f 100644 --- a/examples/ts-react-chat/src/lib/server-fns.ts +++ b/examples/ts-react-chat/src/lib/server-fns.ts @@ -1,6 +1,7 @@ import { createServerFn } from '@tanstack/react-start' import { z } from 'zod' import { + chat, generateAudio, generateImage, generateSpeech, @@ -10,7 +11,13 @@ import { summarize, toServerSentEventsResponse, } from '@tanstack/ai' -import { openaiImage, openaiSummarize, openaiVideo } from '@tanstack/ai-openai' +import { + openaiImage, + openaiSummarize, + openaiText, + openaiVideo, +} from '@tanstack/ai-openai' +import type { UIMessage } from '@tanstack/ai' import { InvalidModelOverrideError, UnknownProviderError, @@ -365,3 +372,23 @@ export const generateVideoStreamFn = createServerFn({ method: 'POST' }) }), ) }) + +// ============================================================================= +// Chat server function — pairs with useChat({ fetcher }) +// ============================================================================= + +export const chatFn = createServerFn({ method: 'POST' }) + .inputValidator( + (data: { messages: Array; data?: Record }) => data, + ) + .handler(({ data }) => + toServerSentEventsResponse( + chat({ + adapter: openaiText('gpt-5.2'), + messages: data.messages as any, + systemPrompts: [ + 'You are a helpful assistant. Keep replies short and friendly.', + ], + }), + ), + ) diff --git a/examples/ts-react-chat/src/routeTree.gen.ts b/examples/ts-react-chat/src/routeTree.gen.ts index f9b2ac825..60779cfa4 100644 --- a/examples/ts-react-chat/src/routeTree.gen.ts +++ b/examples/ts-react-chat/src/routeTree.gen.ts @@ -9,6 +9,7 @@ // Additionally, you should also exclude this file from your linter and/or formatter to prevent it from being checked or modified. import { Route as rootRouteImport } from './routes/__root' +import { Route as ServerFnChatRouteImport } from './routes/server-fn-chat' import { Route as RealtimeRouteImport } from './routes/realtime' import { Route as ImageGenRouteImport } from './routes/image-gen' import { Route as IndexRouteImport } from './routes/index' @@ -31,6 +32,11 @@ import { Route as ApiGenerateSpeechRouteImport } from './routes/api.generate.spe import { Route as ApiGenerateImageRouteImport } from './routes/api.generate.image' import { Route as ApiGenerateAudioRouteImport } from './routes/api.generate.audio' +const ServerFnChatRoute = ServerFnChatRouteImport.update({ + id: '/server-fn-chat', + path: '/server-fn-chat', + getParentRoute: () => rootRouteImport, +} as any) const RealtimeRoute = RealtimeRouteImport.update({ id: '/realtime', path: '/realtime', @@ -143,6 +149,7 @@ export interface FileRoutesByFullPath { '/': typeof IndexRoute '/image-gen': typeof ImageGenRoute '/realtime': typeof RealtimeRoute + '/server-fn-chat': typeof ServerFnChatRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-output': typeof ApiStructuredOutputRoute '/api/summarize': typeof ApiSummarizeRoute @@ -166,6 +173,7 @@ export interface FileRoutesByTo { '/': typeof IndexRoute '/image-gen': typeof ImageGenRoute '/realtime': typeof RealtimeRoute + '/server-fn-chat': typeof ServerFnChatRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-output': typeof ApiStructuredOutputRoute '/api/summarize': typeof ApiSummarizeRoute @@ -190,6 +198,7 @@ export interface FileRoutesById { '/': typeof IndexRoute '/image-gen': typeof ImageGenRoute '/realtime': typeof RealtimeRoute + '/server-fn-chat': typeof ServerFnChatRoute '/api/image-gen': typeof ApiImageGenRoute '/api/structured-output': typeof ApiStructuredOutputRoute '/api/summarize': typeof ApiSummarizeRoute @@ -215,6 +224,7 @@ export interface FileRouteTypes { | '/' | '/image-gen' | '/realtime' + | '/server-fn-chat' | '/api/image-gen' | '/api/structured-output' | '/api/summarize' @@ -238,6 +248,7 @@ export interface FileRouteTypes { | '/' | '/image-gen' | '/realtime' + | '/server-fn-chat' | '/api/image-gen' | '/api/structured-output' | '/api/summarize' @@ -261,6 +272,7 @@ export interface FileRouteTypes { | '/' | '/image-gen' | '/realtime' + | '/server-fn-chat' | '/api/image-gen' | '/api/structured-output' | '/api/summarize' @@ -285,6 +297,7 @@ export interface RootRouteChildren { IndexRoute: typeof IndexRoute ImageGenRoute: typeof ImageGenRoute RealtimeRoute: typeof RealtimeRoute + ServerFnChatRoute: typeof ServerFnChatRoute ApiImageGenRoute: typeof ApiImageGenRoute ApiStructuredOutputRoute: typeof ApiStructuredOutputRoute ApiSummarizeRoute: typeof ApiSummarizeRoute @@ -307,6 +320,13 @@ export interface RootRouteChildren { declare module '@tanstack/react-router' { interface FileRoutesByPath { + '/server-fn-chat': { + id: '/server-fn-chat' + path: '/server-fn-chat' + fullPath: '/server-fn-chat' + preLoaderRoute: typeof ServerFnChatRouteImport + parentRoute: typeof rootRouteImport + } '/realtime': { id: '/realtime' path: '/realtime' @@ -461,6 +481,7 @@ const rootRouteChildren: RootRouteChildren = { IndexRoute: IndexRoute, ImageGenRoute: ImageGenRoute, RealtimeRoute: RealtimeRoute, + ServerFnChatRoute: ServerFnChatRoute, ApiImageGenRoute: ApiImageGenRoute, ApiStructuredOutputRoute: ApiStructuredOutputRoute, ApiSummarizeRoute: ApiSummarizeRoute, diff --git a/examples/ts-react-chat/src/routes/server-fn-chat.tsx b/examples/ts-react-chat/src/routes/server-fn-chat.tsx new file mode 100644 index 000000000..4350c9e25 --- /dev/null +++ b/examples/ts-react-chat/src/routes/server-fn-chat.tsx @@ -0,0 +1,101 @@ +import { useState } from 'react' +import { createFileRoute } from '@tanstack/react-router' +import { useChat } from '@tanstack/ai-react' +import { Send, Square } from 'lucide-react' +import { chatFn } from '@/lib/server-fns' + +export const Route = createFileRoute('/server-fn-chat')({ + component: ServerFnChat, +}) + +function ServerFnChat() { + const { messages, sendMessage, isLoading, error, stop } = useChat({ + fetcher: ({ messages }, { signal }) => + chatFn({ data: { messages }, signal }), + }) + const [input, setInput] = useState('') + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault() + if (!input.trim() || isLoading) return + void sendMessage(input) + setInput('') + } + + return ( +
+
+

Chat via server function

+

+ + useChat({ fetcher: ({'{'}messages{'}'}, {'{'}signal{'}'}) => + chatFn({'{'} data, signal {'}'}) }) + {' '} + — the server function returns an SSE{' '} + Response; the chat client + parses it. +

+
+ +
+ {messages.length === 0 && ( +

+ Say something to start the chat. +

+ )} + {messages.map((m) => ( +
+ {m.parts.map((part, i) => + part.type === 'text' ? {part.content} : null, + )} +
+ ))} + {error && ( +
+ {error.message} +
+ )} +
+ +
+ setInput(e.target.value)} + placeholder="Message..." + disabled={isLoading} + className="flex-1 rounded-lg bg-gray-800 border border-gray-700 px-3 py-2 text-sm focus:outline-none focus:border-cyan-500" + /> + {isLoading ? ( + + ) : ( + + )} +
+
+ ) +} diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts index ab9d07dff..3f33d1620 100644 --- a/packages/typescript/ai-client/src/chat-client.ts +++ b/packages/typescript/ai-client/src/chat-client.ts @@ -4,7 +4,10 @@ import { normalizeToUIMessage, } from '@tanstack/ai' import { DefaultChatClientEventEmitter } from './events' -import { normalizeConnectionAdapter } from './connection-adapters' +import { + fetcherToConnectionAdapter, + normalizeConnectionAdapter, +} from './connection-adapters' import type { AnyClientTool, ContentPart, @@ -19,6 +22,7 @@ import type { ChatClientEventEmitter } from './events' import type { ChatClientOptions, ChatClientState, + ChatFetcher, ConnectionStatus, MessagePart, MultimodalContent, @@ -26,6 +30,26 @@ import type { UIMessage, } from './types' +function resolveTransport(transport: { + connection?: ConnectionAdapter + fetcher?: ChatFetcher +}): ConnectionAdapter { + const hasConnection = transport.connection !== undefined + const hasFetcher = transport.fetcher !== undefined + if (hasConnection && hasFetcher) { + throw new Error( + 'ChatClient: pass either `connection` or `fetcher`, not both.', + ) + } + if (hasConnection) { + return transport.connection! + } + if (hasFetcher) { + return fetcherToConnectionAdapter(transport.fetcher!) + } + throw new Error('ChatClient: either `connection` or `fetcher` is required.') +} + export class ChatClient { private processor: StreamProcessor private connection: SubscribeConnectionAdapter @@ -82,7 +106,7 @@ export class ChatClient { constructor(options: ChatClientOptions) { this.uniqueId = options.id || this.generateUniqueId('chat') this.body = options.body || {} - this.connection = normalizeConnectionAdapter(options.connection) + this.connection = normalizeConnectionAdapter(resolveTransport(options)) this.events = new DefaultChatClientEventEmitter(this.uniqueId) // Build client tools map @@ -969,6 +993,7 @@ export class ChatClient { */ updateOptions(options: { connection?: ConnectionAdapter + fetcher?: ChatFetcher body?: Record tools?: ReadonlyArray onResponse?: (response?: Response) => void | Promise @@ -984,7 +1009,7 @@ export class ChatClient { context: { toolCallId?: string }, ) => void }): void { - if (options.connection !== undefined) { + if (options.connection !== undefined || options.fetcher !== undefined) { const wasSubscribed = this.isSubscribed if (this.isLoading) { @@ -999,7 +1024,12 @@ export class ChatClient { this.resetSessionGenerating() this.setIsSubscribed(false) this.setConnectionStatus('disconnected') - this.connection = normalizeConnectionAdapter(options.connection) + this.connection = normalizeConnectionAdapter( + resolveTransport({ + connection: options.connection, + fetcher: options.fetcher, + }), + ) if (wasSubscribed) { this.subscribe() diff --git a/packages/typescript/ai-client/src/connection-adapters.ts b/packages/typescript/ai-client/src/connection-adapters.ts index 91d63a146..b31d9796f 100644 --- a/packages/typescript/ai-client/src/connection-adapters.ts +++ b/packages/typescript/ai-client/src/connection-adapters.ts @@ -1,4 +1,26 @@ -import type { ModelMessage, StreamChunk, UIMessage } from '@tanstack/ai' +import { EventType } from '@tanstack/ai' +import type { + ModelMessage, + RunErrorEvent, + RunFinishedEvent, + StreamChunk, + UIMessage, +} from '@tanstack/ai' +import type { ChatFetcher } from './types' + +/** + * Thrown when an SSE/HTTP stream ends with a non-empty unterminated buffer. + * Indicates the connection was cut mid-line (server crash, dropped TCP, proxy + * timeout) so the partial content cannot be safely parsed. + */ +export class StreamTruncatedError extends Error { + constructor() { + super( + 'Stream ended with unterminated trailing data — connection was likely cut short.', + ) + this.name = 'StreamTruncatedError' + } +} /** * Merge custom headers into request headers @@ -53,15 +75,81 @@ async function* readStreamLines( } } - // Process any remaining data in the buffer + // A non-empty trailing buffer means the connection was cut mid-line. + // Surface this as an error so the chat client transitions to 'error' + // state instead of silently presenting a partial stream as success. if (buffer.trim()) { - yield buffer + throw new StreamTruncatedError() } } finally { reader.releaseLock() } } +/** + * Yield StreamChunks parsed from an SSE Response body. + * + * Accepts either `data: {...}` lines or bare JSON lines. Skips comments + * starting with `:` (proxies and CDNs inject these as keepalives) and the + * `event:` / `id:` / `retry:` SSE control fields. A `[DONE]` sentinel is + * treated as a terminal event: a synthesized RUN_FINISHED is yielded using + * the most recent upstream `threadId` / `runId`, ensuring the consumer sees + * a clean terminal event with real correlation ids. + * + * A JSON parse failure throws — the consumer surfaces it as an error. + */ +async function* responseToSSEChunks( + response: Response, + abortSignal?: AbortSignal, +): AsyncGenerator { + if (!response.ok) { + throw new Error( + `HTTP error! status: ${response.status} ${response.statusText}`, + ) + } + const reader = response.body?.getReader() + if (!reader) { + throw new Error('Response body is not readable') + } + let lastThreadId: string | undefined + let lastRunId: string | undefined + let lastModel: string | undefined + for await (const line of readStreamLines(reader, abortSignal)) { + if ( + line.startsWith(':') || + line.startsWith('event:') || + line.startsWith('id:') || + line.startsWith('retry:') + ) { + continue + } + const data = line.startsWith('data: ') ? line.slice(6) : line + if (data === '[DONE]') { + const synthetic: RunFinishedEvent = { + type: EventType.RUN_FINISHED, + threadId: lastThreadId ?? '', + runId: lastRunId ?? '', + model: lastModel ?? '', + timestamp: Date.now(), + finishReason: 'stop', + } + yield synthetic + return + } + const chunk = JSON.parse(data) as StreamChunk + if ('threadId' in chunk && typeof chunk.threadId === 'string') { + lastThreadId = chunk.threadId + } + if ('runId' in chunk && typeof chunk.runId === 'string') { + lastRunId = chunk.runId + } + if ('model' in chunk && typeof chunk.model === 'string') { + lastModel = chunk.model + } + yield chunk + } +} + export interface ConnectConnectionAdapter { /** * Connect and return an async iterable of StreamChunks. @@ -175,9 +263,17 @@ export function normalizeConnectionAdapter( }, async send(messages, data, abortSignal) { let hasTerminalEvent = false + let upstreamThreadId: string | undefined + let upstreamRunId: string | undefined try { const stream = connection.connect(messages, data, abortSignal) for await (const chunk of stream) { + if ('threadId' in chunk && typeof chunk.threadId === 'string') { + upstreamThreadId = chunk.threadId + } + if ('runId' in chunk && typeof chunk.runId === 'string') { + upstreamRunId = chunk.runId + } if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') { hasTerminalEvent = true } @@ -187,28 +283,26 @@ export function normalizeConnectionAdapter( // If the connect stream ended cleanly without a terminal event, // synthesize RUN_FINISHED so request-scoped consumers can complete. if (!abortSignal?.aborted && !hasTerminalEvent) { - push({ - type: 'RUN_FINISHED', - runId: `run-${Date.now()}`, + const synthetic: RunFinishedEvent = { + type: EventType.RUN_FINISHED, + threadId: upstreamThreadId ?? `thread-${Date.now()}`, + runId: upstreamRunId ?? `run-${Date.now()}`, model: 'connect-wrapper', timestamp: Date.now(), finishReason: 'stop', - } as unknown as StreamChunk) + } + push(synthetic) } } catch (err) { if (!abortSignal?.aborted && !hasTerminalEvent) { - push({ - type: 'RUN_ERROR', + const message = + err instanceof Error ? err.message : 'Unknown error in connect()' + const synthetic: RunErrorEvent = { + type: EventType.RUN_ERROR, timestamp: Date.now(), - message: - err instanceof Error ? err.message : 'Unknown error in connect()', - error: { - message: - err instanceof Error - ? err.message - : 'Unknown error in connect()', - }, - } as unknown as StreamChunk) + message, + } + push(synthetic) } throw err } @@ -296,37 +390,7 @@ export function fetchServerSentEvents( signal: abortSignal || resolvedOptions.signal, }) - if (!response.ok) { - throw new Error( - `HTTP error! status: ${response.status} ${response.statusText}`, - ) - } - - // Parse Server-Sent Events format - const reader = response.body?.getReader() - if (!reader) { - throw new Error('Response body is not readable') - } - - for await (const line of readStreamLines(reader, abortSignal)) { - // Handle Server-Sent Events format - const data = line.startsWith('data: ') ? line.slice(6) : line - - if (data === '[DONE]') { - console.warn( - '[@tanstack/ai-client] Received [DONE] sentinel. This is deprecated — upgrade your @tanstack/ai server package. RUN_FINISHED is the stream terminator.', - ) - continue - } - - try { - const parsed: StreamChunk = JSON.parse(data) - yield parsed - } catch (parseError) { - // Skip non-JSON lines or malformed chunks - console.warn('Failed to parse SSE chunk:', data) - } - } + yield* responseToSSEChunks(response, abortSignal) }, } } @@ -413,12 +477,7 @@ export function fetchHttpStream( } for await (const line of readStreamLines(reader, abortSignal)) { - try { - const parsed: StreamChunk = JSON.parse(line) - yield parsed - } catch (parseError) { - console.warn('Failed to parse HTTP stream chunk:', line) - } + yield JSON.parse(line) as StreamChunk } }, } @@ -442,17 +501,82 @@ export function stream( streamFactory: ( messages: Array | Array, data?: Record, + abortSignal?: AbortSignal, ) => AsyncIterable, ): ConnectConnectionAdapter { return { - async *connect(messages, data) { + async *connect(messages, data, abortSignal) { // Pass messages as-is (UIMessages with parts preserved) // Server-side chat() handles conversion to ModelMessages - yield* streamFactory(messages, data) + yield* streamFactory(messages, data, abortSignal) }, } } +/** + * Wrap a `ChatFetcher` as a `ConnectConnectionAdapter` so the chat client can + * consume it through the same `subscribe`/`send` plumbing used for SSE / + * HTTP-stream / RPC connections. May return either a `Response` (parsed as + * SSE) or an `AsyncIterable` (yielded directly). + * + * @internal + */ +export function fetcherToConnectionAdapter( + fetcher: ChatFetcher, +): ConnectConnectionAdapter { + return { + async *connect(messages, data, abortSignal) { + if (!abortSignal) { + throw new Error( + 'fetcherToConnectionAdapter requires an AbortSignal — the chat client always supplies one.', + ) + } + const uiMessages = messages as Array + const result = await fetcher( + { messages: uiMessages, data }, + { signal: abortSignal }, + ) + if (result instanceof Response) { + yield* responseToSSEChunks(result, abortSignal) + } else { + yield* abortableIterable(result, abortSignal) + } + }, + } +} + +/** + * Wrap an AsyncIterable so iteration aborts when `signal` fires. Without + * this, a fetcher that returns a generator ignoring its signal would leave + * the for-await loop hanging until the iterable naturally ends. + */ +async function* abortableIterable( + iterable: AsyncIterable, + signal: AbortSignal, +): AsyncGenerator { + if (signal.aborted) return + const iterator = iterable[Symbol.asyncIterator]() + const abortPromise = new Promise<{ done: true; value: undefined }>( + (resolve) => { + signal.addEventListener( + 'abort', + () => resolve({ done: true, value: undefined }), + { once: true }, + ) + }, + ) + try { + // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + while (true) { + const result = await Promise.race([iterator.next(), abortPromise]) + if (result.done) return + yield result.value + } + } finally { + await iterator.return?.() + } +} + /** * Create an RPC stream connection adapter (for RPC-based streaming like Cap'n Web RPC) * @@ -473,13 +597,14 @@ export function rpcStream( rpcCall: ( messages: Array | Array, data?: Record, + abortSignal?: AbortSignal, ) => AsyncIterable, ): ConnectConnectionAdapter { return { - async *connect(messages, data) { + async *connect(messages, data, abortSignal) { // Pass messages as-is (UIMessages with parts preserved) // Server-side chat() handles conversion to ModelMessages - yield* rpcCall(messages, data) + yield* rpcCall(messages, data, abortSignal) }, } } diff --git a/packages/typescript/ai-client/src/index.ts b/packages/typescript/ai-client/src/index.ts index 0f86ae891..84f05fc94 100644 --- a/packages/typescript/ai-client/src/index.ts +++ b/packages/typescript/ai-client/src/index.ts @@ -16,7 +16,10 @@ export type { InferChatMessages, ChatClientState, ConnectionStatus, - // Multimodal content input type + ChatFetcher, + ChatFetcherInput, + ChatFetcherOptions, + ChatTransport, MultimodalContent, } from './types' // Generation client types @@ -57,6 +60,7 @@ export { fetchHttpStream, stream, rpcStream, + StreamTruncatedError, type ConnectConnectionAdapter, type ConnectionAdapter, type FetchConnectionOptions, diff --git a/packages/typescript/ai-client/src/types.ts b/packages/typescript/ai-client/src/types.ts index b705ebbdf..6dbf6f079 100644 --- a/packages/typescript/ai-client/src/types.ts +++ b/packages/typescript/ai-client/src/types.ts @@ -13,6 +13,47 @@ import type { } from '@tanstack/ai' import type { ConnectionAdapter } from './connection-adapters' +/** + * `messages` is the full UIMessage history (not a delta). `data` is the + * merged body — `ChatClientOptions.body` plus any per-call data passed to + * `sendMessage(...)`. + */ +export interface ChatFetcherInput { + messages: Array + data?: Record +} + +export interface ChatFetcherOptions { + /** Fires when `stop()` is called or the request is superseded. */ + signal: AbortSignal +} + +/** + * Direct async function that performs a chat request. Mirrors + * `GenerationFetcher`. Returns either a `Response` (SSE body parsed by the + * chat client) or an `AsyncIterable` (yielded directly). + * + * @example + * ```ts + * useChat({ + * fetcher: ({ messages }, { signal }) => + * chatFn({ data: { messages }, signal }), + * }) + * ``` + */ +export type ChatFetcher = ( + input: ChatFetcherInput, + options: ChatFetcherOptions, +) => Promise> + +/** + * Discriminated union enforcing that exactly one of `connection` or + * `fetcher` is provided. Mirrors `GenerationTransport`. + */ +export type ChatTransport = + | { connection: ConnectionAdapter; fetcher?: never } + | { fetcher: ChatFetcher; connection?: never } + /** * Tool call states - track the lifecycle of a tool call */ @@ -183,16 +224,13 @@ export interface UIMessage = any> { createdAt?: Date } -export interface ChatClientOptions< +/** + * Options for `ChatClient`. Exactly one of `connection` or `fetcher` must be + * provided — the type-level XOR is enforced via `ChatTransport`. + */ +export type ChatClientOptions< TTools extends ReadonlyArray = any, -> { - /** - * Connection adapter for streaming. - * Supports mutually exclusive modes: request-response via `connect()`, or - * subscribe/send mode via `subscribe()` + `send()`. - */ - connection: ConnectionAdapter - +> = { /** * Initial messages to populate the chat */ @@ -299,7 +337,7 @@ export interface ChatClientOptions< */ chunkStrategy?: ChunkStrategy } -} +} & ChatTransport export interface ChatRequestBody { messages: Array diff --git a/packages/typescript/ai-client/tests/chat-client.test.ts b/packages/typescript/ai-client/tests/chat-client.test.ts index ec997c868..8b062c65e 100644 --- a/packages/typescript/ai-client/tests/chat-client.test.ts +++ b/packages/typescript/ai-client/tests/chat-client.test.ts @@ -88,9 +88,9 @@ describe('ChatClient', () => { expect(client1MessageId).not.toBe(client2MessageId) }) - it('should throw if connection is not provided', () => { + it('should throw if neither connection nor fetcher is provided', () => { expect(() => new ChatClient({} as any)).toThrow( - 'Connection adapter is required', + 'either `connection` or `fetcher` is required', ) }) }) diff --git a/packages/typescript/ai-client/tests/chat-fetcher.test.ts b/packages/typescript/ai-client/tests/chat-fetcher.test.ts new file mode 100644 index 000000000..59b73e4ad --- /dev/null +++ b/packages/typescript/ai-client/tests/chat-fetcher.test.ts @@ -0,0 +1,324 @@ +import { describe, expect, it, vi } from 'vitest' +import { ChatClient } from '../src/chat-client' +import { createTextChunks } from './test-utils' +import type { StreamChunk } from '@tanstack/ai' +import type { ChatFetcher, UIMessage } from '../src/types' + +/** + * Tests for the `fetcher` transport on ChatClient — the chat-side mirror of + * `GenerationFetcher` (used by useGenerateSpeech / useSummarize / etc.). + */ +describe('ChatClient — fetcher transport', () => { + it('runs an in-process AsyncIterable fetcher and streams text', async () => { + const chunks = createTextChunks('Hello world', 'msg-1') + const fetcher = vi.fn(async function* () { + for (const chunk of chunks) { + yield chunk + } + } as unknown as ChatFetcher) + + let finalMessages: Array = [] + const client = new ChatClient({ + fetcher, + onMessagesChange: (m) => { + finalMessages = m + }, + }) + + await client.sendMessage('Hi') + + expect(fetcher).toHaveBeenCalledTimes(1) + expect(finalMessages).toHaveLength(2) // user + assistant + const assistant = finalMessages[1]! + expect(assistant.role).toBe('assistant') + const textPart = assistant.parts.find((p) => p.type === 'text') + expect(textPart && 'content' in textPart && textPart.content).toBe( + 'Hello world', + ) + }) + + it('parses an SSE Response returned by the fetcher (server-fn style)', async () => { + const sseBody = + [ + `data: ${JSON.stringify({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'm1', + model: 'test', + timestamp: Date.now(), + delta: 'Hi', + content: 'Hi', + })}`, + `data: ${JSON.stringify({ + type: 'RUN_FINISHED', + runId: 'r1', + threadId: 't1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + })}`, + '', + ].join('\n') + '\n' + + const fetcher = vi.fn(async () => { + return new Response(sseBody, { + status: 200, + headers: { 'content-type': 'text/event-stream' }, + }) + }) + + let finalMessages: Array = [] + const client = new ChatClient({ + fetcher, + onMessagesChange: (m) => { + finalMessages = m + }, + }) + + await client.sendMessage('hi') + + expect(fetcher).toHaveBeenCalledTimes(1) + const assistant = finalMessages[1]! + expect(assistant.role).toBe('assistant') + const textPart = assistant.parts.find((p) => p.type === 'text') + expect(textPart && 'content' in textPart && textPart.content).toBe('Hi') + }) + + it('passes the AbortSignal to the fetcher; stop() aborts it', async () => { + let observedSignal: AbortSignal | undefined + let resolveFetcher: (() => void) | undefined + const fetcherStarted = new Promise((res) => { + resolveFetcher = res + }) + + const fetcher: ChatFetcher = async (_input, { signal }) => { + observedSignal = signal + resolveFetcher?.() + // Hang until aborted + return new Promise((_resolve, reject) => { + signal.addEventListener('abort', () => + reject(new DOMException('aborted', 'AbortError')), + ) + }) + } + + const client = new ChatClient({ fetcher }) + const sendPromise = client.sendMessage('hi') + await fetcherStarted + expect(observedSignal).toBeDefined() + expect(observedSignal!.aborted).toBe(false) + + client.stop() + await sendPromise + expect(observedSignal!.aborted).toBe(true) + }) + + it('surfaces a fetcher error as a ChatClient error', async () => { + const fetcher: ChatFetcher = async () => { + throw new Error('fetcher exploded') + } + let observedError: Error | undefined + const client = new ChatClient({ + fetcher, + onError: (err) => { + observedError = err + }, + }) + + await client.sendMessage('hi') + + expect(observedError).toBeDefined() + expect(observedError!.message).toBe('fetcher exploded') + expect(client.getStatus()).toBe('error') + }) + + it('surfaces a malformed-SSE Response as a ChatClient error', async () => { + // A fetcher that returns a Response whose body has a malformed JSON line. + // The new behavior is to throw SyntaxError from the SSE parser; the + // chat client should surface that as an error rather than silently + // dropping the bad chunk. + const sseBody = 'data: { not valid json\n\n' + const fetcher: ChatFetcher = async () => { + return new Response(sseBody, { + status: 200, + headers: { 'content-type': 'text/event-stream' }, + }) + } + + let observedError: Error | undefined + const client = new ChatClient({ + fetcher, + onError: (err) => { + observedError = err + }, + }) + + await client.sendMessage('hi') + + expect(observedError).toBeDefined() + expect(observedError!.name).toBe('SyntaxError') + expect(client.getStatus()).toBe('error') + }) + + it('passes UIMessages and merged body to the fetcher', async () => { + const fetcher = vi.fn(async function* () { + yield { + type: 'RUN_FINISHED', + runId: 'r1', + threadId: 't1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } as StreamChunk + } as unknown as ChatFetcher) + + const client = new ChatClient({ + fetcher, + body: { provider: 'openai' }, + }) + + await client.sendMessage('hello there') + + expect(fetcher).toHaveBeenCalledTimes(1) + const [input] = fetcher.mock.calls[0]! + expect(input.messages).toHaveLength(1) + expect(input.messages[0]!.role).toBe('user') + expect(input.messages[0]!.parts[0]).toMatchObject({ + type: 'text', + content: 'hello there', + }) + expect(input.data).toMatchObject({ + provider: 'openai', + conversationId: expect.any(String), + }) + }) + + it('throws when both connection and fetcher are passed', () => { + // The XOR is enforced at the type level via `ChatTransport`; the runtime + // check is defense-in-depth for callers using `as any` / dynamic options. + const both: any = { + connection: { connect: async function* () {} }, + fetcher: async () => new Response(''), + } + expect(() => new ChatClient(both)).toThrow( + 'pass either `connection` or `fetcher`', + ) + }) + + it('throws when neither connection nor fetcher is passed', () => { + expect(() => new ChatClient({} as any)).toThrow( + 'either `connection` or `fetcher` is required', + ) + }) + + it('surfaces a non-OK Response as a ChatClient error', async () => { + const fetcher: ChatFetcher = async () => + new Response('Internal Server Error', { + status: 500, + statusText: 'Internal Server Error', + }) + + let observedError: Error | undefined + const client = new ChatClient({ + fetcher, + onError: (err) => { + observedError = err + }, + }) + + await client.sendMessage('hi') + + expect(observedError).toBeDefined() + expect(observedError!.message).toMatch(/HTTP error.*500/) + expect(client.getStatus()).toBe('error') + }) + + it('surfaces an AsyncIterable that throws after yielding chunks', async () => { + const fetcher = vi.fn(async function* () { + yield { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'm1', + model: 'test', + timestamp: Date.now(), + delta: 'partial', + content: 'partial', + } as StreamChunk + throw new Error('mid-stream boom') + } as unknown as ChatFetcher) + + let observedError: Error | undefined + const client = new ChatClient({ + fetcher, + onError: (err) => { + observedError = err + }, + }) + + await client.sendMessage('hi') + + expect(observedError).toBeDefined() + expect(observedError!.message).toBe('mid-stream boom') + expect(client.getStatus()).toBe('error') + }) + + it('completes cleanly when AsyncIterable ends without RUN_FINISHED', async () => { + const fetcher = vi.fn(async function* () { + yield { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'm1', + model: 'test', + timestamp: Date.now(), + delta: 'Hello', + content: 'Hello', + } as StreamChunk + } as unknown as ChatFetcher) + + let finalMessages: Array = [] + const client = new ChatClient({ + fetcher, + onMessagesChange: (m) => { + finalMessages = m + }, + }) + + await client.sendMessage('hi') + + expect(client.getStatus()).toBe('ready') + expect(finalMessages).toHaveLength(2) + const assistant = finalMessages[1]! + const textPart = assistant.parts.find((p) => p.type === 'text') + expect(textPart && 'content' in textPart && textPart.content).toBe('Hello') + }) + + it('stops consuming chunks from an AsyncIterable that ignores its signal', async () => { + const observedChunks: Array = [] + const fetcher: ChatFetcher = async () => { + return (async function* () { + await new Promise((r) => setTimeout(r, 5)) + for (let i = 0; i < 10; i++) { + await new Promise((r) => setTimeout(r, 20)) + yield { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'm1', + model: 'test', + timestamp: Date.now(), + delta: String(i), + content: String(i), + } as StreamChunk + } + })() + } + + const client = new ChatClient({ + fetcher, + onChunk: (c) => observedChunks.push(c), + }) + const sendPromise = client.sendMessage('hi') + await new Promise((r) => setTimeout(r, 30)) + const beforeStop = observedChunks.length + client.stop() + await sendPromise + await new Promise((r) => setTimeout(r, 100)) + + expect(observedChunks.length).toBe(beforeStop) + }) +}) diff --git a/packages/typescript/ai-client/tests/connection-adapters.test.ts b/packages/typescript/ai-client/tests/connection-adapters.test.ts index 60c36763a..736c9b4f8 100644 --- a/packages/typescript/ai-client/tests/connection-adapters.test.ts +++ b/packages/typescript/ai-client/tests/connection-adapters.test.ts @@ -104,9 +104,7 @@ describe('connection-adapters', () => { expect(chunks).toHaveLength(1) }) - it('should skip [DONE] markers and warn about deprecation', async () => { - const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) - + it('should synthesize RUN_FINISHED on [DONE] and stop reading', async () => { const mockReader = { read: vi .fn() @@ -136,19 +134,11 @@ describe('connection-adapters', () => { chunks.push(chunk) } - expect(chunks).toHaveLength(0) - expect(warnSpy).toHaveBeenCalledWith( - expect.stringContaining('[DONE] sentinel'), - ) - - warnSpy.mockRestore() + expect(chunks).toHaveLength(1) + expect(chunks[0]!.type).toBe('RUN_FINISHED') }) - it('should handle malformed JSON gracefully', async () => { - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}) - + it('should throw a SyntaxError on malformed JSON', async () => { const mockReader = { read: vi .fn() @@ -170,17 +160,16 @@ describe('connection-adapters', () => { fetchMock.mockResolvedValue(mockResponse as any) const adapter = fetchServerSentEvents('/api/chat') - const chunks: Array = [] - - for await (const chunk of adapter.connect([ - { role: 'user', content: 'Hello' }, - ])) { - chunks.push(chunk) - } - expect(chunks).toHaveLength(0) - expect(consoleWarnSpy).toHaveBeenCalled() - consoleWarnSpy.mockRestore() + await expect( + (async () => { + for await (const _ of adapter.connect([ + { role: 'user', content: 'Hello' }, + ])) { + // Consume + } + })(), + ).rejects.toThrow(SyntaxError) }) it('should handle HTTP errors', async () => { @@ -486,7 +475,7 @@ describe('connection-adapters', () => { .mockResolvedValueOnce({ done: false, value: new TextEncoder().encode( - 'data: {"type":"RUN_FINISHED","runId":"run-1","finishReason":"stop","timestamp":300}\n\ndata: [DONE]\n\n', + 'data: {"type":"RUN_FINISHED","runId":"run-1","finishReason":"stop","timestamp":300}\n\n', ), }) .mockResolvedValueOnce({ done: true, value: undefined }), @@ -550,11 +539,7 @@ describe('connection-adapters', () => { expect(chunks).toHaveLength(1) }) - it('should handle malformed JSON gracefully', async () => { - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}) - + it('should throw a SyntaxError on malformed JSON', async () => { const mockReader = { read: vi .fn() @@ -576,17 +561,16 @@ describe('connection-adapters', () => { fetchMock.mockResolvedValue(mockResponse as any) const adapter = fetchHttpStream('/api/chat') - const chunks: Array = [] - - for await (const chunk of adapter.connect([ - { role: 'user', content: 'Hello' }, - ])) { - chunks.push(chunk) - } - expect(chunks).toHaveLength(0) - expect(consoleWarnSpy).toHaveBeenCalled() - consoleWarnSpy.mockRestore() + await expect( + (async () => { + for await (const _ of adapter.connect([ + { role: 'user', content: 'Hello' }, + ])) { + // Consume + } + })(), + ).rejects.toThrow(SyntaxError) }) it('should handle HTTP errors', async () => { @@ -836,6 +820,7 @@ describe('connection-adapters', () => { expect(streamFactory).toHaveBeenCalledWith( expect.arrayContaining([expect.objectContaining({ role: 'user' })]), data, + undefined, ) }) }) @@ -1023,6 +1008,7 @@ describe('connection-adapters', () => { expect(rpcCall).toHaveBeenCalledWith( expect.arrayContaining([expect.objectContaining({ role: 'user' })]), data, + undefined, ) }) }) diff --git a/packages/typescript/ai-preact/src/use-chat.ts b/packages/typescript/ai-preact/src/use-chat.ts index dacd2be88..8661a0de0 100644 --- a/packages/typescript/ai-preact/src/use-chat.ts +++ b/packages/typescript/ai-preact/src/use-chat.ts @@ -55,8 +55,12 @@ export function useChat = any>( isFirstMountRef.current = false + const transport = optionsRef.current.connection + ? { connection: optionsRef.current.connection } + : { fetcher: optionsRef.current.fetcher! } + return new ChatClient({ - connection: optionsRef.current.connection, + ...transport, id: clientId, initialMessages: messagesToUse, body: optionsRef.current.body, diff --git a/packages/typescript/ai-react/src/index.ts b/packages/typescript/ai-react/src/index.ts index 5ce8c9911..36d547a7d 100644 --- a/packages/typescript/ai-react/src/index.ts +++ b/packages/typescript/ai-react/src/index.ts @@ -56,7 +56,11 @@ export { fetchServerSentEvents, fetchHttpStream, stream, + rpcStream, createChatClientOptions, + type ChatFetcher, + type ChatFetcherInput, + type ChatFetcherOptions, type ConnectionAdapter, type FetchConnectionOptions, type InferChatMessages, diff --git a/packages/typescript/ai-react/src/types.ts b/packages/typescript/ai-react/src/types.ts index ec8f49a92..4335607f1 100644 --- a/packages/typescript/ai-react/src/types.ts +++ b/packages/typescript/ai-react/src/types.ts @@ -12,20 +12,13 @@ import type { export type { ChatRequestBody, MultimodalContent, UIMessage } /** - * Options for the useChat hook. + * Options for the useChat hook. Pass either `connection` or `fetcher` — + * the XOR is enforced at the type level via `ChatTransport`. * - * This extends ChatClientOptions but omits the state change callbacks that are - * managed internally by React state: - * - `onMessagesChange` - Managed by React state (exposed as `messages`) - * - `onLoadingChange` - Managed by React state (exposed as `isLoading`) - * - `onErrorChange` - Managed by React state (exposed as `error`) - * - `onStatusChange` - Managed by React state (exposed as `status`) - * - * All other callbacks (onResponse, onChunk, onFinish, onError) are - * passed through to the underlying ChatClient and can be used for side effects. - * - * Note: Connection and body changes will recreate the ChatClient instance. - * To update these options, remount the component or use a key prop. + * State-change callbacks (`onMessagesChange`, `onLoadingChange`, + * `onErrorChange`, `onStatusChange`, etc.) are owned by the hook and + * exposed as React state. Side-effect callbacks (`onResponse`, `onChunk`, + * `onFinish`, `onError`) are passed through. */ export type UseChatOptions = any> = Omit< diff --git a/packages/typescript/ai-react/src/use-chat.ts b/packages/typescript/ai-react/src/use-chat.ts index c95589874..79a4758a2 100644 --- a/packages/typescript/ai-react/src/use-chat.ts +++ b/packages/typescript/ai-react/src/use-chat.ts @@ -53,8 +53,12 @@ export function useChat = any>( isFirstMountRef.current = false + const transport = optionsRef.current.connection + ? { connection: optionsRef.current.connection } + : { fetcher: optionsRef.current.fetcher! } + return new ChatClient({ - connection: optionsRef.current.connection, + ...transport, id: clientId, initialMessages: messagesToUse, body: optionsRef.current.body, diff --git a/packages/typescript/ai-react/tests/use-chat-fetcher.test.ts b/packages/typescript/ai-react/tests/use-chat-fetcher.test.ts new file mode 100644 index 000000000..22a73abff --- /dev/null +++ b/packages/typescript/ai-react/tests/use-chat-fetcher.test.ts @@ -0,0 +1,111 @@ +import { renderHook, waitFor } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import type { ChatFetcher } from '@tanstack/ai-client' +import type { StreamChunk } from '@tanstack/ai' +import { useChat } from '../src/use-chat' +import { createTextChunks } from './test-utils' + +describe('useChat — fetcher transport', () => { + it('streams text into messages via an AsyncIterable fetcher', async () => { + const chunks = createTextChunks('Hello world', 'msg-1') + const fetcher: ChatFetcher = async function* () { + for (const chunk of chunks) { + yield chunk + } + } as unknown as ChatFetcher + + const { result } = renderHook(() => useChat({ fetcher })) + + await result.current.sendMessage('hi') + + await waitFor(() => { + expect(result.current.messages).toHaveLength(2) + }) + const assistant = result.current.messages[1]! + const textPart = assistant.parts.find((p) => p.type === 'text') + expect(textPart && 'content' in textPart && textPart.content).toBe( + 'Hello world', + ) + }) + + it('parses an SSE Response returned by the fetcher', async () => { + const sseBody = + [ + `data: ${JSON.stringify({ + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'm1', + model: 'test', + timestamp: Date.now(), + delta: 'Hi', + content: 'Hi', + })}`, + `data: ${JSON.stringify({ + type: 'RUN_FINISHED', + runId: 'r1', + threadId: 't1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + })}`, + '', + ].join('\n') + '\n' + + const fetcher: ChatFetcher = async () => + new Response(sseBody, { + status: 200, + headers: { 'content-type': 'text/event-stream' }, + }) + + const { result } = renderHook(() => useChat({ fetcher })) + + await result.current.sendMessage('hi') + + await waitFor(() => { + expect(result.current.messages).toHaveLength(2) + }) + const assistant = result.current.messages[1]! + const textPart = assistant.parts.find((p) => p.type === 'text') + expect(textPart && 'content' in textPart && textPart.content).toBe('Hi') + }) + + it('surfaces fetcher errors as the hook error state', async () => { + const fetcher: ChatFetcher = async () => { + throw new Error('boom') + } + + const { result } = renderHook(() => useChat({ fetcher })) + + await result.current.sendMessage('hi') + + await waitFor(() => { + expect(result.current.error).toBeDefined() + }) + expect(result.current.error!.message).toBe('boom') + expect(result.current.status).toBe('error') + }) + + it('passes the merged body and full message history to the fetcher', async () => { + const fetcher = vi.fn(async function* () { + yield { + type: 'RUN_FINISHED', + runId: 'r1', + threadId: 't1', + model: 'test', + timestamp: Date.now(), + finishReason: 'stop', + } as StreamChunk + } as unknown as ChatFetcher) + + const { result } = renderHook(() => + useChat({ fetcher, body: { provider: 'openai' } }), + ) + + await result.current.sendMessage('hello') + + expect(fetcher).toHaveBeenCalledTimes(1) + const [input] = fetcher.mock.calls[0]! + expect(input.messages).toHaveLength(1) + expect(input.messages[0]!.role).toBe('user') + expect(input.data).toMatchObject({ provider: 'openai' }) + }) +}) diff --git a/packages/typescript/ai-solid/src/use-chat.ts b/packages/typescript/ai-solid/src/use-chat.ts index 0aff8603a..198b37e96 100644 --- a/packages/typescript/ai-solid/src/use-chat.ts +++ b/packages/typescript/ai-solid/src/use-chat.ts @@ -40,8 +40,11 @@ export function useChat = any>( // in-place mutations propagate. When the user clears a callback (sets it to // undefined), `?.` no-ops. const client = createMemo(() => { + const transport = options.connection + ? { connection: options.connection } + : { fetcher: options.fetcher! } return new ChatClient({ - connection: options.connection, + ...transport, id: clientId, initialMessages: options.initialMessages, body: options.body, diff --git a/packages/typescript/ai-svelte/src/create-chat.svelte.ts b/packages/typescript/ai-svelte/src/create-chat.svelte.ts index 3a9eeb232..fbac75d6f 100644 --- a/packages/typescript/ai-svelte/src/create-chat.svelte.ts +++ b/packages/typescript/ai-svelte/src/create-chat.svelte.ts @@ -60,8 +60,12 @@ export function createChat = any>( // by reference. Callbacks are therefore frozen to whatever the caller passed // at creation — to swap them dynamically, mutate the options object // in-place or call `client.updateOptions(...)` imperatively. + const transport = options.connection + ? { connection: options.connection } + : { fetcher: options.fetcher! } + const client = new ChatClient({ - connection: options.connection, + ...transport, id: clientId, initialMessages: options.initialMessages, body: options.body, diff --git a/packages/typescript/ai-vue/src/use-chat.ts b/packages/typescript/ai-vue/src/use-chat.ts index 3f10b4dcb..a587fd248 100644 --- a/packages/typescript/ai-vue/src/use-chat.ts +++ b/packages/typescript/ai-vue/src/use-chat.ts @@ -32,8 +32,12 @@ export function useChat = any>( // in-place mutations propagate. When the user clears a callback (sets it to // undefined), `?.` no-ops — unlike `client.updateOptions`, which silently // skips undefined and leaves the old callback installed. + const transport = options.connection + ? { connection: options.connection } + : { fetcher: options.fetcher! } + const client = new ChatClient({ - connection: options.connection, + ...transport, id: clientId, initialMessages: options.initialMessages, body: options.body, diff --git a/testing/e2e/src/routes/$provider/$feature.tsx b/testing/e2e/src/routes/$provider/$feature.tsx index de76993a2..fdb448630 100644 --- a/testing/e2e/src/routes/$provider/$feature.tsx +++ b/testing/e2e/src/routes/$provider/$feature.tsx @@ -64,7 +64,7 @@ function FeaturePage() { ) } - return + return } function MediaFeature({ @@ -125,9 +125,11 @@ function MediaFeature({ function ChatFeature({ provider, feature, + mode, }: { provider: Provider feature: Feature + mode?: Mode }) { const needsApproval = feature === 'tool-approval' const showImageInput = @@ -137,9 +139,32 @@ function ChatFeature({ const { testId, aimockPort } = Route.useSearch() + const transport = + mode === 'fetcher' + ? { + fetcher: async ( + input: { messages: unknown; data?: unknown }, + options: { signal: AbortSignal }, + ) => + fetch('/api/chat', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + messages: input.messages, + data: input.data, + provider, + feature, + testId, + aimockPort, + }), + signal: options.signal, + }), + } + : { connection: fetchServerSentEvents('/api/chat') } + const { messages, sendMessage, isLoading, addToolApprovalResponse, stop } = useChat({ - connection: fetchServerSentEvents('/api/chat'), + ...transport, tools, body: { provider, feature, testId, aimockPort }, }) diff --git a/testing/e2e/tests/chat.spec.ts b/testing/e2e/tests/chat.spec.ts index cdec5109f..e5afba466 100644 --- a/testing/e2e/tests/chat.spec.ts +++ b/testing/e2e/tests/chat.spec.ts @@ -22,5 +22,21 @@ for (const provider of providersFor('chat')) { const response = await getLastAssistantMessage(page) expect(response).toContain('Fender Stratocaster') }) + + test('fetcher mode — streams an SSE Response through useChat({ fetcher })', async ({ + page, + testId, + aimockPort, + }) => { + await page.goto( + featureUrl(provider, 'chat', testId, aimockPort, 'fetcher'), + ) + + await sendMessage(page, '[chat] recommend a guitar') + await waitForResponse(page) + + const response = await getLastAssistantMessage(page) + expect(response).toContain('Fender Stratocaster') + }) }) }