Skip to content

Commit 032b570

Browse files
updated ablation studies
1 parent ad692ea commit 032b570

9 files changed

Lines changed: 199 additions & 185 deletions

File tree

ablation_studies/all_commands.txt

Lines changed: 37 additions & 160 deletions
Large diffs are not rendered by default.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
phase_encoder:
2+
enabled: true
3+
routing:
4+
enabled: true
5+
local_plasticity:
6+
enabled: true

ablation_studies/logs/ablation_studies_dry_run.log renamed to ablation_studies/random_logs/ablation_studies_dry_run.log

File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import argparse
2+
import subprocess
3+
import sys
4+
import numpy as np
5+
import json
6+
from pathlib import Path
7+
8+
# --- Configuration ---
9+
DEFAULT_SEEDS = 5
10+
DEFAULT_CONTRACT = "experiment_contract_light.yaml"
11+
12+
def run_single_seed(variant, env, seed, contract, device=None):
13+
"""
14+
Runs a single experiment seed using run_experiment.py via subprocess.
15+
Returns the result dictionary (or None if failed).
16+
"""
17+
cmd = [
18+
sys.executable,
19+
"ablation_studies/run_experiment.py",
20+
"--variant", variant,
21+
"--env", env,
22+
"--seed", str(seed),
23+
"--contract", contract
24+
]
25+
26+
print(f" > Starting Seed {seed}...")
27+
try:
28+
# Run the command and capture output
29+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
30+
31+
# Parse the output to find the final metrics (jsonl or stdout)
32+
# We assume run_experiment.py logs the final metrics in a way we can grab,
33+
# but since it saves to runs/.../metrics.jsonl, we can also read that file.
34+
35+
return True
36+
37+
except subprocess.CalledProcessError as e:
38+
print(f" !!! Error running seed {seed} !!!")
39+
print(e.stderr)
40+
return False
41+
42+
def get_run_metrics(variant, env, seed):
43+
"""
44+
Reads the metrics.jsonl file for a specific run to get the final performance.
45+
"""
46+
# Structure: runs/{variant}/seed_{seed}/{env}/metrics.jsonl
47+
# Note: run_experiment.py logic for run_name:
48+
# run_name = cfg.model.name if cfg.model.name != 'ablation_dsformer' else args.variant
49+
# This might need some adjustment if the variant name != model name mapping is complex.
50+
# Based on run_experiment.py:
51+
# model_name = cfg.get('model', {}).get('name', args.variant if args.variant in ['dt', 'snn_dt', 'iql', 'cql'] else 'ablation_dsformer')
52+
# run_name = cfg.model.name if cfg.model.name != 'ablation_dsformer' else args.variant
53+
54+
# We will try to reconstruct the path.
55+
project_root = Path(__file__).parent
56+
57+
# Determine directory name based on variant logic from run_experiment.py
58+
# If variant is simple, dir is variant. If dsformer, it's the variant name.
59+
# To be safe, we check both possible paths.
60+
61+
possible_run_names = [variant]
62+
# Add mapped names if necessary, but 'snn_dt', 'iql' etc map to themselves usually unless configured otherwise.
63+
64+
metrics_file = None
65+
for r_name in possible_run_names:
66+
p = project_root / "runs" / r_name / f"seed_{seed}" / env / "metrics.jsonl"
67+
if p.exists():
68+
metrics_file = p
69+
break
70+
71+
if not metrics_file:
72+
# Fallback check for model-based names if variant was just a config name
73+
# E.g. variant 'no_plasticity' might map to model 'ablation_dsformer' -> run_name 'no_plasticity'
74+
# It seems consistent.
75+
print(f" [Warning] Could not find metrics file for {variant} seed {seed}")
76+
return None
77+
78+
final_return = None
79+
try:
80+
with open(metrics_file, 'r') as f:
81+
for line in f:
82+
if not line.strip(): continue
83+
data = json.loads(line)
84+
if 'val/mean_return' in data:
85+
final_return = data['val/mean_return']
86+
except Exception as e:
87+
print(f" [Error] Reading metrics file: {e}")
88+
89+
return final_return
90+
91+
def main():
92+
parser = argparse.ArgumentParser(description="Run a group of ablation experiments (multiple seeds) and report Mean +/- Std.")
93+
parser.add_argument("--variant", required=True, help="Experiment variant (e.g., snn_dt, no_plasticity)")
94+
parser.add_argument("--env", required=True, help="Environment (e.g., CartPole-v1)")
95+
parser.add_argument("--num_seeds", type=int, default=DEFAULT_SEEDS, help="Number of seeds to run (0 to N-1)")
96+
parser.add_argument("--contract", default=DEFAULT_CONTRACT, help="Experiment contract YAML")
97+
98+
args = parser.parse_args()
99+
100+
print(f"\n=======================================================")
101+
print(f" Running Ablation Group: {args.variant} | {args.env}")
102+
print(f" Seeds: 0 to {args.num_seeds - 1}")
103+
print(f"=======================================================\n")
104+
105+
returns = []
106+
107+
for seed in range(args.num_seeds):
108+
success = run_single_seed(args.variant, args.env, seed, args.contract)
109+
if success:
110+
val_return = get_run_metrics(args.variant, args.env, seed)
111+
if val_return is not None:
112+
returns.append(val_return)
113+
print(f" > Seed {seed} Finished. Return: {val_return:.2f}")
114+
else:
115+
print(f" > Seed {seed} Finished but no return found.")
116+
else:
117+
print(f" > Seed {seed} FAILED.")
118+
119+
print(f"\n=======================================================")
120+
if returns:
121+
mean_ret = np.mean(returns)
122+
std_ret = np.std(returns)
123+
print(f" FINAL RESULT [{args.variant} / {args.env}]:")
124+
print(f" Mean Return: {mean_ret:.2f} ± {std_ret:.2f}")
125+
print(f" (Based on {len(returns)}/{args.num_seeds} successful runs)")
126+
else:
127+
print(f" NO SUCCESSFUL RUNS.")
128+
print(f"=======================================================\n")
129+
130+
if __name__ == "__main__":
131+
main()

