Skip to content

Commit 079847b

Browse files
Fixed: Training of baseline models
1 parent 4906783 commit 079847b

2 files changed

Lines changed: 41 additions & 1 deletion

File tree

results/all_runs/dt_CartPole-v1/training.log

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,41 @@
4141
2025-11-05 09:19:55,306 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
4242
2025-11-05 09:19:55,307 [INFO] Starting training...
4343
2025-11-05 09:19:55,445 [INFO] DataLoader created with num_workers=1.
44+
2025-11-05 10:40:41,355 [INFO] Checking for dataset...
45+
2025-11-05 10:40:41,375 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
46+
2025-11-05 10:40:41,375 [INFO] Starting training...
47+
2025-11-05 10:40:41,510 [INFO] Dataset size: 1000 clips
48+
2025-11-05 10:40:41,516 [INFO] DataLoader created with num_workers=1 and pin_memory=False.
49+
2025-11-05 10:40:54,189 [INFO] Starting training loop...
50+
2025-11-05 12:47:08,817 [INFO] Checking for dataset...
51+
2025-11-05 12:47:08,825 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
52+
2025-11-05 12:47:08,826 [INFO] Starting training...
53+
2025-11-05 12:47:08,970 [INFO] Dataset size: 1000 clips
54+
2025-11-05 12:47:08,975 [INFO] DataLoader created with num_workers=1 and pin_memory=False.
55+
2025-11-05 12:47:20,769 [INFO] Starting training loop...
56+
2025-11-05 13:09:19,915 [INFO] Checking for dataset...
57+
2025-11-05 13:09:19,917 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
58+
2025-11-05 13:09:19,918 [INFO] Starting training...
59+
2025-11-05 13:09:20,034 [INFO] Dataset size: 1000 clips
60+
2025-11-05 13:09:20,039 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
61+
2025-11-05 13:09:31,624 [INFO] Starting training loop...
62+
2025-11-05 13:24:22,013 [INFO] Checking for dataset...
63+
2025-11-05 13:24:22,027 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
64+
2025-11-05 13:24:22,028 [INFO] Starting training...
65+
2025-11-05 13:24:22,238 [INFO] Dataset size: 1000 clips
66+
2025-11-05 13:24:22,244 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
67+
2025-11-05 13:24:42,828 [INFO] Starting training loop...
68+
2025-11-05 13:24:42,862 [INFO] Epoch 1, Batch 1
69+
2025-11-05 13:25:08,877 [INFO] Epoch 1, Batch 2
70+
2025-11-05 13:25:30,430 [INFO] Epoch 1, Batch 3
71+
2025-11-05 13:25:55,037 [INFO] Epoch 1, Batch 4
72+
2025-11-05 13:26:17,308 [INFO] Epoch 1, Batch 5
73+
2025-11-05 13:26:34,649 [INFO] Epoch 1, Batch 6
74+
2025-11-05 13:26:53,128 [INFO] Epoch 1, Batch 7
75+
2025-11-05 13:27:15,121 [INFO] Epoch 1, Batch 8
76+
2025-11-05 13:27:33,624 [INFO] Epoch 1, Batch 9
77+
2025-11-05 13:27:47,284 [INFO] Epoch 1, Batch 10
78+
2025-11-05 13:28:04,851 [INFO] Epoch 1, Batch 11
79+
2025-11-05 13:28:17,372 [INFO] Epoch 1, Batch 12
80+
2025-11-05 13:28:33,096 [INFO] Epoch 1, Batch 13
81+
2025-11-05 13:28:49,192 [INFO] Epoch 1, Batch 14

snn-dt/scripts/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def train(cfg, logger):
120120
cfg.dataset.max_timesteps = metadata["max_timesteps"]
121121

122122
# OS-aware num_workers
123-
num_workers = cfg.training.get("num_workers", 1 if os.name == "nt" else 4)
123+
num_workers = 0
124124
from torch.utils.data import DataLoader
125125
train_loader = DataLoader(
126126
dataset,
@@ -152,6 +152,7 @@ def train(cfg, logger):
152152
start_time = time.time()
153153
epoch_losses = []
154154
for i, batch in enumerate(train_loader):
155+
logger.info(f"Epoch {epoch+1}, Batch {i+1}")
155156
if i >= cfg.training.batches_per_epoch:
156157
break
157158

@@ -267,6 +268,7 @@ def main():
267268
cfg = {
268269
"model": {
269270
"name": args.model,
271+
"seq_len": cfg_raw.get("seq_len"),
270272
"d_model": cfg_raw.get("hidden_dim", 128),
271273
"n_heads": cfg_raw.get("n_heads", 4),
272274
"n_layers": cfg_raw.get("n_layers", 4),

0 commit comments

Comments
 (0)