diff --git a/README.md b/README.md index e5b390991..eccaa79d9 100644 --- a/README.md +++ b/README.md @@ -184,11 +184,12 @@ The server exposes several endpoints to interact with the Copilot API. It provid These endpoints mimic the OpenAI API structure. -| Endpoint | Method | Description | -| --------------------------- | ------ | --------------------------------------------------------- | -| `POST /v1/chat/completions` | `POST` | Creates a model response for the given chat conversation. | -| `GET /v1/models` | `GET` | Lists the currently available models. | -| `POST /v1/embeddings` | `POST` | Creates an embedding vector representing the input text. | +| Endpoint | Method | Description | +| --------------------------- | ------ | ---------------------------------------------------------------- | +| `POST /v1/responses` | `POST` | Most advanced interface for generating model responses. | +| `POST /v1/chat/completions` | `POST` | Creates a model response for the given chat conversation. | +| `GET /v1/models` | `GET` | Lists the currently available models. | +| `POST /v1/embeddings` | `POST` | Creates an embedding vector representing the input text. | ### Anthropic Compatible Endpoints @@ -304,7 +305,16 @@ Here is an example `.claude/settings.json` file: "ANTHROPIC_BASE_URL": "http://localhost:4141", "ANTHROPIC_AUTH_TOKEN": "dummy", "ANTHROPIC_MODEL": "gpt-4.1", - "ANTHROPIC_SMALL_FAST_MODEL": "gpt-4.1" + "ANTHROPIC_DEFAULT_SONNET_MODEL": "gpt-4.1", + "ANTHROPIC_SMALL_FAST_MODEL": "gpt-4.1", + "ANTHROPIC_DEFAULT_HAIKU_MODEL": "gpt-4.1", + "DISABLE_NON_ESSENTIAL_MODEL_CALLS": "1", + "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1" + }, + "permissions": { + "deny": [ + "WebSearch" + ] } } ``` diff --git a/src/lib/debug-logger.ts b/src/lib/debug-logger.ts new file mode 100644 index 000000000..e51c723ae --- /dev/null +++ b/src/lib/debug-logger.ts @@ -0,0 +1,179 @@ +import { existsSync, mkdirSync } from "node:fs" +import { writeFile } from "node:fs/promises" +import { join } from "node:path" + +import type { GeminiRequest } from "~/routes/generate-content/types" +import type { + ChatCompletionsPayload, + ChatCompletionResponse, +} from "~/services/copilot/create-chat-completions" + +interface DebugLogData { + timestamp: string + requestId: string + originalGeminiPayload: GeminiRequest + translatedOpenAIPayload: ChatCompletionsPayload | null + error?: string + processingTime?: number +} + +export class DebugLogger { + private static instance: DebugLogger | undefined + private logDir: string + + private constructor() { + this.logDir = process.env.DEBUG_LOG_DIR || join(process.cwd(), "debug-logs") + this.ensureLogDir() + } + + static getInstance(): DebugLogger { + if (!DebugLogger.instance) { + DebugLogger.instance = new DebugLogger() + } + return DebugLogger.instance + } + + private ensureLogDir(): void { + if (!existsSync(this.logDir)) { + mkdirSync(this.logDir, { recursive: true }) + } + } + + private generateLogFileName(requestId: string): string { + const timestamp = new Date().toISOString().replaceAll(/[:.]/g, "-") + return join(this.logDir, `debug-gemini-${timestamp}-${requestId}.log`) + } + + async logRequest(data: { + requestId: string + geminiPayload: GeminiRequest + openAIPayload?: ChatCompletionsPayload | null + error?: string + processingTime?: number + }): Promise { + const logData: DebugLogData = { + timestamp: new Date().toISOString(), + requestId: data.requestId, + originalGeminiPayload: data.geminiPayload, + translatedOpenAIPayload: data.openAIPayload ?? null, + error: data.error, + processingTime: data.processingTime, + } + + const logPath = this.generateLogFileName(data.requestId) + + try { + await writeFile(logPath, JSON.stringify(logData, null, 2), "utf8") + console.log(`[DEBUG] Logged request data to: ${logPath}`) + } catch (writeError) { + console.error(`[DEBUG] Failed to write log file ${logPath}:`, writeError) + } + } + + // For backward compatibility during development + static async logGeminiRequest( + geminiPayload: GeminiRequest, + openAIPayload?: ChatCompletionsPayload, + error?: string, + ): Promise { + const logger = DebugLogger.getInstance() + const requestId = Math.random().toString(36).slice(2, 8) + await logger.logRequest({ requestId, geminiPayload, openAIPayload, error }) + } + + // Log GitHub Copilot API Response + static async logCopilotResponse( + response: ChatCompletionResponse, + context?: string, + ): Promise { + const logger = DebugLogger.getInstance() + const requestId = Math.random().toString(36).slice(2, 8) + const timestamp = new Date().toISOString().replaceAll(/[:.]/g, "-") + const logPath = join( + logger.logDir, + `debug-copilot-response-${timestamp}-${requestId}.log`, + ) + + const logData = { + timestamp: new Date().toISOString(), + context: context || "GitHub Copilot API Response", + response, + } + + try { + await writeFile(logPath, JSON.stringify(logData, null, 2), "utf8") + console.log(`[DEBUG] Logged Copilot response to: ${logPath}`) + } catch (writeError) { + console.error( + `[DEBUG] Failed to write Copilot response log file ${logPath}:`, + writeError, + ) + } + } + + // Log any object for debugging purposes + static async logDebugData( + data: unknown, + context: string, + filePrefix = "debug-data", + ): Promise { + const logger = DebugLogger.getInstance() + const requestId = Math.random().toString(36).slice(2, 8) + const timestamp = new Date().toISOString().replaceAll(/[:.]/g, "-") + const logPath = join( + logger.logDir, + `${filePrefix}-${timestamp}-${requestId}.log`, + ) + + const logData = { + timestamp: new Date().toISOString(), + context, + data, + } + + try { + await writeFile(logPath, JSON.stringify(logData, null, 2), "utf8") + console.log(`[DEBUG] Logged ${context} to: ${logPath}`) + } catch (writeError) { + console.error( + `[DEBUG] Failed to write debug log file ${logPath}:`, + writeError, + ) + } + } + + // Log original and translated response comparison + static async logResponseComparison( + originalResponse: unknown, + translatedResponse: unknown, + options: { context: string; filePrefix?: string } = { + context: "Response Comparison", + }, + ): Promise { + const { context, filePrefix = "debug-comparison" } = options + const logger = DebugLogger.getInstance() + const requestId = Math.random().toString(36).slice(2, 8) + const timestamp = new Date().toISOString().replaceAll(/[:.]/g, "-") + const logPath = join( + logger.logDir, + `${filePrefix}-${timestamp}-${requestId}.log`, + ) + + const logData = { + timestamp: new Date().toISOString(), + context, + originalResponse, + translatedResponse, + } + + try { + await writeFile(logPath, JSON.stringify(logData, null, 2), "utf8") + console.log(`[DEBUG] Logged ${context} comparison to: ${logPath}`) + } catch (writeError) { + console.error( + `[DEBUG] Failed to write comparison log file ${logPath}:`, + writeError, + ) + } + } +} diff --git a/src/lib/tokenizer.ts b/src/lib/tokenizer.ts index 73cd499f9..8c3eda736 100644 --- a/src/lib/tokenizer.ts +++ b/src/lib/tokenizer.ts @@ -1,37 +1,345 @@ -import { countTokens } from "gpt-tokenizer/model/gpt-4o" +import type { + ChatCompletionsPayload, + ContentPart, + Message, + Tool, + ToolCall, +} from "~/services/copilot/create-chat-completions" +import type { Model } from "~/services/copilot/get-models" -import type { Message } from "~/services/copilot/create-chat-completions" +// Encoder type mapping +const ENCODING_MAP = { + o200k_base: () => import("gpt-tokenizer/encoding/o200k_base"), + cl100k_base: () => import("gpt-tokenizer/encoding/cl100k_base"), + p50k_base: () => import("gpt-tokenizer/encoding/p50k_base"), + p50k_edit: () => import("gpt-tokenizer/encoding/p50k_edit"), + r50k_base: () => import("gpt-tokenizer/encoding/r50k_base"), +} as const -export const getTokenCount = (messages: Array) => { - const simplifiedMessages = messages.map((message) => { - let content = "" - if (typeof message.content === "string") { - content = message.content - } else if (Array.isArray(message.content)) { - content = message.content - .filter((part) => part.type === "text") - .map((part) => (part as { text: string }).text) - .join("") +type SupportedEncoding = keyof typeof ENCODING_MAP + +// Define encoder interface +interface Encoder { + encode: (text: string) => Array +} + +// Cache loaded encoders to avoid repeated imports +const encodingCache = new Map() + +/** + * Calculate tokens for tool calls + */ +const calculateToolCallsTokens = ( + toolCalls: Array, + encoder: Encoder, + constants: ReturnType, +): number => { + let tokens = 0 + for (const toolCall of toolCalls) { + tokens += constants.funcInit + tokens += encoder.encode(JSON.stringify(toolCall)).length + } + tokens += constants.funcEnd + return tokens +} + +/** + * Calculate tokens for content parts + */ +const calculateContentPartsTokens = ( + contentParts: Array, + encoder: Encoder, +): number => { + let tokens = 0 + for (const part of contentParts) { + if (part.type === "image_url") { + tokens += encoder.encode(part.image_url.url).length + 85 + } else if (part.text) { + tokens += encoder.encode(part.text).length + } + } + return tokens +} + +/** + * Calculate tokens for a single message + */ +const calculateMessageTokens = ( + message: Message, + encoder: Encoder, + constants: ReturnType, +): number => { + const tokensPerMessage = 3 + const tokensPerName = 1 + let tokens = tokensPerMessage + for (const [key, value] of Object.entries(message)) { + if (typeof value === "string") { + tokens += encoder.encode(value).length + } + if (key === "name") { + tokens += tokensPerName + } + if (key === "tool_calls") { + tokens += calculateToolCallsTokens( + value as Array, + encoder, + constants, + ) + } + if (key === "content" && Array.isArray(value)) { + tokens += calculateContentPartsTokens( + value as Array, + encoder, + ) + } + } + return tokens +} + +/** + * Calculate tokens using custom algorithm + */ +const calculateTokens = ( + messages: Array, + encoder: Encoder, + constants: ReturnType, +): number => { + if (messages.length === 0) { + return 0 + } + let numTokens = 0 + for (const message of messages) { + numTokens += calculateMessageTokens(message, encoder, constants) + } + // every reply is primed with <|start|>assistant<|message|> + numTokens += 3 + return numTokens +} + +/** + * Get the corresponding encoder module based on encoding type + */ +const getEncodeChatFunction = async (encoding: string): Promise => { + if (encodingCache.has(encoding)) { + const cached = encodingCache.get(encoding) + if (cached) { + return cached } - return { ...message, content } - }) + } + + const supportedEncoding = encoding as SupportedEncoding + if (!(supportedEncoding in ENCODING_MAP)) { + const fallbackModule = (await ENCODING_MAP.o200k_base()) as Encoder + encodingCache.set(encoding, fallbackModule) + return fallbackModule + } + + const encodingModule = (await ENCODING_MAP[supportedEncoding]()) as Encoder + encodingCache.set(encoding, encodingModule) + return encodingModule +} + +/** + * Get tokenizer type from model information + */ +export const getTokenizerFromModel = (model: Model): string => { + return model.capabilities.tokenizer || "o200k_base" +} + +/** + * Get model-specific constants for token calculation + */ +const getModelConstants = (model: Model) => { + return model.id === "gpt-3.5-turbo" || model.id === "gpt-4" ? + { + funcInit: 10, + propInit: 3, + propKey: 3, + enumInit: -3, + enumItem: 3, + funcEnd: 12, + } + : { + funcInit: 7, + propInit: 3, + propKey: 3, + enumInit: -3, + enumItem: 3, + funcEnd: 12, + } +} - let inputMessages = simplifiedMessages.filter((message) => { - return message.role !== "tool" - }) - let outputMessages: typeof simplifiedMessages = [] +/** + * Calculate tokens for a single parameter + */ +const calculateParameterTokens = ( + key: string, + prop: unknown, + context: { + encoder: Encoder + constants: ReturnType + }, +): number => { + const { encoder, constants } = context + let tokens = constants.propKey - const lastMessage = simplifiedMessages.at(-1) + // Early return if prop is not an object + if (typeof prop !== "object" || prop === null) { + return tokens + } - if (lastMessage?.role === "assistant") { - inputMessages = simplifiedMessages.slice(0, -1) - outputMessages = [lastMessage] + // Type assertion for parameter properties + const param = prop as { + type?: string + description?: string + enum?: Array + [key: string]: unknown } - // @ts-expect-error TS can't infer from arr.filter() - const inputTokens = countTokens(inputMessages) - // @ts-expect-error TS can't infer from arr.filter() - const outputTokens = countTokens(outputMessages) + const paramName = key + const paramType = param.type || "string" + let paramDesc = param.description || "" + + // Handle enum values + if (param.enum && Array.isArray(param.enum)) { + tokens += constants.enumInit + for (const item of param.enum) { + tokens += constants.enumItem + tokens += encoder.encode(String(item)).length + } + } + + // Clean up description + if (paramDesc.endsWith(".")) { + paramDesc = paramDesc.slice(0, -1) + } + + // Encode the main parameter line + const line = `${paramName}:${paramType}:${paramDesc}` + tokens += encoder.encode(line).length + + // Handle additional properties (excluding standard ones) + const excludedKeys = new Set(["type", "description", "enum"]) + for (const propertyName of Object.keys(param)) { + if (!excludedKeys.has(propertyName)) { + const propertyValue = param[propertyName] + const propertyText = + typeof propertyValue === "string" ? propertyValue : ( + JSON.stringify(propertyValue) + ) + tokens += encoder.encode(`${propertyName}:${propertyText}`).length + } + } + + return tokens +} + +/** + * Calculate tokens for function parameters + */ +const calculateParametersTokens = ( + parameters: unknown, + encoder: Encoder, + constants: ReturnType, +): number => { + if (!parameters || typeof parameters !== "object") { + return 0 + } + + const params = parameters as Record + let tokens = 0 + + for (const [key, value] of Object.entries(params)) { + if (key === "properties") { + const properties = value as Record + if (Object.keys(properties).length > 0) { + tokens += constants.propInit + for (const propKey of Object.keys(properties)) { + tokens += calculateParameterTokens(propKey, properties[propKey], { + encoder, + constants, + }) + } + } + } else { + const paramText = + typeof value === "string" ? value : JSON.stringify(value) + tokens += encoder.encode(`${key}:${paramText}`).length + } + } + + return tokens +} + +/** + * Calculate tokens for a single tool + */ +const calculateToolTokens = ( + tool: Tool, + encoder: Encoder, + constants: ReturnType, +): number => { + let tokens = constants.funcInit + const func = tool.function + const fName = func.name + let fDesc = func.description || "" + if (fDesc.endsWith(".")) { + fDesc = fDesc.slice(0, -1) + } + const line = fName + ":" + fDesc + tokens += encoder.encode(line).length + if ( + typeof func.parameters === "object" // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + && func.parameters !== null + ) { + tokens += calculateParametersTokens(func.parameters, encoder, constants) + } + return tokens +} + +/** + * Calculate token count for tools based on model + */ +export const numTokensForTools = ( + tools: Array, + encoder: Encoder, + constants: ReturnType, +): number => { + let funcTokenCount = 0 + for (const tool of tools) { + funcTokenCount += calculateToolTokens(tool, encoder, constants) + } + funcTokenCount += constants.funcEnd + return funcTokenCount +} + +/** + * Calculate the token count of messages, supporting multiple GPT encoders + */ +export const getTokenCount = async ( + payload: ChatCompletionsPayload, + model: Model, +): Promise<{ input: number; output: number }> => { + // Get tokenizer string + const tokenizer = getTokenizerFromModel(model) + + // Get corresponding encoder module + const encoder = await getEncodeChatFunction(tokenizer) + + const simplifiedMessages = payload.messages + const inputMessages = simplifiedMessages.filter( + (msg) => msg.role !== "assistant", + ) + const outputMessages = simplifiedMessages.filter( + (msg) => msg.role === "assistant", + ) + + const constants = getModelConstants(model) + let inputTokens = calculateTokens(inputMessages, encoder, constants) + if (payload.tools && payload.tools.length > 0) { + inputTokens += numTokensForTools(payload.tools, encoder, constants) + } + const outputTokens = calculateTokens(outputMessages, encoder, constants) return { input: inputTokens, diff --git a/src/lib/tool-call-utils.ts b/src/lib/tool-call-utils.ts new file mode 100644 index 000000000..c17117cf9 --- /dev/null +++ b/src/lib/tool-call-utils.ts @@ -0,0 +1,287 @@ +import type { + GeminiTool, + GeminiRequest, + GeminiContent, + GeminiPart, +} from "~/routes/generate-content/types" +import type { Tool } from "~/services/copilot/create-chat-completions" + +// Tool declaration generation - moved from translation.ts +export function translateGeminiToolsToOpenAI( + geminiTools?: Array, +): Array | undefined { + if (!geminiTools || geminiTools.length === 0) return undefined + + const tools: Array = [] + for (const tool of geminiTools) { + // Handle standard function declarations + if (tool.functionDeclarations) { + for (const func of tool.functionDeclarations) { + // Validate that function name exists and is not empty + if ( + !func.name + || typeof func.name !== "string" + || func.name.trim() === "" + ) { + continue + } + + // Ensure parameters is always a valid object + const validParameters = func.parametersJsonSchema + || func.parameters || { type: "object", properties: {} } + + tools.push({ + type: "function", + function: { + name: func.name, + description: func.description, + parameters: validParameters, + }, + }) + } + } + + // Handle googleSearch tool (special case) + if (tool.googleSearch !== undefined) { + tools.push({ + type: "function", + function: { + name: "google_web_search", + description: + "Performs a web search using Google Search (via the Gemini API) and returns the results. This tool is useful for finding information on the internet based on a query.", + parameters: { + type: "object", + properties: { + query: { + type: "string", + description: "The search query to find information on the web.", + }, + }, + required: ["query"], + }, + }, + }) + } + + // Handle urlContext tool (special case for web_fetch) + // Note: GitHub Copilot API doesn't support web_fetch functionality + // Skip this tool to avoid "Failed to create chat completions" errors + if (tool.urlContext !== undefined) { + continue + } + } + + return tools.length > 0 ? tools : undefined +} + +// Tool configuration translation - moved from translation.ts +export function translateGeminiToolConfigToOpenAI( + toolConfig?: GeminiRequest["toolConfig"], +): "auto" | "required" | "none" | undefined { + if (!toolConfig) return undefined + + const mode = toolConfig.functionCallingConfig.mode + switch (mode) { + case "AUTO": { + return "auto" + } + case "ANY": { + return "required" + } + case "NONE": { + return "none" + } + default: { + return undefined + } + } +} + +// Utility function to generate unique tool call IDs - moved from translation.ts +// Generate IDs within 40 character limit (API constraint) +export function generateToolCallId(_functionName: string): string { + const timestamp = Date.now().toString(36) // Base36 for shorter encoding + const random = Math.random().toString(36).slice(2, 8) // 6 chars random + return `call_${timestamp}_${random}` // Format: call_{timestamp}_{random} +} + +// Helper function to try parsing and creating a function call - moved from translation.ts +// NOTE: Used internally by ToolCallAccumulator.handleToolCallWithName() and handleToolCallAccumulation() +// knip may report this as unused, but it's called within this module's class methods +export function tryCreateFunctionCall( + name: string, + argumentsStr: string, +): GeminiPart | null { + try { + const args = JSON.parse(argumentsStr) as Record + return { + functionCall: { + name, + args, + }, + } + } catch { + return null + } +} + +// Tool synthesis from contents - moved from translation.ts +export function synthesizeToolsFromContents( + contents: Array< + | GeminiContent + | Array<{ + functionResponse: { id?: string; name: string; response: unknown } + }> + >, +): Array | undefined { + const names = new Set() + for (const item of contents) { + if (Array.isArray(item)) continue + for (const part of item.parts) { + if ("functionCall" in part && part.functionCall.name) { + names.add(part.functionCall.name) + } + } + } + if (names.size === 0) return undefined + return Array.from(names).map((name) => ({ + type: "function", + function: { name, parameters: { type: "object", properties: {} } }, + })) +} + +/** + * Tool call state manager for incremental parameter accumulation in streaming responses + */ +export class ToolCallAccumulator { + private accumulator = new Map< + number, + { + name: string + arguments: string + id?: string + } + >() + + /** + * Handle tool call with function name (start of new tool call) + */ + handleToolCallWithName(toolCall: { + index: number + id?: string + function: { + name: string + arguments?: string + } + }): GeminiPart | null { + const accumulatedArgs = toolCall.function.arguments || "" + + this.accumulator.set(toolCall.index, { + name: toolCall.function.name, + arguments: accumulatedArgs, + id: toolCall.id, + }) + + // If we already have arguments, try to process immediately (for non-streaming models like Gemini) + if (accumulatedArgs) { + const functionCall = tryCreateFunctionCall( + toolCall.function.name, + accumulatedArgs, + ) + if (functionCall) { + // Clear the accumulator for this index since we've successfully processed it + this.accumulator.delete(toolCall.index) + return functionCall + } + } + + return null + } + + /** + * Handle tool call parameter accumulation (append argument fragments) + */ + handleToolCallAccumulation(toolCall: { + index: number + function?: { + arguments?: string + } + }): GeminiPart | null { + const existingAccumulated = this.accumulator.get(toolCall.index) + + if (existingAccumulated && toolCall.function?.arguments) { + existingAccumulated.arguments += toolCall.function.arguments + + const functionCall = tryCreateFunctionCall( + existingAccumulated.name, + existingAccumulated.arguments, + ) + if (functionCall) { + // Clear the accumulator for this index since we've successfully processed it + this.accumulator.delete(toolCall.index) + return functionCall + } + } + + return null + } + + /** + * Clear all accumulated state (for stream end or error reset) + */ + clear(): void { + this.accumulator.clear() + } +} + +/** + * Process tool calls array and generate Gemini format parts + * Supports both complete parameters and fragmented parameters modes + */ +export function processToolCalls( + toolCalls: Array<{ + index: number + id?: string + type?: "function" + function?: { + name?: string + arguments?: string + } + }>, + accumulator: ToolCallAccumulator, +): Array { + const parts: Array = [] + + for (const toolCall of toolCalls) { + // Debug: Log streaming tool call arguments to verify what GitHub Copilot returns + if (process.env.DEBUG_GEMINI_REQUESTS === "true") { + console.log( + `[DEBUG STREAM] Tool call - name: ${toolCall.function?.name}, arguments: "${toolCall.function?.arguments}", type: ${typeof toolCall.function?.arguments}, truthy: ${Boolean(toolCall.function?.arguments)}`, + ) + } + + // If this chunk has a function name, it's the start of a new tool call + if (toolCall.function?.name && toolCall.function.name.trim() !== "") { + const functionCall = accumulator.handleToolCallWithName({ + index: toolCall.index, + id: toolCall.id, + function: { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }, + }) + if (functionCall) { + parts.push(functionCall) + } + continue + } + + // If we have existing accumulated data and this chunk has arguments, append them + const functionCall = accumulator.handleToolCallAccumulation(toolCall) + if (functionCall) { + parts.push(functionCall) + } + } + + return parts +} diff --git a/src/routes/chat-completions/handler.ts b/src/routes/chat-completions/handler.ts index 6e49029b8..04a5ae9ed 100644 --- a/src/routes/chat-completions/handler.ts +++ b/src/routes/chat-completions/handler.ts @@ -20,15 +20,26 @@ export async function handleCompletion(c: Context) { let payload = await c.req.json() consola.debug("Request payload:", JSON.stringify(payload).slice(-400)) - consola.info("Current token count:", getTokenCount(payload.messages)) + // Find the selected model + const selectedModel = state.models?.data.find( + (model) => model.id === payload.model, + ) + + // Calculate and display token count + try { + if (selectedModel) { + const tokenCount = await getTokenCount(payload, selectedModel) + consola.info("Current token count:", tokenCount) + } else { + consola.warn("No model selected, skipping token count calculation") + } + } catch (error) { + consola.warn("Failed to calculate token count:", error) + } if (state.manualApprove) await awaitApproval() if (isNullish(payload.max_tokens)) { - const selectedModel = state.models?.data.find( - (model) => model.id === payload.model, - ) - payload = { ...payload, max_tokens: selectedModel?.capabilities.limits.max_output_tokens, diff --git a/src/routes/generate-content/handler.ts b/src/routes/generate-content/handler.ts new file mode 100644 index 000000000..26c3c3a43 --- /dev/null +++ b/src/routes/generate-content/handler.ts @@ -0,0 +1,322 @@ +import type { Context } from "hono" +import type { SSEStreamingApi } from "hono/streaming" + +import { streamSSE } from "hono/streaming" + +import { awaitApproval } from "~/lib/approval" +import { DebugLogger } from "~/lib/debug-logger" +import { checkRateLimit } from "~/lib/rate-limit" +import { state } from "~/lib/state" +import { getTokenCount } from "~/lib/tokenizer" +import { + createChatCompletions, + type ChatCompletionResponse, + type ChatCompletionChunk, +} from "~/services/copilot/create-chat-completions" + +// Helper function to extract model from URL path +function extractModelFromUrl(url: string): string { + const match = url.match(/\/v1beta\/models\/([^:]+):/) + if (!match) { + throw new Error("Model name is required in URL path") + } + return match[1] +} + +import { ToolCallAccumulator } from "~/lib/tool-call-utils" + +import { + translateGeminiToOpenAI, + translateOpenAIToGemini, + translateGeminiCountTokensToOpenAI, + translateTokenCountToGemini, + translateOpenAIChunkToGemini, +} from "./translation" +import { + type GeminiRequest, + type GeminiCountTokensRequest, + type GeminiStreamResponse, + type GeminiResponse, +} from "./types" + +// Unified generation handler following Claude's two-branch pattern +export async function handleGeminiGeneration( + c: Context, + stream: boolean = false, +) { + const model = extractModelFromUrl(c.req.url) + + if (!model) { + throw new Error("Model name is required in URL path") + } + + await checkRateLimit(state) + + const geminiPayload = await c.req.json() + const openAIPayload = translateGeminiToOpenAI(geminiPayload, model, stream) + + // Log request for debugging (async, non-blocking) - only if debug logging is enabled + if (process.env.DEBUG_GEMINI_REQUESTS === "true") { + DebugLogger.logGeminiRequest(geminiPayload, openAIPayload).catch( + (error: unknown) => { + console.error("[DEBUG] Failed to log request:", error) + }, + ) + } + + if (state.manualApprove) { + await awaitApproval() + } + + const response = await createChatCompletions(openAIPayload) + + if (isNonStreaming(response)) { + const geminiResponse = translateOpenAIToGemini(response) + + if (stream) { + return handleNonStreamingToStreaming(c, geminiResponse) + } + return c.json(geminiResponse) + } + + if (!stream) { + throw new Error("Unexpected streaming response for non-streaming endpoint") + } + + return handleStreamingResponse(c, response) +} + +// Helper function to handle non-streaming response conversion +function handleNonStreamingToStreaming( + c: Context, + geminiResponse: GeminiResponse, +) { + return streamSSE(c, async (stream) => { + try { + const firstPart = geminiResponse.candidates[0]?.content?.parts?.[0] + // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + const hasTextContent = firstPart && "text" in firstPart + + // eslint-disable-next-line unicorn/prefer-ternary + if (hasTextContent) { + await sendTextInChunks(stream, firstPart.text, geminiResponse) + } else { + await sendFallbackResponse(stream, geminiResponse) + } + + // Add a small delay to ensure all data is flushed + await new Promise((resolve) => setTimeout(resolve, 50)) + } catch (error) { + console.error("[GEMINI_STREAM] Error in non-streaming conversion", error) + } finally { + try { + await stream.close() + } catch (closeError) { + console.error( + "[GEMINI_STREAM] Error closing non-streaming conversion stream", + closeError, + ) + } + } + }) +} + +// Helper function to send text in chunks with configuration object +async function sendTextInChunks( + stream: SSEStreamingApi, + text: string, + geminiResponse: GeminiResponse, +) { + const chunkSize = Math.max(1, Math.min(50, text.length)) + let lastWritePromise: Promise = Promise.resolve() + + for (let i = 0; i < text.length; i += chunkSize) { + const chunk = text.slice(i, i + chunkSize) + const isLast = i + chunkSize >= text.length + const streamResponse: GeminiStreamResponse = { + candidates: [ + { + content: { + parts: [{ text: chunk }], + role: "model", + }, + finishReason: + isLast ? geminiResponse.candidates[0]?.finishReason : undefined, + index: 0, + }, + ], + ...(isLast && geminiResponse.usageMetadata ? + { usageMetadata: geminiResponse.usageMetadata } + : {}), + } + + // Wait for previous write to complete before writing new chunk + await lastWritePromise + lastWritePromise = stream.writeSSE({ + data: JSON.stringify(streamResponse), + }) + } + + // Wait for final write to complete + await lastWritePromise +} + +// Helper function to send fallback response +async function sendFallbackResponse( + stream: SSEStreamingApi, + geminiResponse: GeminiResponse, +) { + const streamResponse: GeminiStreamResponse = { + candidates: geminiResponse.candidates, + usageMetadata: geminiResponse.usageMetadata, + } + + await stream.writeSSE({ data: JSON.stringify(streamResponse) }) +} + +// Simplified Gemini streaming state (inspired by Claude AnthropicStreamState) +interface GeminiStreamState { + jsonAccumulator: string + parseMode: "direct" | "accumulated" +} + +// Minimal state machine for JSON parsing only +class GeminiStreamParser { + private state: GeminiStreamState = { + jsonAccumulator: "", + parseMode: "direct", + } + + parseChunk(rawData: string): unknown { + if (this.state.parseMode === "direct") { + try { + return JSON.parse(rawData) + } catch { + // Switch to accumulated mode on first failure + this.state.parseMode = "accumulated" + this.state.jsonAccumulator = rawData + return null + } + } else { + // Accumulated mode - keep building until valid JSON + this.state.jsonAccumulator += rawData + try { + const result = JSON.parse(this.state.jsonAccumulator) as unknown + // Success - reset for next chunk + this.resetAccumulator() + return result + } catch { + // Continue accumulating + return null + } + } + } + + private resetAccumulator(): void { + this.state.jsonAccumulator = "" + this.state.parseMode = "direct" + } +} + +// Helper function to handle streaming response processing +function handleStreamingResponse( + c: Context, + response: AsyncIterable<{ data?: string }>, +) { + return streamSSE(c, async (stream) => { + // Create a parser instance for this stream (each request gets its own parser) + const streamParser = new GeminiStreamParser() + // Create a tool call accumulator for this stream + const toolCallAccumulator = new ToolCallAccumulator() + let lastWritePromise: Promise = Promise.resolve() + + try { + for await (const rawEvent of response) { + if (rawEvent.data === "[DONE]") { + break + } + + // Inline processing without extra wrapper + if (!rawEvent.data) { + continue + } + + try { + const chunk = streamParser.parseChunk(rawEvent.data) + if (!chunk) { + continue + } + + const geminiChunk = translateOpenAIChunkToGemini( + chunk as ChatCompletionChunk, + toolCallAccumulator, + ) + if (geminiChunk) { + // Wait for previous write to complete before writing new chunk + await lastWritePromise + lastWritePromise = stream.writeSSE({ + data: JSON.stringify(geminiChunk), + }) + } + } catch (parseError) { + console.error("[GEMINI_STREAM] Error parsing chunk", parseError) + continue + } + } + + // Wait for all writes to complete before closing + await lastWritePromise + + // Add a small delay to ensure all data is flushed + await new Promise((resolve) => setTimeout(resolve, 50)) + } catch (error) { + console.error("[GEMINI_STREAM] Error in streaming processing", error) + // Ensure we don't leave the stream hanging + } finally { + // Always close the stream, but with proper cleanup + try { + await stream.close() + } catch (closeError) { + console.error("[GEMINI_STREAM] Error closing stream", closeError) + } + } + }) +} + +// Create convenience wrapper for streaming generation +export function handleGeminiStreamGeneration(c: Context) { + return handleGeminiGeneration(c, true) +} + +// Token counting endpoint +export async function handleGeminiCountTokens(c: Context) { + const model = extractModelFromUrl(c.req.url) + + if (!model) { + throw new Error("Model name is required in URL path") + } + + const geminiPayload = await c.req.json() + + const openAIPayload = translateGeminiCountTokensToOpenAI(geminiPayload, model) + + // Find the full Model object from state + const selectedModel = state.models?.data.find((m) => m.id === model) + + if (!selectedModel) { + // Fallback: return minimal token count if model not found + const geminiResponse = translateTokenCountToGemini(10) + return c.json(geminiResponse) + } + + const tokenCounts = await getTokenCount(openAIPayload, selectedModel) + + const totalTokens = tokenCounts.input + tokenCounts.output + const geminiResponse = translateTokenCountToGemini(totalTokens) + + return c.json(geminiResponse) +} + +const isNonStreaming = ( + response: Awaited>, +): response is ChatCompletionResponse => "choices" in response diff --git a/src/routes/generate-content/route.ts b/src/routes/generate-content/route.ts new file mode 100644 index 000000000..64cb48a45 --- /dev/null +++ b/src/routes/generate-content/route.ts @@ -0,0 +1,67 @@ +import { Hono } from "hono" + +import { forwardError } from "~/lib/error" + +import { + handleGeminiGeneration, + handleGeminiStreamGeneration, + handleGeminiCountTokens, +} from "./handler" + +function isStreamGenerate(url: string): boolean { + return url.includes(":streamGenerateContent") +} +function isCountTokens(url: string): boolean { + return url.includes(":countTokens") +} +function isGenerate(url: string): boolean { + return ( + url.includes(":generateContent") && !url.includes(":streamGenerateContent") + ) +} + +const router = new Hono() + +// Streaming generation endpoint +// POST /v1beta/models/{model}:streamGenerateContent +router.post("/v1beta/models/*", async (c, next) => { + const url = c.req.url + if (isStreamGenerate(url)) { + try { + return await handleGeminiStreamGeneration(c) + } catch (error) { + return await forwardError(c, error) + } + } + await next() +}) + +// Token counting endpoint +// POST /v1beta/models/{model}:countTokens +router.post("/v1beta/models/*", async (c, next) => { + const url = c.req.url + if (isCountTokens(url)) { + try { + return await handleGeminiCountTokens(c) + } catch (error) { + return await forwardError(c, error) + } + } + await next() +}) + +// Standard generation endpoint +// POST /v1beta/models/{model}:generateContent +router.post("/v1beta/models/*", async (c, next) => { + const url = c.req.url + if (isGenerate(url)) { + try { + return await handleGeminiGeneration(c) + } catch (error) { + return await forwardError(c, error) + } + } + await next() +}) + +export { router as geminiRouter } diff --git a/src/routes/generate-content/translation.ts b/src/routes/generate-content/translation.ts new file mode 100644 index 000000000..8f4d70937 --- /dev/null +++ b/src/routes/generate-content/translation.ts @@ -0,0 +1,724 @@ +import { DebugLogger } from "~/lib/debug-logger" +import { + translateGeminiToolsToOpenAI, + translateGeminiToolConfigToOpenAI, + generateToolCallId, + synthesizeToolsFromContents, + ToolCallAccumulator, + processToolCalls as processToolCallsWithAccumulator, +} from "~/lib/tool-call-utils" +import { + type ChatCompletionResponse, + type ChatCompletionChunk, + type ChatCompletionsPayload, + type ContentPart, + type Message, + type Tool, + type ToolCall, +} from "~/services/copilot/create-chat-completions" + +import { + type GeminiRequest, + type GeminiResponse, + type GeminiContent, + type GeminiPart, + type GeminiTextPart, + type GeminiFunctionCallPart, + type GeminiFunctionResponsePart, + type GeminiTool, + type GeminiCandidate, + type GeminiCountTokensRequest, + type GeminiCountTokensResponse, + type GeminiUsageMetadata, +} from "./types" +import { mapOpenAIFinishReasonToGemini } from "./utils" + +// Model mapping for Gemini models - only map unsupported variants to supported ones +function mapGeminiModelToCopilot(geminiModel: string): string { + const modelMap: Record = { + "gemini-2.5-flash": "gemini-2.0-flash-001", // Map to supported Gemini model + "gemini-2.0-flash": "gemini-2.0-flash-001", // Map to full model name + "gemini-2.5-flash-lite": "gemini-2.0-flash-001", // Map to full model name + } + + return modelMap[geminiModel] || geminiModel // Return original if supported +} + +function selectTools( + geminiTools?: Array, + contents?: Array< + | GeminiContent + | Array<{ + functionResponse: { id?: string; name: string; response: unknown } + }> + >, +): Array | undefined { + return ( + translateGeminiToolsToOpenAI(geminiTools) + || (contents ? synthesizeToolsFromContents(contents) : undefined) + ) +} + +// Request translation: Gemini -> OpenAI + +export function translateGeminiToOpenAI( + payload: GeminiRequest, + model: string, + stream: boolean, +): ChatCompletionsPayload { + const tools = selectTools(payload.tools, payload.contents) + const result = { + model: mapGeminiModelToCopilot(model), + messages: translateGeminiContentsToOpenAI( + payload.contents, + payload.systemInstruction, + ), + max_tokens: (payload.generationConfig?.maxOutputTokens as number) || 4096, + stop: payload.generationConfig?.stopSequences as Array | undefined, + stream, + temperature: payload.generationConfig?.temperature as number | undefined, + top_p: payload.generationConfig?.topP as number | undefined, + tools, + tool_choice: + tools ? translateGeminiToolConfigToOpenAI(payload.toolConfig) : undefined, + } + + return result +} + +// Helper function to process function response arrays +function processFunctionResponseArray( + responseArray: Array<{ + functionResponse: { name: string; response: unknown } + }>, + pendingToolCalls: Map, + messages: Array, +): void { + for (const responseItem of responseArray) { + if ("functionResponse" in responseItem) { + const functionName = responseItem.functionResponse.name + // Find tool call ID by searching through the map + let matchedToolCallId: string | undefined + for (const [ + toolCallId, + mappedFunctionName, + ] of pendingToolCalls.entries()) { + if (mappedFunctionName === functionName) { + matchedToolCallId = toolCallId + break + } + } + if (matchedToolCallId) { + messages.push({ + role: "tool", + tool_call_id: matchedToolCallId, + content: JSON.stringify(responseItem.functionResponse.response), + }) + pendingToolCalls.delete(matchedToolCallId) + } + } + } +} + +// Helper function to check if tool calls have corresponding tool responses +function hasCorrespondingToolResponses( + messages: Array, + toolCalls: Array, +): boolean { + const toolCallIds = new Set(toolCalls.map((call) => call.id)) + + // Look for tool messages that respond to these tool calls + for (const message of messages) { + if (message.role === "tool" && message.tool_call_id) { + toolCallIds.delete(message.tool_call_id) + } + } + + // If any tool call ID remains, it means there's no corresponding response + return toolCallIds.size === 0 +} + +// Helper function to process function responses in content +function processFunctionResponses( + functionResponses: Array, + pendingToolCalls: Map, + messages: Array, +): void { + for (const funcResponse of functionResponses) { + const functionName = funcResponse.functionResponse.name + // Find tool call ID by searching through the map + let matchedToolCallId: string | undefined + for (const [toolCallId, mappedFunctionName] of pendingToolCalls.entries()) { + if (mappedFunctionName === functionName) { + matchedToolCallId = toolCallId + break + } + } + if (matchedToolCallId) { + messages.push({ + role: "tool", + tool_call_id: matchedToolCallId, + content: JSON.stringify(funcResponse.functionResponse.response), + }) + pendingToolCalls.delete(matchedToolCallId) + } + } +} + +// Helper function to process function calls and create assistant message +function processFunctionCalls(options: { + functionCalls: Array + content: GeminiContent + pendingToolCalls: Map + messages: Array +}): void { + const { functionCalls, content, pendingToolCalls, messages } = options + + const textContent = extractTextFromGeminiContent(content) + const toolCalls = functionCalls.map((call) => { + const toolCallId = generateToolCallId(call.functionCall.name) + // Remember this tool call for later matching with responses + // Use tool_call_id as key to avoid duplicate function name overwrites + pendingToolCalls.set(toolCallId, call.functionCall.name) + + return { + id: toolCallId, + type: "function" as const, + function: { + name: call.functionCall.name, + arguments: JSON.stringify(call.functionCall.args), + }, + } + }) + + messages.push({ + role: "assistant", + content: textContent || null, + tool_calls: toolCalls, + }) +} + +// Helper function to check if a tool response is duplicate +function isDuplicateToolResponse( + message: Message, + seenToolCallIds: Set, +): boolean { + return ( + message.role === "tool" + && message.tool_call_id !== undefined + && seenToolCallIds.has(message.tool_call_id) + ) +} + +// Helper function to normalize user message content +function normalizeUserMessageContent(message: Message): void { + if ( + message.role === "user" + && typeof message.content === "string" + && !message.content.trim() + ) { + message.content = " " // Add minimal text content as fallback + } +} + +// Helper function to check if messages can be merged +function canMergeMessages( + lastMessage: Message, + currentMessage: Message, +): boolean { + return ( + lastMessage.role === currentMessage.role + && !lastMessage.tool_calls + && !currentMessage.tool_calls + && !(lastMessage as { tool_call_id?: string }).tool_call_id + && !(currentMessage as { tool_call_id?: string }).tool_call_id + && typeof lastMessage.content === "string" + && typeof currentMessage.content === "string" + ) +} + +// Helper function to check if message should be skipped +function shouldSkipMessage( + message: Message, + messages: Array, + seenToolCallIds: Set, +): boolean { + // Skip incomplete assistant messages with tool calls that have no responses + if ( + message.role === "assistant" + && message.tool_calls + && !hasCorrespondingToolResponses(messages, message.tool_calls) + ) { + return true + } + + // Skip duplicate tool responses + if (isDuplicateToolResponse(message, seenToolCallIds)) { + return true + } + + return false +} + +// Helper function to process and add message to cleaned array +function processAndAddMessage( + message: Message, + cleanedMessages: Array, + seenToolCallIds: Set, +): void { + // Track tool call IDs for deduplication + if (message.role === "tool" && message.tool_call_id) { + seenToolCallIds.add(message.tool_call_id) + } + + // Normalize user message content + normalizeUserMessageContent(message) + + // Try to merge with previous message + const lastMessage = cleanedMessages.at(-1) + if (lastMessage && canMergeMessages(lastMessage, message)) { + // Merge with previous message of same role + // canMergeMessages already ensures both contents are strings + if ( + typeof lastMessage.content === "string" + && typeof message.content === "string" + ) { + lastMessage.content = `${lastMessage.content}\n\n${message.content}` + } + } else { + cleanedMessages.push(message) + } +} + +// Consolidated message cleanup function +function cleanupMessages(messages: Array): Array { + const cleanedMessages: Array = [] + const seenToolCallIds = new Set() + + for (const message of messages) { + if (shouldSkipMessage(message, messages, seenToolCallIds)) { + continue + } + + processAndAddMessage(message, cleanedMessages, seenToolCallIds) + } + + return cleanedMessages +} + +function translateGeminiContentsToOpenAI( + contents: Array< + | GeminiContent + | Array<{ + functionResponse: { id?: string; name: string; response: unknown } + }> + >, + systemInstruction?: GeminiContent, +): Array { + const messages: Array = [] + const pendingToolCalls = new Map() // tool_call_id -> function_name + + // Add system instruction first if present + if (systemInstruction) { + const systemText = extractTextFromGeminiContent(systemInstruction) + if (systemText) { + messages.push({ role: "system", content: systemText }) + } + } + + // Process conversation contents + for (const item of contents) { + // Handle special case where Gemini CLI sends function responses as nested arrays + if (Array.isArray(item)) { + processFunctionResponseArray(item, pendingToolCalls, messages) + continue + } + + const content = item + const role = content.role === "model" ? "assistant" : "user" + + // Check for function calls/responses + const functionCalls = content.parts.filter( + (part): part is GeminiFunctionCallPart => "functionCall" in part, + ) + const functionResponses = content.parts.filter( + (part): part is GeminiFunctionResponsePart => "functionResponse" in part, + ) + + if (functionResponses.length > 0) { + processFunctionResponses(functionResponses, pendingToolCalls, messages) + } + + if (functionCalls.length > 0 && role === "assistant") { + processFunctionCalls({ + functionCalls, + content, + pendingToolCalls, + messages, + }) + } else { + // Regular message + const messageContent = translateGeminiContentToOpenAI(content) + if (messageContent) { + messages.push({ role, content: messageContent }) + } + } + } + + // Post-process: Clean up messages and ensure tool call consistency + return cleanupMessages(messages) +} + +function translateGeminiContentToOpenAI( + content: GeminiContent, +): string | Array | null { + if (content.parts.length === 0) return null + + const hasMedia = content.parts.some((part) => "inlineData" in part) + + if (!hasMedia) { + // Text-only content + return extractTextFromGeminiContent(content) + } + + // Mixed content with media + const contentParts: Array = [] + for (const part of content.parts) { + if ("text" in part) { + contentParts.push({ type: "text", text: part.text }) + } else if ("inlineData" in part) { + // Handle inline data for images - this is a legacy format + const partWithInlineData = part as { + inlineData: { mimeType: string; data: string } + } + contentParts.push({ + type: "image_url", + image_url: { + url: `data:${partWithInlineData.inlineData.mimeType};base64,${partWithInlineData.inlineData.data}`, + }, + }) + } + } + + return contentParts +} + +function extractTextFromGeminiContent(content: GeminiContent): string { + return content.parts + .filter((part): part is GeminiTextPart => "text" in part) + .map((part) => part.text) + .join("\n\n") +} + +// Response translation: OpenAI -> Gemini + +// Helper function to deduplicate tool responses - remove duplicate tool_call_ids +// The problem was our logic was CREATING duplicates instead of preventing them + +export function translateOpenAIToGemini( + response: ChatCompletionResponse, +): GeminiResponse { + const result = { + candidates: response.choices.map((choice, index) => ({ + content: translateOpenAIMessageToGeminiContent(choice.message), + finishReason: mapOpenAIFinishReasonToGemini(choice.finish_reason), + index, + })), + usageMetadata: { + promptTokenCount: response.usage?.prompt_tokens || 0, + candidatesTokenCount: response.usage?.completion_tokens || 0, + totalTokenCount: response.usage?.total_tokens || 0, + }, + } + + // Debug: Log original GitHub Copilot response and translated Gemini response + if (process.env.DEBUG_GEMINI_REQUESTS === "true") { + DebugLogger.logResponseComparison(response, result, { + context: "Non-Stream Response Translation", + filePrefix: "debug-nonstream-comparison", + }).catch((error: unknown) => { + console.error( + "[DEBUG] Failed to log non-stream response comparison:", + error, + ) + }) + } + + return result +} + +function translateOpenAIMessageToGeminiContent( + message: Message, +): GeminiContent { + const parts: Array = [] + + // Handle text content + if (typeof message.content === "string") { + if (message.content) { + parts.push({ text: message.content }) + } + } else if (Array.isArray(message.content)) { + for (const part of message.content) { + if (part.type === "text") { + parts.push({ text: part.text }) + } else { + // Convert data URL back to inline data + const match = part.image_url.url.match(/^data:([^;]+);base64,(.+)$/) + if (match) { + parts.push({ + inlineData: { + mimeType: match[1], + data: match[2], + }, + }) + } + } + } + } + + // Handle tool calls + if (message.tool_calls) { + for (const toolCall of message.tool_calls) { + // Debug: Log tool call arguments to verify what GitHub Copilot returns + if (process.env.DEBUG_GEMINI_REQUESTS === "true") { + console.log( + `[DEBUG] Tool call - name: ${toolCall.function.name}, arguments: "${toolCall.function.arguments}", type: ${typeof toolCall.function.arguments}, truthy: ${Boolean(toolCall.function.arguments)}`, + ) + } + + parts.push({ + functionCall: { + name: toolCall.function.name, + args: + toolCall.function.arguments ? + (JSON.parse(toolCall.function.arguments) as Record< + string, + unknown + >) + : {}, + }, + }) + } + } + + return { + parts, + role: "model", + } +} + +// Utility functions + +// Helper function to create usage metadata +function createUsageMetadata(chunk: ChatCompletionChunk): GeminiUsageMetadata { + return { + promptTokenCount: chunk.usage?.prompt_tokens || 0, + candidatesTokenCount: chunk.usage?.completion_tokens || 0, + totalTokenCount: chunk.usage?.total_tokens || 0, + } +} + +// Helper function to process chunk parts +function processChunkParts( + choice: { + delta: { + content?: string | null + tool_calls?: Array<{ + index: number + id?: string + type?: "function" + function?: { + name?: string + arguments?: string + } + }> + } + }, + accumulator: ToolCallAccumulator, +): Array { + const parts: Array = [] + + if (choice.delta.content) { + parts.push({ text: choice.delta.content }) + } + + if (choice.delta.tool_calls) { + parts.push( + ...processToolCallsWithAccumulator(choice.delta.tool_calls, accumulator), + ) + } + + return parts +} + +// Helper function to determine finish reason inclusion +function shouldIncludeFinishReason(choice: { + finish_reason: "stop" | "length" | "tool_calls" | "content_filter" | null + delta: { + tool_calls?: Array + } +}): boolean { + // Always include finish_reason when present, regardless of tool calls + // This ensures proper stream termination for both text and tool call completions + return Boolean(choice.finish_reason) +} + +// Helper function to create candidate object +function createGeminiCandidate( + parts: Array, + mappedFinishReason: string | undefined, + index: number, +): GeminiCandidate { + return { + content: { + parts, + role: "model", + }, + finishReason: mappedFinishReason as GeminiCandidate["finishReason"], + index, + } +} + +// Helper function to handle parts processing and validation +function processParts( + choice: { + finish_reason: "stop" | "length" | "tool_calls" | "content_filter" | null + delta: { + content?: string | null + tool_calls?: Array<{ + index: number + id?: string + type?: "function" + function?: { + name?: string + arguments?: string + } + }> + } + }, + accumulator: ToolCallAccumulator, +): Array | null { + const parts = processChunkParts(choice, accumulator) + + if (parts.length === 0 && !choice.finish_reason) { + return null + } + + // If we have a finish reason but no parts, add an empty text part + // This ensures Gemini CLI receives a properly formatted completion chunk + if (parts.length === 0 && choice.finish_reason) { + parts.push({ text: "" }) + } + + return parts +} + +// Helper function to build complete response +function buildGeminiResponse( + candidate: GeminiCandidate, + shouldInclude: boolean, + chunk: ChatCompletionChunk, +): { + candidates: Array + usageMetadata?: GeminiUsageMetadata +} { + const response: { + candidates: Array + usageMetadata?: GeminiUsageMetadata + } = { + candidates: [candidate], + } + + if (shouldInclude) { + response.usageMetadata = createUsageMetadata(chunk) + } + + return response +} + +// Stream translation: OpenAI Chunk -> Gemini Stream Response +export function translateOpenAIChunkToGemini( + chunk: ChatCompletionChunk, + accumulator: ToolCallAccumulator, +): { + candidates: Array + usageMetadata?: GeminiUsageMetadata +} | null { + if (chunk.choices.length === 0) { + return null + } + + const choice = chunk.choices[0] + + const parts = processParts(choice, accumulator) + if (!parts) { + return null + } + + // Additional validation - if we only have function call parts with empty names, + // skip this chunk entirely to prevent invalid tool call responses + const hasOnlyEmptyToolCalls = + parts.length > 0 + && parts.every((part) => { + if ("functionCall" in part) { + return !part.functionCall.name || part.functionCall.name.trim() === "" + } + return false + }) + && parts.some((part) => "functionCall" in part) + + if (hasOnlyEmptyToolCalls && !choice.finish_reason) { + return null + } + + const shouldInclude = shouldIncludeFinishReason(choice) + const mappedFinishReason = + shouldInclude ? + mapOpenAIFinishReasonToGemini(choice.finish_reason) + : undefined + + const candidate = createGeminiCandidate( + parts, + mappedFinishReason, + choice.index, + ) + const response = buildGeminiResponse(candidate, shouldInclude, chunk) + + // Debug: Log original GitHub Copilot chunk and translated Gemini chunk for comparison + if (process.env.DEBUG_GEMINI_REQUESTS === "true") { + DebugLogger.logResponseComparison(chunk, response, { + context: "Streaming Chunk Translation", + filePrefix: "debug-stream-comparison", + }).catch((error: unknown) => { + console.error("[DEBUG] Failed to log streaming chunk comparison:", error) + }) + } + + return response +} + +// Token counting translation + +export function translateGeminiCountTokensToOpenAI( + request: GeminiCountTokensRequest, + model: string, +): ChatCompletionsPayload { + const tools = selectTools(request.tools, request.contents) + return { + model: mapGeminiModelToCopilot(model), + messages: translateGeminiContentsToOpenAI( + request.contents, + request.systemInstruction, + ), + max_tokens: 1, + tools, + } +} + +export function translateTokenCountToGemini( + totalTokens: number, +): GeminiCountTokensResponse { + return { + totalTokens, + } +} diff --git a/src/routes/generate-content/types.ts b/src/routes/generate-content/types.ts new file mode 100644 index 000000000..8e893afe1 --- /dev/null +++ b/src/routes/generate-content/types.ts @@ -0,0 +1,117 @@ +// Gemini API Types + +export interface GeminiRequest { + contents: Array + tools?: Array + toolConfig?: GeminiToolConfig + safetySettings?: Array> + systemInstruction?: GeminiContent + generationConfig?: Record +} + +export interface GeminiContent { + parts: Array + role?: "user" | "model" +} + +export type GeminiPart = + | GeminiTextPart + | GeminiFunctionCallPart + | GeminiFunctionResponsePart + | GeminiInlineDataPart + +export interface GeminiTextPart { + text: string +} + +interface GeminiInlineDataPart { + inlineData: { + mimeType: string + data: string + } +} + +export interface GeminiFunctionCallPart { + functionCall: { + name: string + args: Record + } +} + +export interface GeminiFunctionResponsePart { + functionResponse: { + name: string + response: Record + } +} + +export interface GeminiTool { + functionDeclarations?: Array + googleSearch?: Record + urlContext?: Record +} + +interface GeminiFunctionDeclaration { + name: string + description?: string + parameters?: Record + parametersJsonSchema?: Record +} + +interface GeminiToolConfig { + functionCallingConfig: { + mode: "AUTO" | "ANY" | "NONE" + allowedFunctionNames?: Array + } +} + +// Response types +export interface GeminiResponse { + candidates: Array + usageMetadata?: GeminiUsageMetadata + promptFeedback?: Record +} + +export interface GeminiCandidate { + content: GeminiContent + finishReason?: + | "FINISH_REASON_UNSPECIFIED" + | "STOP" + | "MAX_TOKENS" + | "SAFETY" + | "RECITATION" + | "LANGUAGE" + | "OTHER" + | "BLOCKLIST" + | "PROHIBITED_CONTENT" + | "SPII" + | "MALFORMED_FUNCTION_CALL" + | "IMAGE_SAFETY" + | "UNEXPECTED_TOOL_CALL" + | "TOO_MANY_TOOL_CALLS" + index: number + safetyRatings?: Array> +} + +export interface GeminiUsageMetadata { + promptTokenCount: number + candidatesTokenCount: number + totalTokenCount: number +} + +// Token counting types +export interface GeminiCountTokensRequest { + contents: Array + tools?: Array + systemInstruction?: GeminiContent +} + +export interface GeminiCountTokensResponse { + totalTokens: number +} + +// Streaming types +export interface GeminiStreamResponse { + candidates?: Array + usageMetadata?: GeminiUsageMetadata +} diff --git a/src/routes/generate-content/utils.ts b/src/routes/generate-content/utils.ts new file mode 100644 index 000000000..1f3fa7dc0 --- /dev/null +++ b/src/routes/generate-content/utils.ts @@ -0,0 +1,43 @@ +import { type GeminiCandidate } from "./types" + +const OpenAIFinish = { + stop: "stop", + length: "length", + content_filter: "content_filter", + tool_calls: "tool_calls", +} as const + +const GeminiFinish = { + FINISH_REASON_UNSPECIFIED: "FINISH_REASON_UNSPECIFIED", + STOP: "STOP", + MAX_TOKENS: "MAX_TOKENS", + SAFETY: "SAFETY", + RECITATION: "RECITATION", + BLOCKLIST: "BLOCKLIST", + PROHIBITED_CONTENT: "PROHIBITED_CONTENT", + SPII: "SPII", + IMAGE_SAFETY: "IMAGE_SAFETY", + MALFORMED_FUNCTION_CALL: "MALFORMED_FUNCTION_CALL", +} as const + +export function mapOpenAIFinishReasonToGemini( + finishReason: string | null, +): GeminiCandidate["finishReason"] { + switch (finishReason) { + case OpenAIFinish.stop: { + return "STOP" + } + case OpenAIFinish.length: { + return "MAX_TOKENS" + } + case OpenAIFinish.content_filter: { + return "SAFETY" + } + case OpenAIFinish.tool_calls: { + return "STOP" // Gemini doesn't have a specific tool_calls finish reason, map to STOP + } + default: { + return GeminiFinish.FINISH_REASON_UNSPECIFIED + } + } +} diff --git a/src/routes/messages/anthropic-types.ts b/src/routes/messages/anthropic-types.ts index 881fffcc8..64882b02b 100644 --- a/src/routes/messages/anthropic-types.ts +++ b/src/routes/messages/anthropic-types.ts @@ -56,6 +56,7 @@ export interface AnthropicToolUseBlock { export interface AnthropicThinkingBlock { type: "thinking" thinking: string + signature: string } export type AnthropicUserContentBlock = @@ -101,7 +102,7 @@ export interface AnthropicResponse { | "refusal" | null stop_sequence: string | null - usage: { + usage?: { input_tokens: number output_tokens: number cache_creation_input_tokens?: number diff --git a/src/routes/messages/count-tokens-handler.ts b/src/routes/messages/count-tokens-handler.ts new file mode 100644 index 000000000..2ec849cb8 --- /dev/null +++ b/src/routes/messages/count-tokens-handler.ts @@ -0,0 +1,70 @@ +import type { Context } from "hono" + +import consola from "consola" + +import { state } from "~/lib/state" +import { getTokenCount } from "~/lib/tokenizer" + +import { type AnthropicMessagesPayload } from "./anthropic-types" +import { translateToOpenAI } from "./non-stream-translation" + +/** + * Handles token counting for Anthropic messages + */ +export async function handleCountTokens(c: Context) { + try { + const anthropicBeta = c.req.header("anthropic-beta") + + const anthropicPayload = await c.req.json() + + const openAIPayload = translateToOpenAI(anthropicPayload) + + const selectedModel = state.models?.data.find( + (model) => model.id === anthropicPayload.model, + ) + + if (!selectedModel) { + consola.warn("Model not found, returning default token count") + return c.json({ + input_tokens: 1, + }) + } + + const tokenCount = await getTokenCount(openAIPayload, selectedModel) + + if (anthropicPayload.tools && anthropicPayload.tools.length > 0) { + let mcpToolExist = false + if (anthropicBeta?.startsWith("claude-code")) { + mcpToolExist = anthropicPayload.tools.some((tool) => + tool.name.startsWith("mcp__"), + ) + } + if (!mcpToolExist) { + if (anthropicPayload.model.startsWith("claude")) { + // https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/overview#pricing + tokenCount.input = tokenCount.input + 346 + } else if (anthropicPayload.model.startsWith("grok")) { + tokenCount.input = tokenCount.input + 480 + } + } + } + + let finalTokenCount = tokenCount.input + tokenCount.output + if (anthropicPayload.model.startsWith("claude")) { + finalTokenCount = Math.round(finalTokenCount * 1.15) + } else if (anthropicPayload.model.startsWith("grok")) { + finalTokenCount = Math.round(finalTokenCount * 1.03) + } + + consola.info("Token count:", finalTokenCount) + + return c.json({ + input_tokens: finalTokenCount, + }) + } catch (error) { + consola.error("Error counting tokens:", error) + return c.json({ + input_tokens: 1, + }) + } +} diff --git a/src/routes/messages/handler.ts b/src/routes/messages/handler.ts index 85dbf6243..10b97c53c 100644 --- a/src/routes/messages/handler.ts +++ b/src/routes/messages/handler.ts @@ -6,11 +6,24 @@ import { streamSSE } from "hono/streaming" import { awaitApproval } from "~/lib/approval" import { checkRateLimit } from "~/lib/rate-limit" import { state } from "~/lib/state" +import { + createResponsesStreamState, + translateResponsesStreamEvent, +} from "~/routes/messages/responses-stream-translation" +import { + translateAnthropicMessagesToResponsesPayload, + translateResponsesResultToAnthropic, +} from "~/routes/messages/responses-translation" +import { getResponsesRequestOptions } from "~/routes/responses/utils" import { createChatCompletions, type ChatCompletionChunk, type ChatCompletionResponse, } from "~/services/copilot/create-chat-completions" +import { + createResponses, + type ResponsesResult, +} from "~/services/copilot/create-responses" import { type AnthropicMessagesPayload, @@ -28,16 +41,31 @@ export async function handleCompletion(c: Context) { const anthropicPayload = await c.req.json() consola.debug("Anthropic request payload:", JSON.stringify(anthropicPayload)) + const useResponsesApi = shouldUseResponsesApi(anthropicPayload.model) + + if (state.manualApprove) { + await awaitApproval() + } + + if (useResponsesApi) { + return await handleWithResponsesApi(c, anthropicPayload) + } + + return await handleWithChatCompletions(c, anthropicPayload) +} + +const RESPONSES_ENDPOINT = "/responses" + +const handleWithChatCompletions = async ( + c: Context, + anthropicPayload: AnthropicMessagesPayload, +) => { const openAIPayload = translateToOpenAI(anthropicPayload) consola.debug( "Translated OpenAI request payload:", JSON.stringify(openAIPayload), ) - if (state.manualApprove) { - await awaitApproval() - } - const response = await createChatCompletions(openAIPayload) if (isNonStreaming(response)) { @@ -86,6 +114,108 @@ export async function handleCompletion(c: Context) { }) } +const handleWithResponsesApi = async ( + c: Context, + anthropicPayload: AnthropicMessagesPayload, +) => { + const responsesPayload = + translateAnthropicMessagesToResponsesPayload(anthropicPayload) + consola.debug( + "Translated Responses payload:", + JSON.stringify(responsesPayload), + ) + + const { vision, initiator } = getResponsesRequestOptions(responsesPayload) + const response = await createResponses(responsesPayload, { + vision, + initiator, + }) + + if (responsesPayload.stream && isAsyncIterable(response)) { + consola.debug("Streaming response from Copilot (Responses API)") + return streamSSE(c, async (stream) => { + const streamState = createResponsesStreamState() + + for await (const chunk of response) { + consola.debug("Responses raw stream event:", JSON.stringify(chunk)) + + const eventName = (chunk as { event?: string }).event + if (eventName === "ping") { + await stream.writeSSE({ event: "ping", data: "" }) + continue + } + + const data = (chunk as { data?: string }).data + if (!data) { + continue + } + + if (data === "[DONE]") { + break + } + + const parsed = safeJsonParse(data) + if (!parsed) { + continue + } + + const events = translateResponsesStreamEvent(parsed, streamState) + for (const event of events) { + consola.debug("Translated Anthropic event:", JSON.stringify(event)) + await stream.writeSSE({ + event: event.type, + data: JSON.stringify(event), + }) + } + } + + if (!streamState.messageCompleted) { + consola.warn( + "Responses stream ended without completion; sending fallback message_stop", + ) + const fallback = { type: "message_stop" as const } + await stream.writeSSE({ + event: fallback.type, + data: JSON.stringify(fallback), + }) + } + }) + } + + consola.debug( + "Non-streaming Responses result:", + JSON.stringify(response).slice(-400), + ) + const anthropicResponse = translateResponsesResultToAnthropic( + response as ResponsesResult, + ) + consola.debug( + "Translated Anthropic response:", + JSON.stringify(anthropicResponse), + ) + return c.json(anthropicResponse) +} + +const shouldUseResponsesApi = (modelId: string): boolean => { + const selectedModel = state.models?.data.find((model) => model.id === modelId) + return ( + selectedModel?.supported_endpoints?.includes(RESPONSES_ENDPOINT) ?? false + ) +} + const isNonStreaming = ( response: Awaited>, ): response is ChatCompletionResponse => Object.hasOwn(response, "choices") + +const isAsyncIterable = (value: unknown): value is AsyncIterable => + Boolean(value) + && typeof (value as AsyncIterable)[Symbol.asyncIterator] === "function" + +const safeJsonParse = (value: string): Record | undefined => { + try { + return JSON.parse(value) as Record + } catch (error) { + consola.warn("Failed to parse Responses stream chunk:", value, error) + return undefined + } +} diff --git a/src/routes/messages/non-stream-translation.ts b/src/routes/messages/non-stream-translation.ts index 271aa47f6..dc41e6382 100644 --- a/src/routes/messages/non-stream-translation.ts +++ b/src/routes/messages/non-stream-translation.ts @@ -313,8 +313,15 @@ export function translateToAnthropic( stop_reason: mapOpenAIStopReasonToAnthropic(stopReason), stop_sequence: null, usage: { - input_tokens: response.usage?.prompt_tokens ?? 0, + input_tokens: + (response.usage?.prompt_tokens ?? 0) + - (response.usage?.prompt_tokens_details?.cached_tokens ?? 0), output_tokens: response.usage?.completion_tokens ?? 0, + ...(response.usage?.prompt_tokens_details?.cached_tokens + !== undefined && { + cache_read_input_tokens: + response.usage.prompt_tokens_details.cached_tokens, + }), }, } } diff --git a/src/routes/messages/responses-stream-translation.ts b/src/routes/messages/responses-stream-translation.ts new file mode 100644 index 000000000..db09bf144 --- /dev/null +++ b/src/routes/messages/responses-stream-translation.ts @@ -0,0 +1,803 @@ +import { type ResponsesResult } from "~/services/copilot/create-responses" + +import { type AnthropicStreamEventData } from "./anthropic-types" +import { translateResponsesResultToAnthropic } from "./responses-translation" + +export interface ResponsesStreamState { + messageStartSent: boolean + messageCompleted: boolean + nextContentBlockIndex: number + blockIndexByKey: Map + openBlocks: Set + blockHasDelta: Set + currentResponseId?: string + currentModel?: string + initialInputTokens?: number + initialInputCachedTokens?: number + functionCallStateByOutputIndex: Map + functionCallOutputIndexByItemId: Map +} + +type FunctionCallStreamState = { + blockIndex: number + toolCallId: string + name: string +} + +export const createResponsesStreamState = (): ResponsesStreamState => ({ + messageStartSent: false, + messageCompleted: false, + nextContentBlockIndex: 0, + blockIndexByKey: new Map(), + openBlocks: new Set(), + blockHasDelta: new Set(), + functionCallStateByOutputIndex: new Map(), + functionCallOutputIndexByItemId: new Map(), +}) + +export const translateResponsesStreamEvent = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const eventType = + typeof rawEvent.type === "string" ? rawEvent.type : undefined + if (!eventType) { + return [] + } + + switch (eventType) { + case "response.created": { + return handleResponseCreated(rawEvent, state) + } + + case "response.reasoning_summary_text.delta": { + return handleReasoningSummaryTextDelta(rawEvent, state) + } + + case "response.output_text.delta": { + return handleOutputTextDelta(rawEvent, state) + } + + case "response.reasoning_summary_part.done": { + return handleReasoningSummaryPartDone(rawEvent, state) + } + + case "response.output_text.done": { + return handleOutputTextDone(rawEvent, state) + } + + case "response.output_item.added": { + return handleOutputItemAdded(rawEvent, state) + } + + case "response.output_item.done": { + return handleOutputItemDone(rawEvent, state) + } + + case "response.function_call_arguments.delta": { + return handleFunctionCallArgumentsDelta(rawEvent, state) + } + + case "response.function_call_arguments.done": { + return handleFunctionCallArgumentsDone(rawEvent, state) + } + + case "response.completed": + case "response.incomplete": { + return handleResponseCompleted(rawEvent, state) + } + + case "response.failed": { + return handleResponseFailed(rawEvent, state) + } + + case "error": { + return handleErrorEvent(rawEvent, state) + } + + default: { + return [] + } + } +} + +// Helper handlers to keep translateResponsesStreamEvent concise +const handleResponseCreated = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const response = toResponsesResult(rawEvent.response) + if (response) { + cacheResponseMetadata(state, response) + } + return ensureMessageStart(state, response) +} + +const handleOutputItemAdded = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const response = toResponsesResult(rawEvent.response) + const events = ensureMessageStart(state, response) + + const functionCallDetails = extractFunctionCallDetails(rawEvent, state) + if (!functionCallDetails) { + return events + } + + const { outputIndex, toolCallId, name, initialArguments, itemId } = + functionCallDetails + + if (itemId) { + state.functionCallOutputIndexByItemId.set(itemId, outputIndex) + } + + const blockIndex = openFunctionCallBlock(state, { + outputIndex, + toolCallId, + name, + events, + }) + + if (initialArguments !== undefined && initialArguments.length > 0) { + events.push({ + type: "content_block_delta", + index: blockIndex, + delta: { + type: "input_json_delta", + partial_json: initialArguments, + }, + }) + state.blockHasDelta.add(blockIndex) + } + + return events +} + +const handleOutputItemDone = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const events = ensureMessageStart(state) + + const item = isRecord(rawEvent.item) ? rawEvent.item : undefined + if (!item) { + return events + } + + const itemType = typeof item.type === "string" ? item.type : undefined + if (itemType !== "reasoning") { + return events + } + + const outputIndex = toNumber(rawEvent.output_index) + + const blockIndex = openThinkingBlockIfNeeded(state, outputIndex, events) + + const signature = + typeof item.encrypted_content === "string" ? item.encrypted_content : "" + + if (signature) { + events.push({ + type: "content_block_delta", + index: blockIndex, + delta: { + type: "signature_delta", + signature, + }, + }) + state.blockHasDelta.add(blockIndex) + } + + closeBlockIfOpen(state, blockIndex, events) + + return events +} + +const handleFunctionCallArgumentsDelta = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const events = ensureMessageStart(state) + + const outputIndex = resolveFunctionCallOutputIndex(state, rawEvent) + if (outputIndex === undefined) { + return events + } + + const deltaText = typeof rawEvent.delta === "string" ? rawEvent.delta : "" + if (!deltaText) { + return events + } + + const blockIndex = openFunctionCallBlock(state, { + outputIndex, + events, + }) + + events.push({ + type: "content_block_delta", + index: blockIndex, + delta: { + type: "input_json_delta", + partial_json: deltaText, + }, + }) + state.blockHasDelta.add(blockIndex) + + return events +} + +const handleFunctionCallArgumentsDone = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const events = ensureMessageStart(state) + + const outputIndex = resolveFunctionCallOutputIndex(state, rawEvent) + if (outputIndex === undefined) { + return events + } + + const blockIndex = openFunctionCallBlock(state, { + outputIndex, + events, + }) + + const finalArguments = + typeof rawEvent.arguments === "string" ? rawEvent.arguments : undefined + + if (!state.blockHasDelta.has(blockIndex) && finalArguments) { + events.push({ + type: "content_block_delta", + index: blockIndex, + delta: { + type: "input_json_delta", + partial_json: finalArguments, + }, + }) + state.blockHasDelta.add(blockIndex) + } + + closeBlockIfOpen(state, blockIndex, events) + + const existingState = state.functionCallStateByOutputIndex.get(outputIndex) + if (existingState) { + state.functionCallOutputIndexByItemId.delete(existingState.toolCallId) + } + state.functionCallStateByOutputIndex.delete(outputIndex) + + const itemId = toNonEmptyString(rawEvent.item_id) + if (itemId) { + state.functionCallOutputIndexByItemId.delete(itemId) + } + + return events +} + +const handleOutputTextDelta = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const events = ensureMessageStart(state) + + const outputIndex = toNumber(rawEvent.output_index) + const contentIndex = toNumber(rawEvent.content_index) + const deltaText = typeof rawEvent.delta === "string" ? rawEvent.delta : "" + + if (!deltaText) { + return events + } + + const blockIndex = openTextBlockIfNeeded(state, { + outputIndex, + contentIndex, + events, + }) + + events.push({ + type: "content_block_delta", + index: blockIndex, + delta: { + type: "text_delta", + text: deltaText, + }, + }) + state.blockHasDelta.add(blockIndex) + + return events +} + +const handleReasoningSummaryTextDelta = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const events = ensureMessageStart(state) + + const outputIndex = toNumber(rawEvent.output_index) + const deltaText = typeof rawEvent.delta === "string" ? rawEvent.delta : "" + + if (!deltaText) { + return events + } + + const blockIndex = openThinkingBlockIfNeeded(state, outputIndex, events) + + events.push({ + type: "content_block_delta", + index: blockIndex, + delta: { + type: "thinking_delta", + thinking: deltaText, + }, + }) + state.blockHasDelta.add(blockIndex) + + return events +} + +const handleReasoningSummaryPartDone = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const events = ensureMessageStart(state) + + const outputIndex = toNumber(rawEvent.output_index) + const part = isRecord(rawEvent.part) ? rawEvent.part : undefined + const text = part && typeof part.text === "string" ? part.text : "" + + const blockIndex = openThinkingBlockIfNeeded(state, outputIndex, events) + + if (text && !state.blockHasDelta.has(blockIndex)) { + events.push({ + type: "content_block_delta", + index: blockIndex, + delta: { + type: "thinking_delta", + thinking: text, + }, + }) + } + + return events +} + +const handleOutputTextDone = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const events = ensureMessageStart(state) + + const outputIndex = toNumber(rawEvent.output_index) + const contentIndex = toNumber(rawEvent.content_index) + const text = typeof rawEvent.text === "string" ? rawEvent.text : "" + + const blockIndex = openTextBlockIfNeeded(state, { + outputIndex, + contentIndex, + events, + }) + + if (text && !state.blockHasDelta.has(blockIndex)) { + events.push({ + type: "content_block_delta", + index: blockIndex, + delta: { + type: "text_delta", + text, + }, + }) + } + + closeBlockIfOpen(state, blockIndex, events) + + return events +} + +const handleResponseCompleted = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const response = toResponsesResult(rawEvent.response) + const events = ensureMessageStart(state, response) + + closeAllOpenBlocks(state, events) + + if (response) { + const anthropic = translateResponsesResultToAnthropic(response) + events.push({ + type: "message_delta", + delta: { + stop_reason: anthropic.stop_reason, + stop_sequence: anthropic.stop_sequence, + }, + usage: anthropic.usage, + }) + } else { + events.push({ + type: "message_delta", + delta: { + stop_reason: null, + stop_sequence: null, + }, + }) + } + + events.push({ type: "message_stop" }) + state.messageCompleted = true + + return events +} + +const handleResponseFailed = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const response = toResponsesResult(rawEvent.response) + const events = ensureMessageStart(state, response) + + closeAllOpenBlocks(state, events) + + const message = + typeof rawEvent.error === "string" ? + rawEvent.error + : "Response generation failed." + + events.push(buildErrorEvent(message)) + state.messageCompleted = true + + return events +} + +const handleErrorEvent = ( + rawEvent: Record, + state: ResponsesStreamState, +): Array => { + const message = + typeof rawEvent.message === "string" ? + rawEvent.message + : "An unexpected error occurred during streaming." + + state.messageCompleted = true + return [buildErrorEvent(message)] +} + +const ensureMessageStart = ( + state: ResponsesStreamState, + response?: ResponsesResult, +): Array => { + if (state.messageStartSent) { + return [] + } + + if (response) { + cacheResponseMetadata(state, response) + } + + const id = response?.id ?? state.currentResponseId ?? "response" + const model = response?.model ?? state.currentModel ?? "" + + state.messageStartSent = true + + const inputTokens = + (state.initialInputTokens ?? 0) - (state.initialInputCachedTokens ?? 0) + return [ + { + type: "message_start", + message: { + id, + type: "message", + role: "assistant", + content: [], + model, + stop_reason: null, + stop_sequence: null, + usage: { + input_tokens: inputTokens, + output_tokens: 0, + ...(state.initialInputCachedTokens !== undefined && { + cache_creation_input_tokens: state.initialInputCachedTokens, + }), + }, + }, + }, + ] +} + +const openTextBlockIfNeeded = ( + state: ResponsesStreamState, + params: { + outputIndex: number + contentIndex: number + events: Array + }, +): number => { + const { outputIndex, contentIndex, events } = params + const key = getBlockKey(outputIndex, contentIndex) + let blockIndex = state.blockIndexByKey.get(key) + + if (blockIndex === undefined) { + blockIndex = state.nextContentBlockIndex + state.nextContentBlockIndex += 1 + state.blockIndexByKey.set(key, blockIndex) + } + + if (!state.openBlocks.has(blockIndex)) { + events.push({ + type: "content_block_start", + index: blockIndex, + content_block: { + type: "text", + text: "", + }, + }) + state.openBlocks.add(blockIndex) + } + + return blockIndex +} + +const openThinkingBlockIfNeeded = ( + state: ResponsesStreamState, + outputIndex: number, + events: Array, +): number => { + const contentIndex = 0 + const key = getBlockKey(outputIndex, contentIndex) + let blockIndex = state.blockIndexByKey.get(key) + + if (blockIndex === undefined) { + blockIndex = state.nextContentBlockIndex + state.nextContentBlockIndex += 1 + state.blockIndexByKey.set(key, blockIndex) + } + + if (!state.openBlocks.has(blockIndex)) { + events.push({ + type: "content_block_start", + index: blockIndex, + content_block: { + type: "thinking", + thinking: "", + }, + }) + state.openBlocks.add(blockIndex) + } + + return blockIndex +} + +const closeBlockIfOpen = ( + state: ResponsesStreamState, + blockIndex: number, + events: Array, +) => { + if (!state.openBlocks.has(blockIndex)) { + return + } + + events.push({ type: "content_block_stop", index: blockIndex }) + state.openBlocks.delete(blockIndex) + state.blockHasDelta.delete(blockIndex) +} + +const closeAllOpenBlocks = ( + state: ResponsesStreamState, + events: Array, +) => { + for (const blockIndex of state.openBlocks) { + closeBlockIfOpen(state, blockIndex, events) + } + + state.functionCallStateByOutputIndex.clear() + state.functionCallOutputIndexByItemId.clear() +} + +const cacheResponseMetadata = ( + state: ResponsesStreamState, + response: ResponsesResult, +) => { + state.currentResponseId = response.id + state.currentModel = response.model + state.initialInputTokens = response.usage?.input_tokens ?? 0 + state.initialInputCachedTokens = + response.usage?.input_tokens_details?.cached_tokens +} + +const buildErrorEvent = (message: string): AnthropicStreamEventData => ({ + type: "error", + error: { + type: "api_error", + message, + }, +}) + +const getBlockKey = (outputIndex: number, contentIndex: number): string => + `${outputIndex}:${contentIndex}` + +const resolveFunctionCallOutputIndex = ( + state: ResponsesStreamState, + rawEvent: Record, +): number | undefined => { + if ( + typeof rawEvent.output_index === "number" + || (typeof rawEvent.output_index === "string" + && rawEvent.output_index.length > 0) + ) { + const parsed = toOptionalNumber(rawEvent.output_index) + if (parsed !== undefined) { + return parsed + } + } + + const itemId = toNonEmptyString(rawEvent.item_id) + if (itemId) { + const mapped = state.functionCallOutputIndexByItemId.get(itemId) + if (mapped !== undefined) { + return mapped + } + } + + return undefined +} + +const openFunctionCallBlock = ( + state: ResponsesStreamState, + params: { + outputIndex: number + toolCallId?: string + name?: string + events: Array + }, +): number => { + const { outputIndex, toolCallId, name, events } = params + + let functionCallState = state.functionCallStateByOutputIndex.get(outputIndex) + + if (!functionCallState) { + const blockIndex = state.nextContentBlockIndex + state.nextContentBlockIndex += 1 + + const resolvedToolCallId = toolCallId ?? `tool_call_${blockIndex}` + const resolvedName = name ?? "function" + + functionCallState = { + blockIndex, + toolCallId: resolvedToolCallId, + name: resolvedName, + } + + state.functionCallStateByOutputIndex.set(outputIndex, functionCallState) + state.functionCallOutputIndexByItemId.set(resolvedToolCallId, outputIndex) + } + + const { blockIndex } = functionCallState + + if (!state.openBlocks.has(blockIndex)) { + events.push({ + type: "content_block_start", + index: blockIndex, + content_block: { + type: "tool_use", + id: functionCallState.toolCallId, + name: functionCallState.name, + input: {}, + }, + }) + state.openBlocks.add(blockIndex) + } + + return blockIndex +} + +type FunctionCallDetails = { + outputIndex: number + toolCallId: string + name: string + initialArguments?: string + itemId?: string +} + +const extractFunctionCallDetails = ( + rawEvent: Record, + state: ResponsesStreamState, +): FunctionCallDetails | undefined => { + const item = isRecord(rawEvent.item) ? rawEvent.item : undefined + if (!item) { + return undefined + } + + const itemType = typeof item.type === "string" ? item.type : undefined + if (itemType !== "function_call") { + return undefined + } + + const outputIndex = resolveFunctionCallOutputIndex(state, rawEvent) + if (outputIndex === undefined) { + return undefined + } + + const callId = toNonEmptyString(item.call_id) + const itemId = toNonEmptyString(item.id) + const name = toNonEmptyString(item.name) ?? "function" + + const toolCallId = callId ?? itemId ?? `tool_call_${outputIndex}` + const initialArguments = + typeof item.arguments === "string" ? item.arguments : undefined + + return { + outputIndex, + toolCallId, + name, + initialArguments, + itemId, + } +} + +const toResponsesResult = (value: unknown): ResponsesResult | undefined => + isResponsesResult(value) ? value : undefined + +const toOptionalNumber = (value: unknown): number | undefined => { + if (typeof value === "number" && Number.isFinite(value)) { + return value + } + + if (typeof value === "string" && value.length > 0) { + const parsed = Number(value) + if (Number.isFinite(parsed)) { + return parsed + } + } + + return undefined +} + +const toNonEmptyString = (value: unknown): string | undefined => { + if (typeof value === "string" && value.length > 0) { + return value + } + + return undefined +} + +const toNumber = (value: unknown): number => { + if (typeof value === "number" && Number.isFinite(value)) { + return value + } + + if (typeof value === "string") { + const parsed = Number(value) + if (Number.isFinite(parsed)) { + return parsed + } + } + + return 0 +} + +const isResponsesResult = (value: unknown): value is ResponsesResult => { + if (!isRecord(value)) { + return false + } + + if (typeof value.id !== "string") { + return false + } + + if (typeof value.model !== "string") { + return false + } + + if (!Array.isArray(value.output)) { + return false + } + + if (typeof value.object !== "string") { + return false + } + + return true +} + +const isRecord = (value: unknown): value is Record => + typeof value === "object" && value !== null diff --git a/src/routes/messages/responses-translation.ts b/src/routes/messages/responses-translation.ts new file mode 100644 index 000000000..00f481291 --- /dev/null +++ b/src/routes/messages/responses-translation.ts @@ -0,0 +1,647 @@ +import consola from "consola" + +import { + type ResponsesPayload, + type ResponseInputContent, + type ResponseInputImage, + type ResponseInputItem, + type ResponseInputMessage, + type ResponseInputReasoning, + type ResponseInputText, + type ResponsesResult, + type ResponseOutputContentBlock, + type ResponseOutputFunctionCall, + type ResponseOutputItem, + type ResponseOutputReasoning, + type ResponseReasoningBlock, + type ResponseOutputRefusal, + type ResponseOutputText, + type ResponseFunctionToolCallItem, + type ResponseFunctionCallOutputItem, +} from "~/services/copilot/create-responses" + +import { + type AnthropicAssistantContentBlock, + type AnthropicAssistantMessage, + type AnthropicResponse, + type AnthropicImageBlock, + type AnthropicMessage, + type AnthropicMessagesPayload, + type AnthropicTextBlock, + type AnthropicThinkingBlock, + type AnthropicTool, + type AnthropicToolResultBlock, + type AnthropicToolUseBlock, + type AnthropicUserContentBlock, + type AnthropicUserMessage, +} from "./anthropic-types" + +const MESSAGE_TYPE = "message" + +export const translateAnthropicMessagesToResponsesPayload = ( + payload: AnthropicMessagesPayload, +): ResponsesPayload => { + const input: Array = [] + + for (const message of payload.messages) { + input.push(...translateMessage(message)) + } + + const translatedTools = convertAnthropicTools(payload.tools) + const toolChoice = convertAnthropicToolChoice(payload.tool_choice) + + const { safetyIdentifier, promptCacheKey } = parseUserId( + payload.metadata?.user_id, + ) + + const responsesPayload: ResponsesPayload = { + model: payload.model, + input, + instructions: translateSystemPrompt(payload.system), + temperature: payload.temperature ?? null, + top_p: payload.top_p ?? null, + max_output_tokens: payload.max_tokens, + tools: translatedTools, + tool_choice: toolChoice, + metadata: payload.metadata ? { ...payload.metadata } : null, + safety_identifier: safetyIdentifier, + prompt_cache_key: promptCacheKey, + stream: payload.stream ?? null, + store: false, + parallel_tool_calls: true, + reasoning: { effort: "high", summary: "auto" }, + include: ["reasoning.encrypted_content"], + } + + return responsesPayload +} + +const translateMessage = ( + message: AnthropicMessage, +): Array => { + if (message.role === "user") { + return translateUserMessage(message) + } + + return translateAssistantMessage(message) +} + +const translateUserMessage = ( + message: AnthropicUserMessage, +): Array => { + if (typeof message.content === "string") { + return [createMessage("user", message.content)] + } + + if (!Array.isArray(message.content)) { + return [] + } + + const items: Array = [] + const pendingContent: Array = [] + + for (const block of message.content) { + if (block.type === "tool_result") { + flushPendingContent("user", pendingContent, items) + items.push(createFunctionCallOutput(block)) + continue + } + + const converted = translateUserContentBlock(block) + if (converted) { + pendingContent.push(converted) + } + } + + flushPendingContent("user", pendingContent, items) + + return items +} + +const translateAssistantMessage = ( + message: AnthropicAssistantMessage, +): Array => { + if (typeof message.content === "string") { + return [createMessage("assistant", message.content)] + } + + if (!Array.isArray(message.content)) { + return [] + } + + const items: Array = [] + const pendingContent: Array = [] + + for (const block of message.content) { + if (block.type === "tool_use") { + flushPendingContent("assistant", pendingContent, items) + items.push(createFunctionToolCall(block)) + continue + } + + if (block.type === "thinking") { + flushPendingContent("assistant", pendingContent, items) + items.push(createReasoningContent(block)) + continue + } + + const converted = translateAssistantContentBlock(block) + if (converted) { + pendingContent.push(converted) + } + } + + flushPendingContent("assistant", pendingContent, items) + + return items +} + +const translateUserContentBlock = ( + block: AnthropicUserContentBlock, +): ResponseInputContent | undefined => { + switch (block.type) { + case "text": { + return createTextContent(block.text) + } + case "image": { + return createImageContent(block) + } + default: { + return undefined + } + } +} + +const translateAssistantContentBlock = ( + block: AnthropicAssistantContentBlock, +): ResponseInputContent | undefined => { + switch (block.type) { + case "text": { + return createOutPutTextContent(block.text) + } + default: { + return undefined + } + } +} + +const flushPendingContent = ( + role: ResponseInputMessage["role"], + pendingContent: Array, + target: Array, +) => { + if (pendingContent.length === 0) { + return + } + + const messageContent = + pendingContent.length === 1 && isPlainText(pendingContent[0]) ? + pendingContent[0].text + : [...pendingContent] + + target.push(createMessage(role, messageContent)) + pendingContent.length = 0 +} + +const createMessage = ( + role: ResponseInputMessage["role"], + content: string | Array, +): ResponseInputMessage => ({ + type: MESSAGE_TYPE, + role, + content, +}) + +const createTextContent = (text: string): ResponseInputText => ({ + type: "input_text", + text, +}) + +const createOutPutTextContent = (text: string): ResponseInputText => ({ + type: "output_text", + text, +}) + +const createImageContent = ( + block: AnthropicImageBlock, +): ResponseInputImage => ({ + type: "input_image", + image_url: `data:${block.source.media_type};base64,${block.source.data}`, +}) + +const createReasoningContent = ( + block: AnthropicThinkingBlock, +): ResponseInputReasoning => ({ + type: "reasoning", + summary: [ + { + type: "summary_text", + text: block.thinking, + }, + ], + encrypted_content: block.signature, +}) + +const createFunctionToolCall = ( + block: AnthropicToolUseBlock, +): ResponseFunctionToolCallItem => ({ + type: "function_call", + call_id: block.id, + name: block.name, + arguments: JSON.stringify(block.input), + status: "completed", +}) + +const createFunctionCallOutput = ( + block: AnthropicToolResultBlock, +): ResponseFunctionCallOutputItem => ({ + type: "function_call_output", + call_id: block.tool_use_id, + output: convertToolResultContent(block.content), + status: block.is_error ? "incomplete" : "completed", +}) + +const translateSystemPrompt = ( + system: string | Array | undefined, +): string | null => { + if (!system) { + return null + } + + const toolUsePrompt = ` +## Tool use +- You have access to many tools. If a tool exists to perform a specific task, you MUST use that tool instead of running a terminal command to perform that task. +### Bash tool +When using the Bash tool, follow these rules: +- always run_in_background set to false, unless you are running a long-running command (e.g., a server or a watch command). +### BashOutput tool +When using the BashOutput tool, follow these rules: +- Only Bash Tool run_in_background set to true, Use BashOutput to read the output later +### TodoWrite tool +When using the TodoWrite tool, follow these rules: +- Skip using the TodoWrite tool for simple or straightforward tasks (roughly the easiest 25%). +- Do not make single-step todo lists. +- When you made a todo, update it after having performed one of the sub-tasks that you shared on the todo list.` + + if (typeof system === "string") { + return system + toolUsePrompt + } + + const text = system + .map((block, index) => { + if (index === 0) { + return block.text + toolUsePrompt + } + return block.text + }) + .join(" ") + return text.length > 0 ? text : null +} + +const convertAnthropicTools = ( + tools: Array | undefined, +): Array> | null => { + if (!tools || tools.length === 0) { + return null + } + + return tools.map((tool) => ({ + type: "function", + name: tool.name, + parameters: tool.input_schema, + strict: false, + ...(tool.description ? { description: tool.description } : {}), + })) +} + +const convertAnthropicToolChoice = ( + choice: AnthropicMessagesPayload["tool_choice"], +): unknown => { + if (!choice) { + return undefined + } + + switch (choice.type) { + case "auto": { + return "auto" + } + case "any": { + return "required" + } + case "tool": { + return choice.name ? { type: "function", name: choice.name } : undefined + } + case "none": { + return "none" + } + default: { + return undefined + } + } +} + +const isPlainText = ( + content: ResponseInputContent, +): content is ResponseInputText | { text: string } => { + if (typeof content !== "object") { + return false + } + + return ( + "text" in content + && typeof (content as ResponseInputText).text === "string" + && !("image_url" in content) + ) +} + +export const translateResponsesResultToAnthropic = ( + response: ResponsesResult, +): AnthropicResponse => { + const contentBlocks = mapOutputToAnthropicContent(response.output) + const usage = mapResponsesUsage(response) + let anthropicContent = fallbackContentBlocks(response.output_text) + if (contentBlocks.length > 0) { + anthropicContent = contentBlocks + } + + const stopReason = mapResponsesStopReason(response) + + return { + id: response.id, + type: "message", + role: "assistant", + content: anthropicContent, + model: response.model, + stop_reason: stopReason, + stop_sequence: null, + usage, + } +} + +const mapOutputToAnthropicContent = ( + output: Array, +): Array => { + const contentBlocks: Array = [] + + for (const item of output) { + switch (item.type) { + case "reasoning": { + const thinkingText = extractReasoningText(item) + if (thinkingText.length > 0) { + contentBlocks.push({ + type: "thinking", + thinking: thinkingText, + signature: item.encrypted_content ?? "", + }) + } + break + } + case "function_call": { + const toolUseBlock = createToolUseContentBlock(item) + if (toolUseBlock) { + contentBlocks.push(toolUseBlock) + } + break + } + case "message": { + const combinedText = combineMessageTextContent(item.content) + if (combinedText.length > 0) { + contentBlocks.push({ type: "text", text: combinedText }) + } + break + } + default: { + // Future compatibility for unrecognized output item types. + const combinedText = combineMessageTextContent( + (item as { content?: Array }).content, + ) + if (combinedText.length > 0) { + contentBlocks.push({ type: "text", text: combinedText }) + } + } + } + } + + return contentBlocks +} + +const combineMessageTextContent = ( + content: Array | undefined, +): string => { + if (!Array.isArray(content)) { + return "" + } + + let aggregated = "" + + for (const block of content) { + if (isResponseOutputText(block)) { + aggregated += block.text + continue + } + + if (isResponseOutputRefusal(block)) { + aggregated += block.refusal + continue + } + + if (typeof (block as { text?: unknown }).text === "string") { + aggregated += (block as { text: string }).text + continue + } + + if (typeof (block as { reasoning?: unknown }).reasoning === "string") { + aggregated += (block as { reasoning: string }).reasoning + continue + } + } + + return aggregated +} + +const extractReasoningText = (item: ResponseOutputReasoning): string => { + const segments: Array = [] + + const collectFromBlocks = (blocks?: Array) => { + if (!Array.isArray(blocks)) { + return + } + + for (const block of blocks) { + if (typeof block.text === "string") { + segments.push(block.text) + continue + } + } + } + + collectFromBlocks(item.summary) + + return segments.join("").trim() +} + +const createToolUseContentBlock = ( + call: ResponseOutputFunctionCall, +): AnthropicToolUseBlock | null => { + const toolId = call.call_id ?? call.id + if (!call.name || !toolId) { + return null + } + + const input = parseFunctionCallArguments(call.arguments) + + return { + type: "tool_use", + id: toolId, + name: call.name, + input, + } +} + +const parseFunctionCallArguments = ( + rawArguments: string, +): Record => { + if (typeof rawArguments !== "string" || rawArguments.trim().length === 0) { + return {} + } + + try { + const parsed: unknown = JSON.parse(rawArguments) + + if (Array.isArray(parsed)) { + return { arguments: parsed } + } + + if (parsed && typeof parsed === "object") { + return parsed as Record + } + } catch (error) { + consola.warn("Failed to parse function call arguments", { + error, + rawArguments, + }) + } + + return { raw_arguments: rawArguments } +} + +const fallbackContentBlocks = ( + outputText: string, +): Array => { + if (!outputText) { + return [] + } + + return [ + { + type: "text", + text: outputText, + }, + ] +} + +const mapResponsesStopReason = ( + response: ResponsesResult, +): AnthropicResponse["stop_reason"] => { + const { status, incomplete_details: incompleteDetails } = response + + if (status === "completed") { + return "end_turn" + } + + if (status === "incomplete") { + if (incompleteDetails?.reason === "max_output_tokens") { + return "max_tokens" + } + if (incompleteDetails?.reason === "content_filter") { + return "end_turn" + } + if (incompleteDetails?.reason === "tool_use") { + return "tool_use" + } + } + + return null +} + +const mapResponsesUsage = ( + response: ResponsesResult, +): AnthropicResponse["usage"] => { + const inputTokens = response.usage?.input_tokens ?? 0 + const outputTokens = response.usage?.output_tokens ?? 0 + const inputCachedTokens = response.usage?.input_tokens_details?.cached_tokens + + return { + input_tokens: inputTokens - (inputCachedTokens ?? 0), + output_tokens: outputTokens, + ...(response.usage?.input_tokens_details?.cached_tokens !== undefined && { + cache_read_input_tokens: + response.usage.input_tokens_details.cached_tokens, + }), + } +} + +const isRecord = (value: unknown): value is Record => + typeof value === "object" && value !== null + +const isResponseOutputText = ( + block: ResponseOutputContentBlock, +): block is ResponseOutputText => + isRecord(block) + && "type" in block + && (block as { type?: unknown }).type === "output_text" + +const isResponseOutputRefusal = ( + block: ResponseOutputContentBlock, +): block is ResponseOutputRefusal => + isRecord(block) + && "type" in block + && (block as { type?: unknown }).type === "refusal" + +const parseUserId = ( + userId: string | undefined, +): { safetyIdentifier: string | null; promptCacheKey: string | null } => { + if (!userId || typeof userId !== "string") { + return { safetyIdentifier: null, promptCacheKey: null } + } + + // Parse safety_identifier: content between "user_" and "_account" + const userMatch = userId.match(/user_([^_]+)_account/) + const safetyIdentifier = userMatch ? userMatch[1] : null + + // Parse prompt_cache_key: content after "_session_" + const sessionMatch = userId.match(/_session_(.+)$/) + const promptCacheKey = sessionMatch ? sessionMatch[1] : null + + return { safetyIdentifier, promptCacheKey } +} + +const convertToolResultContent = ( + content: string | Array | Array, +): string | Array => { + if (typeof content === "string") { + return content + } + + if (Array.isArray(content)) { + const result: Array = [] + for (const block of content) { + switch (block.type) { + case "text": { + result.push(createTextContent(block.text)) + break + } + case "image": { + result.push(createImageContent(block)) + break + } + default: { + break + } + } + } + return result + } + + return "" +} diff --git a/src/routes/messages/route.ts b/src/routes/messages/route.ts index 1f4eee2f9..ef72d802e 100644 --- a/src/routes/messages/route.ts +++ b/src/routes/messages/route.ts @@ -2,6 +2,7 @@ import { Hono } from "hono" import { forwardError } from "~/lib/error" +import { handleCountTokens } from "./count-tokens-handler" import { handleCompletion } from "./handler" export const messageRoutes = new Hono() @@ -13,3 +14,11 @@ messageRoutes.post("/", async (c) => { return await forwardError(c, error) } }) + +messageRoutes.post("/count_tokens", async (c) => { + try { + return await handleCountTokens(c) + } catch (error) { + return await forwardError(c, error) + } +}) diff --git a/src/routes/messages/stream-translation.ts b/src/routes/messages/stream-translation.ts index c8c20a07f..5f3a6a183 100644 --- a/src/routes/messages/stream-translation.ts +++ b/src/routes/messages/stream-translation.ts @@ -41,10 +41,6 @@ export function translateChunkToAnthropicEvents( model: chunk.model, stop_reason: null, stop_sequence: null, - usage: { - input_tokens: chunk.usage?.prompt_tokens ?? 0, - output_tokens: 0, // Will be updated in message_delta when finished - }, }, }) state.messageStartSent = true @@ -152,7 +148,9 @@ export function translateChunkToAnthropicEvents( stop_sequence: null, }, usage: { - input_tokens: chunk.usage?.prompt_tokens ?? 0, + input_tokens: + (chunk.usage?.prompt_tokens ?? 0) + - (chunk.usage?.prompt_tokens_details?.cached_tokens ?? 0), output_tokens: chunk.usage?.completion_tokens ?? 0, ...(chunk.usage?.prompt_tokens_details?.cached_tokens !== undefined && { diff --git a/src/routes/responses/handler.ts b/src/routes/responses/handler.ts new file mode 100644 index 000000000..d06d02d67 --- /dev/null +++ b/src/routes/responses/handler.ts @@ -0,0 +1,94 @@ +import type { Context } from "hono" + +import consola from "consola" +import { streamSSE } from "hono/streaming" + +import { awaitApproval } from "~/lib/approval" +import { checkRateLimit } from "~/lib/rate-limit" +import { state } from "~/lib/state" +import { + createResponses, + type ResponsesPayload, + type ResponsesResult, +} from "~/services/copilot/create-responses" + +import { getResponsesRequestOptions } from "./utils" + +const RESPONSES_ENDPOINT = "/responses" + +export const handleResponses = async (c: Context) => { + await checkRateLimit(state) + + const payload = await c.req.json() + consola.debug("Responses request payload:", JSON.stringify(payload)) + + const selectedModel = state.models?.data.find( + (model) => model.id === payload.model, + ) + const supportsResponses = + selectedModel?.supported_endpoints?.includes(RESPONSES_ENDPOINT) ?? false + + if (!supportsResponses) { + return c.json( + { + error: { + message: + "This model does not support the responses endpoint. Please choose a different model.", + type: "invalid_request_error", + }, + }, + 400, + ) + } + + const { vision, initiator } = getResponsesRequestOptions(payload) + + if (state.manualApprove) { + await awaitApproval() + } + + const response = await createResponses(payload, { vision, initiator }) + + if (isStreamingRequested(payload) && isAsyncIterable(response)) { + consola.debug("Forwarding native Responses stream") + return streamSSE(c, async (stream) => { + const pingInterval = setInterval(async () => { + try { + await stream.writeSSE({ + event: "ping", + data: JSON.stringify({ timestamp: Date.now() }), + }) + } catch (error) { + consola.warn("Failed to send ping:", error) + clearInterval(pingInterval) + } + }, 3000) + + try { + for await (const chunk of response) { + consola.debug("Responses stream chunk:", JSON.stringify(chunk)) + await stream.writeSSE({ + id: (chunk as { id?: string }).id, + event: (chunk as { event?: string }).event, + data: (chunk as { data?: string }).data ?? "", + }) + } + } finally { + clearInterval(pingInterval) + } + }) + } + + consola.debug( + "Forwarding native Responses result:", + JSON.stringify(response).slice(-400), + ) + return c.json(response as ResponsesResult) +} + +const isAsyncIterable = (value: unknown): value is AsyncIterable => + Boolean(value) + && typeof (value as AsyncIterable)[Symbol.asyncIterator] === "function" + +const isStreamingRequested = (payload: ResponsesPayload): boolean => + Boolean(payload.stream) diff --git a/src/routes/responses/route.ts b/src/routes/responses/route.ts new file mode 100644 index 000000000..af2423427 --- /dev/null +++ b/src/routes/responses/route.ts @@ -0,0 +1,15 @@ +import { Hono } from "hono" + +import { forwardError } from "~/lib/error" + +import { handleResponses } from "./handler" + +export const responsesRoutes = new Hono() + +responsesRoutes.post("/", async (c) => { + try { + return await handleResponses(c) + } catch (error) { + return await forwardError(c, error) + } +}) diff --git a/src/routes/responses/utils.ts b/src/routes/responses/utils.ts new file mode 100644 index 000000000..734319cd7 --- /dev/null +++ b/src/routes/responses/utils.ts @@ -0,0 +1,67 @@ +import type { + ResponseInputItem, + ResponsesPayload, +} from "~/services/copilot/create-responses" + +export const getResponsesRequestOptions = ( + payload: ResponsesPayload, +): { vision: boolean; initiator: "agent" | "user" } => { + const vision = hasVisionInput(payload) + const initiator = hasAgentInitiator(payload) ? "agent" : "user" + + return { vision, initiator } +} + +export const hasAgentInitiator = (payload: ResponsesPayload): boolean => + getPayloadItems(payload).some((item) => { + if (!("role" in item) || !item.role) { + return true + } + const role = typeof item.role === "string" ? item.role.toLowerCase() : "" + return role === "assistant" + }) + +export const hasVisionInput = (payload: ResponsesPayload): boolean => { + const values = getPayloadItems(payload) + return values.some((item) => containsVisionContent(item)) +} + +const getPayloadItems = ( + payload: ResponsesPayload, +): Array => { + const result: Array = [] + + const { input } = payload + + if (Array.isArray(input)) { + result.push(...input) + } + + return result +} + +const containsVisionContent = (value: unknown): boolean => { + if (!value) return false + + if (Array.isArray(value)) { + return value.some((entry) => containsVisionContent(entry)) + } + + if (typeof value !== "object") { + return false + } + + const record = value as Record + const type = + typeof record.type === "string" ? record.type.toLowerCase() : undefined + + if (type === "input_image") { + return true + } + + if (Array.isArray(record.content)) { + return record.content.some((entry) => containsVisionContent(entry)) + } + + return false +} diff --git a/src/server.ts b/src/server.ts index 3cb2bb860..b7fea00e4 100644 --- a/src/server.ts +++ b/src/server.ts @@ -4,8 +4,10 @@ import { logger } from "hono/logger" import { completionRoutes } from "./routes/chat-completions/route" import { embeddingRoutes } from "./routes/embeddings/route" +import { geminiRouter } from "./routes/generate-content/route" import { messageRoutes } from "./routes/messages/route" import { modelRoutes } from "./routes/models/route" +import { responsesRoutes } from "./routes/responses/route" import { tokenRoute } from "./routes/token/route" import { usageRoute } from "./routes/usage/route" @@ -21,12 +23,16 @@ server.route("/models", modelRoutes) server.route("/embeddings", embeddingRoutes) server.route("/usage", usageRoute) server.route("/token", tokenRoute) +server.route("/responses", responsesRoutes) // Compatibility with tools that expect v1/ prefix server.route("/v1/chat/completions", completionRoutes) server.route("/v1/models", modelRoutes) server.route("/v1/embeddings", embeddingRoutes) +server.route("/v1/responses", responsesRoutes) // Anthropic compatible endpoints server.route("/v1/messages", messageRoutes) -server.post("/v1/messages/count_tokens", (c) => c.json({ input_tokens: 1 })) + +// Gemini +server.route("/", geminiRouter) diff --git a/src/services/copilot/create-chat-completions.ts b/src/services/copilot/create-chat-completions.ts index 5d38bb452..8534151da 100644 --- a/src/services/copilot/create-chat-completions.ts +++ b/src/services/copilot/create-chat-completions.ts @@ -103,6 +103,9 @@ export interface ChatCompletionResponse { prompt_tokens: number completion_tokens: number total_tokens: number + prompt_tokens_details?: { + cached_tokens: number + } } } diff --git a/src/services/copilot/create-responses.ts b/src/services/copilot/create-responses.ts new file mode 100644 index 000000000..8322cacee --- /dev/null +++ b/src/services/copilot/create-responses.ts @@ -0,0 +1,204 @@ +import consola from "consola" +import { events } from "fetch-event-stream" + +import { copilotBaseUrl, copilotHeaders } from "~/lib/api-config" +import { HTTPError } from "~/lib/error" +import { state } from "~/lib/state" + +export interface ResponsesPayload { + model: string + instructions?: string | null + input?: string | Array + tools?: Array> | null + tool_choice?: unknown + temperature?: number | null + top_p?: number | null + max_output_tokens?: number | null + metadata?: Record | null + stream?: boolean | null + response_format?: Record | null + safety_identifier?: string | null + prompt_cache_key?: string | null + parallel_tool_calls?: boolean | null + store?: boolean | null + reasoning?: Record | null + include?: Array + [key: string]: unknown +} + +export interface ResponseInputMessage { + type?: "message" + role: "user" | "assistant" | "system" | "developer" + content?: string | Array + status?: string +} + +export interface ResponseFunctionToolCallItem { + type: "function_call" + call_id: string + name: string + arguments: string + status?: "in_progress" | "completed" | "incomplete" +} + +export interface ResponseFunctionCallOutputItem { + type: "function_call_output" + call_id: string + output: string | Array + status?: "in_progress" | "completed" | "incomplete" +} + +export interface ResponseInputReasoning { + type: "reasoning" + summary: Array<{ + type: "summary_text" + text: string + }> + encrypted_content: string +} + +export type ResponseInputItem = + | ResponseInputMessage + | ResponseFunctionToolCallItem + | ResponseFunctionCallOutputItem + | ResponseInputReasoning + | Record + +export type ResponseInputContent = + | ResponseInputText + | ResponseInputImage + | Record + +export interface ResponseInputText { + type?: "input_text" | "output_text" + text: string +} + +export interface ResponseInputImage { + type: "input_image" + image_url?: string | null + file_id?: string | null + detail?: "low" | "high" | "auto" +} + +export interface ResponsesResult { + id: string + object: "response" + created_at: number + model: string + output: Array + output_text: string + status: string + usage?: ResponseUsage | null + error: Record | null + incomplete_details: Record | null + instructions: string | null + metadata: Record | null + parallel_tool_calls: boolean + temperature: number | null + tool_choice: unknown + tools: Array> + top_p: number | null +} + +export type ResponseOutputItem = + | ResponseOutputMessage + | ResponseOutputReasoning + | ResponseOutputFunctionCall + +export interface ResponseOutputMessage { + id: string + type: "message" + role: "assistant" + status: "completed" | "in_progress" | "incomplete" + content?: Array +} + +export interface ResponseOutputReasoning { + id: string + type: "reasoning" + summary?: Array + encrypted_content?: string + status: "completed" | "in_progress" | "incomplete" + [key: string]: unknown +} + +export interface ResponseReasoningBlock { + type: string + text?: string +} + +export interface ResponseOutputFunctionCall { + id: string + type: "function_call" + call_id?: string + name: string + arguments: string + status?: "in_progress" | "completed" | "incomplete" + [key: string]: unknown +} + +export type ResponseOutputContentBlock = + | ResponseOutputText + | ResponseOutputRefusal + | Record + +export interface ResponseOutputText { + type: "output_text" + text: string + annotations: Array +} + +export interface ResponseOutputRefusal { + type: "refusal" + refusal: string +} + +export interface ResponseUsage { + input_tokens: number + output_tokens?: number + total_tokens: number + input_tokens_details?: { + cached_tokens: number + } + output_tokens_details?: { + reasoning_tokens: number + } +} + +export type ResponsesStream = ReturnType +export type CreateResponsesReturn = ResponsesResult | ResponsesStream + +interface ResponsesRequestOptions { + vision: boolean + initiator: "agent" | "user" +} + +export const createResponses = async ( + payload: ResponsesPayload, + { vision, initiator }: ResponsesRequestOptions, +): Promise => { + if (!state.copilotToken) throw new Error("Copilot token not found") + + const headers: Record = { + ...copilotHeaders(state, vision), + "X-Initiator": initiator, + } + + const response = await fetch(`${copilotBaseUrl(state)}/responses`, { + method: "POST", + headers, + body: JSON.stringify(payload), + }) + + if (!response.ok) { + consola.error("Failed to create responses", response) + throw new HTTPError("Failed to create responses", response) + } + + if (payload.stream) { + return events(response) + } + + return (await response.json()) as ResponsesResult +} diff --git a/src/services/copilot/get-models.ts b/src/services/copilot/get-models.ts index 792adc480..3690ad3f5 100644 --- a/src/services/copilot/get-models.ts +++ b/src/services/copilot/get-models.ts @@ -28,6 +28,9 @@ interface ModelSupports { tool_calls?: boolean parallel_tool_calls?: boolean dimensions?: boolean + streaming?: boolean + structured_outputs?: boolean + vision?: boolean } interface ModelCapabilities { @@ -39,7 +42,7 @@ interface ModelCapabilities { type: string } -interface Model { +export interface Model { capabilities: ModelCapabilities id: string model_picker_enabled: boolean @@ -52,4 +55,5 @@ interface Model { state: string terms: string } + supported_endpoints?: Array } diff --git a/src/start.ts b/src/start.ts index a1b02303e..296e1b157 100644 --- a/src/start.ts +++ b/src/start.ts @@ -84,7 +84,11 @@ export async function runServer(options: RunServerOptions): Promise { ANTHROPIC_BASE_URL: serverUrl, ANTHROPIC_AUTH_TOKEN: "dummy", ANTHROPIC_MODEL: selectedModel, + ANTHROPIC_DEFAULT_SONNET_MODEL: selectedModel, ANTHROPIC_SMALL_FAST_MODEL: selectedSmallModel, + ANTHROPIC_DEFAULT_HAIKU_MODEL: selectedSmallModel, + DISABLE_NON_ESSENTIAL_MODEL_CALLS: "1", + CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC: "1", }, "claude", ) @@ -103,10 +107,12 @@ export async function runServer(options: RunServerOptions): Promise { consola.box( `🌐 Usage Viewer: https://ericc-ch.github.io/copilot-api?endpoint=${serverUrl}/usage`, ) - serve({ fetch: server.fetch as ServerHandler, port: options.port, + bun: { + idleTimeout: 255, // gemini timeout + }, }) } diff --git a/tests/@types/server-with-query.d.ts b/tests/@types/server-with-query.d.ts new file mode 100644 index 000000000..a8f40adf6 --- /dev/null +++ b/tests/@types/server-with-query.d.ts @@ -0,0 +1,4 @@ +// Allow importing "~/server?foo" variants in tests without impacting runtime behavior. +declare module "~/server?*" { + export const server: import("hono").Hono +} diff --git a/tests/anthropic-request.test.ts b/tests/anthropic-request.test.ts index a4a5b06b5..c86bcac13 100644 --- a/tests/anthropic-request.test.ts +++ b/tests/anthropic-request.test.ts @@ -136,6 +136,7 @@ describe("Anthropic to OpenAI translation logic", () => { { type: "thinking", thinking: "Let me think about this simple math problem...", + signature: "abc123", }, { type: "text", text: "2+2 equals 4." }, ], @@ -168,6 +169,7 @@ describe("Anthropic to OpenAI translation logic", () => { type: "thinking", thinking: "I need to call the weather API to get current weather information.", + signature: "def456", }, { type: "text", text: "I'll check the weather for you." }, { diff --git a/tests/anthropic-response.test.ts b/tests/anthropic-response.test.ts index 352f06ea7..247b554ae 100644 --- a/tests/anthropic-response.test.ts +++ b/tests/anthropic-response.test.ts @@ -100,7 +100,10 @@ describe("OpenAI to Anthropic Non-Streaming Response Translation", () => { expect(anthropicResponse.id).toBe("chatcmpl-123") expect(anthropicResponse.stop_reason).toBe("end_turn") - expect(anthropicResponse.usage.input_tokens).toBe(9) + expect(anthropicResponse.usage).toBeDefined() + if (anthropicResponse.usage) { + expect(anthropicResponse.usage.input_tokens).toBe(9) + } expect(anthropicResponse.content[0].type).toBe("text") if (anthropicResponse.content[0].type === "text") { expect(anthropicResponse.content[0].text).toBe( diff --git a/tests/generate-content/_test-utils.ts b/tests/generate-content/_test-utils.ts new file mode 100644 index 000000000..11a1dcdf1 --- /dev/null +++ b/tests/generate-content/_test-utils.ts @@ -0,0 +1,110 @@ +import { mock } from "bun:test" + +import type { + TestServer, + MockChatCompletionsModule, + MockRateLimitModule, + MockTokenCountModule, +} from "./test-types" + +export function asyncIterableFrom( + events: Array<{ data?: string }>, +): AsyncIterable<{ data: string }> { + return { + [Symbol.asyncIterator]() { + let i = 0 + return { + next(): Promise> { + if (i < events.length) { + const event = events[i++] + return Promise.resolve({ + value: { data: event.data ?? "" }, + done: false, + }) + } + return Promise.resolve({ + value: undefined as unknown as { data: string }, + done: true, + }) + }, + } + }, + } +} + +export function createMockChatCompletions(events: Array<{ data?: string }>) { + return mock.module( + "~/services/copilot/create-chat-completions", + (): MockChatCompletionsModule => ({ + createChatCompletions: () => asyncIterableFrom(events), + }), + ) +} + +export function createMockRateLimit() { + return mock.module( + "~/lib/rate-limit", + (): MockRateLimitModule => ({ + checkRateLimit: (_: unknown) => {}, + }), + ) +} + +export function createMockTokenCount(tokens: { + input: number + output: number +}) { + return mock.module( + "~/services/copilot/get-token-count", + (): MockTokenCountModule => ({ + getTokenCount: () => tokens, + }), + ) +} + +export async function makeStreamRequest( + path: string, + body: Record, + queryString?: string, +): Promise { + const serverModule = (await import(`~/server?${queryString}`)) as { + server: TestServer + } + return serverModule.server.request(path, { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify(body), + }) +} + +export async function makeRequest( + path: string, + body: Record, + queryString?: string, +): Promise { + const serverModule = (await import(`~/server?${queryString}`)) as { + server: TestServer + } + return serverModule.server.request(path, { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify(body), + }) +} + +export const commonResponseData = { + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, +} + +export const sampleGeminiRequest = { + contents: [{ role: "user", parts: [{ text: "Hello" }] }], +} + +export const sampleToolCall = { + index: 0, + type: "function", + function: { + name: "ReadFile", + arguments: '{"absolute_path": "/path/to/file.txt"}', + }, +} diff --git a/tests/generate-content/core-functionality.test.ts b/tests/generate-content/core-functionality.test.ts new file mode 100644 index 000000000..7f05608e8 --- /dev/null +++ b/tests/generate-content/core-functionality.test.ts @@ -0,0 +1,220 @@ +import { afterEach, expect, test, mock } from "bun:test" + +import type { TestServer } from "./test-types" + +import { createMockChatCompletions } from "./_test-utils" + +afterEach(() => { + mock.restore() +}) + +test("translates request and uses local tokenizer without downstream call", async () => { + let downstreamCalled = false + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + downstreamCalled = true + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + await mock.module("~/lib/tokenizer", () => ({ + getTokenCount: (_: unknown) => ({ input: 2, output: 3 }), + })) + + const { server } = (await import("~/server")) as { server: TestServer } + const res = await server.request("/v1beta/models/gemini-pro:countTokens", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }) + + expect(res.status).toBe(200) + const json = (await res.json()) as { totalTokens: number } + expect(json).toEqual({ totalTokens: 5 }) + expect(downstreamCalled).toBe(false) +}) + +test("maps finish_reason stop/length/content_filter/tool_calls correctly (non-stream)", async () => { + const finishCases = [ + { fr: "stop", expected: "STOP" }, + { fr: "length", expected: "MAX_TOKENS" }, + { fr: "content_filter", expected: "SAFETY" }, + { fr: "tool_calls", expected: "STOP" }, + ] + + let idx = 0 + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + const fr = finishCases[idx++].fr as + | "stop" + | "length" + | "content_filter" + | "tool_calls" + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: fr, + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 2, total_tokens: 3 }, + } + }, + })) + + const { server } = (await import("~/server")) as { server: TestServer } + for (const finishCase of finishCases) { + const res = await server.request( + "/v1beta/models/gemini-pro:generateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + expect(res.status).toBe(200) + const json = (await res.json()) as { + candidates: [{ finishReason: string }] + } + const got = json.candidates[0].finishReason + expect(got).toBe(finishCase.expected) + } +}) + +test("optional manual approval gate triggers before downstream call", async () => { + const calls: Array = [] + await mock.module("~/lib/state", () => ({ + state: { manualApprove: true }, + })) + await mock.module("~/lib/approval", () => ({ + awaitApproval: () => { + calls.push("approve") + }, + })) + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + calls.push("create") + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + const { server } = (await import("~/server")) as { server: TestServer } + const res = await server.request( + "/v1beta/models/gemini-pro:generateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + expect(calls).toEqual(["approve", "create"]) +}) + +test("enforces rate limit before processing (non-stream)", async () => { + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: () => { + throw new Error("Rate limited") + }, + })) + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + const { server } = (await import("~/server")) as { server: TestServer } + const res = await server.request( + "/v1beta/models/gemini-pro:generateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + + expect(res.status).toBe(500) + const json = (await res.json()) as { + error: { message: string; type: string } + } + expect(json).toEqual({ error: { message: "Rate limited", type: "error" } }) +}) + +test("enforces rate limit before stream", async () => { + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: () => { + throw new Error("Rate limited stream") + }, + })) + await createMockChatCompletions([ + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: { content: "x" }, finish_reason: null }], + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + }), + }, + { data: "[DONE]" }, + ]) + + const { server } = (await import("~/server")) as { server: TestServer } + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + + expect(res.status).toBe(500) + const txt = await res.text() + expect(txt.includes("Rate limited stream")).toBe(true) +}) diff --git a/tests/generate-content/route-routing.test.ts b/tests/generate-content/route-routing.test.ts new file mode 100644 index 000000000..cb73173d8 --- /dev/null +++ b/tests/generate-content/route-routing.test.ts @@ -0,0 +1,162 @@ +import { afterEach, expect, test, mock } from "bun:test" + +function asyncIterableFrom(events: Array<{ data?: string }>) { + return { + [Symbol.asyncIterator]() { + let i = 0 + return { + next() { + if (i < events.length) + return Promise.resolve({ value: events[i++], done: false }) + return Promise.resolve({ value: undefined, done: true }) + }, + } + }, + } +} + +afterEach(() => { + mock.restore() +}) + +test("routes to stream endpoint based on URL keyword", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (_: unknown) => + asyncIterableFrom([ + { + data: JSON.stringify({ + id: "c1", + choices: [ + { index: 0, delta: { content: "hi" }, finish_reason: null }, + ], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { data: "[DONE]" }, + ]), + })) + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: () => {}, + })) + const { server } = await import("~/server?route-routing") + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + expect(res.status).toBe(200) + const ct = res.headers.get("content-type") || "" + expect(ct.includes("text/event-stream")).toBe(true) + const body = await res.text() + expect(body.includes("data:")).toBe(true) + expect(body.includes('"role":"model"')).toBe(true) +}) + +test("routes to countTokens endpoint based on URL keyword", async () => { + await mock.module("~/lib/tokenizer", () => ({ + getTokenCount: (_: unknown) => ({ input: 2, output: 3 }), + })) + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: () => {}, + })) + const { server } = await import("~/server?route-routing") + const res = await server.request("/v1beta/models/gemini-pro:countTokens", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }) + expect(res.status).toBe(200) + const json = + (await res.json()) as import("~/routes/generate-content/types").GeminiCountTokensResponse + expect(json).toEqual({ totalTokens: 5 }) +}) + +test("routes to non-stream endpoint with path exclusivity", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (_: unknown) => ({ + id: "res-2", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + }), + })) + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: () => {}, + })) + const { server } = await import("~/server?route-routing") + const res = await server.request( + "/v1beta/models/gemini-pro:generateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + expect(res.status).toBe(200) + const ct = res.headers.get("content-type") || "" + expect(ct.includes("application/json")).toBe(true) + const json = + (await res.json()) as import("~/routes/generate-content/types").GeminiResponse + expect(Array.isArray(json.candidates)).toBe(true) +}) + +test("does NOT mis-route to non-stream endpoint", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (_: unknown) => + asyncIterableFrom([ + { + data: JSON.stringify({ + id: "c1", + choices: [ + { index: 0, delta: { content: "x" }, finish_reason: null }, + ], + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + }), + }, + { data: "[DONE]" }, + ]), + })) + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: () => {}, + })) + const { server } = await import("~/server?route-routing") + const res = await server.request( + "/v1beta/models/gemini-pro:generateContent:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + expect(res.status).toBe(200) + const ct = res.headers.get("content-type") || "" + expect(ct.includes("text/event-stream")).toBe(true) +}) diff --git a/tests/generate-content/stream-tool-call-accumulator.test.ts b/tests/generate-content/stream-tool-call-accumulator.test.ts new file mode 100644 index 000000000..9fbde349b --- /dev/null +++ b/tests/generate-content/stream-tool-call-accumulator.test.ts @@ -0,0 +1,240 @@ +import { afterEach, expect, test, mock } from "bun:test" + +function asyncIterableFrom(events: Array<{ data?: string }>) { + return { + [Symbol.asyncIterator]() { + let i = 0 + return { + next() { + if (i < events.length) + return Promise.resolve({ value: events[i++], done: false }) + return Promise.resolve({ value: undefined, done: true }) + }, + } + }, + } +} + +afterEach(() => { + mock.restore() +}) + +test("[Stream] handles complete tool call parameters in single chunk", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => + asyncIterableFrom([ + { + data: JSON.stringify({ + id: "c1", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + type: "function", + function: { + name: "ReadFile", + arguments: '{"absolute_path": "/path/to/file.txt"}', + }, + }, + ], + }, + finish_reason: null, + }, + ], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { data: "[DONE]" }, + ]), + })) + + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: (_: unknown) => {}, + })) + const { server } = await import("~/server?stream-complete-params") + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "Read the file" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + const body = await res.text() + expect( + body.includes( + '"functionCall":{"name":"ReadFile","args":{"absolute_path":"/path/to/file.txt"}}', + ), + ).toBe(true) +}) + +test("[Stream] handles fragmented tool call parameters across multiple chunks", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => + asyncIterableFrom([ + { + data: JSON.stringify({ + id: "c1", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + type: "function", + function: { name: "ReadFile", arguments: '{"absolu' }, + }, + ], + }, + finish_reason: null, + }, + ], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + type: "function", + function: { arguments: 'te_path": "/file.txt"}' }, + }, + ], + }, + finish_reason: null, + }, + ], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { data: "[DONE]" }, + ]), + })) + + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: (_: unknown) => {}, + })) + const { server } = await import("~/server?stream-fragmented-params") + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "Read the file" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + const body = await res.text() + expect( + body.includes( + '"functionCall":{"name":"ReadFile","args":{"absolute_path":"/file.txt"}}', + ), + ).toBe(true) +}) + +test("[Stream] correctly processes multiple concurrent tool calls", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => + asyncIterableFrom([ + { + data: JSON.stringify({ + id: "c1", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + type: "function", + function: { + name: "ReadFile", + arguments: '{"path": "/read.txt"}', + }, + }, + { + index: 1, + type: "function", + function: { + name: "WriteFile", + arguments: '{"path": "/write.txt", "content": "data"}', + }, + }, + ], + }, + finish_reason: null, + }, + ], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { data: "[DONE]" }, + ]), + })) + + await mock.module("~/lib/rate-limit", () => ({ + checkRateLimit: (_: unknown) => {}, + })) + const { server } = await import("~/server?stream-multiple-tools") + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "Read and write files" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + const body = await res.text() + expect( + body.includes( + '"functionCall":{"name":"ReadFile","args":{"path":"/read.txt"}}', + ), + ).toBe(true) + expect( + body.includes( + '"functionCall":{"name":"WriteFile","args":{"path":"/write.txt","content":"data"}}', + ), + ).toBe(true) +}) diff --git a/tests/generate-content/streaming.test.ts b/tests/generate-content/streaming.test.ts new file mode 100644 index 000000000..b55030175 --- /dev/null +++ b/tests/generate-content/streaming.test.ts @@ -0,0 +1,239 @@ +import { afterEach, expect, test, mock } from "bun:test" + +import type { TestServer } from "./test-types" + +import { + asyncIterableFrom, + createMockChatCompletions, + createMockRateLimit, +} from "./_test-utils" + +afterEach(() => { + mock.restore() +}) + +test("falls back to streaming when downstream returns non-stream JSON", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (_: unknown) => ({ + id: "res-3", + choices: [ + { + index: 0, + message: { role: "assistant", content: "stream me" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 }, + }), + })) + + await createMockRateLimit() + const { server } = (await import("~/server?fallback-non-streaming")) as { + server: TestServer + } + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + const ct = res.headers.get("content-type") || "" + expect(ct.includes("text/event-stream")).toBe(true) + const body = await res.text() + + expect(body.includes("data:")).toBe(true) + expect(body.includes("stream me")).toBe(true) + expect(body.includes('"finishReason":"STOP"')).toBe(true) + expect(body.includes('"usageMetadata"')).toBe(true) + + const occurrences = (body.match(/stream me/g) || []).length + expect(occurrences >= 1).toBe(true) +}) + +test("accumulates and parses partial JSON chunks", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (_: unknown) => { + const firstChunk = { + id: "c1", + choices: [ + { index: 0, delta: { content: "hello" }, finish_reason: null }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + const json = JSON.stringify(firstChunk) + const mid = Math.floor(json.length / 2) + const finishChunk = { + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + return asyncIterableFrom([ + { data: json.slice(0, mid) }, + { data: json.slice(mid) }, + { data: JSON.stringify(finishChunk) }, + { data: "[DONE]" }, + ]) + }, + })) + + await createMockRateLimit() + const { server } = (await import( + "~/server?streaming-parser-accumulation" + )) as { server: TestServer } + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + const ct = res.headers.get("content-type") || "" + expect(ct.includes("text/event-stream")).toBe(true) + const body = await res.text() + + const helloCount = (body.match(/hello/g) || []).length + expect(helloCount).toBe(1) + + expect(body.includes('"finishReason":"STOP"')).toBe(true) + expect(body.includes("data:")).toBe(true) +}) + +test("includes usageMetadata only on final chunk and injects empty part when only finish_reason", async () => { + await createMockChatCompletions([ + { + data: JSON.stringify({ + id: "c1", + choices: [ + { index: 0, delta: { content: "hello" }, finish_reason: null }, + ], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 1, completion_tokens: 2, total_tokens: 3 }, + }), + }, + { data: "[DONE]" }, + ]) + + await createMockRateLimit() + const { server } = (await import( + "~/server?stream-finish-reason-and-empty-part" + )) as { server: TestServer } + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + const body = await res.text() + + const usageCount = (body.match(/"usageMetadata"/g) || []).length + expect(usageCount).toBe(1) + + const finishStop = body.includes('"finishReason":"STOP"') + expect(finishStop).toBe(true) + + const injectedEmpty = body.includes('"parts":[{"text":""}]') + expect(injectedEmpty).toBe(true) +}) + +test("[Stream] skips tool_calls with partial JSON arguments until complete", async () => { + await createMockChatCompletions([ + { + data: JSON.stringify({ + id: "c1", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + type: "function", + function: { name: "f", arguments: '{"a":' }, + }, + ], + }, + finish_reason: null, + }, + ], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [ + { + index: 0, + delta: { + tool_calls: [ + { + index: 0, + type: "function", + function: { name: "f", arguments: '{"a":1}' }, + }, + ], + }, + finish_reason: null, + }, + ], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + }), + }, + { data: "[DONE]" }, + ]) + + await createMockRateLimit() + const { server } = (await import( + "~/server?stream-skip-partial-tool-calls" + )) as { + server: TestServer + } + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + const body = await res.text() + + expect(body.includes('"functionCall":{"name":"f","args"')).toBe(true) + expect(body.includes('"functionCall":{"name":"f","args":{')).toBe(true) + expect(body.includes('"functionCall":{"name":"f","args":{')).toBe(true) + expect(body.includes('"functionCall":{"name":"f","args":{"a":1}')).toBe(true) +}) diff --git a/tests/generate-content/test-types.ts b/tests/generate-content/test-types.ts new file mode 100644 index 000000000..4fc8df3f7 --- /dev/null +++ b/tests/generate-content/test-types.ts @@ -0,0 +1,74 @@ +import type { + ChatCompletionResponse, + ChatCompletionsPayload, +} from "~/services/copilot/create-chat-completions" + +// Test utility types +export interface TestServer { + request: ( + url: string, + options: { method: string; headers: Record; body: string }, + ) => Promise +} + +export interface MockChatCompletionsModule { + createChatCompletions: ( + payload: ChatCompletionsPayload, + ) => ChatCompletionResponse | AsyncIterable<{ data: string }> +} + +export interface MockRateLimitModule { + checkRateLimit: (payload: unknown) => void +} + +export interface MockTokenCountModule { + getTokenCount: () => { input: number; output: number } +} + +// Common test data types +export interface CapturedPayload extends Record { + messages?: Array<{ + role: string + content: string + tool_calls?: Array<{ + id: string + type: string + function: { name: string; arguments: string } + }> + tool_call_id?: string + }> + tools?: Array<{ + type: string + function: { name: string; parameters: Record } + }> + tool_choice?: string + model?: string +} + +// Gemini request types for tests +export interface GeminiTestRequest { + contents: Array<{ + role: "user" | "model" + parts: Array< + | { text: string } + | { functionCall: { name: string; args: Record } } + | { + functionResponse: { name: string; response: Record } + } + > + }> + tools?: Array<{ + functionDeclarations?: Array<{ + name: string + parameters: { type: string; properties?: Record } + }> + urlContext?: Record + }> + toolConfig?: { + functionCallingConfig: { mode: "AUTO" | "ANY" | "NONE" } + } + systemInstruction?: { + parts: Array<{ text: string }> + } + model?: string +} diff --git a/tests/generate-content/translation-coverage.test.ts b/tests/generate-content/translation-coverage.test.ts new file mode 100644 index 000000000..b36cdb2ce --- /dev/null +++ b/tests/generate-content/translation-coverage.test.ts @@ -0,0 +1,731 @@ +import { afterEach, expect, test, mock } from "bun:test" + +import type { CapturedPayload } from "./test-types" + +import { makeRequest } from "./_test-utils" + +afterEach(() => { + mock.restore() +}) + +test("processes function response arrays with tool call matching", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should correctly process nested function response arrays + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Call function" }] }, + { + role: "model", + parts: [ + { + functionCall: { name: "testFunc", args: { param: "value" } }, + }, + ], + }, + { + role: "user", + parts: [ + [ + { + functionResponse: { + name: "testFunc", + response: { result: "success" }, + }, + }, + ], + ], + }, + ], + }) + + expect(res.status).toBe(200) + // Verify nested array structure is processed correctly + const messages = capturedPayload.messages ?? [] + expect(messages.length).toBeGreaterThan(0) + + // Should successfully parse and process nested function response arrays + // The actual message structure depends on cleanup logic + // Key is that the request succeeds and messages are generated + const userMessages = messages.filter((m) => m.role === "user") + expect(userMessages.length).toBeGreaterThan(0) +}) + +test("handles function response without matching tool call", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should skip function responses without matching tool calls + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Call function" }] }, + { + role: "user", + parts: [ + { + functionResponse: { + name: "nonExistentFunc", + response: { result: "orphaned" }, + }, + }, + ], + }, + ], + }) + + expect(res.status).toBe(200) + const toolMessages = + capturedPayload.messages?.filter((m) => m.role === "tool") ?? [] + expect(toolMessages.length).toBe(0) + + // Verify user messages are still processed + const userMessages = + capturedPayload.messages?.filter((m) => m.role === "user") ?? [] + expect(userMessages.length).toBeGreaterThan(0) + expect(userMessages[0]?.content).toContain("Call function") +}) + +test("handles empty content merging fallback", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should merge empty and whitespace-only content correctly + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "" }] }, // Empty text + { role: "user", parts: [{ text: " " }] }, // Whitespace only + { role: "user", parts: [{ text: "actual question" }] }, + ], + }) + + expect(res.status).toBe(200) + const userMessages = + capturedPayload.messages?.filter((m) => m.role === "user") ?? [] + expect(userMessages.length).toBe(1) + expect(userMessages[0]?.content).toContain("actual question") + // Ensure empty/whitespace content doesn't appear in merged message + expect(userMessages[0]?.content).not.toMatch(/^\s*$/) +}) + +test("handles complex content that cannot be merged", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should handle complex content mixing text and function responses + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "First message" }] }, + { + role: "user", + parts: [ + { text: "Second message" }, + { + functionResponse: { + name: "func", + response: { data: "complex" }, + }, + }, + ], + }, + ], + }) + + expect(res.status).toBe(200) + const messages = capturedPayload.messages ?? [] + expect(messages.length).toBeGreaterThan(0) + + // Verify text messages are merged but function responses are handled separately + const userMessages = messages.filter((m) => m.role === "user") + expect(userMessages.length).toBeGreaterThan(0) + const mergedContent = userMessages.map((m) => m.content).join(" ") + expect(mergedContent).toContain("First message") + expect(mergedContent).toContain("Second message") +}) + +test("maps unsupported Gemini model names to supported ones", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should map unsupported model names to supported equivalents + const res = await makeRequest( + "/v1beta/models/gemini-2.5-flash:generateContent", + { + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }, + ) + + expect(res.status).toBe(200) + expect(capturedPayload.model).toBe("gemini-2.0-flash-001") +}) + +test("preserves supported model names without mapping", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should preserve already supported model names + const res = await makeRequest( + "/v1beta/models/gemini-1.5-pro:generateContent", + { + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }, + ) + + expect(res.status).toBe(200) + expect(capturedPayload.model).toBe("gemini-1.5-pro") +}) + +test("handles tool call cleanup with incomplete tool calls", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should clean up incomplete tool calls (tool_calls without responses) + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Search something" }] }, + { + role: "model", + parts: [{ functionCall: { name: "search", args: { query: "test" } } }], + }, + { role: "user", parts: [{ text: "What did you find?" }] }, + ], + }) + + expect(res.status).toBe(200) + // Incomplete tool calls should be removed + const assistantMessages = + capturedPayload.messages?.filter((m) => m.role === "assistant") ?? [] + expect(assistantMessages.length).toBe(0) + + // User messages should still be present + const userMessages = + capturedPayload.messages?.filter((m) => m.role === "user") ?? [] + expect(userMessages.length).toBeGreaterThan(0) +}) + +test("processes inline data with inlineData field", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should process inline data (base64-encoded images) correctly + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { + role: "user", + parts: [ + { text: "Analyze this image" }, + { + inlineData: { + mimeType: "image/jpeg", + data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", + }, + }, + ], + }, + ], + }) + + expect(res.status).toBe(200) + expect(capturedPayload.messages?.length).toBe(1) + + const userMessage = capturedPayload.messages?.[0] + expect(userMessage?.role).toBe("user") + // Content should include both text and image data + const content = userMessage?.content + expect(content).toBeDefined() + expect(typeof content === "string" || Array.isArray(content)).toBe(true) +}) + +test("handles streaming tool calls with incomplete arguments", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // This tests the streaming tool call processing logic + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [{ role: "user", parts: [{ text: "Do a search" }] }], + }) + + expect(res.status).toBe(200) +}) + +test("accumulates streaming tool call arguments correctly", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should handle streaming arguments accumulation correctly + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [{ role: "user", parts: [{ text: "Search for something" }] }], + }) + + expect(res.status).toBe(200) + // The request should process successfully even with complex tool call scenarios +}) + +test("handles Google Search tool processing", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Should handle Google Search tool configuration and processing + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + tools: [ + { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, + }, + }, + }, + ], + contents: [{ role: "user", parts: [{ text: "Search for latest news" }] }], + }) + + expect(res.status).toBe(200) + expect(capturedPayload.messages?.length).toBe(1) + + const userMessage = capturedPayload.messages?.[0] + expect(userMessage?.role).toBe("user") + expect(userMessage?.content).toContain("latest news") + + // Google Search tool is Gemini-specific and gets translated + // It may or may not appear in the tools array depending on translation logic + // The key is that the request succeeds + expect(capturedPayload.messages).toBeDefined() +}) + +test("handles translation errors gracefully", async () => { + // Should return appropriate error status when Copilot API fails + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + throw new Error("Copilot API error") + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [{ role: "user", parts: [{ text: "This should fail" }] }], + }) + + // Should handle the error and return appropriate status + expect(res.status).toBeGreaterThanOrEqual(400) +}) + +test("handles malformed tool calls in content processing", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // Test malformed function call handling + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Process this" }] }, + { + role: "model", + parts: [ + { + functionCall: { + name: "", // Empty name should trigger error handling + args: {}, + }, + }, + ], + }, + ], + }) + + expect(res.status).toBe(200) + // Should handle malformed calls gracefully +}) + +// Real scenario tests for multi-turn tool calls and deduplication + +test("handles multi-turn tool call conversation correctly", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Result processed" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 50, completion_tokens: 10, total_tokens: 60 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Read file A" }] }, + { + role: "model", + parts: [ + { functionCall: { name: "readFile", args: { path: "a.txt" } } }, + ], + }, + { + role: "user", + parts: [ + { + functionResponse: { + name: "readFile", + response: { content: "Content of A" }, + }, + }, + ], + }, + { + role: "model", + parts: [{ text: "File A contains: Content of A" }], + }, + { role: "user", parts: [{ text: "Now read file B" }] }, + { + role: "model", + parts: [ + { functionCall: { name: "readFile", args: { path: "b.txt" } } }, + ], + }, + { + role: "user", + parts: [ + { + functionResponse: { + name: "readFile", + response: { content: "Content of B" }, + }, + }, + ], + }, + ], + }) + + expect(res.status).toBe(200) + + // Verify message structure: user, assistant+tool_call, tool, assistant, user, assistant+tool_call, tool + const messages = capturedPayload.messages ?? [] + expect(messages.length).toBeGreaterThanOrEqual(5) + + // Verify tool call ID consistency + const assistantWithTools = messages.filter( + (m) => m.role === "assistant" && m.tool_calls, + ) + expect(assistantWithTools.length).toBeGreaterThanOrEqual(2) + + const toolMessages = messages.filter((m) => m.role === "tool") + expect(toolMessages.length).toBeGreaterThanOrEqual(2) + + // Each tool message should reference a tool_call_id + for (const toolMsg of toolMessages) { + expect(toolMsg.tool_call_id).toBeDefined() + expect(typeof toolMsg.tool_call_id).toBe("string") + } +}) + +test("handles duplicate tool responses by deduplication", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "Processed" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 20, completion_tokens: 5, total_tokens: 25 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Call function" }] }, + { + role: "model", + parts: [ + { functionCall: { name: "testFunc", args: { param: "value1" } } }, + { functionCall: { name: "testFunc2", args: { param: "value2" } } }, + ], + }, + { + role: "user", + parts: [ + { + functionResponse: { + name: "testFunc", + response: { result: "first" }, + }, + }, + { + functionResponse: { + name: "testFunc2", + response: { result: "second" }, + }, + }, + // Duplicate response - should be deduplicated + { + functionResponse: { + name: "testFunc", + response: { result: "duplicate" }, + }, + }, + ], + }, + ], + }) + + expect(res.status).toBe(200) + + // Verify deduplication: should have exactly 2 tool messages (not 3) + const messages = capturedPayload.messages ?? [] + const toolMessages = messages.filter((m) => m.role === "tool") + + // Count unique tool_call_ids + const toolCallIds = new Set( + toolMessages.map((m) => m.tool_call_id).filter(Boolean), + ) + expect(toolCallIds.size).toBeLessThanOrEqual(2) +}) + +test("verifies tool_call_id length constraint (≤40 characters)", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Call a function" }] }, + { + role: "model", + parts: [ + { + functionCall: { + name: "veryLongFunctionNameThatMightCauseIssues", + args: { param: "test" }, + }, + }, + ], + }, + { + role: "user", + parts: [ + { + functionResponse: { + name: "veryLongFunctionNameThatMightCauseIssues", + response: { result: "ok" }, + }, + }, + ], + }, + ], + }) + + expect(res.status).toBe(200) + + const messages = capturedPayload.messages ?? [] + const assistantWithTools = messages.filter( + (m) => m.role === "assistant" && m.tool_calls, + ) + + // Verify all generated tool_call_ids are within limit + for (const msg of assistantWithTools) { + if (msg.tool_calls) { + for (const toolCall of msg.tool_calls) { + expect(toolCall.id.length).toBeLessThanOrEqual(40) + } + } + } +}) diff --git a/tests/generate-content/translation-response-coverage.test.ts b/tests/generate-content/translation-response-coverage.test.ts new file mode 100644 index 000000000..78873696c --- /dev/null +++ b/tests/generate-content/translation-response-coverage.test.ts @@ -0,0 +1,132 @@ +import { describe, it, expect } from "bun:test" + +import type { ChatCompletionResponse } from "~/services/copilot/create-chat-completions" + +import { translateOpenAIToGemini } from "~/routes/generate-content/translation" + +describe("OpenAI to Gemini Response Translation", () => { + it("should handle assistant message with tool calls having arguments", () => { + const openAIResponse: ChatCompletionResponse = { + id: "chatcmpl-123", + object: "chat.completion", + created: Date.now(), + model: "gpt-4", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: "I'll search for that", + tool_calls: [ + { + id: "call_123", + type: "function", + function: { + name: "search", + arguments: '{"query": "test query", "limit": 10}', + }, + }, + ], + }, + finish_reason: "tool_calls", + logprobs: null, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 20, + total_tokens: 30, + }, + } + + const result = translateOpenAIToGemini(openAIResponse) + + expect(result.candidates).toHaveLength(1) + expect(result.candidates[0]?.content.parts).toHaveLength(2) + expect(result.candidates[0]?.content.parts[0]).toEqual({ + text: "I'll search for that", + }) + expect(result.candidates[0]?.content.parts[1]).toEqual({ + functionCall: { + name: "search", + args: { query: "test query", limit: 10 }, + }, + }) + }) + + it("should handle assistant message with tool calls having empty arguments", () => { + const openAIResponse: ChatCompletionResponse = { + id: "chatcmpl-456", + object: "chat.completion", + created: Date.now(), + model: "gpt-4", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: "Getting current time", + tool_calls: [ + { + id: "call_456", + type: "function", + function: { + name: "get_current_time", + arguments: "", + }, + }, + ], + }, + finish_reason: "tool_calls", + logprobs: null, + }, + ], + usage: { + prompt_tokens: 5, + completion_tokens: 10, + total_tokens: 15, + }, + } + + const result = translateOpenAIToGemini(openAIResponse) + + expect(result.candidates[0]?.content.parts[1]).toEqual({ + functionCall: { + name: "get_current_time", + args: {}, + }, + }) + }) + + it("should handle assistant message with simple text content", () => { + const openAIResponse: ChatCompletionResponse = { + id: "chatcmpl-789", + object: "chat.completion", + created: Date.now(), + model: "gpt-4", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: "Here's my response", + }, + finish_reason: "stop", + logprobs: null, + }, + ], + usage: { + prompt_tokens: 15, + completion_tokens: 5, + total_tokens: 20, + }, + } + + const result = translateOpenAIToGemini(openAIResponse) + + expect(result.candidates[0]?.content.parts).toHaveLength(1) + expect(result.candidates[0]?.content.parts[0]).toEqual({ + text: "Here's my response", + }) + }) +}) diff --git a/tests/generate-content/translation.test.ts b/tests/generate-content/translation.test.ts new file mode 100644 index 000000000..85608679c --- /dev/null +++ b/tests/generate-content/translation.test.ts @@ -0,0 +1,320 @@ +import { afterEach, expect, test, mock } from "bun:test" + +import type { CapturedPayload } from "./test-types" + +import { makeRequest } from "./_test-utils" + +afterEach(() => { + mock.restore() +}) + +test("processes toolConfig AUTO/ANY/NONE mapping end-to-end", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + // toolConfig requires tools to be processed, so add tools to request + const baseRequest = { + tools: [ + { + functionDeclarations: [ + { name: "test", parameters: { type: "object" } }, + ], + }, + ], + contents: [{ role: "user", parts: [{ text: "hi" }] }], + } + + // Test AUTO -> auto + const autoRes = await makeRequest( + "/v1beta/models/gemini-pro:generateContent", + { + ...baseRequest, + toolConfig: { functionCallingConfig: { mode: "AUTO" } }, + }, + ) + expect(autoRes.status).toBe(200) + expect(capturedPayload.tool_choice).toBe("auto") + + // Test ANY -> required + const anyRes = await makeRequest( + "/v1beta/models/gemini-pro:generateContent", + { + ...baseRequest, + toolConfig: { functionCallingConfig: { mode: "ANY" } }, + }, + ) + expect(anyRes.status).toBe(200) + expect(capturedPayload.tool_choice).toBe("required") + + // Test NONE -> none + const noneRes = await makeRequest( + "/v1beta/models/gemini-pro:generateContent", + { + ...baseRequest, + toolConfig: { functionCallingConfig: { mode: "NONE" } }, + }, + ) + expect(noneRes.status).toBe(200) + expect(capturedPayload.tool_choice).toBe("none") +}) + +test("handles urlContext tool filtering in request", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + tools: [ + { urlContext: {} }, + { + functionDeclarations: [ + { name: "readFile", parameters: { type: "object" } }, + ], + }, + ], + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }) + + expect(res.status).toBe(200) + expect(capturedPayload.tools).toBeDefined() + const toolNames = new Set( + capturedPayload.tools?.map((t) => t.function.name) ?? [], + ) + expect(toolNames.has("readFile")).toBe(true) + expect(toolNames.has("urlContext")).toBe(false) +}) + +test("synthesizes tools from function calls when tools not provided", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Do a web search" }] }, + { + role: "model", + parts: [{ functionCall: { name: "search", args: { query: "cats" } } }], + }, + ], + }) + + expect(res.status).toBe(200) + expect(capturedPayload.tools).toBeDefined() + const toolNames = capturedPayload.tools?.map((t) => t.function.name) ?? [] + expect(toolNames.includes("search")).toBe(true) +}) + +test("handles same-role message merging behavior", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Hello." }] }, + { role: "user", parts: [{ text: "How are you?" }] }, + ], + }) + + expect(res.status).toBe(200) + const userMessages = + capturedPayload.messages?.filter((m) => m.role === "user") ?? [] + expect(userMessages.length).toBe(1) + expect(userMessages[0]?.content).toContain("Hello.") + expect(userMessages[0]?.content).toContain("How are you?") +}) + +test("handles incomplete tool calls cleanup", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Search for cats." }] }, + { + role: "model", + parts: [{ functionCall: { name: "search", args: { query: "cats" } } }], + }, + { role: "user", parts: [{ text: "Show me results." }] }, + ], + }) + + expect(res.status).toBe(200) + const assistantMessages = + capturedPayload.messages?.filter((m) => m.role === "assistant") ?? [] + expect(assistantMessages.length).toBe(0) + const userMessages = + capturedPayload.messages?.filter((m) => m.role === "user") ?? [] + expect(userMessages.length).toBeGreaterThan(0) +}) + +test("handles system instruction in contents", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + systemInstruction: { parts: [{ text: "You are a helpful assistant" }] }, + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }) + + expect(res.status).toBe(200) + const systemMessage = capturedPayload.messages?.find( + (m) => m.role === "system", + ) + expect(systemMessage).toBeDefined() + expect(systemMessage?.content).toContain("helpful assistant") +}) + +test("handles empty contents gracefully", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + throw new Error("Should not be called with empty contents") + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [], + }) + + // Empty contents cause translation error, expect 500 status + expect(res.status).toBe(500) +}) + +test("handles complex tool call workflow", async () => { + let capturedPayload: CapturedPayload = {} as CapturedPayload + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (payload: CapturedPayload) => { + capturedPayload = payload + return { + id: "x", + choices: [ + { + index: 0, + message: { role: "assistant", content: "ok" }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [ + { role: "user", parts: [{ text: "Read a file" }] }, + { + role: "model", + parts: [ + { functionCall: { name: "readFile", args: { path: "test.txt" } } }, + ], + }, + { + role: "user", + parts: [ + { + functionResponse: { + name: "readFile", + response: { content: "Hello World" }, + }, + }, + ], + }, + ], + }) + + expect(res.status).toBe(200) + expect( + capturedPayload.messages?.some( + (m) => m.role === "assistant" && m.tool_calls, + ), + ).toBe(true) + expect(capturedPayload.messages?.some((m) => m.role === "tool")).toBe(true) +}) diff --git a/tests/generate-content/validation-and-routing.test.ts b/tests/generate-content/validation-and-routing.test.ts new file mode 100644 index 000000000..de11b817e --- /dev/null +++ b/tests/generate-content/validation-and-routing.test.ts @@ -0,0 +1,236 @@ +import { afterEach, expect, test, mock } from "bun:test" + +import { + asyncIterableFrom, + createMockRateLimit, + makeRequest, +} from "./_test-utils" + +afterEach(() => { + mock.restore() +}) + +test("forwards generic errors as HTTP 500", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + throw new Error("Internal issue") + }, + })) + const { server } = await import("~/server") + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + expect(res.status).toBe(500) + const json = (await res.json()) as { + error: { message: string; type: string } + } + expect(json).toEqual({ error: { message: "Internal issue", type: "error" } }) +}) + +test("requires model in URL for non-stream endpoint", async () => { + const { server } = await import("~/server") + const res = await server.request("/v1beta/models/:generateContent", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }) + + expect(res.status).toBe(500) + const json = await res.json() + expect(json).toEqual({ + error: { message: "Model name is required in URL path", type: "error" }, + }) +}) + +test("requires model in URL for stream endpoint", async () => { + const { server } = await import("~/server") + const res = await server.request("/v1beta/models/:streamGenerateContent", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }) + + expect(res.status).toBe(500) + const json = await res.json() + expect(json).toEqual({ + error: { message: "Model name is required in URL path", type: "error" }, + }) +}) + +test("requires model in URL for countTokens endpoint", async () => { + const { server } = await import("~/server") + const res = await server.request("/v1beta/models/:countTokens", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }) + + expect(res.status).toBe(500) + const json = await res.json() + expect(json).toEqual({ + error: { message: "Model name is required in URL path", type: "error" }, + }) +}) + +test("streams fallback response when no text content in non-streaming to streaming conversion", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (_: unknown) => ({ + id: "res-fallback", + choices: [ + { + index: 0, + message: { role: "assistant", content: null }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 1, completion_tokens: 0, total_tokens: 1 }, + }), + })) + + await createMockRateLimit() + + const { server } = await import("~/server?fallback-response-no-text") + + const res = await server.request( + "/v1beta/models/gemini-pro:streamGenerateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "test" }] }], + }), + }, + ) + + expect(res.status).toBe(200) + const ct = res.headers.get("content-type") || "" + expect(ct.includes("text/event-stream")).toBe(true) + + const body = await res.text() + + expect(body.includes("data:")).toBe(true) + expect(body.includes('"candidates"')).toBe(true) + expect(body.includes('"usageMetadata"')).toBe(true) +}) + +test("non-stream endpoint rejects streaming response with 500", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: (_: unknown) => + asyncIterableFrom([ + { + data: JSON.stringify({ + id: "c1", + choices: [ + { index: 0, delta: { content: "x" }, finish_reason: null }, + ], + }), + }, + { + data: JSON.stringify({ + id: "c1", + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + }), + }, + { data: "[DONE]" }, + ]), + })) + + const { server } = await import("~/server") + const res = await server.request( + "/v1beta/models/gemini-pro:generateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }), + }, + ) + + expect(res.status).toBe(500) + const json = await res.json() + expect(json).toEqual({ + error: { + message: "Unexpected streaming response for non-streaming endpoint", + type: "error", + }, + }) +}) + +test("routes fallthrough when URL doesn't match any generate-content patterns", async () => { + await createMockRateLimit() + + const { server } = await import("~/server?route-fallthrough") + + const res = await server.request( + "/v1beta/models/gemini-pro:unknownOperation", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + contents: [{ role: "user", parts: [{ text: "test" }] }], + }), + }, + ) + + expect(res.status).toBe(404) +}) + +test("handles HTTP errors with proper error codes", async () => { + await mock.module("~/services/copilot/create-chat-completions", () => ({ + createChatCompletions: () => { + const error = new Error("Bad Request") + // Simulate HTTPError-like structure + Object.assign(error, { status: 400 }) + throw error + }, + })) + + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + contents: [{ role: "user", parts: [{ text: "hi" }] }], + }) + + // 由于错误处理机制,HTTP错误也会转为500 + expect(res.status).toBe(500) + const json = (await res.json()) as { + error: { message: string; type: string } + } + expect(json.error.message).toContain("Bad Request") +}) + +test("handles malformed JSON in request body", async () => { + const { server } = await import("~/server") + const res = await server.request( + "/v1beta/models/gemini-pro:generateContent", + { + method: "POST", + headers: { "content-type": "application/json" }, + body: "{ invalid json", + }, + ) + + // JSON parsing errors will return 500 + expect(res.status).toBe(500) +}) + +test("validates required contents field in request", async () => { + const res = await makeRequest("/v1beta/models/gemini-pro:generateContent", { + // Missing contents field + model: "gemini-pro", + }) + + expect([400, 500]).toContain(res.status) +}) diff --git a/tests/responses-stream-translation.test.ts b/tests/responses-stream-translation.test.ts new file mode 100644 index 000000000..9f149e1bd --- /dev/null +++ b/tests/responses-stream-translation.test.ts @@ -0,0 +1,137 @@ +import { describe, expect, test } from "bun:test" + +import type { AnthropicStreamEventData } from "~/routes/messages/anthropic-types" + +import { + createResponsesStreamState, + translateResponsesStreamEvent, +} from "~/routes/messages/responses-stream-translation" + +const createFunctionCallAddedEvent = () => ({ + type: "response.output_item.added", + output_index: 1, + item: { + id: "item-1", + type: "function_call", + call_id: "call-1", + name: "TodoWrite", + arguments: "", + status: "in_progress", + }, +}) + +describe("translateResponsesStreamEvent tool calls", () => { + test("streams function call arguments across deltas", () => { + const state = createResponsesStreamState() + + const events = [ + translateResponsesStreamEvent(createFunctionCallAddedEvent(), state), + translateResponsesStreamEvent( + { + type: "response.function_call_arguments.delta", + output_index: 1, + delta: '{"todos":', + }, + state, + ), + translateResponsesStreamEvent( + { + type: "response.function_call_arguments.delta", + output_index: 1, + delta: "[]}", + }, + state, + ), + translateResponsesStreamEvent( + { + type: "response.function_call_arguments.done", + output_index: 1, + arguments: '{"todos":[]}', + }, + state, + ), + ].flat() + + const messageStart = events.find((event) => event.type === "message_start") + expect(messageStart).toBeDefined() + + const blockStart = events.find( + (event) => event.type === "content_block_start", + ) + expect(blockStart).toBeDefined() + if (blockStart?.type === "content_block_start") { + expect(blockStart.content_block).toEqual({ + type: "tool_use", + id: "call-1", + name: "TodoWrite", + input: {}, + }) + } + + const deltas = events.filter( + ( + event, + ): event is Extract< + AnthropicStreamEventData, + { type: "content_block_delta" } + > => event.type === "content_block_delta", + ) + expect(deltas).toHaveLength(2) + expect(deltas[0].delta).toEqual({ + type: "input_json_delta", + partial_json: '{"todos":', + }) + expect(deltas[1].delta).toEqual({ + type: "input_json_delta", + partial_json: "[]}", + }) + + const blockStop = events.find( + (event) => event.type === "content_block_stop", + ) + expect(blockStop).toBeDefined() + + expect(state.openBlocks.size).toBe(0) + expect(state.functionCallStateByOutputIndex.size).toBe(0) + }) + + test("emits full arguments when only done payload is present", () => { + const state = createResponsesStreamState() + + const events = [ + translateResponsesStreamEvent(createFunctionCallAddedEvent(), state), + translateResponsesStreamEvent( + { + type: "response.function_call_arguments.done", + output_index: 1, + arguments: + '{"todos":[{"content":"Review src/routes/responses/translation.ts"}]}', + }, + state, + ), + ].flat() + + const deltas = events.filter( + ( + event, + ): event is Extract< + AnthropicStreamEventData, + { type: "content_block_delta" } + > => event.type === "content_block_delta", + ) + expect(deltas).toHaveLength(1) + expect(deltas[0].delta).toEqual({ + type: "input_json_delta", + partial_json: + '{"todos":[{"content":"Review src/routes/responses/translation.ts"}]}', + }) + + const blockStop = events.find( + (event) => event.type === "content_block_stop", + ) + expect(blockStop).toBeDefined() + + expect(state.openBlocks.size).toBe(0) + expect(state.functionCallStateByOutputIndex.size).toBe(0) + }) +}) diff --git a/tests/translation.test.ts b/tests/translation.test.ts new file mode 100644 index 000000000..0c3ececb2 --- /dev/null +++ b/tests/translation.test.ts @@ -0,0 +1,161 @@ +import { describe, expect, it } from "bun:test" + +import type { AnthropicMessagesPayload } from "~/routes/messages/anthropic-types" +import type { + ResponseInputMessage, + ResponsesResult, +} from "~/services/copilot/create-responses" + +import { + translateAnthropicMessagesToResponsesPayload, + translateResponsesResultToAnthropic, +} from "~/routes/messages/responses-translation" + +const samplePayload = { + model: "claude-3-5-sonnet", + max_tokens: 1024, + messages: [ + { + role: "user", + content: [ + { + type: "text", + text: "\nThis is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware. If you are working on tasks that would benefit from a todo list please use the TodoWrite tool to create one. If not, please feel free to ignore. Again do not mention this message to the user.\n", + }, + { + type: "text", + text: "\nAs you answer the user's questions, you can use the following context:\n# important-instruction-reminders\nDo what has been asked; nothing more, nothing less.\nNEVER create files unless they're absolutely necessary for achieving your goal.\nALWAYS prefer editing an existing file to creating a new one.\nNEVER proactively create documentation files (*.md) or README files. Only create documentation files if explicitly requested by the User.\n\n \n IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n", + }, + { + type: "text", + text: "hi", + }, + { + type: "text", + text: "\nThe user opened the file c:\\Work2\\copilot-api\\src\\routes\\responses\\translation.ts in the IDE. This may or may not be related to the current task.\n", + }, + { + type: "text", + text: "hi", + cache_control: { + type: "ephemeral", + }, + }, + ], + }, + ], +} as unknown as AnthropicMessagesPayload + +describe("translateAnthropicMessagesToResponsesPayload", () => { + it("converts anthropic text blocks into response input messages", () => { + const result = translateAnthropicMessagesToResponsesPayload(samplePayload) + + console.log("result:", JSON.stringify(result, null, 2)) + expect(Array.isArray(result.input)).toBe(true) + const input = result.input as Array + expect(input).toHaveLength(1) + + const message = input[0] + expect(message.role).toBe("user") + expect(Array.isArray(message.content)).toBe(true) + + const content = message.content as Array<{ text: string }> + expect(content.map((item) => item.text)).toEqual([ + "\nThis is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware. If you are working on tasks that would benefit from a todo list please use the TodoWrite tool to create one. If not, please feel free to ignore. Again do not mention this message to the user.\n", + "\nAs you answer the user's questions, you can use the following context:\n# important-instruction-reminders\nDo what has been asked; nothing more, nothing less.\nNEVER create files unless they're absolutely necessary for achieving your goal.\nALWAYS prefer editing an existing file to creating a new one.\nNEVER proactively create documentation files (*.md) or README files. Only create documentation files if explicitly requested by the User.\n\n \n IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.\n", + "hi", + "\nThe user opened the file c:\\Work2\\copilot-api\\src\\routes\\responses\\translation.ts in the IDE. This may or may not be related to the current task.\n", + "hi", + ]) + }) +}) + +describe("translateResponsesResultToAnthropic", () => { + it("handles reasoning and function call items", () => { + const responsesResult: ResponsesResult = { + id: "resp_123", + object: "response", + created_at: 0, + model: "gpt-4.1", + output: [ + { + id: "reason_1", + type: "reasoning", + summary: [{ type: "text", text: "Thinking about the task." }], + status: "completed", + encrypted_content: "encrypted_reasoning_content", + }, + { + id: "call_1", + type: "function_call", + call_id: "call_1", + name: "TodoWrite", + arguments: + '{"todos":[{"content":"Read src/routes/responses/translation.ts","status":"in_progress"}]}', + status: "completed", + }, + { + id: "message_1", + type: "message", + role: "assistant", + status: "completed", + content: [ + { + type: "output_text", + text: "Added the task to your todo list.", + annotations: [], + }, + ], + }, + ], + output_text: "Added the task to your todo list.", + status: "incomplete", + usage: { + input_tokens: 120, + output_tokens: 36, + total_tokens: 156, + }, + error: null, + incomplete_details: { reason: "tool_use" }, + instructions: null, + metadata: null, + parallel_tool_calls: false, + temperature: null, + tool_choice: null, + tools: [], + top_p: null, + } + + const anthropicResponse = + translateResponsesResultToAnthropic(responsesResult) + + expect(anthropicResponse.stop_reason).toBe("tool_use") + expect(anthropicResponse.content).toHaveLength(3) + + const [thinkingBlock, toolUseBlock, textBlock] = anthropicResponse.content + + expect(thinkingBlock.type).toBe("thinking") + if (thinkingBlock.type === "thinking") { + expect(thinkingBlock.thinking).toContain("Thinking about the task") + } + + expect(toolUseBlock.type).toBe("tool_use") + if (toolUseBlock.type === "tool_use") { + expect(toolUseBlock.id).toBe("call_1") + expect(toolUseBlock.name).toBe("TodoWrite") + expect(toolUseBlock.input).toEqual({ + todos: [ + { + content: "Read src/routes/responses/translation.ts", + status: "in_progress", + }, + ], + }) + } + + expect(textBlock.type).toBe("text") + if (textBlock.type === "text") { + expect(textBlock.text).toBe("Added the task to your todo list.") + } + }) +})