Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/llm/src/protocols/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
} from "../schema"
import { JsonObject, optionalArray, ProviderShared } from "./shared"
import { GeminiToolSchema } from "./utils/gemini-tool-schema"
import { unflattenArgs } from "./utils/unflatten-args"
import { Lifecycle } from "./utils/lifecycle"
import { ToolSchemaProjection } from "./utils/tool-schema"

Expand Down Expand Up @@ -439,7 +440,7 @@ const step = (state: ParserState, event: GeminiEvent) => {
}

if ("functionCall" in part) {
const input = part.functionCall.args
const input = unflattenArgs(part.functionCall.args)
const id = `tool_${nextToolCallId++}`
lifecycle = Lifecycle.reasoningEnd(
lifecycle,
Expand Down
67 changes: 67 additions & 0 deletions packages/llm/src/protocols/utils/unflatten-args.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/**
* Unflattens dot-bracket notation keys into nested objects/arrays.
*
* Gemini models sometimes return tool call arguments in flattened form, e.g.
* `{ "questions[0].header": "Auth", "questions[0].multiSelect": false }` instead
* of the nested `{ questions: [{ header: "Auth", multiSelect: false }] }` that
* downstream schema validation expects.
*/
export function unflattenArgs(args: Record<string, unknown> | null | undefined): Record<string, unknown> | null | undefined {
if (!args || typeof args !== "object") return args
const keys = Object.keys(args)
if (keys.length === 0) return args

// Fast-path: if no key contains '[', the args are already nested.
const needsUnflatten = keys.some((k) => k.includes("["))
if (!needsUnflatten) return args

const result: Record<string, unknown> = Object.create(null)
for (const key of keys) {
const tokens = tokenize(key)
if (tokens.length > 0) setNested(result, tokens, args[key])
}
return result
}

/** Parse a flat key like "a[0].b.c[1]" into tokens: ["a", 0, "b", "c", 1] */
function tokenize(key: string): Array<string | number> {
const tokens: Array<string | number> = []
let i = 0
while (i < key.length) {
if (key[i] === "[") {
// bracket segment
const end = key.indexOf("]", i)
if (end === -1) break // malformed key, stop parsing
const inner = key.slice(i + 1, end)
tokens.push(/^\d+$/.test(inner) ? Number(inner) : inner)
i = end + 1
if (key[i] === ".") i++ // skip trailing dot
} else {
// dot-delimited segment
let end = i
while (end < key.length && key[end] !== "." && key[end] !== "[") end++
tokens.push(key.slice(i, end))
i = end
if (key[i] === ".") i++ // skip dot
}
}
return tokens
}

const BANNED_KEYS = new Set(["__proto__", "constructor", "prototype"])

function setNested(obj: Record<string, unknown>, tokens: Array<string | number>, value: unknown): void {
let current: any = obj
for (let i = 0; i < tokens.length - 1; i++) {
const token = tokens[i]
if (typeof token === "string" && BANNED_KEYS.has(token)) return
const next = tokens[i + 1]
if (current[token as any] == null) {
current[token as any] = typeof next === "number" ? [] : Object.create(null)
}
current = current[token as any]
}
const last = tokens[tokens.length - 1]
if (typeof last === "string" && BANNED_KEYS.has(last)) return
current[last as any] = value
}
89 changes: 89 additions & 0 deletions packages/llm/test/unflatten-args.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { describe, expect, test } from "bun:test"
import { unflattenArgs } from "../src/protocols/utils/unflatten-args"

describe("unflattenArgs", () => {
test("returns null/undefined as-is", () => {
expect(unflattenArgs(null)).toBe(null)
expect(unflattenArgs(undefined)).toBe(undefined)
})

test("returns empty object as-is", () => {
const obj = {}
expect(unflattenArgs(obj)).toBe(obj)
})

test("returns already-nested object unchanged", () => {
const obj = { name: "hello", nested: { a: 1 } }
expect(unflattenArgs(obj)).toBe(obj)
})

test("passes through dot-only keys (no brackets)", () => {
const obj = { "a.b": "val" }
expect(unflattenArgs(obj)).toBe(obj)
})

test("unflattens bracket notation", () => {
expect(unflattenArgs({ "a[0]": "val" })).toEqual({ a: ["val"] })
})

test("unflattens deep mixed notation", () => {
expect(
unflattenArgs({ "questions[0].header": "Auth" }),
).toEqual({ questions: [{ header: "Auth" }] })
})

test("unflattens multiple array items", () => {
expect(
unflattenArgs({
"a[0].x": 1,
"a[1].x": 2,
}),
).toEqual({ a: [{ x: 1 }, { x: 2 }] })
})

test("handles the issue #35105 example", () => {
const flat = {
"questions[0].question": "Which auth method?",
"questions[0].header": "Auth",
"questions[0].options[0].label": "OAuth",
"questions[0].options[0].description": "Use OAuth",
"questions[0].options[1].label": "JWT",
"questions[0].options[1].description": "Use JWT",
"questions[0].multiSelect": false,
}
expect(unflattenArgs(flat)).toEqual({
questions: [
{
question: "Which auth method?",
header: "Auth",
options: [
{ label: "OAuth", description: "Use OAuth" },
{ label: "JWT", description: "Use JWT" },
],
multiSelect: false,
},
],
})
})

test("preserves non-bracket keys alongside bracket keys", () => {
expect(
unflattenArgs({
plain: "yes",
"arr[0]": "val",
}),
).toEqual({ plain: "yes", arr: ["val"] })
})

test("handles malformed key with missing closing bracket", () => {
// Should not hang — partial token is ignored
const result = unflattenArgs({ "a[0": "val" })
expect(result).toBeDefined()
})

test("rejects prototype pollution attempts", () => {
const result = unflattenArgs({ "__proto__[0]": "evil" }) as any
expect(({} as any).constructor).toBeDefined() // Object.prototype untouched
expect(result["__proto__"] ?? undefined).toBeUndefined()
})
})
Loading