Skip to content

Commit 6be72a8

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent a719ae2 commit 6be72a8

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -621,19 +621,20 @@ def export_hf_vllm_fq_checkpoint(
621621
qstate_val["_amax"] = max_input_amax
622622

623623
modelopt_state = mto.modelopt_state(model)
624-
# ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
625-
# ``quantizer_state`` and strip weight-quantizer entries (same policy as
626-
# ``modelopt_state_weights``). Reload synthesizes missing WQ rows with ``_disabled``.
627624
_check_all_weight_quantizers_disabled(model)
625+
# Rebuild quantizer_state from the live model (post-disable) and strip weight-quantizer
626+
# entries. Apply to every mode that carries quantizer_state so that stale entries from
627+
# a calibrate pass (which also stores quantizer_state in its metadata) are cleaned up.
628+
# Reload synthesizes missing WQ rows with ``_disabled`` via
629+
# ``filter_modelopt_state_quantizer_state_for_model``.
628630
qstate = quantizer_state(model)
629631
for key in list(qstate):
630632
if is_weight_quantizer_state_key(key):
631633
qstate.pop(key)
632-
633-
for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
634-
if mode_str == "quantize" and "metadata" in m_state:
635-
m_state["metadata"]["quantizer_state"] = qstate
636-
break
634+
for _mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
635+
md = m_state.get("metadata", {})
636+
if "quantizer_state" in md:
637+
md["quantizer_state"] = qstate
637638

638639
# Per-quantizer tensor dict loaded alongside metadata on reload.
639640
modelopt_state["modelopt_state_weights"] = quantizer_state_dict

0 commit comments

Comments (0)