-
Notifications
You must be signed in to change notification settings - Fork 365
Expand file tree
/
Copy pathquant_utils.py
More file actions
executable file
·1552 lines (1304 loc) · 62.1 KB
/
quant_utils.py
File metadata and controls
executable file
·1552 lines (1304 loc) · 62.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for quantization including scaling factors adjustments."""
import logging
from collections.abc import Generator
from types import SimpleNamespace
from typing import Any
from warnings import warn
import torch
import torch.nn as nn
from modelopt import __version__
from modelopt.torch.quantization.model_calib import (
enable_stats_collection,
finish_stats_collection,
svd,
)
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
from modelopt.torch.quantization.qtensor import (
FP8QTensor,
MXFP4QTensor,
MXFP8QTensor,
NVFP4QTensor,
QTensorWrapper,
)
from modelopt.torch.quantization.utils import (
QuantizerAttrNames,
quantizer_attr_names,
reduce_block_amax,
weight_attr_names,
)
from modelopt.torch.utils import clear_cuda_cache
from ..quantization.nn import NVFP4StaticQuantizer, SequentialQuantizer, TensorQuantizer
from .model_config import (
KV_CACHE_FP8,
KV_CACHE_INT8,
KV_CACHE_NVFP4,
KV_CACHE_NVFP4_AFFINE,
QUANTIZATION_FP8,
QUANTIZATION_FP8_PB_REAL,
QUANTIZATION_FP8_PB_WO,
QUANTIZATION_FP8_PC_PT,
QUANTIZATION_INT4_AWQ,
QUANTIZATION_INT8_SQ,
QUANTIZATION_INT8_WO,
QUANTIZATION_MXFP4,
QUANTIZATION_MXFP8,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_MXFP4_FP8,
QUANTIZATION_W4A8_NVFP4_FP8,
)
# Module-level logger; _prefix_wildcard_summarize_exclude_modules emits its debug traces here.
logger = logging.getLogger(__name__)
def get_scaling_factor_from_weight(weight, group_size) -> torch.Tensor:
    """Calculate the weight scaling factor for a given group size.

    Args:
        weight: Weight tensor whose last dimension is the one being grouped. Any number of
            leading dimensions is accepted, e.g. ``(n, k)`` or ``(experts, n, k)``.
        group_size: Number of consecutive elements along the last dimension that share one
            scale. ``0`` means one scale per row (int8_sq, max bound 127.0); a positive value
            means one scale per group (int4_awq, max bound 7.0).

    Returns:
        float32 scaling factors ``amax / maxbound`` of shape ``weight.shape[:-1]`` when
        ``group_size == 0``, otherwise ``(*weight.shape[:-1], k // group_size)``. Entries that
        would be zero (all-zero row or group) are replaced by 1.0.

    Raises:
        NotImplementedError: if the last dimension is not divisible by ``group_size``.
    """
    k = weight.shape[-1]
    if group_size != 0:
        # int4_awq
        if k % group_size != 0:
            raise NotImplementedError(
                "Weight shape is not divisible for block size for block quantization."
            )
        # Split only the last dimension so leading (e.g. expert) dimensions are preserved.
        weight = weight.reshape(*weight.shape[:-1], k // group_size, group_size)
        maxbound = 7.0
    else:
        # int8_sq
        maxbound = 127.0
    amax = weight.abs().max(dim=-1)[0].float()
    weights_scaling_factor = amax / maxbound
    # An all-zero row/group would give a zero scale; use 1.0 so later divisions by the
    # scale do not divide by zero.
    weights_scaling_factor[weights_scaling_factor == 0] = 1.0
    return weights_scaling_factor
def maybe_transpose_expert_weight_dimensions(
weight: torch.Tensor,
weight_scale: torch.Tensor | None = None,
is_bmm_expert_weight: bool = True,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Transpose the last two dimensions of expert weights.
This function transposes expert weights between the two layouts:
- (num_experts, input_dim, output_dim) ↔ (num_experts, output_dim, input_dim)
Since transpose(-2, -1) is self-inverse, this function can be used for both
forward and backward transformations. This is needed for quantization functions
that expect the last dimension to be the input dimension for block quantization.
Specifically used for bmm-style expert weights in models like llama4 and gpt-oss.
Args:
weight: The weight tensor to transpose. Expected shape for experts: (num_experts, dim1, dim2)
weight_scale: Optional weight scaling factor tensor to transpose alongside weight
is_bmm_expert_weight: Whether this is an expert weight (3D tensor) that needs transposition
Returns:
Tuple of (transposed_weight, transposed_weight_scale)
"""
if not is_bmm_expert_weight or weight.dim() != 3:
return weight, weight_scale
transposed_weight = weight.transpose(-2, -1)
transposed_weight_scale = weight_scale.transpose(-2, -1) if weight_scale is not None else None
return transposed_weight, transposed_weight_scale
def resmooth_and_get_scale(
    merged_weights: torch.Tensor,
    pre_quant_scales: list[torch.Tensor],
    ranks: int,
    group_size: int,
    new_pre_quant_scale: torch.Tensor | None = None,
    quantization: str | None = QUANTIZATION_NONE,
):
    """Resmooths weights from a single or multiple ranks and get scaling factors and amax.

    ``merged_weights`` is split into ``ranks`` chunks along dim 0. Chunk ``i`` is de-smoothed
    with ``pre_quant_scales[i]`` and re-smoothed with a shared ``new_pre_quant_scale``
    (``weight * old / new``). A fresh scaling factor is then computed for each chunk.

    Args:
        merged_weights: Merged weights from ranks (concatenated along dim 0). ``shape[1]``
            must equal the number of elements of the pre-quant scale.
        pre_quant_scales: List of pre-quantization scales, one per rank; must be non-empty.
        ranks: Number of ranks (chunks ``merged_weights`` is split into).
        group_size: Group size of the quantization block.
        new_pre_quant_scale (optional): If not provided, weights will be resmoothed using
            the average of pre_quant_scales.
        quantization: Quantization format. For NVFP4_AWQ / NVFP4_SVDQUANT, NVFP4 per-block
            scales are computed; otherwise ``get_scaling_factor_from_weight`` is used.

    Returns:
        weights: Resmoothed weights, concatenated along dim 0.
        weight_scaling_factors: Resmoothed scaling factors, concatenated along dim 0
            (as ``float8_e4m3fn`` for the NVFP4 formats).
        avg_pre_quant_scale: The pre-quant scale that was applied (the average of
            ``pre_quant_scales`` unless ``new_pre_quant_scale`` was given).

    Raises:
        AssertionError: if ``pre_quant_scales`` is empty, or the scale size does not match
            ``merged_weights.shape[1]``.
    """
    # Check emptiness before torch.stack, which would otherwise fail first with a less
    # helpful error.
    assert len(pre_quant_scales) > 0, "Shape of pre_quant_scales and weights do not match."
    if new_pre_quant_scale is None:
        new_pre_quant_scale = torch.stack(pre_quant_scales).mean(dim=0)
    assert new_pre_quant_scale.numel() == merged_weights.shape[1], (
        "Shape of pre_quant_scales and weights do not match."
    )

    is_nvfp4 = quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]
    weights = torch.chunk(merged_weights, ranks, dim=0)
    scales = []
    new_weights = []
    for i, p_scaling_factor in enumerate(pre_quant_scales):
        dtype = weights[i].dtype
        # De smooth with this rank's scale & re smooth with the shared scale.
        weight = weights[i] * p_scaling_factor.type(dtype) / new_pre_quant_scale.type(dtype)
        new_weights.append(weight)
        if is_nvfp4:
            # get_weights_scaling_factor returns a tuple whose first element is the per-block
            # scale, so unpack before calling .view(). The FP8 scale is viewed as uint8 to
            # allow the torch.cat below.
            scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, group_size)
            scale = scale.view(torch.uint8)
        else:
            scale = get_scaling_factor_from_weight(weight, group_size)
        scales.append(scale)

    resmoothed_scales = torch.cat(scales, dim=0)
    if is_nvfp4:
        # View the concatenated bytes back as float8_e4m3fn.
        resmoothed_scales = resmoothed_scales.view(torch.float8_e4m3fn)
    return (
        torch.cat(new_weights, dim=0),
        resmoothed_scales,
        new_pre_quant_scale,
    )
