From 8606d3a6e97e46384f0957cc4eacb682018de460 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 9 Feb 2026 14:05:45 +0800 Subject: [PATCH 1/4] fix: FSDP2 do not support foreach ops in HybridMuon --- deepmd/pt/optimizer/hybrid_muon.py | 159 ++++++---------------------- source/tests/pt/test_hybrid_muon.py | 14 +-- 2 files changed, 37 insertions(+), 136 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index abf4d3a572..e2b6db1195 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -84,22 +84,7 @@ NS_COEFF_C: float = 2.0315 -def _maybe_compile( - fn: callable, -) -> callable: - """Compile a function if torch.compile is available.""" - if not hasattr(torch, "compile"): - return fn - # Skip compile if default device is CUDA but CUDA is unavailable. - if hasattr(torch, "get_default_device"): - default_device = torch.get_default_device() - if default_device.type == "cuda" and not torch.cuda.is_available(): - return fn - return torch.compile(fn, fullgraph=True, dynamic=True) - - -@_maybe_compile -def _zeropower_via_newtonschulz5_2d( +def _newton_schulz_orth( G: torch.Tensor, ) -> torch.Tensor: """ @@ -132,70 +117,6 @@ def _zeropower_via_newtonschulz5_2d( return X -@_maybe_compile -def _zeropower_via_newtonschulz5_3d( - G: torch.Tensor, -) -> torch.Tensor: - """ - Orthogonalize a 3D batch of matrices via quintic Newton-Schulz iteration. - - Mathematical formulation: - X_0 = G / ||G||_F - X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T - Coefficients: a=3.4445, b=-4.7750, c=2.0315 - """ - # === Step 1. Cast to bf16 and transpose tall matrices === - X = G.to(dtype=torch.bfloat16) - transposed = X.size(-2) > X.size(-1) - if transposed: - X = X.transpose(-2, -1) - - # === Step 2. Normalize Frobenius norm to at most 1 === - X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS) - - # === Step 3. Newton-Schulz iterations with batched fused GEMM === - for _ in range(NS_STEPS): - A = torch.bmm(X, X.transpose(-2, -1)) - gram_update = torch.baddbmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C) - X = torch.baddbmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0) - - # === Step 4. Transpose back if needed === - if transposed: - X = X.transpose(-2, -1) - - return X - - -def zeropower_via_newtonschulz5( - G: torch.Tensor, -) -> torch.Tensor: - """ - Compute the zeroth power (orthogonalization) via Newton-Schulz iteration. - - Dispatches to compiled 2D or 3D kernels for best performance. - - Parameters - ---------- - G : torch.Tensor - Input matrix with shape (M, N) or batched input with shape (B, M, N). - - Returns - ------- - torch.Tensor - Orthogonalized tensor in bfloat16 with same shape as input. - - Raises - ------ - ValueError - If input is not 2D or 3D. - """ - if G.ndim == 2: - return _zeropower_via_newtonschulz5_2d(G) - if G.ndim == 3: - return _zeropower_via_newtonschulz5_3d(G) - raise ValueError("Input must be 2D or 3D for Newton-Schulz orthogonalization.") - - def should_fallback_to_adam_for_matrix( p: torch.Tensor, min_2d_dim: int, @@ -478,9 +399,11 @@ def step( # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 - torch._foreach_lerp_(adam_exp_avgs, adam_grads_fp32, 1 - adam_betas[0]) - grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32) - torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) + for ea, g in zip(adam_exp_avgs, adam_grads_fp32): + ea.lerp_(g, 1 - adam_betas[0]) + grad_sq = [g * g for g in adam_grads_fp32] + for eas, gsq in zip(adam_exp_avg_sqs, grad_sq): + eas.lerp_(gsq, 1 - adam_betas[1]) # === Step 1.3. Bias correction and parameter update === for i, p in enumerate(adam_params): @@ -531,11 +454,11 @@ def step( # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 - torch._foreach_lerp_( - adam_nd_exp_avgs, adam_nd_grads_fp32, 1 - adam_betas[0] - ) - grad_sq = torch._foreach_mul(adam_nd_grads_fp32, adam_nd_grads_fp32) - torch._foreach_lerp_(adam_nd_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) + for ea, g in zip(adam_nd_exp_avgs, adam_nd_grads_fp32): + ea.lerp_(g, 1 - adam_betas[0]) + grad_sq = [g * g for g in adam_nd_grads_fp32] + for eas, gsq in zip(adam_nd_exp_avg_sqs, grad_sq): + eas.lerp_(gsq, 1 - adam_betas[1]) # === Step 2.3. Bias correction and parameter update === for i, p in enumerate(adam_nd_params): @@ -589,15 +512,11 @@ def step( # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 - torch._foreach_lerp_( - adam_matrix_exp_avgs, adam_matrix_grads_fp32, 1 - adam_betas[0] - ) - grad_sq_m = torch._foreach_mul( - adam_matrix_grads_fp32, adam_matrix_grads_fp32 - ) - torch._foreach_lerp_( - adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1] - ) + for ea, g in zip(adam_matrix_exp_avgs, adam_matrix_grads_fp32): + ea.lerp_(g, 1 - adam_betas[0]) + grad_sq_m = [g * g for g in adam_matrix_grads_fp32] + for eas, gsq in zip(adam_matrix_exp_avg_sqs, grad_sq_m): + eas.lerp_(gsq, 1 - adam_betas[1]) # === Step 3.3. Compute unclipped deltas === raw_deltas: list[torch.Tensor] = [] @@ -611,8 +530,8 @@ def step( # === Step 3.4. Clip updates by relative norm and apply === max_rel_change = 0.05 - p_norms = torch.stack(torch._foreach_norm(adam_matrix_params)) - delta_norms = torch.stack(torch._foreach_norm(raw_deltas)) + p_norms = torch.stack([p.norm() for p in adam_matrix_params]) + delta_norms = torch.stack([d.norm() for d in raw_deltas]) floors = torch.tensor( adam_matrix_abs_floor, device=p_norms.device, @@ -653,18 +572,21 @@ def step( # === Step 4.2. Apply weight decay (Muon path only) === if weight_decay > 0 and muon_params_for_decay: - torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay) + for p in muon_params_for_decay: + p.mul_(1.0 - lr * weight_decay) if not active_entries: continue # === Step 4.3. Momentum update (Nesterov) === # m_t = beta * m_{t-1} + (1 - beta) * g_t - torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum) + for buf, g in zip(muon_momentum_buffers, muon_grads): + buf.lerp_(g, 1 - momentum) # update = beta * m_t + (1 - beta) * g_t - muon_updates = torch._foreach_lerp( - muon_grads, muon_momentum_buffers, momentum - ) + muon_updates = [ + torch.lerp(g, buf, momentum) + for g, buf in zip(muon_grads, muon_momentum_buffers) + ] # === Step 4.4. Bucket by shape/device/dtype for batched NS === buckets: dict[ @@ -689,37 +611,16 @@ def step( else: scale = max(1.0, rows / cols) ** 0.5 - if len(bucket_entries) == 1: - entry, update_tensor = bucket_entries[0] + # Process each entry individually with _newton_schulz_orth. + # compatible with sharding propagation under FSDP2. + for entry, update_tensor in bucket_entries: update_matrix = update_tensor.reshape(rows, cols) if not update_matrix.is_contiguous(): update_matrix = update_matrix.contiguous() - orth = _zeropower_via_newtonschulz5_2d(update_matrix) + orth = _newton_schulz_orth(update_matrix) orth.mul_(scale) delta = orth.reshape(entry["param"].shape) entry["param"].add_(delta, alpha=-lr) - continue - - matrices: list[torch.Tensor] = [] - params: list[torch.Tensor] = [] - orig_shapes: list[tuple[int, ...]] = [] - - for entry, update_tensor in bucket_entries: - update_matrix = update_tensor.reshape(rows, cols) - matrices.append( - update_matrix - if update_matrix.is_contiguous() - else update_matrix.contiguous() - ) - params.append(entry["param"]) - orig_shapes.append(entry["param"].shape) - - stacked = torch.stack(matrices, dim=0) - orth = _zeropower_via_newtonschulz5_3d(stacked) - orth.mul_(scale) - - for i, _ in enumerate(bucket_entries): - params[i].add_(orth[i].reshape(orig_shapes[i]), alpha=-lr) return loss diff --git a/source/tests/pt/test_hybrid_muon.py b/source/tests/pt/test_hybrid_muon.py index 77973c5728..6d94227a21 100644 --- a/source/tests/pt/test_hybrid_muon.py +++ b/source/tests/pt/test_hybrid_muon.py @@ -5,7 +5,7 @@ from deepmd.pt.optimizer.hybrid_muon import ( HybridMuonOptimizer, - zeropower_via_newtonschulz5, + _newton_schulz_orth, ) from deepmd.pt.utils import ( env, @@ -48,7 +48,7 @@ def test_orthogonalization(self) -> None: """Test that NS produces approximately orthogonal output.""" torch.manual_seed(42) G = torch.randn(4, 4, dtype=torch.float32, device=self.device) - X = zeropower_via_newtonschulz5(G) + X = _newton_schulz_orth(G) # X @ X.T should be approximately identity # Note: NS uses bf16 internally, 5 iterations gives ~0.1-0.3 error @@ -68,17 +68,17 @@ def test_orthogonalization(self) -> None: def test_shape_and_dtype(self) -> None: """Test that output preserves shape and returns bf16.""" torch.manual_seed(42) - for shape in [(4, 4), (6, 4), (3, 4, 4)]: + for shape in [(4, 4), (6, 4)]: G = torch.randn(*shape, dtype=torch.float32, device=self.device) - X = zeropower_via_newtonschulz5(G) + X = _newton_schulz_orth(G) self.assertEqual(X.shape, G.shape) self.assertEqual(X.dtype, torch.bfloat16) def test_invalid_input(self) -> None: - """Test that <2D input raises ValueError.""" + """Test that 1D input raises error.""" G_1d = torch.randn(10, dtype=torch.float32, device=self.device) - with self.assertRaises(ValueError): - zeropower_via_newtonschulz5(G_1d) + with self.assertRaises((ValueError, RuntimeError, IndexError)): + _newton_schulz_orth(G_1d) @unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") From 5f19065d4a8534ce24c3f14cc353ae27e5573557 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 14 Feb 2026 11:42:30 +0800 Subject: [PATCH 2/4] feat: use flash Muon --- deepmd/pt/optimizer/hybrid_muon.py | 221 +++++++++++++++++++++++++++- deepmd/pt/train/training.py | 2 + deepmd/utils/argcheck.py | 10 ++ source/tests/pt/test_hybrid_muon.py | 83 +++++++++++ 4 files changed, 313 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index e2b6db1195..691b75a093 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -48,6 +48,8 @@ https://arxiv.org/abs/2502.16982 .. [3] Moonlight GitHub Repository. https://github.com/MoonshotAI/Moonlight +.. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz. + https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin) """ from __future__ import ( @@ -70,6 +72,18 @@ Iterable, ) +# ============================================================================ +# Triton availability detection +# ============================================================================ + +try: + import triton + import triton.language as tl + + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False + # ============================================================================ # Constants # ============================================================================ @@ -84,6 +98,156 @@ NS_COEFF_C: float = 2.0315 +# ============================================================================ +# Triton-accelerated symmetric matmul kernel (from flash-muon [4]) +# ============================================================================ + +if TRITON_AVAILABLE: + + def _get_autotune_config(): # noqa: ANN202 + return [ + triton.Config( + { + "BLOCK_SIZE_M": blk_m, + "BLOCK_SIZE_K": blk_k, + "GROUP_SIZE_M": 8, + }, + num_stages=n_stages, + num_warps=n_warps, + ) + for blk_m in [32, 64, 128] + for blk_k in [32, 64] + for n_stages in [3, 4, 5] + for n_warps in [4, 8] + ] + + @triton.autotune(configs=_get_autotune_config(), key=["M", "K"]) + @triton.jit + def _mmt_kernel( + x, # noqa: ANN001 + y, # noqa: ANN001 + M, # noqa: ANN001 + K, # noqa: ANN001 + stride_xm, # noqa: ANN001 + stride_xk, # noqa: ANN001 + stride_ym, # noqa: ANN001 + stride_yn, # noqa: ANN001 + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ) -> None: + """Compute y = x @ x.T, exploiting symmetry (upper triangle only).""" + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + # Skip lower triangle — mirror from upper triangle instead + if pid_m > pid_n: + return + + offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) + b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) + a_ptrs += BLOCK_SIZE_K * stride_xk + b_ptrs += BLOCK_SIZE_K * stride_xk + + c = accumulator.to(x.dtype.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, c, mask=c_mask) + + # Transpose-and-copy: mirror upper triangle to lower + if pid_m < pid_n: + ct_ptrs = y + stride_ym * offs_cn[:, None] + stride_yn * offs_cm[None, :] + ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) + + def _matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using triton symmetric matmul kernel.""" + d_in = d_in.contiguous() + M, K = d_in.shape + grid = lambda META: ( # noqa: E731 + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(M, META["BLOCK_SIZE_M"]), + ) + with torch.cuda.device(d_in.device.index): + _mmt_kernel[grid]( + d_in, + d_out, + M, + K, + d_in.stride(0), + d_in.stride(1), + d_out.stride(0), + d_out.stride(1), + ) + + +# ============================================================================ +# Flash Newton-Schulz orthogonalization (triton-accelerated) +# ============================================================================ + + +def _flash_newton_schulz_orth( + G: torch.Tensor, + buf1: torch.Tensor, + buf2: torch.Tensor, +) -> torch.Tensor: + """ + Orthogonalize a 2D matrix via quintic Newton-Schulz with triton-accelerated + symmetric matmul. Mathematically equivalent to ``_newton_schulz_orth``. + + Parameters + ---------- + G : torch.Tensor + Input 2D gradient/update matrix with shape (m, n). + buf1 : torch.Tensor + Pre-allocated buffer with shape (M, M) where M = min(m, n), in bfloat16. + buf2 : torch.Tensor + Pre-allocated buffer with shape (M, M) where M = min(m, n), in bfloat16. + + Returns + ------- + torch.Tensor + Orthogonalized matrix in bfloat16 with shape (m, n). + """ + # === Step 1. Cast to bf16 and transpose tall matrices === + X = G.to(dtype=torch.bfloat16) + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.transpose(-2, -1) + + # === Step 2. Normalize Frobenius norm to at most 1 === + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS) + + # === Step 3. Newton-Schulz iterations with triton symmetric matmul === + for _ in range(NS_STEPS): + _matmul_transpose_assign(X, buf1) # buf1 = X @ X.T = A + _matmul_transpose_assign(buf1, buf2) # buf2 = A @ A.T = A² (A symmetric) + B = NS_COEFF_B * buf1 + NS_COEFF_C * buf2 + X = NS_COEFF_A * X + B @ X + + # === Step 4. Transpose back if needed === + if transposed: + X = X.transpose(-2, -1) + + return X + + def _newton_schulz_orth( G: torch.Tensor, ) -> torch.Tensor: @@ -219,6 +383,11 @@ class HybridMuonOptimizer(Optimizer): Must be >= 1. Set to 1 to disable fallback. Default is 1. + flash_muon : bool + Enable triton-accelerated Newton-Schulz orthogonalization. + Requires triton and CUDA. Falls back to PyTorch implementation + when triton is unavailable or running on CPU. + Default is True. Examples -------- @@ -240,6 +409,7 @@ def __init__( lr_adjust_coeff: float = 0.2, muon_2d_only: bool = True, min_2d_dim: int = 1, + flash_muon: bool = True, ) -> None: if min_2d_dim < 1: raise ValueError("min_2d_dim must be >= 1.") @@ -259,6 +429,42 @@ def __init__( self._routing_built = False self._routing: list[dict[str, Any]] = [] + # Flash-Muon: triton-accelerated Newton-Schulz + self._use_flash = flash_muon and TRITON_AVAILABLE + # Lazily allocated NS iteration buffers, keyed by (M, device) + self._ns_buffers: dict[ + tuple[int, torch.device], + tuple[torch.Tensor, torch.Tensor], + ] = {} + + def _get_ns_buffers( + self, + M: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get or lazily allocate pre-allocated buffers for flash Newton-Schulz. + + Parameters + ---------- + M : int + Square buffer dimension (= min(rows, cols) of the update matrix). + device : torch.device + Target CUDA device. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + (buf1, buf2), each with shape (M, M) in bfloat16. + """ + key = (M, device) + if key not in self._ns_buffers: + self._ns_buffers[key] = ( + torch.empty(M, M, dtype=torch.bfloat16, device=device), + torch.empty(M, M, dtype=torch.bfloat16, device=device), + ) + return self._ns_buffers[key] + def _build_param_routing(self) -> None: """ Classify parameters into Muon and Adam routes (static routing). @@ -611,14 +817,23 @@ def step( else: scale = max(1.0, rows / cols) ** 0.5 - # Process each entry individually with _newton_schulz_orth. - # compatible with sharding propagation under FSDP2. + # Determine if flash path is usable for this bucket + use_flash = self._use_flash and _device.type == "cuda" + if use_flash: + M = min(rows, cols) + buf1, buf2 = self._get_ns_buffers(M, _device) + + # Process each entry individually with Newton-Schulz orth. + # Compatible with sharding propagation under FSDP2. for entry, update_tensor in bucket_entries: update_matrix = update_tensor.reshape(rows, cols) if not update_matrix.is_contiguous(): update_matrix = update_matrix.contiguous() - orth = _newton_schulz_orth(update_matrix) + if use_flash: + orth = _flash_newton_schulz_orth(update_matrix, buf1, buf2) + else: + orth = _newton_schulz_orth(update_matrix) orth.mul_(scale) delta = orth.reshape(entry["param"].shape) entry["param"].add_(delta, alpha=-lr) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 5b35206661..42077e247e 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -175,6 +175,7 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2), "muon_2d_only": params.get("muon_2d_only", True), "min_2d_dim": params.get("min_2d_dim", 1), + "flash_muon": params.get("flash_muon", True), } return opt_type, opt_param @@ -763,6 +764,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]), muon_2d_only=bool(self.opt_param["muon_2d_only"]), min_2d_dim=int(self.opt_param["min_2d_dim"]), + flash_muon=bool(self.opt_param["flash_muon"]), ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.load_state_dict(optimizer_state_dict) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8c20bb8bf4..7424f31aa6 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3554,6 +3554,16 @@ def training_args( "those with min(m, n) < min_2d_dim use Adam fallback. " "Set to 1 to disable fallback.", ), + Argument( + "flash_muon", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + + "Enable triton-accelerated Newton-Schulz orthogonalization. " + "Requires triton and CUDA. Falls back to PyTorch implementation " + "when triton is unavailable or running on CPU.", + ), ], [], optional=True, diff --git a/source/tests/pt/test_hybrid_muon.py b/source/tests/pt/test_hybrid_muon.py index 6d94227a21..78aa267a9d 100644 --- a/source/tests/pt/test_hybrid_muon.py +++ b/source/tests/pt/test_hybrid_muon.py @@ -4,6 +4,7 @@ import torch from deepmd.pt.optimizer.hybrid_muon import ( + TRITON_AVAILABLE, HybridMuonOptimizer, _newton_schulz_orth, ) @@ -229,5 +230,87 @@ def test_state_dict_save_load(self) -> None: self.assertEqual(s1[key], s2[key]) +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") +class TestFlashMuon(unittest.TestCase): + """Test flash_muon triton-accelerated Newton-Schulz path.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_flash_muon_false_runs(self) -> None: + """Test that flash_muon=False uses pure PyTorch path without error.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 20, device=self.device) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, flash_muon=False) + x = torch.randn(4, 10, device=self.device) + model(x).sum().backward() + optimizer.step() + # Should complete without error + + def test_flash_muon_true_runs(self) -> None: + """Test that flash_muon=True runs (falls back on CPU, uses triton on CUDA).""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 20, device=self.device) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, flash_muon=True) + x = torch.randn(4, 10, device=self.device) + model(x).sum().backward() + optimizer.step() + + def test_flash_vs_pytorch_consistency(self) -> None: + """Test that flash and non-flash paths produce consistent results. + + On CPU (no triton), both paths are identical (PyTorch fallback). + On CUDA with triton, results should be close (same math, bf16 rounding). + """ + torch.manual_seed(42) + model1 = torch.nn.Linear(32, 64, device=self.device) + model2 = torch.nn.Linear(32, 64, device=self.device) + model2.load_state_dict(model1.state_dict()) + + opt1 = HybridMuonOptimizer(model1.parameters(), lr=0.02, flash_muon=False) + opt2 = HybridMuonOptimizer(model2.parameters(), lr=0.02, flash_muon=True) + + x = torch.randn(4, 32, device=self.device) + + opt1.zero_grad() + model1(x).sum().backward() + opt1.step() + + opt2.zero_grad() + model2(x).sum().backward() + opt2.step() + + # Both paths should produce similar results + self.assertTrue( + torch.allclose(model1.weight, model2.weight, atol=1e-2), + f"Flash and non-flash weight diff: {(model1.weight - model2.weight).abs().max().item():.6f}", + ) + self.assertTrue( + torch.allclose(model1.bias, model2.bias, atol=1e-2), + f"Flash and non-flash bias diff: {(model1.bias - model2.bias).abs().max().item():.6f}", + ) + + @unittest.skipIf( + not (TRITON_AVAILABLE and env.DEVICE.type == "cuda"), + "Triton + CUDA required for flash path verification", + ) + def test_flash_path_actually_used(self) -> None: + """Verify that flash path is actually active when triton + CUDA available.""" + torch.manual_seed(42) + model = torch.nn.Linear(32, 64, device=self.device) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, flash_muon=True) + # _use_flash should be True when triton is available + self.assertTrue(optimizer._use_flash) + # _ns_buffers should be empty before first step + self.assertEqual(len(optimizer._ns_buffers), 0) + + x = torch.randn(4, 32, device=self.device) + model(x).sum().backward() + optimizer.step() + + # After step, buffers should have been allocated for the weight matrix + self.assertGreater(len(optimizer._ns_buffers), 0) + + if __name__ == "__main__": unittest.main() From 5703ce38154f6c216a05f268784a0de81d3a6e1c Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 14 Feb 2026 12:27:00 +0800 Subject: [PATCH 3/4] feat: add FLASH_MIN_DIM for Muon --- deepmd/pt/optimizer/hybrid_muon.py | 14 +++++++++++--- source/tests/pt/test_hybrid_muon.py | 10 ++++++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index 691b75a093..06d682d286 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -96,6 +96,10 @@ NS_COEFF_A: float = 3.4445 NS_COEFF_B: float = -4.7750 NS_COEFF_C: float = 2.0315 +# Minimum matrix dimension for flash path to be beneficial. +# Below this threshold, triton kernel launch overhead dominates over compute, +# and cuBLAS (via torch.mm/addmm) is faster for small matrices. +FLASH_MIN_DIM: int = 1024 # ============================================================================ @@ -817,10 +821,14 @@ def step( else: scale = max(1.0, rows / cols) ** 0.5 - # Determine if flash path is usable for this bucket - use_flash = self._use_flash and _device.type == "cuda" + # Determine if flash path is usable for this bucket. + # Only beneficial when min(rows, cols) >= FLASH_MIN_DIM; + # for small matrices, triton launch overhead > compute savings. + M = min(rows, cols) + use_flash = ( + self._use_flash and _device.type == "cuda" and M >= FLASH_MIN_DIM + ) if use_flash: - M = min(rows, cols) buf1, buf2 = self._get_ns_buffers(M, _device) # Process each entry individually with Newton-Schulz orth. diff --git a/source/tests/pt/test_hybrid_muon.py b/source/tests/pt/test_hybrid_muon.py index 78aa267a9d..f28e014188 100644 --- a/source/tests/pt/test_hybrid_muon.py +++ b/source/tests/pt/test_hybrid_muon.py @@ -296,15 +296,21 @@ def test_flash_vs_pytorch_consistency(self) -> None: ) def test_flash_path_actually_used(self) -> None: """Verify that flash path is actually active when triton + CUDA available.""" + from deepmd.pt.optimizer.hybrid_muon import ( + FLASH_MIN_DIM, + ) + torch.manual_seed(42) - model = torch.nn.Linear(32, 64, device=self.device) + # Use matrix large enough to exceed FLASH_MIN_DIM threshold + dim = max(FLASH_MIN_DIM, 128) + model = torch.nn.Linear(dim, dim * 2, device=self.device) optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, flash_muon=True) # _use_flash should be True when triton is available self.assertTrue(optimizer._use_flash) # _ns_buffers should be empty before first step self.assertEqual(len(optimizer._ns_buffers), 0) - x = torch.randn(4, 32, device=self.device) + x = torch.randn(4, dim, device=self.device) model(x).sum().backward() optimizer.step() From aa3298927c0cb48d3d261876a75c67696067c399 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 15 Feb 2026 14:42:13 +0800 Subject: [PATCH 4/4] refactor: Muon --- deepmd/pt/optimizer/hybrid_muon.py | 95 +++++++++++++++++++----------- deepmd/utils/argcheck.py | 12 ++-- 2 files changed, 65 insertions(+), 42 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index 06d682d286..1083b30107 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -2,10 +2,19 @@ """ HybridMuon optimizer for DeePMD-kit PyTorch backend. -HybridMuon is a HYBRID optimizer that automatically combines Muon and Adam: -- For >=2D parameters with min(m,n) >= min_2d_dim: Muon update with Newton-Schulz -- For 2D parameters with min(m,n) < min_2d_dim: Adam fallback with update clipping -- For 1D parameters (biases, layer norms): Standard Adam +HybridMuon is a hybrid optimizer that automatically combines Muon and Adam. +Routing is controlled by parameter dimensionality and ``muon_2d_only``: + +- 1D parameters (biases, norms): Adam (no weight decay). +- When ``muon_2d_only=True`` (default): + - 2D parameters: Muon if ``min(m, n) >= min_2d_dim``, else Adam fallback. + - >2D parameters: Adam. +- When ``muon_2d_only=False``: + - >=2D parameters use matrix-view routing: + Muon if ``min(m, n) >= min_2d_dim``, else Adam fallback. + +For matrix-view routing, any parameter with ndim >= 2 is reshaped as: +``(rows, cols) = (numel // shape[-1], shape[-1])``. This is different from PyTorch's torch.optim.Muon, which ONLY supports 2D parameters and requires manual configuration of AdamW for 1D parameters. HybridMuon provides @@ -13,7 +22,7 @@ Algorithm --------- -For >=2D parameters (weight matrices), the Muon update is: +For Muon-routed parameters, the update is: 1. Momentum update (Nesterov): m_t = beta * m_{t-1} + (1 - beta) * g_t @@ -29,7 +38,8 @@ 4. Parameter update: theta -= lr * scale * orth(update) -For 1D parameters (biases, norms), standard Adam is used. +For Adam-routed parameters, standard Adam moments are used. +AdamW behavior (decoupled weight decay) is applied only on >=2D Adam paths. Dtype Behavior -------------- @@ -290,7 +300,7 @@ def should_fallback_to_adam_for_matrix( min_2d_dim: int, ) -> bool: """ - Check if a 2D matrix should fallback to Adam due to small dimensions. + Check if a parameter should fallback to Adam based on matrix-view dimensions. Parameters ---------- @@ -315,8 +325,11 @@ def should_fallback_to_adam_for_matrix( raise ValueError("Parameter must have ndim >= 2 for Muon suitability check.") # === Step 2. Derive matrix shape consistent with Muon reshape === - m = int(p.shape[0]) - n = int(p.numel() // p.shape[0]) + # Flatten all leading axes into rows and keep the last axis as cols. + # This preserves the "input-channel axis" as the NS orthogonalization space + # for N-D linear weights (e.g., (..., C_out, C_in) -> (-1, C_in)). + m = int(p.numel() // p.shape[-1]) + n = int(p.shape[-1]) # === Step 3. Check if any dimension too small for Muon === return min(m, n) < min_2d_dim @@ -324,15 +337,17 @@ def should_fallback_to_adam_for_matrix( class HybridMuonOptimizer(Optimizer): """ - HybridMuon optimizer with small-2D Adam fallback and 1D Adam path. - - This optimizer applies different update rules based on parameter dimensionality: - - For >=2D parameters with min(m, n) >= min_2d_dim: - Muon update with Newton-Schulz orthogonalization. - - For 2D parameters with min(m, n) < min_2d_dim (small matrices): - Adam update with scaled learning rate and update clipping. - - For 1D parameters (biases, layer norms): - Standard Adam update. + HybridMuon optimizer with small-matrix Adam fallback and 1D Adam path. + + This optimizer applies different update rules based on parameter dimensionality + and ``muon_2d_only``: + - 1D parameters (biases, layer norms): standard Adam update. + - When ``muon_2d_only=True``: + - 2D parameters use Muon/Adam-fallback according to ``min_2d_dim``. + - >2D parameters use Adam. + - When ``muon_2d_only=False``: + - >=2D parameters use matrix-view Muon/Adam-fallback according to + ``min_2d_dim``. This hybrid approach is effective because Muon's orthogonalization is designed for weight matrices, while Adam is more suitable for biases and normalization params. @@ -346,8 +361,9 @@ class HybridMuonOptimizer(Optimizer): 4. Scaling: scale = coeff*sqrt(max(m,n)) or sqrt(max(1, m/n)) 5. Parameter update: theta -= lr * scale * orth - Adam (1D params): + Adam: Standard Adam with bias correction, all computations in float32. + Decoupled weight decay is applied only to >=2D Adam-routed parameters. Parameters ---------- @@ -358,7 +374,9 @@ class HybridMuonOptimizer(Optimizer): momentum : float Momentum coefficient for Muon with default 0.95. weight_decay : float - Weight decay coefficient (applied only to Muon-routed parameters) with default 0.001. + Weight decay coefficient with default 0.001. + Applied to Muon-routed parameters and >=2D Adam-routed parameters + with AdamW-style decoupled decay. Not applied to 1D Adam parameters. adam_betas : tuple[float, float] Adam beta coefficients with default (0.9, 0.95). lr_adjust : float @@ -372,17 +390,17 @@ class HybridMuonOptimizer(Optimizer): Dual-purpose coefficient with default 0.2: 1. For Muon (when lr_adjust <= 0): match-RMS scaling factor, scale = lr_adjust_coeff * sqrt(max(m, n)). - 2. For 2D Adam fallback: learning rate multiplier, + 2. For matrix Adam fallback: learning rate multiplier, adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1). The min(., 0.1) cap ensures conservative updates for small matrices. muon_2d_only : bool If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). - Parameters with ndim > 2 use Adam without weight decay. - If False, all >=2D parameters use Muon (default behavior). + Parameters with ndim > 2 use AdamW-style updates. + If False, all >=2D parameters are eligible for Muon via matrix-view routing. Default is True. min_2d_dim : int - Minimum min(m, n) threshold for Muon on 2D matrices. - Matrices with min(m, n) >= min_2d_dim use Muon; + Minimum min(m, n) threshold for Muon on eligible matrix-view parameters. + Eligible parameters with min(m, n) >= min_2d_dim use Muon; those with min(m, n) < min_2d_dim use Adam fallback. Must be >= 1. Set to 1 to disable fallback. @@ -476,9 +494,8 @@ def _build_param_routing(self) -> None: Routing logic: - 1D parameters → Adam path - >2D parameters (when muon_2d_only=True) → Adam path - - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path - - 2D parameters with min(m, n) >= min_2d_dim → Muon path - - >=2D parameters (when muon_2d_only=False) → Muon path + - >=2D parameters with min(m, n) < min_2d_dim → Adam fallback path + - remaining >=2D parameters → Muon path """ if self._routing_built: return @@ -504,10 +521,8 @@ def _build_param_routing(self) -> None: adam_nd.append({"param": p}) continue - # === Step 3. 2D small matrices → Adam fallback === - if (p.ndim == 2) and should_fallback_to_adam_for_matrix( - p, min_2d_dim=min_2d_dim - ): + # === Step 3. Small matrix-view params → Adam fallback === + if should_fallback_to_adam_for_matrix(p, min_2d_dim=min_2d_dim): adam_matrix.append( { "param": p, @@ -520,8 +535,8 @@ def _build_param_routing(self) -> None: muon_params.append( { "param": p, - "rows": int(p.shape[0]), - "cols": int(p.numel() // p.shape[0]), + "rows": int(p.numel() // p.shape[-1]), + "cols": int(p.shape[-1]), } ) @@ -661,6 +676,10 @@ def step( if adam_nd_params: # === Step 2.2. Update exp_avg / exp_avg_sq === adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust + # AdamW decay for >=2D Adam path. + if weight_decay > 0: + for p in adam_nd_params: + p.mul_(1.0 - lr * weight_decay) # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 @@ -681,7 +700,7 @@ def step( delta_fp32 = -step_size * (adam_nd_exp_avgs[i] / denom) p.add_(delta_fp32.to(p.dtype)) - # === Step 3. Adam update for small 2D matrices (fallback path) === + # === Step 3. Adam update for small matrix-view params (fallback path) === # === Step 3.1. Collect gradients and initialize state === adam_matrix_params: list[torch.Tensor] = [] adam_matrix_grads_fp32: list[torch.Tensor] = [] @@ -719,6 +738,10 @@ def step( # === Step 3.2. Update exp_avg / exp_avg_sq with scaled lr === adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1) + # AdamW decay for matrix fallback path. + if weight_decay > 0: + for p in adam_matrix_params: + p.mul_(1.0 - lr * weight_decay) # exp_avg = beta1 * exp_avg + (1 - beta1) * grad # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 @@ -780,7 +803,7 @@ def step( muon_momentum_buffers.append(buf) active_entries.append((entry, grad)) - # === Step 4.2. Apply weight decay (Muon path only) === + # === Step 4.2. Apply weight decay on Muon path === if weight_decay > 0 and muon_params_for_decay: for p in muon_params_for_decay: p.mul_(1.0 - lr * weight_decay) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 7424f31aa6..9795e85d10 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3511,7 +3511,7 @@ def training_args( optional=True, default=0.001, doc=doc_only_pt_supported - + "Weight decay coefficient. Applied only to Muon-routed parameters", + + "Weight decay coefficient. Applied to Muon-routed parameters and >=2D Adam-routed parameters (AdamW-style decoupled decay). Not applied to 1D Adam parameters.", ), Argument( "lr_adjust", @@ -3539,8 +3539,8 @@ def training_args( default=True, doc=doc_only_pt_supported + "If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). " - + "Parameters with ndim > 2 use Adam without weight decay. " - + "If False, all >=2D parameters use Muon.", + + "Parameters with ndim > 2 use AdamW-style updates. " + + "If False, all >=2D parameters are eligible for Muon (with min_2d_dim fallback to AdamW-style updates).", ), Argument( "min_2d_dim", @@ -3549,8 +3549,8 @@ def training_args( default=1, alias=["muon_min_2d_dim"], doc=doc_only_pt_supported - + "Minimum min(m, n) threshold for HybridMuon on 2D matrices. " - "Matrices with min(m, n) >= min_2d_dim use HybridMuon; " + + "Minimum min(m, n) threshold for HybridMuon on matrix-view parameters. " + "Parameters with min(m, n) >= min_2d_dim use HybridMuon; " "those with min(m, n) < min_2d_dim use Adam fallback. " "Set to 1 to disable fallback.", ), @@ -3570,7 +3570,7 @@ def training_args( doc=doc_only_pt_supported + "HybridMuon optimizer (DeePMD-kit custom implementation). " + "This is a Hybrid optimizer that automatically combines Muon and Adam. " - + "For >=2D params: Muon update with Newton-Schulz. " + + "For >=2D params: Muon update with Newton-Schulz (or Adam fallback when matrix-view dimensions are too small). " + "For 1D params: Standard Adam. " + "This is DIFFERENT from PyTorch's torch.optim.Muon which ONLY supports 2D parameters.", ),