Skip to content

Commit 1cec8ec

Browse files
minor
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent ca688d1 commit 1cec8ec

File tree

3 files changed

+9
-12
lines changed

3 files changed

+9
-12
lines changed

examples/megatron_bridge/prune_minitron.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
while skipping pruning of num_attention_heads using following defaults:
1919
1024 samples from nemotron-post-training-dataset-v2 for calibration,
2020
at-most 20% depth (num_layers) and 40% width is pruned per prunable hparam (hidden_size, ffn_hidden_size, ...),
21-
top-10 candidates are evaluated for MMLU score (10% sampled data) to select the best model.
21+
top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
2222
2323
torchrun --nproc_per_node 2 prune_minitron.py \
2424
--hf_model_name_or_path Qwen/Qwen3-8B \
@@ -140,11 +140,11 @@ def get_args() -> argparse.Namespace:
140140
parser.add_argument(
141141
"--prune_score_func",
142142
type=str,
143-
default="mmlu_10pct",
143+
default="mmlu_5pct",
144144
help=(
145145
"Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. "
146146
"Format: mmlu_<N>pct where <N> is the percentage of MMLU data to sample per subject "
147-
"(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)."
147+
"(e.g. mmlu_5pct for 5%, mmlu_100pct for full eval)."
148148
),
149149
)
150150
parser.add_argument(
@@ -299,17 +299,13 @@ def main(args: argparse.Namespace):
299299
match = re.fullmatch(r"mmlu_(\d+)pct", args.prune_score_func)
300300
if not match:
301301
raise ValueError(
302-
f"Invalid score function: {args.prune_score_func}. "
303-
"Expected format: mmlu_<N>pct (e.g. mmlu_10pct)"
302+
f"Invalid score function: {args.prune_score_func}. Expected format: mmlu_<N>pct (e.g. mmlu_5pct)"
304303
)
305-
mmlu_pct = int(match.group(1))
306-
if not 0 < mmlu_pct <= 100:
307-
raise ValueError("--prune_score_func percentage must be in the range [1, 100].")
308-
_mmlu_frac = mmlu_pct / 100.0
304+
mmlu_frac = float(match.group(1)) / 100.0
309305

310306
def score_func(m):
311307
return megatron_mmlu(
312-
m, tokenizer, few_shots=0, fraction=_mmlu_frac, batch_size=args.calib_mbs
308+
m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=args.calib_mbs
313309
)
314310

315311
pruning_config["score_func"] = score_func

examples/pruning/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ This mode can be useful when you don't know the exact dimensions you want to pru
124124
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
125125

126126
def score_func(m):
127-
return megatron_mmlu(m, tokenizer, fraction=0.1, batch_size=4) # 10% sampled data for faster eval
127+
return megatron_mmlu(m, tokenizer, fraction=0.05, batch_size=4) # 5% sampled data for faster eval
128128

129129
# Specify target parameter count and configure the auto pruning algorithm
130130
# Save minitron scores at checkpoint so we can resume pruning without running the forward loop again
@@ -147,7 +147,7 @@ mtp.prune(...)
147147

148148
1. **Importance Scoring**: Same as manual pruning - computes activation magnitudes for all parameters (takes ~5 minutes for an 8B model)
149149
2. **Search Space Construction**: Generates a search space of possible architectures based search space config and other configs (`max_width_pruning`, `max_depth_pruning`, `hparams_to_skip`)
150-
3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~5 min per candidate for an 8B model MMLU score with 10% sampled data)
150+
3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes 2-3 mins per candidate for an 8B model MMLU score with 5% sampled data)
151151
4. **Best Architecture Selection**: Returns the architecture (best `export_config`) with the highest actual score from the top-K evaluated architectures
152152
5. **Weight Slicing**: Slices the model weights according to the best pruned architecture found
153153

modelopt/torch/utils/plugins/megatron_mmlu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def megatron_mmlu(
7878
f"\nMMLU ({fraction * 100}%, {few_shots}-shot, Batch Size: {batch_size}) evaluation started...\n"
7979
"First batch may take longer to evaluate for Pipeline Parallel models."
8080
)
81+
assert 0 < fraction <= 1, "Fraction must be between 0 and 1"
8182

8283
# Token IDs for " A", " B", " C", " D" — the last subword handles edge cases.
8384
choice_ids = [tokenizer.encode(f" {c}", add_special_tokens=False)[-1] for c in _CHOICES]

0 commit comments

Comments
 (0)