|
18 | 18 | while skipping pruning of num_attention_heads using following defaults: |
19 | 19 | 1024 samples from nemotron-post-training-dataset-v2 for calibration, |
20 | 20 | at-most 20% depth (num_layers) and 40% width is pruned per prunable hparam (hidden_size, ffn_hidden_size, ...), |
21 | | - top-10 candidates are evaluated for MMLU score (10% sampled data) to select the best model. |
| 21 | + top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. |
22 | 22 |
|
23 | 23 | torchrun --nproc_per_node 2 prune_minitron.py \ |
24 | 24 | --hf_model_name_or_path Qwen/Qwen3-8B \ |
@@ -140,11 +140,11 @@ def get_args() -> argparse.Namespace: |
140 | 140 | parser.add_argument( |
141 | 141 | "--prune_score_func", |
142 | 142 | type=str, |
143 | | - default="mmlu_10pct", |
| 143 | + default="mmlu_5pct", |
144 | 144 | help=( |
145 | 145 | "Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. " |
146 | 146 | "Format: mmlu_<N>pct where <N> is the percentage of MMLU data to sample per subject " |
147 | | - "(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)." |
| 147 | + "(e.g. mmlu_5pct for 5%, mmlu_100pct for full eval)." |
148 | 148 | ), |
149 | 149 | ) |
150 | 150 | parser.add_argument( |
@@ -299,17 +299,13 @@ def main(args: argparse.Namespace): |
299 | 299 | match = re.fullmatch(r"mmlu_(\d+)pct", args.prune_score_func) |
300 | 300 | if not match: |
301 | 301 | raise ValueError( |
302 | | - f"Invalid score function: {args.prune_score_func}. " |
303 | | - "Expected format: mmlu_<N>pct (e.g. mmlu_10pct)" |
| 302 | + f"Invalid score function: {args.prune_score_func}. Expected format: mmlu_<N>pct (e.g. mmlu_5pct)" |
304 | 303 | ) |
305 | | - mmlu_pct = int(match.group(1)) |
306 | | - if not 0 < mmlu_pct <= 100: |
307 | | - raise ValueError("--prune_score_func percentage must be in the range [1, 100].") |
308 | | - _mmlu_frac = mmlu_pct / 100.0 |
| 304 | + mmlu_frac = float(match.group(1)) / 100.0 |
309 | 305 |
|
310 | 306 | def score_func(m): |
311 | 307 | return megatron_mmlu( |
312 | | - m, tokenizer, few_shots=0, fraction=_mmlu_frac, batch_size=args.calib_mbs |
| 308 | + m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=args.calib_mbs |
313 | 309 | ) |
314 | 310 |
|
315 | 311 | pruning_config["score_func"] = score_func |
|
0 commit comments