Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 156 additions & 43 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# limitations under the License.
"""Export HuggingFace model to vLLM fakequant checkpoint."""

import logging
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path

import torch
Expand All @@ -28,6 +33,93 @@

__all__ = ["export_hf_vllm_fq_checkpoint"]

logger = logging.getLogger(__name__)


@dataclass
class _WeightQuantWork:
    """A single weight tensor to be fake-quantized during export.

    One instance is collected per enabled weight quantizer and later consumed
    by ``_process_weight`` (possibly on a worker thread, one thread per GPU).
    """

    # Key of the weight tensor in the model's state_dict (e.g. "model.layers.0.mlp.up_proj.weight").
    sd_key: str
    # The weight quantizer that will fake-quantize ``weight``.
    quantizer: TensorQuantizer
    # The original (un-quantized) weight tensor, still on its source device.
    weight: torch.Tensor
    # For optional pre_quant_scale folding:
    # the disabled input quantizer whose _pre_quant_scale gets folded into the
    # weight, or None when no folding applies.
    inp_q: TensorQuantizer | None
    # Unwrapped state-dict key of that input quantizer (for bookkeeping of
    # folded scales), or None when no folding applies.
    inp_q_key: str | None


def _collect_quant_work(
    model: nn.Module, state_dict: dict[str, torch.Tensor]
) -> list[_WeightQuantWork]:
    """Build the list of weight-quantization work items for *model*.

    Walks every ``QuantModule``, selects each enabled fake-quant weight
    quantizer, and pairs it with the matching weight tensor from
    ``state_dict``. When the sibling input quantizer is disabled but carries a
    ``_pre_quant_scale``, the item also records it so the scale can be folded
    into the weight later.
    """
    collected: list[_WeightQuantWork] = []
    visited: set[str] = set()
    for mod_name, mod in model.named_modules():
        if not isinstance(mod, QuantModule):
            continue
        for child_name, child in mod.named_children():
            eligible = (
                child_name.endswith("weight_quantizer")
                and isinstance(child, TensorQuantizer)
                and child.fake_quant
                and child.is_enabled
            )
            if not eligible:
                continue
            prefix = f"{mod_name}." if mod_name else ""
            sd_key = f"{prefix}{child_name.removesuffix('_quantizer')}"
            assert sd_key not in visited, f"Weight {sd_key} has already been fakequantized"
            visited.add(sd_key)
            if sd_key not in state_dict:
                continue
            # Decide whether the input quantizer's pre_quant_scale can be folded:
            # it must exist, be disabled, and have pre-quant scaling enabled.
            inp_q: TensorQuantizer | None = None
            inp_q_key: str | None = None
            inp_attr = child_name.replace("weight_quantizer", "input_quantizer")
            candidate = getattr(mod, inp_attr, None)
            if (
                candidate is not None
                and hasattr(candidate, "_pre_quant_scale")
                and candidate._pre_quant_scale is not None
                and candidate._disabled
                and getattr(candidate, "_enable_pre_quant_scale", True)
            ):
                inp_q = candidate
                full_attr = f"{mod_name}.{inp_attr}" if mod_name else inp_attr
                inp_q_key = get_unwrapped_name(full_attr, model)
            collected.append(
                _WeightQuantWork(
                    sd_key=sd_key,
                    quantizer=child,
                    weight=state_dict[sd_key],
                    inp_q=inp_q,
                    inp_q_key=inp_q_key,
                )
            )
    return collected


def _process_weight(item: _WeightQuantWork) -> tuple[str, torch.Tensor, str | None]:
    """Fake-quantize one weight tensor, optionally folding its pre_quant_scale.

    Returns:
        ``(sd_key, quantized_weight_on_cpu, inp_q_key_or_None)`` — the
        state-dict key, the fake-quantized weight moved to CPU, and the input
        quantizer key when a pre_quant_scale was folded (``None`` otherwise).
    """
    original = item.weight
    # Quantize in float32 for precision, then cast back to the weight's dtype.
    quantized = item.quantizer(original.float()).to(original.dtype)
    if item.inp_q is not None:
        # Fold the pre-quant scale into the weight: (x*s) @ W_fq == x @ (W_fq * s).
        # Collection only selects disabled input quantizers, so no activation
        # fake-quant sits between x*s and the matmul.
        pqs = item.inp_q._pre_quant_scale.squeeze().to(device=quantized.device)
        quantized = (quantized * pqs[None, :]).to(quantized.dtype)
    return item.sd_key, quantized.cpu(), item.inp_q_key


