Skip to content

Commit 5b1fa47

Browse files
Update train.py
1 parent 8354657 commit 5b1fa47

1 file changed

Lines changed: 21 additions & 16 deletions

File tree

snn-dt/scripts/train.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,25 @@
3636

3737
class OfflineDataset(Dataset):
3838
def __init__(self, dataset_path):
39-
data = np.load(dataset_path, mmap_mode='r')
40-
self.states = torch.from_numpy(data["states"]).float()
41-
self.actions = torch.from_numpy(data["actions"]).float()
42-
self.returns_to_go = torch.from_numpy(data["returns_to_go"]).float()
43-
self.timesteps = torch.from_numpy(data["timesteps"]).long()
44-
self.mask = torch.from_numpy(data["mask"]).float()
39+
# Load with mmap_mode='r' to keep data on disk
40+
self.data = np.load(dataset_path, mmap_mode='r')
41+
self.states = self.data["states"]
42+
self.actions = self.data["actions"]
43+
self.returns_to_go = self.data["returns_to_go"]
44+
self.timesteps = self.data["timesteps"]
45+
self.mask = self.data["mask"]
4546

4647
def __len__(self):
4748
return len(self.states)
4849

4950
def __getitem__(self, idx):
51+
# Convert to tensor only when accessed
5052
return {
51-
"states": self.states[idx],
52-
"actions": self.actions[idx],
53-
"returns_to_go": self.returns_to_go[idx],
54-
"timesteps": self.timesteps[idx],
55-
"mask": self.mask[idx],
53+
"states": torch.as_tensor(self.states[idx], dtype=torch.float32),
54+
"actions": torch.as_tensor(self.actions[idx], dtype=torch.float32),
55+
"returns_to_go": torch.as_tensor(self.returns_to_go[idx], dtype=torch.float32),
56+
"timesteps": torch.as_tensor(self.timesteps[idx], dtype=torch.long),
57+
"mask": torch.as_tensor(self.mask[idx], dtype=torch.float32),
5658
}
5759

5860

@@ -266,24 +268,27 @@ def train(cfg, logger):
266268
epoch_time = time.time() - start_time
267269
avg_loss = np.mean(epoch_losses)
268270

269-
log_str = f"Epoch {epoch+1}/{cfg.training.epochs} | Time: {epoch_time:.2f}s | Loss: {avg_loss:.4f}"
271+
# Simplified Log String
272+
log_items = [f"Epoch {epoch+1}/{cfg.training.epochs}"]
273+
log_items.append(f"Loss: {avg_loss:.4f}")
274+
log_items.append(f"Return: {eval_results['return_mean']:.2f}")
270275

271276
# Spike counting for SNN models
272277
if hasattr(model, "count_spikes"):
273278
spikes = model.count_spikes()
274-
log_str += f" | Spikes: {spikes:.2f}"
279+
log_items.append(f"Spikes: {spikes:.4f}")
275280
eval_results["spikes"] = spikes
276281
else:
277282
eval_results["spikes"] = 0.0
278283

279284
if hasattr(model, "get_max_attn_score"):
280285
max_attn = model.get_max_attn_score()
281-
log_str += f" | Max Attn: {max_attn:.2f}"
286+
# log_items.append(f"MaxAttn: {max_attn:.2f}") # Reduced clutter
282287
eval_results["max_attn"] = max_attn
283288

284289
metrics.append({"epoch": epoch + 1, "loss": avg_loss, **eval_results, "time_s": epoch_time})
285-
log_str += f" | Eval Return: {eval_results['return_mean']:.2f}"
286-
logger.info(log_str)
290+
291+
logger.info(" | ".join(log_items))
287292

288293
if eval_results['return_mean'] > best_eval_return:
289294
best_eval_return = eval_results['return_mean']

0 commit comments

Comments
 (0)