Skip to content

Commit 3080cfa

Browse files
committed
Fix BUG-004: Schema Drift & Cleanup
- Implemented idempotent schema/example refresh in OrchestratorVectorStore - Added dynamic registration to DatasourceRegistry - Updated CLI indexing command to use new registry methods - Fixed IntentValidatorNode ImportError - Fixed DatasourceRegistry.get_all AttributeError - Fixed ChromaDB delete filter error - Added comprehensive unit tests in test_schema_lifecycle.py
1 parent e3bc8ee commit 3080cfa

File tree

8 files changed

+442
-245
lines changed

8 files changed

+442
-245
lines changed

audit/remediation_plan.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ This document serves as the master backlog for addressing findings from the Arch
2222
- **Fix**: Sanitize or hash non-user-facing errors in `AggregatorNode` before prompt construction. Only show generic error codes to the LLM.
2323
- **Status**: Fixed. Unit tests added in `tests/unit/test_node_aggregator.py`.
2424

25-
- [ ] **BUG-004: Schema Drift (Stale Cache)** (High)
25+
- [x] **BUG-004: Schema Drift (Stale Cache)** (High)
2626
- **Component**: Governance / Registry
2727
- **Issue**: `DatasourceRegistry` caches adapters indefinitely at startup. If the DB schema changes, the Planner hallucinates invalid columns.
28-
- **Fix**: Implement `SchemaWatcher` or TTL-based cache revocation in `DatasourceRegistry` to force schema refresh.
28+
- **Fix**: Implemented idempotent `refresh_schema` and `refresh_examples` in `OrchestratorVectorStore`, along with dynamic `register_datasource` in `DatasourceRegistry`.
29+
- **Status**: Fixed. Unit tests added in `tests/unit/test_schema_lifecycle.py`.
2930

3031
- [ ] **BUG-005: Missing Distributed Tracing** (High)
3132
- **Component**: Observability / Logging

packages/cli/src/nl2sql_cli/commands/indexing.py

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,28 @@
88

99
@handle_cli_errors
1010
def run_indexing(
11-
configs: Any, # List[Dict[str, Any]]
12-
vector_store_path: str,
13-
vector_store: OrchestratorVectorStore,
14-
llm_registry: Any = None
11+
registry: DatasourceRegistry,
12+
vector_store_path: str,
13+
vector_store: OrchestratorVectorStore,
14+
llm_registry: Any = None,
1515
) -> None:
1616
"""Runs the indexing process for schemas and examples.
1717
1818
This function orchestrates the full indexing workflow:
1919
1. Clears existing data from the vector store.
20-
2. Indexes database schemas (tables, columns, foreign keys) for all configured adapters.
21-
3. Indexes example questions from the sample questions file, optionally enriching them
22-
with synthetic variants using the Semantic Analysis Node.
23-
4. Displays a comprehensive summary table of indexed content.
20+
2. Indexes database schemas (tables, columns, foreign keys).
21+
3. Indexes example questions, optionally enriching them.
22+
4. Displays a comprehensive summary.
2423
2524
Args:
26-
configs (List[Any]): List of datasource configuration objects.
27-
vector_store_path (str): Path to the vector store directory.
28-
vector_store (OrchestratorVectorStore): The initialized vector store instance.
29-
llm_registry (Any, optional): Registry of LLMs used for semantic enrichment of examples.
25+
registry: The initialized DatasourceRegistry.
26+
vector_store_path: Path to the vector store directory.
27+
vector_store: The initialized vector store instance.
28+
llm_registry: Registry of LLMs used for semantic enrichment.
3029
"""
3130
presenter = ConsolePresenter()
3231
presenter.print_indexing_start(vector_store_path)
33-
34-
registry = DatasourceRegistry(configs)
32+
3533
adapters = registry.list_adapters()
3634
stats = []
3735

