Skip to content

Commit d03b7a9

Browse files
feat: add v1 SQL prompt with enriched outer tool description for DataFabric
Introduces a versioned prompt system (v0/v1) for the DataFabric inner sub-graph LLM, cherry-picking high-impact patterns from BIRD v7-hybrid: structured query planning, value resolution via ECP metadata, error taxonomy, and convergence rules. Enriches the outer tool description with actual entity names and descriptions from the agent config so the outer ReAct agent can ground its responses in real data instead of hallucinating generic examples. Also adds SQL normalization (trailing semicolon stripping) and multi-statement validation to the query tool. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c66e07e commit d03b7a9

13 files changed

Lines changed: 743 additions & 34 deletions

src/uipath_langchain/agent/tools/datafabric_query_tool.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@
99
from .base_uipath_structured_tool import BaseUiPathStructuredTool
1010

1111

12+
def _normalize_sql(sql: str) -> str:
13+
"""Normalize a generated SQL query before validation / execution.
14+
15+
Strips surrounding whitespace and removes a single trailing semicolon so the
16+
backend receives one bare statement even if the model emits canonical SQL
17+
terminators by default.
18+
"""
19+
normalized = sql.strip()
20+
if normalized.endswith(";"):
21+
normalized = normalized[:-1].rstrip()
22+
return normalized
23+
24+
1225
def _validate_sql(sql: str) -> str | None:
1326
"""Validate SQL syntax using sqlparse.
1427
@@ -18,6 +31,8 @@ def _validate_sql(sql: str) -> str | None:
1831
parsed = sqlparse.parse(sql)
1932
if not parsed or not parsed[0].tokens:
2033
return "Empty or unparseable SQL query"
34+
if len(parsed) != 1:
35+
return "Multiple SQL statements are not allowed"
2136
return None
2237

2338

@@ -37,9 +52,14 @@ async def _arun(
3752
**kwargs: Any,
3853
) -> Any:
3954
sql_query = kwargs.get("sql_query") or (args[0] if args else "")
40-
error = _validate_sql(sql_query)
55+
normalized_sql_query = _normalize_sql(sql_query)
56+
error = _validate_sql(normalized_sql_query)
4157
if error:
4258
raise ValueError(error)
59+
if "sql_query" in kwargs:
60+
kwargs["sql_query"] = normalized_sql_query
61+
elif args:
62+
args = (normalized_sql_query, *args[1:])
4363
return await super()._arun(
4464
*args, config=config, run_manager=run_manager, **kwargs
4565
)

src/uipath_langchain/agent/tools/datafabric_tool/datafabric_prompt_builder.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,25 @@
33
Converts raw Entity SDK objects into structured Pydantic models (SQLContext),
44
then formats them as text for system prompt injection.
55
6-
Note: This module will go through refinements as we better understand
7-
the tool's performance characteristics and scoring in production.
6+
The SQL strategy section (``sql_expert_system_prompt``) is rendered from a
7+
versioned prompt template via the ``prompts`` package. ``SQL_CONSTRAINTS`` is
8+
appended verbatim — the system prompt should describe strategy only, not
9+
backend deny-lists.
810
"""
911

1012
import logging
1113

1214
from uipath.platform.entities import Entity
1315

14-
from .datafabric_prompts import SQL_CONSTRAINTS, SQL_EXPERT_SYSTEM_PROMPT
16+
from .datafabric_prompts import SQL_CONSTRAINTS
1517
from .models import (
1618
EntitySchema,
1719
EntitySQLContext,
1820
FieldSchema,
1921
QueryPattern,
2022
SQLContext,
2123
)
24+
from .prompts import build_prompt_context, get_prompt_version
2225

2326
logger = logging.getLogger(__name__)
2427

