
fix(inference): size DynamicInferenceContext KV layer_map for non-uniform PP #4775

Draft

athitten wants to merge 2 commits into NVIDIA:main from athitten:fix/dynamic-inference-context-pp-layer-count

Conversation

athitten commented May 13, 2026

What does this PR do?

Fixes incorrect sizing of the dynamic inference KV cache layer map for pure Transformer models when pipeline parallelism does not give every rank the same number of decoder layers (e.g. embedding/loss accounted in the PP split, uneven first/last stages, or explicit PP layouts).

Problem

DynamicInferenceContext set

    num_attention_layers = model_config.num_layers // pipeline_model_parallel_size

and built layer_map as the identity mapping {0: 0, 1: 1, …, num_attention_layers - 1: num_attention_layers - 1}. That is correct only when the PP split is uniform.

When account_for_embedding_in_pipeline_split / account_for_loss_in_pipeline_split, num_layers_in_first_pipeline_stage / num_layers_in_last_pipeline_stage, or pipeline_model_parallel_layout is used, the actual number of transformer layers on a rank is given by get_num_layers_to_build() (the same helper TransformerBlock uses), and it can be larger than num_layers // pp_size on some ranks.
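A worked example with invented numbers (purely illustrative, not taken from the PR): num_layers=14 and PP=4, with both embedding and loss accounted for in the split:

```python
# Invented configuration, purely for illustration:
num_layers, pp_size = 14, 4

# account_for_embedding/loss_in_pipeline_split treat the embedding and the loss
# as one layer-equivalent each, so the split distributes 14 + 2 = 16 slots:
slots_per_stage = (num_layers + 2) // pp_size              # 4
layers_on_rank = [
    slots_per_stage - (rank == 0) - (rank == pp_size - 1)  # embedding / loss slots
    for rank in range(pp_size)
]                                                          # [3, 4, 4, 3]

old_sizing = num_layers // pp_size                         # 3 on every rank
```

Ranks 1 and 2 build four transformer layers each, but the old formula gives every rank a layer_map with only three entries.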

The model then runs attention for global layer indices that exceed len(layer_map) - 1, and append_key_value_cache does self.layer_map[layer_number - 1], raising a KeyError (e.g. missing key 5), often surfaced during warmup or larger-batch scheduling.
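Reduced to its essentials, the failing lookup looks like this (values chosen to mirror the KeyError: 5 in the traceback below; real sizes depend on the model and PP layout):

```python
layer_map = {i: i for i in range(5)}                  # sized from num_layers // pp_size
layer_number = 6                                      # 1-based layer number the rank actually runs
attention_layer_number = layer_map[layer_number - 1]  # KeyError: 5
```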

I hit this error while running evaluation on Qwen3-235B-A22B.

Solution

For the non-hybrid (pure Transformer) branch, compute

    num_attention_layers = get_num_layers_to_build(model_config, vp_stage=..., pp_rank=None)

with vp_stage derived from the virtual PP world size and rank when VPP > 1, matching TransformerBlock.

get_num_layers_to_build is imported inside the branch to limit import-time coupling with transformer_block (which itself pulls in the inference contexts).
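A minimal sketch of the fixed sizing, assembled from the description above and the diff excerpt quoted later in this thread; model_config stands for the context's TransformerConfig, and this is not a verbatim copy of the PR's diff:

```python
from megatron.core import parallel_state

# Imported inside the branch in the actual change, to limit import-time
# coupling with transformer_block (which itself pulls in inference contexts):
from megatron.core.transformer.transformer_block import get_num_layers_to_build

vp_stage = None
vp_sz = parallel_state.get_virtual_pipeline_model_parallel_world_size()
if vp_sz is not None and vp_sz > 1:
    vp_stage = parallel_state.get_virtual_pipeline_model_parallel_rank()

# Per-rank layer count matching TransformerBlock, instead of num_layers // pp_size:
num_attention_layers = get_num_layers_to_build(model_config, vp_stage=vp_stage, pp_rank=None)
layer_map = {i: i for i in range(num_attention_layers)}
```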

Before fix:

(ModelWorker pid=3343587)   File "/opt/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/inference/contexts/dynamic_context.py", line 971, in append_key_value_cache
(ModelWorker pid=3343587)     attention_layer_number = self.layer_map[layer_number - 1]
(ModelWorker pid=3343587)                              ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
(ModelWorker pid=3343587) KeyError: 5

After fix in this PR:

The error no longer occurs, and evaluation runs fine on Qwen3-235B-A22B.

⚠️ For major changes (either in lines of code or in impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact @mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see the Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or tag @mcore-oncall in a comment to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge conflicts are resolved and CI is passing.
Final Review may be declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch

The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

…form PP

Replace num_layers // pp_size with get_num_layers_to_build so the
attention layer count matches TransformerBlock on each PP rank.
Embedding/loss pipeline splits and uneven first/last stages otherwise
left layer_map too short, causing KeyError in append_key_value_cache.
copy-pr-bot (bot) commented May 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

athitten (Author) commented:

/ok to test 0825f05

Add TestDynamicContext.test_uneven_decoder_pp_layer_map_matches_get_num_layers_to_build
for num_layers_in_first/last_pipeline_stage with PP=2. Asserts DynamicInferenceContext
num_attention_layers and identity layer_map length match get_num_layers_to_build on
each rank, unlike uniform num_layers // pp_size.
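A sketch of the test's central assertions (hypothetical and abbreviated; the real test builds a full DynamicInferenceContext and the parallel state for PP=2):

```python
expected = get_num_layers_to_build(model_config, vp_stage=None, pp_rank=None)
assert context.num_attention_layers == expected
assert len(context.layer_map) == expected
# With uneven first/last stages, the old uniform formula disagrees on this rank:
assert expected != model_config.num_layers // pp_size
```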

athitten commented May 14, 2026

@NVIDIA/mcore-oncall can you help start CI on this PR? Thanks. Looks like I don't have permission.

Comment on lines +371 to +373:

    vp_sz = parallel_state.get_virtual_pipeline_model_parallel_world_size()
    if vp_sz is not None and vp_sz > 1:
        vp_stage = parallel_state.get_virtual_pipeline_model_parallel_rank()
A contributor commented:
Can we derive this from pg_collection.pp instead?
