"""Export HuggingFace model to vLLM fakequant checkpoint."""

import logging
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path

import torch

__all__ = ["export_hf_vllm_fq_checkpoint"]

# Module-level logger (lazy %-style args are used at call sites).
logger = logging.getLogger(__name__)


@dataclass
class _WeightQuantWork:
    """A single weight tensor to be fake-quantized during export."""

    # Fully-qualified state-dict key of the weight (module path + "weight").
    sd_key: str
    # The weight quantizer to apply (fake-quant, enabled).
    quantizer: TensorQuantizer
    # The original weight tensor looked up from the model's state_dict.
    weight: torch.Tensor
    # For optional pre_quant_scale folding:
    # Input quantizer whose ``_pre_quant_scale`` is folded into the weight,
    # or None when folding does not apply to this weight.
    inp_q: TensorQuantizer | None
    # Unwrapped name of that input quantizer (None when ``inp_q`` is None).
    inp_q_key: str | None


def _collect_quant_work(
    model: nn.Module, state_dict: dict[str, torch.Tensor]
) -> list[_WeightQuantWork]:
    """Collect all weight quantization work items from the model.

    Walks every ``QuantModule`` in *model*, selects enabled fake-quant weight
    quantizers, and pairs each with its weight tensor from *state_dict*.
    Also records, per weight, the input quantizer whose ``_pre_quant_scale``
    can later be folded into the quantized weight.

    Args:
        model: Quantized model to scan.
        state_dict: The model's state dict; weights are looked up by key.

    Returns:
        One ``_WeightQuantWork`` item per eligible weight tensor.
    """
    work_items: list[_WeightQuantWork] = []
    seen_keys: set[str] = set()
    for module_name, module in model.named_modules():
        if not isinstance(module, QuantModule):
            continue
        for attr_name, quantizer in module.named_children():
            # Only enabled fake-quant weight quantizers produce work.
            if not (
                attr_name.endswith("weight_quantizer")
                and isinstance(quantizer, TensorQuantizer)
                and quantizer.fake_quant
                and quantizer.is_enabled
            ):
                continue
            # "weight_quantizer" -> "weight"; prepend module path for the
            # state-dict key.
            weight_name = attr_name.removesuffix("_quantizer")
            prefix = f"{module_name}." if module_name else ""
            sd_key = f"{prefix}{weight_name}"
            assert sd_key not in seen_keys, f"Weight {sd_key} has already been fakequantized"
            seen_keys.add(sd_key)
            if sd_key not in state_dict:
                continue
            # Check for pre_quant_scale folding eligibility.
            # Folding uses (x*s) @ fake_quant(W)^T == x @ (fake_quant(W)*s)^T,
            # which is only valid when the input quantizer does NOT fake-quant
            # activations (it is disabled); otherwise fake_quant(x*s) is
            # non-linear and s cannot be moved into W.
            inp_q = None
            inp_q_key = None
            inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
            if hasattr(module, inp_attr):
                candidate = getattr(module, inp_attr)
                if (
                    hasattr(candidate, "_pre_quant_scale")
                    and candidate._pre_quant_scale is not None
                    and candidate._disabled
                    and getattr(candidate, "_enable_pre_quant_scale", True)
                ):
                    inp_q = candidate
                    inp_q_key = get_unwrapped_name(
                        f"{module_name}.{inp_attr}" if module_name else inp_attr, model
                    )
            work_items.append(
                _WeightQuantWork(
                    sd_key=sd_key,
                    quantizer=quantizer,
                    weight=state_dict[sd_key],
                    inp_q=inp_q,
                    inp_q_key=inp_q_key,
                )
            )
    return work_items


def _process_weight(item: _WeightQuantWork) -> tuple[str, torch.Tensor, str | None]:
    """Fake-quantize a single weight tensor and optionally fold pre_quant_scale.

    Returns (sd_key, quantized_weight_on_cpu, inp_q_key_or_None).
    """
    w = item.weight
    # Run the quantizer in float32, then cast back to the weight's dtype.
    w_quant = item.quantizer(w.float()).to(w.dtype)
    if item.inp_q is not None:
        # NOTE(review): scale[None, :] assumes a 2-D [out_features, in_features]
        # weight layout — confirm for any non-standard weight shapes.
        scale = item.inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
        w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
    # Return on CPU so the caller can fold results into the exported state_dict.
    return item.sd_key, w_quant.cpu(), item.inp_q_key


def _process_device_batch(items: list[_WeightQuantWork], device: torch.device):
    """Process all weight items on a single GPU. Runs in a dedicated thread.

    Returns the list of ``_process_weight`` results for *items*.
    """
    # inference_mode skips autograd bookkeeping; torch.cuda.device makes
    # ``device`` current for any allocations the quantizer does in this thread.
    with torch.inference_mode(), torch.cuda.device(device):
        results = [_process_weight(item) for item in items]
        # Make sure all kernels on this device finished before returning.
        torch.cuda.synchronize(device)
    return results


