Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions modelopt/torch/export/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@

with import_plugin("vllm_fakequant_megatron"):
from .vllm_fakequant_megatron import *

with import_plugin("hf_checkpoint_utils"):
from .hf_checkpoint_utils import *
24 changes: 11 additions & 13 deletions modelopt/torch/export/plugins/hf_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download, list_repo_files
from huggingface_hub import snapshot_download
from safetensors.torch import safe_open
from tqdm import tqdm


def copy_remote_code(
pretrained_model_path: str | os.PathLike,
save_directory: str | os.PathLike,
def copy_hf_ckpt_remote_code(
pretrained_model_path: str | os.PathLike, save_directory: str | os.PathLike
):
"""Copy remote code from pretrained model to save directory.

Expand All @@ -37,26 +36,25 @@ def copy_remote_code(
frameworks.

If ``pretrained_model_path`` is a local directory, Python files are copied directly.
If it is a HuggingFace Hub model ID, Python files are downloaded from the Hub first.
If it's a HF Hub model ID (e.g. ``nvidia/NVIDIA-Nemotron-Nano-12B-v2``), files are downloaded from the Hub.

Args:
pretrained_model_path: Local path to the pretrained model or HuggingFace Hub model ID.
save_directory: Path to the save directory.
"""
hf_checkpoint_path = Path(pretrained_model_path)
save_dir = Path(save_directory)
save_dir.mkdir(parents=True, exist_ok=True)

if hf_checkpoint_path.is_dir():
for py_file in hf_checkpoint_path.glob("*.py"):
if py_file.is_file():
shutil.copy(py_file, save_dir / py_file.name)
shutil.copy2(py_file, save_dir / py_file.name)
else:
# Hub model ID: download any top-level .py files (custom modeling code)
repo_id = str(pretrained_model_path)
for filename in list_repo_files(repo_id):
if "/" not in filename and filename.endswith(".py"):
local_path = hf_hub_download(repo_id=repo_id, filename=filename)
shutil.copy(local_path, save_dir / filename)
snapshot_download(
repo_id=str(pretrained_model_path),
local_dir=str(save_dir),
allow_patterns=["*.py"],
)


def load_multimodal_components(
Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/export/unified_export_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
)
from .plugins.hf_checkpoint_utils import copy_remote_code, load_multimodal_components
from .plugins.hf_checkpoint_utils import copy_hf_ckpt_remote_code, load_multimodal_components
from .plugins.mcore_common import all_mcore_hf_export_mapping
from .plugins.mcore_custom import (
CustomModuleMapping,
Expand Down Expand Up @@ -349,7 +349,7 @@ def save_pretrained(
torch.distributed.barrier()

if is_last_stage_main_rank and self._hf_config is not None:
copy_remote_code(pretrained_model_name_or_path, save_directory)
copy_hf_ckpt_remote_code(pretrained_model_name_or_path, save_directory)

# Newer versions of VLLM expect config.json with hf_quant_config
config_json_file = save_directory + "/config.json"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import yaml
from transformers import AutoModelForCausalLM

from modelopt.torch.export import copy_hf_ckpt_remote_code

from ...anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...anymodel.puzzformer import deci_x_patcher
from ..checkpoint_utils import copy_tokenizer, load_state_dict
Expand Down Expand Up @@ -87,6 +89,9 @@ def init_child_from_parent(
trust_remote_code=descriptor.requires_trust_remote_code(),
)

if descriptor.requires_trust_remote_code():
copy_hf_ckpt_remote_code(parent_checkpoint_dir, output_checkpoint_dir)

parent_model_config = load_model_config(
parent_checkpoint_dir, trust_remote_code=descriptor.requires_trust_remote_code()
)
Expand Down
16 changes: 3 additions & 13 deletions tests/_test_utils/torch/puzzletron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import torch
from _test_utils.torch.transformers_models import get_tiny_tokenizer
from datasets import Dataset, DatasetDict
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase

import modelopt.torch.puzzletron as mtpz
import modelopt.torch.utils.distributed as dist
from modelopt.torch.export import copy_hf_ckpt_remote_code


def setup_test_model_and_data(
Expand Down Expand Up @@ -189,21 +189,11 @@ def create_and_save_small_hf_model(
submodule._tied_weights_keys = None
model.save_pretrained(output_path, save_original_format=False)

# Save tokenizer
# Save tokenizer, config, and custom code files
tokenizer.save_pretrained(output_path)

# Save config
config.save_pretrained(output_path)

# Download trust_remote_code .py files from HF hub into the checkpoint directory so that
# force_cache_dynamic_modules can resolve classes from the local path.
# save_pretrained only saves weights + config, not these .py files.
if hasattr(config, "auto_map") and isinstance(config.auto_map, dict):
snapshot_download(
repo_id=hf_model_name,
local_dir=output_path,
allow_patterns=["*.py"],
)
copy_hf_ckpt_remote_code(hf_model_name, output_path)


def save_dummy_dataset(dataset_path: Path | str):
Expand Down
109 changes: 0 additions & 109 deletions tests/gpu_megatron/torch/export/test_hf_checkpoint_utils.py

This file was deleted.

73 changes: 73 additions & 0 deletions tests/unit/torch/export/test_hf_checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for modelopt/torch/export/plugins/hf_checkpoint_utils.py"""

from unittest.mock import patch

import pytest

pytest.importorskip("huggingface_hub")

from modelopt.torch.export import copy_hf_ckpt_remote_code


def test_copy_hf_ckpt_remote_code_local_dir(tmp_path):
    """Only top-level .py files are copied from a local checkpoint directory."""
    # Build a fake local checkpoint: two custom-code files, one non-Python
    # file, and a nested .py file that must NOT be picked up.
    source = tmp_path / "src"
    source.mkdir()
    expected_py = {
        "modeling_custom.py": "# custom model",
        "configuration_custom.py": "# custom config",
    }
    for name, body in expected_py.items():
        (source / name).write_text(body)
    (source / "not_python.txt").write_text("not python")
    nested = source / "subdir"
    nested.mkdir()
    (nested / "nested.py").write_text("# nested — should not be copied")

    target = tmp_path / "dst"
    target.mkdir()

    copy_hf_ckpt_remote_code(source, target)

    # Top-level .py files arrive with identical contents; everything else
    # (non-Python files, nested .py files) is excluded.
    for name, body in expected_py.items():
        assert (target / name).read_text() == body
    assert not (target / "not_python.txt").exists(), "non-.py files should not be copied"
    assert not (target / "nested.py").exists(), "nested .py files should not be copied"


def test_copy_hf_ckpt_remote_code_local_dir_no_py_files(tmp_path):
    """A local directory without any .py files yields an empty destination."""
    # Source holds only a JSON config — nothing matches the *.py filter.
    source = tmp_path / "src"
    source.mkdir()
    (source / "config.json").write_text("{}")

    target = tmp_path / "dst"
    target.mkdir()

    # Must complete without raising even though there is nothing to copy.
    copy_hf_ckpt_remote_code(source, target)

    assert list(target.iterdir()) == [], "no files should be copied"


def test_copy_hf_ckpt_remote_code_hub_id(tmp_path):
    """A non-local path is treated as a Hub ID and routed to snapshot_download."""
    target = tmp_path / "dst"

    # Patch snapshot_download where the utility imports it so no network
    # access happens; we only verify the delegation arguments.
    patch_target = "modelopt.torch.export.plugins.hf_checkpoint_utils.snapshot_download"
    with patch(patch_target) as mock_sd:
        copy_hf_ckpt_remote_code("nvidia/NVIDIA-Nemotron-Nano-12B-v2", target)

    # Exactly one download restricted to Python files into the save directory.
    mock_sd.assert_called_once_with(
        repo_id="nvidia/NVIDIA-Nemotron-Nano-12B-v2",
        local_dir=str(target),
        allow_patterns=["*.py"],
    )
Loading