Gramnarayan/pr15063 mtp adp nan debug#256

Draft
govind-ramnarayan wants to merge 9 commits into eg/superv3-on-upstream from gramnarayan/pr15063-mtp-adp-nan-debug

Conversation

@govind-ramnarayan

@coderabbitai summary

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

MrGeva and others added 9 commits June 12, 2026 11:10
…mization

Rebases the SuperV3-MTP attention-DP optimization onto current upstream/main
(which now carries gk's MoE all-to-all stateful cache NVIDIA#13718/NVIDIA#13723 and gagam's
SSM-replay PR).

MoE all-to-all per-rank token budget (runtime_max_tokens_per_rank):
- Replace the per-iteration cross-rank-max read (an int(batch_info_host[14]
  .item()) on a pinned-host tensor, fed by a per-forward tp_allgather) with a
  sync-free shape-based budget via _hybrid_runtime_max_tokens_per_rank: under
  cuda-graph capture/warm-up the budget is x.shape[0] (uniform across DP ranks
  because maybe_pad_for_cuda_graph pads every rank to a common cg_batch_size and
  MTP tokens-per-seq is uniform), gated so the tight budget is only taken while
  the MoE-GEMM row count stays in the fast small-M trtllm-gen tactic region
  (https://nvbugspro.nvidia.com/bug/6247543); in eager (prefill or bypass) it
  falls back to the static max_num_tokens every rank computes identically. No
  per-layer host read.
- Drop the now-dead batch_info_host plumbing for the DP-max slot: the slot-14
  (max_dp_num_tokens) storage and update/get accessors in BatchInfo
  (_NUM_ELEMENTS 15->14), the pre-forward tp_allgather + update in the AD shim,
  and the batch_info_host injection into the MoE op in both the dict-based
  (sharding.py) and IR-based (sharding_ir.py) sharding paths, plus the op
  signatures in trtllm_moe.py / torch_moe.py.
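
The budget selection described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the function name, its parameters, and the `small_m_limit` gate are assumptions introduced here, since the commit message does not give the actual threshold or signature of `_hybrid_runtime_max_tokens_per_rank`.

```python
def runtime_max_tokens_per_rank(
    num_rows: int,
    in_cuda_graph: bool,
    max_num_tokens: int,
    small_m_limit: int,
) -> int:
    """Pick a per-rank MoE all-to-all token budget without a host read.

    Hypothetical stand-in for _hybrid_runtime_max_tokens_per_rank; the
    parameter names and the small_m_limit gate are assumptions.
    """
    # Under CUDA-graph capture/warm-up every DP rank is padded to the same
    # batch size, so the local row count is already uniform across ranks.
    if in_cuda_graph and num_rows <= small_m_limit:
        return num_rows
    # Eager (prefill or bypass), or a row count outside the small-M region:
    # fall back to the static bound that every rank computes identically.
    return max_num_tokens
```

Because both branches depend only on values every rank already agrees on, no cross-rank gather or `.item()` call is needed to keep the ranks consistent.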

MTP + attention-DP correctness:
- Keep the draft-EP revert under attention-DP in sharding (replicate the draft
  model's MoE rather than EP-sharding it) to avoid the shared-workspace
  corruption that hangs the all-to-all at concurrency.
- Forward max_draft_len / max_total_draft_tokens to PyExecutor so the
  attention-DP dummy request is seeded with py_draft_tokens=[1]*max_total_draft_tokens
  instead of [], preventing it from classifying as decode and tripping the eagle
  wrapper's `assert num_decode == 0` under MTP + attention_dp.

Mixed-mode cuda-graph bypass uses upstream's process-wide BypassCapturedGraphs()
context manager (cuda_graph_state.in_bypass()), keeping all ranks consistent when
one is in prefill.

Note: the attention-DP request balancer is intentionally NOT included here; it
showed no throughput gain in benchmarking and will be evaluated in a follow-up.

Tests:
- Add NVFP4 SuperV3-MTP attn-DP perf-sanity config + post-merge enrollment.
- test_mtp / test_accuracy NVFP4 coverage resolved to upstream's parametrization.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
…rV3-MTP config

Add attention_dp_config (enable_balance=true, batching_wait_iters=10,
timeout_iters=16) to super_v3_mtp.yaml. enable_attention_dp is already set
in this config, so the balancer (PyExecutor._balance_adp_requests) is active:
it balances per-rank decode load so the busiest rank does not gate every MoE
all-to-all collective step, which otherwise inflates inter-token latency.
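
Why balancing helps can be shown with a toy model: a collective step costs as much as its busiest rank. The greedy policy below is only an illustration and is an assumption of this sketch; the commit message does not describe the policy used by `PyExecutor._balance_adp_requests`.

```python
import heapq
from typing import List


def step_cost(rank_loads: List[int]) -> int:
    # A collective step finishes only when the busiest rank does.
    return max(rank_loads)


def greedy_balance(request_sizes: List[int], num_ranks: int) -> List[int]:
    """Assign each request to the currently least-loaded rank (toy policy)."""
    heap = [(0, rank) for rank in range(num_ranks)]
    heapq.heapify(heap)
    loads = [0] * num_ranks
    for size in sorted(request_sizes, reverse=True):
        load, rank = heapq.heappop(heap)
        loads[rank] = load + size
        heapq.heappush(heap, (loads[rank], rank))
    return loads


unbalanced = [6, 1, 1, 0]                        # one rank holds most decode requests
balanced = greedy_balance([1] * 8, num_ranks=4)  # same 8 requests, spread evenly
```

In this toy case the step cost drops from `step_cost(unbalanced)` to `step_cost(balanced)` for the same total work, which is the effect the balancer aims for on inter-token latency.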

Measured +64-111% output-token throughput at concurrency 128/256 on
SuperV3-MTP NVFP4 attention-DP serving (SPEED-Bench). Trade-off: defers
prefill, raising TTFT at high concurrency.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
…er config

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
…on-MTP) config

Enable attention-DP (enable_attention_dp) and the request balancer
(attention_dp_config: enable_balance) in the non-MTP super_v3.yaml,
matching super_v3_mtp.yaml.

Measured on NVFP4 120B, 4xGB200, SPEED-Bench (attention-DP+balancer vs
no-ADP tensor-parallel baseline): output-token throughput +14/+39/+76/+80%
at concurrency 1/32/128/256, with inter-token latency roughly halved at
high concurrency. Trade-off: higher TTFT at high concurrency (balancer
prefill deferral).

Also fixes a multi-GPU deadlock: the prior config (EP MoE without
attention-DP) hangs in decode at world_size>1; enabling attention-DP
routes the MoE via all-to-all and runs cleanly. No-op for single-GPU
(balancer requires tp_size>1).

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
…perV3 (non-MTP) config"

This reverts commit ec165dd.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
PythonMambaCacheManager allocated conv_states/ssm_states with torch.empty
(uninitialized HBM) instead of torch.zeros. Under attention-DP an idle/short
rank decodes an extra generation slot without a prefill, and the causal-conv
op reads that slot's uninitialized conv state. Whether the garbage reads as
NaN vs finite depends on the allocation layout, which ssm_replay=true shifts
(replay buffers vs intermediate_ssm_state_cache). The resulting NaN is
propagated by the conv into the residual stream, latched by the SSM replay
path into the persistent recurrent state, and spread across ranks by the MoE
all-to-all, collapsing GSM8K accuracy on SuperV3 MTP + attention-DP. With
zero-init, test_mtp[nvfp4_ws4_80gb-trtllm] recovers from 42.95 to 92.04 (PASS)
and the eager run from 44.92 to 92.23.
Zeros are a valid empty Mamba state, so zero-init is correct and robust
regardless of allocation layout or which slots are decoded un-prefilled.
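
The propagation chain can be reproduced in a few lines of plain Python. This is a toy sketch under stated assumptions: the scalar recurrence, the kernel weights, and the function names are all illustrative and do not correspond to the actual causal-conv or SSM kernels, and uninitialized memory is modeled as a single NaN.

```python
import math
from typing import List, Tuple


def conv_decode_step(
    conv_state: List[float], x: float, weights: List[float]
) -> Tuple[List[float], float]:
    """One causal-conv decode step: shift in x, then dot with the kernel."""
    window = conv_state[1:] + [x]
    y = sum(w * v for w, v in zip(weights, window))
    return window, y


def ssm_step(h: float, u: float, decay: float = 0.9) -> float:
    """Toy scalar recurrence standing in for the persistent SSM state."""
    return decay * h + u


weights = [0.25, 0.25, 0.25, 0.25]

# Zero-initialized slot: a valid empty state, so the output stays finite.
state, y = conv_decode_step([0.0] * 4, 1.0, weights)
h_clean = ssm_step(0.0, y)

# Slot holding garbage (modeled as NaN): one NaN poisons the conv output...
state, y_bad = conv_decode_step([0.0, 0.0, math.nan, 0.0], 1.0, weights)
# ...and the recurrence keeps it forever, even after clean inputs.
h_bad = ssm_step(0.0, y_bad)
h_bad = ssm_step(h_bad, 1.0)
```

Once `h_bad` is NaN, no later finite input can clear it, which matches the "latched into the persistent recurrent state" behavior described above.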

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
multi_stream_moe was disabled because it produced NaNs in the target's sampled
logits. Root cause was the uninitialized Mamba conv/ssm cache pools (torch.empty)
read by an un-prefilled attention-DP generation slot, fixed by zero-init in
PythonMambaCacheManager (commit 323e09c). With that fix ms_moe is accuracy-
valid: test_mtp[nvfp4_ws4_80gb-trtllm] GSM8K 91.89 (PASS). It adds ~+18% (c128)
and ~+56% (c256) serving throughput on the SPEED-Bench sweep.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Disable piecewise CUDA graph capture in the SuperV3 MTP registry config because MTP does not support piecewise capture yet.

For the diagnostic branch, initialize the persistent Mamba conv and SSM state pools to NaNs instead of zeros so uninitialized state reads would surface during the MTP + ADP path.
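
The idea behind NaN-poisoning can be sketched with a toy pool. The class below is hypothetical and written in plain Python for illustration; it is not `PythonMambaCacheManager` and it models only the "was this slot written before it was read" check.

```python
import math
from typing import List


class PoisonedPool:
    """Toy state pool whose slots start as NaN so unwritten reads are visible."""

    def __init__(self, num_slots: int, width: int) -> None:
        self.slots: List[List[float]] = [
            [math.nan] * width for _ in range(num_slots)
        ]

    def prefill(self, slot: int, values: List[float]) -> None:
        # A prefill overwrites the poison with real state.
        self.slots[slot] = list(values)

    def read_is_poisoned(self, slot: int) -> bool:
        # Any remaining NaN means the slot is being read before it was written.
        return any(math.isnan(v) for v in self.slots[slot])


pool = PoisonedPool(num_slots=2, width=3)
pool.prefill(0, [0.0, 0.0, 0.0])  # slot 0 is prefilled, slot 1 is never written
```

Unlike zero-init, which hides a read of never-written state behind a valid empty value, the poison makes such a read fail loudly, so a passing accuracy run indicates that no such read reached the outputs.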

Observed result from the in-repo accuracy target tests/integration/defs/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[fp8_ws4_80gb-trtllm]: passed on 4 GPUs with torch-cudagraph, ADP enabled, and piecewise capture disabled. GSM8K average accuracy was 92.27; evaluated accuracy was 92.267 against an 89.407 threshold; spec-dec acceptance was 53.70% (16360/30468 tokens across 1001 speculative iterations).

Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
@govind-ramnarayan

Author

Here is the simple local diff for the current branch HEAD (8ef9470ad852) against its parent (HEAD^). The GitHub compare is noisy because of the rebase.

Validation from the rerun of tests/integration/defs/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[fp8_ws4_80gb-trtllm]:

  • Result: PASSED in 467.00s
  • ADP enabled: enable_attention_dp=True
  • Main graph sharding: EP=40
  • Draft graph sharding: EP=0, with EP reverted/replicated under attention-DP
  • Torch cudagraph captured batch sizes: 8, 4, 2, 1
  • GSM8K average accuracy: 92.68
  • Evaluated accuracy: 92.684, threshold 89.407
  • Spec-dec acceptance: 53.28% (16744/31428 tokens, 1001 speculative iterations)
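
The reported percentages follow directly from the token counts. The short check below assumes the numerator in each pair is the accepted token count and the denominator is the drafted token count; the helper name is illustrative.

```python
def acceptance_rate(accepted: int, drafted: int) -> float:
    """Percentage of tokens accepted, rounded to two decimals."""
    return round(100.0 * accepted / drafted, 2)


# Rerun reported in this comment.
rerun = acceptance_rate(16744, 31428)
# Earlier run reported in the commit message.
earlier = acceptance_rate(16360, 30468)
# Headroom of the rerun's evaluated accuracy over the threshold.
margin = round(92.684 - 89.407, 3)
```

Both runs land within about half a percentage point of each other on acceptance, and the rerun clears the accuracy threshold by roughly 3.3 points.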
diff --git a/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml b/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml
index b46ff60e6f..04ec1fc316 100644
--- a/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml
+++ b/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml
@@ -21,6 +21,9 @@ attention_dp_config:
   batching_wait_iters: 10
   timeout_iters: 16
 transforms:
+  compile_model:
+    # Piecewise CUDA graph capture is disabled for MTP until MTP supports it.
+    piecewise_enabled: false
   detect_sharding:
     allreduce_strategy: NCCL
     enable_attention_dp: true
diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
index 950b908160..55be655918 100644
--- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -459,23 +459,16 @@ class PythonMambaCacheManager(BaseResourceManager):
         ssm_state_shape = (nheads, head_dim, d_state)
 
         # create mamba conv and ssm states
-        # Zero-initialize (not torch.empty): an un-prefilled cache slot (e.g. an
-        # attention-DP idle/dummy generation request that is decoded without a
-        # prefill) is read by the causal-conv / SSM ops before anything writes it.
-        # Uninitialized HBM can be NaN/Inf, which the causal-conv then propagates
-        # into the residual stream (and the SSM replay path latches into the
-        # persistent recurrent state, spreading it across ranks via the MoE
-        # all-to-all). Whether the garbage is NaN vs finite depends on the
-        # allocation layout, which is exactly why the failure only surfaced under
-        # {ssm_replay=true × attention-DP × MTP}. Zeros are a valid empty state.
-        conv_states = torch.zeros(
+        conv_states = torch.full(
             size=(num_local_layers, max_batch_size) + conv_state_shape,
+            fill_value=float("nan"),
             dtype=dtype,
             device=device,
         )
 
-        ssm_states = torch.zeros(
+        ssm_states = torch.full(
             size=(num_local_layers, max_batch_size) + ssm_state_shape,
+            fill_value=float("nan"),
             dtype=self.mamba_ssm_cache_dtype,
             device=device,
         )

@MrGeva MrGeva force-pushed the eg/superv3-on-upstream branch from c76eedc to 033e80d Compare June 17, 2026 09:28