Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/uipath_langchain/agent/tools/datafabric_query_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@
from .base_uipath_structured_tool import BaseUiPathStructuredTool


def _normalize_sql(sql: str) -> str:
"""Normalize a generated SQL query before validation / execution.

Strips surrounding whitespace and removes a single trailing semicolon so the
backend receives one bare statement even if the model emits canonical SQL
terminators by default.
"""
normalized = sql.strip()
if normalized.endswith(";"):
normalized = normalized[:-1].rstrip()
return normalized


def _validate_sql(sql: str) -> str | None:
"""Validate SQL syntax using sqlparse.

Expand All @@ -18,6 +31,8 @@ def _validate_sql(sql: str) -> str | None:
parsed = sqlparse.parse(sql)
if not parsed or not parsed[0].tokens:
return "Empty or unparseable SQL query"
if len(parsed) != 1:
return "Multiple SQL statements are not allowed"
return None


Expand All @@ -37,9 +52,14 @@ async def _arun(
**kwargs: Any,
) -> Any:
sql_query = kwargs.get("sql_query") or (args[0] if args else "")
error = _validate_sql(sql_query)
normalized_sql_query = _normalize_sql(sql_query)
error = _validate_sql(normalized_sql_query)
if error:
raise ValueError(error)
if "sql_query" in kwargs:
kwargs["sql_query"] = normalized_sql_query
elif args:
args = (normalized_sql_query, *args[1:])
return await super()._arun(
*args, config=config, run_manager=run_manager, **kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@
Converts raw Entity SDK objects into structured Pydantic models (SQLContext),
then formats them as text for system prompt injection.

Note: This module will go through refinements as we better understand
the tool's performance characteristics and scoring in production.
The SQL strategy section (``sql_expert_system_prompt``) is rendered from a
versioned prompt template via the ``prompts`` package. ``SQL_CONSTRAINTS`` is
appended verbatim — the system prompt should describe strategy only, not
backend deny-lists.
"""

import logging

from uipath.platform.entities import Entity

from .datafabric_prompts import SQL_CONSTRAINTS, SQL_EXPERT_SYSTEM_PROMPT
from .datafabric_prompts import SQL_CONSTRAINTS
from .models import (
EntitySchema,
EntitySQLContext,
FieldSchema,
QueryPattern,
SQLContext,
)
from .prompts import build_prompt_context, get_prompt_version

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,12 +104,30 @@ def build_sql_context(
entities: list[Entity],
resource_description: str = "",
base_system_prompt: str = "",
prompt_version: str | None = None,
) -> SQLContext:
"""Build the full SQL context from entities, prompts, and constraints."""
"""Build the full SQL context from entities, prompts, and constraints.

Args:
entities: Resolved Data Fabric entities.
resource_description: Optional free-text description folded into the
rendered prompt as ``## Domain Guidance``.
base_system_prompt: Optional outer-agent system prompt prepended as
``## Agent Instructions``.
prompt_version: Optional version key (e.g. ``"v0"``, ``"v1"``).
Defaults to the registry's default.
"""
version = get_prompt_version(prompt_version)
ctx = build_prompt_context(
entities=entities,
resource_description=resource_description,
)
rendered_prompt = version.render(ctx)

return SQLContext(
base_system_prompt=base_system_prompt or None,
resource_description=resource_description or None,
sql_expert_system_prompt=SQL_EXPERT_SYSTEM_PROMPT,
resource_description=None,
sql_expert_system_prompt=rendered_prompt,
constraints=SQL_CONSTRAINTS,
entity_contexts=[build_entity_context(e) for e in entities],
)
Expand Down Expand Up @@ -174,22 +195,31 @@ def build(
entities: list[Entity],
resource_description: str = "",
base_system_prompt: str = "",
prompt_version: str | None = None,
) -> str:
"""Build the full SQL prompt text for the inner sub-graph LLM.

Combines agent system prompt, resource description, SQL guidelines,
constraints, entity schemas, and query patterns into a single prompt string.
Combines agent system prompt, the rendered SQL strategy prompt, the
Calcite constraint deny-list, and entity schemas + query patterns.

Args:
entities: List of Entity objects with schema information.
resource_description: Optional description of the resource/entity set.
resource_description: Optional description of the resource/entity set;
folded into the rendered prompt as domain guidance.
base_system_prompt: Optional system prompt from the outer agent.
prompt_version: Optional version key (e.g. ``"v0"``, ``"v1"``).
Defaults to the registry's default.

Returns:
Formatted prompt string for the inner LLM system message.
"""
if not entities:
return ""

