|
| 1 | +const std = @import("std"); |
| 2 | + |
| 3 | +pub const MuonConfig = struct { |
| 4 | + lr: f32 = 0.02, |
| 5 | + momentum: f32 = 0.95, |
| 6 | + ns_iterations: u32 = 5, |
| 7 | + weight_decay: f32 = 0.0, |
| 8 | + nesterov: bool = true, |
| 9 | +}; |
| 10 | + |
| 11 | +pub const MuonState = struct { |
| 12 | + velocity: []f32, |
| 13 | + momentum_buffer: []f32, |
| 14 | + step_count: u32, |
| 15 | + |
| 16 | + pub fn init(allocator: std.mem.Allocator, num_params: usize) !MuonState { |
| 17 | + const velocity = try allocator.alloc(f32, num_params); |
| 18 | + const momentum_buffer = try allocator.alloc(f32, num_params); |
| 19 | + @memset(velocity, 0); |
| 20 | + @memset(momentum_buffer, 0); |
| 21 | + return .{ |
| 22 | + .velocity = velocity, |
| 23 | + .momentum_buffer = momentum_buffer, |
| 24 | + .step_count = 0, |
| 25 | + }; |
| 26 | + } |
| 27 | + |
| 28 | + pub fn deinit(self: *MuonState, allocator: std.mem.Allocator) void { |
| 29 | + allocator.free(self.velocity); |
| 30 | + allocator.free(self.momentum_buffer); |
| 31 | + } |
| 32 | +}; |
| 33 | + |
| 34 | +pub const MuonOptimizer = struct { |
| 35 | + config: MuonConfig, |
| 36 | + allocator: std.mem.Allocator, |
| 37 | + states: std.ArrayList(MuonState), |
| 38 | + |
| 39 | + pub fn init(allocator: std.mem.Allocator, config: MuonConfig) MuonOptimizer { |
| 40 | + return .{ |
| 41 | + .config = config, |
| 42 | + .allocator = allocator, |
| 43 | + .states = std.ArrayList(MuonState).init(allocator), |
| 44 | + }; |
| 45 | + } |
| 46 | + |
| 47 | + pub fn deinit(self: *MuonOptimizer) void { |
| 48 | + for (self.states.items) |*s| s.deinit(self.allocator); |
| 49 | + self.states.deinit(); |
| 50 | + } |
| 51 | + |
| 52 | + pub fn registerTensor(self: *MuonOptimizer, num_params: usize) !usize { |
| 53 | + const state = try MuonState.init(self.allocator, num_params); |
| 54 | + const idx = self.states.items.len; |
| 55 | + try self.states.append(state); |
| 56 | + return idx; |
| 57 | + } |
| 58 | + |
| 59 | + pub fn step( |
| 60 | + self: *MuonOptimizer, |
| 61 | + tensor_idx: usize, |
| 62 | + weights_2d: []f32, |
| 63 | + grads_2d: []const f32, |
| 64 | + rows: usize, |
| 65 | + cols: usize, |
| 66 | + ) void { |
| 67 | + std.debug.assert(tensor_idx < self.states.items.len); |
| 68 | + std.debug.assert(weights_2d.len >= rows * cols); |
| 69 | + std.debug.assert(grads_2d.len >= rows * cols); |
| 70 | + |
| 71 | + const state = &self.states.items[tensor_idx]; |
| 72 | + const lr = self.config.lr; |
| 73 | + const mu = self.config.momentum; |
| 74 | + const wd = self.config.weight_decay; |
| 75 | + |
| 76 | + for (0..rows) |r| { |
| 77 | + const row_offset = r * cols; |
| 78 | + const grad_row = grads_2d[row_offset .. row_offset + cols]; |
| 79 | + const weight_row = weights_2d[row_offset .. row_offset + cols]; |
| 80 | + const vel_row = state.velocity[row_offset .. row_offset + cols]; |
| 81 | + |
| 82 | + for (vel_row, 0..) |*v, i| { |
| 83 | + const g = grad_row[i]; |
| 84 | + v.* = mu * v.* + g; |
| 85 | + if (wd > 0) { |
| 86 | + v.* -= wd * weight_row[i]; |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + var orthogonal = tryOrthonormalize(vel_row); |
| 91 | + |
| 92 | + const update_base = if (self.config.nesterov) |
| 93 | + mu * vel_row |
| 94 | + else |
| 95 | + vel_row; |
| 96 | + |
| 97 | + for (weight_row, 0..) |*w, i| { |
| 98 | + const ns_scale = @max(std.math.sqrt(@as(f32, @floatFromInt(cols))), 1.0); |
| 99 | + w.* -= lr * (update_base[i] + orthogonal[i]) / ns_scale; |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + state.step_count += 1; |
| 104 | + } |
| 105 | +}; |
| 106 | + |
| 107 | +fn tryOrthonormalize(vec: []f32) []f32 { |
| 108 | + newtonSchulzIteration5(vec); |
| 109 | + return vec; |
| 110 | +} |
| 111 | + |
| 112 | +pub fn newtonSchulzIteration5(vec: []f32) void { |
| 113 | + var x: f32 = 0; |
| 114 | + for (vec) |v| x += v * v; |
| 115 | + if (x < 1e-12) return; |
| 116 | + |
| 117 | + var a = x; |
| 118 | + for (0..5) |_| { |
| 119 | + const b = a * a; |
| 120 | + const c = (a * (2.874_279_6 - b * 0.159_983_7)) / (2.939_741_5 + b * 0.281_969_3); |
| 121 | + a = c; |
| 122 | + } |
| 123 | + |
| 124 | + const scale = 1.0 / @max(@sqrt(a), 1e-8); |
| 125 | + for (vec) |*v| v.* *= scale; |
| 126 | +} |
| 127 | + |
| 128 | +pub fn newtonSchulz3x3(matrix: []f32, n: usize) void { |
| 129 | + for (0..5) |_| { |
| 130 | + var xt: [9]f32 = undefined; |
| 131 | + for (0..n) |i| { |
| 132 | + for (0..n) |j| { |
| 133 | + xt[i * n + j] = matrix[j * n + i]; |
| 134 | + } |
| 135 | + } |
| 136 | + var aa: [9]f32 = undefined; |
| 137 | + matMul3x3(matrix, &xt, &aa, n); |
| 138 | + |
| 139 | + var i2: [9]f32 = undefined; |
| 140 | + for (0..n * n) |idx| { |
| 141 | + i2[idx] = if (idx % (n + 1) == 0) 2.0 else 0.0; |
| 142 | + } |
| 143 | + var half_i: [9]f32 = undefined; |
| 144 | + for (0..n * n) |idx| { |
| 145 | + half_i[idx] = if (idx % (n + 1) == 0) 0.5 else 0.0; |
| 146 | + } |
| 147 | + |
| 148 | + var correction: [9]f32 = undefined; |
| 149 | + matMul3x3(&i2, &aa, &correction, n); |
| 150 | + for (&correction) |*c| c.* *= 0.5; |
| 151 | + |
| 152 | + var result: [9]f32 = undefined; |
| 153 | + matMul3x3(&half_i, &correction, &result, n); |
| 154 | + for (0..n * n) |idx| { |
| 155 | + matrix[idx] = result[idx]; |
| 156 | + } |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +fn matMul3x3(a: []const f32, b: []const f32, out: []f32, n: usize) void { |
| 161 | + for (0..n) |i| { |
| 162 | + for (0..n) |j| { |
| 163 | + var s: f32 = 0; |
| 164 | + for (0..n) |k| { |
| 165 | + s += a[i * n + k] * b[k * n + j]; |
| 166 | + } |
| 167 | + out[i * n + j] = s; |
| 168 | + } |
| 169 | + } |
| 170 | +} |
| 171 | + |
| 172 | +pub const AdamWState = struct { |
| 173 | + m: []f32, |
| 174 | + v: []f32, |
| 175 | + step_count: u32, |
| 176 | + |
| 177 | + pub fn init(allocator: std.mem.Allocator, num_params: usize) !AdamWState { |
| 178 | + const m = try allocator.alloc(f32, num_params); |
| 179 | + const v = try allocator.alloc(f32, num_params); |
| 180 | + @memset(m, 0); |
| 181 | + @memset(v, 0); |
| 182 | + return .{ .m = m, .v = v, .step_count = 0 }; |
| 183 | + } |
| 184 | + |
| 185 | + pub fn deinit(self: *AdamWState, allocator: std.mem.Allocator) void { |
| 186 | + allocator.free(self.m); |
| 187 | + allocator.free(self.v); |
| 188 | + } |
| 189 | +}; |
| 190 | + |
| 191 | +pub const HybridMuonAdamW = struct { |
| 192 | + muon: MuonOptimizer, |
| 193 | + adamw_lr: f32 = 3e-4, |
| 194 | + adamw_beta1: f32 = 0.9, |
| 195 | + adamw_beta2: f32 = 0.999, |
| 196 | + adamw_eps: f32 = 1e-8, |
| 197 | + adamw_wd: f32 = 0.1, |
| 198 | + adamw_states: std.ArrayList(AdamWState), |
| 199 | + allocator: std.mem.Allocator, |
| 200 | + |
| 201 | + pub fn init(allocator: std.mem.Allocator, muon_config: MuonConfig) HybridMuonAdamW { |
| 202 | + return .{ |
| 203 | + .muon = MuonOptimizer.init(allocator, muon_config), |
| 204 | + .adamw_states = std.ArrayList(AdamWState).init(allocator), |
| 205 | + .allocator = allocator, |
| 206 | + }; |
| 207 | + } |
| 208 | + |
| 209 | + pub fn deinit(self: *HybridMuonAdamW) void { |
| 210 | + self.muon.deinit(); |
| 211 | + for (self.adamw_states.items) |*s| s.deinit(self.allocator); |
| 212 | + self.adamw_states.deinit(); |
| 213 | + } |
| 214 | + |
| 215 | + pub fn register2D(self: *HybridMuonAdamW, num_params: usize) !usize { |
| 216 | + return self.muon.registerTensor(num_params); |
| 217 | + } |
| 218 | + |
| 219 | + pub fn register1D(self: *HybridMuonAdamW, num_params: usize) !usize { |
| 220 | + const state = try AdamWState.init(self.allocator, num_params); |
| 221 | + const idx = self.adamw_states.items.len; |
| 222 | + try self.adamw_states.append(state); |
| 223 | + return idx; |
| 224 | + } |
| 225 | + |
| 226 | + pub fn stepAdamW( |
| 227 | + self: *HybridMuonAdamW, |
| 228 | + tensor_idx: usize, |
| 229 | + params: []f32, |
| 230 | + grads: []const f32, |
| 231 | + ) void { |
| 232 | + std.debug.assert(tensor_idx < self.adamw_states.items.len); |
| 233 | + const state = &self.adamw_states.items[tensor_idx]; |
| 234 | + state.step_count += 1; |
| 235 | + const t = @as(f32, @floatFromInt(state.step_count)); |
| 236 | + const bias_corr1 = 1.0 - std.math.pow(f32, self.adamw_beta1, t); |
| 237 | + const bias_corr2 = 1.0 - std.math.pow(f32, self.adamw_beta2, t); |
| 238 | + |
| 239 | + for (params, grads, 0..) |*p, g, i| { |
| 240 | + state.m[i] = self.adamw_beta1 * state.m[i] + (1.0 - self.adamw_beta1) * g; |
| 241 | + state.v[i] = self.adamw_beta2 * state.v[i] + (1.0 - self.adamw_beta2) * g * g; |
| 242 | + const m_hat = state.m[i] / bias_corr1; |
| 243 | + const v_hat = state.v[i] / bias_corr2; |
| 244 | + p.* -= self.adamw_lr * (m_hat / (std.math.sqrt(v_hat) + self.adamw_eps) + self.adamw_wd * p.*); |
| 245 | + } |
| 246 | + } |
| 247 | +}; |
| 248 | + |
| 249 | +test "Newton-Schulz 5-iteration normalizes vector" { |
| 250 | + var vec = [_]f32{ 3.0, 4.0 }; |
| 251 | + newtonSchulzIteration5(&vec); |
| 252 | + |
| 253 | + var norm: f32 = 0; |
| 254 | + for (vec) |v| norm += v * v; |
| 255 | + norm = std.math.sqrt(norm); |
| 256 | + try std.testing.expect(@abs(norm - 1.0) < 0.01); |
| 257 | +} |
| 258 | + |
| 259 | +test "Newton-Schulz handles zero vector" { |
| 260 | + var vec = [_]f32{ 0.0, 0.0, 0.0 }; |
| 261 | + newtonSchulzIteration5(&vec); |
| 262 | + for (vec) |v| try std.testing.expect(v == 0.0); |
| 263 | +} |
| 264 | + |
| 265 | +test "Newton-Schulz preserves direction" { |
| 266 | + const original = [_]f32{ 1.0, 2.0, 3.0 }; |
| 267 | + var vec = original; |
| 268 | + newtonSchulzIteration5(&vec); |
| 269 | + |
| 270 | + const cross = original[0] * vec[1] - original[1] * vec[0]; |
| 271 | + try std.testing.expect(@abs(cross) < 0.1); |
| 272 | +} |
| 273 | + |
| 274 | +test "MuonOptimizer updates 2D weights" { |
| 275 | + const allocator = std.testing.allocator; |
| 276 | + var opt = MuonOptimizer.init(allocator, .{ .lr = 0.01, .ns_iterations = 5 }); |
| 277 | + defer opt.deinit(); |
| 278 | + |
| 279 | + const rows: usize = 2; |
| 280 | + const cols: usize = 3; |
| 281 | + var weights = [_]f32{ 1.0, 0.5, -0.3, 0.8, -0.1, 0.2 }; |
| 282 | + const grads = [_]f32{ 0.1, -0.2, 0.3, -0.1, 0.2, -0.3 }; |
| 283 | + |
| 284 | + const idx = try opt.registerTensor(rows * cols); |
| 285 | + opt.step(idx, &weights, &grads, rows, cols); |
| 286 | + |
| 287 | + for (weights) |w| { |
| 288 | + try std.testing.expect(std.math.isFinite(w)); |
| 289 | + } |
| 290 | +} |
| 291 | + |
| 292 | +test "AdamW step reduces parameter magnitude with weight decay" { |
| 293 | + const allocator = std.testing.allocator; |
| 294 | + var hybrid = HybridMuonAdamW.init(allocator, .{ .lr = 0.01 }); |
| 295 | + defer hybrid.deinit(); |
| 296 | + |
| 297 | + var params = [_]f32{ 1.0, 2.0, 3.0 }; |
| 298 | + const grads = [_]f32{ 0.1, 0.1, 0.1 }; |
| 299 | + |
| 300 | + const idx = try hybrid.register1D(params.len); |
| 301 | + for (0..10) |_| { |
| 302 | + hybrid.stepAdamW(idx, ¶ms, &grads); |
| 303 | + } |
| 304 | + |
| 305 | + try std.testing.expect(params[0] < 1.0); |
| 306 | + try std.testing.expect(params[1] < 2.0); |
| 307 | +} |
| 308 | + |
| 309 | +test "Hybrid optimizer handles both 1D and 2D tensors" { |
| 310 | + const allocator = std.testing.allocator; |
| 311 | + var hybrid = HybridMuonAdamW.init(allocator, .{ .lr = 0.01 }); |
| 312 | + defer hybrid.deinit(); |
| 313 | + |
| 314 | + const idx_2d = try hybrid.register2D(6); |
| 315 | + const idx_1d = try hybrid.register1D(3); |
| 316 | + |
| 317 | + var weights_2d = [_]f32{ 1.0, 0.5, -0.3, 0.8, -0.1, 0.2 }; |
| 318 | + const grads_2d = [_]f32{ 0.1, -0.2, 0.3, -0.1, 0.2, -0.3 }; |
| 319 | + var params_1d = [_]f32{ 1.0, 0.0, -1.0 }; |
| 320 | + const grads_1d = [_]f32{ 0.5, 0.5, 0.5 }; |
| 321 | + |
| 322 | + hybrid.muon.step(idx_2d, &weights_2d, &grads_2d, 2, 3); |
| 323 | + hybrid.stepAdamW(idx_1d, ¶ms_1d, &grads_1d); |
| 324 | + |
| 325 | + for (weights_2d) |w| try std.testing.expect(std.math.isFinite(w)); |
| 326 | + for (params_1d) |p| try std.testing.expect(std.math.isFinite(p)); |
| 327 | +} |
0 commit comments