 from fms_mo.prep import available_packages
 import fms_mo.aiu_addons.fp8.fp8_spyre_op  # pylint: disable=unused-import
 
+# ============================================================================
+# Constants
+# ============================================================================
+
+# FP8 E4M3 maximum value
+FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
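+# (equals 448.0 for the float8_e4m3fn format)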
+
+# ============================================================================
+# Helper Functions
+# ============================================================================
+
+
+def initialize_fp8_weights(
+    fp8_linear,
+    weight_strategy: str,
+    in_features: int,
+    out_features: int,
+) -> None:
+    """Initialize FP8Linear weights with proper absmax scaling.
+
+    Args:
+        fp8_linear: FP8Linear module to initialize
+        weight_strategy: "tensor" or "channel" for weight quantization
+        in_features: Input feature dimension
+        out_features: Output feature dimension
+    """
+    with torch.no_grad():
+        # Create random float weights
+        float_weights = torch.randn(out_features, in_features)
+
+        # Set appropriate scales based on strategy using absmax
+        if weight_strategy == "tensor":
+            # Per-tensor: single scale for entire weight matrix
+            absmax = float_weights.abs().max()
+            scale = absmax / FP8_E4M3_MAX
+            # Ensure scale is not zero
+            scale = torch.clamp(scale, min=1e-12)
+            fp8_linear.weight_scale.fill_(scale.item())
+        else:  # channel (per-row for weight matrix)
+            # Per-channel: one scale per output channel (row)
+            absmax = float_weights.abs().amax(dim=1)
+            scale = absmax / FP8_E4M3_MAX
+            # Ensure scales are not zero
+            scale = torch.clamp(scale, min=1e-12)
+            # Reshape to match weight_scale parameter shape (out_features, 1)
+            fp8_linear.weight_scale.copy_(scale.reshape(-1, 1))
+
+        # Quantize weights to FP8
+        quantized_weights = (float_weights / fp8_linear.weight_scale).clamp(
+            -FP8_E4M3_MAX, FP8_E4M3_MAX
+        )
+        fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn))
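+        # Round-trip sanity (illustrative sketch, not asserted): dequantizing
+        # should approximately recover float_weights, up to FP8 rounding error:
+        #   fp8_linear.weight.to(torch.float32) * fp8_linear.weight_scale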
+
+        # Initialize bias if present
+        if fp8_linear.has_bias:
+            fp8_linear.bias.copy_(torch.randn(out_features))
+
+
+# ============================================================================
+# Pytest Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def fp8_test_dimensions():
+    """Common test dimensions for FP8Linear tests."""
+    return {
+        "batch_size": 2,
+        "seq_len": 4,
+        "in_features": 8,
+        "out_features": 16,
+    }
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
 
 def test_fp8_registration() -> None:
     """
@@ -44,9 +122,10 @@ def test_fp8_registration() -> None:
     reason="FP8 is only available on GPUs with device level 8.9 or higher",
 )
 def test_fp8_op() -> None:
-    """Validate output shapes of GPTQ W4A16 tensors.
-    Note: this AIU-compatible operation only returns a zero tensor of the
-    expected shape, it does not perform a real W4A16 matmul operation.
125+ """Validate output shapes of FP8 attention operation.
126+
127+ Tests the FP8 attention compute operation to ensure it produces
128+ outputs with the expected shape.
     """
     # Local
     from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
@@ -57,3 +136,140 @@
 
     out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
     assert out.size() == query.size()
+
+
+@pytest.mark.skipif(
+    not available_packages["torchao"] or not available_packages["fms"],
+    reason="FMS and torchao required to run this test",
+)
+@pytest.mark.parametrize(
+    "weight_strategy,activation_strategy",
+    [
+        ("tensor", "tensor"),  # Per-tensor W + per-tensor dynamic A
+        ("tensor", "token"),  # Per-tensor W + per-token dynamic A
+        ("channel", "tensor"),  # Per-channel W + per-tensor dynamic A
+        ("channel", "token"),  # Per-channel W + per-token dynamic A
+    ],
+)
+def test_fp8_linear_cpu_support(  # pylint: disable=redefined-outer-name
+    weight_strategy: str,
+    activation_strategy: str,
+    fp8_test_dimensions: dict,
+) -> None:
+    """Test FP8Linear on CPU with different quantization strategies.
+
+    This test ensures that FP8Linear works correctly on CPU with:
+    - Per-tensor quantization (native support in PyTorch 2.10+)
+    - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+)
+
+    Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel
+    and per-token quantization require a fallback to dequantize + regular matmul.
+
+    Args:
+        weight_strategy: "tensor" or "channel" weight quantization
+        activation_strategy: "tensor" or "token" dynamic activation quantization
+        fp8_test_dimensions: Test dimensions fixture
+    """
+    # Local
+    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear
+
+    # Get test dimensions
+    batch_size = fp8_test_dimensions["batch_size"]
+    seq_len = fp8_test_dimensions["seq_len"]
+    in_features = fp8_test_dimensions["in_features"]
+    out_features = fp8_test_dimensions["out_features"]
+
+    # Create FP8Linear configuration
+    linear_config = {
+        "weights": {
+            "strategy": weight_strategy,
+            "symmetric": True,
+            "dynamic": False,
+        },
+        "input_activations": {
+            "strategy": activation_strategy,
+            "symmetric": True,
+            "dynamic": True,
+        },
+    }
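+    # For reference: "tensor" applies one runtime scale to the whole activation
+    # tensor, while "token" computes one scale per token, i.e. per row of the
+    # activations flattened to (batch_size * seq_len, in_features).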
+
+    # Create FP8Linear module
+    fp8_linear = FP8Linear(
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        linear_config=linear_config,
+    )
+
+    # Initialize weights using helper function
+    initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features)
+
+    # Create input tensor on CPU
+    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
+
+    # Run forward pass - should not raise an error
+    output = fp8_linear(x)
+
+    # Validate output shape
+    assert output.shape == (batch_size, seq_len, out_features)
+
+    # Validate output is not NaN or Inf
+    assert not torch.isnan(output).any()
+    assert not torch.isinf(output).any()
+
+    # Validate output dtype matches input dtype
+    assert output.dtype == x.dtype
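+
+    # A possible numerical check (illustrative sketch only, not part of the
+    # test): compare against a dequantized bf16 reference; tolerances would
+    # need to absorb the FP8 weight and dynamic activation quantization error.
+    #   w_ref = (fp8_linear.weight.to(torch.float32) * fp8_linear.weight_scale).to(x.dtype)
+    #   y_ref = x @ w_ref.t() + fp8_linear.bias.to(x.dtype)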
+
+
+@pytest.mark.skipif(
+    not available_packages["torchao"] or not available_packages["fms"],
+    reason="FMS and torchao required to run this test",
+)
+def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None:  # pylint: disable=redefined-outer-name
+    """Test FP8Linear on CPU with only weight quantization (no activation quantization).
+
+    This tests the code path where activations are not quantized but weights are FP8.
+
+    Args:
+        fp8_test_dimensions: Test dimensions fixture
+    """
+    # Local
+    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear
+
+    # Get test dimensions
+    batch_size = fp8_test_dimensions["batch_size"]
+    seq_len = fp8_test_dimensions["seq_len"]
+    in_features = fp8_test_dimensions["in_features"]
+    out_features = fp8_test_dimensions["out_features"]
+
+    # Create FP8Linear configuration with no activation quantization
+    linear_config = {
+        "weights": {
+            "strategy": "channel",
+            "symmetric": True,
+            "dynamic": False,
+        },
+        "input_activations": None,  # No activation quantization
+    }
+
+    # Create FP8Linear module
+    fp8_linear = FP8Linear(
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        linear_config=linear_config,
+    )
+
+    # Initialize weights using helper function
+    initialize_fp8_weights(fp8_linear, "channel", in_features, out_features)
+
+    # Create input tensor on CPU
+    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
+
+    # Run forward pass
+    output = fp8_linear(x)
+
+    # Validate output
+    assert output.shape == (batch_size, seq_len, out_features)
+    assert not torch.isnan(output).any()
+    assert not torch.isinf(output).any()