
Commit 3801923

Support MiniMax M2.1 (FP8 checkpoint) (#817)
## What does this PR do?

**Type of change:** New feature

**Overview:** Support loading the MiniMax M2.1 (FP8) checkpoint for PTQ.

## Usage

```sh
scripts/huggingface_example.sh --model <minimax checkpoint> --quant nvfp4 --trust_remote_code
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Summary by CodeRabbit

* **New Features**
  * Added MiniMax M2.1 model quantization support with the nvfp4 format.
  * Extended FP8 quantization with a configurable dtype parameter for finer precision control.
* **Improvements**
  * Enhanced detection of quantized linear module variants.
  * Improved weight unpacking for FP8-based linear modules.
* **Documentation**
  * Updated the supported-models table to include MiniMax M2.1.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
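Below is a minimal Python sketch of the PTQ flow that `huggingface_example.sh` drives, assuming ModelOpt's standard `mtq.quantize` API with the stock `NVFP4_DEFAULT_CFG`; the checkpoint id and the one-prompt calibration loop are placeholders, not part of this PR:

```python
# Sketch only: the checkpoint id and calibration data are placeholders.
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "MiniMaxAI/MiniMax-M2.1"  # hypothetical checkpoint id
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype="auto",
    trust_remote_code=True,  # MiniMax M2 modules are defined in the model card
)
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)

def forward_loop(m):
    # Real runs calibrate on a proper dataset; one prompt keeps the sketch short.
    batch = tokenizer("Hello, world!", return_tensors="pt").to(m.device)
    m(**batch)

model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
```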
1 parent 590f9fc commit 3801923

9 files changed: 128 additions & 32 deletions


examples/deepseek/ptq.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -56,6 +56,7 @@
 from modelopt.torch.export.model_config import KV_CACHE_FP8
 from modelopt.torch.export.quant_utils import get_quant_config
 from modelopt.torch.quantization.nn import TensorQuantizer
+from modelopt.torch.quantization.triton import weight_dequant
 from modelopt.torch.quantization.utils import (
     is_quantized_column_parallel_linear,
     is_quantized_parallel_linear,
@@ -77,7 +78,6 @@
 )
 
 import model as deekseep_model  # noqa: E402
-from ds_kernel import weight_dequant  # noqa: E402
 from kernel import act_quant, fp8_gemm  # noqa: E402
```

examples/llm_ptq/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -111,6 +111,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 | DeepSeek V3, R1, V3.1, V3.2<sup>7</sup> | - | - | - | - | ✅ |
 | GLM-4.7<sup>8</sup> | ✅ | - | - | - | ✅ |
 | Kimi K2 | - | - | - | - | ✅ |
+| MiniMax M2.1 | - | - | - | - | ✅ |
 | T5 | ✅ | ✅ | ✅ | ✅ | - |
 | Whisper | ✅ | ✅ | ✅ | ✅ | - |
```

examples/llm_ptq/example_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -243,7 +243,7 @@ def build_quant_cfg(
         quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
 
-    if model_type in ["qwen3moe", "qwen3next"] and qformat == "nvfp4":
+    if model_type in ["qwen3moe", "qwen3next", "minimax"] and qformat == "nvfp4":
         # Disable the attention projection layers to retain accuracy
         quant_cfg["quant_cfg"]["model*.*attn*in_proj*"] = {"enable": False}
         quant_cfg["quant_cfg"]["model*.*attn*q_proj*"] = {"enable": False}
```

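For context, this is roughly how such wildcard entries layer onto a base config; the `deepcopy` pattern below is illustrative, while `NVFP4_DEFAULT_CFG` and the glob-style quantizer patterns follow the diff above:

```python
# Illustrative: keep MiniMax/Qwen3-MoE attention projections unquantized on
# top of a stock NVFP4 config, mirroring build_quant_cfg above.
import copy

import modelopt.torch.quantization as mtq

quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
quant_cfg["quant_cfg"]["model*.*attn*in_proj*"] = {"enable": False}
quant_cfg["quant_cfg"]["model*.*attn*q_proj*"] = {"enable": False}
```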
modelopt/torch/export/layer_utils.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -346,7 +346,14 @@ def is_moe(module: nn.Module) -> bool:
 def is_quantlinear(module: nn.Module) -> bool:
     """Returns whether the module is a quantized linear layer."""
     name = type(module).__name__
-    return ("QuantLinear" in name or "QuantCompressedLinear" in name) and "lora" not in name.lower()
+    return (
+        any(
+            keyword in name
+            for keyword in ["QuantLinear", "QuantCompressedLinear", "QuantFP8Linear"]
+        )
+        and "lora" not in name.lower()
+        and "ds_kernel" not in name.lower()
+    )
 
 
 def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor:
```

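A toy predicate mirroring the widened check (the class names below are illustrative, not real ModelOpt modules):

```python
def _matches(name: str) -> bool:
    # Same logic as is_quantlinear above, applied to a bare class name.
    return (
        any(k in name for k in ["QuantLinear", "QuantCompressedLinear", "QuantFP8Linear"])
        and "lora" not in name.lower()
        and "ds_kernel" not in name.lower()
    )

assert _matches("_QuantFP8Linear")            # new FP8 variant is now detected
assert not _matches("LoraQuantLinear")        # LoRA wrappers stay excluded
assert not _matches("ds_kernel_QuantLinear")  # DeepSeek example kernel classes are skipped
```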
modelopt/torch/export/model_utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -55,6 +55,7 @@
     "Deepseek": "deepseek",
     "Whisper": "whisper",
     "gptoss": "gptoss",
+    "MiniMax": "minimax",
 }
 
 __doc__ = f"""Utility functions for model type detection and classification.
```

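A hedged sketch of how a prefix map like this is typically consulted; the helper below is illustrative and assumes matching against the model's class name, which is not shown in this diff:

```python
MODEL_NAME_TO_TYPE = {
    "Deepseek": "deepseek",
    "Whisper": "whisper",
    "gptoss": "gptoss",
    "MiniMax": "minimax",
}

def guess_model_type(model) -> str | None:
    # Case-insensitive substring match on the model's class name (assumed).
    cls_name = type(model).__name__.lower()
    for key, model_type in MODEL_NAME_TO_TYPE.items():
        if key.lower() in cls_name:
            return model_type
    return None

# e.g. a "MiniMaxM2ForCausalLM" instance would map to "minimax".
```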
modelopt/torch/export/unified_export_hf.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -589,7 +589,9 @@ def _process_quantized_modules(
         if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
             continue
 
-        if hasattr(sub_module, "weight_packed"):
+        if hasattr(sub_module, "weight_packed") or (
+            "QuantFP8Linear" in type(sub_module).__name__ and sub_module.weight.element_size() <= 1
+        ):
             sub_module.unpack_weight()
         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
             if is_quantlinear(sub_module):
```

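Why `element_size() <= 1` identifies still-packed FP8 weights: FP8 tensors store one byte per element, while dequantized bf16/fp16 weights store two. A quick check using standard PyTorch behavior:

```python
import torch

w_fp8 = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)
w_bf16 = torch.zeros(4, 4, dtype=torch.bfloat16)
assert w_fp8.element_size() == 1   # still packed -> unpack_weight() runs
assert w_bf16.element_size() == 2  # already dequantized -> skipped
```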
modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 96 additions & 5 deletions
```diff
@@ -22,6 +22,8 @@
 from typing import TYPE_CHECKING
 
 import torch
+import transformers
+from packaging import version
 from torch import Tensor
 from torch.nn.functional import linear
 
@@ -38,7 +40,6 @@
     kitchen = None
 
 import torch.nn as nn
-import transformers
 from transformers.models.t5.modeling_t5 import T5Attention
 
 from modelopt.torch.opt.dynamic import DynamicModule
@@ -48,6 +49,13 @@
 from ..conversion import register
 from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import _QuantLinear
+from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE
+
+if IS_TRITON_AVAILABLE:
+    from ..triton import weight_dequant
+else:
+    weight_dequant = None
+
 from ..utils import replace_function
 from .attention import register_attention_for_kv_quant
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin
@@ -57,6 +65,8 @@
 
 __all__ = ["register_hf_attentions_on_the_fly"]
 
+TRANSFORMERS_VERSION_GE_5_0 = version.parse(transformers.__version__) >= version.parse("5.0.0")
+
 
 class _QuantAttention(QuantModule):
     """Attention class for KV Cache quantization compatible with new_attention_interface in transformers >= 4.48.0."""
@@ -447,10 +457,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             # If any of the experts are in calibration mode, we will forward all tokens to all experts
             # This is used only for calibration, we need to re-calculate the actual outputs again using
             # the original top_k
-            original_top_k = self.top_k
-            self.top_k = self.num_experts
-            super().forward(hidden_states)
-            self.top_k = original_top_k
+            if TRANSFORMERS_VERSION_GE_5_0:
+                assert hasattr(self, "gate")
+                # Path for transformers >= 5.0
+                original_top_k = self.gate.topk
+                self.gate.topk = self.gate.num_experts
+                super().forward(hidden_states)
+                self.gate.topk = original_top_k
+            else:
+                # Path for transformers < 5.0
+                original_top_k = self.top_k
+                if hasattr(self, "num_experts"):
+                    self.top_k = self.num_experts
+                elif hasattr(self, "experts"):
+                    self.top_k = self.experts.num_experts
+                else:
+                    raise ValueError(f"Could not find num_experts in module {self}")
+                super().forward(hidden_states)
+                self.top_k = original_top_k
         return super().forward(hidden_states)
@@ -693,6 +717,53 @@ def unpack_weight(self):
         del self.weight_scale
 
 
+class _QuantFP8Linear(QuantModule):
+    def _setup(self):
+        self.input_quantizer = TensorQuantizer()
+        self.weight_quantizer = TensorQuantizer()
+        assert self.weight_scale_inv.ndim == 2, "Weight scale inverse must be 2D"
+        assert self.weight.ndim == 2, "Weight must be 2D"
+        self.block_size = max(
+            self.weight.shape[0] // self.weight_scale_inv.shape[0],
+            self.weight.shape[1] // self.weight_scale_inv.shape[1],
+        )
+        assert self.block_size == 128, "Block size must be 128"
+
+    def _get_weight_and_scale_inv(self):
+        if isinstance(self.weight, torch.distributed.tensor.DTensor):
+            weight = self.weight._local_tensor.contiguous()
+            scale_inv = self.weight_scale_inv._local_tensor.contiguous()
+        else:
+            weight = self.weight.contiguous()
+            scale_inv = self.weight_scale_inv.contiguous()
+        return weight, scale_inv
+
+    def forward(self, input: Tensor) -> Tensor:
+        assert weight_dequant is not None, "Triton is not available"
+        if self.weight.element_size() == 1:
+            with torch.cuda.device(self.weight.device):
+                weight, scale_inv = self._get_weight_and_scale_inv()
+                weight = weight_dequant(weight, scale_inv, self.block_size, dtype=input.dtype)
+        else:
+            weight = self.weight
+        return linear(
+            self.input_quantizer(input),
+            self.weight_quantizer(weight),
+            self.bias,
+        )
+
+    def unpack_weight(self):
+        assert weight_dequant is not None, "Triton is not available"
+        with torch.cuda.device(self.weight.device):
+            weight, scale_inv = self._get_weight_and_scale_inv()
+            self.weight = nn.Parameter(
+                weight_dequant(weight, scale_inv, self.block_size, dtype=torch.get_default_dtype()),
+                requires_grad=False,
+            )
+        if hasattr(self, "weight_scale_inv"):
+            del self.weight_scale_inv
+
+
 try:
     from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
 
@@ -796,6 +867,14 @@ def unpack_weight(self):
 except ImportError:
     pass
 
+try:
+    from transformers.integrations.finegrained_fp8 import FP8Linear
+
+    if FP8Linear not in QuantModuleRegistry:
+        QuantModuleRegistry.register({FP8Linear: "hf.FP8Linear"})(_QuantFP8Linear)
+except ImportError:
+    pass
+
 
 class _QuantGptOssExperts(_QuantFunctionalMixin):
     """Quantized wrapper for `transformers.GptOssExperts`.
@@ -910,6 +989,17 @@ def register_falcon_linears_on_the_fly(model):
         QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)
 
 
+def register_minimax_m2_moe_on_the_fly(model):
+    """Register MiniMax M2 MoE modules as a QUANT_MODULE.
+
+    MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
+    """
+    if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
+        moe_type = type(model.model.layers[0].block_sparse_moe)
+        if QuantModuleRegistry.get(moe_type) is None:
+            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
+
+
 def _is_supported_hf_model(model):
     """Check if the model a valid model for transformers quantization specific support."""
     supported_models = [transformers.PreTrainedModel]
@@ -975,6 +1065,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
     [
         register_falcon_linears_on_the_fly,
         register_dbrx_moe_on_the_fly,
+        register_minimax_m2_moe_on_the_fly,
        register_hf_attentions_on_the_fly,
         convert_hf_parallel_linears_on_the_fly,
     ]
```

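For intuition, here is a pure-PyTorch reference of the block-wise dequantization that `_QuantFP8Linear` delegates to the Triton `weight_dequant` kernel (one inverse scale per `block x block` tile). This helper is a sketch for checking semantics, not part of the PR:

```python
import torch

def weight_dequant_ref(
    w: torch.Tensor, s_inv: torch.Tensor, block: int = 128, dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor:
    # Expand each per-block scale over its block x block tile, then rescale.
    M, N = w.shape
    s_full = (
        s_inv.repeat_interleave(block, dim=0)[:M]
        .repeat_interleave(block, dim=1)[:, :N]
    )
    return (w.to(torch.float32) * s_full.to(torch.float32)).to(dtype)
```

On CUDA, comparing this against the Triton kernel output for a random FP8 tensor is a cheap sanity check.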
modelopt/torch/quantization/triton/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@
 ):
     # fp4_kernel works on any CUDA GPU with triton
     from .fp4_kernel import *
+    from .fp8_kernel import *
 
     # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
     if torch.cuda.get_device_capability() >= (8, 9):
```

examples/deepseek/ds_kernel.py renamed to modelopt/torch/quantization/triton/fp8_kernel.py

Lines changed: 16 additions & 23 deletions
```diff
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -35,32 +35,18 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+"""FP8 Triton Kernel Implementations."""
 
 import torch
 import triton
 import triton.language as tl
 
-"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py"""
-
 
 @triton.jit
 def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
-    """
-    Dequantizes weights using the provided scaling factors and stores the result.
+    """Dequantizes weights using the provided scaling factors and stores the result.
+
+    Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
 
     Args:
         x_ptr (tl.pointer): Pointer to the quantized weights.
@@ -86,14 +72,21 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     tl.store(y_ptr + offs, y, mask=mask)
 
 
-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
-    """
-    Dequantizes the given weight tensor using the provided scale tensor.
+def weight_dequant(
+    x: torch.Tensor,
+    s: torch.Tensor,
+    block_size: int = 128,
+    dtype: torch.dtype = torch.get_default_dtype(),
+) -> torch.Tensor:
+    """Dequantizes the given weight tensor using the provided scale tensor.
+
+    Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
 
     Args:
         x (torch.Tensor): The quantized weight tensor of shape (M, N).
         s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
         block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+        dtype (torch.dtype, optional): The dtype of the output tensor. Defaults to torch.get_default_dtype().
 
     Returns:
         torch.Tensor: The dequantized weight tensor of the same shape as `x`.
@@ -104,7 +97,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
     assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
     M, N = x.size()
-    y = torch.empty_like(x, dtype=torch.get_default_dtype())
+    y = torch.empty_like(x, dtype=dtype)
     grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
     weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y
```

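A hedged usage sketch of the updated signature (requires a CUDA device with Triton; shapes chosen so the scale grid is exactly M/128 x N/128):

```python
import torch

from modelopt.torch.quantization.triton import weight_dequant

w = torch.randn(256, 512, device="cuda").to(torch.float8_e4m3fn)
s = torch.rand(2, 4, device="cuda")  # one inverse scale per 128x128 block
w_bf16 = weight_dequant(w, s, block_size=128, dtype=torch.bfloat16)
assert w_bf16.shape == w.shape and w_bf16.dtype == torch.bfloat16
```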