@@ -42,73 +40,86 @@ def run_indexing(
4240
presenter.console.print("\n[bold]Indexing Schemas...[/bold]")
4341
for adapter in adapters:
4442
ds_id = adapter.datasource_id
45-
43+
4644
with presenter.console.status(f"[cyan]Indexing schema: {ds_id}...[/cyan]"):
4745
try:
48-
schema_stats = vector_store.index_schema(adapter, datasource_id=ds_id)
49-
schema_stats['id'] = ds_id
46+
# Use idempotent refresh
47+
schema_stats = vector_store.refresh_schema(adapter, datasource_id=ds_id)
48+
schema_stats["id"] = ds_id
5049
stats.append(schema_stats)
51-
52-
t_count = schema_stats['tables']
53-
c_count = schema_stats['columns']
54-
presenter.console.print(f" [green][OK][/green] {ds_id} [dim]({t_count} Tables, {c_count} Columns)[/dim]")
55-
50+
51+
t_count = schema_stats["tables"]
52+
c_count = schema_stats["columns"]
53+
presenter.console.print(
54+
f" [green][OK][/green] {ds_id} [dim]({t_count} Tables, {c_count} Columns)[/dim]"
55+
)
56+
5657
except Exception as e:
5758
presenter.console.print(f" [red][FAIL][/red] {ds_id} [red]Failed: {e}[/red]")
58-
stats.append({'id': ds_id, 'tables': 0, 'columns': 0, 'examples': 0, 'error': str(e)})
59+
stats.append(
60+
{"id": ds_id, "tables": 0, "columns": 0, "examples": 0, "error": str(e)}
61+
)
5962

6063
from nl2sql.common.settings import settings
6164
import yaml
6265
import pathlib
63-
66+
6467
presenter.console.print("\n[bold]Indexing Examples...[/bold]")
65-
68+
6669
total_examples = 0
6770
path = pathlib.Path(settings.sample_questions_path)
68-
71+
6972
if path.exists():
7073
try:
7174
examples_data = yaml.safe_load(path.read_text()) or {}
72-
75+
7376
def get_stat_entry(ds_id):
7477
for s in stats:
75-
if s['id'] == ds_id: return s
76-
new_s = {'id': ds_id, 'tables': 0, 'columns': 0, 'examples': 0}
78+
if s["id"] == ds_id:
79+
return s
80+
new_s = {"id": ds_id, "tables": 0, "columns": 0, "examples": 0}
7781
stats.append(new_s)
7882
return new_s
7983

8084
enricher = None
8185
if llm_registry:
8286
try:
8387
from nl2sql.pipeline.nodes.semantic.node import SemanticAnalysisNode
88+
8489
enricher = SemanticAnalysisNode(llm_registry.semantic_llm())
8590
except Exception as e:
8691
presenter.print_warning(f"Could not load SemanticNode: {e}")
8792
else:
88-
presenter.console.print(" [yellow]![/yellow] [dim]Skipping enrichment (No LLM config)[/dim]")
93+
presenter.console.print(
94+
" [yellow]![/yellow] [dim]Skipping enrichment (No LLM config)[/dim]"
95+
)
8996

9097
for ds_id, questions in examples_data.items():
9198
with presenter.console.status(f"[cyan]Indexing examples for {ds_id}...[/cyan]"):
9299
try:
93-
docs = vector_store.prepare_examples_for_datasource(ds_id, questions, enricher)
94-
vector_store.add_documents(docs)
95-
96-
count = len(docs)
100+
# Use idempotent refresh
101+
count = vector_store.refresh_examples(ds_id, questions, enricher)
97102
total_examples += count
98-
103+
99104
# Update stats
100105
entry = get_stat_entry(ds_id)
101-
entry['examples'] = count
102-
103-
presenter.console.print(f" [green][OK][/green] {ds_id} [dim]({count} examples)[/dim]")
104-
106+
entry["examples"] = count
107+
108+
presenter.console.print(
109+
f" [green][OK][/green] {ds_id} [dim]({count} examples)[/dim]"
110+
)
111+
105112
except Exception as e:
106-
presenter.console.print(f" [red][FAIL][/red] {ds_id} [red]Failed: {e}[/red]")
107-
113+
presenter.console.print(
114+
f" [red][FAIL][/red] {ds_id} [red]Failed: {e}[/red]"
115+
)
116+
108117
except Exception as e:
109-
presenter.console.print(f" [red][FAIL][/red] Failed to load {path}: {e}")
118+
presenter.console.print(f" [red][FAIL][/red] Failed to load {path}: {e}")
110119
else:
111-
presenter.console.print(f" [yellow]![/yellow] [dim]No examples file found at {path}[/dim]")
112-
120+
presenter.console.print(
121+
f" [yellow]![/yellow] [dim]No examples file found at {path}[/dim]"
122+
)
123+
113124
presenter.print_indexing_summary(stats)
114125
presenter.print_indexing_complete()

packages/cli/src/nl2sql_cli/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def index(
128128
vector_store_path=vector_store
129129
)
130130

131-
run_indexing(ctx.registry.get_all(), vector_store, ctx.vector_store, ctx.llm_registry)
131+
run_indexing(ctx.registry, vector_store, ctx.vector_store, ctx.llm_registry)
132132

133133
@app.command()
134134
def doctor():

packages/core/src/nl2sql/datasources/registry.py

Lines changed: 86 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Type, List, Any
1+
from typing import Dict, Type, List, Any, Union
22
import importlib
33
from nl2sql_adapter_sdk import DatasourceAdapter
44
from nl2sql.datasources.discovery import discover_adapters
@@ -43,52 +43,88 @@ def __init__(self, configs: List[Dict[str, Any]]):
4343
"""Initializes the registry by eagerly creating adapters for all configs.
4444
4545
Args:
46-
configs (List[Any]): List of datasource configuration objects (Dict or DatasourceConfig).
46+
configs: List of datasource configuration objects (Dict or DatasourceConfig).
4747
"""
4848
self._adapters: Dict[str, DatasourceAdapter] = {}
49-
available_adapters = discover_adapters()
49+
self._available_adapters = discover_adapters()
5050

5151
for config in configs:
5252
try:
53-
# Normalize Pydantic Model to Dict
54-
if hasattr(config, "model_dump"):
55-
config = config.model_dump()
56-
57-
ds_id = config.get("id")
58-
if not ds_id:
59-
raise ValueError("Datasource ID is required. Please check your configuration.")
60-
61-
connection = config.get("connection", {})
62-
conn_type = connection.get("type", "").lower()
63-
resolved_connection = self.resolved_connection(connection)
64-
65-
if conn_type in available_adapters:
66-
AdapterCls = available_adapters[conn_type]
67-
68-
adapter = AdapterCls(
69-
datasource_id=ds_id,
70-
datasource_engine_type=conn_type,
71-
connection_args=resolved_connection,
72-
statement_timeout_ms=config.get("statement_timeout_ms"),
73-
row_limit=config.get("row_limit"),
74-
max_bytes=config.get("max_bytes")
75-
)
76-
self._adapters[ds_id] = adapter
77-
else:
78-
raise ValueError(f"No adapter found for engine type: '{conn_type}' in datasource '{ds_id}'")
79-
53+
self.register_datasource(config)
8054
except Exception as e:
55+
# Log usage would be better here, but we raise to stop startup on bad config
8156
raise ValueError(f"Failed to initialize adapter for '{config.get('id', 'unknown')}': {e}") from e
8257

58+
def register_datasource(self, config: Union[Dict[str, Any], Any]) -> DatasourceAdapter:
59+
"""Registers a new datasource dynamically.
60+
61+
Args:
62+
config: The datasource configuration dictionary or object.
63+
64+
Returns:
65+
DatasourceAdapter: The created and registered adapter.
66+
67+
Raises:
68+
ValueError: If configuration is invalid or adapter type is unknown.
69+
"""
70+
# Normalize Pydantic Model to Dict
71+
if hasattr(config, "model_dump"):
72+
config = config.model_dump()
73+
74+
ds_id = config.get("id")
75+
if not ds_id:
76+
raise ValueError("Datasource ID is required. Please check your configuration.")
77+
78+
connection = config.get("connection", {})
79+
conn_type = connection.get("type", "").lower()
80+
resolved_connection = self.resolved_connection(connection)
81+
82+
if conn_type in self._available_adapters:
83+
adapter_cls = self._available_adapters[conn_type]
84+
85+
adapter = adapter_cls(
86+
datasource_id=ds_id,
87+
datasource_engine_type=conn_type,
88+
connection_args=resolved_connection,
89+
statement_timeout_ms=config.get("statement_timeout_ms"),
90+
row_limit=config.get("row_limit"),
91+
max_bytes=config.get("max_bytes"),
92+
)
93+
self._adapters[ds_id] = adapter
94+
return adapter
95+
else:
96+
raise ValueError(
97+
f"No adapter found for engine type: '{conn_type}' in datasource '{ds_id}'"
98+
)
99+
100+
def refresh_schema(self, datasource_id: str, vector_store: Any) -> Dict[str, int]:
101+
"""Refreshes the schema for a specific datasource.
102+
103+
This triggers a fresh intrusion of the database schema via the adapter
104+
and updates the vector store index.
105+
106+
Args:
107+
datasource_id: The ID of the datasource to refresh.
108+
vector_store: The OrchestratorVectorStore instance.
109+
110+
Returns:
111+
Dict[str, int]: Statistics of the refreshed components.
112+
113+
Raises:
114+
ValueError: If the datasource ID is unknown.
115+
"""
116+
adapter = self.get_adapter(datasource_id)
117+
return vector_store.refresh_schema(adapter, datasource_id)
118+
83119
def get_adapter(self, datasource_id: str) -> DatasourceAdapter:
84120
"""Retrieves the DataSourceAdapter for a datasource.
85121
86122
Args:
87-
datasource_id (str): The ID of the datasource.
123+
datasource_id: The ID of the datasource.
88124
89125
Returns:
90126
DatasourceAdapter: The active adapter instance.
91-
127+
92128
Raises:
93129
ValueError: If the datasource ID is unknown.
94130
"""
@@ -97,13 +133,28 @@ def get_adapter(self, datasource_id: str) -> DatasourceAdapter:
97133
return self._adapters[datasource_id]
98134

99135
def get_dialect(self, datasource_id: str) -> str:
100-
"""Returns a normalized dialect string from the adapter."""
136+
"""Returns a normalized dialect string from the adapter.
137+
138+
Args:
139+
datasource_id: The ID of the datasource.
140+
141+
Returns:
142+
str: The dialect string (e.g., 'postgres').
143+
"""
101144
return self.get_adapter(datasource_id).get_dialect()
102145

103146
def list_adapters(self) -> List[DatasourceAdapter]:
104-
"""Returns a list of all registered adapters."""
147+
"""Returns a list of all registered adapters.
148+
149+
Returns:
150+
List[DatasourceAdapter]: All active adapters.
151+
"""
105152
return list(self._adapters.values())
106153

107154
def list_ids(self) -> List[str]:
108-
"""Returns a list of all registered datasource IDs."""
155+
"""Returns a list of all registered datasource IDs.
156+
157+
Returns:
158+
List[str]: All registered IDs.
159+
"""
109160
return list(self._adapters.keys())
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .node import IntentValidatorNode
2+
3+
__all__ = ["IntentValidatorNode"]

0 commit comments

Comments
 (0)