Commit ef73576

remove static activation test

Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>

1 parent: e158bd2

1 file changed: tests/aiu_addons/test_fp8_addon.py (9 additions, 53 deletions)
@@ -79,42 +79,6 @@ def initialize_fp8_weights(
         fp8_linear.bias.copy_(torch.randn(out_features))


-def initialize_fp8_input_scale(
-    fp8_linear,
-    activation_strategy: str,
-    batch_size: int,
-    seq_len: int,
-    in_features: int,
-) -> None:
-    """Initialize static input scale for FP8Linear.
-
-    Args:
-        fp8_linear: FP8Linear module to initialize
-        activation_strategy: "tensor" or "token" for activation quantization
-        batch_size: Batch size for sample input
-        seq_len: Sequence length for sample input
-        in_features: Input feature dimension
-    """
-    with torch.no_grad():
-        # For static quantization, use a representative input to calculate scales
-        sample_input = torch.randn(batch_size, seq_len, in_features)
-
-        if activation_strategy == "tensor":
-            # Per-tensor: single scale for entire activation
-            absmax = sample_input.abs().max()
-            scale = absmax / FP8_E4M3_MAX
-            scale = torch.clamp(scale, min=1e-12)
-            fp8_linear.input_scale.fill_(scale.item())
-        else:  # token
-            # For per-token static quantization, use a calibrated scale
-            # based on representative input statistics
-            absmax = sample_input.abs().max()
-            scale = absmax / FP8_E4M3_MAX
-            scale = torch.clamp(scale, min=1e-12)
-            # Fill all scales with the same representative value
-            fp8_linear.input_scale.fill_(scale.item())
-
-
 # ============================================================================
 # Pytest Fixtures
 # ============================================================================
@@ -179,33 +143,31 @@ def test_fp8_op() -> None:
     reason="FMS and torchao required to run this test",
 )
 @pytest.mark.parametrize(
-    "weight_strategy,activation_strategy,dynamic_activation",
+    "weight_strategy,activation_strategy",
     [
-        ("tensor", "tensor", True),  # Per-tensor weights + per-tensor activations
-        ("tensor", "token", True),  # Per-tensor weights + per-token activations
-        ("channel", "tensor", True),  # Per-channel weights + per-tensor activations
-        ("channel", "token", True),  # Per-channel weights + per-token activations
+        ("tensor", "tensor"),  # Per-tensor W + per-tensor dynamic A
+        ("tensor", "token"),  # Per-tensor W + per-token dynamic A
+        ("channel", "tensor"),  # Per-channel W + per-tensor dynamic A
+        ("channel", "token"),  # Per-channel W + per-token dynamic A
     ],
 )
 def test_fp8_linear_cpu_support(
     weight_strategy: str,
     activation_strategy: str,
-    dynamic_activation: bool,
     fp8_test_dimensions: dict,
 ) -> None:
     """Test FP8Linear on CPU with different quantization strategies.

-    This test ensures that FP8Linear works correctly on CPU, including:
+    This test ensures that FP8Linear works correctly on CPU with:
     - Per-tensor quantization (native support in PyTorch 2.10+)
     - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+)

     Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel
     and per-token quantization require a fallback to dequantize + regular matmul.

     Args:
-        weight_strategy: "tensor" or "channel" for weight quantization
-        activation_strategy: "tensor" or "token" for activation quantization
-        dynamic_activation: Whether to use dynamic activation quantization
+        weight_strategy: "tensor" or "channel" weight quantization
+        activation_strategy: "tensor" or "token" dynamic activation quantization
         fp8_test_dimensions: Test dimensions fixture
     """
     # Local
@@ -227,7 +189,7 @@ def test_fp8_linear_cpu_support(
         "input_activations": {
             "strategy": activation_strategy,
             "symmetric": True,
-            "dynamic": dynamic_activation,
+            "dynamic": True,
         },
     }
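For context, the "input_activations" fragment above is one entry of a larger quantization-config dict; a hypothetical complete example follows, with the "weights" entry inferred from the weight_strategy parameter rather than shown in this diff:

# Hypothetical shape of the full config; only "input_activations" appears in this hunk
linear_config = {
    "weights": {
        "strategy": weight_strategy,       # "tensor" or "channel"
        "symmetric": True,
    },
    "input_activations": {
        "strategy": activation_strategy,   # "tensor" or "token"
        "symmetric": True,
        "dynamic": True,                   # the static path this commit removes
    },
}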

@@ -242,12 +204,6 @@ def test_fp8_linear_cpu_support(
     # Initialize weights using helper function
     initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features)

-    # Initialize input scale if static quantization
-    if not dynamic_activation:
-        initialize_fp8_input_scale(
-            fp8_linear, activation_strategy, batch_size, seq_len, in_features
-        )
-
     # Create input tensor on CPU
     x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)

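The docstring's note about the CPU fallback (dequantize + regular matmul whenever scales are not per-tensor) can be sketched as below; names and signatures are illustrative, not FP8Linear internals:

import torch

def fp8_matmul_fallback(x_fp8, x_scale, w_fp8, w_scale, bias=None):
    """Dequantize both FP8 operands, then run an ordinary matmul.

    Needed because the native CPU FP8 matmul (per the docstring,
    PyTorch 2.10+) only accepts per-tensor scales.
    """
    x = x_fp8.to(torch.float32) * x_scale  # per-token scale (B, S, 1) broadcasts over features
    w = w_fp8.to(torch.float32) * w_scale  # per-channel scale (out, 1) broadcasts across each row
    out = x @ w.t()
    if bias is not None:
        out = out + bias
    return out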