
Commit a1b358b

SCAO Authors and Copilot committed
feat: block-diagonal preconditioning, scao_step tracking, distributed sync
- preconditioner.py: implement use_block_diagonal for max(m,n)>max_precond_dim
  each block gets its own Kronecker preconditioner (matches paper Algorithm 1)
  handles update_curvature, precondition, natural_grad_norm, state_dict, memory_bytes
- optimizer.py: track scao_step (t_s) in Phase 2; shared-moment bias correction
  uses global step (correct for non-reset design); blend ramp = min(1, t_s/50)
- distributed.py: sync_preconditioners handles use_block_diagonal sub-blocks
- tests: +4 block-diagonal tests, +3 distributed tests (7 new, all pass, 67 total)
- README.md: accurate Innovation 2 description, corrected algorithm pseudocode

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 0b0b04b commit a1b358b
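
For readers skimming the commit message, the block-diagonal rule can be pictured with a small sketch. This is not the code in preconditioner.py, which is not shown on this page. The function name `block_slices` is hypothetical, and the choice to cut equal-width chunks with a shorter final block is an assumption; the commit only states that each block is at most `max_precond_dim` wide along the larger dimension and gets its own Kronecker preconditioner.

```python
from typing import List, Tuple


def block_slices(m: int, n: int, max_precond_dim: int) -> Tuple[int, List[Tuple[int, int]]]:
    """Return (axis, [(start, stop), ...]) for an m x n gradient matrix.

    axis is 0 when rows are the larger dimension, 1 when columns are.
    Hypothetical helper: the real optimizer may choose block boundaries differently.
    """
    if max_precond_dim <= 0:
        raise ValueError("max_precond_dim must be positive")
    axis = 0 if m >= n else 1
    size = max(m, n)
    if size <= max_precond_dim:
        # Small layer: one block, i.e. the ordinary (non-block-diagonal) path.
        return axis, [(0, size)]
    # Contiguous blocks of width <= max_precond_dim; the last one may be shorter.
    bounds = [(s, min(s + max_precond_dim, size)) for s in range(0, size, max_precond_dim)]
    return axis, bounds


if __name__ == "__main__":
    # A 768 x 3072 layer with max_precond_dim = 1024 splits along columns.
    print(block_slices(768, 3072, 1024))
```

Each `(start, stop)` pair would index one sub-matrix of the gradient, and the eigendecomposition cost per block then stays bounded by `max_precond_dim`. How the smaller dimension is handled when it also exceeds `max_precond_dim` is not stated on this page and is left out here.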

31 files changed

Lines changed: 1215 additions & 27 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -81,3 +81,4 @@ paper/
 secrets.yml
 *.pem
 *.key
+SCAO_Technical_Paper.docx

README.md

Lines changed: 12 additions & 8 deletions
@@ -61,7 +61,7 @@ k* = argmin k such that Σᵢ₌₁ᵏ λᵢ / Σⱼ λⱼ ≥ 1 − ε
 This reduces memory from `O(m² + n²)` to `O((m+n)·k)`. At GPT-2 scale (`d=768`), typical `k ≤ 32–64`, giving a **16–32× reduction** over full-rank Kronecker factors.
 
 ### Innovation 2 — Sparse Block-Diagonal FIM
-For layers where `max(m, n) > max_precond_dim`, SCAO falls back to a **diagonal curvature approximation** rather than storing any matrix at all. This prevents memory blow-up at large scales while preserving per-element adaptivity.
+For layers where `max(m, n) > max_precond_dim`, SCAO applies **sparse block-diagonal preconditioning**: the gradient matrix is partitioned into contiguous blocks of size ≤ `max_precond_dim` along the larger dimension, and an independent low-rank Kronecker preconditioner is applied per block. This bounds eigendecomp cost at `O(max_precond_dim³)` while preserving full curvature information across all blocks — unlike a diagonal fallback, which discards all inter-parameter correlation.
 
 ### Innovation 3 — Phase-Transition Stability
 The transition from Adam (Phase 1) to SCAO preconditioning (Phase 2) is the most dangerous moment in training. Three guards prevent instability:
@@ -93,18 +93,22 @@ Phase 2 — SCAO preconditioning (steps T_w + 1 onwards):
 Store: (U_L[:, :k], S_L[:k], U_R[:, :k], S_R[:k])
 
 Every step:
-Project gradient: G_mid = U_L^T · G · U_R (k×k)
-Scale: G_scaled = diag(S_L^{-1/4}) · G_mid · diag(S_R^{-1/4})
-Reconstruct: g_eff = blend · (U_L · G_scaled · U_R^T) + (1-blend) · g
-where blend = min(1, (t - T_w) / 50)
+Preconditioned gradient (identity + low-rank correction):
+G_proj = U_L^T · G · U_R (k×k)
+G_scaled = diag(S_L^{-1/4}) · G_proj · diag(S_R^{-1/4})
+G_precond = G + U_L · (G_scaled - G_proj) · U_R^T (m×n)
 
-Apply Adam update on g_eff:
+50-step linear blend ramp (t_s = Phase-2 step count):
+blend = min(1.0, t_s / 50)
+g_eff = blend · G_precond + (1 - blend) · g
+
+Apply Adam update on g_eff (shared moments, warm-started from Phase 1):
 m_t ← β₁ · m_{t-1} + (1-β₁) · g_eff
-v_t ← β₂ · v_t-1 + (1-β₂) · g_eff²
+v_t ← β₂ · v_{t-1} + (1-β₂) · g_eff²
 θ_t ← θ_{t-1} - α · (m_t / (1-β₁^t)) / (√(v_t/(1-β₂^t)) + ε)
 ```
 
-**Key insight:** SCAO applies Adam *on top of* the preconditioned gradient. Both the first moment `m_t` and second moment `v_t` track `g_eff`, keeping numerator and denominator on the same scale throughout training.
+**Key insight:** SCAO uses **shared momentum tensors** that warm-start from Phase 1 into Phase 2. The blend ramp ensures a smooth transition — at `t_s=1`, `g_eff ≈ 0.98 · g_raw`, so the preconditioner's influence grows gradually without disrupting the accumulated momentum. Both moments track `g_eff`, keeping the numerator and denominator on the same scale throughout training.
 
 ---
 
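The corrected "Every step" pseudocode in the README diff can be read as the following NumPy sketch. It is an illustration under stated assumptions, not the code in optimizer.py: the function names `scao_effective_grad` and `adam_update` are hypothetical, `U_L` and `U_R` are assumed to have orthonormal columns, `S_L` and `S_R` are assumed strictly positive (any damping the real code applies is omitted), `g` and `G` are treated as the same raw gradient, and the Adam hyperparameter defaults are the usual ones rather than values taken from the repository.

```python
import numpy as np


def scao_effective_grad(G, U_L, S_L, U_R, S_R, t_s, ramp_steps=50):
    """Identity + low-rank correction, then the linear blend ramp.

    G: (m, n) raw gradient; U_L: (m, k); U_R: (n, k); S_L, S_R: (k,) positive.
    t_s: number of Phase-2 steps taken so far (scao_step in the commit message).
    """
    G_proj = U_L.T @ G @ U_R                                   # (k, k)
    # diag(S_L^{-1/4}) · G_proj · diag(S_R^{-1/4}) via broadcasting
    G_scaled = (S_L ** -0.25)[:, None] * G_proj * (S_R ** -0.25)[None, :]
    # Only the component inside the low-rank subspace is rescaled;
    # everything outside it passes through unchanged.
    G_precond = G + U_L @ (G_scaled - G_proj) @ U_R.T          # (m, n)
    blend = min(1.0, t_s / ramp_steps)
    return blend * G_precond + (1.0 - blend) * G


def adam_update(theta, g_eff, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step on g_eff; t is the global step used for bias correction."""
    m = beta1 * m + (1.0 - beta1) * g_eff
    v = beta2 * v + (1.0 - beta2) * g_eff ** 2
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    G = rng.standard_normal((6, 4))
    U_L, _ = np.linalg.qr(rng.standard_normal((6, 2)))  # orthonormal columns
    U_R, _ = np.linalg.qr(rng.standard_normal((4, 2)))
    S_L, S_R = np.array([4.0, 1.0]), np.array([16.0, 1.0])
    g_eff = scao_effective_grad(G, U_L, S_L, U_R, S_R, t_s=1)
    theta, m, v = adam_update(np.zeros_like(G), g_eff, np.zeros_like(G), np.zeros_like(G), t=51)
    print(theta.shape)
```

At `t_s = 1` the blend weight is `1/50 = 0.02`, so `g_eff = 0.02 · G_precond + 0.98 · G`, which is where the README's "≈ 0.98 · g_raw" figure comes from. Keeping `t` (global) separate from `t_s` (Phase-2 only) mirrors the commit message's note that bias correction uses the global step while the ramp uses `scao_step`.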
results_benchmark.csv

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+step,adamw,scao,diag_shampoo
+1,5.5728559494018555,5.5728559494018555,5.5728559494018555
+2,5.578701019287109,5.578701019287109,5.5962300300598145
+3,5.586780548095703,5.586779594421387,5.629923343658447
+4,5.573197364807129,5.573197364807129,5.654507160186768
+5,5.55900239944458,5.55900239944458,5.687943935394287
+6,5.579440593719482,5.579440593719482,5.665504455566406
+7,5.565978050231934,5.565978050231934,5.662998199462891
+8,5.568857192993164,5.568857192993164,5.635212421417236
+9,5.5764312744140625,5.576431751251221,5.631796836853027
+10,5.573879718780518,5.573879718780518,5.6374006271362305
+11,5.5734405517578125,5.5734405517578125,5.630044460296631
+12,5.573633193969727,5.573633193969727,5.607842922210693
+13,5.564875602722168,5.564875602722168,5.60982084274292
+14,5.560469627380371,5.560469627380371,5.59499979019165
+15,5.584121227264404,5.584121227264404,5.5788140296936035
+16,5.569767951965332,5.569767951965332,5.593851089477539
+17,5.562833786010742,5.562833309173584,5.580272674560547
+18,5.558739185333252,5.558739185333252,5.5749406814575195
+19,5.569887638092041,5.569887638092041,5.59522008895874
+20,5.548027515411377,5.548027515411377,5.575991153717041
+21,5.552359580993652,5.552359580993652,5.572174072265625
+22,5.557245254516602,5.557244777679443,5.568301677703857
+23,5.5531721115112305,5.5531721115112305,5.575419902801514
+24,5.551608562469482,5.551608562469482,5.5859551429748535
+25,5.547980785369873,5.547980785369873,5.567434787750244
+26,5.558878421783447,5.558878421783447,5.573901653289795
+27,5.572822570800781,5.572822570800781,5.58870792388916
+28,5.558602333068848,5.558602333068848,5.563601016998291
+29,5.554257869720459,5.554257392883301,5.5545220375061035
+30,5.559041500091553,5.559041500091553,5.554906845092773
+31,5.564431190490723,5.564431190490723,5.592169761657715
+32,5.570614337921143,5.570614337921143,5.566442966461182
+33,5.541515350341797,5.541515350341797,5.56029748916626
+34,5.551905632019043,5.551905632019043,5.564456939697266
+35,5.560442924499512,5.560442924499512,5.5753045082092285
+36,5.552741527557373,5.552741050720215,5.565302848815918
+37,5.552472114562988,5.552472114562988,5.564217567443848
+38,5.564489364624023,5.564489364624023,5.570488929748535
+39,5.560131549835205,5.560131549835205,5.577801704406738
+40,5.56781005859375,5.56781005859375,5.580154895782471
+41,5.559487342834473,5.559488296508789,5.557094097137451
+42,5.549378395080566,5.549378395080566,5.5669989585876465
+43,5.554733753204346,5.554733753204346,5.551562786102295
+44,5.542470455169678,5.5424699783325195,5.5577898025512695
+45,5.553576946258545,5.553576946258545,5.570135116577148
+46,5.560152530670166,5.560152530670166,5.565868854522705
+47,5.557215213775635,5.557215213775635,5.561673164367676
+48,5.553524017333984,5.553524017333984,5.55897331237793
+49,5.558567523956299,5.558567523956299,5.5570068359375
+50,5.561322212219238,5.561322212219238,5.563434600830078
+51,5.562047481536865,5.562047481536865,5.559027194976807
+52,5.554262161254883,5.5553483963012695,5.551726341247559
+53,5.554727077484131,5.554988861083984,5.562111854553223
+54,5.564721584320068,5.566233158111572,5.5689263343811035
+55,5.55735969543457,5.562169075012207,5.564178943634033
+56,5.548841953277588,5.54956579208374,5.55482292175293
+57,5.553441524505615,5.5556817054748535,5.549459934234619
+58,5.551535606384277,5.55885648727417,5.55819845199585
+59,5.548044204711914,5.556419372558594,5.549111843109131
+60,5.554605007171631,5.5576372146606445,5.55856466293335
+61,5.541609287261963,5.5446882247924805,5.542795181274414
+62,5.551147937774658,5.548646926879883,5.5556793212890625
+63,5.546630859375,5.547421932220459,5.549322128295898
+64,5.554196834564209,5.549775123596191,5.561946868896484
+65,5.555569171905518,5.558575630187988,5.5518670082092285
+66,5.5655622482299805,5.562289237976074,5.576994895935059
+67,5.556685924530029,5.547102928161621,5.552685737609863
+68,5.539889812469482,5.554402828216553,5.555147171020508
+69,5.552992820739746,5.554600715637207,5.5526347160339355
+70,5.563039302825928,5.557827472686768,5.560840129852295
+71,5.564231872558594,5.546955585479736,5.567209720611572
+72,5.550198554992676,5.556711673736572,5.553574562072754
+73,5.54953145980835,5.559920787811279,5.551657676696777
+74,5.554469108581543,5.542832851409912,5.54835844039917
+75,5.542092323303223,5.5447916984558105,5.555337429046631
+76,5.54931640625,5.550760269165039,5.557024955749512
+77,5.539743900299072,5.549052715301514,5.546274185180664
+78,5.5560760498046875,5.549284934997559,5.564888954162598
+79,5.550052165985107,5.553570747375488,5.553678035736084
+80,5.551743507385254,5.549341201782227,5.54948616027832
+81,5.550701141357422,5.562289237976074,5.554744720458984
+82,5.545189380645752,5.559484481811523,5.55905294418335
+83,5.5510478019714355,5.564089298248291,5.553728103637695
+84,5.5521464347839355,5.561048984527588,5.546876907348633
+85,5.550338268280029,5.563986301422119,5.553805351257324
+86,5.561927318572998,5.557210922241211,5.559905052185059
+87,5.551202774047852,5.563952922821045,5.550068378448486
+88,5.554379940032959,5.563688278198242,5.5552849769592285
+89,5.549014091491699,5.549198627471924,5.5551958084106445
+90,5.560193061828613,5.565088748931885,5.562371253967285
+91,5.557939052581787,5.5522332191467285,5.554843425750732
+92,5.562680244445801,5.553189277648926,5.557914733886719
+93,5.550804138183594,5.550830364227295,5.545635223388672
+94,5.540013790130615,5.566745281219482,5.540511608123779
+95,5.56000280380249,5.5555830001831055,5.565472602844238
+96,5.549357891082764,5.544004440307617,5.5490899085998535
+97,5.5422773361206055,5.561646461486816,5.544616222381592
+98,5.547672271728516,5.546198844909668,5.546009540557861
+99,5.547053337097168,5.553074359893799,5.554990291595459
+100,5.562369346618652,5.549615859985352,5.562082290649414

results_multiscale.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
optimizer,scale,n_params,steps,seed,final_train,final_val,final_ppl,avg_last_20,auc,total_time_s,tokens_per_sec,peak_mem_gb
2+
adamw,1M,837888,200,42,2.609443187713623,2.6979486656188967,14.84923970496087,2.6730387449264525,3.289305282831192,249.54487310000695,820.6940798095454,0.0
3+
scao,1M,837888,200,42,2.7116637229919434,2.798486475944519,16.41977622845508,2.768521749973297,3.3539043509960176,190.83978559996467,1073.1514885962956,0.0
4+
diag_shampoo,1M,837888,200,42,4.459746360778809,4.489192714691162,89.0495284775349,4.477831196784973,4.808437628746033,190.79926289990544,1073.3794087424722,0.0
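
The column names in results_multiscale.csv are not documented on this page. Judging from the numbers alone, `final_ppl` appears to be the exponential of `final_val`, the usual definition of perplexity from a mean cross-entropy loss in nats. The short check below copies the three rows from the file and compares the two columns; the helper name `ppl_from_loss` is hypothetical.

```python
import math

# (optimizer, final_val, final_ppl) copied from results_multiscale.csv
rows = [
    ("adamw", 2.6979486656188967, 14.84923970496087),
    ("scao", 2.798486475944519, 16.41977622845508),
    ("diag_shampoo", 4.489192714691162, 89.0495284775349),
]


def ppl_from_loss(val_loss: float) -> float:
    """Perplexity implied by a mean cross-entropy loss measured in nats."""
    return math.exp(val_loss)


for name, val, ppl in rows:
    # Print the recomputed value next to the one stored in the CSV.
    print(name, ppl_from_loss(val), ppl)
```

If this reading is right, the perplexity gap between optimizers is a restatement of the validation-loss gap and not an independent measurement.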

results_multiscale_curves.csv

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
optimizer,seed,wall_clock_s,ppl
2+
adamw,42,3.18,149.4621
3+
adamw,42,11.18,113.0973
4+
adamw,42,19.58,82.9878
5+
adamw,42,27.86,54.7381
6+
adamw,42,36.02,37.6566
7+
adamw,42,44.59,28.6875
8+
adamw,42,53.14,23.9316
9+
adamw,42,61.41,21.2133
10+
adamw,42,69.89,19.2509
11+
adamw,42,78.34,17.9698
12+
adamw,42,87.19,17.0501
13+
adamw,42,95.47,16.4875
14+
adamw,42,104.39,16.0428
15+
adamw,42,116.65,15.6604
16+
adamw,42,157.08,15.4163
17+
adamw,42,189.74,15.2580
18+
adamw,42,204.84,15.1107
19+
adamw,42,223.29,15.0021
20+
adamw,42,235.74,14.9159
21+
adamw,42,244.24,14.8492
22+
scao,42,3.47,150.3750
23+
scao,42,11.79,114.1543
24+
scao,42,20.34,84.9395
25+
scao,42,29.06,57.7818
26+
scao,42,37.52,39.7365
27+
scao,42,46.08,30.3128
28+
scao,42,54.66,25.3575
29+
scao,42,63.22,22.4175
30+
scao,42,71.86,20.3986
31+
scao,42,80.28,19.0522
32+
scao,42,88.86,18.4219
33+
scao,42,97.52,17.8990
34+
scao,42,106.12,17.4777
35+
scao,42,115.52,17.2133
36+
scao,42,124.05,16.9788
37+
scao,42,133.15,16.8036
38+
scao,42,141.87,16.6641
39+
scao,42,150.55,16.5728
40+
scao,42,162.65,16.4775
41+
scao,42,180.19,16.4198
42+
diag_shampoo,42,6.87,258.2252
43+
diag_shampoo,42,25.45,232.8457
44+
diag_shampoo,42,38.61,200.6362
45+
diag_shampoo,42,46.85,170.7445
46+
diag_shampoo,42,55.46,149.4159
47+
diag_shampoo,42,64.14,135.2993
48+
diag_shampoo,42,72.42,125.2282
49+
diag_shampoo,42,81.22,117.5615
50+
diag_shampoo,42,90.68,111.4450
51+
diag_shampoo,42,98.89,106.5383
52+
diag_shampoo,42,107.09,102.4727
53+
diag_shampoo,42,115.54,99.1991
54+
diag_shampoo,42,123.59,96.5694
55+
diag_shampoo,42,132.07,94.4825
56+
diag_shampoo,42,140.33,92.8692
57+
diag_shampoo,42,148.76,91.6371
58+
diag_shampoo,42,156.99,90.7070
59+
diag_shampoo,42,165.30,90.0230
60+
diag_shampoo,42,173.38,89.4939
61+
diag_shampoo,42,181.59,89.0495

results_multiscale_v2.csv

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
optimizer,scale,n_params,steps,seed,final_train,final_val,final_ppl,avg_last_20,auc,total_time_s,tokens_per_sec,peak_mem_gb
2+
adamw,1M,837888,500,42,2.467369318008423,2.4726580572128296,11.85391339318089,2.4515873908996584,2.87281511592865,812.2171206999337,1260.7466327692784,0.01248563826084137
3+
scao,1M,837888,500,42,2.5562660694122314,2.5774046182632446,13.162930957876387,2.543643128871918,2.96584121799469,1292.7824406999862,792.0899663871888,0.012485504150390625

results_multiscale_v2_curves.csv

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
optimizer,seed,wall_clock_s,ppl
2+
adamw,42,15.98,116.5229
3+
adamw,42,42.38,66.5564
4+
adamw,42,68.01,34.7519
5+
adamw,42,93.25,21.7346
6+
adamw,42,118.97,16.9728
7+
adamw,42,154.26,15.0195
8+
adamw,42,206.78,14.0978
9+
adamw,42,270.43,13.4651
10+
adamw,42,319.93,13.2326
11+
adamw,42,369.99,12.7835
12+
adamw,42,396.38,12.6075
13+
adamw,42,421.70,12.4974
14+
adamw,42,462.48,12.2746
15+
adamw,42,513.26,12.2202
16+
adamw,42,563.49,12.1451
17+
adamw,42,616.59,12.0276
18+
adamw,42,667.08,11.9785
19+
adamw,42,716.76,11.9383
20+
adamw,42,767.41,11.8942
21+
adamw,42,801.80,11.8539
22+
scao,42,16.33,117.5468
23+
scao,42,66.01,68.5892
24+
scao,42,120.64,36.3092
25+
scao,42,174.24,23.0127
26+
scao,42,248.61,19.3651
27+
scao,42,315.58,17.1757
28+
scao,42,388.50,16.1121
29+
scao,42,462.22,15.3712
30+
scao,42,526.31,14.7936
31+
scao,42,598.67,14.3948
32+
scao,42,675.21,14.1201
33+
scao,42,738.67,13.8888
34+
scao,42,796.70,13.7077
35+
scao,42,872.42,13.5892
36+
scao,42,942.70,13.5458
37+
scao,42,1007.91,13.3735
38+
scao,42,1075.33,13.3204
39+
scao,42,1142.23,13.2639
40+
scao,42,1204.05,13.2167
41+
scao,42,1267.80,13.1629

results_multiscale_v3.csv

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
optimizer,scale,n_params,steps,seed,final_train,final_val,final_ppl,avg_last_20,auc,total_time_s,tokens_per_sec,peak_mem_gb
2+
adamw,1M,837888,200,42,2.6237757205963135,2.680860481262207,14.597648896524271,2.651871347427368,3.267444581985474,544.8934243000112,751.7066305694295,0.01248563826084137
3+
scao,1M,837888,200,42,2.755004644393921,2.8124120807647706,16.650031041261926,2.7797602295875548,3.355783064365387,719.2764292000793,569.4611742741574,0.02484196424484253

results_multiscale_v3_curves.csv

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
optimizer,seed,wall_clock_s,ppl
2+
adamw,42,6.05,148.7067
3+
adamw,42,22.44,112.6560
4+
adamw,42,38.52,82.5069
5+
adamw,42,54.67,54.6896
6+
adamw,42,72.72,37.4680
7+
adamw,42,104.85,28.5375
8+
adamw,42,138.20,23.6325
9+
adamw,42,169.92,20.6589
10+
adamw,42,202.55,18.7545
11+
adamw,42,236.78,17.6057
12+
adamw,42,270.20,16.7743
13+
adamw,42,310.94,16.2468
14+
adamw,42,345.56,15.8245
15+
adamw,42,378.86,15.4520
16+
adamw,42,411.10,15.1534
17+
adamw,42,442.98,14.9748
18+
adamw,42,462.65,14.8452
19+
adamw,42,478.92,14.7617
20+
adamw,42,495.01,14.6753
21+
adamw,42,524.07,14.5976
22+
scao,42,11.56,149.6049
23+
scao,42,44.15,113.7046
24+
scao,42,81.87,84.5125
25+
scao,42,128.89,57.1945
26+
scao,42,180.01,39.5233
27+
scao,42,220.11,30.2746
28+
scao,42,254.97,25.1753
29+
scao,42,287.96,22.0837
30+
scao,42,321.88,20.8752
31+
scao,42,355.28,19.8731
32+
scao,42,388.20,19.0707
33+
scao,42,421.67,18.4424
34+
scao,42,453.84,17.9741
35+
scao,42,487.58,17.6508
36+
scao,42,521.59,17.3120
37+
scao,42,555.05,17.0866
38+
scao,42,587.47,16.9289
39+
scao,42,624.79,16.8306
40+
scao,42,662.70,16.7322
41+
scao,42,696.66,16.6500

results_multiscale_v4.csv

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
optimizer,scale,n_params,steps,seed,final_train,final_val,final_ppl,avg_last_20,auc,total_time_s,tokens_per_sec,peak_mem_gb
2+
adamw,1M,837888,200,42,2.6237757205963135,2.680860481262207,14.597648896524271,2.651871347427368,3.267444581985474,701.1919704000466,584.1481609755365,0.01248563826084137
3+
scao,1M,837888,200,42,2.7310643196105957,2.792778310775757,16.32631642946134,2.759758996963501,3.2999178767204285,801.5647062000353,511.0005428529707,0.02490319311618805
