@@ -144,23 +144,24 @@ def get_weights_scaling_factor_from_quantizer(
144144 def quantize_with_scale(
145145 cls,
146146 weight: torch.Tensor,
147- e8m0_scale: torch.Tensor,
147+ weights_scaling_factor: torch.Tensor,
148148 ) -> torch.Tensor:
149149 """Quantize weight tensor using a pre-computed E8M0 scale.
150150
151151 This method is useful for export paths where the scale has already been computed.
152152
153153 Args:
154154 weight: The weight tensor to quantize. Must be at least 1D.
155- e8m0_scale: E8M0 scale as uint8 biased exponent (bias = 127).
155+ weights_scaling_factor: E8M0 scale as uint8 biased exponent (bias = 127).
156156 Shape should be [..., out_dim, in_dim // 32] for 2D+ tensors,
157157 or [in_dim // 32] for 1D tensors.
158158
159159 Returns:
160160 torch.Tensor: Quantized weight as float8_e4m3fn with same shape as input.
161161 """
162- assert e8m0_scale.dtype == cls.SCALE_DTYPE, (
163- f"e8m0_scale must be {cls.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}"
162+ assert weights_scaling_factor.dtype == cls.SCALE_DTYPE, (
163+ f"weights_scaling_factor must be {cls.SCALE_DTYPE} (E8M0 format), "
164+ f"got {weights_scaling_factor.dtype}"
164165 )
165166
166167 in_dim = weight.shape[-1]
@@ -171,13 +172,13 @@ def quantize_with_scale(
171172 )
172173
173174 # Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent)
174- scale_factor = torch.exp2(127 - e8m0_scale.float())
175+ scale_factor = torch.exp2(127 - weights_scaling_factor.float())
175176
176177 # NOTE: vLLM/flashinfer may require this behavior:
177178 # scale_factor = torch.where(
178- # e8m0_scale == 0,
179+ # weights_scaling_factor == 0,
179180 # 1.0,
180- # torch.exp2(127 - e8m0_scale.float())
181+ # torch.exp2(127 - weights_scaling_factor.float())
181182 # )
182183
183184 weight_reshaped = weight.view(*weight.shape[:-1], num_blocks, cls.BLOCK_SIZE)
@@ -189,30 +190,39 @@ def quantize_with_scale(
189190 return quantized_weight.view(weight.shape)
190191
191192 @classmethod
192- def quantize(cls, input: torch.Tensor) -> tuple:
193+ def quantize(
194+ cls,
195+ input: torch.Tensor,
196+ weights_scaling_factor: torch.Tensor | None = None,
197+ ) -> tuple:
193198 """Convert a tensor to MXFP8 quantized format.
194199
195200 Args:
196201 input (torch.Tensor): The input tensor to be quantized.
202+ weights_scaling_factor (torch.Tensor | None): Optional pre-computed E8M0 scale
203+ as uint8 biased exponent. If None, the scale will be computed from the input.
204+ Shape should be [..., in_dim // 32] matching input dimensions.
197205
198206 Returns:
199- tuple: (MXFP8QTensor, e8m0_scale) where e8m0_scale is uint8 biased exponent.
207+ tuple: (MXFP8QTensor, weights_scaling_factor) where weights_scaling_factor is
208+ E8M0 scale as uint8 biased exponent.
200209 """
201210 original_shape = input.shape
202211 original_dtype = input.dtype
203212
204213 input = reduce_block_padding(input, block_sizes={-1: cls.BLOCK_SIZE})
205- input_amax = reduce_block_amax(input, block_sizes={-1: cls.BLOCK_SIZE})
206214
207- e8m0_exponent = cls._compute_e8m0_exponent(input_amax)
208- e8m0_scale = (e8m0_exponent + 127).to(cls.SCALE_DTYPE)
215+ if weights_scaling_factor is None:
216+ input_amax = reduce_block_amax(input, block_sizes={-1: cls.BLOCK_SIZE})
217+ e8m0_exponent = cls._compute_e8m0_exponent(input_amax)
218+ weights_scaling_factor = (e8m0_exponent + 127).to(cls.SCALE_DTYPE)
209219
210- quantized_data = cls.quantize_with_scale(input, e8m0_scale)
220+ quantized_data = cls.quantize_with_scale(input, weights_scaling_factor)
211221
212222 # Crop back to original shape
213223 quantized_data = quantized_data[..., :original_shape[-1]]
214224
215- return cls(original_shape, original_dtype, quantized_data), e8m0_scale
225+ return cls(original_shape, original_dtype, quantized_data), weights_scaling_factor
216226
217227 def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor:
218228 """Dequantize MXFP8 tensor back to the target dtype.
0 commit comments