Skip to content

Commit 6395b1e

Browse files
copy custom modeling files to pruned checkpoint dirs (#1245)
## Summary pruned FFN checkpoints created by `init_child_from_parent` were missing custom `.py` files (e.g. `modeling_nemotron_h.py`, `configuration_nemotron_h.py`). without them, `AutoConfig.from_pretrained(..., trust_remote_code=True)` fails silently and the checkpoint is excluded from the replacement library, reducing MIP candidate diversity. adds `copy_remote_code_files(source_dir, output_dir)` to `checkpoint_utils_hf.py` that copies `*.py` files from the source checkpoint root. called from `init_child_from_parent` after `copy_tokenizer` when `descriptor.requires_trust_remote_code()` is true. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a new utility to automatically copy Python model files from checkpoints during training initialization. * **Tests** * Added unit tests covering selective Python file copying, preservation of existing files, and proper handling of missing source directories. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 38d9522 commit 6395b1e

File tree

7 files changed

+97
-137
lines changed

7 files changed

+97
-137
lines changed

modelopt/torch/export/plugins/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@
2525

2626
with import_plugin("vllm_fakequant_megatron"):
2727
from .vllm_fakequant_megatron import *
28+
29+
with import_plugin("hf_checkpoint_utils"):
30+
from .hf_checkpoint_utils import *

modelopt/torch/export/plugins/hf_checkpoint_utils.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
from pathlib import Path
2222

2323
import torch
24-
from huggingface_hub import hf_hub_download, list_repo_files
24+
from huggingface_hub import snapshot_download
2525
from safetensors.torch import safe_open
2626
from tqdm import tqdm
2727

2828

29-
def copy_remote_code(
30-
pretrained_model_path: str | os.PathLike,
31-
save_directory: str | os.PathLike,
29+
def copy_hf_ckpt_remote_code(
30+
pretrained_model_path: str | os.PathLike, save_directory: str | os.PathLike
3231
):
3332
"""Copy remote code from pretrained model to save directory.
3433
@@ -37,26 +36,25 @@ def copy_remote_code(
3736
frameworks.
3837
3938
If ``pretrained_model_path`` is a local directory, Python files are copied directly.
40-
If it is a HuggingFace Hub model ID, Python files are downloaded from the Hub first.
39+
If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), files are downloaded from the Hub.
4140
4241
Args:
4342
pretrained_model_path: Local path to the pretrained model or HuggingFace Hub model ID.
4443
save_directory: Path to the save directory.
4544
"""
4645
hf_checkpoint_path = Path(pretrained_model_path)
4746
save_dir = Path(save_directory)
47+
save_dir.mkdir(parents=True, exist_ok=True)
4848

4949
if hf_checkpoint_path.is_dir():
5050
for py_file in hf_checkpoint_path.glob("*.py"):
51-
if py_file.is_file():
52-
shutil.copy(py_file, save_dir / py_file.name)
51+
shutil.copy2(py_file, save_dir / py_file.name)
5352
else:
54-
# Hub model ID: download any top-level .py files (custom modeling code)
55-
repo_id = str(pretrained_model_path)
56-
for filename in list_repo_files(repo_id):
57-
if "/" not in filename and filename.endswith(".py"):
58-
local_path = hf_hub_download(repo_id=repo_id, filename=filename)
59-
shutil.copy(local_path, save_dir / filename)
53+
snapshot_download(
54+
repo_id=str(pretrained_model_path),
55+
local_dir=str(save_dir),
56+
allow_patterns=["*.py"],
57+
)
6058

6159

6260
def load_multimodal_components(

modelopt/torch/export/unified_export_megatron.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
QUANTIZATION_NONE,
4646
QUANTIZATION_NVFP4,
4747
)
48-
from .plugins.hf_checkpoint_utils import copy_remote_code, load_multimodal_components
48+
from .plugins.hf_checkpoint_utils import copy_hf_ckpt_remote_code, load_multimodal_components
4949
from .plugins.mcore_common import all_mcore_hf_export_mapping
5050
from .plugins.mcore_custom import (
5151
CustomModuleMapping,
@@ -349,7 +349,7 @@ def save_pretrained(
349349
torch.distributed.barrier()
350350

351351
if is_last_stage_main_rank and self._hf_config is not None:
352-
copy_remote_code(pretrained_model_name_or_path, save_directory)
352+
copy_hf_ckpt_remote_code(pretrained_model_name_or_path, save_directory)
353353

354354
# Newer versions of VLLM expect config.json with hf_quant_config
355355
config_json_file = save_directory + "/config.json"

modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import yaml
2525
from transformers import AutoModelForCausalLM
2626

27+
from modelopt.torch.export import copy_hf_ckpt_remote_code
28+
2729
from ...anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory
2830
from ...anymodel.puzzformer import deci_x_patcher
2931
from ..checkpoint_utils import copy_tokenizer, load_state_dict
@@ -87,6 +89,9 @@ def init_child_from_parent(
8789
trust_remote_code=descriptor.requires_trust_remote_code(),
8890
)
8991

92+
if descriptor.requires_trust_remote_code():
93+
copy_hf_ckpt_remote_code(parent_checkpoint_dir, output_checkpoint_dir)
94+
9095
parent_model_config = load_model_config(
9196
parent_checkpoint_dir, trust_remote_code=descriptor.requires_trust_remote_code()
9297
)

tests/_test_utils/torch/puzzletron/utils.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
import torch
2020
from _test_utils.torch.transformers_models import get_tiny_tokenizer
2121
from datasets import Dataset, DatasetDict
22-
from huggingface_hub import snapshot_download
2322
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase
2423

2524
import modelopt.torch.puzzletron as mtpz
2625
import modelopt.torch.utils.distributed as dist
26+
from modelopt.torch.export import copy_hf_ckpt_remote_code
2727

2828

2929
def setup_test_model_and_data(
@@ -189,21 +189,11 @@ def create_and_save_small_hf_model(
189189
submodule._tied_weights_keys = None
190190
model.save_pretrained(output_path, save_original_format=False)
191191

192-
# Save tokenizer
192+
# Save tokenizer, config, and custom code files
193193
tokenizer.save_pretrained(output_path)
194-
195-
# Save config
196194
config.save_pretrained(output_path)
197-
198-
# Download trust_remote_code .py files from HF hub into the checkpoint directory so that
199-
# force_cache_dynamic_modules can resolve classes from the local path.
200-
# save_pretrained only saves weights + config, not these .py files.
201195
if hasattr(config, "auto_map") and isinstance(config.auto_map, dict):
202-
snapshot_download(
203-
repo_id=hf_model_name,
204-
local_dir=output_path,
205-
allow_patterns=["*.py"],
206-
)
196+
copy_hf_ckpt_remote_code(hf_model_name, output_path)
207197

208198

209199
def save_dummy_dataset(dataset_path: Path | str):

tests/gpu_megatron/torch/export/test_hf_checkpoint_utils.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

[NOTE: the 73-line section below appears to be a new test file whose filename header was lost during page extraction — likely the relocated replacement for the deleted test above; verify against the original commit.]
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for modelopt/torch/export/plugins/hf_checkpoint_utils.py"""
17+
18+
from unittest.mock import patch
19+
20+
import pytest
21+
22+
pytest.importorskip("huggingface_hub")
23+
24+
from modelopt.torch.export import copy_hf_ckpt_remote_code
25+
26+
27+
def test_copy_hf_ckpt_remote_code_local_dir(tmp_path):
28+
"""copy_hf_ckpt_remote_code copies top-level .py files from a local directory."""
29+
src_dir = tmp_path / "src"
30+
src_dir.mkdir()
31+
(src_dir / "modeling_custom.py").write_text("# custom model")
32+
(src_dir / "configuration_custom.py").write_text("# custom config")
33+
(src_dir / "not_python.txt").write_text("not python")
34+
(src_dir / "subdir").mkdir()
35+
(src_dir / "subdir" / "nested.py").write_text("# nested — should not be copied")
36+
37+
dst_dir = tmp_path / "dst"
38+
dst_dir.mkdir()
39+
40+
copy_hf_ckpt_remote_code(src_dir, dst_dir)
41+
42+
assert (dst_dir / "modeling_custom.py").read_text() == "# custom model"
43+
assert (dst_dir / "configuration_custom.py").read_text() == "# custom config"
44+
assert not (dst_dir / "not_python.txt").exists(), "non-.py files should not be copied"
45+
assert not (dst_dir / "nested.py").exists(), "nested .py files should not be copied"
46+
47+
48+
def test_copy_hf_ckpt_remote_code_local_dir_no_py_files(tmp_path):
49+
"""copy_hf_ckpt_remote_code is a no-op when the local directory has no .py files."""
50+
src_dir = tmp_path / "src"
51+
src_dir.mkdir()
52+
(src_dir / "config.json").write_text("{}")
53+
54+
dst_dir = tmp_path / "dst"
55+
dst_dir.mkdir()
56+
57+
copy_hf_ckpt_remote_code(src_dir, dst_dir) # should not raise
58+
59+
assert list(dst_dir.iterdir()) == [], "no files should be copied"
60+
61+
62+
def test_copy_hf_ckpt_remote_code_hub_id(tmp_path):
63+
"""copy_hf_ckpt_remote_code delegates to snapshot_download for a Hub model ID."""
64+
dst_dir = tmp_path / "dst"
65+
66+
with patch("modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download") as mock_sd:
67+
copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-Nano-12B-v2", dst_dir)
68+
69+
mock_sd.assert_called_once_with(
70+
repo_id="nvidia/NVIDIA-Nemotron-Nano-12B-v2",
71+
local_dir=str(dst_dir),
72+
allow_patterns=["*.py"],
73+
)

0 commit comments

Comments (0)