@@ -101,12 +104,30 @@ def build_sql_context(
101104
entities: list[Entity],
102105
resource_description: str = "",
103106
base_system_prompt: str = "",
107+
prompt_version: str | None = None,
104108
) -> SQLContext:
105-
"""Build the full SQL context from entities, prompts, and constraints."""
109+
"""Build the full SQL context from entities, prompts, and constraints.
110+
111+
Args:
112+
entities: Resolved Data Fabric entities.
113+
resource_description: Optional free-text description folded into the
114+
rendered prompt as ``## Domain Guidance``.
115+
base_system_prompt: Optional outer-agent system prompt prepended as
116+
``## Agent Instructions``.
117+
prompt_version: Optional version key (e.g. ``"v0"``, ``"v1"``).
118+
Defaults to the registry's default.
119+
"""
120+
version = get_prompt_version(prompt_version)
121+
ctx = build_prompt_context(
122+
entities=entities,
123+
resource_description=resource_description,
124+
)
125+
rendered_prompt = version.render(ctx)
126+
106127
return SQLContext(
107128
base_system_prompt=base_system_prompt or None,
108-
resource_description=resource_description or None,
109-
sql_expert_system_prompt=SQL_EXPERT_SYSTEM_PROMPT,
129+
resource_description=None,
130+
sql_expert_system_prompt=rendered_prompt,
110131
constraints=SQL_CONSTRAINTS,
111132
entity_contexts=[build_entity_context(e) for e in entities],
112133
)
@@ -174,22 +195,31 @@ def build(
174195
entities: list[Entity],
175196
resource_description: str = "",
176197
base_system_prompt: str = "",
198+
prompt_version: str | None = None,
177199
) -> str:
178200
"""Build the full SQL prompt text for the inner sub-graph LLM.
179201
180-
Combines agent system prompt, resource description, SQL guidelines,
181-
constraints, entity schemas, and query patterns into a single prompt string.
202+
Combines agent system prompt, the rendered SQL strategy prompt, the
203+
Calcite constraint deny-list, and entity schemas + query patterns.
182204
183205
Args:
184206
entities: List of Entity objects with schema information.
185-
resource_description: Optional description of the resource/entity set.
207+
resource_description: Optional description of the resource/entity set;
208+
folded into the rendered prompt as domain guidance.
186209
base_system_prompt: Optional system prompt from the outer agent.
210+
prompt_version: Optional version key (e.g. ``"v0"``, ``"v1"``).
211+
Defaults to the registry's default.
187212
188213
Returns:
189214
Formatted prompt string for the inner LLM system message.
190215
"""
191216
if not entities:
192217
return ""
193218

194-
ctx = build_sql_context(entities, resource_description, base_system_prompt)
219+
ctx = build_sql_context(
220+
entities,
221+
resource_description,
222+
base_system_prompt,
223+
prompt_version=prompt_version,
224+
)
195225
return format_sql_context(ctx)

src/uipath_langchain/agent/tools/datafabric_tool/datafabric_prompts.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
6. Use the exact table and column names from the provided schema
1919
7. For financial values (salary, price, etc.), use ROUND() function
2020
8. Handle NULL values appropriately with COALESCE() or IFNULL()
21+
9. Do NOT terminate the SQL query with a semicolon
2122
2223
SUPPORTED SCENARIOS (Use these patterns):
2324
@@ -28,8 +29,7 @@
2829
- IS NULL/IS NOT NULL: WHERE deleted_at IS NULL
2930
3031
2. Multi-Entity Joins (≤4 tables):
31-
- LEFT JOIN chains (up to 4 tables): SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id
32-
- Null-preserving semantics
32+
- INNER JOIN chains (up to 4 tables): SELECT o.id, c.name FROM Order o INNER JOIN Customer c ON o.customer_id = c.id
3333
3434
3. Predicate Distribution:
3535
- Table-scoped predicates: WHERE c.country='IN' AND o.total>1000
@@ -94,7 +94,7 @@
9494
- Temporary objects or transactions
9595
9696
5. UNSUPPORTED_CONSTRUCTS - Joins:
97-
- RIGHT JOIN, FULL OUTER JOIN, CROSS JOIN
97+
- LEFT JOIN, RIGHT JOIN, FULL OUTER JOIN, CROSS JOIN
9898
- Non-equi join conditions: ON a.created_at > b.created_at
9999
- Self-joins
100100
- LATERAL/APPLY
@@ -156,14 +156,12 @@
156156
- SELECT id, name FROM Customer WHERE deleted_at IS NULL
157157
158158
### 2. Multi-Entity Joins (≤4 adapters)
159-
- LEFT JOIN chains via entity model (up to 4 tables)
160-
- Optional adapters pruned
159+
- INNER JOIN chains via entity model (up to 4 tables)
161160
- Shared intermediates
162-
- Null-preserving semantics
163161
164162
**Examples:**
165-
- SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id
166-
- Fields spanning 3-4 adapters with proper LEFT JOIN chains
163+
- SELECT o.id, c.name FROM Order o INNER JOIN Customer c ON o.customer_id = c.id
164+
- Fields spanning 3-4 adapters with proper INNER JOIN chains
167165
168166
### 3. Predicate Distribution & Pushdown
169167
- Adapter-scoped predicates pushed down
@@ -255,7 +253,7 @@
255253
- Common Table Expressions (WITH/CTE)
256254
- Window functions (ROW_NUMBER, RANK, PARTITION BY)
257255
- Self-joins
258-
- RIGHT JOIN or FULL OUTER JOIN (only LEFT JOIN supported)
256+
- LEFT JOIN, RIGHT JOIN, FULL OUTER JOIN (only INNER JOIN supported)
259257
- CROSS JOIN
260258
261259
**Examples:**
@@ -277,6 +275,7 @@
277275
278276
### 4. ADVANCED_JOINS
279277
- More than 4 tables in JOIN chain
278+
- LEFT JOIN
280279
- RIGHT JOIN
281280
- FULL OUTER JOIN
282281
- CROSS JOIN
@@ -337,7 +336,7 @@
337336
338337
1. **ALWAYS use explicit column names** - Never use SELECT *
339338
2. **Use COUNT(column_name)** - Never use COUNT(*)
340-
3. **Only LEFT JOIN** - No RIGHT JOIN, FULL OUTER JOIN, or CROSS JOIN
339+
3. **Only INNER JOIN** - No LEFT JOIN, RIGHT JOIN, FULL OUTER JOIN, or CROSS JOIN
341340
4. **Maximum 4 tables** - No more than 4 tables in a JOIN chain
342341
5. **No subqueries** - No subqueries in any clause
343342
6. **No CTEs** - No WITH clauses

