Skip to content

Commit 3d07960

Browse files
jfrometa88xuanyang15
authored andcommitted
fix: use tool_responses role for gemma4 models in LiteLLM integration
Merge #5655 Closes: #5650 Co-authored-by: Xuan Yang <xygoogle@google.com> COPYBARA_INTEGRATE_REVIEW=#5655 from jfrometa88:main d74b521 PiperOrigin-RevId: 917231422
1 parent 1284493 commit 3d07960

2 files changed

Lines changed: 197 additions & 7 deletions

File tree

src/google/adk/models/lite_llm.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -807,9 +807,14 @@ async def _content_to_message_param(
807807
if isinstance(response, str)
808808
else _safe_json_serialize(response)
809809
)
810+
# gemma4 requires role='tool_responses' for recognizing function_response parts as responses
811+
# from the tool call, instead of OpenAI-compatible 'tool' role used by other models.
812+
# Earlier Gemma versions before version 4 do not support tool use,
813+
# so this check is intentionally scoped to only look for "gemma4" in the model name.
814+
tool_role = "tool_responses" if "gemma4" in model.lower() else "tool"
810815
tool_messages.append(
811816
ChatCompletionToolMessage(
812-
role="tool",
817+
role=tool_role,
813818
tool_call_id=part.function_response.id,
814819
content=response_content,
815820
)
@@ -824,6 +829,7 @@ async def _content_to_message_param(
824829
follow_up = await _content_to_message_param(
825830
types.Content(role=content.role, parts=non_tool_parts),
826831
provider=provider,
832+
model=model,
827833
)
828834
follow_up_messages = (
829835
follow_up if isinstance(follow_up, list) else [follow_up]
@@ -934,12 +940,16 @@ async def _content_to_message_param(
934940
)
935941

936942

937-
def _ensure_tool_results(messages: List[Message]) -> List[Message]:
943+
def _ensure_tool_results(messages: List[Message], model: str) -> List[Message]:
938944
"""Insert placeholder tool messages for missing tool results.
939945
940946
LiteLLM-backed providers like OpenAI and Anthropic reject histories where an
941947
assistant tool call is not followed by tool responses before the next
942948
non-tool message. This helps recover from interrupted tool execution.
949+
950+
For models that expect a different tool response role (e.g. Gemma4 models,
951+
which require 'tool_responses' instead of 'tool'), the role is adjusted
952+
accordingly.
943953
"""
944954
if not messages:
945955
return messages
@@ -948,17 +958,19 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]:
948958

949959
healed_messages: List[Message] = []
950960
pending_tool_call_ids: List[str] = []
961+
expected_tool_role = "tool_responses" if "gemma4" in model.lower() else "tool"
951962

952963
for message in messages:
953964
role = message.get("role")
954-
if pending_tool_call_ids and role != "tool":
965+
966+
if pending_tool_call_ids and role != expected_tool_role:
955967
logger.warning(
956968
"Missing tool results for tool_call_id(s): %s",
957969
pending_tool_call_ids,
958970
)
959971
healed_messages.extend(
960972
ChatCompletionToolMessage(
961-
role="tool",
973+
role=expected_tool_role,
962974
tool_call_id=tool_call_id,
963975
content=_MISSING_TOOL_RESULT_MESSAGE,
964976
)
@@ -971,21 +983,22 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]:
971983
pending_tool_call_ids = [
972984
tool_call.get("id") for tool_call in tool_calls if tool_call.get("id")
973985
]
974-
elif role == "tool":
986+
elif role == expected_tool_role:
975987
tool_call_id = message.get("tool_call_id")
976988
if tool_call_id in pending_tool_call_ids:
977989
pending_tool_call_ids.remove(tool_call_id)
978990

979991
healed_messages.append(message)
980992

993+
# Final block also uses expected_tool_role
981994
if pending_tool_call_ids:
982995
logger.warning(
983996
"Missing tool results for tool_call_id(s): %s",
984997
pending_tool_call_ids,
985998
)
986999
healed_messages.extend(
9871000
ChatCompletionToolMessage(
988-
role="tool",
1001+
role=expected_tool_role,
9891002
tool_call_id=tool_call_id,
9901003
content=_MISSING_TOOL_RESULT_MESSAGE,
9911004
)
@@ -1905,7 +1918,7 @@ async def _get_completion_inputs(
19051918
content=llm_request.config.system_instruction,
19061919
),
19071920
)
1908-
messages = _ensure_tool_results(messages)
1921+
messages = _ensure_tool_results(messages, model)
19091922

19101923
# 2. Convert tool declarations
19111924
tools: Optional[List[Dict]] = None
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for Gemma-specific tool role handling in _content_to_message_param.
16+
17+
Gemma's chat template expects role='tool_responses' for tool result messages,
18+
while the OpenAI-compatible default is role='tool'. This module verifies that
19+
_content_to_message_param sets the correct role based on the model name.
20+
"""
21+
22+
from typing import Any
23+
24+
from google.adk.models.lite_llm import _content_to_message_param
25+
from google.genai import types
26+
import pytest
27+
28+
29+
def _make_function_response_content(
30+
function_name: str = "get_weather",
31+
response_data: dict[str, Any] | None = None,
32+
call_id: str = "call_001",
33+
) -> types.Content:
34+
"""Builds a types.Content with a single function_response part."""
35+
if response_data is None:
36+
response_data = {"city": "Santiago de Cuba", "condition": "sunny"}
37+
return types.Content(
38+
role="user",
39+
parts=[
40+
types.Part(
41+
function_response=types.FunctionResponse(
42+
name=function_name,
43+
response=response_data,
44+
id=call_id,
45+
)
46+
)
47+
],
48+
)
49+
50+
51+
def _make_multi_function_response_content(
52+
call_ids: list[str] | None = None,
53+
) -> types.Content:
54+
"""Builds a types.Content with multiple function_response parts."""
55+
if call_ids is None:
56+
call_ids = ["call_001", "call_002"]
57+
return types.Content(
58+
role="user",
59+
parts=[
60+
types.Part(
61+
function_response=types.FunctionResponse(
62+
name=f"tool_{i}",
63+
response={"result": f"value_{i}"},
64+
id=call_id,
65+
)
66+
)
67+
for i, call_id in enumerate(call_ids)
68+
],
69+
)
70+
71+
72+
def _extract_role(msg) -> str:
73+
"""Extracts role from a litellm message, whether dict or object."""
74+
if isinstance(msg, dict):
75+
return msg["role"]
76+
return msg.role
77+
78+
79+
class TestToolRoleSingleResponse:
80+
"""_content_to_message_param with a single function_response part."""
81+
82+
@pytest.mark.asyncio
83+
async def test_gemma4_model_uses_tool_responses_role(self):
84+
"""Models containing 'gemma4' should get role='tool_responses'."""
85+
content = _make_function_response_content()
86+
87+
result = await _content_to_message_param(content, model="ollama/gemma4:e2b")
88+
89+
assert _extract_role(result) == "tool_responses", (
90+
"Gemma models require role='tool_responses' to match their chat "
91+
"template; role='tool' causes infinite tool-calling loops."
92+
)
93+
94+
@pytest.mark.asyncio
95+
async def test_gemma4_uppercase_model_name(self):
96+
"""Model name matching should be case-insensitive."""
97+
content = _make_function_response_content()
98+
99+
result = await _content_to_message_param(content, model="ollama/Gemma4:31b")
100+
101+
assert _extract_role(result) == "tool_responses"
102+
103+
@pytest.mark.asyncio
104+
async def test_tool_call_id_and_content_preserved(self):
105+
"""Fix must not alter tool_call_id or content — only role changes."""
106+
content = _make_function_response_content(
107+
response_data={"status": "ok"}, call_id="my_call_123"
108+
)
109+
110+
result = await _content_to_message_param(content, model="ollama/gemma4:e2b")
111+
112+
if isinstance(result, dict):
113+
assert result["tool_call_id"] == "my_call_123"
114+
assert "ok" in result["content"]
115+
else:
116+
assert result.tool_call_id == "my_call_123"
117+
assert "ok" in result.content
118+
119+
@pytest.mark.asyncio
120+
async def test_empty_model_string_uses_tool_role(self):
121+
"""Empty model string should fall back to default role='tool'."""
122+
content = _make_function_response_content()
123+
124+
result = await _content_to_message_param(content, model="")
125+
126+
assert _extract_role(result) == "tool"
127+
128+
@pytest.mark.asyncio
129+
async def test_unrelated_models_use_tool_role(self):
130+
"""Models that do not contain 'gemma4' must not be affected."""
131+
unaffected_models = [
132+
"ollama/llama3:8b",
133+
"ollama/qwen2.5-coder:3b",
134+
"anthropic/claude-3-opus",
135+
"openai/gpt-4o",
136+
"ollama/gemma3:4b", # gemma3 != gemma4
137+
]
138+
for model in unaffected_models:
139+
content = _make_function_response_content()
140+
result = await _content_to_message_param(content, model=model)
141+
assert (
142+
_extract_role(result) == "tool"
143+
), f"Model '{model}' should not be affected by the Gemma4 fix."
144+
145+
146+
class TestToolRoleMultipleResponses:
147+
"""_content_to_message_param with multiple function_response parts."""
148+
149+
@pytest.mark.asyncio
150+
async def test_gemma4_all_messages_use_tool_responses_role(self):
151+
"""All messages in a multi-response must have role='tool_responses'."""
152+
content = _make_multi_function_response_content(
153+
call_ids=["call_a", "call_b", "call_c"]
154+
)
155+
156+
result = await _content_to_message_param(content, model="ollama/gemma4:4b")
157+
158+
assert isinstance(result, list)
159+
assert len(result) == 3
160+
for msg in result:
161+
assert _extract_role(msg) == "tool_responses", (
162+
"Every tool message in a multi-response must use 'tool_responses' "
163+
"for Gemma4 models."
164+
)
165+
166+
@pytest.mark.asyncio
167+
async def test_non_gemma_multi_response_uses_tool_role(self):
168+
"""Non-Gemma multi-response messages should all have role='tool'."""
169+
content = _make_multi_function_response_content(
170+
call_ids=["call_a", "call_b"]
171+
)
172+
173+
result = await _content_to_message_param(content, model="openai/gpt-4o")
174+
175+
assert isinstance(result, list)
176+
for msg in result:
177+
assert _extract_role(msg) == "tool"

0 commit comments

Comments
 (0)