Skip to content

Commit d28cda9

Browse files
committed
feat: parallelize fakequant export across GPUs via ThreadPoolExecutor
Refactors `export_hf_vllm_fq_checkpoint()` Step 1 (weight fake-quantization) to process weights concurrently across GPUs using ThreadPoolExecutor. - Collects all quantizer work items, groups by device - One thread per GPU, each processes its weight partition - Falls back to sequential on single GPU or CPU (no overhead) - Adds `parallel: bool = True` parameter (default on, backwards compatible) - Adds GPU test: parallel vs sequential produces bitwise identical output - Adds single-GPU fallback test Measured speedup on Qwen3-8B with barboqt_exhaustive_2b (8x H100): - Quantize step: 293s → 49s (6.0x) - Total export: 312s → 66s (4.7x) Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
1 parent 18ce04f commit d28cda9

File tree

2 files changed

+279
-44
lines changed

2 files changed

+279
-44
lines changed

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 150 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
# limitations under the License.
1515
"""Export HuggingFace model to vLLM fakequant checkpoint."""
1616

17+
import logging
18+
import time
19+
from collections import defaultdict
20+
from concurrent.futures import ThreadPoolExecutor
21+
from dataclasses import dataclass
1722
from pathlib import Path
1823

1924
import torch
@@ -28,6 +33,92 @@
2833

2934
__all__ = ["export_hf_vllm_fq_checkpoint"]
3035

36+
logger = logging.getLogger(__name__)
37+
38+
39+
@dataclass
class _WeightQuantWork:
    """A single weight tensor to be fake-quantized during export.

    Work items are produced by ``_collect_quant_work`` and consumed by
    ``_process_weight``; grouping them by ``weight.device`` enables
    per-GPU parallel dispatch.
    """

    # State-dict key of the weight tensor ("<module>.weight").
    sd_key: str
    # The enabled weight quantizer that performs the fake-quantization.
    quantizer: TensorQuantizer
    # The original weight tensor taken from the model state dict.
    weight: torch.Tensor
    # For optional pre_quant_scale folding:
    # disabled input quantizer carrying a ``_pre_quant_scale`` to fold into
    # the weight, or None when no folding applies.
    inp_q: TensorQuantizer | None
    # Unwrapped state-dict key of that input quantizer (None iff inp_q is None).
    inp_q_key: str | None
51+
def _collect_quant_work(
    model: nn.Module, state_dict: dict[str, torch.Tensor]
) -> list[_WeightQuantWork]:
    """Gather every enabled weight fake-quantizer in *model* as a work item.

    Walks all ``QuantModule`` children, pairs each active weight quantizer
    with its weight tensor in *state_dict*, and records pre_quant_scale
    folding eligibility so that ``_process_weight`` can run stand-alone.
    """
    collected: list[_WeightQuantWork] = []
    visited: set[str] = set()
    for mod_name, mod in model.named_modules():
        if not isinstance(mod, QuantModule):
            continue
        for child_name, child in mod.named_children():
            is_active_weight_fq = (
                child_name.endswith("weight_quantizer")
                and isinstance(child, TensorQuantizer)
                and child.fake_quant
                and child.is_enabled
            )
            if not is_active_weight_fq:
                continue
            base = child_name.removesuffix("_quantizer")
            key = f"{mod_name}.{base}" if mod_name else base
            assert key not in visited, f"Weight {key} has already been fakequantized"
            visited.add(key)
            if key not in state_dict:
                continue
            # Folding is only eligible when a disabled input quantizer holds
            # a pre_quant_scale (an enabled one would fake-quant x*s, and the
            # non-linearity prevents folding s into W).
            fold_q: TensorQuantizer | None = None
            fold_q_key: str | None = None
            input_attr = child_name.replace("weight_quantizer", "input_quantizer")
            if hasattr(mod, input_attr):
                cand = getattr(mod, input_attr)
                if (
                    hasattr(cand, "_pre_quant_scale")
                    and cand._pre_quant_scale is not None
                    and cand._disabled
                ):
                    fold_q = cand
                    fold_q_key = get_unwrapped_name(
                        f"{mod_name}.{input_attr}" if mod_name else input_attr, model
                    )
            collected.append(
                _WeightQuantWork(
                    sd_key=key,
                    quantizer=child,
                    weight=state_dict[key],
                    inp_q=fold_q,
                    inp_q_key=fold_q_key,
                )
            )
    return collected
