-
Notifications
You must be signed in to change notification settings - Fork 381
Expand file tree
/
Copy path__init__.py
More file actions
1424 lines (1181 loc) · 60 KB
/
__init__.py
File metadata and controls
1424 lines (1181 loc) · 60 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for advanced quantization algorithms."""
import fnmatch
import gc
import types
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from typing import Any
import regex as re
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from modelopt.torch.opt.conversion import ModeloptStateManager
from modelopt.torch.opt.hparam import CustomHPType, Hparam, HPType
from modelopt.torch.opt.searcher import LPS, BaseSearcher, SearchConfig, SearchStateDict
from modelopt.torch.opt.utils import get_hparam, named_hparams
from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master
from .. import config as mtq_config
from .. import model_calib
from ..config import QuantizeConfig, QuantizerAttributeConfig
from ..conversion import set_quantizer_by_cfg
from ..nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
from ..utils import is_quantized_linear
def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
"""Estimate the compression ratio of a quantization configuration.
Right now, we find the minimum compression ratio across all quantizer attribute configs.
This is not perfect but is a good proxy for the overall compression ratio. We will improve
this in future releases.
Args:
quant_cfg: The quantization configuration to estimate compression for.
Returns:
float: The estimated compression ratio (0.0 to 1.0).
"""
def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
if isinstance(quantizer_attr_cfg, list):
if not quantizer_attr_cfg:
return 1.0
return min(estimate_quant_compression_for_quantizer(q) for q in quantizer_attr_cfg)
if isinstance(quantizer_attr_cfg, dict):
# Handle raw quantizer cfg dicts (e.g. {"num_bits": (4, 3), "axis": None})
if not quantizer_attr_cfg.get("enable", True):
return 1.0
num_bits = quantizer_attr_cfg.get("num_bits")
if num_bits is None:
return 1.0
if isinstance(num_bits, tuple):
return (sum(num_bits) + 1) / 16
elif isinstance(num_bits, int):
return num_bits / 16
else:
raise ValueError(f"Unknown quantization config {num_bits}")
if isinstance(quantizer_attr_cfg, QuantizerAttributeConfig):
if not quantizer_attr_cfg.enable:
return 1.0
if not hasattr(quantizer_attr_cfg, "num_bits"):
return 1.0
if isinstance(quantizer_attr_cfg.num_bits, tuple):
return (sum(quantizer_attr_cfg.num_bits) + 1) / 16
elif isinstance(quantizer_attr_cfg.num_bits, int):
return quantizer_attr_cfg.num_bits / 16
else:
raise ValueError(f"Unknown quantization config {quantizer_attr_cfg.num_bits}")
raise ValueError(f"Unknown type {type(quantizer_attr_cfg)}, {quantizer_attr_cfg}")
cfgs = []
for e in quant_cfg.quant_cfg:
if e.get("enable", True) is False:
continue
c = e.get("cfg")
if c is not None:
cfgs.append(c)
return estimate_quant_compression_for_quantizer(cfgs) if cfgs else 1.0
class QuantRecipe(CustomHPType):
"""A subclass of QuantizeConfig enabling auto_quantize specific configurations.
Args:
quant_cfg: str or dict or None. dict is used for custom quantization formats.
name: name for custom quantization formats. Only used if quantization format is a custom
format not available in :mod:`modelopt.torch.quantization.config`.
"""
def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None):
"""Initialize the QuantRecipe with the quantization configuration."""
name = self.get_auto_name_for_config(quant_cfg) or name
if quant_cfg is None:
quant_cfg = {"quant_cfg": [{"quantizer_name": "*", "enable": False}]}
elif isinstance(quant_cfg, str):
assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}"
quant_cfg = getattr(mtq_config, quant_cfg)
else:
assert name is not None, "name must be provided for custom quantization formats"
self.config = mtq_config.QuantizeConfig(**quant_cfg) # type: ignore [arg-type]
# Disable KV Cache quantization
# Currently KV Cache quantization is enabled for some quantization formats and disabled for others
# This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy
self.config.quant_cfg.append({"quantizer_name": "*output_quantizer", "enable": False})
self.compression = estimate_quant_compression(self.config)
self._str_repr: str = f"{name}(effective-bits: {self.compression * 16})"
@staticmethod
def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None:
"""Get a name for the quantization configuration."""
if quant_cfg is None:
return "NONE"
if isinstance(quant_cfg, str):
return quant_cfg
for quant_cfg_name in mtq_config.choices:
if quant_cfg == getattr(mtq_config, quant_cfg_name):
return quant_cfg_name
return None
@property
def num_bits(self) -> int:
"""Get the number of bits for the quantization format."""
return int(self.compression * 16)
def __str__(self) -> str:
return self._str_repr
def __repr__(self) -> str:
return self._str_repr
def __lt__(self, other: "QuantRecipe"):
return self.compression < other.compression
def __eq__(self, other: object):
assert isinstance(other, QuantRecipe)
return self._str_repr == other._str_repr
def __hash__(self) -> int:
return hash(self._str_repr)
@staticmethod
def disable_folding_pqs_to_weights():
"""Disable the folding of pre_quant_scale to weights."""
model_calib._ENABLE_FOLDING_PQS_TO_WEIGHTS = False
@staticmethod
def fold_pqs_to_weights(model):
"""Fold the pre_quant_scale in weight_quantizers to weights."""
model_calib._ENABLE_FOLDING_PQS_TO_WEIGHTS = True
for name, module in model.named_modules():
if is_quantized_linear(module):
with SequentialQuantizer.convert_to_single_quantizer(module):
if module.weight_quantizer.pre_quant_scale is not None:
weight_pqs = module.weight_quantizer.pre_quant_scale
delattr(module.weight_quantizer, "_pre_quant_scale")
model_calib._apply_weight_pre_quant_scale(module, weight_pqs)
class QuantRecipeHparam(Hparam):
"""An Hparam for quantization recipes.
See :class:`Hparam <modelopt.torch.opt.hparam.Hparam>` for more details. In addition, this Hparam also:
* Keeps a link to its ``quant_modules`` and ``score_modules`` and sets the quantizers for the
``quant_modules`` based on the active recipe.
* Provides ``get_score()`` and ``get_cost()`` methods to evaluate recipes.
* Registers itself with each ``score_module`` via the ``_hparams_for_scoring`` attribute.
"""
def __init__(
self,
choices: Sequence[QuantRecipe] | None = None,
quant_modules: list[nn.Module] | None = None,
score_modules: list[nn.Module] | None = None,
name: str | None = None,
quant_module_names: list[str] | None = None,
) -> None:
"""Initializes Hparam with original value and choices."""
choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)})
super().__init__(choices, original=choices[0])
self.name = name
self.quant_module_names = quant_module_names or []
self.quant_modules = list(set(quant_modules or []))
self.score_modules = list(set(score_modules or self.quant_modules))
# This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
# a dynamic attribute for backward compatibility with the model_calib.py
# TODO: Make input_quantizer, weight_quantizer, output_quantizer a dynamic attribute and get rid of this hack
self._all_quantizer_choices = {quant_recipe: {} for quant_recipe in self.choices}
quant_recipe: QuantRecipe
for quant_recipe in self.choices:
for quant_module in self.quant_modules:
for quantizer_attr_name in [
"input_quantizer",
"weight_quantizer",
"output_quantizer",
]:
setattr(quant_module, quantizer_attr_name, TensorQuantizer())
set_quantizer_by_cfg(quant_module, quant_recipe.config.quant_cfg)
self._all_quantizer_choices[quant_recipe][quant_module] = {
quantizer_attr_name: getattr(quant_module, quantizer_attr_name)
for quantizer_attr_name in [
"input_quantizer",
"weight_quantizer",
"output_quantizer",
]
}
self.active = self.original
# Importance dict is keyed by score_module (where the score is computed)
self._importance_dict = {
quant_recipe: dict.fromkeys(self.score_modules) for quant_recipe in self.choices
}
# Attach this hparam to each score_module's set of hparams it scores
for score_module in self.score_modules:
if not hasattr(score_module, "_hparams_for_scoring"):
score_module._hparams_for_scoring = set()
score_module._hparams_for_scoring.add(self)
@property
def active(self) -> HPType:
"""Return the currently active value."""
return self._active
@active.setter
def active(self, val: HPType | None):
"""Set the active value with a sanity check for choices and dynamic hparams."""
val = self.original if val is None else val
assert val in self._choices, f"val = {val}, choices = {self.choices}"
if self.is_configurable:
self._active = val
else:
assert self._active == val
for nn_module, quantizer_choices in self._all_quantizer_choices[val].items():
for quantizer_attr_name, quantizer in quantizer_choices.items():
setattr(nn_module, quantizer_attr_name, quantizer)
@property
def importance(self) -> dict:
"""Raises an error since this is not a useful abstraction for AutoQuantize."""
raise NotImplementedError
def get_score(self, recipe: QuantRecipe) -> float:
"""Get the score for a given recipe."""
total_score = 0
for score_module in self.score_modules:
importance = self._importance_dict[recipe][score_module]
if importance is None:
continue
parallel_state = getattr(score_module, "parallel_state", None)
if parallel_state is None:
total_score += importance.cpu().item()
continue
if parallel_state.expert_model_parallel_group.is_initialized():
# TODO: Support expert model parallelism for score estimation
warnings.warn("AutoQuantize does not support expert model parallelism yet.")
importance = importance.cpu()
importance = DistributedProcessGroup.get_dist_syncd_obj(
importance,
[parallel_state.tensor_parallel_group, parallel_state.data_parallel_group],
sum,
)
total_score += importance.item()
return total_score
def get_cost(self, recipe: QuantRecipe) -> float:
"""Get the cost for a given recipe.
The cost is the total weight size of the quantizable modules multiplied by
the compression ratio of the recipe.
"""
cost = 0
for quant_module in self.quant_modules:
weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size([quant_module])
parallel_state = getattr(quant_module, "parallel_state", None)
if parallel_state is None:
cost += weight_size * recipe.compression
continue
if parallel_state.expert_model_parallel_group.is_initialized():
# TODO: Support expert model parallelism
warnings.warn("AutoQuantize does not support expert model parallelism yet.")
weight_size = DistributedProcessGroup.get_dist_syncd_obj(
weight_size,
[parallel_state.tensor_parallel_group],
sum,
)
# Across data parallel groups, the weight size is the same for all the ranks.
weight_size = DistributedProcessGroup.get_dist_syncd_obj(
weight_size,
[parallel_state.data_parallel_group],
lambda a: a[0],
)
cost += weight_size * recipe.compression
return cost
@property
def attrs(self) -> list[str]:
"""Return the attributes of the hparam for repr."""
return ["name", *super().attrs]
class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
"""Base searcher for AutoQuantize algorithm."""
# This searcher finds optimal per-layer quantization by searching across quantization formats
# for each quantizable module (quant module). Optionally, quant grouping rules can restrict
# certain modules to share the same format. Sensitivity scores are computed from perturbations
# at score modules. See AutoQuantizeGradientSearcher for detailed documentation.
candidate_stats: dict[str, dict[str, list[float]]]
best: dict[str, Any]
quantizer_states: dict
method_name: str | None = None
quant_grouping_rules = [
r"^(.*?)\.(q_proj|k_proj|v_proj)$", # q_proj, k_proj, v_proj for llama like models
# gate_proj, up_proj, down_proj for Qwen3 like MoE models
r"^(.*?\.mlp\.experts)\.\d+\.(gate_proj|up_proj|down_proj)$",
r"^(.*?\.mixer\.experts)\.\d+\.(up_proj|down_proj)$", # NemotronH MoE experts
r"^(.*?)\.(gate_proj|up_proj)$", # gate_proj, up_proj for llama like models
r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts
r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts
]
score_module_rules = []
@property
def default_search_config(self):
"""Get the default config for the searcher."""
return {
"quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],
"data_loader": None,
"num_calib_steps": 512,
"num_score_steps": 128,
"deployment": None,
"disabled_layers": None,
"verbose": is_master(),
"checkpoint": None,
}
@property
def default_state_dict(self) -> SearchStateDict:
"""Get the default state dict for AutoQuantize."""
return {
"method": self.method_name,
"candidate_stats": defaultdict(dict),
"quantizer_states": {},
"best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False},
}
def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
"""Sanitize the search config dict."""
config = config or {}
config = super().sanitize_search_config(config)
assert config["data_loader"] is not None, (
"`data_loader` must be provided for `auto_quantize`."
)
assert config["forward_step"] is not None, (
"`forward_step` must be provided for `auto_quantize`."
)
return config
def load_search_checkpoint(self) -> bool:
return super().load_search_checkpoint(strict=False)
@staticmethod
def _is_auto_quantize_module(module):
return (
is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
) and isinstance(module, QuantModule)
@staticmethod
def _get_search_recipes(quantization_formats):
return sorted(
{
QuantRecipe(quant_cfg=q[0], name=q[1])
if isinstance(q, tuple)
else QuantRecipe(quant_cfg=q)
for q in quantization_formats
}
)
def _apply_quant_group_rule(self, name: str, rule) -> str | None:
"""Apply a single quant_group_rule to a module name.
Args:
name: Module name
rule: Either a regex pattern string or a callable that returns a unique key;
If callable, it should take the model and the name as input and return the unique key
Returns:
The group key if the rule matches, None otherwise
"""
if callable(rule):
return rule(self.model, name)
else:
# Regex pattern
pattern = re.compile(rule)
match = pattern.match(name)
if match:
return match.group(1)
return None
def _apply_score_group_rule(self, name: str, rule) -> str | None:
"""Apply a single score_group_rule to a module name.
Args:
name: Module name
rule: Either a regex pattern string or a callable that returns the score module name.
If callable, it should take the model and the name as input and return the score module name
Returns:
The score module name if the rule matches, None otherwise
"""
if callable(rule):
return rule(self.model, name)
else:
# Regex pattern - return the matched name or full match
pattern = re.compile(rule)
match = pattern.match(name)
if match:
# For score rules, return the full match or first group
return match.group(0) if match.lastindex is None else match.group(1)
return None
def _get_score_module_from_name(
self, model: nn.Module, score_module_name: str, quant_module: nn.Module
) -> nn.Module:
"""Get the actual score module object from its name.
Args:
model: The model containing all modules
score_module_name: The name of the score module to retrieve
quant_module: The quantized module for which the score is estimated
Returns:
The score module object, or the quantized module itself if the score module is not found
"""
try:
score_module = model.get_submodule(score_module_name)
return score_module
except AttributeError:
warnings.warn(
f"Score module '{score_module_name}' not found. Score will estimated from the quantized module itself."
)
return quant_module
def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers=None):
"""Restrict the search space using the merge rules and insert the hparams for the model."""
# TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer
# Hence we need to restrict the search space so that all these layers share the same recipe
# Lets group the modules based on the rules and insert the same hparam for all the modules in the group
if disabled_layers is None:
disabled_layers = []
elif isinstance(disabled_layers, str):
disabled_layers = [disabled_layers]
# Map from group key to list of (quant_module, name, disabled, score_module)
search_map: dict[str, list[tuple[nn.Module, str, bool, nn.Module]]] = {}
for name, module in model.named_modules():
if not self._is_auto_quantize_module(module):
continue
# Skip layers that match disabled_layers patterns
disabled = False
for pattern in disabled_layers:
if fnmatch.fnmatch(name, pattern):
disabled = True
break
# Apply quant_grouping_rules to determine the group key
group_key = name # Default: each module in its own group
for rule in self.quant_grouping_rules:
result = self._apply_quant_group_rule(name, rule)
if result is not None:
group_key = result
# We support only one rule for matching per module
break
# Apply score_module_rules to determine the score module name, then get the actual module
score_module_name = name # Default: score from same module
for rule in self.score_module_rules:
result = self._apply_score_group_rule(name, rule)
if result is not None:
score_module_name = result
# We support only one rule for matching per module
break
# Get the actual score module object immediately
score_module = self._get_score_module_from_name(model, score_module_name, module)
if group_key not in search_map:
search_map[group_key] = [(module, name, disabled, score_module)]
else:
search_map[group_key].append((module, name, disabled, score_module))
for group_key, module_info_list in search_map.items():
quant_modules = [module for module, _, _, _ in module_info_list]
disabled = any(disabled for _, _, disabled, _ in module_info_list)
score_modules = [score_module for _, _, _, score_module in module_info_list]
_quant_recipes = None if disabled else quant_recipes
hparam = QuantRecipeHparam(
_quant_recipes,
quant_modules=quant_modules,
score_modules=score_modules,
name=str(group_key),
quant_module_names=[name for _, name, _, _ in module_info_list],
)
for module in quant_modules:
module._register_hparam("quant_recipe", hparam)
def _get_formatted_weight_compression_constraint(self):
effective_bits = self.constraints["effective_bits"]
assert effective_bits > 0 and effective_bits <= 16, (
"effective_bits should be between 0 and 16."
)
weight_compression = self.constraints["effective_bits"] / 16.0
return weight_compression
def _verify_constraint(self, search_recipes):
assert self.constraints["effective_bits"] >= search_recipes[0].num_bits, (
f"The effective_bits {self.constraints['effective_bits']} constraint cannot be lower than the "
f"num_bits of most aggressive quantization format for this search which is "
f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}."
)
@abstractmethod
def estimate_sensitivity_scores(self) -> None:
"""Estimate sensitivity scores and track them with Hparam."""
def initialize_candidate_stats(self):
"""Initialize the candidate stats for the model."""
for name, hparam in named_hparams(self.model, unique=True):
if not isinstance(hparam, QuantRecipeHparam):
continue
formats, scores, costs = [], [], []
prev_score = float("inf")
for recipe in hparam.choices:
formats.append(recipe)
score = hparam.get_score(recipe) # type: ignore [arg-type]
cost = hparam.get_cost(recipe) # type: ignore [arg-type]
score = min(score, prev_score) # TODO: Should we get rid of this?
scores.append(score)
costs.append(cost)
prev_score = score
self.candidate_stats[name]["formats"] = formats
self.candidate_stats[name]["scores"] = scores
self.candidate_stats[name]["costs"] = costs
self.candidate_stats[name]["module_names"] = hparam.quant_module_names
def _run_func(self, func, num_iters=1, desc=""):
for i, data in tqdm(
zip(range(num_iters), self.config["data_loader"]),
desc=desc,
total=num_iters,
):
func(self.model, data)
def before_search(self):
"""Prepare the model for search by calibrating the quantizers and collecting ``AutoQuantize`` score."""
# Import here to avoid circular import
from modelopt.torch.quantization.model_quant import calibrate
from ..conversion import restore_quantizer_state, update_quantize_metadata
from ..utils import get_quantizer_state_dict, set_quantizer_state_dict
super().before_search()
restored_method = getattr(self, "method", None)
if self.candidate_stats and restored_method not in (None, self.method_name):
raise ValueError(
f"Checkpoint method '{restored_method}' does not match current method "
f"'{self.method_name}'. Use a different checkpoint path."
)
self.method = self.method_name
search_recipes = self._get_search_recipes(self.config["quantization_formats"])
self._verify_constraint(search_recipes)
self.insert_hparams_after_merge_rules(
self.model, search_recipes, self.config["disabled_layers"]
)
QuantRecipe.disable_folding_pqs_to_weights()
# Iterate over the search recipes and calibrate the quantizers for each recipe
calibrated_new = False
for recipe in search_recipes:
if recipe == QuantRecipe(quant_cfg=None): # No-quant format
continue
for name, hparam in named_hparams(self.model, configurable=True):
if not isinstance(hparam, QuantRecipeHparam):
continue
hparam.active = recipe
if recipe in self.quantizer_states:
saved = self.quantizer_states[recipe]
# config is unused by restore_quantizer_state
restore_quantizer_state(
self.model, QuantizeConfig(), {"quantizer_state": saved["metadata"]}
)
set_quantizer_state_dict(self.model, saved["state_dict"])
if self.config["verbose"]:
print_rank_0(f"AutoQuantize: Restored calibration for {recipe}")
continue
# Lets reduce the number of calibration steps for AWQ since it takes longer
num_calib_steps = (
self.config["num_calib_steps"]
if "awq" not in str(recipe.config.algorithm)
else max(1, self.config["num_calib_steps"] // 4)
)
def forward_loop(model):
self._run_func(
self.config["forward_step"],
num_iters=num_calib_steps,
desc=f"Calibrating for {recipe}",
)
calibrate(
self.model,
algorithm=recipe.config.algorithm,
forward_loop=forward_loop,
)
# Calibrate adds a new mode to the model. Since auto_quantize mixes the quantization recipes
# across layers, lets not save this new mode in the modelopt state.
# TODO: This is a hack. We need to create a mode for auto_quantize to handle this in a clean way.
ModeloptStateManager(self.model).state_dict().pop()
metadata: dict = {}
# config is unused by update_quantize_metadata
update_quantize_metadata(self.model, QuantizeConfig(), metadata)
self.quantizer_states[recipe] = {
"metadata": metadata["quantizer_state"],
"state_dict": get_quantizer_state_dict(self.model),
}
calibrated_new = True
if calibrated_new:
self.save_search_checkpoint(verbose=self.config["verbose"])
if self.candidate_stats:
if self.config["verbose"]:
print_rank_0("AutoQuantize: Restored from checkpoint, skipping scoring")
return
self.estimate_sensitivity_scores()
self.initialize_candidate_stats()
self.save_search_checkpoint(verbose=self.config["verbose"])
@staticmethod
def _print_recipe_summary(best_recipe, total_cost, total_weight_size, prefix="AutoQuantize"):
for name, recipe in best_recipe.items():
print_rank_0(f"{prefix} best recipe for {name.replace('.quant_recipe', '')}: {recipe}")
effective_bits = (total_cost / total_weight_size) * 16
print_rank_0(f"{prefix} effective bits: {effective_bits:.2f}")
return effective_bits
@staticmethod
def _get_total_weight_size(modules):
return sum(
(
module.weight.numel()
if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module)
else 0
)
for module in modules
)
def _get_constraints_for_search(self, max_weight_size, lower_bound=None):
constraints = {
"weight_size_after_compression": (
lower_bound * max_weight_size if lower_bound else lower_bound,
max_weight_size,
)
}
return constraints, "weight_size_after_compression"
@abstractmethod
def run_search_with_stats(self, max_weight_size, verbose=False):
"""Run the search with stats to get the best recipe and whether the constraints are satisfied."""
def run_search(self):
"""Search for the best per-layer quantization configuration and return the best model and configuration."""
verbose = self.config["verbose"]
assert len(self.constraints) == 1 and "effective_bits" in self.constraints, (
f"`constraints` must contain only 'effective_bits' constraint. "
f"Got {self.constraints.keys()}"
)
compression = self._get_formatted_weight_compression_constraint()
total_weight_size = self._get_total_weight_size(self.model.modules())
max_weight_size = total_weight_size * compression
# Run the search with stats to get the best recipe and whether the constraints are satisfied
best_recipe_info, is_satisfied = self.run_search_with_stats(max_weight_size, verbose)
self.best["is_satisfied"] = is_satisfied
best_recipe = {}
best_constraints, best_scores = 0, 0
for name, best_hparam_recipe_info in best_recipe_info.items():
# Solvers could give different solutions for the same layer across DP/TP groups even though
# the scores and costs are the same. Lets make sure the same recipe is selected across DP/TP
_ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
best_format = DistributedProcessGroup.get_dist_syncd_obj(
best_hparam_recipe_info["format"],
[_ps.data_parallel_group, _ps.tensor_parallel_group],
lambda a: a[0],
)
best_recipe[name] = best_format
get_hparam(self.model, name).active = best_format
best_constraints += best_hparam_recipe_info["costs"]
best_scores += best_hparam_recipe_info["scores"]
if verbose:
effective_bits_from_search = self._print_recipe_summary(
best_recipe, best_constraints, total_weight_size
)
else:
effective_bits_from_search = (best_constraints / total_weight_size) * 16
self.best["recipe"] = best_recipe
self.best["constraints"] = {"effective_bits": effective_bits_from_search}
self.best["score"] = best_scores
QuantRecipe.fold_pqs_to_weights(self.model)
def _get_auto_quantize_score(grad_output, output_diff):
x = grad_output.float() * output_diff.float()
return x.clamp(-1e10, 1e10).square().sum()
def _add_auto_quantize_score(grad_output, output_diff, score_tensor):
score_tensor += _get_auto_quantize_score(grad_output, output_diff)
class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher):
"""A searcher for AutoQuantize algorithm that uses gradient based score estimation.
In AutoQuantize, we search for the best per-layer quantization configuration that minimizes the sum of per-layer
scores while meeting the specified constraint. AutoQuantize uses Linear Programming Solver to find the
optimal quantization configuration.
The auto_quantize score for a layer quantization configuration is an approximation of model loss change due
to quantizing the particular layer with the particular configuration.
The approximation is based on taylor expansion of the loss function wrt to the quantized output of the layer and
substitution of Fisher information for Hessian.
This approximation is mathematically correct for models where the loss
is a log likelihood loss such as BERT, GPT, etc. However, the auto_quantize score can still be used as a proxy
for other models such as ResNet.
**Quant Modules:**
This searcher operates on quantizable modules (quant modules), which are typically Linear or Conv layers
that support quantization. Optionally, grouping rules can be applied to ensure certain layers share the same
quantization format (e.g., Q, K, V projections in the same attention layer). For details on quant_grouping_rules
and customization, see the :meth:`auto_quantize <modelopt.torch.quantization.model_quant.auto_quantize>`
API documentation.
**Score Modules:**
By default, for each quant module, its sensitivity score is estimated using that module's output perturbation.
However, the sensitivity can also be estimated by looking at perturbation at a separate point in the neural
network (score module). This is helpful in some cases such as MoEs for speed and lower memory consumption.
Since all experts are already restricted to the same quant format by quant grouping rules, their sensitivity
can be estimated together at a single point (e.g., the MLP output level).
"""
method_name = "gradient"
score_module_rules = [
# Use MLP layer output for gate_proj, up_proj, down_proj for Qwen3 like MoE models (local and shared experts)
r"^(.*?\.mlp)\.experts\.\d+\.(gate_proj|up_proj|down_proj)$",
r"^(.*?\.mixer)\.experts\.\d+\.(up_proj|down_proj)$", # NemotronH MoE experts
r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts
r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts
]
# See `register_custom_support` for details
_custom_support: list[tuple[Callable, Callable, Callable]] = []
@property
def default_search_config(self):
"""Get the default config for the searcher."""
config = super().default_search_config
config.update(
{
"forward_step": None,
"loss_func": None,
"forward_backward_step": None,
}
)
return config
def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
"""Sanitize the search config dict."""
config = config or {}
if "score_func" in config:
warnings.warn("`score_func` is ignored for gradient based `auto_quantize`.")
config.pop("score_func")
config = super().sanitize_search_config(config)
if config["forward_backward_step"] is None:
assert config["loss_func"] is not None, (
"`loss_func` or `forward_backward_step` must be provided for `auto_quantize`."
)
config["forward_backward_step"] = self._get_default_forward_backward_step()
return config
@classmethod
def register_custom_support(
cls,
is_supported_checker: Callable,
grad_ckpt_context: Callable,
is_param_grad_enabled: Callable,
) -> None:
"""(Optional) Register custom support for `AutoQuantize` score estimation.
This custom support is used to enable memory/compute efficient backward gradient propagation. This involves:
- `grad_ckpt_context`: backward pass with gradient checkpointing enabled
- `is_param_grad_enabled`: AutoQuantize only needs activation gradients to be computed (not weight
gradients). `is_param_grad_enabled` is used to select which parameters should have gradients enabled,
limiting gradient computation to only what's needed for activation gradients. For LLMs, to trigger all
activation gradient computation, just enabling the embedding layer weight gradient is sufficient. This will
enable gradient computation for all the activation gradients downstream.
If the `is_supported_checker(model)` returns True, the `grad_ckpt_context(model)` will be
used to enable gradient checkpointing and `is_param_grad_enabled(pname, model)`
will be used to select which parameters have gradients enabled to minimize gradient computation.
"""
cls._custom_support.append((is_supported_checker, grad_ckpt_context, is_param_grad_enabled))
def _get_default_forward_backward_step(self):
def forward_backward_step(model, data):
output = self.config["forward_step"](model, data)
loss = self.config["loss_func"](output, data)
try:
loss.backward()
except RuntimeError as e:
raise RuntimeError(
"AutoQuantize: Error while calling `backward()` on the loss returned by `loss_func`. "
"Please fix this!"
f"error: {e}"
) from e
return forward_backward_step
@torch.enable_grad()
def _estimate_auto_quantize_scores(self, is_param_grad_enabled):
# TODO: remove the no-quant recipe
def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
for hparam in module._hparams_for_scoring:
if hparam.is_configurable:
hparam.active = QuantRecipe(quant_cfg=None)
output = module._forward_original(input, *args, **kwargs)
# If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass.
# With gradient checkpointing, gradients are computed in the local forward pass during backward pass
# Lets compute the output_diff and save it in memory only if gradient is enabled to be memory efficient
if not torch.is_grad_enabled():
return output
module.output_diff_dict = {hparam: {} for hparam in module._hparams_for_scoring}
with torch.no_grad():
for hparam in module._hparams_for_scoring:
if not hparam.is_configurable:
continue
for recipe in hparam.choices:
if recipe == QuantRecipe(quant_cfg=None):
continue
hparam.active = recipe
output_diff = module._forward_original(input, *args, **kwargs)
if isinstance(output_diff, tuple):
output_diff = output_diff[0] - output[0]
else:
output_diff -= output
module.output_diff_dict[hparam][recipe] = output_diff.detach()
# Disable the configurable hparam now that we have computed the diff
hparam.active = QuantRecipe(quant_cfg=None)
return output
def backward_hook(module, grad_input, grad_output):
for hparam, output_diff_dict in module.output_diff_dict.items():
for recipe, output_diff in output_diff_dict.items():
if hparam._importance_dict[recipe][module] is None:
hparam._importance_dict[recipe][module] = _get_auto_quantize_score(
grad_output[0], output_diff
)
else:
_add_auto_quantize_score(
grad_output[0], output_diff, hparam._importance_dict[recipe][module]
)
def setup_params_for_score_estimation(name, param, params_metadata, enable_grad=True):
# Let us delete the gradient as soon as they are computed to save memory
params_metadata[name] = {"requires_grad": param.requires_grad}
param.requires_grad = enable_grad
if not enable_grad:
return
if self.config.get("verbose", False):
print_rank_0(f"AutoQuantize: Enabling gradient for param {name}.")
accum_grad, handle = create_param_grad_clear_hook(param)
params_metadata[name]["accum_grad"] = accum_grad # We need to keep the accum_grad alive
params_metadata[name]["handle"] = handle
def setup_module_for_score_estimation(module):
module._forward_original = module.forward
module.forward = types.MethodType(auto_quantize_score_estimate_forward, module)
module._backward_hook_handle = module.register_full_backward_hook(backward_hook)
def cleanup_module_after_score_estimation(module):
module.forward = module._forward_original
del module._forward_original
module._backward_hook_handle.remove()
def cleanup_params_after_score_estimation(name, param, params_metadata):
param.requires_grad = params_metadata[name]["requires_grad"]
handle = params_metadata[name].get("handle", None)
if handle is not None:
handle.remove()
score_modules = set()
for name, module in self.model.named_modules():
if (
hasattr(module, "_hparams_for_scoring")
and any(hparam.is_configurable for hparam in module._hparams_for_scoring)
and module not in score_modules
):
# Monkey patch the forward methods to cache (Q(Y) - Y)
setup_module_for_score_estimation(module)
score_modules.add(module)
params_metadata = {}
for name, param in self.model.named_parameters():
setup_params_for_score_estimation(
name, param, params_metadata, is_param_grad_enabled(name, self.model)
)