33import subprocess
44import time
55import yaml
6+ import argparse
7+ import concurrent .futures
68from pathlib import Path
79import logging
810
@@ -21,20 +23,8 @@ def is_run_complete(save_dir):
2123 # Check for metrics.csv or a specific completion flag
2224 return (save_dir / "metrics.csv" ).exists ()
2325
24- def main ():
25- configs = get_configs ()
26- logger .info (f"Found { len (configs )} experiments to run." )
27-
28- # Verify Data Exists
29- # We expect data/d4rl/{env}/dataset_v1.npz
30- # We can check the first config to see what env it needs, but roughly:
31- if not DATA_DIR .exists ():
32- logger .error (f"Data directory { DATA_DIR } does not exist. Please run convert_d4rl.py first." )
33- # We could try to run conversion here automatically?
34- # Let's assume the user/previous step handled it, or alert.
35- pass
36-
37- for config_path in configs :
26+ def run_experiment (config_path , worker_id = 0 ):
27+ try :
3828 with open (config_path , 'r' ) as f :
3929 cfg = yaml .safe_load (f )
4030
@@ -49,9 +39,9 @@ def main():
4939
5040 if is_run_complete (save_dir ):
5141 logger .info (f"Skipping { model } on { env } (Run Complete)" )
52- continue
42+ return
5343
54- logger .info (f"Starting { model } on { env } ..." )
44+ logger .info (f"Worker { worker_id } : Starting { model } on { env } ..." )
5545
5646 # Construct Command
5747 cmd = [
@@ -61,21 +51,52 @@ def main():
6151 "--env" , env ,
6252 "--model" , model ,
6353 "--seed" , str (seed ),
64- # Add dataset path explicitly if needed, but train.py infers it.
65- # train.py infers: project_root / f"data/{args.env}/dataset.npz"
66- # Our convert script puts it in: data/d4rl/{env}/dataset_v1.npz
67- # This is a MISMATCH. We need to point train.py to the right place.
6854 "--dataset-path" , str (DATA_DIR / env / "dataset_v1.npz" )
6955 ]
7056
71- try :
72- # Run Synchronously for now
73- subprocess .run (cmd , check = True )
74- logger .info (f"Finished { model } on { env } " )
75- except subprocess .CalledProcessError as e :
76- logger .error (f"Failed { model } on { env } : { e } " )
77- # Continue to next experiment?
78- time .sleep (1 )
57+ # Set environment variables for this process to limit CPU usage
58+ env_vars = os .environ .copy ()
59+ # Limit threads per process to avoid thrashing
60+ # Assuming 3 workers on a typical 12+ core machine, 4 threads each is safe.
61+ # If user has fewer cores, they should reduce max-workers.
62+ env_vars ["OMP_NUM_THREADS" ] = "4"
63+ env_vars ["MKL_NUM_THREADS" ] = "4"
64+ env_vars ["TORCH_NUM_THREADS" ] = "4"
65+
66+ # Run Synchronously (within the worker thread)
67+ subprocess .run (cmd , check = True , env = env_vars )
68+ logger .info (f"Worker { worker_id } : Finished { model } on { env } " )
69+
70+ except subprocess .CalledProcessError as e :
71+ logger .error (f"Worker { worker_id } : Failed { model } on { env } : { e } " )
72+ except Exception as e :
73+ logger .error (f"Worker { worker_id } : Error processing { config_path } : { e } " )
74+
75+ def main ():
76+ parser = argparse .ArgumentParser (description = "Run Phase 0 Experiments" )
77+ parser .add_argument ("--max-workers" , type = int , default = 3 , help = "Number of parallel experiments to run" )
78+ args = parser .parse_args ()
79+
80+ configs = get_configs ()
81+ logger .info (f"Found { len (configs )} experiments to run." )
82+
83+ if not DATA_DIR .exists ():
84+ logger .error (f"Data directory { DATA_DIR } does not exist. Please run convert_d4rl.py first." )
85+ return
86+
87+ # Use ThreadPoolExecutor to run experiments in parallel
88+ with concurrent .futures .ThreadPoolExecutor (max_workers = args .max_workers ) as executor :
89+ futures = []
90+ for i , config_path in enumerate (configs ):
91+ # i % args.max_workers is just a rough worker ID for logging
92+ futures .append (executor .submit (run_experiment , config_path , i % args .max_workers ))
93+
94+ # Wait for all futures to complete
95+ for future in concurrent .futures .as_completed (futures ):
96+ try :
97+ future .result ()
98+ except Exception as e :
99+ logger .error (f"An experiment failed with exception: { e } " )
79100
80101if __name__ == "__main__" :
81102 main ()
0 commit comments