Commit b44c60a

Authored by Shiyang Chen
SVDQuant Hugging Face checkpoint export support (#754)
## What does this PR do?

**Type of change:** new feature

**Overview:** Adds Hugging Face checkpoint export support for SVDQuant (NVFP4).

## Usage

```bash
cd ./examples/llm_ptq/
python hf_ptq.py \
    --pyt_ckpt_path Qwen/Qwen3-4B \
    --export_path /home/scratch.shiychen_coreai/quantized_models/Qwen3-4B-svdq \
    --qformat nvfp4_svdquant --kv_cache_qformat none --sparsity_fmt dense --calib_size 8
```

## Testing

Exported a checkpoint and loaded it back.

## Summary by CodeRabbit

**Release Notes**

* **New Features**
  * Added `nvfp4_svdquant` as a new quantization format option for LLM model quantization workflows.
* **Limitations**
  * Multi-GPU export configurations using tensor or pipeline parallelism are not supported with `nvfp4_svdquant` quantization.

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent 945ee02 commit b44c60a

8 files changed

Lines changed: 160 additions & 62 deletions

examples/llm_ptq/hf_ptq.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -83,6 +83,7 @@
     "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
     "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
     "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
+    "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
 }
 
 KV_QUANT_CFG_CHOICES = {
@@ -506,6 +507,10 @@ def export_quantized(
         or args.sparsity_fmt != "dense"
         or "int8_sq" in args.qformat
     ):
+        if (
+            args.inference_tensor_parallel != 1 or args.inference_pipeline_parallel != 1
+        ) and args.qformat == "nvfp4_svdquant":
+            raise NotImplementedError("Svdquant does not support multiple GPUs yet.")
         warnings.warn(
             "Still exporting TensorRT-LLM checkpoints for models not supported by the TensorRT-LLM torch runtime."
         )
```
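For orientation, the new `nvfp4_svdquant` entry feeds the standard ModelOpt PTQ call inside `hf_ptq.py`. A minimal sketch of that flow (the model name and calibration stub are illustrative, not part of this diff):

```python
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B").cuda()

def forward_loop(model):
    # Feed a few calibration batches through the model here;
    # hf_ptq.py builds this loop from --calib_size samples.
    ...

# --qformat nvfp4_svdquant resolves to this config via QUANT_CFG_CHOICES.
model = mtq.quantize(model, mtq.NVFP4_SVDQUANT_DEFAULT_CFG, forward_loop)
```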

examples/llm_ptq/scripts/huggingface_example.sh

Lines changed: 2 additions & 2 deletions
```diff
@@ -53,9 +53,9 @@ esac
 IFS=","
 for qformat in $QFORMAT; do
     case $qformat in
-    fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only) ;;
+    fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_svdquant) ;;
     *)
-        echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only]" >&2
+        echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_svdquant]" >&2
         exit 1
         ;;
     esac
```
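Assuming the script's existing `--model`/`--quant` interface (not shown in this diff), the new format is exercised like any other entry in the case list:

```bash
# Hypothetical invocation; flag names follow the script's existing usage.
cd examples/llm_ptq
./scripts/huggingface_example.sh --model Qwen/Qwen3-4B --quant nvfp4_svdquant
```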

modelopt/torch/export/model_config.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -33,6 +33,7 @@
 QUANTIZATION_INT4_AWQ = "int4_awq"
 QUANTIZATION_W4A8_AWQ = "w4a8_awq"
 QUANTIZATION_NVFP4 = "nvfp4"
+QUANTIZATION_NVFP4_SVDQUANT = "nvfp4_svdquant"
 QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"
 QUANTIZATION_MXFP4 = "mxfp4"
 QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
@@ -507,12 +508,20 @@ def hidden_size(self):
         """Returns the hidden size of the transformer model."""
         if isinstance(self.mlp, MOEConfig):
             # fc.weight for MOE is stacked
-            if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
+            if self.mlp.fc.quantization in [
+                QUANTIZATION_NVFP4,
+                QUANTIZATION_NVFP4_AWQ,
+                QUANTIZATION_NVFP4_SVDQUANT,
+            ]:
                 return self.mlp.fc.weight.shape[-1] * 2
             return self.mlp.fc.weight.shape[-1]
         else:
             k = self.mlp.fc.weight.shape[1]
-            if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
+            if self.mlp.fc.quantization in [
+                QUANTIZATION_NVFP4,
+                QUANTIZATION_NVFP4_AWQ,
+                QUANTIZATION_NVFP4_SVDQUANT,
+            ]:
                 return k * 2
             return k
```
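The `* 2` in both branches compensates for 4-bit packing: NVFP4-family checkpoints store two 4-bit values per uint8 element, so the stored weight's last dimension is half the logical hidden size. A tiny illustration (shapes are made up):

```python
import torch

# Hypothetical packed NVFP4 fc.weight: two 4-bit values per uint8 element.
packed_weight = torch.empty(4096, 1024, dtype=torch.uint8)

# Recover the logical input dimension that hidden_size() reports.
hidden_size = packed_weight.shape[-1] * 2  # 2048
```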

modelopt/torch/export/postprocess.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -35,6 +35,7 @@
     LINEAR_ROW,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
+    QUANTIZATION_NVFP4_SVDQUANT,
     ConvConfig,
     EmbeddingConfig,
     ExpertConfig,
@@ -398,7 +399,10 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):
                 group_size=config.awq_block_size,
                 quantization=config.quantization,
             )
-            if config.quantization == QUANTIZATION_NVFP4_AWQ:
+            if config.quantization in [
+                QUANTIZATION_NVFP4_AWQ,
+                QUANTIZATION_NVFP4_SVDQUANT,
+            ]:
                 # We have to update weight_scaling_factor and weight_scaling_factor_2
                 config.weights_scaling_factor, config.weights_scaling_factor_2 = (
                     NVFP4QTensor.get_weights_scaling_factor(
@@ -430,6 +434,7 @@
             if config.quantization in [
                 QUANTIZATION_NVFP4,
                 QUANTIZATION_NVFP4_AWQ,
+                QUANTIZATION_NVFP4_SVDQUANT,
             ]:
                 (
                     config.weights_scaling_factor,
```

modelopt/torch/export/quant_utils.py

Lines changed: 67 additions & 10 deletions
```diff
@@ -25,7 +25,11 @@
 import torch.nn as nn
 
 from modelopt import __version__
-from modelopt.torch.quantization.model_calib import enable_stats_collection, finish_stats_collection
+from modelopt.torch.quantization.model_calib import (
+    enable_stats_collection,
+    finish_stats_collection,
+    svd,
+)
 from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
 from modelopt.torch.quantization.qtensor import (
     FP8QTensor,
@@ -57,6 +61,7 @@
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
+    QUANTIZATION_NVFP4_SVDQUANT,
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_MXFP4_FP8,
     QUANTIZATION_W4A8_NVFP4_FP8,
@@ -165,7 +170,7 @@ def resmooth_and_get_scale(
         )
         new_weights.append(weight)
         # If NVFP4_AWQ then we view the scales as uint8 to allow for cat later
-        if quantization == QUANTIZATION_NVFP4_AWQ:
+        if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]:
             scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, group_size).view(torch.uint8)
         else:
             scale = get_scaling_factor_from_weight(weight, group_size)
@@ -176,7 +181,7 @@
     return (
         torch.cat(new_weights, dim=0),
         resmoothed_scales.view(torch.float8_e4m3fn)
-        if quantization == QUANTIZATION_NVFP4_AWQ
+        if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]
         else resmoothed_scales,  # if NVFP4_AWQ we view the scales back as float8_e4m3fn after cat
         new_pre_quant_scale,
     )
@@ -243,6 +248,7 @@ def get_activation_scaling_factor(
     if get_quantization_format(module) in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
     ]:
         return NVFP4QTensor.get_activation_scaling_factor(input_quantizer)
     return get_scaling_factor(input_quantizer)
@@ -270,6 +276,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
     if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
@@ -303,6 +310,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if get_quantization_format(module) in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
@@ -487,6 +495,8 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
     block_sizes = getattr(weight_quantizer, "block_sizes")
     scale_bits = block_sizes.get("scale_bits")
 
+    if input_quantizer is not None and hasattr(weight_quantizer, "svdquant_lora_a"):
+        return QUANTIZATION_NVFP4_SVDQUANT
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
     if getattr(layer, "fused_with_prequant", False):
@@ -660,15 +670,18 @@ def process_layer_quant_config(layer_config_dict):
         elif v == "w4a8_nvfp4_fp8":
             layer_config = {
                 "quant_algo": "W4A8_NVFP4_FP8",
-                "group_size": layer_config_dict[prefix + ".awq_block_size"],
-                "has_zero_point": False,
-                "pre_quant_scale": True,
+                "group_size": block_size_value,
             }
         elif v == "w4a8_mxfp4_fp8":
             layer_config = {
                 "quant_algo": "W4A8_MXFP4_FP8",
                 "group_size": block_size_value,
             }
+        elif v == "nvfp4_svdquant":
+            layer_config = {
+                "quant_algo": "NVFP4_SVD",
+                "group_size": block_size_value,
+            }
         else:
             layer_config = {"quant_algo": v}
 
@@ -813,7 +826,12 @@ def to_quantized_weight(
     if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
         return pack_int4_in_uint8(weight, weights_scaling_factor)
 
-    if quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_W4A8_NVFP4_FP8]:
+    if quantization in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_W4A8_NVFP4_FP8,
+        QUANTIZATION_NVFP4_SVDQUANT,
+    ]:
         assert block_size is not None, "Block size not passed. Unable to quantize to NVFP4 format."
         assert weights_scaling_factor2 is not None, (
             "Weights scaling factor 2 not passed. Unable to quantize to NVFP4 format"
@@ -1014,6 +1032,40 @@ def _update_pre_quant_scale(module, new_pre_quant_scale):
     finish_stats_collection(module.weight_quantizer)
 
 
+def _update_svdquant(modules, new_pre_quant_scale):
+    """Updates the pre_quant_scale, svdquant_lora_a and svdquant_lora_b matrices when pre_quant_scale is changed."""
+    new_pre_quant_scale = new_pre_quant_scale.to(torch.float32)
+    lora_a = [m.weight_quantizer.svdquant_lora_a.to(torch.float32) for m in modules]
+    lora_b = [m.weight_quantizer.svdquant_lora_b.to(torch.float32) for m in modules]
+    weight = [m.weight.to(torch.float32) for m in modules]
+    old_pre_quant_scale = [m.input_quantizer._pre_quant_scale.to(torch.float32) for m in modules]
+    weight = [
+        (w + (lb @ la)) * (s / new_pre_quant_scale)
+        for w, la, lb, s in zip(weight, lora_a, lora_b, old_pre_quant_scale)
+    ]
+    weight_concatenated = torch.cat(weight, dim=0)
+    lb, la = svd(weight_concatenated, rank=lora_a[0].shape[0])
+    weight_concatenated -= lb @ la
+    weight_concatenated = weight_concatenated.to(modules[0].weight.dtype)
+    la = la.to(modules[0].weight_quantizer.svdquant_lora_a.dtype)
+    lb = lb.to(modules[0].weight_quantizer.svdquant_lora_b.dtype)
+    new_pre_quant_scale = new_pre_quant_scale.to(modules[0].input_quantizer.pre_quant_scale.dtype)
+
+    index = 0
+    for i, module in enumerate(modules):
+        module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+        module.weight_quantizer.svdquant_lora_a = la
+        assert lora_b[i].shape[0] == module.weight.shape[0]
+        module.weight_quantizer.svdquant_lora_b = lb[index : index + lora_b[i].shape[0], :]
+        module.weight = nn.Parameter(weight_concatenated[index : index + lora_b[i].shape[0], :])
+        index += lora_b[i].shape[0]
+        # Redo weights collection
+        module.weight_quantizer.reset_amax()
+        enable_stats_collection(module.weight_quantizer)
+        module.weight_quantizer(module.weight)
+        finish_stats_collection(module.weight_quantizer)
+
+
 # Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
 PQS_FUSE_MODULE_MAPPING = [
     # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
@@ -1166,9 +1218,14 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
             dim=0,
         )
 
-    for module in modules:
-        if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
-            _update_pre_quant_scale(module, avg_prequant_scale)
+    if all(
+        getattr(m.weight_quantizer, "svdquant_lora_a", None) is not None for m in modules
+    ):
+        _update_svdquant(modules, avg_prequant_scale)
+    else:
+        for module in modules:
+            if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
+                _update_pre_quant_scale(module, avg_prequant_scale)
 
     if resmooth_only:
         return
```
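A small numeric check of the invariant `_update_svdquant` relies on (a sketch with toy shapes): re-smoothing multiplies the full-rank weight by `old_scale / new_scale` per input channel, then re-splits it into a quantizable residual plus a rank-r correction, so the effective per-channel-scaled weight is unchanged.

```python
import torch

torch.manual_seed(0)
out_dim, in_dim, rank = 8, 16, 4
full_weight = torch.randn(out_dim, in_dim, dtype=torch.float64)
old_scale = torch.rand(in_dim, dtype=torch.float64) + 0.5
new_scale = torch.rand(in_dim, dtype=torch.float64) + 0.5

# Re-smooth the full-rank weight (residual + low-rank part) to the new scale.
resmoothed = full_weight * (old_scale / new_scale)

# Re-split into residual + rank-r correction via truncated SVD, as svd() does.
u, s, vt = torch.linalg.svd(resmoothed, full_matrices=False)
lora_b, lora_a = u[:, :rank] * s[:rank], vt[:rank]
residual = resmoothed - lora_b @ lora_a

# The effective transform is preserved across the scale change.
assert torch.allclose((residual + lora_b @ lora_a) * new_scale, full_weight * old_scale)
```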

modelopt/torch/export/unified_export_hf.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -70,6 +70,7 @@
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
+    QUANTIZATION_NVFP4_SVDQUANT,
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
@@ -258,6 +259,10 @@ def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):
     model_type = type(model).__name__.lower()
     module_names = set()
 
+    # NVFP4 SVDQuant does not need pre-quant scale fusion (either into the previous linear or layernorm) because
+    # 1) its kernel handles the pre-quant scale.
+    # 2) fusing into the previous linear would require changing the lora_up in up_proj, which may cause issues
+    #    in the later gate-up fusion.
     # Fuse pre_quant_scale to the linear weights if possible
     if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
         fuse_prequant_to_linear(model)
@@ -268,7 +273,8 @@
 
         # For MoE models update pre_quant_scale to average pre_quant_scale amongst experts
         if is_moe(module) and (
-            quantization_format is not QUANTIZATION_NONE and "awq" in quantization_format
+            quantization_format is not QUANTIZATION_NONE
+            and ("awq" in quantization_format or quantization_format == QUANTIZATION_NVFP4_SVDQUANT)
         ):
             # update_experts_avg_prequant_scale(module)
             grouped_experts = get_experts_list(module, model_type)
@@ -439,6 +445,7 @@ def _export_quantized_weight(
 
     if quantization_format in [
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_NVFP4,
         QUANTIZATION_W4A8_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
@@ -459,7 +466,11 @@
         for expert_type in ["Llama4TextExperts", "GptOssExperts"]
     )
 
-    if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
+    if quantization_format in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
+    ]:
         # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
         # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
         weight, _ = maybe_transpose_expert_weight_dimensions(
```
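The comment about skipping pre-quant-scale fusion refers to SVDQuant's split of each linear into a 4-bit residual plus a 16-bit low-rank branch. Schematically (a sketch of the math, not the deployed kernel):

```python
import torch

def svdquant_linear(x, residual, lora_a, lora_b, pre_quant_scale):
    # The kernel applies the pre-quant scale itself, so no fusion into the
    # preceding layernorm/linear is required for this format.
    x_scaled = x * pre_quant_scale
    # 4-bit path: `residual` stands in for the dequantized NVFP4 residual ...
    low_precision = x_scaled @ residual.T
    # ... plus the 16-bit rank-r correction, since W ~= residual + lora_b @ lora_a.
    low_rank = (x_scaled @ lora_a.T) @ lora_b.T
    return low_precision + low_rank
```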

modelopt/torch/quantization/model_calib.py

Lines changed: 27 additions & 19 deletions
```diff
@@ -1075,6 +1075,30 @@ def _get_awq_quantizer_block_size(tensor: torch.Tensor, quantizer: TensorQuantizer
     return blocksize
 
 
+def svd(weight, rank):
+    original_device = weight.device
+    original_dtype = weight.dtype
+    weight_f64 = weight.to(dtype=torch.float64, device=original_device)
+    u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
+    us = u[:, :rank] * s[:rank]
+    vt = vt[:rank]
+    us = us.to(device=original_device, dtype=original_dtype)
+    vt = vt.to(device=original_device, dtype=original_dtype)
+    if us.shape[1] < rank or vt.shape[0] < rank:
+        warnings.warn(
+            "The low-rank dimensions do not match the layer dimensions. "
+            "Please verify your configuration and model settings. "
+            f"Rank is {us.shape[1]} and {vt.shape[0]}"
+        )
+        us_temp = torch.zeros((us.shape[0], rank), dtype=us.dtype, device=us.device)
+        vt_temp = torch.zeros((rank, vt.shape[1]), dtype=vt.dtype, device=vt.device)
+        us_temp[:, : us.shape[1]] = us
+        vt_temp[: vt.shape[0], :] = vt
+        us = us_temp
+        vt = vt_temp
+    return us, vt
+
+
 @torch.no_grad()
 def svdquant(
     model: nn.Module,
@@ -1096,25 +1120,9 @@ def svdquant(
     def postprocess(module, name):
         print_rank_0(f"SVD {name}")
         weight = module.weight.data
-        original_device = weight.device
-        original_dtype = weight.dtype
-        weight_f64 = weight.to(dtype=torch.float64, device=original_device)
-        u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
-        if u.shape[1] < lowrank or vt.shape[0] < lowrank:
-            warnings.warn(
-                "The low-rank dimensions do not match the layer dimensions. "
-                "Please verify your configuration and model settings. "
-                f"SVD will be skipped for this layer {name}."
-            )
-            return
-        us = u[:, :lowrank] * s[:lowrank]
-        vt = vt[:lowrank]
-        module.weight_quantizer.svdquant_lora_a = vt.to(
-            dtype=original_dtype, device=original_device
-        )
-        module.weight_quantizer.svdquant_lora_b = us.to(
-            dtype=original_dtype, device=original_device
-        )
+        us, vt = svd(weight, lowrank)
+        module.weight_quantizer.svdquant_lora_a = vt
+        module.weight_quantizer.svdquant_lora_b = us
         module.weight.data.sub_(
             module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
         )
```
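The extracted `svd` helper returns the rank-`r` factors of the truncated SVD, which `svdquant` then subtracts from the weight. A quick check on a toy tensor (shapes are illustrative):

```python
import torch

from modelopt.torch.quantization.model_calib import svd  # exposed by this commit

weight = torch.randn(64, 128)
us, vt = svd(weight, rank=32)  # us: [64, 32], vt: [32, 128]
residual = weight - us @ vt    # what svdquant() leaves behind for 4-bit quantization

# us @ vt is the best rank-32 approximation (Eckart-Young), so the residual
# carries strictly less energy than the original weight.
assert torch.linalg.matrix_norm(residual) < torch.linalg.matrix_norm(weight)
```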
