
Commit f3d02a1

training bar
1 parent 079847b commit f3d02a1

3 files changed

Lines changed: 23 additions & 5 deletions


requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -5,4 +5,5 @@ norse
 pyyaml
 transformers
 bindsnet
-pytest
+pytest
+tqdm
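
The only new dependency is tqdm, which supplies the progress bar wired into train.py below. A minimal sketch of how tqdm wraps an iterable (illustrative only, not repo code):

from tqdm import tqdm
import time

# Wrapping any iterable yields a live progress bar on stderr.
for _ in tqdm(range(100), desc="demo"):
    time.sleep(0.01)  # stand-in for one unit of work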

results/all_runs/dt_CartPole-v1/training.log

Lines changed: 12 additions & 0 deletions
@@ -79,3 +79,15 @@
 2025-11-05 13:28:17,372 [INFO] Epoch 1, Batch 12
 2025-11-05 13:28:33,096 [INFO] Epoch 1, Batch 13
 2025-11-05 13:28:49,192 [INFO] Epoch 1, Batch 14
+2025-11-05 13:42:57,570 [INFO] Checking for dataset...
+2025-11-05 13:42:57,585 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
+2025-11-05 13:42:57,585 [INFO] Starting training...
+2025-11-05 13:42:57,704 [INFO] Dataset size: 1000 clips
+2025-11-05 13:42:57,709 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
+2025-11-05 13:43:09,026 [INFO] Starting training loop...
+2025-11-05 14:33:37,438 [INFO] Checking for dataset...
+2025-11-05 14:33:37,450 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
+2025-11-05 14:33:37,450 [INFO] Starting training...
+2025-11-05 14:33:37,577 [INFO] Dataset size: 1000 clips
+2025-11-05 14:33:37,582 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
+2025-11-05 14:33:48,431 [INFO] Starting training loop...
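
Both logged runs reach "Starting training loop..." without any per-batch output, which is exactly what the progress bar in the next file replaces. For reference, a self-contained sketch of a DataLoader matching the logged settings (num_workers=0, pin_memory=False); the dataset and batch size here are stand-ins, not the repo's CartPole clips or config values:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-in for the 1000-clip dataset mentioned in the log.
dataset = TensorDataset(torch.randn(1000, 4))
loader = DataLoader(dataset,
                    batch_size=64,     # assumed; the real value comes from cfg
                    shuffle=True,
                    num_workers=0,     # matches the logged setting
                    pin_memory=False)  # matches the logged setting
print(f"DataLoader created with num_workers=0 and pin_memory=False ({len(loader)} batches)")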

snn-dt/scripts/train.py

Lines changed: 9 additions & 4 deletions
@@ -16,6 +16,7 @@
 import torch
 import yaml
 from torch.utils.data import Dataset
+from tqdm import tqdm
 import warnings
 warnings.filterwarnings('ignore')
 
@@ -151,11 +152,14 @@ def train(cfg, logger):
     for epoch in range(cfg.training.epochs):
         start_time = time.time()
         epoch_losses = []
-        for i, batch in enumerate(train_loader):
-            logger.info(f"Epoch {epoch+1}, Batch {i+1}")
-            if i >= cfg.training.batches_per_epoch:
-                break
+
+        batch_iter = tqdm(
+            enumerate(train_loader),
+            total=len(train_loader),
+            desc=f"Epoch {epoch+1}/{cfg.training.epochs}"
+        )
 
+        for i, batch in batch_iter:
             model.train()
 
             for k, v in batch.items():
@@ -177,6 +181,7 @@ def train(cfg, logger):
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
             epoch_losses.append(loss.item())
+            batch_iter.set_postfix(loss=f"{np.mean(epoch_losses):.4f}")
 
         # Evaluation, Checkpointing, and Logging
         if (epoch + 1) % cfg.training.eval_every == 0: