Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ def determine_available_memory(self) -> int:
with disable_compilation(model):
return super().determine_available_memory()

def compile_or_warm_up_model(self) -> float:
    """Run the fake-quant prolog when quantization is configured, then warm up.

    The base worker's return value is propagated to the caller instead of
    being discarded.
    """
    # Any of these three settings means fake quantization is active and the
    # prolog must run before compilation/warm-up.
    needs_prolog = any(
        quant_config[key]
        for key in ("quant_cfg", "kv_quant_cfg", "modelopt_state_path")
    )
    if needs_prolog:
        _fakequant_run_prolog_worker(self)
    return super().compile_or_warm_up_model()
137 changes: 135 additions & 2 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""Export HuggingFace model to vLLM fakequant checkpoint."""

import logging
from pathlib import Path

import torch
Expand All @@ -26,6 +27,8 @@
from modelopt.torch.quantization.utils import get_quantizer_state_dict
from modelopt.torch.utils import get_unwrapped_name

logger = logging.getLogger(__name__)

__all__ = ["export_hf_vllm_fq_checkpoint"]


Expand All @@ -38,6 +41,105 @@ def disable_rotate(quantizer: TensorQuantizer):
return False


def _materialize_offloaded_weights(
model: nn.Module,
state_dict: dict[str, torch.Tensor],
meta_keys: list[str],
) -> None:
"""Replace meta tensors in state_dict with actual data from accelerate offload hooks.

