14 | 14 | # limitations under the License. |
15 | 15 | """Distillation script for Megatron-Bridge. |
16 | 16 |
17 | | -Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model |
18 | | -to <output_dir>/checkpoints in megatron distributed checkpoint format. |
19 | | -
20 | | -Example usage to distill a 4B student from an 8B teacher on 8 GPUs: |
21 | | -
22 | | -.. code-block:: bash |
23 | | -
24 | | - torchrun --nproc_per_node 8 distill.py \ |
25 | | - --teacher_hf_path Qwen/Qwen3-8B \ |
26 | | - --student_hf_path Qwen/Qwen3-4B \ |
27 | | - --tp_size 8 \ |
28 | | - --data_paths 1.0 /path/to/tokenized/data \ |
29 | | - --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \ |
30 | | - --seq_length 8192 \ |
31 | | - --mbs 1 \ |
32 | | - --gbs 768 \ |
33 | | - --train_iters 15000 \ |
34 | | - --lr 1e-4 \ |
35 | | - --min_lr 1e-5 \ |
36 | | - --lr_warmup_iters 50 \ |
37 | | - --eval_interval 100 \ |
38 | | - --eval_iters 32 \ |
39 | | - --log_interval 10 \ |
40 | | - --output_dir /output/qwen3_8b_to_4b_distill |
41 | | -
42 | | -Example usage to use mock data for quick testing: |
43 | | -
44 | | -.. code-block:: bash |
45 | | -
46 | | - torchrun --nproc_per_node 8 distill.py \ |
47 | | - --teacher_hf_path Qwen/Qwen3-0.6B \ |
48 | | - --student_hf_path Qwen/Qwen3-0.6B \ |
49 | | - --tp_size 8 \ |
50 | | - --use_mock_data \ |
51 | | - --seq_length 512 \ |
52 | | - --mbs 1 \ |
53 | | - --gbs 8 \ |
54 | | - --train_iters 100 \ |
55 | | - --eval_interval 10 \ |
56 | | - --eval_iters 4 \ |
57 | | - --output_dir /tmp/test_distill |
58 | | -
59 | | -If you want to tokenize your own data for a specific tokenizer, you can use the following command: |
60 | | -
61 | | -.. code-block:: python |
62 | | -
63 | | - from modelopt.torch.utils.plugins import megatron_preprocess_data |
64 | | -
65 | | - megatron_preprocess_data( |
66 | | - input_path="/path/to/your/data.jsonl", |
67 | | - output_dir="/path/to/tokenized/data", |
68 | | - tokenizer_name_or_path="Qwen/Qwen3-0.6B", |
69 | | - json_keys=["text"], |
70 | | - workers=32, |
71 | | - log_interval=100000, |
72 | | - max_sequence_length=256000, |
73 | | - ) |
| 17 | +Loads student and teacher models from HuggingFace or Megatron checkpoints and saves the distilled model |
| 18 | +to `<output_dir>/checkpoints` in Megatron distributed checkpoint format. |
| 19 | +
| 20 | +Supported input model path formats for student and teacher (auto-detected): |
| 21 | + - HuggingFace: model hub name (e.g. `Qwen/Qwen3-8B`) or local HF dir |
| 22 | + - Megatron: `iter_*` directory inside a Megatron checkpoint (e.g. `/path/to/ckpt/iter_0000000`) |
| 23 | +
| 24 | +See `README.md` in this directory for example usage and data preparation instructions. |
74 | 25 | """ |
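The docstring's auto-detection boils down to a single path test. Here is a minimal sketch of that rule, mirroring the `startswith("iter_")` check in `_build_model_provider` further down in the diff (the helper name and sample paths are illustrative, not part of the script):

```python
import os

def _is_megatron_checkpoint(path: str) -> bool:
    # Megatron distributed checkpoints are saved as per-iteration
    # directories named like iter_0000000; any other path is treated
    # as a HuggingFace hub name or a local HF directory.
    return os.path.basename(os.path.normpath(path)).startswith("iter_")

assert _is_megatron_checkpoint("/path/to/ckpt/iter_0000000")
assert not _is_megatron_checkpoint("Qwen/Qwen3-8B")
```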
75 | 26 |
76 | 27 | import argparse |
| 28 | +import contextlib |
77 | 29 | import os |
78 | 30 |
79 | 31 | import torch |
82 | 34 | from megatron.bridge.recipes.utils.optimizer_utils import ( |
83 | 35 | distributed_fused_adam_with_cosine_annealing, |
84 | 36 | ) |
| 37 | +from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint |
85 | 38 | from megatron.bridge.training.config import ( |
86 | 39 | CheckpointConfig, |
87 | 40 | ConfigContainer, |
93 | 46 | TrainingConfig, |
94 | 47 | ) |
95 | 48 | from megatron.bridge.training.distill import distill |
| 49 | +from megatron.bridge.training.model_load_save import load_model_config |
96 | 50 | from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig |
97 | 51 | from megatron.core.datasets.utils import get_blend_from_list |
98 | 52 | from megatron.core.distributed import DistributedDataParallelConfig |
106 | 60 | def get_args(): |
107 | 61 | """Parse command-line arguments.""" |
108 | 62 | parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") |
109 | | - # Model arguments |
| 63 | + # Model arguments (accepts HuggingFace paths or Megatron checkpoint directories) |
110 | 64 | parser.add_argument( |
111 | | - "--student_hf_path", |
| 65 | + "--student_path", |
112 | 66 | type=str, |
113 | 67 | required=True, |
114 | | - help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)", |
| 68 | + help="Student model: HuggingFace name/path (e.g. Qwen/Qwen3-0.6B) or Megatron dir (e.g. /path/iter_0000000)", |
115 | 69 | ) |
116 | 70 | parser.add_argument( |
117 | | - "--teacher_hf_path", |
| 71 | + "--teacher_path", |
118 | 72 | type=str, |
119 | 73 | required=True, |
120 | | - help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)", |
| 74 | + help="Teacher model: HuggingFace name/path (e.g. Qwen/Qwen3-8B) or Megatron dir (e.g. /path/iter_0000000)", |
121 | 75 | ) |
122 | 76 | # Parallelism arguments |
123 | 77 | parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") |
@@ -179,26 +133,93 @@ def get_args(): |
179 | 133 | return args |
180 | 134 |
181 | 135 |
| 136 | +def _build_provider_from_hf(hf_path): |
| 137 | + """Build a model provider from a HuggingFace model path.""" |
| 138 | + bridge = AutoBridge.from_hf_pretrained(hf_path) |
| 139 | + provider = bridge.to_megatron_provider(load_weights=True) |
| 140 | + return provider |
| 141 | + |
| 142 | + |
| 143 | +def _build_provider_from_megatron(iter_path): |
| 144 | + """Build a model provider from a Megatron checkpoint iter directory. |
| 145 | +
| 146 | + Uses load_model_config() to reconstruct the provider (architecture) from the |
| 147 | + checkpoint's saved run config, then registers a pre_wrap_hook that loads the |
| 148 | + checkpoint weights after model creation. |
| 149 | + """ |
| 150 | + print_rank_0(f"Loading model config from Megatron checkpoint: {iter_path}") |
| 151 | + provider, _ = load_model_config(iter_path) |
| 152 | + provider.perform_initialization = False # Will load weights from ckpt |
| 153 | + |
| 154 | + def _load_megatron_weights_hook(model): |
| 155 | + """Pre-wrap hook: load weights from a Megatron checkpoint into the model. |
| 156 | +
| 157 | + For distillation kd_models, hide teacher / loss modules so only the |
| 158 | + student weights are loaded (same pattern as HF weight loading). |
| 159 | + """ |
| 160 | + cms = [] |
| 161 | + for m in model: |
| 162 | + if hasattr(m, "hide_teacher_model"): |
| 163 | + cms.append(m.hide_teacher_model()) |
| 164 | + if hasattr(m, "hide_loss_modules"): |
| 165 | + cms.append(m.hide_loss_modules()) |
| 166 | + with contextlib.ExitStack() as stack: |
| 167 | + for cm in cms: |
| 168 | + stack.enter_context(cm) |
| 169 | + _load_model_weights_from_checkpoint(iter_path, model) |
| 170 | + return model |
| 171 | + |
| 172 | + provider.register_pre_wrap_hook(_load_megatron_weights_hook) |
| 173 | + return provider |
| 174 | + |
| 175 | + |
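The hook above uses `contextlib.ExitStack` to enter however many hide-context-managers the model exposes before loading weights. A self-contained sketch of that pattern, with a stand-in context manager in place of the real `hide_teacher_model()` / `hide_loss_modules()` (all names and prints here are illustrative):

```python
import contextlib

@contextlib.contextmanager
def _hide(name):
    # Stand-in: temporarily detach a submodule so the checkpoint
    # loader only sees student weights, then restore it on exit.
    print(f"hiding {name}")
    try:
        yield
    finally:
        print(f"restoring {name}")

cms = [_hide("teacher"), _hide("loss_modules")]
with contextlib.ExitStack() as stack:
    for cm in cms:
        stack.enter_context(cm)
    print("loading student weights")  # runs with every module hidden
# Exiting the stack restores the hidden modules in LIFO order.
```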
| 176 | +def _build_model_provider(path, *, tp_size, pp_size, seq_length): |
| 177 | + """Build a model provider, auto-detecting HuggingFace vs Megatron checkpoint. |
| 178 | +
| 179 | + Megatron checkpoints are detected by the `iter_` prefix in the directory name |
| 180 | + (e.g. `/path/to/checkpoints/iter_0000000`). |
| 181 | + """ |
| 182 | + if os.path.basename(os.path.normpath(path)).startswith("iter_"): |
| 183 | + print_rank_0(f"Detected Megatron checkpoint: {path}") |
| 184 | + provider = _build_provider_from_megatron(path) |
| 185 | + else: |
| 186 | + print_rank_0(f"Loading from HuggingFace: {path}") |
| 187 | + provider = _build_provider_from_hf(path) |
| 188 | + |
| 189 | + # Override parallelism / training settings |
| 190 | + provider.tensor_model_parallel_size = tp_size |
| 191 | + provider.pipeline_model_parallel_size = pp_size |
| 192 | + provider.context_parallel_size = 1 |
| 193 | + provider.sequence_parallel = tp_size > 1 |
| 194 | + provider.seq_length = seq_length |
| 195 | + provider.pipeline_dtype = torch.bfloat16 |
| 196 | + provider.cross_entropy_fusion_impl = "te" |
| 197 | + |
| 198 | + # Run deferred __post_init__ to compute derived fields (e.g. init_method from init_method_std). |
| 199 | + # Must be called after all overrides are set. |
| 200 | + # Safe to call on HF-sourced providers too (idempotent). |
| 201 | + provider.finalize() |
| 202 | + |
| 203 | + return provider |
| 204 | + |
| 205 | + |
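For illustration, both checkpoint styles flow through the same builder; a hedged usage sketch (the model name and `iter_*` path below are placeholders, not defaults of the script):

```python
# HuggingFace hub name -> HF branch of _build_model_provider
student_provider = _build_model_provider(
    "Qwen/Qwen3-4B", tp_size=8, pp_size=1, seq_length=8192
)
# iter_* directory -> Megatron-checkpoint branch
teacher_provider = _build_model_provider(
    "/ckpts/teacher_8b/iter_0000500", tp_size=8, pp_size=1, seq_length=8192
)
```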
182 | 206 | def main(args: argparse.Namespace): |
183 | 207 | checkpoint_dir = os.path.join(args.output_dir, "checkpoints") |
184 | 208 | tensorboard_dir = os.path.join(args.output_dir, "tb_logs") |
185 | 209 |
186 | | - # Build student and teacher model providers |
187 | | - def _build_model_provider(hf_path): |
188 | | - bridge = AutoBridge.from_hf_pretrained(hf_path) |
189 | | - provider = bridge.to_megatron_provider(load_weights=True) |
190 | | - provider.tensor_model_parallel_size = args.tp_size |
191 | | - provider.pipeline_model_parallel_size = args.pp_size |
192 | | - provider.context_parallel_size = 1 |
193 | | - provider.sequence_parallel = args.tp_size > 1 |
194 | | - provider.seq_length = args.seq_length |
195 | | - provider.pipeline_dtype = torch.bfloat16 |
196 | | - provider.cross_entropy_fusion_impl = "te" |
197 | | - return provider |
198 | | - |
199 | | - # TODO: Support megatron-ckpt as an alternative to HF checkpoints |
200 | | - student_provider = _build_model_provider(args.student_hf_path) |
201 | | - teacher_provider = _build_model_provider(args.teacher_hf_path) |
| 210 | + # Build student and teacher model providers (auto-detects HF vs Megatron ckpt) |
| 211 | + student_provider = _build_model_provider( |
| 212 | + args.student_path, |
| 213 | + tp_size=args.tp_size, |
| 214 | + pp_size=args.pp_size, |
| 215 | + seq_length=args.seq_length, |
| 216 | + ) |
| 217 | + teacher_provider = _build_model_provider( |
| 218 | + args.teacher_path, |
| 219 | + tp_size=args.tp_size, |
| 220 | + pp_size=args.pp_size, |
| 221 | + seq_length=args.seq_length, |
| 222 | + ) |
202 | 223 |
203 | 224 | # Wrap into DistillationProvider |
204 | 225 | kd_config = ModelOptDistillConfig() |
|