def adjust_attn_amax_values(module):
    """Adjusts the amax values for the attention layers.

    A direct child of ``module`` qualifies as a projection layer when:

    * its name contains ``"q"``, ``"k"`` or ``"v"`` (substring match), and
    * it has a ``weight_quantizer`` whose ``amax`` is not ``None``.

    Each qualifying child is collected once, even if its name matches several prefixes.
    Every collected amax is then filled in place with the largest amax among them.

    Args:
        module: The attention module whose children are inspected.

    Raises:
        ValueError: if no qualifying projection layer is found.
        AssertionError: if the largest amax is not positive.
    """
    projection_prefixes = ("q", "k", "v")
    max_amax = float("-inf")
    proj_layers = []

    # Find all projection layers whose names contain 'q', 'k', or 'v'
    for name, sub_module in module.named_children():
        if not any(prefix in name for prefix in projection_prefixes):
            continue
        # `hasattr(quantizer, "amax")` can be true while amax is None (uncalibrated);
        # such layers have no value to contribute, so skip them instead of crashing.
        amax = getattr(getattr(sub_module, "weight_quantizer", None), "amax", None)
        if amax is None:
            continue
        proj_layers.append(sub_module)
        max_amax = max(max_amax, amax.item())

    if not proj_layers:
        raise ValueError(
            "No projection layers with the specified prefixes ('q', 'k', 'v') have amax attributes"
        )

    assert max_amax > 0, "max_amax must be positive."

    # Set all amax values to the maximum found
    for proj_layer in proj_layers:
        proj_layer.weight_quantizer.amax.fill_(max_amax)
def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
    """Derive a float scaling factor from a quantizer's exported amax.

    Returns ``None`` when the quantizer is disabled or exports no amax.

    * FP4 quantizers (``num_bits == (2, 1)``) delegate to
      ``NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer``.
    * All other quantizers use ``amax.float() / quantizer.maxbound``.

    Raises:
        AssertionError: if any element of the resulting scaling factor is not positive.
    """
    if not quantizer.is_enabled:
        return None

    amax = quantizer.export_amax()
    if amax is None:
        return None

    # tensorrt_llm uses float as the scaling_factors.
    is_fp4 = quantizer.num_bits == (2, 1)
    scaling_factor = (
        NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(quantizer)
        if is_fp4
        else amax.float() / quantizer.maxbound
    )

    assert torch.all(scaling_factor > 0), f"scaling factor {scaling_factor} not positive."
    return scaling_factor
def _get_nvfp4_block_size(
    weight_quantizer: NVFP4StaticQuantizer, weight: torch.Tensor, module_name: str = ""
) -> int:
    """Look up the NVFP4 block size for the last weight dimension.

    The size is read from ``weight_quantizer.block_sizes``: key ``-1`` first, then the
    explicit last-dimension index ``weight.dim() - 1`` (used when the ``-1`` entry is falsy).

    Raises:
        ValueError: if ``block_sizes`` is ``None`` or neither key yields a block size.
    """
    label = "NVFP4StaticQuantizer" + (f" for {module_name}" if module_name else "")
    sizes = weight_quantizer.block_sizes
    if sizes is None:
        raise ValueError(f"{label} has no block_sizes; cannot compute per-block amax from weight.")

    last_dim_size = sizes.get(-1) or sizes.get(weight.dim() - 1)
    if last_dim_size is not None:
        return last_dim_size
    raise ValueError(
        f"{label} block_sizes has no -1 or last-dim key; cannot compute per-block amax."
    )
