Skip to content

Commit 0d58863

Browse files
committed
feat: enhance token counting with model validation and update tests
1 parent 510cbf4 commit 0d58863

3 files changed

Lines changed: 40 additions & 3 deletions

File tree

src/routes/generate-content/handler.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,16 @@ export async function handleGeminiCountTokens(c: Context) {
300300

301301
const openAIPayload = translateGeminiCountTokensToOpenAI(geminiPayload, model)
302302

303-
const tokenCounts = getTokenCount(openAIPayload.messages)
303+
// Find the selected model object from state
304+
const selectedModel = state.models?.data.find((m) => m.id === model)
305+
306+
if (!selectedModel) {
307+
throw new Error(
308+
`Model ${model} not found in available models. Please ensure the model list is loaded.`,
309+
)
310+
}
311+
312+
const tokenCounts = await getTokenCount(openAIPayload, selectedModel)
304313

305314
const totalTokens = tokenCounts.input + tokenCounts.output
306315
const geminiResponse = translateTokenCountToGemini(totalTokens)

tests/generate-content/core-functionality.test.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,21 @@ test("translates request and uses local tokenizer without downstream call", asyn
2727
},
2828
}))
2929
await mock.module("~/lib/tokenizer", () => ({
30-
getTokenCount: (_messages: unknown) => ({ input: 2, output: 3 }),
30+
getTokenCount: async (_messages: unknown, _model: unknown) =>
31+
Promise.resolve({ input: 2, output: 3 }),
32+
}))
33+
await mock.module("~/lib/state", () => ({
34+
state: {
35+
models: {
36+
data: [
37+
{
38+
id: "gemini-pro",
39+
name: "Gemini Pro",
40+
capabilities: { tokenizer: "o200k_base" },
41+
},
42+
],
43+
},
44+
},
3145
}))
3246

3347
const { server } = (await import("~/server")) as { server: TestServer }

tests/generate-content/route-routing.test.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,21 @@ test("routes to stream endpoint based on URL keyword", async () => {
6666

6767
test("routes to countTokens endpoint based on URL keyword", async () => {
6868
await mock.module("~/lib/tokenizer", () => ({
69-
getTokenCount: (_messages: unknown) => ({ input: 2, output: 3 }),
69+
getTokenCount: async (_messages: unknown, _model: unknown) =>
70+
Promise.resolve({ input: 2, output: 3 }),
71+
}))
72+
await mock.module("~/lib/state", () => ({
73+
state: {
74+
models: {
75+
data: [
76+
{
77+
id: "gemini-pro",
78+
name: "Gemini Pro",
79+
capabilities: { tokenizer: "o200k_base" },
80+
},
81+
],
82+
},
83+
},
7084
}))
7185
await mock.module("~/lib/rate-limit", () => ({
7286
checkRateLimit: () => {},

0 commit comments

Comments
 (0)