Skip to content

Commit 44a5e6b

Browse files
wukathcopybara-github
authored andcommitted
feat: Add support for ADK tools in SkillToolset
To use ADK tools, users can specify the tool name in a skill object's `additional_tools` and pass the tool in when initializing a SkillToolset. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 879230409
1 parent bcf38fa commit 44a5e6b

File tree

6 files changed

+239
-12
lines changed

6 files changed

+239
-12
lines changed

contributing/samples/skills_agent/agent.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,52 @@
2020
from google.adk.code_executors.unsafe_local_code_executor import UnsafeLocalCodeExecutor
2121
from google.adk.skills import load_skill_from_dir
2222
from google.adk.skills import models
23+
from google.adk.tools.base_tool import BaseTool
2324
from google.adk.tools.skill_toolset import SkillToolset
25+
from google.genai import types
26+
27+
28+
class GetTimezoneTool(BaseTool):
29+
"""A tool to get the timezone for a given location."""
30+
31+
def __init__(self):
32+
super().__init__(
33+
name="get_timezone",
34+
description="Returns the timezone for a given location.",
35+
)
36+
37+
def _get_declaration(self) -> types.FunctionDeclaration | None:
38+
return types.FunctionDeclaration(
39+
name=self.name,
40+
description=self.description,
41+
parameters_json_schema={
42+
"type": "object",
43+
"properties": {
44+
"location": {
45+
"type": "string",
46+
"description": "The location to get the timezone for.",
47+
},
48+
},
49+
"required": ["location"],
50+
},
51+
)
52+
53+
async def run_async(self, *, args: dict, tool_context) -> str:
54+
return f"The timezone for {args['location']} is UTC+00:00."
55+
56+
57+
def get_current_humidity(location: str) -> str:
58+
"""Returns the current humidity for a given location."""
59+
return f"The humidity in {location} is 45%."
60+
2461

