Skip to content

Commit 1d29bbf

Browse files
committed
skip bf16 tests if no bf16 support
1 parent 586ca17 commit 1d29bbf

2 files changed

Lines changed: 32 additions & 4 deletions

File tree

deepmd/pt/optimizer/muon.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,10 @@ def step(
537537
)
538538
max_delta = torch.maximum(max_rel_change * p_norms, floors)
539539
scales_tensor = torch.clamp(max_delta / (delta_norms + 1e-12), max=1.0)
540-
for i, delta in enumerate(raw_deltas):
541-
delta.mul_(scales_tensor[i])
542-
543-
torch._foreach_add_(adam_matrix_params, raw_deltas)
540+
for i, (p, delta) in enumerate(
541+
zip(adam_matrix_params, raw_deltas, strict=False)
542+
):
543+
p.add_(delta.mul_(scales_tensor[i]).to(p.dtype))
544544

545545
# === Step 3. Muon update for >=2D parameters (weight matrices) ===
546546
# === Step 3.1. Collect gradients and initialize momentum ===

source/tests/pt/test_muon.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,32 @@
1212
)
1313

1414

15+
def _bf16_matmul_supported(device: torch.device) -> bool:
16+
"""Check if bf16 matmul is reliably supported on the given device."""
17+
if device.type == "cuda":
18+
if not torch.cuda.is_available():
19+
return False
20+
# bf16 requires compute capability >= 8.0 (Ampere+) for native support
21+
# or >= 7.0 (Volta) with tensor cores, but may have precision issues
22+
if hasattr(torch.cuda, "is_bf16_supported"):
23+
return torch.cuda.is_bf16_supported()
24+
# Fallback: check compute capability directly
25+
cap = torch.cuda.get_device_capability(device)
26+
return cap[0] >= 8
27+
# CPU bf16 support: available on x86 with AVX-512 BF16 or ARM with BF16 extension
28+
# Since it's hard to detect reliably, try a small matmul and check for errors
29+
try:
30+
a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
31+
_ = torch.mm(a, a.T)
32+
return True
33+
except (RuntimeError, TypeError):
34+
return False
35+
36+
37+
BF16_SUPPORTED = _bf16_matmul_supported(env.DEVICE)
38+
39+
40+
@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device")
1541
class TestNewtonSchulzOrthogonalization(unittest.TestCase):
1642
"""Test Newton-Schulz orthogonalization algorithm."""
1743

@@ -55,6 +81,7 @@ def test_invalid_input(self) -> None:
5581
zeropower_via_newtonschulz5(G_1d)
5682

5783

84+
@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device")
5885
class TestMuonOptimizer(unittest.TestCase):
5986
"""Test MuonOptimizer class."""
6087

@@ -164,6 +191,7 @@ def test_lr_adjust_modes(self) -> None:
164191
)
165192

166193

194+
@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device")
167195
class TestMuonOptimizerStateDict(unittest.TestCase):
168196
"""Test optimizer state dict save/load."""
169197

0 commit comments

Comments
 (0)