
Commit 24891f0

fix: megatron export correctness for TP>1 GQA, single-file MTP, and Hub remote code (#1209)
### What does this PR do?

**Type of change:** Bug fix

Three correctness fixes for the Megatron Core GPT export pipeline:

**1. `_qkv_slicing`: reshape failure with TP>1 on GQA models**

When tensor parallelism is enabled, the `linear_qkv` weight tensor arriving in `_qkv_slicing` is already TP-sharded, so `weight.shape[0]` equals `per_rank_qkv_dim * head_size`, not `qkv_total_dim * head_size`. All five reshape/`arange` operations were using the global `qkv_total_dim`, causing a runtime shape mismatch for any GQA model with TP > 1. The fix derives `per_rank_qkv_dim` and `num_query_groups_local` from the actual tensor shape, making the logic correct for any TP degree (a no-op for TP=1).

**2. `_get_mtp_state_dict`: `EntryNotFoundError` for non-sharded models**

`hf_hub_download("model.safetensors.index.json")` raises `EntryNotFoundError` for small models that ship a single `model.safetensors` rather than a sharded index. The function now catches this and falls back to downloading/reading `model.safetensors` directly, scanning its keys with `safe_open`. The same two-path logic applies to local directories.

**3. `copy_remote_code`: `ValueError` for Hub model IDs**

`copy_remote_code` only accepted local directory paths and raised `ValueError` for HuggingFace Hub model IDs (e.g. `"meta-llama/Llama-3.2-1B"`). The function now falls back to `list_repo_files` + `hf_hub_download` to fetch and copy top-level `.py` files (custom modeling code) when the path is not a local directory.
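The shard arithmetic behind fix 1 can be checked in isolation. A minimal sketch with hypothetical dimensions (chosen to mirror the new TP=2 GQA test; only `per_rank_qkv_dim` and `num_query_groups_local` are names from the patched function, the rest is illustrative):

```python
# Hypothetical GQA + TP=2 configuration (mirrors the new test, not a real model).
num_attention_heads = 8
num_query_groups = 2   # GQA: fewer KV groups than attention heads
head_size = 64
tp = 2                 # tensor-parallel degree

heads_per_group = num_attention_heads // num_query_groups    # 4
qkv_total_dim = num_attention_heads + 2 * num_query_groups   # 12 interleaved head rows

# Under TP>1 the linear_qkv weight is row-sharded, so each rank sees only
# qkv_total_dim * head_size / tp rows -- this is what weight.shape[0] holds.
sharded_rows = qkv_total_dim * head_size // tp

# The fix derives per-rank quantities from the actual shard shape:
per_rank_qkv_dim = sharded_rows // head_size                                   # 6 (12 for TP=1)
num_query_groups_local = num_query_groups * per_rank_qkv_dim // qkv_total_dim  # 1

print(per_rank_qkv_dim, num_query_groups_local)  # 6 1
```

For TP=1 the shard is the whole tensor, so `per_rank_qkv_dim == qkv_total_dim` and the new code reduces to the old behavior.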
### Usage

```python
# TP>1 GQA export now works (previously raised RuntimeError on reshape)
export_mcore_gpt_to_hf(gqa_model, "meta-llama/Llama-3.2-1B", export_dir="./out", dtype=torch.bfloat16)

# Models with a single model.safetensors now have their MTP weights exported
export_mcore_gpt_to_hf(model, "./small_model_dir", export_dir="./out", dtype=torch.bfloat16)

# Hub model IDs no longer raise ValueError in copy_remote_code
export_mcore_gpt_to_hf(model, "org/custom-model-with-remote-code", export_dir="./out", dtype=torch.bfloat16)
```

### Testing

New tests added in `tests/gpu_megatron/torch/export/`:

- `test_unified_export_megatron.py::test_qkv_slicing_gqa_tp2` — FP8-quantized GQA model export with TP=2 (`num_query_groups=2 < num_attention_heads=8`), exercises both the weight reshape and per-channel weight-scale reshape paths.
- `test_unified_export_megatron.py::test_mtp_state_dict_single_safetensors` — unit test verifying MTP weights are collected from a single `model.safetensors` file.
- `test_unified_export_megatron.py::test_mtp_state_dict_index_file` — unit test verifying MTP weights are collected from a sharded checkpoint.
- `test_unified_export_megatron.py::test_mtp_state_dict_no_mtp_keys` — edge case: no MTP keys → empty dict, no side effects.
- `test_hf_checkpoint_utils.py` — four tests covering `copy_remote_code` for local directories and Hub model IDs (with and without `.py` files).

### Before your PR is "*Ready for review*"

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: ✅
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A

### Additional Information

Fixes reported against Megatron export when running quantization with TP>1, small non-sharded HF models, and HuggingFace Hub model IDs passed to `export_mcore_gpt_to_hf`.
## Summary by CodeRabbit

* **New Features**
  * Export functionality now supports downloading code directly from Hugging Face Hub model repositories in addition to local directories.
* **Bug Fixes**
  * Improved safetensors loading with better error handling for missing model entries and support for both single and sharded weight files.
  * Enhanced tensor slicing behavior for multi-GPU distributed export scenarios.
* **Tests**
  * Added comprehensive test coverage for Hugging Face integration and export functionality.

---

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e5de5ec commit 24891f0

File tree: 4 files changed (+337, -35 lines)


modelopt/torch/export/plugins/hf_checkpoint_utils.py

Lines changed: 16 additions & 12 deletions
```diff
@@ -21,6 +21,7 @@
 from pathlib import Path

 import torch
+from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors.torch import safe_open
 from tqdm import tqdm

@@ -35,24 +36,27 @@ def copy_remote_code(
     we need to copy them to the export directory for seamless integration with inference
     frameworks.

+    If ``pretrained_model_path`` is a local directory, Python files are copied directly.
+    If it is a HuggingFace Hub model ID, Python files are downloaded from the Hub first.
+
     Args:
-        pretrained_model_path: Path to the pretrained model.
+        pretrained_model_path: Local path to the pretrained model or HuggingFace Hub model ID.
         save_directory: Path to the save directory.
-
-    Raises:
-        ValueError: If the pretrained model path is not a directory.
     """
     hf_checkpoint_path = Path(pretrained_model_path)
     save_dir = Path(save_directory)

-    if not hf_checkpoint_path.is_dir():
-        raise ValueError(
-            f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory."
-        )
-
-    for py_file in hf_checkpoint_path.glob("*.py"):
-        if py_file.is_file():
-            shutil.copy(py_file, save_dir / py_file.name)
+    if hf_checkpoint_path.is_dir():
+        for py_file in hf_checkpoint_path.glob("*.py"):
+            if py_file.is_file():
+                shutil.copy(py_file, save_dir / py_file.name)
+    else:
+        # Hub model ID: download any top-level .py files (custom modeling code)
+        repo_id = str(pretrained_model_path)
+        for filename in list_repo_files(repo_id):
+            if "/" not in filename and filename.endswith(".py"):
+                local_path = hf_hub_download(repo_id=repo_id, filename=filename)
+                shutil.copy(local_path, save_dir / filename)


 def load_multimodal_components(
```
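The Hub-side file selection in the new branch is easy to sanity-check on its own: only names without a `/` (top-level) that end in `.py` are downloaded. A standalone sketch of that filter, with made-up file names for illustration:

```python
# Hypothetical repo listing, as list_repo_files would return it.
repo_files = [
    "modeling_custom.py",   # top-level .py  -> selected for download
    "config.json",          # not .py       -> skipped
    "model.safetensors",    # not .py       -> skipped
    "subdir/nested.py",     # nested .py    -> skipped (contains "/")
]

# Same predicate as the new copy_remote_code branch.
to_copy = [f for f in repo_files if "/" not in f and f.endswith(".py")]

print(to_copy)  # ['modeling_custom.py']
```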

modelopt/torch/export/unified_export_megatron.py

Lines changed: 58 additions & 23 deletions
```diff
@@ -28,6 +28,8 @@
 import torch
 import torch.distributed
 from huggingface_hub import hf_hub_download
+from huggingface_hub.errors import EntryNotFoundError
+from safetensors import safe_open
 from safetensors.torch import save_file

 from modelopt import __version__
@@ -527,29 +529,57 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
         # TODO Implement MTP export for quantized MTP
         # Hacky version for now: copy MTP weights from pretrained model
         mtp_state_dict = {}
-        if self._hf_pretrained_model_name:
-            if os.path.isdir(self._hf_pretrained_model_name):
-                safetensors_index_file = (
-                    Path(self._hf_pretrained_model_name) / "model.safetensors.index.json"
-                )
-            else:
-                safetensors_index_file = hf_hub_download(
-                    repo_id=self._hf_pretrained_model_name, filename="model.safetensors.index.json"
+        if not self._hf_pretrained_model_name:
+            return mtp_state_dict
+
+        mtp_exists = False
+
+        if os.path.isdir(self._hf_pretrained_model_name):
+            safetensors_index_file = (
+                Path(self._hf_pretrained_model_name) / "model.safetensors.index.json"
+            )
+            single_safetensors_file = Path(self._hf_pretrained_model_name) / "model.safetensors"
+        else:
+            try:
+                safetensors_index_file = Path(
+                    hf_hub_download(
+                        repo_id=self._hf_pretrained_model_name,
+                        filename="model.safetensors.index.json",
+                    )
                 )
+                single_safetensors_file = None
+            except EntryNotFoundError:
+                # Model uses a single unsharded safetensors file — check it for MTP weights.
+                safetensors_index_file = None
+                try:
+                    single_safetensors_file = Path(
+                        hf_hub_download(
+                            repo_id=self._hf_pretrained_model_name,
+                            filename="model.safetensors",
+                        )
+                    )
+                except EntryNotFoundError:
+                    return mtp_state_dict

+        if safetensors_index_file is not None and safetensors_index_file.exists():
             print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}")
-            mtp_exists = False
-            if safetensors_index_file and os.path.exists(safetensors_index_file):
-                with open(safetensors_index_file) as f:
-                    safetensors_index = json.load(f)
-                model_dir = Path(safetensors_index_file).parent
-                for key in safetensors_index["weight_map"]:
+            with open(safetensors_index_file) as f:
+                safetensors_index = json.load(f)
+            model_dir = safetensors_index_file.parent
+            for key in safetensors_index["weight_map"]:
+                if key.startswith("mtp.") and key not in self._state_dict:
+                    mtp_state_dict[key] = get_safetensor(model_dir, key)
+                    mtp_exists = True
+        elif single_safetensors_file is not None and single_safetensors_file.exists():
+            print(f"Exporting MTP: using single safetensors file: {single_safetensors_file}")
+            with safe_open(str(single_safetensors_file), framework="pt", device="cpu") as f:
+                for key in f.keys():  # noqa: SIM118
                     if key.startswith("mtp.") and key not in self._state_dict:
-                        mtp_state_dict[key] = get_safetensor(model_dir, key)
+                        mtp_state_dict[key] = f.get_tensor(key)
                         mtp_exists = True

-            if mtp_exists:
-                self.exclude_modules.append("mtp*")
+        if mtp_exists:
+            self.exclude_modules.append("mtp*")
         return mtp_state_dict

     def _get_mamba_layer_state_dict(self, layer, layer_id):
@@ -985,17 +1015,22 @@ def _qkv_slicing(
             )
             hidden_size = 2 * hidden_size

-        weight = weight.reshape([qkv_total_dim, head_size, hidden_size])
+        # When TP > 1 the weight tensor is already sharded: shape[0] = per_rank_qkv_dim, not
+        # qkv_total_dim. Derive the per-rank dimensions from the actual tensor shape so that
+        # all subsequent reshape/slice operations are correct regardless of TP degree.
+        per_rank_qkv_dim = weight.shape[0] // head_size
+        num_query_groups_local = num_query_groups * per_rank_qkv_dim // qkv_total_dim
+        weight = weight.reshape([per_rank_qkv_dim, head_size, hidden_size])
         weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)

         q_slice = torch.cat(
             [
                 torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
-                for i in range(num_query_groups)
+                for i in range(num_query_groups_local)
             ]
         )
-        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
-        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
+        k_slice = torch.arange(heads_per_group, per_rank_qkv_dim, (heads_per_group + 2))
+        v_slice = torch.arange(heads_per_group + 1, per_rank_qkv_dim, (heads_per_group + 2))
         ## Example of slices
         ## 7b: num_query_groups = head_num = 32,
         ## q_slice = [0, 3, 6, 9 , ... 90, 93]
@@ -1020,7 +1055,7 @@ def _qkv_slicing(
         weight_scale_dtype = weight_scale.dtype
         weight_scale_hidden_size = weight_scale.shape[-1]
         weight_scale = weight_scale.to(dtype=float).reshape(
-            [qkv_total_dim, head_size, weight_scale_hidden_size]
+            [per_rank_qkv_dim, head_size, weight_scale_hidden_size]
         )
         proj_weight_scales = [
             weight_scale[s]
@@ -1061,7 +1096,7 @@ def _qkv_slicing(
         if key == "bias":
             # Slice bias similar to weight
             bias = val.detach().clone()
-            bias = bias.reshape([qkv_total_dim, head_size])
+            bias = bias.reshape([per_rank_qkv_dim, head_size])
             proj_biases = [bias[s].reshape(-1) for s in slices]
             proj_bias_keys = [q_proj_prefix + key, k_proj_prefix + key, v_proj_prefix + key]
             for bias_tensor, bias_key in zip(proj_biases, proj_bias_keys):
```
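For the TP=2 GQA configuration used in the new test (`num_attention_heads=8`, `num_query_groups=2`), the per-rank q/k/v slice indices work out as follows. Plain `range` stands in for `torch.arange`/`torch.cat` so this sketch runs without torch; the index arithmetic is the same as in the patched `_qkv_slicing`:

```python
# Per-rank values for the hypothetical test configuration: 12 global QKV rows
# split across 2 ranks, interleaved per group as [q, q, q, q, k, v].
heads_per_group = 4          # 8 attention heads / 2 query groups
per_rank_qkv_dim = 6         # 12 global rows / TP=2
num_query_groups_local = 1   # 2 groups * 6 // 12

# Same index construction as the patched code, with lists instead of tensors.
q_slice = [
    j
    for i in range(num_query_groups_local)
    for j in range((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
]
k_slice = list(range(heads_per_group, per_rank_qkv_dim, heads_per_group + 2))
v_slice = list(range(heads_per_group + 1, per_rank_qkv_dim, heads_per_group + 2))

print(q_slice, k_slice, v_slice)  # [0, 1, 2, 3] [4] [5]
```

With the old code, `k_slice`/`v_slice` ranged up to the global `qkv_total_dim` (12) and indexed past the end of the 6-row shard, which is the reshape/indexing failure the fix removes.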
tests/gpu_megatron/torch/export/test_hf_checkpoint_utils.py (new file)

Lines changed: 109 additions & 0 deletions
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for modelopt/torch/export/plugins/hf_checkpoint_utils.py"""

from unittest.mock import patch

from modelopt.torch.export.plugins.hf_checkpoint_utils import copy_remote_code


def test_copy_remote_code_local_dir(tmp_path):
    """copy_remote_code copies top-level .py files from a local directory."""
    src_dir = tmp_path / "src"
    src_dir.mkdir()
    (src_dir / "modeling_custom.py").write_text("# custom model")
    (src_dir / "configuration_custom.py").write_text("# custom config")
    (src_dir / "not_python.txt").write_text("not python")
    (src_dir / "subdir").mkdir()
    (src_dir / "subdir" / "nested.py").write_text("# nested — should not be copied")

    dst_dir = tmp_path / "dst"
    dst_dir.mkdir()

    copy_remote_code(src_dir, dst_dir)

    assert (dst_dir / "modeling_custom.py").read_text() == "# custom model"
    assert (dst_dir / "configuration_custom.py").read_text() == "# custom config"
    assert not (dst_dir / "not_python.txt").exists(), "non-.py files should not be copied"
    assert not (dst_dir / "nested.py").exists(), "nested .py files should not be copied"


def test_copy_remote_code_local_dir_no_py_files(tmp_path):
    """copy_remote_code is a no-op when the local directory has no .py files."""
    src_dir = tmp_path / "src"
    src_dir.mkdir()
    (src_dir / "config.json").write_text("{}")

    dst_dir = tmp_path / "dst"
    dst_dir.mkdir()

    copy_remote_code(src_dir, dst_dir)  # should not raise

    assert list(dst_dir.iterdir()) == [], "no files should be copied"


def test_copy_remote_code_hub_id(tmp_path):
    """copy_remote_code downloads and copies top-level .py files from a Hub model ID."""
    dst_dir = tmp_path / "dst"
    dst_dir.mkdir()

    # Create a fake cached file that hf_hub_download would return
    cached_py = tmp_path / "cached_modeling_custom.py"
    cached_py.write_text("# custom hub model")

    repo_files = [
        "modeling_custom.py",  # top-level .py — should be downloaded
        "config.json",  # non-.py — skip
        "model.safetensors",  # non-.py — skip
        "subdir/nested.py",  # subdirectory .py — skip (contains "/")
    ]

    with (
        patch(
            "modelopt.torch.export.plugins.hf_checkpoint_utils.list_repo_files",
            return_value=repo_files,
        ) as mock_list,
        patch(
            "modelopt.torch.export.plugins.hf_checkpoint_utils.hf_hub_download",
            return_value=str(cached_py),
        ) as mock_download,
    ):
        copy_remote_code("meta-llama/Llama-3.2-1B", dst_dir)

    mock_list.assert_called_once_with("meta-llama/Llama-3.2-1B")
    # Only the top-level .py should have been downloaded
    mock_download.assert_called_once_with(
        repo_id="meta-llama/Llama-3.2-1B", filename="modeling_custom.py"
    )
    assert (dst_dir / "modeling_custom.py").read_text() == "# custom hub model"


def test_copy_remote_code_hub_id_no_py_files(tmp_path):
    """copy_remote_code is a no-op when the Hub repo has no top-level .py files."""
    dst_dir = tmp_path / "dst"
    dst_dir.mkdir()

    with (
        patch(
            "modelopt.torch.export.plugins.hf_checkpoint_utils.list_repo_files",
            return_value=["config.json", "model.safetensors"],
        ),
        patch("modelopt.torch.export.plugins.hf_checkpoint_utils.hf_hub_download") as mock_download,
    ):
        copy_remote_code("meta-llama/Llama-3.2-1B", dst_dir)

    mock_download.assert_not_called()
    assert list(dst_dir.iterdir()) == []
```
