Commit e4b054b
Fix and Speedup megatron_mmlu by >10x via prefill scoring and global batching (#1280)
### What does this PR do?
Type of change: new feature + bug fix
Two improvements to Megatron inference utilities:
**1. Pipeline Parallel (PP) correctness fixes**
PP inference was producing garbage output (MMLU ~0.24, random chance).
Two root causes:
- `megatron_generate` / `megatron_prefill` used
`get_forward_backward_func()` (the training pipeline scheduler), which
is not designed for inference. Rewrote both functions to use explicit
P2P communication via `recv_from_prev_pipeline_rank_` /
`send_to_next_pipeline_rank`, matching the `run_mcore_inference`
pattern.
- `import_mcore_gpt_from_hf` loads HF weights into stage 0's embedding
but never updates the output_layer on the last PP stage when
`share_embeddings_and_output_weights=True`. At model init,
`setup_embeddings_and_output_layer()` all-reduces from stage 0 to sync
the output layer; after importing HF weights that all-reduce is stale.
Fix: call `model.setup_embeddings_and_output_layer()` again after
import.
**2. `megatron_mmlu` speedup (~6x)**
Replaces the `megatron_mmlu` implementation with a significantly faster
approach that matches how `lm-evaluation-harness` scores multiple-choice
questions.
**Before:** autoregressive generation (`megatron_generate`, `osl=2`) per
example, 114 separate `load_dataset` calls, batch_size=1 — 260s for 5%
data.
**After:** single prefill forward pass + argmax over {A,B,C,D} logits, 2
`load_dataset` calls, configurable batch_size — 18s for 5% data (~6x
faster).
### Changes
**PP fixes:**
- `megatron_generate` / `megatron_prefill`: replace
`get_forward_backward_func` with explicit P2P
(`recv_from_prev_pipeline_rank_` / `send_to_next_pipeline_rank`)
- `import_mcore_gpt_from_hf`: call
`model.setup_embeddings_and_output_layer()` after HF weight import when
PP>1 and `share_embeddings_and_output_weights=True`
- `megatron_prefill`: add `skip_return_logits` param and VLM support
(needed for PP non-last stages)
**MMLU speedup:**
- **Log-likelihood scoring**: replace `megatron_generate` with
`megatron_prefill` — one forward pass per batch, no autoregressive
decode loop
- **Global batching**: collect all examples across all subjects, sort by
descending sequence length, run in `batch_size` chunks
- **2 dataset loads** instead of 114: use `load_dataset("cais/mmlu",
"all")` with per-subject grouping; skip dev load when `few_shots=0`
- **`percentage` → `fraction`** parameter rename for clarity
- **tqdm progress bar** (rank-0 only)
### Testing
- `test_megatron_generate_and_mmlu` parametrized over `tp` and `pp`.
Accuracy assertion: `0.36 < score < 0.39`. Manually checked generated
text is coherent.
- Re-ran M-Bridge Minitron MMLU based pruning for Nano v2 9B -> 7B and
all top 10 candidate's MMLU numbers are ballpark similar as before
### Before your PR is "*Ready for review*"
- Is this change backward compatible?: ❌ — `percentage` parameter
renamed to `fraction`; `enable_kv_cache` removed from `megatron_mmlu`
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: ✅ — existing test updated and
parametrized for TP+PP
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
✅
🤖 Generated with [Claude Code](https://claude.ai/claude-code)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Bug Fixes**
* Improved pipeline-parallel generation and MMLU evaluation reliability;
fixed output-layer synchronization in shared-embedding + pipeline
setups.
* **New Features**
* MMLU scoring now uses batched prefill logit scoring for faster,
batched evaluation.
* **Behavior Changes**
* Default MMLU sampling increased from 5% to 10%; calibration batch
sizing adjusted and related CLI/help text updated.
* **Tests**
* Distributed tests cover tensor- and pipeline-parallel modes and
tighten MMLU validation ranges.
* **Documentation**
* Updated pruning example and benchmark timing to reflect new sampling
and speedup.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>1 parent 4e33368 commit e4b054b
File tree
9 files changed
+325
-254
lines changed- examples
- megatron_bridge
- pruning
- modelopt/torch
- export/plugins
- prune/plugins
- utils
- plugins
- tests/gpu_megatron/torch/utils/plugins
9 files changed
+325
-254
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | 27 | | |
27 | 28 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
21 | | - | |
| 21 | + | |
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| |||
140 | 140 | | |
141 | 141 | | |
142 | 142 | | |
143 | | - | |
| 143 | + | |
144 | 144 | | |
145 | 145 | | |
146 | 146 | | |
147 | | - | |
| 147 | + | |
148 | 148 | | |
149 | 149 | | |
150 | 150 | | |
| |||
299 | 299 | | |
300 | 300 | | |
301 | 301 | | |
302 | | - | |
303 | | - | |
| 302 | + | |
304 | 303 | | |
305 | | - | |
306 | | - | |
307 | | - | |
308 | | - | |
| 304 | + | |
309 | 305 | | |
310 | 306 | | |
311 | | - | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
312 | 310 | | |
313 | 311 | | |
314 | 312 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
124 | 124 | | |
125 | 125 | | |
126 | 126 | | |
127 | | - | |
| 127 | + | |
128 | 128 | | |
129 | 129 | | |
130 | 130 | | |
| |||
147 | 147 | | |
148 | 148 | | |
149 | 149 | | |
150 | | - | |
| 150 | + | |
151 | 151 | | |
152 | 152 | | |
153 | 153 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
747 | 747 | | |
748 | 748 | | |
749 | 749 | | |
| 750 | + | |
| 751 | + | |
| 752 | + | |
| 753 | + | |
| 754 | + | |
| 755 | + | |
| 756 | + | |
| 757 | + | |
| 758 | + | |
| 759 | + | |
| 760 | + | |
750 | 761 | | |
751 | 762 | | |
752 | 763 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
171 | 171 | | |
172 | 172 | | |
173 | 173 | | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
174 | 177 | | |
175 | 178 | | |
176 | 179 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
105 | 105 | | |
106 | 106 | | |
107 | 107 | | |
| 108 | + | |
108 | 109 | | |
109 | | - | |
| 110 | + | |
110 | 111 | | |
111 | 112 | | |
112 | 113 | | |
| |||
0 commit comments