Commit d88dfcb
consolidate mbridge distillation: merge distill_hf.py into distill.py (#1220)
## Summary

- Unified `examples/puzzletron/mbridge_distillation/distill_hf.py` (AnyModel-specific) into `examples/megatron_bridge/distill.py` (general)
  - The single script now handles both standard HF and Puzzletron AnyModel checkpoints.
  - Added `--hf_export_path` / `--student_hf_model` args for inline HF export after distillation.
- Merged the AnyModel integration test into `tests/examples/megatron_bridge/test_distill.py`
  - Test models use `vocab_size=128` (instead of the default 102) for TP divisibility, including TP=8.
- Moved MMLU distillation results into `megatron_bridge/README.md`
  - The puzzletron README now redirects to the consolidated docs.

Limitation discovered during consolidation: HF export via `--hf_export_path` currently does not seem to work for Puzzletron AnyModel (heterogeneous) checkpoints. Megatron-Bridge's `export_ckpt` cannot reload heterogeneous model configs from saved checkpoints (`heterogeneous_layers_config_encoded_json` is `None` during `__post_init__` in `heterogeneous_config.py`). This affects both the inline `--hf_export_path` and the separate `convert_checkpoints.py export` script. The original `distill_hf.py` README documented this as supported, but I think it might have been broken there too (on the side of Megatron-Bridge). The consolidated README now documents this as a known limitation. HF export for standard models works fine via both methods.

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added support for Puzzletron AnyModel checkpoints in the distillation pipeline.
  * Introduced inline HuggingFace export capability during the distillation process.
* **Documentation**
  * Updated the distillation guide with clearer conversion workflows and optional HuggingFace export instructions.
  * Added distillation benchmarks and performance recommendations.
* **Bug Fixes & Improvements**
  * Streamlined test infrastructure and workflow configuration.

---------

Signed-off-by: jrausch <jrausch@nvidia.com>
Signed-off-by: root <root@pool0-00848.cm.cluster>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
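One change in the summary above is easy to verify directly: the test models' vocabulary size was bumped from 102 to 128 so the embedding table splits evenly across every tensor-parallel size the tests exercise (up to TP=8). A minimal sketch of the divisibility check (the helper name is ours, not from the repo):

```python
# Megatron tensor parallelism shards the embedding table across TP ranks,
# so vocab_size must be divisible by every TP size the tests run with.
def divisible_tp_sizes(vocab_size, tp_sizes=(1, 2, 4, 8)):
    """Return the TP sizes for which the vocab shards evenly."""
    return [tp for tp in tp_sizes if vocab_size % tp == 0]

print(divisible_tp_sizes(102))  # [1, 2] -- the old default fails TP=4 and TP=8
print(divisible_tp_sizes(128))  # [1, 2, 4, 8] -- works for all tested TP sizes
```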
1 parent 6395b1e commit d88dfcb

File tree

10 files changed: +234 −612 lines changed

.github/workflows/example_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ jobs:
     strategy: &nemo_strategy
       fail-fast: false
       matrix:
-        example: [megatron_bridge, puzzletron]
+        example: [megatron_bridge]
     uses: ./.github/workflows/_example_tests_runner.yml
     secrets: inherit
     with:

examples/megatron_bridge/README.md

Lines changed: 21 additions & 4 deletions
@@ -92,7 +92,7 @@ This section shows how to distill a student model from a teacher model in the Me
 
 This can be used stand-alone or after [Pruning](#pruning) / [Post-Training Quantization](#post-training-quantization) to recover accuracy of the model by distilling from the original model (teacher).
 
-The [distill.py](distill.py) script loads student and teacher models from HuggingFace checkpoints and saves the distilled model to `<output_dir>/checkpoints` in Megatron distributed checkpoint format.
+The [distill.py](distill.py) script supports both standard HuggingFace checkpoints and [Puzzletron AnyModel](../puzzletron/README.md) checkpoints as student/teacher inputs. Just pass the checkpoint path via `--student_hf_path` / `--teacher_hf_path`. The distilled model is saved to `<output_dir>/checkpoints` in Megatron distributed checkpoint format.
 
 ### Data Preparation
 
@@ -158,9 +158,22 @@ torchrun --nproc_per_node 8 distill.py \
 
 To run the distillation script on a Slurm cluster for multi-node training, you just need to use `python` instead of `torchrun` and set the number of nodes using an `#SBATCH --nodes=<num_nodes>` clause in your Slurm script.
 
-### Convert Megatron checkpoint to Hugging Face format
+### Converting to Hugging Face format (optional)
 
-To convert the Megatron checkpoint from the last iteration (or any intermediate iteration) to Hugging Face format, you need the pruned model config (`--output_hf_path` from the `prune_minitron.py` script) and the distilled Megatron checkpoint dir (`<distill_output_dir>/checkpoints/iter_<iter_number>`) to run the following command:
+The distilled checkpoint is saved in Megatron distributed format. If you need a HuggingFace checkpoint, there are two ways to convert it:
+
+**Inline** -- add `--hf_export_path` and `--student_hf_model` to the `distill.py` command to automatically convert the final checkpoint after distillation:
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 8 distill.py \
+    ... \
+    --hf_export_path /path/to/save/distilled_hf_ckpt \
+    --student_hf_model Qwen/Qwen3-4B
+```
+
+`--student_hf_model` should match the base architecture of the student (it is used as a template for export). For non-Puzzletron (i.e. standard) models, it should be the same as `--student_hf_path`.
+
+**Separate conversion** -- convert any saved iteration using the Megatron-Bridge conversion script:
 
 ```bash
 uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \
@@ -169,7 +182,11 @@ uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py ex
   --hf-path <path_to_save_distilled_hf_ckpt>
 ```
 
-For more details, you can refer to the checkpoint conversion scripts in the [Megatron-Bridge README](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/conversion).
+For more details, see the [Megatron-Bridge conversion README](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/conversion).
+
+### Distillation Results
+
+See [results/puzzletron.md](results/puzzletron.md) for MMLU results demonstrating knowledge distillation on Puzzletron-compressed student models.
 
 ## Post-Training Quantization

examples/megatron_bridge/distill.py

Lines changed: 92 additions & 3 deletions
@@ -15,17 +15,22 @@
 """Distillation script for Megatron-Bridge.
 
 Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model
-to `<output_dir>/checkpoints` in megatron distributed checkpoint format.
+to `<output_dir>/checkpoints` in megatron distributed checkpoint or HuggingFace format.
 
 See `README.md` in this directory for example usage and data preparation instructions.
 """
 
 import argparse
+import contextlib
 import os
+from dataclasses import fields
 
 import torch
 from megatron.bridge import AutoBridge
-from megatron.bridge.models.distillation_provider import convert_to_distillation_provider
+from megatron.bridge.models.distillation_provider import (
+    DistillationProvider,
+    convert_to_distillation_provider,
+)
 from megatron.bridge.recipes.utils.optimizer_utils import (
     distributed_fused_adam_with_cosine_annealing,
 )
@@ -43,13 +48,50 @@
 from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig
 from megatron.core.datasets.utils import get_blend_from_list
 from megatron.core.distributed import DistributedDataParallelConfig
+from transformers import AutoConfig
 
 import modelopt.torch.utils.distributed as dist
 from modelopt.torch.utils import print_rank_0
 
+with contextlib.suppress(ImportError):
+    import modelopt.torch.puzzletron.plugins.mbridge  # noqa: F401
+
 SEED = 1234
 
 
+def _patched_to_cfg_dict(self):
+    """Patched DistillationProvider.to_cfg_dict method for heterogeneous teacher and student models.
+
+    TODO: Upstream this patch to Megatron-Bridge.
+    """
+    from megatron.bridge.training.utils.config_utils import _ConfigContainerBase
+
+    result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"}
+    # Use fields from the actual student provider class, not DistillationProvider.
+    # DistillationProvider's __dataclass_fields__ only includes TransformerConfig fields
+    # (set at class definition time), missing GPTModelProvider-level fields like
+    # vocab_size, share_embeddings_and_output_weights, etc.
+    excluded_fields = {"teacher", "kd_config"}
+    for field in fields(self._super_class):
+        if field.name.startswith("_") or field.name in excluded_fields:
+            continue
+        if hasattr(self, field.name):
+            result[field.name] = _ConfigContainerBase._convert_value_to_dict(
+                getattr(self, field.name)
+            )
+    for field in fields(self):
+        if field.name.startswith("_") or field.name in excluded_fields:
+            continue
+        if field.name not in result:
+            result[field.name] = _ConfigContainerBase._convert_value_to_dict(
+                getattr(self, field.name)
+            )
+    return result
+
+
+DistillationProvider.to_cfg_dict = _patched_to_cfg_dict
+
+
 def get_args():
     """Parse command-line arguments."""
     parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.")
@@ -124,12 +166,33 @@ def get_args():
     )
     parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)")
     parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)")
+    # Export arguments
+    parser.add_argument(
+        "--hf_export_path",
+        type=str,
+        default=None,
+        help=(
+            "Path where to save the HuggingFace export. "
+            "If provided, exports last iteration checkpoint to HF format after distillation."
+        ),
+    )
+    parser.add_argument(
+        "--student_hf_model",
+        type=str,
+        required=False,
+        default=None,
+        help="HuggingFace model ID to use as template for export (e.g., Qwen/Qwen3-0.6B). "
+        "Should match the base architecture of the student model if --hf_export_path is provided.",
+    )
     args = parser.parse_args()
 
     # Sanity checks
     if not args.use_mock_data and not args.data_paths:
         raise ValueError("Must provide either --data_paths or set --use_mock_data.")
 
+    if args.hf_export_path and not args.student_hf_model:
+        raise ValueError("Must provide --student_hf_model if --hf_export_path is provided.")
+
     print_rank_0("\n==================== Arguments ====================")
     for k, v in args.__dict__.items():
         print_rank_0(f"{k:<35} {v}")
@@ -252,9 +315,35 @@ def _build_model_provider(hf_path):
     print_rank_0("\nStarting distillation...")
     distill(config)
     print_rank_0(
-        f"\nDistillation done! Saved checkpoint to {checkpoint_dir} in megatron distributed checkpoint format.\n"
+        f"\nDistillation done! Saved checkpoint to {checkpoint_dir}"
+        " in megatron distributed checkpoint format.\n"
    )
 
+    if args.hf_export_path:
+        print_rank_0(f"Exporting final distilled ckpt to HF format to {args.hf_export_path}")
+        # Save rank before destroying process group (dist.rank() won't work after destruction)
+        is_rank_0 = dist.rank() == 0
+
+        # Destroy process group on all ranks -- export_ckpt will create its own temporary one.
+        # This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone).
+        dist.cleanup()
+
+        if is_rank_0:
+            export_bridge = AutoBridge.from_hf_pretrained(
+                args.student_hf_model, trust_remote_code=args.trust_remote_code
+            )
+            # Copy weights and remote code
+            export_bridge.export_ckpt(
+                megatron_path=f"{checkpoint_dir}/iter_{args.train_iters:07d}",
+                hf_path=args.hf_export_path,
+                show_progress=True,
+                strict=True,
+            )
+            # Copy config.json from student_hf_path (handles both local paths and HF model IDs)
+            AutoConfig.from_pretrained(
+                args.student_hf_path, trust_remote_code=args.trust_remote_code
+            ).save_pretrained(args.hf_export_path)
+
 
 if __name__ == "__main__":
     dist.setup()
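The `_patched_to_cfg_dict` workaround in the diff above rests on a dataclass subtlety: iterating `dataclasses.fields()` over the base provider class misses fields that only exist on the concrete subclass. A minimal, self-contained reproduction with hypothetical stand-in classes (not the real Megatron-Bridge providers):

```python
from dataclasses import dataclass, fields


@dataclass
class BaseProvider:  # stands in for the TransformerConfig-level fields
    hidden_size: int = 64


@dataclass
class GPTProvider(BaseProvider):  # stands in for GPTModelProvider
    vocab_size: int = 128


provider = GPTProvider()
base_fields = {f.name for f in fields(BaseProvider)}
all_fields = {f.name for f in fields(provider)}  # fields() on an instance sees the subclass

# Serializing only base_fields would silently drop subclass-level config,
# which is why the patch walks fields() of the actual provider class too.
print(sorted(all_fields - base_fields))  # ['vocab_size']
```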
examples/megatron_bridge/results/puzzletron.md

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+# Puzzletron Distillation Results
+
+The following MMLU results demonstrate knowledge distillation on student models that were first compressed using [Puzzletron](../../puzzletron/README.md). The original (uncompressed) model serves as the teacher, and distillation recovers accuracy lost during compression.
+
+## Qwen3-8B compressed to 80% of original
+
+The student was created by compressing Qwen3-8B to 80% of its original size using Puzzletron.
+
+| Model | MMLU | Humanities | Other | Social Sciences | STEM |
+|-------|------|------------|-------|-----------------|------|
+| Student (before distillation) | 0.5910 | 0.5046 | 0.6363 | 0.6831 | 0.5855 |
+| Student (after distillation) | 0.6921 | 0.5906 | 0.7316 | 0.7975 | 0.7016 |
+| Teacher (original Qwen3-8B) | 0.7493 | 0.6648 | 0.7856 | 0.8385 | 0.7526 |
+
+MMLU accuracy improved from 59.10% to 69.21% (+10.11 pp) after distillation with just 100 iterations on WikiText-103, recovering 64% of the gap to the teacher model.
+
+## Llama-3.1-8B-Instruct compressed to 50% of original
+
+The student was created by compressing Llama-3.1-8B-Instruct to 50% of its original size using Puzzletron.
+
+| Model | MMLU | Humanities | Other | Social Sciences | STEM |
+|-------|------|------------|-------|-----------------|------|
+| Student (before distillation) | 0.2316 | 0.2462 | 0.2292 | 0.2250 | 0.2274 |
+| Student (after distillation) | 0.2960 | 0.3146 | 0.3085 | 0.2925 | 0.2768 |
+| Teacher (original Llama-3.1-8B-Instruct) | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 |
+
+## Llama-3.1-8B-Instruct compressed to 69% of original (regression)
+
+The student was created by compressing Llama-3.1-8B-Instruct to ~69% of its original size using Puzzletron. This example shows regression due to overfitting on the small WikiText-103 dataset (100 iterations). MMLU was evaluated on a subset of 100 samples per task:
+
+| Model | MMLU | Humanities | Other | Social Sciences | STEM |
+|-------|------|------------|-------|-----------------|------|
+| Student (before distillation) | 0.6626 | 0.7069 | 0.6892 | 0.7525 | 0.5574 |
+| Student (after distillation) | 0.6496 | 0.6862 | 0.6677 | 0.7433 | 0.5532 |
+| Teacher (original Llama-3.1-8B-Instruct) | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 |
+
+MMLU decreased from 66.26% to 64.96% (-1.30 pp) -- the model overfitted to WikiText-103. This highlights the importance of using larger, more diverse datasets for distillation.
+
+## Recommendations
+
+- **Use larger datasets** for production distillation (e.g., [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1)) to avoid overfitting as shown in the regression case above.
+- **Train for more iterations** to ensure proper convergence.
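The "recovering 64% of the gap" figure quoted for the Qwen3-8B run above follows from simple arithmetic on the MMLU scores in that table; a quick sanity check:

```python
# Gap-recovery arithmetic for the Qwen3-8B results table above.
student_before, student_after, teacher = 0.5910, 0.6921, 0.7493

# Fraction of the student-to-teacher gap closed by distillation.
gap_recovered = (student_after - student_before) / (teacher - student_before)
print(f"{gap_recovered:.0%}")  # 64%
```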

examples/puzzletron/README.md

Lines changed: 1 addition & 1 deletion
@@ -299,7 +299,7 @@ vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --
 
 To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one.
 
-See [mbridge_distillation/README.md](./mbridge_distillation/README.md) for instructions on using Megatron-Bridge for knowledge distillation.
+See [Megatron-Bridge distillation](../megatron_bridge/README.md#distillation) for instructions on using Megatron-Bridge for knowledge distillation. The distillation script supports both standard HuggingFace and Puzzletron AnyModel checkpoints.
 
 ## Advanced Usage
