Skip to content

Commit 303e429

Browse files
Speedup megatron_mmlu by ~6x via prefill scoring and global batching
Replace autoregressive generation with a single prefill forward pass per batch and argmax over the four answer-choice token logits. This matches the log-likelihood approach used by lm-evaluation-harness and avoids the autoregressive decode loop entirely.

Additional improvements:
- Load dataset once with the "all" config (2 calls) instead of once per subject (114 calls), eliminating the main CPU overhead bottleneck
- Batch globally across all subjects sorted by descending sequence length to minimise padding waste and fail-fast on OOM
- Skip dev dataset load when few_shots=0
- Rename percentage -> fraction for clearer semantics
- Fix few-shot answer formatting (was emitting integer index, now letter)
- Fix off-by-one: idx > threshold -> idx >= threshold
- Fix avg_correct reset bug inside subject loop
- Add tqdm progress bar (rank-0 only)
- Explicitly del logits/padded after each batch to avoid tensor lifetime overlap that caused OOM on long-sequence runs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 6ded36b commit 303e429

File tree

4 files changed

+118
-81
lines changed

4 files changed

+118
-81
lines changed

examples/megatron_bridge/prune_minitron.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
while skipping pruning of num_attention_heads using following defaults:
1919
1024 samples from nemotron-post-training-dataset-v2 for calibration,
2020
at-most 20% depth (num_layers) and 40% width is pruned per prunable hparam (hidden_size, ffn_hidden_size, ...),
21-
top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
21+
top-10 candidates are evaluated for MMLU score (10% sampled data) to select the best model.
2222
2323
torchrun --nproc_per_node 2 prune_minitron.py \
2424
--hf_model_name_or_path Qwen/Qwen3-8B \
@@ -140,11 +140,11 @@ def get_args() -> argparse.Namespace:
140140
parser.add_argument(
141141
"--prune_score_func",
142142
type=str,
143-
default="mmlu_5pct",
143+
default="mmlu_10pct",
144144
help=(
145145
"Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. "
146146
"Format: mmlu_<N>pct where <N> is the percentage of MMLU data to sample per subject "
147-
"(e.g. mmlu_5pct for 5%, mmlu_100pct for full eval)."
147+
"(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)."
148148
),
149149
)
150150
parser.add_argument(
@@ -300,15 +300,17 @@ def main(args: argparse.Namespace):
300300
if not match:
301301
raise ValueError(
302302
f"Invalid score function: {args.prune_score_func}. "
303-
"Expected format: mmlu_<N>pct (e.g. mmlu_5pct)"
303+
"Expected format: mmlu_<N>pct (e.g. mmlu_10pct)"
304304
)
305305
mmlu_pct = int(match.group(1))
306306
if not 0 < mmlu_pct <= 100:
307307
raise ValueError("--prune_score_func percentage must be in the range [1, 100].")
308-
_mmlu_pct = mmlu_pct / 100.0
308+
_mmlu_frac = mmlu_pct / 100.0
309309

310310
def score_func(m):
311-
return megatron_mmlu(m, tokenizer, percentage=_mmlu_pct)
311+
return megatron_mmlu(
312+
m, tokenizer, few_shots=0, fraction=_mmlu_frac, batch_size=args.calib_mbs
313+
)
312314

313315
pruning_config["score_func"] = score_func
314316
pruning_config["max_width_pruning"] = args.max_width_pruning

examples/pruning/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ This mode can be useful when you don't know the exact dimensions you want to pru
124124
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
125125

126126
def score_func(m):
127-
return megatron_mmlu(m, tokenizer, percentage=0.05) # 5% sampled data for faster eval
127+
return megatron_mmlu(m, tokenizer, fraction=0.1, batch_size=4) # 10% sampled data for faster eval
128128

129129
# Specify target parameter count and configure the auto pruning algorithm
130130
# Save minitron scores at checkpoint so we can resume pruning without running the forward loop again
@@ -147,7 +147,7 @@ mtp.prune(...)
147147

148148
1. **Importance Scoring**: Same as manual pruning - computes activation magnitudes for all parameters (takes ~5 minutes for an 8B model)
149149
2. **Search Space Construction**: Generates a search space of possible architectures based search space config and other configs (`max_width_pruning`, `max_depth_pruning`, `hparams_to_skip`)
150-
3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~10 mins per candidate for an 8B model pruning)
150+
3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~1 min per candidate for an 8B model MMLU score with 10% sampled data)
151151
4. **Best Architecture Selection**: Returns the architecture (best `export_config`) with the highest actual score from the top-K evaluated architectures
152152
5. **Weight Slicing**: Slices the model weights according to the best pruned architecture found
153153

