Skip to content

Commit bd45696

Browse files
committed
addressed comments
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 19608f6 commit bd45696

File tree

3 files changed

+314
-34
lines changed

3 files changed

+314
-34
lines changed

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
1716
import re
1817
import warnings
1918
from collections import defaultdict
@@ -248,7 +247,7 @@ def _infer_prefix_remap(
248247
if new_first != first_component:
249248
prefix_remap[first_component] = new_first
250249
except Exception as e:
251-
logging.getLogger(__name__).debug("prefix-remap probe failed for %r: %s", probe_key, e)
250+
warnings.warn(f"prefix-remap probe failed for {probe_key!r}: {e}")
252251
return prefix_remap
253252

254253

@@ -401,6 +400,17 @@ def _has_buffers(state: dict) -> bool:
401400
if is_weight_quantizer_state_key(k) and not state.get("_disabled"):
402401
state = {**state, "_disabled": True}
403402
filtered[k] = state
403+
404+
# Invariant: weight quantizers absent from export must be _disabled.
405+
for wq_k in model_keys:
406+
if not is_weight_quantizer_state_key(wq_k):
407+
continue
408+
wq_state = filtered[wq_k]
409+
assert wq_k in saved or wq_state.get("_disabled"), (
410+
f"Weight quantizer {wq_k!r} is missing from saved quantizer_state but "
411+
f"is not marked _disabled (got _disabled={wq_state.get('_disabled')!r}). "
412+
f"vLLM fakequant export omits weight quantizer keys when weights are folded."
413+
)
404414
metadata["quantizer_state"] = filtered
405415

406416

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 150 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
# limitations under the License.
1515
"""Export HuggingFace model to vLLM fakequant checkpoint."""
1616

17-
import contextlib
1817
import copy
18+
import logging
1919
import re
20+
from collections.abc import Callable
2021
from pathlib import Path
2122
from typing import Any
2223

@@ -27,7 +28,7 @@
2728
from modelopt.torch.quantization.config import RotateConfig
2829
from modelopt.torch.quantization.conversion import quantizer_state
2930
from modelopt.torch.quantization.model_calib import enable_stats_collection, finish_stats_collection
30-
from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer
31+
from modelopt.torch.quantization.nn import QuantModule, SequentialQuantizer, TensorQuantizer
3132
from modelopt.torch.quantization.utils import get_quantizer_state_dict
3233
from modelopt.torch.utils import get_unwrapped_name, safe_save
3334

@@ -37,6 +38,7 @@
3738

3839
__all__ = [
3940
"export_hf_vllm_fq_checkpoint",
41+
"infer_quantizer_prefix_remap",
4042
"is_weight_quantizer_state_key",
4143
"merge_amax_tensors_for_group",
4244
]
@@ -53,6 +55,79 @@ def is_weight_quantizer_state_key(key: str) -> bool:
5355
return bool(_WEIGHT_QUANTIZER_STATE_KEY.search(key))
5456

5557

58+
def infer_quantizer_prefix_remap(
59+
quantizer_keys: dict[str, Any],
60+
map_fun: Callable[[dict[str, Any]], dict[str, Any]],
61+
) -> dict[str, str]:
62+
"""Infer HF root name → vLLM root (e.g. ``backbone`` → ``model``) for reload/export.
63+
64+
Map HF root → vLLM root (e.g. ``backbone`` → ``model``) by probing ``map_fun`` with
65+
synthetic ``<module>.weight`` keys and a 2-D placeholder (quantizer paths are not weight
66+
keys). Keys under the same HF root must agree on the target root or :exc:`ValueError` is
67+
raised; failed probes are skipped. Returns ``{hf_root: vllm_root}`` only where the root
68+
renames; not for arbitrary layer rewrites.
69+
70+
Args:
71+
quantizer_keys: HF quantizer state paths as keys (values unused).
72+
map_fun: HF→vLLM weight ``state_dict`` mapper, same as for ``convert_dict_to_vllm``.
73+
74+
Returns:
75+
``{hf_root: vllm_root}`` for roots that rename; omits identity pairs.
76+
"""
77+
logger = logging.getLogger(__name__)
78+
probe_weight = torch.empty((1, 1))
79+
observed_vllm_root: dict[str, str] = {}
80+
81+
for key in quantizer_keys:
82+
first_component = key.split(".")[0]
83+
last_dot = key.rfind(".")
84+
if last_dot == -1:
85+
continue
86+
probe_key = key[:last_dot] + ".weight"
87+
try:
88+
result = map_fun({probe_key: probe_weight})
89+
if not result:
90+
continue
91+
new_key = next(iter(result))
92+
new_first = new_key.split(".")[0]
93+
except Exception as e:
94+
logger.debug("prefix-remap probe failed for %r: %s", probe_key, e)
95+
continue
96+
97+
if first_component not in observed_vllm_root:
98+
observed_vllm_root[first_component] = new_first
99+
elif observed_vllm_root[first_component] != new_first:
100+
raise ValueError(
101+
"Inconsistent HF→vLLM prefix remap for "
102+
f"{first_component!r}: probes implied "
103+
f"{observed_vllm_root[first_component]!r} and {new_first!r}. "
104+
"map_fun must apply one target root per HF root, or use explicit quantizer "
105+
"key remapping."
106+
)
107+
108+
return {
109+
hf_root: vllm_root
110+
for hf_root, vllm_root in observed_vllm_root.items()
111+
if hf_root != vllm_root
112+
}
113+
114+
115+
def _check_all_weight_quantizers_disabled(model: nn.Module) -> None:
116+
"""Export invariant before writing metadata: every weight quantizer must be off."""
117+
for _, module in model.named_modules():
118+
if not isinstance(module, QuantModule):
119+
continue
120+
for attr_name, quantizer in module.named_children():
121+
if attr_name.endswith("weight_quantizer") and isinstance(
122+
quantizer, (TensorQuantizer, SequentialQuantizer)
123+
):
124+
assert not quantizer.is_enabled, (
125+
f"vLLM fakequant export: {attr_name!r} must be disabled before saving "
126+
f"quantizer_state (weights already folded). "
127+
f"See filter_modelopt_state_quantizer_state_for_model in vllm_reload_utils."
128+
)
129+
130+
56131
def disable_rotate(quantizer: TensorQuantizer):
57132
"""Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
58133
if isinstance(quantizer._rotate, RotateConfig):
@@ -62,35 +137,61 @@ def disable_rotate(quantizer: TensorQuantizer):
62137
return False
63138

64139

65-
def _collect_expert_pre_quant_scales(
140+
def _collect_group_pre_quant_scales(
66141
experts: list[nn.Module],
67142
) -> list[torch.Tensor] | None:
68143
"""Return per-expert ``pre_quant_scale`` tensors if every expert can be averaged; else None.
69144
70145
Skips groups where any expert has no input quantizer, no pqs (e.g. weight-only AWQ INT4),
71146
or a disabled input quantizer (pqs already folded / not used).
72147
"""
73-
pqs_list: list[torch.Tensor] = []
74-
for ex in experts:
75-
iq = getattr(ex, "input_quantizer", None)
76-
if iq is None or not iq.is_enabled or iq.pre_quant_scale is None:
148+
pre_quant_scales: list[torch.Tensor] = []
149+
for expert_module in experts:
150+
input_quantizer = getattr(expert_module, "input_quantizer", None)
151+
if (
152+
input_quantizer is None
153+
or not input_quantizer.is_enabled
154+
or input_quantizer.pre_quant_scale is None
155+
):
77156
return None
78-
pqs_list.append(iq.pre_quant_scale)
79-
return pqs_list
157+
pre_quant_scales.append(input_quantizer.pre_quant_scale)
158+
return pre_quant_scales
80159

81160

82161
def requant_weights_for_export(
83-
quantizer: TensorQuantizer,
84-
w: torch.Tensor,
162+
quantizer: TensorQuantizer | SequentialQuantizer,
163+
weight: torch.Tensor,
85164
) -> torch.Tensor:
86-
"""Requantize weights for export."""
87-
quantizer_copy = copy.deepcopy(quantizer)
88-
quantizer_copy.eval()
89-
quantizer_copy.reset_amax()
90-
enable_stats_collection(quantizer_copy)
91-
quantizer_copy(w)
92-
finish_stats_collection(quantizer_copy)
93-
return quantizer_copy(w.float()).to(w.dtype)
165+
"""Requantize folded weights after resmooth (``TensorQuantizer`` or ``SequentialQuantizer``).
166+
167+
A single ``TensorQuantizer`` is treated as a one-stage chain so the same
168+
calibrate-then-apply steps cover W4A8-style sequential weights (e.g. INT4→FP8).
169+
170+
Deepcopy may leave buffers on the original device; ``.to(device=weight.device)`` aligns
171+
with ``weight`` (e.g. CPU offload).
172+
"""
173+
copied = copy.deepcopy(quantizer).to(device=weight.device)
174+
sequence_quantizers: list[TensorQuantizer] = (
175+
list(copied) if isinstance(copied, SequentialQuantizer) else [copied]
176+
)
177+
178+
for quantizer_copy in sequence_quantizers:
179+
quantizer_copy.eval()
180+
quantizer_copy.reset_amax()
181+
enable_stats_collection(quantizer_copy)
182+
# Match legacy single-quantizer path: first calibration uses ``weight`` as-is; chains use float.
183+
if len(sequence_quantizers) == 1:
184+
weight_quantized = sequence_quantizers[0](weight)
185+
else:
186+
weight_quantized = weight.float()
187+
for quantizer_copy in sequence_quantizers:
188+
weight_quantized = quantizer_copy(weight_quantized)
189+
for quantizer_copy in sequence_quantizers:
190+
finish_stats_collection(quantizer_copy)
191+
weight_quantized = weight.float()
192+
for quantizer_copy in sequence_quantizers:
193+
weight_quantized = quantizer_copy(weight_quantized)
194+
return weight_quantized.to(weight.dtype)
94195

95196

96197
def merge_amax_tensors_for_group(tensors: list[torch.Tensor]) -> torch.Tensor:
@@ -100,8 +201,10 @@ def merge_amax_tensors_for_group(tensors: list[torch.Tensor]) -> torch.Tensor:
100201
101202
- If every tensor has the same shape, take the element-wise maximum over the group
102203
(conservative when each branch carried the same axis layout).
103-
- If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
104-
per-channel amax; otherwise fall back to a scalar max over all elements.
204+
- If shapes differ: ``torch.cat(..., dim=0)`` assumes **1D per-channel** amaxes in
205+
fused order (e.g. GQA q/k/v → ``[N_q]`` + ``[N_kv]`` + ``[N_kv]``), matching vLLM’s
206+
grouped quantizer. Not valid for 2D blockwise amax; on failure, **scalar**
207+
max (drops channel structure).
105208
"""
106209
if not tensors:
107210
raise ValueError("merge_amax_tensors_for_group: expected at least one tensor")
@@ -151,7 +254,7 @@ def _resmooth_experts_for_export(
151254
requant_weights: set[str] = set()
152255

153256
def _process_group(modules: list[nn.Module]) -> None:
154-
pqs_list = _collect_expert_pre_quant_scales(modules)
257+
pqs_list = _collect_group_pre_quant_scales(modules)
155258
if pqs_list is None:
156259
return
157260

@@ -205,9 +308,15 @@ def _process_group(modules: list[nn.Module]) -> None:
205308
dev = next(model.parameters()).device
206309

207310
def _dummy_forward() -> None:
208-
# Partial forward is OK: hooks record layers reached before failure (e.g. VLMs).
209-
with contextlib.suppress(Exception):
311+
# Partial forward is OK: hooks record layers reached before failure.
312+
try:
210313
model(torch.ones([1, 2], dtype=torch.long, device=dev))
314+
except Exception as e:
315+
import logging
316+
317+
logging.getLogger(__name__).debug(
318+
"Dummy forward for shared-input detection failed (expected for VLMs): %s", e
319+
)
211320

212321
input_to_linear, _ = collect_shared_input_modules(model, _dummy_forward)
213322
for modules in input_to_linear.values():
@@ -263,7 +372,7 @@ def export_hf_vllm_fq_checkpoint(
263372
for attr_name, quantizer in module.named_children():
264373
if not (
265374
attr_name.endswith("weight_quantizer")
266-
and isinstance(quantizer, TensorQuantizer)
375+
and (isinstance(quantizer, (TensorQuantizer, SequentialQuantizer)))
267376
and quantizer.fake_quant
268377
and quantizer.is_enabled
269378
):
@@ -313,11 +422,16 @@ def export_hf_vllm_fq_checkpoint(
313422
for _, module in model.named_modules():
314423
if isinstance(module, QuantModule):
315424
for attr_name, quantizer in module.named_children():
316-
if (
317-
attr_name.endswith("weight_quantizer")
318-
and isinstance(quantizer, TensorQuantizer)
319-
and quantizer.is_enabled
320-
):
425+
if not (attr_name.endswith("weight_quantizer") and quantizer.is_enabled):
426+
continue
427+
if isinstance(quantizer, SequentialQuantizer):
428+
quantizer.disable()
429+
for sub in quantizer:
430+
orig_rotate = sub._rotate
431+
if sub.rotate_is_enabled:
432+
sub._rotate = disable_rotate(sub)
433+
wqs_to_restore.append((sub, orig_rotate))
434+
elif isinstance(quantizer, TensorQuantizer):
321435
quantizer.disable()
322436
orig_rotate = quantizer._rotate
323437
if quantizer.rotate_is_enabled:
@@ -328,6 +442,8 @@ def export_hf_vllm_fq_checkpoint(
328442
for key in list(quantizer_state_dict):
329443
if is_weight_quantizer_state_key(key):
330444
# Fakequant amax is folded into HF weights; do not reload weight quantizer tensors.
445+
# Reload must force-disable WQs missing from saved state (see
446+
# ``filter_modelopt_state_quantizer_state_for_model`` assertion in vllm_reload_utils).
331447
quantizer_state_dict.pop(key)
332448
elif key in input_quantizers_folded_pqs:
333449
# pre_quant_scale was folded into the weight; keep the buffer for strict load but
@@ -351,10 +467,12 @@ def export_hf_vllm_fq_checkpoint(
351467

352468
modelopt_state = mto.modelopt_state(model)
353469
# ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
354-
# ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded).
470+
# ``quantizer_state`` and strip weight-quantizer entries (same policy as
471+
# ``modelopt_state_weights``). Reload synthesizes missing WQ rows with ``_disabled``.
472+
_check_all_weight_quantizers_disabled(model)
355473
qstate = quantizer_state(model)
356474
for key in list(qstate):
357-
if is_weight_quantizer_state_key(key) and qstate[key].get("_disabled"):
475+
if is_weight_quantizer_state_key(key):
358476
qstate.pop(key)
359477

360478
for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):

0 commit comments

Comments (0)