diff --git a/core/src/agents/base_agent.ts b/core/src/agents/base_agent.ts index 70ee2cf8..c76c5579 100644 --- a/core/src/agents/base_agent.ts +++ b/core/src/agents/base_agent.ts @@ -217,7 +217,7 @@ export abstract class BaseAgent { * @returns An AsyncGenerator that yields the events generated by the agent. */ async *runLive( - parentContext: InvocationContext, // eslint-disable-line @typescript-eslint/no-unused-vars + parentContext: InvocationContext, ): AsyncGenerator { const span = tracer.startSpan(`invoke_agent ${this.name}`); const ctx = trace.setSpan(context.active(), span); @@ -226,10 +226,33 @@ export abstract class BaseAgent { ctx, this, async function* () { - // TODO(b/425992518): Implement live mode. + const context = this.createInvocationContext(parentContext); + + const beforeAgentCallbackEvent = + await this.handleBeforeAgentCallback(context); + if (beforeAgentCallbackEvent) { + yield beforeAgentCallbackEvent; + } + + if (context.endInvocation || parentContext.abortSignal?.aborted) { + return; + } + + for await (const event of this.runLiveImpl(context)) { + yield event; + } + + if (context.endInvocation || parentContext.abortSignal?.aborted) { + return; + } + + const afterAgentCallbackEvent = + await this.handleAfterAgentCallback(context); + if (afterAgentCallbackEvent) { + yield afterAgentCallbackEvent; + } }, ); - throw new Error('Live mode is not implemented yet.'); } finally { span.end(); } diff --git a/core/src/agents/invocation_context.ts b/core/src/agents/invocation_context.ts index c80a8116..668daafd 100644 --- a/core/src/agents/invocation_context.ts +++ b/core/src/agents/invocation_context.ts @@ -40,6 +40,7 @@ export interface InvocationContextParams { activeStreamingTools?: Record; pluginManager: PluginManager; abortSignal?: AbortSignal; + liveSessionResumptionHandle?: string; } /** @@ -191,6 +192,14 @@ export class InvocationContext { */ readonly abortSignal?: AbortSignal; + /** + * Most recent Gemini Live session resumption handle. 
Updated from + * `sessionResumptionUpdate` events on the active connection and replayed + * via `liveConnectConfig.sessionResumption` when reconnecting so the + * server can restore in-flight state without client-side history replay. + */ + liveSessionResumptionHandle?: string; + /** * @param params The parameters for creating an invocation context. */ @@ -210,6 +219,7 @@ export class InvocationContext { this.activeStreamingTools = params.activeStreamingTools; this.pluginManager = params.pluginManager; this.abortSignal = params.abortSignal; + this.liveSessionResumptionHandle = params.liveSessionResumptionHandle; } /** diff --git a/core/src/agents/live_request_queue.ts b/core/src/agents/live_request_queue.ts index 5e70a69f..af3882d5 100644 --- a/core/src/agents/live_request_queue.ts +++ b/core/src/agents/live_request_queue.ts @@ -55,17 +55,38 @@ export class LiveRequestQueue { /** * Retrieves a request from the queue. If the queue is empty, it will * wait until a request is available. + * + * @param abortSignal Optional signal. If it aborts while this call is + * waiting, the pending waiter is removed from the queue and the + * returned promise rejects -- so a torn-down consumer does not strand + * a waiter that would later consume (and drop) a request. * @returns A promise that resolves with the next available request. 
*/ - async get(): Promise { + async get(abortSignal?: AbortSignal): Promise { if (this.queue.length > 0) { return this.queue.shift()!; } if (this.isClosed) { return {close: true}; } - return new Promise((resolve) => { - this.resolveFnFifoQueue.push(resolve); + if (abortSignal?.aborted) { + throw new Error('LiveRequestQueue.get() was aborted.'); + } + return new Promise((resolve, reject) => { + let resolveFn: PromiseResolveFn; + const onAbort = () => { + const index = this.resolveFnFifoQueue.indexOf(resolveFn); + if (index !== -1) { + this.resolveFnFifoQueue.splice(index, 1); + } + reject(new Error('LiveRequestQueue.get() was aborted.')); + }; + resolveFn = (req: LiveRequest) => { + abortSignal?.removeEventListener('abort', onAbort); + resolve(req); + }; + this.resolveFnFifoQueue.push(resolveFn); + abortSignal?.addEventListener('abort', onAbort, {once: true}); }); } diff --git a/core/src/agents/llm_agent.ts b/core/src/agents/llm_agent.ts index 30aba287..27c657ed 100644 --- a/core/src/agents/llm_agent.ts +++ b/core/src/agents/llm_agent.ts @@ -17,12 +17,14 @@ import { createNewEventId, Event, getFunctionCalls, + getFunctionResponses, isFinalResponse, } from '../events/event.js'; import {BaseExampleProvider} from '../examples/base_example_provider.js'; import {Example} from '../examples/example.js'; import {BaseLlm, isBaseLlm} from '../models/base_llm.js'; +import {BaseLlmConnection} from '../models/base_llm_connection.js'; import {LlmRequest} from '../models/llm_request.js'; import {LlmResponse} from '../models/llm_response.js'; import {LLMRegistry} from '../models/registry.js'; @@ -55,6 +57,7 @@ import { import {BaseContextCompactor} from '../context/base_context_compactor.js'; import {InvocationContext} from './invocation_context.js'; +import {LiveRequest, LiveRequestQueue} from './live_request_queue.js'; import {AGENT_TRANSFER_LLM_REQUEST_PROCESSOR} from './processors/agent_transfer_llm_request_processor.js'; import {BASIC_LLM_REQUEST_PROCESSOR} from 
'./processors/basic_llm_request_processor.js'; import {CODE_EXECUTION_REQUEST_PROCESSOR} from './processors/code_execution_request_processor.js'; @@ -67,6 +70,113 @@ import {TOOL_FILTER_REQUEST_PROCESSOR} from './processors/tool_filter_request_pr import {ReadonlyContext} from './readonly_context.js'; import {StreamingMode} from './run_config.js'; +/** + * Maximum number of reconnect attempts on transient live connection failure + * when a session resumption handle is available. + */ +const MAX_LIVE_RECONNECT_ATTEMPTS = 5; + +/** + * Delay before closing the parent connection on agent transfer. Gives the + * server-side model a moment to flush any pending audio for the final turn + * before teardown. Mirrors `DEFAULT_TRANSFER_AGENT_DELAY` (1.0s) in the Python + * ADK live flow; the value is an empirical heuristic, not a guarantee. + */ +const TRANSFER_AGENT_DELAY_MS = 1000; + +/** + * Sentinel thrown from `runReceiveLoop` to break out of the receive iterator + * and signal `runLiveFlow` to reconnect using the stored resumption handle. + * Used when the server sends `goAway` or any other recoverable terminal. + */ +class LiveReconnectSignal extends Error { + constructor(readonly reason: string) { + super(`live reconnect requested: ${reason}`); + this.name = 'LiveReconnectSignal'; + } +} + +/** + * Classifies errors that should trigger a reconnect attempt instead of + * propagating. Matches the Python flow's allowlist of recoverable codes. + */ +function isRecoverableLiveError(err: unknown): boolean { + if (err instanceof LiveReconnectSignal) return true; + if (!(err instanceof Error)) return false; + const code = (err as {code?: unknown}).code; + // Standard WebSocket close codes treated as transient by the Python flow. + if (code === 1000 || code === 1006 || code === 1011 || code === 1012) { + return true; + } + const message = err.message ?? 
''; + return /ConnectionClosed|connection closed|ECONNRESET|socket hang up/i.test( + message, + ); +} + +async function closeQuietly(connection: BaseLlmConnection): Promise { + try { + await connection.close(); + } catch (error) { + logger.warn('Error closing live connection:', error); + } +} + +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +function combineAbortSignals( + ...signals: Array +): AbortSignal | undefined { + const present = signals.filter((s): s is AbortSignal => s !== undefined); + if (present.length === 0) return undefined; + if (present.length === 1) return present[0]; + const controller = new AbortController(); + for (const signal of present) { + if (signal.aborted) { + controller.abort(signal.reason); + break; + } + signal.addEventListener('abort', () => controller.abort(signal.reason), { + // Auto-remove when the combined signal aborts, so listeners are not + // left behind on a long-lived input signal (e.g. the invocation's + // abortSignal) across reconnect attempts. 
+ signal: controller.signal, + }); + } + return controller.signal; +} + +function applyLiveRunConfig( + runConfig: InvocationContext['runConfig'], + llmRequest: LlmRequest, +): void { + if (!runConfig) return; + const liveConfig = (llmRequest.liveConnectConfig ??= {}); + if (runConfig.responseModalities) { + liveConfig.responseModalities = runConfig.responseModalities; + } + if (runConfig.speechConfig) { + liveConfig.speechConfig = runConfig.speechConfig; + } + if (runConfig.outputAudioTranscription) { + liveConfig.outputAudioTranscription = runConfig.outputAudioTranscription; + } + if (runConfig.inputAudioTranscription) { + liveConfig.inputAudioTranscription = runConfig.inputAudioTranscription; + } + if (runConfig.realtimeInputConfig) { + liveConfig.realtimeInputConfig = runConfig.realtimeInputConfig; + } + if (runConfig.proactivity) { + liveConfig.proactivity = runConfig.proactivity; + } + if (runConfig.enableAffectiveDialog) { + liveConfig.enableAffectiveDialog = runConfig.enableAffectiveDialog; + } +} + /** * Input/output schema type for agent. */ @@ -707,13 +817,455 @@ export class LlmAgent extends BaseAgent { // -------------------------------------------------------------------------- // #START LlmFlow Logic // -------------------------------------------------------------------------- - // eslint-disable-next-line require-yield + /** + * Runs the bidirectional (live) flow for this agent. + * + * Establishes a live connection to the model, drains the invocation's + * `liveRequestQueue` into the connection on a parallel task, and yields + * events derived from server messages until the queue closes, the model + * finishes, or an agent transfer occurs. + * + * If the live connection drops (network failure, server `goAway`) and a + * session resumption handle has been observed, the flow transparently + * reconnects using that handle up to {@link MAX_LIVE_RECONNECT_ATTEMPTS} + * times. 
Subsequent reconnects skip `sendHistory` because the server + * already holds the conversation state associated with the handle. + */ private async *runLiveFlow( - _invocationContext: InvocationContext, + invocationContext: InvocationContext, ): AsyncGenerator { - // TODO - b/425992518: remove dummy logic, implement this. - await Promise.resolve(); - throw new Error('LlmAgent.runLiveFlow not implemented'); + if (!invocationContext.liveRequestQueue) { + throw new Error('liveRequestQueue is required for LlmAgent.runLiveFlow.'); + } + + const llmRequest: LlmRequest = { + contents: [], + toolsDict: {}, + liveConnectConfig: {}, + }; + + // ========================================================================= + // Preprocess: same processors as runAsync. Yields agent-emitted events + // (e.g. instruction injection metadata events) to the caller. + // ========================================================================= + for await (const event of this.runLivePreprocess( + invocationContext, + llmRequest, + )) { + yield event; + } + + if ( + invocationContext.endInvocation || + invocationContext.abortSignal?.aborted + ) { + return; + } + + // ========================================================================= + // Apply live-only request config from the run config. + // ========================================================================= + applyLiveRunConfig(invocationContext.runConfig, llmRequest); + + const llm = this.canonicalModel; + let attempt = 1; + + // Outer reconnect loop. Re-enters on recoverable failures when a session + // resumption handle is available; the server restores state on the new + // connection so we skip history replay. + + while (true) { + if (invocationContext.abortSignal?.aborted) { + return; + } + + // Apply the latest resumption handle before each connect attempt. 
+ const handle = invocationContext.liveSessionResumptionHandle; + if (handle) { + llmRequest.liveConnectConfig ??= {}; + llmRequest.liveConnectConfig.sessionResumption = { + handle, + transparent: true, + }; + } + + let connection: BaseLlmConnection; + try { + connection = await llm.connect(llmRequest); + } catch (err) { + if ( + isRecoverableLiveError(err) && + invocationContext.liveSessionResumptionHandle + ) { + if (attempt > MAX_LIVE_RECONNECT_ATTEMPTS) { + logger.error( + `Max live reconnection attempts reached (${attempt}).`, + err, + ); + throw err; + } + logger.info( + `Live connect failed (attempt ${attempt}); retrying with session handle.`, + err, + ); + attempt += 1; + continue; + } + throw err; + } + + // Skip history replay when resuming -- server already has the state. + if ( + llmRequest.contents.length > 0 && + !invocationContext.liveSessionResumptionHandle + ) { + await connection.sendHistory(llmRequest.contents); + } + + const sendAbort = new AbortController(); + const combinedAbort = combineAbortSignals( + invocationContext.abortSignal, + sendAbort.signal, + ); + const sendTask = this.runSendLoop( + connection, + invocationContext.liveRequestQueue, + combinedAbort, + ); + sendTask.catch((error) => { + logger.error('Error in live send loop:', error); + }); + + let reconnect = false; + try { + yield* this.runReceiveLoop( + invocationContext, + connection, + llmRequest, + sendAbort, + ); + } catch (err) { + const canReconnect = + !!invocationContext.liveSessionResumptionHandle && + (err instanceof LiveReconnectSignal || isRecoverableLiveError(err)); + if (canReconnect) { + reconnect = true; + logger.info( + 'Live connection closed; will reconnect with session handle.', + err, + ); + } else { + // Tear down before rethrowing. + sendAbort.abort(); + await closeQuietly(connection); + await sendTask.catch(() => undefined); + throw err; + } + } + + // Cancel send loop for this attempt; receive loop has exited. 
+ sendAbort.abort(); + await closeQuietly(connection); + await sendTask.catch(() => undefined); + + if (!reconnect) { + return; + } + + if (attempt > MAX_LIVE_RECONNECT_ATTEMPTS) { + throw new Error(`Max live reconnection attempts reached (${attempt}).`); + } + attempt += 1; + } + } + + private async *runLivePreprocess( + invocationContext: InvocationContext, + llmRequest: LlmRequest, + ): AsyncGenerator { + for (const processor of this.requestProcessors) { + for await (const event of processor.runAsync( + invocationContext, + llmRequest, + )) { + if (invocationContext.abortSignal?.aborted) { + return; + } + yield event; + } + } + for (const toolUnion of this.tools) { + const toolContext = new Context({invocationContext}); + const tools = ( + await convertToolUnionToTools( + toolUnion, + new ReadonlyContext(invocationContext), + ) + ).filter( + (tool) => + !llmRequest.allowedTools || + llmRequest.allowedTools.includes(tool.name), + ); + for (const tool of tools) { + await tool.processLlmRequest({toolContext, llmRequest}); + if (invocationContext.abortSignal?.aborted) { + return; + } + } + } + } + + private async runSendLoop( + connection: BaseLlmConnection, + liveRequestQueue: LiveRequestQueue, + abortSignal?: AbortSignal, + ): Promise { + while (true) { + if (abortSignal?.aborted) { + return; + } + let liveRequest: LiveRequest; + try { + // Pass the abort signal so a parked read is released on teardown + // (reconnect, agent transfer) instead of stranding a waiter that + // would later steal a request from the next connection's send loop. 
+ liveRequest = await liveRequestQueue.get(abortSignal); + } catch (error) { + if (abortSignal?.aborted) { + return; + } + throw error; + } + try { + await this.dispatchLiveRequest(connection, liveRequest); + } catch (error) { + logger.error('Error dispatching live request to model:', error); + throw error; + } + // Cooperative yield: avoid starving the event loop when the queue is + // backlogged so receive-loop events get a chance to interleave. + await Promise.resolve(); + if (liveRequest.close) { + return; + } + } + } + + private async dispatchLiveRequest( + connection: BaseLlmConnection, + liveRequest: LiveRequest, + ): Promise { + if (liveRequest.close) { + await connection.close(); + return; + } + if (liveRequest.activityStart) { + await connection.sendActivityStart?.(); + return; + } + if (liveRequest.activityEnd) { + await connection.sendActivityEnd?.(); + return; + } + if (liveRequest.blob) { + await connection.sendRealtime(liveRequest.blob); + return; + } + if (liveRequest.content) { + await connection.sendContent(liveRequest.content); + } + } + + private async *runReceiveLoop( + invocationContext: InvocationContext, + connection: BaseLlmConnection, + llmRequest: LlmRequest, + sendAbort: AbortController, + ): AsyncGenerator { + for await (const llmResponse of connection.receive()) { + if (invocationContext.abortSignal?.aborted) { + return; + } + + // Capture the latest server-provided resumption handle on the + // invocation context so that any subsequent reconnect attempt can + // resume server-side state instead of replaying history. + if (llmResponse.liveSessionResumptionUpdate?.newHandle) { + invocationContext.liveSessionResumptionHandle = + llmResponse.liveSessionResumptionUpdate.newHandle; + } + + // GoAway is the server's "I'm about to close; reconnect with your + // resumption handle" signal. Throw a sentinel to exit this receive + // loop so the outer loop in runLiveFlow reconnects.
+ if (llmResponse.goAway) { + logger.info('Received goAway from live server; triggering reconnect.'); + throw new LiveReconnectSignal('goAway'); + } + + const author = isUserAuthoredResponse(llmResponse) ? 'user' : this.name; + + const modelResponseEvent = createEvent({ + invocationId: invocationContext.invocationId, + author, + branch: invocationContext.branch, + }); + + for await (const event of this.postprocessLive( + invocationContext, + llmRequest, + llmResponse, + modelResponseEvent, + )) { + yield event; + + // Send function responses directly through the connection rather + // than via the live request queue. The TS LiveRequestQueue rejects + // sends after close (strict semantics), and callers commonly close + // the queue at end-of-input before the model finishes ferrying tool + // results back. Python's queue tolerates post-close sends, but + // porting that semantics is out of scope here. + if (event.content && getFunctionResponses(event).length > 0) { + await connection.sendContent(event.content); + } + + // Handle agent transfer triggered by a transfer_to_agent function + // response. The active connection is closed and the destination + // sub-agent's runLive is yielded into the same generator. + const transferTo = event.actions?.transferToAgent; + if (transferTo) { + // Brief delay lets the model finish flushing pending audio for + // the in-flight turn before we tear down the connection. + await sleep(TRANSFER_AGENT_DELAY_MS); + // Stop the parent send loop before the sub-agent starts its own, + // so the two never consume the shared liveRequestQueue + // concurrently (mirrors `send_task.cancel()` in the Python flow). + sendAbort.abort(); + await connection.close(); + const subAgent = + invocationContext.agent.rootAgent.findAgent(transferTo); + if (subAgent) { + const previousAgent = invocationContext.agent; + invocationContext.agent = subAgent; + // Child agent starts its own live session; do not carry over + // the parent's resumption handle. 
+ const previousHandle = + invocationContext.liveSessionResumptionHandle; + invocationContext.liveSessionResumptionHandle = undefined; + try { + for await (const subEvent of subAgent.runLive( + invocationContext, + )) { + yield subEvent; + } + } finally { + invocationContext.agent = previousAgent; + invocationContext.liveSessionResumptionHandle = previousHandle; + } + } + return; + } + } + } + } + + private async *postprocessLive( + invocationContext: InvocationContext, + llmRequest: LlmRequest, + llmResponse: LlmResponse, + modelResponseEvent: Event, + ): AsyncGenerator { + for (const processor of this.responseProcessors) { + for await (const event of processor.runAsync( + invocationContext, + llmResponse, + )) { + yield event; + } + } + + // Skip purely empty responses, but allow control signals to surface. + if ( + !llmResponse.content && + !llmResponse.errorCode && + !llmResponse.interrupted && + !llmResponse.turnComplete && + !llmResponse.inputTranscription && + !llmResponse.outputTranscription && + !llmResponse.usageMetadata && + !llmResponse.liveSessionResumptionUpdate + ) { + return; + } + + // The connection layer (GeminiLlmConnection.receive) emits resumption + // updates and transcriptions as standalone, single-field responses -- + // never combined with `content`. Each is therefore handled with an early + // return; if that invariant changes, co-located fields would be dropped + // here and these branches would need to merge instead. 
+ if (llmResponse.liveSessionResumptionUpdate) { + yield createEvent({ + ...modelResponseEvent, + liveSessionResumptionUpdate: llmResponse.liveSessionResumptionUpdate, + }); + return; + } + + if (llmResponse.inputTranscription) { + yield createEvent({ + ...modelResponseEvent, + inputTranscription: llmResponse.inputTranscription, + partial: llmResponse.partial, + }); + return; + } + if (llmResponse.outputTranscription) { + yield createEvent({ + ...modelResponseEvent, + outputTranscription: llmResponse.outputTranscription, + partial: llmResponse.partial, + }); + return; + } + + const mergedEvent = createEvent({ + ...modelResponseEvent, + ...llmResponse, + }); + + const functionCalls = getFunctionCalls(mergedEvent); + if (mergedEvent.content && functionCalls.length) { + populateClientFunctionCallId(mergedEvent); + mergedEvent.longRunningToolIds = Array.from( + getLongRunningFunctionCalls(functionCalls, llmRequest.toolsDict), + ); + } + + yield mergedEvent; + + // Execute any function calls returned in this event. + if (!functionCalls.length) { + return; + } + + const functionResponseEvent = await handleFunctionCallsAsync({ + invocationContext, + functionCallEvent: mergedEvent, + toolsDict: llmRequest.toolsDict, + beforeToolCallbacks: this.canonicalBeforeToolCallbacks, + afterToolCallbacks: this.canonicalAfterToolCallbacks, + }); + if (!functionResponseEvent) { + return; + } + const authEvent = generateAuthEvent( + invocationContext, + functionResponseEvent, + ); + if (authEvent) { + yield authEvent; + } + yield functionResponseEvent; } private async *runOneStepAsync( @@ -1220,3 +1772,17 @@ export class LlmAgent extends BaseAgent { // - code_executor // - configurable agents by yaml config } + +/** + * Determines whether a live response should be authored as 'user'. + * + * Input transcription represents the user's spoken input, and any explicit + * user-role content (e.g. echoed function responses) likewise belongs to the + * user side of the transcript. 
+ */ +function isUserAuthoredResponse(llmResponse: LlmResponse): boolean { + if (llmResponse.inputTranscription) { + return true; + } + return llmResponse.content?.role === 'user'; +} diff --git a/core/src/models/base_llm_connection.ts b/core/src/models/base_llm_connection.ts index f984ffda..8677eb1e 100644 --- a/core/src/models/base_llm_connection.ts +++ b/core/src/models/base_llm_connection.ts @@ -44,6 +44,18 @@ export interface BaseLlmConnection { */ sendRealtime(blob: Blob): Promise; + /** + * Optionally signals the start of user activity (e.g. user begins speaking) + * for models that support manual activity boundaries. + */ + sendActivityStart?(): Promise; + + /** + * Optionally signals the end of user activity (e.g. user finishes speaking) + * for models that support manual activity boundaries. + */ + sendActivityEnd?(): Promise; + /** * Receives the model response using the llm server connection. * diff --git a/core/src/models/gemini_llm_connection.ts b/core/src/models/gemini_llm_connection.ts index 3eb6da28..e54ebda9 100644 --- a/core/src/models/gemini_llm_connection.ts +++ b/core/src/models/gemini_llm_connection.ts @@ -4,105 +4,338 @@ * SPDX-License-Identifier: Apache-2.0 */ -import {Blob, Content, FunctionResponse, Session} from '@google/genai'; +import { + Blob, + Content, + FunctionResponse, + LiveServerMessage, + Session, + Transcription, +} from '@google/genai'; import {logger} from '../utils/logger.js'; +import {isGemini31FlashLive} from '../utils/model_name.js'; +import {GoogleLLMVariant} from '../utils/variant_utils.js'; import {BaseLlmConnection} from './base_llm_connection.js'; import {LlmResponse} from './llm_response.js'; -/** The Gemini model connection. */ +/** + * Internal record passed from the GenAI websocket callbacks to `receive()`. 
+ */ +type IncomingRecord = + | {kind: 'message'; message: LiveServerMessage} + | {kind: 'error'; error: Error} + | {kind: 'close'}; + +/** + * Buffers incoming events from a callback-based websocket so they can be + * consumed as an async generator. + */ +export class IncomingMessageBuffer { + private readonly queue: IncomingRecord[] = []; + private readonly waiters: Array<(record: IncomingRecord) => void> = []; + private terminated = false; + + push(record: IncomingRecord): void { + if (this.terminated) { + return; + } + if (record.kind !== 'message') { + this.terminated = true; + } + if (this.waiters.length > 0) { + const resolve = this.waiters.shift()!; + resolve(record); + return; + } + this.queue.push(record); + } + + async pull(): Promise { + if (this.queue.length > 0) { + return this.queue.shift()!; + } + if (this.terminated) { + return {kind: 'close'}; + } + return new Promise((resolve) => { + this.waiters.push(resolve); + }); + } +} + +/** + * The Gemini live model connection. + * + * Bridges the callback-based GenAI live `Session` and the ADK + * `BaseLlmConnection` async-generator contract. + */ export class GeminiLlmConnection implements BaseLlmConnection { - constructor(private readonly geminiSession: Session) {} + private inputTranscriptionText = ''; + private outputTranscriptionText = ''; + + constructor( + private readonly geminiSession: Session, + private readonly incomingMessages: IncomingMessageBuffer, + private readonly apiBackend: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API, + private readonly modelVersion?: string, + ) {} /** * Sends the conversation history to the gemini model. * - * You call this method right after setting up the model connection. - * The model will respond if the last content is from user, otherwise it will - * wait for new user input before responding. + * Called once on a freshly opened connection when there is no session + * resumption handle. 
With a handle the server already holds the state, so + * the caller skips this method entirely. + * + * Audio parts are filtered out: the live API does not accept previous-turn + * audio via `sendClientContent`, and any audio has already been transcribed. + * + * `turnComplete` is set to true only when the last turn is from the user -- + * a fresh user message (or function response) that the model must respond + * to. When the last turn is from the model, the server replays the history + * as past context and waits for new user input via the live request queue. * * @param history The conversation history to send to the model. */ async sendHistory(history: Content[]): Promise { - // We ignore any audio from user during the agent transfer phase. - const contents = history.filter( - (content) => content.parts && content.parts[0]?.text, - ); - - if (contents.length > 0) { - this.geminiSession.sendClientContent({ - turns: contents, - turnComplete: contents[contents.length - 1].role === 'user', - }); - } else { + const contents = history + .map((content) => filterAudioParts(content)) + .filter((content): content is Content => content !== undefined); + + if (contents.length === 0) { logger.info('no content is sent'); + return; } + + this.geminiSession.sendClientContent({ + turns: contents, + turnComplete: contents[contents.length - 1].role === 'user', + }); } /** * Sends a user content to the gemini model. * - * The model will respond immediately upon receiving the content. - * If you send function responses, all parts in the content should be function - * responses. + * If the content contains function responses, all parts must be function + * responses; the call is dispatched as a tool response. * * @param content The content to send to the model. 
*/ async sendContent(content: Content): Promise { - if (!content.parts) { + if (!content.parts?.length) { throw new Error('Content must have parts.'); } if (content.parts[0].functionResponse) { - // All parts have to be function responses. const functionResponses = content.parts .map((part) => part.functionResponse) .filter((fr): fr is FunctionResponse => !!fr); logger.debug('Sending LLM function response:', functionResponses); - this.geminiSession.sendToolResponse({ - functionResponses, - }); - } else { - logger.debug('Sending LLM new content', content); - this.geminiSession.sendClientContent({ - turns: [content], - turnComplete: true, - }); + this.geminiSession.sendToolResponse({functionResponses}); + return; } + logger.debug('Sending LLM new content', content); + // Gemini 3.1 Flash Live ignores sendClientContent text turns, so single + // text-part user content must go through the realtime-input path. Earlier + // models accept either path; prefer sendClientContent to avoid surprising + // semantics differences with the activity-detection / interruption logic. + if ( + isGemini31FlashLive(this.modelVersion) && + content.role === 'user' && + content.parts.length === 1 && + typeof content.parts[0].text === 'string' + ) { + this.geminiSession.sendRealtimeInput({text: content.parts[0].text}); + return; + } + this.geminiSession.sendClientContent({ + turns: [content], + turnComplete: true, + }); } /** * Sends a chunk of audio or a frame of video to the model in realtime. * + * Gemini 3.1 Flash Live splits realtime inputs into typed channels (`audio`, + * `video`); earlier models accept a generic `media` blob with any mime. + * * @param blob The blob to send to the model. */ async sendRealtime(blob: Blob): Promise { - logger.debug('Sending LLM Blob:', blob); + logger.debug('Sending LLM Blob.'); + if (isGemini31FlashLive(this.modelVersion)) { + const mime = blob.mimeType ?? 
''; + if (mime.startsWith('audio/')) { + this.geminiSession.sendRealtimeInput({audio: blob}); + return; + } + if (mime.startsWith('image/')) { + this.geminiSession.sendRealtimeInput({video: blob}); + return; + } + logger.warn( + 'Blob not sent. Unknown or empty mime type for sendRealtimeInput:', + mime, + ); + return; + } this.geminiSession.sendRealtimeInput({media: blob}); } /** - * Builds a full text response. - * - * The text should not be partial and the returned LlmResponse is not be - * partial. - * - * @param text The text to be included in the response. - * @returns An LlmResponse containing the full text. + * Sends an activity start signal to the model. + */ + async sendActivityStart(): Promise { + this.geminiSession.sendRealtimeInput({activityStart: {}}); + } + + /** + * Sends an activity end signal to the model. */ - private buildFullTextResponse(text: string): LlmResponse { - return { - content: { - role: 'model', - parts: [{text}], - }, - }; + async sendActivityEnd(): Promise { + this.geminiSession.sendRealtimeInput({activityEnd: {}}); } - // TODO(b/425992518): GenAI SDK inconsistent API, missing methods. - // eslint-disable-next-line require-yield + /** + * Receives the model response using the llm server connection. + * + * Yields one or more `LlmResponse`s per server message. `turnComplete` is + * an in-stream signal; iteration ends on websocket close, throws on error. + * + * @yields LlmResponse: The model response.
+ */ async *receive(): AsyncGenerator { - throw new Error('Not Implemented.'); + let aggregatedText = ''; + + while (true) { + const record = await this.incomingMessages.pull(); + if (record.kind === 'close') { + break; + } + if (record.kind === 'error') { + throw record.error; + } + const message = record.message; + logger.debug('Got LLM Live message'); + + if (message.usageMetadata) { + yield {usageMetadata: message.usageMetadata}; + } + + if (message.serverContent) { + const serverContent = message.serverContent; + const content = serverContent.modelTurn; + + if ( + (!content || !content.parts?.length) && + serverContent.groundingMetadata && + !serverContent.turnComplete + ) { + yield { + groundingMetadata: serverContent.groundingMetadata, + interrupted: serverContent.interrupted, + }; + } + + if (content?.parts?.length) { + const llmResponse: LlmResponse = { + content, + interrupted: serverContent.interrupted, + }; + if (!serverContent.turnComplete) { + llmResponse.groundingMetadata = serverContent.groundingMetadata; + } + const firstPart = content.parts[0]; + if (firstPart.text) { + aggregatedText += firstPart.text; + llmResponse.partial = true; + } else if (aggregatedText && !firstPart.inlineData) { + yield buildFullTextResponse(aggregatedText); + aggregatedText = ''; + } + yield llmResponse; + } + + if (serverContent.inputTranscription) { + for (const event of this.handleTranscription( + serverContent.inputTranscription, + 'input', + )) { + yield event; + } + } + if (serverContent.outputTranscription) { + for (const event of this.handleTranscription( + serverContent.outputTranscription, + 'output', + )) { + yield event; + } + } + + // Gemini API may not emit a `finished` transcription. Flush pending + // partial transcriptions on terminal control signals. 
+ if ( + this.apiBackend === GoogleLLMVariant.GEMINI_API && + (serverContent.interrupted || + serverContent.turnComplete || + serverContent.generationComplete) + ) { + for (const event of this.flushPendingTranscriptions()) { + yield event; + } + } + + if (serverContent.turnComplete) { + if (aggregatedText) { + yield buildFullTextResponse(aggregatedText); + aggregatedText = ''; + } + yield { + turnComplete: true, + interrupted: serverContent.interrupted, + groundingMetadata: serverContent.groundingMetadata, + }; + // turnComplete is just an in-stream signal here — keep iterating + // so the same `receive()` covers all turns until the websocket + // closes (kind: 'close') or the consumer closes the connection. + } + + if (serverContent.interrupted) { + if (aggregatedText) { + yield buildFullTextResponse(aggregatedText); + aggregatedText = ''; + } else { + yield {interrupted: serverContent.interrupted}; + } + } + } + + // Gemini 3.1 Flash Live emits a toolCall message without a following + // turnComplete, so tool calls are surfaced as soon as they arrive + // rather than buffered until turn end. 
+ if (message.toolCall?.functionCalls?.length) { + if (aggregatedText) { + yield buildFullTextResponse(aggregatedText); + aggregatedText = ''; + } + const parts = message.toolCall.functionCalls.map((functionCall) => ({ + functionCall, + })); + yield {content: {role: 'model', parts}}; + } + + if (message.sessionResumptionUpdate) { + yield {liveSessionResumptionUpdate: message.sessionResumptionUpdate}; + } + + if (message.goAway) { + logger.debug('Received GoAway message', message.goAway); + yield {goAway: message.goAway}; + } + } } /** @@ -111,4 +344,84 @@ export class GeminiLlmConnection implements BaseLlmConnection { async close(): Promise { this.geminiSession.close(); } + + private *handleTranscription( + transcription: Transcription, + direction: 'input' | 'output', + ): IterableIterator { + const isInput = direction === 'input'; + if (transcription.text) { + if (isInput) { + this.inputTranscriptionText += transcription.text; + } else { + this.outputTranscriptionText += transcription.text; + } + const partial: Transcription = { + text: transcription.text, + finished: false, + }; + yield isInput + ? {inputTranscription: partial, partial: true} + : {outputTranscription: partial, partial: true}; + } + if (transcription.finished) { + const accumulated = isInput + ? this.inputTranscriptionText + : this.outputTranscriptionText; + const finished: Transcription = {text: accumulated, finished: true}; + if (isInput) { + this.inputTranscriptionText = ''; + } else { + this.outputTranscriptionText = ''; + } + yield isInput + ? 
{inputTranscription: finished, partial: false} + : {outputTranscription: finished, partial: false}; + } + } + + private *flushPendingTranscriptions(): IterableIterator { + if (this.inputTranscriptionText) { + const text = this.inputTranscriptionText; + this.inputTranscriptionText = ''; + yield { + inputTranscription: {text, finished: true}, + partial: false, + }; + } + if (this.outputTranscriptionText) { + const text = this.outputTranscriptionText; + this.outputTranscriptionText = ''; + yield { + outputTranscription: {text, finished: true}, + partial: false, + }; + } + } +} + +function buildFullTextResponse(text: string): LlmResponse { + return { + content: { + role: 'model', + parts: [{text}], + }, + }; +} + +/** + * Removes inline audio parts from a content. Returns undefined if the content + * has no remaining parts after filtering. + */ +function filterAudioParts(content: Content): Content | undefined { + if (!content.parts?.length) { + return content; + } + const filteredParts = content.parts.filter( + (part) => !part.inlineData?.mimeType?.startsWith('audio/'), + ); + if (filteredParts.length === 0) { + return undefined; + } + return {...content, parts: filteredParts}; } diff --git a/core/src/models/google_llm.ts b/core/src/models/google_llm.ts index 55737629..dace405d 100644 --- a/core/src/models/google_llm.ts +++ b/core/src/models/google_llm.ts @@ -19,7 +19,10 @@ import {GoogleLLMVariant} from '../utils/variant_utils.js'; import {StreamingResponseAggregator} from '../utils/streaming_utils.js'; import {BaseLlm} from './base_llm.js'; import {BaseLlmConnection} from './base_llm_connection.js'; -import {GeminiLlmConnection} from './gemini_llm_connection.js'; +import { + GeminiLlmConnection, + IncomingMessageBuffer, +} from './gemini_llm_connection.js'; import {LlmRequest} from './llm_request.js'; import {createLlmResponse, LlmResponse} from './llm_response.js'; @@ -272,15 +275,27 @@ export class Gemini extends BaseLlm { llmRequest.liveConnectConfig.tools = 
llmRequest.config?.tools; + const incomingMessages = new IncomingMessageBuffer(); const liveSession = await this.liveApiClient.live.connect({ model: llmRequest.model ?? this.model, config: llmRequest.liveConnectConfig, callbacks: { - // TODO - b/425992518: GenAI SDK inconsistent API, missing methods. - onmessage: () => {}, + onmessage: (message) => + incomingMessages.push({kind: 'message', message}), + onerror: (event) => + incomingMessages.push({ + kind: 'error', + error: errorFromEvent(event), + }), + onclose: () => incomingMessages.push({kind: 'close'}), }, }); - return new GeminiLlmConnection(liveSession); + return new GeminiLlmConnection( + liveSession, + incomingMessages, + this.apiBackend, + llmRequest.model ?? this.model, + ); } private preprocessRequest(llmRequest: LlmRequest): void { @@ -312,6 +327,21 @@ function removeDisplayNameIfPresent( } } +function errorFromEvent(event: unknown): Error { + if (event instanceof Error) { + return event; + } + if ( + typeof event === 'object' && + event !== null && + 'message' in event && + typeof (event as {message: unknown}).message === 'string' + ) { + return new Error((event as {message: string}).message); + } + return new Error('Live connection error'); +} + export function geminiInitParams({ model, vertexai, diff --git a/core/src/models/llm_response.ts b/core/src/models/llm_response.ts index 4a869c7d..b952876d 100644 --- a/core/src/models/llm_response.ts +++ b/core/src/models/llm_response.ts @@ -11,6 +11,7 @@ import { GenerateContentResponse, GenerateContentResponseUsageMetadata, GroundingMetadata, + LiveServerGoAway, LiveServerSessionResumptionUpdate, Transcription, } from '@google/genai'; @@ -85,6 +86,12 @@ export interface LlmResponse { */ liveSessionResumptionUpdate?: LiveServerSessionResumptionUpdate; + /** + * Server-side signal that the live connection will be closed soon. The + * caller should reconnect using the latest session resumption handle. 
+ */ + goAway?: LiveServerGoAway; + /** * Audio transcription of user input. */ diff --git a/core/src/runner/runner.ts b/core/src/runner/runner.ts index 88cb3ea9..c0a907fa 100644 --- a/core/src/runner/runner.ts +++ b/core/src/runner/runner.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import {Content, createPartFromText} from '@google/genai'; +import {Content, createPartFromText, Modality} from '@google/genai'; import {context, trace} from '@opentelemetry/api'; import {BaseAgent} from '../agents/base_agent.js'; @@ -12,6 +12,7 @@ import { InvocationContext, newInvocationContextId, } from '../agents/invocation_context.js'; +import {LiveRequestQueue} from '../agents/live_request_queue.js'; import {isLlmAgent} from '../agents/llm_agent.js'; import {createRunConfig, RunConfig} from '../agents/run_config.js'; import {BaseArtifactService} from '../artifacts/base_artifact_service.js'; @@ -470,7 +471,174 @@ export class Runner { } return true; } - // TODO - b/425992518: Implement runLive and related methods. + + /** + * Runs the agent in live (bidirectional) mode. + * + * Establishes an interactive session driven by `liveRequestQueue` and yields + * the resulting events. Live model audio events with `inlineData` are + * yielded but not appended to the session to avoid persisting raw audio + * blobs; events with `fileData` references and most other live events + * (transcriptions, tool calls, usage) are persisted as in `runAsync`. + * + * This feature is **experimental** and its API may change. + * + * @param params.userId The user ID of the session. + * @param params.sessionId The session ID of the session. + * @param params.liveRequestQueue The queue used to feed the live model. + * @param params.runConfig The run config for the agent. + * @yields The events generated by the agent. 
+ */ + async *runLive(params: { + userId: string; + sessionId: string; + liveRequestQueue: LiveRequestQueue; + runConfig?: RunConfig; + abortSignal?: AbortSignal; + /** + * Optional session resumption handle observed from a prior `runLive` + * cycle on the same conversation. When set, the agent's live flow will + * open the connection with `liveConnectConfig.sessionResumption.handle` + * so the server restores its state instead of relying on client-side + * history replay. + */ + liveSessionResumptionHandle?: string; + }): AsyncGenerator { + if (!params.liveRequestQueue) { + throw new Error('liveRequestQueue is required for runLive.'); + } + + const runConfig = createRunConfig(params.runConfig); + if (!runConfig.responseModalities?.length) { + runConfig.responseModalities = [Modality.AUDIO]; + } + // For multi-agent live setups, the model's text transcription is needed + // as context for the transferred agent. + if (this.agent.subAgents?.length) { + if (runConfig.responseModalities.includes(Modality.AUDIO)) { + runConfig.outputAudioTranscription ??= {}; + } + runConfig.inputAudioTranscription ??= {}; + } + + const span = tracer.startSpan('invocation'); + const ctx = trace.setSpan(context.active(), span); + try { + yield* runAsyncGeneratorWithOtelContext( + ctx, + this, + async function* () { + const session = await this.sessionService.getSession({ + appName: this.appName, + userId: params.userId, + sessionId: params.sessionId, + }); + + if (params.abortSignal?.aborted) { + return; + } + + if (!session) { + if (!this.appName) { + throw new Error( + `Session lookup failed: appName must be provided in runner constructor`, + ); + } + throw new Error(`Session not found: ${params.sessionId}`); + } + + const invocationContext = new InvocationContext({ + artifactService: this.artifactService, + sessionService: this.sessionService, + memoryService: this.memoryService, + credentialService: this.credentialService, + invocationId: newInvocationContextId(), + agent: 
this.agent, + session, + runConfig, + pluginManager: this.pluginManager, + liveRequestQueue: params.liveRequestQueue, + abortSignal: params.abortSignal, + liveSessionResumptionHandle: params.liveSessionResumptionHandle, + }); + + invocationContext.agent = this.determineAgentForResumption( + session, + this.agent, + ); + + // Step 1: before-run plugin hook (early exit if it returns content). + const beforeRunCallbackResponse = + await this.pluginManager.runBeforeRunCallback({ + invocationContext, + }); + if (params.abortSignal?.aborted) { + return; + } + if (beforeRunCallbackResponse) { + const earlyExitEvent = createEvent({ + invocationId: invocationContext.invocationId, + author: 'model', + content: beforeRunCallbackResponse, + }); + await this.sessionService.appendEvent({ + session, + event: earlyExitEvent, + }); + yield earlyExitEvent; + return; + } + + // Step 2: drive the agent's runLive and propagate events. + for await (const event of invocationContext.agent.runLive( + invocationContext, + )) { + if (params.abortSignal?.aborted) { + return; + } + + if (!event.partial && shouldAppendLiveEvent(event)) { + await this.sessionService.appendEvent({session, event}); + } + + const modifiedEvent = await this.pluginManager.runOnEventCallback({ + invocationContext, + event, + }); + if (params.abortSignal?.aborted) { + return; + } + + yield modifiedEvent ?? event; + } + + // Step 3: after-run plugin hook for cleanup/metrics. + await this.pluginManager.runAfterRunCallback({invocationContext}); + }, + ); + } finally { + span.end(); + } + } +} + +/** + * Decides whether a live event should be persisted to the session. + * + * Live model audio events that carry raw inline audio bytes are deliberately + * skipped to avoid persisting large blobs. Audio referenced via `fileData` + * (e.g. saved as artifacts) and all non-audio events are persisted. 
+ */ +function shouldAppendLiveEvent(event: Event): boolean { + const parts = event.content?.parts; + if (!parts?.length) { + return true; + } + const inlineData = parts[0].inlineData; + if (!inlineData?.mimeType?.startsWith('audio/')) { + return true; + } + return false; } /** diff --git a/core/src/utils/model_name.ts b/core/src/utils/model_name.ts index 6b1ea728..3a7250d6 100644 --- a/core/src/utils/model_name.ts +++ b/core/src/utils/model_name.ts @@ -100,3 +100,22 @@ export function isGemini2OrAbove(modelString: string): boolean { export function isGeminiModelIdCheckDisabled(): boolean { return getBooleanEnvVar('ADK_DISABLE_GEMINI_MODEL_ID_CHECK'); } + +/** + * Check if the model is a Gemini 3.1 Flash Live model. + * + * Gemini 3.1 Flash Live has different live API semantics from earlier models: + * - User text input must use `sendRealtimeInput({text})`; `sendClientContent` + * text turns are ignored. + * - Tool calls are emitted without a preceding `turnComplete`, so receivers + * must flush them eagerly. + * + * @param modelString Either a simple model name or path-based model name + * @return true if it's a Gemini 3.1 Flash Live model, false otherwise. 
+ */ +export function isGemini31FlashLive(modelString: string | undefined): boolean { + if (!modelString) { + return false; + } + return extractModelName(modelString).startsWith('gemini-3.1-flash-live'); +} diff --git a/core/test/agents/live_request_queue_test.ts b/core/test/agents/live_request_queue_test.ts index a7ccce8d..bb45890b 100644 --- a/core/test/agents/live_request_queue_test.ts +++ b/core/test/agents/live_request_queue_test.ts @@ -125,6 +125,50 @@ describe('LiveRequestQueue', () => { }).toThrowError('Cannot send to a closed queue.'); }); + it('should reject a waiting get() when its abort signal fires', async () => { + const queue = new LiveRequestQueue(); + const controller = new AbortController(); + const getPromise = queue.get(controller.signal); + + controller.abort(); + + await expect(getPromise).rejects.toThrow('aborted'); + }); + + it('should throw immediately from get() when the signal is already aborted', async () => { + const queue = new LiveRequestQueue(); + const controller = new AbortController(); + controller.abort(); + + await expect(queue.get(controller.signal)).rejects.toThrow('aborted'); + }); + + it('should not let an aborted get() consume a later request', async () => { + const queue = new LiveRequestQueue(); + const controller = new AbortController(); + const abortedGet = queue.get(controller.signal); + + controller.abort(); + await expect(abortedGet).rejects.toThrow('aborted'); + + // The request sent after the abort must reach the next live get(), not + // the waiter that was torn down. 
+ const request: LiveRequest = {content: createUserContent('after-abort')}; + const liveGet = queue.get(); + queue.send(request); + expect(await liveGet).toEqual(request); + }); + + it('should still resolve a get() whose abort signal never fires', async () => { + const queue = new LiveRequestQueue(); + const controller = new AbortController(); + const getPromise = queue.get(controller.signal); + + const request: LiveRequest = {content: createUserContent('req')}; + queue.send(request); + expect(await getPromise).toEqual(request); + }); + it('should drain remaining items after close, then return close signal', async () => { const queue = new LiveRequestQueue(); const request1 = {content: createUserContent('item1')}; diff --git a/core/test/models/gemini_llm_connection_test.ts b/core/test/models/gemini_llm_connection_test.ts new file mode 100644 index 00000000..0779774b --- /dev/null +++ b/core/test/models/gemini_llm_connection_test.ts @@ -0,0 +1,316 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import {Blob, LiveServerMessage, Session} from '@google/genai'; +import {describe, expect, it, vi} from 'vitest'; + +import { + GeminiLlmConnection, + IncomingMessageBuffer, +} from '../../src/models/gemini_llm_connection.js'; +import {LlmResponse} from '../../src/models/llm_response.js'; +import {GoogleLLMVariant} from '../../src/utils/variant_utils.js'; + +interface FakeSession { + sendClientContent: ReturnType; + sendRealtimeInput: ReturnType; + sendToolResponse: ReturnType; + close: ReturnType; +} + +function createFakeSession(): FakeSession { + return { + sendClientContent: vi.fn(), + sendRealtimeInput: vi.fn(), + sendToolResponse: vi.fn(), + close: vi.fn(), + }; +} + +const GEMINI_31 = 'gemini-3.1-flash-live-preview'; +const GEMINI_25 = 'gemini-2.5-flash-live-preview'; + +function createConnection( + options: {apiBackend?: GoogleLLMVariant; modelVersion?: string} = {}, +): { + connection: GeminiLlmConnection; + session: 
FakeSession; + buffer: IncomingMessageBuffer; +} { + const session = createFakeSession(); + const buffer = new IncomingMessageBuffer(); + const connection = new GeminiLlmConnection( + session as unknown as Session, + buffer, + options.apiBackend ?? GoogleLLMVariant.GEMINI_API, + options.modelVersion ?? GEMINI_31, + ); + return {connection, session, buffer}; +} + +describe('GeminiLlmConnection.sendRealtime', () => { + it('routes audio blobs through audio: on Gemini 3.1', async () => { + const {connection, session} = createConnection(); + const blob: Blob = {data: 'AAA=', mimeType: 'audio/pcm;rate=16000'}; + await connection.sendRealtime(blob); + expect(session.sendRealtimeInput).toHaveBeenCalledWith({audio: blob}); + }); + + it('routes image blobs through video: on Gemini 3.1', async () => { + const {connection, session} = createConnection(); + const blob: Blob = {data: 'AAA=', mimeType: 'image/jpeg'}; + await connection.sendRealtime(blob); + expect(session.sendRealtimeInput).toHaveBeenCalledWith({video: blob}); + }); + + it('drops unknown mime types on Gemini 3.1 instead of guessing', async () => { + const {connection, session} = createConnection(); + const blob: Blob = {data: 'AAA=', mimeType: 'application/octet-stream'}; + await connection.sendRealtime(blob); + expect(session.sendRealtimeInput).not.toHaveBeenCalled(); + }); + + it('routes blobs through media: on pre-3.1 models', async () => { + const {connection, session} = createConnection({modelVersion: GEMINI_25}); + const blob: Blob = {data: 'AAA=', mimeType: 'audio/pcm;rate=16000'}; + await connection.sendRealtime(blob); + expect(session.sendRealtimeInput).toHaveBeenCalledWith({media: blob}); + }); + + it('routes blobs with unknown mime via media: on pre-3.1 models', async () => { + const {connection, session} = createConnection({modelVersion: GEMINI_25}); + const blob = {data: 'AAA='} as Blob; + await connection.sendRealtime(blob); + expect(session.sendRealtimeInput).toHaveBeenCalledWith({media: blob}); + 
}); +}); + +describe('GeminiLlmConnection.sendContent', () => { + it('routes single user text via sendRealtimeInput on Gemini 3.1', async () => { + const {connection, session} = createConnection(); + await connection.sendContent({role: 'user', parts: [{text: 'hello'}]}); + expect(session.sendRealtimeInput).toHaveBeenCalledWith({text: 'hello'}); + expect(session.sendClientContent).not.toHaveBeenCalled(); + }); + + it('routes single user text via sendClientContent on pre-3.1 models', async () => { + const {connection, session} = createConnection({modelVersion: GEMINI_25}); + const content = {role: 'user', parts: [{text: 'hello'}]}; + await connection.sendContent(content); + expect(session.sendClientContent).toHaveBeenCalledWith({ + turns: [content], + turnComplete: true, + }); + expect(session.sendRealtimeInput).not.toHaveBeenCalled(); + }); + + it('uses sendClientContent for multi-part user content', async () => { + const {connection, session} = createConnection(); + const content = { + role: 'user', + parts: [{text: 'hello'}, {text: 'world'}], + }; + await connection.sendContent(content); + expect(session.sendClientContent).toHaveBeenCalledWith({ + turns: [content], + turnComplete: true, + }); + expect(session.sendRealtimeInput).not.toHaveBeenCalled(); + }); + + it('uses sendClientContent when content role is not user', async () => { + const {connection, session} = createConnection(); + const content = {role: 'model', parts: [{text: 'hi'}]}; + await connection.sendContent(content); + expect(session.sendClientContent).toHaveBeenCalledWith({ + turns: [content], + turnComplete: true, + }); + expect(session.sendRealtimeInput).not.toHaveBeenCalled(); + }); + + it('uses sendToolResponse for function response content', async () => { + const {connection, session} = createConnection(); + const fr = {id: 'fc1', name: 'echo', response: {ok: true}}; + await connection.sendContent({ + role: 'user', + parts: [{functionResponse: fr}], + }); + 
expect(session.sendToolResponse).toHaveBeenCalledWith({ + functionResponses: [fr], + }); + }); + + it('throws when content has no parts', async () => { + const {connection} = createConnection(); + await expect(connection.sendContent({role: 'user'})).rejects.toThrow( + 'Content must have parts.', + ); + }); +}); + +describe('GeminiLlmConnection.sendHistory', () => { + it('does not send when history is empty', async () => { + const {connection, session} = createConnection(); + await connection.sendHistory([]); + expect(session.sendClientContent).not.toHaveBeenCalled(); + }); + + it('seals turn when history ends with a user message', async () => { + const {connection, session} = createConnection(); + const history = [ + {role: 'model', parts: [{text: 'hi'}]}, + {role: 'user', parts: [{text: 'hello'}]}, + ]; + await connection.sendHistory(history); + expect(session.sendClientContent).toHaveBeenCalledWith({ + turns: history, + turnComplete: true, + }); + }); + + it('leaves turn open when history ends with a model message', async () => { + const {connection, session} = createConnection(); + const history = [ + {role: 'user', parts: [{text: 'hello'}]}, + {role: 'model', parts: [{text: 'hi back'}]}, + ]; + await connection.sendHistory(history); + expect(session.sendClientContent).toHaveBeenCalledWith({ + turns: history, + turnComplete: false, + }); + }); + + it('filters out audio parts before sending history', async () => { + const {connection, session} = createConnection(); + const history = [ + { + role: 'model', + parts: [ + {text: 'hello'}, + {inlineData: {data: 'AAA=', mimeType: 'audio/pcm'}}, + ], + }, + ]; + await connection.sendHistory(history); + expect(session.sendClientContent).toHaveBeenCalledWith({ + turns: [{role: 'model', parts: [{text: 'hello'}]}], + turnComplete: false, + }); + }); +}); + +describe('GeminiLlmConnection.receive', () => { + it('does not terminate after turnComplete and yields events for the next turn', async () => { + const {connection, 
buffer} = createConnection(); + + const turn1Audio: LiveServerMessage = { + serverContent: { + modelTurn: { + role: 'model', + parts: [{inlineData: {data: 'AAA=', mimeType: 'audio/pcm'}}], + }, + }, + } as LiveServerMessage; + const turn1Done: LiveServerMessage = { + serverContent: {turnComplete: true}, + } as LiveServerMessage; + const turn2Audio: LiveServerMessage = { + serverContent: { + modelTurn: { + role: 'model', + parts: [{inlineData: {data: 'BBB=', mimeType: 'audio/pcm'}}], + }, + }, + } as LiveServerMessage; + const turn2Done: LiveServerMessage = { + serverContent: {turnComplete: true}, + } as LiveServerMessage; + + buffer.push({kind: 'message', message: turn1Audio}); + buffer.push({kind: 'message', message: turn1Done}); + buffer.push({kind: 'message', message: turn2Audio}); + buffer.push({kind: 'message', message: turn2Done}); + buffer.push({kind: 'close'}); + + const events: LlmResponse[] = []; + for await (const event of connection.receive()) { + events.push(event); + } + + const turnCompleteCount = events.filter((e) => e.turnComplete).length; + expect(turnCompleteCount).toBe(2); + + const inlineDataChunks = events + .map((e) => e.content?.parts?.[0]?.inlineData?.data) + .filter(Boolean); + expect(inlineDataChunks).toContain('AAA='); + expect(inlineDataChunks).toContain('BBB='); + }); + + it('terminates when the buffer reports close', async () => { + const {connection, buffer} = createConnection(); + buffer.push({kind: 'close'}); + + const events: LlmResponse[] = []; + for await (const event of connection.receive()) { + events.push(event); + } + expect(events).toEqual([]); + }); + + it('throws when the buffer reports an error', async () => { + const {connection, buffer} = createConnection(); + buffer.push({kind: 'error', error: new Error('boom')}); + + const consume = async () => { + for await (const _ of connection.receive()) { + // drain + } + }; + + await expect(consume()).rejects.toThrow('boom'); + }); + + it('yields goAway events from the server', 
async () => { + const {connection, buffer} = createConnection(); + const goAway = {timeLeft: '1s'}; + buffer.push({ + kind: 'message', + message: {goAway} as LiveServerMessage, + }); + buffer.push({kind: 'close'}); + + const events: LlmResponse[] = []; + for await (const event of connection.receive()) { + events.push(event); + } + + const goAwayEvents = events.filter((e) => e.goAway); + expect(goAwayEvents.length).toBe(1); + expect(goAwayEvents[0].goAway).toEqual(goAway); + }); + + it('yields sessionResumptionUpdate events from the server', async () => { + const {connection, buffer} = createConnection(); + const update = {newHandle: 'handle-123', resumable: true}; + buffer.push({ + kind: 'message', + message: {sessionResumptionUpdate: update} as LiveServerMessage, + }); + buffer.push({kind: 'close'}); + + const events: LlmResponse[] = []; + for await (const event of connection.receive()) { + events.push(event); + } + + const resumeEvents = events.filter((e) => e.liveSessionResumptionUpdate); + expect(resumeEvents.length).toBe(1); + expect(resumeEvents[0].liveSessionResumptionUpdate).toEqual(update); + }); +}); diff --git a/core/test/runner/run_live_test.ts b/core/test/runner/run_live_test.ts new file mode 100644 index 00000000..925763ea --- /dev/null +++ b/core/test/runner/run_live_test.ts @@ -0,0 +1,500 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + BaseLlm, + BaseLlmConnection, + BaseTool, + Event, + InMemoryArtifactService, + InMemorySessionService, + LiveRequestQueue, + LlmAgent, + LlmRequest, + LlmResponse, + RunAsyncToolRequest, + Runner, +} from '@google/adk'; +import {Blob, Content, FunctionDeclaration, Modality} from '@google/genai'; +import {beforeEach, describe, expect, it} from 'vitest'; + +const TEST_APP_ID = 'test_app_id'; +const TEST_USER_ID = 'test_user_id'; +const TEST_SESSION_ID = 'test_session_id'; + +class RecordingConnection implements BaseLlmConnection { + readonly historyCalls: 
Content[][] = []; + readonly contentCalls: Content[] = []; + readonly realtimeCalls: Blob[] = []; + closed = false; + + constructor(private readonly responses: LlmResponse[]) {} + + async sendHistory(history: Content[]): Promise { + this.historyCalls.push(history); + } + async sendContent(content: Content): Promise { + this.contentCalls.push(content); + } + async sendRealtime(blob: Blob): Promise { + this.realtimeCalls.push(blob); + } + async *receive(): AsyncGenerator { + for (const response of this.responses) { + yield response; + } + } + async close(): Promise { + this.closed = true; + } +} + +class FakeLiveLlm extends BaseLlm { + connection?: RecordingConnection; + llmRequestSeen?: LlmRequest; + readonly connections: RecordingConnection[] = []; + readonly llmRequestsSeen: LlmRequest[] = []; + + constructor( + private readonly responses: LlmResponse[] | LlmResponse[][], + model = 'fake-live-llm', + ) { + super({model}); + } + + // eslint-disable-next-line require-yield + override async *generateContentAsync(): AsyncGenerator< + LlmResponse, + void, + void + > { + throw new Error('generateContentAsync not used in live tests'); + } + + override async connect(llmRequest: LlmRequest): Promise { + // Snapshot the request as the caller may mutate `liveConnectConfig` + // across reconnect attempts (e.g. setting `sessionResumption.handle`). + this.llmRequestSeen = llmRequest; + this.llmRequestsSeen.push( + JSON.parse(JSON.stringify(llmRequest)) as LlmRequest, + ); + const isSequence = + Array.isArray(this.responses) && Array.isArray(this.responses[0]); + const responses = isSequence + ? ((this.responses as LlmResponse[][])[this.connections.length] ?? 
[]) + : (this.responses as LlmResponse[]); + this.connection = new RecordingConnection(responses); + this.connections.push(this.connection); + return this.connection; + } +} + +class EchoTool extends BaseTool { + constructor() { + super({name: 'echo', description: 'Echoes back its input.'}); + } + override _getDeclaration(): FunctionDeclaration | undefined { + return {name: this.name, description: this.description}; + } + override async runAsync(request: RunAsyncToolRequest): Promise { + return {echoed: request.args}; + } +} + +describe('Runner.runLive', () => { + let sessionService: InMemorySessionService; + let artifactService: InMemoryArtifactService; + + beforeEach(async () => { + sessionService = new InMemorySessionService(); + artifactService = new InMemoryArtifactService(); + await sessionService.createSession({ + appName: TEST_APP_ID, + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + }); + }); + + it('throws when liveRequestQueue is missing', async () => { + const llm = new FakeLiveLlm([]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + await expect(async () => { + // @ts-expect-error - intentionally omit required argument + for await (const _ of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + })) { + // no-op + } + }).rejects.toThrow('liveRequestQueue is required'); + }); + + it('throws when session does not exist', async () => { + const llm = new FakeLiveLlm([]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + const queue = new LiveRequestQueue(); + queue.close(); + await expect(async () => { + for await (const _ of runner.runLive({ + userId: TEST_USER_ID, + sessionId: 'missing', + liveRequestQueue: queue, + })) { + // no-op + } + }).rejects.toThrow('Session not found: missing'); + }); + 
+ it('forwards realtime blobs to the connection and yields model events', async () => { + const audioPart: Content = { + role: 'model', + parts: [{inlineData: {data: 'AAA=', mimeType: 'audio/pcm'}}], + }; + const textPart: Content = {role: 'model', parts: [{text: 'hello'}]}; + const llm = new FakeLiveLlm([ + {content: audioPart}, + {content: textPart}, + {turnComplete: true}, + ]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + const queue = new LiveRequestQueue(); + const blob: Blob = {data: 'AAA=', mimeType: 'audio/pcm'}; + queue.sendRealtime(blob); + queue.close(); + + const events: Event[] = []; + for await (const event of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + liveRequestQueue: queue, + })) { + events.push(event); + } + + expect(llm.connection).toBeDefined(); + expect(llm.connection!.realtimeCalls).toEqual([blob]); + expect(llm.connection!.closed).toBe(true); + + expect(events.some((e) => e.content === audioPart)).toBe(true); + expect(events.some((e) => e.content === textPart)).toBe(true); + expect(events.some((e) => e.turnComplete)).toBe(true); + }); + + it('defaults responseModalities to AUDIO and applies live config', async () => { + const llm = new FakeLiveLlm([{turnComplete: true}]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + const queue = new LiveRequestQueue(); + queue.close(); + for await (const _ of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + liveRequestQueue: queue, + })) { + // drain + } + + expect(llm.llmRequestSeen?.liveConnectConfig?.responseModalities).toEqual([ + Modality.AUDIO, + ]); + }); + + it('does not persist live audio events but persists transcription events', async () => { + const audioPart: Content = { + role: 'model', + parts: 
[{inlineData: {data: 'AAA=', mimeType: 'audio/pcm'}}], + }; + const llm = new FakeLiveLlm([ + {content: audioPart}, + { + outputTranscription: {text: 'hello world', finished: true}, + partial: false, + }, + {turnComplete: true}, + ]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + const queue = new LiveRequestQueue(); + queue.close(); + for await (const _ of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + liveRequestQueue: queue, + })) { + // drain + } + + const session = await sessionService.getSession({ + appName: TEST_APP_ID, + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + }); + const persisted = session!.events; + const hasAudioInline = persisted.some((event) => + event.content?.parts?.some((part) => + part.inlineData?.mimeType?.startsWith('audio/'), + ), + ); + expect(hasAudioInline).toBe(false); + const hasTranscription = persisted.some( + (event) => event.outputTranscription !== undefined, + ); + expect(hasTranscription).toBe(true); + }); + + it('runs tool calls and sends function responses back to the model', async () => { + const functionCall: Content = { + role: 'model', + parts: [{functionCall: {name: 'echo', args: {value: 1}}}], + }; + const llm = new FakeLiveLlm([ + {content: functionCall}, + {turnComplete: true}, + ]); + const agent = new LlmAgent({ + name: 'agent', + model: llm, + tools: [new EchoTool()], + }); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + const queue = new LiveRequestQueue(); + queue.close(); + const events: Event[] = []; + for await (const event of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + liveRequestQueue: queue, + })) { + events.push(event); + } + + const responseEvents = events.filter((event) => + event.content?.parts?.some((part) => part.functionResponse), + ); + 
expect(responseEvents.length).toBe(1); + + expect(llm.connection!.contentCalls.length).toBe(1); + const sentBack = llm.connection!.contentCalls[0]; + expect(sentBack.parts?.[0]?.functionResponse?.name).toBe('echo'); + }); + + it('captures sessionResumptionUpdate handles into invocation context', async () => { + const llm = new FakeLiveLlm([ + {liveSessionResumptionUpdate: {newHandle: 'handle-1'}}, + {liveSessionResumptionUpdate: {newHandle: 'handle-2'}}, + {turnComplete: true}, + ]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + const queue = new LiveRequestQueue(); + queue.close(); + const events: Event[] = []; + for await (const event of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + liveRequestQueue: queue, + })) { + events.push(event); + } + + const resumeEvents = events.filter((e) => e.liveSessionResumptionUpdate); + expect(resumeEvents.length).toBe(2); + expect(resumeEvents[1].liveSessionResumptionUpdate?.newHandle).toBe( + 'handle-2', + ); + }); + + it('reconnects with session handle on goAway and skips history replay', async () => { + const llm = new FakeLiveLlm([ + [ + {liveSessionResumptionUpdate: {newHandle: 'handle-1'}}, + {goAway: {timeLeft: '1s'}}, + ], + [{turnComplete: true}], + ]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + // Seed a content event so contents is non-empty on the first connect. 
+ const session = (await sessionService.getSession({ + appName: TEST_APP_ID, + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + }))!; + await sessionService.appendEvent({ + session, + event: { + invocationId: 'seed', + author: 'user', + id: 'seed-evt', + actions: { + stateDelta: {}, + artifactDelta: {}, + requestedAuthConfigs: {}, + requestedToolConfirmations: {}, + }, + longRunningToolIds: [], + timestamp: Date.now(), + content: {role: 'user', parts: [{text: 'hello'}]}, + } as Event, + }); + + const queue = new LiveRequestQueue(); + queue.close(); + for await (const _ of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + liveRequestQueue: queue, + })) { + // drain + } + + expect(llm.connections.length).toBe(2); + // First connection received history, second skipped it. + expect(llm.connections[0].historyCalls.length).toBe(1); + expect(llm.connections[1].historyCalls.length).toBe(0); + // Second connect carried the captured resumption handle. + expect( + llm.llmRequestsSeen[1].liveConnectConfig?.sessionResumption?.handle, + ).toBe('handle-1'); + expect( + llm.llmRequestsSeen[1].liveConnectConfig?.sessionResumption?.transparent, + ).toBe(true); + // First connect had no resumption handle set. + expect( + llm.llmRequestsSeen[0].liveConnectConfig?.sessionResumption?.handle, + ).toBeUndefined(); + }); + + it('uses an externally provided session resumption handle on first connect', async () => { + const llm = new FakeLiveLlm([{turnComplete: true}]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + // Seed contents so without a handle the runner would call sendHistory. 
+ const session = (await sessionService.getSession({ + appName: TEST_APP_ID, + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + }))!; + await sessionService.appendEvent({ + session, + event: { + invocationId: 'seed', + author: 'user', + id: 'seed-evt', + actions: { + stateDelta: {}, + artifactDelta: {}, + requestedAuthConfigs: {}, + requestedToolConfirmations: {}, + }, + longRunningToolIds: [], + timestamp: Date.now(), + content: {role: 'user', parts: [{text: 'hello'}]}, + } as Event, + }); + + const queue = new LiveRequestQueue(); + queue.close(); + for await (const _ of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + liveRequestQueue: queue, + liveSessionResumptionHandle: 'external-handle', + })) { + // drain + } + + // History was skipped because the caller supplied a handle. + expect(llm.connections[0].historyCalls.length).toBe(0); + expect( + llm.llmRequestsSeen[0].liveConnectConfig?.sessionResumption?.handle, + ).toBe('external-handle'); + }); + + it('does not reconnect when no resumption handle has been captured', async () => { + const llm = new FakeLiveLlm([ + [{goAway: {timeLeft: '1s'}}], + [{turnComplete: true}], + ]); + const agent = new LlmAgent({name: 'agent', model: llm}); + const runner = new Runner({ + appName: TEST_APP_ID, + agent, + sessionService, + artifactService, + }); + + const queue = new LiveRequestQueue(); + queue.close(); + await expect(async () => { + for await (const _ of runner.runLive({ + userId: TEST_USER_ID, + sessionId: TEST_SESSION_ID, + liveRequestQueue: queue, + })) { + // drain + } + }).rejects.toThrow(/live reconnect requested/); + + expect(llm.connections.length).toBe(1); + }); +});