forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhf_ptq.py
More file actions
executable file
·1281 lines (1146 loc) · 48.8 KB
/
hf_ptq.py
File metadata and controls
executable file
·1281 lines (1146 loc) · 48.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import random
import time
import warnings
from typing import Any
import numpy as np
import torch
from accelerate.hooks import remove_hook_from_module
from example_utils import (
build_quant_cfg,
copy_custom_model_files,
create_vlm_calibration_loop,
get_model,
get_processor,
get_tokenizer,
is_enc_dec,
is_nemotron_vl,
load_mtp_weights,
run_nemotron_vl_preview,
)
from torch.utils.data import DataLoader
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoProcessor,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
ProcessorMixin,
WhisperProcessor,
)
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.sparsity as mts
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
from modelopt.torch.export import (
export_hf_checkpoint,
export_speculative_decoding,
export_tensorrt_llm_checkpoint,
get_model_type,
has_spec_opt,
save_expert_token_count_table,
)
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.utils.dataset_utils import (
create_forward_loop,
get_dataset_dataloader,
get_max_batch_size,
get_supported_datasets,
)
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
RAND_SEED = 1234
def _set_kv_cache_constant_amax(quant_cfg: dict) -> None:
"""Set use_constant_amax on KV cache quantizers.
Creates a new dict for the KV bmm quantizer config to avoid mutating shared references.
"""
if "*[kv]_bmm_quantizer" in quant_cfg:
quant_cfg["*[kv]_bmm_quantizer"] = {
**quant_cfg["*[kv]_bmm_quantizer"],
"use_constant_amax": True,
}
# --qformat name -> ModelOpt quantization config. Each config is a dict with at least a
# "quant_cfg" (pattern -> quantizer attributes) and an "algorithm" entry, as consumed by
# mtq.quantize / mtq.calibrate in this script. These are shared module-level objects:
# copy them before editing in place.
QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
    "int8": mtq.INT8_DEFAULT_CFG,
    "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
    "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
    "fp8": mtq.FP8_DEFAULT_CFG,
    "int4_awq": mtq.INT4_AWQ_CFG,
    "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
    "nvfp4": mtq.NVFP4_DEFAULT_CFG,
    "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
    "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
    "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
    "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
    "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
    "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
    "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
    "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
    "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG,
    "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
    "mxfp8": mtq.MXFP8_DEFAULT_CFG,
    "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
}

# --kv_cache_qformat name -> name of the ``mtq`` attribute holding the KV-cache config.
# The attribute is resolved at use time with ``getattr(mtq, ...)``; "none" is a sentinel
# that callers in this script check for before doing that lookup.
# "*_cast" aliases map to the same config as their base format; they differ only in
# using a constant amax (see _KV_CAST_FORMATS).
KV_QUANT_CFG_CHOICES: dict[str, str] = {
    "none": "none",
    "fp8_cast": "FP8_KV_CFG",
    "fp8": "FP8_KV_CFG",
    "fp8_affine": "FP8_AFFINE_KV_CFG",
    "nvfp4_cast": "NVFP4_KV_CFG",
    "nvfp4": "NVFP4_KV_CFG",
    "nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
    "nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
}

# Formats that use use_constant_amax (no calibration needed). Immutable on purpose: it is
# a module-level constant that is only ever used for membership tests.
_KV_CAST_FORMATS: frozenset[str] = frozenset({"fp8_cast", "nvfp4_cast"})

# Module-level side effect: runs once when this script is imported/executed.
mto.enable_huggingface_checkpointing()
def extract_and_prepare_language_model_from_vl(full_model):
    """Pick out the language model of a VLM and fence its siblings off from quantization.

    ``get_language_model_from_vl`` yields the module lineage that ends in the language
    model. Every direct child of an ancestor in that lineage that is neither an ancestor
    nor the language model itself is passed through ``mtq.quantize`` with a config whose
    only pattern (``"default"``) disables quantizers, so those modules are excluded
    during HF export.

    Args:
        full_model: The full VLM model.

    Returns:
        tuple: ``(language_model, model_type)``, or ``(None, None)`` if not a VLM.
    """
    lineage = get_language_model_from_vl(full_model)
    if lineage is None:
        return None, None

    # Last lineage entry is the language model; everything before it is an ancestor.
    language_model = lineage[-1]
    ancestors = lineage[:-1]

    disabled_quant_cfg = {
        "quant_cfg": {"default": {"enable": False}},
        "algorithm": "max",
    }
    # Never touch the lineage itself; only its side branches get disabled quantizers.
    seen = {*ancestors, language_model}
    for ancestor in ancestors:
        for _, child in ancestor.named_children():
            if child in seen:
                continue
            mtq.quantize(child, disabled_quant_cfg, forward_loop=None)
            seen.add(child)

    return language_model, get_model_type(language_model)