def _set_amax_from_tensor(weight_quantizer: TensorQuantizer, tensor: torch.Tensor) -> None:
    """Store ``tensor`` as the quantizer's ``_amax`` buffer.

    If an existing non-None ``_amax`` has the same shape, it is overwritten in place (the
    source is moved to its device first). Otherwise any existing ``_amax`` attribute is
    removed and a detached clone of ``tensor`` is registered as a new ``_amax`` buffer.
    """
    existing = getattr(weight_quantizer, "_amax", None)
    if existing is not None and existing.shape == tensor.shape:
        existing.data.copy_(tensor.to(existing.device))
        return

    # Shape changed, or the buffer is missing/None: replace it wholesale.
    if hasattr(weight_quantizer, "_amax"):
        delattr(weight_quantizer, "_amax")
    weight_quantizer.register_buffer("_amax", tensor.clone().detach())
def _ensure_weight_quantizer_calibrated(
    weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
) -> None:
    """Calibrate weight quantizer if amax is not set.

    This is a lazy calibration pattern used during export when weight quantizers
    may not have been calibrated during the main calibration phase (e.g. MoE experts
    that received no calibration tokens). A warning is emitted whenever calibration
    has to happen here.

    For NVFP4StaticQuantizer, _amax is per-block amax and _global_amax is the max over
    blocks. Whichever of the two is missing is computed from the weight using the block
    size from ``_get_nvfp4_block_size``. Values that are already set are left untouched.

    For any other quantizer with a missing/None ``_amax``, amax is reset and re-collected
    by running ``weight`` through the quantizer with statistics collection enabled.

    Args:
        weight_quantizer: The weight quantizer to calibrate
        weight: The weight tensor to use for calibration
        module_name: Optional module name for better warning messages
    """
    if isinstance(weight_quantizer, NVFP4StaticQuantizer):
        need_per_block = not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None
        # NOTE(review): existence is checked on `_global_amax` but the value is read through
        # `global_amax` (presumably a property backed by `_global_amax`) — confirm.
        need_global = (
            not hasattr(weight_quantizer, "_global_amax") or weight_quantizer.global_amax is None
        )
        if not (need_per_block or need_global):
            return
        # Raises ValueError if the quantizer has no usable block size.
        block_size = _get_nvfp4_block_size(weight_quantizer, weight, module_name)
        warn(
            f"NVFP4StaticQuantizer{f' for {module_name}' if module_name else ''} was not fully calibrated. "
            f"Computing per-block amax and global_amax from weights. This may occur if: "
            f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size"
        )
        per_block_amax = reduce_block_amax(weight, block_sizes={-1: block_size})
        if need_per_block:
            _set_amax_from_tensor(weight_quantizer, per_block_amax.to(weight.device))
        if need_global:
            # The global amax is the maximum over all per-block amax values.
            weight_quantizer.global_amax = per_block_amax.max()
        return
    if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
        warn(
            f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "
            f"Computing amax from weights. This may occur if: "
            f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size"
        )
        # One-shot calibration pass: reset, enable stats collection, forward the weight
        # once, then finalize the collected statistics into amax.
        weight_quantizer.reset_amax()
        enable_stats_collection(weight_quantizer)
        weight_quantizer(weight)
        finish_stats_collection(weight_quantizer)
def get_activation_scaling_factor(
    module: nn.Module, input_quantizer_name: str = "input_quantizer"
) -> torch.Tensor:
    """Return the activation scaling factor of ``module``'s input quantizer.

    Returns ``None`` when ``module`` has no attribute named ``input_quantizer_name`` (or it
    is ``None``).

    * NVFP4 / NVFP4_AWQ / NVFP4_SVDQUANT modules use
      ``NVFP4QTensor.get_activation_scaling_factor``.
    * Every other module uses ``get_scaling_factor``.
    """
    input_quantizer = getattr(module, input_quantizer_name, None)
    if input_quantizer is None:
        return None

    nvfp4_formats = (QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT)
    if get_quantization_format(module) in nvfp4_formats:
        return NVFP4QTensor.get_activation_scaling_factor(input_quantizer)
    return get_scaling_factor(input_quantizer)
def get_weight_scaling_factor(
    module: nn.Module, weight_name: str = "weight"
) -> torch.Tensor | None:
    """Returns the weight scaling factor.

    The result depends on the weight quantizer and on ``get_quantization_format(module)``:

    * no weight quantizer for ``weight_name``: ``None``;
    * ``SequentialQuantizer``: ``get_scaling_factor`` of its first stage;
    * NVFP4 / NVFP4_AWQ / NVFP4_SVDQUANT / W4A8_NVFP4_FP8:
      - the quantizer is lazily calibrated from the weight if needed;
      - the per-block scale comes from
        ``NVFP4QTensor.get_weights_scaling_factor_from_quantizer`` (its first returned
        element);
      - it is computed relative to a second-level scale: ``_amax / 448`` for
        W4A8_NVFP4_FP8, otherwise
        ``NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer``;
    * W4A8_MXFP4_FP8 / MXFP4: element ``[1]`` of ``MXFP4QTensor.quantize`` (presumably the
      per-block scale — confirm), reshaped to ``(*weight.shape[:-1], -1)``;
    * MXFP8: ``MXFP8QTensor.get_weights_scaling_factor_from_quantizer``;
    * anything else: ``get_scaling_factor(weight_quantizer)``.
    """
    # module.weight_quantizer could be a TensorQuantizer (for algorithms except W4A8) or
    # a SequentialQuantizer (for W4A8). In the latter case, we need to get the scaling factor from the
    # first quantizer of the SequentialQuantizer instance.
    weight: nn.Parameter = getattr(module, weight_name)
    weight_quantizer: TensorQuantizer | SequentialQuantizer | None = getattr(
        module, quantizer_attr_names(weight_name).weight_quantizer, None
    )
    if weight_quantizer is None:
        return None

    if isinstance(weight_quantizer, SequentialQuantizer):
        return get_scaling_factor(weight_quantizer[0])

    quantization_format = get_quantization_format(module)

    if quantization_format in [
        QUANTIZATION_NVFP4,
        QUANTIZATION_NVFP4_AWQ,
        QUANTIZATION_NVFP4_SVDQUANT,
        QUANTIZATION_W4A8_NVFP4_FP8,
    ]:
        # Calibrate weight quantizer if amax is not set
        module_name = f"{type(module).__name__}.{weight_name}"
        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)

        if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
            # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
            # This is because the kernel dequantizes weight to fp8, which is in range 448.
            weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
        else:
            weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
                weight_quantizer
            )

        # Unified method handles both static and dynamic quantizers
        return NVFP4QTensor.get_weights_scaling_factor_from_quantizer(
            weight_quantizer,
            weight,
            weight_scaling_factor_2.to(weight.device),
        )[0]

    if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
        # Block size comes from the last-dimension entry of the quantizer's block_sizes.
        return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
            1
        ].reshape(*weight.shape[:-1], -1)

    if quantization_format == QUANTIZATION_MXFP8:
        return MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, weight_quantizer)

    return get_scaling_factor(weight_quantizer)
