Speed up megatron_mmlu by ~6x via prefill scoring and global batching
Replace autoregressive generation with a single prefill forward pass per
batch and argmax over the four answer-choice token logits. This matches
the log-likelihood approach used by lm-evaluation-harness and avoids the
autoregressive decode loop entirely.
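The idea above can be sketched as follows. This is an illustrative toy (plain NumPy, made-up names and shapes), not the actual modelopt implementation: after one prefill forward pass, take the logits at the last prompt position and argmax over only the four answer-choice token ids.

```python
import numpy as np

def score_choices(last_position_logits, choice_token_ids):
    """Pick the most likely answer letter from a single prefill pass.

    last_position_logits: (batch, vocab) logits at the final prompt token.
    choice_token_ids: token ids for the four answer letters, e.g. A/B/C/D.
    Returns the index (0-3) of the highest-logit choice per example.
    """
    choice_logits = last_position_logits[:, choice_token_ids]  # (batch, 4)
    return choice_logits.argmax(axis=-1)

# Toy example: vocab of 10, pretend ids 3/4/5/6 encode A/B/C/D.
logits = np.zeros((2, 10))
logits[0, 4] = 5.0  # model favours "B" for example 0
logits[1, 6] = 5.0  # model favours "D" for example 1
preds = score_choices(logits, [3, 4, 5, 6])
# preds -> [1, 3], i.e. "B" and "D"
```

Because only one forward pass per batch is needed and no tokens are sampled, the decode loop (and its per-step kernel launches) disappears entirely, which is where the bulk of the ~6x speedup comes from.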
Additional improvements:
- Load dataset once with the "all" config (2 calls) instead of once per
subject (114 calls), eliminating the main CPU overhead bottleneck
- Batch globally across all subjects, sorted by descending sequence length,
to minimise padding waste and fail fast on OOM
- Skip dev dataset load when few_shots=0
- Rename percentage -> fraction for clearer semantics
- Fix few-shot answer formatting (was emitting integer index, now letter)
- Fix off-by-one: idx > threshold -> idx >= threshold
- Fix avg_correct reset bug inside subject loop
- Add tqdm progress bar (rank-0 only)
- Explicitly del logits/padded after each batch to avoid tensor lifetime
overlap that caused OOM on long-sequence runs
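The global batching strategy from the list above can be sketched like this (a minimal illustration with hypothetical field names, not the actual modelopt code): sort every example across all subjects by descending token length, then chunk, so each batch pads to a similar length and the longest sequences hit the GPU first.

```python
def make_global_batches(examples, batch_size):
    """Batch examples from all subjects together, longest first.

    Sorting by descending length keeps similarly sized sequences in the
    same batch (less padding waste) and puts the longest batch first, so
    an OOM surfaces immediately instead of mid-run.
    """
    ordered = sorted(examples, key=lambda ex: len(ex["tokens"]), reverse=True)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

# Toy data: token lengths 5, 9, 2, 7 across (pretend) different subjects.
data = [{"tokens": [0] * n} for n in (5, 9, 2, 7)]
batches = make_global_batches(data, batch_size=2)
# first batch holds the two longest sequences (lengths 9 and 7)
```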
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
examples/pruning/README.md: 2 additions & 2 deletions
```diff
@@ -124,7 +124,7 @@ This mode can be useful when you don't know the exact dimensions you want to pru
 from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
 
 def score_func(m):
-    return megatron_mmlu(m, tokenizer, percentage=0.05)  # 5% sampled data for faster eval
+    return megatron_mmlu(m, tokenizer, fraction=0.1, batch_size=4)  # 10% sampled data for faster eval
 
 # Specify target parameter count and configure the auto pruning algorithm
 # Save minitron scores at checkpoint so we can resume pruning without running the forward loop again
```
```diff
@@ -147,7 +147,7 @@ mtp.prune(...)
 
 1. **Importance Scoring**: Same as manual pruning - computes activation magnitudes for all parameters (takes ~5 minutes for an 8B model)
 2. **Search Space Construction**: Generates a search space of possible architectures based search space config and other configs (`max_width_pruning`, `max_depth_pruning`, `hparams_to_skip`)
-3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~10 mins per candidate for an 8B model pruning)
+3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~1 min per candidate for an 8B model MMLU score with 10% sampled data)
 4. **Best Architecture Selection**: Returns the architecture (best `export_config`) with the highest actual score from the top-K evaluated architectures
 5. **Weight Slicing**: Slices the model weights according to the best pruned architecture found
```