-
Notifications
You must be signed in to change notification settings - Fork 366
Expand file tree
/
Copy pathhuggingface.py
More file actions
1618 lines (1331 loc) · 65 KB
/
huggingface.py
File metadata and controls
1618 lines (1331 loc) · 65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support quantization for huggingface layers."""
import inspect
import logging
import warnings
from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import transformers
from packaging import version
from torch import Tensor
from torch.nn.functional import linear
from transformers.models.t5.modeling_t5 import T5Attention
from modelopt.torch.opt.dynamic import DynamicModule
from modelopt.torch.utils.distributed import ParallelState
from ..algorithms import AutoQuantizeGradientSearcher
from ..conversion import register
from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import _QuantLinear
from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE
from ..utils import replace_function, sync_moe_expert_amax
from ..utils.activation_collector import LayerActivationCollector
from .attention import register_attention_for_kv_quant
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin
logger = logging.getLogger(__name__)

# ``Shard`` (a DTensor placement type) is optional; it is set to None when it cannot be imported.
try:
    from torch.distributed.tensor import Shard
except ImportError:
    Shard = None

# Optional ``kitchen`` package used by ``_QuantAttention`` for its flash-attention path. If any of
# these imports fails, ``kitchen`` is None and the other two names are left undefined.
try:
    import kitchen
    from kitchen.fa import KitchenFlashAttentionModule
    from kitchen.triton_module import triton_fa_params
except ImportError:
    kitchen = None

# ``weight_dequant`` comes from the Triton backend, or is None when Triton is unavailable.
if IS_TRITON_AVAILABLE:
    from ..triton import weight_dequant
else:
    weight_dequant = None

# Only needed for type annotations.
if TYPE_CHECKING:
    from types import ModuleType

__all__ = ["register_hf_attentions_on_the_fly"]

# Version gate for code paths that differ between transformers < 5.0 and >= 5.0.
TRANSFORMERS_VERSION_GE_5_0 = version.parse(transformers.__version__) >= version.parse("5.0.0")
class _QuantAttention(QuantModule):
    """Attention class for KV Cache quantization compatible with new_attention_interface in transformers >= 4.48.0.

    During ``forward`` the attention function of the original transformers module (either
    ``eager_attention_forward`` or the entry of ``ALL_ATTENTION_FUNCTIONS``) is temporarily
    replaced by :meth:`_quantized_attention`. That function quantizes query/key/value and then
    delegates either to the original function or to the optional ``kitchen`` kernel.
    """

    def _setup(self):
        self.q_bmm_quantizer = TensorQuantizer()
        self.k_bmm_quantizer = TensorQuantizer()
        self.v_bmm_quantizer = TensorQuantizer()
        self.softmax_quantizer = TensorQuantizer()
        # Lazily initialised by ``_init_kitchen_attn_fn``:
        # - ``None`` means not yet initialised.
        # - ``"disabled"`` means the softmax quantizer is off.
        # - Otherwise it holds a KitchenFlashAttentionModule.
        self.kitchen_attn_fn = None
        self.use_kitchen = False

    def _init_kitchen_attn_fn(self):
        """Build the kitchen flash-attention module for an enabled softmax quantizer.

        Raises:
            NotImplementedError: if the softmax quantizer is enabled but is not MXFP8.
        """
        if not self.softmax_quantizer.is_enabled:
            self.kitchen_attn_fn = "disabled"
            return
        self.use_kitchen = True
        if self.softmax_quantizer.is_mxfp(8):
            qfa_params = triton_fa_params.QTritonFAParams(
                backend="triton",
                qk_dot_precisions="bf16@bf16",
                pv_dot_precisions="mxfp8_e4m3_emulation@bf16",
                dp_v_x_do_dot_precisions="bf16@bf16",
                dp_do_x_v_dot_precisions="bf16@bf16",
                dq_ds_x_k_dot_precisions="bf16@bf16",
                dk_ds_x_q_dot_precisions="bf16@bf16",
                dv_p_x_do_dot_precisions="bf16@bf16",
                use_natural_transcendental_func=False,  # Different from default
            )
        else:
            raise NotImplementedError(f"softmax_quantizer not supported: {self.softmax_quantizer}")
        self.kitchen_attn_fn = KitchenFlashAttentionModule(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=self.config.head_dim,
            num_gqa_groups=None,  # self.config.num_key_value_heads, kitchen does not support gqa.
            attention_dropout=self.config.attention_dropout,
            qkv_format="sbhd",  # this is not used at all, but in forward, this is the only supported format.
            attn_mask_type="causal",
            window_size=getattr(self.config, "sliding_window", None),
            sequence_parallel=False,
            get_rng_state_tracker=None,
            layer_number=None,
            attention_type="self",
            softmax_scale=None,  # This will be convert to the same default as sdpa: 1/sqrt(dim_q)
            qfa_params=qfa_params,
        )

    @staticmethod
    def _quantized_attention(
        original_attention_interface,
        self,
        query_states,
        key_states,
        value_states,
        *args,
        **kwargs,
    ):
        """Quantize Q/K/V, then run attention.

        :meth:`forward` binds ``original_attention_interface`` with ``functools.partial`` and
        installs the result in place of the original attention function. The remaining
        arguments presumably come from transformers, with the attention module as ``self``.

        Returns:
            Without kitchen: whatever ``original_attention_interface`` returns.
            With kitchen: ``(attn_out, None)``.
        """
        if kitchen is not None and self.kitchen_attn_fn is None:
            self._init_kitchen_attn_fn()
        query_states = self.q_bmm_quantizer(query_states)
        key_states = self.k_bmm_quantizer(key_states)
        value_states = self.v_bmm_quantizer(value_states)
        if not self.use_kitchen:
            return original_attention_interface(
                self, query_states, key_states, value_states, *args, **kwargs
            )

        query_sequence_length = query_states.shape[2]
        if query_states.shape[2] < key_states.shape[2]:  # For decoding stage.
            # Left-pad the queries with uninitialised rows so q and k have the same length.
            # Only the last ``query_sequence_length`` output rows are kept below.
            shape = list(query_states.shape)
            shape[2] = key_states.shape[2] - query_states.shape[2]
            query_states = torch.cat(
                [
                    torch.empty(shape, dtype=query_states.dtype, device=query_states.device),
                    query_states,
                ],
                dim=2,
            )
        # Kitchen is configured without GQA, so repeat K/V heads to match the query heads.
        n_repeat = self.config.num_attention_heads // self.config.num_key_value_heads
        if n_repeat > 1:
            key_states = key_states.repeat_interleave(n_repeat, dim=1)
            value_states = value_states.repeat_interleave(n_repeat, dim=1)
        # kitchen only supports sbhd. we have bhsd.
        query_states = query_states.permute(2, 0, 1, 3)
        key_states = key_states.permute(2, 0, 1, 3)
        value_states = value_states.permute(2, 0, 1, 3)
        attn_out = self.kitchen_attn_fn(query_states, key_states, value_states)
        attn_out = attn_out[-query_sequence_length:, :, :]
        # output is sb(h*d), we need bshd (query_states.shape[2] is the head count after permute)
        attn_out = attn_out.reshape(
            (attn_out.shape[0], attn_out.shape[1], query_states.shape[2], -1)
        ).permute(1, 0, 2, 3)
        return attn_out.contiguous(), None

    def forward(self, *args, **kwargs):
        """Forward method for KV cache quantization compatible with new_attention_interface in transformers >= 4.48.0.

        The forward method is used to patch the attention interface with _quantized_attention.
        Once output tensors are generated, it restores the original attention interface, even
        if the wrapped forward raises.

        Raises:
            AssertionError: if eager attention is required but the defining transformers module
                has no ``eager_attention_forward``.
        """
        # Read once so that patching and restoring always use the same key.
        attn_implementation = self.config._attn_implementation
        # sdpa cannot return attention weights, so output_attentions falls back to eager attention.
        use_eager = attn_implementation == "eager" or bool(
            attn_implementation == "sdpa" and kwargs.get("output_attentions", False)
        )

        # Get the original transformers module before wrapped in any ModelOpt DynamicModule
        module: ModuleType = inspect.getmodule(self.get_attn_type(self))

        if use_eager:
            # Check before reading the attribute so users get the actionable message
            # instead of a bare AttributeError.
            if not hasattr(module, "eager_attention_forward"):
                raise AssertionError(
                    f"Module {module} does not have `eager_attention_forward` to enable KV Cache quantization. "
                    "Please use a different attention implementation such as `sdpa` by setting "
                    "`model.config._attn_implementation = 'sdpa'` before quantization."
                )
            original_attention_interface = module.eager_attention_forward
        else:
            original_attention_interface = module.ALL_ATTENTION_FUNCTIONS[attn_implementation]

        patch_fn = partial(self._quantized_attention, original_attention_interface)
        if use_eager:
            module.eager_attention_forward = patch_fn  # type: ignore[attr-defined]
        else:
            module.ALL_ATTENTION_FUNCTIONS[attn_implementation] = patch_fn

        try:
            return super().forward(*args, **kwargs)
        finally:
            # Cleanup logic to restore the original attention interface
            if use_eager:
                module.eager_attention_forward = original_attention_interface  # type: ignore[attr-defined]
            else:
                module.ALL_ATTENTION_FUNCTIONS[attn_implementation] = original_attention_interface

    @staticmethod
    def is_compatible_attention(attn):
        """Return True if the module defining class ``attn`` exposes ``ALL_ATTENTION_FUNCTIONS``."""
        # The new_attention_interface is only available in transformers >= 4.48.0
        # In addition, the new attention interface is not available for some models such as T5
        # Hence lets do a crude check here to see if the attention module is using the new_attention_interface
        # This is not foolproof but should work for most cases
        module = inspect.getmodule(attn)
        return getattr(module, "ALL_ATTENTION_FUNCTIONS", None) is not None

    @staticmethod
    def get_attn_type(attn_module) -> type:
        """Return the original (pre-ModelOpt) class of ``attn_module``."""
        # If this is a DynamicModule, it means that the module class has been wrapped by ModelOpt
        # Hence, we need to get the original class by level=0
        return (
            attn_module.get_original_cls_by_level(level=0)
            if isinstance(attn_module, DynamicModule)
            else type(attn_module)
        )
class _T5QuantAttention(QuantModule):
    """KV-cache quantization wrapper for Hugging Face ``T5Attention``.

    T5 does not use the pluggable attention interface. Instead, ``torch.matmul`` is swapped for
    :meth:`_quantized_matmul` for the duration of ``forward``. Calls alternate between the
    query-key matmul and the attention-value matmul.
    """

    def _quantized_matmul(self, batch1, batch2):
        """Quantized stand-in for ``torch.matmul`` inside ``T5Attention.forward``.

        On a query-key call, Q and K are quantized; on an attention-value call, only V is.
        The original matmul is reached through ``torch._matmul``. That name is presumably set
        by ``replace_function`` while the patch is active; confirm.
        """
        is_qk_call = self.qk_quant_matmul
        # Flip the toggle so the next call is treated as the other matmul.
        self.qk_quant_matmul = not is_qk_call

        if not is_qk_call:
            # Second matmul: attention probabilities x values.
            return torch._matmul(batch1, self.v_bmm_quantizer(batch2))

        # First matmul: queries x keys (keys arrive already transposed on the last two dims).
        quantized_query = self.q_bmm_quantizer(batch1)
        quantized_key = self.k_bmm_quantizer(batch2.transpose(3, 2)).transpose(3, 2)
        return torch._matmul(quantized_query, quantized_key)

    def _setup(self):
        for quantizer_name in ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"):
            setattr(
                self, quantizer_name, TensorQuantizer(QuantInputBase.default_quant_desc_input)
            )

    @staticmethod
    def is_compatible_attention(attn):
        """Return True when class ``attn`` derives from ``T5Attention``."""
        return issubclass(attn, T5Attention)

    def forward(self, *args, **kwargs):
        # Every forward pass starts with the query-key matmul.
        self.qk_quant_matmul = True
        with replace_function(torch, "matmul", self._quantized_matmul):
            return super().forward(*args, **kwargs)
def register_hf_attentions_on_the_fly(model):
    """Find HF Attention modules in the model and register them for KV Cache quantization.

    This function attempts to find child modules ending with "Attention" in the name.
    If such child modules are not found, or the corresponding class does not contain
    identifiable attention patterns, the function will not register any new modules.

    Classes are handled as follows:
    - Classes using the transformers attention interface are registered to ``_QuantAttention``.
    - ``T5Attention`` subclasses are registered to ``_T5QuantAttention``.
    - All remaining classes are passed to the AST-based ``register_attention_for_kv_quant``,
      but only when none of the above matched.

    A warning is emitted if none of those fallback registrations succeeds.
    """
    if not _is_supported_hf_model(model):
        return

    # Classes that need the AST-based fallback. A list keeps discovery order, so registration
    # is deterministic (iterating a set of classes depends on object ids).
    attention_cls = []
    registered_attn_module = False
    for module in model.modules():
        # Only register attention classes that are from Huggingface transformers
        if not type(module).__name__.endswith("Attention"):
            continue
        attention_type = _QuantAttention.get_attn_type(module)
        # Add modules to be registered only if they arent already registered
        if QuantModuleRegistry.get(attention_type) is not None or attention_type in attention_cls:
            continue
        if _QuantAttention.is_compatible_attention(attention_type):
            # Lets register the attention class for KV Cache quantization
            register(attention_type, _QuantAttention)
            registered_attn_module = True
            print(
                f"Registered {attention_type} to {_QuantAttention.__name__} for KV Cache quantization"
            )
        elif _T5QuantAttention.is_compatible_attention(attention_type):
            register(attention_type, _T5QuantAttention)
            registered_attn_module = True
            print(
                f"Registered {attention_type} to {_T5QuantAttention.__name__} for KV Cache quantization"
            )
        else:
            attention_cls.append(attention_type)

    # Check if the attention class has been registered
    # For T5Attention, we want to avoid registering T5LayerCrossAttention and T5LayerSelfAttention.
    # Hence we check if the attention class has been registered.
    if registered_attn_module or not attention_cls:
        return

    # this is the case for models that do not use the new_attention_interface or transformers version < 4.48.0
    # Try every fallback class. ``any()`` over a generator would stop at the first success and
    # leave the remaining attention classes of the model unquantized.
    registered_any = False
    for cls in attention_cls:
        if register_attention_for_kv_quant(cls):
            registered_any = True
            # Report only after the registration actually succeeded.
            print(f"Registered {cls} to AST based quantized class for KV Cache quantization")
    if not registered_any:
        warnings.warn(
            f"Could not create a quantized attention class for {attention_cls} from this model. "
            "To enable KV Cache quantization, please create a custom quantized attention class for this model and "
            "register it to ModelOpt using `mtq.register` "
            "(see https://nvidia.github.io/Model-Optimizer/guides/_pytorch_quantization.html#custom-quantized-module-and-quantizer-placement)"
        )
class HFParallelLinear(torch.nn.Linear, DynamicModule):
    """``nn.Linear`` whose weight has been tensor-parallel sharded by Hugging Face.

    Subclasses declare two things:
    - ``supported_hf_tp_plans``: the ``_hf_tp_plan`` tags they accept.
    - ``shard``: the DTensor placements their weight is required to have.
    """

    supported_hf_tp_plans = []
    shard = None

    def _setup(self):
        weight = self.weight
        assert weight.placements == self.shard, (
            f"Received unexpected shard {self.weight.placements} for {self}"
        )
        # The tensor-parallel group is taken from the weight's device mesh;
        # data_parallel_group is passed as -1.
        self._parallel_state = ParallelState(
            data_parallel_group=-1, tensor_parallel_group=weight.device_mesh.get_group()
        )

    @classmethod
    def is_compatible(cls, linear) -> bool:
        """Return True for an ``nn.Linear`` tagged with one of ``supported_hf_tp_plans``."""
        if isinstance(linear, torch.nn.Linear) and hasattr(linear, "_hf_tp_plan"):
            return linear._hf_tp_plan in cls.supported_hf_tp_plans
        return False

    # This is hack for now, otherwise DMRegistry treats this class same as nn.Linear
    def forward(self, x):
        return super().forward(x)
class HFColumnParallelLinear(HFParallelLinear):
    """HF column-wise tensor-parallel linear: the weight must be placed as ``(Shard(0),)``."""

    supported_hf_tp_plans = ["colwise", "colwise_rep"]
    shard = None if Shard is None else (Shard(0),)
class HFRowParallelLinear(HFParallelLinear):
    """HF row-wise tensor-parallel linear: the weight must be placed as ``(Shard(1),)``."""

    supported_hf_tp_plans = ["rowwise", "rowwise_rep"]
    shard = None if Shard is None else (Shard(1),)
class _QuantHFParallelLinear(_ParallelLinear):
    """Shared quantization logic for Hugging Face tensor-parallel linear layers."""

    # Functionals swapped by the quantization mixin during forward.
    _functionals_to_replace = [(torch.nn.functional, "linear")]

    def fold_weight(self, keep_attrs: bool = False):
        """Fold weight quantization while the weight is exposed as its local shard."""
        with self.enable_weight_access_and_writeback():
            super().fold_weight(keep_attrs)

    @contextmanager
    def enable_weight_access_and_writeback(self):
        """Temporarily replace the sharded ``weight`` with a Parameter wrapping its local shard.

        The original sharded parameter is restored on exit, including when the body raises.
        Previously an exception left the local-shard Parameter installed.
        """
        assert self.weight.placements == self.shard, (
            f"Received unexpected shard {self.weight.placements} for {self}"
        )
        weight = self.weight
        # TODO: To support TP + FSDP, we need to redistribute the tensor with replicate instead of shard
        # NOTE(review): write-back relies on ``to_local()`` sharing storage with the DTensor, so that
        # in-place updates made inside the context land in ``weight``. Confirm for the torch version in use.
        self.weight = nn.Parameter(weight.to_local())
        try:
            yield
        finally:
            self.weight = weight
@QuantModuleRegistry.register({HFColumnParallelLinear: "HFColumnParallelLinear"})
class QuantHFColumnParallelLinear(_QuantHFParallelLinear):
    """Quantized counterpart of ``HFColumnParallelLinear``."""

    _is_column_parallel = True
@QuantModuleRegistry.register({HFRowParallelLinear: "HFRowParallelLinear"})
class QuantHFRowParallelLinear(_QuantHFParallelLinear):
    """Quantized counterpart of ``HFRowParallelLinear``."""

    _is_row_parallel = True
def convert_hf_parallel_linears_on_the_fly(model):
    """Convert nn.Linear layers that have been TP sharded by HF.

    Hugging Face dynamically shards plain ``nn.Linear`` layers into row- or column-wise
    tensor-parallel layers. Each such layer is converted in place to ``HFColumnParallelLinear``
    or ``HFRowParallelLinear``, so it is treated as TP-sharded rather than as a regular
    ``nn.Linear``. The column-parallel check takes precedence.
    """
    for submodule in model.modules():
        for tp_linear_cls in (HFColumnParallelLinear, HFRowParallelLinear):
            if tp_linear_cls.is_compatible(submodule):
                tp_linear_cls.convert(submodule)
                break
if transformers.pytorch_utils.Conv1D not in QuantModuleRegistry:
    # ``transformers.pytorch_utils.Conv1D`` (used in HF-GPT2) is not a real convolution. It is a
    # linear layer with a transposed weight, computed via torch.addmm. It is therefore converted
    # through the quantized ``nn.Linear`` class after transposing its weight.

    @QuantModuleRegistry.register({transformers.pytorch_utils.Conv1D: "Conv1D"})
    class _QuantConv1D(_QuantLinear):
        @classmethod
        @torch.no_grad()
        def convert(cls, module: nn.Module) -> "_QuantConv1D":
            """Transpose the Conv1D weight into nn.Linear layout and convert it as an nn.Linear."""
            linear_layout_weight = module.weight.T.contiguous()
            module.weight = nn.Parameter(linear_layout_weight)
            # nn.Linear keeps its weight as (out_features, in_features).
            out_features, in_features = linear_layout_weight.shape
            module.out_features = out_features
            module.in_features = in_features
            # We want the forward method of nn.Linear to be called instead of the forward method of Conv1D
            linear_quant_cls: QuantModule = QuantModuleRegistry.get(nn.Linear)
            return linear_quant_cls.convert(module)
class _TransposedQuantization(torch.autograd.Function):
    """Quantize a tensor along its second-to-last dimension.

    Some MoEs, such as gpt-oss or Llama4, store expert weights as
    ``(num_experts, in_dim, out_dim)``. ModelOpt's per-channel/per-block quantization expects
    ``in_dim`` to be the last dimension. The quantizer is therefore applied to the tensor with
    its last two dims swapped, and the result is swapped back.
    """

    @staticmethod
    def forward(ctx, inputs, quantizer):
        swapped = inputs.transpose(-2, -1).contiguous()
        return quantizer(swapped).transpose(-2, -1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator without clipping: the gradient passes unchanged,
        # and the quantizer argument receives no gradient.
        return grad_output, None


_transposed_quantize = _TransposedQuantization.apply
class _QuantSparseMoe(QuantModule):
    """Quantization wrapper for HuggingFace sparse MoE blocks.

    Supports ``layer_sync_moe_local_experts_amax`` to sync input quantizer amax across experts.

    Optionally supports config-driven features (disabled by default):
    - ``_moe_calib_experts_ratio``: force-forward tokens to more experts during calibration.
      When set to a value > 0, also enables token counting per expert.

    When disabled, forward is a direct pass-through with zero overhead.
    """

    def _setup(self):
        self._moe_calib_experts_ratio = None
        self._token_counting_initialized = False
        # Read by ``_gate_forward_hook``. ``forward`` switches it on only for the widened
        # calibration pass.
        self._count_expert_tokens = False

    def _init_token_counting(self):
        """Lazy-init token counting infra (buffer + gate hook). Called once from forward.

        ``_count_expert_tokens`` is left untouched, so the calibration pass that triggered
        this initialisation is itself counted.
        """
        self._token_counting_initialized = True
        num_experts = 0
        # Look for the expert count on the gate, then on this block, then on the experts container.
        for obj in [getattr(self, "gate", None), self, getattr(self, "experts", None)]:
            if obj is not None:
                for attr in ("num_experts", "n_routed_experts"):
                    if hasattr(obj, attr):
                        num_experts = getattr(obj, attr)
                        break
            if num_experts:
                break
        if num_experts == 0:
            warnings.warn(
                f"{self.__class__.__name__}: could not resolve num_experts; "
                "expert routing will not be tracked for this layer."
            )
            return
        self.register_buffer(
            "expert_token_count",
            torch.zeros(num_experts, dtype=torch.long, device=next(self.parameters()).device),
            persistent=False,
        )
        if hasattr(self, "gate"):
            self.gate.register_forward_hook(self._gate_forward_hook)

    def _gate_forward_hook(self, module, input, output):
        """Accumulate per-expert routed-token counts into ``expert_token_count``."""
        if not self._count_expert_tokens:
            return
        with torch.no_grad():
            if isinstance(output, tuple) and len(output) >= 3:
                # v5.x TopKRouter: returns (logits, scores, indices)
                indices = output[2]
            else:
                # v4.x nn.Linear gate: returns logits tensor
                logits = output if not isinstance(output, tuple) else output[0]
                top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
                _, indices = torch.topk(logits.float(), top_k, dim=-1)
            counts = torch.bincount(indices.reshape(-1), minlength=self.expert_token_count.shape[0])
            self.expert_token_count += counts.to(self.expert_token_count.device)

    def _calib_top_k_owner_and_num_experts(self):
        """Return ``(module owning top_k, number of experts)`` used to widen routing during calibration.

        Raises:
            AssertionError: (transformers >= 5.0) if the gate has no ``top_k``.
            ValueError: (transformers < 5.0) if the number of experts cannot be found.
        """
        if TRANSFORMERS_VERSION_GE_5_0:
            assert hasattr(self, "gate") and hasattr(self.gate, "top_k")
            return self.gate, self.gate.num_experts
        # Path for transformers < 5.0
        if hasattr(self, "gate") and hasattr(self.gate, "top_k"):
            top_k_owner = self.gate
        else:
            top_k_owner = self
        if hasattr(self, "num_experts"):
            return top_k_owner, self.num_experts
        if hasattr(self, "experts"):
            experts = self.experts
            num_experts = (
                experts.num_experts if hasattr(experts, "num_experts") else len(experts)
            )
            return top_k_owner, num_experts
        raise ValueError(f"Could not find num_experts in module {self}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self._moe_calib_experts_ratio is None:
            return super().forward(hidden_states)

        is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
        # During calibration, forward all tokens to a larger fraction of experts to improve
        # calibration coverage, then re-run with the original top_k for actual outputs.
        if is_calib:
            # Skip counting when all experts are calibrated (ratio == 1.0).
            self._count_expert_tokens = self._moe_calib_experts_ratio < 1.0
            if self._count_expert_tokens and not self._token_counting_initialized:
                self._init_token_counting()
            top_k_owner, num_experts = self._calib_top_k_owner_and_num_experts()
            original_top_k = top_k_owner.top_k
            widened_top_k = round(num_experts * self._moe_calib_experts_ratio)
            # Never go below the model's own top_k, and never route to more experts than exist.
            top_k_owner.top_k = max(original_top_k, min(widened_top_k, num_experts))
            try:
                super().forward(hidden_states)
            finally:
                # Restore routing and stop counting even if the widened pass fails.
                top_k_owner.top_k = original_top_k
                self._count_expert_tokens = False

        output = super().forward(hidden_states)
        self._count_expert_tokens = False
        return output

    def layer_sync_moe_local_experts_amax(self, sync_weight_amax=False):
        """Sync input_quantizer amax across experts so all share the same amax per quantizer.

        Skipped when _moe_calib_experts_ratio is set, as each expert is calibrated independently.
        Also skipped when experts is a fused module (e.g. Llama4TextExperts) with shared quantizers.

        Args:
            sync_weight_amax: If True, also sync weight quantizer amax across experts.
        """
        if self._moe_calib_experts_ratio is not None:
            return
        try:
            iter(self.experts)
        except TypeError:
            # Non-iterable experts container: nothing to sync per expert.
            return
        sync_moe_expert_amax(self.experts, sync_weight_amax=sync_weight_amax)
class _QuantLlama4TextExperts(QuantModule):
    """Quantized Llama4 fused text experts.

    The inputs and weights of both batched matmuls are quantized. The expert weights go through
    ``_transposed_quantize`` because their input dimension is the second-to-last one: the
    hidden states of shape (num_experts, -1, hidden_size) are right-multiplied by them.
    """

    def _setup(self):
        # One input quantizer and one weight quantizer for each of the two projections.
        self.gate_up_proj_input_quantizer = TensorQuantizer()
        self.gate_up_proj_weight_quantizer = TensorQuantizer()
        self.down_proj_input_quantizer = TensorQuantizer()
        self.down_proj_weight_quantizer = TensorQuantizer()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expert_inputs = hidden_states.view(self.num_experts, -1, self.hidden_size)
        gate_up = torch.bmm(
            self.gate_up_proj_input_quantizer(expert_inputs),
            _transposed_quantize(self.gate_up_proj, self.gate_up_proj_weight_quantizer),
        )
        gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
        activated = up * self.act_fn(gate)
        down_input = self.down_proj_input_quantizer(activated)
        down_weight = _transposed_quantize(self.down_proj, self.down_proj_weight_quantizer)
        return torch.bmm(down_input, down_weight).view(-1, self.hidden_size)
# For more information on DbrxExpert, see https://github.com/huggingface/transformers/blob/dcdda532/src/transformers/models/dbrx/modeling_dbrx.py#L756
class _QuantDbrxExperts(QuantModule):
    """Quantization wrapper for Dbrx experts; the per-expert MLP is called through ``self.mlp``."""

    def _setup(self):
        """No-op: only the DbrxExpertGLU (``self.mlp``) needs restructuring."""

    # forward method copied from the original dbrx repo - https://github.com/databricks/dbrx/blob/a3200393/model/modeling_dbrx.py#L795
    def forward(
        self,
        x: torch.Tensor,
        weights: torch.Tensor,
        top_weights: torch.Tensor,
        top_experts: torch.LongTensor,
    ) -> torch.Tensor:
        """Send each token to its selected experts and sum the weighted expert outputs.

        ``weights`` is accepted for signature compatibility and is not used here.
        """
        bsz, q_len, hidden_size = x.shape
        tokens = x.view(-1, hidden_size)
        out = torch.zeros_like(tokens)
        # After the permute, the mask is indexed as [expert, k, token].
        expert_mask = nn.functional.one_hot(
            top_experts, num_classes=self.moe_num_experts
        ).permute(2, 1, 0)
        for expert_idx in range(self.moe_num_experts):
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.numel() == 0:
                continue
            token_list = token_idx.tolist()
            topk_list = topk_idx.tolist()
            routed_tokens = tokens[token_list]
            expert_out = (
                self.mlp(routed_tokens, expert_idx) * top_weights[token_list, topk_list, None]
            )
            out.index_add_(0, token_idx, expert_out)
        return out.reshape(bsz, q_len, hidden_size)
class _QuantDbrxExpertGLU(QuantModule):
    """Quantization wrapper for DbrxExpertGLU that splits its fused weights per expert."""

    def _setup(self):
        """Modify the DbrxExpertGLU by using nn.Linear layers.

        The fused ``w1``/``v1``/``w2`` parameters are replaced by ``w1_linear``/``v1_linear``/
        ``w2_linear``. Each is an ``nn.ModuleList`` holding one bias-free ``nn.Linear`` per expert.
        """
        dtype, device = self.w1.dtype, self.w1.device
        num_experts = self.moe_num_experts
        hidden_size, ffn_hidden_size = self.hidden_size, self.ffn_hidden_size

        def _unfuse(name, in_features, out_features, per_expert_weight):
            # Build one Linear per expert, load expert i's slice into it, then drop the fused tensor.
            linears = nn.ModuleList(
                [nn.Linear(in_features, out_features, bias=False) for _ in range(num_experts)]
            )
            setattr(self, f"{name}_linear", linears)
            linears.to(dtype=dtype, device=device)
            for linear, weight in zip(linears, per_expert_weight):
                with torch.no_grad():
                    linear.weight.copy_(weight.detach())
            delattr(self, name)

        # All three fused tensors are viewed as (num_experts, ffn_hidden_size, hidden_size).
        # w2 is additionally transposed to nn.Linear(ffn, hidden)'s (hidden, ffn) layout.
        _unfuse(
            "w1",
            hidden_size,
            ffn_hidden_size,
            self.w1.view(num_experts, ffn_hidden_size, hidden_size),
        )
        _unfuse(
            "v1",
            hidden_size,
            ffn_hidden_size,
            self.v1.view(num_experts, ffn_hidden_size, hidden_size),
        )
        _unfuse(
            "w2",
            ffn_hidden_size,
            hidden_size,
            self.w2.view(num_experts, ffn_hidden_size, hidden_size).transpose(1, 2),
        )

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        """Apply expert ``expert_idx``'s gated MLP: w2(activation(w1(x)) * v1(x))."""
        w1_out = self.w1_linear[expert_idx](x)
        v1_out = self.v1_linear[expert_idx](x)
        return self.w2_linear[expert_idx](self.activation_fn(w1_out) * v1_out)
