Skip to content

Commit 7199bb1

Browse files
Update run_phase0_experiments.py
1 parent fece993 commit 7199bb1

1 file changed

Lines changed: 49 additions & 28 deletions

File tree

scripts/run_phase0_experiments.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import subprocess
44
import time
55
import yaml
6+
import argparse
7+
import concurrent.futures
68
from pathlib import Path
79
import logging
810

@@ -21,20 +23,8 @@ def is_run_complete(save_dir):
2123
# Check for metrics.csv or a specific completion flag
2224
return (save_dir / "metrics.csv").exists()
2325

24-
def main():
25-
configs = get_configs()
26-
logger.info(f"Found {len(configs)} experiments to run.")
27-
28-
# Verify Data Exists
29-
# We expect data/d4rl/{env}/dataset_v1.npz
30-
# We can check the first config to see what env it needs, but roughly:
31-
if not DATA_DIR.exists():
32-
logger.error(f"Data directory {DATA_DIR} does not exist. Please run convert_d4rl.py first.")
33-
# We could try to run conversion here automatically?
34-
# Let's assume the user/previous step handled it, or alert.
35-
pass
36-
37-
for config_path in configs:
26+
def run_experiment(config_path, worker_id=0):
27+
try:
3828
with open(config_path, 'r') as f:
3929
cfg = yaml.safe_load(f)
4030

@@ -49,9 +39,9 @@ def main():
4939

5040
if is_run_complete(save_dir):
5141
logger.info(f"Skipping {model} on {env} (Run Complete)")
52-
continue
42+
return
5343

54-
logger.info(f"Starting {model} on {env}...")
44+
logger.info(f"Worker {worker_id}: Starting {model} on {env}...")
5545

5646
# Construct Command
5747
cmd = [
@@ -61,21 +51,52 @@ def main():
6151
"--env", env,
6252
"--model", model,
6353
"--seed", str(seed),
64-
# Add dataset path explicitly if needed, but train.py infers it.
65-
# train.py infers: project_root / f"data/{args.env}/dataset.npz"
66-
# Our convert script puts it in: data/d4rl/{env}/dataset_v1.npz
67-
# This is a MISMATCH. We need to point train.py to the right place.
6854
"--dataset-path", str(DATA_DIR / env / "dataset_v1.npz")
6955
]
7056

71-
try:
72-
# Run Synchronously for now
73-
subprocess.run(cmd, check=True)
74-
logger.info(f"Finished {model} on {env}")
75-
except subprocess.CalledProcessError as e:
76-
logger.error(f"Failed {model} on {env}: {e}")
77-
# Continue to next experiment?
78-
time.sleep(1)
57+
# Set environment variables for this process to limit CPU usage
58+
env_vars = os.environ.copy()
59+
# Limit threads per process to avoid thrashing
60+
# Assuming 3 workers on a typical 12+ core machine, 4 threads each is safe.
61+
# If user has fewer cores, they should reduce max-workers.
62+
env_vars["OMP_NUM_THREADS"] = "4"
63+
env_vars["MKL_NUM_THREADS"] = "4"
64+
env_vars["TORCH_NUM_THREADS"] = "4"
65+
66+
# Run Synchronously (within the worker thread)
67+
subprocess.run(cmd, check=True, env=env_vars)
68+
logger.info(f"Worker {worker_id}: Finished {model} on {env}")
69+
70+
except subprocess.CalledProcessError as e:
71+
logger.error(f"Worker {worker_id}: Failed {model} on {env}: {e}")
72+
except Exception as e:
73+
logger.error(f"Worker {worker_id}: Error processing {config_path}: {e}")
74+
75+
def main():
    """Run every Phase 0 experiment config, several at a time.

    Parses ``--max-workers`` from the command line, collects the experiment
    configs via ``get_configs()``, and submits one ``run_experiment`` call per
    config to a thread pool. Returns early, after logging an error, when
    ``DATA_DIR`` does not exist.

    Raises:
        SystemExit: if the command line is invalid, including a
            ``--max-workers`` value below 1.
    """
    parser = argparse.ArgumentParser(description="Run Phase 0 Experiments")
    parser.add_argument("--max-workers", type=int, default=3, help="Number of parallel experiments to run")
    args = parser.parse_args()

    # ThreadPoolExecutor raises a bare ValueError for max_workers <= 0;
    # reject it here so the user gets a usage message instead of a traceback.
    if args.max_workers < 1:
        parser.error("--max-workers must be at least 1")

    configs = get_configs()
    logger.info(f"Found {len(configs)} experiments to run.")

    if not DATA_DIR.exists():
        logger.error(f"Data directory {DATA_DIR} does not exist. Please run convert_d4rl.py first.")
        return

    # Each submitted task is one run_experiment call executed on a pool thread.
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # i % args.max_workers is only a rough label for log lines; it does
        # not identify the pool thread that actually picks up the task.
        futures = [
            executor.submit(run_experiment, config_path, i % args.max_workers)
            for i, config_path in enumerate(configs)
        ]

        # Drain the futures as they finish so an exception that escapes a
        # task is logged rather than silently dropped.
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                logger.error(f"An experiment failed with exception: {e}")
79100

80101
if __name__ == "__main__":
81102
main()

0 commit comments

Comments
 (0)