def disable_rotate(quantizer: TensorQuantizer):
    """Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
    # NOTE(review): function body unchanged by this patch and elided from this view.
export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``. + parallel: If True, fake-quantize weights across GPUs concurrently using + one thread per GPU device. Falls back to sequential when all weights + are on the same device or on CPU. Default True. """ export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) @@ -62,50 +158,66 @@ def export_hf_vllm_fq_checkpoint( # parameters are never modified. Apply each weight quantizer's fake-quant # to the corresponding weight tensor in the copy. state_dict = model.state_dict() - fakequant_weights = set() - input_quantizers_folded_pqs = ( - set() - ) # keys for input_quantizers where pre_quant_scale was folded + fakequant_weights: set[str] = set() + input_quantizers_folded_pqs: set[str] = set() + + work_items = _collect_quant_work(model, state_dict) + + # Group work items by device for parallel dispatch. + device_groups: dict[torch.device, list[_WeightQuantWork]] = defaultdict(list) + for item in work_items: + device_groups[item.weight.device].append(item) + + num_cuda_devices = sum(1 for d in device_groups if d.type == "cuda") + use_parallel = parallel and num_cuda_devices > 1 + + t0 = time.monotonic() with torch.inference_mode(): - for module_name, module in model.named_modules(): - if not isinstance(module, QuantModule): - continue - for attr_name, quantizer in module.named_children(): - if not ( - attr_name.endswith("weight_quantizer") - and isinstance(quantizer, TensorQuantizer) - and quantizer.fake_quant - and quantizer.is_enabled - ): - continue - weight_name = attr_name.removesuffix("_quantizer") - prefix = f"{module_name}." 
if module_name else "" - sd_key = f"{prefix}{weight_name}" - assert sd_key not in fakequant_weights, ( - f"Weight {sd_key} has already been fakequantized" - ) - if sd_key in state_dict: - w = state_dict[sd_key] - w_quant = quantizer(w.float()).to(w.dtype).cpu() - # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) - # Only valid when input_quantizer does NOT fake-quant activations. If it does - # fake_quant(x*s), the non-linearity prevents folding s into W. - inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") - if hasattr(module, inp_attr): - inp_q = getattr(module, inp_attr) - if ( - hasattr(inp_q, "_pre_quant_scale") - and inp_q._pre_quant_scale is not None - and inp_q._disabled - ): - scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) - w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) - inp_q_key = get_unwrapped_name( - f"{module_name}.{inp_attr}" if module_name else inp_attr, model - ) + if use_parallel: + logger.info( + "Parallel export: %d weights across %d GPUs (%s)", + len(work_items), + num_cuda_devices, + ", ".join(f"{d}: {len(items)} weights" for d, items in device_groups.items()), + ) + with ThreadPoolExecutor(max_workers=num_cuda_devices) as pool: + # Submit GPU batches first (non-blocking) + futures = [ + pool.submit(_process_device_batch, items, device) + for device, items in device_groups.items() + if device.type == "cuda" + ] + # Process CPU weights inline while GPU futures run + for device, items in device_groups.items(): + if device.type != "cuda": + for sd_key, w_quant, inp_q_key in map(_process_weight, items): + state_dict[sd_key] = w_quant + fakequant_weights.add(sd_key) + if inp_q_key is not None: + input_quantizers_folded_pqs.add(inp_q_key) + # Collect GPU results + for future in futures: + for sd_key, w_quant, inp_q_key in future.result(): + state_dict[sd_key] = w_quant + fakequant_weights.add(sd_key) + if inp_q_key is not None: input_quantizers_folded_pqs.add(inp_q_key) - 
state_dict[sd_key] = w_quant - fakequant_weights.add(sd_key) + else: + # Sequential fallback (single GPU, CPU, or parallel=False). + for item in work_items: + sd_key, w_quant, inp_q_key = _process_weight(item) + state_dict[sd_key] = w_quant + fakequant_weights.add(sd_key) + if inp_q_key is not None: + input_quantizers_folded_pqs.add(inp_q_key) + + elapsed = time.monotonic() - t0 + logger.info( + "Export step 1 (%s): %d weights fake-quantized in %.1fs", + "parallel" if use_parallel else "sequential", + len(fakequant_weights), + elapsed, + ) # Filter quantizer tensors out for a clean HF checkpoint. clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k} @@ -166,4 +278,5 @@ def export_hf_vllm_fq_checkpoint( for wq, orig_rotate in wqs_to_restore: wq.enable() - wq._rotate = orig_rotate + if orig_rotate is not None: + wq._rotate = orig_rotate diff --git a/tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py b/tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py new file mode 100644 index 0000000000..3e559a3ee5 --- /dev/null +++ b/tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test parallel vs sequential export produces identical outputs.""" + +import pytest +import torch +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM + +import modelopt.torch.quantization as mtq +from modelopt.torch.export import export_hf_vllm_fq_checkpoint + + +def _quantize_model(tmp_path, suffix=""): + """Create and quantize a tiny LLaMA model. Returns (model, export_dir).""" + tiny_model_dir = create_tiny_llama_dir(tmp_path / f"model{suffix}", num_hidden_layers=4) + model = AutoModelForCausalLM.from_pretrained(tiny_model_dir) + model = model.cuda() + model.eval() + + def forward_loop(model): + input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda() + with torch.no_grad(): + model(input_ids) + + model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop) + return model + + +@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG]) +def test_parallel_vs_sequential_identical(tmp_path, quant_cfg): + """Verify parallel export produces bitwise identical output to sequential.""" + num_gpus = torch.cuda.device_count() + if num_gpus < 2: + pytest.skip("Need >= 2 GPUs for parallel export test") + + # Create a tiny model and spread across GPUs. + tiny_model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=4) + model = AutoModelForCausalLM.from_pretrained( + tiny_model_dir, device_map="auto", torch_dtype=torch.float16 + ) + model.eval() + + def forward_loop(model): + first_device = next(model.parameters()).device + input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).to(first_device) + with torch.no_grad(): + model(input_ids) + + model = mtq.quantize(model, quant_cfg, forward_loop) + + # Export sequentially. + seq_dir = tmp_path / "export_sequential" + export_hf_vllm_fq_checkpoint(model, export_dir=seq_dir, parallel=False) + + # Re-enable weight quantizers (export disables them — need to restore for second export). 
+ # The function already re-enables them at the end, so we can just call it again. + + # Export in parallel. + par_dir = tmp_path / "export_parallel" + export_hf_vllm_fq_checkpoint(model, export_dir=par_dir, parallel=True) + + # Compare HF weights. + seq_model = AutoModelForCausalLM.from_pretrained(seq_dir) + par_model = AutoModelForCausalLM.from_pretrained(par_dir) + seq_sd = seq_model.state_dict() + par_sd = par_model.state_dict() + + assert seq_sd.keys() == par_sd.keys(), "Key mismatch between sequential and parallel export" + for key in seq_sd: + assert torch.allclose(seq_sd[key], par_sd[key]), ( + f"Weight mismatch for {key}: max diff={torch.abs(seq_sd[key] - par_sd[key]).max()}" + ) + + # Compare full modelopt state payload (weights_only=False: modelopt_state contains + # Python objects — dicts, strings, quantizer configs — that require unpickling). + seq_state = torch.load(seq_dir / "vllm_fq_modelopt_state.pth", weights_only=False) + par_state = torch.load(par_dir / "vllm_fq_modelopt_state.pth", weights_only=False) + + # Compare modelopt_state_dict (quantizer metadata including quantizer_state). + seq_msd = seq_state.get("modelopt_state_dict", []) + par_msd = par_state.get("modelopt_state_dict", []) + assert len(seq_msd) == len(par_msd), "modelopt_state_dict length mismatch" + for (seq_mode, seq_ms), (par_mode, par_ms) in zip(seq_msd, par_msd): + assert seq_mode == par_mode, f"Mode mismatch: {seq_mode} vs {par_mode}" + + # Compare modelopt_state_weights (per-quantizer tensor state). 
+ seq_qsd = seq_state["modelopt_state_weights"] + par_qsd = par_state["modelopt_state_weights"] + assert seq_qsd.keys() == par_qsd.keys(), "Quantizer state dict key mismatch" + for key in seq_qsd: + seq_val = seq_qsd[key] + par_val = par_qsd[key] + if isinstance(seq_val, dict): + for k in seq_val: + if isinstance(seq_val[k], torch.Tensor): + assert torch.equal(seq_val[k], par_val[k]), ( + f"Quantizer state mismatch for {key}.{k}" + ) + else: + assert seq_val[k] == par_val[k], f"Quantizer state mismatch for {key}.{k}" + else: + assert seq_val == par_val, f"Quantizer state mismatch for {key}" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_single_gpu_fallback(tmp_path): + """Verify parallel=True gracefully falls back to sequential on single GPU.""" + tiny_model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=2) + model = AutoModelForCausalLM.from_pretrained(tiny_model_dir) + model = model.cuda() # All on cuda:0 + model.eval() + + def forward_loop(model): + input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda() + with torch.no_grad(): + model(input_ids) + + model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop) + + # parallel=True but single GPU → should fall back to sequential without error. + export_dir = tmp_path / "export" + export_hf_vllm_fq_checkpoint(model, export_dir=export_dir, parallel=True) + + assert (export_dir / "vllm_fq_modelopt_state.pth").exists() + reloaded = AutoModelForCausalLM.from_pretrained(export_dir) + assert reloaded is not None