class _Qwen3VLMoeExpertModule(nn.Module):
    """Bias-free gate/up/down projections of one Qwen3VL MoE expert.

    Registered under an experts container, it yields parameter names such as
    ``experts.{id}.gate_proj.weight``, matching the standard Qwen3 MoE per-expert layout.
    """

    def __init__(self, hidden_size: int, expert_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(in_features=hidden_size, out_features=expert_dim, bias=False)
        self.up_proj = nn.Linear(in_features=hidden_size, out_features=expert_dim, bias=False)
        self.down_proj = nn.Linear(in_features=expert_dim, out_features=hidden_size, bias=False)
class _QuantQwen3VLMoeTextExperts(QuantModule):
    """Quantization wrapper for Qwen3VLMoeTextExperts that unfuses the expert weights.

    ``_setup`` replaces the fused ``gate_up_proj``/``down_proj`` tensors with one
    ``_Qwen3VLMoeExpertModule`` per expert. Each is registered as a child named ``"0"``,
    ``"1"``, ..., so parameter names follow ``experts.{id}.gate_proj.weight``.
    """

    def _setup(self):
        """Split the fused expert weights into per-expert ``nn.Linear`` containers.

        Raises:
            AttributeError: if neither ``intermediate_size`` nor ``intermediate_dim`` exists.
        """
        from accelerate import init_empty_weights

        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

        def _load_weight(linear, weight):
            # The linear was created under ``init_empty_weights``. Allocate it on ``device``,
            # then replace its weight data with the detached slice cast to dtype/device.
            linear.to_empty(device=device)
            with torch.no_grad():
                linear.weight.data = weight.detach().data.to(dtype=dtype, device=device)

        # The attribute name was changed from `intermediate_size` to `intermediate_dim` in
        # https://github.com/huggingface/transformers/commit/0642963ba13f2dae0596fe489415569e1d91fbda
        for dim_attr in ("intermediate_size", "intermediate_dim"):
            if hasattr(self, dim_attr):
                expert_dim = getattr(self, dim_attr)
                break
        else:
            raise AttributeError("Could not find intermediate dimension size in model")

        with init_empty_weights():
            expert_modules = nn.ModuleList(
                _Qwen3VLMoeExpertModule(self.hidden_size, expert_dim)
                for _ in range(self.num_experts)
            )

        for idx, expert in enumerate(expert_modules):
            # In ``gate_up_proj[idx]``, the first ``expert_dim`` columns are the gate and the rest
            # are the up projection. All slices are transposed into nn.Linear's layout.
            fused_gate_up = self.gate_up_proj[idx]
            _load_weight(expert.gate_proj, fused_gate_up[:, :expert_dim].T)
            _load_weight(expert.up_proj, fused_gate_up[:, expert_dim:].T)
            _load_weight(expert.down_proj, self.down_proj[idx].T)

        delattr(self, "gate_up_proj")
        delattr(self, "down_proj")
        # Register expert modules directly as numbered children
        # so the naming pattern is: experts.{id}.gate_proj.weight (no extra nesting)
        for idx, expert in enumerate(expert_modules):
            self.add_module(str(idx), expert)

    def __len__(self):
        """Number of experts, so the module can be sized like a standard experts list."""
        return self.num_experts

    def __iter__(self):
        """Yield the numbered expert children in index order."""
        yield from (getattr(self, str(position)) for position in range(self.num_experts))

    def __getitem__(self, idx):
        """Return the expert child at ``idx`` (any int-convertible index, including a tensor)."""
        return getattr(self, f"{int(idx)}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        router_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Run each token through its routed experts and sum the routing-weighted outputs."""
        batch_size = hidden_states.shape[0]
        flat_states = hidden_states.reshape(-1, self.hidden_size)
        next_states = torch.zeros_like(flat_states)
        with torch.no_grad():
            # After the permute, the mask is indexed as [expert, k, token].
            expert_mask = torch.nn.functional.one_hot(
                router_indices, num_classes=self.num_experts
            ).permute(2, 1, 0)
            # Experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for hit in expert_hit:
            expert_idx = hit[0]
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx])
            routed_states = flat_states[token_idx]
            expert = self[expert_idx]
            gate = expert.gate_proj(routed_states)
            up = expert.up_proj(routed_states)
            expert_out = expert.down_proj(up * self.act_fn(gate))
            weighted = expert_out * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted.to(flat_states.dtype))
        return next_states.view(batch_size, -1, self.hidden_size)
