forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathset_model_response_tool.py
More file actions
162 lines (142 loc) · 5.71 KB
/
set_model_response_tool.py
File metadata and controls
162 lines (142 loc) · 5.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool for setting model response when using output_schema with other tools."""
from __future__ import annotations
import inspect
from typing import Any
from typing import Optional
from google.genai import types
from pydantic import TypeAdapter
from typing_extensions import override
from ..utils._schema_utils import get_list_inner_type
from ..utils._schema_utils import is_basemodel_schema
from ..utils._schema_utils import is_list_of_basemodel
from ..utils._schema_utils import SchemaType
from ._automatic_function_calling_util import build_function_declaration
from .base_tool import BaseTool
from .tool_context import ToolContext
class SetModelResponseTool(BaseTool):
  """Internal tool used for output schema workaround.

  This tool allows the model to set its final response when output_schema
  is configured alongside other tools. The model should use this tool to
  provide its final structured response instead of outputting text directly.
  """

  def __init__(self, output_schema: SchemaType):
    """Initialize the tool with the expected output schema.

    Builds a stand-in function named ``set_model_response`` whose signature
    mirrors ``output_schema``; that signature is what ``_get_declaration``
    turns into the tool's function declaration.

    Args:
      output_schema: The output schema. Supports all types from SchemaUnion:
        - type[BaseModel]: A pydantic model class (e.g., MySchema)
        - list[type[BaseModel]]: A generic list type (e.g., list[MySchema])
        - list[primitive]: e.g., list[str], list[int]
        - dict: Raw dict schemas
        - Schema: Google's Schema type
    """
    self.output_schema = output_schema
    self._is_basemodel = is_basemodel_schema(output_schema)
    self._is_list_of_basemodel = is_list_of_basemodel(output_schema)
    # Validator for the list[BaseModel] case. It is built on first use in
    # run_async and reused afterwards. Constructing a TypeAdapter compiles a
    # validator, so rebuilding it on every call is wasted work. Deferring
    # creation (instead of building it here) keeps any schema error surfacing
    # at the same point it did before caching was introduced.
    self._list_type_adapter: Optional[TypeAdapter] = None

    # Create a function that matches the output schema
    def set_model_response() -> str:
      """Set your final response using the required output schema.
      Use this tool to provide your final structured answer instead
      of outputting text directly.
      """
      return 'Response set successfully.'

    # Add the schema fields as parameters to the function dynamically
    if self._is_basemodel:
      # For regular BaseModel, expose every model field as a keyword-only
      # parameter annotated with that field's type.
      # NOTE(review): no `default=` is attached, so fields that have defaults
      # in the model are presented exactly like required ones — confirm this
      # is intended for the generated declaration.
      params = [
          inspect.Parameter(
              field_name,
              inspect.Parameter.KEYWORD_ONLY,
              annotation=field_info.annotation,
          )
          for field_name, field_info in output_schema.model_fields.items()
      ]
    elif self._is_list_of_basemodel:
      # For list[BaseModel], create a single 'items' parameter
      inner_type = get_list_inner_type(output_schema)
      params = [
          inspect.Parameter(
              'items',
              inspect.Parameter.KEYWORD_ONLY,
              annotation=list[inner_type],
          )
      ]
    elif isinstance(output_schema, dict):
      # Use `dict` type, not the instance — dict instances are unhashable.
      params = [
          inspect.Parameter(
              'response',
              inspect.Parameter.KEYWORD_ONLY,
              annotation=dict,
          )
      ]
    else:
      # For other schema types (list[str], Schema, GenericAlias, etc.),
      # create a single parameter with the actual schema type
      params = [
          inspect.Parameter(
              'response',
              inspect.Parameter.KEYWORD_ONLY,
              annotation=output_schema,
          )
      ]

    # Attach the synthesized signature so introspection (inspect.signature)
    # sees the schema-derived parameters instead of the empty real ones.
    new_sig = inspect.Signature(parameters=params)
    setattr(set_model_response, '__signature__', new_sig)
    self.func = set_model_response

    super().__init__(
        name=self.func.__name__,
        description=self.func.__doc__.strip() if self.func.__doc__ else '',
    )

  @override
  def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
    """Gets the OpenAPI specification of this tool.

    Returns:
      A ``types.FunctionDeclaration`` built from ``self.func`` (whose
      signature was synthesized in ``__init__``), with no parameters ignored.
    """
    function_decl = types.FunctionDeclaration.model_validate(
        build_function_declaration(
            func=self.func,
            ignore_params=[],
            variant=self._api_variant,
        )
    )
    return function_decl

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    """Process the model's response and return the validated data.

    Args:
      args: The structured response data matching the output schema.
      tool_context: Tool execution context (not used by this tool).

    Returns:
      The validated response. Type depends on the output_schema:
        - dict for BaseModel
        - list of dicts for list[BaseModel]
        - raw value for other schema types (list[str], dict, etc.)

    Raises:
      pydantic.ValidationError: If ``args`` does not satisfy a BaseModel or
        list[BaseModel] schema.
    """
    if self._is_basemodel:
      # For regular BaseModel, validate directly
      validated_response = self.output_schema.model_validate(args)
      return validated_response.model_dump(exclude_none=True)
    elif self._is_list_of_basemodel:
      # For list[BaseModel], extract and validate the 'items' field
      items = args.get('items', [])
      if self._list_type_adapter is None:
        self._list_type_adapter = TypeAdapter(self.output_schema)
      validated_response = self._list_type_adapter.validate_python(items)
      return [item.model_dump(exclude_none=True) for item in validated_response]
    else:
      # For other schema types (list[str], dict, etc.),
      # return the value directly without pydantic validation
      return args.get('response')