Skip to content

Commit 6abc8ab

Browse files
j-rauschroot
authored andcommitted
consolidate mbridge distillation scripts; merge distill_hf.py into distill.py
Signed-off-by: jrausch <jrausch@nvidia.com> Signed-off-by: root <root@pool0-00848.cm.cluster>
1 parent 25266b8 commit 6abc8ab

8 files changed

Lines changed: 224 additions & 617 deletions

File tree

examples/megatron_bridge/README.md

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ This section shows how to distill a student model from a teacher model in the Me
9292

9393
This can be used stand-alone or after [Pruning](#pruning) / [Post-Training Quantization](#post-training-quantization) to recover accuracy of the model by distilling from the original model (teacher).
9494

95-
The [distill.py](distill.py) script loads student and teacher models from HuggingFace checkpoints and saves the distilled model to `<output_dir>/checkpoints` in Megatron distributed checkpoint format.
95+
The [distill.py](distill.py) script supports both standard HuggingFace checkpoints and [Puzzletron AnyModel](../puzzletron/README.md) checkpoints as student/teacher inputs. Just pass the checkpoint path via `--student_hf_path` / `--teacher_hf_path`. The distilled model is saved to `<output_dir>/checkpoints` in Megatron distributed checkpoint format.
9696

9797
### Data Preparation
9898

@@ -194,9 +194,22 @@ torchrun --nproc_per_node 8 distill.py \
194194

195195
To run the distillation script on a Slurm cluster for multi-node training, you just need use `python` instead of `torchrun` and set the number of nodes using `#SBATCH --nodes=<num_nodes>` clause in your Slurm script.
196196

197-
### Convert Megatron checkpoint to Hugging Face format
197+
### Converting to Hugging Face format (optional)
198198

199-
To convert the Megatron checkpoint from last iteration (or any intermediate iteration) to Hugging Face format, you need the pruned model config (`--output_hf_path` from `prune_minitron.py` script) and the distilled megatron checkpoint dir (`<distill_output_dir>/checkpoints/iter_<iter_number>`) to run the following command:
199+
The distilled checkpoint is saved in Megatron distributed format. If you need a HuggingFace checkpoint, there are two ways to convert it:
200+
201+
**Inline** -- add `--hf_export_path` and `--student_hf_model` to the `distill.py` command to automatically convert the final checkpoint after distillation:
202+
203+
```bash
204+
torchrun --nnodes 1 --nproc_per_node 8 distill.py \
205+
... \
206+
--hf_export_path /path/to/save/distilled_hf_ckpt \
207+
--student_hf_model Qwen/Qwen3-4B
208+
```
209+
210+
`--student_hf_model` should match the base architecture of the student (used as a template for export).
211+
212+
**Separate conversion** -- convert any saved iteration using the Megatron-Bridge conversion script:
200213

201214
```bash
202215
uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \
@@ -205,7 +218,52 @@ uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py ex
205218
--hf-path <path_to_save_distilled_hf_ckpt>
206219
```
207220

208-
For more details, you can refer to the checkpoint conversion scripts in the [Megatron-Bridge README](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/conversion).
221+
For more details, see the [Megatron-Bridge conversion README](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/conversion).
222+
223+
> **Known limitation:** HF export does not yet work for Puzzletron AnyModel (heterogeneous) checkpoints -- Megatron-Bridge cannot reload heterogeneous configs from saved checkpoints. Standard models export correctly with both methods.
224+
225+
### Distillation Results
226+
227+
The following MMLU results demonstrate knowledge distillation on student models that were first compressed using [Puzzletron](../puzzletron/README.md). The original (uncompressed) model serves as the teacher, and distillation recovers accuracy lost during compression.
228+
229+
#### Qwen3-8B compressed to 80% of original
230+
231+
The student was created by compressing Qwen3-8B to 80% of its original size using Puzzletron.
232+
233+
| Model | MMLU | Humanities | Other | Social Sci | STEM |
234+
|-------|------|------------|-------|------------|------|
235+
| Student (before distillation) | 0.5910 | 0.5046 | 0.6363 | 0.6831 | 0.5855 |
236+
| Student (after distillation) | 0.6921 | 0.5906 | 0.7316 | 0.7975 | 0.7016 |
237+
| Teacher (original Qwen3-8B) | 0.7493 | 0.6648 | 0.7856 | 0.8385 | 0.7526 |
238+
239+
MMLU accuracy improved from 59.10% to 69.21% (+10.11 pp) after distillation with just 100 iterations on WikiText-103, recovering 64% of the gap to the teacher model.
240+
241+
#### Llama-3.1-8B-Instruct compressed to 50% of original
242+
243+
The student was created by compressing Llama-3.1-8B-Instruct to 50% of its original size using Puzzletron.
244+
245+
| Model | MMLU | Humanities | Other | Social Sciences | STEM |
246+
|-------|------|------------|-------|-----------------|------|
247+
| Student (before distillation) | 0.2316 | 0.2462 | 0.2292 | 0.2250 | 0.2274 |
248+
| Student (after distillation) | 0.2960 | 0.3146 | 0.3085 | 0.2925 | 0.2768 |
249+
| Teacher (original Llama-3.1-8B-Instruct) | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 |
250+
251+
#### Llama-3.1-8B-Instruct compressed to 69% of original (regression)
252+
253+
The student was created by compressing Llama-3.1-8B-Instruct to ~69% of its original size using Puzzletron. This example shows regression due to overfitting on the small WikiText-103 dataset (100 iterations). MMLU was evaluated on a subset of 100 samples per task:
254+
255+
| Model | MMLU | Humanities | Other | Social Sciences | STEM |
256+
|-------|------|------------|-------|-----------------|------|
257+
| Student (before distillation) | 0.6626 | 0.7069 | 0.6892 | 0.7525 | 0.5574 |
258+
| Student (after distillation) | 0.6496 | 0.6862 | 0.6677 | 0.7433 | 0.5532 |
259+
| Teacher (original Llama-3.1-8B-Instruct) | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 |
260+
261+
MMLU decreased from 66.26% to 64.96% (-1.30 pp) -- the model overfitted to WikiText-103. This highlights the importance of using larger, more diverse datasets for distillation.
262+
263+
#### Recommendations
264+
265+
- **Use larger datasets** for production distillation (e.g., [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1)) to avoid overfitting as shown in the regression case above.
266+
- **Train for more iterations** to ensure proper convergence.
209267

210268
## Post-Training Quantization
211269

examples/megatron_bridge/distill.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
"""
2222

2323
import argparse
24+
import contextlib
2425
import os
26+
import shutil
2527

2628
import torch
2729
from megatron.bridge import AutoBridge
@@ -44,6 +46,9 @@
4446
from megatron.core.datasets.utils import get_blend_from_list
4547
from megatron.core.distributed import DistributedDataParallelConfig
4648

49+
with contextlib.suppress(ImportError):
50+
import modelopt.torch.puzzletron.export.mbridge # noqa: F401
51+
4752
import modelopt.torch.utils.distributed as dist
4853
from modelopt.torch.utils import print_rank_0
4954

@@ -124,12 +129,33 @@ def get_args():
124129
)
125130
parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)")
126131
parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)")
132+
# Export arguments
133+
parser.add_argument(
134+
"--hf_export_path",
135+
type=str,
136+
default=None,
137+
help=(
138+
"Path where to save the HuggingFace export. "
139+
"If provided, exports last iteration checkpoint to HF format after distillation."
140+
),
141+
)
142+
parser.add_argument(
143+
"--student_hf_model",
144+
type=str,
145+
required=False,
146+
default=None,
147+
help="HuggingFace model ID to use as template for export (e.g., Qwen/Qwen3-0.6B). "
148+
"Should match the base architecture of the student model if --hf_export_path is provided.",
149+
)
127150
args = parser.parse_args()
128151

129152
# Sanity checks
130153
if not args.use_mock_data and not args.data_paths:
131154
raise ValueError("Must provide either --data_paths or set --use_mock_data.")
132155

156+
if args.hf_export_path and not args.student_hf_model:
157+
raise ValueError("Must provide --student_hf_model if --hf_export_path is provided.")
158+
133159
print_rank_0("\n==================== Arguments ====================")
134160
for k, v in args.__dict__.items():
135161
print_rank_0(f"{k:<35} {v}")
@@ -252,9 +278,31 @@ def _build_model_provider(hf_path):
252278
print_rank_0("\nStarting distillation...")
253279
distill(config)
254280
print_rank_0(
255-
f"\nDistillation done! Saved checkpoint to {checkpoint_dir} in megatron distributed checkpoint format.\n"
281+
f"\nDistillation done! Saved checkpoint to {checkpoint_dir}"
282+
" in megatron distributed checkpoint format.\n"
256283
)
257284

285+
if args.hf_export_path:
286+
print_rank_0(f"Exporting final distilled ckpt to HF format to {args.hf_export_path}")
287+
# Save rank before destroying process group (dist.rank() won't work after destruction)
288+
is_rank_0 = dist.rank() == 0
289+
290+
# Destroy process group on all ranks -- export_ckpt will create its own temporary one.
291+
# This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone).
292+
dist.cleanup()
293+
294+
if is_rank_0:
295+
export_bridge = AutoBridge.from_hf_pretrained(
296+
args.student_hf_model, trust_remote_code=args.trust_remote_code
297+
)
298+
export_bridge.export_ckpt(
299+
megatron_path=f"{checkpoint_dir}/iter_{args.train_iters:07d}",
300+
hf_path=args.hf_export_path,
301+
show_progress=True,
302+
strict=True,
303+
)
304+
shutil.copy(f"{args.student_hf_path}/config.json", f"{args.hf_export_path}/config.json")
305+
258306

259307
if __name__ == "__main__":
260308
dist.setup()

examples/puzzletron/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --
277277
278278
To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one.
279279
280-
See [mbridge_distillation/README.md](./mbridge_distillation/README.md) for instructions on using Megatron-Bridge for knowledge distillation.
280+
See [Megatron-Bridge distillation](../megatron_bridge/README.md#distillation) for instructions on using Megatron-Bridge for knowledge distillation. The distillation script supports both standard HuggingFace and Puzzletron AnyModel checkpoints.
281281
282282
## Advanced Usage
283283

examples/puzzletron/mbridge_distillation/README.md

Lines changed: 0 additions & 152 deletions
This file was deleted.

0 commit comments

Comments
 (0)