Skip to content

Commit d498851

Browse files
Merge pull request #201 from foundation-model-stack/fp8_cpu_op_fix
fix: update FP8 syntax for custom torch._scaled_mm on CPU
2 parents 90fc888 + 1af819a commit d498851

3 files changed

Lines changed: 22 additions & 40 deletions

File tree

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"""Implement FP8 linear module to be loaded via FMS."""
1515

1616
# Standard
17-
from importlib.metadata import version
1817
from typing import Any, Mapping
1918

2019
# Third Party
@@ -30,7 +29,6 @@
3029
# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
3130

3231
TORCH_VERSION = Version(torch.__version__.split("+")[0])
33-
SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION
3432

3533
# Gated torchao imports for FP8 implementation
3634
if available_packages["fms"] and available_packages["torchao"]:
@@ -243,30 +241,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
243241
)
244242
qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs)
245243

246-
# Check if we need CPU fallback for per-channel quantization
247-
is_cpu = qx.device.type == "cpu"
248-
is_per_tensor = (
249-
self.linear_config["weights"]["strategy"] == "tensor"
250-
and self.linear_config["input_activations"]["strategy"] == "tensor"
251-
)
252-
253-
# Perform mock FP8xFP8 matmul
254-
if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8:
255-
# Check torchao version without loading the full package
256-
if Version("0.11") < Version(version("torchao")):
257-
raise NotImplementedError(
258-
"Fallback path for FP8 matmul on CPU is not supported "
259-
"on torchao > 0.11."
260-
)
261-
x_dequant = qx.dequantize()
262-
w_dequant = qweight.dequantize()
263-
out = torch.nn.functional.linear(
264-
x_dequant.to(w_dequant.dtype),
265-
w_dequant,
266-
self.bias if self.has_bias else None,
267-
)
268-
return out.to(x.dtype)
269-
270244
# Copied from torchao _linear_fp8_act_fp8_weight_impl
271245
# (with changes to support fp8 out)
272246
scaled_mm_config = Float8MMConfig(use_fast_accum=True)

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def _scaled_mm_cpu_out(
5050
mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
5151

5252
if bias is not None:
53-
ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
53+
ret = torch.addmm(bias.to(dtype=out_dtype), mat1, mat2)
5454
else:
55-
ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
55+
ret = torch.mm(mat1, mat2)
5656

5757
if out is not None:
5858
out.copy_(ret)
@@ -84,9 +84,12 @@ def _scaled_mm_cpu(
8484

8585

8686
if torch.__version__ >= "2.8":
87-
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
88-
torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out
89-
torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu
87+
# In PyTorch 2.8+, use torch.library.impl to override the native CPU kernel
88+
# The py_kernels dictionary assignment no longer works to override native kernels
89+
# Note: default overload is registered without the ".default" suffix
90+
# Suppress the UserWarning about overriding a previously registered kernel
91+
torch.library.impl("aten::_scaled_mm.out", "CPU")(_scaled_mm_cpu_out)
92+
torch.library.impl("aten::_scaled_mm", "CPU")(_scaled_mm_cpu)
9093
else:
9194
torch.library.register_kernel(
9295
torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out

tests/aiu_addons/test_fp8_addon.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,23 @@
1313
# limitations under the License.
1414
"""Test suite for FMS addon introducing FP8 functionalities"""
1515

16+
# Standard
17+
import warnings
18+
1619
# Third Party
1720
import pytest
1821
import torch
1922

2023
# Local
2124
from fms_mo.prep import available_packages
22-
import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import
25+
26+
# Suppress the UserWarning about overriding kernel registration in PyTorch 2.8+
27+
# This warning is expected when we override the native CPU kernel for _scaled_mm
28+
warnings.simplefilter("ignore", UserWarning)
29+
# Local
30+
import fms_mo.aiu_addons.fp8.fp8_spyre_op # noqa: E402 # pylint: disable=unused-import,wrong-import-position
31+
32+
warnings.simplefilter("default", UserWarning) # Reset to default after import
2333

2434
# ============================================================================
2535
# Constants
@@ -146,8 +156,6 @@ def test_fp8_op() -> None:
146156
"weight_strategy,activation_strategy",
147157
[
148158
("tensor", "tensor"), # Per-tensor W + per-tensor dynamic A
149-
("tensor", "token"), # Per-tensor W + per-token dynamic A
150-
("channel", "tensor"), # Per-channel W + per-tensor dynamic A
151159
("channel", "token"), # Per-channel W + per-token dynamic A
152160
],
153161
)
@@ -156,14 +164,11 @@ def test_fp8_linear_cpu_support( # pylint: disable=redefined-outer-name
156164
activation_strategy: str,
157165
fp8_test_dimensions: dict,
158166
) -> None:
159-
"""Test FP8Linear on CPU with different quantization strategies.
167+
"""Test FP8Linear on CPU with supported quantization strategies.
160168
161169
This test ensures that FP8Linear works correctly on CPU with:
162-
- Per-tensor quantization (native support in PyTorch 2.10+)
163-
- Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+)
164-
165-
Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel
166-
and per-token quantization require a fallback to dequantize + regular matmul.
170+
- Per-tensor quantization (weights and activations both per-tensor)
171+
- Per-channel quantization (weights and activations both per-channel/per-token)
167172
168173
Args:
169174
weight_strategy: "tensor" or "channel" weight quantization

0 commit comments

Comments (0)