Skip to content

Commit 356925e

Browse files
fixing my ablations...
1 parent b1a848f commit 356925e

10 files changed

Lines changed: 39 additions & 20 deletions

File tree

ablation_studies/experiment_contract_light.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ batch_size: 64
1111
optimizer: AdamW
1212
lr: 3e-4
1313
weight_decay: 1e-2
14-
epochs: 100
14+
epochs: 500
1515
local_lr_eta_local: 0.05
1616
surrogate_slope_k: 10
1717
spike_energy_pJ: 5.0
@@ -32,4 +32,4 @@ cql:
3232
hidden_size: 256
3333
with_lagrange: false
3434
cql_weight: 1.0
35-
target_action_gap: 10.0
35+
target_action_gap: 10.0

ablation_studies/run_experiment.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
# --- Add Project Root to sys.path ---
2020
project_root = Path(__file__).resolve().parent.parent
2121
sys.path.append(str(project_root))
22+
# Add snn-dt/src to sys.path to allow for model imports
23+
snn_dt_src_path = project_root / 'snn-dt' / 'src'
24+
if snn_dt_src_path.exists():
25+
sys.path.insert(0, str(snn_dt_src_path))
2226

2327
# --- Local Imports ---
2428
from ablation_studies.src.datasets import OfflineSequenceDataset, OfflineTransitionDataset
@@ -52,10 +56,10 @@ def load_config(contract_path, variant_path):
5256
# --- Model Factory ---
5357
def get_model(cfg):
5458
model_name_map = {
55-
'dt': ('snn_dt.src.models.dt', 'DecisionTransformer'),
56-
'snn_dt': ('snn_dt.src.models.snn_dt', 'SnnDt'),
57-
'iql': ('snn_dt.src.models.iql', 'IQL'),
58-
'cql': ('snn_dt.src.models.cql', 'CQL'),
59+
'dt': ('models.dt', 'DecisionTransformer'),
60+
'snn_dt': ('models.snn_dt', 'SnnDt'),
61+
'iql': ('models.iql', 'IQL'),
62+
'cql': ('models.cql', 'CQL'),
5963
'ablation_dsformer': ('ablation_studies.src.models.ablation_dsformer', 'AblationDsFormer'),
6064
}
6165

@@ -115,9 +119,9 @@ def evaluate_policy(model, env_name, cfg):
115119
action = int(np.argmax(action))
116120

117121
state, reward, terminated, truncated, _ = env.step(action)
118-
done = terminated or truncated or (t + 1 >= cfg.dataset.max_timesteps)
122+
done = terminated or truncated or (t >= cfg.sequence_length_N - 1)
119123

120-
if t < cfg.sequence_length_N - 1:
124+
if not done:
121125
actions[0, t+1] = torch.tensor(action, device=cfg.device)
122126
states[0, t+1] = torch.from_numpy(state).to(cfg.device)
123127
rtgs[0, t+1] = rtgs[0, t] - reward
@@ -160,14 +164,18 @@ def train(cfg, logger):
160164
dataset_args = {'path': str(dataset_path)}
161165
if not is_transition_model: dataset_args['seq_len'] = cfg.sequence_length_N
162166
dataset = DatasetClass(**dataset_args)
163-
train_loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=min(os.cpu_count(), 4))
167+
# NOTE: num_workers is set to 0 to avoid a hanging issue with multiprocessing.
168+
train_loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0)
164169

165170
model = get_model(cfg).to(cfg.device)
166171
optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay)) if list(model.parameters()) else None
167172
loss_fn = torch.nn.MSELoss()
168173

169174
logger.info(json.dumps({"train/param_count": sum(p.numel() for p in model.parameters())}))
170175

176+
# Create save directory if it doesn't exist
177+
Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
178+
171179
for epoch in range(1, cfg.epochs + 1):
172180
model.train()
173181
for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.epochs}", file=sys.stderr)):
Binary file not shown.
Binary file not shown.
Binary file not shown.

