Skip to content

Commit d2adbd1

Browse files
committed
fix(litellm): preserve thought_signature across tool call roundtrip
1 parent 2b8ccd4 commit d2adbd1

File tree

2 files changed

+101
-2
lines changed

2 files changed

+101
-2
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,46 @@
143143
"before a response was recorded)."
144144
)
145145

146+
_LITELLM_THOUGHT_SIGNATURE_SEPARATOR = "__thought__"
147+
148+
149+
def _decode_litellm_tool_call_id(
150+
tool_call_id: str,
151+
) -> tuple[str, Optional[bytes]]:
152+
"""Extracts thought_signature bytes from a LiteLLM tool call id."""
153+
if not tool_call_id:
154+
return tool_call_id, None
155+
156+
tool_call_id, separator, encoded_signature = tool_call_id.partition(
157+
_LITELLM_THOUGHT_SIGNATURE_SEPARATOR
158+
)
159+
if not separator or not encoded_signature:
160+
return tool_call_id, None
161+
162+
try:
163+
return tool_call_id, base64.b64decode(encoded_signature)
164+
except (ValueError, TypeError) as err:
165+
logger.warning(
166+
"Failed to decode thought_signature from tool call id %r: %s",
167+
tool_call_id,
168+
err,
169+
)
170+
return tool_call_id, None
171+
172+
173+
def _encode_litellm_tool_call_id(
    tool_call_id: Optional[str], thought_signature: Optional[bytes]
) -> Optional[str]:
  """Embeds thought_signature bytes in a LiteLLM-compatible tool call id.

  The signature is appended to the id as base64 text behind the
  "__thought__" separator so it survives the LiteLLM roundtrip and can be
  recovered later by `_decode_litellm_tool_call_id`.
  """
  # Only embed when there is both an id to attach to and a signature to
  # carry; otherwise hand the id back unchanged.
  if tool_call_id and thought_signature:
    suffix = base64.b64encode(thought_signature).decode("utf-8")
    return tool_call_id + _LITELLM_THOUGHT_SIGNATURE_SEPARATOR + suffix
  return tool_call_id
185+
146186
_LITELLM_IMPORTED = False
147187
_LITELLM_GLOBAL_SYMBOLS = (
148188
"ChatCompletionAssistantMessage",
@@ -665,7 +705,10 @@ async def _content_to_message_param(
665705
tool_calls.append(
666706
ChatCompletionAssistantToolCall(
667707
type="function",
668-
id=part.function_call.id,
708+
id=_encode_litellm_tool_call_id(
709+
part.function_call.id,
710+
part.thought_signature,
711+
),
669712
function=Function(
670713
name=part.function_call.name,
671714
arguments=_safe_json_serialize(part.function_call.args),
@@ -1481,7 +1524,12 @@ def _message_to_generate_content_response(
14811524
name=tool_call.function.name,
14821525
args=json.loads(tool_call.function.arguments or "{}"),
14831526
)
1484-
part.function_call.id = tool_call.id
1527+
tool_call_id, thought_signature = _decode_litellm_tool_call_id(
1528+
tool_call.id
1529+
)
1530+
part.function_call.id = tool_call_id
1531+
if thought_signature:
1532+
part.thought_signature = thought_signature
14851533
parts.append(part)
14861534

14871535
return LlmResponse(

tests/unittests/models/test_litellm.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License
1414

15+
import base64
1516
import contextlib
1617
import json
1718
import logging
@@ -2217,6 +2218,56 @@ def test_message_to_generate_content_response_tool_call():
22172218
assert response.content.parts[0].function_call.id == "test_tool_call_id"
22182219

22192220

2221+
def test_message_to_generate_content_response_tool_call_with_thought_signature():
  """A base64 suffix on the tool call id is restored as thought_signature."""
  raw_signature = b"gemini_signature"
  suffix = base64.b64encode(raw_signature).decode("utf-8")
  assistant_message = ChatCompletionAssistantMessage(
      role="assistant",
      content=None,
      tool_calls=[
          ChatCompletionMessageToolCall(
              type="function",
              id=f"test_tool_call_id__thought__{suffix}",
              function=Function(
                  name="test_function",
                  arguments='{"test_arg": "test_value"}',
              ),
          )
      ],
  )

  response = _message_to_generate_content_response(assistant_message)

  assert response.content.role == "model"
  first_part = response.content.parts[0]
  assert first_part.function_call.name == "test_function"
  assert first_part.function_call.args == {"test_arg": "test_value"}
  # The suffix must be stripped from the id and surfaced as raw bytes.
  assert first_part.function_call.id == "test_tool_call_id"
  assert first_part.thought_signature == raw_signature
2247+
2248+
2249+
@pytest.mark.asyncio
async def test_content_to_message_param_embeds_thought_signature_in_tool_call():
  """thought_signature is base64-embedded into the outgoing tool call id."""
  expected_suffix = base64.b64encode(b"gemini_signature").decode("utf-8")
  function_part = types.Part.from_function_call(
      name="test_function",
      args={"test_arg": "test_value"},
  )
  function_part.function_call.id = "test_tool_call_id"
  function_part.thought_signature = b"gemini_signature"

  message = await _content_to_message_param(
      types.Content(role="model", parts=[function_part])
  )

  tool_calls = message["tool_calls"]
  assert tool_calls is not None
  assert len(tool_calls) == 1
  # Id carries the signature behind the "__thought__" separator.
  assert tool_calls[0]["id"] == f"test_tool_call_id__thought__{expected_suffix}"
2269+
2270+
22202271
def test_message_to_generate_content_response_inline_tool_call_text():
22212272
message = ChatCompletionAssistantMessage(
22222273
role="assistant",

0 commit comments

Comments
 (0)