Skip to content

Commit c9b7f82

Browse files
committed
Add offload test for vllm fakequant export; CHANGELOG entry
Adds test_hf_vllm_export_offload covering the inplace_mem_efficient=True path of export_hf_vllm_fq_checkpoint on a CPU-offloaded tiny LLaMA. The test asserts the inplace path actually mutates offloaded layer weights (falsifying a silent fall-through to the copy path), that the reloaded HF model matches a deepcopy+fold_weight reference built inside enable_weight_access_and_writeback (materializes meta tensors before folding), and that the saved quantizer state preserves input amaxes. Also adds a CHANGELOG.rst bullet under 0.44 New Features describing the layerwise calibration feature and linking to the experts-only recipe. Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 5658381 commit c9b7f82

File tree

2 files changed

+121
-1
lines changed

2 files changed

+121
-1
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Changelog
1515
- Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.
1616
- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
1717
- [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
18+
- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). Each decoder layer is materialized once per calibration step instead of per-batch, enabling larger batch sizes during PTQ. Includes per-layer checkpoint save/resume so calibration can survive cluster time limits. See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage.
1819

1920
**Backward Breaking Changes**
2021

tests/gpu/torch/export/test_vllm_fakequant_hf_export.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,19 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import copy
1516
from copy import deepcopy
1617

1718
import pytest
1819
import torch
1920
from _test_utils.torch.transformers_models import create_tiny_llama_dir
20-
from transformers import AutoModelForCausalLM
21+
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
22+
from transformers import AutoConfig, AutoModelForCausalLM
2123

2224
import modelopt.torch.quantization as mtq
2325
from modelopt.torch.export import export_hf_vllm_fq_checkpoint
2426
from modelopt.torch.quantization.model_quant import fold_weight
27+
from modelopt.torch.quantization.utils import enable_weight_access_and_writeback
2528
from modelopt.torch.utils import safe_load
2629

2730

@@ -111,3 +114,119 @@ def forward_loop(model):
111114
"_amax" in k for k in quantizer_state_dict_before[name]
112115
):
113116
assert any("_amax" in k for k in state), f"input quantizer {name} should preserve _amax"
117+
118+
119+
def _make_cpu_offloaded_model(tmp_path, num_hidden_layers=3):
    """Create a tiny LLaMA model with ``model.layers.0`` offloaded to CPU via accelerate.

    Builds a tiny checkpoint with ``create_tiny_llama_dir``, instantiates the
    architecture on the meta device (``init_empty_weights``), then dispatches the
    real weights with ``load_checkpoint_and_dispatch`` so every module lives on
    GPU 0 except the first decoder layer, which stays on CPU.

    Args:
        tmp_path: Directory in which the tiny checkpoint is written.
        num_hidden_layers: Number of decoder layers for the tiny model.

    Returns:
        Tuple ``(model, config, tiny_llama_dir)`` — the dispatched model, its
        ``AutoConfig``, and the checkpoint directory path.
    """
    tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers)
    config = AutoConfig.from_pretrained(tiny_llama_dir)

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    # Map every non-layer module and each whole decoder layer ("model.layers.N")
    # to GPU 0; submodules inside a layer inherit the device of their parent
    # entry in accelerate's device_map. (Fixed: module objects were bound to an
    # unused variable `m` — only the names are needed.)
    device_map = {
        n: 0
        for n, _ in model.named_modules()
        if "layers" not in n or n.split("layers.")[-1].isdigit()
    }
    # Override: keep the first decoder layer on CPU to exercise offloading.
    device_map["model.layers.0"] = "cpu"

    model = load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map)
    return model, config, tiny_llama_dir
