diff --git a/.changeset/agent-update-instructions.md b/.changeset/agent-update-instructions.md new file mode 100644 index 000000000..69ecaebb4 --- /dev/null +++ b/.changeset/agent-update-instructions.md @@ -0,0 +1,6 @@ +--- +'@livekit/agents': patch +'@livekit/agents-plugin-openai': patch +--- + +Add `Agent.updateInstructions()` to update an agent's instructions mid-session (parity with Python). The change propagates the new instructions through the active `AgentActivity`, records an `AgentConfigUpdate` in the chat and session history, and syncs the realtime/stateless contexts. For OpenAI realtime, per-response instructions now preserve the session-level instructions instead of replacing them. diff --git a/.changeset/async-toolsets.md b/.changeset/async-toolsets.md new file mode 100644 index 000000000..9fd29c00c --- /dev/null +++ b/.changeset/async-toolsets.md @@ -0,0 +1,6 @@ +--- +'@livekit/agents': minor +--- + +Port async tool execution semantics from Python: tools can release their turn with `ctx.update()`, +`AsyncToolset` controls session/activity scope, and cancellable tools expose task-management helpers. diff --git a/.changeset/bargein-default-threshold-drop-http.md b/.changeset/bargein-default-threshold-drop-http.md new file mode 100644 index 000000000..2d937eb5c --- /dev/null +++ b/.changeset/bargein-default-threshold-drop-http.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Adaptive interruption detection now omits the threshold from `session.create` unless the user explicitly overrides it, letting the gateway apply its fetched default (surfaced via `default_threshold` on `session.created`). The HTTP transport has been dropped — detection always connects over WebSocket and always requires LiveKit credentials, and its base URL now defaults from `LIVEKIT_INFERENCE_URL` instead of `LIVEKIT_REMOTE_EOT_URL`. 
Inference requests also send an `X-LiveKit-Worker-Token` header when `LIVEKIT_WORKER_TOKEN` is set (hosted agents); a token supplied via the `--worker-token` CLI flag is now re-exported into the environment so forked job subprocesses inherit it and include the header. The `X-LiveKit-Agent-Id` header is now only attached once the room is connected to avoid leaking an unset local-participant SID. The interruption WebSocket is now closed deterministically on stream teardown (including error and cancel paths) instead of only on graceful completion — previously an orphaned socket leaked per session/activity and accumulated for the worker's lifetime. Mid-session threshold/duration changes via `updateOptions` now reconnect the WebSocket in place rather than closing it and letting the next send error the stream — so option changes no longer consume a failover retry (previously enough updates in a session could exhaust the retry budget and stop interruption detection). diff --git a/.changeset/cold-avocados-behave.md b/.changeset/cold-avocados-behave.md new file mode 100644 index 000000000..a2eef8a7b --- /dev/null +++ b/.changeset/cold-avocados-behave.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +Add Agent.create method diff --git a/.changeset/funky-mugs-stare.md b/.changeset/funky-mugs-stare.md new file mode 100644 index 000000000..f6d283d86 --- /dev/null +++ b/.changeset/funky-mugs-stare.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +Add scoped filler support to RunContext diff --git a/.changeset/gemini-provider-tools.md b/.changeset/gemini-provider-tools.md new file mode 100644 index 000000000..3b1093432 --- /dev/null +++ b/.changeset/gemini-provider-tools.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents-plugin-google': minor +--- + +Add Gemini provider tools for Google Search, Google Maps, URL context, File Search, code execution, and Vertex RAG retrieval, and serialize them from `ToolContext` for Google LLM and realtime sessions. 
diff --git a/.changeset/honest-swans-drum.md b/.changeset/honest-swans-drum.md new file mode 100644 index 000000000..e4782071e --- /dev/null +++ b/.changeset/honest-swans-drum.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +Don't retain recorded events when recording is disabled diff --git a/.changeset/inworld-delivery-mode.md b/.changeset/inworld-delivery-mode.md new file mode 100644 index 000000000..c02a8c94e --- /dev/null +++ b/.changeset/inworld-delivery-mode.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Add Inworld `delivery_mode` to inference TTS model options. diff --git a/.changeset/list-syntax-toolcontext.md b/.changeset/list-syntax-toolcontext.md new file mode 100644 index 000000000..5e854a813 --- /dev/null +++ b/.changeset/list-syntax-toolcontext.md @@ -0,0 +1,16 @@ +--- +'@livekit/agents': minor +--- + +**BREAKING**: `Agent({ tools })` and `agent.updateTools()` now accept a flat list `(FunctionTool | ProviderTool | Toolset)[]` instead of a `Record` map, and `llm.tool({ ... })` requires a `name` field. `ToolContext` is now a Python-parity class with `functionTools` / `providerTools` / `toolsets` accessors, plus `flatten()`, `hasTool(id)`, `getFunctionTool(id)`, `updateTools()`, `copy()`, and `equals()`. To match the Python reference, registering two **different** function-tool instances under the same `name` now throws `duplicate function name: <name>` instead of silently overriding the earlier entry; passing the **same instance** twice is a no-op. `agent.toolCtx` returns a defensive copy so callers can no longer mutate the agent's internal state. `LLM.chat({ toolCtx })` accepts either a `ToolContext` instance or a raw `(FunctionTool | ProviderTool | Toolset)[]` array (`ToolCtxInput`) and normalizes it internally, so callers don't have to construct a `ToolContext` themselves. 
+ +Tools also expose an `id: string` field on the base `Tool` interface (parity with Python's `Tool.id` property): for `FunctionTool` it mirrors `name`, for `ProviderTool` it is the provider tool id. `ToolContext` keys and equality now use `tool.id` consistently. + +**BREAKING**: Provider tools are now modeled to match Python's `ProviderTool`: + +- `ProviderDefinedTool` is renamed to `ProviderTool`, and `isProviderDefinedTool` is renamed to `isProviderTool`. +- `ProviderTool` is now an **abstract class** (Python parity). Plugins must subclass it (`class WebSearch extends ProviderTool { ... }`) to attach provider-specific fields and serializers; bare `new ProviderTool(...)` is rejected at compile time. +- The `tool({ id })` factory overload is removed; `tool({ ... })` only creates function tools now. Construct provider tools by instantiating a `ProviderTool` subclass. +- The `ToolType` literal for provider tools is renamed from `'provider-defined'` to `'provider'`. + +`Toolset` now carries a `TOOLSET_SYMBOL` marker and is detected via a new `isToolset()` guard (consistent with `isFunctionTool` / `isProviderTool`). Existing `instanceof Toolset` checks still work, but symbol-based detection is preferred for cross-realm safety. diff --git a/.changeset/list-unknown-tools.md b/.changeset/list-unknown-tools.md new file mode 100644 index 000000000..091964ae2 --- /dev/null +++ b/.changeset/list-unknown-tools.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +List available tools in unknown-function error output. diff --git a/.changeset/llm-stream-collect.md b/.changeset/llm-stream-collect.md new file mode 100644 index 000000000..17b892718 --- /dev/null +++ b/.changeset/llm-stream-collect.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Add `LLMStream.collect()` for awaiting the full response of a chat stream as a single object (text, tool calls, usage, extra). 
diff --git a/.changeset/mock-tools-testing.md b/.changeset/mock-tools-testing.md new file mode 100644 index 000000000..b0f32ed46 --- /dev/null +++ b/.changeset/mock-tools-testing.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +feat(testing): add `mockTools` utility to override an Agent's tool implementations within an async context, mirroring the Python `mock_tools` API diff --git a/.changeset/object-tools-compat.md b/.changeset/object-tools-compat.md new file mode 100644 index 000000000..907485171 --- /dev/null +++ b/.changeset/object-tools-compat.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Add object-map tool syntax compatibility. diff --git a/.changeset/odd-languages-speak.md b/.changeset/odd-languages-speak.md new file mode 100644 index 000000000..b70a32109 --- /dev/null +++ b/.changeset/odd-languages-speak.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents-plugin-elevenlabs': patch +--- + +Avoid persisting per-call STT language overrides on ElevenLabs STT instances. diff --git a/.changeset/openai-provider-tools.md b/.changeset/openai-provider-tools.md new file mode 100644 index 000000000..8e793a935 --- /dev/null +++ b/.changeset/openai-provider-tools.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents-plugin-openai': minor +--- + +Add OpenAI Responses provider tools for web search, file search, and code interpreter. diff --git a/.changeset/phonic-client-header.md b/.changeset/phonic-client-header.md new file mode 100644 index 000000000..e2202281c --- /dev/null +++ b/.changeset/phonic-client-header.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents-plugin-phonic': patch +--- + +Add LiveKit Agents JS client header to Phonic conversation sockets. diff --git a/.changeset/port-dob-task.md b/.changeset/port-dob-task.md new file mode 100644 index 000000000..707ebd97d --- /dev/null +++ b/.changeset/port-dob-task.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': minor +--- + +Add beta GetDOBTask with two-digit year normalization. 
diff --git a/.changeset/port-end-call-tool.md b/.changeset/port-end-call-tool.md new file mode 100644 index 000000000..f1646f943 --- /dev/null +++ b/.changeset/port-end-call-tool.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': minor +--- + +Add beta EndCallTool for ending calls from agent tools diff --git a/.changeset/quick-meals-breathe.md b/.changeset/quick-meals-breathe.md new file mode 100644 index 000000000..d32233b93 --- /dev/null +++ b/.changeset/quick-meals-breathe.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents': patch +--- + +Adds base `Toolset` support: a stateful container for a group of tools with `setup()` / `aclose()` lifecycle hooks. Toolsets can be passed directly into `Agent({ tools: [...] })` alongside individual function tools; their tools are flattened into the agent's `ToolContext` and the runtime drives `setup()` on activity start, `aclose()` on close, and a setup/close diff when `agent.updateTools()` adds or removes Toolsets mid-session. Per-toolset `setup()` errors are logged but do not abort the activity. The `IGNORE_ON_ENTER` flag is also respected for function tools nested inside a Toolset. Every LLM and realtime plugin tool builder iterates `ToolContext.flatten()` so toolset-contributed tools are correctly advertised. Also exports `ToolCalledEvent` / `ToolCompletedEvent` payload types. diff --git a/.changeset/scope-forward-audio-playback-started.md b/.changeset/scope-forward-audio-playback-started.md new file mode 100644 index 000000000..5c5cd277e --- /dev/null +++ b/.changeset/scope-forward-audio-playback-started.md @@ -0,0 +1,15 @@ +--- +'@livekit/agents': patch +--- + +fix(voice): scope forwardAudio's playback-started listener to its own segment + +When a speech is interrupted, the scheduling loop immediately authorizes the next +speech, so the new segment's `forwardAudio` registers its `playback_started` +listener on the shared audio output while the interrupted segment is still +emitting events during teardown. 
The stray event resolved the new segment's +`firstFrameFut` before its first frame was captured, which skipped resampler +creation and pushed an unresampled frame straight to the `AudioSource` +(`RtcError: sample_rate and num_channels don't match`) and corrupted playback +bookkeeping. The listener now only resolves `firstFrameFut` after the segment has +captured its own first frame. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a527b4fe4..83e3ba826 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -35,6 +35,8 @@ jobs: run: pnpm format:check - name: Throws transformer run: pnpm throws:check + - name: Type check + run: pnpm typecheck build: name: Build runs-on: ubuntu-latest @@ -48,5 +50,13 @@ jobs: cache: pnpm - name: Install dependencies run: pnpm install --frozen-lockfile --ignore-scripts + - name: Cache turbo + uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 + with: + path: .turbo + key: turbo-${{ runner.os }}-node20-${{ hashFiles('pnpm-lock.yaml') }}-${{ github.run_id }} + restore-keys: | + turbo-${{ runner.os }}-node20-${{ hashFiles('pnpm-lock.yaml') }}- + turbo-${{ runner.os }}-node20- - name: Build run: pnpm build diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index de5e3535e..d205e932d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,6 +29,14 @@ jobs: cache: pnpm - name: Install dependencies run: pnpm install --frozen-lockfile --ignore-scripts + - name: Cache turbo + uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 + with: + path: .turbo + key: turbo-${{ runner.os }}-node20-${{ hashFiles('pnpm-lock.yaml') }}-${{ github.run_id }} + restore-keys: | + turbo-${{ runner.os }}-node20-${{ hashFiles('pnpm-lock.yaml') }}- + turbo-${{ runner.os }}-node20- - name: Build run: pnpm build - name: Check which tests to run diff --git a/agents/package.json b/agents/package.json index 3a028e195..bdcc8be98 100644 --- 
a/agents/package.json +++ b/agents/package.json @@ -34,6 +34,7 @@ "clean": "rm -rf dist", "clean:build": "pnpm clean && pnpm build", "lint": "eslint -f unix \"src/**/*.ts\"", + "typecheck": "tsc -p tsconfig.typecheck.json", "api:check": "api-extractor run --typescript-compiler-folder ../node_modules/typescript", "api:update": "api-extractor run --local --typescript-compiler-folder ../node_modules/typescript --verbose", "throws:check": "throws-check src/**/*.ts" diff --git a/agents/src/beta/index.ts b/agents/src/beta/index.ts index 98ac382e9..bd3f7877a 100644 --- a/agents/src/beta/index.ts +++ b/agents/src/beta/index.ts @@ -2,7 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 export { + GetDOBTask, + type GetDOBResult, + type GetDOBTaskOptions, TaskGroup, + type TimeOfBirth, type TaskCompletedEvent, type TaskGroupOptions, type TaskGroupResult, @@ -12,3 +16,10 @@ export { type WarmTransferTaskOptions, } from './workflows/index.js'; export { Instructions } from '../llm/index.js'; +export { + END_CALL_DESCRIPTION, + createEndCallTool, + type EndCallToolCalledEvent, + type EndCallToolCompletedEvent, + type EndCallToolOptions, +} from './tools/index.js'; diff --git a/agents/src/beta/tools/end_call.ts b/agents/src/beta/tools/end_call.ts new file mode 100644 index 000000000..962deda36 --- /dev/null +++ b/agents/src/beta/tools/end_call.ts @@ -0,0 +1,181 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { type EventEmitter, once } from 'node:events'; +import { setTimeout as waitFor } from 'node:timers/promises'; +import { getJobContext } from '../../job.js'; +import { + RealtimeModel, + type ToolCalledEvent, + type ToolCompletedEvent, + Toolset, + tool, +} from '../../llm/index.js'; +import { log } from '../../log.js'; +import type { AgentSession, AgentSessionCallbacks } from '../../voice/agent_session.js'; +import { AgentSessionEventTypes } from '../../voice/events.js'; +import type { UnknownUserData } from '../../voice/run_context.js'; + +/** How long to wait for the agent's goodbye reply to play out before forcing shutdown. */ +const END_CALL_REPLY_TIMEOUT = 5000; + +/** Typed wrapper around `events.once`; abort resolves to `undefined`, other errors propagate. */ +function onceEvent( + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- callbacks don't depend on UserData + session: AgentSession, + event: E, + options?: { signal?: AbortSignal }, +): Promise[0] | undefined> { + return ( + once(session as unknown as EventEmitter, event, options) as Promise< + Parameters + > + ).then( + ([payload]) => payload, + (err) => { + if (options?.signal?.aborted) return undefined; + throw err; + }, + ); +} + +export const END_CALL_DESCRIPTION = ` +Ends the current call and disconnects immediately. + +Call when: +- The user clearly indicates they are done (e.g., "that's all, bye"). + +Do not call when: +- The user asks to pause, hold, or transfer. +- Intent is unclear. + +This is the final action the agent can take. +Once called, no further interaction is possible with the user. +Don't generate any other text or response when the tool is called. +`; + +export type EndCallToolCalledEvent = ToolCalledEvent; + +export type EndCallToolCompletedEvent = ToolCompletedEvent; + +export type EndCallToolOptions = { + /** Additional description to add to the end call tool. 
*/ + extraDescription?: string; + /** + * Whether to delete the room when the user ends the call. + * Deleting the room disconnects all remote users, including SIP callers. + */ + deleteRoom?: boolean; + /** Tool output to the LLM for generating the tool response. */ + endInstructions?: string | null; + /** Callback to call when the tool is called. */ + onToolCalled?: (event: EndCallToolCalledEvent) => Promise | void; + /** Callback to call when the tool is completed. */ + onToolCompleted?: (event: EndCallToolCompletedEvent) => Promise | void; +}; + +/** + * Allows the agent to end the call and disconnect from the room. + */ +export function createEndCallTool({ + extraDescription = '', + deleteRoom = true, + endInstructions = 'say goodbye to the user', + onToolCalled, + onToolCompleted, +}: EndCallToolOptions = {}): Toolset { + // For a realtime LLM that generates the goodbye reply itself, wait for that reply to play out + // (bounded by END_CALL_REPLY_TIMEOUT) before shutting down. `signal` is aborted when the call + // ends or the toolset is torn down, which cancels whichever of the two races is still pending. 
+ const delayedSessionShutdown = async ( + session: AgentSession, + signal: AbortSignal, + ): Promise => { + const speech = onceEvent(session, AgentSessionEventTypes.SpeechCreated, { signal }).then( + (event) => event?.speechHandle, + ); + const timeout = waitFor(END_CALL_REPLY_TIMEOUT, 'timeout' as const, { signal }).catch( + () => undefined, + ); + + const winner = await Promise.race([speech, timeout]); + if (signal.aborted) return; // session already closed or toolset torn down + + if (winner === 'timeout') { + log().warn('tool reply timed out, shutting down session'); + session.shutdown(); + } else if (winner) { + await winner.waitForPlayout(); + session.shutdown(); + } + }; + + return Toolset.create({ + id: 'end_call', + tools: [ + tool({ + name: 'end_call', + description: `${END_CALL_DESCRIPTION}\n${extraDescription}`, + execute: async (_args, { ctx, abortSignal }) => { + log().debug('end_call tool called'); + const session = ctx.session; + const llm = session.currentAgent.getActivityOrThrow().llm; + + // Lifetime of this invocation: aborts when the session closes, and also when the tool + // call itself is aborted. All listeners/timers below are scoped to it. + const controller = new AbortController(); + const signal = abortSignal + ? 
AbortSignal.any([abortSignal, controller.signal]) + : controller.signal; + + void onceEvent(session, AgentSessionEventTypes.Close, { signal }) + .then((event) => { + if (!event) return; // signal aborted before close fired + controller.abort(); // stop the delayed-shutdown race + + const jobCtx = getJobContext(false); + if (!jobCtx) return; + + if (deleteRoom) { + jobCtx.addShutdownCallback(async () => { + log().info('deleting the room because the user ended the call'); + await jobCtx.deleteRoom(); + }); + } + + jobCtx.shutdown(String(event.reason)); + }) + .catch((error) => log().error({ error }, 'error during end call shutdown')); + + ctx.speechHandle.addDoneCallback(() => { + if (!(llm instanceof RealtimeModel) || !llm.capabilities.autoToolReplyGeneration) { + session.shutdown(); + return; + } + + void delayedSessionShutdown(session, signal).catch((error) => + log().error({ error }, 'error during delayed session shutdown'), + ); + }); + + if (onToolCalled) { + await onToolCalled({ ctx, arguments: {} }); + } + + const completedEvent = { + ctx, + output: + endInstructions === null + ? undefined + : ({ type: 'output', value: endInstructions } as const), + }; + if (onToolCompleted) { + await onToolCompleted(completedEvent); + } + + return endInstructions ?? undefined; + }, + }), + ], + }); +} diff --git a/agents/src/beta/tools/index.ts b/agents/src/beta/tools/index.ts new file mode 100644 index 000000000..8ef18b993 --- /dev/null +++ b/agents/src/beta/tools/index.ts @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +export { + END_CALL_DESCRIPTION, + createEndCallTool, + type EndCallToolCalledEvent, + type EndCallToolCompletedEvent, + type EndCallToolOptions, +} from './end_call.js'; diff --git a/agents/src/beta/workflows/dob.ts b/agents/src/beta/workflows/dob.ts new file mode 100644 index 000000000..f87402711 --- /dev/null +++ b/agents/src/beta/workflows/dob.ts @@ -0,0 +1,387 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { z } from 'zod'; +import type { LLMModels, STTModelString, TTSModelString } from '../../inference/index.js'; +import type { ChatContext, LLM, RealtimeModel, ToolContextEntry } from '../../llm/index.js'; +import { Instructions, ToolError, ToolFlag, tool } from '../../llm/index.js'; +import type { STT } from '../../stt/index.js'; +import type { TTS } from '../../tts/index.js'; +import type { VAD } from '../../vad.js'; +import { AgentTask } from '../../voice/agent.js'; +import type { TurnDetectionMode } from '../../voice/agent_session.js'; + +const BASE_INSTRUCTIONS = ` +You are only a single step in a broader system, responsible solely for capturing a date of birth. +{modality_specific} +{time_instructions}Call \`update_dob\` at the first opportunity whenever you form a new hypothesis about the date of birth. (before asking any questions or providing any answers.) +Don't invent dates, stick strictly to what the user said. +{confirmation_instructions} +When reading back dates, use a natural spoken format like 'January fifteenth, nineteen ninety'. +If the date is unclear or invalid, or it takes too much back-and-forth, prompt for it in parts: first the month, then the day, then the year. +Ignore unrelated input and avoid going off-topic. Do not generate markdown, greetings, or unnecessary commentary. +Avoid verbosity by not sharing example dates or formats unless prompted to do so. Do not deviate from the goal of collecting the user's birthday. 
+Always explicitly invoke a tool when applicable. Do not simulate tool usage, no real action is taken unless the tool is explicitly called.{extra_instructions} +`; + +const AUDIO_SPECIFIC = ` +Handle input as noisy voice transcription. Expect that users will say dates aloud with formats like: +- 'January 15th 1990' +- 'the fifteenth of January nineteen ninety' +- '01 15 1990' or 'one fifteen ninety' +- 'Jan 15 90' +- '15th January 1990' +Normalize common spoken patterns silently: +- Convert spoken numbers and ordinals to their numeric form: 'fifteenth' -> 15, 'ninety' -> 1990. +- Recognize month names in various forms: 'Jan', 'January', etc. +- Handle two-digit years appropriately: '90' likely means 1990, '05' likely means 2005. +- Filter out filler words or hesitations. +Don't mention corrections. Treat inputs as possibly imperfect but fix them silently. +`; + +const TEXT_SPECIFIC = ` +Handle input as typed text. Expect users to type their date of birth directly. +Accept common date formats like 'MM/DD/YYYY', 'January 15, 1990', or '1990-01-15'. +Handle two-digit years appropriately: '90' likely means 1990, '05' likely means 2005. 
+`; + +function renderTemplate( + template: string, + replacements: Record< + 'modality_specific' | 'time_instructions' | 'confirmation_instructions' | 'extra_instructions', + string + >, +): string { + return template.replace( + /\{(modality_specific|time_instructions|confirmation_instructions|extra_instructions)\}/g, + (_match, key: keyof typeof replacements) => replacements[key], + ); +} + +function createDateOnly(year: number, month: number, day: number): Date { + if (year < 1 || year > 9999) { + throw new ToolError(`Invalid date: ${year}-${month}-${day}`); + } + // Set Y/M/D in one call: a Date.UTC(0, ...) seed means year 1900 and rolls Feb 29 to Mar 1. + const date = new Date(0); + date.setUTCFullYear(year, month - 1, day); + + if ( + date.getUTCFullYear() !== year || + date.getUTCMonth() !== month - 1 || + date.getUTCDate() !== day + ) { + throw new ToolError(`Invalid date: ${year}-${month}-${day}`); + } + + return date; +} + +function todayDateOnly(): Date { + const today = new Date(); + return createDateOnly(today.getFullYear(), today.getMonth() + 1, today.getDate()); +} + +function formatDate(date: Date): string { + return new Intl.DateTimeFormat('en-US', { + month: 'long', + day: '2-digit', + year: 'numeric', + timeZone: 'UTC', + }).format(date); +} + +function formatTime(time: TimeOfBirth): string { + const date = new Date(Date.UTC(2000, 0, 1, time.hour, time.minute)); + return new Intl.DateTimeFormat('en-US', { + hour: '2-digit', + minute: '2-digit', + hour12: true, + timeZone: 'UTC', + }).format(date); +} + +export interface TimeOfBirth { + hour: number; + minute: number; +} + +export interface GetDOBResult { + dateOfBirth: Date; + timeOfBirth: TimeOfBirth | null; +} + +export interface GetDOBTaskOptions { + extraInstructions?: string; + includeTime?: boolean; + chatCtx?: ChatContext; + turnDetection?: TurnDetectionMode | null; + tools?: readonly ToolContextEntry[]; + stt?: STT | STTModelString | null; + vad?: VAD | null; + llm?: LLM | RealtimeModel | LLMModels | null; + tts?: TTS | TTSModelString | null; + allowInterruptions?: 
boolean; + requireConfirmation?: boolean; + requireExplicitAsk?: boolean; +} + +export class GetDOBTask extends AgentTask { + private _includeTime: boolean; + private _requireConfirmation?: boolean; + private _requireExplicitAsk: boolean; + private _currentDob: Date | null = null; + private _currentTime: TimeOfBirth | null = null; + + constructor(options: GetDOBTaskOptions = {}) { + const { + extraInstructions = '', + includeTime = false, + chatCtx, + turnDetection, + tools, + stt, + vad, + llm, + tts, + allowInterruptions, + requireConfirmation, + requireExplicitAsk = false, + } = options; + + const timeInstructions = includeTime + ? "Also ask for and capture the time of birth if the user knows it. The time is optional - if the user doesn't know it, proceed without it.\n" + : ''; + const confirmationInstructions = + 'Call `confirm_dob` after the user confirmed the date of birth is correct.'; + const renderInstructions = (modalitySpecific: string, confirmation: string) => + renderTemplate(BASE_INSTRUCTIONS, { + modality_specific: modalitySpecific, + time_instructions: timeInstructions, + confirmation_instructions: confirmation, + extra_instructions: extraInstructions, + }); + + super({ + instructions: new Instructions({ + audio: renderInstructions( + AUDIO_SPECIFIC, + requireConfirmation !== false ? confirmationInstructions : '', + ), + text: renderInstructions( + TEXT_SPECIFIC, + requireConfirmation === true ? confirmationInstructions : '', + ), + }), + chatCtx, + turnDetection: turnDetection ?? undefined, + tools, + stt: stt ?? undefined, + vad: vad ?? undefined, + llm: llm ?? undefined, + tts: tts ?? undefined, + allowInterruptions, + }); + + this._includeTime = includeTime; + this._requireConfirmation = requireConfirmation; + this._requireExplicitAsk = requireExplicitAsk; + + const taskTools = [ + ...(tools ?? 
[]), + this.buildUpdateDOBTool(), + this.buildDeclineDOBCaptureTool(), + ]; + if (includeTime) { + taskTools.push(this.buildUpdateTimeTool()); + } + void this.updateTools(taskTools); + } + + async onEnter(): Promise { + await this.session.generateReply({ + instructions: this._includeTime + ? 'Ask the user to provide their date of birth and, if they know it, their time of birth.' + : 'Ask the user to provide their date of birth.', + }); + } + + private buildUpdateDOBTool() { + const flags = this._requireExplicitAsk ? ToolFlag.IGNORE_ON_ENTER : ToolFlag.NONE; + + return tool({ + name: 'update_dob', + description: + "Update the date of birth provided by the user. Given a spoken month and year (e.g., 'July 2030'), return its numerical representation (7/2030).", + flags, + parameters: z.object({ + year: z.number().int().describe('The birth year (e.g., 1990)'), + month: z.number().int().min(1).max(12).describe('The birth month (1-12)'), + day: z.number().int().min(1).max(31).describe('The birth day (1-31)'), + }), + execute: async ({ year, month, day }: { year: number; month: number; day: number }, opts) => + this.updateDOB(year, month, day, opts.ctx.speechHandle.inputDetails.modality), + }); + } + + private async updateDOB( + year: number, + month: number, + day: number, + modality: 'audio' | 'text', + ): Promise { + // Match the prompt's intent for two-digit years; otherwise year 90 is valid AD 90. + if (year >= 0 && year < 100) { + const currentYear = new Date().getFullYear() % 100; + year += year <= currentYear ? 2000 : 1900; + } + + const dob = createDateOnly(year, month, day); + if (dob > todayDateOnly()) { + throw new ToolError( + `Invalid date of birth: ${formatDate(dob)} is in the future. 
Date of birth cannot be a future date.`, + ); + } + + this._currentDob = dob; + + if (!this.confirmationRequired(modality)) { + if (!this.done) { + this.complete(this.result()); + } + return null; + } + + const confirmTool = this.buildConfirmTool(dob); + const currentTools = this.toolCtx.tools.filter((t) => !('id' in t) || t.id !== 'confirm_dob'); + await this.updateTools([...currentTools, confirmTool]); + + let response = `The date of birth has been updated to ${formatDate(dob)}`; + if (this._currentTime) { + response += ` at ${formatTime(this._currentTime)}`; + } + + return ( + `${response}\nRepeat the date back to the user in a natural spoken format.\n` + + 'Prompt the user for confirmation, do not call `confirm_dob` directly' + ); + } + + private buildUpdateTimeTool() { + return tool({ + name: 'update_time', + description: 'Update the time of birth provided by the user.', + parameters: z.object({ + hour: z.number().int().min(0).max(23).describe('The birth hour (0-23)'), + minute: z.number().int().min(0).max(59).describe('The birth minute (0-59)'), + }), + execute: async ({ hour, minute }: { hour: number; minute: number }, opts) => + this.updateTime(hour, minute, opts.ctx.speechHandle.inputDetails.modality), + }); + } + + private async updateTime( + hour: number, + minute: number, + modality: 'audio' | 'text', + ): Promise { + this._currentTime = { hour, minute }; + + if (!this.confirmationRequired(modality) && this._currentDob !== null) { + if (!this.done) { + this.complete(this.result()); + } + return null; + } + + if (this.confirmationRequired(modality)) { + const confirmTool = this.buildConfirmTool(this._currentDob); + const currentTools = this.toolCtx.tools.filter((t) => !('id' in t) || t.id !== 'confirm_dob'); + await this.updateTools([...currentTools, confirmTool]); + } + + let response = `The time of birth has been updated to ${formatTime(this._currentTime)}`; + if (this._currentDob) { + response = `The date and time of birth has been updated to 
${formatDate(this._currentDob)} at ${formatTime(this._currentTime)}`; + } + + if (this.confirmationRequired(modality)) { + response += + '\nRepeat the time back to the user in a natural spoken format.\n' + + 'Prompt the user for confirmation, do not call `confirm_dob` directly'; + } else { + response += '\nThe date of birth has not been provided yet, ask the user to provide it.'; + } + + return response; + } + + private buildConfirmTool(capturedDob: Date | null) { + const capturedTime = this._currentTime; + + return tool({ + name: 'confirm_dob', + description: 'Call after the user confirms the date of birth is correct.', + execute: async () => { + if ( + capturedDob?.getTime() !== this._currentDob?.getTime() || + capturedTime?.hour !== this._currentTime?.hour || + capturedTime?.minute !== this._currentTime?.minute + ) { + await this.session.generateReply({ + instructions: + 'The date of birth has changed since confirmation was requested, ask the user to confirm the updated date.', + }); + return; + } + + if (this._currentDob === null) { + await this.session.generateReply({ + instructions: 'No date of birth was provided yet, ask the user to provide it.', + }); + return; + } + + if (!this.done) { + this.complete(this.result()); + } + }, + }); + } + + private buildDeclineDOBCaptureTool() { + return tool({ + name: 'decline_dob_capture', + description: 'Handles the case when the user explicitly declines to provide a date of birth.', + flags: ToolFlag.IGNORE_ON_ENTER, + parameters: z.object({ + reason: z + .string() + .describe('A short explanation of why the user declined to provide the date of birth'), + }), + execute: async ({ reason }: { reason: string }) => { + if (!this.done) { + this.complete(new ToolError(`couldn't get the date of birth: ${reason}`)); + } + }, + }); + } + + private confirmationRequired(modality: 'audio' | 'text'): boolean { + if (this._requireConfirmation !== undefined) { + return this._requireConfirmation; + } + return modality === 'audio'; + } 
+ + private result(): GetDOBResult { + if (!this._currentDob) { + throw new Error('date of birth has not been provided'); + } + + return { + dateOfBirth: this._currentDob, + timeOfBirth: this._currentTime, + }; + } +} diff --git a/agents/src/beta/workflows/index.ts b/agents/src/beta/workflows/index.ts index 0b0cfcb52..38280a36c 100644 --- a/agents/src/beta/workflows/index.ts +++ b/agents/src/beta/workflows/index.ts @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: 2026 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 +export { GetDOBTask, type GetDOBResult, type GetDOBTaskOptions, type TimeOfBirth } from './dob.js'; export { TaskGroup, type TaskCompletedEvent, diff --git a/agents/src/beta/workflows/task_group.ts b/agents/src/beta/workflows/task_group.ts index 8c96790dd..add6655fb 100644 --- a/agents/src/beta/workflows/task_group.ts +++ b/agents/src/beta/workflows/task_group.ts @@ -84,10 +84,7 @@ export class TaskGroup extends AgentTask { const outOfScopeTool = this.buildOutOfScopeTool(taskId); if (outOfScopeTool) { - await this._currentTask.updateTools({ - ...this._currentTask.toolCtx, - out_of_scope: outOfScopeTool, - }); + await this._currentTask.updateTools([...this._currentTask.toolCtx.tools, outOfScopeTool]); } try { @@ -190,6 +187,7 @@ export class TaskGroup extends AgentTask { const visitedTasks = this._visitedTasks; return tool({ + name: 'out_of_scope', description, flags: ToolFlag.IGNORE_ON_ENTER, parameters: z.object({ diff --git a/agents/src/beta/workflows/warm_transfer.ts b/agents/src/beta/workflows/warm_transfer.ts index b3d99dda0..3cf0016d3 100644 --- a/agents/src/beta/workflows/warm_transfer.ts +++ b/agents/src/beta/workflows/warm_transfer.ts @@ -12,9 +12,9 @@ import type { Instructions, LLM, RealtimeModel, - ToolContext, + ToolContextEntry, } from '../../llm/index.js'; -import { ToolError, ToolFlag, tool } from '../../llm/index.js'; +import { ToolContext, ToolError, ToolFlag, tool } from '../../llm/index.js'; import { log } from '../../log.js'; 
import type { STT } from '../../stt/index.js'; import type { TTS } from '../../tts/index.js'; @@ -72,7 +72,7 @@ export interface WarmTransferTaskOptions { instructions?: InstructionParts | string; chatCtx?: ChatContext; turnDetection?: TurnDetectionMode | null; - tools?: ToolContext; + tools?: readonly ToolContextEntry[]; stt?: STT | STTModelString | null; vad?: VAD | null; llm?: LLM | RealtimeModel | LLMModels | null; @@ -171,13 +171,13 @@ export class WarmTransferTask extends AgentTask { this._resolveHumanAgentFailed = resolve; }); - this._tools = { - ...this._tools, - connect_to_caller: this.buildConnectToCallerTool(), - decline_transfer: this.buildDeclineTransferTool(), - voicemail_detected: this.buildVoicemailDetectedTool(), - }; - this._chatCtx = this._chatCtx.copy({ toolCtx: this._tools }); + this._toolCtx = new ToolContext([ + ...this._toolCtx.tools, + this.buildConnectToCallerTool(), + this.buildDeclineTransferTool(), + this.buildVoicemailDetectedTool(), + ]); + this._chatCtx = this._chatCtx.copy({ toolCtx: this._toolCtx }); this._taskTurnDetection = turnDetection ?? undefined; this._allowInterruptions = allowInterruptions; @@ -268,6 +268,7 @@ export class WarmTransferTask extends AgentTask { private buildConnectToCallerTool() { return tool({ + name: 'connect_to_caller', description: 'Called when the human agent wants to connect to the caller.', flags: ToolFlag.IGNORE_ON_ENTER, execute: async () => { @@ -288,6 +289,7 @@ export class WarmTransferTask extends AgentTask { private buildDeclineTransferTool() { return tool({ + name: 'decline_transfer', description: 'Handles the case when the human agent explicitly declines to connect to the caller.', parameters: z.object({ @@ -304,6 +306,7 @@ export class WarmTransferTask extends AgentTask { private buildVoicemailDetectedTool() { return tool({ + name: 'voicemail_detected', description: 'Called when the call reaches voicemail. 
Use this tool AFTER you hear the voicemail greeting', flags: ToolFlag.IGNORE_ON_ENTER, @@ -418,7 +421,7 @@ export class WarmTransferTask extends AgentTask { vad: this.vad, llm: this.llm, tts: this.tts, - tools: this.toolCtx, + tools: this.toolCtx.tools, chatCtx: this._chatCtx.copy(), turnDetection: this._taskTurnDetection, allowInterruptions: this._allowInterruptions, diff --git a/agents/src/generator.test.ts b/agents/src/generator.test.ts new file mode 100644 index 000000000..a92ff9e58 --- /dev/null +++ b/agents/src/generator.test.ts @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it } from 'vitest'; +import { defineAgent, isAgent } from './generator.js'; + +describe('generator', () => { + it('marks definitions created with defineAgent as agents', () => { + const agent = defineAgent({ + entry: async () => {}, + }); + + expect(isAgent(agent)).toBe(true); + }); + + it('does not treat unmarked structural objects as agents', () => { + expect(isAgent({ entry: async () => {} })).toBe(false); + }); +}); diff --git a/agents/src/generator.ts b/agents/src/generator.ts index b7e7ad93f..e9bb09075 100644 --- a/agents/src/generator.ts +++ b/agents/src/generator.ts @@ -3,24 +3,22 @@ // SPDX-License-Identifier: Apache-2.0 import type { JobContext, JobProcess } from './job.js'; +export const AGENT_DEFINITION_SYMBOL = Symbol.for('livekit.agents.AgentDefinition'); + /** @see {@link defineAgent} */ -export interface Agent> { +export interface AgentDefinition> { entry: (ctx: JobContext) => Promise; prewarm?: (proc: JobProcess) => unknown; } +export type Agent> = AgentDefinition; + /** Helper to check if an object is an agent before running it. 
* * @internal */ -export function isAgent(obj: unknown): obj is Agent { - return ( - typeof obj === 'object' && - obj !== null && - 'entry' in obj && - typeof (obj as Agent).entry === 'function' && - (('prewarm' in obj && typeof (obj as Agent).prewarm === 'function') || !('prewarm' in obj)) - ); +export function isAgent(obj: unknown): obj is AgentDefinition { + return typeof obj === 'object' && obj !== null && AGENT_DEFINITION_SYMBOL in obj; } /** @@ -34,7 +32,10 @@ export function isAgent(obj: unknown): obj is Agent { * ``` */ export function defineAgent>( - agent: Agent, -): Agent { + agent: AgentDefinition, +): AgentDefinition { + Object.defineProperty(agent, AGENT_DEFINITION_SYMBOL, { + value: true, + }); return agent; } diff --git a/agents/src/index.test.ts b/agents/src/index.test.ts new file mode 100644 index 000000000..40b7ade44 --- /dev/null +++ b/agents/src/index.test.ts @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it } from 'vitest'; +import { + Agent, + AgentSession, + ChatContext, + ModelUsageCollector, + logMetrics, + tool, +} from './index.js'; + +describe('index exports', () => { + it('exports voice, llm, and metrics APIs directly from the package root', () => { + expect(Agent).toBeDefined(); + expect(AgentSession).toBeDefined(); + expect(ChatContext).toBeDefined(); + expect(tool).toBeDefined(); + expect(ModelUsageCollector).toBeDefined(); + expect(logMetrics).toBeDefined(); + }); +}); diff --git a/agents/src/index.ts b/agents/src/index.ts index 07e4b45da..9f1519bf3 100644 --- a/agents/src/index.ts +++ b/agents/src/index.ts @@ -14,14 +14,16 @@ export * from './audio.js'; export * as beta from './beta/index.js'; export * as cli from './cli.js'; export * from './connection_pool.js'; -export * from './generator.js'; +export { defineAgent, isAgent, type AgentDefinition } from './generator.js'; export * as inference from './inference/index.js'; export * from 
'./inference_runner.js'; export * as ipc from './ipc/index.js'; export * from './job.js'; export * from './language.js'; +export * from './llm/index.js'; export * as llm from './llm/index.js'; export * from './log.js'; +export * from './metrics/index.js'; export * as metrics from './metrics/index.js'; export * from './plugin.js'; export * as stream from './stream/index.js'; @@ -34,6 +36,6 @@ export * from './types.js'; export * from './utils.js'; export * from './vad.js'; export * from './version.js'; +export * from './voice/index.js'; export * as voice from './voice/index.js'; -export { createTimedString, isTimedString, type TimedString } from './voice/io.js'; export * from './worker.js'; diff --git a/agents/src/inference/llm.ts b/agents/src/inference/llm.ts index 7196ba998..856630208 100644 --- a/agents/src/inference/llm.ts +++ b/agents/src/inference/llm.ts @@ -277,7 +277,7 @@ export class LLM extends llm.LLM { chat({ chatCtx, - toolCtx, + toolCtx: toolCtxInput, connOptions = DEFAULT_API_CONNECT_OPTIONS, parallelToolCalls, toolChoice, @@ -285,13 +285,14 @@ export class LLM extends llm.LLM { extraKwargs, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; inferenceClass?: InferenceClass; extraKwargs?: Record; }): LLMStream { + const toolCtx = llm.toToolContext(toolCtxInput); let modelOptions: Record = { ...(extraKwargs || {}) }; parallelToolCalls = @@ -299,7 +300,11 @@ export class LLM extends llm.LLM { ? 
parallelToolCalls : this.opts.modelOptions.parallel_tool_calls; - if (toolCtx && Object.keys(toolCtx).length > 0 && parallelToolCalls !== undefined) { + if ( + toolCtx && + Object.keys(toolCtx.functionTools).length > 0 && + parallelToolCalls !== undefined + ) { modelOptions.parallel_tool_calls = parallelToolCalls; } @@ -402,6 +407,8 @@ export class LLMStream extends llm.LLMStream { this.providerFmt, )) as OpenAI.ChatCompletionMessageParam[]; + // Provider-defined tools are not supported by the inference adapter; `sortedToolEntries` + // yields only function tools (sorted by name), so they are skipped here. See AJS-112. const tools = this.toolCtx ? llm.sortedToolEntries(this.toolCtx).map(([name, func]) => { const oaiParams = { diff --git a/agents/src/inference/tts.ts b/agents/src/inference/tts.ts index 47a93bb04..78c0d2a26 100644 --- a/agents/src/inference/tts.ts +++ b/agents/src/inference/tts.ts @@ -127,6 +127,8 @@ export interface InworldOptions { speaking_rate?: number; /** Range 0-2. */ temperature?: number; + /** Controls output variation on inworld-tts-2 only. */ + delivery_mode?: 'DELIVERY_MODE_UNSPECIFIED' | 'STABLE' | 'BALANCED' | 'CREATIVE'; timestamp_type?: 'TIMESTAMP_TYPE_UNSPECIFIED' | 'WORD' | 'CHARACTER'; apply_text_normalization?: 'APPLY_TEXT_NORMALIZATION_UNSPECIFIED' | 'ON' | 'OFF'; /** @deprecated Backward-compatible alias. Use `apply_text_normalization`. */ diff --git a/agents/src/llm/async_toolset.test.ts b/agents/src/llm/async_toolset.test.ts new file mode 100644 index 000000000..dfb5b6870 --- /dev/null +++ b/agents/src/llm/async_toolset.test.ts @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it } from 'vitest'; +import { AsyncToolset } from './async_toolset.js'; +import { tool } from './tool_context.js'; + +describe('AsyncToolset', () => { + const lookup = tool({ + name: 'lookup', + description: 'lookup', + execute: async () => 'ok', + }); + + it('is a scope container, not a separate async tool type', () => { + const toolset = AsyncToolset.create({ id: 'booking', tools: [lookup] }); + + expect(toolset.id).toBe('booking'); + expect(toolset.tools).toEqual([lookup]); + expect(toolset._executor).toBeDefined(); + }); +}); diff --git a/agents/src/llm/async_toolset.ts b/agents/src/llm/async_toolset.ts new file mode 100644 index 000000000..e1b0009da --- /dev/null +++ b/agents/src/llm/async_toolset.ts @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { AgentActivity } from '../voice/agent_activity.js'; +import type { AgentSession } from '../voice/agent_session.js'; +import { ToolExecutor, type ToolHandlingOptions } from '../voice/tool_executor.js'; +import { Toolset, type ToolsetCreateOptions } from './tool_context.js'; + +export interface AsyncToolsetCreateOptions extends ToolsetCreateOptions { + toolHandling?: ToolHandlingOptions; +} + +export class AsyncToolset extends Toolset { + readonly _executor = new ToolExecutor({ owningActivity: null }); + private readonly asyncToolOptionsOverride?: ToolHandlingOptions['asyncOptions']; + + private constructor({ id, tools, toolHandling }: AsyncToolsetCreateOptions) { + super({ id, tools }); + this.asyncToolOptionsOverride = toolHandling?.asyncOptions; + } + + static override create(options: AsyncToolsetCreateOptions): AsyncToolset { + return new AsyncToolset(options); + } + + _attachActivity({ + activity, + session, + }: { + activity: AgentActivity | null; + session: AgentSession; + }): void { + this._executor.setOwningActivity( + activity as unknown as Parameters[0], + ); + if 
(this.asyncToolOptionsOverride) { + this._executor.setToolOptions(this.asyncToolOptionsOverride); + return; + } + const activityOptions = activity?.agent?._asyncToolOptions; + if (activityOptions) { + this._executor.setToolOptions(activityOptions); + return; + } + const sessionOptions = session._asyncToolOptions; + if (sessionOptions) { + this._executor.setToolOptions(sessionOptions); + } + } + + override async aclose(): Promise { + await super.aclose(); + await this._executor.drain(); + await this._executor.aclose(); + } +} diff --git a/agents/src/llm/chat_context.test.ts b/agents/src/llm/chat_context.test.ts index 849469e55..101a9fe46 100644 --- a/agents/src/llm/chat_context.test.ts +++ b/agents/src/llm/chat_context.test.ts @@ -19,6 +19,7 @@ import { isInstructions, renderInstructions, } from './chat_context.js'; +import { ProviderTool, ToolContext, tool } from './tool_context.js'; initializeLogger({ pretty: false, level: 'error' }); @@ -1479,3 +1480,32 @@ extra`; expect((baseCtx.items[0]! as ChatMessage).content[0]).toBe(instr); }); }); + +describe('ChatContext.copy with toolCtx filter', () => { + it('drops function calls / outputs whose tool is not in the supplied ToolContext', () => { + const known = tool({ name: 'known', description: 'k', execute: async () => 'ok' }); + const ctx = new ChatContext([ + ChatMessage.create({ role: 'user', content: ['hello'] }), + FunctionCall.create({ callId: 'c1', name: 'known', args: '{}' }), + FunctionCallOutput.create({ callId: 'c1', name: 'known', output: 'done', isError: false }), + FunctionCall.create({ callId: 'c2', name: 'removed', args: '{}' }), + FunctionCallOutput.create({ callId: 'c2', name: 'removed', output: 'x', isError: false }), + ]); + + const filtered = ctx.copy({ toolCtx: new ToolContext([known]) }); + const types = filtered.items.map((i) => `${i.type}:${'name' in i ? 
i.name : ''}`); + expect(types).toEqual(['message:', 'function_call:known', 'function_call_output:known']); + }); + + it('keeps provider-tool calls when the ToolContext holds a matching provider tool id', () => { + class CodeRunner extends ProviderTool {} + const provider = new CodeRunner({ id: 'code_runner' }); + const ctx = new ChatContext([ + FunctionCall.create({ callId: 'p1', name: 'code_runner', args: '{}' }), + FunctionCall.create({ callId: 'p2', name: 'other', args: '{}' }), + ]); + + const filtered = ctx.copy({ toolCtx: new ToolContext([provider]) }); + expect(filtered.items.map((i) => ('name' in i ? i.name : ''))).toEqual(['code_runner']); + }); +}); diff --git a/agents/src/llm/chat_context.ts b/agents/src/llm/chat_context.ts index 743e7efb8..34cf39200 100644 --- a/agents/src/llm/chat_context.ts +++ b/agents/src/llm/chat_context.ts @@ -835,7 +835,7 @@ export class ChatContext { continue; } - if (toolCtx !== undefined && isToolCallOrOutput(item) && toolCtx[item.name] === undefined) { + if (toolCtx !== undefined && isToolCallOrOutput(item) && !toolCtx.hasTool(item.name)) { continue; } diff --git a/agents/src/llm/fallback_adapter.test.ts b/agents/src/llm/fallback_adapter.test.ts index a9747c885..7ab549ff4 100644 --- a/agents/src/llm/fallback_adapter.test.ts +++ b/agents/src/llm/fallback_adapter.test.ts @@ -9,7 +9,7 @@ import { delay } from '../utils.js'; import type { ChatContext } from './chat_context.js'; import { FallbackAdapter } from './fallback_adapter.js'; import { type ChatChunk, LLM, LLMStream } from './llm.js'; -import type { ToolChoice, ToolContext } from './tool_context.js'; +import type { ToolChoice, ToolContextLike } from './tool_context.js'; class MockLLMStream extends LLMStream { public myLLM: LLM; @@ -18,7 +18,7 @@ class MockLLMStream extends LLMStream { llm: LLM, opts: { chatCtx: ChatContext; - toolCtx?: ToolContext; + toolCtx?: ToolContextLike; connOptions: APIConnectOptions; }, private shouldFail: boolean = false, @@ -64,7 +64,7 @@ class 
MockLLM extends LLM { chat(opts: { chatCtx: ChatContext; - toolCtx?: ToolContext; + toolCtx?: ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: ToolChoice; diff --git a/agents/src/llm/fallback_adapter.ts b/agents/src/llm/fallback_adapter.ts index 128c2392c..fdacac36b 100644 --- a/agents/src/llm/fallback_adapter.ts +++ b/agents/src/llm/fallback_adapter.ts @@ -8,7 +8,7 @@ import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../types.js import type { ChatContext } from './chat_context.js'; import type { ChatChunk } from './llm.js'; import { LLM, LLMStream } from './llm.js'; -import type { ToolChoice, ToolContext } from './tool_context.js'; +import type { ToolChoice, ToolContextLike } from './tool_context.js'; /** * Default connection options for FallbackAdapter. @@ -113,7 +113,7 @@ export class FallbackAdapter extends LLM { chat(opts: { chatCtx: ChatContext; - toolCtx?: ToolContext; + toolCtx?: ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: ToolChoice; @@ -159,7 +159,7 @@ class FallbackLLMStream extends LLMStream { adapter: FallbackAdapter, opts: { chatCtx: ChatContext; - toolCtx?: ToolContext; + toolCtx?: ToolContextLike; connOptions: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: ToolChoice; diff --git a/agents/src/llm/index.ts b/agents/src/llm/index.ts index 8fab0e4a7..7fe9256a6 100644 --- a/agents/src/llm/index.ts +++ b/agents/src/llm/index.ts @@ -2,23 +2,46 @@ // // SPDX-License-Identifier: Apache-2.0 export { + CONFIRM_DUPLICATE_PARAM, handoff, isFunctionTool, - tool, + isProviderTool, + isTool, + isToolset, + ProviderTool, sortedToolEntries, sortedToolNames, + tool, + ToolContext, ToolError, ToolFlag, + Toolset, + normalizeToolContextInit, + toToolContext, type AgentHandoff, + type DuplicateMode, type FunctionTool, - type ProviderDefinedTool, type Tool, + type ToolCalledEvent, type ToolChoice, - type ToolContext, + type ToolCompletedEvent, + type 
ToolContextEntry, + type ToolContextInit, + type ToolContextLike, type ToolOptions, + type ToolsetContext, + type ToolsetCreateOptions, type ToolType, } from './tool_context.js'; +export { AsyncToolset, type AsyncToolsetCreateOptions } from './async_toolset.js'; +export type { + AsyncToolOptions, + DuplicatePromptArgs, + ReplyPromptArgs, + ToolHandlingOptions, +} from '../voice/tool_executor.js'; + export { AgentHandoffItem, AgentConfigUpdate, @@ -45,6 +68,7 @@ export { LLMStream, type ChatChunk, type ChoiceDelta, + type CollectedResponse, type CompletionUsage, type LLMCallbacks, } from './llm.js'; diff --git a/agents/src/llm/llm.test.ts b/agents/src/llm/llm.test.ts new file mode 100644 index 000000000..b921e62c2 --- /dev/null +++ b/agents/src/llm/llm.test.ts @@ -0,0 +1,149 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { beforeAll, describe, expect, it } from 'vitest'; +import { initializeLogger } from '../log.js'; +import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../types.js'; +import { delay } from '../utils.js'; +import { ChatContext, FunctionCall } from './chat_context.js'; +import { type ChatChunk, LLM, LLMStream } from './llm.js'; +import type { ToolChoice, ToolCtxInput } from './tool_context.js'; + +class MockLLMStream extends LLMStream { + constructor( + llm: LLM, + opts: { + chatCtx: ChatContext; + toolCtx?: ToolCtxInput; + connOptions: APIConnectOptions; + }, + private chunks: ChatChunk[], + ) { + super(llm, opts); + } + + protected async run(): Promise { + for (const chunk of this.chunks) { + this.queue.put(chunk); + await delay(1); + } + } +} + +class MockLLM extends LLM { + constructor(private chunks: ChatChunk[]) { + super(); + } + + label(): string { + return 'mock-llm'; + } + + chat(opts: { + chatCtx: ChatContext; + toolCtx?: ToolCtxInput; + connOptions?: APIConnectOptions; + parallelToolCalls?: boolean; + toolChoice?: ToolChoice; + extraKwargs?: Record; + }): LLMStream { + 
return new MockLLMStream( + this, + { + chatCtx: opts.chatCtx, + toolCtx: opts.toolCtx, + connOptions: opts.connOptions ?? DEFAULT_API_CONNECT_OPTIONS, + }, + this.chunks, + ); + } +} + +describe('LLMStream.collect', () => { + beforeAll(() => { + initializeLogger({ pretty: false }); + process.on('unhandledRejection', () => {}); + }); + + it('joins content parts and trims surrounding whitespace', async () => { + const llm = new MockLLM([ + { id: '1', delta: { role: 'assistant', content: ' Hello' } }, + { id: '1', delta: { role: 'assistant', content: ', ' } }, + { id: '1', delta: { role: 'assistant', content: 'world! ' } }, + ]); + + const response = await llm.chat({ chatCtx: new ChatContext() }).collect(); + + expect(response.text).toBe('Hello, world!'); + expect(response.toolCalls).toHaveLength(0); + expect(response.usage).toBeUndefined(); + expect(response.extra).toEqual({}); + }); + + it('accumulates tool calls across chunks', async () => { + const callA = new FunctionCall({ + callId: 'call_a', + name: 'get_weather', + args: '{"city":"SF"}', + }); + const callB = new FunctionCall({ + callId: 'call_b', + name: 'play_song', + args: '{"name":"x"}', + }); + const llm = new MockLLM([ + { id: '1', delta: { role: 'assistant', toolCalls: [callA] } }, + { id: '1', delta: { role: 'assistant', toolCalls: [callB] } }, + ]); + + const response = await llm.chat({ chatCtx: new ChatContext() }).collect(); + + expect(response.text).toBe(''); + expect(response.toolCalls).toHaveLength(2); + expect(response.toolCalls[0]!.callId).toBe('call_a'); + expect(response.toolCalls[1]!.callId).toBe('call_b'); + }); + + it('captures the latest usage and merges extra data', async () => { + const llm = new MockLLM([ + { id: '1', delta: { role: 'assistant', content: 'hi', extra: { a: 1 } } }, + { + id: '1', + delta: { role: 'assistant', content: ' there', extra: { b: 2 } }, + usage: { + completionTokens: 2, + promptTokens: 5, + promptCachedTokens: 0, + totalTokens: 7, + }, + }, + { + id: '1', + 
usage: { + completionTokens: 3, + promptTokens: 5, + promptCachedTokens: 0, + totalTokens: 8, + }, + }, + ]); + + const response = await llm.chat({ chatCtx: new ChatContext() }).collect(); + + expect(response.text).toBe('hi there'); + expect(response.usage?.completionTokens).toBe(3); + expect(response.usage?.totalTokens).toBe(8); + expect(response.extra).toEqual({ a: 1, b: 2 }); + }); + + it('returns empty response for an empty stream', async () => { + const llm = new MockLLM([]); + + const response = await llm.chat({ chatCtx: new ChatContext() }).collect(); + + expect(response.text).toBe(''); + expect(response.toolCalls).toHaveLength(0); + expect(response.usage).toBeUndefined(); + expect(response.extra).toEqual({}); + }); +}); diff --git a/agents/src/llm/llm.ts b/agents/src/llm/llm.ts index 553541939..ae9bde31c 100644 --- a/agents/src/llm/llm.ts +++ b/agents/src/llm/llm.ts @@ -11,7 +11,12 @@ import { recordException, traceTypes, tracer } from '../telemetry/index.js'; import { type APIConnectOptions, intervalForRetry } from '../types.js'; import { AsyncIterableQueue, delay, startSoon, toError } from '../utils.js'; import { type ChatContext, type ChatRole, type FunctionCall } from './chat_context.js'; -import type { ToolChoice, ToolContext } from './tool_context.js'; +import { + type ToolChoice, + type ToolContext, + type ToolContextLike, + toToolContext, +} from './tool_context.js'; export interface ChoiceDelta { role: ChatRole; @@ -35,6 +40,17 @@ export interface ChatChunk { usage?: CompletionUsage; } +export interface CollectedResponse { + text: string; + toolCalls: FunctionCall[]; + usage?: CompletionUsage; + /** + * Provider-specific extra data accumulated across chunks + * (e.g., xAI encrypted reasoning, Google thought signatures). 
+ */ + extra: Record; +} + export interface LLMError { type: 'llm_error'; timestamp: number; @@ -91,7 +107,12 @@ export abstract class LLM extends (EventEmitter as new () => TypedEmitter { connOptions, }: { chatCtx: ChatContext; - toolCtx?: ToolContext; + toolCtx?: ToolContextLike; connOptions: APIConnectOptions; }, ) { this.#llm = llm; this.#chatCtx = chatCtx; - this.#toolCtx = toolCtx; + this.#toolCtx = toToolContext(toolCtx); this._connOptions = connOptions; this.monitorMetrics(); this.abortController.signal.addEventListener('abort', () => { @@ -331,6 +352,53 @@ export abstract class LLMStream implements AsyncIterableIterator { this.abortController.abort(); } + /** + * Collect the entire stream into a single response. + * + * @example + * ```ts + * const response = await myLlm.chat({ chatCtx, toolCtx }).collect(); + * + * for (const tc of response.toolCalls) { + * // execute the tool call... + * } + * ``` + */ + async collect(): Promise { + const textParts: string[] = []; + const toolCalls: FunctionCall[] = []; + let usage: CompletionUsage | undefined; + const extra: Record = {}; + + try { + for await (const chunk of this) { + if (chunk.delta) { + if (chunk.delta.content) { + textParts.push(chunk.delta.content); + } + if (chunk.delta.toolCalls) { + toolCalls.push(...chunk.delta.toolCalls); + } + if (chunk.delta.extra) { + Object.assign(extra, chunk.delta.extra); + } + } + if (chunk.usage !== undefined) { + usage = chunk.usage; + } + } + } finally { + this.close(); + } + + return { + text: textParts.join('').trim(), + toolCalls, + usage, + extra, + }; + } + [Symbol.asyncIterator](): LLMStream { return this; } diff --git a/agents/src/llm/tool_context.test.ts b/agents/src/llm/tool_context.test.ts index 183c7ff2a..817bc9895 100644 --- a/agents/src/llm/tool_context.test.ts +++ b/agents/src/llm/tool_context.test.ts @@ -5,9 +5,92 @@ import { describe, expect, it } from 'vitest'; import { z } from 'zod'; import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; 
-import { type ToolOptions, tool } from './tool_context.js'; +import { voice } from '../index.js'; +import { + CONFIRM_DUPLICATE_PARAM, + ProviderTool, + type Tool, + ToolContext, + ToolFlag, + type ToolOptions, + Toolset, + type ToolsetContext, + tool, +} from './tool_context.js'; import { createToolOptions, oaiParams } from './utils.js'; +describe('tools object compatibility', () => { + it('normalizes object tools into named function tools', () => { + const getWeather = tool({ + description: 'Get the weather', + execute: async () => 'sunny', + }); + + const toolCtx = new ToolContext({ + getWeather, + }); + + const namedTool = toolCtx.getFunctionTool('getWeather'); + expect(namedTool).toBeDefined(); + expect(namedTool?.name).toBe('getWeather'); + expect(namedTool?.id).toBe('getWeather'); + expect(namedTool?.description).toBe('Get the weather'); + expect(toolCtx.tools).toEqual([namedTool]); + }); + + it('accepts object tools in agent and session construction APIs', () => { + const tools = { + lookupOrder: tool({ + description: 'Look up an order', + execute: async () => ({ ok: true }), + }), + }; + + const agent = new voice.Agent({ instructions: 'help', tools }); + const createdAgent = voice.Agent.create({ instructions: 'help', tools }); + const task = new voice.AgentTask({ instructions: 'help', tools }); + const createdTask = voice.AgentTask.create({ instructions: 'help', tools }); + const session = new voice.AgentSession({ tools, vad: null }); + + expect(agent.toolCtx.getFunctionTool('lookupOrder')?.name).toBe('lookupOrder'); + expect(createdAgent.toolCtx.getFunctionTool('lookupOrder')?.name).toBe('lookupOrder'); + expect(task.toolCtx.getFunctionTool('lookupOrder')?.name).toBe('lookupOrder'); + expect(createdTask.toolCtx.getFunctionTool('lookupOrder')?.name).toBe('lookupOrder'); + expect(Reflect.get(session, '_tools')[0]?.name).toBe('lookupOrder'); + }); + + it('keeps rejecting duplicate named function tools in array syntax', () => { + const first = tool({ + name: 
'duplicate', + description: 'First tool', + execute: async () => 'first', + }); + const second = tool({ + name: 'duplicate', + description: 'Second tool', + execute: async () => 'second', + }); + + expect(() => new ToolContext([first, second])).toThrow('duplicate function name'); + }); + + it('rejects named tools and toolsets in object syntax', () => { + const namedTool = tool({ + name: 'alreadyNamed', + description: 'Already named', + execute: async () => 'ok', + }); + const toolset = new Toolset({ id: 'group', tools: [] }); + + expect(() => new ToolContext({ alreadyNamed: namedTool })).toThrow( + 'tools object entry "alreadyNamed" must be anonymous', + ); + expect(() => new ToolContext({ group: toolset })).toThrow( + 'tools object entry "group" must be an anonymous function tool', + ); + }); +}); + describe('Tool Context', () => { describe('oaiParams', () => { it('should handle basic object schema', () => { @@ -77,6 +160,7 @@ describe('Tool Context', () => { describe('tool', () => { it('should create and execute a basic core tool', async () => { const getWeather = tool({ + name: 'getWeather', description: 'Get the weather for a given location', parameters: z.object({ location: z.string(), @@ -93,8 +177,100 @@ describe('Tool Context', () => { expect(result).toBe('The weather in San Francisco is sunny, John'); }); + it('defaults to allow duplicate calls and no flags', () => { + const testFunction = tool({ + name: 'defaultDuplicatePolicy', + description: 'Test default duplicate policy', + execute: async () => 'ok', + }); + + expect(testFunction.onDuplicate).toBe('allow'); + expect(testFunction.flags).toBe(ToolFlag.NONE); + }); + + it('stores cancellable flag and duplicate policy', () => { + const testFunction = tool({ + name: 'cancellableDuplicatePolicy', + description: 'Test duplicate policy', + execute: async () => 'ok', + flags: ToolFlag.CANCELLABLE, + onDuplicate: 'replace', + }); + + expect(testFunction.onDuplicate).toBe('replace'); + expect(testFunction.flags 
& ToolFlag.CANCELLABLE).toBeTruthy(); + }); + + it('injects and strips confirm duplicate parameter for zod schemas', async () => { + let receivedArgs: Record | undefined; + const testFunction = tool({ + name: 'confirmDuplicateZod', + description: 'Test confirm duplicate with zod', + parameters: z.object({ + query: z.string(), + }), + onDuplicate: 'confirm', + execute: async (args) => { + receivedArgs = args; + return 'ok'; + }, + }); + + const params = oaiParams(testFunction.parameters); + expect(params.properties).toHaveProperty(CONFIRM_DUPLICATE_PARAM); + expect(params.required).toContain(CONFIRM_DUPLICATE_PARAM); + + const result = await testFunction.execute( + { query: 'flights', [CONFIRM_DUPLICATE_PARAM]: true }, + createToolOptions('confirm-zod'), + ); + + expect(result).toBe('ok'); + expect(receivedArgs).toEqual({ query: 'flights' }); + }); + + it('injects and strips confirm duplicate parameter for raw JSON schemas', async () => { + let receivedArgs: Record | undefined; + const testFunction = tool({ + name: 'confirmDuplicateRaw', + description: 'Test confirm duplicate with raw schema', + parameters: { + type: 'object', + properties: { + query: { type: 'string' }, + }, + required: ['query'], + }, + onDuplicate: 'confirm', + execute: async (args) => { + receivedArgs = args; + return 'ok'; + }, + }); + + expect(testFunction.parameters).toMatchObject({ + properties: { + [CONFIRM_DUPLICATE_PARAM]: { + type: ['boolean', 'null'], + }, + }, + }); + expect((testFunction.parameters as { required: string[] }).required).toContain( + CONFIRM_DUPLICATE_PARAM, + ); + + const result = await testFunction.execute( + { query: 'flights', [CONFIRM_DUPLICATE_PARAM]: true }, + createToolOptions('confirm-raw'), + ); + + expect(result).toBe('ok'); + expect(receivedArgs).toEqual({ query: 'flights' }); + }); + it('should properly type a callable function', async () => { const testFunction = tool({ + name: 'testFunction', description: 'Test function', parameters: z.object({ name: 
z.string().describe('The user name'), @@ -114,6 +290,7 @@ describe('Tool Context', () => { it('should handle async execution', async () => { const testFunction = tool({ + name: 'asyncTestFunction', description: 'Async test function', parameters: z.object({ delay: z.number().describe('Delay in milliseconds'), @@ -157,6 +334,7 @@ describe('Tool Context', () => { describe('optional parameters', () => { it('should create a tool without parameters', async () => { const simpleAction = tool({ + name: 'simpleAction', description: 'Perform a simple action', execute: async () => { return 'Action performed'; @@ -175,6 +353,7 @@ describe('Tool Context', () => { it('should support .optional() fields in tool parameters', async () => { const weatherTool = tool({ + name: 'weatherTool', description: 'Get weather information', parameters: z.object({ location: z.string().describe('The city or location').optional(), @@ -205,6 +384,7 @@ describe('Tool Context', () => { it('should handle tools with context but no parameters', async () => { const greetUser = tool({ + name: 'greetUser', description: 'Greet the current user', execute: async (_, { ctx }: ToolOptions<{ username: string }>) => { return `Hello, ${ctx.userData.username}!`; @@ -217,6 +397,7 @@ describe('Tool Context', () => { it('should create a tool that accesses tool call id without parameters', async () => { const getCallId = tool({ + name: 'getCallId', description: 'Get the current tool call ID', execute: async (_, { toolCallId }) => { return `Tool call ID: ${toolCallId}`; @@ -231,6 +412,7 @@ describe('Tool Context', () => { describe('Zod v3 and v4 compatibility', () => { it('should work with Zod v3 schemas', async () => { const v3Tool = tool({ + name: 'v3Tool', description: 'A tool using Zod v3 schema', parameters: z3.object({ name: z3.string(), @@ -250,6 +432,7 @@ describe('Tool Context', () => { it('should work with Zod v4 schemas', async () => { const v4Tool = tool({ + name: 'v4Tool', description: 'A tool using Zod v4 
schema', parameters: z4.object({ name: z4.string(), @@ -269,6 +452,7 @@ describe('Tool Context', () => { it('should handle v4 schemas with optional fields', async () => { const v4Tool = tool({ + name: 'v4OptionalTool', description: 'Tool with optional field using v4', parameters: z4.object({ required: z4.string(), @@ -291,6 +475,7 @@ describe('Tool Context', () => { it('should handle v4 enum schemas', async () => { const v4Tool = tool({ + name: 'v4EnumTool', description: 'Tool with enum using v4', parameters: z4.object({ color: z4.enum(['red', 'blue', 'green']), @@ -306,6 +491,7 @@ describe('Tool Context', () => { it('should handle v4 array schemas', async () => { const v4Tool = tool({ + name: 'v4ArrayTool', description: 'Tool with array using v4', parameters: z4.object({ tags: z4.array(z4.string()), @@ -324,6 +510,7 @@ describe('Tool Context', () => { it('should handle v4 nested object schemas', async () => { const v4Tool = tool({ + name: 'v4NestedTool', description: 'Tool with nested object using v4', parameters: z4.object({ user: z4.object({ @@ -405,3 +592,304 @@ describe('Tool Context', () => { }); }); }); + +describe('tool() name requirement', () => { + it('creates an anonymous function tool when name is missing', () => { + const t = tool({ + description: 'no name', + execute: async () => 'x', + }); + + expect(t.description).toBe('no name'); + expect('name' in t).toBe(false); + expect('id' in t).toBe(false); + }); + + it('throws when name is empty', () => { + expect(() => + tool({ + name: '', + description: 'empty name', + execute: async () => 'x', + }), + ).toThrow('requires a non-empty name'); + }); + + it('stores the name on the returned function tool', () => { + const t = tool({ + name: 'doStuff', + description: 'd', + execute: async () => 'x', + }); + expect(t.name).toBe('doStuff'); + }); + + it('exposes id mirroring the function tool name', () => { + const t = tool({ + name: 'doStuff', + description: 'd', + execute: async () => 'x', + }); + 
expect(t.id).toBe('doStuff'); + expect(t.id).toBe(t.name); + }); +}); + +class TestProviderTool extends ProviderTool {} + +describe('ToolContext', () => { + const makeFn = (name: string) => + tool({ + name, + description: `${name} tool`, + execute: async () => name, + }); + + it('empty() returns an empty context', () => { + const ctx = ToolContext.empty(); + expect(ctx.functionTools).toEqual({}); + expect(ctx.providerTools).toEqual([]); + expect(ctx.toolsets).toEqual([]); + expect(ctx.flatten()).toEqual([]); + }); + + it('indexes function tools by name and supports lookup', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const ctx = new ToolContext([a, b]); + + expect(ctx.functionTools).toEqual({ a, b }); + expect(ctx.getFunctionTool('a')).toBe(a); + expect(ctx.getFunctionTool('b')).toBe(b); + expect(ctx.getFunctionTool('missing')).toBeUndefined(); + }); + + it('throws on duplicate function names with different instances', () => { + // Matches Python's `if existing is not tool: raise ValueError(...)` — silently overriding + // a registered tool would mask a real bug at the caller (two distinct functions colliding + // on a single advertised name). + const a1 = makeFn('a'); + const a2 = makeFn('a'); + expect(() => new ToolContext([a1, a2])).toThrow('duplicate function name: a'); + }); + + it('silently skips the same function tool instance listed multiple times', () => { + // Matches Python's `return # same instance, skip` branch. Useful when a tool gets + // included both directly and via a future Toolset that re-exports it. 
+ const a = makeFn('a'); + const ctx = new ToolContext([a, a]); + expect(ctx.getFunctionTool('a')).toBe(a); + expect(Object.keys(ctx.functionTools)).toEqual(['a']); + }); + + it('separates provider tools from function tools', () => { + const fnA = makeFn('a'); + const provider = new TestProviderTool({ id: 'code' }); + const ctx = new ToolContext([fnA, provider]); + + expect(ctx.functionTools).toEqual({ a: fnA }); + expect(ctx.providerTools).toEqual([provider]); + expect(ctx.flatten()).toEqual([fnA, provider]); + }); + + it('updateTools replaces the entire context', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const ctx = new ToolContext([a]); + ctx.updateTools([b]); + expect(ctx.getFunctionTool('a')).toBeUndefined(); + expect(ctx.getFunctionTool('b')).toBe(b); + }); + + it('copy() yields an independent context with the same tools', () => { + const a = makeFn('a'); + const ctx = new ToolContext([a]); + const dup = ctx.copy(); + + expect(dup.getFunctionTool('a')).toBe(a); + dup.updateTools([]); + expect(ctx.getFunctionTool('a')).toBe(a); + expect(dup.getFunctionTool('a')).toBeUndefined(); + }); + + it('equals() compares function tool maps and provider lists by identity', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const c = makeFn('c'); + + expect(new ToolContext([a, b]).equals(new ToolContext([a, b]))).toBe(true); + expect(new ToolContext([a, b]).equals(new ToolContext([a]))).toBe(false); + expect(new ToolContext([a, b]).equals(new ToolContext([a, c]))).toBe(false); + }); + + it('equals() is reflexive', () => { + const a = makeFn('a'); + const provider = new TestProviderTool({ id: 'code' }); + const ctx = new ToolContext([a, provider]); + expect(ctx.equals(ctx)).toBe(true); + }); + + it('equals() treats provider tool order as insignificant', () => { + // Matches Python's `set(id(t) for t in self._provider_tools)` comparison: two contexts + // that hold the same provider-tool identities in different order are still equal so + // 
realtime-session / preemptive-generation reuse fast paths are not invalidated. + const a = makeFn('a'); + const p1 = new TestProviderTool({ id: 'code' }); + const p2 = new TestProviderTool({ id: 'browser' }); + expect(new ToolContext([a, p1, p2]).equals(new ToolContext([a, p2, p1]))).toBe(true); + }); + + it('equals() supports contexts with only provider tools', () => { + const p1 = new TestProviderTool({ id: 'code' }); + const p2 = new TestProviderTool({ id: 'browser' }); + expect(new ToolContext([p1, p2]).equals(new ToolContext([p1, p2]))).toBe(true); + const p3 = new TestProviderTool({ id: 'code' }); // distinct identity, same id + expect(new ToolContext([p1]).equals(new ToolContext([p3]))).toBe(false); + }); + + it('hasTool() matches function tools by name and provider tools by id', () => { + const a = makeFn('a'); + const provider = new TestProviderTool({ id: 'code_runner' }); + const ctx = new ToolContext([a, provider]); + + expect(ctx.hasTool('a')).toBe(true); + expect(ctx.hasTool('code_runner')).toBe(true); + expect(ctx.hasTool('missing')).toBe(false); + }); + + it('flatten() returns function tools in insertion order followed by provider tools', () => { + // Matches Python's `flatten()`: list(self._fnc_tools_map.values()) + self._provider_tools. 
+ const a = makeFn('a'); + const b = makeFn('b'); + const provider = new TestProviderTool({ id: 'code' }); + const ctx = new ToolContext([b, provider, a]); + + expect(ctx.flatten()).toEqual([b, a, provider]); + }); +}); + +describe('Toolset', () => { + const makeFn = (name: string) => + tool({ + name, + description: `${name} tool`, + execute: async () => name, + }); + + it('exposes its id and the tools it was constructed with', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const ts = new Toolset({ id: 'set1', tools: [a, b] }); + + expect(ts.id).toBe('set1'); + expect(ts.tools).toEqual([a, b]); + }); + + const fakeToolsetContext = ( + updateTools: (tools: readonly Tool[]) => void = () => {}, + ): ToolsetContext => ({ updateTools }); + + it('default setup and aclose are no-ops', async () => { + const ts = new Toolset({ id: 'noop', tools: [] }); + await expect(ts.setup(fakeToolsetContext())).resolves.toBeUndefined(); + await expect(ts.aclose()).resolves.toBeUndefined(); + }); + + it('lets subclasses override lifecycle hooks', async () => { + const events: string[] = []; + class Recording extends Toolset { + override async setup(_ctx: ToolsetContext): Promise { + events.push(`setup:${this.id}`); + } + override async aclose(): Promise { + events.push(`close:${this.id}`); + } + } + + const ts = new Recording({ id: 'rec', tools: [] }); + await ts.setup(fakeToolsetContext()); + await ts.aclose(); + expect(events).toEqual(['setup:rec', 'close:rec']); + }); + + it('Toolset.create() resolves a static tools list eagerly and composes aclose', async () => { + const a = makeFn('a'); + const events: string[] = []; + const ts = Toolset.create({ + id: 'composed', + tools: [a], + aclose: async () => { + events.push('close'); + }, + }); + + expect(ts).toBeInstanceOf(Toolset); + expect(ts.id).toBe('composed'); + expect(ts.tools).toEqual([a]); // static tools available before activation + + await ts.setup(fakeToolsetContext()); + await ts.aclose(); + 
expect(events).toEqual(['close']); + }); + + it('Toolset.create() defaults aclose to a no-op when omitted', async () => { + const ts = Toolset.create({ id: 'bare', tools: [] }); + await expect(ts.setup(fakeToolsetContext())).resolves.toBeUndefined(); + await expect(ts.aclose()).resolves.toBeUndefined(); + }); + + it('lets setup push tools after activation via ctx.updateTools', async () => { + const a = makeFn('a'); + const b = makeFn('b'); + let push!: (tools: readonly Tool[]) => void; + const ts = Toolset.create({ + id: 'mcp', + setup: async ({ updateTools }) => { + push = updateTools; + }, + tools: [], + }); + + // Mimic the runtime: ctx.updateTools writes the toolset's current tools. + await ts.setup(fakeToolsetContext((tools) => ts._setTools(tools))); + expect(ts.tools).toEqual([]); + + // A dynamic source (e.g. an MCP server) pushes its tools after connecting. + push([a, b]); + expect(ts.tools).toEqual([a, b]); + }); + + it('is flattened into a ToolContext: function tools merged, toolset tracked', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const ts = new Toolset({ id: 'set', tools: [a, b] }); + const direct = makeFn('direct'); + + const ctx = new ToolContext([direct, ts]); + + expect(Object.keys(ctx.functionTools).sort()).toEqual(['a', 'b', 'direct']); + expect(ctx.toolsets).toEqual([ts]); + }); + + it('throws when a Toolset contributes a duplicate function name', () => { + // Mirrors Python's `add_tool`: a name collision between top-level and toolset-contributed + // tools is an error, not silent overwrite. + const a1 = makeFn('a'); + const a2 = makeFn('a'); + const ts = new Toolset({ id: 'collides', tools: [a2] }); + + expect(() => new ToolContext([a1, ts])).toThrow(/duplicate function name: a/); + }); + + it('equals() compares toolsets as identity sets, not by order', () => { + // Matches Python's `{id(ts) for ts in self._tool_sets}` semantics. 
+ const ts1 = new Toolset({ id: 'one', tools: [] }); + const ts2 = new Toolset({ id: 'two', tools: [] }); + + expect(new ToolContext([ts1, ts2]).equals(new ToolContext([ts2, ts1]))).toBe(true); + + const ts3 = new Toolset({ id: 'three', tools: [] }); + expect(new ToolContext([ts1, ts2]).equals(new ToolContext([ts1, ts3]))).toBe(false); + expect(new ToolContext([ts1]).equals(new ToolContext([ts1, ts2]))).toBe(false); + }); +}); diff --git a/agents/src/llm/tool_context.ts b/agents/src/llm/tool_context.ts index 0e66ebc0d..6b947a021 100644 --- a/agents/src/llm/tool_context.ts +++ b/agents/src/llm/tool_context.ts @@ -3,16 +3,18 @@ // SPDX-License-Identifier: Apache-2.0 import type { JSONSchema7 } from 'json-schema'; import { z } from 'zod'; +import * as z4 from 'zod/v4'; import type { Agent } from '../voice/agent.js'; import type { RunContext, UnknownUserData } from '../voice/run_context.js'; -import { isZodObjectSchema, isZodSchema } from './zod-utils.js'; +import { isZod4Schema, isZodObjectSchema, isZodSchema } from './zod-utils.js'; // heavily inspired by Vercel AI's `tool()`: // https://github.com/vercel/ai/blob/3b0983b/packages/ai/core/tool/tool.ts const TOOL_SYMBOL = Symbol('tool'); const FUNCTION_TOOL_SYMBOL = Symbol('function_tool'); -const PROVIDER_DEFINED_TOOL_SYMBOL = Symbol('provider_defined_tool'); +const PROVIDER_TOOL_SYMBOL = Symbol('provider_tool'); +const TOOLSET_SYMBOL = Symbol('toolset'); const TOOL_ERROR_SYMBOL = Symbol('tool_error'); const HANDOFF_SYMBOL = Symbol('handoff'); @@ -57,7 +59,17 @@ export type InferToolInput = T extends { _output: infer O } ? O : any; // eslint-disable-line @typescript-eslint/no-explicit-any -- Fallback type for JSON Schema objects without type inference -export type ToolType = 'function' | 'provider-defined'; +/** + * Tool argument type for a (possibly absent) parameters schema. 
When `parameters` is omitted the + * generic defaults to `undefined`, yielding an empty-args type; otherwise the args are inferred + * from the schema via {@link InferToolInput}. Wrapped in tuples to keep the check non-distributive. + * @internal + */ +export type ToolArgs = [Schema] extends [undefined] + ? Record + : InferToolInput; + +export type ToolType = 'function' | 'provider'; export type ToolChoice = | 'auto' @@ -70,6 +82,14 @@ export type ToolChoice = }; }; +export type DuplicateMode = 'allow' | 'reject' | 'replace' | 'confirm'; + +export const CONFIRM_DUPLICATE_PARAM = 'lk_agents_confirm_duplicate'; + +const CONFIRM_DUPLICATE_DESCRIPTION = + 'Set this to true to confirm you want to run a duplicate. ' + + 'Only do this when user confirms the duplication is needed.'; + export class ToolError extends Error { constructor(message: string) { super(message); @@ -83,6 +103,7 @@ export class ToolError extends Error { export const ToolFlag = { NONE: 0, IGNORE_ON_ENTER: 1 << 0, + CANCELLABLE: 1 << 1, } as const; export type ToolFlag = (typeof ToolFlag)[keyof typeof ToolFlag]; @@ -136,37 +157,47 @@ export type ToolExecuteFunction< export interface Tool { /** * The type of the tool. - * @internal Either user-defined core tool or provider-defined tool. + * @internal Either user-defined function tool or provider-side tool. */ type: ToolType; + /** + * Stable identifier used to key the tool inside a `ToolContext`. For function tools this + * mirrors `name`; for provider tools this is the provider tool id. + */ + id: string; + [TOOL_SYMBOL]: true; } -// TODO(AJS-112): support provider-defined tools -export interface ProviderDefinedTool extends Tool { - type: 'provider-defined'; +// TODO(AJS-112): support provider tools +export abstract class ProviderTool implements Tool { + readonly type = 'provider' as const; - /** - * The ID of the tool. - */ - id: string; + readonly id: string; - /** - * The configuration of the tool. 
- */ - config: Record; + readonly [TOOL_SYMBOL] = true as const; + + readonly [PROVIDER_TOOL_SYMBOL] = true as const; - [PROVIDER_DEFINED_TOOL_SYMBOL]: true; + constructor({ id }: { id: string }) { + this.id = id; + } } export interface FunctionTool< - Parameters extends JSONObject, + Parameters extends JSONObject = JSONObject, UserData = UnknownUserData, Result = unknown, > extends Tool { type: 'function'; + /** + * The name of the tool. Used to identify it inside a `ToolContext` and exposed to the LLM + * as the function name to call. Also surfaced as the inherited `Tool.id`. + */ + name: string; + /** * The description of the tool. Will be used by the language model to decide whether to use the tool. */ @@ -187,54 +218,384 @@ export interface FunctionTool< flags: number; + onDuplicate: DuplicateMode; + [FUNCTION_TOOL_SYMBOL]: true; } -// TODO(AJS-112): support provider-defined tools in the future) -export type ToolContext = { - // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Generic tool registry needs to accept any parameter/result types - [name: string]: FunctionTool; +export type AnonFunctionTool< + Parameters extends JSONObject = JSONObject, + UserData = UnknownUserData, + Result = unknown, +> = Omit, 'id' | 'name'> & { + id?: never; + name?: never; }; -/** @internal */ +export interface ToolCalledEvent { + ctx: RunContext; + arguments: Record; +} + +export interface ToolCompletedEvent { + ctx: RunContext; + output?: { type: 'output'; value: unknown } | { type: 'error'; value: Error }; +} + +/** Context passed to a {@link Toolset}'s `setup` hook when it activates. */ +export interface ToolsetContext { + /** + * Replace the toolset's tools. Useful for dynamic sources + * (e.g. an MCP server) whose tools are discovered after `setup` or change at runtime. + */ + updateTools(tools: readonly ToolContextEntry[]): void; +} + +/** + * Function tools of a `ToolContext`, sorted by name for deterministic provider payloads. 
+ * Provider tools are intentionally excluded — callers that need them iterate `flatten()`. + * @internal + */ export function sortedToolEntries( toolCtx: ToolContext, -): Array<[string, ToolContext[string]]> { - return Object.entries(toolCtx).sort(([nameA], [nameB]) => nameA.localeCompare(nameB)); + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- entries are generic function tools +): Array<[string, FunctionTool]> { + return Object.entries(toolCtx.functionTools).sort(([nameA], [nameB]) => + nameA.localeCompare(nameB), + ); } -/** @internal */ +/** Function tool names of a `ToolContext`, sorted for deterministic output. @internal */ export function sortedToolNames(toolCtx: ToolContext | undefined): string[] { if (!toolCtx) return []; - return Object.keys(toolCtx).sort((nameA, nameB) => nameA.localeCompare(nameB)); + return Object.keys(toolCtx.functionTools).sort((nameA, nameB) => nameA.localeCompare(nameB)); } -export function isSameToolContext(ctx1: ToolContext, ctx2: ToolContext): boolean { - const toolNames = new Set(Object.keys(ctx1)); - const toolNames2 = new Set(Object.keys(ctx2)); +/** + * A stateful collection of tools sharing a lifecycle. Tools registered through a `Toolset` are + * flattened into the surrounding `ToolContext`, while the `Toolset` itself is tracked so its + * `setup()` / `aclose()` hooks can be invoked by the agent runtime. + */ +export class Toolset { + readonly #id: string; + + #tools: readonly ToolContextEntry[]; - if (toolNames.size !== toolNames2.size) { - return false; + readonly [TOOLSET_SYMBOL] = true as const; + + constructor({ id, tools }: { id: string; tools: readonly ToolContextEntry[] }) { + this.#id = id; + this.#tools = [...tools]; + } + + /** + * For when your tools share something that needs setup or cleanup, like a DB pool, an open MCP + * client, or listeners on a shared bus. `setup` runs once at activation, `aclose` once at + * teardown. If the tool list itself is dynamic (e.g. 
an MCP server), push it from `setup` via + * {@link ToolsetContext.updateTools}. + * + * @example Static tool list with a shared backing resource + * ```ts + * function createPostgresToolset(connectionUrl: string): Toolset { + * const pool = new pg.Pool({ connectionString: connectionUrl }); + * return Toolset.create({ + * id: 'postgres', + * tools: [queryOrders, queryCustomers], + * aclose: () => pool.end(), + * }); + * } + * ``` + * + * @example Dynamic tool list bound to an external source + * ```ts + * function createMcpToolset(url: string): Toolset { + * const client = new MCPClient({ url }); + * return Toolset.create({ + * id: 'mcp_remote', + * // setup connects and wires listeners that push the server's tools whenever they change; + * // the runtime re-advertises without re-running anything. + * setup: async ({ updateTools }) => { + * const sync = async () => updateTools(await client.listTools()); + * client.on('connect', sync); + * client.on('tool_list_changed', sync); + * await client.connect(); + * }, + * tools: [], + * aclose: () => client.disconnect(), + * }); + * } + * ``` + */ + static create(options: ToolsetCreateOptions): Toolset { + return new ToolsetFactory(options); + } + + get id(): string { + return this.#id; + } + + get tools(): readonly ToolContextEntry[] { + return this.#tools; } - for (const name of toolNames) { - if (!toolNames2.has(name)) { + /** + * Replace the toolset's current tools. Backs {@link ToolsetContext.updateTools}; the runtime + * re-flattens and re-advertises after calling it. + * + * @internal + */ + _setTools(tools: readonly ToolContextEntry[]): void { + this.#tools = [...tools]; + } + + async setup(_ctx: ToolsetContext): Promise {} + + async aclose(): Promise {} +} + +/** Options accepted by `Toolset.create()` — id + tools plus optional setup/teardown hooks. */ +export interface ToolsetCreateOptions { + id: string; + /** + * One-time async initialization run when the toolset activates — e.g. 
connecting to a server + * and wiring listeners. Push a changed tool list via {@link ToolsetContext.updateTools}. + */ + setup?: (ctx: ToolsetContext) => Promise; + /** The toolset's initial tools. */ + tools: readonly ToolContextEntry[]; + /** Invoked when the toolset is being torn down. Release awaitable resources here. */ + aclose?: () => Promise; +} + +/** Backing implementation of `Toolset.create()`. Kept private so callers go through the factory. */ +class ToolsetFactory extends Toolset { + readonly #setupFn?: (ctx: ToolsetContext) => Promise; + + readonly #acloseFn?: () => Promise; + + constructor({ id, setup, tools, aclose }: ToolsetCreateOptions) { + super({ id, tools }); + this.#setupFn = setup; + this.#acloseFn = aclose; + } + + override async setup(ctx: ToolsetContext): Promise { + if (this.#setupFn) await this.#setupFn(ctx); + } + + override async aclose(): Promise { + if (this.#acloseFn) await this.#acloseFn(); + } +} + +/** + * Tool context or data that can be normalized into one. Used by APIs that accept an already-built + * context as well as direct tool lists or tool maps. + */ +export type ToolContextLike = + | ToolContext + | ToolContextInit; + +/** + * Initial tool data accepted by `ToolContext` constructors and update methods. + */ +export type ToolContextInit = + | readonly ToolContextEntry[] + | ToolDefinitionMap; + +/** + * Object shorthand for declaring anonymous function tools keyed by their model-visible names. 
+ */ +export type ToolDefinitionMap = { + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- entries accept any parameter/result types + readonly [toolName: string]: AnonFunctionTool; +}; + +export function toToolContext( + input: ToolContextLike, +): ToolContext; + +export function toToolContext( + input: ToolContextLike | undefined, +): ToolContext | undefined; + +export function toToolContext( + input: ToolContextLike | undefined, +): ToolContext | undefined { + if (input === undefined) return undefined; + return input instanceof ToolContext ? input : new ToolContext(input); +} + +export function normalizeToolContextInit( + input: ToolContextInit, +): ToolContextEntry[] { + if (Array.isArray(input)) { + return [...input]; + } + + return Object.entries(input).map(([name, toolValue]) => { + if (name.length === 0) { + throw new Error('tools object keys must be non-empty'); + } + + if (!isAnonFunctionTool(toolValue)) { + throw new Error(`tools object entry "${name}" must be an anonymous function tool`); + } + + if ('name' in toolValue || 'id' in toolValue) { + throw new Error(`tools object entry "${name}" must be anonymous`); + } + + return { + ...toolValue, + id: name, + name, + }; + }); +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any -- ToolContext entries accept any function-tool parameter/result types +export type ToolContextEntry = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + FunctionTool | ProviderTool | Toolset; + +export class ToolContext { + private _tools: ToolContextEntry[] = []; + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ToolContext stores generic function tools + private _functionToolsMap: Map> = new Map(); + private _providerTools: ProviderTool[] = []; + private _toolsets: Toolset[] = []; + + constructor(tools: ToolContextInit = []) { + this.updateTools(tools); + } + + static empty(): ToolContext { + return new ToolContext([]); + } + + /** A copy of all function tools in 
the tool context, including those in tool sets. */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + get functionTools(): Record> { + return Object.fromEntries(this._functionToolsMap); + } + + /** A copy of all provider tools in the tool context, including those in tool sets. */ + get providerTools(): ProviderTool[] { + return [...this._providerTools]; + } + + /** A copy of all toolsets registered in the context. */ + get toolsets(): readonly Toolset[] { + return [...this._toolsets]; + } + + /** + * A copy of the raw tool list this context was constructed with. + */ + get tools(): readonly ToolContextEntry[] { + return [...this._tools]; + } + + /** Flatten the tool context to a list of tools. */ + flatten(): Tool[] { + return [...this._functionToolsMap.values(), ...this._providerTools]; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Generic registry over any parameter/result types + getFunctionTool(id: string): FunctionTool | undefined { + return this._functionToolsMap.get(id); + } + + hasTool(id: string): boolean { + if (this._functionToolsMap.has(id)) { + return true; + } + return this._providerTools.some((tool) => tool.id === id); + } + + updateTools(tools: ToolContextInit): void { + const normalizedTools = normalizeToolContextInit(tools); + this._tools = normalizedTools; + this._functionToolsMap = new Map(); + this._providerTools = []; + this._toolsets = []; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- accepts any tool shape + const addTool = (tool: any): void => { + if (isToolset(tool)) { + for (const inner of tool.tools) { + addTool(inner); + } + this._toolsets.push(tool); + return; + } + + if (isProviderTool(tool)) { + this._providerTools.push(tool); + return; + } + + if (isFunctionTool(tool)) { + const existing = this._functionToolsMap.get(tool.id); + if (existing !== undefined) { + if (existing !== tool) { + throw new Error(`duplicate function name: ${tool.id}`); + } + return; // same 
instance, skip + } + this._functionToolsMap.set(tool.id, tool); + return; + } + + throw new Error(`unknown tool type: ${typeof tool}`); + }; + + for (const tool of normalizedTools) { + addTool(tool); + } + } + + copy(): ToolContext { + return new ToolContext([...this._tools]); + } + + equals(other: ToolContext): boolean { + if (this._functionToolsMap.size !== other._functionToolsMap.size) { return false; } - const tool1 = ctx1[name]; - const tool2 = ctx2[name]; + for (const [id, tool] of this._functionToolsMap) { + if (other._functionToolsMap.get(id) !== tool) { + return false; + } + } - if (!tool1 || !tool2) { + if (this._providerTools.length !== other._providerTools.length) { return false; } - if (tool1.description !== tool2.description) { + // Provider tools compare as identity sets to match Python's `set(id(t) for t in ...)` + // semantics — order is not significant. + const otherProviderIds = new Set(other._providerTools); + for (const tool of this._providerTools) { + if (!otherProviderIds.has(tool)) { + return false; + } + } + + if (this._toolsets.length !== other._toolsets.length) { return false; } - } - return true; + const otherToolsets = new Set(other._toolsets); + for (const ts of this._toolsets) { + if (!otherToolsets.has(ts)) { + return false; + } + } + return true; + } } export function isSameToolChoice(choice1: ToolChoice | null, choice2: ToolChoice | null): boolean { @@ -254,117 +615,230 @@ export function isSameToolChoice(choice1: ToolChoice | null, choice2: ToolChoice } /** - * Create a function tool with inferred parameters from the schema. + * Create a function tool. Parameters are inferred from the schema; omit `parameters` for a tool + * that takes no arguments. 
*/ export function tool< - Schema extends ToolInputSchema, // eslint-disable-line @typescript-eslint/no-explicit-any -- Generic constraint needs to accept any JSONObject type UserData = UnknownUserData, + Schema extends ToolInputSchema | undefined = undefined, // eslint-disable-line @typescript-eslint/no-explicit-any -- Generic constraint needs to accept any JSONObject type Result = unknown, >({ + name, description, parameters, execute, flags, + onDuplicate, }: { + /** Unique name the model calls the tool by. Must be non-empty. */ + name: string; + /** Natural-language description that tells the model when to use this tool. */ description: string; - parameters: Schema; - execute: ToolExecuteFunction, UserData, Result>; + /** + * Input schema for the tool's arguments — either a Zod object schema (args + * are type-inferred) or a raw JSON Schema. Omit for a tool that takes no + * arguments. + */ + parameters?: Schema; + /** + * Called when the model invokes the tool. Receives the parsed arguments (an + * empty object when `parameters` is omitted) and a {@link RunContext} + * (`ctx`); the returned value is sent back to the model. + */ + execute: ToolExecuteFunction, UserData, Result>; + /** + * Bitmask of {@link ToolFlag}s, e.g. `ToolFlag.CANCELLABLE` to allow the call + * to be cancelled mid-flight. Defaults to `ToolFlag.NONE`. + */ flags?: number; -}): FunctionTool, UserData, Result>; + /** + * How a concurrent duplicate call of this tool is handled while one is still + * running: `'allow'` | `'reject'` | `'replace'` | `'confirm'`. Defaults to + * `'allow'`. + */ + onDuplicate?: DuplicateMode; +}): FunctionTool, UserData, Result>; /** - * Create a function tool without parameters. + * Create an anonymous (name-less) function tool. Parameters are inferred from the schema; omit + * `parameters` for a tool that takes no arguments. 
*/ -export function tool({ +export function tool< + UserData = UnknownUserData, + Schema extends ToolInputSchema | undefined = undefined, // eslint-disable-line @typescript-eslint/no-explicit-any -- Generic constraint needs to accept any JSONObject type + Result = unknown, +>({ description, + parameters, execute, flags, + onDuplicate, }: { + /** Omitted in object syntax; the containing object key becomes the tool name. */ + name?: never; + /** Natural-language description that tells the model when to use this tool. */ description: string; - parameters?: never; - execute: ToolExecuteFunction, UserData, Result>; + /** + * Input schema for the tool's arguments — either a Zod object schema (args + * are type-inferred) or a raw JSON Schema. Omit for a tool that takes no + * arguments. + */ + parameters?: Schema; + /** + * Called when the model invokes the tool. Receives the parsed arguments (an + * empty object when `parameters` is omitted) and a {@link RunContext} + * (`ctx`); the returned value is sent back to the model. + */ + execute: ToolExecuteFunction, UserData, Result>; + /** + * Bitmask of {@link ToolFlag}s, e.g. `ToolFlag.CANCELLABLE` to allow the call + * to be cancelled mid-flight. Defaults to `ToolFlag.NONE`. + */ flags?: number; -}): FunctionTool, UserData, Result>; - -/** - * Create a provider-defined tool. - * - * @param id - The ID of the tool. - * @param config - The configuration of the tool. - */ -export function tool({ - id, - config, -}: { - id: string; - config: Record; -}): ProviderDefinedTool; + /** + * How a concurrent duplicate call of this tool is handled while one is still + * running: `'allow'` | `'reject'` | `'replace'` | `'confirm'`. Defaults to + * `'allow'`. 
+ */ + onDuplicate?: DuplicateMode; +}): AnonFunctionTool, UserData, Result>; // eslint-disable-next-line @typescript-eslint/no-explicit-any export function tool(tool: any): any { - if (tool.execute !== undefined) { - // Default parameters to z.object({}) if not provided - const parameters = tool.parameters ?? z.object({}); + if (tool.name !== undefined && (typeof tool.name !== 'string' || tool.name.length === 0)) { + throw new Error('tool({ name, ... }) requires a non-empty name'); + } - // if parameters is a Zod schema, ensure it's an object schema - if (isZodSchema(parameters) && !isZodObjectSchema(parameters)) { - throw new Error('Tool parameters must be a Zod object schema (z.object(...))'); - } + const onDuplicate: DuplicateMode = tool.onDuplicate ?? 'allow'; - // Ensure parameters is either a Zod schema or a plain object (JSON schema) - if (!isZodSchema(parameters) && !(typeof parameters === 'object')) { - throw new Error('Tool parameters must be a Zod object schema or a raw JSON schema'); - } + // Default parameters to z.object({}) if not provided + let parameters = tool.parameters ?? z.object({}); - return { - type: 'function', - description: tool.description, - parameters, - execute: tool.execute, - flags: tool.flags ?? 
ToolFlag.NONE, - [TOOL_SYMBOL]: true, - [FUNCTION_TOOL_SYMBOL]: true, - }; + // if parameters is a Zod schema, ensure it's an object schema + if (isZodSchema(parameters) && !isZodObjectSchema(parameters)) { + throw new Error('Tool parameters must be a Zod object schema (z.object(...))'); } - if (tool.config !== undefined && tool.id !== undefined) { - return { - type: 'provider-defined', - id: tool.id, - config: tool.config, - [TOOL_SYMBOL]: true, - [PROVIDER_DEFINED_TOOL_SYMBOL]: true, + // Ensure parameters is either a Zod schema or a plain object (JSON schema) + if (!isZodSchema(parameters) && !(typeof parameters === 'object')) { + throw new Error('Tool parameters must be a Zod object schema or a raw JSON schema'); + } + + if (onDuplicate === 'confirm') { + parameters = injectConfirmDuplicateParameter(parameters); + } + + const execute = + onDuplicate === 'confirm' ? wrapConfirmDuplicateExecute(tool.execute) : tool.execute; + + const functionTool = { + type: 'function', + description: tool.description, + parameters, + execute, + flags: tool.flags ?? ToolFlag.NONE, + onDuplicate, + [TOOL_SYMBOL]: true, + [FUNCTION_TOOL_SYMBOL]: true, + }; + + if (tool.name === undefined) { + return functionTool; + } + + return { + ...functionTool, + id: tool.name, + name: tool.name, + }; +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +function injectConfirmDuplicateParameter(parameters: any): any { + if (isZodSchema(parameters)) { + const maybeObjectSchema = parameters as { + extend?: (shape: Record) => unknown; }; + if (typeof maybeObjectSchema.extend === 'function') { + const confirmSchema = isZod4Schema(parameters) + ? 
z4.boolean().nullable().describe(CONFIRM_DUPLICATE_DESCRIPTION) + : z.boolean().nullable().describe(CONFIRM_DUPLICATE_DESCRIPTION); + return maybeObjectSchema.extend({ [CONFIRM_DUPLICATE_PARAM]: confirmSchema }); + } + throw new Error('Tool parameters must be a Zod object schema (z.object(...))'); + } + + const properties = { + ...(parameters.properties ?? {}), + [CONFIRM_DUPLICATE_PARAM]: { + type: ['boolean', 'null'], + description: CONFIRM_DUPLICATE_DESCRIPTION, + }, + }; + const required = [...(parameters.required ?? [])]; + if (!required.includes(CONFIRM_DUPLICATE_PARAM)) { + required.push(CONFIRM_DUPLICATE_PARAM); } - throw new Error('Invalid tool'); + return { + ...parameters, + properties, + required, + }; +} + +function wrapConfirmDuplicateExecute< + Parameters extends JSONObject, + UserData = UnknownUserData, + Result = unknown, +>( + execute: ToolExecuteFunction, +): ToolExecuteFunction { + return async (args, opts) => { + if (args && typeof args === 'object' && !Array.isArray(args)) { + const stripped = { ...args }; + delete stripped[CONFIRM_DUPLICATE_PARAM]; + return execute(stripped, opts); + } + return execute(args, opts); + }; } // eslint-disable-next-line @typescript-eslint/no-explicit-any export function isTool(tool: any): tool is Tool { - return tool && tool[TOOL_SYMBOL] === true; + return !!tool && tool[TOOL_SYMBOL] === true; +} + +export function isFunctionTool(tool: unknown): tool is FunctionTool { + const maybeTool = tool as Partial; + return ( + isAnonFunctionTool(tool) && + typeof maybeTool.id === 'string' && + typeof maybeTool.name === 'string' + ); +} + +function isAnonFunctionTool(tool: unknown): tool is AnonFunctionTool { + const maybeTool = tool as Partial; + return isTool(tool) && maybeTool[FUNCTION_TOOL_SYMBOL] === true; } // eslint-disable-next-line @typescript-eslint/no-explicit-any -export function isFunctionTool(tool: any): tool is FunctionTool { - const isTool = tool && tool[TOOL_SYMBOL] === true; - const isFunctionTool = 
tool[FUNCTION_TOOL_SYMBOL] === true; - return isTool && isFunctionTool; +export function isProviderTool(tool: any): tool is ProviderTool { + return isTool(tool) && (tool as ProviderTool)[PROVIDER_TOOL_SYMBOL] === true; } // eslint-disable-next-line @typescript-eslint/no-explicit-any -export function isProviderDefinedTool(tool: any): tool is ProviderDefinedTool { - const isTool = tool && tool[TOOL_SYMBOL] === true; - const isProviderDefinedTool = tool[PROVIDER_DEFINED_TOOL_SYMBOL] === true; - return isTool && isProviderDefinedTool; +export function isToolset(value: any): value is Toolset { + return !!value && value[TOOLSET_SYMBOL] === true; } // eslint-disable-next-line @typescript-eslint/no-explicit-any export function isToolError(error: any): error is ToolError { - return error && error[TOOL_ERROR_SYMBOL] === true; + return !!error && error[TOOL_ERROR_SYMBOL] === true; } // eslint-disable-next-line @typescript-eslint/no-explicit-any export function isAgentHandoff(handoff: any): handoff is AgentHandoff { - return handoff && handoff[HANDOFF_SYMBOL] === true; + return !!handoff && handoff[HANDOFF_SYMBOL] === true; } diff --git a/agents/src/llm/tool_context.type.test.ts b/agents/src/llm/tool_context.type.test.ts index 27cf9fe55..a62d2a5a7 100644 --- a/agents/src/llm/tool_context.type.test.ts +++ b/agents/src/llm/tool_context.type.test.ts @@ -3,11 +3,23 @@ // SPDX-License-Identifier: Apache-2.0 import { describe, expect, expectTypeOf, it } from 'vitest'; import { z } from 'zod'; -import { type FunctionTool, type ProviderDefinedTool, type ToolOptions, tool } from './index.js'; +import { + type AnonFunctionTool, + type FunctionTool, + ProviderTool, + type Tool, + ToolContext, + type ToolContextInit, + type ToolContextLike, + type ToolDefinitionMap, + type ToolOptions, + tool, +} from './tool_context.js'; describe('tool type inference', () => { it('should infer argument type from zod schema', () => { const toolType = tool({ + name: 'test', description: 'test', parameters: 
z.object({ number: z.number() }), execute: async () => 'test' as const, @@ -16,19 +28,103 @@ describe('tool type inference', () => { expectTypeOf(toolType).toEqualTypeOf>(); }); - it('should infer provider defined tool type', () => { + it('should infer argument type for an anonymous (name-less) tool with a schema', () => { + tool({ + description: 'test', + parameters: z.object({ number: z.number() }), + execute: async (args) => { + expectTypeOf(args).toEqualTypeOf<{ number: number }>(); + return `${args.number}` as const; + }, + }); + }); + + it('should infer empty args for an anonymous tool when parameters are omitted', () => { + const toolType = tool({ + description: 'test', + execute: async (args) => { + expectTypeOf(args).toEqualTypeOf>(); + return 'done' as const; + }, + }); + + expectTypeOf(toolType).toEqualTypeOf< + AnonFunctionTool, unknown, 'done'> + >(); + expectTypeOf(toolType).toMatchTypeOf<{ readonly id?: never; readonly name?: never }>(); + }); + + it('should infer empty args for an anonymous tool when parameters are explicitly undefined', () => { const toolType = tool({ - id: 'code-interpreter', - config: { - language: 'python', + description: 'test', + parameters: undefined, + execute: async (args) => { + expectTypeOf(args).toEqualTypeOf>(); + return 'done' as const; }, }); - expectTypeOf(toolType).toEqualTypeOf(); + expectTypeOf(toolType).toEqualTypeOf< + AnonFunctionTool, unknown, 'done'> + >(); + }); + + it('should infer args from an explicit empty object schema on anonymous tools', () => { + const parameters = z.object({}); + const toolType = tool({ + description: 'test', + parameters, + execute: async (args) => { + expectTypeOf(args).toMatchTypeOf>(); + expectTypeOf(args).toEqualTypeOf>(); + return 'done' as const; + }, + }); + + expectTypeOf(toolType).toEqualTypeOf< + AnonFunctionTool, unknown, 'done'> + >(); + }); + + it('should infer args from populated object schemas on anonymous tools', () => { + const parameters = z.object({ + location: 
z.string(), + units: z.enum(['celsius', 'fahrenheit']).optional(), + }); + const toolType = tool({ + description: 'test', + parameters, + execute: async (args) => { + expectTypeOf(args).toEqualTypeOf>(); + expectTypeOf(args.location).toEqualTypeOf(); + expectTypeOf(args.units).toEqualTypeOf<'celsius' | 'fahrenheit' | undefined>(); + return { location: args.location, ok: true } as const; + }, + }); + + expectTypeOf(toolType).toEqualTypeOf< + AnonFunctionTool< + z.infer, + unknown, + { readonly location: string; readonly ok: true } + > + >(); + }); + + it('rejects direct instantiation of the abstract ProviderTool base', () => { + // @ts-expect-error - ProviderTool is abstract; plugins must subclass it. + new ProviderTool({ id: 'code-interpreter' }); + + class CodeInterpreter extends ProviderTool {} + const providerTool = new CodeInterpreter({ id: 'code-interpreter' }); + expectTypeOf(providerTool).toMatchTypeOf(); + expect(providerTool.id).toBe('code-interpreter'); + expect(providerTool.type).toBe('provider'); }); it('should infer run context type', () => { const toolType = tool({ + name: 'test', description: 'test', parameters: z.object({ number: z.number() }), execute: async ({ number }, { ctx }: ToolOptions<{ name: string }>) => { @@ -43,7 +139,6 @@ describe('tool type inference', () => { it('should not accept primitive zod schemas', () => { expect(() => { - // @ts-expect-error - Testing that non-object schemas are rejected tool({ name: 'test', description: 'test', @@ -55,7 +150,6 @@ describe('tool type inference', () => { it('should not accept array schemas', () => { expect(() => { - // @ts-expect-error - Testing that array schemas are rejected tool({ name: 'test', description: 'test', @@ -67,7 +161,6 @@ describe('tool type inference', () => { it('should not accept union schemas', () => { expect(() => { - // @ts-expect-error - Testing that union schemas are rejected tool({ name: 'test', description: 'test', @@ -79,10 +172,10 @@ describe('tool type inference', () 
=> { it('should not accept non-Zod values as parameters', () => { expect(() => { - // @ts-expect-error - Testing that non-Zod values are rejected tool({ name: 'test', description: 'test', + // @ts-expect-error - Testing that non-Zod values are rejected parameters: 'invalid schema', execute: async () => 'test' as const, }); @@ -91,6 +184,7 @@ describe('tool type inference', () => { it('should infer empty object type when parameters are omitted', () => { const toolType = tool({ + name: 'simpleAction', description: 'Simple action without parameters', execute: async () => 'done' as const, }); @@ -98,8 +192,29 @@ describe('tool type inference', () => { expectTypeOf(toolType).toEqualTypeOf, unknown, 'done'>>(); }); + it('names tool context input shapes by role', () => { + const objectTools = { + lookupOrder: tool({ + description: 'Look up an order', + execute: async () => 'done' as const, + }), + }; + const namedTool = tool({ + name: 'simpleAction', + description: 'Simple action', + execute: async () => 'done' as const, + }); + + expectTypeOf(objectTools).toMatchTypeOf(); + expectTypeOf(objectTools).toMatchTypeOf(); + expectTypeOf([namedTool]).toMatchTypeOf(); + expectTypeOf(new ToolContext(objectTools)).toMatchTypeOf(); + expectTypeOf(namedTool).toMatchTypeOf(); + }); + it('should infer correct types with context but no parameters', () => { const toolType = tool({ + name: 'actionWithCtx', description: 'Action with context', execute: async (args, { ctx }: ToolOptions<{ userId: number }>) => { expectTypeOf(args).toEqualTypeOf>(); diff --git a/agents/src/llm/utils.test.ts b/agents/src/llm/utils.test.ts index 154c76348..b6b41b266 100644 --- a/agents/src/llm/utils.test.ts +++ b/agents/src/llm/utils.test.ts @@ -13,7 +13,7 @@ import { FunctionCallOutput, type ImageContent, } from './chat_context.js'; -import { tool } from './tool_context.js'; +import { ToolContext, tool } from './tool_context.js'; import { computeChatCtxDiff, executeToolCall, @@ -182,6 +182,7 @@ 
describe('parseFunctionArguments', () => { describe('executeToolCall', () => { it('should canonicalize repaired arguments before returning', async () => { const removeOrderItem = tool({ + name: 'removeOrderItem', description: 'remove order item', parameters: z.object({ orderId: z.array(z.string()) }), execute: async ({ orderId }) => orderId.join(','), @@ -194,7 +195,7 @@ describe('executeToolCall', () => { args: rawArgs, }); - const result = await executeToolCall(toolCall, { removeOrderItem }); + const result = await executeToolCall(toolCall, new ToolContext([removeOrderItem])); expect(result.isError).toBe(false); expect(JSON.parse(result.output)).toBe('O_WAAB70'); @@ -204,6 +205,7 @@ describe('executeToolCall', () => { it('should preserve valid argument structure during execution', async () => { const echo = tool({ + name: 'echo', description: 'echo', parameters: z.object({ arg1: z.string(), optArg2: z.string().optional() }), execute: async ({ arg1, optArg2 }) => ({ arg1, optArg2 }), @@ -216,7 +218,7 @@ describe('executeToolCall', () => { args: originalArgs, }); - const result = await executeToolCall(toolCall, { echo }); + const result = await executeToolCall(toolCall, new ToolContext([echo])); expect(result.isError).toBe(false); expect(JSON.parse(result.output)).toEqual({ arg1: 'hello', optArg2: '<|safe|>' }); diff --git a/agents/src/llm/utils.ts b/agents/src/llm/utils.ts index c9fd7682f..0db2bb802 100644 --- a/agents/src/llm/utils.ts +++ b/agents/src/llm/utils.ts @@ -14,7 +14,12 @@ import { FunctionCallOutput, type ImageContent, } from './chat_context.js'; -import type { ToolContext, ToolInputSchema, ToolOptions } from './tool_context.js'; +import { + type ToolContext, + type ToolInputSchema, + type ToolOptions, + sortedToolNames, +} from './tool_context.js'; import { isZodSchema, parseZodSchema, zodSchemaToJsonSchema } from './zod-utils.js'; export interface SerializedImage { @@ -176,7 +181,7 @@ export const oaiBuildFunctionInfo = ( toolName: string, rawArgs: 
string, ): FunctionCall => { - const tool = toolCtx[toolName]; + const tool = toolCtx.getFunctionTool(toolName); if (!tool) { throw new Error(`AI tool ${toolName} not found`); } @@ -263,8 +268,18 @@ export async function executeToolCall( toolCall: FunctionCall, toolCtx: ToolContext, ): Promise { - const tool = toolCtx[toolCall.name]!; - let args: Record | undefined; + const tool = toolCtx.getFunctionTool(toolCall.name); + if (!tool) { + const availableTools = sortedToolNames(toolCtx).join(', '); + return FunctionCallOutput.create({ + callId: toolCall.callId, + name: toolCall.name, + output: `Unknown function: ${toolCall.name} - available tools: ${availableTools}`, + isError: true, + }); + } + + let args: object | undefined; let params: object | undefined; // Ensure valid JSON diff --git a/agents/src/utils.test.ts b/agents/src/utils.test.ts index eff60d542..421697fb9 100644 --- a/agents/src/utils.test.ts +++ b/agents/src/utils.test.ts @@ -14,6 +14,7 @@ import { delay, isPending, resampleStream, + toStream, } from '../src/utils.js'; describe('utils', () => { @@ -36,6 +37,84 @@ describe('utils', () => { }); }); + describe('toStream', () => { + it('converts an async iterable into a ReadableStream', async () => { + async function* source() { + yield 1; + yield 2; + yield 3; + } + + const reader = toStream(source()).getReader(); + + await expect(reader.read()).resolves.toEqual({ done: false, value: 1 }); + await expect(reader.read()).resolves.toEqual({ done: false, value: 2 }); + await expect(reader.read()).resolves.toEqual({ done: false, value: 3 }); + await expect(reader.read()).resolves.toEqual({ done: true, value: undefined }); + }); + + it('propagates errors from the async iterable', async () => { + const expectedError = new Error('source failed'); + async function* source() { + yield 1; + throw expectedError; + } + + const reader = toStream(source()).getReader(); + + await expect(reader.read()).resolves.toEqual({ done: false, value: 1 }); + await 
expect(reader.read()).rejects.toBe(expectedError); + }); + + it('runs async iterable cleanup when the stream is canceled mid-stream', async () => { + let cleanupRan = false; + async function* source() { + try { + yield 1; + yield 2; + } finally { + cleanupRan = true; + } + } + + const reader = toStream(source()).getReader(); + + await expect(reader.read()).resolves.toEqual({ done: false, value: 1 }); + await reader.cancel('stop early'); + + expect(cleanupRan).toBe(true); + }); + + it('does not wait for a pending next value when canceled mid-stream', async () => { + let releaseNextValue: (() => void) | undefined; + let cleanupRan = false; + async function* source() { + try { + yield 1; + await new Promise((resolve) => { + releaseNextValue = resolve; + }); + yield 2; + } finally { + cleanupRan = true; + } + } + + const reader = toStream(source()).getReader(); + + await expect(reader.read()).resolves.toEqual({ done: false, value: 1 }); + const pendingRead = reader.read(); + await delay(1); + await expect( + Promise.race([reader.cancel('stop early'), delay(10).then(() => 'timeout')]), + ).resolves.not.toBe('timeout'); + releaseNextValue?.(); + await expect(pendingRead).resolves.toEqual({ done: true, value: undefined }); + + expect(cleanupRan).toBe(true); + }); + }); + describe('Task', () => { it('should execute task successfully and return result', async () => { const expectedResult = 'task completed'; @@ -160,29 +239,39 @@ describe('utils', () => { }); it('should handle task that checks abort signal manually', async () => { - const arr: number[] = []; - const task = Task.from(async (controller) => { - for (let i = 0; i < 10; i++) { - if (controller.signal.aborted) { - throw new Error('Task was aborted'); + // fake timers: with real timers the exact tick count is scheduling- + // dependent and flakes on loaded CI runners + vi.useFakeTimers(); + try { + const arr: number[] = []; + const task = Task.from(async (controller) => { + for (let i = 0; i < 10; i++) { + if 
(controller.signal.aborted) { + throw new Error('Task was aborted'); + } + await delay(10); + arr.push(i); } - await delay(10); - arr.push(i); - } - return 'completed'; - }); + return 'completed'; + }); - await delay(35); - task.cancel(); + await vi.advanceTimersByTimeAsync(35); + task.cancel(); - expect(arr).toEqual([0, 1, 2]); - try { - await task.result; - } catch (error: unknown) { - expect((error as Error).message).toBe('Task was aborted'); - } + expect(arr).toEqual([0, 1, 2]); + // the pending (signal-less) delay must elapse for the loop to reach + // its manual abort checkpoint + await vi.advanceTimersByTimeAsync(10); + try { + await task.result; + } catch (error: unknown) { + expect((error as Error).message).toBe('Task was aborted'); + } - expect(task.done).toBe(true); + expect(task.done).toBe(true); + } finally { + vi.useRealTimers(); + } }); it('should handle cleanup in finally block', async () => { diff --git a/agents/src/utils.ts b/agents/src/utils.ts index 2cf409211..c23238f42 100644 --- a/agents/src/utils.ts +++ b/agents/src/utils.ts @@ -14,8 +14,11 @@ import { type Throws, ThrowsPromise } from '@livekit/throws-transformer/throws'; import { AsyncLocalStorage } from 'node:async_hooks'; import { randomUUID } from 'node:crypto'; import { EventEmitter, once } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; -import { TransformStream, type TransformStreamDefaultController } from 'node:stream/web'; +import { + ReadableStream, + TransformStream, + type TransformStreamDefaultController, +} from 'node:stream/web'; import { log } from './log.js'; /** @@ -1322,15 +1325,21 @@ export async function* readStream( const abortPromise = waitForAbort(signal); while (true) { const result = await ThrowsPromise.race([reader.read(), abortPromise]); - if (!result) break; + if (!result) { + break; + } const { done, value } = result; - if (done) break; + if (done) { + break; + } yield value; } } else { while (true) { const { done, value } = await 
reader.read(); - if (done) break; + if (done) { + break; + } yield value; } } @@ -1343,6 +1352,39 @@ export async function* readStream( } } +export function toStream(iterable: AsyncIterable): ReadableStream { + let iterator: AsyncIterator | undefined; + let cancelled = false; + + return new ReadableStream({ + async start(controller) { + iterator = iterable[Symbol.asyncIterator](); + + try { + while (true) { + const { done, value } = await iterator.next(); + if (done || cancelled) { + break; + } + controller.enqueue(value); + } + + if (!cancelled) { + controller.close(); + } + } catch (error) { + if (!cancelled) { + controller.error(error); + } + } + }, + cancel(reason) { + cancelled = true; + void iterator?.return?.(reason).catch(() => {}); + }, + }); +} + export async function waitForAbort(signal: AbortSignal) { if (signal.aborted) { return; diff --git a/agents/src/voice/agent.test.ts b/agents/src/voice/agent.test.ts index 8dd83fee3..e644ca043 100644 --- a/agents/src/voice/agent.test.ts +++ b/agents/src/voice/agent.test.ts @@ -1,9 +1,11 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. 
// // SPDX-License-Identifier: Apache-2.0 +import type { AudioFrame } from '@livekit/rtc-node'; +import { ReadableStream } from 'node:stream/web'; import { describe, expect, it, vi } from 'vitest'; import { z } from 'zod'; -import { tool } from '../llm/index.js'; +import { ChatContext, ChatMessage, tool } from '../llm/index.js'; import { initializeLogger } from '../log.js'; import { Task } from '../utils.js'; import { Agent, AgentTask, _setActivityTaskInfo } from './agent.js'; @@ -15,6 +17,14 @@ vi.mock('ofetch', () => ({ ofetch: vi.fn() })); initializeLogger({ pretty: false, level: 'error' }); +async function collectReadableStream(stream: ReadableStream): Promise { + const chunks: T[] = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + return chunks; +} + describe('Agent', () => { it('should create agent with basic instructions', () => { const instructions = 'You are a helpful assistant'; @@ -29,12 +39,14 @@ describe('Agent', () => { // Create mock tools using the tool function const mockTool1 = tool({ + name: 'getTool1', description: 'First test tool', parameters: z.object({}), execute: async () => 'tool1 result', }); const mockTool2 = tool({ + name: 'getTool2', description: 'Second test tool', parameters: z.object({ input: z.string().describe('Input parameter'), @@ -44,17 +56,14 @@ describe('Agent', () => { const agent = new Agent({ instructions, - tools: { - getTool1: mockTool1, - getTool2: mockTool2, - }, + tools: [mockTool1, mockTool2], }); expect(agent).toBeDefined(); expect(agent.instructions).toBe(instructions); // Assert tools are set correctly - const agentTools = agent.toolCtx; + const agentTools = agent.toolCtx.functionTools; expect(Object.keys(agentTools)).toHaveLength(2); expect(agentTools).toHaveProperty('getTool1'); expect(agentTools).toHaveProperty('getTool2'); @@ -64,27 +73,271 @@ describe('Agent', () => { expect(agentTools.getTool2?.description).toBe('Second test tool'); }); - it('should return a copy of tools, not the 
original reference', () => { + it('toolCtx returns a defensive copy that exposes the same tools', () => { const instructions = 'You are a helpful assistant'; const mockTool = tool({ + name: 'testTool', description: 'Test tool', parameters: z.object({}), execute: async () => 'result', }); - const tools = { testTool: mockTool }; - const agent = new Agent({ instructions, tools }); + const agent = new Agent({ instructions, tools: [mockTool] }); + + // Each call returns a fresh ToolContext so external mutation can't escape into the agent's + // internal state. + expect(agent.toolCtx).not.toBe(agent.toolCtx); + expect(agent.toolCtx.getFunctionTool('testTool')).toBe(mockTool); + }); + + describe('create', () => { + it('preserves constructor options and base Agent default id', () => { + const mockTool = tool({ + name: 'testTool', + description: 'Test tool', + parameters: z.object({}), + execute: async () => 'result', + }); + + const agent = Agent.create({ + instructions: 'factory instructions', + tools: [mockTool], + }); + + expect(agent).toBeInstanceOf(Agent); + expect(agent.instructions).toBe('factory instructions'); + expect(agent.id).toBe('default_agent'); + expect(agent.toolCtx.getFunctionTool('testTool')).toBe(mockTool); + }); + + it('passes AgentContext to lifecycle hooks', async () => { + const calls: string[] = []; + const chatCtx = ChatContext.empty(); + const newMessage = ChatMessage.create({ role: 'user', content: ['hello'] }); + const agent = Agent.create({ + id: 'factory_agent', + instructions: 'factory instructions', + minConsecutiveSpeechDelay: 12, + onEnter: (ctx) => { + expect(ctx.agent).toBe(agent); + expect(ctx.id).toBe(agent.id); + expect(ctx.instructions).toBe(agent.instructions); + expect(ctx.toolCtx.functionTools).toEqual(agent.toolCtx.functionTools); + expect(ctx.chatCtx.items).toEqual(agent.chatCtx.items); + expect(ctx.minConsecutiveSpeechDelay).toBe(agent.minConsecutiveSpeechDelay); + calls.push('enter'); + }, + onExit: async (ctx) => { + 
expect(ctx.agent).toBe(agent); + calls.push('exit'); + }, + onUserTurnCompleted: (ctx, receivedChatCtx, receivedMessage) => { + expect(ctx.agent).toBe(agent); + expect(receivedChatCtx).toBe(chatCtx); + expect(receivedMessage).toBe(newMessage); + calls.push('turn'); + }, + }); + + await agent.onEnter(); + await agent.onExit(); + await agent.onUserTurnCompleted(chatCtx, newMessage); + + expect(calls).toEqual(['enter', 'exit', 'turn']); + }); + + it('adapts stream node hooks between ReadableStream and AsyncIterable', async () => { + const audioFrame = 'audio' as unknown as AudioFrame; + const agent = Agent.create({ + instructions: 'factory instructions', + async sttNode(ctx, audio) { + async function* stream() { + expect(ctx.agent).toBe(agent); + const frames: AudioFrame[] = []; + for await (const frame of audio) { + frames.push(frame); + } + expect(frames).toEqual([audioFrame]); + yield 'transcript'; + } + + return stream(); + }, + }); + const audio = new ReadableStream({ + start(controller) { + controller.enqueue(audioFrame); + controller.close(); + }, + }); + + const result = await agent.sttNode(audio, {}); + + expect(result).not.toBeNull(); + await expect(collectReadableStream(result!)).resolves.toEqual(['transcript']); + }); + + it('supports async generator stream node hooks', async () => { + const audioFrame = 'audio' as unknown as AudioFrame; + const outputFrame = 'output-audio' as unknown as AudioFrame; + const agent = Agent.create({ + instructions: 'factory instructions', + async *sttNode(ctx, audio) { + expect(ctx.agent).toBe(agent); + const frames: AudioFrame[] = []; + for await (const frame of audio) { + frames.push(frame); + } + expect(frames).toEqual([audioFrame]); + yield 'transcript'; + }, + async *llmNode(ctx, chatCtx, toolCtx) { + expect(ctx.agent).toBe(agent); + expect(chatCtx).toBeInstanceOf(ChatContext); + expect(toolCtx.equals(agent.toolCtx)).toBe(true); + yield 'llm-output'; + }, + async *ttsNode(ctx, text) { + expect(ctx.agent).toBe(agent); + 
const chunks: string[] = []; + for await (const chunk of text) { + chunks.push(chunk); + } + expect(chunks).toEqual(['hello']); + yield outputFrame; + }, + async *realtimeAudioOutputNode(ctx, audio) { + expect(ctx.agent).toBe(agent); + const frames: AudioFrame[] = []; + for await (const frame of audio) { + frames.push(frame); + } + expect(frames).toEqual([audioFrame]); + yield outputFrame; + }, + }); + const audio = new ReadableStream({ + start(controller) { + controller.enqueue(audioFrame); + controller.close(); + }, + }); + const text = new ReadableStream({ + start(controller) { + controller.enqueue('hello'); + controller.close(); + }, + }); + + const sttResult = await agent.sttNode(audio, {}); + const llmResult = await agent.llmNode(ChatContext.empty(), agent.toolCtx, {}); + const ttsResult = await agent.ttsNode(text, {}); + const realtimeResult = await agent.realtimeAudioOutputNode( + new ReadableStream({ + start(controller) { + controller.enqueue(audioFrame); + controller.close(); + }, + }), + {}, + ); + + expect(sttResult).not.toBeNull(); + expect(llmResult).not.toBeNull(); + expect(ttsResult).not.toBeNull(); + expect(realtimeResult).not.toBeNull(); + await expect(collectReadableStream(sttResult!)).resolves.toEqual(['transcript']); + await expect(collectReadableStream(llmResult!)).resolves.toEqual(['llm-output']); + await expect(collectReadableStream(ttsResult!)).resolves.toEqual([outputFrame]); + await expect(collectReadableStream(realtimeResult!)).resolves.toEqual([outputFrame]); + }); + + it('supports stream node hooks that return async iterables', async () => { + function asyncIterableOf(...items: T[]): AsyncIterable { + return { + async *[Symbol.asyncIterator]() { + for (const item of items) { + yield item; + } + }, + }; + } + + const audioFrame = 'audio' as unknown as AudioFrame; + const outputFrame = 'output-audio' as unknown as AudioFrame; + const agent = Agent.create({ + instructions: 'factory instructions', + sttNode(ctx) { + 
expect(ctx.agent).toBe(agent); + return asyncIterableOf('transcript'); + }, + llmNode(ctx) { + expect(ctx.agent).toBe(agent); + return asyncIterableOf('llm-output'); + }, + ttsNode(ctx) { + expect(ctx.agent).toBe(agent); + return asyncIterableOf(outputFrame); + }, + realtimeAudioOutputNode(ctx) { + expect(ctx.agent).toBe(agent); + return asyncIterableOf(outputFrame); + }, + }); - const tools1 = agent.toolCtx; - const tools2 = agent.toolCtx; + const sttResult = await agent.sttNode( + new ReadableStream({ + start(controller) { + controller.enqueue(audioFrame); + controller.close(); + }, + }), + {}, + ); + const llmResult = await agent.llmNode(ChatContext.empty(), agent.toolCtx, {}); + const ttsResult = await agent.ttsNode( + new ReadableStream({ + start(controller) { + controller.enqueue('hello'); + controller.close(); + }, + }), + {}, + ); + const realtimeResult = await agent.realtimeAudioOutputNode( + new ReadableStream({ + start(controller) { + controller.enqueue(audioFrame); + controller.close(); + }, + }), + {}, + ); - // Should return different object references - expect(tools1).not.toBe(tools2); - expect(tools1).not.toBe(tools); + expect(sttResult).not.toBeNull(); + expect(llmResult).not.toBeNull(); + expect(ttsResult).not.toBeNull(); + expect(realtimeResult).not.toBeNull(); + await expect(collectReadableStream(sttResult!)).resolves.toEqual(['transcript']); + await expect(collectReadableStream(llmResult!)).resolves.toEqual(['llm-output']); + await expect(collectReadableStream(ttsResult!)).resolves.toEqual([outputFrame]); + await expect(collectReadableStream(realtimeResult!)).resolves.toEqual([outputFrame]); + }); + + it('falls back to existing defaults for missing hooks', async () => { + const audioFrame = 'audio' as unknown as AudioFrame; + const audio = new ReadableStream({ + start(controller) { + controller.enqueue(audioFrame); + controller.close(); + }, + }); + const agent = Agent.create({ instructions: 'factory instructions' }); + + const result = await 
agent.realtimeAudioOutputNode(audio, {}); - // Should contain the same set of tools - expect(tools1).toEqual(tools2); - expect(tools1).toEqual(tools); + expect(result).toBe(audio); + }); }); it('should require AgentTask to run inside task context', async () => { @@ -154,6 +407,94 @@ describe('Agent', () => { await expect(wrapper.result).resolves.toBe('ok'); }); + describe('AgentTask.create', () => { + it('exposes complete on hook context', async () => { + const task = AgentTask.create({ + instructions: 'factory task', + onEnter: (ctx) => { + expect(ctx.agent).toBe(task); + expect(ctx.id).toBe('default_agent'); + expect(ctx.instructions).toBe('factory task'); + ctx.complete('ok'); + }, + }); + const oldAgent = new Agent({ instructions: 'old agent' }); + const mockSession = { + currentAgent: oldAgent, + _globalRunState: undefined, + _updateActivity: async (agent: Agent) => { + if (agent === task) { + await agent.onEnter(); + } + }, + }; + const mockActivity = { + agent: oldAgent, + agentSession: mockSession, + _onEnterTask: undefined, + llm: undefined, + close: async () => {}, + }; + + const wrapper = Task.from(async () => { + const currentTask = Task.current(); + if (!currentTask) { + throw new Error('expected task context'); + } + _setActivityTaskInfo(currentTask, { inlineTask: true }); + return await agentActivityStorage.run(mockActivity as any, () => task.run()); + }); + + await expect(wrapper.result).resolves.toBe('ok'); + }); + + it('adapts stream node hooks between ReadableStream and AsyncIterable', async () => { + const audioFrame = 'audio' as unknown as AudioFrame; + const task = AgentTask.create({ + instructions: 'factory task', + async sttNode(ctx, audio) { + async function* stream() { + expect(ctx.agent).toBe(task); + const frames: AudioFrame[] = []; + for await (const frame of audio) { + frames.push(frame); + } + expect(frames).toEqual([audioFrame]); + yield 'transcript'; + } + + return stream(); + }, + }); + const audio = new ReadableStream({ + 
start(controller) { + controller.enqueue(audioFrame); + controller.close(); + }, + }); + + const result = await task.sttNode(audio, {}); + + expect(result).not.toBeNull(); + await expect(collectReadableStream(result!)).resolves.toEqual(['transcript']); + }); + + it('falls back to existing defaults for missing hooks', async () => { + const audioFrame = 'audio' as unknown as AudioFrame; + const audio = new ReadableStream({ + start(controller) { + controller.enqueue(audioFrame); + controller.close(); + }, + }); + const task = AgentTask.create({ instructions: 'factory task' }); + + const result = await task.realtimeAudioOutputNode(audio, {}); + + expect(result).toBe(audio); + }); + }); + it('should require AgentTask to run inside AgentActivity context', async () => { class TestTask extends AgentTask { constructor() { @@ -242,6 +583,7 @@ describe('Agent', () => { turnHandling: { endpointing: { minDelay: 999 }, interruption: {}, + preemptiveGeneration: {}, turnDetection: 'vad', }, allowInterruptions: false, @@ -258,6 +600,7 @@ describe('Agent', () => { turnHandling: { endpointing: { minDelay: 999, maxDelay: 4000 }, interruption: { enabled: true }, + preemptiveGeneration: {}, turnDetection: 'vad', }, allowInterruptions: false, @@ -275,6 +618,7 @@ describe('Agent', () => { turnHandling: { interruption: { mode: 'adaptive' }, endpointing: {}, + preemptiveGeneration: {}, turnDetection: undefined, }, }); @@ -287,6 +631,7 @@ describe('Agent', () => { turnHandling: { endpointing: { minDelay: 111, maxDelay: 222 }, interruption: { enabled: false }, + preemptiveGeneration: {}, turnDetection: 'manual', }, }); diff --git a/agents/src/voice/agent.ts b/agents/src/voice/agent.ts index c6f7b82df..e68dab727 100644 --- a/agents/src/voice/agent.ts +++ b/agents/src/voice/agent.ts @@ -20,7 +20,9 @@ import { LLM, RealtimeModel, type ToolChoice, - type ToolContext, + ToolContext, + type ToolContextInit, + normalizeToolContextInit, } from '../llm/index.js'; import { log } from '../log.js'; 
import type { STT, SpeechEvent } from '../stt/index.js'; @@ -33,12 +35,28 @@ import { Future, Task } from '../utils.js'; import type { VAD } from '../vad.js'; import { type AgentActivity, agentActivityStorage } from './agent_activity.js'; import type { AgentSession, TurnDetectionMode } from './agent_session.js'; +import { + type AgentCreateOptions, + type AgentTaskCreateOptions, + createAgentTaskV2, + createAgentV2, +} from './agent_v2.js'; import type { UserTurnExceededEvent } from './events.js'; import type { TimedString } from './io.js'; import type { SpeechHandle } from './speech_handle.js'; +import type { ToolHandlingOptions } from './tool_executor.js'; import type { TurnHandlingOptions } from './turn_config/turn_handling.js'; import { migrateTurnHandling } from './turn_config/utils.js'; +export type { + AgentContext, + AgentCreateOptions, + AgentHookNodeResult, + AgentHooks, + AgentTaskContext, + AgentTaskCreateOptions, +} from './agent_v2.js'; + // speechHandle identifies which SpeechHandle owns the current tool call, enabling // SpeechHandle.waitForPlayout() to distinguish self-wait (deadlock) from waiting // on a different handle scheduled inside the tool. 
@@ -118,12 +136,13 @@ export interface AgentOptions { id?: string; instructions: string | Instructions; chatCtx?: ChatContext; - tools?: ToolContext; + tools?: ToolContextInit; stt?: STT | STTModelString; vad?: VAD; llm?: LLM | RealtimeModel | LLMModels; tts?: TTS | TTSModelString; turnHandling?: TurnHandlingOptions; + toolHandling?: ToolHandlingOptions; minConsecutiveSpeechDelay?: number; useTtsAlignedTranscript?: boolean; /** @deprecated use turnHandling.turnDetection instead */ @@ -153,7 +172,14 @@ export class Agent { _instructions: string | Instructions; /** @internal */ - _tools?: ToolContext; + _toolCtx: ToolContext; + + /** @internal */ + _asyncToolOptions?: ToolHandlingOptions['asyncOptions']; + + static create(options: AgentCreateOptions): Agent { + return createAgentV2(Agent, options); + } constructor({ id, @@ -167,6 +193,7 @@ export class Agent { tts, allowInterruptions, turnHandling, + toolHandling, minConsecutiveSpeechDelay, useTtsAlignedTranscript, }: AgentOptions) { @@ -185,10 +212,10 @@ export class Agent { } this._instructions = instructions; - this._tools = { ...tools }; + this._toolCtx = new ToolContext(tools ?? []); this._chatCtx = chatCtx ? chatCtx.copy({ - toolCtx: this._tools, + toolCtx: this._toolCtx, }) : ChatContext.empty(); @@ -199,6 +226,7 @@ export class Agent { }); this._turnHandling = Object.keys(resolvedTurnHandling).length > 0 ? resolvedTurnHandling : undefined; + this._asyncToolOptions = toolHandling?.asyncOptions; this._vad = vad; @@ -259,7 +287,7 @@ export class Agent { } get toolCtx(): ToolContext { - return { ...this._tools }; + return this._toolCtx.copy(); } get session(): AgentSession { @@ -346,15 +374,25 @@ export class Agent { this._agentActivity.updateChatCtx(chatCtx); } - // TODO: Add when AgentConfigUpdate is ported to ChatContext. 
- async updateTools(tools: ToolContext): Promise { + async updateInstructions(instructions: string | Instructions): Promise { if (!this._agentActivity) { - this._tools = { ...tools }; - this._chatCtx = this._chatCtx.copy({ toolCtx: this._tools }); + this._instructions = instructions; return; } - await this._agentActivity.updateTools(tools); + await this._agentActivity.updateInstructions(instructions); + } + + // TODO(parity): Add when AgentConfigUpdate is ported to ChatContext. + async updateTools(tools: ToolContextInit): Promise { + const normalizedTools = normalizeToolContextInit(tools); + if (!this._agentActivity) { + this._toolCtx = new ToolContext(normalizedTools); + this._chatCtx = this._chatCtx.copy({ toolCtx: this._toolCtx }); + return; + } + + await this._agentActivity.updateTools(normalizedTools); } static default = { @@ -557,6 +595,12 @@ export class AgentTask extends Agent( + options: AgentTaskCreateOptions, + ): AgentTask { + return createAgentTaskV2(AgentTask, options); + } + constructor(options: AgentTaskOptions) { const { preserveFunctionCallHistory = false, ...rest } = options; super(rest); @@ -610,13 +654,24 @@ export class AgentTask extends Agent { - if (this.future.done) return; - - // If the Task finished before the AgentTask was completed, complete the AgentTask with an error. - this.#logger.error(`The Task finished before ${this.constructor.name} was completed.`); - this.complete(new Error(`The Task finished before ${this.constructor.name} was completed.`)); - }); + // A non-blocking tool (one that called ctx.update) detaches from its speech + // task, so that task completes before a later ctx.foreground/AgentTask does. + // Binding the guard below to the already-finished task would fire it + // immediately and tear the AgentTask down; the tool's still-running promise + // and the foreground hold keep it alive instead. 
+ const ownerIsNonBlocking = + taskInfo.functionCall?.extra.__livekit_agents_tool_non_blocking === true; + if (!ownerIsNonBlocking) { + currentTask.addDoneCallback(() => { + if (this.future.done) return; + + // If the Task finished before the AgentTask was completed, complete the AgentTask with an error. + this.#logger.error(`The Task finished before ${this.constructor.name} was completed.`); + this.complete( + new Error(`The Task finished before ${this.constructor.name} was completed.`), + ); + }); + } const oldAgent = oldActivity.agent; const session = oldActivity.agentSession; diff --git a/agents/src/voice/agent_activity.test.ts b/agents/src/voice/agent_activity.test.ts index 5600418d3..2bc4e018b 100644 --- a/agents/src/voice/agent_activity.test.ts +++ b/agents/src/voice/agent_activity.test.ts @@ -16,8 +16,9 @@ */ import { Heap } from 'heap-js'; import { describe, expect, it, vi } from 'vitest'; -import { ChatContext } from '../llm/chat_context.js'; +import { AgentConfigUpdate, ChatContext } from '../llm/chat_context.js'; import { LLM, type LLMStream } from '../llm/llm.js'; +import { type Tool, ToolContext, ToolFlag, Toolset, tool } from '../llm/tool_context.js'; import { Future, Task } from '../utils.js'; import { _getActivityTaskInfo } from './agent.js'; import { AgentActivity } from './agent_activity.js'; @@ -449,15 +450,16 @@ function buildPreemptiveRunner(opts: Partial = {}) { const fakeChatCtx = new ChatContext(); + const emptyToolCtx = ToolContext.empty(); const fakeActivity = { _preemptiveGenerationCount: 0, _preemptiveGeneration: undefined, _currentSpeech: undefined as SpeechHandle | undefined, schedulingPaused: false, llm: new FakePreemptiveLLM(), - tools: {}, + tools: emptyToolCtx, toolChoice: null, - agent: { chatCtx: fakeChatCtx }, + agent: { chatCtx: fakeChatCtx, _toolCtx: emptyToolCtx }, agentSession: { sessionOptions: { turnHandling: { preemptiveGeneration: preemptiveOpts }, @@ -555,3 +557,243 @@ describe('AgentActivity - onPreemptiveGeneration 
guards', () => { expect(cancelPreemptiveGeneration).not.toHaveBeenCalled(); }); }); + +/** + * Regression test for the dynamic-toolset push path. + * + * When an already-activated toolset swaps its tools at runtime (e.g. an MCP server pushes a new + * tool list via `ToolsetContext.updateTools`), `setupToolsetList`'s wiring must (1) invoke + * `onToolsetToolsChanged`, which now funnels through `updateTools`, and (2) record an + * `AgentConfigUpdate` in the agent chat context + session history so a non-realtime pipeline's + * chat context reflects the new tool set on the next turn. + */ +class FakeToolsetLLM extends LLM { + label(): string { + return 'fake.toolset.LLM'; + } + chat(): LLMStream { + throw new Error('not used in these tests'); + } +} + +describe('AgentActivity - onToolsetToolsChanged (dynamic toolset push)', () => { + const makeFn = (name: string) => + tool({ name, description: `${name} tool`, execute: async () => name }); + + function buildToolsetActivity(toolset: Toolset) { + const history = new ChatContext(); + const fakeActivity = { + _toolsetsSetup: true, + realtimeSession: undefined, + llm: new FakeToolsetLLM(), + agent: { + _toolCtx: new ToolContext([toolset]), + _chatCtx: new ChatContext(), + }, + agentSession: { history }, + updateChatCtx: vi.fn(async () => {}), + logger: { info() {}, debug() {}, warn() {}, error() {} }, + }; + Object.setPrototypeOf(fakeActivity, AgentActivity.prototype); + return { fakeActivity, history }; + } + + it('fires onToolsetToolsChanged on a dynamic push and records an AgentConfigUpdate', async () => { + const toolA = makeFn('toolA'); + const toolB = makeFn('toolB'); + + // Capture the wired ctx.updateTools the framework hands the toolset during setup. 
+ let pushTools!: (tools: readonly Tool[]) => void; + const toolset = Toolset.create({ + id: 'dynamic', + tools: [toolA], + setup: async ({ updateTools }) => { + pushTools = updateTools; + }, + }); + + const { fakeActivity, history } = buildToolsetActivity(toolset); + + const changedSpy = vi.spyOn( + AgentActivity.prototype as unknown as Record<'onToolsetToolsChanged', () => Promise>, + 'onToolsetToolsChanged', + ); + + // Activate the toolset through the real path so it captures the push channel. + const setupToolsetList = (AgentActivity.prototype as Record) + .setupToolsetList as (this: unknown, toolsets: readonly Toolset[]) => Promise; + await setupToolsetList.call(fakeActivity, [toolset]); + + expect(changedSpy).not.toHaveBeenCalled(); + + // (1) A dynamic push swaps the toolset's tools — the wiring must invoke onToolsetToolsChanged. + pushTools([toolA, toolB]); + expect(changedSpy).toHaveBeenCalledTimes(1); + await changedSpy.mock.results[0]!.value; + + // (2) An AgentConfigUpdate naming the added tool lands in the session history. + const updates = history.items.filter( + (i): i is AgentConfigUpdate => i instanceof AgentConfigUpdate, + ); + expect(updates).toHaveLength(1); + expect(updates[0]!.toolsAdded).toContain('toolB'); + expect(updates[0]!.toolsRemoved ?? []).not.toContain('toolA'); + + // The refreshed tool context advertises the new tool to the next turn, and the non-realtime + // pipeline's chat context was refreshed via updateChatCtx. + expect(Object.keys(fakeActivity.agent._toolCtx.functionTools).sort()).toEqual([ + 'toolA', + 'toolB', + ]); + expect(fakeActivity.updateChatCtx).toHaveBeenCalledTimes(1); + + changedSpy.mockRestore(); + }); +}); + +/** + * Regression test for PR #1736 review (#3378550188): session-level toolsets must have `setup()` + * run ONCE for the session's lifetime (their `aclose()` runs once at session close), while + * agent-level toolsets are set up per activity. 
Re-running a session toolset's `setup()` on every + * handoff would acquire resources without a matching `aclose()` (resource/listener leak). + */ +describe('AgentActivity - session toolset setup lifecycle (#3378550188)', () => { + function buildSetupActivity( + agentSession: { tools: Toolset[]; _sessionToolsetsSetup: boolean }, + agentToolset: Toolset, + ) { + const fakeActivity = { + _toolsetsSetup: false, + agentSession, + agent: { + toolCtx: new ToolContext([agentToolset]), + _toolCtx: { tools: [] as Tool[], updateTools: vi.fn() }, + }, + logger: { info() {}, debug() {}, warn() {}, error() {} }, + }; + Object.setPrototypeOf(fakeActivity, AgentActivity.prototype); + return fakeActivity; + } + + it('sets up session toolsets once across a handoff; agent toolsets per activity', async () => { + const sessionSetup = vi.fn(async () => {}); + const sessionToolset = Toolset.create({ id: 'session', tools: [], setup: sessionSetup }); + + const agentSetupA = vi.fn(async () => {}); + const agentToolsetA = Toolset.create({ id: 'agent_a', tools: [], setup: agentSetupA }); + const agentSetupB = vi.fn(async () => {}); + const agentToolsetB = Toolset.create({ id: 'agent_b', tools: [], setup: agentSetupB }); + + const agentSession = { tools: [sessionToolset], _sessionToolsetsSetup: false }; + + const setupToolsets = (AgentActivity.prototype as Record).setupToolsets as ( + this: unknown, + ) => Promise; + + // Activity #1 (agent A) sets up its own toolset + the session toolset. + await setupToolsets.call(buildSetupActivity(agentSession, agentToolsetA)); + // Activity #2 (handoff to agent B): a fresh activity with its own _toolsetsSetup=false. 
+ await setupToolsets.call(buildSetupActivity(agentSession, agentToolsetB)); + + expect(sessionSetup).toHaveBeenCalledTimes(1); + expect(agentSetupA).toHaveBeenCalledTimes(1); + expect(agentSetupB).toHaveBeenCalledTimes(1); + expect(agentSession._sessionToolsetsSetup).toBe(true); + }); +}); + +describe('AgentActivity - preemptive generation tool snapshot (#3407098507)', () => { + it('snapshots the merged tool set so reuse is not invalidated when a cancellable tool is present', () => { + const cancellable = tool({ + name: 'bookFlight', + description: 'book a flight', + execute: async () => 'ok', + flags: ToolFlag.CANCELLABLE, + }); + const agentToolCtx = new ToolContext([cancellable]); + + const generateReply = vi.fn( + () => ({ id: 'speech_fake', _cancel: () => {} }) as unknown as SpeechHandle, + ); + + const fakeActivity = { + _preemptiveGenerationCount: 0, + _preemptiveGeneration: undefined as unknown, + _currentSpeech: undefined as SpeechHandle | undefined, + schedulingPaused: false, + newTurnsBlocked: false, + llm: new FakePreemptiveLLM(), + toolChoice: null, + // `get tools()` (real prototype getter) reads agentSession.tools + agent.toolCtx and + // injects the management tools when a cancellable tool exists. We intentionally do NOT set + // an own `tools` property so the real getter runs. 
+ agent: { chatCtx: new ChatContext(), toolCtx: agentToolCtx, _toolCtx: agentToolCtx }, + agentSession: { + tools: [] as never[], + sessionOptions: { + turnHandling: { + preemptiveGeneration: { + enabled: true, + preemptiveTts: false, + maxSpeechDuration: 10_000, + maxRetries: 3, + }, + }, + }, + }, + logger: { info() {}, debug() {}, warn() {}, error() {} }, + generateReply, + cancelPreemptiveGeneration: vi.fn(), + }; + Object.setPrototypeOf(fakeActivity, AgentActivity.prototype); + + const onPreemptiveGeneration = (AgentActivity.prototype as unknown as Record) + .onPreemptiveGeneration as (this: unknown, info: PreemptiveGenerationInfo) => void; + + onPreemptiveGeneration.call(fakeActivity, { + newTranscript: 'hello world', + transcriptConfidence: 0.95, + startedSpeakingAt: undefined, + }); + + const snapshot = fakeActivity._preemptiveGeneration as { tools: ToolContext } | undefined; + expect(snapshot).toBeDefined(); + + const liveTools = (fakeActivity as unknown as { tools: ToolContext }).tools; + // Sanity: the live merged set is larger than agent-only (it injected the management tools), + // which is exactly the condition under which the old agent-only snapshot diverged. + expect(liveTools.tools.length).toBeGreaterThan(agentToolCtx.tools.length); + + // The reuse check (onUserTurnCompleted) does `preemptive.tools.equals(this.tools)`. + expect(snapshot!.tools.equals(liveTools)).toBe(true); + }); +}); + +describe('AgentActivity - waitForIdle close abort', () => { + it('returns promptly when the activity closes while waiting', async () => { + const closeAbort = new AbortController(); + const fakeActivity = { + closeAbort, + // Simulates a wait that only completes when its signal aborts (the spin condition). 
+ waitForInactive: (_options: unknown, signal: AbortSignal) => + new Promise((resolve) => { + if (signal.aborted) return resolve(); + signal.addEventListener('abort', () => resolve(), { once: true }); + }), + agentSession: { _waitForIdleHoldReleased: async () => false }, + }; + Object.setPrototypeOf(fakeActivity, AgentActivity.prototype); + + const waitForIdle = ( + AgentActivity.prototype as unknown as { waitForIdle: (this: unknown) => Promise } + ).waitForIdle.bind(fakeActivity); + + const pending = waitForIdle(); + // Not idle yet — the wait is still pending. + expect(await raceTimeout(pending, 50)).toBe('timeout'); + + // close() aborts the shared signal; the wait must unblock. + closeAbort.abort(); + expect(await raceTimeout(pending, 1000)).toBe('resolved'); + }); +}); diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index 744205ecf..16184861a 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -27,6 +27,7 @@ import { instructionsEqual, renderInstructions, } from '../llm/chat_context.js'; +import { AsyncToolset, type Toolset } from '../llm/index.js'; import { type ChatItem, type FunctionCall, @@ -41,11 +42,14 @@ import { type RealtimeModelError, type RealtimeSession, type ToolChoice, - type ToolContext, + ToolContext, + type ToolContextEntry, ToolFlag, + isFunctionTool, + isToolset, } from '../llm/index.js'; import type { LLMError } from '../llm/llm.js'; -import { isSameToolChoice, isSameToolContext } from '../llm/tool_context.js'; +import { isSameToolChoice } from '../llm/tool_context.js'; import { log } from '../log.js'; import type { EOTInferenceMetrics, @@ -125,6 +129,12 @@ import { } from './generation.js'; import type { PlaybackFinishedEvent, TimedString } from './io.js'; import { type InputDetails, SpeechHandle } from './speech_handle.js'; +import { + ToolExecutor, + cancelTaskTool, + getRunningTasksTool, + hasCancellableTool, +} from './tool_executor.js'; import { type 
EndpointingOptions, createEndpointing } from './turn_config/endpointing.js'; import { resolveEndpointing } from './turn_config/utils.js'; import { createSilenceFrameLike, setParticipantSpanAttributes } from './utils.js'; @@ -238,6 +248,8 @@ export class AgentActivity implements RecognitionHooks { private toolChoice: ToolChoice | null = null; private _preemptiveGeneration?: PreemptiveGeneration; private _preemptiveGenerationCount = 0; + private _toolsetsSetup = false; + private readonly closeAbort = new AbortController(); private interruptionDetector?: AdaptiveInterruptionDetector; private isInterruptionDetectionEnabled: boolean; private isInterruptionByAudioActivityEnabled: boolean; @@ -302,6 +314,7 @@ export class AgentActivity implements RecognitionHooks { _onEnterTask?: Task; _onExitTask?: Task; _userTurnCompletedTask?: Task; + _toolExecutor: ToolExecutor; constructor(agent: Agent, agentSession: AgentSession) { this.agent = agent; @@ -316,6 +329,10 @@ export class AgentActivity implements RecognitionHooks { return p1 === p2 ? t1 - t2 : p2 - p1; }); this.q_updated = new Future(); + this._toolExecutor = new ToolExecutor({ + owningActivity: this, + asyncToolOptions: this.agent._asyncToolOptions ?? this.agentSession._asyncToolOptions, + }); this._resolvedTurnDetection = this._resolveTurnDetection(this.turnDetection); this.turnDetectionMode = @@ -451,6 +468,8 @@ export class AgentActivity implements RecognitionHooks { this.agent._agentActivity = this; + await this.setupToolsets(); + if (this.llm instanceof RealtimeModel) { const rtReused = reuseResources?.rtSession !== undefined; @@ -515,7 +534,13 @@ export class AgentActivity implements RecognitionHooks { } } - const initialTools = Object.keys(this.tools); + // Surface every tool the agent advertises at start — function tools by name and provider + // tools by id. 
+ const initialToolCtx = this.tools; + const initialTools = [ + ...Object.keys(initialToolCtx.functionTools), + ...initialToolCtx.providerTools.map((t) => t.id), + ]; if (runOnEnter && (this.agent.instructions || initialTools.length > 0)) { const initialConfig = new AgentConfigUpdate({ instructions: this.agent.instructions, @@ -680,8 +705,7 @@ export class AgentActivity implements RecognitionHooks { // tools update is supported or tools are the same reusable = - reusable && - (capabilities.midSessionToolsUpdate || isSameToolContext(this.tools, newActivity.tools)); + reusable && (capabilities.midSessionToolsUpdate || this.tools.equals(newActivity.tools)); if (reusable) { // detach: remove event listeners but don't close the session @@ -750,7 +774,11 @@ export class AgentActivity implements RecognitionHooks { } get tools(): ToolContext { - return this.agent.toolCtx; + const tools: ToolContextEntry[] = [...this.agentSession.tools, ...this.agent.toolCtx.tools]; + if (hasCancellableTool(tools)) { + tools.push(cancelTaskTool, getRunningTasksTool); + } + return new ToolContext(tools); } get schedulingPaused(): boolean { @@ -830,7 +858,7 @@ export class AgentActivity implements RecognitionHooks { } get toolCtx(): ToolContext { - return this.agent.toolCtx; + return this.tools; } /** @internal */ @@ -863,13 +891,43 @@ export class AgentActivity implements RecognitionHooks { } } - async updateTools(tools: ToolContext): Promise { - const oldToolNames = new Set(Object.keys(this.tools)); - const newToolNames = new Set(Object.keys(tools)); + async updateInstructions(instructions: string | Instructions): Promise { + this.agent._instructions = instructions; + + const configUpdate = new AgentConfigUpdate({ instructions }); + this.agent._chatCtx.insert(configUpdate); + this.agentSession.history.insert(configUpdate); + + if (this.realtimeSession) { + await this.realtimeSession.updateInstructions(renderInstructions(instructions)); + } else { + updateInstructions({ + chatCtx: 
this.agent._chatCtx, + instructions, + addIfMissing: true, + }); + } + } + + async updateTools(tools: readonly ToolContextEntry[]): Promise { + const oldToolCtx = this.agent._toolCtx; + const oldToolNames = new Set(Object.keys(oldToolCtx.functionTools)); + const oldToolsets = oldToolCtx.toolsets; + const newToolCtx = new ToolContext(tools); + const newToolsets = newToolCtx.toolsets; + const addedToolsets = newToolsets.filter((ts) => !oldToolsets.includes(ts)); + const removedToolsets = oldToolsets.filter((ts) => !newToolsets.includes(ts)); + + // Resolve added factory toolsets before re-flattening, so their tools are included in the + // advertised set (newToolNames is computed below, after resolution). + await this.setupToolsetList(addedToolsets); + newToolCtx.updateTools(newToolCtx.tools); + const newToolNames = new Set(Object.keys(newToolCtx.functionTools)); const toolsAdded = [...newToolNames].filter((name) => !oldToolNames.has(name)); const toolsRemoved = [...oldToolNames].filter((name) => !newToolNames.has(name)); - this.agent._tools = { ...tools }; + this.agent._toolCtx = newToolCtx; + await this.closeToolsetList(removedToolsets); if (toolsAdded.length > 0 || toolsRemoved.length > 0) { const configUpdate = new AgentConfigUpdate({ @@ -881,12 +939,12 @@ export class AgentActivity implements RecognitionHooks { } if (this.realtimeSession) { - await this.realtimeSession.updateTools(tools); + await this.realtimeSession.updateTools(this.tools); } if (this.llm instanceof LLM) { // for realtime LLM, we assume the server will remove unvalid tool messages - await this.updateChatCtx(this.agent._chatCtx.copy({ toolCtx: tools })); + await this.updateChatCtx(this.agent._chatCtx.copy({ toolCtx: newToolCtx })); } } @@ -1561,7 +1619,7 @@ export class AgentActivity implements RecognitionHooks { userMessage, info, chatCtx: chatCtx.copy(), - tools: { ...this.tools }, + tools: this.tools, toolChoice: this.toolChoice, createdAt: Date.now(), }; @@ -1772,6 +1830,36 @@ export class 
AgentActivity implements RecognitionHooks { return this.agentSession.chatCtx; } + async waitForIdle( + options: { waitForAgent?: boolean; waitForUser?: boolean } = {}, + ): Promise { + const signal = this.closeAbort.signal; + while (!signal.aborted) { + await this.waitForInactive(options, signal); + if (signal.aborted) break; + if (!(await this.agentSession._waitForIdleHoldReleased())) { + break; + } + } + } + + private async waitForEndOfTurn(signal: AbortSignal): Promise { + if (this.audioRecognition) { + await this.waitForOrAbort( + this.audioRecognition.waitForEndOfTurnTask(), + signal, + 'error waiting for end-of-turn task', + ); + } + if (this._userTurnCompletedTask && !this._userTurnCompletedTask.done) { + await this.waitForOrAbort( + this._userTurnCompletedTask.result, + signal, + 'error waiting for user-turn-completed task', + ); + } + } + private async waitForInactive( options: { waitForAgent?: boolean; waitForUser?: boolean }, signal: AbortSignal, @@ -1787,13 +1875,7 @@ export class AgentActivity implements RecognitionHooks { } if (waitForAgent) { - if (this.audioRecognition) { - await this.waitForOrAbort( - this.audioRecognition.waitForEndOfTurnTask(), - signal, - 'error waiting for end-of-turn task', - ); - } + await this.waitForEndOfTurn(signal); if (!this._currentSpeech && this.speechQueue.size() === 0) { agentActive = false; @@ -1817,6 +1899,7 @@ export class AgentActivity implements RecognitionHooks { if (userActive) { await delay(0, { signal }); } + await this.waitForEndOfTurn(signal); } } } @@ -2023,13 +2106,14 @@ export class AgentActivity implements RecognitionHooks { const shouldFilterTools = onEnterData?.agent === this.agent && onEnterData?.session === this.agentSession; - const tools = shouldFilterTools - ? Object.fromEntries( - Object.entries(this.agent.toolCtx).filter( - ([, fnTool]) => !(fnTool.flags & ToolFlag.IGNORE_ON_ENTER), - ), + const tools: ToolContext = shouldFilterTools + ? 
new ToolContext( + this.tools.tools.filter((t): boolean => { + if (isToolset(t) || !isFunctionTool(t)) return true; + return !(t.flags & ToolFlag.IGNORE_ON_ENTER); + }), ) - : this.agent.toolCtx; + : this.tools; const task = this.createSpeechTask({ taskFn: (abortController: AbortController) => @@ -2274,7 +2358,7 @@ export class AgentActivity implements RecognitionHooks { if ( preemptive.info.newTranscript === userMessage?.textContent && preemptive.chatCtx.isEquivalent(chatCtx) && - isSameToolContext(preemptive.tools, this.tools) && + preemptive.tools.equals(this.tools) && isSameToolChoice(preemptive.toolChoice, this.toolChoice) ) { speechHandle = preemptive.speechHandle; @@ -2316,8 +2400,8 @@ export class AgentActivity implements RecognitionHooks { const eouMetrics: EOUMetrics = { type: 'eou_metrics', timestamp: Date.now(), - endOfUtteranceDelayMs: info.endOfUtteranceDelay, - transcriptionDelayMs: info.transcriptionDelay, + endOfUtteranceDelayMs: info.endOfUtteranceDelay ?? 0, + transcriptionDelayMs: info.transcriptionDelay ?? 0, onUserTurnCompletedDelayMs: callbackDuration, lastSpeakingTimeMs: info.stoppedSpeakingAt ?? 
0, speechId: speechHandle.id, @@ -3044,12 +3128,12 @@ export class AgentActivity implements RecognitionHooks { // important: no agent output should be used after this point const { maxToolSteps } = this.agentSession.sessionOptions; - if (speechHandle.numSteps >= maxToolSteps) { + const maxStepsReached = speechHandle.numSteps >= maxToolSteps + 1; + if (maxStepsReached) { this.logger.warn( { speech_id: speechHandle.id, max_tool_steps: maxToolSteps }, - 'maximum number of function calls steps reached', + "maximum number of function calls steps reached, generating final response with toolChoice = 'none'", ); - return; } const { functionToolsExecutedEvent, shouldGenerateToolReply, newAgentTask, ignoreTaskSwitch } = @@ -3077,9 +3161,11 @@ export class AgentActivity implements RecognitionHooks { speechHandle._numSteps += 1; // Avoid setting tool_choice to "required" or a specific function when - // passing tool response back to the LLM + // passing tool response back to the LLM. const respondToolChoice = - schedulingPaused || modelSettings.toolChoice === 'none' ? 'none' : 'auto'; + maxStepsReached || schedulingPaused || modelSettings.toolChoice === 'none' + ? 'none' + : 'auto'; // Reuse the same speechHandle for the tool response. const toolResponseTask = this.createSpeechTask({ @@ -3586,7 +3672,7 @@ export class AgentActivity implements RecognitionHooks { // important: no agent ouput should be used after this point const { maxToolSteps } = this.agentSession.sessionOptions; - if (speechHandle.numSteps >= maxToolSteps) { + if (speechHandle.numSteps >= maxToolSteps + 1) { this.logger.warn( { speech_id: speechHandle.id, max_tool_steps: maxToolSteps }, 'maximum number of function calls steps reached', @@ -3889,9 +3975,10 @@ export class AgentActivity implements RecognitionHooks { return; } - // When pausing/draining, we ensure that all speech_tasks complete fully. 
- // This means that even if the SpeechHandle themselves have finished, - // we still wait for the entire execution (e.g function_tools) + // Wait for all speech tasks to complete fully (including function tools). + // Tool-level wedges are bounded inside `ToolExecutor.drain()` (it races the + // abort signal and warns on non-abortable tools), so this wait stays + // responsive without a blanket timer here. await this._mainTask.result; } } @@ -3992,11 +4079,14 @@ export class AgentActivity implements RecognitionHooks { } async close(): Promise { + this.closeAbort.abort(); + const unlock = await this.lock.lock(); try { this.cancelPreemptiveGeneration(); await cancelAndWait(Array.from(this.speechTasks), AgentActivity.REPLY_TASK_CANCEL_TIMEOUT); + await this._toolExecutor.drain(); if (this._currentSpeech && !this._currentSpeech.done()) { this._currentSpeech._markDone(); @@ -4006,6 +4096,7 @@ export class AgentActivity implements RecognitionHooks { this.cancelSpeechPauseTask = undefined; await this._closeSessionResources(); + await this._toolExecutor.aclose(); if (this._mainTask) { await this._mainTask.cancelAndWait(); @@ -4319,9 +4410,101 @@ export class AgentActivity implements RecognitionHooks { this.realtimeSpans?.clear(); await this.realtimeSession?.close(); await this.audioRecognition?.close(); + await this.closeToolsets(); this.realtimeSession = undefined; this.audioRecognition = undefined; } + + private async setupToolsets(): Promise { + // Guard against resume() re-entering _startSession on an activity whose toolsets are + // already initialized. 
+ if (this._toolsetsSetup) return; + this._toolsetsSetup = true; + + const sessionToolsets = new ToolContext(this.agentSession.tools).toolsets; + const agentToolsets = this.agent.toolCtx.toolsets; + + for (const toolset of sessionToolsets) { + if (toolset instanceof AsyncToolset) { + toolset._attachActivity({ activity: null, session: this.agentSession }); + } + } + + for (const toolset of agentToolsets) { + if (toolset instanceof AsyncToolset) { + toolset._attachActivity({ activity: this, session: this.agentSession }); + } + } + + // Agent toolsets are set up (and torn down) per activity. Session toolsets are set up ONCE for + // the session's lifetime — re-running setup() on every handoff would acquire resources (DB + // pools, MCP clients, listeners) without a matching aclose(), which only runs once at session + // close. closeToolsets() likewise only closes the agent's toolsets. + const toSetup = [...agentToolsets]; + if (!this.agentSession._sessionToolsetsSetup) { + this.agentSession._sessionToolsetsSetup = true; + toSetup.push(...sessionToolsets); + } + + await this.setupToolsetList(toSetup); + // Re-flatten now that any factory toolsets have resolved their tools, so they're advertised. + this.agent._toolCtx.updateTools(this.agent._toolCtx.tools); + } + + private async closeToolsets(): Promise { + if (!this._toolsetsSetup) return; + this._toolsetsSetup = false; + await this.closeToolsetList(this.agent.toolCtx.toolsets); + } + + /** + * Refresh the agent's tool context after a dynamic toolset pushed a new tool list (via + * `ToolsetContext.updateTools`), routing through updateTools so history, chat context, and the + * realtime session stay in sync. The next LLM turn picks up the new tools automatically. 
+ */ + private async onToolsetToolsChanged(): Promise { + if (!this._toolsetsSetup) return; + const current = this.agent._toolCtx; + if (new ToolContext(current.tools).equals(current)) return; + // Same toolset entries, so updateTools' setup/close steps are no-ops (no re-entrancy here). + await this.updateTools(current.tools); + } + + private async setupToolsetList(toolsets: readonly Toolset[]): Promise { + const outputs = await Promise.allSettled( + toolsets.map((ts) => + ts.setup({ + // A dynamic toolset pushes a changed tool list here; re-flatten and re-advertise it. + // Route through the session's current activity: a session toolset is set up once (on the + // first activity) but its pushes must re-advertise against whatever agent is active now, + // not the original (possibly closed) activity. + updateTools: (tools) => { + ts._setTools(tools); + const activity = this.agentSession._activity ?? this; + void activity + .onToolsetToolsChanged() + .catch((error) => + activity.logger.error({ error }, 'error re-advertising toolset tools'), + ); + }, + }), + ), + ); + for (const output of outputs) { + if (output.status === 'rejected') { + this.logger.error({ error: output.reason }, 'error setting up toolset'); + } + } + } + + private async closeToolsetList(toolsets: readonly Toolset[]): Promise { + const outputs = await Promise.allSettled(toolsets.map((ts) => ts.aclose())); + for (const output of outputs) { + if (output.status === 'rejected') { + this.logger.error({ error: output.reason }, 'error closing toolset'); + } + } + } } function toOaiToolChoice(toolChoice: ToolChoice | null): ToolChoice | undefined { diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index d5ac4eb2e..411c199f8 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -9,6 +9,7 @@ import { ThrowsPromise } from '@livekit/throws-transformer/throws'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; 
import type { Context, Span } from '@opentelemetry/api'; import { context as otelContext, trace } from '@opentelemetry/api'; +import { AsyncLocalStorage } from 'node:async_hooks'; import { EventEmitter } from 'node:events'; import type { ReadableStream } from 'node:stream/web'; import type { z } from 'zod'; @@ -32,7 +33,15 @@ import { ChatMessage, type Instructions, } from '../llm/chat_context.js'; -import type { LLM, RealtimeModel, RealtimeModelError, ToolChoice } from '../llm/index.js'; +import type { + LLM, + RealtimeModel, + RealtimeModelError, + ToolChoice, + ToolContextEntry, + ToolContextInit, +} from '../llm/index.js'; +import { normalizeToolContextInit } from '../llm/index.js'; import type { LLMError } from '../llm/llm.js'; import { log } from '../log.js'; import { type ModelUsage, ModelUsageCollector, filterZeroValues } from '../metrics/model_usage.js'; @@ -46,7 +55,7 @@ import { type ResolvedSessionConnectOptions, type SessionConnectOptions, } from '../types.js'; -import { Task, asError } from '../utils.js'; +import { Event, Task, asError } from '../utils.js'; import type { VAD } from '../vad.js'; import type { Agent } from './agent.js'; import { @@ -94,6 +103,11 @@ import { import type { UnknownUserData } from './run_context.js'; import type { SpeechHandle } from './speech_handle.js'; import { RunResult } from './testing/run_result.js'; +import { + type AsyncToolOptions, + type ToolHandlingOptions, + resolveAsyncToolOptions, +} from './tool_executor.js'; import type { TextTransform } from './transcription/text_transforms.js'; import type { EndpointingOptions } from './turn_config/endpointing.js'; import type { InterruptionOptions } from './turn_config/interruption.js'; @@ -147,6 +161,8 @@ const RECORDING_ALL_OFF: ResolvedRecordingOptions = { transcript: false, }; +const idleHoldStorage = new AsyncLocalStorage(); + /** * Resolve a `record` argument into explicit per-category flags. 
A boolean turns * every category on or off; a partial object is merged onto all-on so omitted @@ -242,6 +258,8 @@ export type AgentSessionOptions = { tts?: TTS | TTSModelString; userData?: UserData; connOptions?: SessionConnectOptions; + tools?: ToolContextInit; + toolHandling?: ToolHandlingOptions; /** @deprecated use turnHandling.turnDetection instead */ turnDetection?: TurnDetectionMode; @@ -356,6 +374,7 @@ export class AgentSession< private _chatCtx: ChatContext; private _userData: UserData | undefined; + private _tools: readonly ToolContextEntry[] = []; private _userState: UserState = 'listening'; private _agentState: AgentState = 'initializing'; @@ -365,6 +384,8 @@ export class AgentSession< private closing = false; private closingTask: Promise | null = null; private userAwayTimer: NodeJS.Timeout | null = null; + private idleHolds = 0; + private idleReleased = new Event(); private _aecWarmupTimer: NodeJS.Timeout | null = null; @@ -417,6 +438,12 @@ export class AgentSession< /** @internal Resolved per-category recording options for this session. */ _recordingOptions: ResolvedRecordingOptions = { ...RECORDING_ALL_OFF }; + /** @internal */ + _asyncToolOptions: AsyncToolOptions = resolveAsyncToolOptions(); + + /** @internal */ + _sessionToolsetsSetup = false; + /** @internal True when any recording category is enabled. */ get _enableRecording(): boolean { return ( @@ -454,7 +481,17 @@ export class AgentSession< const { agentSessionOptions: opts, legacyVoiceOptions } = migrateLegacyOptions(options); - const { vad, stt, llm, tts, userData, connOptions, ...resolvedSessionOptions } = opts; + const { + vad, + stt, + llm, + tts, + userData, + connOptions, + tools, + toolHandling, + ...resolvedSessionOptions + } = opts; // Merge user-provided connOptions with defaults this._connOptions = { sttConnOptions: { ...DEFAULT_API_CONNECT_OPTIONS, ...connOptions?.sttConnOptions }, @@ -509,6 +546,8 @@ export class AgentSession< : configuredTurnDetection ?? 
new InferenceTurnDetector(); this._interruptionDetection = resolvedSessionOptions.turnHandling.interruption?.mode; this._userData = userData; + this._tools = normalizeToolContextInit(tools ?? []); + this._asyncToolOptions = resolveAsyncToolOptions(toolHandling?.asyncOptions); // configurable IO this._input = new AgentInput(this.onAudioInputChanged); @@ -522,14 +561,22 @@ export class AgentSession< this._onUserInputTranscribed = this._onUserInputTranscribed.bind(this); this.on(AgentSessionEventTypes.UserInputTranscribed, this._onUserInputTranscribed); + this.idleReleased.set(); } emit( event: K, ...args: Parameters ): boolean { - const eventData = args[0] as AgentEvent; - this._recordedEvents.push(eventData); + // Only retain events when recording is actually enabled. Otherwise this + // array grows unbounded for the entire (potentially hours-long) session, + // pinning every event's graph (SpeechHandle, OTel spans/contexts, streams) + // and leaking memory even though the events are never reported. The buffer + // is only consumed by makeSessionReport() when recording is enabled. 
+ if (this._enableRecording) { + const eventData = args[0] as AgentEvent; + this._recordedEvents.push(eventData); + } return super.emit(event, ...args); } @@ -1259,6 +1306,58 @@ export class AgentSession< return this.agent; } + get tools(): readonly ToolContextEntry[] { + return [...this._tools]; + } + + async waitForIdle(): Promise { + while (true) { + if (this.closingTask) { + throw new Error('AgentSession is closing'); + } + const activity = this.activity; + if (!activity) { + throw new Error('AgentSession has no active AgentActivity'); + } + try { + await activity.waitForIdle(); + return activity; + } catch (error) { + if (this.activity === activity) { + throw error; + } + } + } + } + + async waitForIdleAndHold(fn: (activity: AgentActivity) => Promise | T): Promise { + const activity = await this.waitForIdle(); + this.idleHolds += 1; + this.idleReleased.clear(); + try { + return await idleHoldStorage.run(true, () => fn(activity)); + } finally { + this.idleHolds -= 1; + if (this.idleHolds === 0) { + this.idleReleased.set(); + } + } + } + + /** + * Wait until any foreground idle-hold (`waitForIdleAndHold`) is released. + * Returns `true` if it actually waited for a release — callers use that to + * re-verify idleness, since work may have resumed during the hold. 
+ * @internal + */ + async _waitForIdleHoldReleased(): Promise { + if (this.idleHolds > 0 && !idleHoldStorage.getStore()) { + await this.idleReleased.wait(); + return true; + } + return false; + } + async close(): Promise { await this.closeImpl(CloseReason.USER_INITIATED); } @@ -1563,6 +1662,12 @@ export class AgentSession< await this.activity?.close(); this.activity = undefined; + const sessionToolsets = this._tools.filter( + (tool): tool is ToolContextEntry & { aclose: () => Promise } => + typeof (tool as { aclose?: unknown }).aclose === 'function', + ); + await Promise.allSettled(sessionToolsets.map((toolset) => toolset.aclose())); + if (this.sessionSpan) { this.sessionSpan.end(); this.sessionSpan = undefined; diff --git a/agents/src/voice/agent_task_handoff_eou.test.ts b/agents/src/voice/agent_task_handoff_eou.test.ts index 6006a6e97..e490852fe 100644 --- a/agents/src/voice/agent_task_handoff_eou.test.ts +++ b/agents/src/voice/agent_task_handoff_eou.test.ts @@ -213,15 +213,16 @@ describe('AgentTask handoff end-of-turn timing', () => { const makeStepTask = (step: number): AgentTask<{ step: number }> => { const task: AgentTask<{ step: number }> = new AgentTask<{ step: number }>({ instructions: `You are handling step ${step}. Wait for the caller to finish it.`, - tools: { - completeStep: tool({ + tools: [ + tool({ + name: 'completeStep', description: `Record that step ${step} is complete.`, execute: async () => { toolExecutedAt[step] = Date.now(); task.complete({ step }); }, }), - }, + ], }); return task; }; diff --git a/agents/src/voice/agent_v2.ts b/agents/src/voice/agent_v2.ts new file mode 100644 index 000000000..2f00402ce --- /dev/null +++ b/agents/src/voice/agent_v2.ts @@ -0,0 +1,472 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import type { AudioFrame } from '@livekit/rtc-node'; +import type { ReadableStream } from 'node:stream/web'; +import type { Instructions, ReadonlyChatContext } from '../llm/chat_context.js'; +import type { + ChatChunk, + ChatContext, + ChatMessage, + LLM, + RealtimeModel, + ToolContext, +} from '../llm/index.js'; +import type { STT, SpeechEvent } from '../stt/index.js'; +import type { TTS } from '../tts/index.js'; +import type { FlushSentinel } from '../types.js'; +import { readStream, toStream } from '../utils.js'; +import type { VAD } from '../vad.js'; +import type { Agent, AgentOptions, AgentTask, AgentTaskOptions, ModelSettings } from './agent.js'; +import type { AgentSession } from './agent_session.js'; +import type { TurnHandlingOptions } from './turn_config/turn_handling.js'; + +/** Context passed to hooks created with `Agent.create()`. */ +export interface AgentContext { + /** The agent instance currently running the hook. */ + agent: Agent; + /** Voice activity detector configured for the agent. */ + vad: VAD | undefined; + /** Speech-to-text model configured for the agent. */ + stt: STT | undefined; + /** LLM or realtime model configured for the agent. */ + llm: LLM | RealtimeModel | undefined; + /** Text-to-speech model configured for the agent. */ + tts: TTS | undefined; + /** Whether TTS-aligned transcripts are enabled for the agent. */ + useTtsAlignedTranscript: boolean | undefined; + /** Readonly view of the agent's current chat context. */ + chatCtx: ReadonlyChatContext; + /** Agent identifier. */ + id: string; + /** Agent instructions. */ + instructions: string | Instructions; + /** Copy of the agent tool context. */ + toolCtx: ToolContext; + /** Current session for the agent. */ + session: AgentSession; + /** Agent-level turn handling configuration. */ + turnHandling: Partial | undefined; + /** Minimum delay between consecutive speech. 
*/ + minConsecutiveSpeechDelay: number | undefined; +} + +/** Return type for stream hooks. Returning `null` stops that pipeline node. */ +export type AgentHookNodeResult = AsyncIterable | Promise | null> | null; + +export interface AgentHooks< + UserData, + ContextT extends AgentContext = AgentContext, +> { + /** Called when the agent becomes active in a session. */ + onEnter?: (ctx: ContextT) => Promise | void; + /** Called when the agent is leaving the active session. */ + onExit?: (ctx: ContextT) => Promise | void; + /** Called after the user's turn has been committed to the chat context. */ + onUserTurnCompleted?: ( + ctx: ContextT, + chatCtx: ChatContext, + newMessage: ChatMessage, + ) => Promise | void; + /** Transforms incoming audio into speech events or transcript text for the agent. */ + sttNode?: ( + ctx: ContextT, + audio: AsyncIterable, + modelSettings: ModelSettings, + ) => AgentHookNodeResult; + /** Produces LLM chunks or text from the current chat and tool context. */ + llmNode?: ( + ctx: ContextT, + chatCtx: ChatContext, + toolCtx: ToolContext, + modelSettings: ModelSettings, + ) => AgentHookNodeResult; + /** Synthesizes agent text into audio frames for playout. */ + ttsNode?: ( + ctx: ContextT, + text: AsyncIterable, + modelSettings: ModelSettings, + ) => AgentHookNodeResult; + /** Processes realtime model audio before it is sent to the agent output. */ + realtimeAudioOutputNode?: ( + ctx: ContextT, + audio: AsyncIterable, + modelSettings: ModelSettings, + ) => AgentHookNodeResult; +} + +export interface AgentCreateOptions + extends AgentOptions, + AgentHooks {} + +/** Context passed to hooks created with `AgentTask.create()`. */ +export interface AgentTaskContext + extends AgentContext { + /** The task instance currently running the hook. */ + agent: AgentTask; + /** Complete the task with either a result or an error. 
*/ + complete(result: ResultT | Error): void; +} + +export interface AgentTaskCreateOptions + extends AgentTaskOptions, + AgentHooks> {} + +// agent.ts passes these runtime base classes in to avoid a circular runtime import. +type AgentCtor = new (options: AgentOptions) => Agent; + +type AgentTaskCtor = new ( + options: AgentTaskOptions, +) => AgentTask; + +export function createAgentV2( + AgentBase: AgentCtor, + options: AgentCreateOptions, +): Agent { + class AgentV2 extends AgentBase { + private readonly hookAdapter: AgentHookAdapter>; + + constructor({ + onEnter, + onExit, + onUserTurnCompleted, + sttNode, + llmNode, + ttsNode, + realtimeAudioOutputNode, + ...agentOptions + }: AgentCreateOptions) { + super({ + ...agentOptions, + id: agentOptions.id ?? 'default_agent', + }); + + this.hookAdapter = new AgentHookAdapter( + { + onEnter, + onExit, + onUserTurnCompleted, + sttNode, + llmNode, + ttsNode, + realtimeAudioOutputNode, + }, + new AgentHookContext(this), + ); + } + + override async onEnter(): Promise { + return this.hookAdapter.onEnter(() => super.onEnter()); + } + + override async onExit(): Promise { + return this.hookAdapter.onExit(() => super.onExit()); + } + + override async onUserTurnCompleted( + chatCtx: ChatContext, + newMessage: ChatMessage, + ): Promise { + return this.hookAdapter.onUserTurnCompleted(chatCtx, newMessage, () => + super.onUserTurnCompleted(chatCtx, newMessage), + ); + } + + override async sttNode( + audio: ReadableStream, + modelSettings: ModelSettings, + ): Promise | null> { + return this.hookAdapter.sttNode(audio, modelSettings, () => + super.sttNode(audio, modelSettings), + ); + } + + override async llmNode( + chatCtx: ChatContext, + toolCtx: ToolContext, + modelSettings: ModelSettings, + ): Promise | null> { + return this.hookAdapter.llmNode(chatCtx, toolCtx, modelSettings, () => + super.llmNode(chatCtx, toolCtx, modelSettings), + ); + } + + override async ttsNode( + text: ReadableStream, + modelSettings: ModelSettings, + ): 
Promise | null> { + return this.hookAdapter.ttsNode(text, modelSettings, () => + super.ttsNode(text, modelSettings), + ); + } + + override async realtimeAudioOutputNode( + audio: ReadableStream, + modelSettings: ModelSettings, + ): Promise | null> { + return this.hookAdapter.realtimeAudioOutputNode(audio, modelSettings, () => + super.realtimeAudioOutputNode(audio, modelSettings), + ); + } + } + + return new AgentV2(options); +} + +export function createAgentTaskV2( + AgentTaskBase: AgentTaskCtor, + options: AgentTaskCreateOptions, +): AgentTask { + class AgentTaskV2 extends AgentTaskBase { + private readonly hookAdapter: AgentHookAdapter>; + + constructor({ + onEnter, + onExit, + onUserTurnCompleted, + sttNode, + llmNode, + ttsNode, + realtimeAudioOutputNode, + ...taskOptions + }: AgentTaskCreateOptions) { + super({ + ...taskOptions, + id: taskOptions.id ?? 'default_agent', + }); + + this.hookAdapter = new AgentHookAdapter( + { + onEnter, + onExit, + onUserTurnCompleted, + sttNode, + llmNode, + ttsNode, + realtimeAudioOutputNode, + }, + new AgentTaskHookContext(this), + ); + } + + override async onEnter(): Promise { + return this.hookAdapter.onEnter(() => super.onEnter()); + } + + override async onExit(): Promise { + return this.hookAdapter.onExit(() => super.onExit()); + } + + override async onUserTurnCompleted( + chatCtx: ChatContext, + newMessage: ChatMessage, + ): Promise { + return this.hookAdapter.onUserTurnCompleted(chatCtx, newMessage, () => + super.onUserTurnCompleted(chatCtx, newMessage), + ); + } + + override async sttNode( + audio: ReadableStream, + modelSettings: ModelSettings, + ): Promise | null> { + return this.hookAdapter.sttNode(audio, modelSettings, () => + super.sttNode(audio, modelSettings), + ); + } + + override async llmNode( + chatCtx: ChatContext, + toolCtx: ToolContext, + modelSettings: ModelSettings, + ): Promise | null> { + return this.hookAdapter.llmNode(chatCtx, toolCtx, modelSettings, () => + super.llmNode(chatCtx, toolCtx, 
modelSettings), + ); + } + + override async ttsNode( + text: ReadableStream, + modelSettings: ModelSettings, + ): Promise | null> { + return this.hookAdapter.ttsNode(text, modelSettings, () => + super.ttsNode(text, modelSettings), + ); + } + + override async realtimeAudioOutputNode( + audio: ReadableStream, + modelSettings: ModelSettings, + ): Promise | null> { + return this.hookAdapter.realtimeAudioOutputNode(audio, modelSettings, () => + super.realtimeAudioOutputNode(audio, modelSettings), + ); + } + } + + return new AgentTaskV2(options); +} + +class AgentHookAdapter> { + constructor( + private readonly hooks: AgentHooks, + private readonly context: ContextT, + ) {} + + async onEnter(fallback: () => Promise): Promise { + if (!this.hooks.onEnter) { + return fallback(); + } + + return this.hooks.onEnter(this.context); + } + + async onExit(fallback: () => Promise): Promise { + if (!this.hooks.onExit) { + return fallback(); + } + + return this.hooks.onExit(this.context); + } + + async onUserTurnCompleted( + chatCtx: ChatContext, + newMessage: ChatMessage, + fallback: () => Promise, + ): Promise { + if (!this.hooks.onUserTurnCompleted) { + return fallback(); + } + + return this.hooks.onUserTurnCompleted(this.context, chatCtx, newMessage); + } + + async sttNode( + audio: ReadableStream, + modelSettings: ModelSettings, + fallback: () => Promise | null>, + ): Promise | null> { + if (!this.hooks.sttNode) { + return fallback(); + } + + const result = await this.hooks.sttNode(this.context, readStream(audio), modelSettings); + return result === null ? null : toStream(result); + } + + async llmNode( + chatCtx: ChatContext, + toolCtx: ToolContext, + modelSettings: ModelSettings, + fallback: () => Promise | null>, + ): Promise | null> { + if (!this.hooks.llmNode) { + return fallback(); + } + + const result = await this.hooks.llmNode( + this.context, + chatCtx, + toolCtx as ToolContext, + modelSettings, + ); + return result === null ? 
null : toStream(result); + } + + async ttsNode( + text: ReadableStream, + modelSettings: ModelSettings, + fallback: () => Promise | null>, + ): Promise | null> { + if (!this.hooks.ttsNode) { + return fallback(); + } + + const result = await this.hooks.ttsNode(this.context, readStream(text), modelSettings); + return result === null ? null : toStream(result); + } + + async realtimeAudioOutputNode( + audio: ReadableStream, + modelSettings: ModelSettings, + fallback: () => Promise | null>, + ): Promise | null> { + if (!this.hooks.realtimeAudioOutputNode) { + return fallback(); + } + + const result = await this.hooks.realtimeAudioOutputNode( + this.context, + readStream(audio), + modelSettings, + ); + return result === null ? null : toStream(result); + } +} + +class AgentHookContext implements AgentContext { + constructor(readonly agent: Agent) {} + + get vad(): VAD | undefined { + return this.agent.vad; + } + + get stt(): STT | undefined { + return this.agent.stt; + } + + get llm(): LLM | RealtimeModel | undefined { + return this.agent.llm; + } + + get tts(): TTS | undefined { + return this.agent.tts; + } + + get useTtsAlignedTranscript(): boolean | undefined { + return this.agent.useTtsAlignedTranscript; + } + + get chatCtx(): ReadonlyChatContext { + return this.agent.chatCtx; + } + + get id(): string { + return this.agent.id; + } + + get instructions(): string | Instructions { + return this.agent.instructions; + } + + get toolCtx(): ToolContext { + return this.agent.toolCtx; + } + + get session(): AgentSession { + return this.agent.session; + } + + get turnHandling(): Partial | undefined { + return this.agent.turnHandling; + } + + get minConsecutiveSpeechDelay(): number | undefined { + return this.agent.minConsecutiveSpeechDelay; + } +} + +class AgentTaskHookContext + extends AgentHookContext + implements AgentTaskContext +{ + declare readonly agent: AgentTask; + + constructor(agent: AgentTask) { + super(agent); + } + + complete(result: ResultT | Error): void { + 
this.agent.complete(result); + } +} diff --git a/agents/src/voice/amd.test.ts b/agents/src/voice/amd.test.ts index 1b79ffa80..0c7bf8307 100644 --- a/agents/src/voice/amd.test.ts +++ b/agents/src/voice/amd.test.ts @@ -7,7 +7,7 @@ import type { ChatContext } from '../llm/chat_context.js'; import { FunctionCall } from '../llm/chat_context.js'; import type { ChatChunk } from '../llm/llm.js'; import { LLM, type LLMStream } from '../llm/llm.js'; -import type { ToolChoice, ToolContext } from '../llm/tool_context.js'; +import type { ToolChoice, ToolContextLike } from '../llm/tool_context.js'; import type { SpeechEvent, SpeechStream } from '../stt/stt.js'; import { STT } from '../stt/stt.js'; import type { APIConnectOptions } from '../types.js'; @@ -48,7 +48,7 @@ class StaticLLM extends LLM { connOptions: _connOptions, }: { chatCtx: ChatContext; - toolCtx?: ToolContext; + toolCtx?: ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: ToolChoice; @@ -270,7 +270,7 @@ describe('AMD', () => { } chat({}: { chatCtx: ChatContext; - toolCtx?: ToolContext; + toolCtx?: ToolContextLike; connOptions?: APIConnectOptions; }): LLMStream { return { @@ -547,7 +547,7 @@ describe('AMD', () => { label(): string { return 'postpone-llm'; } - chat({}: { chatCtx: ChatContext; toolCtx?: ToolContext }): LLMStream { + chat({}: { chatCtx: ChatContext; toolCtx?: ToolContextLike }): LLMStream { callCount += 1; const isFirst = callCount === 1; return { diff --git a/agents/src/voice/amd.ts b/agents/src/voice/amd.ts index 848d6a511..be8042d78 100644 --- a/agents/src/voice/amd.ts +++ b/agents/src/voice/amd.ts @@ -14,8 +14,7 @@ import type { LLMModels, STTModels } from '../inference/index.js'; import { ChatContext } from '../llm/chat_context.js'; import type { FunctionCall } from '../llm/chat_context.js'; import { LLM, type LLMStream } from '../llm/llm.js'; -import { isFunctionTool, tool } from '../llm/tool_context.js'; -import type { ToolContext } from 
'../llm/tool_context.js'; +import { ToolContext, type ToolContextEntry, isFunctionTool, tool } from '../llm/tool_context.js'; import { log } from '../log.js'; import { STT, SpeechEventType, type SpeechStream } from '../stt/stt.js'; import { traceTypes, tracer } from '../telemetry/index.js'; @@ -1181,6 +1180,7 @@ export class AMD extends (EventEmitter as new () => TypedEmitter) const isStale = (): boolean => generation !== this.detectGeneration || this.settled; const savePrediction = tool({ + name: 'save_prediction', description: 'Save the AMD prediction to the verdict.', parameters: z.object({ label: z.enum([ @@ -1213,6 +1213,7 @@ export class AMD extends (EventEmitter as new () => TypedEmitter) }); const postponeTermination = tool({ + name: 'postpone_termination', description: 'Postpone the termination of the classification task. ' + 'Use when the transcript is ambiguous and more audio is expected.', @@ -1244,10 +1245,11 @@ export class AMD extends (EventEmitter as new () => TypedEmitter) }, }); - const toolCtx: ToolContext = { save_prediction: savePrediction }; + const toolList: ToolContextEntry[] = [savePrediction]; if (this.extensionCount < MAX_EXTENSIONS) { - toolCtx.postpone_termination = postponeTermination; + toolList.push(postponeTermination); } + const toolCtx = new ToolContext(toolList); const chatCtx = new ChatContext(); chatCtx.addMessage({ role: 'system', content: this.prompt }); @@ -1282,7 +1284,7 @@ export class AMD extends (EventEmitter as new () => TypedEmitter) // Execute tool calls (save_prediction populates `savedResult`, // postpone_termination mutates the silence timer and returns). 
for (const tc of toolCalls) { - const fnTool = toolCtx[tc.name]; + const fnTool = toolCtx.getFunctionTool(tc.name); if (!fnTool || !isFunctionTool(fnTool)) continue; let parsedArgs: unknown = {}; try { diff --git a/agents/src/voice/audio_recognition.ts b/agents/src/voice/audio_recognition.ts index dc89a1b7e..0232d766b 100644 --- a/agents/src/voice/audio_recognition.ts +++ b/agents/src/voice/audio_recognition.ts @@ -69,9 +69,9 @@ export interface EndOfTurnInfo { /** Confidence score of the transcript (0-1). */ transcriptConfidence: number; /** Delay from speech stop to final transcription in milliseconds. */ - transcriptionDelay: number; + transcriptionDelay: number | undefined; /** Delay from speech stop to end of utterance detection in milliseconds. */ - endOfUtteranceDelay: number; + endOfUtteranceDelay: number | undefined; /** Timestamp when user started speaking (milliseconds since epoch). */ startedSpeakingAt: number | undefined; /** Timestamp when user stopped speaking (milliseconds since epoch). 
*/ @@ -85,6 +85,46 @@ export interface EndOfTurnInfo { skipReply?: boolean; } +type EndOfTurnMetrics = { + startedSpeakingAt: number | undefined; + stoppedSpeakingAt: number | undefined; + transcriptionDelay: number | undefined; + endOfUtteranceDelay: number | undefined; +}; + +function computeEndOfTurnMetrics({ + speechStartTime, + lastSpeakingTime, + lastFinalTranscriptTime, + now, +}: { + speechStartTime: number | undefined; + lastSpeakingTime: number | undefined; + lastFinalTranscriptTime: number; + now: number; +}): EndOfTurnMetrics { + if ( + lastFinalTranscriptTime === 0 || + lastSpeakingTime === undefined || + speechStartTime === undefined || + lastSpeakingTime < speechStartTime + ) { + return { + startedSpeakingAt: undefined, + stoppedSpeakingAt: undefined, + transcriptionDelay: undefined, + endOfUtteranceDelay: undefined, + }; + } + + return { + startedSpeakingAt: speechStartTime, + stoppedSpeakingAt: lastSpeakingTime, + transcriptionDelay: Math.max(lastFinalTranscriptTime - lastSpeakingTime, 0), + endOfUtteranceDelay: Math.max(now - lastSpeakingTime, 0), + }; +} + export interface PreemptiveGenerationInfo { newTranscript: string; transcriptConfidence: number; @@ -1646,39 +1686,31 @@ export class AudioRecognition { this.finalTranscriptConfidence.length : 0; - let startedSpeakingAt: number | undefined; - let stoppedSpeakingAt: number | undefined; - let transcriptionDelay: number | undefined; - let endOfUtteranceDelay: number | undefined; - - // sometimes, we can't calculate the metrics because VAD was unreliable. 
- // in this case, we just ignore the calculation, it's better than providing likely wrong values - if ( - lastFinalTranscriptTime !== 0 && - lastSpeakingTime !== undefined && - speechStartTime !== undefined - ) { - startedSpeakingAt = speechStartTime; - stoppedSpeakingAt = lastSpeakingTime; - transcriptionDelay = Math.max(lastFinalTranscriptTime - lastSpeakingTime, 0); - endOfUtteranceDelay = Date.now() - lastSpeakingTime; - } + // sometimes, we can't calculate the metrics because VAD was unreliable or + // the speaking anchor is stale/out-of-order. in this case, we just ignore the + // calculation, it's better than providing likely wrong values + const metrics = computeEndOfTurnMetrics({ + speechStartTime, + lastSpeakingTime, + lastFinalTranscriptTime, + now: Date.now(), + }); const committed = await this.hooks.onEndOfTurn({ newTranscript: this.audioTranscript, transcriptConfidence: confidenceAvg, - transcriptionDelay: transcriptionDelay ?? 0, - endOfUtteranceDelay: endOfUtteranceDelay ?? 0, - startedSpeakingAt, - stoppedSpeakingAt, + transcriptionDelay: metrics.transcriptionDelay, + endOfUtteranceDelay: metrics.endOfUtteranceDelay, + startedSpeakingAt: metrics.startedSpeakingAt, + stoppedSpeakingAt: metrics.stoppedSpeakingAt, }); if (committed) { this._endUserTurnSpan({ transcript: this.audioTranscript, confidence: confidenceAvg, - transcriptionDelay: transcriptionDelay ?? 0, - endOfUtteranceDelay: endOfUtteranceDelay ?? 0, + transcriptionDelay: metrics.transcriptionDelay ?? 0, + endOfUtteranceDelay: metrics.endOfUtteranceDelay ?? 
0, }); // clear the transcript if the user turn was committed diff --git a/agents/src/voice/generation.ts b/agents/src/voice/generation.ts index 9eaf98194..c047a6d46 100644 --- a/agents/src/voice/generation.ts +++ b/agents/src/voice/generation.ts @@ -17,6 +17,7 @@ import { } from '../llm/chat_context.js'; import type { ChatChunk } from '../llm/llm.js'; import { + type JSONObject, type ToolChoice, type ToolContext, ToolError, @@ -58,6 +59,8 @@ import { } from './io.js'; import { RunContext } from './run_context.js'; import type { SpeechHandle } from './speech_handle.js'; +import { getMockTool } from './testing/run_result.js'; +import { ToolExecutor, buildExecutorMap } from './tool_executor.js'; import { type TextTransform, applyTextTransforms } from './transcription/text_transforms.js'; export const DEFAULT_TTS_READ_IDLE_TIMEOUT_MS = 10_000; @@ -836,10 +839,15 @@ async function forwardAudio( const reader = ttsStream.getReader(); let resampler: AudioResampler | null = null; - // The audio output is shared across overlapping segments, so ignore a - // PLAYBACK_STARTED from another segment until we capture our own first frame. - // Resolving `firstFrameFut` early skips resampler creation and pushes an - // unresampled frame (`RtcError: sample_rate and num_channels don't match`). + // The audio output is shared across overlapping segments. When a speech is + // interrupted, the main loop immediately authorizes the next speech, so this + // forwarder can register its listener while the interrupted segment's teardown + // is still emitting PLAYBACK_STARTED on the same output. Only honor the event + // once this loop has captured its own first frame, so a stray event from + // another segment can't resolve our `firstFrameFut` prematurely. A premature + // resolution skips resampler creation (gated on `!firstFrameFut.done`) and + // pushes an unresampled frame to the AudioSource, raising + // `RtcError: sample_rate and num_channels don't match`. 
let hasCapturedOwnFrame = false; const onPlaybackStarted = (ev: { createdAt: number }) => { @@ -954,7 +962,6 @@ export function performToolExecutions({ output: [], firstToolStartedFuture: new Future(), }; - const toolCompleted = (out: ToolExecutionOutput) => { onToolExecutionCompleted(out); toolOutput.output.push(out); @@ -964,6 +971,19 @@ export function performToolExecutions({ const signal = controller.signal; const reader = toolCallStream.getReader(); + // Production always has an activity (and thus a shared executor). Fall back to a standalone + // executor when it's absent (edge cases / unit tests) instead of dropping every tool call, + // which would leave callers awaiting `firstToolStartedFuture` hanging forever. + const activity = session._activity; + const defaultExecutor = activity?._toolExecutor ?? new ToolExecutor({ owningActivity: null }); + + // Route AsyncToolset members to their own executor so session-scoped async + // tools survive handoff; everything else falls back to the activity executor. 
+ const executorByName = buildExecutorMap({ + toolsets: toolCtx.toolsets, + defaultExecutor, + }); + const tasks: Task[] = []; while (!signal.aborted) { const { done, value: toolCall } = await reader.read(); @@ -983,8 +1003,10 @@ export function performToolExecutions({ // TODO(brian): assert other toolChoice values - const tool = toolCtx[toolCall.name]; + const tool = toolCtx.getFunctionTool(toolCall.name); if (!tool) { + const availableTools = sortedToolNames(toolCtx).join(', '); + const message = `Unknown function: ${toolCall.name} - available tools: ${availableTools}`; logger.warn( { function: toolCall.name, @@ -992,6 +1014,12 @@ export function performToolExecutions({ }, `unknown AI function ${toolCall.name}`, ); + toolCompleted( + createToolOutput({ + toolCall, + exception: new ToolError(message), + }), + ); continue; } @@ -1050,10 +1078,6 @@ export function performToolExecutions({ continue; } - if (!toolOutput.firstToolStartedFuture.done) { - toolOutput.firstToolStartedFuture.resolve(); - } - onToolExecutionStarted(toolCall); logger.info( @@ -1136,10 +1160,37 @@ export function performToolExecutions({ const toolExecution = functionCallStorage.run( { functionCall: toolCall, speechHandle }, async () => { - return await tool.execute(parsedArgs, { - ctx: new RunContext(session, speechHandle, toolCall), - toolCallId: toolCall.callId, + const runCtx = new RunContext(session, speechHandle, toolCall); + const mock = getMockTool(session.currentAgent, toolCall.name); + const toolToExecute = mock + ? { + ...tool, + execute: mock, + } + : tool; + + if (mock) { + logger.debug( + { + function: toolCall.name, + arguments: parsedArgs, + speech_id: speechHandle.id, + }, + 'executing mock tool', + ); + } + + const executor = executorByName.get(toolCall.name) ?? 
defaultExecutor; + return await executor.execute({ + tool: toolToExecute, + runCtx, + rawArguments: parsedArgs as JSONObject, abortSignal: signal, + onUserToolStarted: () => { + if (!toolOutput.firstToolStartedFuture.done) { + toolOutput.firstToolStartedFuture.resolve(); + } + }, }); }, ); diff --git a/agents/src/voice/generation_tools.test.ts b/agents/src/voice/generation_tools.test.ts index 2589fe7f4..0cd95e636 100644 --- a/agents/src/voice/generation_tools.test.ts +++ b/agents/src/voice/generation_tools.test.ts @@ -4,7 +4,7 @@ import { ReadableStream as NodeReadableStream } from 'stream/web'; import { afterEach, describe, expect, it, vi } from 'vitest'; import { z } from 'zod'; -import { FunctionCall, type ToolContext, ToolError, tool } from '../llm/index.js'; +import { FunctionCall, ToolContext, ToolError, tool } from '../llm/index.js'; import { initializeLogger } from '../log.js'; import type { Task } from '../utils.js'; import { cancelAndWait, delay } from '../utils.js'; @@ -70,6 +70,7 @@ describe('Generation + Tool Execution', () => { // Tool that takes > 5 seconds let toolAborted = false; const getWeather = tool({ + name: 'getWeather', description: 'weather', parameters: z.object({ location: z.string() }), execute: async ({ location }, { abortSignal }) => { @@ -92,7 +93,7 @@ describe('Generation + Tool Execution', () => { const [execTask, toolOutput] = performToolExecutions({ session: {} as any, speechHandle: { id: 'speech_test', _itemAdded: () => {} } as any, - toolCtx: { getWeather } as any, + toolCtx: new ToolContext([getWeather]) as any, toolCallStream, controller: replyAbortController, onToolExecutionStarted: () => {}, @@ -123,6 +124,7 @@ describe('Generation + Tool Execution', () => { const replyAbortController = new AbortController(); const echo = tool({ + name: 'echo', description: 'echo', parameters: z.object({ msg: z.string() }), execute: async ({ msg }) => `echo: ${msg}`, @@ -138,7 +140,7 @@ describe('Generation + Tool Execution', () => { const 
[execTask, toolOutput] = performToolExecutions({ session: {} as any, speechHandle: { id: 'speech_test2', _itemAdded: () => {} } as any, - toolCtx: { echo } as any, + toolCtx: new ToolContext([echo]) as any, toolCallStream, controller: replyAbortController, }); @@ -154,6 +156,7 @@ describe('Generation + Tool Execution', () => { const replyAbortController = new AbortController(); const removeOrderItem = tool({ + name: 'removeOrderItem', description: 'remove order item', parameters: z.object({ orderId: z.array(z.string()) }), execute: async ({ orderId }) => orderId.join(','), @@ -170,7 +173,7 @@ describe('Generation + Tool Execution', () => { const [execTask, toolOutput] = performToolExecutions({ session: {} as AgentSession, speechHandle: { id: 'speech_repair', _itemAdded: () => {} } as unknown as SpeechHandle, - toolCtx: { removeOrderItem } as unknown as ToolContext, + toolCtx: new ToolContext([removeOrderItem]) as unknown as ToolContext, toolCallStream, controller: replyAbortController, }); @@ -189,6 +192,7 @@ describe('Generation + Tool Execution', () => { let aborted = false; const longOp = tool({ + name: 'longOp', description: 'longOp', parameters: z.object({ ms: z.number() }), execute: async ({ ms }, { abortSignal }) => { @@ -210,7 +214,7 @@ describe('Generation + Tool Execution', () => { const [execTask, toolOutput] = performToolExecutions({ session: {} as any, speechHandle: { id: 'speech_abort', _itemAdded: () => {} } as any, - toolCtx: { longOp } as any, + toolCtx: new ToolContext([longOp]) as any, toolCallStream, controller: replyAbortController, }); @@ -229,6 +233,7 @@ describe('Generation + Tool Execution', () => { const replyAbortController = new AbortController(); const echo = tool({ + name: 'echo', description: 'echo', parameters: z.object({ msg: z.string() }), execute: async ({ msg }) => `echo: ${msg}`, @@ -245,7 +250,7 @@ describe('Generation + Tool Execution', () => { const [execTask, toolOutput] = performToolExecutions({ session: {} as any, 
speechHandle: { id: 'speech_invalid', _itemAdded: () => {} } as any, - toolCtx: { echo } as any, + toolCtx: new ToolContext([echo]) as any, toolCallStream, controller: replyAbortController, }); @@ -267,6 +272,7 @@ describe('Generation + Tool Execution', () => { const replyAbortController = new AbortController(); const echo = tool({ + name: 'echo', description: 'echo', parameters: z.object({ msg: z.string() }), execute: async ({ msg }) => `echo: ${msg}`, @@ -283,7 +289,7 @@ describe('Generation + Tool Execution', () => { const [execTask, toolOutput] = performToolExecutions({ session: {} as any, speechHandle: { id: 'speech_bad_json', _itemAdded: () => {} } as any, - toolCtx: { echo } as any, + toolCtx: new ToolContext([echo]) as any, toolCallStream, controller: replyAbortController, }); @@ -304,6 +310,7 @@ describe('Generation + Tool Execution', () => { // The tool throws a regular Error whose message contains internals (db URL, // credentials) we must NOT forward to the LLM (and from there to end users). const sensitive = tool({ + name: 'sensitive', description: 'sensitive', parameters: z.object({}), execute: async () => { @@ -321,7 +328,7 @@ describe('Generation + Tool Execution', () => { const [execTask, toolOutput] = performToolExecutions({ session: {} as any, speechHandle: { id: 'speech_generic_err', _itemAdded: () => {} } as any, - toolCtx: { sensitive } as any, + toolCtx: new ToolContext([sensitive]) as any, toolCallStream, controller: replyAbortController, }); @@ -345,6 +352,7 @@ describe('Generation + Tool Execution', () => { // Tools that intend to give the LLM a corrective hint opt in by throwing // ToolError — its message is forwarded as-is. 
const checked = tool({ + name: 'checked', description: 'checked', parameters: z.object({ qty: z.number() }), execute: async ({ qty }) => { @@ -365,7 +373,7 @@ describe('Generation + Tool Execution', () => { const [execTask, toolOutput] = performToolExecutions({ session: {} as any, speechHandle: { id: 'speech_tool_error', _itemAdded: () => {} } as any, - toolCtx: { checked } as any, + toolCtx: new ToolContext([checked]) as any, toolCallStream, controller: replyAbortController, }); @@ -383,11 +391,13 @@ describe('Generation + Tool Execution', () => { const replyAbortController = new AbortController(); const sum = tool({ + name: 'sum', description: 'sum', parameters: z.object({ a: z.number(), b: z.number() }), execute: async ({ a, b }) => a + b, }); const upper = tool({ + name: 'upper', description: 'upper', parameters: z.object({ s: z.string() }), execute: async ({ s }) => s.toUpperCase(), @@ -408,7 +418,7 @@ describe('Generation + Tool Execution', () => { const [execTask, toolOutput] = performToolExecutions({ session: {} as any, speechHandle: { id: 'speech_multi', _itemAdded: () => {} } as any, - toolCtx: { sum, upper } as any, + toolCtx: new ToolContext([sum, upper]) as any, toolCallStream, controller: replyAbortController, }); diff --git a/agents/src/voice/generation_tts_timeout.test.ts b/agents/src/voice/generation_tts_timeout.test.ts index cbe636ebe..5cb3803f9 100644 --- a/agents/src/voice/generation_tts_timeout.test.ts +++ b/agents/src/voice/generation_tts_timeout.test.ts @@ -120,6 +120,7 @@ describe('TTS stream idle timeout', () => { vi.useFakeTimers(); + // Stray PLAYBACK_STARTED before this segment captures anything must be ignored. 
audioOutput.onPlaybackStarted(Date.now()); expect(audioOut.firstFrameFut.done).toBe(false); @@ -148,11 +149,13 @@ describe('TTS stream idle timeout', () => { const controller = new AbortController(); const [task, audioOut] = performAudioForwarding(stream, audioOutput, controller); + // Stray event before the loop captures anything must not skip resampling. audioOutput.onPlaybackStarted(Date.now()); await task.result; expect(audioOut.firstFrameFut.done).toBe(true); + // Every captured frame must match the output sample rate (i.e. was resampled). expect(audioOutput.capturedFrames.length).toBeGreaterThan(0); for (const f of audioOutput.capturedFrames) { expect(f.sampleRate).toBe(24000); diff --git a/agents/src/voice/index.ts b/agents/src/voice/index.ts index d47931dc9..6d18e5aa5 100644 --- a/agents/src/voice/index.ts +++ b/agents/src/voice/index.ts @@ -1,7 +1,19 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -export { Agent, AgentTask, StopResponse, type AgentOptions, type ModelSettings } from './agent.js'; +export { + Agent, + AgentTask, + StopResponse, + type AgentContext, + type AgentCreateOptions, + type AgentHookNodeResult, + type AgentHooks, + type AgentOptions, + type AgentTaskContext, + type AgentTaskCreateOptions, + type ModelSettings, +} from './agent.js'; export * from './amd.js'; export { AgentSession, @@ -30,6 +42,8 @@ export { type PlaybackFinishedEvent, type PlaybackStartedEvent, type TimedString, + createTimedString, + isTimedString, } from './io.js'; export * from './report.js'; export * from './room_io/index.js'; diff --git a/agents/src/voice/io.ts b/agents/src/voice/io.ts index d2bed007f..38f6c4914 100644 --- a/agents/src/voice/io.ts +++ b/agents/src/voice/io.ts @@ -178,6 +178,15 @@ export abstract class AudioOutput extends EventEmitter { return this.playbackSegmentsCount - this.playbackFinishedCount; } + /** + * Monotonic count of playback segments ever captured. 
Lets chained outputs detect — free of + * races with concurrent finishes — whether this output accepted a segment they forwarded. + * @internal + */ + get capturedPlayoutSegments(): number { + return this.playbackSegmentsCount; + } + /** * Called when playback actually starts (first frame is sent to output). * Developers building audio sinks should call this when the first frame is captured. diff --git a/agents/src/voice/run_context.test.ts b/agents/src/voice/run_context.test.ts new file mode 100644 index 000000000..3b18644af --- /dev/null +++ b/agents/src/voice/run_context.test.ts @@ -0,0 +1,458 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { EventEmitter } from 'node:events'; +import { describe, expect, it, vi } from 'vitest'; +import { FunctionCall, type FunctionCallOutput } from '../llm/chat_context.js'; +import { Future } from '../utils.js'; +import type { AgentSession } from './agent_session.js'; +import { AgentSessionEventTypes } from './events.js'; +import { RunContext } from './run_context.js'; +import { SpeechHandle } from './speech_handle.js'; + +describe('RunContext async updates', () => { + function buildContext() { + const functionCall = FunctionCall.create({ + callId: 'call_123', + name: 'slow_lookup', + args: '{"query":"flights"}', + }); + const speechHandle = SpeechHandle.create(); + const session = { userData: { userId: 'user_1' } } as unknown as AgentSession<{ + userId: string; + }>; + const ctx = new RunContext(session, speechHandle, functionCall); + return { ctx, functionCall }; + } + + it('first update resolves dispatch and later updates enqueue deferred replies', async () => { + const { ctx, functionCall } = buildContext(); + const firstUpdate = new Future(); + const enqueued: Array<[FunctionCall, FunctionCallOutput]> = []; + + ctx._attachExecutor( + { + toolOptions: { + updateTemplate: 'Tool {functionName} update for {callId}: {message}', + }, + enqueueReply: async (_ctx, items) => { + 
enqueued.push(items as [FunctionCall, FunctionCallOutput]); + }, + }, + firstUpdate, + ); + + await ctx.update('started'); + + expect(await firstUpdate.await).toBe('Tool slow_lookup update for call_123: started'); + expect(functionCall.extra.__livekit_agents_tool_non_blocking).toBe(true); + expect(enqueued).toHaveLength(0); + + await ctx.update('halfway'); + + expect(enqueued).toHaveLength(1); + expect(enqueued[0]![0].callId).toBe('call_123_update_1'); + expect(enqueued[0]![0].name).toBe('slow_lookup'); + expect(enqueued[0]![1].callId).toBe('call_123_update_1'); + expect(enqueued[0]![1].output).toContain('halfway'); + }); + + it('detached contexts record updates without enqueueing replies', async () => { + const { ctx } = buildContext(); + + await ctx.update('standalone'); + + expect(ctx.updates).toHaveLength(1); + expect(ctx.updates[0]![0].callId).toBe('call_123'); + expect(ctx.updates[0]![1].output).toContain('standalone'); + }); +}); + +describe('RunContext filler', () => { + function sleep(ms: number) { + return new Promise((resolve) => setTimeout(resolve, ms)); + } + + function buildContext() { + const functionCall = FunctionCall.create({ + callId: 'call_123', + name: 'slow_lookup', + args: '{"query":"flights"}', + }); + const speechHandle = SpeechHandle.create(); + const session = new FakeFillerSession(); + const ctx = new RunContext( + session as unknown as AgentSession, + speechHandle, + functionCall, + ); + return { ctx, session }; + } + + it('speaks filler after the session stays idle for the delay', async () => { + const { ctx, session } = buildContext(); + + await ctx.filler('Still searching.', { delay: 10 }, async () => { + await sleep(30); + }); + + expect(session.sayTexts).toEqual(['Still searching.']); + }); + + it('cancels pending filler when the scope exits before the delay', async () => { + const { ctx, session } = buildContext(); + + await ctx.filler('Too late.', { delay: 50 }, async () => { + await sleep(5); + }); + + 
expect(session.sayTexts).toEqual([]); + }); + + it('restarts the idle dwell when update takes the floor', async () => { + const { ctx, session } = buildContext(); + + await ctx.filler('Still working.', { delay: 30 }, async () => { + await sleep(20); + await ctx.update('halfway'); + await sleep(20); + expect(session.sayTexts).toEqual([]); + await sleep(20); + }); + + expect(session.sayTexts).toEqual(['Still working.']); + }); + + it('does not let scheduler shutdown errors replace the callback result', async () => { + const functionCall = FunctionCall.create({ + callId: 'call_123', + name: 'slow_lookup', + args: '{"query":"flights"}', + }); + const speechHandle = SpeechHandle.create(); + const session = new FakeFillerSession({ + waitForIdleError: new Error('AgentSession is closing'), + }); + const ctx = new RunContext( + session as unknown as AgentSession, + speechHandle, + functionCall, + ); + + await expect(ctx.filler('Still searching.', async () => 'lookup result')).resolves.toBe( + 'lookup result', + ); + }); + + it('removes the abort listener when waitForIdle wins the race', async () => { + const { ctx, session } = buildContext(); + const activeAbortListeners = new Map>(); + const originalAddEventListener = AbortSignal.prototype.addEventListener; + const originalRemoveEventListener = AbortSignal.prototype.removeEventListener; + const addSpy = vi.spyOn(AbortSignal.prototype, 'addEventListener').mockImplementation(function ( + this: AbortSignal, + type: string, + callback: EventListenerOrEventListenerObject | null, + options?: AddEventListenerOptions | boolean, + ) { + if (type === 'abort' && callback) { + const listeners = activeAbortListeners.get(this) ?? 
new Set(); + listeners.add(callback); + activeAbortListeners.set(this, listeners); + } + return originalAddEventListener.call(this, type, callback, options); + }); + const removeSpy = vi + .spyOn(AbortSignal.prototype, 'removeEventListener') + .mockImplementation(function ( + this: AbortSignal, + type: string, + callback: EventListenerOrEventListenerObject | null, + options?: EventListenerOptions | boolean, + ) { + if (type === 'abort' && callback) { + const listeners = activeAbortListeners.get(this); + listeners?.delete(callback); + if (listeners?.size === 0) { + activeAbortListeners.delete(this); + } + } + return originalRemoveEventListener.call(this, type, callback, options); + }); + + try { + await ctx.filler('Listener cleanup.', { delay: 30 }, async () => { + await sleep(5); + expect([...activeAbortListeners.values()].every((listeners) => listeners.size <= 1)).toBe( + true, + ); + }); + + expect(session.listenerCount(AgentSessionEventTypes.AgentStateChanged)).toBe(0); + expect(session.listenerCount(AgentSessionEventTypes.UserStateChanged)).toBe(0); + } finally { + addSpy.mockRestore(); + removeSpy.mockRestore(); + } + }); + + it('does not create filler when maxSteps is zero', async () => { + const { ctx, session } = buildContext(); + + await ctx.filler('Disabled.', { delay: 0, interval: 1, maxSteps: 0 }, async () => { + await sleep(20); + }); + + expect(session.sayTexts).toEqual([]); + }); + + it('repeats filler on an interval until maxSteps is reached', async () => { + const { ctx, session } = buildContext(); + + await ctx.filler('Still working.', { delay: 0, interval: 5, maxSteps: 3 }, async () => { + await sleep(40); + }); + + expect(session.sayTexts).toEqual(['Still working.', 'Still working.', 'Still working.']); + }); + + it('invokes callable sources lazily and only advances step for created speeches', async () => { + const { ctx, session } = buildContext(); + const steps: number[] = []; + + await ctx.filler( + (step) => { + steps.push(step); + return 
steps.length === 1 ? null : `step ${step}`; + }, + { delay: 0, interval: 5, maxSteps: 1 }, + async () => { + expect(steps).toEqual([]); + await sleep(25); + }, + ); + + expect(steps).toEqual([0, 0]); + expect(session.sayTexts).toEqual(['step 0']); + }); + + it('accepts SpeechHandle sources without calling session.say', async () => { + const { ctx, session } = buildContext(); + const fillerHandle = SpeechHandle.create(); + fillerHandle._markDone(); + + await ctx.filler( + () => fillerHandle, + { delay: 0 }, + async () => { + await sleep(10); + }, + ); + + expect(session.sayTexts).toEqual([]); + }); + + it('honors an external abort signal before the dwell completes', async () => { + const { ctx, session } = buildContext(); + const abortController = new AbortController(); + + await ctx.filler('Cancelled.', { delay: 30, signal: abortController.signal }, async () => { + await sleep(10); + abortController.abort(); + await sleep(30); + }); + + expect(session.sayTexts).toEqual([]); + }); + + it('honors an already-aborted external signal', async () => { + const { ctx, session } = buildContext(); + const abortController = new AbortController(); + abortController.abort(); + + await ctx.filler( + 'Already cancelled.', + { delay: 0, signal: abortController.signal }, + async () => { + await sleep(10); + }, + ); + + expect(session.sayTexts).toEqual([]); + }); + + it('exits promptly while still waiting for the session to become idle', async () => { + const functionCall = FunctionCall.create({ + callId: 'call_123', + name: 'slow_lookup', + args: '{"query":"flights"}', + }); + const speechHandle = SpeechHandle.create(); + const session = new FakeFillerSession({ waitForIdleNeverResolves: true }); + const ctx = new RunContext( + session as unknown as AgentSession, + speechHandle, + functionCall, + ); + + await expect( + ctx.filler('Still waiting.', { delay: 0 }, async () => { + await sleep(5); + return 'done'; + }), + ).resolves.toBe('done'); + expect(session.sayTexts).toEqual([]); + 
}); + + it('does not let session.say shutdown errors replace the callback result', async () => { + const functionCall = FunctionCall.create({ + callId: 'call_123', + name: 'slow_lookup', + args: '{"query":"flights"}', + }); + const speechHandle = SpeechHandle.create(); + const session = new FakeFillerSession({ + sayError: new Error('AgentSession is closing, cannot use say()'), + }); + const ctx = new RunContext( + session as unknown as AgentSession, + speechHandle, + functionCall, + ); + + await expect( + ctx.filler('Still searching.', { delay: 0 }, async () => { + await sleep(10); + return 'lookup result'; + }), + ).resolves.toBe('lookup result'); + }); + + it('requires a callback scope', async () => { + const { ctx } = buildContext(); + const filler = ctx.filler as unknown as (source: string) => Promise; + + await expect(filler('x')).rejects.toThrow('RunContext.filler requires a callback scope'); + }); + + it('restarts the idle dwell when agent speech or thinking starts', async () => { + const { ctx, session } = buildContext(); + + await ctx.filler('After agent state.', { delay: 30 }, async () => { + await sleep(20); + session.emitAgentState('speaking'); + await sleep(20); + expect(session.sayTexts).toEqual([]); + session.emitAgentState('thinking'); + await sleep(20); + expect(session.sayTexts).toEqual([]); + await sleep(20); + }); + + expect(session.sayTexts).toEqual(['After agent state.']); + }); + + it('restarts the idle dwell when the user starts speaking', async () => { + const { ctx, session } = buildContext(); + + await ctx.filler('After user state.', { delay: 30 }, async () => { + await sleep(20); + session.emitUserState('speaking'); + await sleep(20); + expect(session.sayTexts).toEqual([]); + await sleep(20); + }); + + expect(session.sayTexts).toEqual(['After user state.']); + }); + + it('stops repeating filler after the owning speech handle is interrupted', async () => { + const functionCall = FunctionCall.create({ + callId: 'call_123', + name: 
'slow_lookup', + args: '{"query":"flights"}', + }); + const speechHandle = SpeechHandle.create(); + const session = new FakeFillerSession(); + const ctx = new RunContext( + session as unknown as AgentSession, + speechHandle, + functionCall, + ); + + await ctx.filler('Repeat.', { delay: 0, interval: 5 }, async () => { + await sleep(15); + speechHandle.interrupt(); + const countAtInterrupt = session.sayTexts.length; + await sleep(30); + expect(session.sayTexts).toHaveLength(countAtInterrupt); + }); + }); + + it('validates filler timing options', async () => { + const { ctx } = buildContext(); + + await expect(ctx.filler('x', { delay: -1 }, async () => undefined)).rejects.toThrow( + 'delay must be non-negative', + ); + await expect(ctx.filler('x', { interval: -1 }, async () => undefined)).rejects.toThrow( + 'interval must be non-negative when set', + ); + await expect(ctx.filler('x', { maxSteps: -1 }, async () => undefined)).rejects.toThrow( + 'maxSteps must be non-negative when set', + ); + }); +}); + +class FakeFillerSession extends EventEmitter { + userData = {}; + sayTexts: string[] = []; + + constructor( + private readonly options: { + waitForIdleError?: Error; + waitForIdleNeverResolves?: boolean; + sayError?: Error; + } = {}, + ) { + super(); + } + + async waitForIdle(): Promise { + if (this.options.waitForIdleError) { + throw this.options.waitForIdleError; + } + if (this.options.waitForIdleNeverResolves) { + await new Promise(() => undefined); + } + return; + } + + say(text: string): SpeechHandle { + if (this.options.sayError) { + throw this.options.sayError; + } + this.sayTexts.push(text); + const handle = SpeechHandle.create(); + handle._markDone(); + return handle; + } + + emitAgentState(newState: 'speaking' | 'thinking'): void { + this.emit(AgentSessionEventTypes.AgentStateChanged, { + type: 'agent_state_changed', + oldState: 'idle', + newState, + createdAt: Date.now(), + }); + } + + emitUserState(newState: 'speaking'): void { + 
this.emit(AgentSessionEventTypes.UserStateChanged, { + type: 'user_state_changed', + oldState: 'listening', + newState, + createdAt: Date.now(), + }); + } +} diff --git a/agents/src/voice/run_context.ts b/agents/src/voice/run_context.ts index df9994ba8..a8f764cdb 100644 --- a/agents/src/voice/run_context.ts +++ b/agents/src/voice/run_context.ts @@ -1,14 +1,64 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import type { FunctionCall } from '../llm/chat_context.js'; +import { FunctionCall, FunctionCallOutput } from '../llm/chat_context.js'; +import type { Future } from '../utils.js'; +import type { AgentActivity } from './agent_activity.js'; import type { AgentSession } from './agent_session.js'; -import type { SpeechHandle } from './speech_handle.js'; +import { + AgentSessionEventTypes, + type AgentStateChangedEvent, + type UserStateChangedEvent, +} from './events.js'; +import { type SpeechHandle, isSpeechHandle } from './speech_handle.js'; export type UnknownUserData = unknown; +export type PromptTemplate = string | ((args: Args) => string); + +export interface UpdatePromptArgs { + functionName: string; + callId: string; + message: string; +} + +export interface RunContextUpdateOptions { + template?: PromptTemplate; +} + +export type FillerSource = string | ((step: number) => SpeechHandle | string | null | undefined); + +export interface RunContextFillerOptions { + /** + * Continuous-idle dwell, in milliseconds, before filler speech fires. + * Defaults to 0, which fires as soon as the session is next idle. + */ + delay?: number; + /** + * Cooldown, in milliseconds, before waiting for another idle dwell. + * When omitted, filler speech fires at most once. + */ + interval?: number; + /** Maximum number of filler speeches to create. */ + maxSteps?: number; + /** Optional external cancellation signal for the filler scheduler. 
*/ + signal?: AbortSignal; +} + +export interface AttachedToolExecutor { + toolOptions: { + updateTemplate: PromptTemplate; + }; + enqueueReply(ctx: RunContext, items: [FunctionCall, FunctionCallOutput]): Promise; + replyTask?: Promise; +} + export class RunContext { private readonly initialStepIdx: number; + private _executor?: AttachedToolExecutor; + private _firstUpdateFuture?: Future; + private _updates: Array<[FunctionCall, FunctionCallOutput]> = []; + private _fillerSchedulers: FillerScheduler[] = []; constructor( public readonly session: AgentSession, public readonly speechHandle: SpeechHandle, @@ -31,4 +81,364 @@ export class RunContext { async waitForPlayout() { return this.speechHandle._waitForGeneration(this.initialStepIdx); } + + get updates(): readonly [FunctionCall, FunctionCallOutput][] { + return this._updates; + } + + disallowInterruptions(): void { + this.speechHandle.allowInterruptions = false; + } + + /** + * Report progress from a long-running tool — and, on the first call, turn the tool + * **non-blocking** so the conversation continues while `execute()` keeps running. + * + * Behavior depends on whether this is the first `update()` for the call: + * + * - **First call:** resolves the pending tool result with `message` and marks the + * function call non-blocking (`functionCall.extra.__livekit_agents_tool_non_blocking = true`). + * The framework treats `message` as the tool's immediate output to the LLM and returns + * control to the session, so the agent can speak/listen while the tool continues in the + * background. Whatever `execute()` ultimately returns is delivered later as a deferred reply. + * - **Subsequent calls:** each `message` is enqueued via the owning executor and delivered as a + * fresh assistant turn, gated on the session being idle (so updates never talk over the user). + * The arrival cadence is therefore conversational, not immediate. 
+ * + * Message rendering: + * - A **string** is rendered through a template — {@link RunContextUpdateOptions.template} if + * provided, otherwise the executor's configured `updateTemplate` — with `{functionName}`, + * `{callId}`, and `{message}` substituted. The default template tells the model the task is + * still running and not to fabricate results. + * - A **non-string** value is used as-is (no templating), letting a tool emit structured output. + * + * Every update is also recorded in {@link RunContext.updates} (with a `_update_N` call-id suffix + * for the 2nd and later updates) so the full progress trail is preserved in chat history. + * + * No-op for delivery when the tool isn't attached to an async-capable executor (e.g. a plain + * blocking tool): the update is still recorded but nothing is sent. + * + * @param message - Progress text (templated) or a structured value (sent verbatim). + * @param options - Per-call overrides; see {@link RunContextUpdateOptions}. + */ + async update(message: unknown, options: RunContextUpdateOptions = {}): Promise { + for (const scheduler of this._fillerSchedulers) { + scheduler.resetDwell(); + } + + const updateStep = this._updates.length; + const renderedMessage = + typeof message === 'string' + ? renderTemplate(options.template ?? this._executor?.toolOptions.updateTemplate, { + functionName: this.functionCall.name, + callId: this.functionCall.callId, + message, + }) + : message; + const pair = this._makeUpdatePair( + renderedMessage, + updateStep > 0 ? 
`_update_${updateStep}` : '', + ); + this._updates.push(pair); + + if (!this._executor) { + return; + } + + if (this._firstUpdateFuture && !this._firstUpdateFuture.done) { + this._firstUpdateFuture.resolve(renderedMessage); + this.functionCall.extra.__livekit_agents_tool_non_blocking = true; + return; + } + + await this._executor.enqueueReply(this, pair); + } + + async foreground(fn: (activity: AgentActivity) => Promise | T): Promise { + await this._drainPendingReply(); + return this.session.waitForIdleAndHold(fn); + } + + /** + * Speak filler audio while a long-running tool step is in progress. + * + * The scheduler waits until the session is continuously idle for + * {@link RunContextFillerOptions.delay} milliseconds, then speaks `source` + * through {@link AgentSession.say}. When `interval` is set, it waits that + * many milliseconds before starting another idle dwell; otherwise it fires at + * most once. Agent speech, user speech, and {@link RunContext.update} reset + * any pending dwell so filler does not race real conversation turns. + * + * @example + * ```ts + * await ctx.filler('Still searching, hang on.', { delay: 5000 }, async () => { + * return await slowLookup(); + * }); + * ``` + */ + async filler(source: FillerSource, fn: () => Promise | T): Promise; + async filler( + source: FillerSource, + options: RunContextFillerOptions, + fn: () => Promise | T, + ): Promise; + async filler( + source: FillerSource, + optionsOrFn: RunContextFillerOptions | (() => Promise | T), + maybeFn?: () => Promise | T, + ): Promise { + const options = typeof optionsOrFn === 'function' ? {} : optionsOrFn; + const fn = typeof optionsOrFn === 'function' ? 
/**
 * Convert an arbitrary tool update value into the string stored on a
 * `FunctionCallOutput`.
 *
 * - strings are returned unchanged;
 * - `null` / `undefined` become the empty string;
 * - numbers and booleans use `String(value)`;
 * - everything else is JSON-serialized, falling back to `String(value)` when
 *   serialization throws (e.g. circular structures, `bigint`) **or** yields no
 *   output at all.
 *
 * @param value - The raw (possibly structured) update payload.
 * @returns A string representation; never `undefined`.
 */
function stringifyToolOutput(value: unknown): string {
  if (typeof value === 'string') return value;
  if (value === undefined || value === null) return '';
  if (typeof value === 'number' || typeof value === 'boolean') return String(value);

  let serialized: string | undefined;
  try {
    // JSON.stringify returns `undefined` (it does not throw) for functions, symbols
    // and objects whose toJSON() returns undefined, despite its declared `string` type.
    serialized = JSON.stringify(value);
  } catch {
    serialized = undefined;
  }
  return serialized ?? String(value);
}
this.session.on(AgentSessionEventTypes.AgentStateChanged, onAgentStateChanged); + this.session.on(AgentSessionEventTypes.UserStateChanged, onUserStateChanged); + + const loop = this.loop(); + try { + await this.speechHandle.waitIfNotInterrupted([loop]); + } finally { + this.abortController.abort(); + this.session.off(AgentSessionEventTypes.AgentStateChanged, onAgentStateChanged); + this.session.off(AgentSessionEventTypes.UserStateChanged, onUserStateChanged); + await loop.catch(() => undefined); + } + } + + private async loop(): Promise { + while (!this.abortController.signal.aborted) { + const idleReached = await waitUnlessAborted( + this.session.waitForIdle(), + this.abortController.signal, + ); + if (!idleReached) return; + + this.dwellAbortController = new AbortController(); + const dwellCompleted = await sleepUnlessAborted(this.delay, [ + this.abortController.signal, + this.dwellAbortController.signal, + ]); + this.dwellAbortController = undefined; + if (!dwellCompleted) { + continue; + } + + if (this.maxSteps !== undefined && this.createdSpeeches.length >= this.maxSteps) { + return; + } + + const handle = this.createSpeech(); + if (handle) { + this.createdSpeeches.push(handle); + } + + if ( + this.interval === undefined || + (this.maxSteps !== undefined && this.createdSpeeches.length >= this.maxSteps) + ) { + return; + } + + const intervalCompleted = await sleepUnlessAborted(this.interval, [ + this.abortController.signal, + ]); + if (!intervalCompleted) return; + } + } + + private createSpeech(): SpeechHandle | undefined { + const value = + typeof this.source === 'function' ? 
/**
 * Sleep for `ms` milliseconds unless one of `signals` aborts first.
 *
 * @param ms - Delay in milliseconds passed to `setTimeout`.
 * @param signals - Abort signals that cut the sleep short.
 * @returns `true` when the full delay elapsed, `false` when any signal was already
 *   aborted on entry or aborted while waiting.
 */
async function sleepUnlessAborted(ms: number, signals: AbortSignal[]): Promise<boolean> {
  for (const candidate of signals) {
    if (candidate.aborted) return false;
  }

  return new Promise<boolean>((settle) => {
    // Remove the abort listener from every signal so nothing lingers after settling.
    function detach(): void {
      signals.forEach((s) => s.removeEventListener('abort', handleAbort));
    }

    // Abort path: stop the timer, unhook, and report an interrupted sleep.
    function handleAbort(): void {
      clearTimeout(timer);
      detach();
      settle(false);
    }

    // Timer path: the delay ran to completion.
    const timer = setTimeout(() => {
      detach();
      settle(true);
    }, ms);

    signals.forEach((s) => s.addEventListener('abort', handleAbort, { once: true }));
  });
}
ad3a1bf16..9cf2e4dc5 100644 --- a/agents/src/voice/testing/fake_llm.ts +++ b/agents/src/voice/testing/fake_llm.ts @@ -4,7 +4,7 @@ import type { ChatContext } from '../../llm/chat_context.js'; import { FunctionCall } from '../../llm/chat_context.js'; import { LLMStream as BaseLLMStream, LLM, type LLMStream } from '../../llm/llm.js'; -import type { ToolChoice, ToolContext } from '../../llm/tool_context.js'; +import type { ToolChoice, ToolContextLike } from '../../llm/tool_context.js'; import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../../types.js'; import { delay } from '../../utils.js'; @@ -42,7 +42,7 @@ export class FakeLLM extends LLM { connOptions = DEFAULT_API_CONNECT_OPTIONS, }: { chatCtx: ChatContext; - toolCtx?: ToolContext; + toolCtx?: ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: ToolChoice; @@ -65,7 +65,7 @@ class FakeLLMStream extends BaseLLMStream { constructor( fake: FakeLLM, - params: { chatCtx: ChatContext; toolCtx?: ToolContext; connOptions: APIConnectOptions }, + params: { chatCtx: ChatContext; toolCtx?: ToolContextLike; connOptions: APIConnectOptions }, ) { super(fake, params); this.fake = fake; diff --git a/agents/src/voice/testing/index.ts b/agents/src/voice/testing/index.ts index 23068f272..435b04c3d 100644 --- a/agents/src/voice/testing/index.ts +++ b/agents/src/voice/testing/index.ts @@ -30,6 +30,9 @@ export { MessageAssert, RunAssert, RunResult, + withMockTools, + type MockToolFn, + type MockToolsMap, } from './run_result.js'; export { diff --git a/agents/src/voice/testing/run_result.test.ts b/agents/src/voice/testing/run_result.test.ts new file mode 100644 index 000000000..3aa01aba0 --- /dev/null +++ b/agents/src/voice/testing/run_result.test.ts @@ -0,0 +1,182 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { ReadableStream } from 'node:stream/web'; +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; +import { FunctionCall } from '../../llm/chat_context.js'; +import { ToolContext, tool } from '../../llm/tool_context.js'; +import { Agent } from '../agent.js'; +import { performToolExecutions } from '../generation.js'; +import { SpeechHandle } from '../speech_handle.js'; +import { activeMockTools, withMockTools } from './run_result.js'; + +class AgentA extends Agent { + constructor() { + super({ instructions: 'a' }); + } +} + +class AgentB extends Agent { + constructor() { + super({ instructions: 'b' }); + } +} + +describe('withMockTools', () => { + it('sets the mock registry for the given agent inside the block', () => { + const mock = () => 'mocked'; + + { + using _mock = withMockTools(AgentA, { tool1: mock }); + expect(activeMockTools).toBeDefined(); + expect(activeMockTools?.get(AgentA)?.tool1).toBe(mock); + } + + expect(activeMockTools).toBeUndefined(); + }); + + it('merges mocks across nested blocks and isolates per agent', () => { + const mockA = () => 'a'; + const mockB = () => 'b'; + + { + using _mockA = withMockTools(AgentA, { toolA: mockA }); + { + using _mockB = withMockTools(AgentB, { toolB: mockB }); + expect(activeMockTools?.get(AgentA)?.toolA).toBe(mockA); + expect(activeMockTools?.get(AgentB)?.toolB).toBe(mockB); + } + + expect(activeMockTools?.get(AgentA)?.toolA).toBe(mockA); + expect(activeMockTools?.get(AgentB)).toBeUndefined(); + } + }); + + it('inner block for same agent overrides outer mocks', () => { + const outer = () => 'outer'; + const inner = () => 'inner'; + + { + using _outer = withMockTools(AgentA, { tool1: outer }); + { + using _inner = withMockTools(AgentA, { tool1: inner }); + expect(activeMockTools?.get(AgentA)?.tool1).toBe(inner); + } + expect(activeMockTools?.get(AgentA)?.tool1).toBe(outer); + } + }); + + it('exposes the mock for invocation within the block', 
async () => { + using _mock = withMockTools(AgentA, { tool1: async () => 42 }); + const mock = activeMockTools?.get(AgentA)?.tool1; + expect(await mock?.()).toBe(42); + }); + + it('routes performToolExecutions to the mock when set, original otherwise', async () => { + let realCalled = false; + const realTool = tool({ + name: 'greet', + description: 'real', + parameters: z.object({ name: z.string() }), + execute: async ({ name }) => { + realCalled = true; + return `real:${name}`; + }, + }); + + const toolCtx = new ToolContext([realTool]); + const speechHandle = SpeechHandle.create({ allowInterruptions: false }); + const agent = new AgentA(); + + // Minimal AgentSession stub: performToolExecutions only reads session.currentAgent. + const session = { currentAgent: agent } as never; + + const makeStream = (call: FunctionCall) => + new ReadableStream({ + start(controller) { + controller.enqueue(call); + controller.close(); + }, + }); + + // 1) With a mock registered: the mock runs, the real tool does not. + { + using _mock = withMockTools(AgentA, { greet: () => 'mocked' }); + const controller = new AbortController(); + const call = FunctionCall.create({ + callId: 'call_1', + name: 'greet', + args: JSON.stringify({ name: 'world' }), + }); + const [task, output] = performToolExecutions({ + session, + speechHandle, + toolCtx, + toolCallStream: makeStream(call), + controller, + }); + await task.result; + expect(realCalled).toBe(false); + expect(output.output[0]?.rawOutput).toBe('mocked'); + } + + // 2) Without a mock: the real tool runs. 
+ const controller = new AbortController(); + const call = FunctionCall.create({ + callId: 'call_2', + name: 'greet', + args: JSON.stringify({ name: 'world' }), + }); + const [task, output] = performToolExecutions({ + session, + speechHandle, + toolCtx, + toolCallStream: makeStream(call), + controller, + }); + await task.result; + expect(realCalled).toBe(true); + expect(output.output[0]?.rawOutput).toBe('real:world'); + }); + + it('propagates thrown errors from mocks as tool errors', async () => { + const realTool = tool({ + name: 'failing', + description: 'real', + parameters: z.object({}), + execute: async () => 'ok', + }); + const toolCtx = new ToolContext([realTool]); + const speechHandle = SpeechHandle.create({ allowInterruptions: false }); + const session = { currentAgent: new AgentA() } as never; + + using _mock = withMockTools(AgentA, { + failing: () => { + throw new Error('test failure'); + }, + }); + const controller = new AbortController(); + const call = FunctionCall.create({ + callId: 'call_err', + name: 'failing', + args: '{}', + }); + const stream = new ReadableStream({ + start(c) { + c.enqueue(call); + c.close(); + }, + }); + const [task, output] = performToolExecutions({ + session, + speechHandle, + toolCtx, + toolCallStream: stream, + controller, + }); + await task.result; + expect(output.output[0]?.rawException?.message).toBe('test failure'); + expect(output.output[0]?.toolCallOutput?.isError).toBe(true); + }); +}); diff --git a/agents/src/voice/testing/run_result.ts b/agents/src/voice/testing/run_result.ts index 4ee0ccc56..21c3a9e11 100644 --- a/agents/src/voice/testing/run_result.ts +++ b/agents/src/voice/testing/run_result.ts @@ -28,8 +28,9 @@ import { } from './types.js'; // Type for agent constructor (used in assertions) +/** @internal */ // eslint-disable-next-line @typescript-eslint/no-explicit-any -type AgentConstructor = new (...args: any[]) => Agent; +export type AgentConstructor = new (...args: any[]) => Agent; // In JS we use a zod 
/**
 * Look up the mock registered for `toolName` on the given agent instance.
 *
 * Matching is by exact constructor identity (`agent.constructor === key`), so an
 * instance of a subclass does not pick up mocks registered for its parent class.
 *
 * @param agent - The agent instance whose tool is about to run.
 * @param toolName - Name of the tool being invoked.
 * @returns The mock function, or `undefined` when no registry is active, the agent's
 *   constructor has no entry, or that entry has no mock under `toolName`.
 *
 * @internal
 */
export function getMockTool(agent: Agent, toolName: string): MockToolFn | undefined {
  const registry = activeMockTools;
  if (!registry) return undefined;

  const entry = [...registry].find(([agentConstructor]) => agent.constructor === agentConstructor);
  return entry === undefined ? undefined : entry[1][toolName];
}
Returns + * a {@link Disposable} for use with a `using` declaration. While the binding is in + * scope, tool calls for the matching agent type and tool name are routed to the + * supplied mock instead of the real `execute` implementation, and are restored when + * the enclosing block exits. + * + * Mirrors the Python `mock_tools` context manager, adapted to JS via the explicit + * resource management `using` syntax (Python uses `with`). + * + * @param agent - The Agent constructor whose tools should be mocked. + * @param mocks - A record mapping tool name to a mock implementation. + * + * @example + * ```typescript + * { + * using _mock = withMockTools(DriveThruAgent, { + * orderRegularItem: () => new Error('test failure'), + * getWeather: () => 'sunny', + * }); + * + * const result = await session.run({ userInput: 'Order a burger' }); + * result.expect.containsFunctionCall({ name: 'orderRegularItem' }); + * } + * ``` + */ +export function withMockTools( + agent: AgentConstructor, + mocks: Record, +): Disposable { + const previous = activeMockTools; + const updated: MockToolsMap = new Map(previous ?? []); + updated.set(agent, mocks); + activeMockTools = updated; + + return { + [Symbol.dispose]() { + activeMockTools = previous; + }, + }; +} /** * Format events for debug output, optionally marking a selected index. diff --git a/agents/src/voice/tool_executor.test.ts b/agents/src/voice/tool_executor.test.ts new file mode 100644 index 000000000..2b9cd089c --- /dev/null +++ b/agents/src/voice/tool_executor.test.ts @@ -0,0 +1,307 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; +import { ChatContext, FunctionCall } from '../llm/chat_context.js'; +import { ToolFlag, tool } from '../llm/tool_context.js'; +import { Future } from '../utils.js'; +import type { AgentSession } from './agent_session.js'; +import { RunContext } from './run_context.js'; +import { SpeechHandle } from './speech_handle.js'; +import { ToolExecutor, getRunningTasks } from './tool_executor.js'; + +describe('ToolExecutor', () => { + function buildRunContext( + name: string = 'slow_lookup', + callId: string = `call_${name}`, + speechOptions: { allowInterruptions?: boolean } = {}, + ) { + const functionCall = FunctionCall.create({ + callId, + name, + args: '{"query":"flights"}', + }); + const history = new ChatContext(); + const agent = { + chatCtx: ChatContext.empty(), + async updateChatCtx(chatCtx: ChatContext) { + this.chatCtx = chatCtx; + }, + }; + const session = { + userData: {}, + history, + currentAgent: agent, + _globalRunState: undefined, + async waitForIdle() { + return { agent }; + }, + generateReply: () => ({ id: 'speech_reply', addDoneCallback: () => {} }), + } as unknown as AgentSession; + const speechHandle = SpeechHandle.create(speechOptions); + return { + runCtx: new RunContext(session, speechHandle, functionCall), + history, + agent, + }; + } + + it('preserves blocking tool return semantics when no update is sent', async () => { + const executor = new ToolExecutor(); + const { runCtx } = buildRunContext('blocking_lookup'); + const lookup = tool({ + name: 'blocking_lookup', + description: 'Blocking lookup', + parameters: z.object({ query: z.string() }), + execute: async ({ query }) => `result:${query}`, + }); + + const result = await executor.execute({ + tool: lookup, + runCtx, + rawArguments: { query: 'flights' }, + }); + + expect(result).toBe('result:flights'); + }); + + it('returns on first update and later enqueues final return', async 
() => { + const executor = new ToolExecutor(); + const { runCtx, history, agent } = buildRunContext('async_lookup'); + const releaseFinal = new Future(); + const lookup = tool({ + name: 'async_lookup', + description: 'Async lookup', + parameters: z.object({ query: z.string() }), + execute: async ({ query }, { ctx }) => { + await ctx.update(`started:${query}`); + await releaseFinal.await; + return `final:${query}`; + }, + }); + + const resultPromise = executor.execute({ + tool: lookup, + runCtx, + rawArguments: { query: 'flights' }, + }); + + await expect(resultPromise).resolves.toContain('started:flights'); + expect(runCtx.functionCall.extra.__livekit_agents_tool_non_blocking).toBe(true); + expect(history.items).toHaveLength(0); + + releaseFinal.resolve(); + await executor.waitForAll(); + + expect(history.items.some((item) => item.type === 'function_call_output')).toBe(true); + expect(agent.chatCtx.items.some((item) => item.type === 'function_call_output')).toBe(true); + }); + + it('rejects duplicate calls when onDuplicate is reject', async () => { + const executor = new ToolExecutor(); + const first = buildRunContext('dedupe_lookup'); + const second = buildRunContext('dedupe_lookup'); + const neverFinish = new Future(); + const lookup = tool({ + name: 'dedupe_lookup', + description: 'Dedupe lookup', + flags: ToolFlag.CANCELLABLE, + onDuplicate: 'reject', + execute: async (_, { ctx }) => { + await ctx.update('started'); + await neverFinish.await; + return 'done'; + }, + }); + + await executor.execute({ tool: lookup, runCtx: first.runCtx, rawArguments: {} }); + const duplicate = await executor.execute({ + tool: lookup, + runCtx: second.runCtx, + rawArguments: {}, + }); + + expect(String(duplicate)).toContain('Same tool `dedupe_lookup` is already running'); + + neverFinish.resolve(); + await executor.waitForAll(); + }); + + it('rejects concurrent duplicate calls when onDuplicate is reject', async () => { + const executor = new ToolExecutor(); + const first = 
buildRunContext('concurrent_dedupe', 'concurrent_dedupe_1'); + const second = buildRunContext('concurrent_dedupe', 'concurrent_dedupe_2'); + const neverFinish = new Future(); + const lookup = tool({ + name: 'concurrent_dedupe', + description: 'Concurrent dedupe lookup', + flags: ToolFlag.CANCELLABLE, + onDuplicate: 'reject', + execute: async (_, { ctx }) => { + await ctx.update('started'); + await neverFinish.await; + return 'done'; + }, + }); + + const outputs = await Promise.all([ + executor.execute({ tool: lookup, runCtx: first.runCtx, rawArguments: {} }), + executor.execute({ tool: lookup, runCtx: second.runCtx, rawArguments: {} }), + ]); + + expect(outputs.filter((output) => String(output).includes('started'))).toHaveLength(1); + expect( + outputs.filter((output) => + String(output).includes('Same tool `concurrent_dedupe` is already running'), + ), + ).toHaveLength(1); + + neverFinish.resolve(); + await executor.waitForAll(); + }); + + it('returns from cancel even when a cancellable tool ignores abortSignal', async () => { + const executor = new ToolExecutor(); + const { runCtx } = buildRunContext('non_cooperative_cancel'); + const neverFinish = new Future(); + const lookup = tool({ + name: 'non_cooperative_cancel', + description: 'Non-cooperative cancellable lookup', + flags: ToolFlag.CANCELLABLE, + execute: async (_, { ctx }) => { + await ctx.update('started'); + await neverFinish.await; + return 'done'; + }, + }); + + await executor.execute({ tool: lookup, runCtx, rawArguments: {} }); + + const cancelled = await Promise.race([ + executor.cancel(runCtx.functionCall.callId).then(() => 'cancelled'), + new Promise<'hung'>((resolve) => setTimeout(() => resolve('hung'), 25)), + ]); + + expect(cancelled).toBe('cancelled'); + expect(getRunningTasks(runCtx.session)).toHaveLength(0); + }); + + // An abortable tool that resolves as soon as its abortSignal fires — mirrors how + // the example tools (abortable sleep / fetch signal) honor cancellation. 
+ function abortableTool(name: string, started: Future, stopped: Future) { + return tool({ + name, + description: 'Abortable cancellable lookup', + flags: ToolFlag.CANCELLABLE, + execute: async (_, { ctx, abortSignal }) => { + await ctx.update('started'); + started.resolve(); + await new Promise((resolve) => { + if (abortSignal.aborted) return resolve(); + abortSignal.addEventListener('abort', () => resolve(), { once: true }); + }); + stopped.resolve(); + return 'done'; + }, + }); + } + + it('explicit cancel aborts the tool so an abortable execute() actually stops', async () => { + const executor = new ToolExecutor(); + const { runCtx } = buildRunContext('abortable_cancel'); + const started = new Future(); + const stopped = new Future(); + + await executor.execute({ + tool: abortableTool('abortable_cancel', started, stopped), + runCtx, + rawArguments: {}, + }); + await started.await; + expect(stopped.done).toBe(false); + + await executor.cancel(runCtx.functionCall.callId); + // The tool observes the abort signal and stops on its own (no deadline needed). + await Promise.race([ + stopped.await, + new Promise((_, reject) => setTimeout(() => reject(new Error('tool never stopped')), 1000)), + ]); + expect(stopped.done).toBe(true); + }); + + it('drain (handoff) aborts cancellable tools so an abortable execute() stops', async () => { + const executor = new ToolExecutor(); + const { runCtx } = buildRunContext('abortable_drain'); + const started = new Future(); + const stopped = new Future(); + + await executor.execute({ + tool: abortableTool('abortable_drain', started, stopped), + runCtx, + rawArguments: {}, + }); + await started.await; + + await executor.drain(); + await Promise.race([ + stopped.await, + new Promise((_, reject) => setTimeout(() => reject(new Error('tool never stopped')), 1000)), + ]); + expect(stopped.done).toBe(true); + }); + + // longcw review #1: drain() must force-cancel cancellable tools at teardown even when the + // speech disallows interruptions. 
cancel() throws in that case (LLM-path guard); if drain used + // cancel() the throw would abort the loop and strand the remaining tools. + it('drain force-cancels a cancellable tool whose speech disallows interruptions (no throw)', async () => { + const executor = new ToolExecutor(); + const { runCtx } = buildRunContext('noninterrupt_drain', 'call_noninterrupt_drain', { + allowInterruptions: false, + }); + const started = new Future(); + const stopped = new Future(); + + await executor.execute({ + tool: abortableTool('noninterrupt_drain', started, stopped), + runCtx, + rawArguments: {}, + }); + await started.await; + expect(runCtx.speechHandle.allowInterruptions).toBe(false); + + // Must resolve (not throw) and actually abort the tool. + await expect(executor.drain()).resolves.toBeUndefined(); + await Promise.race([ + stopped.await, + new Promise((_, reject) => setTimeout(() => reject(new Error('tool never stopped')), 1000)), + ]); + expect(stopped.done).toBe(true); + expect(getRunningTasks(runCtx.session)).toHaveLength(0); + }); + + it('keeps running task visibility scoped to one session', async () => { + const executor = new ToolExecutor(); + const first = buildRunContext('session_scoped_lookup'); + const second = buildRunContext('session_scoped_lookup'); + const neverFinish = new Future(); + const lookup = tool({ + name: 'session_scoped_lookup', + description: 'Session scoped lookup', + flags: ToolFlag.CANCELLABLE, + execute: async (_, { ctx }) => { + await ctx.update('started'); + await neverFinish.await; + return 'done'; + }, + }); + + await executor.execute({ tool: lookup, runCtx: first.runCtx, rawArguments: {} }); + + expect(getRunningTasks(first.runCtx.session)).toHaveLength(1); + expect(getRunningTasks(second.runCtx.session)).toHaveLength(0); + + neverFinish.resolve(); + await executor.waitForAll(); + }); +}); diff --git a/agents/src/voice/tool_executor.ts b/agents/src/voice/tool_executor.ts new file mode 100644 index 000000000..c18602999 --- /dev/null 
+++ b/agents/src/voice/tool_executor.ts @@ -0,0 +1,675 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { Mutex } from '@livekit/mutex'; +import { z } from 'zod'; +import { ChatContext, FunctionCall, FunctionCallOutput } from '../llm/chat_context.js'; +import { + CONFIRM_DUPLICATE_PARAM, + type DuplicateMode, + type FunctionTool, + type JSONObject, + type ToolContextEntry, + ToolError, + ToolFlag, + Toolset, + isFunctionTool, + tool, +} from '../llm/tool_context.js'; +import { log } from '../log.js'; +import { Future } from '../utils.js'; +import type { AgentSession } from './agent_session.js'; +import type { PromptTemplate, RunContext, UpdatePromptArgs } from './run_context.js'; + +// Upper bound on how long `drain()` waits for in-flight tool promises to settle +// after the executor has signalled abort. +const DRAIN_TOOL_TIMEOUT_MS = Number(process.env.LK_DRAIN_PLAYOUT_TIMEOUT_MS) || 5_000; + +export interface DuplicatePromptArgs { + functionName: string; + functionCallsJson: string[]; + functionCallsText: string; +} + +export interface ReplyPromptArgs { + callIds: string[]; +} + +export interface AsyncToolOptions { + updateTemplate: PromptTemplate; + duplicateRejectTemplate: PromptTemplate; + duplicateConfirmTemplate: PromptTemplate; + replyAtTailTemplate: PromptTemplate; + replyMaybeCoveredTemplate: PromptTemplate; +} + +export interface ToolHandlingOptions { + asyncOptions?: Partial; +} + +export const UPDATE_TEMPLATE = + 'The tool `{functionName}` has updated, message: {message}\n' + + "The task is still running, so DON'T make up or give information not included in the message above."; + +export const DUPLICATE_REJECT = + 'Same tool `{functionName}` is already running:\n' + + '{functionCallsText}\n' + + 'If you want to cancel the existing one, call `lk_agents_cancel_task` with call_id.'; + +export const DUPLICATE_CONFIRM = + 'Same tool `{functionName}` is already running:\n' + + '{functionCallsText}\n' + + 
'Re-call with confirm duplicate True to run a duplicate if needed,\n' + + 'or if you want to cancel the existing one, call `lk_agents_cancel_task` with call_id.'; + +export const REPLY_INSTRUCTIONS_AT_TAIL = + 'New results arrived from background tool calls (call_ids: {callIds}).\n' + + 'Summarize the results naturally. Do NOT repeat information you have already told the user.'; + +export const REPLY_INSTRUCTIONS_MAYBE_COVERED = + 'New results arrived from background tool calls (call_ids: {callIds}).\n' + + 'You may have already mentioned them in your most recent replies.\n' + + 'If you already told the user everything in these results, reply with an empty response (no text at all).\n' + + 'Otherwise, summarize only what you have not said yet, with a natural transition.\n' + + 'Never repeat information you have already told the user.'; + +export function resolveAsyncToolOptions(options?: Partial): AsyncToolOptions { + return { + updateTemplate: options?.updateTemplate ?? UPDATE_TEMPLATE, + duplicateRejectTemplate: options?.duplicateRejectTemplate ?? DUPLICATE_REJECT, + duplicateConfirmTemplate: options?.duplicateConfirmTemplate ?? DUPLICATE_CONFIRM, + replyAtTailTemplate: options?.replyAtTailTemplate ?? REPLY_INSTRUCTIONS_AT_TAIL, + replyMaybeCoveredTemplate: + options?.replyMaybeCoveredTemplate ?? 
/**
 * Render a prompt template against a bag of named arguments.
 *
 * - A function template is called with `args` and its result returned as-is.
 * - A string template has every `{key}` placeholder whose `key` is an own enumerable
 *   property of `args` replaced with `String(args[key])`. Placeholders with no
 *   matching key are left untouched.
 *
 * Substitution happens in a single pass with a function replacer, so:
 * - `$`-sequences in a value (`$$`, `$&`, `` $` ``, `$'`) are inserted literally
 *   instead of being interpreted as replacement patterns;
 * - placeholder-looking text inside a substituted value is not expanded again.
 *
 * NOTE(review): the generic constraint was lost in extraction and is reconstructed
 * here as `Record<string, unknown>` — confirm against the original declaration.
 *
 * @param template - Template string or render function.
 * @param args - Values available to the template.
 * @returns The rendered prompt text.
 */
export function renderTemplate<Args extends Record<string, unknown>>(
  template: PromptTemplate<Args>,
  args: Args,
): string {
  if (typeof template === 'function') return template(args);

  const values = new Map<string, string>();
  for (const [key, value] of Object.entries(args)) {
    values.set(key, String(value));
  }

  return template.replace(/\{([^{}]+)\}/g, (placeholder: string, key: string) => {
    const replacement = values.get(key);
    return replacement === undefined ? placeholder : replacement;
  });
}
duplicateLock = new Mutex(); + private pendingUpdates: PendingUpdate[] = []; + private _replyTask?: Promise; + private _replyTaskDone = false; + toolOptions: AsyncToolOptions; + + constructor({ + owningActivity, + asyncToolOptions, + }: { + owningActivity?: ToolExecutorActivity | null; + asyncToolOptions?: Partial; + } = {}) { + this.owningActivity = owningActivity ?? null; + this.toolOptions = resolveAsyncToolOptions(asyncToolOptions); + } + + private owningActivity: ToolExecutorActivity | null; + + get replyTask(): Promise | undefined { + return this._replyTask; + } + + setOwningActivity(activity: ToolExecutorActivity | null): void { + this.owningActivity = activity; + } + + setToolOptions(options?: Partial): void { + this.toolOptions = resolveAsyncToolOptions(options); + } + + get hasRunningTasks(): boolean { + return this.runningTasks.size > 0; + } + + get hasCancellableRunningTasks(): boolean { + return [...this.runningTasks.values()].some((task) => task.allowCancellation); + } + + async execute({ + tool, + runCtx, + rawArguments, + abortSignal, + onUserToolStarted, + }: { + tool: FunctionTool; + runCtx: RunContext; + rawArguments: Parameters; + abortSignal?: AbortSignal; + onUserToolStarted?: () => void; + }): Promise { + const callId = runCtx.functionCall.callId; + const functionName = runCtx.functionCall.name; + const args = { ...rawArguments } as Parameters & Record; + const confirmDuplicate = + tool.onDuplicate === 'confirm' ? 
Boolean(args[CONFIRM_DUPLICATE_PARAM]) : undefined; + delete args[CONFIRM_DUPLICATE_PARAM]; + + const unlock = await this.duplicateLock.lock(); + try { + const duplicateResult = await this.checkDuplicate(functionName, { + onDuplicate: tool.onDuplicate, + confirmDuplicate, + }); + if (duplicateResult !== undefined) return duplicateResult; + + if (this.runningTasks.has(callId)) { + throw new Error(`Task already running for call_id: ${callId}`); + } + + const firstUpdateFuture = new Future(); + runCtx._attachExecutor(this, firstUpdateFuture); + + const controller = new AbortController(); + const abort = () => { + queueMicrotask(() => { + controller.abort(); + if (!firstUpdateFuture.done) { + firstUpdateFuture.reject(new Error('tool call was aborted')); + } + }); + }; + abortSignal?.addEventListener('abort', abort, { once: true }); + + // Once a tool goes non-blocking (it called ctx.update and detached from its + // owning speech), a speech interruption must NOT abort it — async tools are + // meant to survive interruptions and deliver their result later (matches + // Python, where the exe_task is independent and only cancel()/drain() stop it). + // Stop forwarding the speech abort to this tool; explicit cancel()/drain()/ + // aclose() still abort it directly via task.controller. 
+ void firstUpdateFuture.await + .then(() => { + if (runCtx.functionCall.extra.__livekit_agents_tool_non_blocking === true) { + abortSignal?.removeEventListener('abort', abort); + } + }) + .catch(() => {}); + + const toolPromiseRef: { promise?: Promise } = {}; + const promise = this.runTool({ + tool, + runCtx, + rawArguments: args as Parameters, + firstUpdateFuture, + controller, + onUserToolStarted, + toolPromiseRef, + }).finally(() => { + this.runningTasks.delete(callId); + runningTasks.get(runCtx.session)?.delete(callId); + abortSignal?.removeEventListener('abort', abort); + runCtx._detachExecutor(); + }); + + const task: RunningTask = { + ctx: runCtx, + promise, + controller, + firstUpdateFuture, + executor: this, + allowCancellation: Boolean(tool.flags & ToolFlag.CANCELLABLE), + toolPromiseRef, + }; + this.runningTasks.set(callId, task); + let sessionTasks = runningTasks.get(runCtx.session); + if (!sessionTasks) { + sessionTasks = new Map(); + runningTasks.set(runCtx.session, sessionTasks); + } + sessionTasks.set(callId, task); + + return firstUpdateFuture.await; + } finally { + unlock(); + } + } + + async waitForAll(): Promise { + await Promise.allSettled([...this.runningTasks.values()].map((task) => task.promise)); + if (this._replyTask) await this._replyTask; + } + + async cancel(callId: string): Promise { + const task = this.runningTasks.get(callId); + if (!task) return false; + if (!task.allowCancellation) { + throw new ToolError(`Tool call ${callId} is not cancellable`); + } + if (!task.ctx.speechHandle.allowInterruptions) { + throw new ToolError( + `Tool call ${callId} is not cancellable because interruptions are disallowed`, + ); + } + this.forceCancelTask(task); + return true; + } + + /** + * Abort + detach a running task unconditionally. 
`cancel()` gates on cancellability and + * interruptibility (LLM/user-initiated path); teardown via `drain()` must skip those guards — + * the activity is going away, and `cancel()`'s `allowInterruptions` throw would otherwise abort + * the drain loop and strand the remaining tasks. + */ + private forceCancelTask(task: RunningTask): void { + const { callId } = task.ctx.functionCall; + task.controller.abort(); + if (!task.firstUpdateFuture.done) { + task.firstUpdateFuture.resolve(undefined); + } + + this.runningTasks.delete(callId); + runningTasks.get(task.ctx.session)?.delete(callId); + task.ctx._detachExecutor(); + void task.promise.catch(() => undefined); + // We've abandoned the wait, but the user's execute() may ignore the abort + // signal and keep running. Error if it hasn't stopped by the deadline. + this.errorIfCancelledToolKeepsRunning(task.ctx.functionCall, task.toolPromiseRef.promise); + } + + /** + * Fire-and-forget watcher: a cancelled tool whose `execute()` hasn't settled by the + * deadline is ignoring its abort signal. Surface it so the dev can make execute abortable — + * abandoning the promise alone leaves the work running invisibly. + */ + private errorIfCancelledToolKeepsRunning( + call: FunctionCall, + rawPromise: Promise | undefined, + ): void { + if (!rawPromise) return; + let settled = false; + void rawPromise.then(() => { + settled = true; + }); + const timer = setTimeout(() => { + if (settled) return; + log().error( + { tool: call.name, callId: call.callId, timeoutMs: DRAIN_TOOL_TIMEOUT_MS }, + `tool ${call.name} was cancelled but its execute() is still running after the deadline; it likely ` + + 'does not honor the abort signal. 
Observe the provided abortSignal in execute() so ' + + 'cancellation actually stops the work.', + ); + }, DRAIN_TOOL_TIMEOUT_MS); + timer.unref?.(); + void rawPromise.finally(() => clearTimeout(timer)); + } + + async drain(): Promise { + const tasks = [...this.runningTasks.values()]; + + // Cancellable tools: signal abort + abandon (force, bypassing the allowInterruptions guard + // that `cancel()` enforces for the LLM path — at teardown it would throw and strand the rest). + // Non-cancellable: let them run (awaited below). + for (const task of tasks) { + if (task.allowCancellation) { + this.forceCancelTask(task); + } + } + + const inflight = tasks + .filter((task) => !task.allowCancellation) + .map((task) => ({ name: task.ctx.functionCall.name, promise: task.toolPromiseRef.promise })) + .filter((t): t is { name: string; promise: Promise } => t.promise !== undefined); + if (inflight.length === 0) return; + + const settled = new Set(); + inflight.forEach((t, i) => void t.promise.then(() => settled.add(i))); + + const timer = setTimeout(() => { + const slow = inflight.filter((_, i) => !settled.has(i)).map((t) => t.name); + if (slow.length > 0) { + log().warn( + { tools: slow, timeoutMs: DRAIN_TOOL_TIMEOUT_MS }, + 'non-cancellable tool(s) still running after the drain deadline; awaiting their ' + + 'completion so the deferred result is not dropped. 
Keep non-cancellable tool work ' + + 'short, or mark the tool cancellable if it should stop on handoff.', + ); + } + }, DRAIN_TOOL_TIMEOUT_MS); + timer.unref?.(); + + await Promise.allSettled(inflight.map((t) => t.promise)); + clearTimeout(timer); + } + + async aclose(): Promise { + this.pendingUpdates = []; + const tasks = [...this.runningTasks.values()]; + for (const task of tasks) { + task.controller.abort(); + if (!task.firstUpdateFuture.done) { + task.firstUpdateFuture.resolve(undefined); + } + runningTasks.get(task.ctx.session)?.delete(task.ctx.functionCall.callId); + task.ctx._detachExecutor(); + void task.promise.catch(() => undefined); + } + this.runningTasks.clear(); + } + + async enqueueReply( + ctx: RunContext, + items: [FunctionCall, FunctionCallOutput], + ): Promise { + const target = this.owningActivity?.agent ?? getCurrentAgent(ctx.session); + const chatCtx = target.chatCtx.copy(); + chatCtx.insert(items); + await target.updateChatCtx(chatCtx); + ctx.session.history.insert(items); + + this.pendingUpdates.push({ ctx, items, target }); + if (this._replyTask === undefined || this._replyTaskDone) { + this._replyTaskDone = false; + this._replyTask = this.deliverReply(ctx.session).catch((error) => { + log().warn( + { error }, + 'deliverReply failed; async tool result may not trigger a follow-up reply', + ); + }); + const runState = ( + ctx.session as unknown as { + _globalRunState?: { _watchHandle?: (p: Promise) => void }; + } + )._globalRunState; + runState?._watchHandle?.(this._replyTask); + } + } + + private async runTool({ + tool, + runCtx, + rawArguments, + firstUpdateFuture, + controller, + onUserToolStarted, + toolPromiseRef, + }: { + tool: FunctionTool; + runCtx: RunContext; + rawArguments: Parameters; + firstUpdateFuture: Future; + controller: AbortController; + onUserToolStarted?: () => void; + toolPromiseRef: { promise?: Promise }; + }): Promise { + let output: unknown; + let exception: unknown; + + // Wrap so a synchronous throw inside 
execute() also becomes a rejection. + const toolPromise = (async () => + tool.execute(rawArguments, { + ctx: runCtx, + toolCallId: runCtx.functionCall.callId, + abortSignal: controller.signal, + }))(); + + // Guarded handle for drain() — never rejects, so abandoning it can't surface + // as an unhandled rejection. + toolPromiseRef.promise = toolPromise.then( + () => undefined, + () => undefined, + ); + onUserToolStarted?.(); + + // Await the tool to completion. Cancellation responsiveness is handled by + // `cancel()` (abandons this wait) and `drain()` (bounds the wait on the raw + // promise via `toolPromiseRef`), so we must NOT abandon a tool that finishes + // around the same time an abort fires — doing so dropped its output and left + // the function call without an output (dangling), wedging later turns. + try { + output = await toolPromise; + } catch (error) { + exception = error; + } + + if (controller.signal.aborted) { + if (!firstUpdateFuture.done) { + firstUpdateFuture.resolve(undefined); + } + return; + } + + if (!firstUpdateFuture.done) { + if (exception !== undefined) { + // Propagate non-Error throws too (e.g. `throw "boom"`), otherwise the tool would be + // reported as succeeding with an `undefined` output. + firstUpdateFuture.reject( + exception instanceof Error ? 
exception : new Error(String(exception)), + ); + } else { + firstUpdateFuture.resolve(output); + } + return; + } + + if (exception !== undefined) { + log().error( + { + function: runCtx.functionCall.name, + callId: runCtx.functionCall.callId, + error: exception, + }, + 'exception occurred while executing tool after a progress update', + ); + return; + } + if (output === undefined || output === null) { + return; + } + if (!this.runningTasks.has(runCtx.functionCall.callId)) { + return; + } + const pair = runCtx._makeUpdatePair(output, '_final'); + runCtx._recordUpdatePair(pair); + await this.enqueueReply(runCtx, pair); + } + + private async checkDuplicate( + functionName: string, + { + onDuplicate, + confirmDuplicate, + }: { + onDuplicate: DuplicateMode; + confirmDuplicate?: boolean; + }, + ): Promise { + const runningFunctionCalls = [...this.runningTasks.values()] + .map((task) => task.ctx.functionCall) + .filter((call) => call.name === functionName); + if (runningFunctionCalls.length === 0) return undefined; + + if (onDuplicate === 'allow') { + return undefined; + } + + if (onDuplicate === 'replace') { + const nonCancellable = runningFunctionCalls.filter( + (call) => !this.runningTasks.get(call.callId)?.allowCancellation, + ); + if (nonCancellable.length > 0) { + throw new ToolError( + `cannot replace duplicate call of \`${functionName}\`: running call is not cancellable`, + ); + } + await Promise.all(runningFunctionCalls.map((call) => this.cancel(call.callId))); + return undefined; + } + + const functionCallsJson = runningFunctionCalls.map((call) => JSON.stringify(call.toJSON(true))); + const args = { + functionName, + functionCallsJson, + functionCallsText: functionCallsJson.join('\n'), + }; + + if (onDuplicate === 'reject') { + return renderTemplate(this.toolOptions.duplicateRejectTemplate, args); + } + + if (onDuplicate === 'confirm' && !confirmDuplicate) { + return renderTemplate(this.toolOptions.duplicateConfirmTemplate, args); + } + + return undefined; + } + 
+ private async deliverReply(session: AgentSession): Promise { + try { + if (this.owningActivity) { + await this.owningActivity.waitForIdle(); + } else if ('waitForIdle' in session && typeof session.waitForIdle === 'function') { + await session.waitForIdle(); + } + + const updates = [...this.pendingUpdates]; + this.pendingUpdates = []; + const pendingItems = updates.flatMap((update) => update.items); + if (pendingItems.length === 0) return; + + const targetAgent = this.owningActivity?.agent ?? getCurrentAgent(session); + + const itemsToInsert = updates + .filter((update) => update.target !== targetAgent) + .flatMap((update) => update.items); + + let chatCtx: ChatContext | undefined; + if (itemsToInsert.length > 0) { + chatCtx = targetAgent.chatCtx.copy(); + chatCtx.insert(itemsToInsert); + } + + const lastItem = pendingItems[pendingItems.length - 1]!; + const targetItems = targetAgent.chatCtx.items; + const atTail = + targetItems.length > 0 && targetItems[targetItems.length - 1]!.id === lastItem.id; + const callIds = pendingItems + .filter((item): item is FunctionCallOutput => item.type === 'function_call_output') + .map((item) => item.callId); + const instructions = renderTemplate( + atTail ? 
this.toolOptions.replyAtTailTemplate : this.toolOptions.replyMaybeCoveredTemplate, + { callIds }, + ); + + session.generateReply({ + instructions, + toolChoice: 'none', + chatCtx, + }); + } finally { + this._replyTaskDone = true; + } + } +} + +export function hasCancellableTool(tools: readonly ToolContextEntry[]): boolean { + for (const entry of tools) { + if (isFunctionTool(entry) && entry.flags & ToolFlag.CANCELLABLE) return true; + if (entry instanceof Toolset && hasCancellableTool(entry.tools)) return true; + } + return false; +} + +export function buildExecutorMap({ + toolsets, + defaultExecutor, +}: { + toolsets: readonly Toolset[]; + defaultExecutor: ToolExecutor; +}): Map { + const mapping = new Map(); + + const walk = (toolset: Toolset, current: ToolExecutor): void => { + const scopedExecutor = + '_executor' in toolset && toolset._executor instanceof ToolExecutor + ? toolset._executor + : current; + for (const child of toolset.tools) { + if (isFunctionTool(child)) { + mapping.set(child.name, scopedExecutor); + } else if (child instanceof Toolset) { + walk(child, scopedExecutor); + } + } + }; + + for (const toolset of toolsets) { + walk(toolset, defaultExecutor); + } + return mapping; +} + +export function getRunningTasks(session: AgentSession): FunctionCall[] { + return [...(runningTasks.get(session)?.values() ?? 
[])] + .filter((task) => task.allowCancellation) + .map((task) => FunctionCall.create({ ...task.ctx.functionCall })); +} + +function getCurrentAgent(session: AgentSession): ToolExecutorAgent { + return (session as unknown as { currentAgent: ToolExecutorAgent }).currentAgent; +} diff --git a/agents/src/voice/transcription/synchronizer.test.ts b/agents/src/voice/transcription/synchronizer.test.ts index 2b92c7be3..d3b100d3f 100644 --- a/agents/src/voice/transcription/synchronizer.test.ts +++ b/agents/src/voice/transcription/synchronizer.test.ts @@ -325,6 +325,54 @@ describe('TranscriptionSynchronizer playback-counter drift on a dropped frame', await synchronizer.close(); }); + it('does not let a synthetic drift finish consume a rotated segment awaiting its real finish', async () => { + // downstream that accepts segment A, then drops segment B at its + // pause/interrupt gate (bails before counting it) + class GateDroppingAudioOutput extends AudioOutput { + dropping = false; + constructor() { + super(8000); + } + async captureFrame(frame: AudioFrame): Promise { + if (this.dropping) return; + await super.captureFrame(frame); + } + clearBuffer(): void {} + } + + const downstream = new GateDroppingAudioOutput(); + const synchronizer = new TranscriptionSynchronizer(downstream, new MockTextOutput()); + const frame = new AudioFrame(new Int16Array(160), 8000, 1, 160); + + // segment A: text + audio accepted downstream; its real finish stays in flight + await synchronizer.textOutput.captureText('alpha'); + synchronizer.textOutput.flush(); + await synchronizer.audioOutput.captureFrame(frame); + synchronizer.audioOutput.flush(); + + // the next reply's text arrives before A's playback_finished -> A is rotated out and queued + await synchronizer.textOutput.captureText('bravo'); + await synchronizer.barrier(); + expect(synchronizer._pendingRotatedSegments).toHaveLength(1); + + // segment B: dropped at the downstream gate -> drift of 1 + downstream.dropping = true; + await 
synchronizer.audioOutput.captureFrame(frame); + synchronizer.audioOutput.flush(); + + // drift reconciliation must settle B (never accepted downstream) and leave + // A's queue entry for the real finish that is still coming + const playout = synchronizer.audioOutput.waitForPlayout(); + expect(synchronizer._pendingRotatedSegments).toHaveLength(1); + + // A's real finish arrives and settles A's queued segment + downstream.onPlaybackFinished({ playbackPosition: 1, interrupted: true }); + await playout; + expect(synchronizer._pendingRotatedSegments).toHaveLength(0); + + await synchronizer.close(); + }); + it('routes the synthetic finish through the synchronizer like a real playback finish', async () => { const synchronizer = new TranscriptionSynchronizer( new DroppingAudioOutput(), diff --git a/agents/src/voice/transcription/synchronizer.ts b/agents/src/voice/transcription/synchronizer.ts index efc9d0305..dc5c081f2 100644 --- a/agents/src/voice/transcription/synchronizer.ts +++ b/agents/src/voice/transcription/synchronizer.ts @@ -596,6 +596,22 @@ export class TranscriptionSynchronizer { * @internal */ _paused: boolean = false; + /** + * Segments that were rotated out by an incoming `capture_text` *before* their own + * `on_playback_finished` arrived (the "should not happen" path in `_SyncedTextOutput`). + * Each still owes one `on_playback_finished`; when it arrives it must settle that + * already-closed segment instead of the freshly-rotated current `_impl` (which now holds + * the *next* reply's text). Without this, the stale interrupted verdict tears down the new + * reply's leading segment and drops its transcript. + * + * `acceptedDownstream` records whether the next-in-chain audio output counted the segment: + * accepted segments owe a *real* finish from downstream, dropped segments owe a *synthetic* + * drift finish from `waitForPlayout()` — pairing by kind keeps a synthetic finish from + * consuming an entry whose real finish is still in flight. 
+ * @internal + */ + _pendingRotatedSegments: { impl: SegmentSynchronizerImpl; acceptedDownstream: boolean }[] = []; + private logger = log(); constructor( @@ -701,6 +717,11 @@ export class TranscriptionSynchronizer { class SyncedAudioOutput extends AudioOutput { private pushedDuration: number = 0.0; + /** Whether the downstream output counted the segment currently being captured. */ + private segmentAccepted = false; + private segmentOpen = false; + /** Acceptance verdict of the most recently flushed segment. */ + private lastSegmentAccepted = false; constructor( public synchronizer: TranscriptionSynchronizer, @@ -709,6 +730,15 @@ class SyncedAudioOutput extends AudioOutput { super(nextInChainAudio.sampleRate, nextInChainAudio, { pause: true }); } + /** + * Whether the segment a rotation is about to queue was accepted downstream — the open + * segment's live verdict if frames are still flowing, else the last flushed segment's. + * @internal + */ + get _rotationCandidateAccepted(): boolean { + return this.segmentOpen ? 
this.segmentAccepted : this.lastSegmentAccepted; + } + pause(): void { super.pause(); this.synchronizer._paused = true; @@ -730,8 +760,16 @@ class SyncedAudioOutput extends AudioOutput { // capture_frame isn't completed await this.synchronizer.barrier(); + if (!this.segmentOpen) { + this.segmentOpen = true; + this.segmentAccepted = false; + } + const downstreamCapturedBefore = this.nextInChainAudio.capturedPlayoutSegments; await super.captureFrame(frame); await this.nextInChainAudio.captureFrame(frame); // passthrough audio + if (this.nextInChainAudio.capturedPlayoutSegments > downstreamCapturedBefore) { + this.segmentAccepted = true; + } // TODO(AJS-102): use frame.durationMs once available in rtc-node this.pushedDuration += frame.samplesPerChannel / frame.sampleRate; @@ -769,11 +807,23 @@ class SyncedAudioOutput extends AudioOutput { flush() { super.flush(); this.nextInChainAudio.flush(); + if (this.segmentOpen) { + this.segmentOpen = false; + this.lastSegmentAccepted = this.segmentAccepted; + } if (!this.synchronizer.outputsAttached || !this.synchronizer.enabled) { return; } + // If a previous (interrupted) speech was rotated out early by an incoming capture_text, + // this flush belongs to that already-closed segment. Applying endAudioInput to the current + // `_impl` would mark the NEXT reply's audio as ended and trigger a spurious rotation that + // drops its transcript. Skip it; the pending stale playback_finished finalizes the old one. + if (this.synchronizer._pendingRotatedSegments.length > 0) { + return; + } + if (!this.pushedDuration) { // For timed texts, audio goes directly to room without going through synchronizer. // If text was pushed but no audio, still end audio input so text can be processed. 
@@ -798,15 +848,46 @@ class SyncedAudioOutput extends AudioOutput { async waitForPlayout(): Promise { const drift = this.pendingPlayoutSegments - this.nextInChainAudio.pendingPlayoutSegments; for (let i = 0; i < drift; i++) { - // route the synthetic finish through our own override (not the base - // class) so the synchronizer marks the segment finished, attaches the - // synchronized transcript, and rotates — the dropped segment was - // captured through the synchronizer like any other - this.onPlaybackFinished({ playbackPosition: 0, interrupted: true }); + this.settleDriftFinish(); } return super.waitForPlayout(); } + /** + * Synthetic finish for a segment the downstream output dropped. Routed through the + * synchronizer like a real finish (mark finished, attach transcript, rotate), but paired + * only with queue entries that were *not* accepted downstream — an accepted entry's real + * finish is still in flight and must settle it instead. + */ + private settleDriftFinish(): void { + const ev: PlaybackFinishedEvent = { playbackPosition: 0, interrupted: true }; + if (!this.synchronizer.outputsAttached || !this.synchronizer.enabled) { + super.onPlaybackFinished(ev); + return; + } + + const queue = this.synchronizer._pendingRotatedSegments; + const idx = queue.findIndex((entry) => !entry.acceptedDownstream); + if (idx >= 0) { + const [entry] = queue.splice(idx, 1); + super.onPlaybackFinished({ + ...ev, + synchronizedTranscript: entry!.impl.synchronizedTranscript, + }); + this.pushedDuration = 0.0; + return; + } + + // the dropped segment is the current one + this.synchronizer._impl.markPlaybackFinished(ev.playbackPosition, ev.interrupted); + super.onPlaybackFinished({ + ...ev, + synchronizedTranscript: this.synchronizer._impl.synchronizedTranscript, + }); + this.synchronizer.rotateSegment(); + this.pushedDuration = 0.0; + } + // this is going to be automatically called by the next_in_chain onPlaybackStarted(createdAt: number): void { 
super.onPlaybackStarted(createdAt); @@ -822,6 +903,25 @@ class SyncedAudioOutput extends AudioOutput { return; } + // If a segment was rotated out by an incoming capture_text before its own + // playback_finished arrived, this event belongs to that already-closed segment — not the + // freshly-rotated current `_impl` (which holds the next reply). Settle the old one and + // leave the current segment untouched, so the new reply's leading text isn't dropped. + // Real finishes can only be for segments the downstream accepted; dropped entries are + // settled by `settleDriftFinish` instead. + const queue = this.synchronizer._pendingRotatedSegments; + const idx = queue.findIndex((entry) => entry.acceptedDownstream); + if (idx >= 0) { + const [entry] = queue.splice(idx, 1); + super.onPlaybackFinished({ + playbackPosition: ev.playbackPosition, + interrupted: ev.interrupted, + synchronizedTranscript: entry!.impl.synchronizedTranscript, + }); + this.pushedDuration = 0.0; + return; + } + this.synchronizer._impl.markPlaybackFinished(ev.playbackPosition, ev.interrupted); super.onPlaybackFinished({ playbackPosition: ev.playbackPosition, @@ -888,6 +988,13 @@ class SyncedTextOutput extends TextOutput { this.logger.warn( 'SegmentSynchronizerImpl text marked as ended in capture text, rotating segment', ); + // This segment's text input already ended but its on_playback_finished hasn't arrived + // yet (interrupt + fast next reply). Remember it so the still-pending playback_finished + // settles *this* segment, not the new one we're about to rotate in. 
+ this.synchronizer._pendingRotatedSegments.push({ + impl: this.synchronizer._impl, + acceptedDownstream: this.synchronizer.audioOutput._rotationCandidateAccepted, + }); this.synchronizer.rotateSegment(); await this.synchronizer.barrier(); } diff --git a/agents/tsconfig.typecheck.json b/agents/tsconfig.typecheck.json new file mode 100644 index 000000000..afc099719 --- /dev/null +++ b/agents/tsconfig.typecheck.json @@ -0,0 +1,9 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "noEmit": true, + "emitDeclarationOnly": false + }, + "include": ["src/**/*.type.test.ts"], + "exclude": [] +} diff --git a/examples/src/async_tool_agent.ts b/examples/src/async_tool_agent.ts new file mode 100644 index 000000000..a988ef4b8 --- /dev/null +++ b/examples/src/async_tool_agent.ts @@ -0,0 +1,400 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { + Agent, + AgentSession, + AgentTask, + type ChatContext, + type JobContext, + type LLMStream, + type RunContext, + ServerOptions, + ToolFlag, + cli, + defineAgent, + delay, + inference, + log, + tool, +} from '@livekit/agents'; +import type * as silero from '@livekit/agents-plugin-silero'; +import { fileURLToPath } from 'node:url'; +import { z } from 'zod'; + +type SearchResult = { + title: string; + body: string; + href?: string; +}; + +function sample(items: readonly T[], count: number): T[] { + const shuffled = [...items].sort(() => Math.random() - 0.5); + return shuffled.slice(0, count); +} + +function randomInt(min: number, max: number): number { + return Math.floor(Math.random() * (max - min + 1)) + min; +} + +function today(): string { + return new Intl.DateTimeFormat('en-US', { + weekday: 'long', + year: 'numeric', + month: '2-digit', + day: '2-digit', + }).format(new Date()); +} + +async function collectText(stream: LLMStream, signal?: AbortSignal): Promise { + let text = ''; + + const onAbort = () => stream.close(); + signal?.addEventListener('abort', onAbort, { 
once: true }); + try { + for await (const chunk of stream) { + if (signal?.aborted) break; + if (chunk.delta?.content) { + text += chunk.delta.content; + } + } + } finally { + signal?.removeEventListener('abort', onAbort); + stream.close(); + } + return text; +} + +function compactHistory(chatCtx: ChatContext): string { + return chatCtx.items + .slice(-8) + .map((item) => { + if (item.type === 'message') { + return `${item.role}: ${item.textContent ?? ''}`; + } + if (item.type === 'function_call_output') { + return `tool ${item.name}: ${item.output}`; + } + return item.type; + }) + .join('\n'); +} + +function createGetEmailTask(extraInstructions: string): AgentTask<{ emailAddress: string }> { + const task = AgentTask.create<{ emailAddress: string }>({ + instructions: + 'You collect the user email address for a flight booking. ' + + `${extraInstructions} As soon as you have the email, call save_email.`, + tools: [ + tool({ + name: 'save_email', + description: 'Save the user email address.', + parameters: z.object({ + emailAddress: z.string().describe('The user email address'), + }), + execute: async ({ emailAddress }) => { + task.complete({ emailAddress }); + return `Saved email address ${emailAddress}.`; + }, + }), + ], + onEnter: (ctx) => { + ctx.session.generateReply({ + instructions: + 'Ask the user for their email address in one short sentence, then call save_email.', + }); + }, + }); + return task; +} + +function createTravelAgent() { + const logger = log(); + const thinkingLLM = new inference.LLM({ + model: 'openai/gpt-5.4', + modelOptions: { reasoning_effort: 'medium' }, + }); + let userEmail: string | null = null; + + async function bookFlight( + { + origin, + destination, + date, + }: { + origin: string; + destination: string; + date: string; + }, + ctx: RunContext, + signal: AbortSignal, + ): Promise { + await ctx.update( + `Searching flights from ${origin} to ${destination} on ${date}. 
` + + 'This will take a couple of minutes.', + ); + + await ctx.filler( + 'Still searching flight inventory, hang tight.', + { delay: 5_000, signal }, + () => delay(30_000, { signal }), + ); + + const airlines = sample(['United', 'Delta', 'American', 'JetBlue', 'Southwest', 'Alaska'], 3); + const prices = Object.fromEntries(airlines.map((airline) => [airline, randomInt(180, 650)])); + const cheapest = airlines.reduce((best, airline) => + prices[airline]! < prices[best]! ? airline : best, + ); + + logger.info({ airlines, prices, cheapest }, 'Found airlines and prices'); + + await ctx.update( + `Found ${airlines.length} options. Best price: $${prices[cheapest]} on ${cheapest}. ` + + 'Confirming the booking now.', + ); + + if (!userEmail) { + logger.info('Getting user email address'); + const email = await ctx.foreground(async () => { + ctx.session.say('We will need your email address to confirm the flight booking.'); + return createGetEmailTask( + 'You are capturing the email address of the user for the flight booking.', + ).run(); + }); + // The foreground hold can resolve right as we're cancelled; bail before the + // final wait so a cancelled booking doesn't push a confirmation. + if (signal.aborted) throw new Error('aborted'); + userEmail = email.emailAddress; + logger.info({ email: userEmail }, 'Captured user email address'); + } + + const confirmationFillers = [ + 'Still confirming the booking.', + "Almost there, I'm finalizing the reservation.", + ]; + await ctx.filler( + (step) => confirmationFillers[step], + { delay: 5_000, interval: 10_000, maxSteps: confirmationFillers.length, signal }, + () => delay(40_000, { signal }), + ); + + const confirmation = `FL-${randomInt(100000, 999999)}`; + return ( + `Flight booked! ${cheapest} from ${origin} to ${destination} on ${date}. ` + + `Price: $${prices[cheapest]}. Confirmation: ${confirmation}. ` + + 'The details will be sent to your email.' 
+ ); + } + + async function tourGuide( + { + destination, + interests, + }: { + destination: string; + interests: string; + }, + ctx: RunContext, + signal: AbortSignal, + ): Promise { + await ctx.update(`Looking up the best spots in ${destination} for you.`); + + const sources = await search(destination, interests, ctx.session.history, signal); + if (sources.length === 0) { + return `Could not find information about ${destination}.`; + } + + logger.info({ count: sources.length, destination }, 'Found tour guide sources'); + return summarize(destination, interests, sources, ctx.session.history, signal); + } + + async function search( + destination: string, + interests: string, + chatCtx: ChatContext, + signal: AbortSignal, + ): Promise { + logger.info({ destination, interests }, 'Planning search queries'); + const planCtx = chatCtx.copy({ excludeFunctionCall: true, excludeInstructions: true }); + planCtx.addMessage({ + role: 'system', + content: + 'You are a travel research assistant. Output 3-4 web search queries ' + + `to find the best places to visit, eat, and explore in ${destination} ` + + `for someone interested in: ${interests}. 
` + + 'Output only the queries, one per line, nothing else.', + }); + + const planResponse = await collectText(thinkingLLM.chat({ chatCtx: planCtx }), signal); + const queries = planResponse + .split('\n') + .map((query) => query.trim()) + .filter(Boolean) + .slice(0, 4); + logger.info({ queries }, 'Search queries'); + + const results: SearchResult[] = []; + for (const query of queries) { + if (signal.aborted) break; + results.push(...(await searchDuckDuckGo(query, signal))); + } + return results.slice(0, 12); + } + + async function searchDuckDuckGo(query: string, signal: AbortSignal): Promise { + const url = new URL('https://api.duckduckgo.com/'); + url.searchParams.set('q', query); + url.searchParams.set('format', 'json'); + url.searchParams.set('no_html', '1'); + url.searchParams.set('skip_disambig', '1'); + + const response = await fetch(url, { + headers: { accept: 'application/json' }, + signal, + }); + if (!response.ok) { + logger.warn({ status: response.status, query }, 'DuckDuckGo search failed'); + return []; + } + + const payload = (await response.json()) as { + AbstractText?: string; + Heading?: string; + AbstractURL?: string; + RelatedTopics?: Array<{ + Text?: string; + FirstURL?: string; + Name?: string; + Topics?: Array<{ Text?: string; FirstURL?: string }>; + }>; + }; + + const results: SearchResult[] = []; + if (payload.AbstractText) { + results.push({ + title: payload.Heading || query, + body: payload.AbstractText, + href: payload.AbstractURL, + }); + } + + for (const topic of payload.RelatedTopics ?? []) { + if (topic.Text) { + results.push({ title: topic.Name || query, body: topic.Text, href: topic.FirstURL }); + } + for (const nested of topic.Topics ?? []) { + if (nested.Text) { + results.push({ title: topic.Name || query, body: nested.Text, href: nested.FirstURL }); + } + } + } + + if (results.length === 0) { + results.push({ + title: query, + body: 'No direct instant-answer result was returned. 
Use this query as a research lead and provide general travel guidance.', + }); + } + + return results; + } + + async function summarize( + destination: string, + interests: string, + sources: SearchResult[], + chatCtx: ChatContext, + signal: AbortSignal, + ): Promise { + const summaryCtx = chatCtx.copy({ excludeFunctionCall: true, excludeInstructions: true }); + const sourceText = sources + .map((source) => `- ${source.title}: ${source.body}${source.href ? ` (${source.href})` : ''}`) + .join('\n\n'); + + summaryCtx.addMessage({ + role: 'system', + content: + `You are a local tour guide for ${destination}. The user is interested in: "${interests}". ` + + 'Based on the search results below, recommend specific places to visit, restaurants to try, ' + + 'and things to do. Be specific, with actual names and neighborhoods when available. ' + + 'This will be spoken aloud, so keep it conversational and brief: 3 to 5 top picks, ' + + 'no more than 200 words. No bullet points or markdown.\n\n' + + `Conversation context:\n${compactHistory(chatCtx)}\n\nSearch results:\n${sourceText}`, + }); + + return collectText(thinkingLLM.chat({ chatCtx: summaryCtx }), signal); + } + + return Agent.create({ + instructions: + 'You are a friendly travel assistant that communicates via voice. ' + + 'Avoid emojis and markdown. Speak naturally and concisely. ' + + 'You can help with two things: booking flights and recommending what to see, eat, ' + + 'and do at a destination. Use the book_flight tool when the user wants to book a ' + + 'flight. Use the tour_guide tool when the user asks about places to visit, restaurants, ' + + 'sightseeing, nightlife, or things to do somewhere. Summarize results naturally for voice. ' + + `Today is ${today()}. When the user is not asking, do not repeat messages already said. ` + + 'Do not make up flight details or ask for flight preferences. 
Always use the tools.', + tools: [ + tool({ + name: 'book_flight', + description: 'Called when the user wants to book a flight.', + flags: ToolFlag.CANCELLABLE, + onDuplicate: 'confirm', + parameters: z.object({ + origin: z.string().describe('Departure city or airport code'), + destination: z.string().describe('Arrival city or airport code'), + date: z.string().describe('Travel date, for example 2026-04-15'), + }), + execute: (args, options) => bookFlight(args, options.ctx, options.abortSignal), + }), + tool({ + name: 'tour_guide', + description: + 'Called when the user asks about places to visit, restaurants, local food, nightlife, or things to do somewhere.', + flags: ToolFlag.CANCELLABLE, + onDuplicate: 'confirm', + parameters: z.object({ + destination: z.string().describe('The city or area the user is visiting'), + interests: z + .string() + .describe('What the user is interested in, such as street food, museums, or nightlife'), + }), + execute: (args, options) => tourGuide(args, options.ctx, options.abortSignal), + }), + ], + onEnter: (ctx) => { + ctx.session.generateReply({ instructions: 'Greet the user and introduce yourself.' 
}); + }, + }); +} + +export default defineAgent({ + entry: async (ctx: JobContext) => { + const session = new AgentSession({ + vad: ctx.proc.userData.vad as silero.VAD, + stt: new inference.STT({ model: 'deepgram/nova-3' }), + llm: new inference.LLM({ model: 'google/gemini-3.1-flash-lite' }), + tts: new inference.TTS({ + model: 'cartesia/sonic-3', + voice: 'e07c00bc-4134-4eae-9ea4-1a55fb45746b', + }), + turnHandling: { + interruption: { + mode: 'adaptive', + }, + }, + }); + + await session.start({ + agent: createTravelAgent(), + room: ctx.room, + }); + }, +}); + +cli.runApp( + new ServerOptions({ + agent: fileURLToPath(import.meta.url), + }), +); diff --git a/examples/src/background_audio.ts b/examples/src/background_audio.ts index fc718ac0c..8d5884be9 100644 --- a/examples/src/background_audio.ts +++ b/examples/src/background_audio.ts @@ -35,6 +35,7 @@ export default defineAgent({ logger.info('Connected to room'); const searchWeb = llm.tool({ + name: 'searchWeb', description: 'Search the web for information based on the given query. Always use this function whenever the user requests a web search', parameters: z.object({ @@ -49,9 +50,7 @@ export default defineAgent({ const agent = new voice.Agent({ instructions: 'You are a helpful assistant', - tools: { - searchWeb, - }, + tools: [searchWeb], }); const session = new voice.AgentSession({ diff --git a/examples/src/basic_agent.ts b/examples/src/basic_agent.ts index 10c8e5028..3947b911c 100644 --- a/examples/src/basic_agent.ts +++ b/examples/src/basic_agent.ts @@ -2,15 +2,17 @@ // // SPDX-License-Identifier: Apache-2.0 import { + Agent, + AgentSession, + AgentSessionEventTypes, type JobContext, ServerOptions, cli, defineAgent, inference, - llm, log, - metrics, - voice, + logMetrics, + tool, } from '@livekit/agents'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; @@ -21,11 +23,12 @@ import { z } from 'zod'; // lazy-loads on first stream. 
export default defineAgent({ entry: async (ctx: JobContext) => { - const agent = new voice.Agent({ + const agent = Agent.create({ instructions: "You are a helpful assistant, you can hear the user's message and respond to it.", - tools: { - getWeather: llm.tool({ + tools: [ + tool({ + name: 'getWeather', description: 'Get the weather for a given location.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -34,12 +37,12 @@ export default defineAgent({ return `The weather in ${location} is sunny.`; }, }), - }, + ], }); const logger = log(); - const session = new voice.AgentSession({ + const session = new AgentSession({ // Speech-to-text (STT) is your agent's ears, turning the user's speech into text that the LLM can understand // See all available models at https://docs.livekit.io/agents/models/stt/ stt: new inference.STT({ @@ -98,8 +101,8 @@ export default defineAgent({ }); // Log metrics as they are emitted - session.on(voice.AgentSessionEventTypes.MetricsCollected, (ev) => { - metrics.logMetrics(ev.metrics); + session.on(AgentSessionEventTypes.MetricsCollected, (ev) => { + logMetrics(ev.metrics); }); // Log usage summary when job shuts down @@ -112,8 +115,8 @@ export default defineAgent({ ); }); - session.on(voice.AgentSessionEventTypes.OverlappingSpeech, (ev) => { - logger.info({ type: ev.type, isInterruption: ev.isInterruption }, 'user overlapping speech'); + session.on(AgentSessionEventTypes.OverlappingSpeech, (ev) => { + logger.warn({ type: ev.type, isInterruption: ev.isInterruption }, 'user overlapping speech'); }); await session.start({ diff --git a/examples/src/basic_agent_task.ts b/examples/src/basic_agent_task.ts index 0549f4197..6d2023628 100644 --- a/examples/src/basic_agent_task.ts +++ b/examples/src/basic_agent_task.ts @@ -2,10 +2,13 @@ // // SPDX-License-Identifier: Apache-2.0 import { + Agent, + AgentTask, type JobContext, ServerOptions, cli, defineAgent, + handoff, inference, llm, voice, @@ -14,97 
+17,101 @@ import * as openai from '@livekit/agents-plugin-openai'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; -class InfoTask extends voice.AgentTask { - constructor(private info: string) { - super({ - instructions: `Collect the user's information. around ${info}. Once you have the information, call the saveUserInfo tool to save the information to the database IMMEDIATELY. DO NOT have chitchat with the user, just collect the information and call the saveUserInfo tool.`, - tts: 'elevenlabs/eleven_turbo_v2_5', - tools: { - saveUserInfo: llm.tool({ - description: `Save the user's ${info} to database`, - parameters: z.object({ - [info]: z.string(), - }), - execute: async (args) => { - this.complete(args[info] as string); - return `Thanks, collected ${info} successfully: ${args[info]}`; - }, +function createInfoTask(info: string): AgentTask { + const task = AgentTask.create({ + instructions: `Collect the user's information. around ${info}. Once you have the information, call the saveUserInfo tool to save the information to the database IMMEDIATELY. DO NOT have chitchat with the user, just collect the information and call the saveUserInfo tool.`, + tts: 'elevenlabs/eleven_turbo_v2_5', + tools: [ + llm.tool({ + name: 'saveUserInfo', + description: `Save the user's ${info} to database`, + parameters: z.object({ + [info]: z.string(), }), - }, - }); - } + execute: async (args) => { + task.complete(args[info] as string); + return `Thanks, collected ${info} successfully: ${args[info]}`; + }, + }), + ], + onEnter: (ctx) => { + ctx.session.generateReply({ + userInput: `Ask the user for their ${info}`, + }); + }, + }); - async onEnter() { - this.session.generateReply({ - userInput: `Ask the user for their ${this.info}`, - }); - } + return task; } -class SurveyAgent extends voice.Agent { - constructor() { - super({ - instructions: - 'You orchestrate a short intro survey. 
Speak naturally and keep the interaction brief.', - tools: { - collectUserInfo: llm.tool({ - description: 'Call this when user want to provide some information to you', - parameters: z.object({ - key: z - .string() - .describe( - 'The key of the information to collect, e.g. "name" or "role" should be no space and underscore separated', - ), - }), - execute: async ({ key }) => { - const value = await new InfoTask(key).run(); - return `Collected ${key} successfully: ${value}`; - }, +function createWeatherAgent() { + return Agent.create({ + instructions: + 'You are a weather agent. You are responsible for providing the weather information to the user.', + tts: 'deepgram/aura-2', + tools: [ + llm.tool({ + name: 'getWeather', + description: 'Get the weather for a given location', + parameters: z.object({ + location: z.string().describe('The location to get the weather for'), }), - transferToWeatherAgent: llm.tool({ - description: 'Call this immediately after user want to know the weather', - execute: async () => { - const agent = new voice.Agent({ - instructions: - 'You are a weather agent. 
You are responsible for providing the weather information to the user.', - tts: 'deepgram/aura-2', - tools: { - getWeather: llm.tool({ - description: 'Get the weather for a given location', - parameters: z.object({ - location: z.string().describe('The location to get the weather for'), - }), - execute: async ({ location }) => { - return `The weather in ${location} is sunny today.`; - }, - }), - finishWeatherConversation: llm.tool({ - description: 'Call this when you want to finish the weather conversation', - execute: async () => { - return llm.handoff({ - agent: new SurveyAgent(), - returns: 'Transfer to survey agent successfully!', - }); - }, - }), - }, - }); + execute: async ({ location }) => { + return `The weather in ${location} is sunny today.`; + }, + }), + llm.tool({ + name: 'finishWeatherConversation', + description: 'Call this when you want to finish the weather conversation', + execute: async () => { + return llm.handoff({ + agent: createSurveyAgent(), + returns: 'Transfer to survey agent successfully!', + }); + }, + }), + ], + }); +} - return llm.handoff({ agent, returns: "Let's start the weather conversation!" }); - }, +function createSurveyAgent(): Agent { + return Agent.create({ + instructions: + 'You orchestrate a short intro survey. Speak naturally and keep the interaction brief.', + tools: [ + llm.tool({ + name: 'collectUserInfo', + description: 'Call this when user want to provide some information to you', + parameters: z.object({ + key: z + .string() + .describe( + 'The key of the information to collect, e.g. 
"name" or "role" should be no space and underscore separated', + ), }), - }, - }); - } - - async onEnter() { - const name = await new InfoTask('name').run(); - const role = await new InfoTask('role').run(); + execute: async ({ key }) => { + const value = await createInfoTask(key).run(); + return `Collected ${key} successfully: ${value}`; + }, + }), + llm.tool({ + name: 'transferToWeatherAgent', + description: 'Call this immediately after user want to know the weather', + execute: async () => { + const agent = createWeatherAgent(); + return handoff({ agent, returns: "Let's start the weather conversation!" }); + }, + }), + ], + onEnter: async (ctx) => { + const name = await createInfoTask('name').run(); + const role = await createInfoTask('role').run(); - await this.session.say( - `Great to meet you ${name}. I noted your role as ${role}. We can continue now.`, - ); - } + await ctx.session.say( + `Great to meet you ${name}. I noted your role as ${role}. We can continue now.`, + ); + }, + }); } export default defineAgent({ @@ -120,7 +127,7 @@ export default defineAgent({ await session.start({ room: ctx.room, - agent: new SurveyAgent(), + agent: createSurveyAgent(), }); }, }); diff --git a/examples/src/basic_task_group.ts b/examples/src/basic_task_group.ts index 0c24c2059..b27849ca1 100644 --- a/examples/src/basic_task_group.ts +++ b/examples/src/basic_task_group.ts @@ -23,8 +23,9 @@ class CollectNameTask extends voice.AgentTask { instructions: 'Collect the user name from the latest user message. 
As soon as you have it, call save_name.', tts: taskTts, - tools: { - save_name: llm.tool({ + tools: [ + llm.tool({ + name: 'save_name', description: 'Save the user name.', parameters: z.object({ name: z.string().describe('The user name'), @@ -34,7 +35,7 @@ class CollectNameTask extends voice.AgentTask { return `Saved name: ${name}`; }, }), - }, + ], }); } @@ -52,8 +53,9 @@ class CollectEmailTask extends voice.AgentTask { instructions: 'Collect the user email from the latest user message. As soon as you have it, call save_email.', tts: taskTts, - tools: { - save_email: llm.tool({ + tools: [ + llm.tool({ + name: 'save_email', description: 'Save the user email.', parameters: z.object({ email: z.string().describe('The user email'), @@ -63,7 +65,7 @@ class CollectEmailTask extends voice.AgentTask { return `Saved email: ${email}`; }, }), - }, + ], }); } @@ -80,8 +82,9 @@ class TaskGroupDemoAgent extends voice.Agent { super({ instructions: 'You are onboarding assistant. When user asks to begin onboarding, call startOnboarding exactly once.', - tools: { - startOnboarding: llm.tool({ + tools: [ + llm.tool({ + name: 'startOnboarding', description: 'Start a two-step onboarding flow (name then email).', parameters: z.object({}), execute: async () => { @@ -105,7 +108,7 @@ class TaskGroupDemoAgent extends voice.Agent { return JSON.stringify(result.taskResults); }, }), - }, + ], }); } diff --git a/examples/src/basic_tool_call_agent.ts b/examples/src/basic_tool_call_agent.ts index 8d576cf1f..06f7924d7 100644 --- a/examples/src/basic_tool_call_agent.ts +++ b/examples/src/basic_tool_call_agent.ts @@ -38,6 +38,7 @@ class GameAgent extends voice.Agent { export default defineAgent({ entry: async (ctx: JobContext) => { const getWeather = llm.tool({ + name: 'getWeather', description: ' Called when the user asks about the weather.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -49,6 +50,7 @@ export default defineAgent({ }); const 
toggleLight = llm.tool({ + name: 'toggleLight', description: 'Called when the user asks to turn on or off the light.', parameters: z.object({ room: roomNameSchema.describe('The room to turn the light in'), @@ -64,6 +66,7 @@ export default defineAgent({ }); const getNumber = llm.tool({ + name: 'getNumber', description: 'Called when the user wants to get a number value, None if user want a random value', parameters: z.object({ @@ -81,6 +84,7 @@ export default defineAgent({ }); const checkStoredNumber = llm.tool({ + name: 'checkStoredNumber', description: 'Called when the user wants to check the stored number.', execute: async (_, { ctx }: llm.ToolOptions) => { return `The stored number is ${ctx.userData.number}.`; @@ -88,6 +92,7 @@ export default defineAgent({ }); const updateStoredNumber = llm.tool({ + name: 'updateStoredNumber', description: 'Called when the user wants to update the stored number.', parameters: z.object({ number: z.number().describe('The number to update the stored number to'), @@ -100,31 +105,33 @@ export default defineAgent({ const routerAgent = new RouterAgent({ instructions: 'You are a helpful assistant.', - tools: { + tools: [ getWeather, toggleLight, - playGame: llm.tool({ + llm.tool({ + name: 'playGame', description: 'Called when the user wants to play a game (transfer user to a game agent).', execute: async (): Promise => { return llm.handoff({ agent: gameAgent, returns: 'The game is now playing.' }); }, }), - }, + ], }); const gameAgent = new GameAgent({ instructions: 'You are a game agent. You are playing a game with the user.', - tools: { + tools: [ getNumber, checkStoredNumber, updateStoredNumber, - finishGame: llm.tool({ + llm.tool({ + name: 'finishGame', description: 'Called when the user wants to finish the game.', execute: async () => { return llm.handoff({ agent: routerAgent, returns: 'The game is now finished.' 
}); }, }), - }, + ], }); const session = new voice.AgentSession({ diff --git a/examples/src/basic_toolsets.ts b/examples/src/basic_toolsets.ts new file mode 100644 index 000000000..cb0cf2b6e --- /dev/null +++ b/examples/src/basic_toolsets.ts @@ -0,0 +1,180 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { + type JobContext, + type JobProcess, + ServerOptions, + cli, + defineAgent, + inference, + llm, + voice, +} from '@livekit/agents'; +import * as livekit from '@livekit/agents-plugin-livekit'; +import * as silero from '@livekit/agents-plugin-silero'; +import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; +import { fileURLToPath } from 'node:url'; +import { z } from 'zod'; + +class InfoTask extends voice.AgentTask { + private key: string; + + constructor(key: string, sharedToolset: llm.Toolset) { + super({ + instructions: `Collect the user's ${key}. Once you have it, call saveUserInfo IMMEDIATELY. No chitchat.`, + tools: [ + sharedToolset, + llm.tool({ + name: 'saveUserInfo', + description: `Save the user's ${key} to the database`, + parameters: z.object({ + [key]: z.string(), + }), + execute: async (args) => { + this.complete(args[key] as string); + return `Thanks, collected ${key} successfully: ${args[key]}`; + }, + }), + ], + }); + this.key = key; + } + + async onEnter() { + this.session.generateReply({ userInput: `Ask the user for their ${this.key}` }); + } +} + +function makeWeatherAgent(returnHome: () => voice.Agent) { + const weatherToolset = new llm.Toolset({ + id: 'weather_tools', + tools: [ + llm.tool({ + name: 'getWeather', + description: 'Get the weather for a given location', + parameters: z.object({ location: z.string() }), + execute: async ({ location }) => `The weather in ${location} is sunny today.`, + }), + ], + }); + + return new voice.Agent({ + instructions: 'You are a weather agent. 
Provide weather information then hand back when done.', + tools: [ + weatherToolset, + llm.tool({ + name: 'finishWeatherConversation', + description: 'Call this when you want to finish the weather conversation', + execute: async () => { + return llm.handoff({ agent: returnHome(), returns: 'Transfer back to main agent.' }); + }, + }), + ], + }); +} + +class MainAgent extends voice.Agent { + private locationToolset: llm.Toolset; + + constructor(locationToolset: llm.Toolset) { + super({ + instructions: + 'You are a helpful assistant. Use the location toolset for weather/timezone. Use transferToWeather when the user asks about weather. Use swapToolset / reapplyTools to exercise updateTools.', + tools: [ + locationToolset, + llm.tool({ + name: 'transferToWeather', + description: 'Call this when the user wants to know the weather', + execute: async () => { + return llm.handoff({ + agent: makeWeatherAgent(() => new MainAgent(locationToolset)), + returns: "Let's switch to the weather agent.", + }); + }, + }), + llm.tool({ + name: 'swapToolset', + description: 'Replace the active toolset with a brand-new toolset (tests updateTools).', + execute: async () => { + const replacement = new llm.Toolset({ + id: 'location_tools_v2', + tools: [ + llm.tool({ + name: 'getWeather', + description: 'v2 weather', + parameters: z.object({ location: z.string() }), + execute: async ({ location }) => `v2: ${location} -> sunny`, + }), + ], + }); + await this.updateTools([replacement]); + return 'Swapped toolset.'; + }, + }), + llm.tool({ + name: 'reapplyTools', + description: 'Re-apply the current tool list unchanged (idempotent updateTools).', + execute: async () => { + await this.updateTools([...this.toolCtx.tools]); + return 'Re-applied the same tool list.'; + }, + }), + ], + }); + this.locationToolset = locationToolset; + } + + async onEnter() { + const name = await new InfoTask('name', this.locationToolset).run(); + await this.session.say( + `Got it, ${name}. 
Ask me about weather, or say "swap" / "reapply" to exercise updateTools.`, + ); + } +} + +export default defineAgent({ + prewarm: async (proc: JobProcess) => { + proc.userData.vad = await silero.VAD.load(); + }, + entry: async (ctx: JobContext) => { + const locationToolset = new llm.Toolset({ + id: 'location_tools', + tools: [ + llm.tool({ + name: 'getWeather', + description: 'Get the weather for a given location.', + parameters: z.object({ location: z.string() }), + execute: async ({ location }) => `The weather in ${location} is sunny.`, + }), + llm.tool({ + name: 'lookupTimezone', + description: 'Look up the timezone for a city or region.', + parameters: z.object({ location: z.string() }), + execute: async ({ location }) => `${location} is in the America/Los_Angeles timezone.`, + }), + ], + }); + + const session = new voice.AgentSession({ + vad: ctx.proc.userData.vad! as silero.VAD, + stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), + llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), + tts: new inference.TTS({ + model: 'cartesia/sonic-3', + voice: '9626c31c-bec5-4cca-baa8-f8ba9e84c8bc', + }), + turnDetection: new livekit.turnDetector.MultilingualModel(), + }); + + await session.start({ + agent: new MainAgent(locationToolset), + room: ctx.room, + inputOptions: { noiseCancellation: BackgroundVoiceCancellation() }, + }); + + session.say('Hello! 
I will ask you a quick question, then we can chat.'); + }, +}); + +cli.runApp(new ServerOptions({ agent: fileURLToPath(import.meta.url) })); diff --git a/examples/src/comprehensive_test.ts b/examples/src/comprehensive_test.ts index 9e9e25225..e72bc7741 100644 --- a/examples/src/comprehensive_test.ts +++ b/examples/src/comprehensive_test.ts @@ -74,8 +74,9 @@ class MainAgent extends voice.Agent { tts: ttsOptions['elevenlabs'](), llm: llmOptions['openai'](), turnDetection: eouOptions['multilingual'](), - tools: { - testAgent: llm.tool({ + tools: [ + llm.tool({ + name: 'testAgent', description: 'Called when user want to test an agent with STT, TTS, EOU, LLM, and optionally realtime LLM configuration', parameters: z.object({ @@ -100,7 +101,7 @@ class MainAgent extends voice.Agent { }); }, }), - }, + ], }); } @@ -157,8 +158,9 @@ class TestAgent extends voice.Agent { tts: tts, llm: realtimeModel ?? model, turnDetection: eou, - tools: { - testTool: llm.tool({ + tools: [ + llm.tool({ + name: 'testTool', description: "Testing agent's tool calling ability", parameters: z .object({ @@ -171,7 +173,8 @@ class TestAgent extends voice.Agent { }; }, }), - nextAgent: llm.tool({ + llm.tool({ + name: 'nextAgent', description: 'Called when user confirm current agent is working and want to proceed to next agent', parameters: z.object({ @@ -202,7 +205,7 @@ class TestAgent extends voice.Agent { }); }, }), - }, + ], }); this.sttChoice = sttChoice; diff --git a/examples/src/drive-thru/drivethru_agent.ts b/examples/src/drive-thru/drivethru_agent.ts index c9a534dec..838f5d104 100644 --- a/examples/src/drive-thru/drivethru_agent.ts +++ b/examples/src/drive-thru/drivethru_agent.ts @@ -47,23 +47,24 @@ export class DriveThruAgent extends voice.Agent { super({ instructions, - tools: { - orderComboMeal: DriveThruAgent.buildComboOrderTool( + tools: [ + DriveThruAgent.buildComboOrderTool( userdata.comboItems, userdata.drinkItems, userdata.sauceItems, ), - orderHappyMeal: 
DriveThruAgent.buildHappyOrderTool( + DriveThruAgent.buildHappyOrderTool( userdata.happyItems, userdata.drinkItems, userdata.sauceItems, ), - orderRegularItem: DriveThruAgent.buildRegularOrderTool( + DriveThruAgent.buildRegularOrderTool( userdata.regularItems, userdata.drinkItems, userdata.sauceItems, ), - removeOrderItem: llm.tool({ + llm.tool({ + name: 'removeOrderItem', description: `Removes one or more items from the user's order using their \`orderId\`s. Useful when the user asks to cancel or delete existing items (e.g., "Remove the cheeseburger"). @@ -90,7 +91,8 @@ If the \`orderId\`s are unknown, call \`listOrderItems\` first to retrieve them. return 'Removed items:\n' + removedItems.map((item) => JSON.stringify(item)).join('\n'); }, }), - listOrderItems: llm.tool({ + llm.tool({ + name: 'listOrderItems', description: `Retrieves the current list of items in the user's order, including each item's internal \`orderId\`. Helpful when: @@ -110,7 +112,7 @@ Examples: return items.map((item) => JSON.stringify(item)).join('\n'); }, }), - }, + ], }); } @@ -124,6 +126,7 @@ Examples: const availableSauceIds = [...new Set(sauceItems.map((item) => item.id))]; return llm.tool({ + name: 'orderComboMeal', description: `Call this when the user orders a **Combo Meal**, like: "Number 4b with a large Sprite" or "I'll do a medium meal." Do not call this tool unless the user clearly refers to a known combo meal by name or number. @@ -212,6 +215,7 @@ If the user says just "a large meal," assume both drink and fries are that size. const availableSauceIds = [...new Set(sauceItems.map((item) => item.id))]; return llm.tool({ + name: 'orderHappyMeal', description: `Call this when the user orders a **Happy Meal**, typically for children. These meals come with a main item, a drink, and a sauce. The user must clearly specify a valid Happy Meal option (e.g., "Can I get a Happy Meal?"). 
@@ -289,6 +293,7 @@ Assume Small as default only if the user says "Happy Meal" and gives no size pre const availableIds = [...new Set(allItems.map((item) => item.id))]; return llm.tool({ + name: 'orderRegularItem', description: `Call this when the user orders **a single item on its own**, not as part of a Combo Meal or Happy Meal. The customer must provide clear and specific input. For example, item variants such as flavor must **always** be explicitly stated. diff --git a/examples/src/drive-thru/test_agent.test.ts b/examples/src/drive-thru/test_agent.test.ts index 5750c2452..7eae74f8e 100644 --- a/examples/src/drive-thru/test_agent.test.ts +++ b/examples/src/drive-thru/test_agent.test.ts @@ -156,6 +156,54 @@ describe('DriveThru Agent Tests', { timeout: 180_000 }, () => { }); }); + describe('test_failure', () => { + let session: voice.AgentSession; + let llmInstance: openai.LLM; + let judgeInstance: openai.LLM; + let userdata: UserData; + + beforeAll(async () => { + userdata = await newUserData(); + llmInstance = mainLLM(); + judgeInstance = judgeLLM(); + session = new voice.AgentSession({ + llm: llmInstance, + userData: userdata, + }); + await session.start({ agent: new DriveThruAgent(userdata) }); + }, 30_000); + + afterAll(async () => { + await session?.close(); + }); + + it('should recover gracefully when a tool throws', async () => { + // Simulate a tool error via withMockTools (mirrors Python's mock_tools usage). + using _mock = voice.testing.withMockTools(DriveThruAgent, { + orderRegularItem: () => { + throw new Error('test failure'); + }, + }); + + const result = session.run({ userInput: 'Can I get a large vanilla shake?' 
}); + await result.wait(); + + result.expect.skipNextEventIf({ type: 'message', role: 'assistant' }); + result.expect.nextEvent().isFunctionCall({ + name: 'orderRegularItem', + args: { itemId: 'shake_vanilla', size: 'L' }, + }); + result.expect.nextEvent().isFunctionCallOutput(); + await result.expect.nextEvent().isMessage({ role: 'assistant' }).judge(judgeInstance, { + intent: + "should inform the user that something went wrong, it's ok to ask them to try again", + }); + + // leaving this commented, some LLMs may occasionally try to retry. + // result.expect.noMoreEvents(); + }); + }); + describe('test_unavailable_item', () => { let session: voice.AgentSession; let llmInstance: openai.LLM; diff --git a/examples/src/frontdesk/frontdesk_agent.ts b/examples/src/frontdesk/frontdesk_agent.ts index 8e9e50f42..b7a7c9bff 100644 --- a/examples/src/frontdesk/frontdesk_agent.ts +++ b/examples/src/frontdesk/frontdesk_agent.ts @@ -49,8 +49,9 @@ export class FrontDeskAgent extends voice.Agent { super({ instructions, - tools: { - scheduleAppointment: llm.tool({ + tools: [ + llm.tool({ + name: 'scheduleAppointment', description: 'Schedule an appointment at the given slot.', parameters: z.object({ slotId: z @@ -100,7 +101,8 @@ export class FrontDeskAgent extends voice.Agent { return `The appointment was successfully scheduled for ${formatted}.`; }, }), - listAvailableSlots: llm.tool({ + llm.tool({ + name: 'listAvailableSlots', description: `Return a plain-text list of available slots, one per line. 
- , , at () @@ -178,7 +180,7 @@ You must infer the appropriate range implicitly from the conversational context return lines.join('\n') || 'No slots available at the moment.'; }, }), - }, + ], }); this.tz = options.timezone; diff --git a/examples/src/gemini_realtime_agent.ts b/examples/src/gemini_realtime_agent.ts index 3b82c8ec0..f131d2526 100644 --- a/examples/src/gemini_realtime_agent.ts +++ b/examples/src/gemini_realtime_agent.ts @@ -44,6 +44,7 @@ type StoryData = { const roomNameSchema = z.enum(['bedroom', 'living room', 'kitchen', 'bathroom', 'office']); const getWeather = llm.tool({ + name: 'getWeather', description: 'Called when the user asks about the weather.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -56,6 +57,7 @@ const getWeather = llm.tool({ }); const toggleLight = llm.tool({ + name: 'toggleLight', description: 'Called when the user asks to turn on or off the light.', parameters: z.object({ room: roomNameSchema.describe('The room to turn the light in'), @@ -73,28 +75,29 @@ class IntroAgent extends voice.Agent { }); } - static create() { + static createIntroAgent() { return new IntroAgent({ instructions: `You are a story teller. Your goal is to gather a few pieces of information from the user to make the story personalized and engaging. 
Ask the user for their name and where they are from.`, - tools: { - informationGathered: llm.tool({ + tools: [ + llm.tool({ + name: 'informationGathered', description: 'Called when the user has provided the information needed to make the story personalized and engaging.', parameters: z.object({ name: z.string().describe('The name of the user'), location: z.string().describe('The location of the user'), }), - execute: async ({ name, location }, { ctx }) => { + execute: async ({ name, location }, { ctx }: llm.ToolOptions) => { ctx.userData.name = name; ctx.userData.location = location; - const storyAgent = StoryAgent.create(name, location); + const storyAgent = StoryAgent.createStoryAgent(name, location); return llm.handoff({ agent: storyAgent, returns: "Let's start the story!" }); }, }), getWeather, toggleLight, - }, + ], }); } } @@ -104,7 +107,7 @@ class StoryAgent extends voice.Agent { this.session.generateReply(); } - static create(name: string, location: string) { + static createStoryAgent(name: string, location: string) { return new StoryAgent({ instructions: dedent` You are a storyteller. Use the user's information in order to make the story personalized. 
@@ -132,7 +135,7 @@ export default defineAgent({ }); await session.start({ - agent: IntroAgent.create(), + agent: IntroAgent.createIntroAgent(), room: ctx.room, }); diff --git a/examples/src/instructions_per_modality.ts b/examples/src/instructions_per_modality.ts index 793f9819c..f2ab6b910 100644 --- a/examples/src/instructions_per_modality.ts +++ b/examples/src/instructions_per_modality.ts @@ -55,8 +55,9 @@ class SchedulingAgent extends voice.Agent { super({ instructions, - tools: { - bookAppointment: llm.tool({ + tools: [ + llm.tool({ + name: 'bookAppointment', description: 'Book an appointment.', parameters: z.object({ date: z.string().describe('The date of the appointment in the format YYYY-MM-DD'), @@ -67,7 +68,7 @@ class SchedulingAgent extends voice.Agent { return `Appointment booked for ${date} at ${time}`; }, }), - }, + ], }); } diff --git a/examples/src/llm_fallback_adapter.ts b/examples/src/llm_fallback_adapter.ts index d3d407214..8db1e243d 100644 --- a/examples/src/llm_fallback_adapter.ts +++ b/examples/src/llm_fallback_adapter.ts @@ -59,8 +59,9 @@ export default defineAgent({ const agent = new voice.Agent({ instructions: 'You are a helpful assistant. 
Demonstrate that you are working by responding to user queries.', - tools: { - getWeather: llm.tool({ + tools: [ + llm.tool({ + name: 'getWeather', description: 'Get the weather for a given location.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -69,7 +70,7 @@ export default defineAgent({ return `The weather in ${location} is sunny with a temperature of 72°F.`; }, }), - }, + ], }); const session = new voice.AgentSession({ diff --git a/examples/src/manual_shutdown.ts b/examples/src/manual_shutdown.ts index 56f770ca0..372601dd2 100644 --- a/examples/src/manual_shutdown.ts +++ b/examples/src/manual_shutdown.ts @@ -19,8 +19,9 @@ export default defineAgent({ const agent = new voice.Agent({ instructions: "You are a helpful assistant, you can hear the user's message and respond to it, end the call when the user asks you to.", - tools: { - getWeather: llm.tool({ + tools: [ + llm.tool({ + name: 'getWeather', description: 'Get the weather for a given location.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -29,7 +30,8 @@ export default defineAgent({ return `The weather in ${location} is sunny.`; }, }), - endCall: llm.tool({ + llm.tool({ + name: 'endCall', description: 'End the call.', parameters: z.object({ reason: z @@ -50,7 +52,7 @@ export default defineAgent({ session.shutdown({ reason }); }, }), - }, + ], }); const session = new voice.AgentSession({ diff --git a/examples/src/multi_agent.ts b/examples/src/multi_agent.ts index 263ba6093..8137527e1 100644 --- a/examples/src/multi_agent.ts +++ b/examples/src/multi_agent.ts @@ -29,26 +29,27 @@ class IntroAgent extends voice.Agent { }); } - static create() { + static createIntroAgent() { return new IntroAgent({ instructions: `You are a story teller. Your goal is to gather a few pieces of information from the user to make the story personalized and engaging. 
Ask the user for their name and where they are from.`, - tools: { - informationGathered: llm.tool({ + tools: [ + llm.tool({ + name: 'informationGathered', description: 'Called when the user has provided the information needed to make the story personalized and engaging.', parameters: z.object({ name: z.string().describe('The name of the user'), location: z.string().describe('The location of the user'), }), - execute: async ({ name, location }, { ctx }) => { + execute: async ({ name, location }, { ctx }: llm.ToolOptions) => { ctx.userData.name = name; ctx.userData.location = location; - const storyAgent = StoryAgent.create(name, location); + const storyAgent = StoryAgent.createStoryAgent(name, location); return llm.handoff({ agent: storyAgent, returns: "Let's start the story!" }); }, }), - }, + ], }); } } @@ -58,7 +59,7 @@ class StoryAgent extends voice.Agent { this.session.generateReply(); } - static create(name: string, location: string) { + static createStoryAgent(name: string, location: string) { return new StoryAgent({ instructions: dedent` You are a storyteller. Use the user's information in order to make the story personalized. @@ -85,7 +86,7 @@ export default defineAgent({ }); await session.start({ - agent: IntroAgent.create(), + agent: IntroAgent.createIntroAgent(), room: ctx.room, }); diff --git a/examples/src/phonic_realtime_agent.ts b/examples/src/phonic_realtime_agent.ts index 35a038450..934485f5b 100644 --- a/examples/src/phonic_realtime_agent.ts +++ b/examples/src/phonic_realtime_agent.ts @@ -7,6 +7,7 @@ import { fileURLToPath } from 'node:url'; import { z } from 'zod'; const toggleLight = llm.tool({ + name: 'toggle_light', description: 'Toggle a light on or off. 
Available lights are A05, A06, A07, and A08.', parameters: z.object({ light_id: z.string().describe('The ID of the light to toggle'), @@ -23,9 +24,7 @@ export default defineAgent({ entry: async (ctx: JobContext) => { const agent = new voice.Agent({ instructions: 'You are a helpful voice AI assistant named Alex.', - tools: { - toggle_light: toggleLight, - }, + tools: [toggleLight], }); const session = new voice.AgentSession({ diff --git a/examples/src/raw_function_description.ts b/examples/src/raw_function_description.ts index 6a1d744a5..db15cecae 100644 --- a/examples/src/raw_function_description.ts +++ b/examples/src/raw_function_description.ts @@ -15,8 +15,9 @@ import { fileURLToPath } from 'node:url'; function createRawFunctionAgent() { return new voice.Agent({ instructions: 'You are a helpful assistant.', - tools: { - openGate: llm.tool({ + tools: [ + llm.tool({ + name: 'openGate', description: 'Opens a specified gate from a predefined set of access points.', parameters: { type: 'object', @@ -40,7 +41,7 @@ function createRawFunctionAgent() { return `The gate ${gateId} is now open.`; }, }), - }, + ], }); } diff --git a/examples/src/realtime_agent.ts b/examples/src/realtime_agent.ts index a6879262b..7b7b82a29 100644 --- a/examples/src/realtime_agent.ts +++ b/examples/src/realtime_agent.ts @@ -12,6 +12,7 @@ const roomNameSchema = z.enum(['bedroom', 'living room', 'kitchen', 'bathroom', export default defineAgent({ entry: async (ctx: JobContext) => { const getWeather = llm.tool({ + name: 'getWeather', description: ' Called when the user asks about the weather.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -22,6 +23,7 @@ export default defineAgent({ }); const toggleLight = llm.tool({ + name: 'toggleLight', description: 'Called when the user asks to turn on or off the light.', parameters: z.object({ room: roomNameSchema.describe('The room to turn the light in'), @@ -53,10 +55,7 @@ export default defineAgent({ 
instructions: "You are a helpful assistant created by LiveKit, always speaking English, you can hear the user's message and respond to it.", chatCtx, - tools: { - getWeather, - toggleLight, - }, + tools: [getWeather, toggleLight], }); const session = new voice.AgentSession({ diff --git a/examples/src/realtime_with_tts.ts b/examples/src/realtime_with_tts.ts index 05df047be..d0e898630 100644 --- a/examples/src/realtime_with_tts.ts +++ b/examples/src/realtime_with_tts.ts @@ -13,6 +13,7 @@ export default defineAgent({ const logger = log(); const getWeather = llm.tool({ + name: 'getWeather', description: 'Called when the user asks about the weather.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -25,9 +26,7 @@ export default defineAgent({ const agent = new voice.Agent({ instructions: 'You are a helpful assistant. Always speak in English.', - tools: { - getWeather, - }, + tools: [getWeather], }); const session = new voice.AgentSession({ diff --git a/examples/src/restaurant_agent.ts b/examples/src/restaurant_agent.ts index 081552c1a..e83877b6a 100644 --- a/examples/src/restaurant_agent.ts +++ b/examples/src/restaurant_agent.ts @@ -77,6 +77,7 @@ function summarize({ } const updateName = llm.tool({ + name: 'updateName', description: 'Called when the user provides their name. Confirm the spelling with the user before calling the function.', parameters: z.object({ @@ -89,6 +90,7 @@ const updateName = llm.tool({ }); const updatePhone = llm.tool({ + name: 'updatePhone', description: 'Called when the user provides their phone number. 
Confirm the spelling with the user before calling the function.', parameters: z.object({ @@ -101,6 +103,7 @@ const updatePhone = llm.tool({ }); const toGreeter = llm.tool({ + name: 'toGreeter', description: 'Called when user asks any unrelated questions or requests any other services not in your job description.', execute: async (_, { ctx }: llm.ToolOptions) => { @@ -171,35 +174,37 @@ function createGreeterAgent(menu: string) { instructions: `You are a friendly restaurant receptionist. The menu is: ${menu}\nYour jobs are to greet the caller and understand if they want to make a reservation or order takeaway. Guide them to the right agent using tools.`, llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), tts: new inference.TTS({ model: 'cartesia/sonic-3', voice: voices.greeter }), - tools: { - toReservation: llm.tool({ + tools: [ + llm.tool({ + name: 'toReservation', description: dedent` Called when user wants to make or update a reservation. This function handles transitioning to the reservation agent who will collect the necessary details like reservation time, customer name and phone number. `, - execute: async (_, { ctx }): Promise => { + execute: async (_, { ctx }: llm.ToolOptions): Promise => { return await greeter.transferToAgent({ name: 'reservation', ctx, }); }, }), - toTakeaway: llm.tool({ + llm.tool({ + name: 'toTakeaway', description: dedent` Called when the user wants to place a takeaway order. This includes handling orders for pickup, delivery, or when the user wants to proceed to checkout with their existing order. `, - execute: async (_, { ctx }): Promise => { + execute: async (_, { ctx }: llm.ToolOptions): Promise => { return await greeter.transferToAgent({ name: 'takeaway', ctx, }); }, }), - }, + ], }); return greeter; @@ -210,11 +215,12 @@ function createReservationAgent() { name: 'reservation', instructions: `You are a reservation agent at a restaurant. Your jobs are to ask for the reservation time, then customer's name, and phone number. 
Then confirm the reservation details with the customer.`, tts: new inference.TTS({ model: 'cartesia/sonic-3', voice: voices.reservation }), - tools: { + tools: [ updateName, updatePhone, toGreeter, - updateReservationTime: llm.tool({ + llm.tool({ + name: 'updateReservationTime', description: dedent` Called when the user provides their reservation time. Confirm the time with the user before calling the function. @@ -222,14 +228,18 @@ function createReservationAgent() { parameters: z.object({ time: z.string().describe('The reservation time'), }), - execute: async ({ time }, { ctx }) => { + execute: async ({ time }, { ctx }: llm.ToolOptions) => { ctx.userData.reservationTime = time; return `The reservation time is updated to ${time}`; }, }), - confirmReservation: llm.tool({ + llm.tool({ + name: 'confirmReservation', description: `Called when the user confirms the reservation.`, - execute: async (_, { ctx }): Promise => { + execute: async ( + _, + { ctx }: llm.ToolOptions, + ): Promise => { const userdata = ctx.userData; if (!userdata.customer.name || !userdata.customer.phone) { return 'Please provide your name and phone number first.'; @@ -243,7 +253,7 @@ function createReservationAgent() { }); }, }), - }, + ], }); return reservation; @@ -254,21 +264,26 @@ function createTakeawayAgent(menu: string) { name: 'takeaway', instructions: `Your are a takeaway agent that takes orders from the customer. 
Our menu is: ${menu}\nClarify special requests and confirm the order with the customer.`, tts: new inference.TTS({ model: 'cartesia/sonic-3', voice: voices.takeaway }), - tools: { + tools: [ toGreeter, - updateOrder: llm.tool({ + llm.tool({ + name: 'updateOrder', description: `Called when the user provides their order.`, parameters: z.object({ items: z.array(z.string()).describe('The items of the full order'), }), - execute: async ({ items }, { ctx }) => { + execute: async ({ items }, { ctx }: llm.ToolOptions) => { ctx.userData.order = items; return `The order is updated to ${items}`; }, }), - toCheckout: llm.tool({ + llm.tool({ + name: 'toCheckout', description: `Called when the user confirms the order.`, - execute: async (_, { ctx }): Promise => { + execute: async ( + _, + { ctx }: llm.ToolOptions, + ): Promise => { const userdata = ctx.userData; if (!userdata.order) { return 'No takeaway order found. Please make an order first.'; @@ -279,7 +294,7 @@ function createTakeawayAgent(menu: string) { }); }, }), - }, + ], }); return takeaway; @@ -290,21 +305,23 @@ function createCheckoutAgent(menu: string) { name: 'checkout', instructions: `You are a checkout agent at a restaurant. 
The menu is: ${menu}\nYour are responsible for confirming the expense of the order and then collecting customer's name, phone number and credit card information, including the card number, expiry date, and CVV step by step.`, tts: new inference.TTS({ model: 'cartesia/sonic-3', voice: voices.checkout }), - tools: { + tools: [ updateName, updatePhone, toGreeter, - confirmExpense: llm.tool({ + llm.tool({ + name: 'confirmExpense', description: `Called when the user confirms the expense.`, parameters: z.object({ expense: z.number().describe('The expense of the order'), }), - execute: async ({ expense }, { ctx }) => { + execute: async ({ expense }, { ctx }: llm.ToolOptions) => { ctx.userData.expense = expense; return `The expense is confirmed to be ${expense}`; }, }), - updateCreditCard: llm.tool({ + llm.tool({ + name: 'updateCreditCard', description: dedent` Called when the user provides their credit card number, expiry date, and CVV. Confirm the spelling with the user before calling the function. 
@@ -314,14 +331,18 @@ function createCheckoutAgent(menu: string) { expiry: z.string().describe('The expiry date of the credit card'), cvv: z.string().describe('The CVV of the credit card'), }), - execute: async ({ number, expiry, cvv }, { ctx }) => { + execute: async ({ number, expiry, cvv }, { ctx }: llm.ToolOptions) => { ctx.userData.creditCard = { number, expiry, cvv }; return `The credit card number is updated to ${number}`; }, }), - confirmCheckout: llm.tool({ + llm.tool({ + name: 'confirmCheckout', description: `Called when the user confirms the checkout.`, - execute: async (_, { ctx }): Promise => { + execute: async ( + _, + { ctx }: llm.ToolOptions, + ): Promise => { const userdata = ctx.userData; if (!userdata.expense) { return 'Please confirm the expense first.'; @@ -340,16 +361,17 @@ function createCheckoutAgent(menu: string) { }); }, }), - toTakeaway: llm.tool({ + llm.tool({ + name: 'toTakeaway', description: `Called when the user wants to update their order.`, - execute: async (_, { ctx }): Promise => { + execute: async (_, { ctx }: llm.ToolOptions): Promise => { return await checkout.transferToAgent({ name: 'takeaway', ctx, }); }, }), - }, + ], }); return checkout; diff --git a/examples/src/survey_agent.ts b/examples/src/survey_agent.ts index 8504fa1a7..918fd9859 100644 --- a/examples/src/survey_agent.ts +++ b/examples/src/survey_agent.ts @@ -76,6 +76,7 @@ async function writeCsvRow(path: string, data: Record): Promise function disqualifyTool() { return llm.tool({ + name: 'disqualify', description: 'End the interview if the candidate refuses to cooperate, provides inappropriate answers, or is not a fit.', parameters: z.object({ @@ -101,8 +102,9 @@ export class IntroTask extends voice.AgentTask { super({ instructions: 'You are Alex, an interviewer screening a software engineer candidate. 
Gather the candidate name and short self-introduction.', - tools: { - saveIntro: llm.tool({ + tools: [ + llm.tool({ + name: 'saveIntro', description: 'Save candidate name and intro notes.', parameters: z.object({ name: z.string().describe('Candidate name'), @@ -114,7 +116,7 @@ export class IntroTask extends voice.AgentTask { return `Saved intro for ${name}.`; }, }), - }, + ], }); } @@ -132,9 +134,10 @@ export class EmailTask extends voice.AgentTask { super({ instructions: 'Collect a valid email address. If the candidate refuses, call disqualify immediately.', - tools: { + tools: [ disqualify, - saveEmail: llm.tool({ + llm.tool({ + name: 'saveEmail', description: 'Save candidate email address.', parameters: z.object({ email: z.string().describe('Candidate email'), @@ -144,7 +147,7 @@ export class EmailTask extends voice.AgentTask { return `Saved email: ${email}`; }, }), - }, + ], }); } @@ -161,9 +164,10 @@ export class CommuteTask extends voice.AgentTask super({ instructions: 'Collect commute flexibility. The role expects office attendance three days per week.', - tools: { + tools: [ disqualify, - saveCommute: llm.tool({ + llm.tool({ + name: 'saveCommute', description: 'Save candidate commute information.', parameters: z.object({ canCommute: z.boolean().describe('Whether the candidate can commute to office'), @@ -176,7 +180,7 @@ export class CommuteTask extends voice.AgentTask return 'Saved commute flexibility.'; }, }), - }, + ], }); } @@ -194,9 +198,10 @@ export class ExperienceTask extends voice.AgentTask { super({ instructions: 'You are a survey interviewer for a software engineer screening. Be concise, professional, and natural. 
Call endScreening when the process is complete.', - tools: { - endScreening: llm.tool({ + tools: [ + llm.tool({ + name: 'endScreening', description: 'End interview and hang up.', execute: async (_, { ctx }: llm.ToolOptions) => { ctx.session.shutdown(); return 'Interview concluded.'; }, }), - }, + ], }); } diff --git a/examples/src/testing/agent_task.test.ts b/examples/src/testing/agent_task.test.ts index 163d232a1..38d2a85d9 100644 --- a/examples/src/testing/agent_task.test.ts +++ b/examples/src/testing/agent_task.test.ts @@ -148,8 +148,9 @@ describe('AgentTask examples', { timeout: 120_000 }, () => { super({ instructions: 'You are collecting a name and role. Extract both from user input and call recordIntro.', - tools: { - recordIntro: llm.tool({ + tools: [ + llm.tool({ + name: 'recordIntro', description: 'Record the name and role', parameters: z.object({ name: z.string().describe('User name'), @@ -160,7 +161,7 @@ describe('AgentTask examples', { timeout: 120_000 }, () => { return 'recorded'; }, }), - }, + ], }); } @@ -222,8 +223,9 @@ describe('AgentTask examples', { timeout: 120_000 }, () => { super({ instructions: 'When asked to capture email, ALWAYS call captureEmail exactly once, then respond briefly.', - tools: { - captureEmail: llm.tool({ + tools: [ + llm.tool({ + name: 'captureEmail', description: 'Capture an email by running a nested AgentTask.', parameters: z.object({}), execute: async () => { @@ -236,7 +238,7 @@ describe('AgentTask examples', { timeout: 120_000 }, () => { } }, }), - }, + ], }); } @@ -275,8 +277,9 @@ describe('AgentTask examples', { timeout: 120_000 }, () => { instructions: 'You are Alex, an interviewer. Extract the candidate name and a short intro from the latest user input. 
' + 'Use the tool recordIntro exactly once when both are available.', - tools: { - recordIntro: llm.tool({ + tools: [ + llm.tool({ + name: 'recordIntro', description: 'Record candidate name and intro summary.', parameters: z.object({ name: z.string().describe('Candidate name'), @@ -288,7 +291,7 @@ describe('AgentTask examples', { timeout: 120_000 }, () => { return 'Intro recorded.'; }, }), - }, + ], }); } @@ -305,8 +308,9 @@ describe('AgentTask examples', { timeout: 120_000 }, () => { super({ instructions: 'When the user asks to run the intro task, ALWAYS call collectIntroWithTask exactly once.', - tools: { - collectIntroWithTask: llm.tool({ + tools: [ + llm.tool({ + name: 'collectIntroWithTask', description: 'Launch the IntroTask and return the captured intro details.', parameters: z.object({}), execute: async () => { @@ -316,7 +320,7 @@ describe('AgentTask examples', { timeout: 120_000 }, () => { return JSON.stringify(result); }, }), - }, + ], }); } } diff --git a/examples/src/testing/basic_task_group.test.ts b/examples/src/testing/basic_task_group.test.ts index 07931279b..fc90cccba 100644 --- a/examples/src/testing/basic_task_group.test.ts +++ b/examples/src/testing/basic_task_group.test.ts @@ -82,8 +82,9 @@ class CollectNameTask extends voice.AgentTask { super({ instructions: 'Collect the user name from the latest user message. As soon as you have it, call save_name.', - tools: { - save_name: llm.tool({ + tools: [ + llm.tool({ + name: 'save_name', description: 'Save the user name.', parameters: z.object({ name: z.string().describe('The user name') }), execute: async ({ name }) => { @@ -91,7 +92,7 @@ class CollectNameTask extends voice.AgentTask { return `Saved name: ${name}`; }, }), - }, + ], }); this.ready = ready; } @@ -108,8 +109,9 @@ class CollectEmailTask extends voice.AgentTask { super({ instructions: 'Collect the user email from the latest user message. 
As soon as you have it, call save_email.', - tools: { - save_email: llm.tool({ + tools: [ + llm.tool({ + name: 'save_email', description: 'Save the user email.', parameters: z.object({ email: z.string().describe('The user email') }), execute: async ({ email }) => { @@ -117,7 +119,7 @@ class CollectEmailTask extends voice.AgentTask { return `Saved email: ${email}`; }, }), - }, + ], }); this.ready = ready; } diff --git a/examples/src/testing/run_result.test.ts b/examples/src/testing/run_result.test.ts index 583cbaffa..4169dea93 100644 --- a/examples/src/testing/run_result.test.ts +++ b/examples/src/testing/run_result.test.ts @@ -45,8 +45,9 @@ Response rules: - After ordering, confirm what was added (e.g., "I've added the burger to your order"). - When asked about sizes, always ask for clarification if not specified. - Be friendly and proactive in suggesting next steps.`, - tools: { - getWeather: llm.tool({ + tools: [ + llm.tool({ + name: 'getWeather', description: 'Get the current weather for a location', parameters: z.object({ location: z.string().describe('The city name'), @@ -59,14 +60,16 @@ Response rules: }); }, }), - getCurrentTime: llm.tool({ + llm.tool({ + name: 'getCurrentTime', description: 'Get the current time', parameters: z.object({}), execute: async () => { return '3:00 PM'; }, }), - orderItem: llm.tool({ + llm.tool({ + name: 'orderItem', description: 'Add an item to the order', parameters: z.object({ itemId: z.string().describe('The menu item ID'), @@ -84,7 +87,8 @@ Response rules: }); }, }), - getOrderStatus: llm.tool({ + llm.tool({ + name: 'getOrderStatus', description: 'Get the current order status', parameters: z.object({}), execute: async () => { @@ -98,7 +102,8 @@ Response rules: }); }, }), - getMenuItems: llm.tool({ + llm.tool({ + name: 'getMenuItems', description: 'Get available menu items and prices', parameters: z.object({ category: z @@ -128,7 +133,7 @@ Response rules: return JSON.stringify(menu); }, }), - }, + ], }); } } diff --git 
a/examples/src/testing/task_group.test.ts b/examples/src/testing/task_group.test.ts index 3d5afff06..188d46e59 100644 --- a/examples/src/testing/task_group.test.ts +++ b/examples/src/testing/task_group.test.ts @@ -260,8 +260,9 @@ describe('TaskGroup', { timeout: 120_000 }, () => { super({ instructions: 'Extract the user name from the latest user message. Call recordName immediately.', - tools: { - recordName: llm.tool({ + tools: [ + llm.tool({ + name: 'recordName', description: 'Record the user name', parameters: z.object({ name: z.string().describe('The user name') }), execute: async ({ name }) => { @@ -269,7 +270,7 @@ describe('TaskGroup', { timeout: 120_000 }, () => { return 'recorded'; }, }), - }, + ], }); } @@ -283,8 +284,9 @@ describe('TaskGroup', { timeout: 120_000 }, () => { super({ instructions: 'Extract an email address from the latest user message. Call recordEmail immediately.', - tools: { - recordEmail: llm.tool({ + tools: [ + llm.tool({ + name: 'recordEmail', description: 'Record the user email', parameters: z.object({ email: z.string().describe('The email address') }), execute: async ({ email }) => { @@ -292,7 +294,7 @@ describe('TaskGroup', { timeout: 120_000 }, () => { return 'recorded'; }, }), - }, + ], }); } @@ -400,8 +402,9 @@ describe('TaskGroup', { timeout: 120_000 }, () => { super({ instructions: 'Extract the user favorite color from the latest message. Call recordColor immediately.', - tools: { - recordColor: llm.tool({ + tools: [ + llm.tool({ + name: 'recordColor', description: 'Record favorite color', parameters: z.object({ color: z.string() }), execute: async ({ color }) => { @@ -409,7 +412,7 @@ describe('TaskGroup', { timeout: 120_000 }, () => { return 'recorded'; }, }), - }, + ], }); } @@ -423,8 +426,9 @@ describe('TaskGroup', { timeout: 120_000 }, () => { super({ instructions: 'Extract the user favorite food from the latest message. 
Call recordFood immediately.', - tools: { - recordFood: llm.tool({ + tools: [ + llm.tool({ + name: 'recordFood', description: 'Record favorite food', parameters: z.object({ food: z.string() }), execute: async ({ food }) => { @@ -432,7 +436,7 @@ describe('TaskGroup', { timeout: 120_000 }, () => { return 'recorded'; }, }), - }, + ], }); } diff --git a/examples/src/tool_call_disfluency.ts b/examples/src/tool_call_disfluency.ts index c89917d54..16f943f8e 100644 --- a/examples/src/tool_call_disfluency.ts +++ b/examples/src/tool_call_disfluency.ts @@ -32,6 +32,7 @@ export default defineAgent({ await ctx.waitForParticipant(); const getWeather = llm.tool({ + name: 'getWeather', description: ' Called when the user asks about the weather.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -48,9 +49,7 @@ export default defineAgent({ const agent = new VoiceAgent({ instructions: "You are a helpful assistant, you can hear the user's message and respond to it.", - tools: { - getWeather, - }, + tools: [getWeather], }); const session = new voice.AgentSession({ diff --git a/examples/src/warm_transfer.ts b/examples/src/warm_transfer.ts index 546993724..095b87e74 100644 --- a/examples/src/warm_transfer.ts +++ b/examples/src/warm_transfer.ts @@ -23,8 +23,9 @@ class SupportAgent extends voice.Agent { constructor() { super({ instructions: INSTRUCTIONS, - tools: { - transfer_to_human: llm.tool({ + tools: [ + llm.tool({ + name: 'transfer_to_human', description: `Called when the user asks to speak to a human agent. This will put the user on hold while the supervisor is connected. Ensure that the user has confirmed that they wanted to be transferred. Do not start transfer until the user has confirmed. 
@@ -83,7 +84,7 @@ Examples on when the tool should be called: } }, }), - }, + ], }); } diff --git a/examples/src/xai-realtime.ts b/examples/src/xai-realtime.ts index ad383a0a6..a5130a640 100644 --- a/examples/src/xai-realtime.ts +++ b/examples/src/xai-realtime.ts @@ -10,8 +10,9 @@ export default defineAgent({ entry: async (ctx: JobContext) => { const agent = new voice.Agent({ instructions: 'You are a helpful assistant. Keep your responses short and concise.', - tools: { - getWeather: llm.tool({ + tools: [ + llm.tool({ + name: 'getWeather', description: 'Get the weather for a given location.', parameters: z.object({ location: z.string().describe('The location to get the weather for'), @@ -20,7 +21,7 @@ export default defineAgent({ return `The weather in ${location} is sunny.`; }, }), - }, + ], }); const session = new voice.AgentSession({ diff --git a/package.json b/package.json index fa531e4f0..01b808c7a 100644 --- a/package.json +++ b/package.json @@ -16,6 +16,7 @@ "format:write": "prettier --write \"**/src/**/*.{ts,tsx,md,json}\"", "lint": "turbo lint", "lint:fix": "turbo lint -- --fix", + "typecheck": "turbo typecheck", "test": "vitest run", "test:watch": "vitest", "test:examples": "vitest run examples", diff --git a/plugins/baseten/src/llm.ts b/plugins/baseten/src/llm.ts index 4039b7cac..ca50a5540 100644 --- a/plugins/baseten/src/llm.ts +++ b/plugins/baseten/src/llm.ts @@ -72,19 +72,20 @@ export class OpenAILLM extends llm.LLM { chat({ chatCtx, - toolCtx, + toolCtx: toolCtxInput, connOptions = DEFAULT_API_CONNECT_OPTIONS, parallelToolCalls, toolChoice, extraKwargs, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; extraKwargs?: Record; }): inference.LLMStream { + const toolCtx = llm.toToolContext(toolCtxInput); const extras: Record = { ...extraKwargs }; if (this.#opts.metadata) { @@ -125,7 +126,11 @@ export class OpenAILLM extends 
llm.LLM { parallelToolCalls = parallelToolCalls !== undefined ? parallelToolCalls : this.#opts.parallelToolCalls; - if (toolCtx && Object.keys(toolCtx).length > 0 && parallelToolCalls !== undefined) { + if ( + toolCtx && + Object.keys(toolCtx.functionTools).length > 0 && + parallelToolCalls !== undefined + ) { extras.parallel_tool_calls = parallelToolCalls; } diff --git a/plugins/cerebras/src/llm.test.ts b/plugins/cerebras/src/llm.test.ts index 3a0a3ca8b..dda5d7b3c 100644 --- a/plugins/cerebras/src/llm.test.ts +++ b/plugins/cerebras/src/llm.test.ts @@ -88,8 +88,9 @@ class WeatherAgent extends voice.Agent { constructor() { super({ instructions: 'You are a helpful assistant.', - tools: { - get_weather: llm.tool({ + tools: [ + llm.tool({ + name: 'get_weather', description: 'Get the current weather for a location.', parameters: z.object({ location: z.string().describe('The city name'), @@ -98,7 +99,7 @@ class WeatherAgent extends voice.Agent { return `The weather in ${location} is sunny, 72°F.`; }, }), - }, + ], }); } } diff --git a/plugins/elevenlabs/src/stt.test.ts b/plugins/elevenlabs/src/stt.test.ts index c0349641a..918c321dc 100644 --- a/plugins/elevenlabs/src/stt.test.ts +++ b/plugins/elevenlabs/src/stt.test.ts @@ -280,7 +280,7 @@ describe('ElevenLabs STT', () => { expect(url.pathname).toBe('/speech-to-text/realtime'); expect(url.searchParams.get('model_id')).toBe('scribe_v2_realtime'); expect(url.searchParams.get('audio_format')).toBe('pcm_16000'); - expect(url.searchParams.get('commit_strategy')).toBe('vad'); + expect(url.searchParams.get('commit_strategy')).toBe('manual'); expect(url.searchParams.get('include_language_detection')).toBe('true'); expect(receivedMessages[0]).toMatchObject({ message_type: 'input_audio_chunk', @@ -334,7 +334,6 @@ describe('ElevenLabs STT', () => { words: [{ text: 'kept', start: 0, end: 0.2 }], }), ); - ws.send(JSON.stringify({ message_type: 'committed_transcript_with_timestamps', text: '' })); }); }); @@ -364,6 +363,7 @@ 
describe('ElevenLabs STT', () => { const url = new URL(`ws://127.0.0.1${requestUrl}`); expect(url.searchParams.get('language_code')).toBe('en'); + expect(url.searchParams.get('commit_strategy')).toBe('vad'); expect(url.searchParams.get('include_language_detection')).toBeNull(); expect(url.searchParams.get('include_timestamps')).toBe('true'); expect(url.searchParams.get('vad_silence_threshold_secs')).toBe('0.5'); @@ -393,14 +393,18 @@ describe('ElevenLabs STT', () => { const stream = eleven.stream(); await waitUntil(() => urls.length === 1); - eleven.updateOptions({ serverVad: null }); + eleven.updateOptions({ serverVad: { vadSilenceThresholdSecs: 0.5 } }); await waitUntil(() => urls.length === 2, 2000); + eleven.updateOptions({ serverVad: null }); + await waitUntil(() => urls.length === 3, 2000); stream.close(); const first = new URL(`ws://127.0.0.1${urls[0]}`); const second = new URL(`ws://127.0.0.1${urls[1]}`); - expect(first.searchParams.get('commit_strategy')).toBe('vad'); - expect(second.searchParams.get('commit_strategy')).toBe('manual'); + const third = new URL(`ws://127.0.0.1${urls[2]}`); + expect(first.searchParams.get('commit_strategy')).toBe('manual'); + expect(second.searchParams.get('commit_strategy')).toBe('vad'); + expect(third.searchParams.get('commit_strategy')).toBe('manual'); } finally { await closeWebSocketServer(wss); } diff --git a/plugins/elevenlabs/src/stt.ts b/plugins/elevenlabs/src/stt.ts index a205a29b7..5ff10488b 100644 --- a/plugins/elevenlabs/src/stt.ts +++ b/plugins/elevenlabs/src/stt.ts @@ -330,17 +330,16 @@ export class STT extends stt.STT { language?: string, abortSignal?: AbortSignal, ): Promise { - if (language !== undefined) { - this.#opts.languageCode = normalizeLanguage(language); - } + const languageCode = + language !== undefined ? 
normalizeLanguage(language) : this.#opts.languageCode; const wavBytes = createWav(mergeFrames(buffer)); const form = new FormData(); form.append('file', new Blob([new Uint8Array(wavBytes)], { type: 'audio/x-wav' }), 'audio.wav'); form.append('model_id', this.#opts.modelId); form.append('tag_audio_events', String(this.#opts.tagAudioEvents)); - if (this.#opts.languageCode) { - form.append('language_code', this.#opts.languageCode); + if (languageCode) { + form.append('language_code', languageCode); } if (this.#opts.keyterms !== undefined) { for (const keyterm of this.#opts.keyterms) { @@ -369,7 +368,7 @@ export class STT extends stt.STT { const startTime = words.length > 0 ? Math.min(...words.map((word) => word.start ?? 0)) : 0; const endTime = words.length > 0 ? Math.max(...words.map((word) => word.end ?? 0)) : 0; const normalizedLanguage = normalizeLanguage( - responseJson.language_code ?? this.#opts.languageCode ?? '', + responseJson.language_code ?? languageCode ?? '', ); return this.#transcriptionToSpeechEvent( @@ -650,7 +649,8 @@ export class SpeechStream extends stt.SpeechStream { } async #connectWs(): Promise { - const commitStrategy = this.#opts.serverVad === null ? 'manual' : 'vad'; + const serverVad = this.#opts.serverVad; + const commitStrategy = serverVad === undefined || serverVad === null ? 
'manual' : 'vad'; const params = [ `model_id=${this.#opts.modelId}`, `audio_format=pcm_${this.#opts.sampleRate}`, @@ -661,30 +661,21 @@ export class SpeechStream extends stt.SpeechStream { params.push('include_language_detection=true'); } - if (this.#opts.serverVad) { + if (serverVad !== undefined && serverVad !== null) { if ( - this.#opts.serverVad.vadSilenceThresholdSecs !== undefined && - this.#opts.serverVad.vadSilenceThresholdSecs !== null + serverVad.vadSilenceThresholdSecs !== undefined && + serverVad.vadSilenceThresholdSecs !== null ) { - params.push(`vad_silence_threshold_secs=${this.#opts.serverVad.vadSilenceThresholdSecs}`); + params.push(`vad_silence_threshold_secs=${serverVad.vadSilenceThresholdSecs}`); } - if ( - this.#opts.serverVad.vadThreshold !== undefined && - this.#opts.serverVad.vadThreshold !== null - ) { - params.push(`vad_threshold=${this.#opts.serverVad.vadThreshold}`); + if (serverVad.vadThreshold !== undefined && serverVad.vadThreshold !== null) { + params.push(`vad_threshold=${serverVad.vadThreshold}`); } - if ( - this.#opts.serverVad.minSpeechDurationMs !== undefined && - this.#opts.serverVad.minSpeechDurationMs !== null - ) { - params.push(`min_speech_duration_ms=${this.#opts.serverVad.minSpeechDurationMs}`); + if (serverVad.minSpeechDurationMs !== undefined && serverVad.minSpeechDurationMs !== null) { + params.push(`min_speech_duration_ms=${serverVad.minSpeechDurationMs}`); } - if ( - this.#opts.serverVad.minSilenceDurationMs !== undefined && - this.#opts.serverVad.minSilenceDurationMs !== null - ) { - params.push(`min_silence_duration_ms=${this.#opts.serverVad.minSilenceDurationMs}`); + if (serverVad.minSilenceDurationMs !== undefined && serverVad.minSilenceDurationMs !== null) { + params.push(`min_silence_duration_ms=${serverVad.minSilenceDurationMs}`); } } @@ -804,6 +795,10 @@ export class SpeechStream extends stt.SpeechStream { type: stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives: [speechData], }); + if (this.#opts.serverVad 
!== undefined && this.#opts.serverVad !== null) { + this.queue.put({ type: stt.SpeechEventType.END_OF_SPEECH }); + this.#speaking = false; + } } else if (this.#speaking) { this.queue.put({ type: stt.SpeechEventType.END_OF_SPEECH }); this.#speaking = false; diff --git a/plugins/google/src/aiplatform_llm.ts b/plugins/google/src/aiplatform_llm.ts index 9217ee150..daa363086 100644 --- a/plugins/google/src/aiplatform_llm.ts +++ b/plugins/google/src/aiplatform_llm.ts @@ -165,19 +165,20 @@ export class AIPlatformLLM extends llm.LLM { chat({ chatCtx, - toolCtx, + toolCtx: toolCtxInput, connOptions = DEFAULT_API_CONNECT_OPTIONS, parallelToolCalls, toolChoice, extraKwargs, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; extraKwargs?: Record; }): inference.LLMStream { + const toolCtx = llm.toToolContext(toolCtxInput); const extras: Record = { ...extraKwargs }; if (this.#opts.temperature !== undefined) { @@ -201,7 +202,11 @@ export class AIPlatformLLM extends llm.LLM { parallelToolCalls = parallelToolCalls !== undefined ? 
parallelToolCalls : this.#opts.parallelToolCalls; - if (toolCtx && Object.keys(toolCtx).length > 0 && parallelToolCalls !== undefined) { + if ( + toolCtx && + Object.keys(toolCtx.functionTools).length > 0 && + parallelToolCalls !== undefined + ) { extras.parallel_tool_calls = parallelToolCalls; } diff --git a/plugins/google/src/index.ts b/plugins/google/src/index.ts index 0afb408fc..8e786ae5b 100644 --- a/plugins/google/src/index.ts +++ b/plugins/google/src/index.ts @@ -15,6 +15,7 @@ export { export { beta }; export { LLM, LLMStream, type LLMOptions } from './llm.js'; export * from './models.js'; +export * from './tools.js'; export { realtime }; class GooglePlugin extends Plugin { diff --git a/plugins/google/src/llm.ts b/plugins/google/src/llm.ts index 4e8e52ef3..706a18046 100644 --- a/plugins/google/src/llm.ts +++ b/plugins/google/src/llm.ts @@ -13,7 +13,7 @@ import { } from '@livekit/agents'; import type { ChatModels } from './models.js'; import type { LLMTools } from './tools.js'; -import { toFunctionDeclarations } from './utils.js'; +import { toToolsConfig } from './utils.js'; interface GoogleFormatData { systemMessages: string[] | null; @@ -189,20 +189,21 @@ export class LLM extends llm.LLM { chat({ chatCtx, - toolCtx, + toolCtx: toolCtxInput, connOptions = DEFAULT_API_CONNECT_OPTIONS, toolChoice, extraKwargs, geminiTools, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; extraKwargs?: Record; geminiTools?: LLMTools; }): LLMStream { + const toolCtx = llm.toToolContext(toolCtxInput); const extras: GenerateContentConfig = { ...extraKwargs } as GenerateContentConfig; if (this.#opts.httpOptions !== undefined && extras.httpOptions === undefined) { @@ -358,11 +359,11 @@ export class LLMStream extends llm.LLMStream { parts: turn.parts as types.Part[], })); - const functionDeclarations = this.toolCtx ? 
toFunctionDeclarations(this.toolCtx) : undefined; - const tools = - functionDeclarations && functionDeclarations.length > 0 - ? [{ functionDeclarations }] - : undefined; + const tools = toToolsConfig({ + toolCtx: this.toolCtx, + geminiTools: this.#geminiTools, + onlySingleType: true, + }); let systemInstruction: types.Content | undefined = undefined; if (extraData.systemMessages && extraData.systemMessages.length > 0) { diff --git a/plugins/google/src/realtime/realtime_api.ts b/plugins/google/src/realtime/realtime_api.ts index 01f37208a..d12ca3b92 100644 --- a/plugins/google/src/realtime/realtime_api.ts +++ b/plugins/google/src/realtime/realtime_api.ts @@ -33,7 +33,7 @@ import { import { Mutex } from '@livekit/mutex'; import { AudioFrame, AudioResampler, type VideoFrame } from '@livekit/rtc-node'; import { type LLMTools } from '../tools.js'; -import { toFunctionDeclarations } from '../utils.js'; +import { toToolsConfig } from '../utils.js'; import type * as api_proto from './api_proto.js'; import type { LiveAPIModels, Voice } from './api_proto.js'; @@ -70,13 +70,6 @@ export interface InputTranscription { transcript: string; } -/** - * Helper function to check if two sets are equal - */ -function setsEqual(a: Set, b: Set): boolean { - return a.size === b.size && [...a].every((x) => b.has(x)); -} - /** * Internal realtime options for Google Realtime API */ @@ -451,11 +444,10 @@ export class RealtimeModel extends llm.RealtimeModel { * supporting both text and audio modalities with function calling capabilities. 
*/ export class RealtimeSession extends llm.RealtimeSession { - private _tools: llm.ToolContext = {}; + private _tools: llm.ToolContext = llm.ToolContext.empty(); private _chatCtx = llm.ChatContext.empty(); private options: RealtimeOptions; - private geminiDeclarations: types.FunctionDeclaration[] = []; private messageChannel = new Queue(); private inputResampler?: AudioResampler; private inputResamplerInputRate?: number; @@ -764,15 +756,12 @@ export class RealtimeSession extends llm.RealtimeSession { } async updateTools(tools: llm.ToolContext): Promise { - const newDeclarations = toFunctionDeclarations(tools); - const currentToolNames = new Set(this.geminiDeclarations.map((f) => f.name)); - const newToolNames = new Set(newDeclarations.map((f) => f.name)); - - if (!setsEqual(currentToolNames, newToolNames)) { - this.geminiDeclarations = newDeclarations; - this._tools = tools; - this.markRestartNeeded(); + if (this._tools.equals(tools)) { + return; } + + this._tools = tools; + this.markRestartNeeded(); } get chatCtx(): llm.ChatContext { @@ -780,7 +769,7 @@ export class RealtimeSession extends llm.RealtimeSession { } get tools(): llm.ToolContext { - return { ...this._tools }; + return this._tools.copy(); } get manualActivityDetection(): boolean { @@ -1452,21 +1441,11 @@ export class RealtimeSession extends llm.RealtimeSession { }, languageCode: opts.language, }, - tools: - this.geminiDeclarations.length > 0 || this.options.geminiTools - ? [ - { - functionDeclarations: - this.options.toolBehavior !== undefined - ? 
this.geminiDeclarations.map((d) => ({ - ...d, - behavior: this.options.toolBehavior, - })) - : this.geminiDeclarations, - ...this.options.geminiTools, - }, - ] - : undefined, + tools: toToolsConfig({ + toolCtx: this._tools, + geminiTools: this.options.geminiTools, + toolBehavior: this.options.toolBehavior, + }), inputAudioTranscription: opts.inputAudioTranscription, outputAudioTranscription: opts.outputAudioTranscription, sessionResumption: this.sessionResumptionHandle diff --git a/plugins/google/src/tools.ts b/plugins/google/src/tools.ts index 90cd9cc7c..b864e98a6 100644 --- a/plugins/google/src/tools.ts +++ b/plugins/google/src/tools.ts @@ -1,6 +1,100 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import type { Tool } from '@google/genai'; +import type * as types from '@google/genai'; +import { llm } from '@livekit/agents'; -export type LLMTools = Omit; +export type LLMTools = Omit; + +export abstract class GeminiTool extends llm.ProviderTool { + abstract toToolConfig(): types.Tool; +} + +export class GoogleSearch extends GeminiTool { + constructor(public readonly options: types.GoogleSearch = {}) { + super({ id: 'gemini_google_search' }); + } + + toToolConfig(): types.Tool { + return { googleSearch: this.options }; + } +} + +export class GoogleMaps extends GeminiTool { + constructor(public readonly options: types.GoogleMaps = {}) { + super({ id: 'gemini_google_maps' }); + } + + toToolConfig(): types.Tool { + return { googleMaps: this.options }; + } +} + +export class URLContext extends GeminiTool { + constructor() { + super({ id: 'gemini_url_context' }); + } + + toToolConfig(): types.Tool { + return { urlContext: {} }; + } +} + +export interface FileSearchOptions extends types.FileSearch { + fileSearchStoreNames: string[]; +} + +export class FileSearch extends GeminiTool { + constructor(public readonly options: FileSearchOptions) { + super({ id: 'gemini_file_search' }); + } + + toToolConfig(): types.Tool { + return { 
fileSearch: this.options }; + } +} + +export class ToolCodeExecution extends GeminiTool { + constructor() { + super({ id: 'gemini_code_execution' }); + } + + toToolConfig(): types.Tool { + return { codeExecution: {} }; + } +} + +export interface VertexRAGRetrievalOptions { + ragResources: string[]; + similarityTopK?: number; + vectorDistanceThreshold?: number; +} + +export class VertexRAGRetrieval extends GeminiTool { + readonly ragResources: string[]; + readonly similarityTopK: number; + readonly vectorDistanceThreshold?: number; + + constructor({ + ragResources, + similarityTopK = 3, + vectorDistanceThreshold, + }: VertexRAGRetrievalOptions) { + super({ id: 'gemini_vertex_rag_retrieval' }); + this.ragResources = ragResources; + this.similarityTopK = similarityTopK; + this.vectorDistanceThreshold = vectorDistanceThreshold; + } + + toToolConfig(): types.Tool { + return { + retrieval: { + vertexRagStore: { + ragResources: this.ragResources.map((ragCorpus) => ({ ragCorpus })), + similarityTopK: this.similarityTopK, + vectorDistanceThreshold: this.vectorDistanceThreshold, + }, + }, + }; + } +} diff --git a/plugins/google/src/utils.ts b/plugins/google/src/utils.ts index 3bb2ba0b1..c9fa471e1 100644 --- a/plugins/google/src/utils.ts +++ b/plugins/google/src/utils.ts @@ -1,9 +1,11 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. 
// // SPDX-License-Identifier: Apache-2.0 +import type * as types from '@google/genai'; import type { FunctionDeclaration, Schema } from '@google/genai'; import { llm } from '@livekit/agents'; import type { JSONSchema7 } from 'json-schema'; +import { GeminiTool, type LLMTools } from './tools.js'; /** * JSON Schema v7 @@ -139,6 +141,8 @@ function isEmptyObjectSchema(jsonSchema: JSONSchema7Definition): boolean { export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclaration[] { const functionDeclarations: FunctionDeclaration[] = []; + // Provider tools are not supported by the Gemini schema; `sortedToolEntries` yields only + // function tools (sorted by name), so they are skipped here. for (const [name, tool] of llm.sortedToolEntries(toolCtx)) { const { description, parameters } = tool; const jsonSchema = llm.toJsonSchema(parameters, false); @@ -155,3 +159,57 @@ export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclar return functionDeclarations; } + +export function toToolsConfig({ + toolCtx, + geminiTools, + toolBehavior, + onlySingleType = false, +}: { + toolCtx?: llm.ToolContext; + geminiTools?: LLMTools; + toolBehavior?: types.Behavior; + onlySingleType?: boolean; +}): types.Tool[] | undefined { + const tools: types.Tool[] = []; + const providerTools: types.Tool[] = []; + + if (toolCtx) { + const functionDeclarations = toFunctionDeclarations(toolCtx); + if (functionDeclarations.length > 0) { + tools.push({ + functionDeclarations: + toolBehavior !== undefined + ? 
functionDeclarations.map((declaration) => ({ + ...declaration, + behavior: toolBehavior, + })) + : functionDeclarations, + }); + } + } + + if (geminiTools !== undefined) { + providerTools.push(geminiTools); + } + + if (toolCtx) { + for (const tool of toolCtx.providerTools) { + if (tool instanceof GeminiTool) { + providerTools.push(tool.toToolConfig()); + } + } + } + + if (tools.length > 0 && providerTools.length > 0) { + throw new Error('Gemini does not support mixing function tools and provider tools'); + } + + if (onlySingleType && tools.length > 0) { + return tools; + } + + tools.push(...providerTools); + + return tools.length > 0 ? tools : undefined; +} diff --git a/plugins/mistralai/src/llm.ts b/plugins/mistralai/src/llm.ts index 97afbc6ff..d857c65f8 100644 --- a/plugins/mistralai/src/llm.ts +++ b/plugins/mistralai/src/llm.ts @@ -123,7 +123,7 @@ export class LLM extends llm.LLM { extraKwargs, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; @@ -187,7 +187,7 @@ export class LLMStream extends llm.LLMStream { client: Mistral; opts: LLMOpts; chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions: APIConnectOptions; extraKwargs: Record; }, @@ -211,7 +211,9 @@ export class LLMStream extends llm.LLMStream { // eslint-disable-next-line @typescript-eslint/no-explicit-any const toolsList: any[] = []; - if (this.toolCtx && Object.keys(this.toolCtx).length > 0) { + // Provider tools are not supported by the Mistral schema; `sortedToolEntries` yields only + // function tools (sorted by name), so they are skipped here. 
+ if (this.toolCtx) { for (const [name, func] of llm.sortedToolEntries(this.toolCtx)) { toolsList.push({ type: 'function' as const, diff --git a/plugins/openai/src/index.ts b/plugins/openai/src/index.ts index ccffdcb3f..6a5d9cb7c 100644 --- a/plugins/openai/src/index.ts +++ b/plugins/openai/src/index.ts @@ -5,6 +5,7 @@ import { Plugin } from '@livekit/agents'; export { LLM, LLMStream, type LLMOptions } from './llm.js'; export * from './models.js'; +export * from './tools.js'; export * as realtime from './realtime/index.js'; export * as responses from './responses/index.js'; export { STT, type STTOptions } from './stt.js'; diff --git a/plugins/openai/src/llm.ts b/plugins/openai/src/llm.ts index 50e025b94..6f83d33c2 100644 --- a/plugins/openai/src/llm.ts +++ b/plugins/openai/src/llm.ts @@ -469,19 +469,20 @@ export class LLM extends llm.LLM { chat({ chatCtx, - toolCtx, + toolCtx: toolCtxInput, connOptions = DEFAULT_API_CONNECT_OPTIONS, parallelToolCalls, toolChoice, extraKwargs, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; extraKwargs?: Record; }): LLMStream { + const toolCtx = llm.toToolContext(toolCtxInput); const extras: Record = { ...extraKwargs }; if (this.#opts.metadata) { @@ -514,7 +515,11 @@ export class LLM extends llm.LLM { parallelToolCalls = parallelToolCalls !== undefined ? 
parallelToolCalls : this.#opts.parallelToolCalls; - if (toolCtx && Object.keys(toolCtx).length > 0 && parallelToolCalls !== undefined) { + if ( + toolCtx && + Object.keys(toolCtx.functionTools).length > 0 && + parallelToolCalls !== undefined + ) { extras.parallel_tool_calls = parallelToolCalls; } diff --git a/plugins/openai/src/realtime/realtime_model.test.ts b/plugins/openai/src/realtime/realtime_model.test.ts index 6080cdbd6..d52d1090a 100644 --- a/plugins/openai/src/realtime/realtime_model.test.ts +++ b/plugins/openai/src/realtime/realtime_model.test.ts @@ -13,9 +13,15 @@ import { type RealtimeSessionInternals = { generateReply: RealtimeSession['generateReply']; + updateInstructions: RealtimeSession['updateInstructions']; responseCreatedFutures: Record; sendEvent: ReturnType; textModeRecoveryRetries: number; + instructions?: string; + _options: { + isAzure?: boolean; + apiVersion?: string; + }; }; type ResponseDoneSessionInternals = { @@ -36,6 +42,7 @@ function createSessionForTest(): RealtimeSessionInternals { session.responseCreatedFutures = {}; session.sendEvent = vi.fn(); session.textModeRecoveryRetries = 0; + session._options = {}; return session; } @@ -49,6 +56,37 @@ function stubTaskRuntime(): void { } describe('RealtimeSession.generateReply', () => { + it('preserves session instructions when generating with per-response instructions', async () => { + const session = createSessionForTest(); + await session.updateInstructions('Your name is Kelly. Always respond in English.'); + + const abortController = new AbortController(); + const promise = session.generateReply('Tell the user what your name is.', { + signal: abortController.signal, + }); + abortController.abort(); + + await expect(promise).rejects.toThrow('generateReply aborted'); + expect(session.instructions).toBe('Your name is Kelly. 
Always respond in English.'); + expect(session.sendEvent).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'session.update', + session: expect.objectContaining({ + instructions: 'Your name is Kelly. Always respond in English.', + }), + }), + ); + expect(session.sendEvent).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'response.create', + response: expect.objectContaining({ + instructions: + 'Your name is Kelly. Always respond in English.\nTell the user what your name is.', + }), + }), + ); + }); + it('cancels an in-flight response when aborted before response.created', async () => { const session = createSessionForTest(); const abortController = new AbortController(); diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index cc28f9618..efc4cfca6 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -425,7 +425,7 @@ export function processBaseURL({ * - openai_client_event_queued: expose the raw client events sent to the OpenAI Realtime API */ export class RealtimeSession extends llm.RealtimeSession { - private _tools: llm.ToolContext = {}; + private _tools: llm.ToolContext = llm.ToolContext.empty(); private remoteChatCtx: llm.RemoteChatContext = new llm.RemoteChatContext(); private messageChannel = new Queue(); private inputResampler?: AudioResampler; @@ -552,7 +552,7 @@ export class RealtimeSession extends llm.RealtimeSession { } get tools() { - return { ...this._tools } as llm.ToolContext; + return this._tools.copy(); } async updateChatCtx(_chatCtx: llm.ChatContext): Promise { @@ -719,13 +719,12 @@ export class RealtimeSession extends llm.RealtimeSession { // TODO(brian): these logics below are noops I think; remove them later. 
const retainedToolNames = new Set(ev.session.tools.map((tool) => tool.name)); - const retainedTools = Object.fromEntries( - Object.entries(_tools).filter( - ([name, tool]) => llm.isFunctionTool(tool) && retainedToolNames.has(name), - ), + // Keep provider tools and Toolsets as-is; only drop function tools the server didn't accept. + const retainedEntries = _tools.tools.filter( + (entry) => !llm.isFunctionTool(entry) || retainedToolNames.has(entry.name), ); - this._tools = retainedTools as llm.ToolContext; + this._tools = new llm.ToolContext(retainedEntries); unlock(); } @@ -733,26 +732,26 @@ export class RealtimeSession extends llm.RealtimeSession { private createToolsUpdateEvent(_tools: llm.ToolContext): api_proto.SessionUpdateEvent { const oaiTools: api_proto.Tool[] = []; - for (const [name, tool] of Object.entries(_tools)) { - if (!llm.isFunctionTool(tool)) { - this.#logger.error({ name, tool }, "OpenAI Realtime API doesn't support this tool type"); - continue; - } + for (const t of _tools.flatten()) { + // TODO: support provider tools in the Realtime session-update schema. 
+ if (!llm.isFunctionTool(t)) continue; - const { parameters: toolParameters, description } = tool; try { const parameters = llm.toJsonSchema( - toolParameters, + t.parameters, ) as unknown as api_proto.Tool['parameters']; oaiTools.push({ - name, - description, + name: t.name, + description: t.description, parameters: parameters, type: 'function', }); } catch (e) { - this.#logger.error({ name, tool }, "OpenAI Realtime API doesn't support this tool type"); + this.#logger.error( + { name: t.name, tool: t }, + "OpenAI Realtime API doesn't support this tool type", + ); continue; } } @@ -842,7 +841,18 @@ export class RealtimeSession extends llm.RealtimeSession { instructions?: string, options: { signal?: AbortSignal } = {}, ): Promise { - const handle = this.createResponse({ instructions, userInitiated: true }); + // In OpenAI realtime, the session-level instructions are completely replaced by the + // per-response instructions for this response. Prepend the session instructions so they + // are preserved (parity with the Python implementation). 
+ let responseInstructions = instructions; + if (instructions && this.instructions) { + responseInstructions = `${this.instructions}\n${instructions}`; + } + + const handle = this.createResponse({ + instructions: responseInstructions, + userInitiated: true, + }); this.textModeRecoveryRetries = 0; const onAbort = () => { @@ -1049,7 +1059,7 @@ export class RealtimeSession extends llm.RealtimeSession { events.push(this.createSessionUpdateEvent()); // tools - if (Object.keys(this._tools).length > 0) { + if (Object.keys(this._tools.functionTools).length > 0) { events.push(this.createToolsUpdateEvent(this._tools)); } diff --git a/plugins/openai/src/responses/llm.ts b/plugins/openai/src/responses/llm.ts index 219651d83..1615352af 100644 --- a/plugins/openai/src/responses/llm.ts +++ b/plugins/openai/src/responses/llm.ts @@ -13,6 +13,7 @@ import { } from '@livekit/agents'; import OpenAI from 'openai'; import type { ChatModels } from '../models.js'; +import { toResponsesTools } from '../tool_utils.js'; import { WSLLM } from '../ws/llm.js'; export interface LLMOptions { @@ -77,25 +78,30 @@ class ResponsesHttpLLM extends llm.LLM { override chat({ chatCtx, - toolCtx, + toolCtx: toolCtxInput, connOptions = DEFAULT_API_CONNECT_OPTIONS, parallelToolCalls, toolChoice, extraKwargs, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; extraKwargs?: Record; }): ResponsesHttpLLMStream { + const toolCtx = llm.toToolContext(toolCtxInput); const modelOptions: Record = { ...(extraKwargs || {}) }; parallelToolCalls = parallelToolCalls !== undefined ? 
parallelToolCalls : this.#opts.parallelToolCalls; - if (toolCtx && Object.keys(toolCtx).length > 0 && parallelToolCalls !== undefined) { + if ( + toolCtx && + Object.keys(toolCtx.functionTools).length > 0 && + parallelToolCalls !== undefined + ) { modelOptions.parallel_tool_calls = parallelToolCalls; } @@ -181,25 +187,9 @@ class ResponsesHttpLLMStream extends llm.LLMStream { 'openai.responses', )) as OpenAI.Responses.ResponseInputItem[]; + // TODO: support provider tools in the Responses schema. const tools = this.toolCtx - ? llm.sortedToolEntries(this.toolCtx).map(([name, func]) => { - const oaiParams = { - type: 'function' as const, - name: name, - description: func.description, - parameters: llm.toJsonSchema( - func.parameters, - true, - this.strictToolSchema, - ) as unknown as OpenAI.Responses.FunctionTool['parameters'], - } as OpenAI.Responses.FunctionTool; - - if (this.strictToolSchema) { - oaiParams.strict = true; - } - - return oaiParams; - }) + ? toResponsesTools(this.toolCtx, this.strictToolSchema) : undefined; const requestOptions: Record = { ...this.modelOptions }; @@ -430,7 +420,7 @@ export class LLM extends llm.LLM { extraKwargs, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; diff --git a/plugins/openai/src/tool_utils.test.ts b/plugins/openai/src/tool_utils.test.ts new file mode 100644 index 000000000..ce1922f52 --- /dev/null +++ b/plugins/openai/src/tool_utils.test.ts @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { llm } from '@livekit/agents'; +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; +import { toResponsesTools } from './tool_utils.js'; +import { CodeInterpreter, FileSearch, WebSearch } from './tools.js'; + +describe('toResponsesTools', () => { + it('serializes function tools', () => { + const fn = llm.tool({ + name: 'lookup_weather', + description: 'Look up weather', + parameters: z.object({ city: z.string() }), + execute: async () => 'sunny', + }); + + expect(toResponsesTools(new llm.ToolContext([fn]), true)).toEqual([ + { + type: 'function', + name: 'lookup_weather', + description: 'Look up weather', + parameters: { + $schema: 'http://json-schema.org/draft-07/schema#', + type: 'object', + properties: { city: { type: 'string' } }, + required: ['city'], + additionalProperties: false, + }, + strict: true, + }, + ]); + }); + + it('serializes OpenAI provider tools', () => { + const tools = toResponsesTools( + new llm.ToolContext([ + new WebSearch({ + filters: { allowed_domains: ['docs.livekit.io'] }, + searchContextSize: 'low', + userLocation: { type: 'approximate', country: 'US' }, + }), + new FileSearch({ + vectorStoreIds: ['vs_123'], + maxNumResults: 3, + rankingOptions: { ranker: 'auto' }, + }), + new CodeInterpreter({ container: { type: 'auto', file_ids: ['file_123'] } }), + ]), + false, + ); + + expect(tools).toEqual([ + { + type: 'web_search', + search_context_size: 'low', + filters: { allowed_domains: ['docs.livekit.io'] }, + user_location: { type: 'approximate', country: 'US' }, + }, + { + type: 'file_search', + vector_store_ids: ['vs_123'], + max_num_results: 3, + ranking_options: { ranker: 'auto' }, + }, + { type: 'code_interpreter', container: { type: 'auto', file_ids: ['file_123'] } }, + ]); + }); + + it('omits the code interpreter container when unset', () => { + expect(toResponsesTools(new llm.ToolContext([new CodeInterpreter()]), false)).toEqual([ + { type: 'code_interpreter' 
}, + ]); + }); + + it('ignores non-OpenAI provider tools', () => { + class OtherProviderTool extends llm.ProviderTool {} + + expect( + toResponsesTools(new llm.ToolContext([new OtherProviderTool({ id: 'other' })]), false), + ).toBeUndefined(); + }); +}); diff --git a/plugins/openai/src/tool_utils.ts b/plugins/openai/src/tool_utils.ts new file mode 100644 index 000000000..a0442b496 --- /dev/null +++ b/plugins/openai/src/tool_utils.ts @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { llm } from '@livekit/agents'; +import type OpenAI from 'openai'; +import { OpenAITool } from './tools.js'; + +export function toResponsesTools( + toolCtx: llm.ToolContext, + strictToolSchema: boolean, +): OpenAI.Responses.Tool[] | undefined { + // Function tools are emitted first, sorted by name for deterministic payloads; provider + // tools follow in registration order. + const functionTools = llm.sortedToolEntries(toolCtx).map(([name, tool]) => { + const oaiParams = { + type: 'function' as const, + name, + description: tool.description, + parameters: llm.toJsonSchema( + tool.parameters, + true, + strictToolSchema, + ) as unknown as OpenAI.Responses.FunctionTool['parameters'], + } as OpenAI.Responses.FunctionTool; + + if (strictToolSchema) { + oaiParams.strict = true; + } + + return oaiParams; + }); + + const providerTools = toolCtx + .flatten() + .filter((tool) => !llm.isFunctionTool(tool)) + .map((tool) => + tool instanceof OpenAITool + ? (tool.toToolConfig() as unknown as OpenAI.Responses.Tool) + : undefined, + ) + .filter((tool): tool is OpenAI.Responses.Tool => tool !== undefined); + + const tools = [...functionTools, ...providerTools]; + + return tools.length > 0 ? 
tools : undefined; +} diff --git a/plugins/openai/src/tools.ts b/plugins/openai/src/tools.ts new file mode 100644 index 000000000..1ce779376 --- /dev/null +++ b/plugins/openai/src/tools.ts @@ -0,0 +1,166 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { llm } from '@livekit/agents'; +import type OpenAI from 'openai'; + +/** Base class for OpenAI Responses API provider tools. */ +export abstract class OpenAITool extends llm.ProviderTool { + /** Convert this provider tool to the OpenAI Responses API tool configuration. */ + abstract toToolConfig(): Record; +} + +/** + * High-level guidance for the amount of context window space to use for web search. + * OpenAI defaults this to `medium`. + */ +export type WebSearchContextSize = 'low' | 'medium' | 'high'; + +/** Options for the OpenAI web search tool. */ +export interface WebSearchOptions { + /** + * Filters for the search, such as allowed domains. If not provided, all domains are allowed. + */ + filters?: OpenAI.Responses.WebSearchTool['filters']; + + /** + * Amount of context window space to use for the search. Defaults to `medium`. + */ + searchContextSize?: WebSearchContextSize | null; + + /** Approximate location of the user, such as city, region, country, or timezone. */ + userLocation?: OpenAI.Responses.WebSearchTool['user_location']; +} + +/** + * Search the Internet for sources related to the prompt. + * + * @see https://platform.openai.com/docs/guides/tools-web-search + */ +export class WebSearch extends OpenAITool { + /** Filters for the search, such as allowed domains. */ + readonly filters: OpenAI.Responses.WebSearchTool['filters'] | undefined; + + /** Amount of context window space to use for the search. */ + readonly searchContextSize: WebSearchContextSize | null; + + /** Approximate location of the user. 
*/ + readonly userLocation: OpenAI.Responses.WebSearchTool['user_location'] | undefined; + + constructor({ filters, searchContextSize = 'medium', userLocation }: WebSearchOptions = {}) { + super({ id: 'openai_web_search' }); + this.filters = filters; + this.searchContextSize = searchContextSize; + this.userLocation = userLocation; + } + + toToolConfig(): Record { + const result: Record = { + type: 'web_search', + search_context_size: this.searchContextSize, + }; + if (this.userLocation !== undefined) { + result.user_location = this.userLocation; + } + if (this.filters !== undefined) { + result.filters = this.filters; + } + return result; + } +} + +/** Options for the OpenAI file search tool. */ +export interface FileSearchOptions { + /** IDs of the vector stores to search. */ + vectorStoreIds?: string[]; + + /** Filter to apply to file search results. */ + filters?: OpenAI.Responses.FileSearchTool['filters']; + + /** Maximum number of results to return. This should be between 1 and 50 inclusive. */ + maxNumResults?: number; + + /** Ranking options for search, including ranker and score threshold. */ + rankingOptions?: OpenAI.Responses.FileSearchTool.RankingOptions; +} + +/** + * Search for relevant content from uploaded files. + * + * @see https://platform.openai.com/docs/guides/tools-file-search + */ +export class FileSearch extends OpenAITool { + /** IDs of the vector stores to search. */ + readonly vectorStoreIds: string[]; + + /** Filter to apply to file search results. */ + readonly filters: OpenAI.Responses.FileSearchTool['filters'] | undefined; + + /** Maximum number of results to return. */ + readonly maxNumResults: number | undefined; + + /** Ranking options for search. 
*/ + readonly rankingOptions: OpenAI.Responses.FileSearchTool.RankingOptions | undefined; + + constructor({ + vectorStoreIds = [], + filters, + maxNumResults, + rankingOptions, + }: FileSearchOptions = {}) { + super({ id: 'openai_file_search' }); + this.vectorStoreIds = [...vectorStoreIds]; + this.filters = filters; + this.maxNumResults = maxNumResults; + this.rankingOptions = rankingOptions; + } + + toToolConfig(): Record { + const result: Record = { + type: 'file_search', + vector_store_ids: this.vectorStoreIds, + }; + if (this.filters !== undefined) { + result.filters = this.filters; + } + if (this.maxNumResults !== undefined) { + result.max_num_results = this.maxNumResults; + } + if (this.rankingOptions !== undefined) { + result.ranking_options = this.rankingOptions; + } + return result; + } +} + +/** Options for the OpenAI code interpreter tool. */ +export interface CodeInterpreterOptions { + /** + * Code interpreter container. Can be a container ID or an object that specifies uploaded file IDs + * to make available to the code. + */ + container?: OpenAI.Responses.Tool.CodeInterpreter['container'] | null; +} + +/** + * Run Python code to help generate a response to a prompt. + * + * @see https://platform.openai.com/docs/guides/tools-code-interpreter + */ +export class CodeInterpreter extends OpenAITool { + /** Code interpreter container ID or configuration. 
*/ + readonly container: OpenAI.Responses.Tool.CodeInterpreter['container'] | null; + + constructor({ container = null }: CodeInterpreterOptions = {}) { + super({ id: 'openai_code_interpreter' }); + this.container = container; + } + + toToolConfig(): Record { + const result: Record = { type: 'code_interpreter' }; + if (this.container !== null) { + result.container = this.container; + } + return result; + } +} diff --git a/plugins/openai/src/ws/llm.ts b/plugins/openai/src/ws/llm.ts index 646853b7a..358a3e5b6 100644 --- a/plugins/openai/src/ws/llm.ts +++ b/plugins/openai/src/ws/llm.ts @@ -15,6 +15,7 @@ import { import type OpenAI from 'openai'; import { WebSocket } from 'ws'; import type { ChatModels } from '../models.js'; +import { toResponsesTools } from '../tool_utils.js'; import type { WsOutputItemDoneEvent, WsOutputTextDeltaEvent, @@ -253,24 +254,29 @@ export class WSLLM extends llm.LLM { chat({ chatCtx, - toolCtx, + toolCtx: toolCtxInput, connOptions = DEFAULT_API_CONNECT_OPTIONS, parallelToolCalls, toolChoice, extraKwargs, }: { chatCtx: llm.ChatContext; - toolCtx?: llm.ToolContext; + toolCtx?: llm.ToolContextLike; connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; extraKwargs?: Record; }): WSLLMStream { + const toolCtx = llm.toToolContext(toolCtxInput); const modelOptions: Record = { ...(extraKwargs ?? {}) }; parallelToolCalls = parallelToolCalls !== undefined ? parallelToolCalls : this.#opts.parallelToolCalls; - if (toolCtx && Object.keys(toolCtx).length > 0 && parallelToolCalls !== undefined) { + if ( + toolCtx && + Object.keys(toolCtx.functionTools).length > 0 && + parallelToolCalls !== undefined + ) { modelOptions.parallel_tool_calls = parallelToolCalls; } @@ -446,26 +452,7 @@ export class WSLLMStream extends llm.LLMStream { 'openai.responses', )) as OpenAI.Responses.ResponseInputItem[]; - const tools = this.toolCtx - ? 
llm.sortedToolEntries(this.toolCtx).map(([name, func]) => { - const oaiParams = { - type: 'function' as const, - name, - description: func.description, - parameters: llm.toJsonSchema( - func.parameters, - true, - this.#strictToolSchema, - ) as unknown as OpenAI.Responses.FunctionTool['parameters'], - } as OpenAI.Responses.FunctionTool; - - if (this.#strictToolSchema) { - oaiParams.strict = true; - } - - return oaiParams; - }) - : undefined; + const tools = this.toolCtx ? toResponsesTools(this.toolCtx, this.#strictToolSchema) : undefined; const requestOptions: Record = { ...this.#modelOptions }; if (!tools) { diff --git a/plugins/phonic/src/realtime/realtime_model.ts b/plugins/phonic/src/realtime/realtime_model.ts index aa2a28bf2..ada896966 100644 --- a/plugins/phonic/src/realtime/realtime_model.ts +++ b/plugins/phonic/src/realtime/realtime_model.ts @@ -239,7 +239,7 @@ interface GenerationState { * Realtime session for Phonic (https://docs.phonic.co/) */ export class RealtimeSession extends llm.RealtimeSession { - private _tools: llm.ToolContext = {}; + private _tools: llm.ToolContext = llm.ToolContext.empty(); private _chatCtx = llm.ChatContext.empty(); private options: RealtimeModelOptions; @@ -290,7 +290,7 @@ export class RealtimeSession extends llm.RealtimeSession { } get tools(): llm.ToolContext { - return { ...this._tools }; + return this._tools.copy(); } async updateInstructions(instructions: string): Promise { @@ -367,23 +367,23 @@ export class RealtimeSession extends llm.RealtimeSession { return; } - this._tools = { ...tools }; - this.toolDefinitions = Object.entries(tools) - .filter(([_, tool]) => llm.isFunctionTool(tool)) - .map(([name, tool]) => ({ - type: 'custom_websocket', + this._tools = tools.copy(); + // TODO: support provider tools in the Phonic schema. 
+ this.toolDefinitions = tools + .flatten() + .filter(llm.isFunctionTool) + .map((t) => ({ + type: 'custom_websocket' as const, tool_schema: { - type: 'function', + type: 'function' as const, function: { - name, - description: tool.description, - parameters: llm.toJsonSchema(tool.parameters), + name: t.name, + description: t.description, + parameters: llm.toJsonSchema(t.parameters), strict: true, }, }, tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, - // Tool chaining and tool calls during speech are not supported at this time - // for ease of implementation within the RealtimeSession generations framework wait_for_speech_before_tool_call: true, allow_tool_chaining: false, })); @@ -405,17 +405,19 @@ export class RealtimeSession extends llm.RealtimeSession { this.options.instructions = instructions; } if (tools !== undefined) { - this._tools = { ...tools }; - this.toolDefinitions = Object.entries(tools) - .filter(([, tool]) => llm.isFunctionTool(tool)) - .map(([name, tool]) => ({ - type: 'custom_websocket', + this._tools = tools.copy(); + // TODO: support provider tools in the Phonic schema. 
+ this.toolDefinitions = tools + .flatten() + .filter(llm.isFunctionTool) + .map((t) => ({ + type: 'custom_websocket' as const, tool_schema: { - type: 'function', + type: 'function' as const, function: { - name, - description: tool.description, - parameters: llm.toJsonSchema(tool.parameters), + name: t.name, + description: t.description, + parameters: llm.toJsonSchema(t.parameters), strict: true, }, }, @@ -571,6 +573,7 @@ export class RealtimeSession extends llm.RealtimeSession { private async connect(): Promise { this.socket = await this.client.conversations.connect({ + headers: { 'x-phonic-client': 'livekit-agents-js' }, reconnectAttempts: this.options.connOptions.maxRetry, }); diff --git a/plugins/test/src/llm.ts b/plugins/test/src/llm.ts index 8d85c75c3..1f912654a 100644 --- a/plugins/test/src/llm.ts +++ b/plugins/test/src/llm.ts @@ -5,8 +5,9 @@ import { initializeLogger, llm as llmlib } from '@livekit/agents'; import { describe, expect, it } from 'vitest'; import { z } from 'zod/v4'; -const toolCtx: llmlib.ToolContext = { - getWeather: llmlib.tool({ +const toolCtx = new llmlib.ToolContext([ + llmlib.tool({ + name: 'getWeather', description: 'Get the current weather in a given location', parameters: z.object({ location: z.string().describe('The city and state, e.g. 
San Francisco, CA'), @@ -14,14 +15,16 @@ const toolCtx: llmlib.ToolContext = { }), execute: async () => {}, }), - playMusic: llmlib.tool({ + llmlib.tool({ + name: 'playMusic', description: 'Play music', parameters: z.object({ name: z.string().describe('The artist and name of the song'), }), execute: async () => {}, }), - toggleLight: llmlib.tool({ + llmlib.tool({ + name: 'toggleLight', description: 'Turn on/off the lights in a room', parameters: z.object({ name: z.string().describe('The room to control'), @@ -31,7 +34,8 @@ const toolCtx: llmlib.ToolContext = { await new Promise((resolve) => setTimeout(resolve, 60_000)); }, }), - selectCurrencies: llmlib.tool({ + llmlib.tool({ + name: 'selectCurrencies', description: 'Currencies of a specific area', parameters: z.object({ currencies: z @@ -40,7 +44,8 @@ const toolCtx: llmlib.ToolContext = { }), execute: async () => {}, }), - updateUserInfo: llmlib.tool({ + llmlib.tool({ + name: 'updateUserInfo', description: 'Update user info.', parameters: z.object({ email: z.string().optional().describe("User's email address"), @@ -49,18 +54,20 @@ const toolCtx: llmlib.ToolContext = { }), execute: async () => {}, }), - simulateFailure: llmlib.tool({ + llmlib.tool({ + name: 'simulateFailure', description: 'Simulate a failure', parameters: z.object({}), execute: async () => { throw new Error('Simulated failure'); }, }), -}; +]); // Tool context for strict mode - uses nullable() instead of optional() -const toolCtxStrict: llmlib.ToolContext = { - getWeather: llmlib.tool({ +const toolCtxStrict = new llmlib.ToolContext([ + llmlib.tool({ + name: 'getWeather', description: 'Get the current weather in a given location', parameters: z.object({ location: z.string().describe('The city and state, e.g. 
San Francisco, CA'), @@ -68,14 +75,16 @@ const toolCtxStrict: llmlib.ToolContext = { }), execute: async () => {}, }), - playMusic: llmlib.tool({ + llmlib.tool({ + name: 'playMusic', description: 'Play music', parameters: z.object({ name: z.string().describe('The artist and name of the song'), }), execute: async () => {}, }), - toggleLight: llmlib.tool({ + llmlib.tool({ + name: 'toggleLight', description: 'Turn on/off the lights in a room', parameters: z.object({ name: z.string().describe('The room to control'), @@ -85,7 +94,8 @@ const toolCtxStrict: llmlib.ToolContext = { await new Promise((resolve) => setTimeout(resolve, 60_000)); }, }), - selectCurrencies: llmlib.tool({ + llmlib.tool({ + name: 'selectCurrencies', description: 'Currencies of a specific area', parameters: z.object({ currencies: z @@ -94,7 +104,8 @@ const toolCtxStrict: llmlib.ToolContext = { }), execute: async () => {}, }), - updateUserInfo: llmlib.tool({ + llmlib.tool({ + name: 'updateUserInfo', description: 'Update user info.', parameters: z.object({ email: z.string().nullable().describe("User's email address"), @@ -103,14 +114,15 @@ const toolCtxStrict: llmlib.ToolContext = { }), execute: async () => {}, }), - simulateFailure: llmlib.tool({ + llmlib.tool({ + name: 'simulateFailure', description: 'Simulate a failure', parameters: z.object({}), execute: async () => { throw new Error('Simulated failure'); }, }), -}; +]); export const llm = async (llm: llmlib.LLM, skipOptionalArgs: boolean) => { initializeLogger({ pretty: false }); @@ -188,6 +200,57 @@ export const llm = async (llm: llmlib.LLM, skipOptionalArgs: boolean) => { expect(JSON.parse(calls[0]!.args).address).toBeUndefined(); }); }); + + describe('toolset', async () => { + const buildToolsetContext = () => { + const weatherToolset = new llmlib.Toolset({ + id: 'weather_toolset', + tools: [ + llmlib.tool({ + name: 'getWeather', + description: 'Get the current weather in a given location', + parameters: z.object({ + location: 
z.string().describe('The city and state, e.g. San Francisco, CA'), + unit: z.enum(['celsius', 'fahrenheit']).describe('The temperature unit to use'), + }), + execute: async () => {}, + }), + ], + }); + + const directTool = llmlib.tool({ + name: 'playMusic', + description: 'Play music', + parameters: z.object({ + name: z.string().describe('The artist and name of the song'), + }), + execute: async () => {}, + }); + + return new llmlib.ToolContext([weatherToolset, directTool]); + }; + + it('should call a function tool that lives inside a Toolset', async () => { + const ctx = buildToolsetContext(); + const calls = await requestFncCall( + llm, + "What's the weather in San Francisco, in Celsius?", + ctx, + ); + + expect(calls.length).toStrictEqual(1); + expect(calls[0]!.name).toStrictEqual('getWeather'); + expect(JSON.parse(calls[0]!.args).unit).toStrictEqual('celsius'); + }); + + it('should expose direct tools alongside Toolset tools', async () => { + const ctx = buildToolsetContext(); + const calls = await requestFncCall(llm, 'Play the song "Bohemian Rhapsody" by Queen.', ctx); + + expect(calls.length).toStrictEqual(1); + expect(calls[0]!.name).toStrictEqual('playMusic'); + }); + }); }); }; @@ -315,7 +378,7 @@ const executeCalls = async (calls: llmlib.FunctionCall[]) => { const results: llmlib.FunctionCallOutput[] = []; for (const call of calls) { - const tool = toolCtx[call.name]; + const tool = toolCtx.getFunctionTool(call.name); if (!tool) { throw new Error(`Tool ${call.name} not found`); } diff --git a/turbo.json b/turbo.json index b0bc90527..b0ad80ff4 100644 --- a/turbo.json +++ b/turbo.json @@ -87,6 +87,7 @@ "SIP_PHONE_NUMBER", "LK_OPENAI_DEBUG", "LK_GOOGLE_DEBUG", + "LK_DRAIN_PLAYOUT_TIMEOUT_MS", "LIVEKIT_EVALS_VERBOSE", "OVHCLOUD_API_KEY", "INWORLD_API_KEY", @@ -114,6 +115,10 @@ "lint": { "outputs": [] }, + "typecheck": { + "dependsOn": ["^build"], + "outputs": [] + }, "api:check": { "cache": false, "dependsOn": ["^build"]