33Implements a self-contained ReAct loop where an inner LLM translates
44natural-language questions into SQL, executes them via ``execute_sql``,
55and retries on errors — all within a single outer tool call.
6+
7+ On a successful SQL execution the graph short-circuits straight to END
8+ rather than invoking the LLM again to reformat the records into prose;
9+ the outer agent receives the raw tool result and produces the final
10+ natural-language answer. Errors still loop back to the inner LLM so the
11+ retry path remains intact.
612"""
713
814import asyncio
@@ -37,6 +43,7 @@ class DataFabricSubgraphState(BaseModel):
3743
3844 messages : Annotated [list [AnyMessage ], add_messages ] = []
3945 iteration_count : int = 0
46+ last_tool_success : bool = False
4047
4148
4249class QueryExecutor :
@@ -104,7 +111,7 @@ def __init__(
104111 graph .add_conditional_edges (
105112 "inner_llm" , self .router , ["inner_tool" , "termination" , END ]
106113 )
107- graph .add_edge ("inner_tool" , "inner_llm" )
114+ graph .add_conditional_edges ("inner_tool" , self . tool_router , [ "inner_llm" , END ] )
108115 graph .add_edge ("termination" , END )
109116 self .compiled_graph : CompiledStateGraph [Any ] = graph .compile ()
110117
@@ -120,16 +127,19 @@ async def tool_node(self, state: DataFabricSubgraphState) -> dict[str, Any]:
120127 if not isinstance (last , AIMessage ) or not last .tool_calls :
121128 return {"iteration_count" : state .iteration_count }
122129
123- tool_messages = await asyncio .gather (
130+ results = await asyncio .gather (
124131 * [self ._execute_tool_call (tc ) for tc in last .tool_calls ]
125132 )
133+ tool_messages = [msg for msg , _ in results ]
134+ all_succeeded = bool (results ) and all (success for _ , success in results )
126135 return {
127- "messages" : list ( tool_messages ) ,
136+ "messages" : tool_messages ,
128137 "iteration_count" : state .iteration_count + len (last .tool_calls ),
138+ "last_tool_success" : all_succeeded ,
129139 }
130140
131- async def _execute_tool_call (self , tool_call : ToolCall ) -> ToolMessage :
132- """Execute a single tool call and wrap the result ."""
141+ async def _execute_tool_call (self , tool_call : ToolCall ) -> tuple [ ToolMessage , bool ] :
142+ """Execute a single tool call and report whether it succeeded ."""
133143 args = tool_call .get ("args" , {})
134144 try :
135145 result = await self ._execute_sql_tool .ainvoke (args )
@@ -140,10 +150,18 @@ async def _execute_tool_call(self, tool_call: ToolCall) -> ToolMessage:
140150 "error" : str (e ),
141151 "sql_query" : args .get ("sql_query" , "" ),
142152 }
143- return ToolMessage (
144- content = str (result ),
145- tool_call_id = tool_call ["id" ],
146- name = "execute_sql" ,
153+ succeeded = (
154+ isinstance (result , dict )
155+ and not result .get ("error" )
156+ and result .get ("total_count" , 0 ) > 0
157+ )
158+ return (
159+ ToolMessage (
160+ content = str (result ),
161+ tool_call_id = tool_call ["id" ],
162+ name = "execute_sql" ,
163+ ),
164+ succeeded ,
147165 )
148166
149167 async def termination_node (self , state : DataFabricSubgraphState ) -> dict [str , Any ]:
@@ -161,14 +179,26 @@ async def termination_node(self, state: DataFabricSubgraphState) -> dict[str, An
161179 }
162180
163181 def router (self , state : DataFabricSubgraphState ) -> str :
164- """Route to tool, termination, or END based on state ."""
182+ """Route from ``inner_llm`` to tool, termination, or END."""
165183 last = state .messages [- 1 ] if state .messages else None
166184 if isinstance (last , AIMessage ) and last .tool_calls :
167185 if state .iteration_count < self ._max_iterations :
168186 return "inner_tool"
169187 return "termination"
170188 return END
171189
190+ def tool_router (self , state : DataFabricSubgraphState ) -> str :
191+ """Route from ``inner_tool``: short-circuit on success, retry on error.
192+
193+ Skips the redundant LLM call that would otherwise reformat a
194+ successful SQL result into prose — the outer agent receives the
195+ raw tool output and produces the final natural-language answer.
196+ Errors loop back to ``inner_llm`` so the retry path is preserved.
197+ """
198+ if state .last_tool_success :
199+ return END
200+ return "inner_llm"
201+
172202 def _create_execute_sql_tool (
173203 self ,
174204 entities_service : EntitiesService ,
0 commit comments