When a model is loaded with ``device_map="auto"`` and some layers are offloaded
to CPU or disk, ``model.state_dict()`` returns meta tensors (no data) for those
layers. This function walks the model's accelerate hooks to retrieve the actual
weight data and updates state_dict in-place.
"""
hook_map: dict[str, tuple] = {}
for name, module in model.named_modules():
hook = getattr(module, "_hf_hook", None)
if hook is None:
continue
hooks = [hook]
if hasattr(hook, "hooks"):
hooks = hook.hooks
for h in hooks:
if hasattr(h, "weights_map") and h.weights_map is not None:
prefix = f"{name}." if name else ""
hook_map[prefix] = (module, h)
break

materialized = 0
for key in meta_keys:
for prefix, (module, hook) in hook_map.items():
if not key.startswith(prefix):
continue
local_key = key[len(prefix) :]
wmap = hook.weights_map
if hasattr(wmap, "dataset"):
lookup_key = wmap.prefix + local_key
actual_sd = wmap.dataset.state_dict
else:
lookup_key = local_key
actual_sd = wmap
if lookup_key in actual_sd:
state_dict[key] = actual_sd[lookup_key].detach().clone()
materialized += 1
break
else:
logger.warning("Could not materialize meta tensor for key: %s", key)

logger.info("Materialized %d/%d offloaded weights to CPU", materialized, len(meta_keys))


def _save_clean_checkpoint(
    model: nn.Module,
    clean_sd: dict[str, torch.Tensor],
    export_dir: Path,
) -> None:
    """Save clean weights + config directly, bypassing model.save_pretrained().

    For accelerate-offloaded models, ``save_pretrained(state_dict=clean_sd)``
    ignores the provided state_dict and saves from internal state, leaking
    quantizer keys. This function saves ``clean_sd`` directly via the
    safetensors API, guaranteeing only the intended keys are written.

    Args:
        model: Source model; only its ``config`` (if present) is consulted.
        clean_sd: The exact tensors to write, keyed by HF weight names.
        export_dir: Existing directory to write shards, index, and config into.
    """
    import json

    from huggingface_hub import split_torch_state_dict_into_shards
    from safetensors.torch import save_file

    # Plan the shards from the original tensors, then move/clone one shard at
    # a time. Cloning the whole state dict up front would hold a second full
    # copy of the checkpoint in host RAM and can OOM large-model exports.
    state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB")
    for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items():
        # .cpu().clone() breaks shared storage (tied weights like
        # lm_head/embed_tokens); safetensors rejects tensors that share
        # underlying storage.
        shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys}
        save_file(shard, str(export_dir / shard_file))
        logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard))

    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        (export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2))

    if hasattr(model, "config"):
        model.config.save_pretrained(export_dir)
        config_path = export_dir / "config.json"
        if config_path.exists():
            config = json.loads(config_path.read_text())
            # Strip auto_map so the exported config does not reference custom
            # modeling code from the source repo.
            if config.pop("auto_map", None):
                config_path.write_text(json.dumps(config, indent=2))
                logger.info("Saved config.json (auto_map stripped)")

    logger.info(
        "Checkpoint saved: %d weights in %d shard(s)",
        len(clean_sd),
        len(state_dict_split.filename_to_tensors),
    )


def export_hf_vllm_fq_checkpoint(
model: nn.Module,
export_dir: Path | str,
Expand All @@ -62,6 +164,18 @@ def export_hf_vllm_fq_checkpoint(
# parameters are never modified. Apply each weight quantizer's fake-quant
# to the corresponding weight tensor in the copy.
state_dict = model.state_dict()

# Handle accelerate-offloaded models: state_dict() returns meta tensors
# for CPU/disk-offloaded layers. Materialize them from the offload hooks.
meta_keys = [k for k, v in state_dict.items() if v.is_meta]
if meta_keys:
logger.info(
"Found %d meta tensors in state_dict (accelerate offloading). "
"Materializing from offload hooks...",
len(meta_keys),
)
_materialize_offloaded_weights(model, state_dict, meta_keys)

Comment on lines +167 to +178
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fail fast if any offloaded tensors stay on meta.

_materialize_offloaded_weights() only logs misses. If a non-quantizer key is still meta here, the later fake-quant fold or _save_clean_checkpoint() will blow up with a much less actionable error. Please re-check the state dict immediately after materialization and raise with the unresolved keys.

💡 Proposed fix
     if meta_keys:
         logger.info(
             "Found %d meta tensors in state_dict (accelerate offloading). "
             "Materializing from offload hooks...",
             len(meta_keys),
         )
         _materialize_offloaded_weights(model, state_dict, meta_keys)
+        unresolved_meta_keys = [
+            k for k, v in state_dict.items() if v.is_meta and "quantizer" not in k
+        ]
+        if unresolved_meta_keys:
+            shown = ", ".join(unresolved_meta_keys[:10])
+            suffix = " ..." if len(unresolved_meta_keys) > 10 else ""
+            raise RuntimeError(f"Failed to materialize offloaded tensors: {shown}{suffix}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py` around lines 167 - 178,
After calling _materialize_offloaded_weights(model, state_dict, meta_keys)
recompute unresolved_meta = [k for k,v in state_dict.items() if v.is_meta]; if
unresolved_meta is non-empty and contains any keys that are not
quantizer-related (e.g. not containing "quant" or "quantizer"), raise a
RuntimeError listing unresolved_meta and a short message mentioning that
materialization failed and will break subsequent fake-quant folding or
_save_clean_checkpoint; reference the symbols meta_keys,
_materialize_offloaded_weights, state_dict, and _save_clean_checkpoint so the
error helps locate the problem.

fakequant_weights = set()
input_quantizers_folded_pqs = (
set()
Expand All @@ -86,6 +200,23 @@ def export_hf_vllm_fq_checkpoint(
)
if sd_key in state_dict:
w = state_dict[sd_key]
# Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA.
# Offloaded weights materialized to CPU need a GPU hop.
if not w.is_cuda:
# Find a CUDA device: check quantizer buffers/params first,
# then fall back to sibling tensors on the parent module.
cuda_dev = None
for t in list(quantizer.parameters()) + list(quantizer.buffers()):
if t.is_cuda:
cuda_dev = t.device
break
if cuda_dev is None:
for t in module.parameters():
if t.is_cuda:
cuda_dev = t.device
break
if cuda_dev is not None:
w = w.to(cuda_dev)
w_quant = quantizer(w.float()).to(w.dtype).cpu()
# Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
# Only valid when input_quantizer does NOT fake-quant activations. If it does
Expand Down Expand Up @@ -161,8 +292,10 @@ def export_hf_vllm_fq_checkpoint(
modelopt_state["modelopt_state_weights"] = quantizer_state_dict
torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")

# Step 3: Save HF weights using the pre-built folded state dict.
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
# Step 3: Save HF weights directly from clean_sd.
# Bypass model.save_pretrained() because accelerate-offloaded models
# ignore the state_dict= argument, leaking quantizer keys into safetensors.
_save_clean_checkpoint(model, clean_sd, export_dir)

for wq, orig_rotate in wqs_to_restore:
wq.enable()
Expand Down
Loading