
Commit c11b8bf

Fix PP inference correctness in megatron_generate and megatron_importer
Two bugs caused pipeline-parallel inference to produce garbage output:

1. megatron_generate/megatron_prefill used get_forward_backward_func() (the training pipeline scheduler), which is not designed for inference. Rewrote both functions to use explicit P2P communication via recv_from_prev_pipeline_rank_ / send_to_next_pipeline_rank, matching the pattern from run_mcore_inference.
2. import_mcore_gpt_from_hf loads HF weights into stage 0's embedding but never updates the output_layer on the last PP stage when share_embeddings_and_output_weights=True. After import, call model.setup_embeddings_and_output_layer() to re-run the all-reduce that syncs the output layer from stage 0 to the last stage.

Also parametrize the megatron_generate test to cover both TP and PP.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
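The per-rank control flow the rewrite adopts can be sketched as a single-process simulation: every stage except the first receives hidden states from its predecessor, runs its local layers, and every stage except the last sends the result onward. The `run_pipeline` function, the toy stage lambdas, and the `wire` variable below are illustrative stand-ins, not Megatron-Core APIs.

```python
def run_pipeline(stages, tokens):
    """Run tokens through the stages sequentially, mimicking per-rank P2P flow."""
    wire = None  # stands in for the point-to-point channel between adjacent ranks
    output = None
    for rank, stage in enumerate(stages):
        is_first = rank == 0
        is_last = rank == len(stages) - 1
        # First stage consumes the tokens; later stages "recv" from the previous rank
        # (the role recv_from_prev_pipeline_rank_ plays in the real code).
        hidden = stage(tokens if is_first else wire)
        if is_last:
            output = hidden  # last stage produces the logits; nothing to send
        else:
            wire = hidden    # stands in for send_to_next_pipeline_rank
    return output

# Three toy "stages": embedding, middle layers, final layers + output head.
stages = [
    lambda toks: [t * 2 for t in toks],
    lambda h: [x + 1 for x in h],
    lambda h: [x * 10 for x in h],
]
print(run_pipeline(stages, [1, 2, 3]))  # [30, 50, 70]
```

The key property the real fix restores is the same one this loop has: only adjacent stages exchange activations, with no training-style forward/backward scheduling involved.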
1 parent 303e429 commit c11b8bf

File tree

5 files changed: +211 −176 lines


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ Changelog
 
 **Bug Fixes**
 
+- Fix Megatron utility functions for generation and MMLU score evaluation.
 - Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this.
 - Fix TRT support for remote autotuning in ONNX Autotune from 10.16+ to 10.15+ and fix TRT versioning check to the ``trtexec`` version instead of the TRT Python API when using ``trtexec`` backend.

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 11 additions & 0 deletions
@@ -747,6 +747,17 @@ def _import_state_dict(self):
         if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
             self.rules["output_layer"](model.output_layer)
 
+        # For PP with shared embedding/output weights, re-sync the output layer on the last
+        # pipeline stage from stage 0's (now HF-loaded) embedding. At model init,
+        # setup_embeddings_and_output_layer() zeros out the last stage's weight and all-reduces
+        # from stage 0. After importing HF weights into stage 0's embedding, that sync is stale,
+        # so we re-run it here.
+        if (
+            model.share_embeddings_and_output_weights
+            and model.config.pipeline_model_parallel_size > 1
+        ):
+            model.setup_embeddings_and_output_layer()
+
         # MTP
         if hasattr(model, "mtp"):
             layer_pbar.set_description("Importing MTP")