Skip to content

Commit 5e0490f

Browse files
committed
Professionalize and tweak record experiment
1 parent b8c99b2 commit 5e0490f

1 file changed

Lines changed: 116 additions & 96 deletions

File tree

Lines changed: 116 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,191 +1,211 @@
1+
import os
import time

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from odyssnet import OdyssNet, OdyssNetTrainer, TrainingHistory, set_seed

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
SEED = 42          # global RNG seed (passed to set_seed for reproducibility)
NUM_EPOCHS = 100
BATCH_SIZE = 32
LR = 1e-2          # learning rate handed to OdyssNetTrainer

# Architecture
NUM_NEURONS = 10
EMBED_NEURONS = 4   # Neurons that receive patch input (first N neurons)
NUM_CLASSES = 10    # Output classes

# Patch strategy: divide 28×28 image into GRID_SIZE×GRID_SIZE non-overlapping patches
IMAGE_SIZE = 28
GRID_SIZE = 4
PATCH_SIZE = IMAGE_SIZE // GRID_SIZE     # 7 pixels per side
PATCH_PIXELS = PATCH_SIZE * PATCH_SIZE   # 49 pixels per patch (embed input dim)
NUM_PATCHES = GRID_SIZE * GRID_SIZE      # 16 patches total
THINKING_RATIO = 1  # Thinking steps per patch (1 = inject only, 2 = inject + 1 free step, ...)
THINKING_STEPS = NUM_PATCHES * THINKING_RATIO

# DataLoader
NUM_WORKERS = min(4, os.cpu_count() or 1)  # cap workers; os.cpu_count() may return None

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
1937

20-
def get_spiral_indices(rows: int, cols: int) -> list[int]:
    """Return flat patch indices in clockwise inward spiral order.

    Walks the perimeter of the remaining rectangle (right along the top row,
    down the right column, left along the bottom row, up the left column),
    shrinking the bounds after each side until the rectangle is exhausted.
    """
    order: list[int] = []
    r_lo, r_hi = 0, rows - 1
    c_lo, c_hi = 0, cols - 1

    while r_lo <= r_hi and c_lo <= c_hi:
        # Top row, left → right
        order.extend(r_lo * cols + c for c in range(c_lo, c_hi + 1))
        r_lo += 1

        # Right column, top → bottom
        order.extend(r * cols + c_hi for r in range(r_lo, r_hi + 1))
        c_hi -= 1

        # Bottom row, right → left (only if a row remains)
        if r_lo <= r_hi:
            order.extend(r_hi * cols + c for c in range(c_hi, c_lo - 1, -1))
            r_hi -= 1

        # Left column, bottom → top (only if a column remains)
        if c_lo <= c_hi:
            order.extend(r * cols + c_lo for r in range(r_hi, r_lo - 1, -1))
            c_lo += 1

    return order
5369

5470

