Skip to content

Commit a3a5c22

Browse files
committed
feat: refactor schema handling and enhance column/table management
- Removed deprecated Column and Table classes from the schema module, introducing ColumnRef for better reference management. - Updated the SchemaRetrieverNode to build tables from schema snapshots, incorporating relationships and metadata. - Enhanced the LogicalValidatorNode to enforce join relationships and validate against column statistics. - Improved error handling and logging in various pipeline nodes for better traceability. - Refactored imports across the codebase to utilize the new schema structure, ensuring consistency and clarity.
1 parent bff7c64 commit a3a5c22

16 files changed

Lines changed: 548 additions & 152 deletions

File tree

packages/adapter-sdk/src/nl2sql_adapter_sdk/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from .contracts import AdapterRequest, ResultError, ResultFrame
55
from .protocols import DatasourceAdapterProtocol
66
from .schema import (
7-
Column,
8-
Table,
97
TableRef,
108
ColumnStatistics,
119
ColumnMetadata,
@@ -16,6 +14,7 @@
1614
SchemaContract,
1715
SchemaMetadata,
1816
SchemaSnapshot,
17+
ColumnRef
1918
)
2019

2120
__all__ = [
@@ -24,8 +23,6 @@
2423
"ResultError",
2524
"ResultFrame",
2625
"DatasourceAdapterProtocol",
27-
"Column",
28-
"Table",
2926
"TableRef",
3027
"ColumnStatistics",
3128
"ColumnMetadata",
@@ -36,4 +33,5 @@
3633
"SchemaContract",
3734
"SchemaMetadata",
3835
"SchemaSnapshot",
36+
"ColumnRef",
3937
]

packages/adapter-sdk/src/nl2sql_adapter_sdk/schema.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,10 @@
88
JsonValue = Union[Scalar, List[Scalar], Dict[str, Scalar]]
99

1010

11-
class Column(BaseModel):
12-
"""Lightweight column schema for routing/planning."""
13-
14-
name: str
15-
type: Optional[str] = None
16-
17-
model_config = ConfigDict(extra="allow")
18-
19-
20-
class Table(BaseModel):
21-
"""Lightweight table schema for routing/planning."""
22-
23-
name: str
24-
columns: List[Column] = Field(default_factory=list)
11+
class ColumnRef(BaseModel):
12+
table: TableRef
13+
column_name: str
2514

26-
model_config = ConfigDict(extra="allow")
2715

2816

2917
class TableRef(BaseModel):

packages/core/src/nl2sql/indexing/chunk_builder.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
ColumnChunk,
88
RelationshipChunk,
99
MetricChunk,
10-
TableRef,
11-
ColumnRef,
1210
)
1311
from nl2sql.schema import SchemaSnapshot
12+
from nl2sql_adapter_sdk.schema import TableRef, ColumnRef
1413

1514

1615
class SchemaChunkBuilder:
@@ -150,14 +149,13 @@ def _build_column_chunks(self) -> List[ColumnChunk]:
150149
column_md = table_md.columns.get(column_name) if table_md else None
151150

152151
column_ref = ColumnRef(
153-
schema_name=table_ref.schema_name,
154-
table_name=table_ref.table_name,
152+
table=table_ref,
155153
column_name=column_name,
156154
)
157155

158156
chunks.append(
159157
ColumnChunk(
160-
id=f"schema.column:{column_ref.full_name}:{self.schema_version}",
158+
id=f"schema.column:{column_ref.table.full_name}:{column_ref.column_name}:{self.schema_version}",
161159
datasource_id=self.ds_id,
162160
column=column_ref,
163161
dtype=column_contract.data_type,

packages/core/src/nl2sql/indexing/models.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
from typing import List, Dict, Optional, Literal, Any
33
from pydantic import BaseModel, Field
4+
from nl2sql_adapter_sdk.schema import TableRef, ColumnRef
45

56

67
class BaseChunk(BaseModel):
@@ -20,28 +21,6 @@ def get_metadata(self) -> Dict[str, Any]:
2021
"type": self.type,
2122
}
2223

23-
class TableRef(BaseModel):
24-
schema_name: str
25-
table_name: str
26-
27-
@property
28-
def full_name(self) -> str:
29-
return f"{self.schema_name}.{self.table_name}"
30-
31-
32-
class ColumnRef(BaseModel):
33-
schema_name: str
34-
table_name: str
35-
column_name: str
36-
37-
@property
38-
def full_name(self) -> str:
39-
return f"{self.schema_name}.{self.table_name}.{self.column_name}"
40-
41-
@property
42-
def table_full_name(self) -> str:
43-
return f"{self.schema_name}.{self.table_name}"
44-
4524

4625
class DatasourceChunk(BaseChunk):
4726
type: Literal["schema.datasource"] = Field(
@@ -106,6 +85,9 @@ def get_metadata(self) -> Dict[str, Any]:
10685
"table": self.table.full_name,
10786
"row_count": self.row_count,
10887
"schema_version": self.schema_version,
88+
"description": self.description,
89+
"primary_key": ','.join(self.primary_key),
90+
"foreign_keys": ','.join(self.foreign_keys),
10991
}
11092

11193

@@ -131,7 +113,8 @@ def get_page_content(self) -> str:
131113
else ""
132114
)
133115
return (
134-
f"Column: {self.column.full_name}\n"
116+
f"Table: {self.column.table.full_name}\n"
117+
f"Column: {self.column.column_name}\n"
135118
f"Type: {self.dtype}\n"
136119
f"{self.description or ''}\n"
137120
f"{stats}\n"
@@ -142,11 +125,12 @@ def get_metadata(self) -> Dict[str, Any]:
142125
return {
143126
**super().get_metadata(),
144127
"datasource_id": self.datasource_id,
145-
"column": self.column.full_name,
146-
"table": self.column.table_full_name,
128+
"column": self.column.column_name,
129+
"table": self.column.table.full_name,
147130
"dtype": self.dtype,
148131
"pii": self.pii,
149132
"schema_version": self.schema_version,
133+
"description": self.description,
150134
}
151135

152136
class RelationshipChunk(BaseChunk):

packages/core/src/nl2sql/indexing/vector_store.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,16 @@ def _initialize_vector_store(self) -> None:
5050
persist_directory=self.persist_directory,
5151
)
5252