2562
greeting_skill = models.Skill(
2663
frontmatter=models.Frontmatter(
2764
name="greeting-skill",
2865
description=(
2966
"A friendly greeting skill that can say hello to a specific person."
3067
),
68+
metadata={"adk_additional_tools": ["get_timezone"]},
3169
),
3270
instructions=(
3371
"Step 1: Read the 'references/hello_world.txt' file to understand how"
@@ -49,6 +87,7 @@
4987
# be used in production environments.
5088
my_skill_toolset = SkillToolset(
5189
skills=[greeting_skill, weather_skill],
90+
additional_tools=[GetTimezoneTool(), get_current_humidity],
5291
code_executor=UnsafeLocalCodeExecutor(),
5392
)
5493

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
name: weather-skill
3+
description: A skill that provides weather information based on reference data.
4+
metadata:
5+
adk_additional_tools:
6+
- get_current_humidity
7+
---
8+
9+
Step 1: Check 'references/weather_info.md' for the current weather.
10+
Step 2: Provide the weather update to the user.

src/google/adk/skills/models.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import re
20+
from typing import Any
2021
from typing import Optional
2122
import unicodedata
2223

@@ -37,11 +38,13 @@ class Frontmatter(BaseModel):
3738
(required).
3839
license: License for the skill (optional).
3940
compatibility: Compatibility information for the skill (optional).
40-
allowed_tools: Tool patterns the skill requires (optional, experimental).
41-
Accepts both ``allowed_tools`` and the YAML-friendly ``allowed-tools``
42-
key.
41+
allowed_tools: A space-delimited list of tools that are pre-approved to
42+
run (optional, experimental). Accepts both ``allowed_tools`` and the
43+
YAML-friendly ``allowed-tools`` key. For more details, see
44+
https://agentskills.io/specification#allowed-tools-field.
4345
metadata: Key-value pairs for client-specific properties (defaults to
44-
empty dict).
46+
empty dict). For example, to include additional tools, use the
47+
``adk_additional_tools`` key with a list of tools.
4548
"""
4649

4750
model_config = ConfigDict(
@@ -58,7 +61,16 @@ class Frontmatter(BaseModel):
5861
alias="allowed-tools",
5962
serialization_alias="allowed-tools",
6063
)
61-
metadata: dict[str, str] = {}
64+
metadata: dict[str, Any] = {}
65+
66+
@field_validator("metadata")
67+
@classmethod
68+
def _validate_metadata(cls, v: dict[str, Any]) -> dict[str, Any]:
69+
if "adk_additional_tools" in v:
70+
tools = v["adk_additional_tools"]
71+
if not isinstance(tools, list):
72+
raise ValueError("adk_additional_tools must be a list of strings")
73+
return v
6274

6375
@field_validator("name")
6476
@classmethod

src/google/adk/tools/skill_toolset.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@
3737
from ..skills import prompt
3838
from .base_tool import BaseTool
3939
from .base_toolset import BaseToolset
40+
from .function_tool import FunctionTool
4041
from .tool_context import ToolContext
4142

4243
if TYPE_CHECKING:
44+
from ..agents.llm_agent import ToolUnion
4345
from ..models.llm_request import LlmRequest
4446

4547
logger = logging.getLogger("google_adk." + __name__)
@@ -138,6 +140,15 @@ async def run_async(
138140
"error_code": "SKILL_NOT_FOUND",
139141
}
140142

143+
# Record skill activation in agent state for tool resolution.
144+
agent_name = tool_context.agent_name
145+
state_key = f"_adk_activated_skill_{agent_name}"
146+
147+
activated_skills = list(tool_context.state.get(state_key, []))
148+
if skill_name not in activated_skills:
149+
activated_skills.append(skill_name)
150+
tool_context.state[state_key] = activated_skills
151+
141152
return {
142153
"skill_name": skill_name,
143154
"instructions": skill.instructions,
@@ -586,6 +597,7 @@ def __init__(
586597
*,
587598
code_executor: Optional[BaseCodeExecutor] = None,
588599
script_timeout: int = _DEFAULT_SCRIPT_TIMEOUT,
600+
additional_tools: list[ToolUnion] | None = None,
589601
):
590602
"""Initializes the SkillToolset.
591603
@@ -609,20 +621,73 @@ def __init__(
609621
self._code_executor = code_executor
610622
self._script_timeout = script_timeout
611623

624+
self._provided_tools_by_name = {}
625+
for tool_union in additional_tools or []:
626+
if isinstance(tool_union, BaseTool):
627+
self._provided_tools_by_name[tool_union.name] = tool_union
628+
elif callable(tool_union):
629+
ft = FunctionTool(tool_union)
630+
self._provided_tools_by_name[ft.name] = ft
631+
612632
# Initialize core skill tools
613633
self._tools = [
614634
ListSkillsTool(self),
615635
LoadSkillTool(self),
616636
LoadSkillResourceTool(self),
637+
RunSkillScriptTool(self),
617638
]
618-
# Always add RunSkillScriptTool, relies on invocation_context fallback if _code_executor is None
619-
self._tools.append(RunSkillScriptTool(self))
620639

621640
async def get_tools(
622641
self, readonly_context: ReadonlyContext | None = None
623642
) -> list[BaseTool]:
624643
"""Returns the list of tools in this toolset."""
625-
return self._tools
644+
dynamic_tools = await self._resolve_additional_tools_from_state(
645+
readonly_context
646+
)
647+
return self._tools + dynamic_tools
648+
649+
async def _resolve_additional_tools_from_state(
650+
self, readonly_context: ReadonlyContext | None
651+
) -> list[BaseTool]:
652+
"""Resolves tools listed in the "adk_additional_tools" metadata of skills."""
653+
654+
if not readonly_context:
655+
return []
656+
657+
agent_name = readonly_context.agent_name
658+
state_key = f"_adk_activated_skill_{agent_name}"
659+
activated_skills = readonly_context.state.get(state_key, [])
660+
661+
if not activated_skills:
662+
return []
663+
664+
additional_tool_names = set()
665+
for skill_name in activated_skills:
666+
skill = self._skills.get(skill_name)
667+
if skill:
668+
additional_tools = skill.frontmatter.metadata.get(
669+
"adk_additional_tools"
670+
)
671+
if additional_tools:
672+
additional_tool_names.update(additional_tools)
673+
674+
if not additional_tool_names:
675+
return []
676+
677+
resolved_tools = []
678+
existing_tool_names = {t.name for t in self._tools}
679+
for name in additional_tool_names:
680+
if name in self._provided_tools_by_name:
681+
tool = self._provided_tools_by_name[name]
682+
if tool.name in existing_tool_names:
683+
logger.error(
684+
"Tool name collision: tool '%s' already exists.", tool.name
685+
)
686+
continue
687+
resolved_tools.append(tool)
688+
existing_tool_names.add(tool.name)
689+
690+
return resolved_tools
626691

627692
def _get_skill(self, name: str) -> models.Skill | None:
628693
"""Retrieves a skill by name."""

tests/unittests/skills/test_models.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,34 @@ def test_allowed_tools_serialization_alias():
173173
dumped = fm.model_dump(by_alias=True)
174174
assert "allowed-tools" in dumped
175175
assert dumped["allowed-tools"] == "tool-pattern"
176+
177+
178+
def test_metadata_adk_additional_tools_list():
179+
fm = models.Frontmatter.model_validate({
180+
"name": "my-skill",
181+
"description": "desc",
182+
"metadata": {"adk_additional_tools": ["tool1", "tool2"]},
183+
})
184+
assert fm.metadata["adk_additional_tools"] == ["tool1", "tool2"]
185+
186+
187+
def test_metadata_adk_additional_tools_rejected_as_string():
188+
with pytest.raises(
189+
ValidationError, match="adk_additional_tools must be a list of strings"
190+
):
191+
models.Frontmatter.model_validate({
192+
"name": "my-skill",
193+
"description": "desc",
194+
"metadata": {"adk_additional_tools": "tool1 tool2"},
195+
})
196+
197+
198+
def test_metadata_adk_additional_tools_invalid_type():
199+
with pytest.raises(
200+
ValidationError, match="adk_additional_tools must be a list of strings"
201+
):
202+
models.Frontmatter.model_validate({
203+
"name": "my-skill",
204+
"description": "desc",
205+
"metadata": {"adk_additional_tools": 123},
206+
})

tests/unittests/tools/test_skill_toolset.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# pylint: disable=redefined-outer-name,g-import-not-at-top,protected-access
16-
17-
15+
import logging
1816
from unittest import mock
1917

2018
from google.adk.code_executors.base_code_executor import BaseCodeExecutor
@@ -145,7 +143,13 @@ def get_asset(name):
145143
@pytest.fixture
146144
def tool_context_instance():
147145
"""Fixture for tool context."""
148-
return mock.create_autospec(tool_context.ToolContext, instance=True)
146+
ctx = mock.create_autospec(tool_context.ToolContext, instance=True)
147+
ctx._invocation_context = mock.MagicMock()
148+
ctx._invocation_context.agent = mock.MagicMock()
149+
ctx._invocation_context.agent.name = "test_agent"
150+
ctx._invocation_context.agent_states = {}
151+
ctx.agent_name = "test_agent"
152+
return ctx
149153

150154

151155
# SkillToolset tests
@@ -361,6 +365,10 @@ def _make_tool_context_with_agent(agent=None):
361365
ctx = mock.MagicMock(spec=tool_context.ToolContext)
362366
ctx._invocation_context = mock.MagicMock()
363367
ctx._invocation_context.agent = agent or mock.MagicMock()
368+
ctx._invocation_context.agent.name = "test_agent"
369+
ctx._invocation_context.agent_states = {}
370+
ctx.agent_name = "test_agent"
371+
ctx.state = {}
364372
return ctx
365373

366374

@@ -1202,3 +1210,65 @@ async def test_execute_script_binary_content_packaged():
12021210
assert "b'\\x00\\x01\\x02'" in code_input.code
12031211
# Wrapper code handles binary with 'wb' mode
12041212
assert "'wb' if isinstance(content, bytes)" in code_input.code
1213+
1214+
1215+
@pytest.mark.asyncio
1216+
async def test_skill_toolset_dynamic_tool_resolution(mock_skill1):
1217+
# Set up a skill with additional_tools in metadata
1218+
mock_skill1.frontmatter.metadata = {
1219+
"adk_additional_tools": ["my_custom_tool", "my_func"]
1220+
}
1221+
mock_skill1.name = "skill1"
1222+
1223+
# Prepare additional tools
1224+
custom_tool = mock.create_autospec(skill_toolset.BaseTool, instance=True)
1225+
custom_tool.name = "my_custom_tool"
1226+
1227+
def my_func():
1228+
"""My function description."""
1229+
pass
1230+
1231+
toolset = skill_toolset.SkillToolset(
1232+
[mock_skill1],
1233+
additional_tools=[custom_tool, my_func],
1234+
)
1235+
1236+
ctx = _make_tool_context_with_agent()
1237+
# Initial tools (only core)
1238+
tools = await toolset.get_tools(readonly_context=ctx)
1239+
assert len(tools) == 4
1240+
1241+
# Activate skill
1242+
load_tool = skill_toolset.LoadSkillTool(toolset)
1243+
await load_tool.run_async(args={"name": "skill1"}, tool_context=ctx)
1244+
1245+
# Dynamic tools should now be resolved
1246+
tools = await toolset.get_tools(readonly_context=ctx)
1247+
tool_names = {t.name for t in tools}
1248+
assert "my_custom_tool" in tool_names
1249+
assert "my_func" in tool_names
1250+
1251+
# Check specific tool resolution details
1252+
my_func_tool = next(t for t in tools if t.name == "my_func")
1253+
assert isinstance(my_func_tool, skill_toolset.FunctionTool)
1254+
assert my_func_tool.description == "My function description."
1255+
1256+
1257+
@pytest.mark.asyncio
1258+
async def test_skill_toolset_resolution_error_handling(mock_skill1, caplog):
1259+
mock_skill1.frontmatter.metadata = {
1260+
"adk_additional_tools": ["nonexistent_tool"]
1261+
}
1262+
mock_skill1.name = "skill1"
1263+
toolset = skill_toolset.SkillToolset([mock_skill1])
1264+
ctx = _make_tool_context_with_agent()
1265+
1266+
# Activate skill
1267+
load_tool = skill_toolset.LoadSkillTool(toolset)
1268+
await load_tool.run_async(args={"name": "skill1"}, tool_context=ctx)
1269+
1270+
with caplog.at_level(logging.WARNING):
1271+
tools = await toolset.get_tools(readonly_context=ctx)
1272+
1273+
# Should still return basic skill tools
1274+
assert len(tools) == 4

0 commit comments

Comments
 (0)