71+
def format_time(seconds: float) -> str:
    """Format a duration in seconds as zero-padded HH:MM:SS."""
    total = int(seconds)
    hours = total // 3600
    minutes = total % 3600 // 60
    secs = total % 60
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
75+
76+
77+
def extract_spiral_patches(images: torch.Tensor, spiral: list[int]) -> torch.Tensor:
78+
"""
79+
Extract and spiral-reorder non-overlapping patches from a batch of images.
80+
81+
Args:
82+
images: (B, 1, H, W) image batch.
83+
spiral: Patch visit order (flat indices).
84+
85+
Returns:
86+
(B, NUM_PATCHES, PATCH_PIXELS) tensor.
87+
"""
88+
b = images.size(0)
89+
patches = images.unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
90+
# (B, 1, GRID_SIZE, GRID_SIZE, PATCH_SIZE, PATCH_SIZE) → (B, NUM_PATCHES, PATCH_PIXELS)
91+
patches = patches.contiguous().view(b, NUM_PATCHES, PATCH_PIXELS)
92+
return patches[:, spiral, :]
93+
94+
95+
# ---------------------------------------------------------------------------
96+
# Main
97+
# ---------------------------------------------------------------------------
98+
5599
def main():
    """Train OdyssNet on MNIST fed as a spiral-ordered sequence of image patches,
    evaluating test accuracy after every epoch and plotting the history at the end."""
    print("OdyssNet: MNIST Record Challenge")
    print(
        f" Strategy : {NUM_PATCHES} spiral patches "
        f"({PATCH_SIZE}x{PATCH_SIZE}={PATCH_PIXELS} px) → "
        f"Embed({EMBED_NEURONS}) → Core({NUM_NEURONS}) → Decoder({NUM_CLASSES})"
    )
    set_seed(SEED)

    # GPU optimisations (compile is only attempted on CUDA)
    use_compile = False
    if DEVICE == 'cuda':
        torch.set_float32_matmul_precision('high')
        torch.backends.cudnn.benchmark = True
        use_compile = hasattr(torch, 'compile')  # present on PyTorch 2.0+
        if use_compile:
            print(" torch.compile enabled.")

    # Model: first EMBED_NEURONS neurons receive patch input; the decoder
    # reads from all NUM_NEURONS neurons.
    input_ids = list(range(EMBED_NEURONS))
    output_ids = list(range(NUM_NEURONS))

    model = OdyssNet(
        num_neurons=NUM_NEURONS,
        input_ids=input_ids,
        output_ids=output_ids,
        device=DEVICE,
        # vocab_size = [input_dim, output_dim]: 49 patch pixels in, 10 classes out
        vocab_size=[PATCH_PIXELS, NUM_CLASSES],
        vocab_mode='continuous',
        weight_init='micro_quiet_warm',  # NOTE(review): project-specific init scheme — see odyssnet docs
        gate='none',
    )

    if use_compile:
        model = torch.compile(model)

    total_params = model.get_num_params()
    print(f" Params : {total_params} (target: < 500)\n")

    # Data: light affine augmentation on train only; both splits normalized
    # to [-1, 1] via mean/std of 0.5.
    train_transform = transforms.Compose([
        transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
    train_dataset = datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
    test_dataset = datasets.MNIST(root=data_dir, train=False, download=True, transform=test_transform)

    # pin_memory only helps (and only applies) when transferring to CUDA
    loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=(DEVICE == 'cuda'), num_workers=NUM_WORKERS)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Trainer: override the default loss with label-smoothed cross-entropy
    trainer = OdyssNetTrainer(model, device=DEVICE, lr=LR)
    trainer.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    print(f"Training for {NUM_EPOCHS} epochs | batch {BATCH_SIZE} | lr {LR} | device {DEVICE}")

    history = TrainingHistory()
    spiral = get_spiral_indices(GRID_SIZE, GRID_SIZE)  # fixed visit order, computed once
    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        # --- Train ---
        model.train()
        total_loss = 0.0

        for images, targets in train_loader:
            images = images.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)
            # (B, 1, 28, 28) → (B, 16, 49) patch sequence in spiral order
            seq = extract_spiral_patches(images, spiral)
            total_loss += trainer.train_batch(seq, targets, thinking_steps=THINKING_STEPS)

        avg_loss = total_loss / len(train_loader)

        # --- Evaluate ---
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, targets in test_loader:
                images = images.to(DEVICE, non_blocking=True)
                targets = targets.to(DEVICE, non_blocking=True)
                seq = extract_spiral_patches(images, spiral)
                # assumes trainer.predict returns (B, NUM_CLASSES) logits — TODO confirm
                preds = trainer.predict(seq, thinking_steps=THINKING_STEPS)
                correct += (preds.argmax(1) == targets).sum().item()
                total += targets.size(0)

        acc = 100.0 * correct / total

        # ETA = average epoch time so far × epochs remaining
        elapsed = time.time() - start_time
        eta = (elapsed / (epoch + 1)) * (NUM_EPOCHS - epoch - 1)

        history.record(loss=avg_loss, accuracy=acc)
        print(
            f"Epoch {epoch+1:4d}/{NUM_EPOCHS} | "
            f"Loss {avg_loss:.4f} | "
            f"Acc {acc:5.2f}% | "
            f"Elapsed {format_time(elapsed)} | "
            f"ETA {format_time(eta)}"
        )

    history.plot(title=f"MNIST Record ({total_params} params) — {NUM_PATCHES}-patch spiral")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)