def make_calib_dataloader(
    args: argparse.Namespace,
    language_model: torch.nn.Module,
    processor: BaseImageProcessor | ProcessorMixin | None,
    tokenizer: PreTrainedTokenizerBase | None,
    device: torch.device,
    model_type: str | None,
) -> tuple[DataLoader, str | None]:
    """Build the calibration dataloader that matches the model family.

    Branches, checked in this order:
      * ``args.calib_with_images``: image-text samples from ``nemotron_vlm_dataset_v2``
        built with ``processor`` (``max_length=args.calib_seq``);
      * ``model_type == "mllama"``: VLM dataset (default ``scienceqa``) via an
        ``MllamaImageProcessor``;
      * ``model_type == "whisper"``: speech dataset (default ``peoples_speech``) via a
        ``WhisperProcessor``, cast to ``language_model.dtype``;
      * otherwise: the text datasets in ``args.dataset`` tokenized with ``tokenizer``
        using ``max_sample_length=args.calib_seq``. Labels are included only for
        gradient-based auto_quantize.

    Returns:
        ``(calib_dataloader, first_text_speech_dataset)``. The second item is whatever
        ``get_speech_dataset_dataloader`` returns alongside the loader on the Whisper
        path, and ``None`` on every other path.

    Raises:
        AssertionError: if the processor/tokenizer required by the chosen branch is
            missing, or a single-dataset branch receives several ``calib_size`` values.
    """
    first_text_speech_dataset = None
    if args.calib_with_images:
        # VLM image-text calibration path: assume Nemotron VLM dataset by default.
        assert processor is not None, (
            "Please provide a processor (e.g., AutoProcessor) for image calibration."
        )
        assert len(args.calib_size) == 1, (
            "Image calibration currently supports a single dataset. "
            "Please pass --calib_size with one value (e.g., --calib_size 256)."
        )
        calib_dataloader = get_vlm_dataset_dataloader(
            dataset_name="nemotron_vlm_dataset_v2",
            processor=processor,
            batch_size=args.batch_size,
            num_samples=args.calib_size[0],
            device=device,
            max_length=args.calib_seq,
            require_image=True,
            subsets=["sparsetables", "plotqa_cot", "wiki_en"],
            shuffle_buffer_size=10_000,
            seed=42,
            use_media_shards=True,
            max_shards=1,
        )
    elif model_type == "mllama":
        assert processor is not None and isinstance(processor, MllamaImageProcessor), (
            "The MllamaImageProcessor must be set."
        )
        assert len(args.calib_size) == 1, (
            "mllama only supports one dataset for calibration, can extend this in the future"
        )
        calib_dataloader = get_vlm_dataset_dataloader(
            dataset_name=args.dataset[0] if args.dataset else "scienceqa",
            processor=processor,
            batch_size=args.batch_size,
            num_samples=args.calib_size[0],
        )
    elif model_type == "whisper":
        assert processor is not None and isinstance(processor, WhisperProcessor), (
            "The WhisperProcessor must be set."
        )
        assert len(args.calib_size) == 1, (
            "whisper only supports one dataset for calibration, can extend this in the future"
        )
        calib_dataloader, first_text_speech_dataset = get_speech_dataset_dataloader(
            dataset_name=args.dataset[0] if args.dataset else "peoples_speech",
            processor=processor,
            batch_size=args.batch_size,
            num_samples=args.calib_size[0],
            device=device,
            dtype=language_model.dtype,
        )
    else:
        assert tokenizer is not None and isinstance(
            tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
        ), "The PreTrainedTokenizer must be set"
        # Labels are only needed for gradient-based auto_quantize
        include_labels = (
            args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
        )
        calib_dataloader = get_dataset_dataloader(
            dataset_name=args.dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_samples=args.calib_size,
            # Honour --calib_seq like the image path and sparsity_main do; the automatic
            # batch size is also probed with samples of this length.
            max_sample_length=args.calib_seq,
            device=device,
            include_labels=include_labels,
        )
    return calib_dataloader, first_text_speech_dataset
