|
| 1 | +import os |
| 2 | +import time |
| 3 | + |
1 | 4 | import torch |
2 | 5 | import torch.nn as nn |
3 | 6 | from torchvision import datasets, transforms |
4 | 7 | from torch.utils.data import DataLoader |
5 | | -import os |
6 | | -import time |
7 | 8 |
|
8 | 9 | from odyssnet import OdyssNet, OdyssNetTrainer, TrainingHistory, set_seed |
9 | 10 |
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
SEED = 42        # Global RNG seed (passed to set_seed) for reproducibility
NUM_EPOCHS = 100
BATCH_SIZE = 32
LR = 1e-2        # Learning rate handed to OdyssNetTrainer

# Architecture
NUM_NEURONS = 10
EMBED_NEURONS = 4  # Neurons that receive patch input (first N neurons)
NUM_CLASSES = 10   # Output classes

# Patch strategy: divide 28×28 image into GRID_SIZE×GRID_SIZE non-overlapping patches
IMAGE_SIZE = 28
GRID_SIZE = 4
PATCH_SIZE = IMAGE_SIZE // GRID_SIZE  # 7 pixels per side
PATCH_PIXELS = PATCH_SIZE * PATCH_SIZE  # 49 pixels per patch (embed input dim)
NUM_PATCHES = GRID_SIZE * GRID_SIZE  # 16 patches total
THINKING_RATIO = 1  # Thinking steps per patch (1 = inject only, 2 = inject + 1 free step, ...)
THINKING_STEPS = NUM_PATCHES * THINKING_RATIO

# DataLoader
# Cap at 4 workers; `os.cpu_count()` can return None, hence the `or 1` fallback.
NUM_WORKERS = min(4, os.cpu_count() or 1)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
19 | 37 |
|
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def get_spiral_indices(rows: int, cols: int) -> list[int]:
    """Return flat patch indices in clockwise inward spiral order."""
    order: list[int] = []
    top, bottom = 0, rows - 1
    left, right = 0, cols - 1

    # Peel one rectangular ring per iteration, shrinking the bounds inward.
    while top <= bottom and left <= right:
        # Top edge, left → right.
        order.extend(top * cols + c for c in range(left, right + 1))
        top += 1

        # Right edge, top → bottom.
        order.extend(r * cols + right for r in range(top, bottom + 1))
        right -= 1

        # Bottom edge, right → left (guard: a single remaining row was
        # already consumed by the top edge).
        if top <= bottom:
            order.extend(bottom * cols + c for c in range(right, left - 1, -1))
            bottom -= 1

        # Left edge, bottom → top (guard: a single remaining column was
        # already consumed by the right edge).
        if left <= right:
            order.extend(r * cols + left for r in range(bottom, top - 1, -1))
            left += 1

    return order
53 | 69 |
|
54 | 70 |
|
def format_time(seconds: float) -> str:
    """Format a duration in seconds as a zero-padded ``HH:MM:SS`` string."""
    total = int(seconds)  # truncate fractional seconds, matching int()
    hours = total // 3600
    minutes = (total % 3600) // 60
    return f"{hours:02d}:{minutes:02d}:{total % 60:02d}"
| 75 | + |
| 76 | + |
def extract_spiral_patches(images: torch.Tensor, spiral: list[int]) -> torch.Tensor:
    """
    Extract and spiral-reorder non-overlapping patches from a batch of images.

    Args:
        images: (B, 1, H, W) image batch.
        spiral: Patch visit order (flat indices).

    Returns:
        (B, NUM_PATCHES, PATCH_PIXELS) tensor.
    """
    # unfold twice → (B, 1, GRID_SIZE, GRID_SIZE, PATCH_SIZE, PATCH_SIZE)
    tiles = images.unfold(2, PATCH_SIZE, PATCH_SIZE)
    tiles = tiles.unfold(3, PATCH_SIZE, PATCH_SIZE)
    # Flatten grid and pixel dims, then index rows into spiral visit order.
    flat = tiles.contiguous().view(images.size(0), NUM_PATCHES, PATCH_PIXELS)
    return flat[:, spiral, :]
