|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Tests for Gemma-specific tool role handling in _content_to_message_param. |
| 16 | +
|
| 17 | +Gemma's chat template expects role='tool_responses' for tool result messages, |
| 18 | +while the OpenAI-compatible default is role='tool'. This module verifies that |
| 19 | +_content_to_message_param sets the correct role based on the model name. |
| 20 | +""" |
| 21 | + |
| 22 | +from typing import Any |
| 23 | + |
| 24 | +from google.adk.models.lite_llm import _content_to_message_param |
| 25 | +from google.genai import types |
| 26 | +import pytest |
| 27 | + |
| 28 | + |
| 29 | +def _make_function_response_content( |
| 30 | + function_name: str = "get_weather", |
| 31 | + response_data: dict[str, Any] | None = None, |
| 32 | + call_id: str = "call_001", |
| 33 | +) -> types.Content: |
| 34 | + """Builds a types.Content with a single function_response part.""" |
| 35 | + if response_data is None: |
| 36 | + response_data = {"city": "Santiago de Cuba", "condition": "sunny"} |
| 37 | + return types.Content( |
| 38 | + role="user", |
| 39 | + parts=[ |
| 40 | + types.Part( |
| 41 | + function_response=types.FunctionResponse( |
| 42 | + name=function_name, |
| 43 | + response=response_data, |
| 44 | + id=call_id, |
| 45 | + ) |
| 46 | + ) |
| 47 | + ], |
| 48 | + ) |
| 49 | + |
| 50 | + |
| 51 | +def _make_multi_function_response_content( |
| 52 | + call_ids: list[str] | None = None, |
| 53 | +) -> types.Content: |
| 54 | + """Builds a types.Content with multiple function_response parts.""" |
| 55 | + if call_ids is None: |
| 56 | + call_ids = ["call_001", "call_002"] |
| 57 | + return types.Content( |
| 58 | + role="user", |
| 59 | + parts=[ |
| 60 | + types.Part( |
| 61 | + function_response=types.FunctionResponse( |
| 62 | + name=f"tool_{i}", |
| 63 | + response={"result": f"value_{i}"}, |
| 64 | + id=call_id, |
| 65 | + ) |
| 66 | + ) |
| 67 | + for i, call_id in enumerate(call_ids) |
| 68 | + ], |
| 69 | + ) |
| 70 | + |
| 71 | + |
| 72 | +def _extract_role(msg) -> str: |
| 73 | + """Extracts role from a litellm message, whether dict or object.""" |
| 74 | + if isinstance(msg, dict): |
| 75 | + return msg["role"] |
| 76 | + return msg.role |
| 77 | + |
| 78 | + |
| 79 | +class TestToolRoleSingleResponse: |
| 80 | + """_content_to_message_param with a single function_response part.""" |
| 81 | + |
| 82 | + @pytest.mark.asyncio |
| 83 | + async def test_gemma4_model_uses_tool_responses_role(self): |
| 84 | + """Models containing 'gemma4' should get role='tool_responses'.""" |
| 85 | + content = _make_function_response_content() |
| 86 | + |
| 87 | + result = await _content_to_message_param(content, model="ollama/gemma4:e2b") |
| 88 | + |
| 89 | + assert _extract_role(result) == "tool_responses", ( |
| 90 | + "Gemma models require role='tool_responses' to match their chat " |
| 91 | + "template; role='tool' causes infinite tool-calling loops." |
| 92 | + ) |
| 93 | + |
| 94 | + @pytest.mark.asyncio |
| 95 | + async def test_gemma4_uppercase_model_name(self): |
| 96 | + """Model name matching should be case-insensitive.""" |
| 97 | + content = _make_function_response_content() |
| 98 | + |
| 99 | + result = await _content_to_message_param(content, model="ollama/Gemma4:31b") |
| 100 | + |
| 101 | + assert _extract_role(result) == "tool_responses" |
| 102 | + |
| 103 | + @pytest.mark.asyncio |
| 104 | + async def test_tool_call_id_and_content_preserved(self): |
| 105 | + """Fix must not alter tool_call_id or content — only role changes.""" |
| 106 | + content = _make_function_response_content( |
| 107 | + response_data={"status": "ok"}, call_id="my_call_123" |
| 108 | + ) |
| 109 | + |
| 110 | + result = await _content_to_message_param(content, model="ollama/gemma4:e2b") |
| 111 | + |
| 112 | + if isinstance(result, dict): |
| 113 | + assert result["tool_call_id"] == "my_call_123" |
| 114 | + assert "ok" in result["content"] |
| 115 | + else: |
| 116 | + assert result.tool_call_id == "my_call_123" |
| 117 | + assert "ok" in result.content |
| 118 | + |
| 119 | + @pytest.mark.asyncio |
| 120 | + async def test_empty_model_string_uses_tool_role(self): |
| 121 | + """Empty model string should fall back to default role='tool'.""" |
| 122 | + content = _make_function_response_content() |
| 123 | + |
| 124 | + result = await _content_to_message_param(content, model="") |
| 125 | + |
| 126 | + assert _extract_role(result) == "tool" |
| 127 | + |
| 128 | + @pytest.mark.asyncio |
| 129 | + async def test_unrelated_models_use_tool_role(self): |
| 130 | + """Models that do not contain 'gemma4' must not be affected.""" |
| 131 | + unaffected_models = [ |
| 132 | + "ollama/llama3:8b", |
| 133 | + "ollama/qwen2.5-coder:3b", |
| 134 | + "anthropic/claude-3-opus", |
| 135 | + "openai/gpt-4o", |
| 136 | + "ollama/gemma3:4b", # gemma3 != gemma4 |
| 137 | + ] |
| 138 | + for model in unaffected_models: |
| 139 | + content = _make_function_response_content() |
| 140 | + result = await _content_to_message_param(content, model=model) |
| 141 | + assert ( |
| 142 | + _extract_role(result) == "tool" |
| 143 | + ), f"Model '{model}' should not be affected by the Gemma4 fix." |
| 144 | + |
| 145 | + |
| 146 | +class TestToolRoleMultipleResponses: |
| 147 | + """_content_to_message_param with multiple function_response parts.""" |
| 148 | + |
| 149 | + @pytest.mark.asyncio |
| 150 | + async def test_gemma4_all_messages_use_tool_responses_role(self): |
| 151 | + """All messages in a multi-response must have role='tool_responses'.""" |
| 152 | + content = _make_multi_function_response_content( |
| 153 | + call_ids=["call_a", "call_b", "call_c"] |
| 154 | + ) |
| 155 | + |
| 156 | + result = await _content_to_message_param(content, model="ollama/gemma4:4b") |
| 157 | + |
| 158 | + assert isinstance(result, list) |
| 159 | + assert len(result) == 3 |
| 160 | + for msg in result: |
| 161 | + assert _extract_role(msg) == "tool_responses", ( |
| 162 | + "Every tool message in a multi-response must use 'tool_responses' " |
| 163 | + "for Gemma4 models." |
| 164 | + ) |
| 165 | + |
| 166 | + @pytest.mark.asyncio |
| 167 | + async def test_non_gemma_multi_response_uses_tool_role(self): |
| 168 | + """Non-Gemma multi-response messages should all have role='tool'.""" |
| 169 | + content = _make_multi_function_response_content( |
| 170 | + call_ids=["call_a", "call_b"] |
| 171 | + ) |
| 172 | + |
| 173 | + result = await _content_to_message_param(content, model="openai/gpt-4o") |
| 174 | + |
| 175 | + assert isinstance(result, list) |
| 176 | + for msg in result: |
| 177 | + assert _extract_role(msg) == "tool" |
0 commit comments