
Commit e4b054b

Fix and speed up megatron_mmlu by >10x via prefill scoring and global batching (#1280)
### What does this PR do?

**Type of change:** new feature + bug fix

Two improvements to Megatron inference utilities:

**1. Pipeline Parallel (PP) correctness fixes**

PP inference was producing garbage output (MMLU ~0.24, i.e. random chance). Two root causes:

- `megatron_generate` / `megatron_prefill` used `get_forward_backward_func()` (the training pipeline scheduler), which is not designed for inference. Both functions were rewritten to use explicit P2P communication via `recv_from_prev_pipeline_rank_` / `send_to_next_pipeline_rank`, matching the `run_mcore_inference` pattern.
- `import_mcore_gpt_from_hf` loads HF weights into stage 0's embedding but never updates the output_layer on the last PP stage when `share_embeddings_and_output_weights=True`. At model init, `setup_embeddings_and_output_layer()` all-reduces from stage 0 to sync the output layer; after importing HF weights, that all-reduce is stale. Fix: call `model.setup_embeddings_and_output_layer()` again after import.

**2. `megatron_mmlu` speedup (>10x)**

Replaces the `megatron_mmlu` implementation with a significantly faster approach that matches how `lm-evaluation-harness` scores multiple-choice questions.

- **Before:** autoregressive generation (`megatron_generate`, `osl=2`) per example, 114 separate `load_dataset` calls, batch_size=1 — 260s for 5% data.
- **After:** a single prefill forward pass + argmax over {A, B, C, D} logits, 2 `load_dataset` calls, configurable batch_size — 18s for 5% data (~14x faster).
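The prefill-scoring idea can be sketched in a few lines. This is a hypothetical stand-in, not the `megatron_mmlu` implementation: the vocab size and the {A, B, C, D} token ids below are made up, and a real run would take the logits at the last prompt position of a single prefill forward pass.

```python
import torch

# Assumed token ids for the answer letters (made up for illustration;
# a real tokenizer supplies these).
CHOICE_TOKEN_IDS = {"A": 32, "B": 33, "C": 34, "D": 35}

def score_choice(last_token_logits: torch.Tensor) -> str:
    """Pick the answer letter with the highest logit at the final prompt position."""
    letters = list(CHOICE_TOKEN_IDS)
    ids = torch.tensor([CHOICE_TOKEN_IDS[c] for c in letters])
    return letters[int(torch.argmax(last_token_logits[ids]))]

# Toy logits over a 64-token vocab where token id 34 ("C") is most likely.
logits = torch.full((64,), -1.0)
logits[34] = 5.0
print(score_choice(logits))  # prints C
```

Compared with autoregressive decoding, this needs no sampling loop: one forward pass per batch yields everything required to rank the four choices.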
### Changes

**PP fixes:**

- `megatron_generate` / `megatron_prefill`: replace `get_forward_backward_func` with explicit P2P (`recv_from_prev_pipeline_rank_` / `send_to_next_pipeline_rank`)
- `import_mcore_gpt_from_hf`: call `model.setup_embeddings_and_output_layer()` after HF weight import when PP>1 and `share_embeddings_and_output_weights=True`
- `megatron_prefill`: add `skip_return_logits` param and VLM support (needed for PP non-last stages)

**MMLU speedup:**

- **Log-likelihood scoring**: replace `megatron_generate` with `megatron_prefill` — one forward pass per batch, no autoregressive decode loop
- **Global batching**: collect all examples across all subjects, sort by descending sequence length, run in `batch_size` chunks
- **2 dataset loads** instead of 114: use `load_dataset("cais/mmlu", "all")` with per-subject grouping; skip the dev load when `few_shots=0`
- **`percentage` → `fraction`** parameter rename for clarity
- **tqdm progress bar** (rank-0 only)

### Testing

- `test_megatron_generate_and_mmlu` parametrized over `tp` and `pp`. Accuracy assertion: `0.36 < score < 0.39`. Manually checked that generated text is coherent.
- Re-ran M-Bridge Minitron MMLU-based pruning for Nano v2 9B -> 7B; all top-10 candidates' MMLU numbers are in the same ballpark as before.

### Before your PR is "*Ready for review*"

- Is this change backward compatible?: ❌ — `percentage` parameter renamed to `fraction`; `enable_kv_cache` removed from `megatron_mmlu`
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: ✅ — existing test updated and parametrized for TP+PP
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅

🤖 Generated with [Claude Code](https://claude.ai/claude-code)

## Summary by CodeRabbit

* **Bug Fixes**
  * Improved pipeline-parallel generation and MMLU evaluation reliability; fixed output-layer synchronization in shared-embedding + pipeline setups.
* **New Features**
  * MMLU scoring now uses batched prefill logit scoring for faster evaluation.
* **Behavior Changes**
  * Default MMLU sampling increased from 5% to 10%; calibration batch sizing adjusted and related CLI/help text updated.
* **Tests**
  * Distributed tests cover tensor- and pipeline-parallel modes and tighten MMLU validation ranges.
* **Documentation**
  * Updated pruning example and benchmark timing to reflect new sampling and speedup.

---------

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
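As intuition for the explicit P2P rewrite described in the PP fixes above: each pipeline rank receives the previous rank's hidden states, runs its layer shard, and sends the result onward, so only the last stage produces logits. A deliberately simplified single-process analogy (not Megatron's API):

```python
# Toy single-process analogy of pipeline-parallel prefill: each "stage"
# stands in for one PP rank doing recv -> forward -> send.
def run_pipeline(stages, tokens):
    activation = tokens
    for stage in stages:
        # recv from previous rank, run this rank's shard, send to next rank
        activation = stage(activation)
    return activation  # only the last stage holds the final output

stages = [lambda x: x + 1, lambda x: x * 2]  # two toy layer shards
print(run_pipeline(stages, 3))  # prints 8
```

The training scheduler returned by `get_forward_backward_func()` interleaves forward and backward micro-batches; inference needs only this forward relay, which is why the explicit send/recv pattern is both simpler and correct here.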
1 parent 4e33368 commit e4b054b

File tree

9 files changed: +325 −254 lines changed


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -22,6 +22,7 @@ Changelog
 
 **Bug Fixes**
 
+- Fix Megatron utility functions for generation (with pipeline parallelism) and ~10x speedup in MMLU score evaluation (by batching prefill passes).
 - Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this.
 - Fix TRT support for remote autotuning in ONNX Autotune from 10.16+ to 10.15+ and fix TRT versioning check to the ``trtexec`` version instead of the TRT Python API when using ``trtexec`` backend.
 
```

examples/megatron_bridge/prune_minitron.py

Lines changed: 8 additions & 10 deletions
```diff
@@ -18,7 +18,7 @@
 while skipping pruning of num_attention_heads using following defaults:
     1024 samples from nemotron-post-training-dataset-v2 for calibration,
     at-most 20% depth (num_layers) and 40% width is pruned per prunable hparam (hidden_size, ffn_hidden_size, ...),
-    top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
+    top-10 candidates are evaluated for MMLU score (10% sampled data) to select the best model.
 
 torchrun --nproc_per_node 2 prune_minitron.py \
     --hf_model_name_or_path Qwen/Qwen3-8B \
@@ -140,11 +140,11 @@ def get_args() -> argparse.Namespace:
     parser.add_argument(
         "--prune_score_func",
         type=str,
-        default="mmlu_5pct",
+        default="mmlu_10pct",
         help=(
             "Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. "
             "Format: mmlu_<N>pct where <N> is the percentage of MMLU data to sample per subject "
-            "(e.g. mmlu_5pct for 5%, mmlu_100pct for full eval)."
+            "(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)."
         ),
     )
     parser.add_argument(
@@ -299,16 +299,14 @@ def main(args: argparse.Namespace):
     match = re.fullmatch(r"mmlu_(\d+)pct", args.prune_score_func)
     if not match:
         raise ValueError(
-            f"Invalid score function: {args.prune_score_func}. "
-            "Expected format: mmlu_<N>pct (e.g. mmlu_5pct)"
+            f"Invalid score function: {args.prune_score_func}. Expected format: mmlu_<N>pct (e.g. mmlu_10pct)"
         )
-    mmlu_pct = int(match.group(1))
-    if not 0 < mmlu_pct <= 100:
-        raise ValueError("--prune_score_func percentage must be in the range [1, 100].")
-    _mmlu_pct = mmlu_pct / 100.0
+    mmlu_frac = float(match.group(1)) / 100.0
 
     def score_func(m):
-        return megatron_mmlu(m, tokenizer, percentage=_mmlu_pct)
+        return megatron_mmlu(
+            m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=args.calib_mbs
+        )
 
     pruning_config["score_func"] = score_func
     pruning_config["max_width_pruning"] = args.max_width_pruning
```
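The `mmlu_<N>pct` parsing in the hunk above follows this pattern (a standalone sketch; `parse_mmlu_fraction` is a hypothetical name, the real code inlines this in `main`):

```python
import re

def parse_mmlu_fraction(name: str) -> float:
    """Turn 'mmlu_<N>pct' into a sampling fraction; fullmatch rejects extra prefixes/suffixes."""
    match = re.fullmatch(r"mmlu_(\d+)pct", name)
    if not match:
        raise ValueError(f"Invalid score function: {name}. Expected format: mmlu_<N>pct (e.g. mmlu_10pct)")
    return float(match.group(1)) / 100.0

print(parse_mmlu_fraction("mmlu_10pct"))   # prints 0.1
print(parse_mmlu_fraction("mmlu_100pct"))  # prints 1.0
```

Note that `re.fullmatch` (rather than `re.match`) is what makes a malformed name like `mmlu_10pctx` fail fast instead of silently matching a prefix.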

examples/pruning/README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -124,7 +124,7 @@ This mode can be useful when you don't know the exact dimensions you want to pru
 from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
 
 def score_func(m):
-    return megatron_mmlu(m, tokenizer, percentage=0.05)  # 5% sampled data for faster eval
+    return megatron_mmlu(m, tokenizer, fraction=0.1, batch_size=4)  # 10% sampled data for faster eval
 
 # Specify target parameter count and configure the auto pruning algorithm
 # Save minitron scores at checkpoint so we can resume pruning without running the forward loop again
@@ -147,7 +147,7 @@ mtp.prune(...)
 
 1. **Importance Scoring**: Same as manual pruning - computes activation magnitudes for all parameters (takes ~5 minutes for an 8B model)
 2. **Search Space Construction**: Generates a search space of possible architectures based search space config and other configs (`max_width_pruning`, `max_depth_pruning`, `hparams_to_skip`)
-3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~10 mins per candidate for an 8B model pruning)
+3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~5 min per candidate for an 8B model MMLU score with 10% sampled data)
 4. **Best Architecture Selection**: Returns the architecture (best `export_config`) with the highest actual score from the top-K evaluated architectures
 5. **Weight Slicing**: Slices the model weights according to the best pruned architecture found
```

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -747,6 +747,17 @@ def _import_state_dict(self):
             if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
                 self.rules["output_layer"](model.output_layer)
 
+            # For PP with shared embedding/output weights, re-sync the output layer on the last
+            # pipeline stage from stage 0's (now HF-loaded) embedding. At model init,
+            # setup_embeddings_and_output_layer() zeros out the last stage's weight and all-reduces
+            # from stage 0. After importing HF weights into stage 0's embedding, that sync is stale,
+            # so we re-run it here.
+            if (
+                model.share_embeddings_and_output_weights
+                and model.config.pipeline_model_parallel_size > 1
+            ):
+                model.setup_embeddings_and_output_layer()
+
             # MTP
             if hasattr(model, "mtp"):
                 layer_pbar.set_description("Importing MTP")
```

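A toy single-process illustration of the staleness problem the new block fixes, with plain tensors standing in for the embedding-group all-reduce (this is not Megatron code):

```python
import torch

stage0_embedding = torch.zeros(4, 8)   # stage 0 owns the tied weight
last_stage_output = torch.zeros(4, 8)  # last PP stage keeps a synced copy

def sync_output_layer():
    # Stand-in for setup_embeddings_and_output_layer(): with the last stage
    # zero-initialized, the embedding-group all-reduce reduces to a copy.
    last_stage_output.copy_(stage0_embedding)

sync_output_layer()                        # sync at model init (both still zero)
stage0_embedding.copy_(torch.randn(4, 8))  # HF import overwrites stage 0 only
print(torch.equal(last_stage_output, stage0_embedding))  # prints False: stale
sync_output_layer()                        # the fix: re-sync after the import
print(torch.equal(last_stage_output, stage0_embedding))  # prints True
```

Without the re-sync, the last stage's output layer keeps the init-time (zero) weights, which is exactly the garbage-output symptom the PR describes.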
modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -171,6 +171,9 @@ class CandidateSubnet:
     score: float | None
 
 
+torch.serialization.add_safe_globals([CandidateSubnet])
+
+
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm.
```

modelopt/torch/utils/logging.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -105,8 +105,9 @@ def no_stdout():
 
 def print_rank_0(*args, **kwargs):
     """Prints only on the master process."""
+    kwargs.setdefault("flush", True)
     if dist.is_master():
-        print(*args, **kwargs, flush=True)
+        print(*args, **kwargs)
 
 
 def warn_rank_0(message, *args, **kwargs):
```
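The `setdefault` change above keeps `flush=True` as the default while letting callers override it; the old form `print(*args, **kwargs, flush=True)` would raise `TypeError: got multiple values for keyword argument 'flush'` if a caller also passed `flush`. A minimal stand-in without the rank check:

```python
def print_flushed(*args, **kwargs):
    # Only fills in flush when the caller didn't pass it, so an explicit
    # flush=False no longer collides with the hard-coded default.
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)

print_flushed("hello")                  # flushed immediately by default
print_flushed("buffered", flush=False)  # caller's override wins
```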
