|
36 | 36 |
|
37 | 37 | class OfflineDataset(Dataset): |
38 | 38 | def __init__(self, dataset_path): |
39 | | - data = np.load(dataset_path, mmap_mode='r') |
40 | | - self.states = torch.from_numpy(data["states"]).float() |
41 | | - self.actions = torch.from_numpy(data["actions"]).float() |
42 | | - self.returns_to_go = torch.from_numpy(data["returns_to_go"]).float() |
43 | | - self.timesteps = torch.from_numpy(data["timesteps"]).long() |
44 | | - self.mask = torch.from_numpy(data["mask"]).float() |
| 39 | + # Keep arrays as NumPy; mmap_mode='r' memory-maps plain .npy files (np.load ignores it for .npz archives)
| 40 | + self.data = np.load(dataset_path, mmap_mode='r') |
| 41 | + self.states = self.data["states"] |
| 42 | + self.actions = self.data["actions"] |
| 43 | + self.returns_to_go = self.data["returns_to_go"] |
| 44 | + self.timesteps = self.data["timesteps"] |
| 45 | + self.mask = self.data["mask"] |
45 | 46 |
|
46 | 47 | def __len__(self): |
47 | 48 | return len(self.states) |
48 | 49 |
|
49 | 50 | def __getitem__(self, idx): |
| 51 | + # Convert to tensors lazily, per item; as_tensor copies only when a dtype cast is needed
50 | 52 | return { |
51 | | - "states": self.states[idx], |
52 | | - "actions": self.actions[idx], |
53 | | - "returns_to_go": self.returns_to_go[idx], |
54 | | - "timesteps": self.timesteps[idx], |
55 | | - "mask": self.mask[idx], |
| 53 | + "states": torch.as_tensor(self.states[idx], dtype=torch.float32), |
| 54 | + "actions": torch.as_tensor(self.actions[idx], dtype=torch.float32), |
| 55 | + "returns_to_go": torch.as_tensor(self.returns_to_go[idx], dtype=torch.float32), |
| 56 | + "timesteps": torch.as_tensor(self.timesteps[idx], dtype=torch.long), |
| 57 | + "mask": torch.as_tensor(self.mask[idx], dtype=torch.float32), |
56 | 58 | } |
57 | 59 |
|
58 | 60 |
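A quick usage sketch of the refactored class (the file name and batch size here are illustrative, not from the repo). Each __getitem__ call converts one item, so DataLoader workers only materialize the slices they actually serve:

    import torch
    from torch.utils.data import DataLoader

    # Hypothetical .npz with "states", "actions", "returns_to_go",
    # "timesteps", and "mask" arrays, matching the keys above.
    ds = OfflineDataset("trajectories.npz")
    loader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    print(batch["states"].shape, batch["states"].dtype)  # e.g. torch.float32

One caveat: if an array is memory-mapped read-only and already stored as the target dtype, torch.as_tensor wraps it without copying and PyTorch warns about non-writable arrays; copying with np.array(...) inside __getitem__ would silence that at the cost of one per-item copy.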
|
@@ -266,24 +268,27 @@ def train(cfg, logger): |
266 | 268 | epoch_time = time.time() - start_time |
267 | 269 | avg_loss = np.mean(epoch_losses) |
268 | 270 |
|
269 | | - log_str = f"Epoch {epoch+1}/{cfg.training.epochs} | Time: {epoch_time:.2f}s | Loss: {avg_loss:.4f}" |
| 271 | + # Build the log line from a list of parts
| 272 | + log_items = [f"Epoch {epoch+1}/{cfg.training.epochs}"] |
| 273 | + log_items.append(f"Loss: {avg_loss:.4f}") |
| 274 | + log_items.append(f"Return: {eval_results['return_mean']:.2f}") |
270 | 275 |
|
271 | 276 | # Spike counting for SNN models |
272 | 277 | if hasattr(model, "count_spikes"): |
273 | 278 | spikes = model.count_spikes() |
274 | | - log_str += f" | Spikes: {spikes:.2f}" |
| 279 | + log_items.append(f"Spikes: {spikes:.4f}") |
275 | 280 | eval_results["spikes"] = spikes |
276 | 281 | else: |
277 | 282 | eval_results["spikes"] = 0.0 |
278 | 283 |
|
279 | 284 | if hasattr(model, "get_max_attn_score"): |
280 | 285 | max_attn = model.get_max_attn_score() |
281 | | - log_str += f" | Max Attn: {max_attn:.2f}" |
| 286 | + # log_items.append(f"MaxAttn: {max_attn:.2f}") # Reduced clutter |
282 | 287 | eval_results["max_attn"] = max_attn |
283 | 288 |
|
284 | 289 | metrics.append({"epoch": epoch + 1, "loss": avg_loss, **eval_results, "time_s": epoch_time}) |
285 | | - log_str += f" | Eval Return: {eval_results['return_mean']:.2f}" |
286 | | - logger.info(log_str) |
| 290 | + |
| 291 | + logger.info(" | ".join(log_items)) |
287 | 292 |
|
288 | 293 | if eval_results['return_mean'] > best_eval_return: |
289 | 294 | best_eval_return = eval_results['return_mean'] |
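The list-plus-join refactor keeps optional fields (like the spike count) easy to append or drop without juggling separators in conditional string concatenation. A minimal illustration with made-up values:

    log_items = ["Epoch 3/50", "Loss: 0.0412", "Return: 812.37"]
    log_items.append("Spikes: 1.5320")  # only appended for SNN models
    print(" | ".join(log_items))
    # -> Epoch 3/50 | Loss: 0.0412 | Return: 812.37 | Spikes: 1.5320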
|