def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") -> torch.Tensor:
    """Return the secondary (second-level) weight scaling factor of ``module``.

    * NVFP4 / NVFP4_AWQ / NVFP4_SVDQUANT / W4A8_NVFP4_FP8: the weight quantizer is lazily
      calibrated from the weight if needed. W4A8_NVFP4_FP8 then returns
      ``_amax.float() / 448.0``; the other formats use
      ``NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer``.
    * Otherwise: the scaling factor of the last stage of a two-stage
      ``SequentialQuantizer``.

    Returns ``None`` when:

    * there is no weight quantizer; or
    * (outside the NVFP4 family) it is not a ``SequentialQuantizer``, or its last stage is
      disabled.

    Raises:
        AssertionError: if a ``SequentialQuantizer`` does not have exactly 2 stages.
    """
    quantizer = getattr(module, quantizer_attr_names(weight_name).weight_quantizer, None)
    if quantizer is None:
        return None

    fmt = get_quantization_format(module)
    nvfp4_family = (
        QUANTIZATION_NVFP4,
        QUANTIZATION_NVFP4_AWQ,
        QUANTIZATION_NVFP4_SVDQUANT,
        QUANTIZATION_W4A8_NVFP4_FP8,
    )
    if fmt in nvfp4_family:
        # Calibrate weight quantizer if amax is not set
        _ensure_weight_quantizer_calibrated(
            quantizer, getattr(module, weight_name), f"{type(module).__name__}.{weight_name}"
        )
        if fmt != QUANTIZATION_W4A8_NVFP4_FP8:
            # Unified method handles both static and dynamic quantizers
            return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(quantizer)
        # The W4A8 kernel dequantizes weights to fp8 (range 448), so the second-level scale is
        # amax / 448, keeping the per-block scale within 448 / 6.
        return quantizer._amax.float() / 448.0

    # Outside the NVFP4 family, a second-level scale only exists for an enabled
    # two-stage SequentialQuantizer.
    if isinstance(quantizer, SequentialQuantizer) and quantizer[-1].is_enabled:
        assert len(quantizer) == 2, (
            "modelopt only supports 2 sequential quantization layers for now"
        )
        return get_scaling_factor(quantizer[-1])
    return None
def get_prequant_scaling_factor(module: nn.Module) -> torch.Tensor:
    """Return ``module.input_quantizer._pre_quant_scale`` with singleton dims squeezed.

    Returns ``None`` when the module has no ``input_quantizer`` attribute, or the input
    quantizer has no ``_pre_quant_scale`` attribute.

    Raises:
        AssertionError: if any element of the scale is not positive.
    """
    if not hasattr(module, "input_quantizer"):
        return None
    input_quantizer = module.input_quantizer
    if not hasattr(input_quantizer, "_pre_quant_scale"):
        return None

    scale = input_quantizer._pre_quant_scale.squeeze()
    assert torch.all(scale > 0), f"prequant scaling factor {scale} not positive."
    return scale
def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]:
    """Return ``[k_bias, v_bias]`` read from the K/V BMM quantizers of ``kv_module``.

    Each entry is the quantizer's ``_bias_value``. An entry is ``None`` when the
    quantizer attribute is missing or the quantizer has no ``_bias_value``. The list
    itself always has two entries.
    """
    return [
        getattr(getattr(kv_module, quantizer_name, None), "_bias_value", None)
        for quantizer_name in ("k_bmm_quantizer", "v_bmm_quantizer")
    ]
def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch.Tensor | None]:
    """Return ``[k_scale, v_scale]`` for the K and V BMM quantizers of an attention module.

    Returns ``[None, None]`` unless the module has both ``k_bmm_quantizer`` and
    ``v_bmm_quantizer``. Each scale comes from ``get_scaling_factor`` and may itself be
    ``None``.

    When the KV cache dtype is FP8, every non-None scale is raised to at least 1.0
    (``torch.max`` against a one-element tensor). A warning is issued for scales above 0.5.
    """
    quantizer_names = ("k_bmm_quantizer", "v_bmm_quantizer")
    if not all(hasattr(self_attention_module, name) for name in quantizer_names):
        return [None, None]

    factors = [get_scaling_factor(getattr(self_attention_module, name)) for name in quantizer_names]
    if get_kv_cache_dtype(self_attention_module) != KV_CACHE_FP8:
        return factors

    # For FP8, we recommend default kv cache scaling factor to be 1.
    for idx, factor in enumerate(factors):
        if factor is None:
            continue
        if factor.item() > 0.5:
            warn(
                f"Warning: Large KV activation detected: {factor.item()}, "
                "Quantized KV cache may lead to higher accuracy drop."
            )
        floor = torch.tensor([1.0], dtype=torch.float, device=factor.device)
        factors[idx] = torch.max(factor, floor)
    return factors
