-
Notifications
You must be signed in to change notification settings - Fork 375
Expand file tree
/
Copy pathnvfp4_tensor.py
More file actions
402 lines (334 loc) · 15.8 KB
/
nvfp4_tensor.py
File metadata and controls
402 lines (334 loc) · 15.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements NVFP4 quantization for efficient tensor storage and computation."""
import torch
from ..backends.utils import fp4_compatible
from ..qtensor.base_qtensor import BaseQuantizedTensor
from ..utils import reduce_amax, reduce_block_amax, reduce_block_padding
# FP4 (E2M1) lookup tables used by NVFP4QTensor.
#
# The eight non-negative E2M1 magnitudes, indexed by the 3-bit magnitude code.
_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)

# Rounding thresholds: the midpoints between consecutive magnitudes,
# i.e. (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0).
e2m1_bounds = torch.tensor(
    [(lower + upper) / 2 for lower, upper in zip(_E2M1_MAGNITUDES, _E2M1_MAGNITUDES[1:])]
)

# Code -> value table for all 16 4-bit codes: codes 0-7 are the magnitudes above,
# codes 8-15 (sign bit set) are their negations. Code 8 is deliberately +0.0, not -0.0.
e2m1_values = torch.tensor(
    [*_E2M1_MAGNITUDES, 0.0, *(-magnitude for magnitude in _E2M1_MAGNITUDES[1:])]
)
__all__ = ["NVFP4QTensor"]


class NVFP4QTensor(BaseQuantizedTensor):
    """Implements NVFP4 quantization on tensors for more efficient storage or computation.

    Every element is encoded as a 4-bit code: bit 3 is the sign and bits 0-2 index the
    non-negative E2M1 magnitudes ``0, 0.5, 1, 1.5, 2, 3, 4, 6``. Two codes are packed into
    one ``uint8`` (even element in the low nibble, odd element in the high nibble).
    Scaling is two-level: a per-block scale (FP8 E4M3 unless high precision is requested)
    multiplied by a per-tensor scale (``weights_scaling_factor_2``).

    Attributes:
        quantized_data (torch.Tensor): The quantized data stored as a packed uint8 tensor.
    """

    # Lazily-filled per-device copies of the module-level E2M1 tables, so each table is
    # transferred to a given device at most once.
    e2m1_values_on_device = {}
    e2m1_bounds_on_device = {}

    @classmethod
    def get_e2m1_values(cls, device):
        """Return the 16-entry E2M1 code -> value table on ``device`` (cached per device)."""
        if device not in cls.e2m1_values_on_device:
            cls.e2m1_values_on_device[device] = e2m1_values.to(device)
        return cls.e2m1_values_on_device[device]

    @classmethod
    def get_e2m1_bounds(cls, device):
        """Return the E2M1 rounding bounds (magnitude midpoints) on ``device`` (cached)."""
        if device not in cls.e2m1_bounds_on_device:
            cls.e2m1_bounds_on_device[device] = e2m1_bounds.to(device)
        return cls.e2m1_bounds_on_device[device]

    @classmethod
    def _is_static_quantizer(cls, weight_quantizer) -> bool:
        """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax."""
        return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None

    @classmethod
    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
        """Return the per-tensor weight scaling factor from ``weight_quantizer``.

        Static NVFP4 quantizers provide ``global_amax``; dynamic quantizers provide
        ``_amax``. In both cases the factor is ``amax / (6 * 448)``: 6 is the largest
        E2M1 magnitude and 448 the largest finite FP8 E4M3 value.

        Args:
            weight_quantizer: The weight quantizer (static or dynamic).

        Returns:
            The global scaling factor as a float tensor.
        """
        if cls._is_static_quantizer(weight_quantizer):
            return weight_quantizer.global_amax.float() / (6.0 * 448.0)
        assert hasattr(weight_quantizer, "_amax"), (
            "Weight quantizer does not have attribute amax"
        )
        return weight_quantizer._amax.float() / (6.0 * 448.0)

    @classmethod
    def get_weights_scaling_factor_from_quantizer(
        cls,
        weight_quantizer,
        weight: torch.Tensor,
        weights_scaling_factor_2: torch.Tensor | None = None,
        keep_high_precision: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the per-block weight scaling factor from the quantizer.

        Handles both static NVFP4 quantizers (which carry a pre-computed per-block amax)
        and dynamic quantizers (for which the scale is computed from ``weight``). Both
        paths return the per-block scale in the same units, ``block_amax / (6 *
        per_tensor_scale)``, with all-zero blocks set to 1.0.

        Args:
            weight_quantizer: The weight quantizer (static or dynamic).
            weight: The weight tensor. The static path uses only its shape; the dynamic
                path uses its values.
            weights_scaling_factor_2: Optional pre-computed per-tensor scale.
            keep_high_precision: If True, skip the cast of the per-block scale to FP8 E4M3.

        Returns:
            Tuple of (per_block_scale, weights_scaling_factor_2).
        """
        block_size = weight_quantizer.block_sizes[-1]

        if weights_scaling_factor_2 is None:
            weights_scaling_factor_2 = cls.get_weights_scaling_factor_2_from_quantizer(
                weight_quantizer
            )

        if not cls._is_static_quantizer(weight_quantizer):
            # Dynamic path: compute the per-block amax from the weight values.
            return cls.get_weights_scaling_factor(
                weight, block_size, weights_scaling_factor_2, keep_high_precision
            )

        # Static path: use the pre-computed per-block amax stored on the quantizer.
        # NOTE(review): this path derives the scale from global_amax and does not use a
        # caller-supplied weights_scaling_factor_2. Confirm callers never pass a differing
        # one.
        global_amax = weight_quantizer.global_amax.float()
        per_block_amax = weight_quantizer._amax.float()

        # Map each block scale (amax / 6) onto the FP8 E4M3 range, relative to the largest
        # possible block scale (global_amax / 6). This equals the dynamic path's
        # amax / (6 * weights_scaling_factor_2) when that factor is derived from global_amax.
        per_block_scale_max = global_amax / 6.0
        per_block_scale = per_block_amax / 6.0 * 448.0 / per_block_scale_max

        # All-zero blocks get a unit scale. This must happen *after* mapping into FP8 units,
        # as in the dynamic path. Replacing zeros earlier would turn them into
        # 448 * 6 / global_amax, which exceeds the FP8 E4M3 range (max 448) whenever
        # global_amax < 6 and does not survive the cast.
        per_block_scale[per_block_scale == 0] = 1.0

        # Reshape to the weight's block structure: (*weight.shape[:-1], num_blocks_per_row).
        num_blocks_per_row = weight.shape[-1] // block_size
        per_block_scale = per_block_scale.reshape(*weight.shape[:-1], num_blocks_per_row)

        if not keep_high_precision:
            per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
        return per_block_scale, weights_scaling_factor_2

    @classmethod
    def get_weights_scaling_factor(
        cls,
        input: torch.Tensor,
        block_size: int,
        weights_scaling_factor_2: torch.Tensor | None = None,
        keep_high_precision: bool = False,
    ):
        """Return the per-block weight scaling factor computed from the weight values.

        This is the dynamic path. For quantizers with pre-computed amax, use
        ``get_weights_scaling_factor_from_quantizer``.

        Args:
            input: Weight tensor; its last dimension must be divisible by ``block_size``.
            block_size: Number of consecutive last-dim elements per block.
            weights_scaling_factor_2: Per-tensor scale; computed from ``input`` if None.
            keep_high_precision: If True, skip the cast of the per-block scale to FP8 E4M3.

        Returns:
            Tuple ``(per_block_scale, weights_scaling_factor_2)``, where ``per_block_scale``
            is ``block_amax / (6 * weights_scaling_factor_2)`` with all-zero blocks set to 1.0.
        """
        if weights_scaling_factor_2 is None:
            weights_scaling_factor_2 = cls.get_weights_scaling_factor_2(input)

        assert block_size != 0, "Block size is zero. Cannot return per_block amax for given input."
        assert input.shape[-1] % block_size == 0, (
            "Weight shape is not divisible for block size for block quantization."
        )

        per_block_amax = reduce_block_amax(input, block_sizes={-1: block_size}).float()
        per_block_scale = per_block_amax / (
            6.0 * weights_scaling_factor_2.to(per_block_amax.device)
        )
        # All-zero blocks would otherwise get a zero scale (division by zero on dequant).
        per_block_scale[per_block_scale == 0] = 1.0

        if not keep_high_precision:
            per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
        return per_block_scale, weights_scaling_factor_2

    @classmethod
    def get_weights_scaling_factor_2(cls, input: torch.Tensor):
        """Return the per-tensor weight scaling factor ``amax(input) / (6 * 448)``."""
        return reduce_amax(input).float() / (6.0 * 448.0)

    @classmethod
    def get_activation_scaling_factor(cls, quantizer):
        """Return the activation scaling factor for export, or None if unavailable.

        Returns None when the quantizer is disabled or exports no amax. Otherwise returns
        ``amax / (quantizer.maxbound * 448)``, with exact-zero entries replaced as
        described below.
        """
        # TODO: Update to use module and not quantizer
        if not quantizer.is_enabled:
            return None

        amax = quantizer.export_amax()
        if amax is None:
            return None

        activation_scaling_factor = amax.float() / (quantizer.maxbound * 448.0)

        # Handle exact-zero entries produced by MoE routing sparsity: some
        # per-channel input slots on rarely-routed experts never see traffic
        # during calibration, leaving their amax (and therefore scaling factor)
        # at exactly zero. A zero scaling factor would break downstream
        # dequantization arithmetic. Replace exact zeros with the minimum
        # positive value in the same tensor — this is a no-op for values
        # flowing through zeroed channels (~0 anyway) and keeps the tensor
        # valid. We deliberately leave negative entries alone so that the
        # existing positivity assertion below still catches upstream
        # quantizer/config bugs rather than silently masking them.
        zero_mask = activation_scaling_factor == 0
        if zero_mask.any():
            positive = activation_scaling_factor[activation_scaling_factor > 0]
            replacement = (
                positive.min()
                if positive.numel() > 0
                else torch.tensor(
                    1e-8,
                    device=activation_scaling_factor.device,
                    dtype=activation_scaling_factor.dtype,
                )
            )
            activation_scaling_factor = torch.where(
                zero_mask, replacement, activation_scaling_factor
            )

        assert torch.all(activation_scaling_factor > 0), (
            f" activation scaling factor {activation_scaling_factor} not positive."
        )

        return activation_scaling_factor

    @classmethod
    def _cast_fp4(cls, weight: torch.Tensor):
        """Round ``weight`` to the nearest E2M1 value and return its 4-bit codes as uint8.

        Ties are resolved to the neighbour with the even code (round-half-to-even).

        NOTE: ``weight`` is overwritten in place with its absolute value (``abs_``) to
        avoid another full-size allocation. Only pass a tensor that is no longer needed.
        """
        device = weight.device

        # The sign must be captured before abs_() overwrites the tensor; -0.0 maps to +0.
        sign_bit = (weight < 0).to(torch.uint8)
        weight_abs = weight.abs_()

        # The bounds are the midpoints between consecutive magnitudes. searchsorted
        # (right=False) counts the bounds strictly below each value, which is the
        # 3-bit code of the nearest magnitude. A value exactly on a bound goes to the
        # lower neighbour.
        e2m1_bounds = cls.get_e2m1_bounds(device)
        magnitude_code = torch.searchsorted(e2m1_bounds, weight_abs, out_int32=True).to(
            torch.uint8
        )

        # Bound i sits between codes i and i + 1. For odd i the even code is the *upper*
        # neighbour, so ties at the bounds with indices 1, 3, 5 are bumped up by one.
        odd_bounds = e2m1_bounds[[1, 3, 5]]  # [0.75, 1.75, 3.5]
        on_odd_bound = torch.any(weight_abs.unsqueeze(-1) == odd_bounds, dim=-1).to(
            torch.uint8
        )

        return (sign_bit << 3) + magnitude_code + on_odd_bound

    @classmethod
    def quantize(
        cls,
        input: torch.Tensor,
        block_size: int,
        weights_scaling_factor: torch.Tensor | None = None,
        weights_scaling_factor_2: torch.Tensor | None = None,
        keep_high_precision: bool = False,
        try_tensorrt: bool = False,
    ):
        """Quantize a tensor to NVFP4.

        Args:
            input (torch.Tensor): The tensor to quantize. It is first passed through
                ``reduce_block_padding`` for ``block_size`` on the last dimension.
            block_size (int): Number of consecutive last-dim elements sharing one block scale.
            weights_scaling_factor (torch.Tensor): Optional per-block scale. If omitted, it
                is computed from ``input`` (as FP8 E4M3).
            weights_scaling_factor_2 (torch.Tensor): Optional per-tensor scale. If omitted,
                it is computed from ``input``.
            keep_high_precision (bool): If True, return only the scaled (not yet
                FP4-rounded) tensor instead of the usual tuple.
            try_tensorrt (bool): If True, try TensorRT-LLM's ``fp4_quantize`` op when the
                input qualifies: CUDA fp16/bf16 input, block size 16, and no explicit
                per-block scale. That op's per-block scale has a different layout, which
                ``dequantize`` converts.

        Returns:
            tuple: ``(quantized tensor, per-block scale, per-tensor scale)``, or the scaled
            high-precision tensor alone when ``keep_high_precision`` is True.
        """
        input_shape = input.shape
        input_dtype = input.dtype

        input = reduce_block_padding(input, block_sizes={-1: block_size})

        if weights_scaling_factor_2 is None:
            weights_scaling_factor_2 = cls.get_weights_scaling_factor_2(input)

        # Try the TensorRT-LLM fp4 quantization kernel when possible.
        if (
            fp4_compatible()
            and weights_scaling_factor is None
            and try_tensorrt
            and block_size == 16
            and input.is_cuda
            and input.dtype in [torch.half, torch.bfloat16]
        ):
            try:
                import tensorrt_llm  # noqa: F401

                # Make sure this util is available for dequantize.
                from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
                    cutlass_fp4_scale_to_modelopt_fp4_scale,  # noqa: F401
                )

                packed_weight, weights_scaling_factor = torch.ops.trtllm.fp4_quantize(
                    input, 1.0 / weights_scaling_factor_2, block_size, False
                )
                # This weights_scaling_factor is ready for nvfp4_gemm to use, but its
                # layout differs from the non-trtllm version, so dequantize converts it.
                return (
                    cls(input_shape, input_dtype, packed_weight),
                    weights_scaling_factor,
                    weights_scaling_factor_2,
                )
            except ImportError:
                # TensorRT-LLM unavailable: fall back to the reference implementation below.
                pass

        if weights_scaling_factor is None:
            weights_scaling_factor, _ = cls.get_weights_scaling_factor(
                input, block_size, weights_scaling_factor_2
            )

        # Divide each block of `block_size` elements by its combined block * tensor scale.
        padded_shape = input.shape
        blocked_input = input.reshape(*padded_shape[:-1], -1, block_size)
        combined_scale = weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2
        scaled_weight = (blocked_input / combined_scale.unsqueeze(-1)).view(padded_shape)

        if keep_high_precision:
            return scaled_weight

        # scaled_weight is a fresh temporary, so _cast_fp4's in-place abs is safe here.
        q_weight = cls._cast_fp4(scaled_weight)

        # Pack two codes per byte: even element in the low nibble, odd in the high nibble.
        packed_weight = (q_weight[..., 1::2] << 4) | q_weight[..., 0::2]

        return (
            cls(input_shape, input_dtype, packed_weight),
            weights_scaling_factor,
            weights_scaling_factor_2,
        )

    def dequantize(self, dtype: torch.dtype | None = None, fast=False, **kwarg):
        """Dequantize the packed NVFP4 tensor to ``dtype``.

        Args:
            dtype: Target dtype. Defaults to ``self.metadata["dtype"]``.
            fast: If True, use the Triton ``fp4_dequantize`` kernel.
            **kwarg: Must provide ``scale`` (per-block scale), ``double_scale`` (per-tensor
                scale) and ``block_sizes`` (the block size is read from ``[-1]``).

        Returns:
            The dequantized tensor reshaped to ``self.metadata["shape"]``.

        Raises:
            ImportError: If the scale came from TensorRT-LLM but tensorrt_llm is missing.
        """
        if dtype is None:
            dtype = self.metadata["dtype"]

        scale = kwarg["scale"]
        if scale.dtype == torch.uint8 and scale.ndim == 1:
            # A flat uint8 scale was produced by trtllm's fp4_quantize (see quantize).
            # Convert the cutlass fp4 scale layout to the ModelOpt fp4 scale layout.
            try:
                from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
                    cutlass_fp4_scale_to_modelopt_fp4_scale,
                )
            except ImportError as e:
                raise ImportError(
                    "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported."
                ) from e
            scale = cutlass_fp4_scale_to_modelopt_fp4_scale(scale, self.metadata["shape"][-2:])

        block_size = kwarg["block_sizes"][-1]
        double_scale = kwarg["double_scale"]

        # NOTE(review): if quantize() padded the last dimension, the unpacked element count
        # exceeds metadata["shape"] and the final reshape fails. Confirm that padding never
        # occurs for tensors dequantized here.
        if fast:
            from ..triton.fp4_kernel import fp4_dequantize

            return fp4_dequantize(
                self._quantized_data,
                scale,
                double_scale,
                block_size=block_size,
                dtype=dtype,
            ).reshape(self.metadata["shape"])

        # Reference path: value = E2M1(code) * per-block scale * per-tensor scale.
        if scale.dtype == torch.float8_e4m3fn:
            scale = scale.to(torch.float32)
        per_block_scale = scale * double_scale

        # Unpack two 4-bit codes per byte (low nibble first) and look up their values.
        # Codes stay uint8 here; only the looked-up values are materialized as floats.
        packed = self._quantized_data
        codes = torch.stack((packed & 0x0F, packed >> 4), dim=-1).flatten(-2)
        values = self.get_e2m1_values(packed.device)[codes.long()]

        deq_data = values.view(*values.shape[:-1], -1, block_size) * per_block_scale.unsqueeze(-1)
        return deq_data.reshape(self.metadata["shape"]).to(dtype)