@@ -79,42 +79,6 @@ def initialize_fp8_weights(
         fp8_linear.bias.copy_(torch.randn(out_features))


-def initialize_fp8_input_scale(
-    fp8_linear,
-    activation_strategy: str,
-    batch_size: int,
-    seq_len: int,
-    in_features: int,
-) -> None:
-    """Initialize static input scale for FP8Linear.
-
-    Args:
-        fp8_linear: FP8Linear module to initialize
-        activation_strategy: "tensor" or "token" for activation quantization
-        batch_size: Batch size for sample input
-        seq_len: Sequence length for sample input
-        in_features: Input feature dimension
-    """
-    with torch.no_grad():
-        # For static quantization, use a representative input to calculate scales
-        sample_input = torch.randn(batch_size, seq_len, in_features)
-
-        if activation_strategy == "tensor":
-            # Per-tensor: single scale for entire activation
-            absmax = sample_input.abs().max()
-            scale = absmax / FP8_E4M3_MAX
-            scale = torch.clamp(scale, min=1e-12)
-            fp8_linear.input_scale.fill_(scale.item())
-        else:  # token
-            # For per-token static quantization, use a calibrated scale
-            # based on representative input statistics
-            absmax = sample_input.abs().max()
-            scale = absmax / FP8_E4M3_MAX
-            scale = torch.clamp(scale, min=1e-12)
-            # Fill all scales with the same representative value
-            fp8_linear.input_scale.fill_(scale.item())
-
-
 # ============================================================================
 # Pytest Fixtures
 # ============================================================================
@@ -179,33 +143,31 @@ def test_fp8_op() -> None:
     reason="FMS and torchao required to run this test",
 )
 @pytest.mark.parametrize(
-    "weight_strategy,activation_strategy,dynamic_activation",
+    "weight_strategy,activation_strategy",
     [
-        ("tensor", "tensor", True),  # Per-tensor weights + per-tensor activations
-        ("tensor", "token", True),  # Per-tensor weights + per-token activations
-        ("channel", "tensor", True),  # Per-channel weights + per-tensor activations
-        ("channel", "token", True),  # Per-channel weights + per-token activations
+        ("tensor", "tensor"),  # Per-tensor W + per-tensor dynamic A
+        ("tensor", "token"),  # Per-tensor W + per-token dynamic A
+        ("channel", "tensor"),  # Per-channel W + per-tensor dynamic A
+        ("channel", "token"),  # Per-channel W + per-token dynamic A
     ],
 )
 def test_fp8_linear_cpu_support(
     weight_strategy: str,
     activation_strategy: str,
-    dynamic_activation: bool,
     fp8_test_dimensions: dict,
 ) -> None:
     """Test FP8Linear on CPU with different quantization strategies.

-    This test ensures that FP8Linear works correctly on CPU, including:
+    This test ensures that FP8Linear works correctly on CPU with:
     - Per-tensor quantization (native support in PyTorch 2.10+)
     - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+)

     Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel
     and per-token quantization require a fallback to dequantize + regular matmul.

     Args:
-        weight_strategy: "tensor" or "channel" for weight quantization
-        activation_strategy: "tensor" or "token" for activation quantization
-        dynamic_activation: Whether to use dynamic activation quantization
+        weight_strategy: "tensor" or "channel" weight quantization
+        activation_strategy: "tensor" or "token" dynamic activation quantization
         fp8_test_dimensions: Test dimensions fixture
     """
     # Local
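The fallback path the docstring's Note describes can be pictured as dequantize-then-matmul. The sketch below is an illustrative approximation under assumed scale shapes, not FMS's actual FP8Linear internals; all names are hypothetical:

```python
import torch

def dequant_matmul_fallback(
    x_fp8: torch.Tensor,    # activations in float8_e4m3fn
    x_scale: torch.Tensor,  # scalar (per-tensor) or (..., 1) (per-token)
    w_fp8: torch.Tensor,    # weight in float8_e4m3fn, shape (out, in)
    w_scale: torch.Tensor,  # scalar (per-tensor) or (out, 1) (per-channel)
) -> torch.Tensor:
    # Widen both operands, reapply their scales, then use an ordinary matmul
    # in place of the per-tensor-only CPU FP8 matmul kernel.
    x = x_fp8.to(torch.float32) * x_scale
    w = w_fp8.to(torch.float32) * w_scale
    return x @ w.t()
```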
@@ -227,7 +189,7 @@ def test_fp8_linear_cpu_support(
227189 "input_activations" : {
228190 "strategy" : activation_strategy ,
229191 "symmetric" : True ,
230- "dynamic" : dynamic_activation ,
192+ "dynamic" : True ,
231193 },
232194 }
233195
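With `"dynamic": True` now pinned in the config, input scales are derived from each incoming activation at call time rather than from a calibration pass, which is why the static-scale helper above could be deleted. A minimal sketch of the two activation strategies (hypothetical helper name; the real quantization is handled inside FMS/torchao):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def dynamic_quantize(x: torch.Tensor, strategy: str):
    """Sketch of on-the-fly FP8 activation quantization."""
    if strategy == "tensor":
        absmax = x.abs().max()                       # one scale for the tensor
    else:  # "token"
        absmax = x.abs().amax(dim=-1, keepdim=True)  # one scale per token
    scale = torch.clamp(absmax / FP8_MAX, min=1e-12)
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    return x_fp8, scale
```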
@@ -242,12 +204,6 @@ def test_fp8_linear_cpu_support(
     # Initialize weights using helper function
     initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features)

-    # Initialize input scale if static quantization
-    if not dynamic_activation:
-        initialize_fp8_input_scale(
-            fp8_linear, activation_strategy, batch_size, seq_len, in_features
-        )
-
     # Create input tensor on CPU
     x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)

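For intuition on the precision such tests work with: E4M3 keeps only 3 mantissa bits, so a quantize-dequantize roundtrip has a visible error floor, which is why FP8 tests typically compare against loose tolerances. A small self-contained check:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

# Quantize-dequantize a random input and measure the roundtrip error.
x = torch.randn(4, 16, 128, dtype=torch.bfloat16).to(torch.float32)
scale = torch.clamp(x.abs().max() / FP8_MAX, min=1e-12)
x_round = (x / scale).to(torch.float8_e4m3fn).to(torch.float32) * scale
print((x_round - x).abs().max().item())  # small but nonzero, per E4M3's 3 mantissa bits
```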