@@ -132,15 +132,6 @@ def get_weights_scaling_factor_from_quantizer(
132132 assert scale.dtype == cls.SCALE_DTYPE, (
133133 f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}"
134134 )
135-
136- # Reshape if needed (same number of elements but wrong shape)
137- if scale.shape != expected_shape:
138- expected_numel = 1
139- for dim in expected_shape:
140- expected_numel *= dim
141- if scale.numel() == expected_numel:
142- scale = scale.reshape(expected_shape)
143-
144135 assert scale.shape == expected_shape, (
145136 f"Scale shape {scale.shape} does not match expected shape {expected_shape}"
146137 )
@@ -179,12 +170,6 @@ def quantize_with_scale(
179170 f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})"
180171 )
181172
182- # Reshape scale if needed (same number of elements but wrong shape)
183- expected_shape = (*weight.shape[:-1], num_blocks)
184- if e8m0_scale.shape != expected_shape:
185- if e8m0_scale.numel() == weight.numel() // cls.BLOCK_SIZE:
186- e8m0_scale = e8m0_scale.reshape(expected_shape)
188173 # Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent)
189174 scale_factor = torch.exp2(127 - e8m0_scale.float())
190175
@@ -258,13 +243,6 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor:
258243 # Convert E8M0 biased exponent back to scale factor: descale = 2^(exponent - 127)
259244 descale = torch.exp2(e8m0_scale.float() - 127)
260245
261- # Reshape descale to match blocked tensor for broadcasting
262- expected_scale_shape = (*quantized_data.shape[:-1], num_blocks)
263- if descale.shape != expected_scale_shape and descale.numel() == num_blocks * (
264- quantized_data.numel() // quantized_data.shape[-1]
265- ):
266- descale = descale.view(expected_scale_shape)
267-
268246 dequantized = quantized_blocked * descale.unsqueeze(-1)
269247
270248 # Reshape and crop back to original shape