def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None:
    """Infer the KV-cache dtype from the enabled KV quantizers of one or more modules.

    For every module, the ``k_bmm_quantizer``, ``v_bmm_quantizer`` and ``output_quantizer``
    attributes are inspected (the last one for the unified_megatron_export path). Each
    quantizer that is truthy and enabled contributes its ``num_bits``. The mode is
    considered affine only if every such quantizer has a ``_bias_value`` attribute.

    The result comes from ``_compute_kv_cache_dtype``, checked in this order:

    * ``KV_CACHE_FP8`` for ``(4, 3)``;
    * ``KV_CACHE_INT8`` for ``8``;
    * ``KV_CACHE_NVFP4_AFFINE`` / ``KV_CACHE_NVFP4`` for ``(2, 1)``;
    * otherwise ``QUANTIZATION_NONE``.

    Args:
        modules: The module or list of modules to inspect.
    """
    module_list = [modules] if isinstance(modules, nn.Module) else modules

    enabled_quantizers = []
    for module in module_list:
        for attr in ("k_bmm_quantizer", "v_bmm_quantizer", "output_quantizer"):
            quantizer = getattr(module, attr, None)
            if quantizer and quantizer.is_enabled:
                enabled_quantizers.append(quantizer)

    num_bits_list = [quantizer.num_bits for quantizer in enabled_quantizers]
    is_affine = all(hasattr(quantizer, "_bias_value") for quantizer in enabled_quantizers)
    return _compute_kv_cache_dtype(num_bits_list, is_affine)
def _compute_kv_cache_dtype(
    num_bits_list: list[int | tuple[int, int]], is_affine: bool = False
) -> str | None:
    """Map a collection of quantizer ``num_bits`` values to a KV-cache dtype.

    The first matching rule wins:

    1. any ``(4, 3)`` -> ``KV_CACHE_FP8``;
    2. any ``8`` -> ``KV_CACHE_INT8``;
    3. any ``(2, 1)`` -> ``KV_CACHE_NVFP4_AFFINE`` if ``is_affine`` else ``KV_CACHE_NVFP4``;
    4. otherwise -> ``QUANTIZATION_NONE``.

    Args:
        num_bits_list: The list of num_bits from quantizers.
        is_affine: Whether the quantizers have bias (affine mode).
    """
    if (4, 3) in num_bits_list:
        return KV_CACHE_FP8
    if 8 in num_bits_list:
        return KV_CACHE_INT8
    if (2, 1) in num_bits_list:
        return KV_CACHE_NVFP4_AFFINE if is_affine else KV_CACHE_NVFP4
    return QUANTIZATION_NONE
def get_weight_block_size(module: nn.Module, weight_name: str = "weight") -> int:
    """Return the last-dimension block size of ``module``'s weight quantizer.

    For a ``SequentialQuantizer`` only its first stage is inspected. Returns 0 when:

    * there is no weight quantizer;
    * the (inspected) quantizer is disabled; or
    * its ``block_sizes`` is empty or ``None``.
    """
    quantizer = getattr(module, quantizer_attr_names(weight_name).weight_quantizer, None)
    if quantizer is None:
        return 0

    if isinstance(quantizer, SequentialQuantizer):
        quantizer = quantizer[0]
    if not quantizer.is_enabled:
        return 0

    block_sizes = quantizer.block_sizes
    return block_sizes[-1] if block_sizes else 0
