Skip to content

Commit c0db575

Browse files
fix dynamic and static options (#394)
1 parent acb4ca5 commit c0db575

2 files changed

Lines changed: 577 additions & 71 deletions

File tree

Lines changed: 177 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Context tool creation for semantic index retrieval."""
22

33
import uuid
4-
from typing import Any
4+
from typing import Any, Optional, Type
55

66
from langchain_core.documents import Document
77
from langchain_core.tools import StructuredTool
@@ -26,6 +26,13 @@
2626
from .utils import sanitize_tool_name
2727

2828

29+
def is_static_query(resource: AgentContextResourceConfig) -> bool:
    """Check if the resource configuration uses a static query variant.

    Args:
        resource: Context resource configuration whose ``settings.query``
            (and its ``variant``) may be unset.

    Returns:
        ``True`` only when a query with a variant is configured and that
        variant equals ``"static"`` ignoring case; ``False`` when the query
        or its variant is ``None``, or for any other variant.
    """
    query = resource.settings.query
    # A missing query or a missing variant is reported as "not static"
    # rather than raised here; this function only answers the yes/no question.
    if query is None or query.variant is None:
        return False
    return query.variant.lower() == "static"
34+
35+
2936
def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
3037
tool_name = sanitize_tool_name(resource.name)
3138
retrieval_mode = resource.settings.retrieval_mode.lower()
@@ -40,34 +47,58 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool:
4047
def handle_semantic_search(
4148
tool_name: str, resource: AgentContextResourceConfig
4249
) -> StructuredTool:
50+
ensure_valid_fields(resource)
51+
52+
# needed for type checking
53+
assert resource.settings.query is not None
54+
assert resource.settings.query.variant is not None
55+
4356
retriever = ContextGroundingRetriever(
4457
index_name=resource.index_name,
4558
folder_path=resource.folder_path,
4659
number_of_results=resource.settings.result_count,
4760
)
4861

49-
class ContextInputSchemaModel(BaseModel):
50-
query: str = Field(
51-
..., description="The query to search for in the knowledge base"
52-
)
53-
5462
class ContextOutputSchemaModel(BaseModel):
5563
documents: list[Document] = Field(
5664
..., description="List of retrieved documents."
5765
)
5866

59-
input_model = ContextInputSchemaModel
6067
output_model = ContextOutputSchemaModel
6168

62-
@mockable(
63-
name=resource.name,
64-
description=resource.description,
65-
input_schema=input_model.model_json_schema(),
66-
output_schema=output_model.model_json_schema(),
67-
example_calls=[], # Examples cannot be provided for context.
68-
)
69-
async def context_tool_fn(query: str) -> dict[str, Any]:
70-
return {"documents": await retriever.ainvoke(query)}
69+
if is_static_query(resource):
70+
static_query_value = resource.settings.query.value
71+
assert static_query_value is not None
72+
input_model = None
73+
74+
@mockable(
75+
name=resource.name,
76+
description=resource.description,
77+
input_schema=input_model,
78+
output_schema=output_model.model_json_schema(),
79+
example_calls=[], # Examples cannot be provided for context.
80+
)
81+
async def context_tool_fn() -> dict[str, Any]:
82+
return {"documents": await retriever.ainvoke(static_query_value)}
83+
84+
else:
85+
# Dynamic query - requires query parameter
86+
class ContextInputSchemaModel(BaseModel):
87+
query: str = Field(
88+
..., description="The query to search for in the knowledge base"
89+
)
90+
91+
input_model = ContextInputSchemaModel
92+
93+
@mockable(
94+
name=resource.name,
95+
description=resource.description,
96+
input_schema=input_model.model_json_schema(),
97+
output_schema=output_model.model_json_schema(),
98+
example_calls=[], # Examples cannot be provided for context.
99+
)
100+
async def context_tool_fn(query: str) -> dict[str, Any]:
101+
return {"documents": await retriever.ainvoke(query)}
71102

72103
return StructuredToolWithOutputType(
73104
name=tool_name,
@@ -82,36 +113,69 @@ def handle_deep_rag(
82113
tool_name: str, resource: AgentContextResourceConfig
83114
) -> StructuredTool:
84115
ensure_valid_fields(resource)
116+
85117
# needed for type checking
86118
assert resource.settings.query is not None
87-
assert resource.settings.query.value is not None
119+
assert resource.settings.query.variant is not None
88120

89121
index_name = resource.index_name
90-
prompt = resource.settings.query.value
91122
if not resource.settings.citation_mode:
92123
raise ValueError("Citation mode is required for Deep RAG")
93124
citation_mode = CitationMode(resource.settings.citation_mode.value)
94125

95-
input_model = None
96126
output_model = DeepRagResponse
97127

98-
@mockable(
99-
name=resource.name,
100-
description=resource.description,
101-
input_schema=input_model,
102-
output_schema=output_model.model_json_schema(),
103-
example_calls=[], # Examples cannot be provided for context.
104-
)
105-
async def context_tool_fn() -> dict[str, Any]:
106-
# TODO: add glob pattern support
107-
return interrupt(
108-
CreateDeepRag(
109-
name=f"task-{uuid.uuid4()}",
110-
index_name=index_name,
111-
prompt=prompt,
112-
citation_mode=citation_mode,
128+
if is_static_query(resource):
129+
# Static query - no input parameter needed
130+
static_prompt = resource.settings.query.value
131+
assert static_prompt is not None
132+
input_model = None
133+
134+
@mockable(
135+
name=resource.name,
136+
description=resource.description,
137+
input_schema=input_model,
138+
output_schema=output_model.model_json_schema(),
139+
example_calls=[], # Examples cannot be provided for context.
140+
)
141+
async def context_tool_fn() -> dict[str, Any]:
142+
# TODO: add glob pattern support
143+
return interrupt(
144+
CreateDeepRag(
145+
name=f"task-{uuid.uuid4()}",
146+
index_name=index_name,
147+
prompt=static_prompt,
148+
citation_mode=citation_mode,
149+
)
150+
)
151+
152+
else:
153+
# Dynamic query - requires query parameter
154+
class DeepRagInputSchemaModel(BaseModel):
155+
query: str = Field(
156+
...,
157+
description="Describe the task: what to research across documents, what to synthesize, and how to cite sources",
113158
)
159+
160+
input_model = DeepRagInputSchemaModel
161+
162+
@mockable(
163+
name=resource.name,
164+
description=resource.description,
165+
input_schema=input_model.model_json_schema(),
166+
output_schema=output_model.model_json_schema(),
167+
example_calls=[], # Examples cannot be provided for context.
114168
)
169+
async def context_tool_fn(query: str) -> dict[str, Any]:
170+
# TODO: add glob pattern support
171+
return interrupt(
172+
CreateDeepRag(
173+
name=f"task-{uuid.uuid4()}",
174+
index_name=index_name,
175+
prompt=query,
176+
citation_mode=citation_mode,
177+
)
178+
)
115179

116180
return StructuredToolWithOutputType(
117181
name=tool_name,
@@ -129,11 +193,9 @@ def handle_batch_transform(
129193

130194
# needed for type checking
131195
assert resource.settings.query is not None
132-
assert resource.settings.query.value is not None
196+
assert resource.settings.query.variant is not None
133197

134198
index_name = resource.index_name
135-
prompt = resource.settings.query.value
136-
137199
index_folder_path = resource.folder_path
138200
if not resource.settings.web_search_grounding:
139201
raise ValueError("Web search grounding field is required for Batch Transform")
@@ -157,35 +219,82 @@ def handle_batch_transform(
157219
)
158220
)
159221

160-
class BatchTransformSchemaModel(BaseModel):
161-
destination_path: str = Field(
162-
...,
163-
description="The relative file path destination for the modified csv file",
164-
)
165-
166-
input_model = BatchTransformSchemaModel
167222
output_model = BatchTransformResponse
168223

169-
@mockable(
170-
name=resource.name,
171-
description=resource.description,
172-
input_schema=input_model.model_json_schema(),
173-
output_schema=output_model.model_json_schema(),
174-
example_calls=[], # Examples cannot be provided for context.
175-
)
176-
async def context_tool_fn(destination_path: str) -> dict[str, Any]:
177-
# TODO: storage_bucket_folder_path_prefix support
178-
return interrupt(
179-
CreateBatchTransform(
180-
name=f"task-{uuid.uuid4()}",
181-
index_name=index_name,
182-
prompt=prompt,
183-
destination_path=destination_path,
184-
index_folder_path=index_folder_path,
185-
enable_web_search_grounding=enable_web_search_grounding,
186-
output_columns=batch_transform_output_columns,
224+
input_model: Optional[Type[BaseModel]]
225+
226+
if is_static_query(resource):
227+
# Static query - only destination_path parameter needed
228+
static_prompt = resource.settings.query.value
229+
assert static_prompt is not None
230+
231+
class StaticBatchTransformSchemaModel(BaseModel):
232+
destination_path: str = Field(
233+
default="output.csv",
234+
description="The relative file path destination for the modified csv file",
235+
)
236+
237+
input_model = StaticBatchTransformSchemaModel
238+
239+
@mockable(
240+
name=resource.name,
241+
description=resource.description,
242+
input_schema=input_model.model_json_schema(),
243+
output_schema=output_model.model_json_schema(),
244+
example_calls=[], # Examples cannot be provided for context.
245+
)
246+
async def context_tool_fn(
247+
destination_path: str = "output.csv",
248+
) -> dict[str, Any]:
249+
# TODO: storage_bucket_folder_path_prefix support
250+
return interrupt(
251+
CreateBatchTransform(
252+
name=f"task-{uuid.uuid4()}",
253+
index_name=index_name,
254+
prompt=static_prompt,
255+
destination_path=destination_path,
256+
index_folder_path=index_folder_path,
257+
enable_web_search_grounding=enable_web_search_grounding,
258+
output_columns=batch_transform_output_columns,
259+
)
260+
)
261+
262+
else:
263+
# Dynamic query - requires both query and destination_path parameters
264+
class DynamicBatchTransformSchemaModel(BaseModel):
265+
query: str = Field(
266+
...,
267+
description="Describe the task for each row: what to analyze, what to extract, and how to populate the output columns",
187268
)
269+
destination_path: str = Field(
270+
default="output.csv",
271+
description="The relative file path destination for the modified csv file",
272+
)
273+
274+
input_model = DynamicBatchTransformSchemaModel
275+
276+
@mockable(
277+
name=resource.name,
278+
description=resource.description,
279+
input_schema=input_model.model_json_schema(),
280+
output_schema=output_model.model_json_schema(),
281+
example_calls=[], # Examples cannot be provided for context.
188282
)
283+
async def context_tool_fn(
284+
query: str, destination_path: str = "output.csv"
285+
) -> dict[str, Any]:
286+
# TODO: storage_bucket_folder_path_prefix support
287+
return interrupt(
288+
CreateBatchTransform(
289+
name=f"task-{uuid.uuid4()}",
290+
index_name=index_name,
291+
prompt=query,
292+
destination_path=destination_path,
293+
index_folder_path=index_folder_path,
294+
enable_web_search_grounding=enable_web_search_grounding,
295+
output_columns=batch_transform_output_columns,
296+
)
297+
)
189298

190299
return StructuredToolWithOutputType(
191300
name=tool_name,
@@ -199,5 +308,9 @@ async def context_tool_fn(destination_path: str) -> dict[str, Any]:
199308
def ensure_valid_fields(resource_config: AgentContextResourceConfig) -> None:
    """Validate the query settings of a context resource.

    A query value is only mandatory for the static variant; other variants
    are accepted without a configured value.

    Args:
        resource_config: Context resource configuration to validate.

    Raises:
        ValueError: If the query object is missing, if the query variant is
            missing or empty, or if the variant is static but no query value
            is set.
    """
    query = resource_config.settings.query
    if not query:
        raise ValueError("Query object is required")

    if not query.variant:
        raise ValueError("Query variant is required")

    # The unconditional "value is required" check was removed by this commit:
    # only the static variant carries its prompt in the configuration.
    if is_static_query(resource_config) and not query.value:
        raise ValueError("Static query requires a query value to be set")

0 commit comments

Comments
 (0)