
Commit d0f930f

Support Megatron ckpt inputs for distill.py + doc updates
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent c18315b commit d0f930f

3 files changed: 183 additions & 79 deletions

File tree

examples/megatron_bridge/README.md

Lines changed: 82 additions & 1 deletion
@@ -86,7 +86,88 @@ torchrun --nproc_per_node 1 prune_minitron.py --help

## Distillation

This section shows how to distill a student model from a teacher model in the Megatron-Bridge framework.
The student and teacher can each be loaded from a HuggingFace model (hub name or local directory) or a Megatron checkpoint (an `iter_*` directory); the format is auto-detected from the path provided.

Distillation can be used stand-alone, or after pruning (see [Pruning](#pruning)) or quantization (see [Quantization](#quantization)) to recover accuracy by distilling from the original model as the teacher.

The [distill.py](distill.py) script loads the student and teacher models from HuggingFace or Megatron checkpoints and saves the distilled model to `<output_dir>/checkpoints` in Megatron distributed checkpoint format.

Supported input model path formats for the student and teacher (auto-detected by the script):

- HuggingFace: a model hub name (e.g. `Qwen/Qwen3-8B`) or a local HF directory
- Megatron: an `iter_*` directory inside a Megatron checkpoint (e.g. `/path/to/ckpt/iter_0000000`)
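
The auto-detection can be illustrated with a small sketch that mirrors the `iter_` prefix check used by `distill.py` (the paths here are hypothetical):

```python
import os

def is_megatron_checkpoint(path: str) -> bool:
    """Return True if `path` looks like a Megatron `iter_*` checkpoint directory."""
    # distill.py detects Megatron checkpoints by the `iter_` prefix on the
    # final path component; anything else is treated as a HuggingFace path.
    return os.path.basename(os.path.normpath(path)).startswith("iter_")

print(is_megatron_checkpoint("/path/to/ckpt/iter_0000000"))  # True
print(is_megatron_checkpoint("Qwen/Qwen3-8B"))               # False
```

Note that only the final directory name matters, so a trailing slash is harmless.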

### Data Preparation

The distillation script expects pre-tokenized data in Megatron's binary format (`.bin` / `.idx` files).
You can tokenize a JSONL dataset with the following function. If you have multiple JSONL files, tokenize them one by one and pass all of the resulting paths to the `--data_paths` argument.

```python
from modelopt.torch.utils.plugins import megatron_preprocess_data

megatron_preprocess_data(
    input_path="/path/to/your/data.jsonl",
    output_dir="/path/to/tokenized/data",
    tokenizer_name_or_path="Qwen/Qwen3-0.6B",
    json_keys=["text"],  # change to your JSON key if needed
    workers=32,
    log_interval=100000,
    max_sequence_length=256000,  # avoids rare OOM errors if a text is very long
)
```
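
The `--data_paths` argument appears to follow Megatron's blend convention of interleaved weight/path pairs (e.g. `1.0 /path/to/tokenized/data` in the example below). A small sketch of building such a blend from several datasets; the dataset paths and weights here are hypothetical:

```python
# Build the interleaved weight/path list that --data_paths expects,
# e.g. "1.0 /path/to/tokenized/data". Paths and weights are made up
# for illustration.
datasets = [(0.7, "/tok/web_text_document"), (0.3, "/tok/code_text_document")]

blend = []
for weight, path in datasets:
    blend += [str(weight), path]

print(" ".join(blend))  # → "0.7 /tok/web_text_document 0.3 /tok/code_text_document"
```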

### Distillation with Real Data

Example usage to distill a 4B student (HF) from an 8B teacher (HF) on 8 GPUs:

```bash
torchrun --nproc_per_node 8 distill.py \
    --teacher_path Qwen/Qwen3-8B \
    --student_path Qwen/Qwen3-4B \
    --tp_size 8 \
    --data_paths 1.0 /path/to/tokenized/data \
    --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \
    --seq_length 8192 \
    --mbs 1 \
    --gbs 768 \
    --train_iters 15000 \
    --lr 1e-4 \
    --min_lr 1e-5 \
    --lr_warmup_iters 50 \
    --eval_interval 100 \
    --eval_iters 32 \
    --log_interval 10 \
    --output_dir /output/qwen3_8b_to_4b_distill
```

TensorBoard logging is enabled by default, and logs are saved to the `<output_dir>/tb_logs` directory.
To log to Weights & Biases as well, set the `WANDB_API_KEY` environment variable and pass the `--wandb_project` argument.
Optionally, you can also pass the `--wandb_entity` and `--wandb_exp_name` arguments to group runs under an entity and experiment name.

To see all available arguments:

```bash
torchrun --nproc_per_node 1 distill.py --help
```

### Quick Test with Mock Data

Example usage with mock data for quick testing (no pre-tokenized data needed):

```bash
torchrun --nproc_per_node 8 distill.py \
    --teacher_path Qwen/Qwen3-0.6B \
    --student_path Qwen/Qwen3-0.6B \
    --tp_size 8 \
    --use_mock_data \
    --seq_length 512 \
    --mbs 1 \
    --gbs 8 \
    --train_iters 100 \
    --eval_interval 10 \
    --eval_iters 4 \
    --output_dir /tmp/test_distill
```

## Quantization
examples/megatron_bridge/distill.py

Lines changed: 99 additions & 78 deletions
@@ -14,66 +14,18 @@
 # limitations under the License.
 """Distillation script for Megatron-Bridge.
 
-Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model
-to <output_dir>/checkpoints in megatron distributed checkpoint format.
-
-Example usage to distill a 4B student from an 8B teacher on 8 GPUs:
-
-.. code-block:: bash
-
-    torchrun --nproc_per_node 8 distill.py \
-        --teacher_hf_path Qwen/Qwen3-8B \
-        --student_hf_path Qwen/Qwen3-4B \
-        --tp_size 8 \
-        --data_paths 1.0 /path/to/tokenized/data \
-        --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \
-        --seq_length 8192 \
-        --mbs 1 \
-        --gbs 768 \
-        --train_iters 15000 \
-        --lr 1e-4 \
-        --min_lr 1e-5 \
-        --lr_warmup_iters 50 \
-        --eval_interval 100 \
-        --eval_iters 32 \
-        --log_interval 10 \
-        --output_dir /output/qwen3_8b_to_4b_distill
-
-Example usage to use mock data for quick testing:
-
-.. code-block:: bash
-
-    torchrun --nproc_per_node 8 distill.py \
-        --teacher_hf_path Qwen/Qwen3-0.6B \
-        --student_hf_path Qwen/Qwen3-0.6B \
-        --tp_size 8 \
-        --use_mock_data \
-        --seq_length 512 \
-        --mbs 1 \
-        --gbs 8 \
-        --train_iters 100 \
-        --eval_interval 10 \
-        --eval_iters 4 \
-        --output_dir /tmp/test_distill
-
-If you want to tokenize your own data for a specific tokenizer, you can use the following command:
-
-.. code-block:: python
-
-    from modelopt.torch.utils.plugins import megatron_preprocess_data
-
-    megatron_preprocess_data(
-        input_path="/path/to/your/data.jsonl",
-        output_dir="/path/to/tokenized/data",
-        tokenizer_name_or_path="Qwen/Qwen3-0.6B",
-        json_keys=["text"],
-        workers=32,
-        log_interval=100000,
-        max_sequence_length=256000,
-    )
+Loads student and teacher models from HuggingFace or Megatron checkpoints and saves the distilled model
+to `<output_dir>/checkpoints` in Megatron distributed checkpoint format.
+
+Supported input model path formats for student and teacher (auto-detected):
+- HuggingFace: model hub name (e.g. `Qwen/Qwen3-8B`) or local HF dir
+- Megatron: `iter_*` directory inside a Megatron checkpoint (e.g. `/path/to/ckpt/iter_0000000`)
+
+See `README.md` in this directory for example usage and data preparation instructions.
 """
 
 import argparse
+import contextlib
 import os
 
 import torch
@@ -82,6 +34,7 @@
 from megatron.bridge.recipes.utils.optimizer_utils import (
     distributed_fused_adam_with_cosine_annealing,
 )
+from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint
 from megatron.bridge.training.config import (
     CheckpointConfig,
     ConfigContainer,
@@ -93,6 +46,7 @@
     TrainingConfig,
 )
 from megatron.bridge.training.distill import distill
+from megatron.bridge.training.model_load_save import load_model_config
 from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig
 from megatron.core.datasets.utils import get_blend_from_list
 from megatron.core.distributed import DistributedDataParallelConfig
@@ -106,18 +60,18 @@
 def get_args():
     """Parse command-line arguments."""
     parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.")
-    # Model arguments
+    # Model arguments (accepts HuggingFace paths or Megatron checkpoint directories)
     parser.add_argument(
-        "--student_hf_path",
+        "--student_path",
         type=str,
         required=True,
-        help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)",
+        help="Student model: HuggingFace name/path (e.g. Qwen/Qwen3-0.6B) or Megatron dir (e.g. /path/iter_0000000)",
     )
     parser.add_argument(
-        "--teacher_hf_path",
+        "--teacher_path",
         type=str,
         required=True,
-        help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)",
+        help="Teacher model: HuggingFace name/path (e.g. Qwen/Qwen3-8B) or Megatron dir (e.g. /path/iter_0000000)",
     )
     # Parallelism arguments
     parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
@@ -179,26 +133,93 @@ def get_args():
     return args
 
 
+def _build_provider_from_hf(hf_path):
+    """Build a model provider from a HuggingFace model path."""
+    bridge = AutoBridge.from_hf_pretrained(hf_path)
+    provider = bridge.to_megatron_provider(load_weights=True)
+    return provider
+
+
+def _build_provider_from_megatron(iter_path):
+    """Build a model provider from a Megatron checkpoint iter directory.
+
+    Uses load_model_config() to reconstruct the provider (architecture) from the
+    checkpoint's saved run config, then registers a pre_wrap_hook that loads the
+    checkpoint weights after model creation.
+    """
+    print_rank_0(f"Loading model config from Megatron checkpoint: {iter_path}")
+    provider, _ = load_model_config(iter_path)
+    provider.perform_initialization = False  # Will load weights from ckpt
+
+    def _load_megatron_weights_hook(model):
+        """Pre-wrap hook: load weights from a Megatron checkpoint into the model.
+
+        For distillation kd_models, hide teacher / loss modules so only the
+        student weights are loaded (same pattern as HF weight loading).
+        """
+        cms = []
+        for m in model:
+            if hasattr(m, "hide_teacher_model"):
+                cms.append(m.hide_teacher_model())
+            if hasattr(m, "hide_loss_modules"):
+                cms.append(m.hide_loss_modules())
+        with contextlib.ExitStack() as stack:
+            for cm in cms:
+                stack.enter_context(cm)
+            _load_model_weights_from_checkpoint(iter_path, model)
+        return model
+
+    provider.register_pre_wrap_hook(_load_megatron_weights_hook)
+    return provider
+
+
+def _build_model_provider(path, *, tp_size, pp_size, seq_length):
+    """Build a model provider, auto-detecting HuggingFace vs Megatron checkpoint.
+
+    Megatron checkpoints are detected by the `iter_` prefix in the directory name
+    (e.g. `/path/to/checkpoints/iter_0000000`).
+    """
+    if os.path.basename(os.path.normpath(path)).startswith("iter_"):
+        print_rank_0(f"Detected Megatron checkpoint: {path}")
+        provider = _build_provider_from_megatron(path)
+    else:
+        print_rank_0(f"Loading from HuggingFace: {path}")
+        provider = _build_provider_from_hf(path)
+
+    # Override parallelism / training settings
+    provider.tensor_model_parallel_size = tp_size
+    provider.pipeline_model_parallel_size = pp_size
+    provider.context_parallel_size = 1
+    provider.sequence_parallel = tp_size > 1
+    provider.seq_length = seq_length
+    provider.pipeline_dtype = torch.bfloat16
+    provider.cross_entropy_fusion_impl = "te"
+
+    # Run deferred __post_init__ to compute derived fields (e.g. init_method from init_method_std).
+    # Must be called after all overrides are set.
+    # Safe to call on HF-sourced providers too (idempotent).
+    provider.finalize()
+
+    return provider
+
+
 def main(args: argparse.Namespace):
     checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
     tensorboard_dir = os.path.join(args.output_dir, "tb_logs")
 
-    # Build student and teacher model providers
-    def _build_model_provider(hf_path):
-        bridge = AutoBridge.from_hf_pretrained(hf_path)
-        provider = bridge.to_megatron_provider(load_weights=True)
-        provider.tensor_model_parallel_size = args.tp_size
-        provider.pipeline_model_parallel_size = args.pp_size
-        provider.context_parallel_size = 1
-        provider.sequence_parallel = args.tp_size > 1
-        provider.seq_length = args.seq_length
-        provider.pipeline_dtype = torch.bfloat16
-        provider.cross_entropy_fusion_impl = "te"
-        return provider
-
-    # TODO: Support megatron-ckpt as an alternative to HF checkpoints
-    student_provider = _build_model_provider(args.student_hf_path)
-    teacher_provider = _build_model_provider(args.teacher_hf_path)
+    # Build student and teacher model providers (auto-detects HF vs Megatron ckpt)
+    student_provider = _build_model_provider(
+        args.student_path,
+        tp_size=args.tp_size,
+        pp_size=args.pp_size,
+        seq_length=args.seq_length,
+    )
+    teacher_provider = _build_model_provider(
+        args.teacher_path,
+        tp_size=args.tp_size,
+        pp_size=args.pp_size,
+        seq_length=args.seq_length,
+    )
 
     # Wrap into DistillationProvider
     kd_config = ModelOptDistillConfig()
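
In `_load_megatron_weights_hook` above, a variable number of context managers is entered at once via `contextlib.ExitStack` before the weights are loaded. A minimal standalone sketch of that pattern; the `hide` helper here is hypothetical, standing in for `hide_teacher_model()` / `hide_loss_modules()`:

```python
import contextlib

@contextlib.contextmanager
def hide(name, hidden):
    # Stand-in for hide_teacher_model() / hide_loss_modules(): record that
    # `name` is hidden for the duration of the `with` block, then restore.
    hidden.add(name)
    try:
        yield
    finally:
        hidden.discard(name)

hidden = set()
cms = [hide("teacher", hidden), hide("loss_modules", hidden)]

# Enter all collected context managers at once; every one is exited
# (in reverse order) when the ExitStack block ends, even on error.
with contextlib.ExitStack() as stack:
    for cm in cms:
        stack.enter_context(cm)
    assert hidden == {"teacher", "loss_modules"}  # both hidden while loading

assert hidden == set()  # everything restored on exit
```

This is why the hook can support models with any subset of `hide_teacher_model` / `hide_loss_modules`: the list of context managers is built dynamically and `ExitStack` handles however many there are.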

examples/megatron_bridge/prune_minitron.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@
 
 To see the full usage for advanced configurations, run:
     torchrun --nproc_per_node 1 prune_minitron.py --help
+
+See `README.md` in this directory for more details.
 """
 
 import argparse

0 commit comments