def get_quantization_format(module) -> str | None:
    """Gets the quantization string.

    The search order is:

    1. every weight of ``module`` itself (names from ``weight_attr_names``);
    2. recursively, each direct child in ``named_children`` order.

    The first result that is not ``QUANTIZATION_NONE`` is returned. If nothing is
    quantized, ``QUANTIZATION_NONE`` is returned.

    Raises:
        AssertionError: on inconsistent quantizer configurations (see the asserts below).
        NotImplementedError: for an enabled weight quantizer whose ``num_bits`` (or FP4
            ``scale_bits``) is not recognized.
    """

    def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames):
        """Classify one weight of ``layer`` from its weight and input quantizers.

        Note: the parameter name shadows the module-level ``quantizer_attr_names`` helper
        inside this function only.
        """
        weight_quantizer = getattr(layer, quantizer_attr_names.weight_quantizer, None)
        input_quantizer = getattr(layer, quantizer_attr_names.input_quantizer, None)

        if weight_quantizer is None or not weight_quantizer.is_enabled:
            return QUANTIZATION_NONE

        # Handle SequentialQuantizer
        # Only INT4 (block-quantized) followed by FP8 is accepted, reported as W4A8 AWQ.
        if isinstance(weight_quantizer, SequentialQuantizer):
            assert (
                len(weight_quantizer) == 2
                and weight_quantizer[0].num_bits == 4
                and weight_quantizer[1].num_bits == (4, 3)
            ), "Unsupported SequentialQuantizer configuration"
            assert (
                weight_quantizer[0].block_sizes
                and len(weight_quantizer[0].block_sizes) > 0
                and weight_quantizer[0].block_sizes[-1] > 0
            ), "Invalid block_sizes for SequentialQuantizer"
            return QUANTIZATION_W4A8_AWQ

        # Handle individual num_bits cases
        if weight_quantizer.num_bits == 4:
            assert len(weight_quantizer.block_sizes) > 0 and weight_quantizer.block_sizes[-1] > 0, (
                "Invalid block_sizes for INT4 quantizer"
            )
            return QUANTIZATION_INT4_AWQ

        # INT8: SmoothQuant if activations are also quantized, otherwise weight-only.
        if weight_quantizer.num_bits == 8:
            if input_quantizer is not None and input_quantizer.is_enabled:
                return QUANTIZATION_INT8_SQ
            else:
                return QUANTIZATION_INT8_WO

        # FP8 (E4M3): block-wise (MXFP8 or per-block), per-channel (axis 0) or per-tensor.
        if weight_quantizer.num_bits == (4, 3):
            if weight_quantizer.block_sizes:
                assert weight_quantizer.block_sizes[-1] > 0, "Invalid block_sizes for FP8 quantizer"
                # Check if this is MXFP8 (dynamic block quantization with scale_bits (8, 0))
                block_sizes = getattr(weight_quantizer, "block_sizes")
                if (
                    isinstance(block_sizes, dict)
                    and block_sizes.get("type", "static") == "dynamic"
                    and block_sizes.get("scale_bits") == (8, 0)
                ):
                    return QUANTIZATION_MXFP8
                if weight_quantizer.fake_quant:
                    return QUANTIZATION_FP8_PB_WO
                else:
                    return QUANTIZATION_FP8_PB_REAL
            if weight_quantizer.axis == 0:
                return QUANTIZATION_FP8_PC_PT
            return QUANTIZATION_FP8

        # FP4 (E2M1): the variant is decided by, in order:
        #   * SVDQuant LoRA factors on the weight quantizer -> NVFP4_SVDQUANT;
        #   * a pre-quant scale (or fused_with_prequant) -> NVFP4_AWQ;
        #   * dynamic block scales + FP8 per-tensor activations -> W4A8 formats;
        #   * otherwise the scale format: (4, 3) -> NVFP4, (8, 0) -> MXFP4.
        if weight_quantizer.num_bits == (2, 1):
            # FP4 formats are all block quantization
            # NOTE(review): block_sizes is assumed to be a dict here; `.get` on None would
            # raise AttributeError.
            block_sizes = getattr(weight_quantizer, "block_sizes")
            scale_bits = block_sizes.get("scale_bits")

            if input_quantizer is not None and hasattr(weight_quantizer, "svdquant_lora_a"):
                return QUANTIZATION_NVFP4_SVDQUANT
            if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
                return QUANTIZATION_NVFP4_AWQ
            if getattr(layer, "fused_with_prequant", False):
                return QUANTIZATION_NVFP4_AWQ
            assert input_quantizer is not None, (
                f"input_quantizer is None for {quantizer_attr_names}"
            )
            if (
                block_sizes.get("type", "static") == "dynamic"
                and scale_bits == (8, 0)
                and input_quantizer.is_enabled
                and input_quantizer.num_bits == (4, 3)
                and input_quantizer.block_sizes is None
            ):
                return QUANTIZATION_W4A8_MXFP4_FP8
            if (
                block_sizes.get("type", "static") == "dynamic"
                and scale_bits == (4, 3)
                and input_quantizer.is_enabled
                and input_quantizer.num_bits == (4, 3)
                and input_quantizer.block_sizes is None
            ):
                return QUANTIZATION_W4A8_NVFP4_FP8
            if scale_bits == (4, 3):
                return QUANTIZATION_NVFP4
            elif scale_bits == (8, 0):
                return QUANTIZATION_MXFP4
            # Any other FP4 scale_bits falls through to the error below.

        # Raise error for unsupported num_bits
        raise NotImplementedError(
            f"Unsupported quantizer with num_bits: {weight_quantizer.num_bits}"
        )

    for weight_name in weight_attr_names(module):
        quantization = _get_quantization_from_layer(module, quantizer_attr_names(weight_name))
        if quantization != QUANTIZATION_NONE:
            return quantization

    # Nothing quantized on this module's own weights: recurse into children.
    for _, layer in module.named_children():
        format = get_quantization_format(layer)
        if format != QUANTIZATION_NONE:
            return format

    return QUANTIZATION_NONE
def _prefix_wildcard_summarize_exclude_modules(unquantized_layers, quantized_layers):
    """Generate a summarization of the quantization layer configs using prefix wildcards.

    Prefix wildcards are patterns whose only wildcard is a trailing star (e.g.
    ``model.layers.0*``). Other wildcards such as ``a*b`` are not considered.

    How it works:

    1. Every prefix wildcard of every quantized layer name, plus the exact name, is
       "negative". A negative pattern would also match a quantized layer, so it may
       never be emitted.
    2. For each unquantized layer, candidates are tried from broadest (cut at the first
       ``.``) to most specific (the exact name). Candidates containing a negative pattern
       are skipped. The first remaining candidate is added to the result, unless all of
       its patterns are already present.
    3. A pair candidate ``[prefix, prefix.*]`` is emitted as ``prefix.*`` only.

    Args:
        unquantized_layers: Iterable of layer names to be covered by the summary.
        quantized_layers: Iterable of quantized layer names that must not be matched.

    Returns:
        A set of pattern strings.
    """

    def all_matching_prefix_wildcards(name):
        # include all possible prefix wildcards (from "*" up to name + "*"), and the exact name itself
        wildcards = {name}
        for i in range(len(name) + 1):
            wildcards.add(name[:i] + "*")
        return wildcards

    def next_formatted_matching_prefix_wildcards(name: str) -> Generator[list[str], None, None]:
        """Enumerate formatted prefix wildcards. A result may be a combination of prefix wildcards.

        Formatted here means we only consider wildcards at dot split, from the first dot to
        the last. For each dot, two candidates are produced:

        1. a single wildcard: module_name*
        2. a set of 2 wildcards: {module_name, module_name.*}. This set is needed because
           module_name* may also match other modules that merely share module_name as a
           string prefix (e.g. layers.1* also matches layers.10).

        The exact name is yielded last as the most specific candidate.
        """
        for i in range(len(name)):
            if name[i] == ".":
                yield [name[:i] + "*"]
                yield [name[:i], name[:i] + ".*"]
        # in the end, itself only is a wildcard
        yield [name]

    # None of the wildcards in this set may appear in the result.
    negative_wild_candidates = set()
    for layer in quantized_layers:
        negative = all_matching_prefix_wildcards(layer)
        negative_wild_candidates.update(negative)
        logger.debug(
            f"Quantized layer {layer}, prefix wildcards {negative} identified as negative wildcards"
        )

    res_summary = set()
    for layer in unquantized_layers:
        candidate_wildcards = []
        for wildcards in next_formatted_matching_prefix_wildcards(layer):
            if any(wildcard in negative_wild_candidates for wildcard in wildcards):
                # would match a quantized layer; try a more specific wildcard
                logger.debug(
                    f"Unquantized layer {layer}, prefix wildcards {wildcards} invalidated by negative wildcards"
                )
                continue
            if all(wildcard in res_summary for wildcard in wildcards):
                # already covered by the result; nothing to add for this layer
                logger.debug(
                    f"Unquantized layer {layer}, prefix wildcards {wildcards} already covered"
                )
                candidate_wildcards = []
                break
            # found the broadest valid candidate; stop searching
            candidate_wildcards = wildcards
            logger.debug(
                f"Unquantized layer {layer}, prefix wildcards {wildcards} identified as a new match"
            )
            break
        # When candidate is the pair [prefix, prefix+".*"], emit only prefix+".*" for deployment.
        # NOTE(review): since only prefix+".*" is stored, a later layer under the same prefix
        # does not pass the "already covered" check for the pair and re-adds the same
        # pattern (a no-op on the set).
        if len(candidate_wildcards) == 2:
            a, b = sorted(candidate_wildcards, key=len)
            if b == a + ".*":
                res_summary.add(b)
            else:
                res_summary.update(candidate_wildcards)
        else:
            res_summary.update(candidate_wildcards)

    return res_summary
