Commit 0b2aabc
Fix PP inference correctness in megatron_generate and megatron_importer
Two bugs caused pipeline-parallel inference to produce garbage output:

1. megatron_generate/megatron_prefill used get_forward_backward_func() (the training pipeline scheduler), which is not designed for inference. Rewrote both functions to use explicit P2P communication via recv_from_prev_pipeline_rank_ / send_to_next_pipeline_rank, matching the pattern from run_mcore_inference.
2. import_mcore_gpt_from_hf loads HF weights into stage 0's embedding but never updates the output_layer on the last PP stage when share_embeddings_and_output_weights=True. After import, call model.setup_embeddings_and_output_layer() to re-run the all-reduce that syncs the output layer from stage 0 to the last stage.

Also parametrize the megatron_generate test to cover both TP and PP.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 67b8313 commit 0b2aabc
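The P2P schedule described in point 1 hands each stage's hidden state directly to the next rank instead of going through the training scheduler. A minimal single-process sketch of that handoff, using queues as stand-ins for the real Megatron primitives recv_from_prev_pipeline_rank_ / send_to_next_pipeline_rank (the stage functions and run_pipeline_inference are hypothetical stand-ins, not the actual API):

```python
from queue import Queue


def run_pipeline_inference(stage_fns, tokens):
    """Run `tokens` through pipeline stages sequentially via P2P handoff."""
    # One link per rank boundary; the first link carries the input tokens.
    links = [Queue() for _ in range(len(stage_fns) + 1)]
    links[0].put(tokens)
    for rank, fn in enumerate(stage_fns):
        hidden = links[rank].get()       # stand-in for recv_from_prev_pipeline_rank_
        links[rank + 1].put(fn(hidden))  # stand-in for send_to_next_pipeline_rank
    return links[-1].get()               # the last stage holds the final output


# Example: each "stage" just adds 1, mimicking a 3-stage pipeline.
out = run_pipeline_inference([lambda h: h + 1] * 3, 0)
```

In the real fix, each rank runs only its own stage and the queues are replaced by blocking point-to-point sends/receives between adjacent pipeline ranks.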

File tree

6 files changed: +208 −177 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ Changelog

 **Bug Fixes**

+- Fix Megatron utility functions for generation (with pipeline parallelism) and MMLU score evaluation (10x speedup).
 - Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this.
 - Fix TRT support for remote autotuning in ONNX Autotune from 10.16+ to 10.15+ and fix TRT versioning check to the ``trtexec`` version instead of the TRT Python API when using ``trtexec`` backend.

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 11 additions & 0 deletions
@@ -747,6 +747,17 @@ def _import_state_dict(self):
         if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
             self.rules["output_layer"](model.output_layer)

+        # For PP with shared embedding/output weights, re-sync the output layer on the last
+        # pipeline stage from stage 0's (now HF-loaded) embedding. At model init,
+        # setup_embeddings_and_output_layer() zeros out the last stage's weight and all-reduces
+        # from stage 0. After importing HF weights into stage 0's embedding, that sync is stale,
+        # so we re-run it here.
+        if (
+            model.share_embeddings_and_output_weights
+            and model.config.pipeline_model_parallel_size > 1
+        ):
+            model.setup_embeddings_and_output_layer()
+
         # MTP
         if hasattr(model, "mtp"):
             layer_pbar.set_description("Importing MTP")
modelopt/torch/utils/logging.py

Lines changed: 2 additions & 1 deletion
@@ -105,8 +105,9 @@ def no_stdout():


 def print_rank_0(*args, **kwargs):
     """Prints only on the master process."""
+    kwargs.setdefault("flush", True)
     if dist.is_master():
-        print(*args, **kwargs, flush=True)
+        print(*args, **kwargs)


 def warn_rank_0(message, *args, **kwargs):
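Moving the flush default into kwargs.setdefault keeps flushing on by default while letting callers pass flush=False; the old hard-coded print(*args, **kwargs, flush=True) would raise TypeError ("got multiple values for keyword argument 'flush'") if a caller supplied flush themselves. A self-contained sketch with a fake dist module standing in for the real one:

```python
import io


class _FakeDist:
    """Stand-in for the real distributed helper; always the master rank here."""

    @staticmethod
    def is_master():
        return True


dist = _FakeDist()


def print_rank_0(*args, **kwargs):
    """Prints only on the master process."""
    kwargs.setdefault("flush", True)  # default on, but callers may override
    if dist.is_master():
        print(*args, **kwargs)


buf = io.StringIO()
print_rank_0("hello", file=buf)               # flush defaults to True
print_rank_0("quiet", file=buf, flush=False)  # now legal; old code would raise
```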
