Skip to content

Commit 104edc8

Browse files
committed
fix: convert Union[Pydantic, Pydantic] tool args at runtime
FunctionTool._preprocess_args only converted dict args to a Pydantic model for single-model and Optional[Model] annotations. A Union[ModelA, ModelB] parameter was left as a raw dict, so isinstance checks inside the tool failed with "Unexpected entity type: <class 'dict'>" Use pydantic.TypeAdapter to validate against the full Union so pydantic picks the matching member. None and instances of any declared union member pass through unchanged; instances of unrelated BaseModels fall back to the existing graceful-failure warning path. Close #5799 Change-Id: Ie69f8efc8395162eac375a0eaad0c77ed2097cec
1 parent b3d0759 commit 104edc8

3 files changed

Lines changed: 180 additions & 54 deletions

File tree

contributing/samples/tools/pydantic_argument/tests/test_create_company.json

Lines changed: 11 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -53,66 +53,23 @@
5353
"id": "fc-1",
5454
"name": "create_entity_profile",
5555
"response": {
56-
"message": "Unexpected entity type: <class 'dict'>",
57-
"status": "error"
58-
}
59-
}
60-
}
61-
],
62-
"role": "user"
63-
},
64-
"id": "e-3",
65-
"invocationId": "i-1",
66-
"nodeInfo": {
67-
"path": "profile_agent@1"
68-
}
69-
},
70-
{
71-
"author": "profile_agent",
72-
"content": {
73-
"parts": [
74-
{
75-
"functionCall": {
76-
"args": {
77-
"entity": {
56+
"entity_type": "company",
57+
"message": "Company profile created for Acme Corp!",
58+
"profile": {
7859
"company_name": "Acme Corp",
7960
"employee_count": 50,
80-
"industry": "tech"
81-
}
82-
},
83-
"id": "fc-2",
84-
"name": "create_entity_profile"
85-
}
86-
}
87-
],
88-
"role": "model"
89-
},
90-
"finishReason": "STOP",
91-
"id": "e-4",
92-
"invocationId": "i-1",
93-
"longRunningToolIds": [],
94-
"nodeInfo": {
95-
"path": "profile_agent@1"
96-
}
97-
},
98-
{
99-
"author": "profile_agent",
100-
"content": {
101-
"parts": [
102-
{
103-
"functionResponse": {
104-
"id": "fc-2",
105-
"name": "create_entity_profile",
106-
"response": {
107-
"message": "Unexpected entity type: <class 'dict'>",
108-
"status": "error"
61+
"industry": "tech",
62+
"model_type": "CompanyProfile",
63+
"website": "Not provided"
64+
},
65+
"status": "company_profile_created"
10966
}
11067
}
11168
}
11269
],
11370
"role": "user"
11471
},
115-
"id": "e-5",
72+
"id": "e-3",
11673
"invocationId": "i-1",
11774
"nodeInfo": {
11875
"path": "profile_agent@1"
@@ -123,13 +80,13 @@
12380
"content": {
12481
"parts": [
12582
{
126-
"text": "I apologize for the repeated errors. It seems I am having trouble with the tool's expected input format. I will try again to create the company profile for Acme Corp. I believe the issue was in how I was structuring the data for the `create_entity_profile` function. I will now use the correct dataclass constructor to create the profile. Please bear with me."
83+
"text": "I have created a profile for Acme Corp in the tech industry with 50 employees."
12784
}
12885
],
12986
"role": "model"
13087
},
13188
"finishReason": "STOP",
132-
"id": "e-6",
89+
"id": "e-4",
13390
"invocationId": "i-1",
13491
"nodeInfo": {
13592
"path": "profile_agent@1"

src/google/adk/tools/function_tool.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,24 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
148148
]
149149
if len(non_none_types) == 1:
150150
target_type = non_none_types[0]
151+
elif len(non_none_types) > 1 and all(
152+
inspect.isclass(t) and issubclass(t, pydantic.BaseModel)
153+
for t in non_none_types
154+
):
155+
if args[param_name] is None or isinstance(
156+
args[param_name], tuple(non_none_types)
157+
):
158+
continue
159+
try:
160+
converted_args[param_name] = pydantic.TypeAdapter(
161+
param.annotation
162+
).validate_python(args[param_name])
163+
except Exception as e:
164+
logger.warning(
165+
f"Failed to convert argument '{param_name}' to"
166+
f' {param.annotation}: {e}'
167+
)
168+
continue
151169

152170
# Check if the target type is a Pydantic model
153171
if inspect.isclass(target_type) and issubclass(

tests/unittests/tools/test_function_tool_pydantic.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# Pydantic model conversion tests
1616

1717
from typing import Optional
18+
from typing import Union
1819
from unittest.mock import MagicMock
1920

2021
from google.adk.agents.invocation_context import InvocationContext
@@ -40,6 +41,14 @@ class PreferencesModel(pydantic.BaseModel):
4041
notifications: bool = True
4142

4243

44+
class CompanyModel(pydantic.BaseModel):
45+
"""Test Pydantic model for company data."""
46+
47+
company_name: str
48+
industry: str
49+
employee_count: int
50+
51+
4352
def sync_function_with_pydantic_model(user: UserModel) -> dict:
4453
"""Sync function that takes a Pydantic model."""
4554
return {
@@ -370,3 +379,145 @@ def place_order(orders: list[UserModel]) -> int:
370379
result = await tool.run_async(args=args, tool_context=tool_context_mock)
371380

372381
assert result == 50
382+
383+
384+
def _function_with_union_of_basemodels(
385+
entity: Union[UserModel, CompanyModel],
386+
) -> str:
387+
return type(entity).__name__
388+
389+
390+
def test_preprocess_args_with_union_of_basemodels_picks_user():
391+
"""Dict matching UserModel is converted to UserModel."""
392+
tool = FunctionTool(_function_with_union_of_basemodels)
393+
394+
processed_args = tool._preprocess_args(
395+
{"entity": {"name": "Diana", "age": 32, "email": "d@example.com"}}
396+
)
397+
398+
assert isinstance(processed_args["entity"], UserModel)
399+
assert processed_args["entity"].name == "Diana"
400+
401+
402+
def test_preprocess_args_with_union_of_basemodels_picks_company():
403+
"""Dict matching CompanyModel is converted to CompanyModel."""
404+
tool = FunctionTool(_function_with_union_of_basemodels)
405+
406+
processed_args = tool._preprocess_args({
407+
"entity": {
408+
"company_name": "Acme Corp",
409+
"industry": "tech",
410+
"employee_count": 50,
411+
}
412+
})
413+
414+
assert isinstance(processed_args["entity"], CompanyModel)
415+
assert processed_args["entity"].company_name == "Acme Corp"
416+
417+
418+
def test_preprocess_args_with_union_of_basemodels_existing_instance_unchanged():
419+
"""Existing instance of any union member is left unchanged."""
420+
tool = FunctionTool(_function_with_union_of_basemodels)
421+
422+
user = UserModel(name="Bob", age=25)
423+
assert tool._preprocess_args({"entity": user})["entity"] is user
424+
425+
company = CompanyModel(
426+
company_name="Acme", industry="tech", employee_count=10
427+
)
428+
assert tool._preprocess_args({"entity": company})["entity"] is company
429+
430+
431+
def test_preprocess_args_with_union_of_basemodels_unrelated_instance_passthrough():
432+
"""A BaseModel instance not in the union is not silently accepted."""
433+
tool = FunctionTool(_function_with_union_of_basemodels)
434+
435+
class UnrelatedModel(pydantic.BaseModel):
436+
name: str
437+
age: int
438+
439+
unrelated = UnrelatedModel(name="Carol", age=20)
440+
processed_args = tool._preprocess_args({"entity": unrelated})
441+
442+
# Conversion fails (UnrelatedModel is not in the union); value is left
443+
# alone so the function receives it and raises a clear error itself.
444+
assert processed_args["entity"] is unrelated
445+
446+
447+
def test_preprocess_args_with_optional_union_of_basemodels_none():
448+
"""Optional[Union[A, B]] passes None through unchanged."""
449+
450+
def fn(entity: Optional[Union[UserModel, CompanyModel]] = None) -> str:
451+
return type(entity).__name__
452+
453+
tool = FunctionTool(fn)
454+
455+
processed_args = tool._preprocess_args({"entity": None})
456+
457+
assert processed_args["entity"] is None
458+
459+
460+
def test_preprocess_args_with_optional_union_of_basemodels_dict():
461+
"""Optional[Union[A, B]] converts a dict to the matching model."""
462+
463+
def fn(entity: Optional[Union[UserModel, CompanyModel]] = None) -> str:
464+
return type(entity).__name__
465+
466+
tool = FunctionTool(fn)
467+
468+
processed_args = tool._preprocess_args({"entity": {"name": "Eve", "age": 40}})
469+
470+
assert isinstance(processed_args["entity"], UserModel)
471+
assert processed_args["entity"].name == "Eve"
472+
473+
474+
def test_preprocess_args_with_union_of_basemodels_invalid_data():
475+
"""Invalid data for Union[BaseModel, BaseModel] is kept unchanged."""
476+
tool = FunctionTool(_function_with_union_of_basemodels)
477+
478+
# Dict matches neither model.
479+
processed_args = tool._preprocess_args(
480+
{"entity": {"unrelated_field": "value"}}
481+
)
482+
483+
assert processed_args["entity"] == {"unrelated_field": "value"}
484+
485+
486+
@pytest.mark.asyncio
487+
async def test_run_async_with_union_of_basemodels():
488+
"""run_async end-to-end converts dict to the matching union member."""
489+
490+
def create_entity_profile(
491+
entity: Union[UserModel, CompanyModel],
492+
) -> dict:
493+
if isinstance(entity, UserModel):
494+
return {"entity_type": "user", "name": entity.name}
495+
if isinstance(entity, CompanyModel):
496+
return {"entity_type": "company", "name": entity.company_name}
497+
return {"entity_type": "unknown"}
498+
499+
tool = FunctionTool(create_entity_profile)
500+
501+
tool_context_mock = MagicMock(spec=ToolContext)
502+
invocation_context_mock = MagicMock(spec=InvocationContext)
503+
session_mock = MagicMock(spec=Session)
504+
invocation_context_mock.session = session_mock
505+
tool_context_mock.invocation_context = invocation_context_mock
506+
507+
user_result = await tool.run_async(
508+
args={"entity": {"name": "Diana", "age": 32}},
509+
tool_context=tool_context_mock,
510+
)
511+
assert user_result == {"entity_type": "user", "name": "Diana"}
512+
513+
company_result = await tool.run_async(
514+
args={
515+
"entity": {
516+
"company_name": "Acme Corp",
517+
"industry": "tech",
518+
"employee_count": 50,
519+
}
520+
},
521+
tool_context=tool_context_mock,
522+
)
523+
assert company_result == {"entity_type": "company", "name": "Acme Corp"}

0 commit comments

Comments
 (0)