1515# Pydantic model conversion tests
1616
1717from typing import Optional
18+ from typing import Union
1819from unittest .mock import MagicMock
1920
2021from google .adk .agents .invocation_context import InvocationContext
@@ -40,6 +41,14 @@ class PreferencesModel(pydantic.BaseModel):
4041 notifications : bool = True
4142
4243
44+ class CompanyModel (pydantic .BaseModel ):
45+ """Test Pydantic model for company data."""
46+
47+ company_name : str
48+ industry : str
49+ employee_count : int
50+
51+
4352def sync_function_with_pydantic_model (user : UserModel ) -> dict :
4453 """Sync function that takes a Pydantic model."""
4554 return {
@@ -370,3 +379,145 @@ def place_order(orders: list[UserModel]) -> int:
370379 result = await tool .run_async (args = args , tool_context = tool_context_mock )
371380
372381 assert result == 50
382+
383+
384+ def _function_with_union_of_basemodels (
385+ entity : Union [UserModel , CompanyModel ],
386+ ) -> str :
387+ return type (entity ).__name__
388+
389+
390+ def test_preprocess_args_with_union_of_basemodels_picks_user ():
391+ """Dict matching UserModel is converted to UserModel."""
392+ tool = FunctionTool (_function_with_union_of_basemodels )
393+
394+ processed_args = tool ._preprocess_args (
395+ {"entity" : {"name" : "Diana" , "age" : 32 , "email" : "d@example.com" }}
396+ )
397+
398+ assert isinstance (processed_args ["entity" ], UserModel )
399+ assert processed_args ["entity" ].name == "Diana"
400+
401+
402+ def test_preprocess_args_with_union_of_basemodels_picks_company ():
403+ """Dict matching CompanyModel is converted to CompanyModel."""
404+ tool = FunctionTool (_function_with_union_of_basemodels )
405+
406+ processed_args = tool ._preprocess_args ({
407+ "entity" : {
408+ "company_name" : "Acme Corp" ,
409+ "industry" : "tech" ,
410+ "employee_count" : 50 ,
411+ }
412+ })
413+
414+ assert isinstance (processed_args ["entity" ], CompanyModel )
415+ assert processed_args ["entity" ].company_name == "Acme Corp"
416+
417+
418+ def test_preprocess_args_with_union_of_basemodels_existing_instance_unchanged ():
419+ """Existing instance of any union member is left unchanged."""
420+ tool = FunctionTool (_function_with_union_of_basemodels )
421+
422+ user = UserModel (name = "Bob" , age = 25 )
423+ assert tool ._preprocess_args ({"entity" : user })["entity" ] is user
424+
425+ company = CompanyModel (
426+ company_name = "Acme" , industry = "tech" , employee_count = 10
427+ )
428+ assert tool ._preprocess_args ({"entity" : company })["entity" ] is company
429+
430+
431+ def test_preprocess_args_with_union_of_basemodels_unrelated_instance_passthrough ():
432+ """A BaseModel instance not in the union is not silently accepted."""
433+ tool = FunctionTool (_function_with_union_of_basemodels )
434+
435+ class UnrelatedModel (pydantic .BaseModel ):
436+ name : str
437+ age : int
438+
439+ unrelated = UnrelatedModel (name = "Carol" , age = 20 )
440+ processed_args = tool ._preprocess_args ({"entity" : unrelated })
441+
442+ # Conversion fails (UnrelatedModel is not in the union); value is left
443+ # alone so the function receives it and raises a clear error itself.
444+ assert processed_args ["entity" ] is unrelated
445+
446+
447+ def test_preprocess_args_with_optional_union_of_basemodels_none ():
448+ """Optional[Union[A, B]] passes None through unchanged."""
449+
450+ def fn (entity : Optional [Union [UserModel , CompanyModel ]] = None ) -> str :
451+ return type (entity ).__name__
452+
453+ tool = FunctionTool (fn )
454+
455+ processed_args = tool ._preprocess_args ({"entity" : None })
456+
457+ assert processed_args ["entity" ] is None
458+
459+
460+ def test_preprocess_args_with_optional_union_of_basemodels_dict ():
461+ """Optional[Union[A, B]] converts a dict to the matching model."""
462+
463+ def fn (entity : Optional [Union [UserModel , CompanyModel ]] = None ) -> str :
464+ return type (entity ).__name__
465+
466+ tool = FunctionTool (fn )
467+
468+ processed_args = tool ._preprocess_args ({"entity" : {"name" : "Eve" , "age" : 40 }})
469+
470+ assert isinstance (processed_args ["entity" ], UserModel )
471+ assert processed_args ["entity" ].name == "Eve"
472+
473+
474+ def test_preprocess_args_with_union_of_basemodels_invalid_data ():
475+ """Invalid data for Union[BaseModel, BaseModel] is kept unchanged."""
476+ tool = FunctionTool (_function_with_union_of_basemodels )
477+
478+ # Dict matches neither model.
479+ processed_args = tool ._preprocess_args (
480+ {"entity" : {"unrelated_field" : "value" }}
481+ )
482+
483+ assert processed_args ["entity" ] == {"unrelated_field" : "value" }
484+
485+
486+ @pytest .mark .asyncio
487+ async def test_run_async_with_union_of_basemodels ():
488+ """run_async end-to-end converts dict to the matching union member."""
489+
490+ def create_entity_profile (
491+ entity : Union [UserModel , CompanyModel ],
492+ ) -> dict :
493+ if isinstance (entity , UserModel ):
494+ return {"entity_type" : "user" , "name" : entity .name }
495+ if isinstance (entity , CompanyModel ):
496+ return {"entity_type" : "company" , "name" : entity .company_name }
497+ return {"entity_type" : "unknown" }
498+
499+ tool = FunctionTool (create_entity_profile )
500+
501+ tool_context_mock = MagicMock (spec = ToolContext )
502+ invocation_context_mock = MagicMock (spec = InvocationContext )
503+ session_mock = MagicMock (spec = Session )
504+ invocation_context_mock .session = session_mock
505+ tool_context_mock .invocation_context = invocation_context_mock
506+
507+ user_result = await tool .run_async (
508+ args = {"entity" : {"name" : "Diana" , "age" : 32 }},
509+ tool_context = tool_context_mock ,
510+ )
511+ assert user_result == {"entity_type" : "user" , "name" : "Diana" }
512+
513+ company_result = await tool .run_async (
514+ args = {
515+ "entity" : {
516+ "company_name" : "Acme Corp" ,
517+ "industry" : "tech" ,
518+ "employee_count" : 50 ,
519+ }
520+ },
521+ tool_context = tool_context_mock ,
522+ )
523+ assert company_result == {"entity_type" : "company" , "name" : "Acme Corp" }
0 commit comments