def process_layer_quant_config(layer_config_dict):
    """Build the TRTLLM ``quant_cfg.json`` content from flattened per-layer quantization info.

    ``layer_config_dict`` maps ``"<layer>.<field>"`` keys to values:

    * Keys containing ``"awq_block_size"`` only supply the group size for their layer
      (looked up as ``"<layer>.awq_block_size"``, default 0).
    * Every other value is a quantization name that is translated into a TRTLLM per-layer
      config dict. Unrecognized names are passed through as ``quant_algo``.

    Returns:
        A dict with keys:

        * ``quant_algo``: ``None`` when nothing is quantized, ``"MIXED_PRECISION"`` when
          quantized layers use more than one distinct config. With exactly one distinct
          config, that config's fields are merged in instead.
        * ``kv_cache_quant_algo``: always ``None`` here.
        * ``exclude_modules``: a sorted list of prefix-wildcard patterns covering the layers
          whose ``quant_algo`` equals ``QUANTIZATION_NONE``.
    """
    # Quantization names that map to a bare quant_algo.
    plain_algos = (
        ("fp8", "FP8"),
        ("fp8_pc_pt", "FP8_PER_CHANNEL_PER_TOKEN"),
        ("int8_wo", "W8A16"),
        ("int8_sq", "W8A8_SQ_PER_CHANNEL"),
    )
    # Names that carry a group size plus the AWQ-specific fields.
    awq_algos = (
        ("int4_awq", "W4A16_AWQ"),
        ("w4a8_awq", "W4A8_AWQ"),
        ("nvfp4_awq", "NVFP4_AWQ"),
    )
    # Names that carry only a group size.
    grouped_algos = (
        ("nvfp4", "NVFP4"),
        ("nvfp4_static", "NVFP4"),
        ("w4a8_nvfp4_fp8", "W4A8_NVFP4_FP8"),
        ("w4a8_mxfp4_fp8", "W4A8_MXFP4_FP8"),
        ("nvfp4_svdquant", "NVFP4_SVD"),
        ("mxfp8", "MXFP8"),
    )

    def _find(table, quant_name):
        """Return the quant_algo paired with ``quant_name`` in ``table`` (by ==), else None."""
        for name, algo in table:
            if quant_name == name:
                return algo
        return None

    def _to_layer_config(quant_name, group_size):
        """Translate one quantization name into its per-layer TRTLLM config dict."""
        if (algo := _find(plain_algos, quant_name)) is not None:
            return {"quant_algo": algo}
        if (algo := _find(awq_algos, quant_name)) is not None:
            return {
                "quant_algo": algo,
                "group_size": group_size,
                "has_zero_point": False,
                "pre_quant_scale": True,
            }
        if (algo := _find(grouped_algos, quant_name)) is not None:
            return {"quant_algo": algo, "group_size": group_size}
        return {"quant_algo": quant_name}

    quantized_layers: dict[str, Any] = {}
    distinct_formats = set()
    last_quantized_config = None
    exclude_modules = []

    for key, quant_name in layer_config_dict.items():
        if "awq_block_size" in key:
            continue
        # Layer name = key without its last dotted component.
        layer_name = ".".join(key.rsplit(".", 1)[:-1])
        group_size = layer_config_dict.get(layer_name + ".awq_block_size", 0)
        layer_config = _to_layer_config(quant_name, group_size)

        if layer_config["quant_algo"] != QUANTIZATION_NONE:
            # str() makes the dict hashable so identical configs collapse in the set.
            distinct_formats.add(str(layer_config))
            last_quantized_config = layer_config
            quantized_layers[layer_name] = layer_config
        else:
            exclude_modules.append(layer_name)

    per_layer_config: dict[str, Any] = {"quant_algo": None, "kv_cache_quant_algo": None}
    if len(distinct_formats) > 1:
        per_layer_config["quant_algo"] = "MIXED_PRECISION"
    elif len(distinct_formats) == 1 and last_quantized_config is not None:
        per_layer_config.update(last_quantized_config)

    per_layer_config["exclude_modules"] = sorted(
        _prefix_wildcard_summarize_exclude_modules(exclude_modules, quantized_layers.keys())
    )
    return per_layer_config
