Skip to content

Commit 3e282d2

Browse files
wukathcopybara-github
authored andcommitted
feat: Implement Skill Registry interface in ADK
Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 907023117
1 parent 6d2ada8 commit 3e282d2

3 files changed

Lines changed: 295 additions & 3 deletions

File tree

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Interface for skill registry."""
16+
17+
from __future__ import annotations
18+
19+
import abc
20+
from typing import Any
21+
from typing import Dict
22+
from typing import List
23+
24+
from . import models
25+
26+
27+
class SkillRegistry(abc.ABC):
28+
"""Interface for a skill registry."""
29+
30+
@abc.abstractmethod
31+
async def get_skill(
32+
self, *, name: str, version: str | None = None
33+
) -> models.Skill:
34+
"""Fetches a skill from the registry.
35+
36+
Args:
37+
name: The name of the skill.
38+
version: Optional version of the skill.
39+
40+
Returns:
41+
A Skill object.
42+
"""
43+
pass
44+
45+
@abc.abstractmethod
46+
async def search_skills(
47+
self,
48+
*,
49+
query: str,
50+
filters: Dict[str, Any] | None = None,
51+
**kwargs,
52+
) -> List[models.Frontmatter]:
53+
"""Searches for skills in the registry.
54+
55+
Args:
56+
query: The search query.
57+
filters: Optional filters.
58+
**kwargs: Additional implementation-specific arguments.
59+
60+
Returns:
61+
A list of Frontmatter objects for discovery.
62+
"""
63+
pass
64+
65+
@abc.abstractmethod
66+
def get_filter_schema(self) -> Dict[str, Any] | None:
67+
"""Returns the JSON schema for the filters supported by this registry.
68+
69+
Returns:
70+
A JSON schema dict or None if filters are not supported
71+
"""
72+
pass
73+
74+
def get_search_description(self) -> str:
75+
"""Returns the description for the search_skills tool.
76+
77+
Registries can override this to provide specialized instructions to the
78+
model on how to use their specific search capabilities.
79+
"""
80+
return (
81+
"Searches for relevant skills in the registry based on a semantic or"
82+
" keyword query."
83+
)

src/google/adk/tools/skill_toolset.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from ..features import FeatureName
3838
from ..skills import models
3939
from ..skills import prompt
40+
from ..skills.skill_registry import SkillRegistry
4041
from .base_tool import BaseTool
4142
from .base_toolset import BaseToolset
4243
from .function_tool import FunctionTool
@@ -155,6 +156,20 @@ async def run_async(
155156
}
156157

157158
skill = self._toolset._get_skill(skill_name)
159+
if not skill and self._toolset._registry:
160+
try:
161+
skill = await self._toolset._registry.get_skill(name=skill_name)
162+
if skill:
163+
self._toolset._skills[skill_name] = skill
164+
except Exception as e:
165+
logger.exception(
166+
"Failed to fetch skill '%s' from registry.", skill_name
167+
)
168+
return {
169+
"error": f"Failed to fetch skill '{skill_name}' from registry: {e}",
170+
"error_code": "REGISTRY_ERROR",
171+
}
172+
158173
if not skill:
159174
return {
160175
"error": f"Skill '{skill_name}' not found.",
@@ -771,14 +786,75 @@ async def run_async(
771786
)
772787

773788

789+
@experimental(FeatureName.SKILL_TOOLSET)
790+
class SearchSkillsTool(BaseTool):
791+
"""Tool to search for relevant skills in the registry."""
792+
793+
def __init__(self, toolset: "SkillToolset"):
794+
super().__init__(
795+
name="search_skills",
796+
description=toolset._registry.get_search_description(),
797+
)
798+
self._toolset = toolset
799+
800+
def _get_declaration(self) -> types.FunctionDeclaration | None:
801+
properties = {
802+
"query": {
803+
"type": "string",
804+
"description": "Semantic or keyword search query.",
805+
},
806+
}
807+
filter_schema = self._toolset._registry.get_filter_schema()
808+
if filter_schema:
809+
properties["filters"] = filter_schema
810+
return types.FunctionDeclaration(
811+
name=self.name,
812+
description=self.description,
813+
parameters_json_schema={
814+
"type": "object",
815+
"properties": properties,
816+
"required": ["query"],
817+
},
818+
)
819+
820+
async def run_async(
821+
self, *, args: dict[str, Any], tool_context: ToolContext
822+
) -> Any:
823+
query = args.get("query")
824+
filters = args.get("filters")
825+
826+
if not query:
827+
return {
828+
"error": "Argument 'query' is required.",
829+
"error_code": "INVALID_ARGUMENTS",
830+
}
831+
832+
results = await self._toolset._registry.search_skills(
833+
query=query, filters=filters
834+
)
835+
836+
formatted_results = []
837+
for r in results:
838+
if r.name in self._toolset._skills:
839+
logger.warning(
840+
"Naming conflict detected: Skill '%s' exists both locally and in"
841+
" the registry. Filtering out the registry skill.",
842+
r.name,
843+
)
844+
continue
845+
formatted_results.append(r.model_dump())
846+
return formatted_results
847+
848+
774849
@experimental(FeatureName.SKILL_TOOLSET)
775850
class SkillToolset(BaseToolset):
776851
"""A toolset for managing and interacting with agent skills."""
777852

778853
def __init__(
779854
self,
780-
skills: list[models.Skill],
855+
skills: list[models.Skill] | None = None,
781856
*,
857+
registry: Optional[SkillRegistry] = None,
782858
code_executor: Optional[BaseCodeExecutor] = None,
783859
script_timeout: int = _DEFAULT_SCRIPT_TIMEOUT,
784860
additional_tools: list[ToolUnion] | None = None,
@@ -787,6 +863,7 @@ def __init__(
787863
788864
Args:
789865
skills: List of skills to register.
866+
registry: Optional skill registry for dynamic discovery.
790867
code_executor: Optional code executor for script execution.
791868
script_timeout: Timeout in seconds for shell script execution via
792869
subprocess.run. Defaults to 300 seconds. Does not apply to Python
@@ -796,12 +873,13 @@ def __init__(
796873

797874
# Check for duplicate skill names
798875
seen: set[str] = set()
799-
for skill in skills:
876+
for skill in skills or []:
800877
if skill.name in seen:
801878
raise ValueError(f"Duplicate skill name '{skill.name}'.")
802879
seen.add(skill.name)
803880

804-
self._skills = {skill.name: skill for skill in skills}
881+
self._skills = {skill.name: skill for skill in skills or []}
882+
self._registry = registry
805883
self._code_executor = code_executor
806884
self._script_timeout = script_timeout
807885
self._use_invocation_cache = False
@@ -824,6 +902,8 @@ def __init__(
824902
LoadSkillResourceTool(self),
825903
RunSkillScriptTool(self),
826904
]
905+
if self._registry:
906+
self._tools.append(SearchSkillsTool(self))
827907

828908
async def get_tools(
829909
self, readonly_context: ReadonlyContext | None = None
@@ -904,6 +984,12 @@ async def process_llm_request(
904984
skills_xml = prompt.format_skills_as_xml(skills)
905985
instructions = []
906986
instructions.append(_DEFAULT_SKILL_SYSTEM_INSTRUCTION)
987+
if self._registry:
988+
instructions.append(
989+
"\nYou can also use the `search_skills` tool to discover additional"
990+
" skills in the registry if the available skills listed below are"
991+
" not sufficient.\n"
992+
)
907993
instructions.append(skills_xml)
908994
llm_request.append_instructions(instructions)
909995

tests/unittests/tools/test_skill_toolset.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from google.adk.code_executors.unsafe_local_code_executor import UnsafeLocalCodeExecutor
2323
from google.adk.models import llm_request as llm_request_model
2424
from google.adk.skills import models
25+
from google.adk.skills.skill_registry import SkillRegistry
2526
from google.adk.tools import skill_toolset
2627
from google.adk.tools import tool_context
2728
from google.genai import types
@@ -473,6 +474,128 @@ async def test_scripts_resource_not_found(mock_skill1, tool_context_instance):
473474
assert result["error_code"] == "RESOURCE_NOT_FOUND"
474475

475476

477+
class MockSkillRegistry(SkillRegistry):
478+
479+
def __init__(self):
480+
self.skills = {}
481+
self.search_results = []
482+
483+
async def get_skill(self, *, name, version=None):
484+
return self.skills.get(name)
485+
486+
async def search_skills(self, *, query, filters=None, **kwargs):
487+
return self.search_results
488+
489+
def get_filter_schema(self):
490+
return None
491+
492+
493+
@pytest.mark.asyncio
494+
async def test_skill_toolset_init_with_registry(mock_skill1):
495+
registry = MockSkillRegistry()
496+
toolset = skill_toolset.SkillToolset([mock_skill1], registry=registry)
497+
assert toolset._registry == registry
498+
tools = await toolset.get_tools()
499+
assert len(tools) == 5 # 4 default + SearchSkillsTool
500+
assert isinstance(tools[4], skill_toolset.SearchSkillsTool)
501+
502+
503+
@pytest.mark.asyncio
504+
async def test_search_skills_tool_run_async(mock_skill1, tool_context_instance):
505+
registry = MockSkillRegistry()
506+
frontmatter = mock.create_autospec(models.Frontmatter, instance=True)
507+
frontmatter.name = "remote-skill"
508+
frontmatter.model_dump.return_value = {"name": "remote-skill"}
509+
registry.search_results = [frontmatter]
510+
511+
toolset = skill_toolset.SkillToolset([mock_skill1], registry=registry)
512+
tool = skill_toolset.SearchSkillsTool(toolset)
513+
514+
result = await tool.run_async(
515+
args={"query": "test"}, tool_context=tool_context_instance
516+
)
517+
assert result == [{"name": "remote-skill"}]
518+
519+
520+
@pytest.mark.asyncio
521+
async def test_search_skills_tool_collision(
522+
mock_skill1, tool_context_instance, caplog
523+
):
524+
registry = MockSkillRegistry()
525+
frontmatter = mock.create_autospec(models.Frontmatter, instance=True)
526+
frontmatter.name = "skill1" # Same name as mock_skill1
527+
frontmatter.model_dump.return_value = {"name": "skill1"}
528+
529+
frontmatter2 = mock.create_autospec(models.Frontmatter, instance=True)
530+
frontmatter2.name = "remote-skill"
531+
frontmatter2.model_dump.return_value = {"name": "remote-skill"}
532+
533+
registry.search_results = [frontmatter, frontmatter2]
534+
535+
toolset = skill_toolset.SkillToolset([mock_skill1], registry=registry)
536+
tool = skill_toolset.SearchSkillsTool(toolset)
537+
538+
with caplog.at_level(logging.WARNING):
539+
result = await tool.run_async(
540+
args={"query": "test"}, tool_context=tool_context_instance
541+
)
542+
assert result == [{"name": "remote-skill"}]
543+
assert "Naming conflict detected" in caplog.text
544+
545+
546+
@pytest.mark.asyncio
547+
async def test_load_skill_tool_fetches_from_registry(
548+
tool_context_instance, mock_skill1
549+
):
550+
registry = MockSkillRegistry()
551+
registry.skills["my-skill"] = mock_skill1
552+
553+
toolset = skill_toolset.SkillToolset([], registry=registry)
554+
tool = skill_toolset.LoadSkillTool(toolset)
555+
556+
result = await tool.run_async(
557+
args={"skill_name": "my-skill"}, tool_context=tool_context_instance
558+
)
559+
assert result["skill_name"] == "my-skill"
560+
assert toolset._skills["my-skill"] == mock_skill1
561+
562+
563+
@pytest.mark.asyncio
564+
async def test_load_skill_tool_registry_error(tool_context_instance):
565+
registry = MockSkillRegistry()
566+
registry.get_skill = mock.AsyncMock(
567+
side_effect=Exception("Test registry error")
568+
)
569+
570+
toolset = skill_toolset.SkillToolset([], registry=registry)
571+
tool = skill_toolset.LoadSkillTool(toolset)
572+
573+
result = await tool.run_async(
574+
args={"skill_name": "my-skill"}, tool_context=tool_context_instance
575+
)
576+
assert result["error_code"] == "REGISTRY_ERROR"
577+
assert "Failed to fetch skill 'my-skill' from registry" in result["error"]
578+
579+
580+
@pytest.mark.asyncio
581+
async def test_process_llm_request_with_registry(
582+
mock_skill1, tool_context_instance
583+
):
584+
registry = MockSkillRegistry()
585+
toolset = skill_toolset.SkillToolset([mock_skill1], registry=registry)
586+
llm_req = mock.create_autospec(llm_request_model.LlmRequest, instance=True)
587+
588+
await toolset.process_llm_request(
589+
tool_context=tool_context_instance, llm_request=llm_req
590+
)
591+
592+
llm_req.append_instructions.assert_called_once()
593+
args, _ = llm_req.append_instructions.call_args
594+
instructions = args[0]
595+
assert len(instructions) == 3 # default + search instruction + skills xml
596+
assert "search_skills" in instructions[1]
597+
598+
476599
# RunSkillScriptTool tests
477600

478601

0 commit comments

Comments
 (0)