100+
101+
102+
def _process_weight(item: _WeightQuantWork) -> tuple[str, torch.Tensor, str | None]:
    """Fake-quantize a single weight tensor and optionally fold pre_quant_scale.

    Returns (sd_key, quantized_weight_on_cpu, inp_q_key_or_None).
    """
    original = item.weight
    # Quantize in fp32 for accuracy, cast back to the weight dtype, move to CPU.
    quantized = item.quantizer(original.float()).to(original.dtype).cpu()
    if item.inp_q is not None:
        # Fold identity: (x*s) @ fake_quant(W) == x @ (fake_quant(W) * s).
        # NOTE(review): `s[None, :]` assumes a 2D weight with a 1D scale over
        # the input dim — confirm for non-linear layers.
        s = item.inp_q._pre_quant_scale.squeeze().to(device=quantized.device)
        quantized = (quantized * s[None, :]).to(quantized.dtype)
    return item.sd_key, quantized, item.inp_q_key
113+
114+
115+
def _process_device_batch(items: list[_WeightQuantWork], device: torch.device):
    """Process all weight items on a single GPU. Runs in a dedicated thread."""
    results = []
    with torch.cuda.device(device):
        for work in items:
            results.append(_process_weight(work))
        # Ensure every kernel launched by this thread has completed before the
        # results are handed back to the main thread.
        torch.cuda.synchronize(device)
    return results
121+
31122

