11"""Context tool creation for semantic index retrieval."""
22
33import uuid
4- from typing import Any
4+ from typing import Any , Optional , Type
55
66from langchain_core .documents import Document
77from langchain_core .tools import StructuredTool
2626from .utils import sanitize_tool_name
2727
2828
29+ def is_static_query (resource : AgentContextResourceConfig ) -> bool :
30+ """Check if the resource configuration uses a static query variant."""
31+ if resource .settings .query is None or resource .settings .query .variant is None :
32+ return False
33+ return resource .settings .query .variant .lower () == "static"
34+
35+
2936def create_context_tool (resource : AgentContextResourceConfig ) -> StructuredTool :
3037 tool_name = sanitize_tool_name (resource .name )
3138 retrieval_mode = resource .settings .retrieval_mode .lower ()
@@ -40,34 +47,58 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
4047def handle_semantic_search (
4148 tool_name : str , resource : AgentContextResourceConfig
4249) -> StructuredTool :
50+ ensure_valid_fields (resource )
51+
52+ # needed for type checking
53+ assert resource .settings .query is not None
54+ assert resource .settings .query .variant is not None
55+
4356 retriever = ContextGroundingRetriever (
4457 index_name = resource .index_name ,
4558 folder_path = resource .folder_path ,
4659 number_of_results = resource .settings .result_count ,
4760 )
4861
49- class ContextInputSchemaModel (BaseModel ):
50- query : str = Field (
51- ..., description = "The query to search for in the knowledge base"
52- )
53-
5462 class ContextOutputSchemaModel (BaseModel ):
5563 documents : list [Document ] = Field (
5664 ..., description = "List of retrieved documents."
5765 )
5866
59- input_model = ContextInputSchemaModel
6067 output_model = ContextOutputSchemaModel
6168
62- @mockable (
63- name = resource .name ,
64- description = resource .description ,
65- input_schema = input_model .model_json_schema (),
66- output_schema = output_model .model_json_schema (),
67- example_calls = [], # Examples cannot be provided for context.
68- )
69- async def context_tool_fn (query : str ) -> dict [str , Any ]:
70- return {"documents" : await retriever .ainvoke (query )}
69+ if is_static_query (resource ):
70+ static_query_value = resource .settings .query .value
71+ assert static_query_value is not None
72+ input_model = None
73+
74+ @mockable (
75+ name = resource .name ,
76+ description = resource .description ,
77+ input_schema = input_model ,
78+ output_schema = output_model .model_json_schema (),
79+ example_calls = [], # Examples cannot be provided for context.
80+ )
81+ async def context_tool_fn () -> dict [str , Any ]:
82+ return {"documents" : await retriever .ainvoke (static_query_value )}
83+
84+ else :
85+ # Dynamic query - requires query parameter
86+ class ContextInputSchemaModel (BaseModel ):
87+ query : str = Field (
88+ ..., description = "The query to search for in the knowledge base"
89+ )
90+
91+ input_model = ContextInputSchemaModel
92+
93+ @mockable (
94+ name = resource .name ,
95+ description = resource .description ,
96+ input_schema = input_model .model_json_schema (),
97+ output_schema = output_model .model_json_schema (),
98+ example_calls = [], # Examples cannot be provided for context.
99+ )
100+ async def context_tool_fn (query : str ) -> dict [str , Any ]:
101+ return {"documents" : await retriever .ainvoke (query )}
71102
72103 return StructuredToolWithOutputType (
73104 name = tool_name ,
@@ -82,36 +113,69 @@ def handle_deep_rag(
82113 tool_name : str , resource : AgentContextResourceConfig
83114) -> StructuredTool :
84115 ensure_valid_fields (resource )
116+
85117 # needed for type checking
86118 assert resource .settings .query is not None
87- assert resource .settings .query .value is not None
119+ assert resource .settings .query .variant is not None
88120
89121 index_name = resource .index_name
90- prompt = resource .settings .query .value
91122 if not resource .settings .citation_mode :
92123 raise ValueError ("Citation mode is required for Deep RAG" )
93124 citation_mode = CitationMode (resource .settings .citation_mode .value )
94125
95- input_model = None
96126 output_model = DeepRagResponse
97127
98- @mockable (
99- name = resource .name ,
100- description = resource .description ,
101- input_schema = input_model ,
102- output_schema = output_model .model_json_schema (),
103- example_calls = [], # Examples cannot be provided for context.
104- )
105- async def context_tool_fn () -> dict [str , Any ]:
106- # TODO: add glob pattern support
107- return interrupt (
108- CreateDeepRag (
109- name = f"task-{ uuid .uuid4 ()} " ,
110- index_name = index_name ,
111- prompt = prompt ,
112- citation_mode = citation_mode ,
128+ if is_static_query (resource ):
129+ # Static query - no input parameter needed
130+ static_prompt = resource .settings .query .value
131+ assert static_prompt is not None
132+ input_model = None
133+
134+ @mockable (
135+ name = resource .name ,
136+ description = resource .description ,
137+ input_schema = input_model ,
138+ output_schema = output_model .model_json_schema (),
139+ example_calls = [], # Examples cannot be provided for context.
140+ )
141+ async def context_tool_fn () -> dict [str , Any ]:
142+ # TODO: add glob pattern support
143+ return interrupt (
144+ CreateDeepRag (
145+ name = f"task-{ uuid .uuid4 ()} " ,
146+ index_name = index_name ,
147+ prompt = static_prompt ,
148+ citation_mode = citation_mode ,
149+ )
150+ )
151+
152+ else :
153+ # Dynamic query - requires query parameter
154+ class DeepRagInputSchemaModel (BaseModel ):
155+ query : str = Field (
156+ ...,
157+ description = "Describe the task: what to research across documents, what to synthesize, and how to cite sources" ,
113158 )
159+
160+ input_model = DeepRagInputSchemaModel
161+
162+ @mockable (
163+ name = resource .name ,
164+ description = resource .description ,
165+ input_schema = input_model .model_json_schema (),
166+ output_schema = output_model .model_json_schema (),
167+ example_calls = [], # Examples cannot be provided for context.
114168 )
169+ async def context_tool_fn (query : str ) -> dict [str , Any ]:
170+ # TODO: add glob pattern support
171+ return interrupt (
172+ CreateDeepRag (
173+ name = f"task-{ uuid .uuid4 ()} " ,
174+ index_name = index_name ,
175+ prompt = query ,
176+ citation_mode = citation_mode ,
177+ )
178+ )
115179
116180 return StructuredToolWithOutputType (
117181 name = tool_name ,
@@ -129,11 +193,9 @@ def handle_batch_transform(
129193
130194 # needed for type checking
131195 assert resource .settings .query is not None
132- assert resource .settings .query .value is not None
196+ assert resource .settings .query .variant is not None
133197
134198 index_name = resource .index_name
135- prompt = resource .settings .query .value
136-
137199 index_folder_path = resource .folder_path
138200 if not resource .settings .web_search_grounding :
139201 raise ValueError ("Web search grounding field is required for Batch Transform" )
@@ -157,35 +219,82 @@ def handle_batch_transform(
157219 )
158220 )
159221
160- class BatchTransformSchemaModel (BaseModel ):
161- destination_path : str = Field (
162- ...,
163- description = "The relative file path destination for the modified csv file" ,
164- )
165-
166- input_model = BatchTransformSchemaModel
167222 output_model = BatchTransformResponse
168223
169- @mockable (
170- name = resource .name ,
171- description = resource .description ,
172- input_schema = input_model .model_json_schema (),
173- output_schema = output_model .model_json_schema (),
174- example_calls = [], # Examples cannot be provided for context.
175- )
176- async def context_tool_fn (destination_path : str ) -> dict [str , Any ]:
177- # TODO: storage_bucket_folder_path_prefix support
178- return interrupt (
179- CreateBatchTransform (
180- name = f"task-{ uuid .uuid4 ()} " ,
181- index_name = index_name ,
182- prompt = prompt ,
183- destination_path = destination_path ,
184- index_folder_path = index_folder_path ,
185- enable_web_search_grounding = enable_web_search_grounding ,
186- output_columns = batch_transform_output_columns ,
224+ input_model : Optional [Type [BaseModel ]]
225+
226+ if is_static_query (resource ):
227+ # Static query - only destination_path parameter needed
228+ static_prompt = resource .settings .query .value
229+ assert static_prompt is not None
230+
231+ class StaticBatchTransformSchemaModel (BaseModel ):
232+ destination_path : str = Field (
233+ default = "output.csv" ,
234+ description = "The relative file path destination for the modified csv file" ,
235+ )
236+
237+ input_model = StaticBatchTransformSchemaModel
238+
239+ @mockable (
240+ name = resource .name ,
241+ description = resource .description ,
242+ input_schema = input_model .model_json_schema (),
243+ output_schema = output_model .model_json_schema (),
244+ example_calls = [], # Examples cannot be provided for context.
245+ )
246+ async def context_tool_fn (
247+ destination_path : str = "output.csv" ,
248+ ) -> dict [str , Any ]:
249+ # TODO: storage_bucket_folder_path_prefix support
250+ return interrupt (
251+ CreateBatchTransform (
252+ name = f"task-{ uuid .uuid4 ()} " ,
253+ index_name = index_name ,
254+ prompt = static_prompt ,
255+ destination_path = destination_path ,
256+ index_folder_path = index_folder_path ,
257+ enable_web_search_grounding = enable_web_search_grounding ,
258+ output_columns = batch_transform_output_columns ,
259+ )
260+ )
261+
262+ else :
263+ # Dynamic query - requires both query and destination_path parameters
264+ class DynamicBatchTransformSchemaModel (BaseModel ):
265+ query : str = Field (
266+ ...,
267+ description = "Describe the task for each row: what to analyze, what to extract, and how to populate the output columns" ,
187268 )
269+ destination_path : str = Field (
270+ default = "output.csv" ,
271+ description = "The relative file path destination for the modified csv file" ,
272+ )
273+
274+ input_model = DynamicBatchTransformSchemaModel
275+
276+ @mockable (
277+ name = resource .name ,
278+ description = resource .description ,
279+ input_schema = input_model .model_json_schema (),
280+ output_schema = output_model .model_json_schema (),
281+ example_calls = [], # Examples cannot be provided for context.
188282 )
283+ async def context_tool_fn (
284+ query : str , destination_path : str = "output.csv"
285+ ) -> dict [str , Any ]:
286+ # TODO: storage_bucket_folder_path_prefix support
287+ return interrupt (
288+ CreateBatchTransform (
289+ name = f"task-{ uuid .uuid4 ()} " ,
290+ index_name = index_name ,
291+ prompt = query ,
292+ destination_path = destination_path ,
293+ index_folder_path = index_folder_path ,
294+ enable_web_search_grounding = enable_web_search_grounding ,
295+ output_columns = batch_transform_output_columns ,
296+ )
297+ )
189298
190299 return StructuredToolWithOutputType (
191300 name = tool_name ,
@@ -199,5 +308,9 @@ async def context_tool_fn(destination_path: str) -> dict[str, Any]:
199308def ensure_valid_fields (resource_config : AgentContextResourceConfig ):
200309 if not resource_config .settings .query :
201310 raise ValueError ("Query object is required" )
202- if not resource_config .settings .query .value :
203- raise ValueError ("Query prompt is required" )
311+
312+ if not resource_config .settings .query .variant :
313+ raise ValueError ("Query variant is required" )
314+
315+ if is_static_query (resource_config ) and not resource_config .settings .query .value :
316+ raise ValueError ("Static query requires a query value to be set" )
0 commit comments