Skip to content

Commit 31cc5a1

Browse files
committed
fix: parse noncanonical litellm tool call arguments
Change-Id: Iced114b05b6d89cba62e02e7440134e8f5612215
1 parent b73679e commit 31cc5a1

2 files changed

Lines changed: 149 additions & 2 deletions

File tree

src/google/adk/models/lite_llm.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import ast
1718
import base64
1819
import binascii
1920
import copy
@@ -98,6 +99,7 @@
9899
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
99100
_LITELLM_STRUCTURED_TYPES = {"json_object", "json_schema"}
100101
_JSON_DECODER = json.JSONDecoder()
102+
_UNQUOTED_KEY_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
101103

102104
# Mapping of major MIME type prefixes to LiteLLM content types for URL blocks.
103105
# Audio is handled separately as `input_audio` content blocks because LiteLLM
@@ -122,6 +124,100 @@
122124
"content_filter": types.FinishReason.SAFETY,
123125
}
124126

127+
128+
def _quote_unquoted_json_object_keys(value: str) -> str:
129+
"""Quotes simple unquoted object keys without touching string contents."""
130+
result = []
131+
i = 0
132+
in_string = False
133+
string_quote = ""
134+
escaped = False
135+
136+
while i < len(value):
137+
char = value[i]
138+
if in_string:
139+
result.append(char)
140+
if escaped:
141+
escaped = False
142+
elif char == "\\":
143+
escaped = True
144+
elif char == string_quote:
145+
in_string = False
146+
string_quote = ""
147+
i += 1
148+
continue
149+
150+
if char in {'"', "'"}:
151+
in_string = True
152+
string_quote = char
153+
result.append(char)
154+
i += 1
155+
continue
156+
157+
if char in "{,":
158+
result.append(char)
159+
i += 1
160+
whitespace_start = i
161+
while i < len(value) and value[i].isspace():
162+
i += 1
163+
result.append(value[whitespace_start:i])
164+
165+
key_match = _UNQUOTED_KEY_RE.match(value, i)
166+
if key_match:
167+
key_end = key_match.end()
168+
colon_index = key_end
169+
while colon_index < len(value) and value[colon_index].isspace():
170+
colon_index += 1
171+
if colon_index < len(value) and value[colon_index] == ":":
172+
result.append(f'"{key_match.group(0)}"')
173+
result.append(value[key_end:colon_index])
174+
i = colon_index
175+
continue
176+
continue
177+
178+
result.append(char)
179+
i += 1
180+
181+
return "".join(result)
182+
183+
184+
def _parse_tool_call_arguments(arguments: Any) -> Any:
185+
"""Parses LiteLLM tool call arguments.
186+
187+
LiteLLM normally returns OpenAI-compatible tool call arguments as JSON
188+
strings, but some providers can stream a complete tool call whose finalized
189+
argument payload is a Python dict literal or has unquoted object keys. Keep
190+
strict JSON as the primary path, then repair only those complete
191+
object-literal shapes so ADK can still surface the intended function call.
192+
"""
193+
if not arguments:
194+
return {}
195+
if not isinstance(arguments, str):
196+
return arguments
197+
198+
try:
199+
return json.loads(arguments)
200+
except json.JSONDecodeError as exc:
201+
json_error = exc
202+
203+
try:
204+
return ast.literal_eval(arguments)
205+
except (SyntaxError, ValueError):
206+
pass
207+
208+
repaired_arguments = _quote_unquoted_json_object_keys(arguments)
209+
if repaired_arguments != arguments:
210+
try:
211+
return json.loads(repaired_arguments)
212+
except json.JSONDecodeError:
213+
try:
214+
return ast.literal_eval(repaired_arguments)
215+
except (SyntaxError, ValueError):
216+
pass
217+
218+
raise json_error
219+
220+
125221
# File MIME types supported for upload as file content (not decoded as text).
126222
# Note: text/* types are handled separately and decoded as text content.
127223
# These types are uploaded as files to providers that support it.
@@ -1727,7 +1823,7 @@ def _message_to_generate_content_response(
17271823
thought_signature = _extract_thought_signature_from_tool_call(tool_call)
17281824
part = types.Part.from_function_call(
17291825
name=tool_call.function.name,
1730-
args=json.loads(tool_call.function.arguments or "{}"),
1826+
args=_parse_tool_call_arguments(tool_call.function.arguments),
17311827
)
17321828
part.function_call.id = tool_call.id
17331829
if thought_signature:
@@ -2281,7 +2377,7 @@ def _finalize_tool_call_response(
22812377
if func_data["id"]:
22822378
if finish_reason == "length":
22832379
try:
2284-
json.loads(func_data["args"] or "{}")
2380+
_parse_tool_call_arguments(func_data["args"])
22852381
except json.JSONDecodeError:
22862382
has_incomplete_tool_call_args = True
22872383
continue

tests/unittests/models/test_litellm.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2253,6 +2253,57 @@ def test_message_to_generate_content_response_tool_call():
22532253
assert response.content.parts[0].function_call.id == "test_tool_call_id"
22542254

22552255

2256+
def test_message_to_generate_content_response_tool_call_accepts_python_literal_arguments():
2257+
message = ChatCompletionAssistantMessage(
2258+
role="assistant",
2259+
content=None,
2260+
tool_calls=[
2261+
ChatCompletionMessageToolCall(
2262+
type="function",
2263+
id="test_tool_call_id",
2264+
function=Function(
2265+
name="test_function",
2266+
arguments="{'query': 'MATCH (n) RETURN n'}",
2267+
),
2268+
)
2269+
],
2270+
)
2271+
2272+
response = _message_to_generate_content_response(message)
2273+
2274+
assert response.content.role == "model"
2275+
assert response.content.parts[0].function_call.name == "test_function"
2276+
assert response.content.parts[0].function_call.args == {
2277+
"query": "MATCH (n) RETURN n"
2278+
}
2279+
2280+
2281+
def test_message_to_generate_content_response_tool_call_accepts_unquoted_json_keys():
2282+
message = ChatCompletionAssistantMessage(
2283+
role="assistant",
2284+
content=None,
2285+
tool_calls=[
2286+
ChatCompletionMessageToolCall(
2287+
type="function",
2288+
id="test_tool_call_id",
2289+
function=Function(
2290+
name="test_function",
2291+
arguments='{query: "MATCH (n) RETURN n", limit: 5}',
2292+
),
2293+
)
2294+
],
2295+
)
2296+
2297+
response = _message_to_generate_content_response(message)
2298+
2299+
assert response.content.role == "model"
2300+
assert response.content.parts[0].function_call.name == "test_function"
2301+
assert response.content.parts[0].function_call.args == {
2302+
"query": "MATCH (n) RETURN n",
2303+
"limit": 5,
2304+
}
2305+
2306+
22562307
def test_message_to_generate_content_response_inline_tool_call_text():
22572308
message = ChatCompletionAssistantMessage(
22582309
role="assistant",

0 commit comments

Comments
 (0)