Commit ce4d081

minor
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 48c74bd commit ce4d081

1 file changed: examples/megatron_bridge/distill.py (27 additions, 8 deletions)
@@ -15,7 +15,7 @@
 """Distillation script for Megatron-Bridge.
 
 Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model
-to <log_dir>/checkpoints in megatron torch_dist checkpoint format.
+to <output_dir>/checkpoints in megatron distributed checkpoint format.
 
 Example usage to distill a 4B student from an 8B teacher on 8 GPUs:
 
@@ -26,6 +26,7 @@
     --student_hf_path Qwen/Qwen3-4B \
     --tp_size 8 \
     --data_paths 1.0 /path/to/tokenized/data \
+    --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \
     --seq_length 8192 \
     --mbs 1 \
     --gbs 768 \
@@ -36,7 +37,7 @@
     --eval_interval 100 \
     --eval_iters 32 \
     --log_interval 10 \
-    --log_dir /output/qwen3_8b_to_4b_distill
+    --output_dir /output/qwen3_8b_to_4b_distill
 
 Example usage to use mock data for quick testing:
 
@@ -51,7 +52,9 @@
     --mbs 1 \
     --gbs 8 \
     --train_iters 100 \
-    --log_dir /tmp/test_distill
+    --eval_interval 10 \
+    --eval_iters 4 \
+    --output_dir /tmp/test_distill
 
 If you want to tokenize your own data for a specific tokenizer, you can use the following command:
 
@@ -129,12 +132,15 @@ def get_args():
     parser.add_argument(
         "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data"
     )
+    parser.add_argument(
+        "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices"
+    )
     parser.add_argument(
         "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths"
     )
-    # Training arguments
+    # Training & Eval arguments
     parser.add_argument(
-        "--log_dir", type=str, required=True, help="Folder for logging and checkpoint saving"
+        "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving"
     )
     parser.add_argument(
         "--seq_length", type=int, default=8192, help="Number of tokens per input sample"
@@ -153,7 +159,13 @@ def get_args():
     parser.add_argument(
         "--eval_iters", type=int, default=32, help="Number of batches per validation stage"
     )
+    # Logging arguments
     parser.add_argument("--log_interval", type=int, default=10, help="Write to log every <N> steps")
+    parser.add_argument(
+        "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)"
+    )
+    parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)")
+    parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)")
     args = parser.parse_args()
 
     # Sanity checks
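
The help text for --wandb_project marks it as the switch that turns Wandb logging on, with --wandb_entity and --wandb_exp_name as optional extras. A minimal sketch of that gating pattern, assuming the logger treats a missing project name as "Wandb disabled" (the conventional behavior implied by the help text, not code from this commit):

    import wandb  # pip install wandb

    def maybe_init_wandb(args):
        # Hypothetical helper: only initialize Wandb when a project is given.
        if args.wandb_project is None:
            return  # Wandb logging stays disabled
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,  # optional
            name=args.wandb_exp_name,  # optional
        )
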
@@ -169,8 +181,8 @@ def get_args():
 
 
 def main(args: argparse.Namespace):
-    checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
-    tensorboard_dir = os.path.join(args.log_dir, "tb_logs")
+    checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
+    tensorboard_dir = os.path.join(args.output_dir, "tb_logs")
 
     # Build student and teacher model providers
     def _build_model_provider(hf_path):
@@ -206,6 +218,7 @@ def _build_model_provider(hf_path):
     # Build dataset config
     dataset_kwargs = {
         "seq_length": args.seq_length,
+        "path_to_cache": args.data_path_to_cache,
         "random_seed": SEED,
         "reset_attention_mask": False,
         "reset_position_ids": False,
@@ -249,15 +262,21 @@ def _build_model_provider(hf_path):
             log_interval=args.log_interval,
             tensorboard_dir=tensorboard_dir,
             log_timers_to_tensorboard=True,
+            # Weights & Biases logging
+            wandb_project=args.wandb_project,
+            wandb_entity=args.wandb_entity,  # optional
+            wandb_exp_name=args.wandb_exp_name,
         ),
         tokenizer=TokenizerConfig(
             tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size
         ),
         checkpoint=CheckpointConfig(
             save_interval=args.eval_interval,
             save=checkpoint_dir,
-            load=checkpoint_dir,
+            load=checkpoint_dir,  # Resume from this directory (if exists)
+            most_recent_k=3,  # Keeps 3 most recent checkpoints (not metric-based)
             ckpt_format="torch_dist",
+            async_save=True,
             fully_parallel_save=True,
             finetune=True,
         ),
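
most_recent_k=3 names a purely recency-based retention policy: only the three newest checkpoints are kept, independent of validation metrics, while async_save moves torch_dist serialization off the training critical path. Illustrative logic for the retention step only, with a hypothetical helper name and the iter_<N> directory convention assumed (this is not Megatron-Bridge's actual implementation):

    import os
    import re
    import shutil

    def prune_to_most_recent_k(checkpoint_dir: str, k: int = 3) -> None:
        """Keep the k newest iter_<N> checkpoint directories, delete the rest."""
        ckpts = sorted(
            (d for d in os.listdir(checkpoint_dir) if re.fullmatch(r"iter_\d+", d)),
            key=lambda d: int(d.split("_", 1)[1]),
        )
        for stale in ckpts[:-k]:
            shutil.rmtree(os.path.join(checkpoint_dir, stale))
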
