
Commit 2dfa873

Browse files
daniserebdanielkorzekwa
authored andcommitted
Add support for MXFP8 PTQ (#736)
## What does this PR do?

**Type of change:** new feature

**Overview:** Add support for MXFP8 PTQ, enabling MXFP8 hardware acceleration during inference on Blackwell GPUs.

## Usage

```bash
export MODEL_PATH=/my_home/hf_models/nvidia/OpenMath2-Llama3.1-8B
export OUTPUT_PATH=/my_home/hf_models/nvidia/OpenMath2-Llama3.1-8B-MXFP8
mkdir -p $OUTPUT_PATH

python examples/llm_ptq/hf_ptq.py \
    --export_fmt hf \
    --dataset cnn_dailymail \
    --pyt_ckpt_path $MODEL_PATH \
    --export_path $OUTPUT_PATH \
    --qformat mxfp8
```

The `hf_quant_config.json` of the output checkpoint:

```json
{
    "producer": {
        "name": "modelopt",
        "version": "0.41.0.dev50+g7a796a875"
    },
    "quantization": {
        "quant_algo": "MXFP8",
        "kv_cache_quant_algo": "FP8",
        "group_size": 32,
        "exclude_modules": [
            "lm_head"
        ]
    }
}
```

And `config.json` (only the `quantization_config`):

```json
...
"quantization_config": {
    "ignore": [
        "lm_head"
    ],
    "quant_algo": "MXFP8",
    "kv_cache_scheme": {
        "dynamic": false,
        "num_bits": 8,
        "type": "float"
    },
    "producer": {
        "name": "modelopt",
        "version": "0.41.0.dev50+g7a796a875"
    },
    "quant_method": "modelopt"
}
```

## Testing

- Used `hf_ptq.py` to quantize the model `nvidia/OpenMath2-Llama3.1-8B` ([available on Hugging Face](https://huggingface.co/nvidia/OpenMath2-Llama3.1-8B)); see the example command above.
- Checked that the generated MXFP8 checkpoint can be loaded with vLLM (requires changes in vLLM that are not yet merged to main).
- Added tests for `MXFP8QTensor` in `tests/gpu/torch/quantization/test_qtensor_cuda.py`.
- Added "mxfp8" in `tests/examples/llm_ptq/test_llm_ptq.py`.

#### Support for Nemotron Models

Verified that Nemotron Nano V3 BF16 ([nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)) can be converted to MXFP8 using `hf_ptq.py`.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Summary by CodeRabbit

- **New Features**
  - Added MXFP8 quantization format support with new scaling mechanisms and quantization utilities.
  - Updated configuration options, example scripts, and utilities to recognize and process MXFP8 quantization workflows.
  - Extended quantization export pipelines to handle MXFP8 quantized models.
- **Tests**
  - Expanded test coverage for MXFP8 quantization across various tensor shapes, data types, and device configurations.

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
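As background for reviewers: MXFP8 stores each weight as an FP8 E4M3 element plus one shared power-of-two (E8M0) scale per 32-element block, which is why `group_size` is 32 in the exported config. Below is a minimal pure-Python sketch of that block scaling; the FP8 cast itself is omitted, and the constants are assumed per the OCP MX convention rather than taken from `MXFP8QTensor`:

```python
import math

BLOCK = 32          # MXFP8 block size, matching "group_size": 32 in the exported config
E4M3_MAX = 448.0    # largest normal FP8 E4M3 value
E4M3_MAX_EXP = 8    # floor(log2(448))

def block_scale(block):
    """Shared power-of-two scale mapping the block's absolute max into the E4M3 range."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** (math.floor(math.log2(amax)) - E4M3_MAX_EXP)

def mxfp8_scale_roundtrip(values):
    """Divide each 32-element block by its scale (where the FP8 cast would
    happen), clamp to the E4M3 range, and multiply the scale back in."""
    out = []
    for i in range(0, len(values), BLOCK):
        block = values[i : i + BLOCK]
        s = block_scale(block)
        out.extend(min(max(v / s, -E4M3_MAX), E4M3_MAX) * s for v in block)
    return out
```

Because the scale is a power of two, dividing and re-multiplying is exact in binary floating point; all real precision loss in MXFP8 comes from the FP8 cast omitted here.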
1 parent d1dac55 commit 2dfa873

File tree

10 files changed: +717 −18 lines


examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -85,6 +85,7 @@
     "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
     "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
     "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
+    "mxfp8": mtq.MXFP8_DEFAULT_CFG,
 }

 KV_QUANT_CFG_CHOICES = {
@@ -248,6 +249,7 @@ def auto_quantize(
             "fp8_pb_wo",
             "w4a8_mxfp4_fp8",
             "nvfp4_mlp_only",
+            "mxfp8",
         ]
         for args.qformat in qformat_list
     ), "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -862,6 +864,7 @@ def quantize_main(
             "fp8_pb_wo",
             "w4a8_mxfp4_fp8",
             "nvfp4_mlp_only",
+            "mxfp8",
         ]
         or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
     ), f"Plain quantization format {args.qformat} not supported for HF export path"
```

examples/llm_ptq/scripts/huggingface_example.sh

Lines changed: 2 additions & 2 deletions

```diff
@@ -53,9 +53,9 @@ esac
 IFS=","
 for qformat in $QFORMAT; do
     case $qformat in
-        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_svdquant) ;;
+        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_svdquant | mxfp8) ;;
         *)
-            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_svdquant]" >&2
+            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_svdquant, mxfp8]" >&2
             exit 1
             ;;
     esac
```
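The extended `case` pattern can be exercised standalone; this is a trimmed sketch with a shortened format list, not the full script:

```shell
QFORMAT="nvfp4,mxfp8"
IFS=","
for qformat in $QFORMAT; do
    case $qformat in
        fp8 | nvfp4 | nvfp4_awq | mxfp8) echo "accepted: $qformat" ;;
        *) echo "Unknown quant argument: $qformat" >&2; exit 1 ;;
    esac
done
```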

modelopt/torch/export/model_config.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -36,6 +36,7 @@
 QUANTIZATION_NVFP4_SVDQUANT = "nvfp4_svdquant"
 QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"
 QUANTIZATION_MXFP4 = "mxfp4"
+QUANTIZATION_MXFP8 = "mxfp8"
 QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
 QUANTIZATION_NVFP4_AWQ = "nvfp4_awq"
 QUANTIZATION_FP8_PB_REAL = "fp8_pb_real"
```

modelopt/torch/export/quant_utils.py

Lines changed: 21 additions & 0 deletions

```diff
@@ -34,6 +34,7 @@
 from modelopt.torch.quantization.qtensor import (
     FP8QTensor,
     MXFP4QTensor,
+    MXFP8QTensor,
     NVFP4QTensor,
     QTensorWrapper,
 )
@@ -58,6 +59,7 @@
     QUANTIZATION_INT8_SQ,
     QUANTIZATION_INT8_WO,
     QUANTIZATION_MXFP4,
+    QUANTIZATION_MXFP8,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
@@ -326,6 +328,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
             1
         ].reshape(*weight.shape[:-1], -1)
+
+    if quantization_format == QUANTIZATION_MXFP8:
+        return MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, weight_quantizer)
     return get_scaling_factor(weight_quantizer)
@@ -524,6 +529,14 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
     if weight_quantizer.num_bits == (4, 3):
         if weight_quantizer.block_sizes:
             assert weight_quantizer.block_sizes[-1] > 0, "Invalid block_sizes for FP8 quantizer"
+            # Check if this is MXFP8 (dynamic block quantization with scale_bits (8, 0))
+            block_sizes = getattr(weight_quantizer, "block_sizes")
+            if (
+                isinstance(block_sizes, dict)
+                and block_sizes.get("type", "static") == "dynamic"
+                and block_sizes.get("scale_bits") == (8, 0)
+            ):
+                return QUANTIZATION_MXFP8
             if weight_quantizer.fake_quant:
                 return QUANTIZATION_FP8_PB_WO
             else:
@@ -724,6 +737,11 @@ def process_layer_quant_config(layer_config_dict):
             layer_config = {
                 "quant_algo": "NVFP4_SVD",
                 "group_size": block_size_value,
             }
+        elif v == "mxfp8":
+            layer_config = {
+                "quant_algo": "MXFP8",
+                "group_size": block_size_value,
+            }
         else:
             layer_config = {"quant_algo": v}
@@ -828,6 +846,9 @@ def to_quantized_weight(
     if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

+    if quantization == QUANTIZATION_MXFP8:
+        return MXFP8QTensor.quantize_with_scale(weight, weights_scaling_factor)
+
     if quantization == QUANTIZATION_FP8_PB_WO:
         return FP8QTensor.quantize(
             weight, weights_scaling_factor.squeeze(), block_sizes={-1: block_size, -2: block_size}
```
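The MXFP8 detection added in `_get_quantization_from_layer` keys entirely off the quantizer's `block_sizes` metadata. Restated as a standalone predicate over plain dicts and tuples (a sketch, not the module's API):

```python
def is_mxfp8(num_bits, block_sizes) -> bool:
    """True when a weight quantizer config denotes MXFP8:
    FP8 E4M3 elements (num_bits == (4, 3)) combined with dynamic
    block quantization using E8M0 scales (scale_bits == (8, 0))."""
    return (
        num_bits == (4, 3)
        and isinstance(block_sizes, dict)
        and block_sizes.get("type", "static") == "dynamic"
        and block_sizes.get("scale_bits") == (8, 0)
    )
```

Plain FP8 quantizers also report `num_bits == (4, 3)`, so the `block_sizes` checks are what keep the two formats apart.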

modelopt/torch/export/unified_export_hf.py

Lines changed: 11 additions & 1 deletion

```diff
@@ -51,7 +51,7 @@

 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
-from modelopt.torch.quantization.qtensor import NVFP4QTensor
+from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor
 from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names

 from .convert_hf_config import convert_hf_quant_config_format
@@ -67,6 +67,7 @@
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
     QUANTIZATION_FP8_PC_PT,
+    QUANTIZATION_MXFP8,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
@@ -426,6 +427,15 @@ def _export_quantized_weight(
             weight_quantizer._scale.to(torch.float32),
         )
         del weight_quantizer._scale
+    elif quantization_format == QUANTIZATION_MXFP8:
+        # MXFP8 uses dynamic block quantization with E8M0 scales (uint8)
+        weight = getattr(sub_module, weight_name)
+        e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(
+            weight, weight_quantizer
+        )
+        sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale)
+        if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
+            del weight_quantizer._scale
     else:
         sub_module.register_buffer(
             quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name)
```
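The comment in the new export branch notes that MXFP8 scales are E8M0 values stored as uint8. A pure-Python sketch of that encoding, assuming the OCP MX convention (bias 127, E4M3 max exponent 8); this is not the `MXFP8QTensor` kernel:

```python
import math

E8M0_BIAS = 127     # E8M0 stores exponent e as the byte value e + 127 (255 encodes NaN)
E4M3_MAX_EXP = 8    # exponent of the FP8 E4M3 maximum normal value (448 = 1.75 * 2**8)

def e8m0_scale_byte(block):
    """E8M0-encoded shared scale for one block of weights.

    Picks the power-of-two scale that maps the block's absolute maximum
    into the FP8 E4M3 range, then biases the exponent into uint8 range.
    """
    amax = max(abs(v) for v in block)
    exp = 0 if amax == 0.0 else math.floor(math.log2(amax)) - E4M3_MAX_EXP
    exp = max(-E8M0_BIAS, min(E8M0_BIAS, exp))  # clamp to representable exponents
    return exp + E8M0_BIAS
```

Storing only a biased exponent per block is what lets the scale tensor be a plain uint8 buffer rather than a float tensor.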

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 27 additions & 14 deletions

```diff
@@ -49,6 +49,7 @@
     INT4QTensor,
     INT8QTensor,
     MXFP4QTensor,
+    MXFP8QTensor,
     NF4QTensor,
     NVFP4QTensor,
     QTensorWrapper,
@@ -649,8 +650,32 @@ def _real_quantize(self, inputs):
         assert self._is_real_quantize_support(), "Real quantization not supported for this format."

         buffer_to_register = {}
-        if self._num_bits == (4, 3):
-            # FP8 quantization
+        # Check MX formats first (before FP8) since MXFP8 also has num_bits=(4,3)
+        if (
+            self._block_sizes
+            and self._block_sizes.get("scale_bits") == (8, 0)
+            and self._block_sizes.get("type") == "dynamic"
+        ):
+            # MX quantization (MXFP4/MXFP8)
+            if self._num_bits == (2, 1):
+                # MXFP4
+                outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
+                buffer_to_register["_scale"] = scales
+            elif self._num_bits == (4, 3):
+                # MXFP8
+                assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, (
+                    f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, "
+                    f"got {self._block_sizes[-1]}"
+                )
+                outputs, scales = MXFP8QTensor.quantize(inputs)
+                buffer_to_register["_scale"] = scales
+            else:
+                raise ValueError(
+                    f"Unsupported MX format: num_bits={self._num_bits}. "
+                    f"Expected (2, 1) for MXFP4 or (4, 3) for MXFP8."
+                )
+        elif self._num_bits == (4, 3):
+            # FP8 quantization (non-MX)
             # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks
             # For blockwise quantization, amax will be recomputed in the kernel
             use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1)
@@ -683,18 +708,6 @@ def _real_quantize(self, inputs):
             buffer_to_register["_scale"] = _scale
             buffer_to_register["_double_scale"] = _double_scale
             buffer_to_register["_scale_zeros"] = _scale_zeros
-        elif (
-            self._block_sizes.get("scale_bits") == (8, 0)
-            and self._block_sizes.get("type") == "dynamic"
-        ):
-            # MX quantization
-            if self._num_bits == (2, 1):
-                outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
-                buffer_to_register["_scale"] = scales
-            else:
-                raise ValueError(
-                    f"Real quantization for MX {self._num_bits} format is not supported."
-                )
         elif self._block_sizes.get("scale_bits") == (4, 3):
             # NVFP4 default quantization
             # Return real quantized tensor and store scales inside TensorQuantizer
```
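The reordering above matters because MXFP8 and plain FP8 share `num_bits == (4, 3)`; only the `block_sizes` metadata distinguishes them, so the MX branch must be checked first. A sketch of the resulting dispatch over plain tuples and dicts (illustrative, not the quantizer object's API):

```python
def dispatch_real_quant(num_bits, block_sizes):
    """Mirror the branch order in _real_quantize: MX formats are
    matched first so MXFP8 is not swallowed by the FP8 branch."""
    if (
        block_sizes
        and block_sizes.get("scale_bits") == (8, 0)
        and block_sizes.get("type") == "dynamic"
    ):
        if num_bits == (2, 1):
            return "mxfp4"
        if num_bits == (4, 3):
            return "mxfp8"
        raise ValueError(f"Unsupported MX format: num_bits={num_bits}")
    if num_bits == (4, 3):
        return "fp8"
    return "other"
```

With the old order (FP8 first), the third call below would have returned "fp8" for an MXFP8 quantizer.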

modelopt/torch/quantization/qtensor/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -20,5 +20,6 @@
 from .int4_tensor import *
 from .int8_tensor import *
 from .mxfp4_tensor import *
+from .mxfp8_tensor import *
 from .nf4_tensor import *
 from .nvfp4_tensor import *
```
