Skip to content

Commit d05ea66

Browse files
committed
Fix BUG-005: Trace Context & Misc Regressions
- Implemented BUG-005: distributed tracing (`TraceContextFilter`, `trace_id` in `GraphState`, `traced_node` wrapper)
- Fixed a crash in `nl2sql setup --demo` (`DatasourceRegistry` misuse)
- Fixed a syntax error in the Intent Validator prompt (unescaped JSON braces)
- Fixed the `SqliteAdapter` `DryRunResult` schema mismatch (`valid` -> `is_valid`, `error` -> `error_message`)
- Updated all adapters (Postgres, MySQL, MSSQL, SQLite) so `get_dialect` returns the dialect name as a string
- Added unit tests for tracing
1 parent 3080cfa commit d05ea66

File tree

12 files changed

+154
-28
lines changed

12 files changed

+154
-28
lines changed

audit/remediation_plan.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ This document serves as the master backlog for addressing findings from the Arch
2828
- **Fix**: Implemented idempotent `refresh_schema` and `refresh_examples` in `OrchestratorVectorStore`, along with dynamic `register_datasource` in `DatasourceRegistry`.
2929
- **Status**: Fixed. Unit tests added in `tests/unit/test_schema_lifecycle.py`.
3030

31-
- [ ] **BUG-005: Missing Distributed Tracing** (High)
31+
- [x] **BUG-005: Missing Distributed Tracing** (High)
3232
- **Component**: Observability / Logging
3333
- **Issue**: Logs lack a unique `trace_id` per request, making concurrent request debugging impossible in multi-threaded environments.
34-
- **Fix**: Inject `trace_id` at the `GraphState` entry point and propagate it to `Python LogRecord` context.
34+
- **Fix**: Implemented `trace_id` in `GraphState`, `TraceContextFilter` in logger, and `traced_node` wrapper for context propagation.
35+
- **Status**: Fixed. Verified in `tests/unit/test_tracing.py`.
3536

3637
## 🟡 Medium & Low Priority Bugs
3738

configs/llm.demo.yaml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,3 @@ default:
66
model: gpt-4o
77
temperature: 0.0
88
api_key: ${env:OPENAI_API_KEY}
9-
agents:
10-
intent_validator:
11-
provider: openai
12-
model: gpt-4o-mini
13-
temperature: 0.0
14-
api_key: ${env:OPENAI_API_KEY}

packages/adapters/mssql/src/nl2sql_mssql/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def explain(self, sql: str) -> QueryPlan:
9090

9191
def get_dialect(self) -> str:
9292
"""MSSQL uses T-SQL dialect."""
93-
return mssql.dialect()
93+
return mssql.dialect.name
9494

9595
def cost_estimate(self, sql: str) -> CostEstimate:
9696
import re

packages/adapters/mysql/src/nl2sql_mysql/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,5 +116,5 @@ def cost_estimate(self, sql: str) -> CostEstimate:
116116
return CostEstimate(estimated_cost=0.0, estimated_rows=0)
117117

118118
def get_dialect(self) -> str:
119-
return mysql.dialect()
119+
return mysql.dialect.name
120120

packages/adapters/postgres/src/nl2sql_postgres/adapter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,4 @@ def cost_estimate(self, sql: str) -> CostEstimate:
108108
return CostEstimate(estimated_cost=0.0, estimated_rows=0)
109109

110110
def get_dialect(self) -> str:
111-
return postgresql.dialect()
112-
113-
111+
return postgresql.dialect.name

packages/adapters/sqlite/src/nl2sql_sqlite/adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def dry_run(self, query: str) -> DryRunResult:
6464
try:
6565
with self.engine.connect() as conn:
6666
conn.execute(text(f"EXPLAIN QUERY PLAN {query}"))
67-
return DryRunResult(valid=True, error=None)
67+
return DryRunResult(is_valid=True, error_message=None)
6868
except Exception as e:
69-
return DryRunResult(valid=False, error=str(e))
69+
return DryRunResult(is_valid=False, error_message=str(e))
7070

