Commit 90fc888

Merge pull request #200 from foundation-model-stack/fp8_cpu
fix: FP8 fallback for AIU addons running on CPU
2 parents: c7e0b85 + 1558c3f

3 files changed

Lines changed: 257 additions & 7 deletions

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 37 additions & 3 deletions
@@ -14,9 +14,11 @@
 """Implement FP8 linear module to be loaded via FMS."""
 
 # Standard
+from importlib.metadata import version
 from typing import Any, Mapping
 
 # Third Party
+from packaging.version import Version
 import torch
 
 # Local
@@ -27,6 +29,9 @@
 # torch.nn.functional.linear not recognized as callable
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
+TORCH_VERSION = Version(torch.__version__.split("+")[0])
+SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION
+
 # Gated torchao imports for FP8 implementation
 if available_packages["fms"] and available_packages["torchao"]:
@@ -213,7 +218,11 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor":
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """If input quantization is active, compute FP8xFP8 addmm leveraging torchao
-        functionalities. Otherwise compute non-quantized addmm."""
+        functionalities. Otherwise compute non-quantized addmm.
+
+        In PyTorch 2.10, torch._scaled_mm only supports FP8 on CPU when
+        quantization is per-tensor. In those other cases, we perform a mock FP8xFP8 matmul.
+        """
         # fp8 weight tensor for torchao
         qweight: AffineQuantizedTensor = self._construct_qweight_structure()
 
@@ -234,6 +243,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs)
 
+            # Check if we need CPU fallback for per-channel quantization
+            is_cpu = qx.device.type == "cpu"
+            is_per_tensor = (
+                self.linear_config["weights"]["strategy"] == "tensor"
+                and self.linear_config["input_activations"]["strategy"] == "tensor"
+            )
+
+            # Perform mock FP8xFP8 matmul
+            if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8:
+                # Check torchao version without loading the full package
+                if Version("0.11") < Version(version("torchao")):
+                    raise NotImplementedError(
+                        "Fallback path for FP8 matmul on CPU is not supported "
+                        "on torchao > 0.11."
+                    )
+                x_dequant = qx.dequantize()
+                w_dequant = qweight.dequantize()
+                out = torch.nn.functional.linear(
+                    x_dequant.to(w_dequant.dtype),
+                    w_dequant,
+                    self.bias if self.has_bias else None,
+                )
+                return out.to(x.dtype)
+
             # Copied from torchao _linear_fp8_act_fp8_weight_impl
             # (with changes to support fp8 out)
             scaled_mm_config = Float8MMConfig(use_fast_accum=True)
@@ -276,10 +309,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             ).reshape(out_shape)
 
         # activations not quantized, dequant fp8 weight and do regular matmul
+        w_dequant = qweight.dequantize()
         out = torch.nn.functional.linear(
-            x, qweight.dequantize(), self.bias if self.has_bias else None
+            x.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None
         )
-        return out
+        return out.to(x.dtype)
 
     def __repr__(self) -> str:
         return (
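For context on the fallback path above: when per-channel FP8 is not natively supported by torch._scaled_mm on CPU, FP8Linear dequantizes both operands, runs a regular matmul, and casts the result back to the activation dtype. A minimal standalone sketch of that idea, using hypothetical tensors x, w and a per-row scale rather than the module's torchao structures:

import torch
import torch.nn.functional as F

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

# Hypothetical bf16 activations and per-channel (per-row) FP8 weights.
x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(16, 8)
w_scale = w.abs().amax(dim=1, keepdim=True) / FP8_MAX  # shape (16, 1)
w_fp8 = (w / w_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# "Mock" FP8 matmul: dequantize, run a regular linear, cast back to x.dtype.
w_dequant = w_fp8.to(torch.float32) * w_scale
out = F.linear(x.to(w_dequant.dtype), w_dequant).to(x.dtype)
print(out.shape, out.dtype)  # torch.Size([4, 16]) torch.bfloat16

This trades the speed of a fused FP8 kernel for correctness on CPU setups that lack one, which is the intent of this commit.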

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ dependencies = [
 
 [project.optional-dependencies]
 examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
-fp8 = ["llmcompressor", "torchao==0.11"]
+fp8 = ["llmcompressor", "torchao==0.11"]  # FP8 matmul on CPU needs a fix before advancing torchao > 0.11
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
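The pin keeps the CPU fallback usable: fp8_linear.py refuses the fallback on torchao releases newer than 0.11 and derives SUPPORTS_CPU_PER_CHANNEL_FP8 from the running PyTorch version. A small sketch of those two gates, assuming torchao is installed (TORCHAO_OK_FOR_FALLBACK is a hypothetical name added here; the other two constants appear in the diff above):

from importlib.metadata import version

import torch
from packaging.version import Version

TORCH_VERSION = Version(torch.__version__.split("+")[0])
SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION
TORCHAO_OK_FOR_FALLBACK = Version(version("torchao")) <= Version("0.11")
print(SUPPORTS_CPU_PER_CHANNEL_FP8, TORCHAO_OK_FOR_FALLBACK)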

tests/aiu_addons/test_fp8_addon.py

Lines changed: 219 additions & 3 deletions
@@ -21,6 +21,84 @@
 from fms_mo.prep import available_packages
 import fms_mo.aiu_addons.fp8.fp8_spyre_op  # pylint: disable=unused-import
 
+# ============================================================================
+# Constants
+# ============================================================================
+
+# FP8 E4M3 maximum value
+FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+# ============================================================================
+# Helper Functions
+# ============================================================================
+
+
+def initialize_fp8_weights(
+    fp8_linear,
+    weight_strategy: str,
+    in_features: int,
+    out_features: int,
+) -> None:
+    """Initialize FP8Linear weights with proper absmax scaling.
+
+    Args:
+        fp8_linear: FP8Linear module to initialize
+        weight_strategy: "tensor" or "channel" for weight quantization
+        in_features: Input feature dimension
+        out_features: Output feature dimension
+    """
+    with torch.no_grad():
+        # Create random float weights
+        float_weights = torch.randn(out_features, in_features)
+
+        # Set appropriate scales based on strategy using absmax
+        if weight_strategy == "tensor":
+            # Per-tensor: single scale for entire weight matrix
+            absmax = float_weights.abs().max()
+            scale = absmax / FP8_E4M3_MAX
+            # Ensure scale is not zero
+            scale = torch.clamp(scale, min=1e-12)
+            fp8_linear.weight_scale.fill_(scale.item())
+        else:  # channel (per-row for weight matrix)
+            # Per-channel: one scale per output channel (row)
+            absmax = float_weights.abs().amax(dim=1)
+            scale = absmax / FP8_E4M3_MAX
+            # Ensure scales are not zero
+            scale = torch.clamp(scale, min=1e-12)
+            # Reshape to match weight_scale parameter shape (out_features, 1)
+            fp8_linear.weight_scale.copy_(scale.reshape(-1, 1))
+
+        # Quantize weights to FP8
+        quantized_weights = (float_weights / fp8_linear.weight_scale).clamp(
+            -FP8_E4M3_MAX, FP8_E4M3_MAX
+        )
+        fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn))
+
+        # Initialize bias if present
+        if fp8_linear.has_bias:
+            fp8_linear.bias.copy_(torch.randn(out_features))
+
+
+# ============================================================================
+# Pytest Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def fp8_test_dimensions():
+    """Common test dimensions for FP8Linear tests."""
+    return {
+        "batch_size": 2,
+        "seq_len": 4,
+        "in_features": 8,
+        "out_features": 16,
+    }
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
 
 def test_fp8_registration() -> None:
     """
@@ -44,9 +122,10 @@ def test_fp8_registration() -> None:
     reason="FP8 is only available on GPUs with device level 8.9 or higher",
 )
 def test_fp8_op() -> None:
-    """Validate output shapes of GPTQ W4A16 tensors.
-    Note: this AIU-compatible operation only returns a zero tensor of the
-    expected shape, it does not perform a real W4A16 matmul operation.
+    """Validate output shapes of FP8 attention operation.
+
+    Tests the FP8 attention compute operation to ensure it produces
+    outputs with the expected shape.
     """
     # Local
     from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
@@ -57,3 +136,140 @@ def test_fp8_op() -> None:
 
     out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
     assert out.size() == query.size()
+
+
+@pytest.mark.skipif(
+    not available_packages["torchao"] or not available_packages["fms"],
+    reason="FMS and torchao required to run this test",
+)
+@pytest.mark.parametrize(
+    "weight_strategy,activation_strategy",
+    [
+        ("tensor", "tensor"),  # Per-tensor W + per-tensor dynamic A
+        ("tensor", "token"),  # Per-tensor W + per-token dynamic A
+        ("channel", "tensor"),  # Per-channel W + per-tensor dynamic A
+        ("channel", "token"),  # Per-channel W + per-token dynamic A
+    ],
+)
+def test_fp8_linear_cpu_support(  # pylint: disable=redefined-outer-name
+    weight_strategy: str,
+    activation_strategy: str,
+    fp8_test_dimensions: dict,
+) -> None:
+    """Test FP8Linear on CPU with different quantization strategies.
+
+    This test ensures that FP8Linear works correctly on CPU with:
+    - Per-tensor quantization (native support in PyTorch 2.10+)
+    - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+)
+
+    Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel
+    and per-token quantization require a fallback to dequantize + regular matmul.
+
+    Args:
+        weight_strategy: "tensor" or "channel" weight quantization
+        activation_strategy: "tensor" or "token" dynamic activation quantization
+        fp8_test_dimensions: Test dimensions fixture
+    """
+    # Local
+    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear
+
+    # Get test dimensions
+    batch_size = fp8_test_dimensions["batch_size"]
+    seq_len = fp8_test_dimensions["seq_len"]
+    in_features = fp8_test_dimensions["in_features"]
+    out_features = fp8_test_dimensions["out_features"]
+
+    # Create FP8Linear configuration
+    linear_config = {
+        "weights": {
+            "strategy": weight_strategy,
+            "symmetric": True,
+            "dynamic": False,
+        },
+        "input_activations": {
+            "strategy": activation_strategy,
+            "symmetric": True,
+            "dynamic": True,
+        },
+    }
+
+    # Create FP8Linear module
+    fp8_linear = FP8Linear(
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        linear_config=linear_config,
+    )
+
+    # Initialize weights using helper function
+    initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features)
+
+    # Create input tensor on CPU
+    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
+
+    # Run forward pass - should not raise an error
+    output = fp8_linear(x)
+
+    # Validate output shape
+    assert output.shape == (batch_size, seq_len, out_features)
+
+    # Validate output is not NaN or Inf
+    assert not torch.isnan(output).any()
+    assert not torch.isinf(output).any()
+
+    # Validate output dtype matches input dtype
+    assert output.dtype == x.dtype
+
+
+@pytest.mark.skipif(
+    not available_packages["torchao"] or not available_packages["fms"],
+    reason="FMS and torchao required to run this test",
+)
+def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None:  # pylint: disable=redefined-outer-name
+    """Test FP8Linear on CPU with only weight quantization (no activation quantization).
+
+    This tests the code path where activations are not quantized but weights are FP8.
+
+    Args:
+        fp8_test_dimensions: Test dimensions fixture
+    """
+    # Local
+    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear
+
+    # Get test dimensions
+    batch_size = fp8_test_dimensions["batch_size"]
+    seq_len = fp8_test_dimensions["seq_len"]
+    in_features = fp8_test_dimensions["in_features"]
+    out_features = fp8_test_dimensions["out_features"]
+
+    # Create FP8Linear configuration with no activation quantization
+    linear_config = {
+        "weights": {
+            "strategy": "channel",
+            "symmetric": True,
+            "dynamic": False,
+        },
+        "input_activations": None,  # No activation quantization
+    }
+
+    # Create FP8Linear module
+    fp8_linear = FP8Linear(
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        linear_config=linear_config,
+    )
+
+    # Initialize weights using helper function
+    initialize_fp8_weights(fp8_linear, "channel", in_features, out_features)
+
+    # Create input tensor on CPU
+    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
+
+    # Run forward pass
+    output = fp8_linear(x)
+
+    # Validate output
+    assert output.shape == (batch_size, seq_len, out_features)
+    assert not torch.isnan(output).any()
+    assert not torch.isinf(output).any()
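For intuition on the absmax scaling used by initialize_fp8_weights: with the FP8 E4M3 maximum of 448, a weight row whose largest magnitude is 2.24 gets scale 2.24 / 448 = 0.005, is stored as w / 0.005 in float8_e4m3fn, and is recovered on dequantization as the stored value times 0.005. A plausible way to run just the new CPU tests locally (standard pytest keyword selection; both tests skip unless FMS and torchao are importable, so the fp8 extra should be installed first):

pip install "fms-model-optimizer[fp8]"   # or, from a checkout: pip install ".[fp8]"
pytest tests/aiu_addons/test_fp8_addon.py -k "fp8_linear_cpu" -v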
