Skip to content

Commit 4fcb798

Browse files
authored
Merge branch 'main' into nvfp4-block-size-validation
2 parents 19d26f8 + 6a3b6b8 commit 4fcb798

6 files changed

Lines changed: 247 additions & 29 deletions

File tree

examples/llm_ptq/scripts/huggingface_example.sh

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,7 @@ dense | sparsegpt) ;;
4949
;;
5050
esac
5151

52-
#Iterate over list of qformats provided and check if they are valid
53-
IFS=","
54-
for qformat in $QFORMAT; do
55-
case $qformat in
56-
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
57-
*)
58-
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
59-
exit 1
60-
;;
61-
esac
62-
done
63-
IFS=" "
52+
# Quant format / recipe validation is delegated to hf_ptq.py.
6453

6554
script_dir="$(dirname "$(readlink -f "$0")")"
6655

@@ -72,7 +61,14 @@ fi
7261

7362
QFORMAT_MODIFIED="${QFORMAT//,/_}"
7463

75-
MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
64+
# When using --recipe, build the model name from the recipe basename (without
65+
# directory or .yaml suffix) so each recipe gets its own SAVE_PATH.
66+
if [ -n "$RECIPE" ]; then
67+
RECIPE_TAG=$(basename "$RECIPE" .yaml | sed 's/[^0-9a-zA-Z\-]/_/g')
68+
MODEL_NAME=$(basename "$MODEL_PATH" | sed 's/[^0-9a-zA-Z\-]/_/g')_recipe_${RECIPE_TAG}
69+
else
70+
MODEL_NAME=$(basename "$MODEL_PATH" | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
71+
fi
7672

7773
SAVE_PATH=${ROOT_SAVE_PATH}/saved_models_${MODEL_NAME}
7874

@@ -164,24 +160,18 @@ fi
164160

165161
if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH) ]]; then
166162