def auto_quantize(
    args: argparse.Namespace,
    language_model: torch.nn.Module,
    calib_dataloader: DataLoader,
    auto_quantize_method="gradient",
    auto_quantize_score_size=128,
    auto_quantize_checkpoint=None,
):
    """Auto search quantization of multiple formats.

    Runs ``mtq.auto_quantize`` over the candidate formats listed in the comma-separated
    ``args.qformat``, targeting ``args.auto_quantize_bits`` effective bits. Whitespace
    around entries and empty entries are ignored. Afterwards, KV-cache quantization from
    ``args.kv_cache_qformat`` is applied on top. Non-cast KV formats are then
    max-calibrated while every other quantizer is temporarily disabled.

    Args:
        args: Parsed CLI arguments.
        language_model: Model to search over.
        calib_dataloader: Batches used for calibration and scoring.
        auto_quantize_method: ``"gradient"`` or ``"kl_div"``.
            ``"gradient"`` scores with ``output.loss``, so the batches presumably need
            labels (see ``include_labels`` in ``make_calib_dataloader``).
            ``"kl_div"`` compares logits.
        auto_quantize_score_size: Approximate number of samples used for scoring. It is
            converted to steps using ``args.batch_size``.
        auto_quantize_checkpoint: Forwarded to ``mtq.auto_quantize`` as ``checkpoint``.

    Returns:
        The searched language model with KV-cache quantizers configured.

    Raises:
        NotImplementedError: If image-text calibration was requested.
        AssertionError: If pipeline parallel > 1, or ``args.qformat`` is empty or
            contains unsupported formats.
        ValueError: If ``auto_quantize_method`` is unknown.
    """
    if args.calib_with_images:
        raise NotImplementedError(
            "AutoQuantize with image-text calibration is not supported yet. "
            "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images."
        )
    assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), (
        "Auto Quantization is not supported for pipeline parallel size > 1"
    )

    supported_qformats = (
        "fp8",
        "int8_sq",
        "int8_wo",
        "int4_awq",
        "nvfp4",
        "nvfp4_awq",
        "nvfp4_mse",
        "w4a8_awq",
        "fp8_pb_wo",
        "w4a8_mxfp4_fp8",
        "nvfp4_mlp_only",
        "nvfp4_experts_only",
        "nvfp4_omlp_only",
        "mxfp8",
    )
    # str.split never returns an empty list, so strip and drop blanks first; otherwise
    # an empty --qformat would slip past the "no formats" check below.
    qformat_list = [entry.strip() for entry in args.qformat.split(",") if entry.strip()]
    assert qformat_list, "No quantization formats provided"
    unsupported = [qformat for qformat in qformat_list if qformat not in supported_qformats]
    assert not unsupported, (
        "One or more quantization formats provided are not supported for unified "
        f"checkpoint export: {unsupported}"
    )

    def loss_func(output, data):
        # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
        # which contains the loss attribute.
        return output.loss

    if auto_quantize_method == "gradient":
        # For gradient-based method, return full output with loss
        def forward_step(model, batch):
            return model(**batch)

    elif auto_quantize_method == "kl_div":
        # For KL divergence method, return only logits
        def forward_step(model, batch):
            return model(**batch).logits

    else:
        raise ValueError(
            f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
        )

    language_model, _ = mtq.auto_quantize(
        language_model,
        constraints={"effective_bits": args.auto_quantize_bits},
        data_loader=calib_dataloader,
        forward_step=forward_step,
        loss_func=loss_func,  # Only used for gradient-based method
        # TRTLLM only support one quantization format or None (do not quantize, internally supported)
        quantization_formats=[QUANT_CFG_CHOICES[qformat] for qformat in qformat_list],
        num_calib_steps=len(calib_dataloader),
        # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration.
        num_score_steps=min(
            len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
        ),
        verbose=True,
        # Disable all default disabled layers such as lm_head, mlp.gate, router etc.
        disabled_layers=list(_default_disabled_quantizer_cfg.keys()),
        method=auto_quantize_method,
        checkpoint=auto_quantize_checkpoint,
    )

    # We need to explicitly set up KV cache quantization after auto_quantize
    enable_quant_kv_cache = args.kv_cache_qformat != "none"
    print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
    if enable_quant_kv_cache:
        # Deep copy: the mtq KV config is a shared module-level object.
        kv_cache_quant_cfg = copy.deepcopy(
            getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
        )
        kv_cache_quant_cfg.pop("default", None)  # keep other quantizers from auto_quantize
        is_cast_format = args.kv_cache_qformat in _KV_CAST_FORMATS
        if is_cast_format:
            _set_kv_cache_constant_amax(kv_cache_quant_cfg)
        mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
        if not is_cast_format:
            # Calibrate only the KV cache quantizers; disable all others.
            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
            with mtq.set_quantizer_by_cfg_context(
                language_model, {"*": {"enable": False}, **kv_cache_quant_cfg}
            ):
                mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)
    return language_model
