-
Notifications
You must be signed in to change notification settings - Fork 342
fix: handle accelerate CPU-offloaded models in FakeQuant export #1194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| # limitations under the License. | ||
| """Export HuggingFace model to vLLM fakequant checkpoint.""" | ||
|
|
||
| import logging | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
@@ -26,6 +27,8 @@ | |
| from modelopt.torch.quantization.utils import get_quantizer_state_dict | ||
| from modelopt.torch.utils import get_unwrapped_name | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| __all__ = ["export_hf_vllm_fq_checkpoint"] | ||
|
|
||
|
|
||
|
|
@@ -38,6 +41,105 @@ def disable_rotate(quantizer: TensorQuantizer): | |
| return False | ||
|
|
||
|
|
||
| def _materialize_offloaded_weights( | ||
| model: nn.Module, | ||
| state_dict: dict[str, torch.Tensor], | ||
| meta_keys: list[str], | ||
| ) -> None: | ||
| """Replace meta tensors in state_dict with actual data from accelerate offload hooks. | ||
|
|
||
| When a model is loaded with ``device_map="auto"`` and some layers are offloaded | ||
| to CPU or disk, ``model.state_dict()`` returns meta tensors (no data) for those | ||
| layers. This function walks the model's accelerate hooks to retrieve the actual | ||
| weight data and updates state_dict in-place. | ||
| """ | ||
| hook_map: dict[str, tuple] = {} | ||
| for name, module in model.named_modules(): | ||
| hook = getattr(module, "_hf_hook", None) | ||
| if hook is None: | ||
| continue | ||
| hooks = [hook] | ||
| if hasattr(hook, "hooks"): | ||
| hooks = hook.hooks | ||
| for h in hooks: | ||
| if hasattr(h, "weights_map") and h.weights_map is not None: | ||
| prefix = f"{name}." if name else "" | ||
| hook_map[prefix] = (module, h) | ||
| break | ||
|
|
||
| materialized = 0 | ||
| for key in meta_keys: | ||
| for prefix, (module, hook) in hook_map.items(): | ||
| if not key.startswith(prefix): | ||
| continue | ||
| local_key = key[len(prefix) :] | ||
| wmap = hook.weights_map | ||
| if hasattr(wmap, "dataset"): | ||
| lookup_key = wmap.prefix + local_key | ||
| actual_sd = wmap.dataset.state_dict | ||
| else: | ||
| lookup_key = local_key | ||
| actual_sd = wmap | ||
| if lookup_key in actual_sd: | ||
| state_dict[key] = actual_sd[lookup_key].detach().clone() | ||
| materialized += 1 | ||
| break | ||
| else: | ||
| logger.warning("Could not materialize meta tensor for key: %s", key) | ||
|
|
||
| logger.info("Materialized %d/%d offloaded weights to CPU", materialized, len(meta_keys)) | ||
|
|
||
|
|
||
| def _save_clean_checkpoint( | ||
| model: nn.Module, | ||
| clean_sd: dict[str, torch.Tensor], | ||
| export_dir: Path, | ||
| ) -> None: | ||
| """Save clean weights + config directly, bypassing model.save_pretrained(). | ||
|
|
||
| For accelerate-offloaded models, ``save_pretrained(state_dict=clean_sd)`` | ||
| ignores the provided state_dict and saves from internal state, leaking | ||
| quantizer keys. This function saves ``clean_sd`` directly via safetensors | ||
| API, guaranteeing only the intended keys are written. | ||
| """ | ||
| import json | ||
|
|
||
| from huggingface_hub import split_torch_state_dict_into_shards | ||
| from safetensors.torch import save_file | ||
|
|
||
| # Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens). | ||
| # safetensors rejects tensors that share underlying storage. | ||
| cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()} | ||
|
|
||
| state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB") | ||
| for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items(): | ||
| shard = {k: cpu_sd[k] for k in tensor_keys} | ||
| save_file(shard, str(export_dir / shard_file)) | ||
| logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard)) | ||
|
Comment on lines +110 to +118
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Shard before cloning tensors to CPU.
💡 Proposed fix- # Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens).
- # safetensors rejects tensors that share underlying storage.
- cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()}
-
- state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB")
+ state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB")
for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items():
- shard = {k: cpu_sd[k] for k in tensor_keys}
+ # Move only the current shard to CPU to keep peak memory bounded.
+ shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys}
save_file(shard, str(export_dir / shard_file))
logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard))
@@
logger.info(
"Checkpoint saved: %d weights in %d shard(s)",
- len(cpu_sd),
+ len(clean_sd),
len(state_dict_split.filename_to_tensors),
)Also applies to: 136-140 |
||
|
|
||
| if state_dict_split.is_sharded: | ||
| index = { | ||
| "metadata": state_dict_split.metadata, | ||
| "weight_map": state_dict_split.tensor_to_filename, | ||
| } | ||
| (export_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2)) | ||
|
|
||
| if hasattr(model, "config"): | ||
| model.config.save_pretrained(export_dir) | ||
| config_path = export_dir / "config.json" | ||
| if config_path.exists(): | ||
| config = json.loads(config_path.read_text()) | ||
| if config.pop("auto_map", None): | ||
| config_path.write_text(json.dumps(config, indent=2)) | ||
| logger.info("Saved config.json (auto_map stripped)") | ||
|
|
||
| logger.info( | ||
| "Checkpoint saved: %d weights in %d shard(s)", | ||
| len(cpu_sd), | ||
| len(state_dict_split.filename_to_tensors), | ||
| ) | ||
|
|
||
|
|
||
| def export_hf_vllm_fq_checkpoint( | ||
| model: nn.Module, | ||
| export_dir: Path | str, | ||
|
|
@@ -62,6 +164,18 @@ def export_hf_vllm_fq_checkpoint( | |
| # parameters are never modified. Apply each weight quantizer's fake-quant | ||
| # to the corresponding weight tensor in the copy. | ||
| state_dict = model.state_dict() | ||
|
|
||
| # Handle accelerate-offloaded models: state_dict() returns meta tensors | ||
| # for CPU/disk-offloaded layers. Materialize them from the offload hooks. | ||
| meta_keys = [k for k, v in state_dict.items() if v.is_meta] | ||
| if meta_keys: | ||
| logger.info( | ||
| "Found %d meta tensors in state_dict (accelerate offloading). " | ||
| "Materializing from offload hooks...", | ||
| len(meta_keys), | ||
| ) | ||
| _materialize_offloaded_weights(model, state_dict, meta_keys) | ||
|
|
||
|
Comment on lines +167 to +178
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fail fast if any offloaded tensors stay on
💡 Proposed fix if meta_keys:
logger.info(
"Found %d meta tensors in state_dict (accelerate offloading). "
"Materializing from offload hooks...",
len(meta_keys),
)
_materialize_offloaded_weights(model, state_dict, meta_keys)
+ unresolved_meta_keys = [
+ k for k, v in state_dict.items() if v.is_meta and "quantizer" not in k
+ ]
+ if unresolved_meta_keys:
+ shown = ", ".join(unresolved_meta_keys[:10])
+ suffix = " ..." if len(unresolved_meta_keys) > 10 else ""
+        raise RuntimeError(f"Failed to materialize offloaded tensors: {shown}{suffix}")

🤖 Prompt for AI Agents |
||
| fakequant_weights = set() | ||
| input_quantizers_folded_pqs = ( | ||
| set() | ||
|
|
@@ -86,6 +200,23 @@ def export_hf_vllm_fq_checkpoint( | |
| ) | ||
| if sd_key in state_dict: | ||
| w = state_dict[sd_key] | ||
| # Quantizer kernels (e.g., fp4_fake_quant_block) require CUDA. | ||
| # Offloaded weights materialized to CPU need a GPU hop. | ||
| if not w.is_cuda: | ||
| # Find a CUDA device: check quantizer buffers/params first, | ||
| # then fall back to sibling tensors on the parent module. | ||
| cuda_dev = None | ||
| for t in list(quantizer.parameters()) + list(quantizer.buffers()): | ||
| if t.is_cuda: | ||
| cuda_dev = t.device | ||
| break | ||
| if cuda_dev is None: | ||
| for t in module.parameters(): | ||
| if t.is_cuda: | ||
| cuda_dev = t.device | ||
| break | ||
| if cuda_dev is not None: | ||
| w = w.to(cuda_dev) | ||
| w_quant = quantizer(w.float()).to(w.dtype).cpu() | ||
| # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) | ||
| # Only valid when input_quantizer does NOT fake-quant activations. If it does | ||
|
|
@@ -161,8 +292,10 @@ def export_hf_vllm_fq_checkpoint( | |
| modelopt_state["modelopt_state_weights"] = quantizer_state_dict | ||
| torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") | ||
|
|
||
| # Step 3: Save HF weights using the pre-built folded state dict. | ||
| model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) | ||
| # Step 3: Save HF weights directly from clean_sd. | ||
| # Bypass model.save_pretrained() because accelerate-offloaded models | ||
| # ignore the state_dict= argument, leaking quantizer keys into safetensors. | ||
| _save_clean_checkpoint(model, clean_sd, export_dir) | ||
|
|
||
| for wq, orig_rotate in wqs_to_restore: | ||
| wq.enable() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.