from fms_mo.prep import available_packages
import fms_mo.aiu_addons.fp8.fp8_spyre_op  # pylint: disable=unused-import

+# ============================================================================
+# Helper Functions
+# ============================================================================
+
+
+def initialize_fp8_weights(
+    fp8_linear,
+    weight_strategy: str,
+    in_features: int,
+    out_features: int,
+) -> None:
+    """Initialize FP8Linear weights with proper absmax scaling.
+
+    Args:
+        fp8_linear: FP8Linear module to initialize
+        weight_strategy: "tensor" or "channel" for weight quantization
+        in_features: Input feature dimension
+        out_features: Output feature dimension
+    """
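+    # Absmax scaling recap: scale = max(|W|) / 448 (the E4M3 maximum), the
+    # stored weight is W_q = clamp(W / scale) cast to float8_e4m3fn, and the
+    # original value is recovered approximately as W_q * scale. "tensor" uses
+    # one scale for the whole matrix, "channel" one scale per output row.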
+    with torch.no_grad():
+        # Create random float weights
+        float_weights = torch.randn(out_features, in_features)
+
+        # Calculate FP8 E4M3 max value (448.0)
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+
+        # Set appropriate scales based on strategy using absmax
+        if weight_strategy == "tensor":
+            # Per-tensor: single scale for entire weight matrix
+            absmax = float_weights.abs().max()
+            scale = absmax / fp8_max
+            # Ensure scale is not zero
+            scale = torch.clamp(scale, min=1e-12)
+            fp8_linear.weight_scale.fill_(scale.item())
+        else:  # channel (per-row for weight matrix)
+            # Per-channel: one scale per output channel (row)
+            absmax = float_weights.abs().amax(dim=1)
+            scale = absmax / fp8_max
+            # Ensure scales are not zero
+            scale = torch.clamp(scale, min=1e-12)
+            # Reshape to match weight_scale parameter shape (out_features, 1)
+            fp8_linear.weight_scale.copy_(scale.reshape(-1, 1))
+
+        # Quantize weights to FP8
+        quantized_weights = (float_weights / fp8_linear.weight_scale).clamp(
+            -fp8_max, fp8_max
+        )
+        fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn))
+
+        # Initialize bias if present
+        if fp8_linear.has_bias:
+            fp8_linear.bias.copy_(torch.randn(out_features))
+
+
+def initialize_fp8_input_scale(
+    fp8_linear,
+    activation_strategy: str,
+    batch_size: int,
+    seq_len: int,
+    in_features: int,
+) -> None:
+    """Initialize static input scale for FP8Linear.
+
+    Args:
+        fp8_linear: FP8Linear module to initialize
+        activation_strategy: "tensor" or "token" for activation quantization
+        batch_size: Batch size for sample input
+        seq_len: Sequence length for sample input
+        in_features: Input feature dimension
+    """
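+    # Only needed for static activation quantization: with dynamic quantization
+    # the input scale is derived from the activations at runtime, so the tests
+    # below call this helper only when dynamic_activation is False.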
+    with torch.no_grad():
+        # For static quantization, use a representative input to calculate scales
+        sample_input = torch.randn(batch_size, seq_len, in_features)
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+
+        if activation_strategy == "tensor":
+            # Per-tensor: single scale for entire activation
+            absmax = sample_input.abs().max()
+            scale = absmax / fp8_max
+            scale = torch.clamp(scale, min=1e-12)
+            fp8_linear.input_scale.fill_(scale.item())
+        else:  # token
+            # For per-token static quantization, use a calibrated scale
+            # based on representative input statistics
+            absmax = sample_input.abs().max()
+            scale = absmax / fp8_max
+            scale = torch.clamp(scale, min=1e-12)
+            # Fill all scales with the same representative value
+            fp8_linear.input_scale.fill_(scale.item())
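+            # Replicating one representative value across all token positions
+            # makes static per-token quantization behave like per-tensor here;
+            # a real calibration pass would produce per-token statistics.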
+
+
+# ============================================================================
+# Pytest Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def fp8_test_dimensions():
+    """Common test dimensions for FP8Linear tests."""
+    return {
+        "batch_size": 2,
+        "seq_len": 4,
+        "in_features": 8,
+        "out_features": 16,
+    }
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+

def test_fp8_registration() -> None:
    """
@@ -44,9 +155,10 @@ def test_fp8_registration() -> None:
    reason="FP8 is only available on GPUs with device level 8.9 or higher",
)
def test_fp8_op() -> None:
-    """Validate output shapes of GPTQ W4A16 tensors.
-    Note: this AIU-compatible operation only returns a zero tensor of the
-    expected shape, it does not perform a real W4A16 matmul operation.
+    """Validate output shapes of FP8 attention operation.
+
+    Tests the FP8 attention compute operation to ensure it produces
+    outputs with the expected shape.
    """
    # Local
    from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
@@ -57,3 +169,148 @@ def test_fp8_op() -> None:

    out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
    assert out.size() == query.size()
+
+
+@pytest.mark.skipif(
+    not available_packages["torchao"] or not available_packages["fms"],
+    reason="FMS and torchao required to run this test",
+)
+@pytest.mark.parametrize(
+    "weight_strategy,activation_strategy,dynamic_activation",
+    [
+        ("tensor", "tensor", True),  # Per-tensor weights + per-tensor activations
+        ("tensor", "token", True),  # Per-tensor weights + per-token activations
+        ("channel", "tensor", True),  # Per-channel weights + per-tensor activations
+        ("channel", "token", True),  # Per-channel weights + per-token activations
+    ],
+)
+def test_fp8_linear_cpu_support(
+    weight_strategy: str,
+    activation_strategy: str,
+    dynamic_activation: bool,
+    fp8_test_dimensions: dict,
+) -> None:
+    """Test FP8Linear on CPU with different quantization strategies.
+
+    This test ensures that FP8Linear works correctly on CPU, including:
+    - Per-tensor quantization (native support in PyTorch 2.10+)
+    - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+)
+
+    Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel
+    and per-token quantization require a fallback to dequantize + regular matmul.
+
+    Args:
+        weight_strategy: "tensor" or "channel" for weight quantization
+        activation_strategy: "tensor" or "token" for activation quantization
+        dynamic_activation: Whether to use dynamic activation quantization
+        fp8_test_dimensions: Test dimensions fixture
+    """
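+    # Rough expectation, per the note above: per-tensor configs can use the
+    # native CPU FP8 matmul, while per-channel/per-token configs are expected
+    # to fall back to dequantize-then-matmul; both paths must yield the same
+    # output shape and dtype, which is what this test asserts.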
+    # Local
+    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear
+
+    # Get test dimensions
+    batch_size = fp8_test_dimensions["batch_size"]
+    seq_len = fp8_test_dimensions["seq_len"]
+    in_features = fp8_test_dimensions["in_features"]
+    out_features = fp8_test_dimensions["out_features"]
+
+    # Create FP8Linear configuration
+    linear_config = {
+        "weights": {
+            "strategy": weight_strategy,
+            "symmetric": True,
+            "dynamic": False,
+        },
+        "input_activations": {
+            "strategy": activation_strategy,
+            "symmetric": True,
+            "dynamic": dynamic_activation,
+        },
+    }
+
+    # Create FP8Linear module
+    fp8_linear = FP8Linear(
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        linear_config=linear_config,
+    )
+
+    # Initialize weights using helper function
+    initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features)
+
+    # Initialize input scale if static quantization
+    if not dynamic_activation:
+        initialize_fp8_input_scale(
+            fp8_linear, activation_strategy, batch_size, seq_len, in_features
+        )
+
+    # Create input tensor on CPU
+    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
+
+    # Run forward pass - should not raise an error
+    output = fp8_linear(x)
+
+    # Validate output shape
+    assert output.shape == (batch_size, seq_len, out_features)
+
+    # Validate output is not NaN or Inf
+    assert not torch.isnan(output).any()
+    assert not torch.isinf(output).any()
+
+    # Validate output dtype matches input dtype
+    assert output.dtype == x.dtype
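+    # Optional numerical sanity check one could add here (sketch only, assuming
+    # weight_scale broadcasts over the weight matrix as in the helpers above):
+    #   ref = x.to(torch.float32) @ (
+    #       fp8_linear.weight.to(torch.float32) * fp8_linear.weight_scale
+    #   ).t() + fp8_linear.bias.to(torch.float32)
+    #   torch.testing.assert_close(output.to(torch.float32), ref, atol=0.3, rtol=0.1)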
+
+
+@pytest.mark.skipif(
+    not available_packages["torchao"] or not available_packages["fms"],
+    reason="FMS and torchao required to run this test",
+)
+def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None:
+    """Test FP8Linear on CPU with only weight quantization (no activation quantization).
+
+    This tests the code path where activations are not quantized but weights are FP8.
+
+    Args:
+        fp8_test_dimensions: Test dimensions fixture
+    """
+    # Local
+    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear
+
+    # Get test dimensions
+    batch_size = fp8_test_dimensions["batch_size"]
+    seq_len = fp8_test_dimensions["seq_len"]
+    in_features = fp8_test_dimensions["in_features"]
+    out_features = fp8_test_dimensions["out_features"]
+
+    # Create FP8Linear configuration with no activation quantization
+    linear_config = {
+        "weights": {
+            "strategy": "channel",
+            "symmetric": True,
+            "dynamic": False,
+        },
+        "input_activations": None,  # No activation quantization
+    }
+
+    # Create FP8Linear module
+    fp8_linear = FP8Linear(
+        in_features=in_features,
+        out_features=out_features,
+        bias=True,
+        linear_config=linear_config,
+    )
+
+    # Initialize weights using helper function
+    initialize_fp8_weights(fp8_linear, "channel", in_features, out_features)
+
+    # Create input tensor on CPU
+    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
+
+    # Run forward pass
+    output = fp8_linear(x)
+
+    # Validate output
+    assert output.shape == (batch_size, seq_len, out_features)
+    assert not torch.isnan(output).any()
+    assert not torch.isinf(output).any()