modelopt/torch/utils/plugins/megatron_mmlu.py

Lines changed: 106 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -40,62 +40,53 @@
4040

4141
"""A simple MMLU evaluation for Megatron LM models."""
4242

43-
import requests
4443
import torch
45-
import transformers
4644
from datasets import load_dataset
45+
from tqdm import tqdm
46+
from transformers import PreTrainedTokenizer
4747

48-
from .megatron_generate import megatron_generate
48+
from .. import distributed as dist
49+
from .megatron_generate import megatron_prefill
4950

5051
__all__ = ["megatron_mmlu"]
5152

52-
53-
def _get_all_subjects():
54-
"""All subjects (anatomy, ...) can be acquired from querying all subsets and splits."""
55-
response = requests.get(
56-
"https://datasets-server.huggingface.co/splits?dataset=cais/mmlu", timeout=10
57-
)
58-
data = response.json()
59-
all_subjects = set()
60-
for split in data["splits"]:
61-
all_subjects.add(split["config"])
62-
for name in ["all", "auxiliary_train"]:
63-
all_subjects.discard(name)
64-
return sorted(all_subjects)
53+
_CHOICES = ["A", "B", "C", "D"]
6554

6655

6756
def megatron_mmlu(
6857
model,
69-
tokenizer: transformers.PreTrainedTokenizer,
58+
tokenizer: PreTrainedTokenizer,
7059
few_shots: int = 0,
71-
percentage: float = 0.05,
72-
enable_kv_cache: bool = False,
60+
fraction: float = 0.05,
61+
batch_size: int = 1,
7362
) -> float:
74-
"""Evaluate the model on MMLU.
63+
"""Evaluate the model on MMLU using log-likelihood scoring over batched prefill passes.
64+
65+
Instead of autoregressively generating tokens, a single prefill forward pass is run per
66+
batch and the answer is selected as argmax over the four choice token logits at the last
67+
prompt position. This is the same approach used by lm-evaluation-harness.
7568
7669
Args:
7770
model: The model to evaluate.
7871
tokenizer: The tokenizer to use.
7972
few_shots: The number of few-shot examples to use.
80-
percentage: The percentage of the test set to evaluate on.
81-
enable_kv_cache: Whether to disable KV-cache.
73+
fraction: The fraction of the test set to evaluate on.
74+
batch_size: Number of examples to process in one forward pass.
8275
"""
83-
all_correct = {}
84-
all_subjects = _get_all_subjects()
76+
# Token IDs for " A", " B", " C", " D" — the last subword handles edge cases.
77+
choice_ids = [tokenizer.encode(f" {c}", add_special_tokens=False)[-1] for c in _CHOICES]
8578

8679
def _format_example(example, include_answer: bool = True):
87-
"""Format an example into a multi-choices problem."""
8880
prompt = example["question"]
89-
for choice, answer in zip(["A", "B", "C", "D"], example["choices"]):
81+
for choice, answer in zip(_CHOICES, example["choices"]):
9082
prompt += f"\n{choice}. {answer}"
9183
if include_answer:
92-
prompt += "Answer: {}\n\n".format(example["answer"])
84+
prompt += "Answer: {}\n\n".format(_CHOICES[example["answer"]])
9385
else:
9486
prompt += "\nAnswer:"
9587
return prompt
9688

9789
def _generate_prompt(test_example, dev_examples, few_shots=0):
98-
"""Generating few-shot prompts."""
9990
prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
10091
" ".join(test_example["subject"].split("_"))
10192
)
@@ -104,51 +95,97 @@ def _generate_prompt(test_example, dev_examples, few_shots=0):
10495
prompt += _format_example(test_example, include_answer=False)
10596
return prompt
10697

107-
if torch.distributed.get_rank() == 0:
108-
print(f"\nMMLU ({percentage * 100}%, {few_shots}-shot) evaluation started...\n", flush=True)
98+
# Load all subjects in two dataset calls instead of 2x num_subjects calls.
99+
# The "all" config includes a "subject" field for per-subject reporting.
100+
test_dataset = load_dataset("cais/mmlu", "all", split="test")
101+
dev_dataset = load_dataset("cais/mmlu", "all", split="dev") if few_shots > 0 else None
102+
103+
# Group dev examples by subject for few-shot prompt construction.
104+
dev_by_subject: dict = {}
105+
if dev_dataset is not None:
106+
for ex in dev_dataset:
107+
dev_by_subject.setdefault(ex["subject"], []).append(ex)
108+
109+
# Collect all examples, tracking subject membership for per-subject reporting.
110+
all_subjects_seen: list[str] = []
111+
all_prompts: list[str] = []
112+
all_labels: list[str] = []
113+
114+
# Count test examples per subject to apply the fraction cutoff correctly.
115+
subject_counts: dict[str, int] = {}
116+
for ex in test_dataset:
117+
subject_counts[ex["subject"]] = subject_counts.get(ex["subject"], 0) + 1
118+
119+
subject_idx: dict[str, int] = {}
120+
for ex in test_dataset:
121+
subj = ex["subject"]
122+
idx = subject_idx.get(subj, 0)
123+
if idx >= fraction * subject_counts[subj]:
124+
continue
125+
subject_idx[subj] = idx + 1
126+
prompt = _generate_prompt(ex, dev_by_subject.get(subj, []), few_shots=few_shots)
127+
all_prompts.append(prompt)
128+
all_labels.append(_CHOICES[ex["answer"]])
129+
all_subjects_seen.append(subj)
130+
131+
# Tokenize all prompts and sort by length to minimise padding waste within batches.
132+
encoded = [tokenizer(p, return_tensors="pt").input_ids[0] for p in all_prompts]
133+
lengths = [e.shape[0] for e in encoded]
134+
order = sorted(range(len(encoded)), key=lambda i: lengths[i], reverse=True)
135+
136+
sorted_encoded = [encoded[i] for i in order]
137+
sorted_lengths = [lengths[i] for i in order]
138+
139+
# Run inference in global batches.
140+
predictions: list[str] = [""] * len(encoded)
141+
n_batches = (len(sorted_encoded) + batch_size - 1) // batch_size
142+
pbar = tqdm(
143+
range(0, len(sorted_encoded), batch_size),
144+
total=n_batches,
145+
desc="MMLU",
146+
unit="batch",
147+
disable=not dist.is_master(),
148+
)
149+
for batch_start in pbar:
150+
batch_enc = sorted_encoded[batch_start : batch_start + batch_size]
151+
batch_len = sorted_lengths[batch_start : batch_start + batch_size]
152+
max_len = max(batch_len)
153+
154+
# Right-pad to max_len; causal mask means the last real token is unaffected by padding.
155+
padded = torch.zeros(len(batch_enc), max_len, dtype=torch.long)
156+
for i, (e, seq_len) in enumerate(zip(batch_enc, batch_len)):
157+
padded[i, :seq_len] = e
158+
159+
logits = megatron_prefill(model, padded.cuda()) # [B, max_len, vocab]
160+
161+
for i, seq_len in enumerate(batch_len):
162+
answer_logits = logits[i, seq_len - 1, choice_ids]
163+
predictions[order[batch_start + i]] = _CHOICES[answer_logits.argmax().item()]
164+
165+
examples_done = min(batch_start + batch_size, len(sorted_encoded))
166+
pbar.set_postfix(examples=f"{examples_done}/{len(sorted_encoded)}")
167+
168+
# Compute per-subject accuracy and overall average.
169+
subject_correct: dict[str, list[bool]] = {}
170+
for pred, label, subj in zip(predictions, all_labels, all_subjects_seen):
171+
subject_correct.setdefault(subj, []).append(pred == label)
172+
173+
all_correct = [pred == label for pred, label in zip(predictions, all_labels)]
174+
n_total = len(all_correct)
175+
avg = sum(all_correct) / n_total
176+
177+
if dist.is_master():
178+
print(f"\nMMLU ({fraction * 100}%, {few_shots}-shot) evaluation started...\n", flush=True)
109179
print("{:48} | (ACC) | Count/Total".format("Subject"), flush=True)
110180
print("{:48} | {:5} | {:11}".format("-" * 48, "-" * 5, "-" * 11), flush=True)
111-
112-
for subject in all_subjects:
113-
test_data = load_dataset("cais/mmlu", subject, split="test")
114-
dev_data = load_dataset("cais/mmlu", subject, split="dev")
115-
116-
correct = []
117-
for idx, test_example in enumerate(test_data):
118-
if idx > percentage * len(test_data):
119-
break
120-
prompt = _generate_prompt(test_example, dev_data, few_shots=few_shots)
121-
label = ["A", "B", "C", "D"][test_example["answer"]]
122-
tokens = tokenizer(prompt, return_tensors="pt")
123-
generated_ids = megatron_generate(
124-
model,
125-
tokens.input_ids.cuda(),
126-
osl=2,
127-
disable_tqdm=True,
128-
enable_kv_cache=enable_kv_cache,
129-
)
130-
predict = tokenizer.batch_decode(generated_ids)[0].strip()
131-
correct += [True] if predict.startswith(label) else [False]
132-
all_correct[subject] = correct
133-
134-
if torch.distributed.get_rank() == 0:
135-
print(
136-
f"{subject:48} | {sum(correct) / len(correct):.3f} | {sum(correct):5}/{len(correct):5}",
137-
flush=True,
138-
)
139-
140-
avg_correct = []
141-
142-
for subject, correct in all_correct.items():
143-
avg_correct += correct
144-
145-
if torch.distributed.get_rank() == 0:
181+
for subj in sorted(subject_correct):
182+
correct = subject_correct[subj]
183+
n = len(correct)
184+
print(f"{subj:48} | {sum(correct) / n:.3f} | {sum(correct):5}/{n:5}", flush=True)
146185
print("{:48} | {:5} | {:11}".format("-" * 48, "-" * 5, "-" * 11), flush=True)
147186
print(
148-
"{:48} | {:.3f} | {:5}/{:5}".format(
149-
"average", sum(avg_correct) / len(avg_correct), sum(avg_correct), len(avg_correct)
150-
),
187+
"{:48} | {:.3f} | {:5}/{:5}".format("average", avg, sum(all_correct), n_total),
151188
flush=True,
152189
)
153190

154-
return sum(avg_correct) / len(avg_correct)
191+
return avg

tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525

2626
def _test_megatron_generate_and_mmlu(rank, size):
2727
initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED)
28-
2928
model = get_mcore_qwen3_600m(tensor_model_parallel_size=size).cuda().eval()
30-
3129
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
3230

3331
messages = [
@@ -42,9 +40,9 @@ def _test_megatron_generate_and_mmlu(rank, size):
4240
model_inputs = tokenizer([text], return_tensors="pt").to(device="cuda")
4341
output_ids = megatron_generate(model, model_inputs["input_ids"])
4442
output_text = tokenizer.batch_decode(output_ids)
45-
print(output_text)
43+
print(rank, output_text)
4644

47-
assert megatron_mmlu(model, tokenizer) > 0.24
45+
assert 0.37 < megatron_mmlu(model, tokenizer, fraction=0.1, batch_size=16) < 0.38
4846

4947

5048
def test_megatron_generate_and_mmlu(dist_workers):

0 commit comments

Comments (0)