def _process_device_batch(items: list[_WeightQuantWork], device: torch.device):
    """Fake-quantize all work items assigned to one GPU (worker-thread entry point)."""
    processed = []
    with torch.inference_mode(), torch.cuda.device(device):
        for work in items:
            processed.append(_process_weight(work))
        # Ensure every kernel queued by this thread has finished before the
        # results are handed back to the main thread.
        torch.cuda.synchronize(device)
    return processed


def disable_rotate(quantizer: TensorQuantizer):
"""Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
Expand All @@ -41,6 +133,7 @@ def disable_rotate(quantizer: TensorQuantizer):
def export_hf_vllm_fq_checkpoint(
model: nn.Module,
export_dir: Path | str,
parallel: bool = True,
):
"""Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload.

Expand All @@ -53,6 +146,9 @@ def export_hf_vllm_fq_checkpoint(
Args:
model: In-memory quantized model.
export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``.
parallel: If True, fake-quantize weights across GPUs concurrently using
one thread per GPU device. Falls back to sequential when all weights
are on the same device or on CPU. Default True.
"""
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -62,50 +158,66 @@ def export_hf_vllm_fq_checkpoint(
# parameters are never modified. Apply each weight quantizer's fake-quant
# to the corresponding weight tensor in the copy.
state_dict = model.state_dict()
fakequant_weights = set()
input_quantizers_folded_pqs = (
set()
) # keys for input_quantizers where pre_quant_scale was folded
fakequant_weights: set[str] = set()
input_quantizers_folded_pqs: set[str] = set()

work_items = _collect_quant_work(model, state_dict)

# Group work items by device for parallel dispatch.
device_groups: dict[torch.device, list[_WeightQuantWork]] = defaultdict(list)
for item in work_items:
device_groups[item.weight.device].append(item)

num_cuda_devices = sum(1 for d in device_groups if d.type == "cuda")
use_parallel = parallel and num_cuda_devices > 1

t0 = time.monotonic()
with torch.inference_mode():
for module_name, module in model.named_modules():
if not isinstance(module, QuantModule):
continue
for attr_name, quantizer in module.named_children():
if not (
attr_name.endswith("weight_quantizer")
and isinstance(quantizer, TensorQuantizer)
and quantizer.fake_quant
and quantizer.is_enabled
):
continue
weight_name = attr_name.removesuffix("_quantizer")
prefix = f"{module_name}." if module_name else ""
sd_key = f"{prefix}{weight_name}"
assert sd_key not in fakequant_weights, (
f"Weight {sd_key} has already been fakequantized"
)
if sd_key in state_dict:
w = state_dict[sd_key]
w_quant = quantizer(w.float()).to(w.dtype).cpu()
# Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
# Only valid when input_quantizer does NOT fake-quant activations. If it does
# fake_quant(x*s), the non-linearity prevents folding s into W.
inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
if hasattr(module, inp_attr):
inp_q = getattr(module, inp_attr)
if (
hasattr(inp_q, "_pre_quant_scale")
and inp_q._pre_quant_scale is not None
and inp_q._disabled
):
scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
inp_q_key = get_unwrapped_name(
f"{module_name}.{inp_attr}" if module_name else inp_attr, model
)
if use_parallel:
logger.info(
"Parallel export: %d weights across %d GPUs (%s)",
len(work_items),
num_cuda_devices,
", ".join(f"{d}: {len(items)} weights" for d, items in device_groups.items()),
)
with ThreadPoolExecutor(max_workers=num_cuda_devices) as pool:
# Submit GPU batches first (non-blocking)
futures = [
pool.submit(_process_device_batch, items, device)
for device, items in device_groups.items()
if device.type == "cuda"
]
# Process CPU weights inline while GPU futures run
for device, items in device_groups.items():
if device.type != "cuda":
for sd_key, w_quant, inp_q_key in map(_process_weight, items):
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
if inp_q_key is not None:
input_quantizers_folded_pqs.add(inp_q_key)
# Collect GPU results
for future in futures:
for sd_key, w_quant, inp_q_key in future.result():
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
if inp_q_key is not None:
input_quantizers_folded_pqs.add(inp_q_key)
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
else:
# Sequential fallback (single GPU, CPU, or parallel=False).
for item in work_items:
sd_key, w_quant, inp_q_key = _process_weight(item)
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
if inp_q_key is not None:
input_quantizers_folded_pqs.add(inp_q_key)

elapsed = time.monotonic() - t0
logger.info(
"Export step 1 (%s): %d weights fake-quantized in %.1fs",
"parallel" if use_parallel else "sequential",
len(fakequant_weights),
elapsed,
)

# Filter quantizer tensors out for a clean HF checkpoint.
clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
Expand Down Expand Up @@ -166,4 +278,5 @@ def export_hf_vllm_fq_checkpoint(

for wq, orig_rotate in wqs_to_restore:
wq.enable()
wq._rotate = orig_rotate
if orig_rotate is not None:
wq._rotate = orig_rotate
139 changes: 139 additions & 0 deletions tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test parallel vs sequential export produces identical outputs."""

import pytest
import torch
from _test_utils.torch.transformers_models import create_tiny_llama_dir
from transformers import AutoModelForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_vllm_fq_checkpoint


def _quantize_model(tmp_path, suffix=""):
    """Create and FP8-quantize a tiny 4-layer LLaMA model; return the quantized model.

    NOTE(review): the docstring previously claimed a ``(model, export_dir)``
    tuple, but only the model is returned. Also, neither visible test in this
    file calls this helper — confirm it is still needed.
    """
    tiny_model_dir = create_tiny_llama_dir(tmp_path / f"model{suffix}", num_hidden_layers=4)
    model = AutoModelForCausalLM.from_pretrained(tiny_model_dir)
    model = model.cuda()
    model.eval()

    def forward_loop(model):
        # Single random batch is sufficient calibration for FP8 amax collection.
        input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda()
        with torch.no_grad():
            model(input_ids)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
    return model


@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
def test_parallel_vs_sequential_identical(tmp_path, quant_cfg):
"""Verify parallel export produces bitwise identical output to sequential."""
num_gpus = torch.cuda.device_count()
if num_gpus < 2:
pytest.skip("Need >= 2 GPUs for parallel export test")

# Create a tiny model and spread across GPUs.
tiny_model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=4)
model = AutoModelForCausalLM.from_pretrained(
tiny_model_dir, device_map="auto", torch_dtype=torch.float16
)
model.eval()

