Skip to content

Commit d7f62d3

Browse files
ChenhanYu and kevalmorabia97
authored and committed
Chenhany/megatron export per layer (#881)
## What does this PR do? **Type of change:** ? <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> Bug fix **Overview:** ? 1. Fix the megatron ignored-module suffix containing an additional `.` 2. Change megatron export to save each layer as a separate safetensors file (avoiding ghost safetensors) ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how you have tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Export workflow now supports additional model components (EAGLE/Medusa modules) * Per-layer model state organization for improved checkpoint management * **Bug Fixes** * More robust Hugging Face configuration, tokenizer, and image processor preservation * Enhanced multimodal component extraction and loading * **Refactor** * Optimized model export process with improved per-layer safetensors handling <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
1 parent 8857025 commit d7f62d3

File tree

4 files changed

+279
-167
lines changed

4 files changed

+279
-167
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Hugging Face checkpoint utility."""
17+
18+
import json
19+
import os
20+
import shutil
21+
from pathlib import Path
22+
23+
import torch
24+
from safetensors.torch import safe_open
25+
from tqdm import tqdm
26+
27+
28+
def copy_remote_code(
29+
pretrained_model_path: str | os.PathLike,
30+
save_directory: str | os.PathLike,
31+
):
32+
"""Copy remote code from pretrained model to save directory.
33+
34+
For models that keep configuration and modeling files as part of the checkpoint,
35+
we need to copy them to the export directory for seamless integration with inference
36+
frameworks.
37+
38+
Args:
39+
pretrained_model_path: Path to the pretrained model.
40+
save_directory: Path to the save directory.
41+
42+
Raises:
43+
ValueError: If the pretrained model path is not a directory.
44+
"""
45+
hf_checkpoint_path = Path(pretrained_model_path)
46+
save_dir = Path(save_directory)
47+
48+
if not hf_checkpoint_path.is_dir():
49+
raise ValueError(
50+
f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory."
51+
)
52+
53+
for py_file in hf_checkpoint_path.glob("*.py"):
54+
if py_file.is_file():
55+
shutil.copy(py_file, save_dir / py_file.name)
56+
57+
58+
def load_multimodal_components(
    pretrained_model_path: str | os.PathLike,
) -> dict[str, torch.Tensor]:
    """Load multimodal components (vision tower and projector) from safetensors files.

    Tensors whose names start with ``multi_modal_projector`` or ``vision_model`` are
    collected from either a single ``model.safetensors`` file or a sharded checkpoint.
    For sharded checkpoints the index's ``weight_map`` is used to locate every shard
    that holds a multimodal tensor, instead of assuming they all live in the first
    shard (which silently drops tensors when a large vision tower spans shards).

    Args:
        pretrained_model_path: Path to the pretrained model directory.

    Returns:
        A dictionary mapping tensor names to tensors; empty if no safetensors files
        or no multimodal tensors are found.

    Raises:
        ValueError: If the pretrained model path is not a directory.
    """
    multimodal_prefixes = ("multi_modal_projector", "vision_model")

    hf_checkpoint_path = Path(pretrained_model_path)
    if not hf_checkpoint_path.is_dir():
        raise ValueError(
            f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory."
        )

    safetensors_file = hf_checkpoint_path / "model.safetensors"
    safetensors_index_file = hf_checkpoint_path / "model.safetensors.index.json"

    multimodal_state_dict: dict[str, torch.Tensor] = {}

    if safetensors_file.is_file():
        print(f"Loading multimodal components from single file: {safetensors_file}")
        with safe_open(safetensors_file, framework="pt") as f:
            multimodal_keys = [
                key
                for key in f.keys()  # noqa: SIM118
                if key.startswith(multimodal_prefixes)
            ]
            for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"):
                multimodal_state_dict[key] = f.get_tensor(key)

    elif safetensors_index_file.is_file():
        print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}")
        with open(safetensors_index_file) as f:
            safetensors_index = json.load(f)

        # Group multimodal keys by the shard that actually holds them so no tensor
        # is missed when the vision tower/projector spill past the first shard.
        keys_by_shard: dict[str, list[str]] = {}
        for key, shard_file in safetensors_index["weight_map"].items():
            if key.startswith(multimodal_prefixes):
                keys_by_shard.setdefault(shard_file, []).append(key)

        if not keys_by_shard:
            print(f"No multimodal components found in {hf_checkpoint_path}")

        for shard_file in sorted(keys_by_shard):
            shard_keys = keys_by_shard[shard_file]
            print(f"Found {len(shard_keys)} multimodal tensors in {shard_file}")
            with safe_open(hf_checkpoint_path / shard_file, framework="pt") as f:
                for key in tqdm(shard_keys, desc="Loading multimodal tensors"):
                    multimodal_state_dict[key] = f.get_tensor(key)

    else:
        print(f"Warning: No safetensors files found in {hf_checkpoint_path}")

    print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors")
    return multimodal_state_dict

modelopt/torch/export/plugins/mcore_custom.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,59 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike):
274274
json.dump(safetensor_index, f, indent=4)
275275

276276

277+
def save_safetensors_by_layer_index(
    layer_state_dicts: dict[int, dict[str, torch.Tensor]],
    total_layers: int,
    save_directory: str | os.PathLike,
    name_template: str = "model-{:05d}-of-{:05d}",
):
    """Save each layer's state dict as its own safetensors shard plus a merged index.

    Each rank writes one ``.safetensors`` file and a ``.json`` metadata sidecar per
    layer it owns; after a global barrier, rank 0 merges all sidecars into the
    standard ``model.safetensors.index.json``.

    Args:
        layer_state_dicts: Mapping from layer index to that layer's state dict.
            NOTE(review): the merge step below reads indices ``1..total_layers``,
            so keys are assumed to be 1-based and to cover that range exactly once
            across all ranks — confirm against callers.
        total_layers: Total number of layers (shards) across all ranks.
        save_directory: Path to the save directory.
        name_template: Template for the shard filename stem.
    """
    # Use pathlib so an os.PathLike save_directory works; the previous
    # `save_directory + "/" + name` string concatenation raised TypeError
    # for non-str paths despite the annotation.
    save_dir = Path(save_directory)

    for layer_index, layer_state_dict in layer_state_dicts.items():
        stem = name_template.format(layer_index, total_layers)
        meta_filename = stem + ".json"
        ckpt_filename = stem + ".safetensors"

        weight_map = {}
        layer_total_size = 0
        for key, val in layer_state_dict.items():
            # Size in bytes of this tensor, accumulated into the shard total.
            layer_total_size += val.numel() * val.element_size()
            weight_map[key] = ckpt_filename

        with open(save_dir / meta_filename, "w") as f:
            json.dump(
                {"metadata": {"total_size": layer_total_size}, "weight_map": weight_map},
                f,
                indent=4,
            )
        save_file(layer_state_dict, str(save_dir / ckpt_filename), metadata={"format": "pt"})

    # [TODO]: this global barrier needs to be replaced with something safer
    torch.distributed.barrier()

    if torch.distributed.get_rank() == 0:
        safetensor_index = {
            "metadata": {"total_size": 0},
            "weight_map": {},
        }
        # Sidecar filenames use 1-based layer indices, hence layer_index + 1.
        for layer_index in range(total_layers):
            meta_filename = name_template.format(layer_index + 1, total_layers) + ".json"
            with open(save_dir / meta_filename) as f:
                shard = json.load(f)
            safetensor_index["metadata"]["total_size"] += shard["metadata"]["total_size"]
            safetensor_index["weight_map"].update(shard["weight_map"])

        with open(save_dir / "model.safetensors.index.json", "w") as f:
            json.dump(safetensor_index, f, indent=4)
328+
329+
277330
def _get_safetensors_file(pretrained_model_path: str | Path, key: str) -> Path | None:
278331
"""Given a tensor key return the safetensors file that contains this tensor if exists.
279332

modelopt/torch/export/plugins/vllm_fakequant_megatron.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class VllmFqGPTModelExporter(GPTModelExporter):
7272
def save_pretrained(
7373
self,
7474
save_directory: str | os.PathLike,
75-
pretrained_model_name_or_path: str | os.PathLike | None = None,
75+
pretrained_model_name_or_path: str | os.PathLike,
7676
):
7777
os.makedirs(save_directory, exist_ok=True)
7878
gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory)
@@ -91,7 +91,7 @@ def _get_quantization_format(self, module: torch.nn.Module):
9191

9292
def export_mcore_gpt_to_hf_vllm_fq(
9393
model: torch.nn.Module,
94-
pretrained_model_name_or_path: str | os.PathLike | None = None,
94+
pretrained_model_name_or_path: str | os.PathLike,
9595
export_extra_modules: bool = False,
9696
dtype: torch.dtype = torch.bfloat16,
9797
export_dir: Path | str = tempfile.gettempdir(),

0 commit comments

Comments
 (0)