Skip to content

Commit 4b83906

Browse files
fix: validate and coerce function tool argument types (#4612)
_preprocess_args now uses pydantic.TypeAdapter to validate and coerce all annotated argument types (primitives, enums, containers), not just Pydantic models. Invalid arguments return a descriptive error to the LLM so it can self-correct and retry, matching the existing pattern for missing mandatory args. - Coerces compatible types (e.g. str "42" -> int 42, str "red" -> Color.RED) - Returns validation errors to LLM for incompatible types (e.g. "foobar" -> int) - Existing Pydantic BaseModel handling unchanged (graceful failure) - Updated all 3 call sites (FunctionTool, CrewAITool, sync tool path)
1 parent 0ad4de7 commit 4b83906

File tree

5 files changed

+316
-55
lines changed

5 files changed

+316
-55
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,16 @@ async def _call_tool_in_thread_pool(
147147
# For sync FunctionTool, call the underlying function directly
148148
def run_sync_tool():
149149
if isinstance(tool, FunctionTool):
150-
args_to_call = tool._preprocess_args(args)
150+
args_to_call, validation_errors = tool._preprocess_args(args)
151+
if validation_errors:
152+
validation_errors_str = '\n'.join(validation_errors)
153+
return {
154+
'error': (
155+
f'Invoking `{tool.name}()` failed due to argument'
156+
f' validation errors:\n{validation_errors_str}\nYou could'
157+
' retry calling this tool with corrected argument types.'
158+
)
159+
}
151160
signature = inspect.signature(tool.func)
152161
valid_params = {param for param in signature.parameters}
153162
if tool._context_param_name in valid_params:

src/google/adk/tools/crewai_tool.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,16 @@ async def run_async(
7373
duplicates, but is re-added if the function signature explicitly requires it
7474
as a parameter.
7575
"""
76-
# Preprocess arguments (includes Pydantic model conversion)
77-
args_to_call = self._preprocess_args(args)
76+
# Preprocess arguments (includes Pydantic model conversion and type
77+
# validation)
78+
args_to_call, validation_errors = self._preprocess_args(args)
79+
80+
if validation_errors:
81+
validation_errors_str = '\n'.join(validation_errors)
82+
error_str = f"""Invoking `{self.name}()` failed due to argument validation errors:
83+
{validation_errors_str}
84+
You could retry calling this tool with corrected argument types."""
85+
return {'error': error_str}
7886

7987
signature = inspect.signature(self.func)
8088
valid_params = {param for param in signature.parameters}

src/google/adk/tools/function_tool.py

Lines changed: 78 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -100,68 +100,101 @@ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
100100

101101
return function_decl
102102

103-
def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
104-
"""Preprocess and convert function arguments before invocation.
103+
def _preprocess_args(
104+
self, args: dict[str, Any]
105+
) -> tuple[dict[str, Any], list[str]]:
106+
"""Preprocess, validate, and convert function arguments before invocation.
105107
106-
Currently handles:
108+
Handles:
107109
- Converting JSON dictionaries to Pydantic model instances where expected
108-
109-
Future extensions could include:
110-
- Type coercion for other complex types
111-
- Validation and sanitization
112-
- Custom conversion logic
110+
- Validating and coercing primitive types (int, float, str, bool)
111+
- Validating enum values
112+
- Validating container types (list[int], dict[str, float], etc.)
113113
114114
Args:
115115
args: Raw arguments from the LLM tool call
116116
117117
Returns:
118-
Processed arguments ready for function invocation
118+
A tuple of (processed_args, validation_errors). If validation_errors is
119+
non-empty, the caller should return the errors to the LLM instead of
120+
invoking the function.
119121
"""
120122
signature = inspect.signature(self.func)
121123
converted_args = args.copy()
124+
validation_errors = []
122125

123126
for param_name, param in signature.parameters.items():
124-
if param_name in args and param.annotation != inspect.Parameter.empty:
125-
target_type = param.annotation
126-
127-
# Handle Optional[PydanticModel] types
128-
if get_origin(param.annotation) is Union:
129-
union_args = get_args(param.annotation)
130-
# Find the non-None type in Optional[T] (which is Union[T, None])
131-
non_none_types = [arg for arg in union_args if arg is not type(None)]
132-
if len(non_none_types) == 1:
133-
target_type = non_none_types[0]
134-
135-
# Check if the target type is a Pydantic model
136-
if inspect.isclass(target_type) and issubclass(
137-
target_type, pydantic.BaseModel
138-
):
139-
# Skip conversion if the value is None and the parameter is Optional
140-
if args[param_name] is None:
141-
continue
142-
143-
# Convert to Pydantic model if it's not already the correct type
144-
if not isinstance(args[param_name], target_type):
145-
try:
146-
converted_args[param_name] = target_type.model_validate(
147-
args[param_name]
148-
)
149-
except Exception as e:
150-
logger.warning(
151-
f"Failed to convert argument '{param_name}' to Pydantic model"
152-
f' {target_type.__name__}: {e}'
153-
)
154-
# Keep the original value if conversion fails
155-
pass
156-
157-
return converted_args
127+
if param_name not in args or param.annotation is inspect.Parameter.empty:
128+
continue
129+
130+
target_type = param.annotation
131+
is_optional = False
132+
133+
# Handle Optional[T] (Union[T, None]) - unwrap to get inner type
134+
if get_origin(param.annotation) is Union:
135+
union_args = get_args(param.annotation)
136+
non_none_types = [arg for arg in union_args if arg is not type(None)]
137+
if len(non_none_types) == 1:
138+
target_type = non_none_types[0]
139+
is_optional = len(union_args) != len(non_none_types)
140+
141+
# Pydantic models: keep existing graceful-failure behavior
142+
if inspect.isclass(target_type) and issubclass(
143+
target_type, pydantic.BaseModel
144+
):
145+
if args[param_name] is None:
146+
continue
147+
if not isinstance(args[param_name], target_type):
148+
try:
149+
converted_args[param_name] = target_type.model_validate(
150+
args[param_name]
151+
)
152+
except Exception as e:
153+
logger.warning(
154+
f"Failed to convert argument '{param_name}' to Pydantic model"
155+
f' {target_type.__name__}: {e}'
156+
)
157+
continue
158+
159+
# Skip None values only for Optional params
160+
if args[param_name] is None and is_optional:
161+
continue
162+
163+
# Validate and coerce all other annotated types using TypeAdapter.
164+
# This handles primitives (int, float, str, bool), enums, and
165+
# container types (list[int], dict[str, float], etc.).
166+
try:
167+
adapter = pydantic.TypeAdapter(target_type)
168+
converted_args[param_name] = adapter.validate_python(
169+
args[param_name]
170+
)
171+
except pydantic.ValidationError as e:
172+
validation_errors.append(
173+
f"Parameter '{param_name}': expected type '{target_type}',"
174+
f' validation error: {e}'
175+
)
176+
except Exception:
177+
# TypeAdapter could not handle this annotation (e.g. a forward
178+
# reference string). Skip validation silently.
179+
pass
180+
181+
return converted_args, validation_errors
158182

159183
@override
160184
async def run_async(
161185
self, *, args: dict[str, Any], tool_context: ToolContext
162186
) -> Any:
163-
# Preprocess arguments (includes Pydantic model conversion)
164-
args_to_call = self._preprocess_args(args)
187+
# Preprocess arguments (includes Pydantic model conversion and type
188+
# validation). Validation errors are returned to the LLM so it can
189+
# self-correct and retry with proper argument types.
190+
args_to_call, validation_errors = self._preprocess_args(args)
191+
192+
if validation_errors:
193+
validation_errors_str = '\n'.join(validation_errors)
194+
error_str = f"""Invoking `{self.name}()` failed due to argument validation errors:
195+
{validation_errors_str}
196+
You could retry calling this tool with corrected argument types."""
197+
return {'error': error_str}
165198

166199
signature = inspect.signature(self.func)
167200
valid_params = {param for param in signature.parameters}
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for FunctionTool argument type validation and coercion."""
16+
17+
from enum import Enum
18+
from typing import Optional
19+
from unittest.mock import MagicMock
20+
21+
from google.adk.agents.invocation_context import InvocationContext
22+
from google.adk.sessions.session import Session
23+
from google.adk.tools.function_tool import FunctionTool
24+
from google.adk.tools.tool_context import ToolContext
25+
import pytest
26+
27+
28+
class Color(Enum):
29+
RED = "red"
30+
GREEN = "green"
31+
BLUE = "blue"
32+
33+
34+
def int_func(num: int) -> int:
35+
return num
36+
37+
38+
def float_func(val: float) -> float:
39+
return val
40+
41+
42+
def bool_func(flag: bool) -> bool:
43+
return flag
44+
45+
46+
def enum_func(color: Color) -> str:
47+
return color.value
48+
49+
50+
def list_int_func(nums: list[int]) -> list[int]:
51+
return nums
52+
53+
54+
def optional_int_func(num: Optional[int] = None) -> Optional[int]:
55+
return num
56+
57+
58+
def multi_param_func(name: str, count: int, flag: bool) -> dict:
59+
return {"name": name, "count": count, "flag": flag}
60+
61+
62+
# --- _preprocess_args coercion tests ---
63+
64+
65+
class TestArgCoercion:
66+
67+
def test_string_to_int(self):
68+
tool = FunctionTool(int_func)
69+
args, errors = tool._preprocess_args({"num": "42"})
70+
assert errors == []
71+
assert args["num"] == 42
72+
assert isinstance(args["num"], int)
73+
74+
def test_float_to_int(self):
75+
"""Pydantic lax mode truncates float to int."""
76+
tool = FunctionTool(int_func)
77+
args, errors = tool._preprocess_args({"num": 3.0})
78+
assert errors == []
79+
assert args["num"] == 3
80+
assert isinstance(args["num"], int)
81+
82+
def test_string_to_float(self):
83+
tool = FunctionTool(float_func)
84+
args, errors = tool._preprocess_args({"val": "3.14"})
85+
assert errors == []
86+
assert abs(args["val"] - 3.14) < 1e-9
87+
88+
def test_int_to_float(self):
89+
tool = FunctionTool(float_func)
90+
args, errors = tool._preprocess_args({"val": 5})
91+
assert errors == []
92+
assert args["val"] == 5.0
93+
assert isinstance(args["val"], float)
94+
95+
def test_enum_valid_value(self):
96+
tool = FunctionTool(enum_func)
97+
args, errors = tool._preprocess_args({"color": "red"})
98+
assert errors == []
99+
assert args["color"] == Color.RED
100+
101+
def test_enum_invalid_value(self):
102+
tool = FunctionTool(enum_func)
103+
args, errors = tool._preprocess_args({"color": "purple"})
104+
assert len(errors) == 1
105+
assert "color" in errors[0]
106+
107+
def test_list_int_coercion(self):
108+
tool = FunctionTool(list_int_func)
109+
args, errors = tool._preprocess_args({"nums": ["1", "2", "3"]})
110+
assert errors == []
111+
assert args["nums"] == [1, 2, 3]
112+
113+
def test_optional_none_skipped(self):
114+
tool = FunctionTool(optional_int_func)
115+
args, errors = tool._preprocess_args({"num": None})
116+
assert errors == []
117+
assert args["num"] is None
118+
119+
def test_optional_value_coerced(self):
120+
tool = FunctionTool(optional_int_func)
121+
args, errors = tool._preprocess_args({"num": "7"})
122+
assert errors == []
123+
assert args["num"] == 7
124+
125+
def test_bool_from_int(self):
126+
tool = FunctionTool(bool_func)
127+
args, errors = tool._preprocess_args({"flag": 1})
128+
assert errors == []
129+
assert args["flag"] is True
130+
131+
132+
# --- _preprocess_args validation error tests ---
133+
134+
135+
class TestArgValidationErrors:
136+
137+
def test_string_for_int_returns_error(self):
138+
tool = FunctionTool(int_func)
139+
args, errors = tool._preprocess_args({"num": "foobar"})
140+
assert len(errors) == 1
141+
assert "num" in errors[0]
142+
143+
def test_none_for_required_int_returns_error(self):
144+
"""None for a non-Optional int should be flagged."""
145+
tool = FunctionTool(int_func)
146+
# None passed for a required int param. The Optional unwrap won't
147+
# trigger because the annotation is plain `int`, not Optional[int].
148+
# TypeAdapter(int).validate_python(None) raises ValidationError.
149+
args, errors = tool._preprocess_args({"num": None})
150+
assert len(errors) == 1
151+
assert "num" in errors[0]
152+
153+
def test_multiple_param_errors(self):
154+
tool = FunctionTool(multi_param_func)
155+
args, errors = tool._preprocess_args(
156+
{"name": 123, "count": "not_a_number", "flag": "not_a_bool"}
157+
)
158+
# name: int->str coercion might fail in strict, but lax mode might
159+
# handle it. count: "not_a_number"->int will fail. flag: depends on
160+
# pydantic behavior.
161+
assert any("count" in e for e in errors)
162+
163+
164+
# --- run_async integration tests ---
165+
166+
167+
def _make_tool_context():
168+
tool_context_mock = MagicMock(spec=ToolContext)
169+
invocation_context_mock = MagicMock(spec=InvocationContext)
170+
session_mock = MagicMock(spec=Session)
171+
invocation_context_mock.session = session_mock
172+
tool_context_mock.invocation_context = invocation_context_mock
173+
return tool_context_mock
174+
175+
176+
class TestRunAsyncValidation:
177+
178+
@pytest.mark.asyncio
179+
async def test_invalid_arg_returns_error_to_llm(self):
180+
tool = FunctionTool(int_func)
181+
result = await tool.run_async(
182+
args={"num": "foobar"}, tool_context=_make_tool_context()
183+
)
184+
assert isinstance(result, dict)
185+
assert "error" in result
186+
assert "validation error" in result["error"].lower()
187+
188+
@pytest.mark.asyncio
189+
async def test_valid_coercion_invokes_function(self):
190+
tool = FunctionTool(int_func)
191+
result = await tool.run_async(
192+
args={"num": "42"}, tool_context=_make_tool_context()
193+
)
194+
assert result == 42
195+
196+
@pytest.mark.asyncio
197+
async def test_enum_invalid_returns_error(self):
198+
tool = FunctionTool(enum_func)
199+
result = await tool.run_async(
200+
args={"color": "purple"}, tool_context=_make_tool_context()
201+
)
202+
assert isinstance(result, dict)
203+
assert "error" in result
204+
205+
@pytest.mark.asyncio
206+
async def test_enum_valid_invokes_function(self):
207+
tool = FunctionTool(enum_func)
208+
result = await tool.run_async(
209+
args={"color": "green"}, tool_context=_make_tool_context()
210+
)
211+
assert result == "green"

0 commit comments

Comments
 (0)