
Commit a4ad1b8

Add Megatron-Bridge recipe-free distillation example script
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 452c5a0 commit a4ad1b8

4 files changed: 286 additions & 5 deletions


CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead.
 - Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint.
 - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details on its usage.
-- New example for Minitron pruning with Megatron-Bridge framework along with advanced pruning usage with new ``params`` constraint based pruning. Check `examples/megatron_bridge/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge>`_ for example scripts.
+- New example for Minitron pruning with Megatron-Bridge framework along with advanced pruning usage with new ``params`` constraint based pruning. Also add example for distillation with Megatron-Bridge framework. Check `examples/megatron_bridge/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge>`_ for example scripts.
 - Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow.
 - Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.
 - Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models.

examples/megatron_bridge/README.md

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/exampl
 To see the full usage for advanced configurations, run:
 
 ```bash
-python /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
+torchrun --nproc_per_node 1 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
 ```
 
 > [!TIP]
@@ -60,7 +60,7 @@ python /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/pr
 
 ## Distillation
 
-TODO
+TODO - Add info!
 
 ## Quantization
 

examples/megatron_bridge/distill.py

Lines changed: 279 additions & 0 deletions

@@ -0,0 +1,279 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Distillation script for Megatron-Bridge.

Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model
to <log_dir>/checkpoints in megatron torch_dist checkpoint format.

Example usage to distill a 4B student from an 8B teacher on 8 GPUs:

.. code-block:: bash

    torchrun --nproc_per_node 8 distill.py \
        --teacher_hf_path Qwen/Qwen3-8B \
        --student_hf_path Qwen/Qwen3-4B \
        --tp_size 8 \
        --data_paths 1.0 /path/to/tokenized/data \
        --seq_length 8192 \
        --mbs 1 \
        --gbs 768 \
        --train_iters 15000 \
        --lr 1e-4 \
        --min_lr 1e-5 \
        --lr_warmup_iters 50 \
        --eval_interval 100 \
        --eval_iters 32 \
        --log_interval 10 \
        --log_dir /output/qwen3_8b_to_4b_distill

Example usage to use mock data for quick testing:

.. code-block:: bash

    torchrun --nproc_per_node 8 distill.py \
        --teacher_hf_path Qwen/Qwen3-0.6B \
        --student_hf_path Qwen/Qwen3-0.6B \
        --tp_size 8 \
        --use_mock_data \
        --seq_length 512 \
        --mbs 1 \
        --gbs 8 \
        --train_iters 100 \
        --log_dir /tmp/test_distill

If you want to tokenize your own data for a specific tokenizer, you can use the following command:

.. code-block:: python

    from modelopt.torch.utils.plugins import megatron_preprocess_data

    megatron_preprocess_data(
        input_path="/path/to/your/data.jsonl",
        output_dir="/path/to/tokenized/data",
        tokenizer_name_or_path="Qwen/Qwen3-0.6B",
        json_keys=["text"],
        workers=32,
        log_interval=100000,
        max_sequence_length=256000,
    )
"""
# TODO: Fix resuming distillation from an intermediate checkpoint.

import argparse
import os

import torch
from megatron.bridge import AutoBridge
from megatron.bridge.models.distillation_provider import convert_to_distillation_provider
from megatron.bridge.recipes.utils.optimizer_utils import (
    distributed_fused_adam_with_cosine_annealing,
)
from megatron.bridge.training.config import (
    CheckpointConfig,
    ConfigContainer,
    GPTDatasetConfig,
    LoggerConfig,
    MockGPTDatasetConfig,
    RNGConfig,
    TokenizerConfig,
    TrainingConfig,
)
from megatron.bridge.training.distill import distill
from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.distributed import DistributedDataParallelConfig

import modelopt.torch.utils.distributed as dist
from modelopt.torch.utils import print_rank_0

SEED = 1234


def get_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.")
    # Model arguments
    parser.add_argument(
        "--student_hf_path",
        type=str,
        required=True,
        help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)",
    )
    parser.add_argument(
        "--teacher_hf_path",
        type=str,
        required=True,
        help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)",
    )
    # Parallelism arguments
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
    # Dataset arguments
    parser.add_argument(
        "--data_paths",
        nargs="+",
        help="List of tokenized data paths to load from (weight1 path1 weight2 path2 ...)",
    )
    parser.add_argument(
        "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data"
    )
    parser.add_argument(
        "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths"
    )
    # Training arguments
    parser.add_argument(
        "--log_dir", type=str, required=True, help="Folder for logging and checkpoint saving"
    )
    parser.add_argument(
        "--seq_length", type=int, default=8192, help="Number of tokens per input sample"
    )
    parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size")
    parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size")
    parser.add_argument(
        "--train_iters", type=int, required=True, help="Number of training iterations"
    )
    parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate")
    parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate")
    parser.add_argument("--lr_warmup_iters", type=int, default=50, help="Number of LR warmup steps")
    parser.add_argument(
        "--eval_interval", type=int, default=100, help="Validate + checkpoint every <N> steps"
    )
    parser.add_argument(
        "--eval_iters", type=int, default=32, help="Number of batches per validation stage"
    )
    parser.add_argument("--log_interval", type=int, default=10, help="Write to log every <N> steps")
    args = parser.parse_args()

    # Sanity checks
    if not args.use_mock_data and not args.data_paths:
        raise ValueError("Must provide either --data_paths or set --use_mock_data.")

    print_rank_0("\n==================== Arguments ====================")
    for k, v in args.__dict__.items():
        print_rank_0(f"{k:<35} {v}")
    print_rank_0("===================================================\n")

    return args


def main(args: argparse.Namespace):
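    """Build training configs from the CLI args and run teacher-student distillation."""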
    checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
    tensorboard_dir = os.path.join(args.log_dir, "tb_logs")

    # Build student and teacher model providers
    def _build_model_provider(hf_path):
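        """Load a Megatron model provider from a HF checkpoint and apply parallelism settings."""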
        bridge = AutoBridge.from_hf_pretrained(hf_path)
        provider = bridge.to_megatron_provider(load_weights=True)
        provider.tensor_model_parallel_size = args.tp_size
        provider.pipeline_model_parallel_size = args.pp_size
        provider.context_parallel_size = 1
        provider.sequence_parallel = args.tp_size > 1
        provider.seq_length = args.seq_length
        provider.pipeline_dtype = torch.bfloat16
        provider.cross_entropy_fusion_impl = "te"
        return provider

    # TODO: Support megatron-ckpt as an alternative to HF checkpoints
    student_provider = _build_model_provider(args.student_hf_path)
    teacher_provider = _build_model_provider(args.teacher_hf_path)

    # Wrap into DistillationProvider
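    # ModelOptDistillConfig() is used with its defaults here -- presumably a logits-based
    # KD loss between student and teacher; check the class for the exact defaults and knobs.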
    kd_config = ModelOptDistillConfig()
    distill_provider = convert_to_distillation_provider(
        student_provider, teacher_provider, kd_config
    )

    # Build optimizer and scheduler
    optimizer_config, scheduler_config = distributed_fused_adam_with_cosine_annealing(
        lr_warmup_iters=args.lr_warmup_iters,
        max_lr=args.lr,
        min_lr=args.min_lr,
        adam_beta2=0.98,
    )

    # Build dataset config
    dataset_kwargs = {
        "seq_length": args.seq_length,
        "random_seed": SEED,
        "reset_attention_mask": False,
        "reset_position_ids": False,
        "eod_mask_loss": False,
        "num_dataset_builder_threads": 1,
        "data_sharding": True,
        "dataloader_type": "single",
        "skip_getting_attention_mask_from_dataset": True,
    }
    if args.use_mock_data:
        dataset_config = MockGPTDatasetConfig(**dataset_kwargs)
    else:
        # Convert flat CLI list (e.g. ["1.0", "/path/data"]) to Megatron blend format
        blend = get_blend_from_list(args.data_paths)
        dataset_config = GPTDatasetConfig(blend=blend, split=args.split, **dataset_kwargs)

    # Assemble ConfigContainer and run distillation
    config = ConfigContainer(
        model=distill_provider,
        train=TrainingConfig(
            train_iters=args.train_iters,
            eval_interval=args.eval_interval,
            eval_iters=args.eval_iters,
            global_batch_size=args.gbs,
            micro_batch_size=args.mbs,
            manual_gc=True,
            manual_gc_interval=100,
        ),
        optimizer=optimizer_config,
        scheduler=scheduler_config,
        ddp=DistributedDataParallelConfig(
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
            average_in_collective=True,
            use_distributed_optimizer=True,
        ),
        dataset=dataset_config,
        logger=LoggerConfig(
            log_interval=args.log_interval,
            tensorboard_dir=tensorboard_dir,
            log_timers_to_tensorboard=True,
        ),
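        # The input data is already tokenized (see module docstring), so a placeholder
        # NullTokenizer with the model's vocab size is sufficient here.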
        tokenizer=TokenizerConfig(
            tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size
        ),
        checkpoint=CheckpointConfig(
            save_interval=args.eval_interval,
            save=checkpoint_dir,
            load=checkpoint_dir,
            ckpt_format="torch_dist",
            fully_parallel_save=True,
            finetune=True,
        ),
        rng=RNGConfig(seed=SEED),
        mixed_precision="bf16_mixed",
    )

    print_rank_0("\nStarting distillation...")
    distill(config)
    print_rank_0(f"\nDistillation done! Saved checkpoint to {checkpoint_dir}\n")


if __name__ == "__main__":
    dist.setup()
    args = get_args()
    try:
        main(args)
    finally:
        dist.cleanup()
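
Side note on `--data_paths`: the flag takes an alternating weight/path list, which `get_blend_from_list` converts into Megatron's blend format before it is handed to `GPTDatasetConfig`. A minimal sketch of the expected conversion, with hypothetical corpus paths (and assuming Megatron core's prefix/weight tuple convention):

```python
from megatron.core.datasets.utils import get_blend_from_list

# Alternating "weight path" pairs, exactly as passed on the CLI via --data_paths
flat = ["0.7", "/data/web_text", "0.3", "/data/code_text"]

# Expected result: a (dataset prefixes, per-dataset float weights) tuple
prefixes, weights = get_blend_from_list(flat)
print(prefixes)  # ["/data/web_text", "/data/code_text"]
print(weights)   # [0.7, 0.3]
```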

modelopt/torch/utils/plugins/megatron_preprocess_data.py

Lines changed: 4 additions & 2 deletions
@@ -42,6 +42,8 @@
 from megatron.core.datasets import indexed_dataset
 from transformers import AutoTokenizer
 
+from modelopt.torch.utils import num2hrb
+
 __all__ = ["megatron_preprocess_data"]
 
 
@@ -109,7 +111,7 @@ def __init__(self, vocab_size: int, json_keys: list[str], log_interval: int, wor
     def _print_processing_stats(self, count: int, total_doc_len: int, total_enc_len: int):
         if count % self.log_interval == 0:
             print(
-                f"Processed {count} documents, {total_doc_len} chars, {total_enc_len} tokens",
+                f"Processed {num2hrb(count)} docs = {num2hrb(total_doc_len)} chars = {num2hrb(total_enc_len)} tokens",
                 file=sys.stderr,
             )
 
@@ -202,7 +204,7 @@ def megatron_preprocess_data(
         num_tokens = partition.process_json_file(name, output_dir, encoder)
         final_enc_len += num_tokens
 
-    print(f">>> Total number of tokens: {final_enc_len}")
+    print(f">>> Total number of tokens: {num2hrb(final_enc_len)}")
 
 
 def main():
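
The new log lines use `num2hrb` from `modelopt.torch.utils` to print large counts in human-readable form. The diff does not show its implementation; a minimal stand-in with the presumed K/M/B/T-suffix behavior:

```python
def human_readable(num: float) -> str:
    """Format a count with K/M/B/T suffixes (a sketch of num2hrb's presumed behavior)."""
    for suffix in ("", "K", "M", "B", "T"):
        if abs(num) < 1000.0:
            return f"{num:.2f}{suffix}"
        num /= 1000.0
    return f"{num:.2f}Q"


# What a log line in the new format would then look like:
count, total_doc_len, total_enc_len = 200_000, 1_500_000_000, 380_000_000
print(
    f"Processed {human_readable(count)} docs = {human_readable(total_doc_len)} chars"
    f" = {human_readable(total_enc_len)} tokens"
)
# -> Processed 200.00K docs = 1.50B chars = 380.00M tokens
```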