7171
def explain(self, query: str) -> QueryPlan:
7272
return QueryPlan(original_query=query, plan="EXPLAIN QUERY PLAN not fully parsed")
@@ -83,4 +83,4 @@ def cost_estimate(self, query: str) -> CostEstimate:
8383

8484

8585
def get_dialect(self) -> str:
86-
return sqlite.dialect()
86+
return sqlite.dialect.name

packages/cli/src/nl2sql_cli/demo/manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def index_demo_data(self):
210210
from nl2sql.services.vector_store import OrchestratorVectorStore
211211
from nl2sql.services.llm import LLMRegistry
212212
from nl2sql.configs import ConfigManager
213+
from nl2sql.datasources import DatasourceRegistry
213214

214215
try:
215216
indexer_config_manager = ConfigManager(self.project_root)
@@ -224,7 +225,8 @@ def index_demo_data(self):
224225
except Exception:
225226
pass # Optional for schema indexing
226227

227-
run_indexing(configs, settings.vector_store_path, v_store, llm_registry)
228+
registry = DatasourceRegistry(configs)
229+
run_indexing(registry, settings.vector_store_path, v_store, llm_registry)
228230
return True
229231
except Exception as e:
230232
self.print_error(f"Indexing Failed: {e}")

packages/core/src/nl2sql/common/logger.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,26 @@
11
import logging
22
import json
33
import time
4-
from typing import Any, Dict
4+
import contextvars
5+
from contextlib import contextmanager
6+
from typing import Any, Dict, Optional
57

8+
# Holds the trace ID for the current execution context; None when no trace is active.
_trace_id_ctx: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("trace_id", default=None)
9+
10+
class TraceContextFilter(logging.Filter):
    """Logging filter that stamps each record with the active trace ID.

    The filter never drops a record; it only guarantees that
    ``record.trace_id`` exists so formatters can reference it.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Attach ``trace_id`` to *record* and let the record through.

        A trace ID already present on the record (for example one passed
        via ``extra={"trace_id": ...}``) is kept. Otherwise the value is
        read from the module-level trace-ID context variable, which may
        be ``None`` when no trace context is active.

        Args:
            record: The log record about to be emitted.

        Returns:
            Always ``True``, so the record is never filtered out.
        """
        # Only fill in from the context when the caller did not supply an
        # explicit trace ID; overwriting it would discard caller intent.
        if getattr(record, "trace_id", None) is None:
            record.trace_id = _trace_id_ctx.get()
        return True
15+
16+
@contextmanager
def trace_context(trace_id: str):
    """Bind *trace_id* to the current context for the duration of a ``with`` block.

    On exit, whether normal or via an exception, the context variable is
    reset to the value it held before entry.
    """
    reset_token = _trace_id_ctx.set(trace_id)
    try:
        yield
    finally:
        # The token restores the previous value, so nested contexts unwind correctly.
        _trace_id_ctx.reset(reset_token)
624

725
class JsonFormatter(logging.Formatter):
826
"""Formatter that outputs JSON strings after parsing the LogRecord."""
@@ -23,13 +41,16 @@ def format(self, record: logging.LogRecord) -> str:
2341
"message": record.getMessage(),
2442
}
2543

44+
if getattr(record, "trace_id", None):
45+
log_record["trace_id"] = record.trace_id
46+
2647
# Standard LogRecord attributes to ignore
2748
standard_attrs = {
2849
"args", "asctime", "created", "exc_info", "exc_text", "filename",
2950
"funcName", "levelname", "levelno", "lineno", "module",
3051
"msecs", "message", "msg", "name", "pathname", "process",
3152
"processName", "relativeCreated", "stack_info", "thread", "threadName",
32-
"taskName"
53+
"taskName", "trace_id"
3354
}
3455

3556
for key, value in record.__dict__.items():
@@ -54,11 +75,19 @@ def configure_logging(level: str = "INFO", json_format: bool = False):
5475
root_logger.removeHandler(handler)
5576

5677
handler = logging.StreamHandler()
78+
handler.addFilter(TraceContextFilter())
5779