class _Qwen35MoeExpertModule(nn.Module):
    """Bias-free gate/up/down projections of one Qwen3.5 MoE expert.

    Registered under an experts container, it yields parameter names such as
    ``experts.{id}.gate_proj.weight``, matching the standard Qwen3 MoE per-expert layout.
    """

    def __init__(self, hidden_dim: int, expert_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(in_features=hidden_dim, out_features=expert_dim, bias=False)
        self.up_proj = nn.Linear(in_features=hidden_dim, out_features=expert_dim, bias=False)
        self.down_proj = nn.Linear(in_features=expert_dim, out_features=hidden_dim, bias=False)
class _QuantQwen35MoeExperts(QuantModule):
    """Quantization wrapper for Qwen3.5 MoE experts stored as fused weight tensors.

    ``_setup`` splits the fused ``gate_up_proj`` / ``down_proj`` tensors into one
    ``_Qwen35MoeExpertModule`` per expert, registered as children ``"0"``, ``"1"``, ...
    The container then supports ``len()``, iteration and indexing like an
    ``nn.ModuleList`` of experts. ``forward`` routes tokens through those per-expert
    linear layers.
    """

    def _setup(self):
        """Modify the Qwen3_5MoeExperts by using per-expert nn.Module containers.

        This produces the naming pattern: experts.{id}.gate_proj.weight
        (consistent with standard Qwen3 MoE).
        """
        from accelerate import init_empty_weights

        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

        def _copy_weight(module, weight):
            # The expert modules are built under ``init_empty_weights``. Allocate real
            # (uninitialized) storage on the target device, then install the expert slice.
            module.to_empty(device=device)
            with torch.no_grad():
                # NOTE(review): when the source slice already has this dtype/device,
                # ``Tensor.to`` returns it unchanged. The expert weight then shares storage
                # with the fused tensor rather than copying it. Confirm this is intended
                # (e.g. for checkpoint saving).
                module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

        expert_dim = self.intermediate_dim
        with init_empty_weights():
            expert_modules = nn.ModuleList(
                [
                    _Qwen35MoeExpertModule(self.hidden_dim, expert_dim)
                    for _ in range(self.num_experts)
                ]
            )

        for idx in range(self.num_experts):
            # gate_up_proj shape: (num_experts, 2*intermediate_dim, hidden_dim)
            # Already in (out_features, in_features) format, no transpose needed.
            # The first ``expert_dim`` rows are the gate projection; the rest are the up projection.
            _copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :expert_dim, :])
            _copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, expert_dim:, :])
            # down_proj shape: (num_experts, hidden_dim, intermediate_dim)
            # Already in (out_features, in_features) format
            _copy_weight(expert_modules[idx].down_proj, self.down_proj[idx])

        # Unregister the fused tensors; only the per-expert modules stay registered.
        delattr(self, "gate_up_proj")
        delattr(self, "down_proj")

        # Register expert modules directly as numbered children (like nn.ModuleList)
        # so the naming pattern is: experts.{id}.gate_proj.weight (no extra nesting)
        for idx in range(self.num_experts):
            self.add_module(str(idx), expert_modules[idx])

    def __len__(self):
        """Support len() so the module is iterable like standard MoE experts."""
        return self.num_experts

    def __iter__(self):
        """Support iteration over expert modules."""
        for idx in range(self.num_experts):
            yield getattr(self, str(idx))

    def __getitem__(self, idx):
        """Return the expert module at position ``idx``, with ``nn.ModuleList``-style indexing.

        Args:
            idx: Expert position. A Python int or anything ``int()`` accepts (e.g. a 0-d
                integer tensor produced by ``nonzero()``). Negative values count from the end.

        Returns:
            The child module registered under ``str(position)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-num_experts, num_experts)``.
        """
        position = int(idx)
        num_experts = self.num_experts
        if not -num_experts <= position < num_experts:
            raise IndexError(
                f"expert index {position} is out of range for {num_experts} experts"
            )
        if position < 0:
            # Normalize so the lookup hits the numbered child name ("0", "1", ...).
            position += num_experts
        return getattr(self, str(position))

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Route each token through its selected experts and sum the weighted outputs.

        Args:
            hidden_states: Token activations, one row per token. Rows are gathered and
                scattered along dim 0.
            top_k_index: Selected expert ids, indexed as ``[token, k]``.
            top_k_weights: Routing weight for each selection, indexed as ``[token, k]``.

        Returns:
            Tensor shaped like ``hidden_states``. Each token row is
            ``sum_k top_k_weights[t, k] * down(act(gate(x_t)) * up(x_t))``.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # (num_experts, top_k, num_tokens) membership mask.
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only visit experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # Defensive guard. With ``num_classes=self.num_experts`` the one-hot mask never
            # yields this id, so the branch should be unreachable here.
            if expert_idx == self.num_experts:
                continue
            with torch.no_grad():
                top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            expert = self[expert_idx]
            gate = expert.gate_proj(current_state)
            up = expert.up_proj(current_state)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = expert.down_proj(current_hidden_states)
            current_hidden_states = (
                current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            )
            final_hidden_states.index_add_(
                0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
            )
        return final_hidden_states
