Skip to content

Commit c010f3f

Browse files
committed
addressed comments
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent bd45696 commit c010f3f

File tree

4 files changed

+141
-144
lines changed

4 files changed

+141
-144
lines changed

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from vllm.distributed.parallel_state import get_tp_group
2424

2525
from modelopt.torch.export.plugins.vllm_fakequant_hf import (
26+
infer_quantizer_prefix_remap,
2627
is_weight_quantizer_state_key,
2728
merge_amax_tensors_for_group,
2829
)
@@ -149,9 +150,10 @@ def _group_keys_for_vllm(
149150
for key, value in state_dict.items():
150151
action, new_key, new_value = _convert_key_for_vllm(key, value)
151152
if new_key is None or new_value is None:
152-
assert action == "skip", (
153-
f"Expected action to be 'skip' for key {key}, value {value}, got {action}"
154-
)
153+
if action != "skip":
154+
raise RuntimeError(
155+
f"Expected action to be 'skip' for key {key}, value {value}, got {action}"
156+
)
155157
continue
156158
if action == "copy":
157159
vllm_state_dict[new_key] = new_value
@@ -219,38 +221,6 @@ def _merge_values_require_identical(merged_key: str, key_value_pairs: list[tuple
219221
return first_value
220222

221223

222-
def _infer_prefix_remap(
223-
quantizer_keys: dict[str, Any],
224-
map_fun: Callable[[dict[str, Any]], dict[str, Any]],
225-
) -> dict[str, str]:
226-
"""Map HF root name → vLLM root (e.g. ``backbone`` → ``model``) using ``map_fun`` on ``*.weight`` keys.
227-
228-
Quantizer keys never go through ``map_fun`` later, so we probe with a tiny CPU placeholder.
229-
It must be **2-D** (mappers expect matrix weights; 1-D often errors). Only the returned key
230-
path is used; values are ignored. A CPU tensor is enough for typical HF↔vLLM name mapping.
231-
"""
232-
prefix_remap: dict[str, str] = {}
233-
probe_weight = torch.empty((1, 1))
234-
for key in quantizer_keys:
235-
first_component = key.split(".")[0]
236-
if first_component in prefix_remap:
237-
continue
238-
last_dot = key.rfind(".")
239-
if last_dot == -1:
240-
continue
241-
probe_key = key[:last_dot] + ".weight"
242-
try:
243-
result = map_fun({probe_key: probe_weight})
244-
if result:
245-
new_key = next(iter(result))
246-
new_first = new_key.split(".")[0]
247-
if new_first != first_component:
248-
prefix_remap[first_component] = new_first
249-
except Exception as e:
250-
warnings.warn(f"prefix-remap probe failed for {probe_key!r}: {e}")
251-
return prefix_remap
252-
253-
254224
def convert_dict_to_vllm(
255225
state_dict: dict[str, Any],
256226
max_or_concat: bool = True,
@@ -273,7 +243,7 @@ def convert_dict_to_vllm(
273243
# invoked on non-quantizer keys.
274244
if map_fun is not None:
275245
q_only = {k: v for k, v in state_dict.items() if "_quantizer" in k}
276-
prefix_remap = _infer_prefix_remap(q_only, map_fun)
246+
prefix_remap = infer_quantizer_prefix_remap(q_only, map_fun)
277247
if prefix_remap:
278248
renamed = {}
279249
for k, v in state_dict.items():
@@ -406,11 +376,12 @@ def _has_buffers(state: dict) -> bool:
406376
if not is_weight_quantizer_state_key(wq_k):
407377
continue
408378
wq_state = filtered[wq_k]
409-
assert wq_k in saved or wq_state.get("_disabled"), (
410-
f"Weight quantizer {wq_k!r} is missing from saved quantizer_state but "
411-
f"is not marked _disabled (got _disabled={wq_state.get('_disabled')!r}). "
412-
f"vLLM fakequant export omits weight quantizer keys when weights are folded."
413-
)
379+
if wq_k not in saved and not wq_state.get("_disabled"):
380+
raise RuntimeError(
381+
f"Weight quantizer {wq_k!r} is missing from saved quantizer_state but "
382+
f"is not marked _disabled (got _disabled={wq_state.get('_disabled')!r}). "
383+
f"vLLM fakequant export omits weight quantizer keys when weights are folded."
384+
)
414385
metadata["quantizer_state"] = filtered
415386

416387

@@ -449,7 +420,8 @@ def restore_from_modelopt_state_vllm(
449420

450421
if not manager.has_state and isinstance(model, ModelLikeModule):
451422
model = model.init_modellike()
452-
assert not isinstance(model, ModelLikeModule), "Model must be a regular Module now!"
423+
if isinstance(model, ModelLikeModule):
424+
raise RuntimeError("Model must be a regular Module after restore, got ModelLikeModule")
453425
return model
454426

455427

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 112 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy
1818
import logging
1919
import re
20+
import warnings
2021
from collections.abc import Callable
2122
from pathlib import Path
2223
from typing import Any
@@ -121,11 +122,12 @@ def _check_all_weight_quantizers_disabled(model: nn.Module) -> None:
121122
if attr_name.endswith("weight_quantizer") and isinstance(
122123
quantizer, (TensorQuantizer, SequentialQuantizer)
123124
):
124-
assert not quantizer.is_enabled, (
125-
f"vLLM fakequant export: {attr_name!r} must be disabled before saving "
126-
f"quantizer_state (weights already folded). "
127-
f"See filter_modelopt_state_quantizer_state_for_model in vllm_reload_utils."
128-
)
125+
if quantizer.is_enabled:
126+
raise RuntimeError(
127+
f"vLLM fakequant export: {attr_name!r} must be disabled before saving "
128+
f"quantizer_state (weights already folded). "
129+
f"See filter_modelopt_state_quantizer_state_for_model in vllm_reload_utils."
130+
)
129131

130132

131133
def disable_rotate(quantizer: TensorQuantizer):
@@ -171,25 +173,27 @@ def requant_weights_for_export(
171173
``w`` (e.g. CPU offload).
172174
"""
173175
copied = copy.deepcopy(quantizer).to(device=weight.device)
174-
sequence_quantizers: list[TensorQuantizer] = (
176+
quantizers: list[TensorQuantizer] = (
175177
list(copied) if isinstance(copied, SequentialQuantizer) else [copied]
176178
)
177179

178-
for quantizer_copy in sequence_quantizers:
180+
for quantizer_copy in quantizers:
179181
quantizer_copy.eval()
180182
quantizer_copy.reset_amax()
181183
enable_stats_collection(quantizer_copy)
182184
# Match legacy single-quantizer path: first calib uses ``w`` as-is; chains use float.
183-
if len(sequence_quantizers) == 1:
184-
weight_quantized = sequence_quantizers[0](weight)
185+
if len(quantizers) == 1:
186+
weight_quantized = quantizers[0](weight)
185187
else:
186-
weight_quantized = weight.float()
187-
for quantizer_copy in sequence_quantizers:
188+
weight_quantized = weight
189+
for quantizer_copy in quantizers:
188190
weight_quantized = quantizer_copy(weight_quantized)
189-
for quantizer_copy in sequence_quantizers:
191+
for quantizer_copy in quantizers:
190192
finish_stats_collection(quantizer_copy)
191-
weight_quantized = weight.float()
192-
for quantizer_copy in sequence_quantizers:
193+
# Re-run application pass to get the quantized output with the freshly collected amax.
194+
# The calibration forward above only collected stats; its output is intentionally discarded.
195+
weight_quantized = weight
196+
for quantizer_copy in quantizers:
193197
weight_quantized = quantizer_copy(weight_quantized)
194198
return weight_quantized.to(weight.dtype)
195199

@@ -219,6 +223,12 @@ def merge_amax_tensors_for_group(tensors: list[torch.Tensor]) -> torch.Tensor:
219223
try:
220224
return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device)
221225
except RuntimeError:
226+
shapes = [tuple(t.shape) for t in tensors]
227+
warnings.warn(
228+
f"merge_amax_tensors_for_group: torch.cat failed for shapes {shapes}; "
229+
"falling back to scalar max which loses per-channel amax structure.",
230+
stacklevel=2,
231+
)
222232
flat = torch.cat([t.reshape(-1).float() for t in tensors])
223233
return torch.max(flat).to(dtype=first.dtype, device=first.device)
224234

@@ -258,7 +268,9 @@ def _process_group(modules: list[nn.Module]) -> None:
258268
if pqs_list is None:
259269
return
260270

261-
avg_pqs = torch.stack(pqs_list).mean(0)
271+
# Mean and clamp in float32: in fp16/bf16, float32.tiny would flush to 0, leading to division by zero.
272+
pqs_dtype = pqs_list[0].dtype
273+
avg_pqs = torch.stack([p.float() for p in pqs_list]).mean(0)
262274
avg_pqs = avg_pqs.clamp(min=torch.finfo(torch.float32).tiny)
263275

264276
for m in modules:
@@ -270,8 +282,8 @@ def _process_group(modules: list[nn.Module]) -> None:
270282
if torch.equal(old_pqs, avg_pqs_dev):
271283
continue
272284
weight = state_dict[f"{nm}.weight"]
273-
ratio = old_pqs.to(dtype=torch.float32, device=weight.device) / avg_pqs_dev.to(
274-
dtype=torch.float32, device=weight.device
285+
ratio = old_pqs.to(dtype=torch.float32, device=weight.device) / avg_pqs.to(
286+
device=weight.device
275287
)
276288
state_dict[f"{nm}.weight"] = (weight.to(torch.float32) * ratio).to(weight.dtype)
277289
requant_weights.add(f"{nm}.weight")
@@ -281,7 +293,7 @@ def _process_group(modules: list[nn.Module]) -> None:
281293
if all(a is not None for a in amaxes):
282294
synced_amax = merge_amax_tensors_for_group(amaxes)
283295

284-
avg_pqs_out = avg_pqs.detach().clone()
296+
avg_pqs_out = avg_pqs.detach().to(pqs_dtype).clone()
285297
for m in modules:
286298
nm = id_to_name.get(id(m))
287299
if nm is None:
@@ -309,14 +321,15 @@ def _process_group(modules: list[nn.Module]) -> None:
309321

310322
def _dummy_forward() -> None:
311323
# Partial forward is OK: hooks record layers reached before failure.
312-
try:
313-
model(torch.ones([1, 2], dtype=torch.long, device=dev))
314-
except Exception as e:
315-
import logging
324+
with torch.inference_mode():
325+
try:
326+
model(torch.ones([1, 2], dtype=torch.long, device=dev))
327+
except Exception as e:
328+
import logging
316329

317-
logging.getLogger(__name__).debug(
318-
"Dummy forward for shared-input detection failed (expected for VLMs): %s", e
319-
)
330+
logging.getLogger(__name__).debug(
331+
"Dummy forward for shared-input detection failed (expected for VLMs): %s", e
332+
)
320333

321334
input_to_linear, _ = collect_shared_input_modules(model, _dummy_forward)
322335
for modules in input_to_linear.values():
@@ -380,9 +393,8 @@ def export_hf_vllm_fq_checkpoint(
380393
weight_name = attr_name.removesuffix("_quantizer")
381394
prefix = f"{module_name}." if module_name else ""
382395
sd_key = f"{prefix}{weight_name}"
383-
assert sd_key not in fakequant_weights, (
384-
f"Weight {sd_key} has already been fakequantized"
385-
)
396+
if sd_key in fakequant_weights:
397+
raise RuntimeError(f"Weight {sd_key} has already been fakequantized")
386398
if sd_key in state_dict:
387399
w = state_dict[sd_key]
388400
if sd_key in requant_weights:
@@ -419,74 +431,75 @@ def export_hf_vllm_fq_checkpoint(
419431
# Rotation is also cleared: the weight was already folded with rotation applied,
420432
# so if fold_weight is called on reload it must not re-rotate the exported weight.
421433
wqs_to_restore: list[tuple[TensorQuantizer, Any]] = []
422-
for _, module in model.named_modules():
423-
if isinstance(module, QuantModule):
424-
for attr_name, quantizer in module.named_children():
425-
if not (attr_name.endswith("weight_quantizer") and quantizer.is_enabled):
426-
continue
427-
if isinstance(quantizer, SequentialQuantizer):
428-
quantizer.disable()
429-
for sub in quantizer:
430-
orig_rotate = sub._rotate
431-
if sub.rotate_is_enabled:
432-
sub._rotate = disable_rotate(sub)
433-
wqs_to_restore.append((sub, orig_rotate))
434-
elif isinstance(quantizer, TensorQuantizer):
435-
quantizer.disable()
436-
orig_rotate = quantizer._rotate
437-
if quantizer.rotate_is_enabled:
438-
quantizer._rotate = disable_rotate(quantizer)
439-
wqs_to_restore.append((quantizer, orig_rotate))
440-
441-
quantizer_state_dict = get_quantizer_state_dict(model)
442-
for key in list(quantizer_state_dict):
443-
if is_weight_quantizer_state_key(key):
444-
# Fakequant amax is folded into HF weights; do not reload weight quantizer tensors.
445-
# Reload must force-disable WQs missing from saved state (see
446-
# ``filter_modelopt_state_quantizer_state_for_model`` assertion in vllm_reload_utils).
447-
quantizer_state_dict.pop(key)
448-
elif key in input_quantizers_folded_pqs:
449-
# pre_quant_scale was folded into the weight; keep the buffer for strict load but
450-
# save identity so activations are not scaled twice.
451-
qstate_val = quantizer_state_dict[key]
452-
if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val:
453-
quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like(
454-
qstate_val["_pre_quant_scale"]
455-
)
456-
457-
# Patch input quantizers with averaged pqs and unified amax so that vLLM's single
458-
# per-group input quantizer sees consistent values (covers both dense qkv and MoE experts).
459-
for iq_key, (avg_pqs, max_input_amax) in pqs_overrides.items():
460-
if iq_key in quantizer_state_dict:
461-
qstate_val = quantizer_state_dict[iq_key]
462-
if isinstance(qstate_val, dict):
463-
if "_pre_quant_scale" in qstate_val:
464-
qstate_val["_pre_quant_scale"] = avg_pqs
465-
if max_input_amax is not None and "_amax" in qstate_val:
466-
qstate_val["_amax"] = max_input_amax
467-
468-
modelopt_state = mto.modelopt_state(model)
469-
# ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
470-
# ``quantizer_state`` and strip weight-quantizer entries (same policy as
471-
# ``modelopt_state_weights``). Reload synthesizes missing WQ rows with ``_disabled``.
472-
_check_all_weight_quantizers_disabled(model)
473-
qstate = quantizer_state(model)
474-
for key in list(qstate):
475-
if is_weight_quantizer_state_key(key):
476-
qstate.pop(key)
477-
478-
for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
479-
if mode_str == "quantize" and "metadata" in m_state:
480-
m_state["metadata"]["quantizer_state"] = qstate
481-
break
482-
483-
# Per-quantizer tensor dict loaded alongside metadata on reload.
484-
modelopt_state["modelopt_state_weights"] = quantizer_state_dict
485-
safe_save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
486-
487-
# Step 3: Save HF weights using the pre-built folded state dict.
488-
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
489-
490-
for wq, orig_rotate in wqs_to_restore:
491-
wq.enable()
492-
wq._rotate = orig_rotate
434+
try:
435+
for _, module in model.named_modules():
436+
if isinstance(module, QuantModule):
437+
for attr_name, quantizer in module.named_children():
438+
if not (attr_name.endswith("weight_quantizer") and quantizer.is_enabled):
439+
continue
440+
if isinstance(quantizer, SequentialQuantizer):
441+
quantizer.disable()
442+
for sub in quantizer:
443+
orig_rotate = sub._rotate
444+
if sub.rotate_is_enabled:
445+
sub._rotate = disable_rotate(sub)
446+
wqs_to_restore.append((sub, orig_rotate))
447+
elif isinstance(quantizer, TensorQuantizer):
448+
quantizer.disable()
449+
orig_rotate = quantizer._rotate
450+
if quantizer.rotate_is_enabled:
451+
quantizer._rotate = disable_rotate(quantizer)
452+
wqs_to_restore.append((quantizer, orig_rotate))
453+
454+
quantizer_state_dict = get_quantizer_state_dict(model)
455+
for key in list(quantizer_state_dict):
456+
if is_weight_quantizer_state_key(key):
457+
# Fakequant amax is folded into HF weights; do not reload weight quantizer tensors.
458+
# Reload must force-disable WQs missing from saved state (see
459+
# ``filter_modelopt_state_quantizer_state_for_model`` assertion in vllm_reload_utils).
460+
quantizer_state_dict.pop(key)
461+
elif key in input_quantizers_folded_pqs:
462+
# pre_quant_scale was folded into the weight; keep the buffer for strict load but
463+
# save identity so activations are not scaled twice.
464+
qstate_val = quantizer_state_dict[key]
465+
if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val:
466+
quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like(
467+
qstate_val["_pre_quant_scale"]
468+
)
469+
470+
# Patch input quantizers with averaged pqs and unified amax so that vLLM's single
471+
# per-group input quantizer sees consistent values (covers both dense qkv and MoE experts).
472+
for iq_key, (avg_pqs, max_input_amax) in pqs_overrides.items():
473+
if iq_key in quantizer_state_dict:
474+
qstate_val = quantizer_state_dict[iq_key]
475+
if isinstance(qstate_val, dict):
476+
if "_pre_quant_scale" in qstate_val:
477+
qstate_val["_pre_quant_scale"] = avg_pqs
478+
if max_input_amax is not None and "_amax" in qstate_val:
479+
qstate_val["_amax"] = max_input_amax
480+
481+
modelopt_state = mto.modelopt_state(model)
482+
# ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
483+
# ``quantizer_state`` and strip weight-quantizer entries (same policy as
484+
# ``modelopt_state_weights``). Reload synthesizes missing WQ rows with ``_disabled``.
485+
_check_all_weight_quantizers_disabled(model)
486+
qstate = quantizer_state(model)
487+
for key in list(qstate):
488+
if is_weight_quantizer_state_key(key):
489+
qstate.pop(key)
490+
491+
for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
492+
if mode_str == "quantize" and "metadata" in m_state:
493+
m_state["metadata"]["quantizer_state"] = qstate
494+
break
495+
496+
# Per-quantizer tensor dict loaded alongside metadata on reload.
497+
modelopt_state["modelopt_state_weights"] = quantizer_state_dict
498+
safe_save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
499+
500+
# Step 3: Save HF weights using the pre-built folded state dict.
501+
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
502+
finally:
503+
for wq, orig_rotate in wqs_to_restore:
504+
wq.enable()
505+
wq._rotate = orig_rotate

0 commit comments

Comments
 (0)