ablation_studies/scripts/run_ablations.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55

66
# Configuration
7-
VARIANTS = ["no_plasticity", "no_routing", "no_phase", "dt", "snn_dt", "iql", "cql", "full"]
7+
VARIANTS = ["no_plasticity", "no_routing", "no_phase", "dt", "snn_dt", "iql", "cql", "full", "dsformer"]
88
ENVS = ["CartPole-v1", "Acrobot-v1", "Pendulum-v1", "MountainCar-v0"]
99
SEEDS = [0, 1, 2, 3, 4]
1010
CONTRACT = "experiment_contract_light.yaml"
@@ -24,33 +24,32 @@ def main():
2424
print(f"Contract: {CONTRACT}")
2525
print(f"---------------------------------")
2626

27-
total_jobs = len(VARIANTS) * len(ENVS) * len(SEEDS)
27+
# Remove seed loop, run_grouped_ablation handles it
28+
total_jobs = len(VARIANTS) * len(ENVS)
2829
current_job = 0
2930

3031
for env in ENVS:
3132
for variant in VARIANTS:
32-
for seed in SEEDS:
33-
current_job += 1
34-
cmd = [
35-
sys.executable,
36-
str(run_script),
37-
"--variant", variant,
38-
"--env", env,
39-
"--seed", str(seed),
40-
"--contract", CONTRACT
41-
]
33+
current_job += 1
34+
cmd = [
35+
sys.executable,
36+
"ablation_studies/run_grouped_ablation.py",
37+
"--variant", variant,
38+
"--env", env,
39+
"--contract", CONTRACT
40+
]
4241

43-
print(f"\n--- [{current_job}/{total_jobs}] Running: Variant={variant}, Env={env}, Seed={seed} ---")
44-
45-
if args.dry_run:
46-
print(f" Command: {' '.join(cmd)}")
47-
else:
48-
try:
49-
subprocess.run(cmd, check=True)
50-
print(f"--- Finished: Variant={variant}, Env={env}, Seed={seed} (Success) ---")
51-
except subprocess.CalledProcessError as e:
52-
print(f" Error running job: {e}")
53-
print(f"--- Finished: Variant={variant}, Env={env}, Seed={seed} (Failed) ---")
42+
print(f"\n--- [{current_job}/{total_jobs}] Running: Variant={variant}, Env={env} ---")
43+
44+
if args.dry_run:
45+
print(f" Command: {' '.join(cmd)}")
46+
else:
47+
try:
48+
subprocess.run(cmd, check=True)
49+
print(f"--- Finished: Variant={variant}, Env={env} (Success) ---")
50+
except subprocess.CalledProcessError as e:
51+
print(f" Error running job: {e}")
52+
print(f"--- Finished: Variant={variant}, Env={env} (Failed) ---")
5453

5554
print("\n--- All targeted experimental runs complete! ---")
5655

snn-dt/src/models/iql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(self, cfg):
112112
self.tau = cfg.iql.tau
113113
self.temperature = cfg.iql.temperature
114114
self.expectile = cfg.iql.expectile
115-
self.is_discrete = cfg.dataset.is_discrete
115+
self.is_discrete = 'CartPole' in cfg.env or 'Acrobot' in cfg.env or 'MountainCar' in cfg.env
116116

117117
self.actor = Actor(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
118118
self.critic1 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)

verify_iql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def __getattr__(self, name):
3232
cfg.dataset = AttrDict()
3333
cfg.dataset.act_dim = 2
3434
cfg.dataset.state_dim = 4
35-
cfg.dataset.is_discrete = True
35+
cfg.dataset.state_dim = 4
36+
# cfg.dataset.is_discrete = True # Removed to verify the fix works without this setting
3637

3738
try:
3839
model = IQL(cfg)

0 commit comments

Comments
 (0)