5880
if json_format:
5981
handler.setFormatter(JsonFormatter())
6082
else:
61-
handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
83+
# Text logs carry the trace_id inline, between the timestamp and the logger name.
84+
# The TraceContextFilter attached to this handler sets record.trace_id on every
85+
# record it handles, so %(trace_id)s always resolves in the format string below.
86+
# When no trace context is active the field is rendered as "None".
87+
formatter = logging.Formatter(
88+
"%(asctime)s - [%(trace_id)s] - %(name)s - %(levelname)s - %(message)s"
89+
)
90+
handler.setFormatter(formatter)
6291

6392
root_logger.addHandler(handler)
6493

packages/core/src/nl2sql/pipeline/graph.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,29 @@
1515
from nl2sql.services.vector_store import OrchestratorVectorStore
1616
from nl2sql.services.llm import LLMRegistry
1717
from nl2sql.common.errors import PipelineError, ErrorSeverity, ErrorCode
18+
from nl2sql.common.logger import trace_context
1819

1920

2021
LLMCallable = Union[Callable[[str], str], Runnable]
2122

2223

24+
def traced_node(node: Callable) -> Callable:
    """Wrap a graph node so it runs inside the logging trace context.

    The wrapper reads ``trace_id`` from the state (a dict key or an
    attribute). When a truthy trace ID is found, the node is executed
    inside ``trace_context(trace_id)``; otherwise the node is called
    directly.

    Args:
        node: The callable to wrap. It is invoked with the state as its
            first argument.

    Returns:
        A callable with the same name and docstring as *node*. Any extra
        positional or keyword arguments are forwarded unchanged, so the
        wrapper stays call-compatible with the wrapped node.
    """
    # Function-scope import: the file's import block is not fully visible here.
    import functools

    @functools.wraps(node)
    def wrapper(state: Union[Dict, Any], *args: Any, **kwargs: Any):
        # State may be a plain dict (e.g. a Send payload) or an object.
        if isinstance(state, dict):
            tid = state.get("trace_id")
        else:
            tid = getattr(state, "trace_id", None)

        if not tid:
            return node(state, *args, **kwargs)

        with trace_context(tid):
            return node(state, *args, **kwargs)

    return wrapper
39+
40+
2341
def build_graph(
2442
registry: DatasourceRegistry,
2543
llm_registry: LLMRegistry,
@@ -112,12 +130,12 @@ def report_missing_datasource(state: Dict):
112130
"intermediate_results": [message],
113131
}
114132

115-
graph.add_node("semantic_analysis", semantic_node)
116-
graph.add_node("intent_validator", intent_validator_node)
117-
graph.add_node("decomposer", decomposer_node)
118-
graph.add_node("execution_branch", execution_wrapper)
119-
graph.add_node("report_missing_datasource", report_missing_datasource)
120-
graph.add_node("aggregator", aggregator_node)
133+
graph.add_node("semantic_analysis", traced_node(semantic_node))
134+
graph.add_node("intent_validator", traced_node(intent_validator_node))
135+
graph.add_node("decomposer", traced_node(decomposer_node))
136+
graph.add_node("execution_branch", traced_node(execution_wrapper))
137+
graph.add_node("report_missing_datasource", traced_node(report_missing_datasource))
138+
graph.add_node("aggregator", traced_node(aggregator_node))
121139

122140
graph.set_entry_point("semantic_analysis")
123141

@@ -140,6 +158,7 @@ def continue_to_subqueries(state: GraphState):
140158
branches.append(Send("report_missing_datasource", {"user_query": sq.query}))
141159

142160
payload = {
161+
"trace_id": state.trace_id,
143162
"user_query": sq.query,
144163
"selected_datasource_id": sq.datasource_id,
145164
"complexity": sq.complexity,

packages/core/src/nl2sql/pipeline/nodes/intent_validator/prompts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
1818
[OUTPUT SCHEMA]
1919
Return a JSON object matching `IntentValidationResult`:
20-
{
20+
{{
2121
"is_safe": boolean,
2222
"violation_category": "jailbreak" | "pii_exfiltration" | "destructive" | "system_probing" | "none",
2323
"reasoning": "string"
24-
}
24+
}}
2525
2626
[USER_QUERY]
2727
{user_query}

0 commit comments

Comments
 (0)