src/uipath_langchain/agent/tools/datafabric_tool/datafabric_subgraph.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
Implements a self-contained ReAct loop where an inner LLM translates
44
natural-language questions into SQL, executes them via ``execute_sql``,
55
and retries on errors — all within a single outer tool call.
6+
7+
On a successful SQL execution the graph short-circuits straight to END
8+
rather than invoking the LLM again to reformat the records into prose;
9+
the outer agent receives the raw tool result and produces the final
10+
natural-language answer. Errors still loop back to the inner LLM so the
11+
retry path remains intact.
612
"""
713

814
import asyncio
@@ -37,6 +43,7 @@ class DataFabricSubgraphState(BaseModel):
3743

3844
messages: Annotated[list[AnyMessage], add_messages] = []
3945
iteration_count: int = 0
46+
last_tool_success: bool = False
4047

4148

4249
class QueryExecutor:
@@ -104,7 +111,7 @@ def __init__(
104111
graph.add_conditional_edges(
105112
"inner_llm", self.router, ["inner_tool", "termination", END]
106113
)
107-
graph.add_edge("inner_tool", "inner_llm")
114+
graph.add_conditional_edges("inner_tool", self.tool_router, ["inner_llm", END])
108115
graph.add_edge("termination", END)
109116
self.compiled_graph: CompiledStateGraph[Any] = graph.compile()
110117

@@ -120,16 +127,19 @@ async def tool_node(self, state: DataFabricSubgraphState) -> dict[str, Any]:
120127
if not isinstance(last, AIMessage) or not last.tool_calls:
121128
return {"iteration_count": state.iteration_count}
122129

123-
tool_messages = await asyncio.gather(
130+
results = await asyncio.gather(
124131
*[self._execute_tool_call(tc) for tc in last.tool_calls]
125132
)
133+
tool_messages = [msg for msg, _ in results]
134+
all_succeeded = bool(results) and all(success for _, success in results)
126135
return {
127-
"messages": list(tool_messages),
136+
"messages": tool_messages,
128137
"iteration_count": state.iteration_count + len(last.tool_calls),
138+
"last_tool_success": all_succeeded,
129139
}
130140

131-
async def _execute_tool_call(self, tool_call: ToolCall) -> ToolMessage:
132-
"""Execute a single tool call and wrap the result."""
141+
async def _execute_tool_call(self, tool_call: ToolCall) -> tuple[ToolMessage, bool]:
142+
"""Execute a single tool call and report whether it succeeded."""
133143
args = tool_call.get("args", {})
134144
try:
135145
result = await self._execute_sql_tool.ainvoke(args)
@@ -140,10 +150,18 @@ async def _execute_tool_call(self, tool_call: ToolCall) -> ToolMessage:
140150
"error": str(e),
141151
"sql_query": args.get("sql_query", ""),
142152
}
143-
return ToolMessage(
144-
content=str(result),
145-
tool_call_id=tool_call["id"],
146-
name="execute_sql",
153+
succeeded = (
154+
isinstance(result, dict)
155+
and not result.get("error")
156+
and result.get("total_count", 0) > 0
157+
)
158+
return (
159+
ToolMessage(
160+
content=str(result),
161+
tool_call_id=tool_call["id"],
162+
name="execute_sql",
163+
),
164+
succeeded,
147165
)
148166

149167
async def termination_node(self, state: DataFabricSubgraphState) -> dict[str, Any]:
@@ -161,14 +179,26 @@ async def termination_node(self, state: DataFabricSubgraphState) -> dict[str, An
161179
}
162180

163181
def router(self, state: DataFabricSubgraphState) -> str:
164-
"""Route to tool, termination, or END based on state."""
182+
"""Route from ``inner_llm`` to tool, termination, or END."""
165183
last = state.messages[-1] if state.messages else None
166184
if isinstance(last, AIMessage) and last.tool_calls:
167185
if state.iteration_count < self._max_iterations:
168186
return "inner_tool"
169187
return "termination"
170188
return END
171189

190+
def tool_router(self, state: DataFabricSubgraphState) -> str:
191+
"""Route from ``inner_tool``: short-circuit on success, retry on error.
192+
193+
Skips the redundant LLM call that would otherwise reformat a
194+
successful SQL result into prose — the outer agent receives the
195+
raw tool output and produces the final natural-language answer.
196+
Errors loop back to ``inner_llm`` so the retry path is preserved.
197+
"""
198+
if state.last_tool_success:
199+
return END
200+
return "inner_llm"
201+
172202
def _create_execute_sql_tool(
173203
self,
174204
entities_service: EntitiesService,

0 commit comments

Comments
 (0)