136+
137+
138+
def _make_layerwise_cfg(base_cfg):
139+
"""Add layerwise=True to a quant config's algorithm field."""
140+
cfg = copy.deepcopy(base_cfg)
141+
algo = cfg.get("algorithm", "max")
142+
if isinstance(algo, str):
143+
cfg["algorithm"] = {"method": algo, "layerwise": True}
144+
else:
145+
algo["layerwise"] = True
146+
return cfg
147+
148+
149+
@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
def test_hf_vllm_export_offload(tmp_path, quant_cfg):
    """Test ``inplace_mem_efficient=True`` export path on a CPU-offloaded model.

    Mirrors ``test_hf_vllm_export`` but uses a CPU-offloaded model with layerwise
    calibration. Skips the "model not mutated" assertion since the inplace path
    is intentionally destructive.
    """
    num_hidden_layers = 3

    # Test model: CPU-offloaded, layerwise calibration
    model, _config, tiny_llama_dir = _make_cpu_offloaded_model(
        tmp_path / "offloaded", num_hidden_layers=num_hidden_layers
    )
    model.eval()

    # Wrap the base quant config so calibration runs layer-by-layer, matching
    # the offloaded-model setup.
    seq_cfg = _make_layerwise_cfg(quant_cfg)

    def forward_loop(model):
        # Single random batch is enough to populate calibration amaxes.
        input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda()
        with torch.no_grad():
            model(input_ids)

    model = mtq.quantize(model, seq_cfg, forward_loop)
    # Snapshot quantizer state BEFORE export so input-quantizer amaxes can be
    # compared against the saved state at the end of the test.
    quantizer_state_dict_before = mtq.utils.get_quantizer_state_dict(model)

    # Build the expected (folded) weights on an independent deepcopy.
    # enable_weight_access_and_writeback materializes the offloaded/meta layer-0
    # tensors so fold_weight and state_dict see real values.
    folded_model = deepcopy(model)
    with enable_weight_access_and_writeback(folded_model.model.layers[0], folded_model):
        fold_weight(folded_model)
        expected_weights = {
            k: v.detach().clone()
            for k, v in folded_model.state_dict().items()
            if "quantizer" not in k
        }
    del folded_model

    export_dir = tmp_path / "vllm_export_offload"
    export_dir.mkdir(exist_ok=True)

    # Snapshot the offloaded layer's weight before/after export to verify the
    # inplace_mem_efficient path actually mutates offloaded weights (would otherwise
    # be unfalsifiable if the function silently took the copy path).
    with enable_weight_access_and_writeback(model.model.layers[0], model):
        weight_before = model.model.layers[0].self_attn.q_proj.weight.data.clone()

    export_hf_vllm_fq_checkpoint(model, export_dir=export_dir, inplace_mem_efficient=True)

    with enable_weight_access_and_writeback(model.model.layers[0], model):
        weight_after = model.model.layers[0].self_attn.q_proj.weight.data.clone()
    # Fold must have changed the offloaded layer's weight in place.
    assert not torch.equal(weight_before, weight_after), (
        "inplace path must mutate offloaded layer weights"
    )

    # The fakequant export saves the modelopt state alongside the checkpoint...
    modelopt_state_file = export_dir / "vllm_fq_modelopt_state.pth"
    assert modelopt_state_file.exists(), (
        f"vllm_fq_modelopt_state.pth file should be created in {export_dir}"
    )

    # ...but must NOT emit the real-quant hf_quant_config.json.
    hf_quant_config_file = export_dir / "hf_quant_config.json"
    assert not hf_quant_config_file.exists(), (
        f"hf_quant_config.json file should not be created in {export_dir}"
    )

    # Reload the exported checkpoint as a plain HF model and compare its weights
    # against the deepcopy+fold_weight reference built above.
    model_after = AutoModelForCausalLM.from_pretrained(export_dir).cuda()
    model_after.eval()
    model_after_state_dict = model_after.state_dict()
    for key, param in expected_weights.items():
        assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), (
            f"Weight mismatch for {key}: "
            f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, "
            f"max diff={torch.abs(param - model_after_state_dict[key]).max()}"
        )

    # Saved quantizer state: weight quantizers are folded away (empty state),
    # while input quantizers that had an amax before export must keep it.
    quantizer_state_dict = safe_load(modelopt_state_file)["modelopt_state_weights"]
    assert len(quantizer_state_dict) > 0, (
        f"modelopt_state_weights should not be empty in {modelopt_state_file}"
    )
    for name, state in quantizer_state_dict.items():
        if "weight_quantizer" in name:
            assert state == {}, f"weight quantizer {name} should have empty state after fold"
        elif "input_quantizer" in name and any(
            "_amax" in k for k in quantizer_state_dict_before[name]
        ):
            assert any("_amax" in k for k in state), f"input quantizer {name} should preserve _amax"

0 commit comments

Comments
 (0)