ctx = build_sql_context(entities, resource_description, base_system_prompt)
ctx = build_sql_context(
entities,
resource_description,
base_system_prompt,
prompt_version=prompt_version,
)
return format_sql_context(ctx)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
6. Use the exact table and column names from the provided schema
7. For financial values (salary, price, etc.), use ROUND() function
8. Handle NULL values appropriately with COALESCE() or IFNULL()
9. Do NOT terminate the SQL query with a semicolon

SUPPORTED SCENARIOS (Use these patterns):

Expand All @@ -28,8 +29,7 @@
- IS NULL/IS NOT NULL: WHERE deleted_at IS NULL

2. Multi-Entity Joins (≤4 tables):
- LEFT JOIN chains (up to 4 tables): SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id
- Null-preserving semantics
- INNER JOIN chains (up to 4 tables): SELECT o.id, c.name FROM Order o INNER JOIN Customer c ON o.customer_id = c.id

3. Predicate Distribution:
- Table-scoped predicates: WHERE c.country='IN' AND o.total>1000
Expand Down Expand Up @@ -94,7 +94,7 @@
- Temporary objects or transactions

5. UNSUPPORTED_CONSTRUCTS - Joins:
- RIGHT JOIN, FULL OUTER JOIN, CROSS JOIN
- LEFT JOIN, RIGHT JOIN, FULL OUTER JOIN, CROSS JOIN
- Non-equi join conditions: ON a.created_at > b.created_at
- Self-joins
- LATERAL/APPLY
Expand Down Expand Up @@ -156,14 +156,12 @@
- SELECT id, name FROM Customer WHERE deleted_at IS NULL

### 2. Multi-Entity Joins (≤4 adapters)
- LEFT JOIN chains via entity model (up to 4 tables)
- Optional adapters pruned
- INNER JOIN chains via entity model (up to 4 tables)
- Shared intermediates
- Null-preserving semantics

**Examples:**
- SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id
- Fields spanning 3-4 adapters with proper LEFT JOIN chains
- SELECT o.id, c.name FROM Order o INNER JOIN Customer c ON o.customer_id = c.id
- Fields spanning 3-4 adapters with proper INNER JOIN chains

### 3. Predicate Distribution & Pushdown
- Adapter-scoped predicates pushed down
Expand Down Expand Up @@ -255,7 +253,7 @@
- Common Table Expressions (WITH/CTE)
- Window functions (ROW_NUMBER, RANK, PARTITION BY)
- Self-joins
- RIGHT JOIN or FULL OUTER JOIN (only LEFT JOIN supported)
- LEFT JOIN, RIGHT JOIN, FULL OUTER JOIN (only INNER JOIN supported)
- CROSS JOIN

**Examples:**
Expand All @@ -277,6 +275,7 @@

### 4. ADVANCED_JOINS
- More than 4 tables in JOIN chain
- LEFT JOIN
- RIGHT JOIN
- FULL OUTER JOIN
- CROSS JOIN
Expand Down Expand Up @@ -337,7 +336,7 @@

1. **ALWAYS use explicit column names** - Never use SELECT *
2. **Use COUNT(column_name)** - Never use COUNT(*)
3. **Only LEFT JOIN** - No RIGHT JOIN, FULL OUTER JOIN, or CROSS JOIN
3. **Only INNER JOIN** - No LEFT JOIN, RIGHT JOIN, FULL OUTER JOIN, or CROSS JOIN
4. **Maximum 4 tables** - No more than 4 tables in a JOIN chain
5. **No subqueries** - No subqueries in any clause
6. **No CTEs** - No WITH clauses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
Implements a self-contained ReAct loop where an inner LLM translates
natural-language questions into SQL, executes them via ``execute_sql``,
and retries on errors — all within a single outer tool call.