def load_model(args: argparse.Namespace):
    """Load the checkpoint and prepare the model, processor/tokenizer and calibration args.

    Loading modes:
      * default: ``get_model`` loads the full model (``args.device``, memory limit,
        optional sequential device map and attention implementation);
      * ``args.low_memory_mode``: the model is loaded with ``AutoModelForCausalLM``
        inside ``init_quantized_weights`` using the ``args.qformat`` config (plus the
        KV-cache config, if any). ``calibration_only`` is returned as ``True`` so callers
        run ``mtq.calibrate`` instead of ``mtq.quantize`` on the already-quantized model.

    Afterwards, depending on the model family:
      * ``mllama`` / ``whisper``: a processor from ``get_processor``;
      * Nemotron VL (image-text calibration is forced on): an ``AutoProcessor`` and its
        tokenizer (or ``get_tokenizer`` as fallback);
      * otherwise: ``get_tokenizer``, with default calibration datasets if none were given.
    Tokenizers are switched to left padding. On the last two paths, VLMs are reduced to
    their language model via ``extract_and_prepare_language_model_from_vl``.

    Mutates ``args``: may set ``args.calib_with_images`` (Nemotron VL), and
    ``args.dataset`` / ``args.calib_size`` (text path without ``--dataset``).

    Returns:
        tuple ``(full_model, language_model, model_type, calibration_only, processor,
        tokenizer, default_padding_side, default_pad_token, device)``. The two
        ``default_*`` entries are the tokenizer settings captured before this function
        changed them (``None`` when no tokenizer was created).

    Raises:
        AssertionError: if ``args.qformat`` is not in QUANT_CFG_CHOICES in low memory
            mode, or no pad token can be set for a Nemotron VL tokenizer.
    """
    # If low memory mode is enabled, we compress the model while loading the HF checkpoint.
    calibration_only = False
    if not args.low_memory_mode:
        full_model = get_model(
            args.pyt_ckpt_path,
            args.device,
            gpu_mem_percentage=args.gpu_max_mem_percentage,
            trust_remote_code=args.trust_remote_code,
            use_seq_device_map=args.use_seq_device_map,
            attn_implementation=args.attn_implementation,
        )
    else:
        assert args.qformat in QUANT_CFG_CHOICES, (
            f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
        )
        # This is the shared QUANT_CFG_CHOICES entry; the cast-format branch below
        # deep-copies it before editing in place.
        quant_cfg = QUANT_CFG_CHOICES[args.qformat]
        if args.kv_cache_qformat != "none":
            quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
                quant_cfg,
                getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
            )
        # Mirror the use_constant_amax logic from quantize_main so that init_quantized_weights
        # builds the KV quantizers with use_constant_amax already set. In calibration_only mode
        # mtq.calibrate() does not re-apply quant_cfg, so this must happen before
        # init_quantized_weights runs.
        if args.kv_cache_qformat in _KV_CAST_FORMATS:
            quant_cfg = copy.deepcopy(quant_cfg)
            _set_kv_cache_constant_amax(quant_cfg["quant_cfg"])
        # Do not use real quant GEMM so the calibration can be more accurate.
        with init_quantized_weights(
            quant_cfg, gpu_mem_percentage=args.gpu_max_mem_percentage, quant_gemm=False
        ):
            model_kwargs = {"trust_remote_code": args.trust_remote_code}
            if args.attn_implementation is not None:
                model_kwargs["attn_implementation"] = args.attn_implementation
            full_model = AutoModelForCausalLM.from_pretrained(
                args.pyt_ckpt_path,
                **model_kwargs,
            )
        # Weights were quantized while loading; callers only need to calibrate.
        calibration_only = True

    model_type = get_model_type(full_model)

    # Prefer the inner `.model` device when the wrapper exposes one.
    device = full_model.device
    if hasattr(full_model, "model"):
        device = full_model.model.device

    processor = None
    tokenizer = None
    language_model = full_model
    default_padding_side = None
    default_pad_token = None

    is_nemotron_vl_model = is_nemotron_vl(full_model)

    # Default to image-text calibration for VLM models
    if is_nemotron_vl_model and not args.calib_with_images:
        print("Nemotron VL model detected. Enabling image-text calibration by default.")
        args.calib_with_images = True

    if model_type == "mllama":
        processor = get_processor(
            args.pyt_ckpt_path,
            model_type,
            device,
            trust_remote_code=args.trust_remote_code,
            attn_implementation=args.attn_implementation,
        )
    elif model_type == "whisper":
        processor = get_processor(
            args.pyt_ckpt_path,
            model_type,
            device,
            trust_remote_code=args.trust_remote_code,
        )
    elif is_nemotron_vl_model and args.calib_with_images:
        # For Nemotron VL image calibration, we need an AutoProcessor to build multimodal inputs.
        processor = AutoProcessor.from_pretrained(
            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code, padding_side="left"
        )

        if hasattr(processor, "tokenizer") and processor.tokenizer is not None:
            tokenizer = processor.tokenizer
        else:
            tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

        # Captured before the eos fallback below, so the original (possibly None) value
        # is what gets returned for later restoration.
        default_pad_token = tokenizer.pad_token
        # Some Nemotron tokenizers may not define pad_token by default; but we use padding=True during calibration.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        assert tokenizer.pad_token is not None, f"Pad token for {args.pyt_ckpt_path} cannot be set!"

        default_padding_side = tokenizer.padding_side
        tokenizer.padding_side = "left"

        # Quantize only the language model, but keep the full_model for calibration forward.
        extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
        if extracted_lm is not None:
            language_model = extracted_lm
            model_type = extracted_model_type
    else:
        if args.dataset is None:
            args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
            warnings.warn(
                "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
            )
            # Adjust calib_size to match dataset length by extending or truncating as needed
            # (pads with the last given size, then truncates to one size per dataset).
            args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
                : len(args.dataset)
            ]
        tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

        default_padding_side = tokenizer.padding_side
        default_pad_token = tokenizer.pad_token
        # Left padding usually provides better calibration result.
        tokenizer.padding_side = "left"

        # We only quantize the language model for VLMs other than the type supported above.
        extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
        if extracted_lm is not None:
            language_model = extracted_lm
            model_type = extracted_model_type

    if model_type == "phi4mm":
        warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")

    return (
        full_model,
        language_model,
        model_type,
        calibration_only,
        processor,
        tokenizer,
        default_padding_side,
        default_pad_token,
        device,
    )
