diff --git a/packages/core/lib/inference.ts b/packages/core/lib/inference.ts index b293230d0f..681a6c25ff 100644 --- a/packages/core/lib/inference.ts +++ b/packages/core/lib/inference.ts @@ -16,7 +16,14 @@ import type { InferStagehandSchema, StagehandZodObject, } from "./v3/zodCompat.js"; -import { SupportedUnderstudyAction } from "./v3/types/private/handlers.js"; +import { + ElementRef, + ModelAction, + ModelActResponse, + modelActionSchema, + modelActResponseSchema, +} from "./v3/types/private/modelActions.js"; +import type { EncodedId } from "./v3/types/private/internal.js"; import type { Variables } from "./v3/types/public/agent.js"; // Re-export for backward compatibility @@ -30,6 +37,81 @@ function withLlmTimeout(promise: Promise, operation: string): Promise { ); } +type LegacyInferenceAction = { + elementId: EncodedId; + description: string; + method: ModelAction["method"]; + arguments: string[]; +}; + +function encodeElementRef(ref: ElementRef): EncodedId { + return `${ref.frameOrdinal}-${ref.backendNodeId}`; +} + +function toLegacyInferenceAction(action: ModelAction): LegacyInferenceAction { + switch (action.method) { + case "click": + return { + elementId: encodeElementRef(action.target), + description: action.description, + method: action.method, + arguments: action.button ? [action.button] : [], + }; + case "fill": + return { + elementId: encodeElementRef(action.target), + description: action.description, + method: action.method, + arguments: [action.value], + }; + case "type": + return { + elementId: encodeElementRef(action.target), + description: action.description, + method: action.method, + arguments: [action.text], + }; + case "press": + return { + elementId: encodeElementRef(action.target), + description: action.description, + method: action.method, + arguments: [action.key], + }; + case "scrollTo": + return { + elementId: encodeElementRef(action.target), + description: action.description, + method: action.method, + arguments: [action.position], + }; + case "selectOptionFromDropdown": + return { + elementId: encodeElementRef(action.target), + description: action.description, + method: action.method, + arguments: [action.option], + }; + case "dragAndDrop": + return { + elementId: encodeElementRef(action.target), + description: action.description, + method: action.method, + arguments: [encodeElementRef(action.destination)], + }; + case "doubleClick": + case "hover": + case "nextChunk": + case "prevChunk": + return { + elementId: encodeElementRef(action.target), + description: action.description, + method: action.method, + arguments: [], + }; + } +} + export async function extract({ instruction, domElements, @@ -164,7 +246,6 @@ export async function extract({ response_model: { name: "Metadata", schema: metadataSchema, - strict: true, }, temperature: isGPT5 ? 1 : 0.1, top_p: 1, @@ -265,39 +346,7 @@ export async function observe({ const observeSchema = z.object({ elements: z - .array( - z.object({ - elementId: z - .string() - .regex(/^\d+-\d+$/) - .describe( - "the ID string associated with the element. Never include surrounding square brackets. This field must follow the format of 'number-number'.", - ), - description: z - .string() - .describe( - "a description of the accessible element and its purpose", - ), - method: z - .enum( - // Use Object.values() for Zod v3 compatibility - z.enum() in v3 doesn't accept TypeScript enums directly - Object.values(SupportedUnderstudyAction) as unknown as readonly [ - string, - ...string[], - ], - ) - .describe( - `the candidate method/action to interact with the element. Select one of the available Understudy interaction methods.`, - ), - arguments: z.array( - z - .string() - .describe( - "the arguments to pass to the method. For example, for a click, the arguments are empty, but for a fill, the arguments are the value to fill in.", - ), - ), - }), - ) + .array(modelActionSchema) .describe("an array of accessible elements that match the instruction"), }); @@ -334,7 +383,6 @@ export async function observe({ response_model: { schema: observeSchema, name: "Observation", - strict: true, }, temperature: isGPT5 ? 1 : 0.1, top_p: 1, @@ -379,13 +427,7 @@ export async function observe({ const parsedElements = observeData.elements?.map((el) => { - const base = { - elementId: el.elementId, - description: String(el.description), - method: String(el.method), - arguments: el.arguments, - }; - return base; + return toLegacyInferenceAction(el); }) ?? []; return { @@ -415,38 +457,9 @@ export async function act({ }) { const isGPT5 = llmClient.modelName.includes("gpt-5"); // TODO: remove this as we update support for gpt-5 configuration options - const actSchema = z.object({ - elementId: z - .string() - .regex(/^\d+-\d+$/) - .describe( - "the ID string associated with the element. Never include surrounding square brackets. This field must follow the format of 'number-number'. for example, '0-76' or '16-21'", - ), - description: z - .string() - .describe("a description of the accessible element and its purpose"), - method: z - .enum( - // Use Object.values() for Zod v3 compatibility - z.enum() in v3 doesn't accept TypeScript enums directly - Object.values(SupportedUnderstudyAction) as unknown as readonly [ - string, - ...string[], - ], - ) - .describe( - "the candidate method/action to interact with the element. Select one of the available Understudy interaction methods.", - ), - arguments: z.array( - z - .string() - .describe( - "the arguments to pass to the method. For example, for a click, the arguments are empty, but for a fill, the arguments are the value to fill in.", - ), - ), - twoStep: z.boolean(), - }); + const actSchema = modelActResponseSchema; - type ActResponse = z.infer; + type ActResponse = ModelActResponse; const messages: ChatMessage[] = [ buildActSystemPrompt(userProvidedInstructions), @@ -475,7 +488,6 @@ export async function act({ response_model: { schema: actSchema, name: "act", - strict: true, }, temperature: isGPT5 ? 1 : 0.1, top_p: 1, @@ -518,12 +530,7 @@ export async function act({ }); } - const parsedElement = { - elementId: actData.elementId, - description: String(actData.description), - method: String(actData.method), - arguments: actData.arguments, - }; + const parsedElement = toLegacyInferenceAction(actData.action); return { element: parsedElement, diff --git a/packages/core/lib/v3/agent/tools/act.ts b/packages/core/lib/v3/agent/tools/act.ts index c6131640fa..f512e4befd 100644 --- a/packages/core/lib/v3/agent/tools/act.ts +++ b/packages/core/lib/v3/agent/tools/act.ts @@ -1,4 +1,4 @@ -import { tool } from "ai"; +import { NoObjectGeneratedError, tool } from "ai"; import { z } from "zod"; import type { V3 } from "../../v3.js"; import type { Action } from "../../types/public/methods.js"; @@ -66,6 +66,9 @@ export const actTool = ( if (error instanceof TimeoutError) { throw error; } + if (NoObjectGeneratedError.isInstance(error)) { + throw error; + } return { success: false, error: error?.message ?? String(error), diff --git a/packages/core/lib/v3/agent/tools/extract.ts b/packages/core/lib/v3/agent/tools/extract.ts index 4050d8c97a..16b37a162d 100644 --- a/packages/core/lib/v3/agent/tools/extract.ts +++ b/packages/core/lib/v3/agent/tools/extract.ts @@ -1,4 +1,4 @@ -import { tool } from "ai"; +import { NoObjectGeneratedError, tool } from "ai"; import { z, ZodTypeAny } from "zod"; import type { V3 } from "../../v3.js"; import type { AgentModelConfig } from "../../types/public/agent.js"; @@ -103,6 +103,9 @@ export const extractTool = ( if (error instanceof TimeoutError) { throw error; } + if (NoObjectGeneratedError.isInstance(error)) { + throw error; + } return { success: false, error: error?.message ?? String(error) }; } }, diff --git a/packages/core/lib/v3/agent/tools/fillform.ts b/packages/core/lib/v3/agent/tools/fillform.ts index 4502d5c454..2824bfd7df 100644 --- a/packages/core/lib/v3/agent/tools/fillform.ts +++ b/packages/core/lib/v3/agent/tools/fillform.ts @@ -1,4 +1,4 @@ -import { tool } from "ai"; +import { NoObjectGeneratedError, tool } from "ai"; import { z } from "zod"; import type { V3 } from "../../v3.js"; import type { Action } from "../../types/public/methods.js"; @@ -77,6 +77,9 @@ export const fillFormTool = ( if (error instanceof TimeoutError) { throw error; } + if (NoObjectGeneratedError.isInstance(error)) { + throw error; + } return { success: false, error: error?.message ?? String(error), diff --git a/packages/core/lib/v3/external_clients/aisdk.ts b/packages/core/lib/v3/external_clients/aisdk.ts index 2cbafbb0b6..7d4bc99b8f 100644 --- a/packages/core/lib/v3/external_clients/aisdk.ts +++ b/packages/core/lib/v3/external_clients/aisdk.ts @@ -134,10 +134,6 @@ export class AISdkClient extends LLMClient { topP: options.top_p, frequencyPenalty: options.frequency_penalty, presencePenalty: options.presence_penalty, - providerOptions: - options.response_model.strict === false - ? { openai: { strictJsonSchema: false } } - : undefined, }); return { diff --git a/packages/core/lib/v3/llm/LLMProvider.ts b/packages/core/lib/v3/llm/LLMProvider.ts index 918bd259e6..f1d8b9d879 100644 --- a/packages/core/lib/v3/llm/LLMProvider.ts +++ b/packages/core/lib/v3/llm/LLMProvider.ts @@ -3,7 +3,6 @@ import { ExperimentalNotConfiguredError, UnsupportedAISDKModelProviderError, UnsupportedModelError, - UnsupportedModelProviderError, } from "../types/public/sdkErrors.js"; import { LogLine } from "../types/public/logs.js"; import { @@ -12,12 +11,7 @@ import { ModelProvider, } from "../types/public/model.js"; import { AISdkClient } from "./aisdk.js"; -import { AnthropicClient } from "./AnthropicClient.js"; -import { CerebrasClient } from "./CerebrasClient.js"; -import { GoogleClient } from "./GoogleClient.js"; -import { GroqClient } from "./GroqClient.js"; import { LLMClient } from "./LLMClient.js"; -import { OpenAIClient } from "./OpenAIClient.js"; import { openai, createOpenAI } from "@ai-sdk/openai"; import { bedrock, createAmazonBedrock } from "@ai-sdk/amazon-bedrock"; import { vertex, createVertex } from "@ai-sdk/google-vertex"; @@ -100,6 +94,25 @@ const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { "gemini-2.5-pro-preview-03-25": "google", }; +function getAISDKProviderModel( + modelName: AvailableModel, +): { provider: string; model: string } | null { + if (modelName.includes("/")) { + const firstSlashIndex = modelName.indexOf("/"); + return { + provider: modelName.substring(0, firstSlashIndex), + model: modelName.substring(firstSlashIndex + 1), + }; + } + + const provider = modelToProviderMap[modelName]; + if (!provider || provider === "aisdk") { + return null; + } + + return { provider, model: modelName }; +} + export function getAISDKLanguageModel( subProvider: string, subModelName: string, @@ -159,12 +172,10 @@ export class LLMProvider { middleware?: LanguageModelV3Middleware; }, ): LLMClient { - if (modelName.includes("/")) { - const firstSlashIndex = modelName.indexOf("/"); - const subProvider = modelName.substring(0, firstSlashIndex); - const subModelName = modelName.substring(firstSlashIndex + 1); + const aisdkTarget = getAISDKProviderModel(modelName); + if (aisdkTarget) { if ( - subProvider === "vertex" && + aisdkTarget.provider === "vertex" && !options?.disableAPI && !options?.experimental ) { @@ -173,70 +184,29 @@ export class LLMProvider { const effectiveMiddleware = options?.middleware ?? this.middleware; const languageModel = getAISDKLanguageModel( - subProvider, - subModelName, + aisdkTarget.provider, + aisdkTarget.model, clientOptions, effectiveMiddleware, ); + if (!modelName.includes("/")) { + this.logger({ + category: "llm", + message: `Deprecation warning: Model format "${modelName}" is deprecated. Please use the provider/model format (e.g., "openai/gpt-5" or "anthropic/claude-sonnet-4").`, + level: 0, + }); + } + return new AISdkClient({ model: languageModel, logger: this.logger, clientOptions, + providerName: aisdkTarget.provider, }); } - // Model name doesn't include "/" - this format is deprecated - const provider = modelToProviderMap[modelName]; - if (!provider) { - throw new UnsupportedModelError(Object.keys(modelToProviderMap)); - } - - this.logger({ - category: "llm", - message: `Deprecation warning: Model format "${modelName}" is deprecated. Please use the provider/model format (e.g., "openai/gpt-5" or "anthropic/claude-sonnet-4").`, - level: 0, - }); - - const availableModel = modelName as AvailableModel; - switch (provider) { - case "openai": - return new OpenAIClient({ - logger: this.logger, - modelName: availableModel, - clientOptions, - }); - case "anthropic": - return new AnthropicClient({ - logger: this.logger, - modelName: availableModel, - clientOptions, - }); - case "cerebras": - return new CerebrasClient({ - logger: this.logger, - modelName: availableModel, - clientOptions, - }); - case "groq": - return new GroqClient({ - logger: this.logger, - modelName: availableModel, - clientOptions, - }); - case "google": - return new GoogleClient({ - logger: this.logger, - modelName: availableModel, - clientOptions, - }); - default: - // This default case handles unknown providers that exist in modelToProviderMap - // but aren't implemented in the switch. This is an internal consistency issue. - throw new UnsupportedModelProviderError([ - ...new Set(Object.values(modelToProviderMap)), - ]); - } + throw new UnsupportedModelError(Object.keys(modelToProviderMap)); } static getModelProvider(modelName: AvailableModel): ModelProvider { diff --git a/packages/core/lib/v3/llm/OpenAIClient.ts b/packages/core/lib/v3/llm/OpenAIClient.ts deleted file mode 100644 index d3391d8ca2..0000000000 --- a/packages/core/lib/v3/llm/OpenAIClient.ts +++ /dev/null @@ -1,407 +0,0 @@ -import OpenAI, { ClientOptions } from "openai"; -import { - ChatCompletionAssistantMessageParam, - ChatCompletionContentPartImage, - ChatCompletionContentPartText, - ChatCompletionCreateParamsNonStreaming, - ChatCompletionMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, -} from "openai/resources/chat"; -import { LogLine } from "../types/public/logs.js"; -import { AvailableModel } from "../types/public/model.js"; -import { validateZodSchema } from "../../utils.js"; -import { - ChatCompletionOptions, - ChatMessage, - CreateChatCompletionOptions, - LLMClient, - LLMResponse, -} from "./LLMClient.js"; -import { - CreateChatCompletionResponseError, - StagehandError, - ZodSchemaValidationError, -} from "../types/public/sdkErrors.js"; -import { toJsonSchema } from "../zodCompat.js"; - -export class OpenAIClient extends LLMClient { - public type = "openai" as const; - private client: OpenAI; - declare public clientOptions: ClientOptions; - - constructor({ - modelName, - clientOptions, - }: { - logger: (message: LogLine) => void; - modelName: AvailableModel; - clientOptions?: ClientOptions; - }) { - super(modelName); - this.clientOptions = clientOptions; - this.client = new OpenAI(clientOptions); - this.modelName = modelName; - } - - async createChatCompletion({ - options: optionsInitial, - logger, - retries = 3, - }: CreateChatCompletionOptions): Promise { - let options: Partial = optionsInitial; - - // O1 models do not support most of the options. So we override them. - // For schema and tools, we add them as user messages. - let isToolsOverridedForO1 = false; - if (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) { - /* eslint-disable */ - // Remove unsupported options - let { - tool_choice, - top_p, - frequency_penalty, - presence_penalty, - temperature, - } = options; - ({ - tool_choice, - top_p, - frequency_penalty, - presence_penalty, - temperature, - ...options - } = options); - /* eslint-enable */ - // Remove unsupported options - options.messages = options.messages.map((message) => ({ - ...message, - role: "user", - })); - if (options.tools && options.response_model) { - throw new StagehandError( - "Cannot use both tool and response_model for o1 models", - ); - } - - if (options.tools) { - // Remove unsupported options - const { tools, ...rest } = options; - options = rest; - isToolsOverridedForO1 = true; - options.messages.push({ - role: "user", - content: `You have the following tools available to you:\n${JSON.stringify( - tools, - )} - - Respond with the following zod schema format to use a method: { - "name": "", - "arguments": - } - - Do not include any other text or formattings like \`\`\` in your response. Just the JSON object.`, - }); - } - } - if ( - options.temperature && - (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) - ) { - throw new StagehandError("Temperature is not supported for o1 models"); - } - - const { requestId, ...optionsWithoutImageAndRequestId } = options; - - logger({ - category: "openai", - message: "creating chat completion", - level: 2, - auxiliary: { - options: { - value: JSON.stringify({ - ...optionsWithoutImageAndRequestId, - requestId, - }), - type: "object", - }, - modelName: { - value: this.modelName, - type: "string", - }, - }, - }); - - if (options.image) { - const screenshotMessage: ChatMessage = { - role: "user", - content: [ - { - type: "image_url", - image_url: { - url: `data:image/jpeg;base64,${options.image.buffer.toString("base64")}`, - }, - }, - ...(options.image.description - ? [{ type: "text", text: options.image.description }] - : []), - ], - }; - - options.messages.push(screenshotMessage); - } - - let responseFormat: - | ChatCompletionCreateParamsNonStreaming["response_format"] - | undefined; - if (options.response_model) { - // For O1 models, we need to add the schema as a user message. - if (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) { - try { - const parsedSchema = JSON.stringify( - toJsonSchema(options.response_model.schema), - ); - options.messages.push({ - role: "user", - content: `Respond in this zod schema format:\n${parsedSchema}\n - - Do not include any other text, formatting or markdown in your output. Do not include \`\`\` or \`\`\`json in your response. Only the JSON object itself.`, - }); - } catch (error) { - logger({ - category: "openai", - message: "Failed to parse response model schema", - level: 0, - }); - - if (retries > 0) { - // as-casting to account for o1 models not supporting all options - return this.createChatCompletion({ - options: options as ChatCompletionOptions, - logger, - retries: retries - 1, - }); - } - - throw error; - } - } else { - responseFormat = { - type: "json_schema", - json_schema: { - name: options.response_model.name, - schema: toJsonSchema(options.response_model.schema), - }, - }; - } - } - - /* eslint-disable */ - // Remove unsupported options - const { response_model, ...openAiOptions } = { - ...optionsWithoutImageAndRequestId, - model: this.modelName, - }; - /* eslint-enable */ - - logger({ - category: "openai", - message: "creating chat completion", - level: 2, - auxiliary: { - openAiOptions: { - value: JSON.stringify(openAiOptions), - type: "object", - }, - }, - }); - - const formattedMessages: ChatCompletionMessageParam[] = - options.messages.map((message) => { - if (Array.isArray(message.content)) { - const contentParts = message.content.map((content) => { - if ("image_url" in content) { - const imageContent: ChatCompletionContentPartImage = { - image_url: { - url: content.image_url.url, - }, - type: "image_url", - }; - return imageContent; - } else { - const textContent: ChatCompletionContentPartText = { - text: content.text, - type: "text", - }; - return textContent; - } - }); - - if (message.role === "system") { - const formattedMessage: ChatCompletionSystemMessageParam = { - ...message, - role: "system", - content: contentParts.filter( - (content): content is ChatCompletionContentPartText => - content.type === "text", - ), - }; - return formattedMessage; - } else if (message.role === "user") { - const formattedMessage: ChatCompletionUserMessageParam = { - ...message, - role: "user", - content: contentParts, - }; - return formattedMessage; - } else { - const formattedMessage: ChatCompletionAssistantMessageParam = { - ...message, - role: "assistant", - content: contentParts.filter( - (content): content is ChatCompletionContentPartText => - content.type === "text", - ), - }; - return formattedMessage; - } - } - - const formattedMessage: ChatCompletionUserMessageParam = { - role: "user", - content: message.content, - }; - - return formattedMessage; - }); - - const body: ChatCompletionCreateParamsNonStreaming = { - ...openAiOptions, - model: this.modelName, - messages: formattedMessages, - response_format: responseFormat, - stream: false, - tools: options.tools?.map((tool) => ({ - function: { - name: tool.name, - description: tool.description, - parameters: tool.parameters, - }, - type: "function", - })), - }; - - const response = await this.client.chat.completions.create(body); - - // For O1 models, we need to parse the tool call response manually and add it to the response. - if (isToolsOverridedForO1) { - try { - const parsedContent = JSON.parse(response.choices[0].message.content); - - response.choices[0].message.tool_calls = [ - { - function: { - name: parsedContent["name"], - arguments: JSON.stringify(parsedContent["arguments"]), - }, - type: "function", - id: "-1", - }, - ]; - response.choices[0].message.content = null; - } catch (error) { - logger({ - category: "openai", - message: "Failed to parse tool call response", - level: 0, - auxiliary: { - error: { - value: error.message, - type: "string", - }, - content: { - value: response.choices[0].message.content, - type: "string", - }, - }, - }); - - if (retries > 0) { - // as-casting to account for o1 models not supporting all options - return this.createChatCompletion({ - options: options as ChatCompletionOptions, - logger, - retries: retries - 1, - }); - } - - throw error; - } - } - - logger({ - category: "openai", - message: "response", - level: 2, - auxiliary: { - response: { - value: JSON.stringify(response), - type: "object", - }, - requestId: { - value: requestId, - type: "string", - }, - }, - }); - - if (options.response_model) { - const extractedData = response.choices[0].message.content; - const parsedData = JSON.parse(extractedData); - - try { - validateZodSchema(options.response_model.schema, parsedData); - } catch (e) { - logger({ - category: "openai", - message: "Response failed Zod schema validation", - level: 0, - }); - if (retries > 0) { - // as-casting to account for o1 models not supporting all options - return this.createChatCompletion({ - options: options as ChatCompletionOptions, - logger, - retries: retries - 1, - }); - } - - if (e instanceof ZodSchemaValidationError) { - logger({ - category: "openai", - message: `Error during OpenAI chat completion: ${e.message}`, - level: 0, - auxiliary: { - errorDetails: { - value: `Message: ${e.message}${e.stack ? "\nStack: " + e.stack : ""}`, - type: "string", - }, - requestId: { value: requestId, type: "string" }, - }, - }); - throw new CreateChatCompletionResponseError(e.message); - } - throw e; - } - - return { - data: parsedData, - usage: response.usage, - } as T; - } - - // if the function was called with a response model, it would have returned earlier - // so we can safely cast here to T, which defaults to ChatCompletion - return response as T; - } -} diff --git a/packages/core/lib/v3/llm/aisdk.ts b/packages/core/lib/v3/llm/aisdk.ts index 86899d8182..97c34c44c0 100644 --- a/packages/core/lib/v3/llm/aisdk.ts +++ b/packages/core/lib/v3/llm/aisdk.ts @@ -63,48 +63,81 @@ function toLLMUsage(usage?: { }; } -function buildOpenAiStructuredProviderOptions(options: { +function buildStructuredProviderOptions(options: { + providerName?: string; isGPT5: boolean; isCodex: boolean; reasoningEffort?: string; - strict: boolean; }) { - const openaiOptions: Record = {}; + const providerOptions: Record> = {}; - if (options.isGPT5) { - openaiOptions.textVerbosity = options.isCodex ? "medium" : "low"; - } + switch (options.providerName) { + case "openai": { + const openaiOptions: Record = { + strictJsonSchema: true, + }; - if (options.reasoningEffort) { - openaiOptions.reasoningEffort = options.reasoningEffort; - } + if (options.isGPT5) { + openaiOptions.textVerbosity = options.isCodex ? "medium" : "low"; + } - if (!options.strict) { - openaiOptions.strictJsonSchema = false; + if (options.reasoningEffort) { + openaiOptions.reasoningEffort = options.reasoningEffort; + } + + providerOptions.openai = openaiOptions; + break; + } + case "azure": + providerOptions.azure = { strictJsonSchema: true }; + break; + case "google": + providerOptions.google = { structuredOutputs: true }; + break; + case "vertex": + providerOptions.vertex = { structuredOutputs: true }; + break; + case "anthropic": + providerOptions.anthropic = { structuredOutputMode: "auto" }; + break; + case "groq": + providerOptions.groq = { structuredOutputs: true }; + break; + case "cerebras": + providerOptions.cerebras = { strictJsonSchema: true }; + break; + case "mistral": + providerOptions.mistral = { + structuredOutputs: true, + strictJsonSchema: true, + }; + break; } - return Object.keys(openaiOptions).length > 0 - ? { openai: openaiOptions } - : undefined; + return Object.keys(providerOptions).length > 0 ? providerOptions : undefined; } export class AISdkClient extends LLMClient { public type = "aisdk" as const; private model: LanguageModelV2 | LanguageModelV3; private logger?: (message: LogLine) => void; + private providerName?: string; constructor({ model, logger, clientOptions, + providerName, }: { model: LanguageModelV2 | LanguageModelV3; logger?: (message: LogLine) => void; clientOptions?: ClientOptions; + providerName?: string; }) { super(model.modelId as AvailableModel); this.model = model; this.logger = logger; + this.providerName = providerName; if (clientOptions) { this.clientOptions = clientOptions; } @@ -255,11 +288,11 @@ You must respond in JSON format. respond WITH JSON. Do not include any other tex topP: options.top_p, frequencyPenalty: options.frequency_penalty, presencePenalty: options.presence_penalty, - providerOptions: buildOpenAiStructuredProviderOptions({ + providerOptions: buildStructuredProviderOptions({ + providerName: this.providerName, isGPT5, isCodex, reasoningEffort: resolvedReasoningEffort, - strict: options.response_model.strict ?? true, }), }); diff --git a/packages/core/lib/v3/types/private/modelActions.ts b/packages/core/lib/v3/types/private/modelActions.ts new file mode 100644 index 0000000000..8353275293 --- /dev/null +++ b/packages/core/lib/v3/types/private/modelActions.ts @@ -0,0 +1,98 @@ +import { z } from "zod"; +import { SupportedUnderstudyAction } from "./handlers.js"; + +export const elementRefSchema = z.strictObject({ + frameOrdinal: z + .number() + .int() + .nonnegative() + .describe( + "The frame ordinal from the accessibility tree element identifier.", + ), + backendNodeId: z + .number() + .int() + .positive() + .describe( + "The backend node ID from the accessibility tree element identifier.", + ), +}); + +export type ElementRef = z.infer; + +const modelActionBaseSchema = z.strictObject({ + target: elementRefSchema.describe( + "The element to act on, represented as a frame ordinal plus backend node ID.", + ), + description: z + .string() + .describe("A description of the accessible element and its purpose."), +}); + +export const modelActionSchema = z.union([ + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.CLICK), + button: z + .enum(["right", "middle"]) + .nullable() + .describe( + "Mouse button override for click actions. Use null for the default left click.", + ), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.DOUBLE_CLICK), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.HOVER), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.FILL), + value: z + .string() + .describe("The text value to fill into the target element."), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.TYPE), + text: z.string().describe("The text to type into the target element."), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.PRESS), + key: z + .string() + .describe("The keyboard key to press, for example 'Enter' or 'Tab'."), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.SCROLL), + position: z + .string() + .describe("The target scroll position, such as '50%' or '75%'."), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.NEXT_CHUNK), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.PREV_CHUNK), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.SELECT_OPTION_FROM_DROPDOWN), + option: z + .string() + .describe("The exact dropdown option text that should be selected."), + }), + modelActionBaseSchema.extend({ + method: z.literal(SupportedUnderstudyAction.DRAG_AND_DROP), + destination: elementRefSchema.describe( + "The target destination element for the drag-and-drop action.", + ), + }), +]); + +const modelActResponseSchemaInner = z.strictObject({ + action: modelActionSchema.describe("The action to perform."), + twoStep: z.boolean(), +}); + +export const modelActResponseSchema = modelActResponseSchemaInner; + +export type ModelAction = z.infer; +export type ModelActResponse = z.infer; diff --git a/packages/core/lib/v3/understudy/cdp.ts b/packages/core/lib/v3/understudy/cdp.ts index ffda2f267c..708477e2a6 100644 --- a/packages/core/lib/v3/understudy/cdp.ts +++ b/packages/core/lib/v3/understudy/cdp.ts @@ -116,13 +116,41 @@ export class CdpConnection implements CDPSessionLike { static async connect( wsUrl: string, - options?: { headers?: Record }, + options?: { headers?: Record; retryForMs?: number }, + ): Promise { + const deadlineMs = options?.retryForMs + ? Date.now() + options.retryForMs + : undefined; + let lastError: unknown; + + do { + try { + return await CdpConnection.connectOnce(wsUrl, options?.headers); + } catch (error) { + lastError = error; + if ( + !deadlineMs || + Date.now() >= deadlineMs || + !isTransientConnectError(error) + ) { + throw error; + } + await new Promise((resolve) => setTimeout(resolve, 100)); + } + } while (deadlineMs && Date.now() < deadlineMs); + + throw lastError; + } + + private static async connectOnce( + wsUrl: string, + userHeaders?: Record, ): Promise { // Include User-Agent header for server-side observability and version tracking // Merge user-provided headers, letting them override defaults const headers = { "User-Agent": `Stagehand/${STAGEHAND_VERSION}`, - ...options?.headers, + ...userHeaders, }; const ws = new WebSocket(wsUrl, { headers }); await new Promise((resolve, reject) => { @@ -510,6 +538,25 @@ export class CdpConnection implements CDPSessionLike { } } +function isTransientConnectError(error: unknown): boolean { + if (!error || typeof error !== "object") return false; + + const code = + "code" in error && typeof error.code === "string" ? error.code : ""; + const message = + "message" in error && typeof error.message === "string" + ? error.message + : ""; + + return ( + code === "ECONNREFUSED" || + code === "ECONNRESET" || + code === "EPIPE" || + code === "ETIMEDOUT" || + /ECONNREFUSED|ECONNRESET|socket hang up|EPIPE|ETIMEDOUT/i.test(message) + ); +} + export class CdpSession implements CDPSessionLike { constructor( private readonly root: CdpConnection, diff --git a/packages/core/lib/v3/understudy/context.ts b/packages/core/lib/v3/understudy/context.ts index f1c70e16f5..1415944cf6 100644 --- a/packages/core/lib/v3/understudy/context.ts +++ b/packages/core/lib/v3/understudy/context.ts @@ -166,6 +166,9 @@ export class V3Context { const connectTask = async () => { const conn = await CdpConnection.connect(wsUrl, { headers: opts?.cdpHeaders, + // Local Chrome sometimes briefly refuses the first real CDP socket + // immediately after the readiness probe succeeds under CI load. + retryForMs: opts?.env === "LOCAL" ? 3_000 : undefined, }); const ctx = new V3Context( conn, diff --git a/packages/core/tests/integration/agent-cache-self-heal.spec.ts b/packages/core/tests/integration/agent-cache-self-heal.spec.ts index 2c20c273f9..9a918f55b2 100644 --- a/packages/core/tests/integration/agent-cache-self-heal.spec.ts +++ b/packages/core/tests/integration/agent-cache-self-heal.spec.ts @@ -2,12 +2,53 @@ import { test, expect } from "@playwright/test"; import fs from "fs/promises"; import path from "path"; import { V3 } from "../../lib/v3/v3.js"; -import { v3TestConfig } from "./v3.config.js"; +import { getV3TestConfig } from "./v3.config.js"; import type { AgentReplayActStep, AgentReplayFillFormStep, CachedAgentEntry, } from "../../lib/v3/types/private/cache.js"; +import { + createScriptedAisdkTestLlmClient, + doneToolResponse, + findElementRefForText, + toolCallResponse, +} from "./testUtils.js"; + +function encodeHtml(html: string): string { + return `data:text/html,${encodeURIComponent(html)}`; +} + +function createSelfHealLlmClient() { + return createScriptedAisdkTestLlmClient({ + jsonResponses: { + act: [ + (options) => ({ + action: { + target: findElementRefForText(options, "Launch self-heal"), + description: "launch self-heal button", + method: "click", + button: null, + }, + twoStep: false, + }), + (options) => ({ + action: { + target: findElementRefForText(options, "Launch self-heal"), + description: "launch self-heal button", + method: "click", + button: null, + }, + twoStep: false, + }), + ], + }, + generateResponses: [ + toolCallResponse("act", { action: "click the button" }), + doneToolResponse("Clicked the button successfully.", true), + ], + }); +} test.describe("Agent cache self-heal (e2e)", () => { let v3: V3; @@ -18,7 +59,10 @@ test.describe("Agent cache self-heal (e2e)", () => { await fs.mkdir(testInfo.outputDir, { recursive: true }); cacheDir = await fs.mkdtemp(path.join(testInfo.outputDir, "agent-cache-")); v3 = new V3({ - ...v3TestConfig, + ...getV3TestConfig({ + experimental: true, + llmClient: createSelfHealLlmClient(), + }), cacheDir, selfHeal: true, }); @@ -30,19 +74,32 @@ test.describe("Agent cache self-heal (e2e)", () => { }); test("replays heal corrupted selectors", async () => { - test.setTimeout(120_000); + test.setTimeout(60_000); - const agent = v3.agent({ - model: "anthropic/claude-haiku-4-5-20251001", - }); + const agent = v3.agent(); const page = v3.context.pages()[0]; - const url = - "https://browserbase.github.io/stagehand-eval-sites/sites/shadow-dom/"; + const url = encodeHtml(` + + + + +
idle
+ + + `); const instruction = "click the button"; - await page.goto(url, { waitUntil: "networkidle" }); + await page.goto(url, { waitUntil: "load" }); const firstResult = await agent.execute({ instruction, maxSteps: 20 }); expect(firstResult.success).toBe(true); + await expect + .poll(async () => page.evaluate(() => document.body.textContent ?? "")) + .toContain("clicked"); const cachePath = await locateAgentCacheFile(cacheDir); const originalEntry = await readCacheEntry(cachePath); @@ -62,9 +119,13 @@ test.describe("Agent cache self-heal (e2e)", () => { ); // Second run should replay from cache, self-heal, and update the file. - await page.goto(url, { waitUntil: "networkidle" }); + await page.goto(url, { waitUntil: "load" }); const replayResult = await agent.execute({ instruction, maxSteps: 20 }); expect(replayResult.success).toBe(true); + expect(replayResult.metadata?.cacheHit).toBe(true); + await expect + .poll(async () => page.evaluate(() => document.body.textContent ?? "")) + .toContain("clicked"); const healedEntry = await readCacheEntry(cachePath); const healedActionStep = findFirstActionStep(healedEntry); diff --git a/packages/core/tests/integration/agent-callbacks.spec.ts b/packages/core/tests/integration/agent-callbacks.spec.ts index b799434339..42f09fef20 100644 --- a/packages/core/tests/integration/agent-callbacks.spec.ts +++ b/packages/core/tests/integration/agent-callbacks.spec.ts @@ -3,11 +3,13 @@ import { V3 } from "../../lib/v3/v3.js"; import { v3TestConfig } from "./v3.config.js"; import type { StepResult, ToolSet } from "ai"; import { StreamingCallbacksInNonStreamingModeError } from "../../lib/v3/types/public/sdkErrors.js"; +import { closeV3 } from "./testUtils.js"; test.describe("Stagehand agent callbacks behavior", () => { + test.describe.configure({ mode: "serial" }); let v3: V3; - test.beforeEach(async () => { + test.beforeAll(async () => { v3 = new V3({ ...v3TestConfig, experimental: true, // Required for callbacks and streaming @@ -15,8 +17,8 @@ test.describe("Stagehand agent callbacks behavior", () => { await v3.init(); }); - test.afterEach(async () => { - await v3?.close?.().catch(() => {}); + test.afterAll(async () => { + await closeV3(v3); }); test.describe("Non-streaming callbacks (stream: false)", () => { diff --git a/packages/core/tests/integration/agent-experimental-validation.spec.ts b/packages/core/tests/integration/agent-experimental-validation.spec.ts index fdf4c7864b..2d9859437b 100644 --- a/packages/core/tests/integration/agent-experimental-validation.spec.ts +++ b/packages/core/tests/integration/agent-experimental-validation.spec.ts @@ -7,6 +7,7 @@ import { ExperimentalNotConfiguredError, StagehandInvalidArgumentError, } from "../../lib/v3/types/public/sdkErrors.js"; +import { closeV3 } from "./testUtils.js"; // Define a mock custom tool for testing const mockCustomTool = tool({ @@ -28,11 +29,10 @@ test.describe("Stagehand agent experimental feature validation", () => { ...v3TestConfig, experimental: false, }); - await v3.init(); }); test.afterEach(async () => { - await v3?.close?.().catch(() => {}); + await closeV3(v3); }); test("throws StagehandInvalidArgumentError when CUA and streaming are both enabled", async () => { @@ -51,15 +51,11 @@ test.describe("Stagehand agent experimental feature validation", () => { }); test("throws StagehandInvalidArgumentError for CUA + streaming even with experimental: true", async () => { - // Close the non-experimental instance - await v3.close(); - // Create an experimental instance const v3Experimental = new V3({ ...v3TestConfig, experimental: true, }); - await v3Experimental.init(); try { v3Experimental.agent({ @@ -73,7 +69,7 @@ test.describe("Stagehand agent experimental feature validation", () => { expect((error as Error).message).toContain("streaming"); expect((error as Error).message).toContain("not supported with CUA"); } finally { - await v3Experimental.close(); + await closeV3(v3Experimental); } }); }); @@ -90,7 +86,7 @@ test.describe("Stagehand agent experimental feature validation", () => { }); test.afterEach(async () => { - await v3?.close?.().catch(() => {}); + await closeV3(v3); }); test("throws ExperimentalNotConfiguredError for MCP integrations", async () => { @@ -233,7 +229,7 @@ test.describe("Stagehand agent experimental feature validation", () => { }); test.afterEach(async () => { - await v3?.close?.().catch(() => {}); + await closeV3(v3); }); test("throws ExperimentalNotConfiguredError for CUA with integrations", async () => { @@ -317,7 +313,7 @@ test.describe("Stagehand agent experimental feature validation", () => { test("throws StagehandInvalidArgumentError for CUA unsupported features even with experimental: true", async () => { // Close the non-experimental instance - await v3.close(); + await closeV3(v3); // Create an experimental instance const v3Experimental = new V3({ @@ -342,7 +338,7 @@ test.describe("Stagehand agent experimental feature validation", () => { expect(error).toBeInstanceOf(StagehandInvalidArgumentError); expect((error as Error).message).toContain("not supported with CUA"); } finally { - await v3Experimental.close(); + await closeV3(v3Experimental); } }); }); @@ -355,11 +351,10 @@ test.describe("Stagehand agent experimental feature validation", () => { ...v3TestConfig, experimental: true, }); - await v3.init(); }); test.afterEach(async () => { - await v3?.close?.().catch(() => {}); + await closeV3(v3); }); test("allows CUA without streaming", () => { @@ -385,7 +380,6 @@ test.describe("Stagehand agent experimental feature validation", () => { ...v3TestConfig, experimental: false, }); - await v3NonExperimental.init(); try { // This should work - just creating a basic agent with no experimental features @@ -395,7 +389,7 @@ test.describe("Stagehand agent experimental feature validation", () => { }), ).not.toThrow(); } finally { - await v3NonExperimental.close(); + await closeV3(v3NonExperimental); } }); }); diff --git a/packages/core/tests/integration/click-count.spec.ts b/packages/core/tests/integration/click-count.spec.ts index f4690d127d..da34ad83a4 100644 --- a/packages/core/tests/integration/click-count.spec.ts +++ b/packages/core/tests/integration/click-count.spec.ts @@ -1,6 +1,7 @@ import { test, expect } from "@playwright/test"; import { V3 } from "../../lib/v3/v3.js"; import { v3TestConfig } from "./v3.config.js"; +import { closeV3 } from "./testUtils.js"; // Keep double-click verification event-based and deterministic. // Time-delta counters (Date.now() between mousedowns) are flaky at ms boundaries @@ -38,15 +39,16 @@ const doubleClickFixtureUrl = `data:text/html,${encodeURIComponent(``)}`; test.describe("Locator and Page click methods", () => { + test.describe.configure({ mode: "serial" }); let v3: V3; - test.beforeEach(async () => { + test.beforeAll(async () => { v3 = new V3(v3TestConfig); await v3.init(); }); - test.afterEach(async () => { - await v3?.close?.().catch(() => {}); + test.afterAll(async () => { + await closeV3(v3); }); test("locator.click() performs single click by default", async () => { diff --git a/packages/core/tests/integration/flowLogger.spec.ts b/packages/core/tests/integration/flowLogger.spec.ts index 88f45cf653..3b0533ac01 100644 --- a/packages/core/tests/integration/flowLogger.spec.ts +++ b/packages/core/tests/integration/flowLogger.spec.ts @@ -8,7 +8,7 @@ import { createScriptedAisdkTestLlmClient, closeV3, doneToolResponse, - findLastEncodedId, + findLastElementRef, toolCallResponse, } from "./testUtils.js"; import { getV3TestConfig } from "./v3.config.js"; @@ -213,10 +213,12 @@ test.describe("flow logger integration", () => { const llmClient = createScriptedAisdkTestLlmClient({ jsonResponses: { act: (options) => ({ - elementId: findLastEncodedId(options), - description: `click ${buttonText}`, - method: "click", - arguments: [], + action: { + target: findLastElementRef(options), + description: `click ${buttonText}`, + method: "click", + button: null, + }, twoStep: false, }), }, @@ -298,10 +300,10 @@ test.describe("flow logger integration", () => { Observation: (options) => ({ elements: [ { - elementId: findLastEncodedId(options), + target: findLastElementRef(options), description: observeText, method: "click", - arguments: [], + button: null, }, ], }), @@ -435,10 +437,12 @@ test.describe("flow logger integration", () => { const llmClient = createScriptedAisdkTestLlmClient({ jsonResponses: { act: (options) => ({ - elementId: findLastEncodedId(options), - description: `click ${buttonText}`, - method: "click", - arguments: [], + action: { + target: findLastElementRef(options), + description: `click ${buttonText}`, + method: "click", + button: null, + }, twoStep: false, }), }, @@ -538,10 +542,10 @@ test.describe("flow logger integration", () => { Observation: (options) => ({ elements: [ { - elementId: findLastEncodedId(options), + target: findLastElementRef(options), description: "name input", method: "fill", - arguments: ["hello"], + value: "hello", }, ], }), diff --git a/packages/core/tests/integration/page-hover.spec.ts b/packages/core/tests/integration/page-hover.spec.ts index b6098c4ec5..daf91bafc7 100644 --- a/packages/core/tests/integration/page-hover.spec.ts +++ b/packages/core/tests/integration/page-hover.spec.ts @@ -1,17 +1,19 @@ import { test, expect } from "@playwright/test"; import { V3 } from "../../lib/v3/v3.js"; import { v3TestConfig } from "./v3.config.js"; +import { closeV3 } from "./testUtils.js"; test.describe("Page.hover() - mouse hover at coordinates", () => { + test.describe.configure({ mode: "serial" }); let v3: V3; - test.beforeEach(async () => { + test.beforeAll(async () => { v3 = new V3(v3TestConfig); await v3.init(); }); - test.afterEach(async () => { - await v3?.close?.().catch(() => {}); + test.afterAll(async () => { + await closeV3(v3); }); test("hover triggers mouseover event at coordinates", async () => { diff --git a/packages/core/tests/integration/testUtils.ts b/packages/core/tests/integration/testUtils.ts index 64c21aa9ad..4f321ba4c2 100644 --- a/packages/core/tests/integration/testUtils.ts +++ b/packages/core/tests/integration/testUtils.ts @@ -147,7 +147,7 @@ function resolveJsonResponseKey( }; const properties = schema?.properties ?? {}; - if ("elementId" in properties && "twoStep" in properties) { + if ("action" in properties && "twoStep" in properties) { return "act"; } @@ -209,6 +209,38 @@ export function findLastEncodedId(options: LanguageModelV3CallOptions): string { return matches[matches.length - 1]; } +function parseEncodedId(encodedId: string): { + frameOrdinal: number; + backendNodeId: number; +} { + const match = encodedId.match(/^(\d+)-(\d+)$/); + if (!match) { + throw new Error(`Invalid encoded id: ${encodedId}`); + } + + return { + frameOrdinal: Number(match[1]), + backendNodeId: Number(match[2]), + }; +} + +export function findLastElementRef(options: LanguageModelV3CallOptions): { + frameOrdinal: number; + backendNodeId: number; +} { + return parseEncodedId(findLastEncodedId(options)); +} + +export function findElementRefForText( + options: LanguageModelV3CallOptions, + text: string, +): { + frameOrdinal: number; + backendNodeId: number; +} { + return parseEncodedId(findEncodedIdForText(options, text)); +} + export function toolCallResponse( toolName: string, input: Record, diff --git a/packages/core/tests/integration/text-selector-innermost.spec.ts b/packages/core/tests/integration/text-selector-innermost.spec.ts index a5a56514d2..db0afb2daa 100644 --- a/packages/core/tests/integration/text-selector-innermost.spec.ts +++ b/packages/core/tests/integration/text-selector-innermost.spec.ts @@ -5,14 +5,15 @@ import { v3DynamicTestConfig } from "./v3.dynamic.config.js"; import { closeV3 } from "./testUtils.js"; test.describe("Text selector innermost element matching", () => { + test.describe.configure({ mode: "serial" }); let v3: V3; - test.beforeEach(async () => { + test.beforeAll(async () => { v3 = new V3(v3DynamicTestConfig); await v3.init(); }); - test.afterEach(async () => { + test.afterAll(async () => { await closeV3(v3); }); diff --git a/packages/core/tests/integration/timeouts.spec.ts b/packages/core/tests/integration/timeouts.spec.ts index 76134fa8a7..9f1edc963d 100644 --- a/packages/core/tests/integration/timeouts.spec.ts +++ b/packages/core/tests/integration/timeouts.spec.ts @@ -124,10 +124,15 @@ function createToolTimeoutTestLlmClient( if (responseModelName === "act") { return { data: { - elementId: "1-0", - description: "click body", - method: "click", - arguments: [], + action: { + target: { + frameOrdinal: 1, + backendNodeId: 1, + }, + description: "click body", + method: "click", + button: null, + }, twoStep: false, }, usage, diff --git a/packages/core/tests/unit/aisdk-client-compat.test.ts b/packages/core/tests/unit/aisdk-client-compat.test.ts index 5ea12d9c62..de35a5bddd 100644 --- a/packages/core/tests/unit/aisdk-client-compat.test.ts +++ b/packages/core/tests/unit/aisdk-client-compat.test.ts @@ -7,6 +7,7 @@ import type { LanguageModelV3GenerateResult, LanguageModelV3Usage, } from "@ai-sdk/provider"; +import { NoObjectGeneratedError } from "ai"; import { z } from "zod"; import { AISdkClient } from "../../lib/v3/llm/aisdk.js"; @@ -167,6 +168,34 @@ describe("AISdkClient compatibility", () => { expect(result.usage).not.toHaveProperty("cachedInputTokens"); }); + it("createChatCompletion() with response_model throws NoObjectGeneratedError for invalid structured output", async () => { + const model = createScriptedModel(() => ({ + content: [ + { + type: "text", + text: JSON.stringify({ + invalidResponseShape: "missing required title field", + }), + }, + ], + })); + + const client = new AISdkClient({ model }); + + await expect( + client.createChatCompletion({ + options: { + messages: [{ role: "user", content: "Return the extraction title." }], + response_model: { + name: "Extraction", + schema: z.object({ title: z.string() }), + }, + }, + logger: vi.fn(), + }), + ).rejects.toBeInstanceOf(NoObjectGeneratedError); + }); + it("createChatCompletion() without response_model maps tool calls into legacy chat completion shape", async () => { const model = createScriptedModel(() => ({ content: [ diff --git a/packages/core/tests/unit/element-id-regression.test.ts b/packages/core/tests/unit/element-id-regression.test.ts new file mode 100644 index 0000000000..eaa83f5763 --- /dev/null +++ b/packages/core/tests/unit/element-id-regression.test.ts @@ -0,0 +1,338 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { act as runAct, observe as runObserve } from "../../lib/inference.js"; +import { ActHandler } from "../../lib/v3/handlers/actHandler.js"; +import { ObserveHandler } from "../../lib/v3/handlers/observeHandler.js"; +import type { LLMClient } from "../../lib/v3/llm/LLMClient.js"; +import type { Page } from "../../lib/v3/understudy/page.js"; +import { waitForDomNetworkQuiet } from "../../lib/v3/handlers/handlerUtils/actHandlerUtils.js"; +import { captureHybridSnapshot } from "../../lib/v3/understudy/a11y/snapshot/index.js"; + +vi.mock("../../lib/v3/handlers/handlerUtils/actHandlerUtils", () => ({ + waitForDomNetworkQuiet: vi.fn(), + performUnderstudyMethod: vi.fn(), +})); + +vi.mock("../../lib/v3/understudy/a11y/snapshot/index.js", () => ({ + captureHybridSnapshot: vi.fn(), + diffCombinedTrees: vi.fn(), +})); + +const usage = { + prompt_tokens: 1, + completion_tokens: 1, + total_tokens: 2, +}; + +function buildSchemaValidatingLlmClient( + payloads: Record, +): LLMClient { + return { + type: "aisdk", + modelName: "google/gemini-3-flash-preview", + hasVision: false, + clientOptions: {}, + createChatCompletion: async ({ + options, + }: { + options: { + response_model?: { + name: string; + schema: { + safeParseAsync: (data: unknown) => Promise<{ + success: boolean; + data?: unknown; + error?: unknown; + }>; + }; + }; + }; + }): Promise => { + const responseModel = options.response_model; + if (!responseModel) { + return { data: {}, usage } as T; + } + + const result = await responseModel.schema.safeParseAsync( + payloads[responseModel.name], + ); + if (!result.success) { + throw result.error; + } + + return { + data: result.data, + usage, + } as T; + }, + } as LLMClient; +} + +function buildActHandler(llmClient: LLMClient): ActHandler { + return new ActHandler( + llmClient, + "google/gemini-3-flash-preview", + {}, + () => llmClient, + ); +} + +function buildObserveHandler(llmClient: LLMClient): ObserveHandler { + return new ObserveHandler( + llmClient, + "google/gemini-3-flash-preview", + {}, + () => llmClient, + ); +} + +describe("typed element reference regression", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("accepts structured act element refs and re-encodes them for upstream callers", async () => { + const llmClient = buildSchemaValidatingLlmClient({ + act: { + action: { + target: { + frameOrdinal: 0, + backendNodeId: 9786, + }, + description: "gear button", + method: "click", + button: null, + }, + twoStep: false, + }, + }); + + await expect( + runAct({ + instruction: "click the gear button", + domElements: "[0-9786] button Gear", + llmClient, + logger: vi.fn(), + }), + ).resolves.toMatchObject({ + element: { + elementId: "0-9786", + description: "gear button", + method: "click", + arguments: [], + }, + twoStep: false, + }); + }); + + it("rejects act responses when twoStep is nested inside action", async () => { + const llmClient = buildSchemaValidatingLlmClient({ + act: { + action: { + target: { + frameOrdinal: 0, + backendNodeId: 9786, + }, + description: "gear button", + method: "click", + button: null, + twoStep: false, + }, + }, + }); + + await expect( + runAct({ + instruction: "click the gear button", + domElements: "[0-9786] button Gear", + llmClient, + logger: vi.fn(), + }), + ).rejects.toBeTruthy(); + }); + + it("accepts structured observe element refs and re-encodes them for upstream callers", async () => { + const llmClient = buildSchemaValidatingLlmClient({ + Observation: { + elements: [ + { + target: { + frameOrdinal: 0, + backendNodeId: 9786, + }, + description: "gear button", + method: "click", + button: null, + }, + ], + }, + }); + + await expect( + runObserve({ + instruction: "find the gear button", + domElements: "[0-9786] button Gear", + llmClient, + logger: vi.fn(), + }), + ).resolves.toMatchObject({ + elements: [ + { + elementId: "0-9786", + description: "gear button", + method: "click", + arguments: [], + }, + ], + }); + }); + + it("resolves structured act element refs against the encoded xpath map", async () => { + const llmClient = buildSchemaValidatingLlmClient({ + act: { + action: { + target: { + frameOrdinal: 0, + backendNodeId: 9786, + }, + description: "gear button", + method: "click", + button: null, + }, + twoStep: false, + }, + }); + vi.mocked(waitForDomNetworkQuiet).mockResolvedValue(undefined); + vi.mocked(captureHybridSnapshot).mockResolvedValue({ + combinedTree: "[0-9786] button Gear", + combinedXpathMap: { + "0-9786": "/html/body/button", + }, + combinedUrlMap: {}, + }); + + const { performUnderstudyMethod } = await import( + "../../lib/v3/handlers/handlerUtils/actHandlerUtils.js" + ); + const performUnderstudyMethodMock = vi.mocked(performUnderstudyMethod); + performUnderstudyMethodMock.mockResolvedValue(undefined); + + const handler = buildActHandler(llmClient); + const fakePage = { + mainFrame: vi.fn().mockReturnValue({}), + } as unknown as Page; + + await expect( + handler.act({ + instruction: "click the gear button", + page: fakePage, + }), + ).resolves.toMatchObject({ + success: true, + actions: [ + { + selector: "xpath=/html/body/button", + method: "click", + arguments: [], + }, + ], + }); + }); + + it("resolves typed drag-and-drop refs against the encoded xpath map", async () => { + const llmClient = buildSchemaValidatingLlmClient({ + act: { + action: { + target: { + frameOrdinal: 0, + backendNodeId: 9786, + }, + destination: { + frameOrdinal: 1, + backendNodeId: 4321, + }, + description: "drag the gear button into the dropzone", + method: "dragAndDrop", + }, + twoStep: false, + }, + }); + vi.mocked(waitForDomNetworkQuiet).mockResolvedValue(undefined); + vi.mocked(captureHybridSnapshot).mockResolvedValue({ + combinedTree: "[0-9786] button Gear\n[1-4321] region Dropzone", + combinedXpathMap: { + "0-9786": "/html/body/button", + "1-4321": "/html/body/div[2]", + }, + combinedUrlMap: {}, + }); + + const { performUnderstudyMethod } = await import( + "../../lib/v3/handlers/handlerUtils/actHandlerUtils.js" + ); + const performUnderstudyMethodMock = vi.mocked(performUnderstudyMethod); + performUnderstudyMethodMock.mockResolvedValue(undefined); + + const handler = buildActHandler(llmClient); + const fakePage = { + mainFrame: vi.fn().mockReturnValue({}), + } as unknown as Page; + + await expect( + handler.act({ + instruction: "drag the gear button into the dropzone", + page: fakePage, + }), + ).resolves.toMatchObject({ + success: true, + actions: [ + { + selector: "xpath=/html/body/button", + method: "dragAndDrop", + arguments: ["xpath=/html/body/div[2]"], + }, + ], + }); + }); + + it("resolves structured observe element refs against the encoded xpath map", async () => { + const llmClient = buildSchemaValidatingLlmClient({ + Observation: { + elements: [ + { + target: { + frameOrdinal: 0, + backendNodeId: 9786, + }, + description: "gear button", + method: "click", + button: null, + }, + ], + }, + }); + vi.mocked(captureHybridSnapshot).mockResolvedValue({ + combinedTree: "[0-9786] button Gear", + combinedXpathMap: { + "0-9786": "/html/body/button", + }, + combinedUrlMap: {}, + }); + + const handler = buildObserveHandler(llmClient); + const fakePage = {} as Page; + + await expect( + handler.observe({ + instruction: "find the gear button", + page: fakePage, + }), + ).resolves.toEqual([ + { + description: "gear button", + method: "click", + arguments: [], + selector: "xpath=/html/body/button", + }, + ]); + }); +}); diff --git a/packages/core/tests/unit/model-deprecation.test.ts b/packages/core/tests/unit/model-deprecation.test.ts index 32d45940bb..db3f078a0c 100644 --- a/packages/core/tests/unit/model-deprecation.test.ts +++ b/packages/core/tests/unit/model-deprecation.test.ts @@ -85,7 +85,7 @@ describe("Model format deprecation", () => { expect(message).toContain("openai/gpt-5"); }); - it("returns OpenAIClient for legacy OpenAI model names", () => { + it("returns AISdkClient for legacy OpenAI model names", () => { const logs: LogLine[] = []; const logger = (line: LogLine) => logs.push(line); const provider = new LLMProvider(logger); @@ -94,11 +94,11 @@ describe("Model format deprecation", () => { // Should return a client expect(client).toBeDefined(); - // The client should be an OpenAIClient (check constructor name) - expect(client.constructor.name).toBe("OpenAIClient"); + // Legacy bare model names should still normalize to the AI SDK path. + expect(client.constructor.name).toBe("AISdkClient"); }); - it("returns GoogleClient for legacy Google model names", () => { + it("returns AISdkClient for legacy Google model names", () => { const logs: LogLine[] = []; const logger = (line: LogLine) => logs.push(line); const provider = new LLMProvider(logger); @@ -107,8 +107,7 @@ describe("Model format deprecation", () => { // Should return a client expect(client).toBeDefined(); - // The client should be a GoogleClient - expect(client.constructor.name).toBe("GoogleClient"); + expect(client.constructor.name).toBe("AISdkClient"); }); }); diff --git a/packages/core/tests/unit/structured-output-errors.test.ts b/packages/core/tests/unit/structured-output-errors.test.ts new file mode 100644 index 0000000000..4ee6e02480 --- /dev/null +++ b/packages/core/tests/unit/structured-output-errors.test.ts @@ -0,0 +1,146 @@ +import { describe, expect, it, vi } from "vitest"; +import { NoObjectGeneratedError } from "ai"; +import { actTool } from "../../lib/v3/agent/tools/act.js"; +import { extractTool } from "../../lib/v3/agent/tools/extract.js"; +import { fillFormTool } from "../../lib/v3/agent/tools/fillform.js"; +import { V3AgentHandler } from "../../lib/v3/handlers/v3AgentHandler.js"; +import type { V3 } from "../../lib/v3/v3.js"; +import type { LLMClient } from "../../lib/v3/llm/LLMClient.js"; + +function createNoObjectGeneratedError(): NoObjectGeneratedError { + return new NoObjectGeneratedError({ + message: "Invalid structured output: missing required fields", + text: '{"invalidResponseShape":"missing required title field"}', + response: { + id: "resp_mock", + timestamp: new Date(), + modelId: "mock/stagehand-compat", + } as never, + usage: { + inputTokens: 1, + outputTokens: 1, + totalTokens: 2, + } as never, + finishReason: "stop" as never, + }); +} + +describe("structured output error propagation", () => { + it("actTool rethrows NoObjectGeneratedError from v3.act()", async () => { + const error = createNoObjectGeneratedError(); + const v3 = { + logger: vi.fn(), + act: vi.fn().mockRejectedValue(error), + recordAgentReplayStep: vi.fn(), + } as unknown as V3; + + const toolDef = actTool(v3); + + await expect( + toolDef.execute?.({ action: "click the button" }, {} as never), + ).rejects.toBe(error); + expect(v3.recordAgentReplayStep).not.toHaveBeenCalled(); + }); + + it("fillFormTool rethrows NoObjectGeneratedError from v3.observe()", async () => { + const error = createNoObjectGeneratedError(); + const v3 = { + logger: vi.fn(), + observe: vi.fn().mockRejectedValue(error), + act: vi.fn(), + recordAgentReplayStep: vi.fn(), + } as unknown as V3; + + const toolDef = fillFormTool(v3); + + await expect( + toolDef.execute?.( + { + fields: [{ action: "type hello into the first name input" }], + }, + {} as never, + ), + ).rejects.toBe(error); + expect(v3.recordAgentReplayStep).not.toHaveBeenCalled(); + }); + + it("extractTool rethrows NoObjectGeneratedError from v3.extract()", async () => { + const error = createNoObjectGeneratedError(); + const v3 = { + extract: vi.fn().mockRejectedValue(error), + } as unknown as V3; + + const toolDef = extractTool(v3); + + await expect( + toolDef.execute?.( + { + instruction: "extract the title", + schema: { + type: "object", + properties: { + title: { type: "string" }, + }, + }, + }, + {} as never, + ), + ).rejects.toBe(error); + }); + + it("V3AgentHandler.execute() preserves the failed AgentResult contract for NoObjectGeneratedError", async () => { + const error = createNoObjectGeneratedError(); + const llmClient = { + generateText: vi.fn().mockRejectedValue(error), + } as unknown as LLMClient; + + const handler = new V3AgentHandler({} as V3, vi.fn(), llmClient); + + vi.spyOn( + handler as unknown as { prepareAgent: () => Promise }, + "prepareAgent", + ).mockResolvedValue({ + options: { instruction: "describe the page" }, + maxSteps: 3, + systemPrompt: "", + allTools: {}, + messages: [{ role: "user", content: "describe the page" }], + wrappedModel: {}, + initialPageUrl: "https://example.com", + }); + + await expect(handler.execute("describe the page")).resolves.toMatchObject({ + success: false, + completed: false, + message: + "Failed to execute task: Invalid structured output: missing required fields", + }); + }); + + it("V3AgentHandler.execute() still returns a failed result for generic errors", async () => { + const llmClient = { + generateText: vi.fn().mockRejectedValue(new Error("boom")), + } as unknown as LLMClient; + + const handler = new V3AgentHandler({} as V3, vi.fn(), llmClient); + + vi.spyOn( + handler as unknown as { prepareAgent: () => Promise }, + "prepareAgent", + ).mockResolvedValue({ + options: { instruction: "describe the page" }, + maxSteps: 3, + systemPrompt: "", + allTools: {}, + messages: [{ role: "user", content: "describe the page" }], + wrappedModel: {}, + initialPageUrl: "https://example.com", + }); + + await expect(handler.execute("describe the page")).resolves.toMatchObject({ + success: false, + completed: false, + message: "Failed to execute task: boom", + }); + }); +}); diff --git a/packages/core/tests/unit/timeout-handlers.test.ts b/packages/core/tests/unit/timeout-handlers.test.ts index fad085834d..a16e08cb9d 100644 --- a/packages/core/tests/unit/timeout-handlers.test.ts +++ b/packages/core/tests/unit/timeout-handlers.test.ts @@ -204,6 +204,8 @@ describe("ActHandler two-step timeout", () => { twoStep: true, prompt_tokens: 100, completion_tokens: 50, + reasoning_tokens: 0, + cached_input_tokens: 0, inference_time_ms: 500, } as ReturnType extends Promise ? T : never); @@ -287,6 +289,8 @@ describe("ActHandler self-heal timeout", () => { twoStep: false, prompt_tokens: 100, completion_tokens: 50, + reasoning_tokens: 0, + cached_input_tokens: 0, inference_time_ms: 500, } as ReturnType extends Promise ? T : never); @@ -358,6 +362,8 @@ describe("ActHandler self-heal timeout", () => { twoStep: false, prompt_tokens: 100, completion_tokens: 50, + reasoning_tokens: 0, + cached_input_tokens: 0, inference_time_ms: 500, } as ReturnType extends Promise ? T : never); @@ -856,6 +862,8 @@ describe("No-timeout success paths", () => { { elementId: "1-0", description: "Submit button", + method: "click", + arguments: [], }, ], prompt_tokens: 150, @@ -986,6 +994,8 @@ describe("No-timeout success paths", () => { twoStep: false, prompt_tokens: 100, completion_tokens: 50, + reasoning_tokens: 0, + cached_input_tokens: 0, inference_time_ms: 500, } as ReturnType extends Promise ? T : never); @@ -1043,6 +1053,8 @@ describe("No-timeout success paths", () => { twoStep: false, prompt_tokens: 100, completion_tokens: 50, + reasoning_tokens: 0, + cached_input_tokens: 0, inference_time_ms: 500, } as ReturnType extends Promise ? T : never); diff --git a/packages/evals/lib/AISdkClientWrapped.ts b/packages/evals/lib/AISdkClientWrapped.ts index 8747a43dc5..10aa019656 100644 --- a/packages/evals/lib/AISdkClientWrapped.ts +++ b/packages/evals/lib/AISdkClientWrapped.ts @@ -2,6 +2,7 @@ import { ModelMessage, ImagePart, NoObjectGeneratedError, + Output, TextPart, ToolSet, Tool, @@ -20,7 +21,7 @@ import { } from "@browserbasehq/stagehand"; // Wrap AI SDK functions with Braintrust for tracing -const { generateObject, generateText } = wrapAISDK(ai); +const { generateText } = wrapAISDK(ai); export class AISdkClientWrapped extends LLMClient { public type = "aisdk" as const; @@ -161,7 +162,6 @@ export class AISdkClientWrapped extends LLMClient { }, ); - let objectResponse: Awaited>; const isGPT5 = this.model.modelId.includes("gpt-5"); const isCodex = this.model.modelId.includes("codex"); const isDeepSeek = this.model.modelId.includes("deepseek"); @@ -175,6 +175,7 @@ export class AISdkClientWrapped extends LLMClient { const resolvedReasoningEffort = userReasoningEffort ?? (isGPT5SubModel ? "none" : undefined); if (options.response_model) { + let objectResponse: Awaited>; if (isDeepSeek || isKimi) { const parsedSchema = JSON.stringify( toJsonSchema(options.response_model.schema), @@ -188,11 +189,18 @@ You must respond in JSON format. respond WITH JSON. Do not include any other tex } try { - objectResponse = await generateObject({ + objectResponse = await generateText({ model: this.model, messages: formattedMessages, - schema: options.response_model.schema, + output: Output.object({ + schema: options.response_model.schema, + name: options.response_model.name, + }), temperature, + maxOutputTokens: options.maxOutputTokens, + topP: options.top_p, + frequencyPenalty: options.frequency_penalty, + presencePenalty: options.presence_penalty, providerOptions: resolvedReasoningEffort ? { openai: { @@ -244,7 +252,7 @@ You must respond in JSON format. respond WITH JSON. Do not include any other tex } const result = { - data: objectResponse.object, + data: objectResponse.output, usage: { prompt_tokens: objectResponse.usage.inputTokens ?? 0, completion_tokens: objectResponse.usage.outputTokens ?? 0, @@ -261,7 +269,7 @@ You must respond in JSON format. respond WITH JSON. Do not include any other tex auxiliary: { response: { value: JSON.stringify({ - object: objectResponse.object, + output: objectResponse.output, usage: objectResponse.usage, finishReason: objectResponse.finishReason, // Omit request and response properties that might contain images