53+
def initialize_if_not_exists(self) -> None:
54+
"""
55+
Initializes the vector store if it does not exist.
56+
"""
57+
try:
58+
_ = self.vectorstore._collection.count()
59+
except Exception:
60+
logger.info("Vector store not found, initializing new store.")
61+
self._initialize_vector_store()
62+
5363
def is_empty(self) -> bool:
5464
"""
5565
Checks whether the vector store is empty.
@@ -185,6 +195,7 @@ def retrieve_datasource_candidates(
185195
Returns:
186196
Retrieved datasource documents.
187197
"""
198+
self.initialize_if_not_exists()
188199
from nl2sql.common.resilience import VECTOR_BREAKER
189200

190201
@VECTOR_BREAKER
@@ -216,8 +227,10 @@ def retrieve_schema_context(
216227
Returns:
217228
Retrieved schema documents.
218229
"""
230+
self.initialize_if_not_exists()
219231
from nl2sql.common.resilience import VECTOR_BREAKER
220232

233+
221234
@VECTOR_BREAKER
222235
def _execute():
223236
return self.vectorstore.max_marginal_relevance_search(
@@ -254,6 +267,8 @@ def retrieve_column_candidates(
254267
"""
255268
from nl2sql.common.resilience import VECTOR_BREAKER
256269

270+
self.initialize_if_not_exists()
271+
257272
@VECTOR_BREAKER
258273
def _execute():
259274
return self.vectorstore.max_marginal_relevance_search(
@@ -292,6 +307,8 @@ def retrieve_planning_context(
292307
"""
293308
from nl2sql.common.resilience import VECTOR_BREAKER
294309

310+
self.initialize_if_not_exists()
311+
295312
@VECTOR_BREAKER
296313
def _execute():
297314
return self.vectorstore.max_marginal_relevance_search(

packages/core/src/nl2sql/pipeline/nodes/ast_planner/node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def __call__(self, state: SubgraphExecutionState) -> Dict[str, Any]:
7474
}
7575
)
7676

77+
logger.info(f"Generated Plan: {plan.model_dump_json(indent=2)}")
78+
7779
return {
7880
"ast_planner_response": ASTPlannerResponse(plan=plan),
7981
"reasoning": [

packages/core/src/nl2sql/pipeline/nodes/ast_planner/prompts.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,22 @@
104104

105105
"[INSTRUCTIONS]\n"
106106
"1. Analyze [USER_QUERY] and [SEMANTIC_CONTEXT].\n"
107-
"2. Select tables from [RELEVANT_TABLES]. Assign strict 'ordinal' positions 0..N.\n"
108-
"3. Define joins using ONLY table aliases (left_alias/right_alias).\n"
109-
"4. Build Expr trees using:\n"
107+
"2. Select ONLY tables from [RELEVANT_TABLES]. Assign strict 'ordinal' positions 0..N.\n"
108+
"3. When joining, use relationships listed in [RELEVANT_TABLES]. If no relationship exists, do not join.\n"
109+
"4. Define joins using ONLY table aliases (left_alias/right_alias).\n"
110+
"5. Build Expr trees using:\n"
110111
" literal | column | func | binary | unary | case\n"
111-
"5. Every list MUST contain `ordinal` fields in ascending order starting at 0.\n"
112-
"6. Order lists to match ordinals (0..N) exactly.\n\n"
112+
"6. Every list MUST contain `ordinal` fields in ascending order starting at 0.\n"
113+
"7. Order lists to match ordinals (0..N) exactly.\n"
114+
"8. For literal values on '=' or 'IN', choose values from column stats if available.\n"
115+
"9. If no exact match is available, fall back to LIKE but keep the pattern derived from stats/synonyms.\n\n"
113116

114117
"[OUTPUT CONTRACT]\n"
115118
"- If [EXPECTED_SCHEMA] is provided and non-empty:\n"
116119
" - select_items length MUST equal expected_schema length.\n"
117120
" - select_items aliases MUST match expected_schema names in the same order.\n"
118121
"- All table/column references MUST come from [RELEVANT_TABLES].\n"
122+
"- Joins MUST use relationships provided in [RELEVANT_TABLES].\n"
119123
"- The [EXAMPLES] are illustrative; always follow [EXPECTED_SCHEMA] when provided.\n\n"
120124

121125
"[CONSTRAINTS]\n"

0 commit comments

Comments
 (0)