11import sys
2- import argparse
32import statistics
43import yaml
54from nl2sql .services .llm import parse_llm_config , LLMRegistry , get_usage_summary
98from nl2sql .pipeline .graph import run_with_graph
109from nl2sql .evaluation .evaluator import ModelEvaluator
1110from nl2sql .reporting import ConsolePresenter
11+ from nl2sql_cli .types import BenchmarkConfig
1212
1313
1414def run_benchmark (
15- args : argparse . Namespace ,
15+ config : BenchmarkConfig ,
1616 datasource_registry : DatasourceRegistry ,
1717 vector_store : OrchestratorVectorStore
1818) -> None :
1919 """Runs the benchmark suite based on provided arguments.
2020
2121 Args:
22- args (argparse.Namespace ): Command-line arguments .
22+ config (BenchmarkConfig ): Benchmark run configuration .
2323 datasource_registry (DatasourceRegistry): Datasource registry.
2424 vector_store (OrchestratorVectorStore): Vector store instance.
2525 """
@@ -28,9 +28,9 @@ def run_benchmark(
2828 # Matrix Benchmarking
2929 llm_configs = {}
3030
31- if args . bench_config and args . bench_config .exists ():
31+ if config . bench_config_path and config . bench_config_path .exists ():
3232 try :
33- bench_data = yaml .safe_load (args . bench_config .read_text ()) or {}
33+ bench_data = yaml .safe_load (config . bench_config_path .read_text ()) or {}
3434 for name , cfg_data in bench_data .items ():
3535 if isinstance (cfg_data , dict ):
3636 llm_configs [name ] = parse_llm_config (cfg_data )
@@ -40,11 +40,11 @@ def run_benchmark(
4040
4141 if not llm_configs :
4242 llm_cfg = parse_llm_config ({"default" : {"provider" : "openai" , "model" : "gpt-4o" }}) # Fallback
43- if args . llm_config and args . llm_config .exists ():
43+ if config . llm_config_path and config . llm_config_path .exists ():
4444 from nl2sql .services .llm import load_llm_config
45- llm_cfg = load_llm_config (args . llm_config )
45+ llm_cfg = load_llm_config (config . llm_config_path )
4646
47- if getattr ( args , " stub_llm" , False ) :
47+ if config . stub_llm :
4848 llm_cfg .default .provider = "stub"
4949 for agent_cfg in llm_cfg .agents .values ():
5050 agent_cfg .provider = "stub"
@@ -55,7 +55,7 @@ def run_benchmark(
5555 for name , llm_cfg in llm_configs .items ():
5656 llm_registry = LLMRegistry (llm_cfg )
5757 _run_dataset_evaluation (
58- args ,
58+ config ,
5959 datasource_registry ,
6060 vector_store ,
6161 llm_registry ,
@@ -64,7 +64,7 @@ def run_benchmark(
6464
6565
6666def _run_dataset_evaluation (
67- args : argparse . Namespace ,
67+ config : BenchmarkConfig ,
6868 datasource_registry : DatasourceRegistry ,
6969 vector_store : OrchestratorVectorStore ,
7070 llm_registry : LLMRegistry ,
@@ -73,7 +73,7 @@ def _run_dataset_evaluation(
7373 """Runs evaluation against a golden dataset for a specific config.
7474
7575 Args:
76- args (argparse.Namespace ): CLI arguments.
76+ config (BenchmarkConfig ): Benchmark run configuration.
7777 datasource_registry (DatasourceRegistry): Registry of datasources.
7878 vector_store (OrchestratorVectorStore): Vector store.
7979 llm_registry (LLMRegistry): LLM registry.
@@ -84,7 +84,7 @@ def _run_dataset_evaluation(
8484
8585 presenter = ConsolePresenter ()
8686
87- dataset_path = args . dataset
87+ dataset_path = config . dataset_path
8888 if not dataset_path .exists ():
8989 presenter .print_error (f"Dataset file not found: { dataset_path } " )
9090 sys .exit (1 )
@@ -99,10 +99,10 @@ def _run_dataset_evaluation(
9999 presenter .print_error ("Dataset must be a list of test cases." )
100100 sys .exit (1 )
101101
102- if args .include_ids :
103- dataset = [item for item in dataset if item .get ("id" ) in args .include_ids ]
102+ if config .include_ids :
103+ dataset = [item for item in dataset if item .get ("id" ) in config .include_ids ]
104104 if not dataset :
105- presenter .print_error (f"No test cases found matching IDs: { args .include_ids } " )
105+ presenter .print_error (f"No test cases found matching IDs: { config .include_ids } " )
106106 sys .exit (1 )
107107
108108 presenter .print_header (f"Evaluating Config: { config_name } " )
@@ -122,9 +122,9 @@ def _evaluate_case(item: dict) -> dict:
122122 llm_registry = llm_registry ,
123123 user_query = question ,
124124 datasource_id = None ,
125- execute = not args .routing_only ,
125+ execute = not config .routing_only ,
126126 vector_store = vector_store ,
127- vector_store_path = args . vector_store
127+ vector_store_path = config . vector_store_path
128128 )
129129 except Exception as e :
130130 return {
@@ -181,7 +181,7 @@ def get_val(obj, key, default=None):
181181
182182 layer_match = (routing_layer == expected_layer )
183183
184- if args .routing_only :
184+ if config .routing_only :
185185 return {
186186 "id" : q_id ,
187187 "question" : question ,
@@ -329,7 +329,7 @@ def get_val(obj, key, default=None):
329329
330330 import concurrent .futures
331331 workers = 5
332- iterations = args .iterations if args .iterations else 1
332+ iterations = config .iterations if config .iterations else 1
333333
334334 with concurrent .futures .ThreadPoolExecutor (max_workers = workers ) as executor :
335335 futures = []
@@ -345,11 +345,11 @@ def get_val(obj, key, default=None):
345345 results .sort (key = lambda x : x ["id" ])
346346
347347 # Delegate Reporting
348- presenter .print_dataset_benchmark_results (results , iterations = iterations , routing_only = args .routing_only )
348+ presenter .print_dataset_benchmark_results (results , iterations = iterations , routing_only = config .routing_only )
349349
350350 # Calculate Metrics and Print Summary
351351 metrics = ModelEvaluator .calculate_aggregate_metrics (results , len (results ))
352- presenter .print_metrics_summary (metrics , results , routing_only = args .routing_only )
352+ presenter .print_metrics_summary (metrics , results , routing_only = config .routing_only )
353353
354- if args .export_path :
355- presenter .export_results (results , args .export_path )
354+ if config .export_path :
355+ presenter .export_results (results , config .export_path )
0 commit comments