@@ -1168,6 +1168,67 @@ def format_comparison_report(scdl_result, anndata_result, input_path: str, sampl
11681168 return "\n " .join (report )
11691169
11701170
def average_benchmark_results(
    results: List[BenchmarkResult], averaged_name: str | None = None
) -> BenchmarkResult:
    """Collapse several benchmark runs into one result holding their mean values.

    Args:
        results: Results collected from repeated runs of a benchmark.
        averaged_name: Name given to the combined result. When it is omitted
            (or empty), the first result's name followed by
            " (avg of N runs)" is used instead.

    Returns:
        A newly constructed BenchmarkResult carrying the averaged fields.
        When exactly one result is supplied, that same object is returned
        untouched and ``averaged_name`` is not applied.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("Cannot average empty list of results")

    run_count = len(results)
    if run_count == 1:
        # Nothing to average; hand back the original object as-is.
        return results[0]

    first = results[0]

    def mean_of(attr):
        """Arithmetic mean of one attribute across every run."""
        return sum(getattr(run, attr) for run in results) / run_count

    float_attrs = (
        "disk_size_mb",
        "setup_time_seconds",
        "warmup_time_seconds",
        "total_iteration_time_seconds",
        "average_batch_time_seconds",
        "samples_per_second",
        "batches_per_second",
        "peak_memory_mb",
        "average_memory_mb",
        "gpu_memory_mb",
    )
    count_attrs = ("total_batches", "total_samples", "warmup_samples", "warmup_batches")
    nullable_attrs = (
        "instantiation_time_seconds",
        "peak_memory_during_instantiation_mb",
        "memory_before_instantiation_mb",
        "memory_after_instantiation_mb",
        "conversion_time_seconds",
        "load_time_seconds",
    )
    inherited_attrs = (
        "data_path",
        "max_time_seconds",
        "shuffle",
        "madvise_interval",
        "conversion_performed",
        "load_performed",
    )

    averaged = {"name": averaged_name or f"{first.name} (avg of {run_count} runs)"}

    # Plain floating-point means.
    averaged.update((attr, mean_of(attr)) for attr in float_attrs)

    # Counts remain integers; int() truncates the mean toward zero.
    averaged.update((attr, int(mean_of(attr))) for attr in count_attrs)

    # A nullable measurement is averaged only when every run recorded it;
    # otherwise the key is left out of the constructor arguments entirely.
    for attr in nullable_attrs:
        samples = [getattr(run, attr) for run in results]
        if all(sample is not None for sample in samples):
            averaged[attr] = sum(samples) / run_count

    # Settings that are not averaged are taken from the first run.
    averaged.update((attr, getattr(first, attr)) for attr in inherited_attrs)

    return BenchmarkResult(**averaged)
1230+
1231+
11711232def main ():
11721233 """Main function to run the benchmark."""
11731234 parser = argparse .ArgumentParser (
@@ -1182,6 +1243,8 @@ def main():
11821243 python scdl_speedtest.py --csv # Export detailed CSV files
11831244 python scdl_speedtest.py --json results.json # Export detailed JSON file
11841245 python scdl_speedtest.py --generate-baseline # Compare SCDL vs AnnData performance
1246+ python scdl_speedtest.py --num-runs 3 # Run 3 iterations and average results
1247+ python scdl_speedtest.py --num-runs 5 --csv # Average 5 runs and export CSV
11851248 """ ,
11861249 )
11871250
@@ -1205,9 +1268,15 @@ def main():
12051268 parser .add_argument ("--scdl-path" , type = str , help = "Path to SCDL dataset (default: None)" )
12061269
12071270 parser .add_argument ("--num-epochs" , type = int , default = 1 , help = "Number of epochs (default: 1)" )
1271+ parser .add_argument ("--num-runs" , type = int , default = 1 , help = "Number of benchmark runs to average (default: 1)" )
12081272
12091273 args = parser .parse_args ()
12101274
1275+ # Validate num_runs parameter
1276+ if args .num_runs < 1 :
1277+ print ("Error: --num-runs must be a positive integer" )
1278+ sys .exit (1 )
1279+
12111280 # Check if baseline generation is requested
12121281 if args .generate_baseline :
12131282 if not ANNDATA_AVAILABLE :
@@ -1246,14 +1315,32 @@ def main():
12461315 print ("Exiting..." )
12471316 sys .exit (1 )
12481317
def run_single_benchmark(name, factory, data_path, run_num=None):
    """Execute one benchmark pass, announcing its run number in multi-run sessions.

    Reads ``args`` from the enclosing scope for the run count, epoch count,
    time limit and warmup time.

    Args:
        name: Benchmark name forwarded to ``benchmark_dataloader`` and shown
            in the progress banner.
        factory: Dataloader factory forwarded as ``dataloader_factory``.
        data_path: Dataset path forwarded as ``data_path``.
        run_num: Run number for the banner. No banner is printed when this is
            falsy or when ``args.num_runs`` is not greater than 1.

    Returns:
        Whatever ``benchmark_dataloader`` returns for this pass.
    """
    # NOTE(review): literal spacing in the two f-strings below is copied from
    # the extracted diff text - confirm against the real file.
    if args.num_runs > 1 and run_num:
        run_suffix = f" (run {run_num} /{args.num_runs} )"
        print(f"\n --- Running {name} {run_suffix} ---")

    benchmark_options = {
        "name": name,
        "dataloader_factory": factory,
        "data_path": data_path,
        "num_epochs": args.num_epochs,
        "max_time_seconds": args.max_time,
        "warmup_time_seconds": args.warmup_time,
        "print_progress": True,
    }
    return benchmark_dataloader(**benchmark_options)
