
Commit 50b6b7e

Update readme
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent c18315b commit 50b6b7e

3 files changed: 122 additions & 73 deletions

File tree

examples/megatron_bridge/README.md

Lines changed: 92 additions & 8 deletions
@@ -4,13 +4,13 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br

 <div align="center">

-| **Section** | **Description** | **Link** | **Docs** |
-| :------------: | :------------: | :------------: | :------------: |
-| Pre-Requisites | Development environment setup | \[[Link](#pre-requisites)\] | |
-| Pruning | Examples of pruning a model using Minitron algorithm | \[[Link](#pruning)\] | |
-| Distillation | Examples of distillation a pruned or quantized model | \[[Link](#distillation)\] | |
-| Quantization | Examples of quantizing a model | \[[Link](#quantization)\] | |
-| Resources | Extra links to relevant resources | \[[Link](#resources)\] | |
+| **Section** | **Description** | **Link** |
+| :------------: | :------------: | :------------: |
+| Pre-Requisites | Development environment setup | \[[Link](#pre-requisites)\] |
+| Pruning | Examples of pruning a model using the Minitron algorithm | \[[Link](#pruning)\] |
+| Distillation | Examples of distilling a pruned or quantized model | \[[Link](#distillation)\] |
+| Quantization | Examples of quantizing a model | \[[Link](#quantization)\] |
+| Resources | Extra links to relevant resources | \[[Link](#resources)\] |

 </div>

@@ -57,6 +57,7 @@ Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while

 ```bash
 torchrun --nproc_per_node 2 prune_minitron.py \
+    --pp_size 2 \
     --hf_model_name_or_path Qwen/Qwen3-8B \
     --prune_target_params 6e9 \
     --hparams_to_skip num_attention_heads \
@@ -68,6 +69,7 @@ Example usage for manually pruning to a specific architecture using following de

 ```bash
 torchrun --nproc_per_node 2 prune_minitron.py \
+    --pp_size 2 \
     --hf_model_name_or_path Qwen/Qwen3-8B \
     --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \
     --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual
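The value of `--prune_export_config` in the command above is plain JSON. As a minimal sketch (an editor's illustration, not part of the commit), the flag value can be composed with `json.dumps` to avoid shell-quoting mistakes:

```python
import json

# Target hyperparameters from the manual-pruning example above
export_config = {"hidden_size": 3584, "ffn_hidden_size": 9216}

# json.dumps yields the exact JSON string the flag expects
flag_value = json.dumps(export_config)
print(f"--prune_export_config '{flag_value}'")
```

Round-tripping the string through `json.loads` is a quick sanity check before launching a long pruning run.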
@@ -86,7 +88,89 @@ torchrun --nproc_per_node 1 prune_minitron.py --help

 ## Distillation

-TODO - Add info!
+This section shows how to distill a student model from a teacher model in the Megatron-Bridge framework.
+
+It can be used stand-alone or after pruning (see [Pruning](#pruning)) or quantization (see [Quantization](#quantization)) to recover model accuracy by distilling from the original (teacher) model.
+
+The [distill.py](distill.py) script loads the student and teacher models from HuggingFace checkpoints and saves the distilled model to `<output_dir>/checkpoints` in Megatron distributed checkpoint format.
+
+### Data Preparation
+
+The distillation script expects pre-tokenized data in Megatron's binary format (`.bin` / `.idx` files).
+You can tokenize your JSONL dataset using the following function:
+
+```python
+from modelopt.torch.utils.plugins import megatron_preprocess_data
+
+megatron_preprocess_data(
+    input_path="/path/to/your/data.jsonl",
+    output_dir="/path/to/tokenized/data",
+    tokenizer_name_or_path="Qwen/Qwen3-0.6B",
+    json_keys=["text"],  # change to your JSON key if needed
+    workers=32,
+    log_interval=100000,
+    max_sequence_length=256000,  # avoids rare OOM errors if a text is too long
+)
+```
+
+If you have multiple JSONL files, you can tokenize them one by one and pass all the resulting paths to the `--data_paths` argument.
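The example commands in this commit pass `--data_paths 1.0 /path/to/tokenized/data`, i.e. a weight followed by a path prefix. As a hedged sketch (shard paths are hypothetical, and it assumes the flag follows Megatron's alternating weight/prefix convention for blending multiple datasets), the argument for several shards can be assembled like this:

```python
# Hypothetical tokenized-data prefixes, one per JSONL shard
prefixes = ["/data/tok/shard0_text_document", "/data/tok/shard1_text_document"]

# Flatten into alternating weight/prefix pairs with equal weights
weight = 1.0 / len(prefixes)
tokens = []
for prefix in prefixes:
    tokens += [str(weight), prefix]

print("--data_paths " + " ".join(tokens))
```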
+
+### Distillation with Real Data
+
+Example usage to distill a 4B student (HF) from an 8B teacher (HF) on 8 GPUs (TP=8, PP=1):
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 8 distill.py \
+    --tp_size 8 \
+    --teacher_hf_path Qwen/Qwen3-8B \
+    --student_hf_path Qwen/Qwen3-4B \
+    --data_paths 1.0 /path/to/tokenized/data \
+    --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \
+    --seq_length 8192 \
+    --mbs 1 \
+    --gbs 768 \
+    --train_iters 15000 \
+    --lr 1e-4 \
+    --min_lr 1e-5 \
+    --lr_warmup_iters 50 \
+    --eval_interval 100 \
+    --eval_iters 32 \
+    --log_interval 10 \
+    --output_dir /output/qwen3_8b_to_4b_distill
+```
+
+TensorBoard logging is enabled by default; logs are saved to the `<output_dir>/tensorboard` directory.
+To use Weights & Biases for logging, set the `WANDB_API_KEY` environment variable and pass the `--wandb_project` argument.
+Optionally, you can also pass the `--wandb_entity` and `--wandb_exp_name` arguments to group runs under a project and experiment name.
+
+To see all available arguments:
+
+```bash
+torchrun --nproc_per_node 1 distill.py --help
+```
+
+### Quick Test with Mock Data
+
+Example usage with mock data for quick testing (no pre-tokenized data needed):
+
+```bash
+torchrun --nproc_per_node 8 distill.py \
+    --tp_size 8 \
+    --teacher_hf_path Qwen/Qwen3-0.6B \
+    --student_hf_path Qwen/Qwen3-0.6B \
+    --use_mock_data \
+    --seq_length 512 \
+    --mbs 1 \
+    --gbs 8 \
+    --train_iters 100 \
+    --eval_interval 10 \
+    --eval_iters 4 \
+    --output_dir /tmp/test_distill
+```
+
+### Slurm Usage
+
+To run the distillation script on a Slurm cluster for multi-node training, use `python` instead of `torchrun` and set the number of nodes with an `#SBATCH --nodes=<num_nodes>` directive in your Slurm script.
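That note can be sketched as a batch script. The following is an untested sketch, not part of the commit: the job name, resource directives, and all paths are hypothetical, and it assumes `srun` launches the distributed ranks so plain `python` replaces `torchrun`:

```shell
#!/bin/bash
#SBATCH --nodes=2                 # multi-node training
#SBATCH --ntasks-per-node=8       # hypothetical: one task per GPU
#SBATCH --gpus-per-node=8
#SBATCH --job-name=qwen3-distill  # hypothetical job name

# Plain `python` instead of `torchrun`; srun spawns one rank per task.
srun python distill.py \
    --tp_size 8 \
    --teacher_hf_path Qwen/Qwen3-8B \
    --student_hf_path Qwen/Qwen3-4B \
    --data_paths 1.0 /path/to/tokenized/data \
    --seq_length 8192 \
    --mbs 1 \
    --gbs 768 \
    --train_iters 15000 \
    --output_dir /output/qwen3_8b_to_4b_distill
```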

 ## Quantization

examples/megatron_bridge/distill.py

Lines changed: 11 additions & 59 deletions
@@ -15,62 +15,9 @@
 """Distillation script for Megatron-Bridge.

 Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model
-to <output_dir>/checkpoints in megatron distributed checkpoint format.
+to `<output_dir>/checkpoints` in megatron distributed checkpoint format.

-Example usage to distill a 4B student from an 8B teacher on 8 GPUs:
-
-.. code-block:: bash
-
-    torchrun --nproc_per_node 8 distill.py \
-        --teacher_hf_path Qwen/Qwen3-8B \
-        --student_hf_path Qwen/Qwen3-4B \
-        --tp_size 8 \
-        --data_paths 1.0 /path/to/tokenized/data \
-        --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \
-        --seq_length 8192 \
-        --mbs 1 \
-        --gbs 768 \
-        --train_iters 15000 \
-        --lr 1e-4 \
-        --min_lr 1e-5 \
-        --lr_warmup_iters 50 \
-        --eval_interval 100 \
-        --eval_iters 32 \
-        --log_interval 10 \
-        --output_dir /output/qwen3_8b_to_4b_distill
-
-Example usage to use mock data for quick testing:
-
-.. code-block:: bash
-
-    torchrun --nproc_per_node 8 distill.py \
-        --teacher_hf_path Qwen/Qwen3-0.6B \
-        --student_hf_path Qwen/Qwen3-0.6B \
-        --tp_size 8 \
-        --use_mock_data \
-        --seq_length 512 \
-        --mbs 1 \
-        --gbs 8 \
-        --train_iters 100 \
-        --eval_interval 10 \
-        --eval_iters 4 \
-        --output_dir /tmp/test_distill
-
-If you want to tokenize your own data for a specific tokenizer, you can use the following command:
-
-.. code-block:: python
-
-    from modelopt.torch.utils.plugins import megatron_preprocess_data
-
-    megatron_preprocess_data(
-        input_path="/path/to/your/data.jsonl",
-        output_dir="/path/to/tokenized/data",
-        tokenizer_name_or_path="Qwen/Qwen3-0.6B",
-        json_keys=["text"],
-        workers=32,
-        log_interval=100000,
-        max_sequence_length=256000,
-    )
+See `README.md` in this directory for example usage and data preparation instructions.
 """

 import argparse
@@ -106,7 +53,7 @@
 def get_args():
     """Parse command-line arguments."""
     parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.")
-    # Model arguments
+    # Model arguments (accepts HuggingFace input only at the moment)
     parser.add_argument(
         "--student_hf_path",
         type=str,
@@ -142,7 +89,10 @@ def get_args():
         "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving"
     )
     parser.add_argument(
-        "--seq_length", type=int, default=8192, help="Number of tokens per input sample"
+        "--seq_length",
+        type=int,
+        default=4096,
+        help="Number of tokens per input sample. Use 8192 if your dataset has longer sequences.",
     )
     parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size")
     parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size")
@@ -187,16 +137,18 @@ def main(args: argparse.Namespace):
     def _build_model_provider(hf_path):
         bridge = AutoBridge.from_hf_pretrained(hf_path)
         provider = bridge.to_megatron_provider(load_weights=True)
+
+        # Override parallelism / training settings
         provider.tensor_model_parallel_size = args.tp_size
         provider.pipeline_model_parallel_size = args.pp_size
         provider.context_parallel_size = 1
         provider.sequence_parallel = args.tp_size > 1
         provider.seq_length = args.seq_length
         provider.pipeline_dtype = torch.bfloat16
-        provider.cross_entropy_fusion_impl = "te"
         return provider

-    # TODO: Support megatron-ckpt as an alternative to HF checkpoints
+    # TODO: Support megatron-ckpt as an alternative to HF checkpoints (e.g. /path/to/ckpt/iter_0000000)
+    # Still requires an HF model name or path to build provider correctly
     student_provider = _build_model_provider(args.student_hf_path)
     teacher_provider = _build_model_provider(args.teacher_hf_path)

examples/megatron_bridge/prune_minitron.py

Lines changed: 19 additions & 6 deletions
@@ -28,8 +28,11 @@

 To see the full usage for advanced configurations, run:
     torchrun --nproc_per_node 1 prune_minitron.py --help
+
+See `README.md` in this directory for more details.
 """

+# TODO: Test multi-node pruning
 import argparse
 import json
 import os
@@ -66,9 +69,20 @@ def get_args() -> argparse.Namespace:
         "--output_hf_path", type=str, help="Path to save the pruned model in HF checkpoint format"
     )

-    # Uneven Pipeline Parallelism parameters
-    parser.add_argument("--num_layers_in_first_pipeline_stage", type=int, default=None)
-    parser.add_argument("--num_layers_in_last_pipeline_stage", type=int, default=None)
+    # Parallelism arguments
+    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
+    parser.add_argument(
+        "--num_layers_in_first_pipeline_stage",
+        type=int,
+        default=None,
+        help="Number of layers in the first pipeline stage (Uneven Pipeline Parallelism)",
+    )
+    parser.add_argument(
+        "--num_layers_in_last_pipeline_stage",
+        type=int,
+        default=None,
+        help="Number of layers in the last pipeline stage (Uneven Pipeline Parallelism)",
+    )

     # Calibration dataset parameters
     parser.add_argument(
@@ -201,8 +215,7 @@ def get_args() -> argparse.Namespace:


 def main(args: argparse.Namespace):
-    pp_size = dist.size()
-    print_rank_0(f"Setting pipeline_model_parallel_size to {pp_size}")
+    assert dist.size() == args.pp_size, "Only Pipeline parallelism is supported for pruning."

     if args.output_megatron_path and os.path.exists(
         f"{args.output_megatron_path}/latest_checkpointed_iteration.txt"
@@ -218,7 +231,7 @@ def main(args: argparse.Namespace):
         trust_remote_code=args.trust_remote_code,
         provider_overrides={
             "tensor_model_parallel_size": 1,
-            "pipeline_model_parallel_size": pp_size,
+            "pipeline_model_parallel_size": args.pp_size,
             "num_layers_in_first_pipeline_stage": args.num_layers_in_first_pipeline_stage,
             "num_layers_in_last_pipeline_stage": args.num_layers_in_last_pipeline_stage,
             "pipeline_dtype": torch.bfloat16,
