From a4ad1b89113ad390b66b8cafed8d7434dc15a3d1 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 5 Feb 2026 18:04:34 -0800 Subject: [PATCH 1/6] Add Megatron-Bridge recipe-free distillation example script Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 2 +- examples/megatron_bridge/README.md | 4 +- examples/megatron_bridge/distill.py | 279 ++++++++++++++++++ .../utils/plugins/megatron_preprocess_data.py | 6 +- 4 files changed, 286 insertions(+), 5 deletions(-) create mode 100644 examples/megatron_bridge/distill.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cf3e2ff628..8cf2a4315e 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,7 +13,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead. - Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint. - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md `_ for more details on its usage. -- New example for Minitron pruning with Megatron-Bridge framework along with advanced pruning usage with new ``params`` constraint based pruning. Check `examples/megatron_bridge/README.md `_ for example scripts. +- New example for Minitron pruning with Megatron-Bridge framework along with advanced pruning usage with new ``params`` constraint based pruning. Also add example for distillation with Megatron-Bridge framework. Check `examples/megatron_bridge/README.md `_ for example scripts. 
- Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow. - Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model. - Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models. diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 1a60a952a0..ebd5b90f13 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -50,7 +50,7 @@ torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/exampl To see the full usage for advanced configurations, run: ```bash -python /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help +torchrun --nproc_per_node 1 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help ``` > [!TIP] @@ -60,7 +60,7 @@ python /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/pr ## Distillation -TODO +TODO - Add info! ## Quantization diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py new file mode 100644 index 0000000000..49e8f93b0e --- /dev/null +++ b/examples/megatron_bridge/distill.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Distillation script for Megatron-Bridge. + +Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model +to /checkpoints in megatron torch_dist checkpoint format. + +Example usage to distill a 4B student from an 8B teacher on 8 GPUs: + +.. code-block:: bash + + torchrun --nproc_per_node 8 distill.py \ + --teacher_hf_path Qwen/Qwen3-8B \ + --student_hf_path Qwen/Qwen3-4B \ + --tp_size 8 \ + --data_paths 1.0 /path/to/tokenized/data \ + --seq_length 8192 \ + --mbs 1 \ + --gbs 768 \ + --train_iters 15000 \ + --lr 1e-4 \ + --min_lr 1e-5 \ + --lr_warmup_iters 50 \ + --eval_interval 100 \ + --eval_iters 32 \ + --log_interval 10 \ + --log_dir /output/qwen3_8b_to_4b_distill + +Example usage to use mock data for quick testing: + +.. code-block:: bash + + torchrun --nproc_per_node 8 distill.py \ + --teacher_hf_path Qwen/Qwen3-0.6B \ + --student_hf_path Qwen/Qwen3-0.6B \ + --tp_size 8 \ + --use_mock_data \ + --seq_length 512 \ + --mbs 1 \ + --gbs 8 \ + --train_iters 100 \ + --log_dir /tmp/test_distill + +If you want to tokenize your own data for a specific tokenizer, you can use the following command: + +.. code-block:: python + + from modelopt.torch.utils.plugins import megatron_preprocess_data + + megatron_preprocess_data( + input_path="/path/to/your/data.jsonl", + output_dir="/path/to/tokenized/data", + tokenizer_name_or_path="Qwen/Qwen3-0.6B", + json_keys=["text"], + workers=32, + log_interval=100000, + max_sequence_length=256000, + ) +""" +# TODO: Fix resuming distillation from an intermediate checkpoint. 
+ +import argparse +import os + +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.models.distillation_provider import convert_to_distillation_provider +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, +) +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + GPTDatasetConfig, + LoggerConfig, + MockGPTDatasetConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.distill import distill +from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.distributed import DistributedDataParallelConfig + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils import print_rank_0 + +SEED = 1234 + + +def get_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") + # Model arguments + parser.add_argument( + "--student_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)", + ) + parser.add_argument( + "--teacher_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the teacher (e.g. 
Qwen/Qwen3-8B)", + ) + # Parallelism arguments + parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + # Dataset arguments + parser.add_argument( + "--data_paths", + nargs="+", + help="List of tokenized data paths to load from (weight1 path1 weight2 path2 ...)", + ) + parser.add_argument( + "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data" + ) + parser.add_argument( + "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths" + ) + # Training arguments + parser.add_argument( + "--log_dir", type=str, required=True, help="Folder for logging and checkpoint saving" + ) + parser.add_argument( + "--seq_length", type=int, default=8192, help="Number of tokens per input sample" + ) + parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size") + parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size") + parser.add_argument( + "--train_iters", type=int, required=True, help="Number of training iterations" + ) + parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate") + parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate") + parser.add_argument("--lr_warmup_iters", type=int, default=50, help="Number of LR warmup steps") + parser.add_argument( + "--eval_interval", type=int, default=100, help="Validate + checkpoint every steps" + ) + parser.add_argument( + "--eval_iters", type=int, default=32, help="Number of batches per validation stage" + ) + parser.add_argument("--log_interval", type=int, default=10, help="Write to log every steps") + args = parser.parse_args() + + # Sanity checks + if not args.use_mock_data and not args.data_paths: + raise ValueError("Must provide either --data_paths or set --use_mock_data.") + + print_rank_0("\n==================== Arguments ====================") + for k, v in 
args.__dict__.items(): + print_rank_0(f"{k:<35} {v}") + print_rank_0("===================================================\n") + + return args + + +def main(args: argparse.Namespace): + checkpoint_dir = os.path.join(args.log_dir, "checkpoints") + tensorboard_dir = os.path.join(args.log_dir, "tb_logs") + + # Build student and teacher model providers + def _build_model_provider(hf_path): + bridge = AutoBridge.from_hf_pretrained(hf_path) + provider = bridge.to_megatron_provider(load_weights=True) + provider.tensor_model_parallel_size = args.tp_size + provider.pipeline_model_parallel_size = args.pp_size + provider.context_parallel_size = 1 + provider.sequence_parallel = args.tp_size > 1 + provider.seq_length = args.seq_length + provider.pipeline_dtype = torch.bfloat16 + provider.cross_entropy_fusion_impl = "te" + return provider + + # TODO: Support megatron-ckpt as an alternative to HF checkpoints + student_provider = _build_model_provider(args.student_hf_path) + teacher_provider = _build_model_provider(args.teacher_hf_path) + + # Wrap into DistillationProvider + kd_config = ModelOptDistillConfig() + distill_provider = convert_to_distillation_provider( + student_provider, teacher_provider, kd_config + ) + + # Build optimizer and scheduler + optimizer_config, scheduler_config = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=args.lr_warmup_iters, + max_lr=args.lr, + min_lr=args.min_lr, + adam_beta2=0.98, + ) + + # Build dataset config + dataset_kwargs = { + "seq_length": args.seq_length, + "random_seed": SEED, + "reset_attention_mask": False, + "reset_position_ids": False, + "eod_mask_loss": False, + "num_dataset_builder_threads": 1, + "data_sharding": True, + "dataloader_type": "single", + "skip_getting_attention_mask_from_dataset": True, + } + if args.use_mock_data: + dataset_config = MockGPTDatasetConfig(**dataset_kwargs) + else: + # Convert flat CLI list (e.g. 
["1.0", "/path/data"]) to Megatron blend format + blend = get_blend_from_list(args.data_paths) + dataset_config = GPTDatasetConfig(blend=blend, split=args.split, **dataset_kwargs) + + # Assemble ConfigContainer and run distillation + config = ConfigContainer( + model=distill_provider, + train=TrainingConfig( + train_iters=args.train_iters, + eval_interval=args.eval_interval, + eval_iters=args.eval_iters, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + manual_gc=True, + manual_gc_interval=100, + ), + optimizer=optimizer_config, + scheduler=scheduler_config, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dataset=dataset_config, + logger=LoggerConfig( + log_interval=args.log_interval, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size + ), + checkpoint=CheckpointConfig( + save_interval=args.eval_interval, + save=checkpoint_dir, + load=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + finetune=True, + ), + rng=RNGConfig(seed=SEED), + mixed_precision="bf16_mixed", + ) + + print_rank_0("\nStarting distillation...") + distill(config) + print_rank_0(f"\nDistillation done! 
Saved checkpoint to {checkpoint_dir}\n") + + +if __name__ == "__main__": + dist.setup() + args = get_args() + try: + main(args) + finally: + dist.cleanup() diff --git a/modelopt/torch/utils/plugins/megatron_preprocess_data.py b/modelopt/torch/utils/plugins/megatron_preprocess_data.py index 94bf268bcb..92ea4bd51b 100644 --- a/modelopt/torch/utils/plugins/megatron_preprocess_data.py +++ b/modelopt/torch/utils/plugins/megatron_preprocess_data.py @@ -42,6 +42,8 @@ from megatron.core.datasets import indexed_dataset from transformers import AutoTokenizer +from modelopt.torch.utils import num2hrb + __all__ = ["megatron_preprocess_data"] @@ -109,7 +111,7 @@ def __init__(self, vocab_size: int, json_keys: list[str], log_interval: int, wor def _print_processing_stats(self, count: int, total_doc_len: int, total_enc_len: int): if count % self.log_interval == 0: print( - f"Processed {count} documents, {total_doc_len} chars, {total_enc_len} tokens", + f"Processed {num2hrb(count)} docs = {num2hrb(total_doc_len)} chars = {num2hrb(total_enc_len)} tokens", file=sys.stderr, ) @@ -202,7 +204,7 @@ def megatron_preprocess_data( num_tokens = partition.process_json_file(name, output_dir, encoder) final_enc_len += num_tokens - print(f">>> Total number of tokens: {final_enc_len}") + print(f">>> Total number of tokens: {num2hrb(final_enc_len)}") def main(): From 48c74bd59edd132bbadada814dd72f31601a68b3 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Fri, 6 Feb 2026 11:26:11 -0800 Subject: [PATCH 2/6] Update docs Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/megatron_bridge/README.md | 34 +++++++++++++++++++--- examples/megatron_bridge/prune_minitron.py | 2 +- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index ebd5b90f13..860df2cd17 100644 --- a/examples/megatron_bridge/README.md +++ 
b/examples/megatron_bridge/README.md @@ -18,7 +18,33 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed. -To get the latest ModelOpt features and examples, you can mount your latest ModelOpt cloned repository to the container at `/opt/Megatron-Bridge/3rdparty/Model-Optimizer` or pull the latest changes once inside the docker container (`cd /opt/Megatron-Bridge/3rdparty/Model-Optimizer && git checkout main && git pull`). +To get the latest ModelOpt features and examples scripts, mount your Model-Optimizer repo to the container. + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer # or set to your local Model-Optimizer repository path if you have cloned it +if [ ! -d "${MODELOPT_DIR}" ]; then + git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +fi + +export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02 +docker run \ + --gpus all \ + --shm-size=16GB \ + --net=host \ + --ulimit memlock=-1 \ + --rm -it \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -w /opt/Model-Optimizer/examples/megatron_bridge \ + ${DOCKER_IMAGE} bash +``` + +Once inside the container, you need to login with your HuggingFace token to download gated datasets / models. +Note that the default dataset for pruning and quantization is [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), which is gated. + +```bash +huggingface-cli login --token +``` ## Pruning @@ -30,7 +56,7 @@ Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. 
```bash -torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \ +torchrun --nproc_per_node 2 prune_minitron.py \ --hf_model_name_or_path Qwen/Qwen3-8B \ --prune_target_params 6e9 \ --hparams_to_skip num_attention_heads \ @@ -41,7 +67,7 @@ Example usage for manually pruning to a specific architecture using following de 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration. ```bash -torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \ +torchrun --nproc_per_node 2 prune_minitron.py \ --hf_model_name_or_path Qwen/Qwen3-8B \ --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \ --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual @@ -50,7 +76,7 @@ torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/exampl To see the full usage for advanced configurations, run: ```bash -torchrun --nproc_per_node 1 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help +torchrun --nproc_per_node 1 prune_minitron.py --help ``` > [!TIP] diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 1f4e54adfd..52c792b9c8 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -27,7 +27,7 @@ --output_hf_path /tmp/Qwen3-8B-Pruned-6B To see the full usage for advanced configurations, run: - python prune_minitron.py --help + torchrun --nproc_per_node 1 prune_minitron.py --help """ import argparse From ce4d0813db029539c51ef09b4736795a78992809 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 9 Feb 2026 10:33:37 -0800 Subject: [PATCH 3/6] minor Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- 
examples/megatron_bridge/distill.py | 35 ++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index 49e8f93b0e..ec2086c7ae 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -15,7 +15,7 @@ """Distillation script for Megatron-Bridge. Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model -to /checkpoints in megatron torch_dist checkpoint format. +to /checkpoints in megatron distributed checkpoint format. Example usage to distill a 4B student from an 8B teacher on 8 GPUs: @@ -26,6 +26,7 @@ --student_hf_path Qwen/Qwen3-4B \ --tp_size 8 \ --data_paths 1.0 /path/to/tokenized/data \ + --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \ --seq_length 8192 \ --mbs 1 \ --gbs 768 \ @@ -36,7 +37,7 @@ --eval_interval 100 \ --eval_iters 32 \ --log_interval 10 \ - --log_dir /output/qwen3_8b_to_4b_distill + --output_dir /output/qwen3_8b_to_4b_distill Example usage to use mock data for quick testing: @@ -51,7 +52,9 @@ --mbs 1 \ --gbs 8 \ --train_iters 100 \ - --log_dir /tmp/test_distill + --eval_interval 10 \ + --eval_iters 4 \ + --output_dir /tmp/test_distill If you want to tokenize your own data for a specific tokenizer, you can use the following command: @@ -129,12 +132,15 @@ def get_args(): parser.add_argument( "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data" ) + parser.add_argument( + "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices" + ) parser.add_argument( "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths" ) - # Training arguments + # Training & Eval arguments parser.add_argument( - "--log_dir", type=str, required=True, help="Folder for logging and checkpoint saving" + "--output_dir", type=str, required=True, help="Folder for logging and 
checkpoint saving" ) parser.add_argument( "--seq_length", type=int, default=8192, help="Number of tokens per input sample" @@ -153,7 +159,13 @@ def get_args(): parser.add_argument( "--eval_iters", type=int, default=32, help="Number of batches per validation stage" ) + # Logging arguments parser.add_argument("--log_interval", type=int, default=10, help="Write to log every steps") + parser.add_argument( + "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)" + ) + parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)") + parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)") args = parser.parse_args() # Sanity checks @@ -169,8 +181,8 @@ def get_args(): def main(args: argparse.Namespace): - checkpoint_dir = os.path.join(args.log_dir, "checkpoints") - tensorboard_dir = os.path.join(args.log_dir, "tb_logs") + checkpoint_dir = os.path.join(args.output_dir, "checkpoints") + tensorboard_dir = os.path.join(args.output_dir, "tb_logs") # Build student and teacher model providers def _build_model_provider(hf_path): @@ -206,6 +218,7 @@ def _build_model_provider(hf_path): # Build dataset config dataset_kwargs = { "seq_length": args.seq_length, + "path_to_cache": args.data_path_to_cache, "random_seed": SEED, "reset_attention_mask": False, "reset_position_ids": False, @@ -249,6 +262,10 @@ def _build_model_provider(hf_path): log_interval=args.log_interval, tensorboard_dir=tensorboard_dir, log_timers_to_tensorboard=True, + # Weights & Biases logging + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, # optional + wandb_exp_name=args.wandb_exp_name, ), tokenizer=TokenizerConfig( tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size @@ -256,8 +273,10 @@ def _build_model_provider(hf_path): checkpoint=CheckpointConfig( save_interval=args.eval_interval, save=checkpoint_dir, - load=checkpoint_dir, + load=checkpoint_dir, # Resume from this directory (if 
exists) + most_recent_k=3, # Keeps 3 most recent checkpoints (not metric-based) ckpt_format="torch_dist", + async_save=True, fully_parallel_save=True, finetune=True, ), From c18315bdf1024a9458bcc98423a81ece7a63ff63 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 9 Feb 2026 13:40:06 -0800 Subject: [PATCH 4/6] Fix resuming Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/megatron_bridge/distill.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index ec2086c7ae..58ec0594d2 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -72,7 +72,6 @@ max_sequence_length=256000, ) """ -# TODO: Fix resuming distillation from an intermediate checkpoint. import argparse import os @@ -278,7 +277,6 @@ def _build_model_provider(hf_path): ckpt_format="torch_dist", async_save=True, fully_parallel_save=True, - finetune=True, ), rng=RNGConfig(seed=SEED), mixed_precision="bf16_mixed", From 50b6b7eab328cdcc757d6e08b9087d7dd7c38b13 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:31:41 -0800 Subject: [PATCH 5/6] Update readme Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/megatron_bridge/README.md | 100 +++++++++++++++++++-- examples/megatron_bridge/distill.py | 70 +++------------ examples/megatron_bridge/prune_minitron.py | 25 ++++-- 3 files changed, 122 insertions(+), 73 deletions(-) diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 860df2cd17..fdad51e3ad 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -4,13 +4,13 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br
-| **Section** | **Description** | **Link** | **Docs** | -| :------------: | :------------: | :------------: | :------------: | -| Pre-Requisites | Development environment setup | \[[Link](#pre-requisites)\] | | -| Pruning | Examples of pruning a model using Minitron algorithm | \[[Link](#pruning)\] | | -| Distillation | Examples of distillation a pruned or quantized model | \[[Link](#distillation)\] | | -| Quantization | Examples of quantizing a model | \[[Link](#quantization)\] | | -| Resources | Extra links to relevant resources | \[[Link](#resources)\] | | +| **Section** | **Description** | **Link** | +| :------------: | :------------: | :------------: | +| Pre-Requisites | Development environment setup | \[[Link](#pre-requisites)\] | +| Pruning | Examples of pruning a model using Minitron algorithm | \[[Link](#pruning)\] | +| Distillation | Examples of distillation a pruned or quantized model | \[[Link](#distillation)\] | +| Quantization | Examples of quantizing a model | \[[Link](#quantization)\] | +| Resources | Extra links to relevant resources | \[[Link](#resources)\] |
@@ -57,6 +57,7 @@ Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while ```bash torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ --hf_model_name_or_path Qwen/Qwen3-8B \ --prune_target_params 6e9 \ --hparams_to_skip num_attention_heads \ @@ -68,6 +69,7 @@ Example usage for manually pruning to a specific architecture using following de ```bash torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ --hf_model_name_or_path Qwen/Qwen3-8B \ --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \ --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual @@ -86,7 +88,89 @@ torchrun --nproc_per_node 1 prune_minitron.py --help ## Distillation -TODO - Add info! +This section shows how to distill a student model from a teacher model in the Megatron-Bridge framework. + +This can be used stand-alone or after pruning (see [Pruning](#pruning)) / quantization (see [Quantization](#quantization)) to recover accuracy of the model by distilling from the original model (teacher). + +The [distill.py](distill.py) script loads student and teacher models from HuggingFace checkpoints and saves the distilled model to `/checkpoints` in Megatron distributed checkpoint format. + +### Data Preparation + +The distillation script expects pre-tokenized data in Megatron's binary format (`.bin` / `.idx` files). +You can tokenize your JSONL dataset using the following function: + +```python +from modelopt.torch.utils.plugins import megatron_preprocess_data + +megatron_preprocess_data( + input_path="/path/to/your/data.jsonl", + output_dir="/path/to/tokenized/data", + tokenizer_name_or_path="Qwen/Qwen3-0.6B", + json_keys=["text"], # change to your JSON key if needed + workers=32, + log_interval=100000, + max_sequence_length=256000, # To avoid rare OOM errors if text is too long +) +``` + +If you have multiple JSONL files, you can tokenize them one by one and pass all the paths to the `--data_paths` argument. 
+ +### Distillation with Real Data + +Example usage to distill a 4B student (HF) from an 8B teacher (HF) on 8 GPUs (TP=8, PP=1): + +```bash +torchrun --nnodes 1 --nproc_per_node 8 distill.py \ + --tp_size 8 \ + --teacher_hf_path Qwen/Qwen3-8B \ + --student_hf_path Qwen/Qwen3-4B \ + --data_paths 1.0 /path/to/tokenized/data \ + --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \ + --seq_length 8192 \ + --mbs 1 \ + --gbs 768 \ + --train_iters 15000 \ + --lr 1e-4 \ + --min_lr 1e-5 \ + --lr_warmup_iters 50 \ + --eval_interval 100 \ + --eval_iters 32 \ + --log_interval 10 \ + --output_dir /output/qwen3_8b_to_4b_distill +``` + +Tensorboard logging is enabled by default and logs are saved to the `<output_dir>/tb_logs` directory. +To use Weights & Biases for logging, set the `WANDB_API_KEY` environment variable and pass the `--wandb_project` argument. +Optionally, you can also pass `--wandb_entity` and `--wandb_exp_name` arguments to group runs under a project and experiment name. + +To see all available arguments: + +```bash +torchrun --nproc_per_node 1 distill.py --help +``` + +### Quick Test with Mock Data + +Example usage with mock data for quick testing (no pre-tokenized data needed): + +```bash +torchrun --nproc_per_node 8 distill.py \ + --tp_size 8 \ + --teacher_hf_path Qwen/Qwen3-0.6B \ + --student_hf_path Qwen/Qwen3-0.6B \ + --use_mock_data \ + --seq_length 512 \ + --mbs 1 \ + --gbs 8 \ + --train_iters 100 \ + --eval_interval 10 \ + --eval_iters 4 \ + --output_dir /tmp/test_distill +``` + +### Slurm Usage + +To run the distillation script on a Slurm cluster for multi-node training, you just need to use `python` instead of `torchrun` and set the number of nodes using the `#SBATCH --nodes=<N>` clause in your Slurm script. 
## Quantization diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index 58ec0594d2..c21bf73121 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -15,62 +15,9 @@ """Distillation script for Megatron-Bridge. Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model -to /checkpoints in megatron distributed checkpoint format. +to `/checkpoints` in megatron distributed checkpoint format. -Example usage to distill a 4B student from an 8B teacher on 8 GPUs: - -.. code-block:: bash - - torchrun --nproc_per_node 8 distill.py \ - --teacher_hf_path Qwen/Qwen3-8B \ - --student_hf_path Qwen/Qwen3-4B \ - --tp_size 8 \ - --data_paths 1.0 /path/to/tokenized/data \ - --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \ - --seq_length 8192 \ - --mbs 1 \ - --gbs 768 \ - --train_iters 15000 \ - --lr 1e-4 \ - --min_lr 1e-5 \ - --lr_warmup_iters 50 \ - --eval_interval 100 \ - --eval_iters 32 \ - --log_interval 10 \ - --output_dir /output/qwen3_8b_to_4b_distill - -Example usage to use mock data for quick testing: - -.. code-block:: bash - - torchrun --nproc_per_node 8 distill.py \ - --teacher_hf_path Qwen/Qwen3-0.6B \ - --student_hf_path Qwen/Qwen3-0.6B \ - --tp_size 8 \ - --use_mock_data \ - --seq_length 512 \ - --mbs 1 \ - --gbs 8 \ - --train_iters 100 \ - --eval_interval 10 \ - --eval_iters 4 \ - --output_dir /tmp/test_distill - -If you want to tokenize your own data for a specific tokenizer, you can use the following command: - -.. 
code-block:: python - - from modelopt.torch.utils.plugins import megatron_preprocess_data - - megatron_preprocess_data( - input_path="/path/to/your/data.jsonl", - output_dir="/path/to/tokenized/data", - tokenizer_name_or_path="Qwen/Qwen3-0.6B", - json_keys=["text"], - workers=32, - log_interval=100000, - max_sequence_length=256000, - ) +See `README.md` in this directory for example usage and data preparation instructions. """ import argparse @@ -106,7 +53,7 @@ def get_args(): """Parse command-line arguments.""" parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") - # Model arguments + # Model arguments (accepts HuggingFace input only at the moment) parser.add_argument( "--student_hf_path", type=str, @@ -142,7 +89,10 @@ def get_args(): "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving" ) parser.add_argument( - "--seq_length", type=int, default=8192, help="Number of tokens per input sample" + "--seq_length", + type=int, + default=4096, + help="Number of tokens per input sample. Use 8192 if your dataset has longer sequences.", ) parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size") parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size") @@ -187,16 +137,18 @@ def main(args: argparse.Namespace): def _build_model_provider(hf_path): bridge = AutoBridge.from_hf_pretrained(hf_path) provider = bridge.to_megatron_provider(load_weights=True) + + # Override parallelism / training settings provider.tensor_model_parallel_size = args.tp_size provider.pipeline_model_parallel_size = args.pp_size provider.context_parallel_size = 1 provider.sequence_parallel = args.tp_size > 1 provider.seq_length = args.seq_length provider.pipeline_dtype = torch.bfloat16 - provider.cross_entropy_fusion_impl = "te" return provider - # TODO: Support megatron-ckpt as an alternative to HF checkpoints + # TODO: Support megatron-ckpt as an alternative to HF checkpoints (e.g. 
/path/to/ckpt/iter_0000000) + # Still requires an HF model name or path to build provider correctly student_provider = _build_model_provider(args.student_hf_path) teacher_provider = _build_model_provider(args.teacher_hf_path) diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 52c792b9c8..44eac3a31a 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -28,8 +28,11 @@ To see the full usage for advanced configurations, run: torchrun --nproc_per_node 1 prune_minitron.py --help + +See `README.md` in this directory for more details. """ +# TODO: Test multi-node pruning import argparse import json import os @@ -66,9 +69,20 @@ def get_args() -> argparse.Namespace: "--output_hf_path", type=str, help="Path to save the pruned model in HF checkpoint format" ) - # Uneven Pipeline Parallelism parameters - parser.add_argument("--num_layers_in_first_pipeline_stage", type=int, default=None) - parser.add_argument("--num_layers_in_last_pipeline_stage", type=int, default=None) + # Parallelism arguments + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + parser.add_argument( + "--num_layers_in_first_pipeline_stage", + type=int, + default=None, + help="Number of layers in the first pipeline stage (Uneven Pipeline Parallelism)", + ) + parser.add_argument( + "--num_layers_in_last_pipeline_stage", + type=int, + default=None, + help="Number of layers in the last pipeline stage (Uneven Pipeline Parallelism)", + ) # Calibration dataset parameters parser.add_argument( @@ -201,8 +215,7 @@ def get_args() -> argparse.Namespace: def main(args: argparse.Namespace): - pp_size = dist.size() - print_rank_0(f"Setting pipeline_model_parallel_size to {pp_size}") + assert dist.size() == args.pp_size, "Only Pipeline parallelism is supported for pruning." 
if args.output_megatron_path and os.path.exists( f"{args.output_megatron_path}/latest_checkpointed_iteration.txt" @@ -218,7 +231,7 @@ def main(args: argparse.Namespace): trust_remote_code=args.trust_remote_code, provider_overrides={ "tensor_model_parallel_size": 1, - "pipeline_model_parallel_size": pp_size, + "pipeline_model_parallel_size": args.pp_size, "num_layers_in_first_pipeline_stage": args.num_layers_in_first_pipeline_stage, "num_layers_in_last_pipeline_stage": args.num_layers_in_last_pipeline_stage, "pipeline_dtype": torch.bfloat16, From 59bc44cb1c9444274b685d9b5c8d147871df86ef Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:29:40 -0800 Subject: [PATCH 6/6] minor doc update Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/megatron_bridge/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index fdad51e3ad..385cbcb659 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -172,6 +172,19 @@ torchrun --nproc_per_node 8 distill.py \ To run the distillation script on a Slurm cluster for multi-node training, you just need use `python` instead of `torchrun` and set the number of nodes using `#SBATCH --nodes=` clause in your Slurm script. 
+### Convert Megatron checkpoint to Hugging Face format
+
+To convert the Megatron checkpoint from the last iteration (or any intermediate iteration) to Hugging Face format, you need the pruned model config (`--output_hf_path` from the `prune_minitron.py` script) and the distilled megatron checkpoint dir (`<output_dir>/checkpoints/iter_<iteration>`) to run the following command:
+
+```bash
+uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \
+    --hf-model <pruned_model_config> \
+    --megatron-path <output_dir>/checkpoints/iter_<iteration> \
+    --hf-path <converted_hf_model_path>
+```
+
+For more details, you can refer to the checkpoint conversion scripts in the [Megatron-Bridge README](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/conversion).
+
 ## Quantization
 
 TODO