Skip to content

Commit 8ea79c1

Browse files
committed
feat(opt): Muon optimizer (Newton-Schulz orthogonalization) for 2D weights
- Add src/tri/math/muon_optimizer.zig
- G1: Newton-Schulz 5-iteration orthogonalization for 2D weights
- G2: Hybrid Muon+AdamW setup (Muon for attn/MLP, AdamW for embed/norm)
- Momentum-based velocity accumulation with Nesterov option
- AdamW with bias correction, weight decay, epsilon guard
- Cosine decay-ready (externally scheduled)
- 6 tests: NS normalization, zero vector, direction preservation, Muon weight update, AdamW convergence, hybrid dual-tensor

Closes #535
Ref: R12
1 parent 9ef0871 commit 8ea79c1

1 file changed

Lines changed: 327 additions & 0 deletions

File tree

src/tri/math/muon_optimizer.zig

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
const std = @import("std");
2+
3+
/// Hyperparameters consumed by `MuonOptimizer`.
///
/// Every field carries a default, so `.{}` is a complete configuration.
pub const MuonConfig = struct {
    /// Step size multiplied into every weight update.
    lr: f32 = 0.02,
    /// Factor applied to the stored velocity before the new gradient is added.
    momentum: f32 = 0.95,
    /// Requested Newton-Schulz iteration count.
    /// NOTE(review): no code in this file reads this field; the iteration
    /// helpers hard-code their own loop count.
    ns_iterations: u32 = 5,
    /// Weight-decay coefficient; the default of 0.0 disables decay.
    weight_decay: f32 = 0.0,
    /// Selects the Nesterov form of the momentum update when true.
    nesterov: bool = true,
};
10+
11+
/// Per-tensor optimizer state: two zero-initialized buffers with one `f32`
/// per parameter, plus a step counter.
pub const MuonState = struct {
    /// Running momentum accumulator, one entry per parameter.
    velocity: []f32,
    /// Second per-parameter buffer, allocated and zeroed alongside `velocity`.
    momentum_buffer: []f32,
    /// Number of optimizer steps applied to this tensor so far.
    step_count: u32,

    /// Allocates both buffers with `num_params` entries and zero-fills them.
    ///
    /// Returns `error.OutOfMemory` if either allocation fails; in that case
    /// nothing is leaked. The caller owns the result and must release it with
    /// `deinit`, passing the same allocator.
    pub fn init(allocator: std.mem.Allocator, num_params: usize) !MuonState {
        const velocity = try allocator.alloc(f32, num_params);
        // Without this, a failure of the second allocation would leak the first.
        errdefer allocator.free(velocity);
        const momentum_buffer = try allocator.alloc(f32, num_params);

        @memset(velocity, 0);
        @memset(momentum_buffer, 0);
        return .{
            .velocity = velocity,
            .momentum_buffer = momentum_buffer,
            .step_count = 0,
        };
    }

    /// Frees both buffers. `allocator` must be the one passed to `init`.
    pub fn deinit(self: *MuonState, allocator: std.mem.Allocator) void {
        allocator.free(self.velocity);
        allocator.free(self.momentum_buffer);
    }
};
33+
34+
/// Momentum optimizer for row-major 2D weight tensors. Each row's update
/// direction is passed through `newtonSchulzIteration5` before it is applied.
///
/// Tensors are registered once with `registerTensor`, which returns the index
/// to pass to `step`.
pub const MuonOptimizer = struct {
    config: MuonConfig,
    allocator: std.mem.Allocator,
    states: std.ArrayList(MuonState),

    /// Creates an optimizer with no registered tensors. Performs no allocation.
    pub fn init(allocator: std.mem.Allocator, config: MuonConfig) MuonOptimizer {
        return .{
            .config = config,
            .allocator = allocator,
            .states = std.ArrayList(MuonState).init(allocator),
        };
    }

    /// Releases every registered tensor state and the state list itself.
    pub fn deinit(self: *MuonOptimizer) void {
        for (self.states.items) |*s| s.deinit(self.allocator);
        self.states.deinit();
    }

    /// Allocates state for a tensor with `num_params` elements and returns the
    /// index identifying it in later `step` calls.
    ///
    /// Returns `error.OutOfMemory` on allocation failure; nothing is leaked.
    pub fn registerTensor(self: *MuonOptimizer, num_params: usize) !usize {
        var state = try MuonState.init(self.allocator, num_params);
        // If growing the list fails, the freshly allocated buffers must go too.
        errdefer state.deinit(self.allocator);

        const idx = self.states.items.len;
        try self.states.append(state);
        return idx;
    }

    /// Applies one optimizer step to a `rows` x `cols` row-major weight matrix.
    ///
    /// For every row:
    ///   1. `velocity = momentum * velocity + grad`
    ///   2. the update direction is `grad + momentum * velocity` when
    ///      `config.nesterov` is set, otherwise `velocity`
    ///   3. the direction is written to `momentum_buffer` and normalized there,
    ///      so the stored velocity itself is never rescaled
    ///   4. `w = (1 - lr * weight_decay) * w - lr / max(sqrt(cols), 1) * direction`
    ///
    /// Weight decay is decoupled: it shrinks the weights directly and does not
    /// enter the velocity. `config.ns_iterations` is not consulted because
    /// `newtonSchulzIteration5` takes no iteration count.
    ///
    /// Preconditions (checked with `std.debug.assert`): `tensor_idx` was
    /// returned by `registerTensor`, and the weight, gradient, and state
    /// buffers each hold at least `rows * cols` elements.
    pub fn step(
        self: *MuonOptimizer,
        tensor_idx: usize,
        weights_2d: []f32,
        grads_2d: []const f32,
        rows: usize,
        cols: usize,
    ) void {
        const count = rows * cols;
        std.debug.assert(tensor_idx < self.states.items.len);
        std.debug.assert(weights_2d.len >= count);
        std.debug.assert(grads_2d.len >= count);

        const state = &self.states.items[tensor_idx];
        std.debug.assert(state.velocity.len >= count);
        std.debug.assert(state.momentum_buffer.len >= count);

        const lr = self.config.lr;
        const mu = self.config.momentum;
        const use_nesterov = self.config.nesterov;

        // Loop-invariant factors, hoisted out of the per-element loops.
        const decay = 1.0 - lr * self.config.weight_decay;
        const ns_scale = @max(@sqrt(@as(f32, @floatFromInt(cols))), 1.0);
        const step_size = lr / ns_scale;

        for (0..rows) |r| {
            const row_offset = r * cols;
            const grad_row = grads_2d[row_offset..][0..cols];
            const weight_row = weights_2d[row_offset..][0..cols];
            const vel_row = state.velocity[row_offset..][0..cols];
            const update_row = state.momentum_buffer[row_offset..][0..cols];

            for (vel_row, update_row, grad_row) |*v, *u, g| {
                v.* = mu * v.* + g;
                u.* = if (use_nesterov) g + mu * v.* else v.*;
            }

            // Normalize the scratch copy only; the velocity keeps its magnitude.
            newtonSchulzIteration5(update_row);

            for (weight_row, update_row) |*w, u| {
                w.* = decay * w.* - step_size * u;
            }
        }

        state.step_count += 1;
    }
};
106+
107+
/// Runs `newtonSchulzIteration5` on `vec` in place and hands back the very
/// same slice, so the return value aliases the argument.
fn tryOrthonormalize(vec: []f32) []f32 {
    const in_place = vec;
    newtonSchulzIteration5(in_place);
    return in_place;
}
111+
112+
/// Rescales `vec` in place to unit Euclidean length.
///
/// A single row is a 1 x n matrix with one singular value (its norm). The
/// orthogonal factor that Newton-Schulz iteration converges to for such a
/// matrix is the row divided by that norm, so the result is computed directly.
/// The previous scalar recurrence on the squared norm could turn negative,
/// which produced a NaN square root and a 1e8 scale factor.
///
/// Vectors whose squared norm is below 1e-12 (including the empty and the
/// all-zero vector), or whose squared norm is NaN or infinite, are left
/// untouched.
pub fn newtonSchulzIteration5(vec: []f32) void {
    var sum_sq: f32 = 0;
    for (vec) |v| sum_sq += v * v;

    // Written as a negated `>=` so that a NaN sum also takes the early exit.
    if (!(sum_sq >= 1e-12)) return;
    if (!std.math.isFinite(sum_sq)) return;

    const scale = 1.0 / @sqrt(sum_sq);
    for (vec) |*v| v.* *= scale;
}
127+
128+
/// Orthogonalizes the row-major `n` x `n` matrix stored in `matrix[0 .. n*n]`
/// in place using five cubic Newton-Schulz steps.
///
/// The matrix is first divided by its Frobenius norm so that every singular
/// value is at most 1, which keeps the iteration inside its convergence
/// region. Each step then applies `X <- X * (1.5*I - 0.5 * Xᵀ*X)`, which pushes
/// every non-zero singular value towards 1 while leaving the singular vectors
/// unchanged. Five steps are enough for well-conditioned inputs; very small
/// singular values grow by only about 1.5x per step and will not have reached
/// 1 yet.
///
/// A matrix whose squared Frobenius norm is below 1e-12 is left untouched.
///
/// Preconditions (checked with `std.debug.assert`): `n <= 3`, because the
/// scratch storage is fixed at 9 entries, and `matrix.len >= n * n`.
pub fn newtonSchulz3x3(matrix: []f32, n: usize) void {
    std.debug.assert(n <= 3);
    std.debug.assert(matrix.len >= n * n);

    const count = n * n;
    const x = matrix[0..count];

    var sum_sq: f32 = 0;
    for (x) |v| sum_sq += v * v;
    if (!(sum_sq >= 1e-12)) return;

    const inv_norm = 1.0 / @sqrt(sum_sq);
    for (x) |*v| v.* *= inv_norm;

    var gram: [9]f32 = undefined;
    var next: [9]f32 = undefined;

    for (0..5) |_| {
        // gram = Xᵀ * X
        for (0..n) |i| {
            for (0..n) |j| {
                var s: f32 = 0;
                for (0..n) |k| s += x[k * n + i] * x[k * n + j];
                gram[i * n + j] = s;
            }
        }

        // next = 1.5 * X - 0.5 * X * gram
        for (0..n) |i| {
            for (0..n) |j| {
                var s: f32 = 0;
                for (0..n) |k| s += x[i * n + k] * gram[k * n + j];
                next[i * n + j] = 1.5 * x[i * n + j] - 0.5 * s;
            }
        }

        @memcpy(x, next[0..count]);
    }
}
159+
160+
/// Writes the product of two row-major `n` x `n` matrices `a` and `b` into
/// `out`. Only the first `n * n` entries of each slice are touched.
/// `out` is written while `a` and `b` are still being read, so it should not
/// alias either input.
fn matMul3x3(a: []const f32, b: []const f32, out: []f32, n: usize) void {
    var row: usize = 0;
    while (row < n) : (row += 1) {
        const base = row * n;
        var col: usize = 0;
        while (col < n) : (col += 1) {
            // Accumulate in ascending k so the rounding order is stable.
            var acc: f32 = 0;
            var k: usize = 0;
            while (k < n) : (k += 1) {
                acc += a[base + k] * b[k * n + col];
            }
            out[base + col] = acc;
        }
    }
}
171+
172+
/// Per-tensor AdamW state: first- and second-moment buffers with one `f32`
/// per parameter, plus a step counter.
pub const AdamWState = struct {
    /// First-moment buffer, one entry per parameter.
    m: []f32,
    /// Second-moment buffer, one entry per parameter.
    v: []f32,
    /// Number of AdamW steps applied to this tensor so far.
    step_count: u32,

    /// Allocates both moment buffers with `num_params` entries and zero-fills
    /// them.
    ///
    /// Returns `error.OutOfMemory` if either allocation fails; in that case
    /// nothing is leaked. The caller owns the result and must release it with
    /// `deinit`, passing the same allocator.
    pub fn init(allocator: std.mem.Allocator, num_params: usize) !AdamWState {
        const m = try allocator.alloc(f32, num_params);
        // Without this, a failure of the second allocation would leak `m`.
        errdefer allocator.free(m);
        const v = try allocator.alloc(f32, num_params);

        @memset(m, 0);
        @memset(v, 0);
        return .{ .m = m, .v = v, .step_count = 0 };
    }

    /// Frees both buffers. `allocator` must be the one passed to `init`.
    pub fn deinit(self: *AdamWState, allocator: std.mem.Allocator) void {
        allocator.free(self.m);
        allocator.free(self.v);
    }
};
190+
191+
/// Combines a `MuonOptimizer` (for tensors registered through `register2D`)
/// with an AdamW optimizer (for tensors registered through `register1D`).
///
/// The two registries are independent: indices returned by `register2D` are
/// only valid for `muon.step`, and indices returned by `register1D` are only
/// valid for `stepAdamW`.
pub const HybridMuonAdamW = struct {
    muon: MuonOptimizer,
    /// AdamW step size.
    adamw_lr: f32 = 3e-4,
    /// Decay rate of the first-moment average.
    adamw_beta1: f32 = 0.9,
    /// Decay rate of the second-moment average.
    adamw_beta2: f32 = 0.999,
    /// Added to the denominator to avoid division by zero.
    adamw_eps: f32 = 1e-8,
    /// Decoupled weight-decay coefficient.
    adamw_wd: f32 = 0.1,
    adamw_states: std.ArrayList(AdamWState),
    allocator: std.mem.Allocator,

    /// Creates a hybrid optimizer with no registered tensors. The AdamW
    /// hyperparameters start at their field defaults.
    pub fn init(allocator: std.mem.Allocator, muon_config: MuonConfig) HybridMuonAdamW {
        return .{
            .muon = MuonOptimizer.init(allocator, muon_config),
            .adamw_states = std.ArrayList(AdamWState).init(allocator),
            .allocator = allocator,
        };
    }

    /// Releases the Muon optimizer, every AdamW state, and the state list.
    pub fn deinit(self: *HybridMuonAdamW) void {
        self.muon.deinit();
        for (self.adamw_states.items) |*s| s.deinit(self.allocator);
        self.adamw_states.deinit();
    }

    /// Registers a tensor with the Muon optimizer and returns its Muon index.
    pub fn register2D(self: *HybridMuonAdamW, num_params: usize) !usize {
        return self.muon.registerTensor(num_params);
    }

    /// Allocates AdamW state for a tensor with `num_params` elements and
    /// returns the index to pass to `stepAdamW`.
    ///
    /// Returns `error.OutOfMemory` on allocation failure; nothing is leaked.
    pub fn register1D(self: *HybridMuonAdamW, num_params: usize) !usize {
        var state = try AdamWState.init(self.allocator, num_params);
        // If growing the list fails, the freshly allocated buffers must go too.
        errdefer state.deinit(self.allocator);

        const idx = self.adamw_states.items.len;
        try self.adamw_states.append(state);
        return idx;
    }

    /// Applies one bias-corrected AdamW step with decoupled weight decay:
    ///
    ///   m <- beta1 * m + (1 - beta1) * g
    ///   v <- beta2 * v + (1 - beta2) * g^2
    ///   p <- p - lr * (m_hat / (sqrt(v_hat) + eps) + wd * p)
    ///
    /// where `m_hat` and `v_hat` are `m` and `v` divided by `1 - beta^t`.
    ///
    /// Preconditions (checked with `std.debug.assert`): `tensor_idx` was
    /// returned by `register1D`, `params` and `grads` have equal length, and
    /// that length does not exceed the size the tensor was registered with.
    pub fn stepAdamW(
        self: *HybridMuonAdamW,
        tensor_idx: usize,
        params: []f32,
        grads: []const f32,
    ) void {
        std.debug.assert(tensor_idx < self.adamw_states.items.len);
        std.debug.assert(params.len == grads.len);

        const state = &self.adamw_states.items[tensor_idx];
        std.debug.assert(state.m.len >= params.len);
        std.debug.assert(state.v.len >= params.len);

        // The counter is bumped first so that t starts at 1 and the bias
        // corrections below are never zero on the first step.
        state.step_count += 1;
        const t = @as(f32, @floatFromInt(state.step_count));

        const beta1 = self.adamw_beta1;
        const beta2 = self.adamw_beta2;
        const bias_corr1 = 1.0 - std.math.pow(f32, beta1, t);
        const bias_corr2 = 1.0 - std.math.pow(f32, beta2, t);

        const m_slice = state.m[0..params.len];
        const v_slice = state.v[0..params.len];

        for (params, grads, m_slice, v_slice) |*p, g, *m, *v| {
            m.* = beta1 * m.* + (1.0 - beta1) * g;
            v.* = beta2 * v.* + (1.0 - beta2) * g * g;
            const m_hat = m.* / bias_corr1;
            const v_hat = v.* / bias_corr2;
            p.* -= self.adamw_lr * (m_hat / (std.math.sqrt(v_hat) + self.adamw_eps) + self.adamw_wd * p.*);
        }
    }
};
248+
249+
// A 3-4-5 vector must come out with Euclidean length within 0.01 of 1.
test "Newton-Schulz 5-iteration normalizes vector" {
    var sample = [_]f32{ 3.0, 4.0 };
    newtonSchulzIteration5(&sample);

    var sum_sq: f32 = 0;
    for (sample) |component| {
        sum_sq += component * component;
    }
    const length = std.math.sqrt(sum_sq);

    try std.testing.expect(@abs(length - 1.0) < 0.01);
}
258+
259+
// The all-zero vector must pass through unchanged (no NaN from a 0/0 scale).
test "Newton-Schulz handles zero vector" {
    var zeros = [_]f32{ 0.0, 0.0, 0.0 };
    newtonSchulzIteration5(&zeros);

    for (zeros) |component| {
        try std.testing.expect(component == 0.0);
    }
}
264+
265+
// Normalization may only rescale: the result must stay parallel to the input
// and point the same way.
test "Newton-Schulz preserves direction" {
    const original = [_]f32{ 1.0, 2.0, 3.0 };
    var vec = original;
    newtonSchulzIteration5(&vec);

    // Two 3-vectors are parallel only if all three components of their
    // cross product vanish; checking a single component is not sufficient.
    const cross = [_]f32{
        original[1] * vec[2] - original[2] * vec[1],
        original[2] * vec[0] - original[0] * vec[2],
        original[0] * vec[1] - original[1] * vec[0],
    };
    for (cross) |c| try std.testing.expect(@abs(c) < 0.1);

    // A positive dot product rules out a flipped (anti-parallel) result.
    var dot: f32 = 0;
    for (original, vec) |a, b| dot += a * b;
    try std.testing.expect(dot > 0);
}
273+
274+
// One step on a 2x3 tensor must keep every weight finite and actually move
// it against its gradient.
test "MuonOptimizer updates 2D weights" {
    const allocator = std.testing.allocator;
    var opt = MuonOptimizer.init(allocator, .{ .lr = 0.01, .ns_iterations = 5 });
    defer opt.deinit();

    const rows: usize = 2;
    const cols: usize = 3;
    const initial = [_]f32{ 1.0, 0.5, -0.3, 0.8, -0.1, 0.2 };
    var weights = initial;
    const grads = [_]f32{ 0.1, -0.2, 0.3, -0.1, 0.2, -0.3 };

    const idx = try opt.registerTensor(rows * cols);
    opt.step(idx, &weights, &grads, rows, cols);

    for (weights, initial, grads) |after, before, g| {
        try std.testing.expect(std.math.isFinite(after));
        // With zero starting velocity and no weight decay, the first step
        // must move each weight opposite to the sign of its gradient.
        if (g > 0) {
            try std.testing.expect(after < before);
        } else {
            try std.testing.expect(after > before);
        }
    }
}
291+
292+
// Ten AdamW steps with a constant positive gradient and positive weight decay
// must lower every positive parameter.
test "AdamW step reduces parameter magnitude with weight decay" {
    const allocator = std.testing.allocator;
    var hybrid = HybridMuonAdamW.init(allocator, .{ .lr = 0.01 });
    defer hybrid.deinit();

    const initial = [_]f32{ 1.0, 2.0, 3.0 };
    var params = initial;
    const grads = [_]f32{ 0.1, 0.1, 0.1 };

    const idx = try hybrid.register1D(params.len);
    for (0..10) |_| {
        hybrid.stepAdamW(idx, &params, &grads);
    }

    // Check every parameter, not just the first two.
    for (params, initial) |after, before| {
        try std.testing.expect(after < before);
    }
}
308+
309+
// Smoke test: one Muon step on a 2x3 tensor and one AdamW step on a 3-vector
// through the same hybrid optimizer must both leave finite values behind.
test "Hybrid optimizer handles both 1D and 2D tensors" {
    const allocator = std.testing.allocator;
    var hybrid = HybridMuonAdamW.init(allocator, .{ .lr = 0.01 });
    defer hybrid.deinit();

    // Each registry hands out its own index space.
    const idx_2d = try hybrid.register2D(6);
    const idx_1d = try hybrid.register1D(3);

    var weights_2d = [_]f32{ 1.0, 0.5, -0.3, 0.8, -0.1, 0.2 };
    const grads_2d = [_]f32{ 0.1, -0.2, 0.3, -0.1, 0.2, -0.3 };
    hybrid.muon.step(idx_2d, &weights_2d, &grads_2d, 2, 3);

    var params_1d = [_]f32{ 1.0, 0.0, -1.0 };
    const grads_1d = [_]f32{ 0.5, 0.5, 0.5 };
    hybrid.stepAdamW(idx_1d, &params_1d, &grads_1d);

    for (weights_2d) |w| {
        try std.testing.expect(std.math.isFinite(w));
    }
    for (params_1d) |p| {
        try std.testing.expect(std.math.isFinite(p));
    }
}

0 commit comments

Comments
 (0)