ablation_studies/runs/no_plasticity/seed_0/CartPole-v1/metrics.jsonl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,15 @@
77
{"train/param_count": 399889}
88
{"train/step": 250, "train/loss": 0.1063467487692833}
99
{"train/step": 500, "train/loss": 0.08335910737514496}
10+
{"train/param_count": 399889}
11+
{"train/step": 250, "train/loss": 0.1063467487692833}
12+
{"train/step": 500, "train/loss": 0.08335910737514496}
13+
{"epoch": 10, "val/mean_return": 20.0, "val/std_return": 0.0}
14+
{"train/step": 750, "train/loss": 0.08903709799051285}
15+
{"train/step": 1000, "train/loss": 0.06532430648803711}
16+
{"train/step": 1250, "train/loss": 0.07413662225008011}
17+
{"epoch": 20, "val/mean_return": 20.0, "val/std_return": 0.0}
18+
{"train/step": 1500, "train/loss": 0.06820853054523468}
19+
{"train/step": 1750, "train/loss": 0.06684345006942749}
20+
{"train/step": 2000, "train/loss": 0.06681139767169952}
21+
{"epoch": 30, "val/mean_return": 20.0, "val/std_return": 0.0}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
Command: D:\Github\neuromorphic_decision_transformer\ablation_studies\run_experiment.py --variant no_plasticity --env CartPole-v1 --seed 0 --contract experiment_contract_light.yaml
2-
Git Hash: fd20c7baafb79138e1195b4cb4a1a62d3ace4f60
2+
Git Hash: b1a848ff72b193feb2a523af890fd78e7a644a82

ablation_studies/scripts/run_ablations.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,19 @@ def main():
4040
"--contract", CONTRACT
4141
]
4242

43-
print(f"[{current_job}/{total_jobs}] Running: Variant={variant}, Env={env}, Seed={seed}")
43+
print(f"\n--- [{current_job}/{total_jobs}] Running: Variant={variant}, Env={env}, Seed={seed} ---")
4444

4545
if args.dry_run:
46-
print(f"Command: {' '.join(cmd)}")
46+
print(f" Command: {' '.join(cmd)}")
4747
else:
4848
try:
4949
subprocess.run(cmd, check=True)
50+
print(f"--- Finished: Variant={variant}, Env={env}, Seed={seed} (Success) ---")
5051
except subprocess.CalledProcessError as e:
51-
print(f"Error running job: {e}")
52-
# Depending on preference, we might want to continue or stop.
53-
# For now, let's continue to the next one but log the error.
54-
print("Continuing to next job...")
52+
print(f" Error running job: {e}")
53+
print(f"--- Finished: Variant={variant}, Env={env}, Seed={seed} (Failed) ---")
5554

56-
print("--- All targeted experimental runs complete! ---")
55+
print("\n--- All targeted experimental runs complete! ---")
5756

5857
if __name__ == "__main__":
5958
main()

ablation_studies/src/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ class OfflineTransitionDataset(Dataset):
3232
Dataset for transition-based models like IQL and CQL.
3333
Processes trajectories into individual (s, a, r, s', d) transitions.
3434
"""
35-
def __init__(self, dataset_path):
36-
data = np.load(dataset_path, mmap_mode='r')
35+
def __init__(self, path):
36+
data = np.load(path, mmap_mode='r')
3737

3838
# Calculate total number of transitions
3939
total_transitions = int(np.sum(data['mask'])) - data['mask'].shape[0]

ablation_studies/src/models/ablation_dsformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def forward(self, batch):
138138

139139
state_embed = self.embed_state(batch["states"])
140140
action_embed = self.embed_action(batch["actions"])
141-
rtg_embed = self.embed_return(batch["returns_to_go"])
141+
rtg_embed = self.embed_return(batch["returns_to_go"].float())
142142
time_embed = self.embed_timestep(batch["timesteps"].squeeze(-1))
143143

144144
state_embed, action_embed, rtg_embed = state_embed + time_embed, action_embed + time_embed, rtg_embed + time_embed

0 commit comments

Comments
 (0)