1515"""Distillation script for Megatron-Bridge.
1616
1717Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model
18- to <log_dir >/checkpoints in megatron torch_dist checkpoint format.
18+ to <output_dir >/checkpoints in megatron distributed checkpoint format.
1919
2020Example usage to distill a 4B student from an 8B teacher on 8 GPUs:
2121
2626 --student_hf_path Qwen/Qwen3-4B \
2727 --tp_size 8 \
2828 --data_paths 1.0 /path/to/tokenized/data \
29+ --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \
2930 --seq_length 8192 \
3031 --mbs 1 \
3132 --gbs 768 \
3637 --eval_interval 100 \
3738 --eval_iters 32 \
3839 --log_interval 10 \
39- --log_dir /output/qwen3_8b_to_4b_distill
40+ --output_dir /output/qwen3_8b_to_4b_distill
4041
4142Example usage to use mock data for quick testing:
4243
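A note on the batch geometry in the distillation example above: with `--tp_size 8` on 8 GPUs there is a single data-parallel replica, so `--gbs 768` with `--mbs 1` implies 768 gradient-accumulation micro-batches per optimizer step. A quick sketch of the arithmetic, assuming Megatron's usual `dp = world_size // tp_size` (no pipeline parallelism):

```python
# Batch geometry implied by the example flags (arithmetic only).
world_size, tp_size = 8, 8
mbs, gbs = 1, 768

dp_size = world_size // tp_size      # 8 // 8 = 1 data-parallel replica
assert gbs % (mbs * dp_size) == 0
grad_accum = gbs // (mbs * dp_size)  # 768 micro-batches per optimizer step
print(grad_accum)                    # -> 768
```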
@@ -51,7 +52,9 @@
     --mbs 1 \
     --gbs 8 \
     --train_iters 100 \
-    --log_dir /tmp/test_distill
+    --eval_interval 10 \
+    --eval_iters 4 \
+    --output_dir /tmp/test_distill
 
 If you want to tokenize your own data for a specific tokenizer, you can use the following command:
 
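The tokenization command itself falls outside the hunks shown here. For orientation, the prefix passed to `--data_paths` refers to a Megatron indexed dataset (a `.bin`/`.idx` pair). Below is a minimal sketch of producing one with `megatron.core`'s `IndexedDatasetBuilder`; the paths and the one-JSON-object-per-line input layout are hypothetical:

```python
import json

import numpy as np
import torch
from megatron.core.datasets.indexed_dataset import IndexedDatasetBuilder
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
builder = IndexedDatasetBuilder("/path/to/tokenized/data.bin", dtype=np.int32)

with open("/path/to/raw/corpus.jsonl") as f:  # one {"text": ...} per line (assumed)
    for line in f:
        ids = tokenizer(json.loads(line)["text"])["input_ids"]
        builder.add_item(torch.tensor(ids, dtype=torch.int32))
        builder.end_document()

builder.finalize("/path/to/tokenized/data.idx")
# The common prefix, /path/to/tokenized/data, is what --data_paths expects.
```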
@@ -129,12 +132,15 @@ def get_args():
     parser.add_argument(
         "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data"
     )
+    parser.add_argument(
+        "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices"
+    )
     parser.add_argument(
         "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths"
     )
-    # Training arguments
+    # Training & Eval arguments
     parser.add_argument(
-        "--log_dir", type=str, required=True, help="Folder for logging and checkpoint saving"
+        "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving"
     )
     parser.add_argument(
         "--seq_length", type=int, default=8192, help="Number of tokens per input sample"
@@ -153,7 +159,13 @@ def get_args():
     parser.add_argument(
         "--eval_iters", type=int, default=32, help="Number of batches per validation stage"
     )
+    # Logging arguments
     parser.add_argument("--log_interval", type=int, default=10, help="Write to log every <N> steps")
+    parser.add_argument(
+        "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)"
+    )
+    parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)")
+    parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)")
     args = parser.parse_args()
 
     # Sanity checks
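The body of the sanity-check block sits between the hunks and is not shown. Checks consistent with the flags defined above would plausibly look like the following; this is a hypothetical reconstruction, not the script's actual code:

```python
# Hypothetical sanity checks matching the argument set above.
if args.use_mock_data:
    assert not args.data_paths, "--use_mock_data and --data_paths are mutually exclusive"
else:
    assert args.data_paths, "provide --data_paths or pass --use_mock_data"
assert args.gbs % args.mbs == 0, "--gbs must be a multiple of --mbs"
```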
@@ -169,8 +181,8 @@ def get_args():
 
 
 def main(args: argparse.Namespace):
-    checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
-    tensorboard_dir = os.path.join(args.log_dir, "tb_logs")
+    checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
+    tensorboard_dir = os.path.join(args.output_dir, "tb_logs")
 
     # Build student and teacher model providers
     def _build_model_provider(hf_path):
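The helper's body is mostly outside the shown hunks. In Megatron-Bridge, building a provider straight from a HuggingFace checkpoint typically goes through `AutoBridge`; the sketch below approximates that pattern (the `tensor_model_parallel_size` assignment is an assumption about how `--tp_size` gets applied) and is not this script's actual code:

```python
# Rough shape of an HF-checkpoint-to-Megatron-provider helper (approximation).
from megatron.bridge import AutoBridge

def build_model_provider(hf_path: str, tp_size: int = 1):
    bridge = AutoBridge.from_hf_pretrained(hf_path)  # local dir or hub id
    provider = bridge.to_megatron_provider()
    provider.tensor_model_parallel_size = tp_size    # assumed wiring for --tp_size
    return provider
```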
@@ -206,6 +218,7 @@ def _build_model_provider(hf_path):
     # Build dataset config
     dataset_kwargs = {
         "seq_length": args.seq_length,
+        "path_to_cache": args.data_path_to_cache,
         "random_seed": SEED,
         "reset_attention_mask": False,
         "reset_position_ids": False,
@@ -249,15 +262,21 @@ def _build_model_provider(hf_path):
             log_interval=args.log_interval,
             tensorboard_dir=tensorboard_dir,
             log_timers_to_tensorboard=True,
+            # Weights & Biases logging
+            wandb_project=args.wandb_project,
+            wandb_entity=args.wandb_entity,  # optional
+            wandb_exp_name=args.wandb_exp_name,
         ),
         tokenizer=TokenizerConfig(
             tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size
         ),
         checkpoint=CheckpointConfig(
             save_interval=args.eval_interval,
             save=checkpoint_dir,
-            load=checkpoint_dir,
+            load=checkpoint_dir,  # Resume from this directory (if exists)
+            most_recent_k=3,  # Keeps 3 most recent checkpoints (not metric-based)
             ckpt_format="torch_dist",
+            async_save=True,
             fully_parallel_save=True,
             finetune=True,
         ),
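Since `load` points at the same directory as `save`, rerunning the identical command resumes from the newest checkpoint, while `finetune=True` lets the first run start from the converted HF weights without optimizer state; `async_save=True` overlaps shard writing with subsequent training steps, and `most_recent_k=3` prunes saves purely by recency. A quick way to check the resume state by hand, assuming the tracker-file convention Megatron-style checkpointing uses (path from the usage example):

```python
# Inspect the checkpoint tracker file that records the latest saved step.
from pathlib import Path

ckpt_dir = Path("/output/qwen3_8b_to_4b_distill/checkpoints")
tracker = ckpt_dir / "latest_checkpointed_iteration.txt"
if tracker.exists():
    print(f"will resume from iteration {tracker.read_text().strip()}")
else:
    print("no checkpoint yet; training starts from the HF weights")
```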