diff --git a/README.md b/README.md index e5b390991..c36301dd7 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,7 @@ The following command line options are available for the `start` command: | --wait | Wait instead of error when rate limit is hit | false | -w | | --github-token | Provide GitHub token directly (must be generated using the `auth` subcommand) | none | -g | | --claude-code | Generate a command to launch Claude Code with Copilot API config | false | -c | +| --claude-code-env | Generate Claude Code Environment variables | true | none | | --show-token | Show GitHub and Copilot tokens on fetch and refresh | false | none | ### Auth Command Options diff --git a/src/lib/tokenizer.ts b/src/lib/tokenizer.ts index 73cd499f9..8c3eda736 100644 --- a/src/lib/tokenizer.ts +++ b/src/lib/tokenizer.ts @@ -1,37 +1,345 @@ -import { countTokens } from "gpt-tokenizer/model/gpt-4o" +import type { + ChatCompletionsPayload, + ContentPart, + Message, + Tool, + ToolCall, +} from "~/services/copilot/create-chat-completions" +import type { Model } from "~/services/copilot/get-models" -import type { Message } from "~/services/copilot/create-chat-completions" +// Encoder type mapping +const ENCODING_MAP = { + o200k_base: () => import("gpt-tokenizer/encoding/o200k_base"), + cl100k_base: () => import("gpt-tokenizer/encoding/cl100k_base"), + p50k_base: () => import("gpt-tokenizer/encoding/p50k_base"), + p50k_edit: () => import("gpt-tokenizer/encoding/p50k_edit"), + r50k_base: () => import("gpt-tokenizer/encoding/r50k_base"), +} as const -export const getTokenCount = (messages: Array) => { - const simplifiedMessages = messages.map((message) => { - let content = "" - if (typeof message.content === "string") { - content = message.content - } else if (Array.isArray(message.content)) { - content = message.content - .filter((part) => part.type === "text") - .map((part) => (part as { text: string }).text) - .join("") +type SupportedEncoding = keyof typeof ENCODING_MAP + +// Define encoder interface +interface Encoder { + encode: (text: string) => Array +} + +// Cache loaded encoders to avoid repeated imports +const encodingCache = new Map() + +/** + * Calculate tokens for tool calls + */ +const calculateToolCallsTokens = ( + toolCalls: Array, + encoder: Encoder, + constants: ReturnType, +): number => { + let tokens = 0 + for (const toolCall of toolCalls) { + tokens += constants.funcInit + tokens += encoder.encode(JSON.stringify(toolCall)).length + } + tokens += constants.funcEnd + return tokens +} + +/** + * Calculate tokens for content parts + */ +const calculateContentPartsTokens = ( + contentParts: Array, + encoder: Encoder, +): number => { + let tokens = 0 + for (const part of contentParts) { + if (part.type === "image_url") { + tokens += encoder.encode(part.image_url.url).length + 85 + } else if (part.text) { + tokens += encoder.encode(part.text).length + } + } + return tokens +} + +/** + * Calculate tokens for a single message + */ +const calculateMessageTokens = ( + message: Message, + encoder: Encoder, + constants: ReturnType, +): number => { + const tokensPerMessage = 3 + const tokensPerName = 1 + let tokens = tokensPerMessage + for (const [key, value] of Object.entries(message)) { + if (typeof value === "string") { + tokens += encoder.encode(value).length + } + if (key === "name") { + tokens += tokensPerName + } + if (key === "tool_calls") { + tokens += calculateToolCallsTokens( + value as Array, + encoder, + constants, + ) + } + if (key === "content" && Array.isArray(value)) { + tokens += calculateContentPartsTokens( + value as Array, + encoder, + ) + } + } + return tokens +} + +/** + * Calculate tokens using custom algorithm + */ +const calculateTokens = ( + messages: Array, + encoder: Encoder, + constants: ReturnType, +): number => { + if (messages.length === 0) { + return 0 + } + let numTokens = 0 + for (const message of messages) { + numTokens += calculateMessageTokens(message, encoder, constants) + } + // every reply is primed with <|start|>assistant<|message|> + numTokens += 3 + return numTokens +} + +/** + * Get the corresponding encoder module based on encoding type + */ +const getEncodeChatFunction = async (encoding: string): Promise => { + if (encodingCache.has(encoding)) { + const cached = encodingCache.get(encoding) + if (cached) { + return cached } - return { ...message, content } - }) + } + + const supportedEncoding = encoding as SupportedEncoding + if (!(supportedEncoding in ENCODING_MAP)) { + const fallbackModule = (await ENCODING_MAP.o200k_base()) as Encoder + encodingCache.set(encoding, fallbackModule) + return fallbackModule + } + + const encodingModule = (await ENCODING_MAP[supportedEncoding]()) as Encoder + encodingCache.set(encoding, encodingModule) + return encodingModule +} + +/** + * Get tokenizer type from model information + */ +export const getTokenizerFromModel = (model: Model): string => { + return model.capabilities.tokenizer || "o200k_base" +} + +/** + * Get model-specific constants for token calculation + */ +const getModelConstants = (model: Model) => { + return model.id === "gpt-3.5-turbo" || model.id === "gpt-4" ? + { + funcInit: 10, + propInit: 3, + propKey: 3, + enumInit: -3, + enumItem: 3, + funcEnd: 12, + } + : { + funcInit: 7, + propInit: 3, + propKey: 3, + enumInit: -3, + enumItem: 3, + funcEnd: 12, + } +} - let inputMessages = simplifiedMessages.filter((message) => { - return message.role !== "tool" - }) - let outputMessages: typeof simplifiedMessages = [] +/** + * Calculate tokens for a single parameter + */ +const calculateParameterTokens = ( + key: string, + prop: unknown, + context: { + encoder: Encoder + constants: ReturnType + }, +): number => { + const { encoder, constants } = context + let tokens = constants.propKey - const lastMessage = simplifiedMessages.at(-1) + // Early return if prop is not an object + if (typeof prop !== "object" || prop === null) { + return tokens + } - if (lastMessage?.role === "assistant") { - inputMessages = simplifiedMessages.slice(0, -1) - outputMessages = [lastMessage] + // Type assertion for parameter properties + const param = prop as { + type?: string + description?: string + enum?: Array + [key: string]: unknown } - // @ts-expect-error TS can't infer from arr.filter() - const inputTokens = countTokens(inputMessages) - // @ts-expect-error TS can't infer from arr.filter() - const outputTokens = countTokens(outputMessages) + const paramName = key + const paramType = param.type || "string" + let paramDesc = param.description || "" + + // Handle enum values + if (param.enum && Array.isArray(param.enum)) { + tokens += constants.enumInit + for (const item of param.enum) { + tokens += constants.enumItem + tokens += encoder.encode(String(item)).length + } + } + + // Clean up description + if (paramDesc.endsWith(".")) { + paramDesc = paramDesc.slice(0, -1) + } + + // Encode the main parameter line + const line = `${paramName}:${paramType}:${paramDesc}` + tokens += encoder.encode(line).length + + // Handle additional properties (excluding standard ones) + const excludedKeys = new Set(["type", "description", "enum"]) + for (const propertyName of Object.keys(param)) { + if (!excludedKeys.has(propertyName)) { + const propertyValue = param[propertyName] + const propertyText = + typeof propertyValue === "string" ? propertyValue : ( + JSON.stringify(propertyValue) + ) + tokens += encoder.encode(`${propertyName}:${propertyText}`).length + } + } + + return tokens +} + +/** + * Calculate tokens for function parameters + */ +const calculateParametersTokens = ( + parameters: unknown, + encoder: Encoder, + constants: ReturnType, +): number => { + if (!parameters || typeof parameters !== "object") { + return 0 + } + + const params = parameters as Record + let tokens = 0 + + for (const [key, value] of Object.entries(params)) { + if (key === "properties") { + const properties = value as Record + if (Object.keys(properties).length > 0) { + tokens += constants.propInit + for (const propKey of Object.keys(properties)) { + tokens += calculateParameterTokens(propKey, properties[propKey], { + encoder, + constants, + }) + } + } + } else { + const paramText = + typeof value === "string" ? value : JSON.stringify(value) + tokens += encoder.encode(`${key}:${paramText}`).length + } + } + + return tokens +} + +/** + * Calculate tokens for a single tool + */ +const calculateToolTokens = ( + tool: Tool, + encoder: Encoder, + constants: ReturnType, +): number => { + let tokens = constants.funcInit + const func = tool.function + const fName = func.name + let fDesc = func.description || "" + if (fDesc.endsWith(".")) { + fDesc = fDesc.slice(0, -1) + } + const line = fName + ":" + fDesc + tokens += encoder.encode(line).length + if ( + typeof func.parameters === "object" // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + && func.parameters !== null + ) { + tokens += calculateParametersTokens(func.parameters, encoder, constants) + } + return tokens +} + +/** + * Calculate token count for tools based on model + */ +export const numTokensForTools = ( + tools: Array, + encoder: Encoder, + constants: ReturnType, +): number => { + let funcTokenCount = 0 + for (const tool of tools) { + funcTokenCount += calculateToolTokens(tool, encoder, constants) + } + funcTokenCount += constants.funcEnd + return funcTokenCount +} + +/** + * Calculate the token count of messages, supporting multiple GPT encoders + */ +export const getTokenCount = async ( + payload: ChatCompletionsPayload, + model: Model, +): Promise<{ input: number; output: number }> => { + // Get tokenizer string + const tokenizer = getTokenizerFromModel(model) + + // Get corresponding encoder module + const encoder = await getEncodeChatFunction(tokenizer) + + const simplifiedMessages = payload.messages + const inputMessages = simplifiedMessages.filter( + (msg) => msg.role !== "assistant", + ) + const outputMessages = simplifiedMessages.filter( + (msg) => msg.role === "assistant", + ) + + const constants = getModelConstants(model) + let inputTokens = calculateTokens(inputMessages, encoder, constants) + if (payload.tools && payload.tools.length > 0) { + inputTokens += numTokensForTools(payload.tools, encoder, constants) + } + const outputTokens = calculateTokens(outputMessages, encoder, constants) return { input: inputTokens, diff --git a/src/routes/chat-completions/handler.ts b/src/routes/chat-completions/handler.ts index 6e49029b8..04a5ae9ed 100644 --- a/src/routes/chat-completions/handler.ts +++ b/src/routes/chat-completions/handler.ts @@ -20,15 +20,26 @@ export async function handleCompletion(c: Context) { let payload = await c.req.json() consola.debug("Request payload:", JSON.stringify(payload).slice(-400)) - consola.info("Current token count:", getTokenCount(payload.messages)) + // Find the selected model + const selectedModel = state.models?.data.find( + (model) => model.id === payload.model, + ) + + // Calculate and display token count + try { + if (selectedModel) { + const tokenCount = await getTokenCount(payload, selectedModel) + consola.info("Current token count:", tokenCount) + } else { + consola.warn("No model selected, skipping token count calculation") + } + } catch (error) { + consola.warn("Failed to calculate token count:", error) + } if (state.manualApprove) await awaitApproval() if (isNullish(payload.max_tokens)) { - const selectedModel = state.models?.data.find( - (model) => model.id === payload.model, - ) - payload = { ...payload, max_tokens: selectedModel?.capabilities.limits.max_output_tokens, diff --git a/src/routes/messages/anthropic-types.ts b/src/routes/messages/anthropic-types.ts index 881fffcc8..100d906ee 100644 --- a/src/routes/messages/anthropic-types.ts +++ b/src/routes/messages/anthropic-types.ts @@ -101,7 +101,7 @@ export interface AnthropicResponse { | "refusal" | null stop_sequence: string | null - usage: { + usage?: { input_tokens: number output_tokens: number cache_creation_input_tokens?: number diff --git a/src/routes/messages/count-tokens-handler.ts b/src/routes/messages/count-tokens-handler.ts new file mode 100644 index 000000000..2ec849cb8 --- /dev/null +++ b/src/routes/messages/count-tokens-handler.ts @@ -0,0 +1,70 @@ +import type { Context } from "hono" + +import consola from "consola" + +import { state } from "~/lib/state" +import { getTokenCount } from "~/lib/tokenizer" + +import { type AnthropicMessagesPayload } from "./anthropic-types" +import { translateToOpenAI } from "./non-stream-translation" + +/** + * Handles token counting for Anthropic messages + */ +export async function handleCountTokens(c: Context) { + try { + const anthropicBeta = c.req.header("anthropic-beta") + + const anthropicPayload = await c.req.json() + + const openAIPayload = translateToOpenAI(anthropicPayload) + + const selectedModel = state.models?.data.find( + (model) => model.id === anthropicPayload.model, + ) + + if (!selectedModel) { + consola.warn("Model not found, returning default token count") + return c.json({ + input_tokens: 1, + }) + } + + const tokenCount = await getTokenCount(openAIPayload, selectedModel) + + if (anthropicPayload.tools && anthropicPayload.tools.length > 0) { + let mcpToolExist = false + if (anthropicBeta?.startsWith("claude-code")) { + mcpToolExist = anthropicPayload.tools.some((tool) => + tool.name.startsWith("mcp__"), + ) + } + if (!mcpToolExist) { + if (anthropicPayload.model.startsWith("claude")) { + // https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/overview#pricing + tokenCount.input = tokenCount.input + 346 + } else if (anthropicPayload.model.startsWith("grok")) { + tokenCount.input = tokenCount.input + 480 + } + } + } + + let finalTokenCount = tokenCount.input + tokenCount.output + if (anthropicPayload.model.startsWith("claude")) { + finalTokenCount = Math.round(finalTokenCount * 1.15) + } else if (anthropicPayload.model.startsWith("grok")) { + finalTokenCount = Math.round(finalTokenCount * 1.03) + } + + consola.info("Token count:", finalTokenCount) + + return c.json({ + input_tokens: finalTokenCount, + }) + } catch (error) { + consola.error("Error counting tokens:", error) + return c.json({ + input_tokens: 1, + }) + } +} diff --git a/src/routes/messages/non-stream-translation.ts b/src/routes/messages/non-stream-translation.ts index 271aa47f6..dc41e6382 100644 --- a/src/routes/messages/non-stream-translation.ts +++ b/src/routes/messages/non-stream-translation.ts @@ -313,8 +313,15 @@ export function translateToAnthropic( stop_reason: mapOpenAIStopReasonToAnthropic(stopReason), stop_sequence: null, usage: { - input_tokens: response.usage?.prompt_tokens ?? 0, + input_tokens: + (response.usage?.prompt_tokens ?? 0) + - (response.usage?.prompt_tokens_details?.cached_tokens ?? 0), output_tokens: response.usage?.completion_tokens ?? 0, + ...(response.usage?.prompt_tokens_details?.cached_tokens + !== undefined && { + cache_read_input_tokens: + response.usage.prompt_tokens_details.cached_tokens, + }), }, } } diff --git a/src/routes/messages/route.ts b/src/routes/messages/route.ts index 1f4eee2f9..ef72d802e 100644 --- a/src/routes/messages/route.ts +++ b/src/routes/messages/route.ts @@ -2,6 +2,7 @@ import { Hono } from "hono" import { forwardError } from "~/lib/error" +import { handleCountTokens } from "./count-tokens-handler" import { handleCompletion } from "./handler" export const messageRoutes = new Hono() @@ -13,3 +14,11 @@ messageRoutes.post("/", async (c) => { return await forwardError(c, error) } }) + +messageRoutes.post("/count_tokens", async (c) => { + try { + return await handleCountTokens(c) + } catch (error) { + return await forwardError(c, error) + } +}) diff --git a/src/routes/messages/stream-translation.ts b/src/routes/messages/stream-translation.ts index c8c20a07f..5f3a6a183 100644 --- a/src/routes/messages/stream-translation.ts +++ b/src/routes/messages/stream-translation.ts @@ -41,10 +41,6 @@ export function translateChunkToAnthropicEvents( model: chunk.model, stop_reason: null, stop_sequence: null, - usage: { - input_tokens: chunk.usage?.prompt_tokens ?? 0, - output_tokens: 0, // Will be updated in message_delta when finished - }, }, }) state.messageStartSent = true @@ -152,7 +148,9 @@ export function translateChunkToAnthropicEvents( stop_sequence: null, }, usage: { - input_tokens: chunk.usage?.prompt_tokens ?? 0, + input_tokens: + (chunk.usage?.prompt_tokens ?? 0) + - (chunk.usage?.prompt_tokens_details?.cached_tokens ?? 0), output_tokens: chunk.usage?.completion_tokens ?? 0, ...(chunk.usage?.prompt_tokens_details?.cached_tokens !== undefined && { diff --git a/src/server.ts b/src/server.ts index 3cb2bb860..462a278f3 100644 --- a/src/server.ts +++ b/src/server.ts @@ -29,4 +29,3 @@ server.route("/v1/embeddings", embeddingRoutes) // Anthropic compatible endpoints server.route("/v1/messages", messageRoutes) -server.post("/v1/messages/count_tokens", (c) => c.json({ input_tokens: 1 })) diff --git a/src/services/copilot/create-chat-completions.ts b/src/services/copilot/create-chat-completions.ts index 5d38bb452..8534151da 100644 --- a/src/services/copilot/create-chat-completions.ts +++ b/src/services/copilot/create-chat-completions.ts @@ -103,6 +103,9 @@ export interface ChatCompletionResponse { prompt_tokens: number completion_tokens: number total_tokens: number + prompt_tokens_details?: { + cached_tokens: number + } } } diff --git a/src/services/copilot/get-models.ts b/src/services/copilot/get-models.ts index 792adc480..3cfa30af0 100644 --- a/src/services/copilot/get-models.ts +++ b/src/services/copilot/get-models.ts @@ -39,7 +39,7 @@ interface ModelCapabilities { type: string } -interface Model { +export interface Model { capabilities: ModelCapabilities id: string model_picker_enabled: boolean diff --git a/src/start.ts b/src/start.ts index a1b02303e..737da7719 100644 --- a/src/start.ts +++ b/src/start.ts @@ -23,6 +23,7 @@ interface RunServerOptions { githubToken?: string claudeCode: boolean showToken: boolean + claudeCodeEnv?: boolean } export async function runServer(options: RunServerOptions): Promise { @@ -60,7 +61,7 @@ export async function runServer(options: RunServerOptions): Promise { const serverUrl = `http://localhost:${options.port}` - if (options.claudeCode) { + if (options.claudeCode && options.claudeCodeEnv) { invariant(state.models, "Models should be loaded by now") const selectedModel = await consola.prompt( @@ -169,6 +170,11 @@ export const start = defineCommand({ default: false, description: "Show GitHub and Copilot tokens on fetch and refresh", }, + "claude-code-env": { + type: "boolean", + default: true, + description: "Generate Claude Code Environment variables", + }, }, run({ args }) { const rateLimitRaw = args["rate-limit"] @@ -186,6 +192,7 @@ export const start = defineCommand({ githubToken: args["github-token"], claudeCode: args["claude-code"], showToken: args["show-token"], + claudeCodeEnv: args["claude-code-env"], }) }, }) diff --git a/tests/anthropic-response.test.ts b/tests/anthropic-response.test.ts index 352f06ea7..247b554ae 100644 --- a/tests/anthropic-response.test.ts +++ b/tests/anthropic-response.test.ts @@ -100,7 +100,10 @@ describe("OpenAI to Anthropic Non-Streaming Response Translation", () => { expect(anthropicResponse.id).toBe("chatcmpl-123") expect(anthropicResponse.stop_reason).toBe("end_turn") - expect(anthropicResponse.usage.input_tokens).toBe(9) + expect(anthropicResponse.usage).toBeDefined() + if (anthropicResponse.usage) { + expect(anthropicResponse.usage.input_tokens).toBe(9) + } expect(anthropicResponse.content[0].type).toBe("text") if (anthropicResponse.content[0].type === "text") { expect(anthropicResponse.content[0].text).toBe(