Dependency list (the pytest line reappears unchanged on both sides, most likely because the original file ended without a trailing newline):

@@ -5,4 +5,5 @@
 pyyaml
 transformers
 bindsnet
-pytest
+pytest
+tqdm
results/all_runs/dt_CartPole-v1 (training log):

@@ -79,3 +79,15 @@
 2025-11-05 13:28:17,372 [INFO] Epoch 1, Batch 12
 2025-11-05 13:28:33,096 [INFO] Epoch 1, Batch 13
 2025-11-05 13:28:49,192 [INFO] Epoch 1, Batch 14
+2025-11-05 13:42:57,570 [INFO] Checking for dataset...
+2025-11-05 13:42:57,585 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
+2025-11-05 13:42:57,585 [INFO] Starting training...
+2025-11-05 13:42:57,704 [INFO] Dataset size: 1000 clips
+2025-11-05 13:42:57,709 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
+2025-11-05 13:43:09,026 [INFO] Starting training loop...
+2025-11-05 14:33:37,438 [INFO] Checking for dataset...
+2025-11-05 14:33:37,450 [INFO] Dataset found at D:\Github\neuromorphic_decision_transformer\data\CartPole-v1\dataset.npz.
+2025-11-05 14:33:37,450 [INFO] Starting training...
+2025-11-05 14:33:37,577 [INFO] Dataset size: 1000 clips
+2025-11-05 14:33:37,582 [INFO] DataLoader created with num_workers=0 and pin_memory=False.
+2025-11-05 14:33:48,431 [INFO] Starting training loop...
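The log lines above record the loader settings the run used. For reference, a minimal sketch of a DataLoader constructed with those flags; the toy dataset, shapes, and batch size are assumptions for illustration, not taken from the repo:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the repo's clip dataset; shapes are illustrative only.
dataset = TensorDataset(torch.randn(1000, 4), torch.randint(0, 2, (1000,)))

loader = DataLoader(
    dataset,
    batch_size=64,     # assumption; the real value comes from the config
    shuffle=True,      # assumption
    num_workers=0,     # single-process loading, as logged
    pin_memory=False,  # no pinned host memory, as logged
)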
Training script:

@@ -16,6 +16,7 @@
 import torch
 import yaml
 from torch.utils.data import Dataset
+from tqdm import tqdm
 import warnings
 warnings.filterwarnings('ignore')
 
@@ -151,11 +152,14 @@ def train(cfg, logger):
     for epoch in range(cfg.training.epochs):
         start_time = time.time()
         epoch_losses = []
-        for i, batch in enumerate(train_loader):
-            logger.info(f"Epoch {epoch + 1}, Batch {i + 1}")
-            if i >= cfg.training.batches_per_epoch:
-                break
+
+        batch_iter = tqdm(
+            enumerate(train_loader),
+            total=len(train_loader),
+            desc=f"Epoch {epoch + 1}/{cfg.training.epochs}"
+        )
 
+        for i, batch in batch_iter:
             model.train()
 
             for k, v in batch.items():
@@ -177,6 +181,7 @@ def train(cfg, logger):
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
             epoch_losses.append(loss.item())
+            batch_iter.set_postfix(loss=f"{np.mean(epoch_losses):.4f}")
 
         # Evaluation, Checkpointing, and Logging
         if (epoch + 1) % cfg.training.eval_every == 0:
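In isolation, the progress-bar pattern this patch introduces looks like the following runnable sketch; the loop body and loss values are placeholders standing in for the repo's training code:

import numpy as np
from tqdm import tqdm

epochs = 3
batches = [0.9, 0.7, 0.6, 0.5]  # stand-in "losses" in place of real batches

for epoch in range(epochs):
    epoch_losses = []
    batch_iter = tqdm(
        enumerate(batches),
        total=len(batches),  # enumerate() hides the length, so pass it explicitly
        desc=f"Epoch {epoch + 1}/{epochs}",
    )
    for i, loss in batch_iter:
        epoch_losses.append(loss)
        # set_postfix shows running stats to the right of the bar
        batch_iter.set_postfix(loss=f"{np.mean(epoch_losses):.4f}")

Passing total explicitly matters here because wrapping the loader in enumerate() strips its __len__, so tqdm would otherwise render without a percentage or ETA.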