Skip to content

Commit 19608f6

Browse files
committed
Add support for GQA (grouped-query attention)
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent b758cc5 commit 19608f6

File tree

2 files changed

+89
-68
lines changed

2 files changed

+89
-68
lines changed

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 86 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
"""Export HuggingFace model to vLLM fakequant checkpoint."""
1616

17+
import contextlib
1718
import copy
1819
import re
1920
from pathlib import Path
@@ -32,6 +33,7 @@
3233

3334
from ..layer_utils import get_experts_list, is_moe
3435
from ..quant_utils import get_quantization_format
36+
from ..unified_export_hf import collect_shared_input_modules
3537

3638
__all__ = [
3739
"export_hf_vllm_fq_checkpoint",
@@ -122,12 +124,22 @@ def _resmooth_experts_for_export(
122124
model: nn.Module,
123125
state_dict: dict[str, Any],
124126
) -> tuple[dict[str, tuple[torch.Tensor, torch.Tensor | None]], set[str]]:
125-
"""Average pqs and unify input amax across MoE experts when AWQ smoothing applies (no-op otherwise).
126-
127-
Adjusts expert weights in ``state_dict`` as ``W' = W * old_pqs / avg_pqs`` and returns
128-
input-quantizer overrides for ``modelopt_state_weights``. **Does nothing** for weight-only
129-
MoE (no ``pre_quant_scale`` on experts) or unsupported MoE layouts — same as skipping the
130-
MoE branch in :func:`requantize_resmooth_fused_llm_layers`.
127+
"""Average pre-quant scales (pqs) and unify input amax for all groups vLLM collapses to one input quantizer.
128+
129+
Covers two cases that are structurally identical — a set of linears sharing the
130+
same input that vLLM fuses behind a single input quantizer:
131+
132+
* **MoE experts** — vLLM uses one input quantizer per expert group.
133+
* **Dense GQA attention** — vLLM fuses q/k/v into ``qkv_proj`` with one input
134+
quantizer. AWQ calibration gives each projection a different pqs because they
135+
share the same input but have different weight magnitudes; using only q's pqs
136+
for k and v at inference corrupts those activations.
137+
138+
For each group, adjusts weights in ``state_dict`` as ``W' = W * old_pqs / avg_pqs``
139+
and returns input-quantizer overrides so every member exports the same ``avg_pqs``.
140+
Only runs when input quantizers are **enabled** (``nvfp4_awq_wa`` style); for
141+
weight-only AWQ the input quantizer is disabled and pqs is folded into each
142+
projection's own weight independently.
131143
"""
132144
qfmt = get_quantization_format(model)
133145
if qfmt is None or "awq" not in qfmt.lower():
@@ -137,61 +149,70 @@ def _resmooth_experts_for_export(
137149
id_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()}
138150
out: dict[str, tuple[torch.Tensor, torch.Tensor | None]] = {}
139151
requant_weights: set[str] = set()
140-
for _, module in model.named_modules():
141-
if not is_moe(module):
142-
continue
143-
try:
144-
expert_groups = get_experts_list(module, model_type)
145-
except NotImplementedError:
146-
continue
147-
148-
for experts in expert_groups:
149-
if not experts:
152+
153+
def _process_group(modules: list[nn.Module]) -> None:
154+
pqs_list = _collect_expert_pre_quant_scales(modules)
155+
if pqs_list is None:
156+
return
157+
158+
avg_pqs = torch.stack(pqs_list).mean(0)
159+
avg_pqs = avg_pqs.clamp(min=torch.finfo(torch.float32).tiny)
160+
161+
for m in modules:
162+
nm = id_to_name.get(id(m))
163+
if nm is None or f"{nm}.weight" not in state_dict:
150164
continue
151-
pre_quant_scales_list = _collect_expert_pre_quant_scales(experts)
152-
if pre_quant_scales_list is None:
165+
old_pqs = m.input_quantizer._pre_quant_scale
166+
avg_pqs_dev = avg_pqs.to(device=old_pqs.device, dtype=old_pqs.dtype)
167+
if torch.equal(old_pqs, avg_pqs_dev):
153168
continue
169+
weight = state_dict[f"{nm}.weight"]
170+
ratio = old_pqs.to(dtype=torch.float32, device=weight.device) / avg_pqs_dev.to(
171+
dtype=torch.float32, device=weight.device
172+
)
173+
state_dict[f"{nm}.weight"] = (weight.to(torch.float32) * ratio).to(weight.dtype)
174+
requant_weights.add(f"{nm}.weight")
175+
176+
synced_amax: torch.Tensor | None = None
177+
amaxes = [m.input_quantizer.amax for m in modules]
178+
if all(a is not None for a in amaxes):
179+
synced_amax = merge_amax_tensors_for_group(amaxes)
180+
181+
avg_pqs_out = avg_pqs.detach().clone()
182+
for m in modules:
183+
nm = id_to_name.get(id(m))
184+
if nm is None:
185+
continue
186+
out[get_unwrapped_name(f"{nm}.input_quantizer", model)] = (avg_pqs_out, synced_amax)
154187

155-
avg_pre_quant_scale = torch.stack(pre_quant_scales_list).mean(0)
156-
# Guard against degenerate calibration where a channel's scale is zero:
157-
# zero avg_pqs would produce inf ratio and corrupt the exported weight.
158-
avg_pre_quant_scale = avg_pre_quant_scale.clamp(min=torch.finfo(torch.float32).tiny)
159-
160-
for ex in experts:
161-
nm = id_to_name.get(id(ex))
162-
if nm is None or f"{nm}.weight" not in state_dict:
163-
continue
164-
old_pre_quant_scale = ex.input_quantizer._pre_quant_scale
165-
avg_pre_quant_scale = avg_pre_quant_scale.to(
166-
device=old_pre_quant_scale.device, dtype=old_pre_quant_scale.dtype
167-
)
168-
if torch.equal(old_pre_quant_scale, avg_pre_quant_scale):
169-
continue
170-
weight = state_dict[f"{nm}.weight"]
171-
updated_weight = (
172-
weight.to(torch.float32)
173-
* old_pre_quant_scale.to(dtype=torch.float32, device=weight.device)
174-
/ avg_pre_quant_scale.to(dtype=torch.float32, device=weight.device)
175-
).to(weight.dtype)
176-
state_dict[f"{nm}.weight"] = updated_weight
177-
requant_weights.add(f"{nm}.weight")
178-
179-
iq0 = experts[0].input_quantizer
180-
synced_amax: torch.Tensor | None = None
181-
if iq0.is_enabled:
182-
amaxes = [e.input_quantizer.amax for e in experts]
183-
if all(a is not None for a in amaxes):
184-
synced_amax = merge_amax_tensors_for_group(amaxes)
185-
186-
avg_pre_quant_scale_output = avg_pre_quant_scale.detach().clone()
187-
for ex in experts:
188-
nm = id_to_name.get(id(ex))
189-
if nm is None:
190-
continue
191-
out[get_unwrapped_name(f"{nm}.input_quantizer", model)] = (
192-
avg_pre_quant_scale_output,
193-
synced_amax,
194-
)
188+
# MoE expert groups — must be enumerated by name because MoE routing sends
189+
# different tokens to each expert, so forward hooks cannot detect them as
190+
# sharing the same input tensor.
191+
for _, module in model.named_modules():
192+
if is_moe(module):
193+
try:
194+
expert_groups = get_experts_list(module, model_type)
195+
except NotImplementedError:
196+
pass
197+
else:
198+
for experts in expert_groups:
199+
if experts:
200+
_process_group(experts)
201+
202+
# Dense shared-input groups (e.g. q/k/v in GQA attention) — detected via forward
203+
# hooks so any architecture is covered regardless of projection attribute names.
204+
205+
dev = next(model.parameters()).device
206+
207+
def _dummy_forward() -> None:
208+
# Partial forward is OK: hooks record layers reached before failure (e.g. VLMs).
209+
with contextlib.suppress(Exception):
210+
model(torch.ones([1, 2], dtype=torch.long, device=dev))
211+
212+
input_to_linear, _ = collect_shared_input_modules(model, _dummy_forward)
213+
for modules in input_to_linear.values():
214+
if len(modules) > 1:
215+
_process_group(modules)
195216

196217
return out, requant_weights
197218

@@ -227,10 +248,10 @@ def export_hf_vllm_fq_checkpoint(
227248
# to the corresponding weight tensor in the copy.
228249
state_dict = model.state_dict()
229250

230-
# Non-mutating MoE expert resmooth: average pqs and adjust state_dict weights.
231-
# Must run before the fakequant loop so that the adjusted weights are fakequanted
232-
# with the correct per-block scales.
233-
expert_pqs_overrides, requant_weights = _resmooth_experts_for_export(model, state_dict)
251+
# Non-mutating resmooth: average pqs across all shared-input groups (dense GQA q/k/v
252+
# and MoE experts) and adjust state_dict weights. Must run before the fakequant loop
253+
# so adjusted weights are fakequanted with the correct per-block scales.
254+
pqs_overrides, requant_weights = _resmooth_experts_for_export(model, state_dict)
234255

235256
fakequant_weights: set[str] = set()
236257
# Input quantizer keys whose _pre_quant_scale was folded into the weight above.
@@ -317,9 +338,9 @@ def export_hf_vllm_fq_checkpoint(
317338
qstate_val["_pre_quant_scale"]
318339
)
319340

320-
# Patch expert input quantizers with averaged pqs and unified amax so that
321-
# vLLM's single per-group input quantizer sees consistent values across experts.
322-
for iq_key, (avg_pqs, max_input_amax) in expert_pqs_overrides.items():
341+
# Patch input quantizers with averaged pqs and unified amax so that vLLM's single
342+
# per-group input quantizer sees consistent values (covers both dense qkv and MoE experts).
343+
for iq_key, (avg_pqs, max_input_amax) in pqs_overrides.items():
323344
if iq_key in quantizer_state_dict:
324345
qstate_val = quantizer_state_dict[iq_key]
325346
if isinstance(qstate_val, dict):

modelopt/torch/export/unified_export_hf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _save_component_state_dict_safetensors(
163163
json.dump(metadata, f, indent=4)
164164

165165

166-
def _collect_shared_input_modules(
166+
def collect_shared_input_modules(
167167
model: nn.Module,
168168
dummy_forward_fn: Callable[[], None],
169169
collect_layernorms: bool = False,
@@ -387,7 +387,7 @@ def llm_dummy_forward():
387387
else:
388388
model(fake_input)
389389

390-
input_to_linear, output_to_layernorm = _collect_shared_input_modules(
390+
input_to_linear, output_to_layernorm = collect_shared_input_modules(
391391
model, llm_dummy_forward, collect_layernorms=True
392392
)
393393

@@ -862,7 +862,7 @@ def _fuse_qkv_linears_diffusion(
862862

863863
# Collect modules sharing the same input
864864
try:
865-
input_to_linear, _ = _collect_shared_input_modules(
865+
input_to_linear, _ = collect_shared_input_modules(
866866
model, dummy_forward_fn, collect_layernorms=False
867867
)
868868
except Exception as e:

0 commit comments

Comments
 (0)