def pack_int4_in_uint8(weight, weights_scaling_factor):
    """Quantize weights to signed INT4 and pack pairs of output rows into a uint8 tensor.

    Each weight is divided by the scaling factor of its input-column group, rounded and
    clamped to the signed 4-bit range [-8, 7]. Two consecutive rows along the output
    dimension (``dim=-2``) are then packed into one byte: row ``2j`` goes into the low
    nibble and row ``2j + 1`` into the high nibble.

    Args:
        weight: Weight tensor of shape ``(*batch, out_dim, in_dim)``. 2D is a plain linear
            weight, 3D a stacked MoE weight ``(experts, out_dim, in_dim)``; any number of
            leading batch dims is accepted.
        weights_scaling_factor: Per-group scales whose last dim holds one entry per group
            of ``in_dim`` columns; its leading dims must broadcast against ``weight``'s.

    Returns:
        A ``torch.uint8`` tensor of shape ``(*batch, out_dim // 2, in_dim)``.

    Raises:
        AssertionError: if ``out_dim`` is odd, or ``in_dim`` is not an exact multiple of the
            number of scale groups.
    """
    out_dim = weight.shape[-2]
    assert out_dim % 2 == 0, f"Cannot pack weight. Out dimension {out_dim} is not an even number."
    in_dim = weight.shape[-1]
    num_groups = weights_scaling_factor.shape[-1]
    # The block size is inferred from the scale shape, so it is only well defined when the
    # groups tile in_dim exactly; otherwise columns would map to the wrong scale group or
    # index past the last one.
    assert num_groups > 0 and in_dim % num_groups == 0, (
        f"Cannot pack weight. In dimension {in_dim} is not divisible by the number of "
        f"scaling factor groups {num_groups}."
    )
    block_size = in_dim // num_groups

    # Expand each group scale across its block_size columns (stays on the scale's device),
    # then scale, round, and clamp to the signed 4-bit range [-8..7].
    expanded_scale = weights_scaling_factor.repeat_interleave(block_size, dim=-1)
    int8_tensor = (weight / expanded_scale).round().clamp(-8, 7).to(torch.int8)

    # Pair consecutive output rows: (*batch, out_dim, in_dim) -> (*batch, out_dim // 2, 2, in_dim).
    # Operating on dim -2 directly handles 2D, MoE (3D) and higher-rank weights without
    # any transposes.
    pairs = int8_tensor.reshape(*int8_tensor.shape[:-2], out_dim // 2, 2, in_dim)
    # Mask to 4 bits first so negative values do not sign-extend into the other nibble.
    low = pairs[..., 0, :] & 0x0F
    high = pairs[..., 1, :] & 0x0F
    packed_byte = low | (high << 4)
    # Reinterpret the int8 bit pattern as uint8 (same element size, no value conversion).
    return packed_byte.contiguous().view(torch.uint8)
def to_quantized_weight(
weight: torch.Tensor,
weights_scaling_factor: torch.Tensor,
quantization: str,
weights_scaling_factor2: torch.Tensor | None = None,
block_size: int | None = None,
):
"""Converts the weight to the quantized (packed) format."""
if weights_scaling_factor is not None:
weights_scaling_factor = weights_scaling_factor.to(weight.device)
if weights_scaling_factor2 is not None:
weights_scaling_factor2 = weights_scaling_factor2.to(weight.device)
# For compressed weights, we directly return the data from wrapper
if isinstance(weight, QTensorWrapper):
return weight.data
if quantization == QUANTIZATION_FP8:
# Fix RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Float
# in speculative decoding fp8 model export
if weight.dtype == torch.float8_e4m3fn:
warn("Skipping quantization: weight already in fp8 format")
return weight
if weight.dim() == 3:
# for MOE stacked weights
# Clear GPU cache to avoid pontential GPU OOM issues for large models.
clear_cuda_cache()
return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)
if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)
if quantization == QUANTIZATION_MXFP8:
return MXFP8QTensor.quantize_with_scale(weight, weights_scaling_factor)
if quantization == QUANTIZATION_FP8_PB_WO:
return FP8QTensor.quantize(
weight, weights_scaling_factor.squeeze(), block_sizes={-1: block_size, -2: block_size}
)[0]._quantized_data
if quantization == QUANTIZATION_FP8_PC_PT:
if weight.dim() == 3:
# Handle different scale tensor shapes
if weights_scaling_factor.dim() == 1:
# Per-expert scaling only: (num_experts,) -> (num_experts, 1, 1)
return (weight / weights_scaling_factor[:, None, None]).to(torch.float8_e4m3fn)
elif weights_scaling_factor.dim() == 2:
# Per-channel scaling: check which dimension matches
if weights_scaling_factor.shape[0] != weight.shape[0]:
raise ValueError(
f"First dimension (num_experts) mismatch for FP8_PC_PT quantization. "
f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}"
)
if weight.shape[-1] == weight.shape[-2]:
raise ValueError(
f"Ambiguous scaling dimension for FP8_PC_PT quantization with square weight matrix. "
f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}. "
f"Cannot determine if scaling should be applied to input_dim or output_dim."
)
if weights_scaling_factor.shape[-1] == weight.shape[-1]:
# (num_experts, input_dim) -> (num_experts, 1, input_dim), BMM-style
return (weight / weights_scaling_factor.unsqueeze(-2)).to(torch.float8_e4m3fn)
elif weights_scaling_factor.shape[-1] == weight.shape[-2]:
# (num_experts, output_dim) -> (num_experts, output_dim, 1), Standard MoE case
return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
else:
raise ValueError(
f"Cannot determine correct unsqueeze dimension for FP8_PC_PT quantization. "
f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}"
)
return (weight / weights_scaling_factor[:, None]).to(torch.float8_e4m3fn)
if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
return pack_int4_in_uint8(weight, weights_scaling_factor)
if quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
QUANTIZATION_NVFP4_SVDQUANT,
]:
assert block_size is not None, "Block size not passed. Unable to quantize to NVFP4 format."
assert weights_scaling_factor2 is not None, (
"Weights scaling factor 2 not passed. Unable to quantize to NVFP4 format"
)
# If MoE reshape weights_scaling_factor2 to enable quantize operations
return NVFP4QTensor.quantize(
weight,
block_size,
weights_scaling_factor,
weights_scaling_factor2.view(-1, 1, 1)
if weights_scaling_factor2.dim() != 0
else weights_scaling_factor2,
)[0]._quantized_data
if quantization in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]: