Skip to content

Commit 740a3db

Browse files
j-rauschroot
authored andcommitted
copy custom modeling files to pruned checkpoint dirs; without them, trust_remote_code checkpoints are silently excluded from the replacement library
Signed-off-by: jrausch <jrausch@nvidia.com>
1 parent 38d9522 commit 740a3db

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..checkpoint_utils_hf import (
3131
_get_auto_class_for_trust_remote_code,
3232
_save_checkpoint,
33+
copy_remote_code_files,
3334
load_model_config,
3435
)
3536
from ..logger import mprint
@@ -87,6 +88,9 @@ def init_child_from_parent(
8788
trust_remote_code=descriptor.requires_trust_remote_code(),
8889
)
8990

91+
if descriptor.requires_trust_remote_code():
92+
copy_remote_code_files(parent_checkpoint_dir, output_checkpoint_dir)
93+
9094
parent_model_config = load_model_config(
9195
parent_checkpoint_dir, trust_remote_code=descriptor.requires_trust_remote_code()
9296
)

modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import dataclasses
2323
import fcntl
2424
import os
25+
import shutil
2526
import time
2627
from collections import defaultdict
2728
from collections.abc import Callable, Mapping
@@ -448,3 +449,26 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str
448449
for conf in model_config.block_configs
449450
]
450451
model_config.save_pretrained(checkpoint_dir)
452+
453+
454+
def copy_remote_code_files(source_dir: Path | str, output_dir: Path | str) -> None:
455+
"""Copy custom code ``.py`` files from a trust-remote-code checkpoint into *output_dir*.
456+
457+
Models with dynamic modules (e.g. Nemotron-H) ship custom Python files
458+
(``modeling_*.py``, ``configuration_*.py``, etc.) alongside ``config.json``.
459+
When pruned checkpoints are saved via ``_save_checkpoint``, these files are not
460+
automatically copied, so ``AutoConfig.from_pretrained(output_dir, trust_remote_code=True)``
461+
fails and the checkpoint is silently excluded from the replacement library.
462+
463+
Only copies ``.py`` files from the root of *source_dir* (matching the layout of
464+
HuggingFace Hub checkpoints with custom code).
465+
"""
466+
source_dir = Path(source_dir)
467+
output_dir = Path(output_dir)
468+
if not source_dir.is_dir():
469+
return
470+
output_dir.mkdir(parents=True, exist_ok=True)
471+
for py_file in source_dir.glob("*.py"):
472+
target = output_dir / py_file.name
473+
if not target.exists():
474+
shutil.copy2(py_file, target)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for copy_remote_code_files.
17+
18+
Verifies that custom modeling .py files are copied from a source checkpoint
19+
directory to an output directory so that pruned checkpoints remain valid
20+
trust_remote_code checkpoints.
21+
"""
22+
23+
import pytest
24+
25+
pytest.importorskip("transformers")
26+
27+
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import copy_remote_code_files
28+
29+
30+
def test_copies_only_py_files(tmp_path):
31+
source = tmp_path / "source"
32+
source.mkdir()
33+
(source / "modeling_nemotron_h.py").write_text("class Model: pass")
34+
(source / "configuration_nemotron_h.py").write_text("class Config: pass")
35+
(source / "config.json").write_text("{}")
36+
(source / "model.safetensors").write_bytes(b"\x00")
37+
38+
output = tmp_path / "child"
39+
40+
copy_remote_code_files(source, output)
41+
42+
assert (output / "modeling_nemotron_h.py").exists()
43+
assert (output / "configuration_nemotron_h.py").exists()
44+
assert not (output / "config.json").exists()
45+
assert not (output / "model.safetensors").exists()
46+
47+
48+
def test_skips_existing_files(tmp_path):
49+
source = tmp_path / "source"
50+
source.mkdir()
51+
(source / "modeling.py").write_text("new content")
52+
53+
output = tmp_path / "child"
54+
output.mkdir()
55+
(output / "modeling.py").write_text("existing content")
56+
57+
copy_remote_code_files(source, output)
58+
59+
assert (output / "modeling.py").read_text() == "existing content"
60+
61+
62+
def test_noop_for_missing_source(tmp_path):
63+
copy_remote_code_files(tmp_path / "nonexistent", tmp_path / "output")
64+
65+
assert not (tmp_path / "output").exists()

0 commit comments

Comments
 (0)