Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Changelog

**Bug Fixes**

- Fix Megatron utility functions for generation (with pipeline parallelism) and ~10x speedup in MMLU score evaluation (by batching prefill passes).
- Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this.
- Fix TRT support for remote autotuning in ONNX Autotune from 10.16+ to 10.15+ and fix TRT versioning check to the ``trtexec`` version instead of the TRT Python API when using ``trtexec`` backend.

Expand Down
18 changes: 8 additions & 10 deletions examples/megatron_bridge/prune_minitron.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
while skipping pruning of num_attention_heads using the following defaults:
1024 samples from nemotron-post-training-dataset-v2 for calibration,
at-most 20% depth (num_layers) and 40% width is pruned per prunable hparam (hidden_size, ffn_hidden_size, ...),
top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.
top-10 candidates are evaluated for MMLU score (10% sampled data) to select the best model.

torchrun --nproc_per_node 2 prune_minitron.py \
--hf_model_name_or_path Qwen/Qwen3-8B \
Expand Down Expand Up @@ -140,11 +140,11 @@ def get_args() -> argparse.Namespace:
parser.add_argument(
"--prune_score_func",
type=str,
default="mmlu_5pct",
default="mmlu_10pct",
help=(
"Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. "
"Format: mmlu_<N>pct where <N> is the percentage of MMLU data to sample per subject "
"(e.g. mmlu_5pct for 5%, mmlu_100pct for full eval)."
"(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)."
),
)
parser.add_argument(
Expand Down Expand Up @@ -299,16 +299,14 @@ def main(args: argparse.Namespace):
match = re.fullmatch(r"mmlu_(\d+)pct", args.prune_score_func)
if not match:
raise ValueError(
f"Invalid score function: {args.prune_score_func}. "
"Expected format: mmlu_<N>pct (e.g. mmlu_5pct)"
f"Invalid score function: {args.prune_score_func}. Expected format: mmlu_<N>pct (e.g. mmlu_10pct)"
)
mmlu_pct = int(match.group(1))
if not 0 < mmlu_pct <= 100:
raise ValueError("--prune_score_func percentage must be in the range [1, 100].")
_mmlu_pct = mmlu_pct / 100.0
mmlu_frac = float(match.group(1)) / 100.0

def score_func(m):
return megatron_mmlu(m, tokenizer, percentage=_mmlu_pct)
return megatron_mmlu(
m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=args.calib_mbs
)

pruning_config["score_func"] = score_func
pruning_config["max_width_pruning"] = args.max_width_pruning
Expand Down
4 changes: 2 additions & 2 deletions examples/pruning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ This mode can be useful when you don't know the exact dimensions you want to pru
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu

def score_func(m):
return megatron_mmlu(m, tokenizer, percentage=0.05) # 5% sampled data for faster eval
return megatron_mmlu(m, tokenizer, fraction=0.1, batch_size=4) # 10% sampled data for faster eval

# Specify target parameter count and configure the auto pruning algorithm
# Save minitron scores at checkpoint so we can resume pruning without running the forward loop again
Expand All @@ -147,7 +147,7 @@ mtp.prune(...)

1. **Importance Scoring**: Same as manual pruning - computes activation magnitudes for all parameters (takes ~5 minutes for an 8B model)
2. **Search Space Construction**: Generates a search space of possible architectures based on the search space config and other configs (`max_width_pruning`, `max_depth_pruning`, `hparams_to_skip`)
3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~10 mins per candidate for an 8B model pruning)
3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~5 min per candidate for an 8B model MMLU score with 10% sampled data)
4. **Best Architecture Selection**: Returns the architecture (best `export_config`) with the highest actual score from the top-K evaluated architectures
5. **Weight Slicing**: Slices the model weights according to the best pruned architecture found

Expand Down
11 changes: 11 additions & 0 deletions modelopt/torch/export/plugins/megatron_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,17 @@ def _import_state_dict(self):
if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
self.rules["output_layer"](model.output_layer)

# For PP with shared embedding/output weights, re-sync the output layer on the last
# pipeline stage from stage 0's (now HF-loaded) embedding. At model init,
# setup_embeddings_and_output_layer() zeros out the last stage's weight and all-reduces
# from stage 0. After importing HF weights into stage 0's embedding, that sync is stale,
# so we re-run it here.
if (
model.share_embeddings_and_output_weights
and model.config.pipeline_model_parallel_size > 1
):
model.setup_embeddings_and_output_layer()

# MTP
if hasattr(model, "mtp"):
layer_pbar.set_description("Importing MTP")
Expand Down
3 changes: 3 additions & 0 deletions modelopt/torch/prune/plugins/mcore_minitron.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ class CandidateSubnet:
score: float | None


torch.serialization.add_safe_globals([CandidateSubnet])


class MCoreMinitronSearcher(BaseSearcher):
"""Searcher for Minitron pruning algorithm.

Expand Down
3 changes: 2 additions & 1 deletion modelopt/torch/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def no_stdout():

def print_rank_0(*args, **kwargs):
    """Prints only on the master process.

    Thin wrapper around :func:`print` that is a no-op on non-master ranks
    (as reported by ``dist.is_master()``). Accepts the same ``*args`` and
    ``**kwargs`` as ``print``.
    """
    # Flush by default so rank-0 progress output appears promptly in
    # distributed runs; callers may still pass flush=False explicitly.
    kwargs.setdefault("flush", True)
    if dist.is_master():
        # NOTE(review): the next two lines are the pre-/post-diff variants from
        # the scraped diff view. Only the second belongs in the merged file:
        # the first would pass flush twice (TypeError) now that flush is
        # injected via setdefault above. Confirm against the merged source.
        print(*args, **kwargs, flush=True)
        print(*args, **kwargs)


def warn_rank_0(message, *args, **kwargs):
Expand Down
Loading
Loading