Gramnarayan/pr15063 mtp adp nan debug#256
Draft
govind-ramnarayan wants to merge 9 commits into
…mization

Rebases the SuperV3-MTP attention-DP optimization onto current upstream/main (which now carries gk's MoE all-to-all stateful cache NVIDIA#13718/NVIDIA#13723 and gagam's SSM-replay PR).

MoE all-to-all per-rank token budget (runtime_max_tokens_per_rank):
- Replace the per-iteration cross-rank-max read (an int(batch_info_host[14].item()) on a pinned-host tensor, fed by a per-forward tp_allgather) with a sync-free shape-based budget via _hybrid_runtime_max_tokens_per_rank: under cuda-graph capture/warm-up the budget is x.shape[0] (uniform across DP ranks because maybe_pad_for_cuda_graph pads every rank to a common cg_batch_size and MTP tokens-per-seq is uniform), gated so the tight budget is only taken while the MoE-GEMM row count stays in the fast small-M trtllm-gen tactic region (https://nvbugspro.nvidia.com/bug/6247543); in eager (prefill or bypass) it falls back to the static max_num_tokens every rank computes identically. No per-layer host read.
- Drop the now-dead batch_info_host plumbing for the DP-max slot: the slot-14 (max_dp_num_tokens) storage and update/get accessors in BatchInfo (_NUM_ELEMENTS 15->14), the pre-forward tp_allgather + update in the AD shim, and the batch_info_host injection into the MoE op in both the dict-based (sharding.py) and IR-based (sharding_ir.py) sharding paths, plus the op signatures in trtllm_moe.py / torch_moe.py.

MTP + attention-DP correctness:
- Keep the draft-EP revert under attention-DP in sharding (replicate the draft model's MoE rather than EP-sharding it) to avoid the shared-workspace corruption that hangs the all-to-all at concurrency.
- Forward max_draft_len / max_total_draft_tokens to PyExecutor so the attention-DP dummy request is seeded with py_draft_tokens=[1]*max_total_draft_tokens instead of [], preventing it from classifying as decode and tripping the eagle wrapper's `assert num_decode == 0` under MTP + attention_dp.

Mixed-mode cuda-graph bypass uses upstream's process-wide BypassCapturedGraphs() context manager (cuda_graph_state.in_bypass()), keeping all ranks consistent when one is in prefill.

Note: the attention-DP request balancer is intentionally NOT included here; it showed no throughput gain in benchmarking and will be evaluated in a follow-up.

Tests:
- Add NVFP4 SuperV3-MTP attn-DP perf-sanity config + post-merge enrollment.
- test_mtp / test_accuracy NVFP4 coverage resolved to upstream's parametrization.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
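For reviewers, the budget selection described in the commit message above can be sketched roughly as follows. This is only an illustration of the decision, not the code in _hybrid_runtime_max_tokens_per_rank: the function name `hybrid_budget`, the `in_cuda_graph` flag, and the `small_m_limit` threshold are all hypothetical stand-ins, and the real gate on the small-M tactic region may be computed differently.

```python
def hybrid_budget(num_rows: int, max_num_tokens: int,
                  in_cuda_graph: bool, small_m_limit: int) -> int:
    """Pick a per-rank token budget from local information only.

    num_rows       -- x.shape[0] on this rank (already padded to the common
                      cuda-graph batch size, so identical on every DP rank)
    max_num_tokens -- static upper bound that every rank computes identically
    in_cuda_graph  -- hypothetical flag: True under capture/warm-up/replay
    small_m_limit  -- hypothetical threshold standing in for the
                      "fast small-M tactic region" gate
    """
    # Tight budget: only under cuda graphs, and only while the row count
    # stays small enough that the fast tactic would still be selected.
    if in_cuda_graph and num_rows <= small_m_limit:
        return num_rows
    # Eager (prefill or bypass) or large row count: static budget.
    return max_num_tokens


# Every branch depends only on values each rank already holds, so no
# host read or cross-rank collective is needed to agree on the result.
decode_budget = hybrid_budget(8, 8192, in_cuda_graph=True, small_m_limit=64)
eager_budget = hybrid_budget(8, 8192, in_cuda_graph=False, small_m_limit=64)
```

The point of the sketch is that both branches are functions of inputs that are uniform across ranks, which is what lets the per-forward tp_allgather go away.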
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
…rV3-MTP config Add attention_dp_config (enable_balance=true, batching_wait_iters=10, timeout_iters=16) to super_v3_mtp.yaml. enable_attention_dp is already set in this config, so the balancer (PyExecutor._balance_adp_requests) is active: it balances per-rank decode load so the busiest rank does not gate every MoE all-to-all collective step, which otherwise inflates inter-token latency. Measured +64-111% output-token throughput at concurrency 128/256 on SuperV3-MTP NVFP4 attention-DP serving (SPEED-Bench). Trade-off: defers prefill, raising TTFT at high concurrency. Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
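As a rough intuition for why balancing helps, here is a toy model in which the cost of one collective step is set by the most loaded rank. It is a sketch under that assumption only; `step_cost` and `even_split` are hypothetical helpers and do not reflect how PyExecutor._balance_adp_requests actually assigns requests.

```python
def step_cost(loads: list[int]) -> int:
    """Toy cost model: a collective step finishes when the busiest rank does."""
    return max(loads)


def even_split(loads: list[int]) -> list[int]:
    """Redistribute the same total decode load as evenly as possible."""
    total, n = sum(loads), len(loads)
    base, extra = divmod(total, n)
    # The first `extra` ranks take one more request than the rest.
    return [base + 1] * extra + [base] * (n - extra)


unbalanced = [12, 1, 1, 2]          # one rank holds most decode requests
balanced = even_split(unbalanced)   # same total, spread across 4 ranks

cost_before = step_cost(unbalanced)
cost_after = step_cost(balanced)
```

In this toy model the total work is unchanged; only the maximum drops, which is the quantity that gates each all-to-all step.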
…er config Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
…on-MTP) config Enable attention-DP (enable_attention_dp) and the request balancer (attention_dp_config: enable_balance) in the non-MTP super_v3.yaml, matching super_v3_mtp.yaml. Measured on NVFP4 120B, 4xGB200, SPEED-Bench (attention-DP+balancer vs no-ADP tensor-parallel baseline): output-token throughput +14/+39/+76/+80% at concurrency 1/32/128/256, with inter-token latency roughly halved at high concurrency. Trade-off: higher TTFT at high concurrency (balancer prefill deferral). Also fixes a multi-GPU deadlock: the prior config (EP MoE without attention-DP) hangs in decode at world_size>1; enabling attention-DP routes the MoE via all-to-all and runs cleanly. No-op for single-GPU (balancer requires tp_size>1). Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
…perV3 (non-MTP) config" This reverts commit ec165dd. Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
PythonMambaCacheManager allocated conv_states/ssm_states with torch.empty (uninitialized HBM) instead of torch.zeros. Under attention-DP an idle/short rank decodes an extra generation slot without a prefill, and the causal-conv op reads that slot's uninitialized conv state. Whether the garbage reads as NaN vs finite depends on the allocation layout, which ssm_replay=true shifts (replay buffers vs intermediate_ssm_state_cache). The resulting NaN is propagated by the conv into the residual stream, latched by the SSM replay path into the persistent recurrent state, and spread across ranks by the MoE all-to-all, collapsing GSM8K accuracy on SuperV3 MTP + attention-DP (test_mtp[nvfp4_ws4_80gb-trtllm]: 42.95 -> 92.04 PASS; eager 44.92 -> 92.23). Zeros are a valid empty Mamba state, so zero-init is correct and robust regardless of allocation layout or which slots are decoded un-prefilled. Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
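The propagation described above can be reproduced in miniature with a single-channel causal-conv decode step. This is a pure-Python sketch with hypothetical names (`causal_conv_step`, the 4-tap `weights`), not the kernel used by the cache manager; it only shows that a NaN in the carried state poisons the output and stays in the state, while a zero state behaves as a valid empty history.

```python
import math


def causal_conv_step(state: list[float], x_new: float,
                     weights: list[float]) -> tuple[float, list[float]]:
    """One decode step of a 1-D causal conv for a single channel.

    `state` holds the previous len(weights) - 1 inputs for this cache slot.
    Returns the output and the updated state.
    """
    window = state + [x_new]
    y = sum(w * v for w, v in zip(weights, window))
    # Slide the window: the oldest input is dropped, the newest is kept.
    return y, window[1:]


weights = [0.1, 0.2, 0.3, 0.4]

# Zero-initialized slot decoded without a prefill: finite output.
y_zero, state_zero = causal_conv_step([0.0, 0.0, 0.0], 1.0, weights)

# Slot whose memory happens to contain NaN: the output is NaN, and the
# NaN is still present in the state carried to the next step.
y_nan, state_nan = causal_conv_step([math.nan, 0.0, 0.0], 1.0, weights)
y_nan2, state_nan2 = causal_conv_step([0.0, math.nan, 0.0], 1.0, weights)
```

Note that a zero weight would not mask the problem, since `0.0 * nan` is still NaN in IEEE arithmetic.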
multi_stream_moe was disabled because it produced NaNs in the target's sampled logits. Root cause was the uninitialized Mamba conv/ssm cache pools (torch.empty) read by an un-prefilled attention-DP generation slot, fixed by zero-init in PythonMambaCacheManager (commit 323e09c). With that fix ms_moe is accuracy-valid: test_mtp[nvfp4_ws4_80gb-trtllm] GSM8K 91.89 (PASS). It adds ~+18% (c128) and ~+56% (c256) serving throughput on the SPEED-Bench sweep. Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Disable piecewise CUDA graph capture in the SuperV3 MTP registry config because MTP does not support piecewise capture yet.

For the diagnostic branch, initialize the persistent Mamba conv and SSM state pools to NaNs instead of zeros so that any read of uninitialized state surfaces during the MTP + ADP path.

Observed result from the in-repo accuracy target tests/integration/defs/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[fp8_ws4_80gb-trtllm]: passed on 4 GPUs with torch-cudagraph, ADP enabled, and piecewise disabled. GSM8K average accuracy was 92.27, evaluated accuracy was 92.267 against an 89.407 threshold, and spec-dec acceptance was 53.70% (16360/30468 tokens across 1001 speculative iterations).

Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
Here is the simple local diff for the current branch HEAD ( Validation from the rerun of
diff --git a/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml b/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml
index b46ff60e6f..04ec1fc316 100644
--- a/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml
+++ b/examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml
@@ -21,6 +21,9 @@ attention_dp_config:
batching_wait_iters: 10
timeout_iters: 16
transforms:
+ compile_model:
+ # Piecewise CUDA graph capture is disabled for MTP until MTP supports it.
+ piecewise_enabled: false
detect_sharding:
allreduce_strategy: NCCL
enable_attention_dp: true
diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
index 950b908160..55be655918 100644
--- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -459,23 +459,16 @@ class PythonMambaCacheManager(BaseResourceManager):
ssm_state_shape = (nheads, head_dim, d_state)
# create mamba conv and ssm states
- # Zero-initialize (not torch.empty): an un-prefilled cache slot (e.g. an
- # attention-DP idle/dummy generation request that is decoded without a
- # prefill) is read by the causal-conv / SSM ops before anything writes it.
- # Uninitialized HBM can be NaN/Inf, which the causal-conv then propagates
- # into the residual stream (and the SSM replay path latches into the
- # persistent recurrent state, spreading it across ranks via the MoE
- # all-to-all). Whether the garbage is NaN vs finite depends on the
- # allocation layout, which is exactly why the failure only surfaced under
- # {ssm_replay=true × attention-DP × MTP}. Zeros are a valid empty state.
- conv_states = torch.zeros(
+ conv_states = torch.full(
size=(num_local_layers, max_batch_size) + conv_state_shape,
+ fill_value=float("nan"),
dtype=dtype,
device=device,
)
- ssm_states = torch.zeros(
+ ssm_states = torch.full(
size=(num_local_layers, max_batch_size) + ssm_state_shape,
+ fill_value=float("nan"),
dtype=self.mamba_ssm_cache_dtype,
device=device,
)
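The idea behind the NaN fill in the diff above is a standard poisoning diagnostic: if a slot is read before anything has written it, the poison value is still there and can be detected. Below is a minimal pure-Python sketch of that check; `make_pool`, `write_slot`, and `read_before_write` are hypothetical helpers for illustration and are not part of mamba_cache_manager.py.

```python
import math


def make_pool(num_slots: int, width: int, fill: float) -> list[list[float]]:
    """Allocate a state pool with every element set to `fill`."""
    return [[fill] * width for _ in range(num_slots)]


def write_slot(pool: list[list[float]], slot: int, values: list[float]) -> None:
    """Stand-in for a prefill writing real state into a cache slot."""
    pool[slot] = list(values)


def read_before_write(pool: list[list[float]], read_slots: list[int]) -> list[int]:
    """Return the slots about to be read that still contain the NaN poison."""
    return [s for s in read_slots
            if any(math.isnan(v) for v in pool[s])]


pool = make_pool(num_slots=4, width=3, fill=math.nan)

# Slots 0 and 2 are prefilled; slot 1 stands in for a generation slot
# that is decoded without a prefill.
write_slot(pool, 0, [0.5, 0.1, 0.0])
write_slot(pool, 2, [0.0, 0.0, 0.0])

suspect = read_before_write(pool, read_slots=[0, 1, 2])
```

With a zero fill the same check would report nothing, which is why the diagnostic branch swaps the fill value rather than adding instrumentation to the ops themselves.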
c76eedc to
033e80d
Compare
@coderabbitai summary
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.