167-
if [ "$qformat" == "bf16" ] || [ "$qformat" == "fp16" ]; then
168-
if [ -d "$MODEL_PATH" ]; then
169-
MODEL_CONFIG_EXIST=true
170-
MODEL_CONFIG=$MODEL_PATH/config.json
171-
for file in $MODEL_PATH/*; do ln -sf "$file" $SAVE_PATH/; done
172-
else
173-
echo "Please use the model directory where the config.json file is present."
174-
exit 1
175-
fi
176-
fi
177-
178163
if [[ "$MODEL_CONFIG_EXIST" == false ]]; then
179164
echo "Quantizing original model..."
165+
if [ -n "$RECIPE" ]; then
166+
QUANT_SPEC_ARGS="--recipe=$RECIPE"
167+
else
168+
QUANT_SPEC_ARGS="--qformat=${QFORMAT// /,}"
169+
fi
180170
python hf_ptq.py \
181171
--pyt_ckpt_path=$MODEL_PATH \
182172
--export_path=$SAVE_PATH \
183173
--sparsity_fmt=$SPARSITY_FMT \
184-
--qformat="${QFORMAT// /,}" \
174+
$QUANT_SPEC_ARGS \
185175
--calib_size=$CALIB_SIZE \
186176
--batch_size=$CALIB_BATCH_SIZE \
187177
--inference_tensor_parallel=$TP \
@@ -203,7 +193,7 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
203193
exit 0
204194
fi
205195

206-
if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
196+
if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]] || [[ "$RECIPE" == *"nvfp4"* ]]; then
207197
cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)
208198

209199
if [ "$cuda_major" -lt 10 ]; then
@@ -212,6 +202,11 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
212202
fi
213203
fi
214204

205+
if [ -n "$RECIPE" ]; then
206+
echo "Recipe $RECIPE used. Please deploy with TensorRT-LLM directly. Checkpoint export_path: $SAVE_PATH"
207+
exit 0
208+
fi
209+
215210
if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then
216211
echo "Quant $QFORMAT specified. Please read TensorRT-LLM quantization support matrix https://nvidia.github.io/TensorRT-LLM/features/quantization.html#quantization-in-tensorrt-llm and use TensorRT-LLM for deployment. Checkpoint export_path: $SAVE_PATH"
217212
exit 0

examples/llm_ptq/scripts/parser.sh

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ parse_options() {
2020
# Default values
2121
MODEL_PATH=""
2222
QFORMAT=""
23+
RECIPE=""
2324
KV_CACHE_QUANT=""
2425
TP=1
2526
PP=1
@@ -37,13 +38,14 @@ parse_options() {
3738
CAST_MXFP4_TO_NVFP4=false
3839

3940
# Parse command-line options
40-
ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@")
41+
ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@")
4142

4243
eval set -- "$ARGS"
4344
while true; do
4445
case "$1" in
4546
--model ) MODEL_PATH="$2"; shift 2;;
4647
--quant ) QFORMAT="$2"; shift 2;;
48+
--recipe ) RECIPE="$2"; shift 2;;
4749
--kv_cache_quant ) KV_CACHE_QUANT="$2"; shift 2;;
4850
--tp ) TP="$2"; shift 2;;
4951
--pp ) PP="$2"; shift 2;;
@@ -99,12 +101,19 @@ parse_options() {
99101
fi
100102

101103
# Verify required options are provided
102-
if [ -z "$MODEL_PATH" ] || [ -z "$QFORMAT" ] || [ -z "$TASKS" ]; then
103-
echo "Usage: $0 --model=<MODEL_PATH> --quant=<QFORMAT> --tasks=<TASK,...>"
104+
if [ -z "$MODEL_PATH" ] || [ -z "$TASKS" ] || ([ -z "$QFORMAT" ] && [ -z "$RECIPE" ]); then
105+
echo "Usage: $0 --model=<MODEL_PATH> (--quant=<QFORMAT> | --recipe=<RECIPE>) --tasks=<TASK,...>"
104106
echo "Optional args: --sparsity=<SPARSITY_FMT> --awq_block_size=<AWQ_BLOCK_SIZE> --calib=<CALIB_SIZE>"
105107
exit 1
106108
fi
107109

110+
# --quant and --recipe are mutually exclusive: --recipe is a full PTQ spec, while
111+
# --quant selects a built-in qformat preset. Pick exactly one.
112+
if [ -n "$QFORMAT" ] && [ -n "$RECIPE" ]; then
113+
echo "Cannot specify both --quant and --recipe; pick one." >&2
114+
exit 1
115+
fi
116+
108117
VALID_TASKS=("quant" "mmlu" "lm_eval" "livecodebench" "simple_eval")
109118

110119
for task in $(echo "$TASKS" | tr ',' ' '); do
@@ -135,6 +144,7 @@ parse_options() {
135144
echo "================="
136145
echo "model: $MODEL_PATH"
137146
echo "quant: $QFORMAT"
147+
echo "recipe: $RECIPE"
138148
echo "tp (TensorRT-LLM Checkpoint only): $TP"
139149
echo "pp (TensorRT-LLM Checkpoint only): $PP"
140150
echo "sparsity: $SPARSITY_FMT"

modelopt/torch/export/moe_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,29 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
6262
for idx in range(n):
6363
expert = nn.Module()
6464

65+
# If the gate_up source quantizer was never calibrated (rare expert
66+
# that received no calibration tokens), derive its amax once from the
67+
# FUSED tensor so gate and up share the same weight_scale_2 below.
68+
# Why: vLLM fuses W1 (gate) and W3 (up) at load time and asserts a
69+
# single per-tensor scale across the fusion. The per-projection
70+
# fallback further down would otherwise compute amax independently from
71+
# each half — gate's max and up's max generally differ — producing
72+
# mismatched weight_scale_2 and garbled MoE output at inference.
73+
gate_up_q = module.gate_up_proj_weight_quantizers[idx]
74+
if getattr(gate_up_q, "is_enabled", False) and (
75+
not hasattr(gate_up_q, "_amax")
76+
or gate_up_q._amax is None
77+
or torch.all(gate_up_q._amax == 0)
78+
):
79+
gate_up_q.amax = gate_up[idx].abs().amax().to(torch.float32)
80+
warnings.warn(
81+
f"Expert {idx} gate_up_proj weight quantizer was not calibrated "
82+
f"(amax missing or zero). Using fused-tensor amax as fallback "
83+
f"(shared by gate and up so weight_scale_2 stays consistent). "
84+
f"Consider increasing calibration size to activate all experts.",
85+
stacklevel=2,
86+
)
87+
6588
projections = [
6689
("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True),
6790
("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True),
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
imports:
17+
base_disable_all: configs/ptq/units/base_disable_all
18+
default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
19+
nvfp4: configs/numerics/nvfp4
20+
nvfp4_static: configs/numerics/nvfp4_static
21+
kv_fp8_cast: configs/ptq/units/kv_fp8_cast
22+
23+
metadata:
24+
recipe_type: ptq
25+
description: NVFP4 static weight (MSE FP8-scale sweep) and dynamic activation for expert layers only (W4A4), FP8 KV cache with constant amax.
26+
quantize:
27+
algorithm:
28+
method: mse
29+
fp8_scale_sweep: true
30+
# layerwise=false required for VLMs where the decoder layers are nested under
31+
# `model.language_model.layers` (layerwise_calibrate can't find them otherwise).
32+
layerwise: false
33+
quant_cfg:
34+
- $import: base_disable_all
35+
- quantizer_name: '*mlp.experts*weight_quantizer'
36+
cfg:
37+
$import: nvfp4_static
38+
- quantizer_name: '*mlp.experts*input_quantizer'
39+
cfg:
40+
$import: nvfp4
41+
- quantizer_name: '*block_sparse_moe*weight_quantizer'
42+
cfg:
43+
$import: nvfp4_static
44+
- quantizer_name: '*block_sparse_moe*input_quantizer'
45+
cfg:
46+
$import: nvfp4
47+
- $import: kv_fp8_cast
48+
- $import: default_disabled_quantizers
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
imports:
17+
base_disable_all: configs/ptq/units/base_disable_all
18+
default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
19+
nvfp4: configs/numerics/nvfp4
20+
nvfp4_static: configs/numerics/nvfp4_static
21+
kv_fp8_cast: configs/ptq/units/kv_fp8_cast
22+
23+
metadata:
24+
recipe_type: ptq
25+
description: NVFP4 static weight (MSE FP8-scale sweep) and dynamic activation for MLP/MoE linear layers (W4A4), FP8 KV cache with constant amax.
26+
quantize:
27+
algorithm:
28+
method: mse
29+
fp8_scale_sweep: true
30+
# layerwise=false required for VLMs where the decoder layers are nested under
31+
# `model.language_model.layers` (layerwise_calibrate can't find them otherwise).
32+
layerwise: false
33+
quant_cfg:
34+
- $import: base_disable_all
35+
- quantizer_name: '*mlp*weight_quantizer'
36+
cfg:
37+
$import: nvfp4_static
38+
- quantizer_name: '*mlp*input_quantizer'
39+
cfg:
40+
$import: nvfp4
41+
- quantizer_name: '*block_sparse_moe*weight_quantizer'
42+
cfg:
43+
$import: nvfp4_static
44+
- quantizer_name: '*block_sparse_moe*input_quantizer'
45+
cfg:
46+
$import: nvfp4
47+
- quantizer_name: '*.experts.*weight_quantizer'
48+
cfg:
49+
$import: nvfp4_static
50+
- quantizer_name: '*.experts.*input_quantizer'
51+
cfg:
52+
$import: nvfp4
53+
- $import: kv_fp8_cast
54+
- $import: default_disabled_quantizers

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,94 @@ def test_export_creates_per_expert_submodules(self):
300300
if QuantModuleRegistry.get(expert_type) is not None:
301301
QuantModuleRegistry.unregister(expert_type)
302302

303+
def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch):
304+
"""gate_proj and up_proj must share weight_scale_2 even when an expert
305+
was never routed during calibration.
306+
307+
Regression for the bug where ``_export_fused_experts``'s per-projection
308+
fallback computed amax independently from the gate and up halves of the
309+
fused tensor — producing mismatched ``weight_scale_2`` values for any
310+
uncalibrated expert. vLLM fuses W1 (gate) and W3 (up) at load time and
311+
asserts a single shared scale; mismatched scales corrupted MoE output.
312+
The fix derives the fallback amax once from the fused ``gate_up[idx]``
313+
tensor before the deepcopies, so gate's clone and up's clone start with
314+
the same amax.
315+
"""
316+
from modelopt.torch.export.moe_utils import _export_fused_experts
317+
318+
# Build experts where gate and up have very different magnitudes —
319+
# any per-half fallback would clearly produce different amaxes.
320+
experts = _SyntheticFusedExperts()
321+
gate = torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02
322+
up = torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.20
323+
with torch.no_grad():
324+
experts.gate_up_proj.copy_(torch.cat([gate, up], dim=1))
325+
326+
expert_type = type(experts)
327+
if QuantModuleRegistry.get(expert_type) is None:
328+
QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})(
329+
_QuantFusedExperts
330+
)
331+
try:
332+
converted = QuantModuleRegistry.convert(experts)
333+
334+
# Leave every expert weight quantizer uncalibrated (no _amax).
335+
# Mark them enabled to exercise the export-time fallback path.
336+
for q in converted.gate_up_proj_weight_quantizers:
337+
q._disabled = False
338+
for q in converted.down_proj_weight_quantizers:
339+
q._disabled = False
340+
341+
# Capture the amax each per-projection wrapper carries into the
342+
# FP4 quantization step. Patching here avoids needing CUDA / FP4.
343+
seen = {} # (expert_idx, proj_name) -> amax tensor
344+
345+
def _spy_export(wrapper, dtype):
346+
# Identify which expert/projection this wrapper belongs to by
347+
# matching the weight tensor against the fused parameters.
348+
w = wrapper.weight.data
349+
# gate_up_proj is (N, 2*INTER, HIDDEN); split halves are
350+
# contiguous .data views or .contiguous() copies — we can match
351+
# by shape and value identity for this synthetic case.
352+
amax = wrapper.weight_quantizer._amax.detach().clone()
353+
# Identify by matching against gate vs. up slices of each expert.
354+
for idx in range(NUM_EXPERTS):
355+
g_slice = converted.gate_up_proj.data[idx, :INTERMEDIATE_DIM, :]
356+
u_slice = converted.gate_up_proj.data[idx, INTERMEDIATE_DIM:, :]
357+
d_slice = converted.down_proj.data[idx]
358+
if w.shape == g_slice.shape and torch.equal(w, g_slice):
359+
seen[(idx, "gate_proj")] = amax
360+
return
361+
if w.shape == u_slice.shape and torch.equal(w, u_slice):
362+
seen[(idx, "up_proj")] = amax
363+
return
364+
if w.shape == d_slice.shape and torch.equal(w, d_slice):
365+
seen[(idx, "down_proj")] = amax
366+
return
367+
368+
monkeypatch.setattr(
369+
"modelopt.torch.export.unified_export_hf._export_quantized_weight",
370+
_spy_export,
371+
)
372+
373+
_export_fused_experts(converted, torch.float16)
374+
375+
# Assert: for every expert, gate's amax matches up's amax.
376+
for idx in range(NUM_EXPERTS):
377+
g_amax = seen.get((idx, "gate_proj"))
378+
u_amax = seen.get((idx, "up_proj"))
379+
assert g_amax is not None and u_amax is not None, (
380+
f"Expert {idx}: missing recorded amax (gate={g_amax}, up={u_amax})"
381+
)
382+
assert torch.allclose(g_amax, u_amax), (
383+
f"Expert {idx}: gate amax {g_amax.item()} != up amax {u_amax.item()}. "
384+
f"Uncalibrated fused experts must share gate/up amax so that "
385+
f"weight_scale_2 stays consistent across the fusion."
386+
)
387+
finally:
388+
if QuantModuleRegistry.get(expert_type) is not None:
389+
QuantModuleRegistry.unregister(expert_type)
390+
303391

304392
# ---------------------------------------------------------------------------
305393
# Tests for force_eager_experts_impl_on_the_fly

0 commit comments

Comments
 (0)