| 93 | + |
| 94 | + |
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Train and evaluate OdyssNet on MNIST with spiral-ordered patches.

    Each 28x28 image is cut into a GRID_SIZE x GRID_SIZE grid of patches,
    reordered edge-first/center-last (inward spiral), and fed to the model
    across THINKING_STEPS recurrent steps. Progress (loss, accuracy, ETA)
    is printed per epoch and plotted at the end.
    """
    print("OdyssNet: MNIST Record Challenge")
    print(
        f" Strategy : {NUM_PATCHES} spiral patches "
        f"({PATCH_SIZE}x{PATCH_SIZE}={PATCH_PIXELS} px) → "
        f"Embed({EMBED_NEURONS}) → Core({NUM_NEURONS}) → Decoder({NUM_CLASSES})"
    )
    set_seed(SEED)

    # GPU optimisations — only relevant on CUDA; use_compile stays False on CPU.
    use_compile = False
    if DEVICE == 'cuda':
        torch.set_float32_matmul_precision('high')
        torch.backends.cudnn.benchmark = True
        use_compile = hasattr(torch, 'compile')  # PyTorch 2.0+ guard
        if use_compile:
            print(" torch.compile enabled.")

    # Model: first EMBED_NEURONS neurons take patch input; decoder reads all.
    input_ids = list(range(EMBED_NEURONS))
    output_ids = list(range(NUM_NEURONS))

    # NOTE(review): vocab_size appears to be [input_dim, output_dim] given
    # 'continuous' vocab_mode — confirm against the OdyssNet docs.
    model = OdyssNet(
        num_neurons=NUM_NEURONS,
        input_ids=input_ids,
        output_ids=output_ids,
        device=DEVICE,
        vocab_size=[PATCH_PIXELS, NUM_CLASSES],
        vocab_mode='continuous',
        weight_init='micro_quiet_warm',
        gate='none',
    )

    if use_compile:
        model = torch.compile(model)

    total_params = model.get_num_params()
    print(f" Params : {total_params} (target: < 500)\n")

    # Data: light affine augmentation for training only; both splits are
    # normalized to roughly [-1, 1].
    train_transform = transforms.Compose([
        transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # Dataset lives in ../data relative to this script; downloaded on demand.
    data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
    train_dataset = datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
    test_dataset = datasets.MNIST(root=data_dir, train=False, download=True, transform=test_transform)

    # pin_memory only helps (and only applies) when copying to a CUDA device.
    loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=(DEVICE == 'cuda'), num_workers=NUM_WORKERS)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Trainer
    trainer = OdyssNetTrainer(model, device=DEVICE, lr=LR)
    # NOTE(review): assumes the trainer reads its criterion from .loss_fn —
    # confirm OdyssNetTrainer supports overriding it after construction.
    trainer.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    print(f"Training for {NUM_EPOCHS} epochs | batch {BATCH_SIZE} | lr {LR} | device {DEVICE}")

    history = TrainingHistory()
    spiral = get_spiral_indices(GRID_SIZE, GRID_SIZE)  # computed once, reused every batch
    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        # --- Train ---
        model.train()
        total_loss = 0.0

        for images, targets in train_loader:
            images = images.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)
            seq = extract_spiral_patches(images, spiral)
            # train_batch is assumed to return a scalar loss value per batch.
            total_loss += trainer.train_batch(seq, targets, thinking_steps=THINKING_STEPS)

        avg_loss = total_loss / len(train_loader)

        # --- Evaluate ---
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, targets in test_loader:
                images = images.to(DEVICE, non_blocking=True)
                targets = targets.to(DEVICE, non_blocking=True)
                seq = extract_spiral_patches(images, spiral)
                # predict is assumed to return per-class scores (B, NUM_CLASSES).
                preds = trainer.predict(seq, thinking_steps=THINKING_STEPS)
                correct += (preds.argmax(1) == targets).sum().item()
                total += targets.size(0)

        acc = 100.0 * correct / total

        # ETA from the running mean epoch time.
        elapsed = time.time() - start_time
        eta = (elapsed / (epoch + 1)) * (NUM_EPOCHS - epoch - 1)

        history.record(loss=avg_loss, accuracy=acc)
        print(
            f"Epoch {epoch+1:4d}/{NUM_EPOCHS} | "
            f"Loss {avg_loss:.4f} | "
            f"Acc {acc:5.2f}% | "
            f"Elapsed {format_time(elapsed)} | "
            f"ETA {format_time(eta)}"
        )

    history.plot(title=f"MNIST Record ({total_params} params) — {NUM_PATCHES}-patch spiral")


if __name__ == "__main__":
    main()
0 commit comments