def forward_loop(model):
first_device = next(model.parameters()).device
input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).to(first_device)
with torch.no_grad():
model(input_ids)

model = mtq.quantize(model, quant_cfg, forward_loop)

# Export sequentially.
seq_dir = tmp_path / "export_sequential"
export_hf_vllm_fq_checkpoint(model, export_dir=seq_dir, parallel=False)

# Re-enable weight quantizers (export disables them — need to restore for second export).
# The function already re-enables them at the end, so we can just call it again.

# Export in parallel.
par_dir = tmp_path / "export_parallel"
export_hf_vllm_fq_checkpoint(model, export_dir=par_dir, parallel=True)

# Compare HF weights.
seq_model = AutoModelForCausalLM.from_pretrained(seq_dir)
par_model = AutoModelForCausalLM.from_pretrained(par_dir)
seq_sd = seq_model.state_dict()
par_sd = par_model.state_dict()

assert seq_sd.keys() == par_sd.keys(), "Key mismatch between sequential and parallel export"
for key in seq_sd:
assert torch.allclose(seq_sd[key], par_sd[key]), (
f"Weight mismatch for {key}: max diff={torch.abs(seq_sd[key] - par_sd[key]).max()}"
)

# Compare full modelopt state payload (weights_only=False: modelopt_state contains
# Python objects — dicts, strings, quantizer configs — that require unpickling).
seq_state = torch.load(seq_dir / "vllm_fq_modelopt_state.pth", weights_only=False)
par_state = torch.load(par_dir / "vllm_fq_modelopt_state.pth", weights_only=False)
Comment on lines +89 to +90
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cat -n tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py | sed -n '50,120p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3520


