Skip to content

Commit 47cefdd

Browse files
wukathcopybara-github
authored andcommitted
fix: Cache BaseToolset.get_tools() for calls within the same invocation
This reduces latency by avoiding redudant network calls Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 890689589
1 parent f973673 commit 47cefdd

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

src/google/adk/tools/base_toolset.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def __init__(
8080
"""
8181
self.tool_filter = tool_filter
8282
self.tool_name_prefix = tool_name_prefix
83+
self._cached_invocation_id: Optional[str] = None
84+
self._cached_prefixed_tools: Optional[list[BaseTool]] = None
8385

8486
@abstractmethod
8587
async def get_tools(
@@ -112,9 +114,19 @@ async def get_tools_with_prefix(
112114
Returns:
113115
list[BaseTool]: A list of tools with prefixed names if tool_name_prefix is provided.
114116
"""
117+
invocation_id = readonly_context.invocation_id if readonly_context else None
118+
119+
if (
120+
self._cached_prefixed_tools is not None
121+
and self._cached_invocation_id == invocation_id
122+
):
123+
return self._cached_prefixed_tools
124+
115125
tools = await self.get_tools(readonly_context)
116126

117127
if not self.tool_name_prefix:
128+
self._cached_invocation_id = invocation_id
129+
self._cached_prefixed_tools = tools
118130
return tools
119131

120132
prefix = self.tool_name_prefix
@@ -147,6 +159,8 @@ def _get_prefixed_declaration():
147159
tool_copy._get_declaration = _create_prefixed_declaration()
148160
prefixed_tools.append(tool_copy)
149161

162+
self._cached_invocation_id = invocation_id
163+
self._cached_prefixed_tools = prefixed_tools
150164
return prefixed_tools
151165

152166
async def close(self) -> None:

tests/unittests/tools/test_base_toolset.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,58 @@ async def test_no_duplicate_prefixing():
383383
original_tools = await toolset.get_tools()
384384
assert original_tools[0].name == 'original'
385385

386-
# The prefixed tools should be different instances
387-
assert prefixed_tools_1[0] is not prefixed_tools_2[0]
386+
# The prefixed tools should be the same instance when cached
387+
assert prefixed_tools_1[0] is prefixed_tools_2[0]
388388
assert prefixed_tools_1[0] is not original_tools[0]
389+
390+
391+
@pytest.mark.asyncio
392+
async def test_get_tools_with_prefix_caching():
393+
"""Test that get_tools_with_prefix caches results within the same invocation."""
394+
tool1 = _TestingTool(name='tool1', description='Test tool 1')
395+
toolset = _TestingToolset(tools=[tool1], tool_name_prefix='test')
396+
397+
session_service = InMemorySessionService()
398+
session = await session_service.create_session(
399+
app_name='test_app', user_id='test_user'
400+
)
401+
agent = SequentialAgent(name='test_agent')
402+
invocation_context1 = InvocationContext(
403+
invocation_id='inv-1',
404+
agent=agent,
405+
session=session,
406+
session_service=session_service,
407+
)
408+
readonly_context1 = ReadonlyContext(invocation_context1)
409+
410+
# First call
411+
tools1 = await toolset.get_tools_with_prefix(
412+
readonly_context=readonly_context1
413+
)
414+
assert len(tools1) == 1
415+
assert tools1[0].name == 'test_tool1'
416+
417+
# Second call with same context/invocation_id
418+
tools2 = await toolset.get_tools_with_prefix(
419+
readonly_context=readonly_context1
420+
)
421+
assert len(tools2) == 1
422+
assert (
423+
tools2 is tools1
424+
) # Should return the exact same list instance (from cache)
425+
426+
# Third call with different invocation_id
427+
invocation_context2 = InvocationContext(
428+
invocation_id='inv-2',
429+
agent=agent,
430+
session=session,
431+
session_service=session_service,
432+
)
433+
readonly_context2 = ReadonlyContext(invocation_context2)
434+
435+
tools3 = await toolset.get_tools_with_prefix(
436+
readonly_context=readonly_context2
437+
)
438+
assert len(tools3) == 1
439+
assert tools3 is not tools1 # Should be a new list instance
440+
assert tools3[0].name == 'test_tool1'

0 commit comments

Comments
 (0)