
Commit c0a617f

add unit tests for FP8 matmul on CPU
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent 8a046a9 commit c0a617f

1 file changed: tests/aiu_addons/test_fp8_addon.py

Lines changed: 260 additions & 3 deletions
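The new tests are skipped unless the optional fms and torchao packages are installed (see the skipif markers in the diff). A typical local invocation, assuming pytest is available and the command is run from the repository root, might be:

    pytest tests/aiu_addons/test_fp8_addon.py -k "fp8_linear_cpu" -v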
@@ -21,6 +21,117 @@
 from fms_mo.prep import available_packages
 import fms_mo.aiu_addons.fp8.fp8_spyre_op  # pylint: disable=unused-import
 
+# ============================================================================
+# Helper Functions
+# ============================================================================
+
+
+def initialize_fp8_weights(
+    fp8_linear,
+    weight_strategy: str,
+    in_features: int,
+    out_features: int,
+) -> None:
+    """Initialize FP8Linear weights with proper absmax scaling.
+
+    Args:
+        fp8_linear: FP8Linear module to initialize
+        weight_strategy: "tensor" or "channel" for weight quantization
+        in_features: Input feature dimension
+        out_features: Output feature dimension
+    """
+    with torch.no_grad():
+        # Create random float weights
+        float_weights = torch.randn(out_features, in_features)
+
+        # Calculate FP8 E4M3 max value (448.0)
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+
+        # Set appropriate scales based on strategy using absmax
+        if weight_strategy == "tensor":
+            # Per-tensor: single scale for entire weight matrix
+            absmax = float_weights.abs().max()
+            scale = absmax / fp8_max
+            # Ensure scale is not zero
+            scale = torch.clamp(scale, min=1e-12)
+            fp8_linear.weight_scale.fill_(scale.item())
+        else:  # channel (per-row for weight matrix)
+            # Per-channel: one scale per output channel (row)
+            absmax = float_weights.abs().amax(dim=1)
+            scale = absmax / fp8_max
+            # Ensure scales are not zero
+            scale = torch.clamp(scale, min=1e-12)
+            # Reshape to match weight_scale parameter shape (out_features, 1)
+            fp8_linear.weight_scale.copy_(scale.reshape(-1, 1))
+
+        # Quantize weights to FP8
+        quantized_weights = (float_weights / fp8_linear.weight_scale).clamp(
+            -fp8_max, fp8_max
+        )
+        fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn))
+
+        # Initialize bias if present
+        if fp8_linear.has_bias:
+            fp8_linear.bias.copy_(torch.randn(out_features))
+
+
+def initialize_fp8_input_scale(
+    fp8_linear,
+    activation_strategy: str,
+    batch_size: int,
+    seq_len: int,
+    in_features: int,
+) -> None:
+    """Initialize static input scale for FP8Linear.
+
+    Args:
+        fp8_linear: FP8Linear module to initialize
+        activation_strategy: "tensor" or "token" for activation quantization
+        batch_size: Batch size for sample input
+        seq_len: Sequence length for sample input
+        in_features: Input feature dimension
+    """
+    with torch.no_grad():
+        # For static quantization, use a representative input to calculate scales
+        sample_input = torch.randn(batch_size, seq_len, in_features)
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+
+        if activation_strategy == "tensor":
+            # Per-tensor: single scale for entire activation
+            absmax = sample_input.abs().max()
+            scale = absmax / fp8_max
+            scale = torch.clamp(scale, min=1e-12)
+            fp8_linear.input_scale.fill_(scale.item())
+        else:  # token
+            # For per-token static quantization, use a calibrated scale
+            # based on representative input statistics
+            absmax = sample_input.abs().max()
+            scale = absmax / fp8_max
+            scale = torch.clamp(scale, min=1e-12)
+            # Fill all scales with the same representative value
+            fp8_linear.input_scale.fill_(scale.item())
+
+
+# ============================================================================
+# Pytest Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def fp8_test_dimensions():
+    """Common test dimensions for FP8Linear tests."""
+    return {
+        "batch_size": 2,
+        "seq_len": 4,
+        "in_features": 8,
+        "out_features": 16,
+    }
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
 
 def test_fp8_registration() -> None:
     """
@@ -44,9 +155,10 @@ def test_fp8_registration() -> None:
     reason="FP8 is only available on GPUs with device level 8.9 or higher",
 )
 def test_fp8_op() -> None:
-    """Validate output shapes of GPTQ W4A16 tensors.
-    Note: this AIU-compatible operation only returns a zero tensor of the
-    expected shape, it does not perform a real W4A16 matmul operation.
+    """Validate output shapes of FP8 attention operation.
+
+    Tests the FP8 attention compute operation to ensure it produces
+    outputs with the expected shape.
     """
     # Local
     from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
@@ -57,3 +169,148 @@ def test_fp8_op() -> None:
 
     out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
     assert out.size() == query.size()
+
+
+@pytest.mark.skipif(
+    not available_packages["torchao"] or not available_packages["fms"],
+    reason="FMS and torchao required to run this test",
+)
+@pytest.mark.parametrize(
+    "weight_strategy,activation_strategy,dynamic_activation",
+    [
+        ("tensor", "tensor", True),  # Per-tensor weights + per-tensor activations
+        ("tensor", "token", True),  # Per-tensor weights + per-token activations
+        ("channel", "tensor", True),  # Per-channel weights + per-tensor activations
+        ("channel", "token", True),  # Per-channel weights + per-token activations
+    ],
+)
+def test_fp8_linear_cpu_support(
+    weight_strategy: str,
+    activation_strategy: str,
+    dynamic_activation: bool,
+    fp8_test_dimensions: dict,
+) -> None:
+    """Test FP8Linear on CPU with different quantization strategies.
+
+    This test ensures that FP8Linear works correctly on CPU, including:
+    - Per-tensor quantization (native support in PyTorch 2.10+)
+    - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+)
+
+    Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel
+    and per-token quantization require a fallback to dequantize + regular matmul.
+
+    Args:
+        weight_strategy: "tensor" or "channel" for weight quantization
+        activation_strategy: "tensor" or "token" for activation quantization
+        dynamic_activation: Whether to use dynamic activation quantization
+        fp8_test_dimensions: Test dimensions fixture
+    """
+    # Local
+    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear
+
+    # Get test dimensions
+    batch_size = fp8_test_dimensions["batch_size"]
+    seq_len = fp8_test_dimensions["seq_len"]
+    in_features = fp8_test_dimensions["in_features"]
+    out_features = fp8_test_dimensions["out_features"]
+
+    # Create FP8Linear configuration
+    linear_config = {
+        "weights": {
+            "strategy": weight_strategy,
+            "symmetric": True,
+            "dynamic": False,
+        },
+        "input_activations": {
+            "strategy": activation_strategy,
+            "symmetric": True,
+            "dynamic": dynamic_activation,
+        },
+    }
+
+    # Create FP8Linear module
+    fp8_linear = FP8Linear(
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        linear_config=linear_config,
+    )
+
+    # Initialize weights using helper function
+    initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features)
+
+    # Initialize input scale if static quantization
+    if not dynamic_activation:
+        initialize_fp8_input_scale(
+            fp8_linear, activation_strategy, batch_size, seq_len, in_features
+        )
+
+    # Create input tensor on CPU
+    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
+
+    # Run forward pass - should not raise an error
+    output = fp8_linear(x)
+
+    # Validate output shape
+    assert output.shape == (batch_size, seq_len, out_features)
+
+    # Validate output is not NaN or Inf
+    assert not torch.isnan(output).any()
+    assert not torch.isinf(output).any()
+
+    # Validate output dtype matches input dtype
+    assert output.dtype == x.dtype
+
+
+@pytest.mark.skipif(
+    not available_packages["torchao"] or not available_packages["fms"],
+    reason="FMS and torchao required to run this test",
+)
+def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None:
+    """Test FP8Linear on CPU with only weight quantization (no activation quantization).
+
+    This tests the code path where activations are not quantized but weights are FP8.
+
+    Args:
+        fp8_test_dimensions: Test dimensions fixture
+    """
+    # Local
+    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear
+
+    # Get test dimensions
+    batch_size = fp8_test_dimensions["batch_size"]
+    seq_len = fp8_test_dimensions["seq_len"]
+    in_features = fp8_test_dimensions["in_features"]
+    out_features = fp8_test_dimensions["out_features"]
+
+    # Create FP8Linear configuration with no activation quantization
+    linear_config = {
+        "weights": {
+            "strategy": "channel",
+            "symmetric": True,
+            "dynamic": False,
+        },
+        "input_activations": None,  # No activation quantization
+    }
+
+    # Create FP8Linear module
+    fp8_linear = FP8Linear(
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        linear_config=linear_config,
+    )
+
+    # Initialize weights using helper function
+    initialize_fp8_weights(fp8_linear, "channel", in_features, out_features)
+
+    # Create input tensor on CPU
+    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
+
+    # Run forward pass
+    output = fp8_linear(x)
+
+    # Validate output
+    assert output.shape == (batch_size, seq_len, out_features)
+    assert not torch.isnan(output).any()
+    assert not torch.isinf(output).any()
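For readers skimming the diff, the following standalone sketch (not part of the commit) illustrates the two mechanisms the tests exercise: the absmax scaling used by initialize_fp8_weights, and the dequantize-then-regular-matmul fallback that the test_fp8_linear_cpu_support docstring describes for per-channel/per-token FP8 on CPU. The tensor names and tolerances are illustrative assumptions; the dimensions follow the fp8_test_dimensions fixture.

    import torch

    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3

    # Per-channel (per-row) absmax scaling of a weight matrix, as in initialize_fp8_weights
    w = torch.randn(16, 8)  # (out_features, in_features)
    w_scale = (w.abs().amax(dim=1, keepdim=True) / fp8_max).clamp(min=1e-12)
    w_fp8 = (w / w_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

    # Fallback path: dequantize the FP8 weights, then run a regular high-precision matmul
    x = torch.randn(2, 4, 8, dtype=torch.bfloat16)  # (batch_size, seq_len, in_features)
    w_deq = w_fp8.to(torch.bfloat16) * w_scale.to(torch.bfloat16)
    y = x @ w_deq.t()  # (batch_size, seq_len, out_features)

    # Dequantized weights match the originals up to FP8 E4M3 rounding error
    assert torch.allclose(w, w_fp8.to(torch.float32) * w_scale, atol=0.01, rtol=0.1)
    assert y.shape == (2, 4, 16) and y.dtype == torch.bfloat16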
