Skip to content

Commit 275157e

Browse files
committed
feat: enhance query history management and output handling in run pipeline
- Updated the run pipeline function to incorporate subgraph outputs, allowing for more detailed query history tracking. - Improved handling of sub-query metadata and SQL drafts, ensuring accurate reporting of execution details. - Added functionality to display used datasources in verbose mode, enhancing user feedback during execution. - Refactored query history construction to streamline data collection and improve clarity in output presentation.
1 parent 05734e2 commit 275157e

3 files changed

Lines changed: 61 additions & 15 deletions

File tree

packages/cli/src/nl2sql_cli/commands/run.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,61 @@ def run_pipeline(
5959
result_refs = final_state.get("result_refs", {})
6060
sub_queries = final_state.get("sub_queries", [])
6161
sq_map = {sq.id: sq for sq in sub_queries}
62-
62+
6363
query_history = []
64-
for sq_id, result_id in result_refs.items():
65-
sq = sq_map.get(sq_id)
66-
if sq:
67-
frame = ctx.result_store.get(result_id)
68-
metadata = ctx.result_store.get_metadata(result_id)
69-
query_history.append({
70-
"sub_query": sq.intent,
71-
"datasource_id": metadata.get("datasource_id", sq.datasource_id),
72-
"execution": {
64+
subgraph_outputs = final_state.get("subgraph_outputs") or {}
65+
if subgraph_outputs:
66+
for _subgraph_id, output in subgraph_outputs.items():
67+
if isinstance(output, dict):
68+
sub_query = output.get("sub_query")
69+
sql_draft = output.get("sql_draft")
70+
else:
71+
sub_query = getattr(output, "sub_query", None)
72+
sql_draft = getattr(output, "sql_draft", None)
73+
74+
if not sub_query:
75+
continue
76+
77+
if isinstance(sub_query, dict):
78+
sq_id = sub_query.get("id")
79+
ds_id = sub_query.get("datasource_id")
80+
intent = sub_query.get("intent")
81+
else:
82+
sq_id = getattr(sub_query, "id", None)
83+
ds_id = getattr(sub_query, "datasource_id", None)
84+
intent = getattr(sub_query, "intent", None)
85+
86+
entry: Dict[str, Any] = {
87+
"sub_query": intent,
88+
"datasource_id": ds_id,
89+
"sql": sql_draft,
90+
}
91+
92+
result_id = result_refs.get(sq_id)
93+
if result_id:
94+
frame = ctx.result_store.get(result_id)
95+
metadata = ctx.result_store.get_metadata(result_id)
96+
entry["datasource_id"] = metadata.get("datasource_id", ds_id)
97+
entry["execution"] = {
7398
"row_count": frame.row_count,
7499
"columns": [c.name for c in frame.columns],
75-
},
76-
})
100+
}
101+
102+
query_history.append(entry)
103+
else:
104+
for sq_id, result_id in result_refs.items():
105+
sq = sq_map.get(sq_id)
106+
if sq:
107+
frame = ctx.result_store.get(result_id)
108+
metadata = ctx.result_store.get_metadata(result_id)
109+
query_history.append({
110+
"sub_query": sq.intent,
111+
"datasource_id": metadata.get("datasource_id", sq.datasource_id),
112+
"execution": {
113+
"row_count": frame.row_count,
114+
"columns": [c.name for c in frame.columns],
115+
},
116+
})
77117

78118
if config.verbose:
79119
reasoning = final_state.get("reasoning", [])
@@ -91,14 +131,17 @@ def run_pipeline(
91131
pass
92132

93133
if query_history:
134+
datasources_used = sorted({item.get("datasource_id") for item in query_history if item.get("datasource_id")})
135+
if datasources_used:
136+
presenter.print_info(f"Datasources used: {', '.join(datasources_used)}")
137+
94138
for item in query_history:
95139
ds = item.get("datasource_id", "Unknown")
96-
ds_type = item.get("datasource_type", "Unknown")
97140
sub_query = item.get("sub_query")
98141
sql = item.get("sql")
99-
142+
100143
if sql:
101-
header = f"[bold]Datasource: {ds} ({ds_type})[/bold]"
144+
header = f"[bold]Datasource: {ds}[/bold]"
102145
if sub_query:
103146
header = f"[bold]Sub-Query: {sub_query}[/bold]\n" + header
104147
presenter.print_sql(f"{header}\n\n{sql}")

packages/core/src/nl2sql/pipeline/graph_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def _wrapper(state_dict: dict) -> Dict[str, Any]:
101101

102102
executor_response = returned_state.executor_response
103103
planner_response = returned_state.ast_planner_response
104+
generator_response = returned_state.generator_response
104105
sub_reasoning = returned_state.reasoning
105106
artifact_refs: Dict[str, Any] = {}
106107
artifact = executor_response.artifact
@@ -114,6 +115,7 @@ def _wrapper(state_dict: dict) -> Dict[str, Any]:
114115
subgraph_id=subgraph_id,
115116
retry_count=retry_count,
116117
plan=planner_response.plan,
118+
sql_draft=generator_response.sql_draft if generator_response else None,
117119
artifact=artifact,
118120
errors=returned_state.errors,
119121
reasoning=sub_reasoning,

packages/core/src/nl2sql/pipeline/subgraphs/schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class SubgraphOutput(BaseModel):
1616
subgraph_name: Optional[str] = None
1717
retry_count: int = 0
1818
plan: Optional[PlanModel] = None
19+
sql_draft: Optional[str] = None
1920
artifact: Optional[ArtifactRef] = None
2021
errors: List[PipelineError] = Field(default_factory=list)
2122
reasoning: List[Dict[str, Any]] = Field(default_factory=list)

0 commit comments

Comments
 (0)