Skip to content

Commit 22be41b

Browse files
authored
feat: integration of vanna tools (#533)
1 parent 33f602d commit 22be41b

File tree

8 files changed

+1909
-0
lines changed

8 files changed

+1909
-0
lines changed
Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, Optional, List
16+
from google.adk.tools import BaseTool, ToolContext
17+
from google.genai import types
18+
from vanna.tools.agent_memory import (
19+
SaveQuestionToolArgsTool as VannaSaveQuestionToolArgsTool,
20+
SearchSavedCorrectToolUsesTool as VannaSearchSavedCorrectToolUsesTool,
21+
SaveTextMemoryTool as VannaSaveTextMemoryTool,
22+
)
23+
from vanna.core.user import User
24+
from vanna.core.tool import ToolContext as VannaToolContext
25+
26+
27+
class SaveQuestionToolArgsTool(BaseTool):
28+
"""Save successful question-tool-argument combinations for future reference."""
29+
30+
def __init__(
31+
self,
32+
agent_memory,
33+
access_groups: Optional[List[str]] = None,
34+
):
35+
"""
36+
Initialize the save tool usage tool with custom agent_memory.
37+
38+
Args:
39+
agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
40+
access_groups: List of user groups that can access this tool (e.g., ['admin'])
41+
"""
42+
self.agent_memory = agent_memory
43+
self.vanna_tool = VannaSaveQuestionToolArgsTool()
44+
self.access_groups = access_groups or ["admin", "user"]
45+
46+
super().__init__(
47+
name="save_question_tool_args", # Keep the same name as Vanna
48+
description="Save a successful question-tool-argument combination for future reference.",
49+
)
50+
51+
def _get_declaration(self) -> types.FunctionDeclaration:
52+
return types.FunctionDeclaration(
53+
name=self.name,
54+
description=self.description,
55+
parameters=types.Schema(
56+
type=types.Type.OBJECT,
57+
properties={
58+
"question": types.Schema(
59+
type=types.Type.STRING,
60+
description="The original question that was asked",
61+
),
62+
"tool_name": types.Schema(
63+
type=types.Type.STRING,
64+
description="The name of the tool that was used successfully",
65+
),
66+
"args": types.Schema(
67+
type=types.Type.OBJECT,
68+
description="The arguments that were passed to the tool",
69+
),
70+
},
71+
required=["question", "tool_name", "args"],
72+
),
73+
)
74+
75+
def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
76+
"""Get user groups from context."""
77+
user_groups = tool_context.state.get("user_groups", ["user"])
78+
return user_groups
79+
80+
def _check_access(self, user_groups: List[str]) -> bool:
81+
"""Check if user has access to this tool."""
82+
return any(group in self.access_groups for group in user_groups)
83+
84+
def _create_vanna_context(
85+
self, tool_context: ToolContext, user_groups: List[str]
86+
) -> VannaToolContext:
87+
"""Create Vanna context from Veadk ToolContext."""
88+
user_id = tool_context.user_id
89+
session_id = tool_context.session.id
90+
user_email = tool_context.state.get("user_email", "user@example.com")
91+
92+
vanna_user = User(
93+
id=user_id + "_" + session_id,
94+
email=user_email,
95+
group_memberships=user_groups,
96+
)
97+
98+
vanna_context = VannaToolContext(
99+
user=vanna_user,
100+
conversation_id=session_id,
101+
request_id=session_id,
102+
agent_memory=self.agent_memory,
103+
)
104+
105+
return vanna_context
106+
107+
async def run_async(
108+
self, *, args: Dict[str, Any], tool_context: ToolContext
109+
) -> str:
110+
"""Save a tool usage pattern."""
111+
question = args.get("question", "").strip()
112+
tool_name = args.get("tool_name", "").strip()
113+
tool_args = args.get("args", {})
114+
115+
if not question:
116+
return "Error: No question provided"
117+
118+
if not tool_name:
119+
return "Error: No tool name provided"
120+
121+
try:
122+
user_groups = self._get_user_groups(tool_context)
123+
124+
if not self._check_access(user_groups):
125+
return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"
126+
127+
vanna_context = self._create_vanna_context(tool_context, user_groups)
128+
129+
args_model = self.vanna_tool.get_args_schema()(
130+
question=question, tool_name=tool_name, args=tool_args
131+
)
132+
result = await self.vanna_tool.execute(vanna_context, args_model)
133+
134+
return str(result.result_for_llm)
135+
except Exception as e:
136+
return f"Error saving tool usage: {str(e)}"
137+
138+
139+
class SearchSavedCorrectToolUsesTool(BaseTool):
140+
"""Search for similar tool usage patterns based on a question."""
141+
142+
def __init__(
143+
self,
144+
agent_memory,
145+
access_groups: Optional[List[str]] = None,
146+
):
147+
"""
148+
Initialize the search similar tools tool with custom agent_memory.
149+
150+
Args:
151+
agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
152+
access_groups: List of user groups that can access this tool (e.g., ['admin', 'user'])
153+
user_group_resolver: Optional callable that takes ToolContext and returns user groups
154+
"""
155+
self.agent_memory = agent_memory
156+
self.vanna_tool = VannaSearchSavedCorrectToolUsesTool()
157+
self.access_groups = access_groups or ["admin", "user"]
158+
159+
super().__init__(
160+
name="search_saved_correct_tool_uses", # Keep the same name as Vanna
161+
description="Search for similar tool usage patterns based on a question.",
162+
)
163+
164+
def _get_declaration(self) -> types.FunctionDeclaration:
165+
return types.FunctionDeclaration(
166+
name=self.name,
167+
description=self.description,
168+
parameters=types.Schema(
169+
type=types.Type.OBJECT,
170+
properties={
171+
"question": types.Schema(
172+
type=types.Type.STRING,
173+
description="The question to find similar tool usage patterns for",
174+
),
175+
"limit": types.Schema(
176+
type=types.Type.INTEGER,
177+
description="Maximum number of results to return (default: 10)",
178+
),
179+
},
180+
required=["question"],
181+
),
182+
)
183+
184+
def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
185+
"""Get user groups from context."""
186+
user_groups = tool_context.state.get("user_groups", ["user"])
187+
return user_groups
188+
189+
def _check_access(self, user_groups: List[str]) -> bool:
190+
"""Check if user has access to this tool."""
191+
return any(group in self.access_groups for group in user_groups)
192+
193+
def _create_vanna_context(
194+
self, tool_context: ToolContext, user_groups: List[str]
195+
) -> VannaToolContext:
196+
"""Create Vanna context from Veadk ToolContext."""
197+
user_id = tool_context.user_id
198+
session_id = tool_context.session.id
199+
user_email = tool_context.state.get("user_email", "user@example.com")
200+
201+
vanna_user = User(
202+
id=user_id + "_" + session_id,
203+
email=user_email,
204+
group_memberships=user_groups,
205+
)
206+
207+
vanna_context = VannaToolContext(
208+
user=vanna_user,
209+
conversation_id=session_id,
210+
request_id=session_id,
211+
agent_memory=self.agent_memory,
212+
)
213+
214+
return vanna_context
215+
216+
async def run_async(
217+
self, *, args: Dict[str, Any], tool_context: ToolContext
218+
) -> str:
219+
"""Search for similar tool usage patterns."""
220+
question = args.get("question", "").strip()
221+
limit = args.get("limit", 10)
222+
223+
if not question:
224+
return "Error: No question provided"
225+
226+
try:
227+
user_groups = self._get_user_groups(tool_context)
228+
229+
if not self._check_access(user_groups):
230+
return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"
231+
232+
vanna_context = self._create_vanna_context(tool_context, user_groups)
233+
234+
args_model = self.vanna_tool.get_args_schema()(
235+
question=question, limit=limit
236+
)
237+
result = await self.vanna_tool.execute(vanna_context, args_model)
238+
239+
return str(result.result_for_llm)
240+
except Exception as e:
241+
return f"Error searching similar tools: {str(e)}"
242+
243+
244+
class SaveTextMemoryTool(BaseTool):
245+
"""Save free-form text memories for important insights, observations, or context."""
246+
247+
def __init__(
248+
self,
249+
agent_memory,
250+
access_groups: Optional[List[str]] = None,
251+
):
252+
"""
253+
Initialize the save text memory tool with custom agent_memory.
254+
255+
Args:
256+
agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
257+
access_groups: List of user groups that can access this tool (e.g., ['admin', 'user'])
258+
user_group_resolver: Optional callable that takes ToolContext and returns user groups
259+
"""
260+
self.agent_memory = agent_memory
261+
self.vanna_tool = VannaSaveTextMemoryTool()
262+
self.access_groups = access_groups or ["admin", "user"]
263+
264+
super().__init__(
265+
name="save_text_memory", # Keep the same name as Vanna
266+
description="Save free-form text memory for important insights, observations, or context.",
267+
)
268+
269+
def _get_declaration(self) -> types.FunctionDeclaration:
270+
return types.FunctionDeclaration(
271+
name=self.name,
272+
description=self.description,
273+
parameters=types.Schema(
274+
type=types.Type.OBJECT,
275+
properties={
276+
"content": types.Schema(
277+
type=types.Type.STRING,
278+
description="The text content to save as a memory",
279+
),
280+
},
281+
required=["content"],
282+
),
283+
)
284+
285+
def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
286+
"""Get user groups from context."""
287+
user_groups = tool_context.state.get("user_groups", ["user"])
288+
return user_groups
289+
290+
def _check_access(self, user_groups: List[str]) -> bool:
291+
"""Check if user has access to this tool."""
292+
return any(group in self.access_groups for group in user_groups)
293+
294+
def _create_vanna_context(
295+
self, tool_context: ToolContext, user_groups: List[str]
296+
) -> VannaToolContext:
297+
"""Create Vanna context from Veadk ToolContext."""
298+
user_id = tool_context.user_id
299+
session_id = tool_context.session.id
300+
user_email = tool_context.state.get("user_email", "user@example.com")
301+
302+
vanna_user = User(
303+
id=user_id + "_" + session_id,
304+
email=user_email,
305+
group_memberships=user_groups,
306+
)
307+
308+
vanna_context = VannaToolContext(
309+
user=vanna_user,
310+
conversation_id=session_id,
311+
request_id=session_id,
312+
agent_memory=self.agent_memory,
313+
)
314+
315+
return vanna_context
316+
317+
async def run_async(
318+
self, *, args: Dict[str, Any], tool_context: ToolContext
319+
) -> str:
320+
"""Save a text memory."""
321+
content = args.get("content", "").strip()
322+
323+
if not content:
324+
return "Error: No content provided"
325+
326+
try:
327+
user_groups = self._get_user_groups(tool_context)
328+
329+
if not self._check_access(user_groups):
330+
return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"
331+
332+
vanna_context = self._create_vanna_context(tool_context, user_groups)
333+
334+
args_model = self.vanna_tool.get_args_schema()(content=content)
335+
result = await self.vanna_tool.execute(vanna_context, args_model)
336+
337+
return str(result.result_for_llm)
338+
except Exception as e:
339+
return f"Error saving text memory: {str(e)}"

0 commit comments

Comments
 (0)