def sparsity_main(
    args: argparse.Namespace,
    full_model: torch.nn.Module,
    tokenizer: PreTrainedTokenizerBase | None,
    device: torch.device,
):
    """Sparsify ``full_model`` with ``mts.sparsify`` and export the sparse result.

    Text calibration batches come from ``args.dataset`` / ``args.calib_size`` and are
    limited to ``args.calib_seq`` via ``max_sample_length``. When ``args.batch_size`` is
    0 it is auto-selected and written back to ``args``. The value is a quarter of
    ``get_max_batch_size(full_model)``, at least 1, and capped at the total number of
    calibration samples.

    NOTE(review): the sparsified model is not returned, so callers presumably rely on
    ``mts.sparsify`` modifying ``full_model`` in place. Confirm this.

    Raises:
        AssertionError: If ``tokenizer`` is not a Hugging Face (fast) tokenizer.
    """
    if args.batch_size == 0:
        # Sparse algorithm takes more GPU memory so we reduce the batch_size by 4.
        probed_batch_size = get_max_batch_size(full_model)
        args.batch_size = min(max(probed_batch_size // 4, 1), sum(args.calib_size))
    print(f"Use calib batch_size {args.batch_size}")

    # Different calibration datasets are also available, e.g., "pile" and "wikipedia"
    # Please also check the docstring for the datasets available
    hf_tokenizer_types = (PreTrainedTokenizer, PreTrainedTokenizerFast)
    assert tokenizer is not None and isinstance(tokenizer, hf_tokenizer_types), (
        "The PreTrainedTokenizer must be set"
    )
    calib_dataloader = get_dataset_dataloader(
        dataset_name=args.dataset,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_samples=args.calib_size,
        max_sample_length=args.calib_seq,
        device=device,
    )

    # Batches are passed through to the model unchanged.
    sparsify_config = {"data_loader": calib_dataloader, "collect_func": lambda batch: batch}
    sparse_model = mts.sparsify(full_model, args.sparsity_fmt, config=sparsify_config)
    mts.export(sparse_model)
def mono_quantize(
    args: argparse.Namespace,
    quant_cfg: dict[str, Any],
    full_model: torch.nn.Module,
    language_model: torch.nn.Module,
    model_type: str | None,
    calibration_only: bool,
    calib_dataloader: DataLoader,
    is_nemotron_vl_model: bool,
):
    """Plain quantization of the given language model to a single quantization configuration.

    Calls ``mtq.quantize`` with ``quant_cfg``. When ``calibration_only`` is set (weights
    already quantized at load time), it calls ``mtq.calibrate`` with
    ``quant_cfg["algorithm"]`` instead. A calibration forward loop is only built if
    ``need_calibration(quant_cfg)``. For Nemotron VL image calibration, the loop drives
    ``full_model``, because the batches carry multimodal kwargs.

    For Nemotron VL models this function adds "disabled" entries for vision/image/radio/
    visual/encoder patterns to ``quant_cfg["quant_cfg"]`` **in place** (the caller's
    dict is modified). It also re-attaches the quantized language model to its parent
    inside ``full_model``.

    If ``language_model`` is already quantized and ``calibration_only`` is False,
    nothing is done apart from a warning. ``model_type`` is accepted but not used here.
    """
    already_quantized = is_quantized(language_model)

    if "awq" in args.qformat:
        print(
            "\n####\nAWQ calibration could take longer than other calibration methods. "
            "Consider reducing calib_size to reduce calibration time.\n####\n"
        )

    # For Nemotron VL models, disable quantization of vision components
    if is_nemotron_vl_model:
        print("Disabling quantization for vision components in Nemotron VL model")
        layer_cfg = quant_cfg["quant_cfg"]
        # "*radio*" and "*model_encoder*" target Nemotron-Parse specific components.
        for pattern in (
            "*vision*",
            "*image*",
            "*radio*",
            "*visual*",
            "*encoder*",
            "*model_encoder*",
        ):
            layer_cfg[pattern] = {"enable": False}
        print("Quantization will only be applied to the decoder (text generation) component")

    if already_quantized and not calibration_only:
        warnings.warn("Skipping quantization: model is already quantized.")
        return

    calibrate_loop = None
    if need_calibration(quant_cfg):
        # Nemotron VL image batches (e.g. pixel_values) must go through the *full* model.
        if args.calib_with_images and is_nemotron_vl_model:
            calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader)
        else:
            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
    else:
        warnings.warn("Dynamic quantization. Calibration skipped.")

    if calibration_only:
        language_model = mtq.calibrate(
            language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop
        )
    else:
        language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)

    if not is_nemotron_vl_model:
        return
    # For VL models, update full_model to use the quantized language model
    lineage = get_language_model_from_vl(full_model)
    if lineage is not None:
        print("Updating full_model with quantized language_model...")
        lineage[-2].language_model = language_model