class _QuantDbrxFFN(_QuantSparseMoe):
    """Sparse-MoE quantization wrapper for the DBRX feed-forward block.

    DBRX keeps its expert count and top-k on ``self.router`` as ``moe_num_experts``
    and ``moe_top_k``. These properties expose them as ``num_experts`` / ``top_k``,
    presumably the names the ``_QuantSparseMoe`` base reads (confirm against the base).
    """

    @property
    def num_experts(self):
        """Number of experts, read from ``router.moe_num_experts``."""
        router = self.router
        return router.moe_num_experts

    @property
    def top_k(self):
        """Experts selected per token, read from ``router.moe_top_k``."""
        router = self.router
        return router.moe_top_k

    @top_k.setter
    def top_k(self, new_top_k):
        """Write the new top-k through to ``router.moe_top_k``."""
        router = self.router
        router.moe_top_k = new_top_k
@contextmanager
def patch_compressed_linear_loading():
    """Context manager that patches CompressedLinear to survive custom ``_init_weights`` calls.

    When loading pack-quantized models with ``trust_remote_code=True``,
    ``compressed_tensors`` replaces ``.weight`` with ``.weight_packed`` on
    CompressedLinear modules. Custom model code (e.g. ``modeling_deepseek.py``)
    often does ``module.weight.data.normal_(...)`` inside ``_init_weights``,
    which crashes because ``.weight`` no longer exists.

    This context manager monkey-patches ``CompressedLinear.__getattr__`` to
    return a harmless dummy for ``.weight`` accesses. On exit it restores the class
    to exactly its previous state, even if an exception is raised. It does nothing
    if ``compressed_tensors`` is not importable, or if the patch is already active
    (nested use; the outermost invocation performs the restore).

    Usage::

        from modelopt.torch.quantization.plugins.huggingface import patch_compressed_linear_loading

        with patch_compressed_linear_loading():
            model = AutoModelForCausalLM.from_pretrained(
                ckpt_path,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype="auto",
            )
    """
    try:
        from compressed_tensors.linear.compressed_linear import CompressedLinear
    except ImportError:
        yield
        return

    if getattr(CompressedLinear, "_modelopt_init_patched", False):
        # An enclosing invocation already installed the patch and will undo it.
        yield
        return

    # ``__getattr__`` is normally inherited (e.g. from ``nn.Module``) rather than defined
    # on CompressedLinear itself. Record the class's *own* override separately from the
    # effective implementation. On exit the class dict can then be restored exactly,
    # instead of leaving a copied-down ``__getattr__`` behind.
    own_getattr = CompressedLinear.__dict__.get("__getattr__")
    original_getattr = getattr(CompressedLinear, "__getattr__", None)

    class _DummyWeightData:
        # Any method call on ``.data`` (e.g. ``normal_``, ``zero_``) is a no-op returning self.
        def __getattr__(self, name):
            return lambda *args, **kwargs: self

    class _DummyWeight:
        # Stand-in for a missing ``.weight``: ``.data`` and any method call are no-ops.
        def __init__(self):
            self.data = _DummyWeightData()

        def __getattr__(self, name):
            return lambda *args, **kwargs: self

    def patched_getattr(self, name):
        if name == "weight":
            # Prefer a real weight when the module still has one.
            if "_parameters" in self.__dict__ and "weight" in self._parameters:
                return self._parameters["weight"]
            if "weight" in self.__dict__:
                return self.__dict__["weight"]
            return _DummyWeight()
        if original_getattr is not None:
            return original_getattr(self, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    CompressedLinear.__getattr__ = patched_getattr
    CompressedLinear._modelopt_init_patched = True
    logger.info("Patched CompressedLinear for transformers compatibility")
    try:
        yield
    finally:
        if own_getattr is not None:
            CompressedLinear.__getattr__ = own_getattr
        else:
            # Remove the patch so attribute lookup falls back to the inherited implementation.
            del CompressedLinear.__getattr__
        CompressedLinear._modelopt_init_patched = False
        logger.info("Restored CompressedLinear original state")
class _QuantCompressedLinear(QuantModule):
"""Quantization wrapper for ``compressed_tensors`` CompressedLinear modules.
Handles on-the-fly decompression of pack-quantized INT4 weights during
calibration. This avoids fully decompressing all experts into GPU memory
at once (which would OOM for large MoE models), and also correctly handles
the ``weight_shape`` metadata that ``compressed_tensors`` stores as a
tensor rather than a plain list.
"""
def _setup(self):
    """Attach fresh input and weight quantizers, registered in that order."""
    for quantizer_attr in ("input_quantizer", "weight_quantizer"):
        setattr(self, quantizer_attr, TensorQuantizer())
def _build_compressed_data(self):