On a successful SQL execution the graph short-circuits straight to END
rather than invoking the LLM again to reformat the records into prose;
the outer agent receives the raw tool result and produces the final
natural-language answer. Errors still loop back to the inner LLM so the
retry path remains intact.
"""

import asyncio
Expand Down Expand Up @@ -37,6 +43,7 @@ class DataFabricSubgraphState(BaseModel):

messages: Annotated[list[AnyMessage], add_messages] = []
iteration_count: int = 0
last_tool_success: bool = False


class QueryExecutor:
Expand Down Expand Up @@ -104,7 +111,7 @@ def __init__(
graph.add_conditional_edges(
"inner_llm", self.router, ["inner_tool", "termination", END]
)
graph.add_edge("inner_tool", "inner_llm")
graph.add_conditional_edges("inner_tool", self.tool_router, ["inner_llm", END])
graph.add_edge("termination", END)
self.compiled_graph: CompiledStateGraph[Any] = graph.compile()

Expand All @@ -120,16 +127,19 @@ async def tool_node(self, state: DataFabricSubgraphState) -> dict[str, Any]:
if not isinstance(last, AIMessage) or not last.tool_calls:
return {"iteration_count": state.iteration_count}

tool_messages = await asyncio.gather(
results = await asyncio.gather(
*[self._execute_tool_call(tc) for tc in last.tool_calls]
)
tool_messages = [msg for msg, _ in results]
all_succeeded = bool(results) and all(success for _, success in results)
return {
"messages": list(tool_messages),
"messages": tool_messages,
"iteration_count": state.iteration_count + len(last.tool_calls),
"last_tool_success": all_succeeded,
}

async def _execute_tool_call(self, tool_call: ToolCall) -> ToolMessage:
"""Execute a single tool call and wrap the result."""
async def _execute_tool_call(self, tool_call: ToolCall) -> tuple[ToolMessage, bool]:
"""Execute a single tool call and report whether it succeeded."""
args = tool_call.get("args", {})
try:
result = await self._execute_sql_tool.ainvoke(args)
Expand All @@ -140,10 +150,18 @@ async def _execute_tool_call(self, tool_call: ToolCall) -> ToolMessage:
"error": str(e),
"sql_query": args.get("sql_query", ""),
}
return ToolMessage(
content=str(result),
tool_call_id=tool_call["id"],
name="execute_sql",
succeeded = (
isinstance(result, dict)
and not result.get("error")
and result.get("total_count", 0) > 0
)
return (
ToolMessage(
content=str(result),
tool_call_id=tool_call["id"],
name="execute_sql",
),
succeeded,
)

async def termination_node(self, state: DataFabricSubgraphState) -> dict[str, Any]:
Expand All @@ -161,14 +179,26 @@ async def termination_node(self, state: DataFabricSubgraphState) -> dict[str, An
}

def router(self, state: DataFabricSubgraphState) -> str:
    """Route from ``inner_llm`` to tool, termination, or END.

    Looks only at the most recent message in ``state.messages``:

    * If it is an ``AIMessage`` carrying tool calls, go to ``"inner_tool"``
      while ``state.iteration_count`` is below ``self._max_iterations``,
      otherwise go to ``"termination"``.
    * In every other case (no messages, or a last message without tool
      calls) finish at ``END``.
    """
    latest = state.messages[-1] if state.messages else None
    # Anything that is not a tool-calling AI message ends the sub-graph.
    if not (isinstance(latest, AIMessage) and latest.tool_calls):
        return END
    within_budget = state.iteration_count < self._max_iterations
    return "inner_tool" if within_budget else "termination"

def tool_router(self, state: DataFabricSubgraphState) -> str:
    """Route from ``inner_tool``: short-circuit on success, retry on error.

    When ``state.last_tool_success`` is set the sub-graph finishes at ``END``,
    skipping the extra LLM call that would otherwise reformat a successful
    SQL result into prose — the outer agent receives the raw tool output and
    produces the final natural-language answer. Otherwise control returns to
    ``"inner_llm"`` so the retry path is preserved.
    """
    return END if state.last_tool_success else "inner_llm"

def _create_execute_sql_tool(
self,
entities_service: EntitiesService,
Expand Down
Loading
Loading