|
12 | 12 | ) |
13 | 13 |
|
14 | 14 |
|
| 15 | +def _bf16_matmul_supported(device: torch.device) -> bool: |
| 16 | + """Check if bf16 matmul is reliably supported on the given device.""" |
| 17 | + if device.type == "cuda": |
| 18 | + if not torch.cuda.is_available(): |
| 19 | + return False |
| 20 | + # bf16 requires compute capability >= 8.0 (Ampere+) for native support |
| 21 | + # or >= 7.0 (Volta) with tensor cores, but may have precision issues |
| 22 | + if hasattr(torch.cuda, "is_bf16_supported"): |
| 23 | + return torch.cuda.is_bf16_supported() |
| 24 | + # Fallback: check compute capability directly |
| 25 | + cap = torch.cuda.get_device_capability(device) |
| 26 | + return cap[0] >= 8 |
| 27 | + # CPU bf16 support: available on x86 with AVX-512 BF16 or ARM with BF16 extension |
| 28 | + # Since it's hard to detect reliably, try a small matmul and check for errors |
| 29 | + try: |
| 30 | + a = torch.randn(4, 4, dtype=torch.bfloat16, device=device) |
| 31 | + _ = torch.mm(a, a.T) |
| 32 | + return True |
| 33 | + except (RuntimeError, TypeError): |
| 34 | + return False |
| 35 | + |
| 36 | + |
# Evaluated once at import time so it can be referenced from class-level
# @unittest.skipIf decorators. env.DEVICE is presumably the test suite's
# configured target device — TODO confirm against the env module.
BF16_SUPPORTED = _bf16_matmul_supported(env.DEVICE)
| 38 | + |
| 39 | + |
| 40 | +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") |
15 | 41 | class TestNewtonSchulzOrthogonalization(unittest.TestCase): |
16 | 42 | """Test Newton-Schulz orthogonalization algorithm.""" |
17 | 43 |
|
@@ -55,6 +81,7 @@ def test_invalid_input(self) -> None: |
55 | 81 | zeropower_via_newtonschulz5(G_1d) |
56 | 82 |
|
57 | 83 |
|
| 84 | +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") |
58 | 85 | class TestMuonOptimizer(unittest.TestCase): |
59 | 86 | """Test MuonOptimizer class.""" |
60 | 87 |
|
@@ -164,6 +191,7 @@ def test_lr_adjust_modes(self) -> None: |
164 | 191 | ) |
165 | 192 |
|
166 | 193 |
|
| 194 | +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") |
167 | 195 | class TestMuonOptimizerStateDict(unittest.TestCase): |
168 | 196 | """Test optimizer state dict save/load.""" |
169 | 197 |
|
|
0 commit comments