1333+
12491334 try :
12501335 if args .generate_baseline :
12511336 # Run comparison benchmark
12521337 print (f"\n Running SCDL vs AnnData comparison: { Path (input_path ).name } " )
12531338 print (f"Sampling: { args .sampling_scheme } " )
1339+ if args .num_runs > 1 :
1340+ print (f"Number of runs: { args .num_runs } (results will be averaged)" )
12541341 print ("This will benchmark both SCDL and AnnData approaches...\n " )
12551342
1256- # Run SCDL benchmark
1343+ # Run SCDL benchmark(s)
12571344 print ("=== Running SCDL Benchmark ===" )
12581345 if args .scdl_path :
12591346 scdl_path = args .scdl_path
@@ -1263,32 +1350,46 @@ def main():
12631350 str (scdl_path ), args .sampling_scheme , args .batch_size , use_anndata = False
12641351 )
12651352
1266- scdl_result = benchmark_dataloader (
1267- name = f"SCDL-{ args .sampling_scheme } " ,
1268- dataloader_factory = scdl_factory ,
1269- data_path = str (scdl_path ),
1270- num_epochs = args .num_epochs ,
1271- max_time_seconds = args .max_time ,
1272- warmup_time_seconds = args .warmup_time ,
1273- print_progress = True ,
1274- )
1353+ scdl_results = []
1354+ for run_num in range (1 , args .num_runs + 1 ):
1355+ result = run_single_benchmark (
1356+ f"SCDL-{ args .sampling_scheme } " ,
1357+ scdl_factory ,
1358+ str (scdl_path ),
1359+ run_num if args .num_runs > 1 else None ,
1360+ )
1361+ scdl_results .append (result )
1362+
1363+ # Average SCDL results if multiple runs
1364+ if args .num_runs > 1 :
1365+ scdl_result = average_benchmark_results (scdl_results , f"SCDL-{ args .sampling_scheme } " )
1366+ print (f"\n SCDL benchmark completed: averaged { len (scdl_results )} runs" )
1367+ else :
1368+ scdl_result = scdl_results [0 ]
12751369
1276- # Run AnnData benchmark
1370+ # Run AnnData benchmark(s)
12771371 adata_path = input_path
12781372 print ("\n === Running AnnData Benchmark ===" )
12791373 anndata_factory = create_dataloader_factory (
12801374 str (adata_path ), args .sampling_scheme , args .batch_size , use_anndata = True
12811375 )
12821376
1283- anndata_result = benchmark_dataloader (
1284- name = f"AnnData-{ args .sampling_scheme } " ,
1285- dataloader_factory = anndata_factory ,
1286- data_path = str (adata_path ),
1287- num_epochs = args .num_epochs ,
1288- max_time_seconds = args .max_time ,
1289- warmup_time_seconds = args .warmup_time ,
1290- print_progress = True ,
1291- )
1377+ anndata_results = []
1378+ for run_num in range (1 , args .num_runs + 1 ):
1379+ result = run_single_benchmark (
1380+ f"AnnData-{ args .sampling_scheme } " ,
1381+ anndata_factory ,
1382+ str (adata_path ),
1383+ run_num if args .num_runs > 1 else None ,
1384+ )
1385+ anndata_results .append (result )
1386+
1387+ # Average AnnData results if multiple runs
1388+ if args .num_runs > 1 :
1389+ anndata_result = average_benchmark_results (anndata_results , f"AnnData-{ args .sampling_scheme } " )
1390+ print (f"\n AnnData benchmark completed: averaged { len (anndata_results )} runs" )
1391+ else :
1392+ anndata_result = anndata_results [0 ]
12921393
12931394 # Format and output comparison report
12941395 comparison_report = format_comparison_report (
@@ -1323,17 +1424,23 @@ def main():
13231424
13241425 print (f"\n Benchmarking: { Path (input_path ).name } " )
13251426 print (f"Sampling: { args .sampling_scheme } " )
1427+ if args .num_runs > 1 :
1428+ print (f"Number of runs: { args .num_runs } (results will be averaged)" )
13261429 print ("Running benchmark...\n " )
13271430
1328- result = benchmark_dataloader (
1329- name = f"SCDL-{ args .sampling_scheme } " ,
1330- dataloader_factory = factory ,
1331- data_path = str (input_path ),
1332- num_epochs = args .num_epochs ,
1333- max_time_seconds = args .max_time ,
1334- warmup_time_seconds = args .warmup_time ,
1335- print_progress = True ,
1336- )
1431+ results = []
1432+ for run_num in range (1 , args .num_runs + 1 ):
1433+ single_result = run_single_benchmark (
1434+ f"SCDL-{ args .sampling_scheme } " , factory , str (input_path ), run_num if args .num_runs > 1 else None
1435+ )
1436+ results .append (single_result )
1437+
1438+ # Average results if multiple runs
1439+ if args .num_runs > 1 :
1440+ result = average_benchmark_results (results , f"SCDL-{ args .sampling_scheme } " )
1441+ print (f"\n Benchmark completed: averaged { len (results )} runs" )
1442+ else :
1443+ result = results [0 ]
13371444
13381445 # Format and output report
13391446 report = format_report (result , str (input_path ), args .sampling_scheme )
0 commit comments