Skip to content

Commit 551445e

Browse files
sbunchgoogleGWeale
authored andcommitted
fix: add PEP 604 union syntax in function tool parameters
- Support PEP 604 unions in direct parameter parser. - Fix a bug where collapsing simple unions (e.g., `Optional[list[T]]`) lost nested schema properties (like `items` or `properties`) by replacing the parent schema with the collapsed inner schema instead of just copying its type. - Update `test_required_fields_set_in_json_schema_fallback` to use `tuple[str, ...]` to ensure the fallback path remains tested. Change-Id: Idc1cb55e265ba888c03aa923f60d2d4b3d1ae131
1 parent cae2337 commit 551445e

3 files changed

Lines changed: 83 additions & 10 deletions

File tree

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _parse_schema_from_parameter(
247247
_raise_if_schema_unsupported(variant, schema)
248248
return schema
249249
if (
250-
get_origin(param.annotation) is Union
250+
get_origin(param.annotation) in (Union, typing_types.UnionType)
251251
# only parse simple UnionType, example int | str | float | bool
252252
# complex types.UnionType will be invoked in raise branch
253253
and all(
@@ -276,8 +276,10 @@ def _parse_schema_from_parameter(
276276
schema.any_of.append(schema_in_any_of)
277277
unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
278278
if len(schema.any_of) == 1: # param: list | None -> Array
279-
schema.type = schema.any_of[0].type
280-
schema.any_of = None
279+
collapsed = schema.any_of[0]
280+
if schema.nullable:
281+
collapsed.nullable = True
282+
schema = collapsed
281283
if (
282284
param.default is not inspect.Parameter.empty
283285
and param.default is not None
@@ -287,8 +289,10 @@ def _parse_schema_from_parameter(
287289
schema.default = param.default
288290
_raise_if_schema_unsupported(variant, schema)
289291
return schema
290-
if isinstance(param.annotation, _GenericAlias) or isinstance(
291-
param.annotation, typing_types.GenericAlias
292+
if (
293+
isinstance(param.annotation, _GenericAlias)
294+
or isinstance(param.annotation, typing_types.GenericAlias)
295+
or isinstance(param.annotation, typing_types.UnionType)
292296
):
293297
origin = get_origin(param.annotation)
294298
args = get_args(param.annotation)
@@ -330,7 +334,7 @@ def _parse_schema_from_parameter(
330334
schema.default = param.default
331335
_raise_if_schema_unsupported(variant, schema)
332336
return schema
333-
if origin is Union:
337+
if origin in (Union, typing_types.UnionType):
334338
schema.any_of = []
335339
schema.type = types.Type.OBJECT
336340
unique_types = set()
@@ -365,8 +369,10 @@ def _parse_schema_from_parameter(
365369
schema.any_of.append(schema_in_any_of)
366370
unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
367371
if len(schema.any_of) == 1: # param: Union[List, None] -> Array
368-
schema.type = schema.any_of[0].type
369-
schema.any_of = None
372+
collapsed = schema.any_of[0]
373+
if schema.nullable:
374+
collapsed.nullable = True
375+
schema = collapsed
370376
if (
371377
param.default is not None
372378
and param.default is not inspect.Parameter.empty

tests/unittests/tools/test_from_function_with_options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]:
324324
def test_required_fields_set_in_json_schema_fallback():
325325
"""Test that required fields are populated when the json_schema fallback path is used.
326326
327-
When a parameter has a complex union type (e.g. list[str] | None) that
327+
When a parameter has a complex type (e.g. tuple[str, ...] | None) that
328328
_parse_schema_from_parameter can't handle, from_function_with_options falls
329329
back to the parameters_json_schema branch. This test verifies that the
330330
required fields are correctly populated in that fallback branch.
@@ -333,7 +333,7 @@ def test_required_fields_set_in_json_schema_fallback():
333333
def complex_tool(
334334
query: str,
335335
mode: str = 'default',
336-
tags: list[str] | None = None,
336+
tags: tuple[str, ...] | None = None,
337337
) -> str:
338338
"""A tool where one param has a complex union type."""
339339
return query

tests/unittests/tools/test_set_model_response_tool.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
"""Tests for SetModelResponseTool."""
1616

1717
import inspect
18+
from typing import Optional
1819

1920
from google.adk.agents.invocation_context import InvocationContext
2021
from google.adk.agents.llm_agent import LlmAgent
2122
from google.adk.agents.run_config import RunConfig
23+
from google.adk.features._feature_registry import FeatureName
24+
from google.adk.features._feature_registry import temporary_feature_override
2225
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2326
from google.adk.tools.set_model_response_tool import SetModelResponseTool
2427
from google.adk.tools.tool_context import ToolContext
@@ -467,3 +470,67 @@ async def test_run_async_dict_schema():
467470
assert result is not None
468471
assert isinstance(result, dict)
469472
assert result == {'a': 1, 'b': 2, 'c': 3}
473+
474+
475+
class SubSchema(BaseModel):
476+
477+
field1: str = Field(description='Field 1')
478+
field2: int = Field(description='Field 2')
479+
480+
481+
class ConsolidatedOptionalSchema(BaseModel):
482+
483+
nested: Optional[SubSchema] = Field(default=None, description='Nested model')
484+
nested_list: Optional[list[SubSchema]] = Field(
485+
default=None, description='Nested list of models'
486+
)
487+
pep604_nested: SubSchema | None = Field(
488+
default=None, description='PEP 604 optional nested model'
489+
)
490+
pep604_raw_list: list | None = Field(default=None, description='Raw list')
491+
492+
493+
def test_get_declaration_optional_fields():
494+
"""Test that tool declaration preserves properties for various optional fields."""
495+
with temporary_feature_override(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, False):
496+
tool = SetModelResponseTool(ConsolidatedOptionalSchema)
497+
498+
declaration = tool._get_declaration()
499+
500+
assert declaration is not None
501+
assert declaration.name == 'set_model_response'
502+
params_schema = declaration.parameters
503+
assert params_schema is not None
504+
assert params_schema.type == 'OBJECT'
505+
506+
# 1. Optional[SubSchema]
507+
assert 'nested' in params_schema.properties
508+
nested_schema = params_schema.properties['nested']
509+
assert nested_schema.type == 'OBJECT'
510+
assert nested_schema.properties is not None
511+
assert nested_schema.properties['field1'].type == 'STRING'
512+
assert nested_schema.properties['field2'].type == 'INTEGER'
513+
514+
# 2. Optional[list[SubSchema]]
515+
assert 'nested_list' in params_schema.properties
516+
nested_list_schema = params_schema.properties['nested_list']
517+
assert nested_list_schema.type == 'ARRAY'
518+
assert nested_list_schema.items is not None
519+
items_schema = nested_list_schema.items
520+
assert items_schema.type == 'OBJECT'
521+
assert items_schema.properties is not None
522+
assert items_schema.properties['field1'].type == 'STRING'
523+
assert items_schema.properties['field2'].type == 'INTEGER'
524+
525+
# 3. SubSchema | None (PEP 604)
526+
assert 'pep604_nested' in params_schema.properties
527+
pep604_nested_schema = params_schema.properties['pep604_nested']
528+
assert pep604_nested_schema.type == 'OBJECT'
529+
assert pep604_nested_schema.properties is not None
530+
assert pep604_nested_schema.properties['field1'].type == 'STRING'
531+
assert pep604_nested_schema.properties['field2'].type == 'INTEGER'
532+
533+
# 4. list | None (PEP 604)
534+
assert 'pep604_raw_list' in params_schema.properties
535+
pep604_raw_list_schema = params_schema.properties['pep604_raw_list']
536+
assert pep604_raw_list_schema.type == 'ARRAY'

0 commit comments

Comments
 (0)