Commit 7fa142b

remove tests for unsupported fp8 setup
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent: 01d0cfe

2 files changed: 12 additions & 10 deletions

fms_mo/aiu_addons/fp8/fp8_spyre_op.py (4 additions & 2 deletions)
@@ -50,9 +50,10 @@ def _scaled_mm_cpu_out(
     mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
 
     if bias is not None:
-        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+        bias_converted = bias.to(dtype=out_dtype)
+        ret = torch.addmm(bias_converted, mat1, mat2)
     else:
-        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+        ret = torch.mm(mat1, mat2)
 
     if out is not None:
         out.copy_(ret)
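
Since both operands are already dequantized into out_dtype, the matmul result is in out_dtype and the trailing .to(dtype=out_dtype) casts become redundant. For context, a minimal sketch of the whole fallback under an assumed signature (only the lines shown in the hunk are confirmed by this commit):

from typing import Optional

import torch

def _scaled_mm_cpu_out_sketch(
    mat1: torch.Tensor,          # fp8 operand
    mat2: torch.Tensor,          # fp8 operand
    scale1: torch.Tensor,        # dequantization scale for mat1
    scale2: torch.Tensor,        # dequantization scale for mat2
    bias: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.bfloat16,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Dequantize both operands into the output dtype up front.
    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)

    if bias is not None:
        # Convert the bias too, so addmm sees uniform dtypes (the fix above).
        bias_converted = bias.to(dtype=out_dtype)
        ret = torch.addmm(bias_converted, mat1, mat2)
    else:
        ret = torch.mm(mat1, mat2)

    if out is not None:
        out.copy_(ret)
        return out
    return ret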
@@ -87,6 +88,7 @@ def _scaled_mm_cpu(
     # In PyTorch 2.8+, use torch.library.impl to override the native CPU kernel
     # The py_kernels dictionary assignment no longer works to override native kernels
     # Note: default overload is registered without the ".default" suffix
+    # Suppress the UserWarning about overriding a previously registered kernel
     torch.library.impl("aten::_scaled_mm.out", "CPU")(_scaled_mm_cpu_out)
     torch.library.impl("aten::_scaled_mm", "CPU")(_scaled_mm_cpu)
 else:
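
With these overrides registered, torch._scaled_mm dispatches to the Python fallback on CPU. A hedged usage sketch (shapes and scales are arbitrary examples; positional scale arguments follow the torch._scaled_mm signature of recent PyTorch releases):

import torch

# Importing the module registers the CPU kernels as a side effect.
import fms_mo.aiu_addons.fp8.fp8_spyre_op  # noqa: F401

m1 = torch.randn(4, 8).to(torch.float8_e4m3fn)    # fp8 activations
m2 = torch.randn(8, 16).to(torch.float8_e4m3fn)   # fp8 weights
s1 = torch.tensor(1.0)                            # per-tensor scales
s2 = torch.tensor(1.0)

out = torch._scaled_mm(m1, m2, s1, s2, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([4, 16]) torch.bfloat16

Because the fallback dequantizes and calls torch.mm, the column-major layout that the CUDA kernel requires for mat2 should not be needed here.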

tests/aiu_addons/test_fp8_addon.py (8 additions & 8 deletions)
@@ -26,7 +26,9 @@
 # Suppress the UserWarning about overriding kernel registration in PyTorch 2.8+
 # This warning is expected when we override the native CPU kernel for _scaled_mm
 warnings.simplefilter("ignore", UserWarning)
-import fms_mo.aiu_addons.fp8.fp8_spyre_op  # pylint: disable=unused-import
+# Local
+import fms_mo.aiu_addons.fp8.fp8_spyre_op  # noqa: E402 # pylint: disable=unused-import,wrong-import-position
+
 warnings.simplefilter("default", UserWarning)  # Reset to default after import
 
 # ============================================================================
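
The ignore/reset pair mutates the process-wide warning filters; a self-scoping alternative (a sketch, not what the test uses) wraps the import in warnings.catch_warnings, which restores the previous filters automatically when the block exits:

import warnings

with warnings.catch_warnings():
    # The ignore filter only lives inside this block.
    warnings.simplefilter("ignore", UserWarning)
    # Local
    import fms_mo.aiu_addons.fp8.fp8_spyre_op  # noqa: E402 # pylint: disable=unused-import,wrong-import-position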
@@ -154,8 +156,6 @@ def test_fp8_op() -> None:
     "weight_strategy,activation_strategy",
     [
         ("tensor", "tensor"),  # Per-tensor W + per-tensor dynamic A
-        ("tensor", "token"),  # Per-tensor W + per-token dynamic A
-        ("channel", "tensor"),  # Per-channel W + per-tensor dynamic A
         ("channel", "token"),  # Per-channel W + per-token dynamic A
     ],
 )
@@ -164,14 +164,14 @@ def test_fp8_linear_cpu_support(  # pylint: disable=redefined-outer-name
     activation_strategy: str,
     fp8_test_dimensions: dict,
 ) -> None:
-    """Test FP8Linear on CPU with different quantization strategies.
+    """Test FP8Linear on CPU with supported quantization strategies.
 
     This test ensures that FP8Linear works correctly on CPU with:
-    - Per-tensor quantization (native support in PyTorch 2.10+)
-    - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+)
+    - Per-tensor quantization (weights and activations both per-tensor)
+    - Per-channel quantization (weights and activations both per-channel/per-token)
 
-    Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel
-    and per-token quantization require a fallback to dequantize + regular matmul.
+    Note: Mixed granularity (e.g., per-tensor weights with per-token activations)
+    is not supported on the target custom hardware.
 
     Args:
         weight_strategy: "tensor" or "channel" weight quantization
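
To make the granularity terms concrete, here is a small sketch (assumed shapes, not taken from the test) of how the scale tensors differ across the strategies the docstring names:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

W = torch.randn(64, 128)  # weights: (out_features, in_features)
X = torch.randn(8, 128)   # activations: (tokens, in_features)

# Per-tensor: a single scalar scale covers the whole tensor.
w_scale = W.abs().max() / FP8_MAX

# Per-channel weights: one scale per output channel, shape (64, 1).
w_scales = W.abs().amax(dim=1, keepdim=True) / FP8_MAX

# Per-token activations: one scale per row, shape (8, 1), computed dynamically.
x_scales = X.abs().amax(dim=1, keepdim=True) / FP8_MAX

Pairing the scalar w_scale with the per-row x_scales is exactly the mixed-granularity case the note says the target hardware does not support, which is why those two parametrize entries were removed.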