🏁 Script executed:

head -30 tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1371


Add inline comment documenting why weights_only=False is safe.

The checkpoint files are test-generated on lines 66 and 73, so they are internally trusted. However, per coding guidelines, an inline comment must be added to document this whenever weights_only=False is used.

🔐 Fix
     # Compare modelopt state.
+    # Safe: both files are written by this test above and are not user-supplied.
     seq_state = torch.load(seq_dir / "vllm_fq_modelopt_state.pth", weights_only=False)
     par_state = torch.load(par_dir / "vllm_fq_modelopt_state.pth", weights_only=False)

As per coding guidelines: "Do not use torch.load(..., weights_only=False) unless a documented exception is provided. If weights_only=False is required, provide an inline comment explaining why and confirming the file is internally-generated or trusted."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py` around
lines 88 - 89, Add an inline comment next to the two torch.load calls where
weights_only=False (the assignments to seq_state and par_state loading
"vllm_fq_modelopt_state.pth" from seq_dir and par_dir) that documents this
exception: note that these checkpoint files are generated earlier in the test
(trusted/test-generated on lines where the test writes the files) so using
weights_only=False is safe and intentional per guidelines; reference seq_dir,
par_dir and the filename in the comment for clarity.


# Compare modelopt_state_dict (quantizer metadata including quantizer_state).
seq_msd = seq_state.get("modelopt_state_dict", [])
par_msd = par_state.get("modelopt_state_dict", [])
assert len(seq_msd) == len(par_msd), "modelopt_state_dict length mismatch"
for (seq_mode, seq_ms), (par_mode, par_ms) in zip(seq_msd, par_msd):
assert seq_mode == par_mode, f"Mode mismatch: {seq_mode} vs {par_mode}"

# Compare modelopt_state_weights (per-quantizer tensor state).
seq_qsd = seq_state["modelopt_state_weights"]
par_qsd = par_state["modelopt_state_weights"]
assert seq_qsd.keys() == par_qsd.keys(), "Quantizer state dict key mismatch"
for key in seq_qsd:
seq_val = seq_qsd[key]
par_val = par_qsd[key]
if isinstance(seq_val, dict):
for k in seq_val:
if isinstance(seq_val[k], torch.Tensor):
assert torch.equal(seq_val[k], par_val[k]), (
f"Quantizer state mismatch for {key}.{k}"
)
else:
assert seq_val[k] == par_val[k], f"Quantizer state mismatch for {key}.{k}"
else:
assert seq_val == par_val, f"Quantizer state mismatch for {key}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_single_gpu_fallback(tmp_path):
    """Verify parallel=True gracefully falls back to sequential on single GPU."""
    tiny_model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=2)
    model = AutoModelForCausalLM.from_pretrained(tiny_model_dir)
    model = model.cuda()  # All weights on cuda:0 -> single device group, sequential path.
    model.eval()

    def forward_loop(model):
        # One random batch is enough calibration for FP8 amax collection.
        input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda()
        with torch.no_grad():
            model(input_ids)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

    # parallel=True but single GPU → should fall back to sequential without error.
    export_dir = tmp_path / "export"
    export_hf_vllm_fq_checkpoint(model, export_dir=export_dir, parallel=True)

    # The modelopt state payload must be written alongside the HF files.
    assert (export_dir / "vllm_fq_modelopt_state.pth").exists()
    # Reload and check the checkpoint is a usable HF model, not merely present
    # on disk. (`from_pretrained` raises on failure, so a bare `is not None`
    # check would be vacuous.)
    reloaded = AutoModelForCausalLM.from_pretrained(export_dir)
    reloaded_sd = reloaded.state_dict()
    assert reloaded_sd, "Reloaded checkpoint has an empty state dict"
    # The exported HF checkpoint must be clean: no quantizer tensors leak in.
    assert all("quantizer" not in k for k in reloaded_sd)
Loading