def export_quantized(
    args: argparse.Namespace,
    full_model: torch.nn.Module,
    language_model: torch.nn.Module,
    model_type: str | None,
    tokenizer: PreTrainedTokenizerBase | None,
    default_padding_side,
    default_pad_token,
):
    """Export the quantized model, tokenizer and configs to ``args.export_path``.

    Runs entirely under ``torch.inference_mode()``. Export flavours, checked in order:
      * speculative-decoding checkpoints (``has_spec_opt``) go through
        ``export_speculative_decoding``, and the function returns early (no tokenizer
        is saved);
      * ``t5`` / ``bart`` / ``whisper`` models, non-dense sparsity, or a qformat string
        containing ``int8_sq`` go through ``export_tensorrt_llm_checkpoint`` on
        ``language_model``;
      * everything else goes through the unified ``export_hf_checkpoint`` on
        ``full_model``, including any MTP weights found by ``load_mtp_weights``.

    For multimodal models the original HF config is saved first, and so is the
    processor config when ``AutoProcessor`` can load it (failures are only printed).
    After either non-speculative export, the tokenizer's padding side and pad token are
    restored and the tokenizer is saved. Custom model files are then copied.

    Args:
        default_padding_side: Tokenizer padding side to restore before saving.
        default_pad_token: Pad token to restore. ``None`` keeps the current pad token.

    Raises:
        NotImplementedError: If ``nvfp4_svdquant`` is used with tensor/pipeline
            parallel != 1 on the TensorRT-LLM path.
    """
    with torch.inference_mode():
        if model_type is None:
            print(f"Unknown model type {type(language_model).__name__}. Continue exporting...")
            model_type = f"unknown:{type(language_model).__name__}"

        export_path = args.export_path

        # Early exit for speculative decoding checkpoints
        # No tokenizer saving needed for spec ckpts
        if has_spec_opt(full_model):
            export_speculative_decoding(full_model, export_dir=export_path)
            print(f"Quantized speculative decoding checkpoint exported to: {export_path}")
            return

        # Check if the model is a multimodal/VLM model
        is_vlm = is_multimodal_model(full_model)

        if is_vlm:
            # Save original model config and the processor config to the export path for VLMs.
            print(f"Saving original model config to {export_path}")

            config_kwargs = {"trust_remote_code": args.trust_remote_code}
            if args.attn_implementation is not None:
                config_kwargs["attn_implementation"] = args.attn_implementation
            AutoConfig.from_pretrained(args.pyt_ckpt_path, **config_kwargs).save_pretrained(
                export_path
            )

            # Try to save processor config if available
            try:
                print(f"Saving processor config to {export_path}")
                AutoProcessor.from_pretrained(
                    args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
                ).save_pretrained(export_path)
            except Exception as e:
                # Best effort: a missing processor must not abort the export.
                print(f"Warning: Could not save processor config: {e}")
                print("This is normal for some VLM architectures that don't use AutoProcessor")

        if model_type == "mllama":
            full_model_config = full_model.config
            # TRT-LLM expects both the vision_config and text_config to be set for export.
            # NOTE(review): full_model_config *is* full_model.config, so each setattr below
            # assigns an attribute to itself and has no effect. The intended target was
            # probably a sub-model's config (e.g. language_model.config). Confirm.
            setattr(full_model.config, "vision_config", full_model_config.vision_config)
            setattr(full_model.config, "text_config", full_model_config.text_config)
            setattr(full_model.config, "architectures", full_model_config.architectures)

        start_time = time.time()
        # Note: `"int8_sq" in args.qformat` is a substring test on the raw qformat string.
        if (
            model_type in ["t5", "bart", "whisper"]
            or args.sparsity_fmt != "dense"
            or "int8_sq" in args.qformat
        ):
            if (
                args.inference_tensor_parallel != 1 or args.inference_pipeline_parallel != 1
            ) and args.qformat == "nvfp4_svdquant":
                raise NotImplementedError("Svdquant does not support multiple GPUs yet.")

            warnings.warn(
                "Still exporting TensorRT-LLM checkpoints for models not supported by the TensorRT-LLM torch runtime."
            )

            # Move meta tensor back to device before exporting.
            remove_hook_from_module(language_model, recurse=True)

            export_tensorrt_llm_checkpoint(
                language_model,
                model_type,
                export_dir=export_path,
                inference_tensor_parallel=args.inference_tensor_parallel,
                inference_pipeline_parallel=args.inference_pipeline_parallel,
            )

            # Copy custom model files (Python files and JSON configs) for TensorRT-LLM export
            # NOTE(review): copy_custom_model_files runs again after the tokenizer is saved
            # below, so this call looks redundant.
            copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)
        else:
            # Check arguments for unified_hf export format and set to default if unsupported arguments are provided
            # (Non-dense sparsity is routed to the TensorRT-LLM branch above, so this
            # assertion is a sanity check.)
            assert args.sparsity_fmt == "dense", (
                f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
            )

            if args.inference_tensor_parallel != 1 or args.inference_pipeline_parallel != 1:
                warnings.warn(
                    "Unified HF export format does not specify inference tensor parallel or pipeline parallel. "
                    "They will be set at deployment time."
                )

            # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
            # Store the MTP layer prefixes on the model for later exclusion from quantization
            mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path)

            if mtp_layer_prefixes:
                full_model._mtp_layer_prefixes = mtp_layer_prefixes

            export_hf_checkpoint(
                full_model,
                export_dir=export_path,
                extra_state_dict=mtp_state_dict,
            )

        # Restore default padding and export the tokenizer as well.
        if tokenizer is not None:
            tokenizer.padding_side = default_padding_side
            if default_pad_token is not None:
                tokenizer.pad_token = default_pad_token
            tokenizer.save_pretrained(export_path)

        # Copy custom model files (Python files and JSON configs) if trust_remote_code is used.
        # This must run AFTER tokenizer.save_pretrained() so original tokenizer files
        # from the source checkpoint take precedence over regenerated ones (which may
        # differ in format due to newer transformers versions).
        copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)

        end_time = time.time()
        print(
            f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s"
        )
