|
22 | 22 | from typing import TYPE_CHECKING |
23 | 23 |
|
24 | 24 | import torch |
| 25 | +import transformers |
| 26 | +from packaging import version |
25 | 27 | from torch import Tensor |
26 | 28 | from torch.nn.functional import linear |
27 | 29 |
|
|
38 | 40 | kitchen = None |
39 | 41 |
|
40 | 42 | import torch.nn as nn |
41 | | -import transformers |
42 | 43 | from transformers.models.t5.modeling_t5 import T5Attention |
43 | 44 |
|
44 | 45 | from modelopt.torch.opt.dynamic import DynamicModule |
|
48 | 49 | from ..conversion import register |
49 | 50 | from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer |
50 | 51 | from ..nn.modules.quant_linear import _QuantLinear |
| 52 | +from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE |
| 53 | + |
| 54 | +if IS_TRITON_AVAILABLE: |
| 55 | + from ..triton import weight_dequant |
| 56 | +else: |
| 57 | + weight_dequant = None |
| 58 | + |
51 | 59 | from ..utils import replace_function |
52 | 60 | from .attention import register_attention_for_kv_quant |
53 | 61 | from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin |
|
57 | 65 |
|
58 | 66 | __all__ = ["register_hf_attentions_on_the_fly"] |
59 | 67 |
|
| 68 | +TRANSFORMERS_VERSION_GE_5_0 = version.parse(transformers.__version__) >= version.parse("5.0.0") |
| 69 | + |
60 | 70 |
|
61 | 71 | class _QuantAttention(QuantModule): |
62 | 72 | """Attention class for KV Cache quantization compatible with new_attention_interface in transformers >= 4.48.0.""" |
@@ -447,10 +457,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
447 | 457 | # If any of the experts are in calibration mode, we will forward all tokens to all experts |
448 | 458 |         # This is used only for calibration; we then re-compute the actual outputs |
449 | 459 |         # using the original top_k |
450 | | - original_top_k = self.top_k |
451 | | - self.top_k = self.num_experts |
452 | | - super().forward(hidden_states) |
453 | | - self.top_k = original_top_k |
| 460 | + if TRANSFORMERS_VERSION_GE_5_0: |
| 461 | +            assert hasattr(self, "gate"), "expected a `gate` submodule in transformers >= 5.0" |
| 462 | + # Path for transformers >= 5.0 |
| 463 | + original_top_k = self.gate.topk |
| 464 | + self.gate.topk = self.gate.num_experts |
| 465 | + super().forward(hidden_states) |
| 466 | + self.gate.topk = original_top_k |
| 467 | + else: |
| 468 | + # Path for transformers < 5.0 |
| 469 | + original_top_k = self.top_k |
| 470 | + if hasattr(self, "num_experts"): |
| 471 | + self.top_k = self.num_experts |
| 472 | + elif hasattr(self, "experts"): |
| 473 | + self.top_k = self.experts.num_experts |
| 474 | + else: |
| 475 | + raise ValueError(f"Could not find num_experts in module {self}") |
| 476 | + super().forward(hidden_states) |
| 477 | + self.top_k = original_top_k |
454 | 478 | return super().forward(hidden_states) |
455 | 479 |
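The pattern above, in isolation: temporarily widen top_k so every expert's quantizers observe calibration data, then rerun with the real routing (a sketch; `moe` stands for any MoE block exposing these attributes):

    # Calibration pass: route every token to every expert so all expert
    # quantizers collect amax statistics; the output is intentionally discarded.
    original_top_k = moe.top_k
    moe.top_k = moe.num_experts
    moe(hidden_states)
    moe.top_k = original_top_k
    # Real pass: recompute the outputs with the original top-k routing.
    output = moe(hidden_states)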
|
456 | 480 |
|
@@ -693,6 +717,53 @@ def unpack_weight(self): |
693 | 717 | del self.weight_scale |
694 | 718 |
|
695 | 719 |
|
| 720 | +class _QuantFP8Linear(QuantModule): |
| 721 | + def _setup(self): |
| 722 | + self.input_quantizer = TensorQuantizer() |
| 723 | + self.weight_quantizer = TensorQuantizer() |
| 724 | + assert self.weight_scale_inv.ndim == 2, "Weight scale inverse must be 2D" |
| 725 | + assert self.weight.ndim == 2, "Weight must be 2D" |
| 726 | + self.block_size = max( |
| 727 | + self.weight.shape[0] // self.weight_scale_inv.shape[0], |
| 728 | + self.weight.shape[1] // self.weight_scale_inv.shape[1], |
| 729 | + ) |
| 730 | + assert self.block_size == 128, "Block size must be 128" |
| 731 | + |
| 732 | + def _get_weight_and_scale_inv(self): |
| 733 | + if isinstance(self.weight, torch.distributed.tensor.DTensor): |
| 734 | + weight = self.weight._local_tensor.contiguous() |
| 735 | + scale_inv = self.weight_scale_inv._local_tensor.contiguous() |
| 736 | + else: |
| 737 | + weight = self.weight.contiguous() |
| 738 | + scale_inv = self.weight_scale_inv.contiguous() |
| 739 | + return weight, scale_inv |
| 740 | + |
| 741 | + def forward(self, input: Tensor) -> Tensor: |
| 742 | +        if self.weight.element_size() == 1: |
| 743 | +            assert weight_dequant is not None, "Triton is not available" |
| 744 | + with torch.cuda.device(self.weight.device): |
| 745 | + weight, scale_inv = self._get_weight_and_scale_inv() |
| 746 | + weight = weight_dequant(weight, scale_inv, self.block_size, dtype=input.dtype) |
| 747 | + else: |
| 748 | + weight = self.weight |
| 749 | + return linear( |
| 750 | + self.input_quantizer(input), |
| 751 | + self.weight_quantizer(weight), |
| 752 | + self.bias, |
| 753 | + ) |
| 754 | + |
| 755 | + def unpack_weight(self): |
| 756 | + assert weight_dequant is not None, "Triton is not available" |
| 757 | + with torch.cuda.device(self.weight.device): |
| 758 | + weight, scale_inv = self._get_weight_and_scale_inv() |
| 759 | + self.weight = nn.Parameter( |
| 760 | + weight_dequant(weight, scale_inv, self.block_size, dtype=torch.get_default_dtype()), |
| 761 | + requires_grad=False, |
| 762 | + ) |
| 763 | + if hasattr(self, "weight_scale_inv"): |
| 764 | + del self.weight_scale_inv |
| 765 | + |
| 766 | + |
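For reference, a minimal pure-PyTorch sketch of the block-wise dequantization that the Triton `weight_dequant` kernel is assumed to perform here (one scale per 128x128 weight tile, DeepSeek-style fine-grained FP8; `weight_dequant_ref` is a hypothetical name for illustration, not the actual kernel):

    def weight_dequant_ref(
        weight: torch.Tensor,     # FP8 weight, shape (m, n)
        scale_inv: torch.Tensor,  # per-block scales, shape (ceil(m/bs), ceil(n/bs))
        block_size: int = 128,
        dtype: torch.dtype = torch.bfloat16,
    ) -> torch.Tensor:
        m, n = weight.shape
        # Broadcast each per-block scale over its block_size x block_size tile,
        # cropping edge tiles to the weight shape.
        scales = scale_inv.repeat_interleave(block_size, dim=0)[:m]
        scales = scales.repeat_interleave(block_size, dim=1)[:, :n]
        # Upcast the FP8 weight and rescale element-wise.
        return weight.to(dtype) * scales.to(dtype)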
696 | 767 | try: |
697 | 768 | from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe |
698 | 769 |
|
@@ -796,6 +867,14 @@ def unpack_weight(self): |
796 | 867 | except ImportError: |
797 | 868 | pass |
798 | 869 |
|
| 870 | +try: |
| 871 | + from transformers.integrations.finegrained_fp8 import FP8Linear |
| 872 | + |
| 873 | + if FP8Linear not in QuantModuleRegistry: |
| 874 | + QuantModuleRegistry.register({FP8Linear: "hf.FP8Linear"})(_QuantFP8Linear) |
| 875 | +except ImportError: |
| 876 | + pass |
| 877 | + |
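With `FP8Linear` registered, fine-grained-FP8 checkpoints should flow through the standard quantization entry point unchanged; a hedged usage sketch (the checkpoint name and `my_calib_loop` are placeholders):

    import modelopt.torch.quantization as mtq
    from transformers import AutoModelForCausalLM

    # A checkpoint whose linear layers load as `FP8Linear` (placeholder name).
    model = AutoModelForCausalLM.from_pretrained("org/model-fp8")
    # During quantize(), registered `FP8Linear` modules convert to `_QuantFP8Linear`.
    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=my_calib_loop)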
799 | 878 |
|
800 | 879 | class _QuantGptOssExperts(_QuantFunctionalMixin): |
801 | 880 | """Quantized wrapper for `transformers.GptOssExperts`. |
@@ -910,6 +989,17 @@ def register_falcon_linears_on_the_fly(model): |
910 | 989 | QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear) |
911 | 990 |
|
912 | 991 |
|
| 992 | +def register_minimax_m2_moe_on_the_fly(model): |
| 993 | + """Register MiniMax M2 MoE modules as a QUANT_MODULE. |
| 994 | +
|
| 995 | +    MiniMax M2 MoE modules are defined in the model's remote code, not in transformers, so we need to register them on the fly. |
| 996 | + """ |
| 997 | + if type(model).__name__ in ["MiniMaxM2ForCausalLM"]: |
| 998 | + moe_type = type(model.model.layers[0].block_sparse_moe) |
| 999 | + if QuantModuleRegistry.get(moe_type) is None: |
| 1000 | + QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe) |
| 1001 | + |
| 1002 | + |
913 | 1003 | def _is_supported_hf_model(model): |
914 | 1004 |     """Check if the model is a valid model for transformers-specific quantization support.""" |
915 | 1005 | supported_models = [transformers.PreTrainedModel] |
@@ -975,6 +1065,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): |
975 | 1065 | [ |
976 | 1066 | register_falcon_linears_on_the_fly, |
977 | 1067 | register_dbrx_moe_on_the_fly, |
| 1068 | + register_minimax_m2_moe_on_the_fly, |
978 | 1069 | register_hf_attentions_on_the_fly, |
979 | 1070 | convert_hf_parallel_linears_on_the_fly, |
980 | 1071 | ] |
|