Skip to content

Commit 466652c

Browse files
committed
feat: support vanna tools
1 parent 68db19e commit 466652c

7 files changed

Lines changed: 1614 additions & 0 deletions

File tree

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
from typing import Any, Dict, Optional, List
2+
from google.adk.tools import BaseTool, ToolContext
3+
from google.genai import types
4+
from vanna.tools.agent_memory import (
5+
SaveQuestionToolArgsTool as VannaSaveQuestionToolArgsTool,
6+
SearchSavedCorrectToolUsesTool as VannaSearchSavedCorrectToolUsesTool,
7+
SaveTextMemoryTool as VannaSaveTextMemoryTool,
8+
)
9+
from vanna.core.user import User
10+
from vanna.core.tool import ToolContext as VannaToolContext
11+
12+
13+
class SaveQuestionToolArgsTool(BaseTool):
14+
"""Save successful question-tool-argument combinations for future reference."""
15+
16+
def __init__(
17+
self,
18+
agent_memory,
19+
access_groups: Optional[List[str]] = None,
20+
):
21+
"""
22+
Initialize the save tool usage tool with custom agent_memory.
23+
24+
Args:
25+
agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
26+
access_groups: List of user groups that can access this tool (e.g., ['admin'])
27+
"""
28+
self.agent_memory = agent_memory
29+
self.vanna_tool = VannaSaveQuestionToolArgsTool()
30+
self.access_groups = access_groups or ["admin"] # Default: only admin
31+
32+
super().__init__(
33+
name="save_question_tool_args", # Keep the same name as Vanna
34+
description="Save a successful question-tool-argument combination for future reference.",
35+
)
36+
37+
def _get_declaration(self) -> types.FunctionDeclaration:
38+
return types.FunctionDeclaration(
39+
name=self.name,
40+
description=self.description,
41+
parameters=types.Schema(
42+
type=types.Type.OBJECT,
43+
properties={
44+
"question": types.Schema(
45+
type=types.Type.STRING,
46+
description="The original question that was asked",
47+
),
48+
"tool_name": types.Schema(
49+
type=types.Type.STRING,
50+
description="The name of the tool that was used successfully",
51+
),
52+
"args": types.Schema(
53+
type=types.Type.OBJECT,
54+
description="The arguments that were passed to the tool",
55+
),
56+
},
57+
required=["question", "tool_name", "args"],
58+
),
59+
)
60+
61+
def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
62+
"""Get user groups from context."""
63+
user_groups = tool_context.state.get("user_groups", ["user"])
64+
return user_groups
65+
66+
def _check_access(self, user_groups: List[str]) -> bool:
67+
"""Check if user has access to this tool."""
68+
return any(group in self.access_groups for group in user_groups)
69+
70+
def _create_vanna_context(
71+
self, tool_context: ToolContext, user_groups: List[str]
72+
) -> VannaToolContext:
73+
"""Create Vanna context from Veadk ToolContext."""
74+
user_id = tool_context.user_id
75+
user_email = tool_context.state.get("user_email", "user@example.com")
76+
77+
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
78+
79+
vanna_context = VannaToolContext(
80+
user=vanna_user,
81+
conversation_id=tool_context.session.id,
82+
request_id=tool_context.session.id,
83+
agent_memory=self.agent_memory,
84+
)
85+
86+
return vanna_context
87+
88+
async def run_async(
89+
self, *, args: Dict[str, Any], tool_context: ToolContext
90+
) -> str:
91+
"""Save a tool usage pattern."""
92+
question = args.get("question", "").strip()
93+
tool_name = args.get("tool_name", "").strip()
94+
tool_args = args.get("args", {})
95+
96+
if not question:
97+
return "Error: No question provided"
98+
99+
if not tool_name:
100+
return "Error: No tool name provided"
101+
102+
try:
103+
user_groups = self._get_user_groups(tool_context)
104+
105+
if not self._check_access(user_groups):
106+
return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"
107+
108+
vanna_context = self._create_vanna_context(tool_context, user_groups)
109+
110+
args_model = self.vanna_tool.get_args_schema()(
111+
question=question, tool_name=tool_name, args=tool_args
112+
)
113+
result = await self.vanna_tool.execute(vanna_context, args_model)
114+
115+
return str(result.result_for_llm)
116+
except Exception as e:
117+
return f"Error saving tool usage: {str(e)}"
118+
119+
120+
class SearchSavedCorrectToolUsesTool(BaseTool):
121+
"""Search for similar tool usage patterns based on a question."""
122+
123+
def __init__(
124+
self,
125+
agent_memory,
126+
access_groups: Optional[List[str]] = None,
127+
):
128+
"""
129+
Initialize the search similar tools tool with custom agent_memory.
130+
131+
Args:
132+
agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
133+
access_groups: List of user groups that can access this tool (e.g., ['admin', 'user'])
134+
user_group_resolver: Optional callable that takes ToolContext and returns user groups
135+
"""
136+
self.agent_memory = agent_memory
137+
self.vanna_tool = VannaSearchSavedCorrectToolUsesTool()
138+
self.access_groups = access_groups or ["admin", "user"]
139+
140+
super().__init__(
141+
name="search_saved_correct_tool_uses", # Keep the same name as Vanna
142+
description="Search for similar tool usage patterns based on a question.",
143+
)
144+
145+
def _get_declaration(self) -> types.FunctionDeclaration:
146+
return types.FunctionDeclaration(
147+
name=self.name,
148+
description=self.description,
149+
parameters=types.Schema(
150+
type=types.Type.OBJECT,
151+
properties={
152+
"question": types.Schema(
153+
type=types.Type.STRING,
154+
description="The question to find similar tool usage patterns for",
155+
),
156+
"limit": types.Schema(
157+
type=types.Type.INTEGER,
158+
description="Maximum number of results to return (default: 10)",
159+
),
160+
},
161+
required=["question"],
162+
),
163+
)
164+
165+
def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
166+
"""Get user groups from context."""
167+
user_groups = tool_context.state.get("user_groups", ["user"])
168+
return user_groups
169+
170+
def _check_access(self, user_groups: List[str]) -> bool:
171+
"""Check if user has access to this tool."""
172+
return any(group in self.access_groups for group in user_groups)
173+
174+
def _create_vanna_context(
175+
self, tool_context: ToolContext, user_groups: List[str]
176+
) -> VannaToolContext:
177+
"""Create Vanna context from Veadk ToolContext."""
178+
user_id = tool_context.user_id
179+
user_email = tool_context.state.get("user_email", "user@example.com")
180+
181+
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
182+
183+
vanna_context = VannaToolContext(
184+
user=vanna_user,
185+
conversation_id=tool_context.session.id,
186+
request_id=tool_context.session.id,
187+
agent_memory=self.agent_memory,
188+
)
189+
190+
return vanna_context
191+
192+
async def run_async(
193+
self, *, args: Dict[str, Any], tool_context: ToolContext
194+
) -> str:
195+
"""Search for similar tool usage patterns."""
196+
question = args.get("question", "").strip()
197+
limit = args.get("limit", 10)
198+
199+
if not question:
200+
return "Error: No question provided"
201+
202+
try:
203+
user_groups = self._get_user_groups(tool_context)
204+
205+
if not self._check_access(user_groups):
206+
return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"
207+
208+
vanna_context = self._create_vanna_context(tool_context, user_groups)
209+
210+
args_model = self.vanna_tool.get_args_schema()(
211+
question=question, limit=limit
212+
)
213+
result = await self.vanna_tool.execute(vanna_context, args_model)
214+
215+
return str(result.result_for_llm)
216+
except Exception as e:
217+
return f"Error searching similar tools: {str(e)}"
218+
219+
220+
class SaveTextMemoryTool(BaseTool):
221+
"""Save free-form text memories for important insights, observations, or context."""
222+
223+
def __init__(
224+
self,
225+
agent_memory,
226+
access_groups: Optional[List[str]] = None,
227+
):
228+
"""
229+
Initialize the save text memory tool with custom agent_memory.
230+
231+
Args:
232+
agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
233+
access_groups: List of user groups that can access this tool (e.g., ['admin', 'user'])
234+
user_group_resolver: Optional callable that takes ToolContext and returns user groups
235+
"""
236+
self.agent_memory = agent_memory
237+
self.vanna_tool = VannaSaveTextMemoryTool()
238+
self.access_groups = access_groups or ["admin", "user"]
239+
240+
super().__init__(
241+
name="save_text_memory", # Keep the same name as Vanna
242+
description="Save free-form text memory for important insights, observations, or context.",
243+
)
244+
245+
def _get_declaration(self) -> types.FunctionDeclaration:
246+
return types.FunctionDeclaration(
247+
name=self.name,
248+
description=self.description,
249+
parameters=types.Schema(
250+
type=types.Type.OBJECT,
251+
properties={
252+
"content": types.Schema(
253+
type=types.Type.STRING,
254+
description="The text content to save as a memory",
255+
),
256+
},
257+
required=["content"],
258+
),
259+
)
260+
261+
def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
262+
"""Get user groups from context."""
263+
user_groups = tool_context.state.get("user_groups", ["user"])
264+
return user_groups
265+
266+
def _check_access(self, user_groups: List[str]) -> bool:
267+
"""Check if user has access to this tool."""
268+
return any(group in self.access_groups for group in user_groups)
269+
270+
def _create_vanna_context(
271+
self, tool_context: ToolContext, user_groups: List[str]
272+
) -> VannaToolContext:
273+
"""Create Vanna context from Veadk ToolContext."""
274+
user_id = tool_context.user_id
275+
user_email = tool_context.state.get("user_email", "user@example.com")
276+
277+
vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
278+
279+
vanna_context = VannaToolContext(
280+
user=vanna_user,
281+
conversation_id=tool_context.session.id,
282+
request_id=tool_context.session.id,
283+
agent_memory=self.agent_memory,
284+
)
285+
286+
return vanna_context
287+
288+
async def run_async(
289+
self, *, args: Dict[str, Any], tool_context: ToolContext
290+
) -> str:
291+
"""Save a text memory."""
292+
content = args.get("content", "").strip()
293+
294+
if not content:
295+
return "Error: No content provided"
296+
297+
try:
298+
user_groups = self._get_user_groups(tool_context)
299+
300+
if not self._check_access(user_groups):
301+
return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"
302+
303+
vanna_context = self._create_vanna_context(tool_context, user_groups)
304+
305+
args_model = self.vanna_tool.get_args_schema()(content=content)
306+
result = await self.vanna_tool.execute(vanna_context, args_model)
307+
308+
return str(result.result_for_llm)
309+
except Exception as e:
310+
return f"Error saving text memory: {str(e)}"

0 commit comments

Comments
 (0)