def pre_quantize(
    args: argparse.Namespace,
    full_model: torch.nn.Module,
    model_type: str | None,
    tokenizer: PreTrainedTokenizerBase | None,
    calib_dataloader: DataLoader,
    is_nemotron_vl_model: bool,
):
    """
    Processing before the quantization.

    Takes the first sample of the first calibration batch as a preview input. For
    Whisper this is ``input_features``; for everything else it is ``input_ids``. It then
    generates a reference output with the unquantized model, to be compared with the
    post-quantize generation.

    Generation is skipped (reference is ``None``) when ``args.skip_generate`` is set or
    the model is DeepSeek. Nemotron VL models with a tokenizer use
    ``run_nemotron_vl_preview``. Any other model calls ``full_model.generate`` with up
    to 100 new tokens.

    Returns:
        ``(preview_input_ids, generated_ids_before_ptq)``.
    """
    batch_key = "input_features" if model_type == "whisper" else "input_ids"
    # Keep the batch dimension: a single-sample slice of the first batch.
    preview_input_ids = next(iter(calib_dataloader))[batch_key][:1]

    # DeepSeek generation may go OOM, so it is skipped like an explicit --skip_generate.
    if args.skip_generate or model_type == "deepseek":
        return preview_input_ids, None

    if is_nemotron_vl_model and tokenizer is not None:
        reference_output = run_nemotron_vl_preview(
            full_model,
            tokenizer,
            preview_input_ids,
            args.pyt_ckpt_path,
            "before quantization",
            allow_fallback=False,
        )
    else:
        reference_output = full_model.generate(preview_input_ids, max_new_tokens=100)
    return preview_input_ids, reference_output
def post_quantize(
    args: argparse.Namespace,
    full_model: torch.nn.Module,
    model_type: str | None,
    tokenizer: PreTrainedTokenizerBase | None,
    processor: BaseImageProcessor | ProcessorMixin | None,
    preview_input_ids,
    generated_ids_before_ptq,
    is_nemotron_vl_model,
    first_text_speech_dataset,
):
    """
    Processing after the quantization.

    With ``args.verbose``, a quantization summary and the expert token count table are
    written first. This is best effort: failures are printed and ignored. Then one
    generation is run with the quantized model on ``preview_input_ids`` and printed next
    to ``generated_ids_before_ptq`` from ``pre_quantize``.

    Post-quantization generation is skipped in these cases:
      * there is no pre-quantization sample;
      * the model is ``llama4`` (known generation issue after quantization);
      * the model is Nemotron VL without a tokenizer. Its preview path needs one, so
        the comparison would not be like-for-like.

    Raises:
        ValueError: If a decode is needed but neither a usable processor nor a
            tokenizer is set.
    """
    if args.verbose:
        try:
            mtq.print_quant_summary(full_model, args.export_path)
            save_expert_token_count_table(full_model, args.export_path)
        except Exception as e:
            # Diagnostics only: never abort the run because the summary failed.
            print(f"Error saving quant summary: {e}")
            print("Continuing with generation...")

    # Run some samples
    torch.cuda.empty_cache()
    generated_ids_after_ptq = None
    if generated_ids_before_ptq is None:
        # Pre-quantization generation was skipped (e.g. --skip_generate or DeepSeek),
        # so there is nothing to compare against.
        pass
    elif is_nemotron_vl_model and tokenizer is not None:
        generated_ids_after_ptq = run_nemotron_vl_preview(
            full_model,
            tokenizer,
            preview_input_ids,
            args.pyt_ckpt_path,
            "after quantization",
            allow_fallback=False,
        )
    elif model_type == "llama4":
        warnings.warn(
            "Llama4 Maverick generation after quantization has a bug. Skipping generation sample."
        )
    elif is_nemotron_vl_model:
        # Previously this case fell through to the Llama4 warning above, which was misleading.
        warnings.warn(
            "Nemotron VL model has no tokenizer. Skipping generation sample after quantization."
        )
    else:
        # Our fake quantizer may not be fully compatible with torch.compile.
        generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)

    def input_decode(input_ids):
        # Whisper inputs are audio features, so show the reference transcript instead.
        if processor is not None and isinstance(processor, MllamaImageProcessor):
            return processor.tokenizer.batch_decode(input_ids)
        elif processor is not None and isinstance(processor, WhisperProcessor):
            return first_text_speech_dataset
        elif tokenizer is not None:
            return tokenizer.batch_decode(input_ids)
        else:
            raise ValueError("The processor or tokenizer must be set")

    def output_decode(generated_ids, input_shape):
        # Encoder-decoder outputs contain only generated tokens. Decoder-only outputs
        # start with the prompt, so the first `input_shape` positions are dropped.
        if is_enc_dec(model_type):
            if processor is not None and isinstance(processor, WhisperProcessor):
                return processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            elif tokenizer is not None:
                return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            # Previously this case silently returned None.
            raise ValueError("The processor or tokenizer must be set")
        elif processor is not None and isinstance(processor, MllamaImageProcessor):
            return processor.tokenizer.batch_decode(generated_ids[:, input_shape:])
        elif tokenizer is not None:
            return tokenizer.batch_decode(generated_ids[:, input_shape:])
        else:
            raise ValueError("The processor or tokenizer must be set")

    if generated_ids_after_ptq is not None:
        print("--------")
        if is_nemotron_vl_model:
            # For Nemotron VL models, generated_ids are text strings from model.chat()
            print("Nemotron VL model text-only generation results:")
            print(f"Text response before quantization: {generated_ids_before_ptq}")
            print("--------")
            print(f"Text response after quantization: {generated_ids_after_ptq}")
            print("--------")
            print("Note: Additional VL tests with images were run separately above")
        else:
            # For regular LLMs, generated_ids are token tensors that need decoding
            print(f"example test input: {input_decode(preview_input_ids)}")
            print("--------")
            print(
                f"example outputs before ptq: {output_decode(generated_ids_before_ptq, preview_input_ids.shape[1])}"
            )
            print("--------")
            print(
                f"example outputs after ptq: {output_decode(generated_ids_after_ptq, preview_input_ids.shape[1])}"
            )
