Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions apps/web/src/lib/ai-gateway/api-request-log-errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,25 @@ import type Anthropic from '@anthropic-ai/sdk';
import { createParser } from 'eventsource-parser';
import type { GatewayRequest } from '@/lib/ai-gateway/providers/openrouter/types';

export const toolCallArgumentErrorSchema = z.object({
tool_call_id: z.string(),
tool_name: z.string(),
kind: z.enum(['unparseable_json', 'schema_mismatch']),
details: z.string().optional(),
});
export const toolCallArgumentErrorSchema = z.discriminatedUnion('kind', [
z.object({
tool_call_id: z.string(),
tool_name: z.string(),
kind: z.literal('unparseable_json'),
details: z.string(),
}),
z.object({
tool_call_id: z.string(),
tool_name: z.string(),
kind: z.literal('schema_mismatch'),
details: z.unknown(),
}),
z.object({
tool_call_id: z.string(),
tool_name: z.string(),
kind: z.literal('unknown_tool'),
}),
]);

export const apiRequestLogErrorSchema = z.object({
invalid_tool_call_arguments: z.array(toolCallArgumentErrorSchema),
Expand All @@ -19,6 +32,21 @@ export type ApiRequestLogError = z.infer<typeof apiRequestLogErrorSchema>;

type ToolCallError = z.infer<typeof toolCallArgumentErrorSchema>;

function checkKnownTool(
knownToolNames: Set<string>,
toolCallId: string,
toolName: string,
errors: ToolCallError[]
): boolean {
if (knownToolNames.has(toolName)) return true;
errors.push({
tool_call_id: toolCallId,
tool_name: toolName,
kind: 'unknown_tool',
});
return false;
}

function validateAgainstSchema(
parsedArgs: unknown,
parameters: unknown,
Expand All @@ -40,7 +68,7 @@ function validateAgainstSchema(
tool_call_id: toolCallId,
tool_name: toolName,
kind: 'schema_mismatch',
details: result.error.message,
details: z.treeifyError(result.error),
});
}
}
Expand Down Expand Up @@ -86,8 +114,10 @@ function detectChatCompletionSseErrors(
tools: OpenAI.Chat.ChatCompletionTool[] | null | undefined
): ToolCallError[] {
const toolSchemaByName = new Map<string, unknown>();
const knownToolNames = new Set<string>();
for (const tool of tools ?? []) {
if (tool.type === 'function') {
knownToolNames.add(tool.function.name);
toolSchemaByName.set(tool.function.name, tool.function.parameters);
}
}
Expand All @@ -109,6 +139,7 @@ function detectChatCompletionSseErrors(
const errors: ToolCallError[] = [];
for (const [, acc] of byIndex) {
if (!acc.name) continue;
if (!checkKnownTool(knownToolNames, acc.id, acc.name, errors)) continue;
const result = parseArgsString(acc.arguments, acc.id, acc.name, errors);
if (result.ok) {
validateAgainstSchema(
Expand All @@ -128,8 +159,10 @@ function detectResponsesSseErrors(
tools: OpenAI.Responses.ResponseCreateParams['tools']
): ToolCallError[] {
const toolSchemaByName = new Map<string, unknown>();
const knownToolNames = new Set<string>();
for (const tool of tools ?? []) {
if (tool.type === 'function') {
knownToolNames.add(tool.name);
toolSchemaByName.set(tool.name, tool.parameters);
}
}
Expand All @@ -143,6 +176,7 @@ function detectResponsesSseErrors(
const callId: string = event.item.call_id;
const name: string = event.item.name;
const argsStr: string = event.item.arguments;
if (!checkKnownTool(knownToolNames, callId, name, errors)) continue;
const result = parseArgsString(argsStr, callId, name, errors);
if (result.ok) {
validateAgainstSchema(result.parsed, toolSchemaByName.get(name), callId, name, errors);
Expand All @@ -156,7 +190,9 @@ function detectMessagesSseErrors(
tools: Anthropic.MessageCreateParams['tools']
): ToolCallError[] {
const toolSchemaByName = new Map<string, unknown>();
const knownToolNames = new Set<string>();
for (const tool of tools ?? []) {
knownToolNames.add(tool.name);
// Anthropic.Tool has input_schema; server tools (BashTool, TextEditorTool, etc.) do not
if ('input_schema' in tool) {
toolSchemaByName.set(tool.name, tool.input_schema);
Expand All @@ -179,6 +215,7 @@ function detectMessagesSseErrors(

const errors: ToolCallError[] = [];
for (const [, acc] of byIndex) {
if (!checkKnownTool(knownToolNames, acc.id, acc.name, errors)) continue;
const result = parseArgsString(acc.arguments, acc.id, acc.name, errors);
if (result.ok) {
// acc.arguments is accumulated JSON — validate against tool schema
Expand Down
Loading