Skip to content

Commit 7145ed6

Browse files
update: train.py
1 parent 584b3cf commit 7145ed6

2 files changed

Lines changed: 16 additions & 10 deletions

File tree

results/all_runs/dsformer_CartPole-v1/training.log

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,11 @@
2222
2025-11-09 14:50:55,110 [INFO] Dataset size: 1000 clips
2323
2025-11-09 14:50:55,139 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
2424
2025-11-09 14:51:05,312 [INFO] Starting training loop...
25+
2025-11-12 09:53:11,160 [INFO] Checking for dataset...
26+
2025-11-12 09:53:11,277 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
27+
2025-11-12 09:53:11,277 [INFO] --- Checkpoint: Starting training ---
28+
2025-11-12 09:53:11,293 [INFO] --- Checkpoint: Save directory created at results\all_runs\dsformer_CartPole-v1 ---
29+
2025-11-12 09:53:11,457 [INFO] Dataset size: 1000 clips
30+
2025-11-12 09:53:11,499 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
31+
2025-11-12 09:53:11,556 [INFO] --- Checkpoint: Model 'dsformer' initialized on device 'cpu' ---
32+
2025-11-12 09:53:24,560 [INFO] --- Checkpoint: Starting main training loop ---

snn-dt/scripts/train.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,15 @@ def train(cfg, logger):
285285
df.to_csv(save_dir / "metrics.csv", index=False)
286286

287287
# Update summary
288-
summary_path = Path(cfg.save_dir).parent.parent / "summary.csv"
289-
summary_df = pd.DataFrame([{"model": cfg.model.name, "env": cfg.env, "seed": cfg.seed, "return_mean": df["return_mean"].max()}])
290-
if summary_path.exists():
291-
summary_df.to_csv(summary_path, mode="a", header=False, index=False)
288+
if not df.empty and "return_mean" in df.columns:
289+
summary_path = Path(cfg.save_dir).parent.parent / "summary.csv"
290+
summary_df = pd.DataFrame([{"model": cfg.model.name, "env": cfg.env, "seed": cfg.seed, "return_mean": df["return_mean"].max()}])
291+
if summary_path.exists():
292+
summary_df.to_csv(summary_path, mode="a", header=False, index=False)
293+
else:
294+
summary_df.to_csv(summary_path, index=False)
292295
else:
293-
summary_df.to_csv(summary_path, index=False)
296+
logger.warning("No evaluation metrics found. Skipping summary generation.")
294297

295298
logger.info("Training complete.")
296299

@@ -397,11 +400,6 @@ def handle_exception(exc_type, exc_value, exc_traceback):
397400
# Convert to AttrDict for easy access
398401
cfg = AttrDict(cfg)
399402

400-
# Adaptive training controls for SNNs
401-
if "snn" in cfg.model.name or "dsformer" in cfg.model.name:
402-
cfg.training.batches_per_epoch = min(cfg.training.batches_per_epoch, cfg_raw.get("snn_batches_per_epoch", 100))
403-
cfg.training.eval_every = max(cfg.training.eval_every, cfg_raw.get("snn_eval_every", 50))
404-
405403
# Construct dataset path from env name, relative to project root
406404
cfg.dataset.path = str(project_root / f"data/{args.env}/dataset.npz")
407405

0 commit comments

Comments
 (0)