def quantize_main(
args: argparse.Namespace,
full_model: torch.nn.Module,
language_model: torch.nn.Module,
model_type: str | None,
calibration_only: bool,
processor: BaseImageProcessor | ProcessorMixin | None,
tokenizer: PreTrainedTokenizerBase | None,
default_padding_side,
default_pad_token,
device: torch.device,
):
if args.batch_size == 0:
# For VL models with image-text calibration, skip automatic batch size detection
# since get_max_batch_size can't handle multimodal inputs
if args.calib_with_images:
print("Image-text calibration enabled. Using default batch_size=1 for calibration.")
args.batch_size = 1
else:
# Calibration/sparsification will actually take much more memory than regular inference
# due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
# to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
# Whisper model expects mel-spectrogram input features of length 3000
# Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
# As the encoder of Whisper doesn't have embedding layer, input dtype has to be float
# For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
if model_type == "whisper":
max_sample_length = 3000
num_mel_bins = language_model.config.num_mel_bins
sample_input_single_batch = (
torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
language_model.device
)
* 100
)
else:
sample_input_single_batch = None
run_auto_quant = args.auto_quantize_bits is not None
args.batch_size = get_max_batch_size(
language_model,
max_sample_length=args.calib_seq,
sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
sample_input_single_batch=sample_input_single_batch,
enable_grad=run_auto_quant,
)
args.batch_size = min(args.batch_size, sum(args.calib_size))
print(f"Use calib batch_size {args.batch_size}")
calib_dataloader, first_text_speech_dataset = make_calib_dataloader(
args, language_model, processor, tokenizer, device, model_type
)
# Detect if this is a Nemotron VL model using architecture-based detection
is_nemotron_vl_model = is_nemotron_vl(full_model)
preview_input_ids, generated_ids_before_ptq = pre_quantize(
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
)
if args.auto_quantize_bits:
assert len(args.qformat.split(",")) > 1, (
"Auto quantization needs multiple quantization format."
)
auto_quantize(
args,
language_model,
calib_dataloader,
)
else:
# mono quantization
if args.recipe is not None:
print(f"Use recipe {args.recipe} for quantization")
recipe = load_recipe(args.recipe)
assert isinstance(recipe, ModelOptPTQRecipe), (
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
)
quant_cfg = recipe.ptq_cfg
else:
assert len(args.qformat.split(",")) == 1, (
"Plain quantization supports only one quantization format."
)
assert args.qformat in QUANT_CFG_CHOICES, (
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
quant_cfg = build_quant_cfg(
args.qformat,
quant_cfg,
args.awq_block_size,
model_type,
args.moe_calib_experts_ratio,
)
enable_quant_kv_cache = args.kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
)
# Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92)
# These layers are typically speculative decoding layers that should be exported as-is
mtp_layer_prefixes = getattr(full_model, "_mtp_layer_prefixes", None)
if mtp_layer_prefixes:
quant_cfg = copy.deepcopy(quant_cfg)
for prefix in mtp_layer_prefixes:
# Add exclusion pattern for this MTP layer (e.g., "*layers.92*")
pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
quant_cfg["quant_cfg"][pattern] = {"enable": False}
print(f"Excluding MTP layer from quantization: {pattern}")
# Use constant amax for KV quantizers when a cast format is selected.
if args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])
if args.qformat in QUANT_CFG_CHOICES:
mono_quantize(
args,
quant_cfg,
full_model,
language_model,
model_type,
calibration_only,
calib_dataloader,
is_nemotron_vl_model,
)
else:
assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton"
print(f"qformat: {args.qformat}. No quantization applied, export {device} model")
post_quantize(
args,
full_model,
model_type,
tokenizer,
processor,
preview_input_ids,