You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
-**Fused Kronecker preconditioner kernel** (k ≤ 128) — computes the full identity+correction in one launch, no intermediate `(m,n)` tensor
90
+
-**Int8 EMA update kernel** — two-pass design: compute new EMA value + requantize to int8
91
+
-**Bug fix**: the naïve implementation had an `O(k·m²·n)` complexity regression (each output thread recomputed the full `U^T @ G` projection); the fused kernel achieves the correct `O(k·m·n)`
92
+
93
+
```bash
94
+
# Compile CUDA extension (requires nvcc + CUDA toolkit)
95
+
cd scao/cuda && python setup.py build_ext --inplace
96
+
```
97
+
98
+
Falls back to pure PyTorch automatically when CUDA extension is not compiled.
99
+
73
100
---
74
101
75
102
## 3. Algorithm
@@ -172,6 +199,24 @@ PPL improvement vs AdamW (lower is better):
172
199
173
200
This confirms the theoretical prediction: as model scale grows, off-diagonal curvature structure becomes more informative, and SCAO's Kronecker approximation provides larger improvements over the diagonal AdamW baseline.
174
201
202
+
### GPT-2 Scale Smoke Test: 125M and 350M Parameters
203
+
204
+
CPU smoke test (5 steps, batch 2, seq\_len 64, seed 42). **Not converged** — validates correctness and int8 memory savings only.
205
+
206
+
| Scale | Optimizer | Val PPL | tok/s | Peak Mem (GB) | Mem Saved |
-**Int8 EMA is lossless**: SCAO+int8 matches full-precision SCAO PPL exactly at both scales.
217
+
-**Consistent 36.7% memory reduction** from int8 EMA (125M: 2.49→1.58 GB; 350M: 8.83→5.59 GB).
218
+
- 350M shows AdamW winning early-steps (5 warmup steps insufficient for the preconditioner); full GPU runs at ≥5k steps are required for the regime where Kronecker curvature dominates.
0 commit comments