Skip to content

Commit bff7c64

Browse files
committed
feat: update artifact handling and schema management in execution pipeline
- Enhanced artifact handling by introducing a unified method for creating artifact references across different storage backends (S3, ADLS, Local). - Refactored the execution contracts to replace the deprecated ExecutorBaseModel with ExecutorResponse, improving clarity and consistency. - Updated the ExecutorNode to include tenant_id in requests, facilitating multi-tenant support. - Improved error handling and logging across various pipeline nodes, ensuring better traceability and debugging capabilities. - Removed unused schema management methods and streamlined the datasource resolution process for improved performance.
1 parent e7ffb3b commit bff7c64

39 files changed

Lines changed: 214 additions & 865 deletions

File tree

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,6 @@ site
2626

2727
data
2828

29-
last_reasoning.json
29+
last_reasoning.json
30+
31+
artifacts/

packages/adapter-sdk/src/nl2sql_adapter_sdk/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Adapter SDK: shared contracts for core and adapters."""
22

33
from .capabilities import DatasourceCapability
4-
from .contracts import AdapterRequest, ResultColumn, ResultError, ResultFrame
4+
from .contracts import AdapterRequest, ResultError, ResultFrame
55
from .protocols import DatasourceAdapterProtocol
66
from .schema import (
77
Column,
@@ -21,7 +21,6 @@
2121
__all__ = [
2222
"DatasourceCapability",
2323
"AdapterRequest",
24-
"ResultColumn",
2524
"ResultError",
2625
"ResultFrame",
2726
"DatasourceAdapterProtocol",

packages/adapter-sdk/src/nl2sql_adapter_sdk/contracts.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,6 @@ class AdapterRequest(BaseModel):
2626
model_config = ConfigDict(extra="ignore")
2727

2828

29-
class ResultColumn(BaseModel):
30-
"""Column metadata for a ResultFrame."""
31-
32-
name: str
33-
type: str = Field(default="unknown", description="Logical or native column type.")
34-
3529

3630
class ResultError(BaseModel):
3731
"""Standardized error envelope for adapter results."""
@@ -49,7 +43,7 @@ class ResultFrame(BaseModel):
4943
"""Adapter-agnostic, DataFrame-like result contract."""
5044

5145
success: bool = Field(default=True)
52-
columns: List[ResultColumn] = Field(default_factory=list)
46+
columns: List[str] = Field(default_factory=list)
5347
rows: List[List[Any]] = Field(default_factory=list)
5448
row_count: int = Field(default=0)
5549
truncated: bool = Field(default=False)
@@ -75,11 +69,10 @@ def from_row_dicts(
7569
if columns is None:
7670
columns = list(rows[0].keys()) if rows else []
7771

78-
col_specs = [ResultColumn(name=col, type="unknown") for col in columns]
7972
row_values = [[row.get(col) for col in columns] for row in rows]
8073

8174
return cls(
82-
columns=col_specs,
75+
columns=columns,
8376
rows=row_values,
8477
row_count=row_count if row_count is not None else len(row_values),
8578
**kwargs,
@@ -90,5 +83,5 @@ def to_row_dicts(self) -> List[Dict[str, Any]]:
9083

9184
if not self.rows or not self.columns:
9285
return []
93-
names = [col.name for col in self.columns]
86+
names = self.columns
9487
return [dict(zip(names, row)) for row in self.rows]

packages/adapter-sqlalchemy/src/nl2sql_sqlalchemy_adapter/adapter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from nl2sql_adapter_sdk.capabilities import DatasourceCapability
66
from nl2sql_adapter_sdk.contracts import (
77
AdapterRequest,
8-
ResultColumn,
98
ResultError,
109
ResultFrame,
1110
)
@@ -155,7 +154,7 @@ def execute_sql(self, sql: str) -> ResultFrame:
155154

156155
return ResultFrame(
157156
success=True,
158-
columns=[ResultColumn(name=col, type="unknown") for col in cols],
157+
columns=cols,
159158
rows=rows,
160159
row_count=row_count,
161160
bytes=total_bytes,

packages/core/src/nl2sql/aggregation/aggregator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def execute(
3434

3535
computed: Dict[str, pl.DataFrame] = {}
3636

37-
for layer in dag.layers or []:
37+
for layer in dag.layers:
3838
for node_id in layer:
3939
node = node_index.get(node_id)
4040
if not node:

packages/core/src/nl2sql/aggregation/engines/polars_duckdb.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,7 @@ def __init__(self):
1616
self.artifact_store = build_artifact_store()
1717

1818
def load_scan(self, artifact: ArtifactRef) -> pl.DataFrame:
19-
if artifact.uri.startswith("s3://") or artifact.uri.startswith("abfs://"):
20-
result_frame = self.artifact_store.read_result_frame(artifact)
21-
return pl.from_dicts(result_frame.to_row_dicts())
22-
23-
relation = duckdb.query(f"SELECT * FROM '{artifact.uri}'")
24-
table = relation.arrow()
25-
return pl.from_arrow(table)
19+
return self.artifact_store.read_parquet(artifact)
2620

2721
def combine(
2822
self,

packages/core/src/nl2sql/common/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ class ErrorCode(str, Enum):
4747
PIPELINE_TIMEOUT = "PIPELINE_TIMEOUT"
4848
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
4949
EXECUTION_TIMEOUT = "EXECUTION_TIMEOUT"
50-
CANCELLED = "CANCELLED"
50+
CANCELLED = "CANCELLED",
51+
EXECUTION_FAILED = "EXECUTION_FAILED"
5152

5253

5354

packages/core/src/nl2sql/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
vector_store_path = pathlib.Path(settings.vector_store_path)
4949

5050
cm = ConfigManager()
51+
self.tenant_id = settings.tenant_id
5152
self.config_manager = cm
5253

5354
secret_configs = cm.load_secrets(secrets_config_path)

packages/core/src/nl2sql/datasources/registry.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,24 +112,6 @@ def register_datasource(self, config: DatasourceConfig) -> DatasourceAdapterProt
112112
f"No adapter found for engine type: '{conn_type}' in datasource '{ds_id}'"
113113
)
114114

115-
def refresh_schema(self, datasource_id: str, vector_store: Any) -> Dict[str, int]:
116-
"""Refreshes the schema for a specific datasource.
117-
118-
This triggers a fresh intrusion of the database schema via the adapter
119-
and updates the vector store index.
120-
121-
Args:
122-
datasource_id: The ID of the datasource to refresh.
123-
vector_store: The VectorStore instance.
124-
125-
Returns:
126-
Dict[str, int]: Statistics of the refreshed components.
127-
128-
Raises:
129-
ValueError: If the datasource ID is unknown.
130-
"""
131-
adapter = self.get_adapter(datasource_id)
132-
return vector_store.refresh_schema(adapter, datasource_id)
133115

134116
def get_adapter(self, datasource_id: str) -> DatasourceAdapterProtocol:
135117
"""Retrieves the adapter for a datasource.
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .contracts import ArtifactRef, ExecutorBaseModel, ExecutorRequest
1+
from .contracts import ArtifactRef, ExecutorResponse, ExecutorRequest
22
from .execution_store import ExecutionStore
33

4-
__all__ = ["ArtifactRef", "ExecutorBaseModel", "ExecutorRequest", "ExecutionStore"]
4+
__all__ = ["ArtifactRef", "ExecutorResponse", "ExecutorRequest", "ExecutionStore"]

0 commit comments

Comments
 (0)