Skip to content

Commit 4906783

Browse files
Update train.py
1 parent 7fa84d8 commit 4906783

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

snn-dt/scripts/train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ def train(cfg, logger):
104104

105105
# Load data and metadata
106106
dataset = OfflineDataset(cfg.dataset.path)
107-
assert len(dataset) > 0, f"Dataset at {cfg.dataset.path} is empty."
107+
if len(dataset) == 0:
108+
logger.error(f"Dataset at {cfg.dataset.path} is empty! Aborting training.")
109+
sys.exit(1)
110+
logger.info(f"Dataset size: {len(dataset)} clips")
108111

109112
with np.load(cfg.dataset.path, allow_pickle=True) as data:
110113
metadata = data["metadata"].item()
@@ -124,8 +127,9 @@ def train(cfg, logger):
124127
batch_size=cfg.training.batch_size,
125128
shuffle=True,
126129
num_workers=num_workers,
130+
pin_memory=False,
127131
)
128-
logger.info(f"DataLoader created with num_workers={num_workers}.")
132+
logger.info(f"DataLoader created with num_workers={num_workers} and pin_memory=False.")
129133

130134
# Initialize model and optimizer
131135
model = get_model(cfg).to(cfg.training.device)
@@ -143,6 +147,7 @@ def train(cfg, logger):
143147
# Lazily initialize the environment
144148
env = None
145149

150+
logger.info("Starting training loop...")
146151
for epoch in range(cfg.training.epochs):
147152
start_time = time.time()
148153
epoch_losses = []

0 commit comments

Comments
 (0)