32123
def disable_rotate(quantizer: TensorQuantizer):
33124
"""Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
@@ -41,6 +132,7 @@ def disable_rotate(quantizer: TensorQuantizer):
41132
def export_hf_vllm_fq_checkpoint(
42133
model: nn.Module,
43134
export_dir: Path | str,
135+
parallel: bool = True,
44136
):
45137
"""Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload.
46138
@@ -53,6 +145,9 @@ def export_hf_vllm_fq_checkpoint(
53145
Args:
54146
model: In-memory quantized model.
55147
export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``.
148+
parallel: If True, fake-quantize weights across GPUs concurrently using
149+
one thread per GPU device. Falls back to sequential when all weights
150+
are on the same device or on CPU. Default True.
56151
"""
57152
export_dir = Path(export_dir)
58153
export_dir.mkdir(parents=True, exist_ok=True)
@@ -62,50 +157,60 @@ def export_hf_vllm_fq_checkpoint(
62157
# parameters are never modified. Apply each weight quantizer's fake-quant
63158
# to the corresponding weight tensor in the copy.
64159
state_dict = model.state_dict()
65-
fakequant_weights = set()
66-
input_quantizers_folded_pqs = (
67-
set()
68-
) # keys for input_quantizers where pre_quant_scale was folded
160+
fakequant_weights: set[str] = set()
161+
input_quantizers_folded_pqs: set[str] = set()
162+
163+
work_items = _collect_quant_work(model, state_dict)
164+
165+
# Group work items by device for parallel dispatch.
166+
device_groups: dict[torch.device, list[_WeightQuantWork]] = defaultdict(list)
167+
for item in work_items:
168+
device_groups[item.weight.device].append(item)
169+
170+
num_cuda_devices = sum(1 for d in device_groups if d.type == "cuda")
171+
use_parallel = parallel and num_cuda_devices > 1
172+
173+
t0 = time.monotonic()
69174
with torch.inference_mode():
70-
for module_name, module in model.named_modules():
71-
if not isinstance(module, QuantModule):
72-
continue
73-
for attr_name, quantizer in module.named_children():
74-
if not (
75-
attr_name.endswith("weight_quantizer")
76-
and isinstance(quantizer, TensorQuantizer)
77-
and quantizer.fake_quant
78-
and quantizer.is_enabled
79-
):
80-
continue
81-
weight_name = attr_name.removesuffix("_quantizer")
82-
prefix = f"{module_name}." if module_name else ""
83-
sd_key = f"{prefix}{weight_name}"
84-
assert sd_key not in fakequant_weights, (
85-
f"Weight {sd_key} has already been fakequantized"
86-
)
87-
if sd_key in state_dict:
88-
w = state_dict[sd_key]
89-
w_quant = quantizer(w.float()).to(w.dtype).cpu()
90-
# Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
91-
# Only valid when input_quantizer does NOT fake-quant activations. If it does
92-
# fake_quant(x*s), the non-linearity prevents folding s into W.
93-
inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
94-
if hasattr(module, inp_attr):
95-
inp_q = getattr(module, inp_attr)
96-
if (
97-
hasattr(inp_q, "_pre_quant_scale")
98-
and inp_q._pre_quant_scale is not None
99-
and inp_q._disabled
100-
):
101-
scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
102-
w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
103-
inp_q_key = get_unwrapped_name(
104-
f"{module_name}.{inp_attr}" if module_name else inp_attr, model
105-
)
106-
input_quantizers_folded_pqs.add(inp_q_key)
107-
state_dict[sd_key] = w_quant
108-
fakequant_weights.add(sd_key)
175+
if use_parallel:
176+
logger.info(
177+
"Parallel export: %d weights across %d GPUs (%s)",
178+
len(work_items),
179+
num_cuda_devices,
180+
", ".join(f"{d}: {len(items)} weights" for d, items in device_groups.items()),
181+
)
182+
all_results: list[tuple[str, torch.Tensor, str | None]] = []
183+
with ThreadPoolExecutor(max_workers=num_cuda_devices) as pool:
184+
futures = []
185+
for device, items in device_groups.items():
186+
if device.type == "cuda":
187+
futures.append(pool.submit(_process_device_batch, items, device))
188+
else:
189+
# CPU weights: process inline (no thread needed).
190+
all_results.extend([_process_weight(item) for item in items])
191+
for future in futures:
192+
all_results.extend(future.result())
193+
for sd_key, w_quant, inp_q_key in all_results:
194+
state_dict[sd_key] = w_quant
195+
fakequant_weights.add(sd_key)
196+
if inp_q_key is not None:
197+
input_quantizers_folded_pqs.add(inp_q_key)
198+
else:
199+
# Sequential fallback (single GPU, CPU, or parallel=False).
200+
for item in work_items:
201+
sd_key, w_quant, inp_q_key = _process_weight(item)
202+
state_dict[sd_key] = w_quant
203+
fakequant_weights.add(sd_key)
204+
if inp_q_key is not None:
205+
input_quantizers_folded_pqs.add(inp_q_key)
206+
207+
elapsed = time.monotonic() - t0
208+
logger.info(
209+
"Export step 1 (%s): %d weights fake-quantized in %.1fs",
210+
"parallel" if use_parallel else "sequential",
211+
len(fakequant_weights),
212+
elapsed,
213+
)
109214

110215
# Filter quantizer tensors out for a clean HF checkpoint.
111216
clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
@@ -166,4 +271,5 @@ def export_hf_vllm_fq_checkpoint(
166271

167272
for wq, orig_rotate in wqs_to_restore:
168273
wq.enable()
169-
wq._rotate = orig_rotate
274+
if orig_rotate is not None:
275+
wq._rotate = orig_rotate
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Test parallel vs sequential export produces identical outputs."""
16+
17+
import pytest
18+
import torch
19+
from _test_utils.torch.transformers_models import create_tiny_llama_dir
20+
from transformers import AutoModelForCausalLM
21+
22+
import modelopt.torch.quantization as mtq
23+
from modelopt.torch.export import export_hf_vllm_fq_checkpoint
24+
25+
26+
def _quantize_model(tmp_path, suffix=""):
    """Create and FP8-quantize a tiny LLaMA model on GPU; return the quantized model.

    Args:
        tmp_path: Base temp directory; the tiny model is written under
            ``tmp_path / f"model{suffix}"``.
        suffix: Optional suffix to keep multiple model dirs distinct.

    Returns:
        The quantized ``nn.Module`` (note: only the model — no export dir).
    """
    tiny_model_dir = create_tiny_llama_dir(tmp_path / f"model{suffix}", num_hidden_layers=4)
    model = AutoModelForCausalLM.from_pretrained(tiny_model_dir)
    model = model.cuda()
    model.eval()

    def forward_loop(model):
        # Single random batch for calibration.
        input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda()
        with torch.no_grad():
            model(input_ids)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
    return model
40+
41+
42+
@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
def test_parallel_vs_sequential_identical(tmp_path, quant_cfg):
    """Verify parallel export produces bitwise identical output to sequential.

    Exports the same quantized, multi-GPU-sharded model twice (sequential,
    then parallel) and compares both the HF weights and the saved modelopt
    quantizer state tensor-by-tensor.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus < 2:
        pytest.skip("Need >= 2 GPUs for parallel export test")

    # Create a tiny model and spread across GPUs so the parallel path engages.
    tiny_model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=4)
    model = AutoModelForCausalLM.from_pretrained(
        tiny_model_dir, device_map="auto", torch_dtype=torch.float16
    )
    model.eval()

    def forward_loop(model):
        first_device = next(model.parameters()).device
        input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).to(first_device)
        with torch.no_grad():
            model(input_ids)

    model = mtq.quantize(model, quant_cfg, forward_loop)

    # Export sequentially first. The exporter restores quantizer state before
    # returning, so the same model can be exported again in parallel below.
    seq_dir = tmp_path / "export_sequential"
    export_hf_vllm_fq_checkpoint(model, export_dir=seq_dir, parallel=False)

    # Export in parallel.
    par_dir = tmp_path / "export_parallel"
    export_hf_vllm_fq_checkpoint(model, export_dir=par_dir, parallel=True)

    # Compare HF weights bitwise.
    seq_model = AutoModelForCausalLM.from_pretrained(seq_dir)
    par_model = AutoModelForCausalLM.from_pretrained(par_dir)
    seq_sd = seq_model.state_dict()
    par_sd = par_model.state_dict()

    assert seq_sd.keys() == par_sd.keys(), "Key mismatch between sequential and parallel export"
    for key in seq_sd:
        assert torch.equal(seq_sd[key], par_sd[key]), (
            f"Weight mismatch for {key}: max diff={torch.abs(seq_sd[key] - par_sd[key]).max()}"
        )

    # Compare modelopt state.
    seq_state = torch.load(seq_dir / "vllm_fq_modelopt_state.pth", weights_only=False)
    par_state = torch.load(par_dir / "vllm_fq_modelopt_state.pth", weights_only=False)

    seq_qsd = seq_state["modelopt_state_weights"]
    par_qsd = par_state["modelopt_state_weights"]
    assert seq_qsd.keys() == par_qsd.keys(), "Quantizer state dict key mismatch"

    def _assert_state_equal(a, b, label):
        # Fix: a bare `assert a == b` on tensors produces an elementwise
        # boolean tensor whose truth value is ambiguous and raises — use
        # torch.equal for tensor-valued entries.
        if isinstance(a, torch.Tensor):
            assert torch.equal(a, b), f"Quantizer state mismatch for {label}"
        else:
            assert a == b, f"Quantizer state mismatch for {label}"

    for key in seq_qsd:
        seq_val = seq_qsd[key]
        par_val = par_qsd[key]
        if isinstance(seq_val, dict):
            for k in seq_val:
                _assert_state_equal(seq_val[k], par_val[k], f"{key}.{k}")
        else:
            _assert_state_equal(seq_val, par_val, key)
107+
108+
109+
def test_single_gpu_fallback(tmp_path):
    """Verify parallel=True gracefully falls back to sequential on single GPU."""
    model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=2)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    model = model.cuda()  # everything lives on cuda:0, so only one device group
    model.eval()

    def forward_loop(m):
        calib_ids = torch.randint(0, m.config.vocab_size, (1, 128)).cuda()
        with torch.no_grad():
            m(calib_ids)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

    # With a single GPU, parallel=True must silently take the sequential path.
    out_dir = tmp_path / "export"
    export_hf_vllm_fq_checkpoint(model, export_dir=out_dir, parallel=True)

    assert (out_dir / "vllm_fq_modelopt_state.pth").exists()
    reloaded = AutoModelForCausalLM.from_pretrained(out_dir)
    assert reloaded is not None

0 commit comments

Comments
 (0)