
[None][feat] Checkpointing variant of replay for MTP for mamba models#14203

Merged
hnover-nv merged 102 commits into
NVIDIA:mainfrom
hnover-nv:mamba_checkpointing_submit
Jun 19, 2026

@hnover-nv hnover-nv commented May 16, 2026

Collaborator

Description

Summary of Changes
Quick summary; see Ideas below for more detail.

We replace old per-step replay with a "checkpointing" replay, where we don't write out mamba state every step. Main changes:

  • Families of tunable kernels for the actual state update, tuned for Nemotron v3 Super on B200.
  • Framework adjustments to support the new conditions on mamba cache management that reflect when we write state. Unfortunately, this required switching to a double-buffered old_x, because Triton can't tell that two derived memory addresses are actually the same, which leads to a read hazard.
  • New fields on MambaMetadata that track the number of requests in a batch that will want to write checkpoints this step, plus a tensor that stores per-request data for the batch, sorted into writes and nowrites.
  • Zero out mamba state in the cache manager at init. Previously we did not zero the actual mamba head state or the convolution state. This matters for dummy requests (either CUDA graph padding or ADP), which don't go through prefill and so would read that uninitialized state, which could contain inf or NaN. They probably have no way to "get out" to live requests without other bugs, but since nothing prevents cross-request aggregation at other layers, this does seem necessary.
  • In the same spirit, prior code changes made the CUDA graph dummy padding requests share a cache slot. This results in races on that slot between the different copies of the dummy request in a batch. That's fine for most parameters, but some, like cum_dA and its relation to num_previously_accepted_tokens, have invariants that can produce inf/NaN when violated. So we force dummy request slots to stay at PNAT = 0 and not flip their cache double-buffer bit, so the races can't introduce problems.
  • We no longer support per-cache-slot Philox seeds. Including that code caused big slowdowns (>20% in some cases). Almost certainly this is just that we are so auto-tuned that small code changes perturb something important. But we did the tuning before that change went in, and do not have time to retune now. If it's a necessary capability, we can retune after this PR. This also brings us in line with both the current "copy every state" flashinfer and the upcoming checkpointing version of flashinfer.
  • Removal of the benchmark script. It got insanely complicated in order to support efficient tuning runs. A version may come back some day.
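To make the dummy-slot reasoning in the bullets above concrete, here is a minimal, hypothetical sketch of end-of-step slot bookkeeping. `SlotState`, `end_step`, and the assumption that a live slot flips its double-buffer bit every step are illustrative only, not the cache manager's actual code; the point is just that pinning dummy slots turns their racing updates into no-ops.

```python
from dataclasses import dataclass


@dataclass
class SlotState:
    pnat: int = 0      # previously accepted tokens held in the history buffer
    buf_bit: int = 0   # which half of the (assumed) double-buffered old_x is current


def end_step(slot: SlotState, accepted: int, wrote_checkpoint: bool, is_dummy: bool) -> None:
    """Hypothetical per-slot bookkeeping at the end of a decode step."""
    if is_dummy:
        # Dummy padding requests may share one slot. Leaving them pinned at
        # PNAT = 0 with an unflipped bit makes every copy's update a no-op,
        # so racing copies cannot push the slot into an inconsistent state.
        return
    # Assumed rule: a checkpoint restarts the history with this step's tokens.
    slot.pnat = accepted if wrote_checkpoint else slot.pnat + accepted
    slot.buf_bit ^= 1  # assumption: live slots ping-pong buffers every step
```

If the dummy copies were not pinned, several racing copies would each add to `pnat` and flip `buf_bit`, leaving a result that depends on how many copies ran and how their read-modify-writes interleaved.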

While we were here:

  • Fixed/clarified get_max_resource_count, which for mamba was sometimes treated as the number of slots and sometimes as the number of slots excluding dummies. Now it is unambiguously the number of slots for live requests.
  • TRTLLM_USE_PY_MAMBA=1 silently fell back to CppMambaHybridCacheManage if KV cache reuse was enabled. Given the debugging nature of the flag, we changed this to raise an error instead, so you don't think you're testing something you're not.

Timing results
These are the times from the start of conv1d to the end of the last replay kernel. The replay state update is broken up into 2 or 3 kernels, and conv1d PDLs into the first one. It's a kernel microbenchmark that tries to simulate hot and cold inputs based on what is and is not coming out of in_proj in the mamba2_mixer. The runtime depends on the mix of write and no-write requests, so to get a single score, we took the distribution of acceptance lengths of Super on a dataset, run with draft length 5, treated those acceptance lengths as transitions of a Markov process, found its stationary distribution, and then drew dummy requests with prev_num_accepted_token from that distribution.
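The Markov-process step above can be sketched roughly as follows. This is an illustrative reconstruction, not the benchmark's code: it assumes the state is PNAT, that the transition follows the checkpoint rule described under Ideas (write and restart the history when PNAT + T > window, otherwise append), that T = draft length + 1 = 6, and that `accept_pmf` is a made-up acceptance-length distribution. It finds the stationary distribution by power iteration on the "lazy" chain, which has the same stationary distribution but also converges when the chain is periodic.

```python
def pnat_transition(pnat: int, accepted: int, T: int, window: int) -> int:
    # Assumed rule: no room for T more tokens means checkpoint and restart
    # the history with this step's accepted tokens; otherwise append.
    if pnat + T > window:
        return accepted
    return pnat + accepted


def stationary_pnat(accept_pmf: dict, T: int, window: int = 16, iters: int = 5000) -> list:
    """accept_pmf maps accepted-token count (1..T) to its probability."""
    assert T <= window and all(1 <= k <= T for k in accept_pmf)
    n = window + 1                     # PNAT ranges over 0..window
    dist = [0.0] * n
    dist[0] = 1.0                      # a fresh request starts with empty history
    for _ in range(iters):
        nxt = [0.0] * n
        for p, mass in enumerate(dist):
            if mass == 0.0:
                continue
            for k, prob in accept_pmf.items():
                nxt[pnat_transition(p, k, T, window)] += mass * prob
        # Lazy step: average with the previous distribution to avoid oscillation.
        dist = [0.5 * a + 0.5 * b for a, b in zip(dist, nxt)]
    return dist


def write_fraction(dist: list, T: int, window: int = 16) -> float:
    """Expected share of requests that must write a checkpoint this step."""
    return sum(m for p, m in enumerate(dist) if p + T > window)
```

For example, if every step accepted all 6 tokens, `stationary_pnat({6: 1.0}, T=6)` splits the mass evenly between PNAT 6 and PNAT 12, so half the requests would be writes. Sampling dummy PNAT values from such a distribution gives a single representative write/no-write mix.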

Batch size is w.r.t. Nemotron v3 Super with TP=8. Across Nemotron families, all mamba heads are the same, so all that really matters as you vary Super vs Ultra, or TEP vs ADP, is the number of mamba heads. I did not have time for e2e numbers, but the up-front metadata computation is minimal, and unlike the prior replay vs "copy every draft state out" comparison, this should really only change kernel time.

Listed below are the timings vs old replay, and against the flashinfer checkpointing kernels from flashinfer-ai/flashinfer#3324. These kernels mostly win despite being Triton, thanks to the aggressive auto-tuning and the specialized kernel-pair breakdown for larger batches (see below). Igor is actively working on FI / CUDA equivalents, which will surely be faster. These numbers used to be faster still, but our tunings ran into a Triton compiler bug and we didn't have time to fully retune the workaround. Maybe later. See the bullet under "Other tweaks" below.

We are faster than old replay everywhere, and faster than the new checkpointing flashinfer at 39/44 batch/dtype combos. More tuning time would probably help, especially anywhere that uses persistent_main vs persistent dynamic.

[Figures: new replay timing, small batches (B200); new replay timing, full range (B200); speedup vs. flashinfer (B200); speedup vs. old replay (B200)]

Ideas

Building on #13453, here we take the replay idea a step further, to a "checkpointing" replay. This replaces the old replay version in our code.

This section walks through the ideas, skip below to the tl;dr summary.

To summarize "old" replay: we would save a mamba state every step. It just lagged the true state by one step, because we don't know in advance how many tokens will be accepted. Compared to the prior method of saving every possible state at each step, we saved substantially on memory traffic in the kernel and at step end, when we'd otherwise copy the "winning" state over. And needing to materialize only one state let us save further by skipping a lot of computation, and by accelerating what remained with tensor cores.

But there's no reason we have to write even one state every step. We can have an ssm state checkpoint arbitrarily far back, store all the needed inputs since, and reconstruct the state. When we don't accept tokens, we just overwrite the unaccepted ones with the next step's new ones. Of course, in the limit, that just becomes linear-ish attention where you kept the whole KV cache.

How often should we checkpoint? One could try to balance the memory traffic of the replay window against that of the state. But we're not actually just memory bound. The easiest choice comes from our tensor core usage: it requires us to pad all our operations over dimension T (draft length + 1) to at least 16, so 16 is a natural history size. In the prior PR you can see that the old kernels' runtimes were quite flat with respect to T up to 16, and those kernels used the larger T for both state replay and output generation. Here we replay up to 16 tokens from history but only output T.

The algorithm, then, is: every step, if the number of previously accepted tokens (PNAT) + T > window size, we can't store the T new tokens, so we need to write a checkpoint and restart the history buffer. Otherwise we have room to append.
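The per-step rule above can be illustrated with a toy, single-scalar recurrence h <- a * h + bx standing in for the real per-head state. This is a hedged sketch, not the kernel: `CheckpointedScalarState` and `replay` are hypothetical names, it ignores quantization and rounding, and it stores only accepted inputs rather than overwriting unaccepted slots in a fixed-size buffer. It shows that a checkpoint plus the inputs since that checkpoint reproduce the state exactly, while checkpoints are only written when PNAT + T exceeds the window.

```python
def replay(state: float, history) -> float:
    """Apply h <- a * h + bx for each (a, bx) pair in order."""
    for a, bx in history:
        state = a * state + bx
    return state


class CheckpointedScalarState:
    """Toy SSM state kept as a checkpoint plus a replay history."""

    def __init__(self, T: int, window: int = 16):
        assert T <= window
        self.T, self.window = T, window
        self.checkpoint = 0.0   # last materialized state
        self.history = []       # accepted (a, bx) inputs since the checkpoint; len == PNAT

    def step(self, drafts, accepted: int) -> bool:
        """drafts: T (a, bx) pairs this step; returns True if a checkpoint was written."""
        assert len(drafts) == self.T and 1 <= accepted <= self.T
        wrote = len(self.history) + self.T > self.window
        if wrote:
            # No room for T more tokens: fold the history into a new checkpoint.
            self.checkpoint = replay(self.checkpoint, self.history)
            self.history = []
        # Only accepted drafts are kept; the real buffer just overwrites the rest.
        self.history.extend(drafts[:accepted])
        return wrote

    def current_state(self) -> float:
        return replay(self.checkpoint, self.history)
```

With T = 6 and every draft accepted, PNAT goes 0, 6, 12, and the third step must write (12 + 6 > 16), after which the pattern alternates between appending and writing.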

Already this approach saves time. We don't issue instructions for saving the state, and for quantized, stochastically rounded, or block scaled SSM dstates, we also save on that pre-writing computation. Also, by quantizing less often, we accumulate less quantization error, which means we could quantize more aggressively. For example, perhaps int8 w/ stochastic rounding is viable if it only happens every 8 tokens, but not every token. Speculative decoding already gives some of that benefit, but it depends on acceptance rate. Here, we guarantee 16 - T tokens between checkpoints, and it can be as high as 16.

This approach also unlocks a number of additional speedups:

The Rectangle
Ignoring dA/dt scaling factors, the original replay algorithm is:
last_step_state = very_old_state + (old_X^t @ old_B)
outputs = new_C @ one_back_state^t + (new_C @ new_B^t) @ new_X  # bypass all the draft states and go straight to output tokens. We call C @ B^t the CB matrix; it is TxT square lower triangular, padded to 16x16.
Here old_* are last step's new_*, except that we only take the first k of them, where k is the number of accepted tokens.

All the matrix multiplies have at least one dimension of size T, padded to 16.
But if we aren't writing, then just as we go straight from one_back_state to output tokens, we can go straight from very_old_state instead. If you plug last_step_state in for one_back_state and expand, you get:
outputs = new_C @ very_old_state^t + (new_C @ (old_B cat new_B)^t) @ (old_X cat new_X)
Now the CB matrix is a rectangular lower triangular (PNAT + T) x T, and this fits in our window size (16) exactly when PNAT + T <= window_size. Oh look, our non-checkpointing condition!

The old state computation is one of the largest matmuls, so this is a substantial savings. But due to implementation details it is not always faster at small batch sizes.
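The rectangle identity can be checked numerically with a small NumPy sketch. Like the derivation, it ignores the dA/dt scaling factors; the shapes (state `(P, N)`, X rows `(., P)`, B and C rows `(., N)`) and the function names `two_stage` and `rectangle` are assumptions for illustration, not the kernel's layout. The causal mask on the rectangular CB matrix lets new token i see all k history tokens plus new tokens 0..i.

```python
import numpy as np


def two_stage(S0, old_X, old_B, new_X, new_B, new_C):
    """Old replay: materialize last_step_state, then emit outputs from it."""
    S1 = S0 + old_X.T @ old_B                   # (P, N) state after k accepted tokens
    CB = np.tril(new_C @ new_B.T)               # (T, T) causal CB matrix
    return new_C @ S1.T + CB @ new_X            # (T, P) outputs


def rectangle(S0, old_X, old_B, new_X, new_B, new_C, window=16):
    """Nowrite path: go straight from very_old_state using a rectangular CB."""
    k, T = old_X.shape[0], new_X.shape[0]
    assert k + T <= window                      # the non-checkpointing condition
    B_all = np.concatenate([old_B, new_B])      # (k + T, N)
    X_all = np.concatenate([old_X, new_X])      # (k + T, P)
    CB = new_C @ B_all.T                        # (T, k + T) rectangular CB matrix
    mask = np.tril(np.ones((T, k + T)), k=k)    # row i keeps columns 0..k+i
    return new_C @ S0.T + (CB * mask) @ X_all
```

Both functions should agree to floating-point tolerance for any k + T within the window; the rectangle version skips the (P, N) state matmul at the cost of a wider CB matrix.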

constexpr-specialized Persistent Kernels
For a given request, sometimes we now write a checkpoint and sometimes we don't. And when we don't, we may even use a different algorithm. We can specialize the nowrite algorithm choice as a constexpr that we tune based on batch/dtype/rounding mode, but not write vs nowrite. Or can we? If we do specialize them, the compiler can take a lot of advantage. Nowrite should need less live state, even more so when it uses the rectangle or when writing uses stochastic rounding or quantization. The compiler generates very different code. We also build on our prior PR, where we tune parameters like M (the slice of a head_dim done by one block), number of warps, etc. Separate kernels let us tune these knobs for each case.

Ideally we'd launch one correctly sized grid for all the write requests, and one for the nowrite requests. But we're using CUDA graphs in PyTorch, so we can't vary the grid size. Instead, we need to launch enough CTAs to cover all requests as write, and all as nowrite. Although the kernel can early-out on requests that aren't its type, that overhead made specialized kernels uncompetitive.

Except, we can use persistent kernels. As long as we have enough work to fill all the SMs a few times over, then there's no real overhead. And if we add some secondary data structures to present the requests partitioned by write/nowrite, the kernels can just loop over their portion. If we pack other needed data into that structure, it's not even added latency.
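A rough host-side model of that scheme, under stated assumptions: `build_work_items` and `persistent_schedule` are hypothetical helpers, the write test reuses the PNAT + T > window rule, and the real metadata also packs other per-request data that is omitted here. The grid size (`num_programs`) stays fixed for the CUDA graph; only the contents of the work-item list and `n_writes` change per step, and each CTA of each specialized kernel strides through its own partition.

```python
def build_work_items(pnat, T, window=16):
    """Return request indices with writes first, plus the write count."""
    writes = [r for r, p in enumerate(pnat) if p + T > window]
    nowrites = [r for r, p in enumerate(pnat) if p + T <= window]
    return writes + nowrites, len(writes)


def persistent_schedule(work_items, n_writes, num_programs, write_kernel):
    """Emulate which requests each CTA of a persistent kernel would visit."""
    start, end = (0, n_writes) if write_kernel else (n_writes, len(work_items))
    # CTA `pid` handles items pid, pid + num_programs, ... within its partition.
    return [
        [work_items[i] for i in range(start + pid, end, num_programs)]
        for pid in range(num_programs)
    ]
```

For PNAT values `[0, 12, 5, 11]` with T = 6, requests 1 and 3 need writes, so the work items become `[1, 3, 0, 2]` with `n_writes = 2`; the write kernel's CTAs only ever touch the first two entries and the nowrite kernel's CTAs the last two.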

This makes specialized kernels a huge win at medium and large batches. At smaller batches a "dynamic" (i.e. unspecialized, non-constexpr) kernel still wins, probably because we can't easily use PDL optimally in PyTorch between the precomputation kernel and the two main kernels, and we didn't have time to explore emulating this with atomics. But the dynamic kernel also wins by being persistent. Similarly constrained by time, the precompute kernel is non-persistent and dynamic, whereas making it persistent, specialized, or both would surely help.

Other tweaks

  • Optionally support TMA loads and stores of mamba state, as a specializable knob per kernel and scenario.
  • Use of tensorization in Triton to replace some for loops, notably in the heads-per-block part of the precompute kernel. This both removed some undefined behavior and improved performance.
  • Sadly, we had to double-buffer old_x. In the main kernel, we read old x from the history buffer and new x from the conv1d inputs. Then on nowrite requests we append new x after old x, but on write requests we put new x at the start of the history buffer, on top of old x. The same block, and even the same threads, do this work, so it should be safe, but Triton does not realize the addresses are the same, so we have to double-buffer. In the old replay, where old_x and new_x are the same size, this was presumably clearer to the compiler. Or it was a latent bug; hard to know! If space is an issue, we could replace this with a circular buffer of size window + T, instead of today's 2 * window.
  • More sadly, our amazing auto-tuning discovered a bug in the Triton compiler where using num_stages on tl.range causes it to move loads before the gdc_wait needed for PDL, which corrupts the answers. We tried a few fixes and settled on moving the wait outside the loop when the loop's num_stages > 1. This lets tuning decide whether it would rather have more PDL work without the pipelining, or PDL with it. We directly tested varying num_loops_stages while leaving the other tunings fixed, which led to regressions at many batch size/dtype combinations. We had time to partly re-auto-tune a few of the worst results (see next section for tuning details), but not all of them, and many of the cells were clearly not converged. So further re-tuning would surely speed things up. int8+SR was especially hit by regressions, for whatever reason, although they showed up everywhere.

Tunings

We implemented a huge number of tunable knobs, and can pick which knobs to use based on batch size, dtype, and whether we are doing stochastic rounding. But how do we pick the best? We have a mode choice of a single dynamic kernel or a pair of specialized main kernels. Then each of these offers many tuning knobs. When we pick the specialized kernels, there are a total of 17 active knobs (more are in the code, but a few either don't compile or weren't tuned as they seemed dead), a mixture of boolean and integer. This results in over 61 billion possible combinations. While we'd hoped to be able to drop a lot of them, exhaustive searching of select subspaces showed that none of the 17 were always dead. Obviously we can't exhaustively search the whole space for even one input configuration, let alone for all supported batch/dtype/rounding combinations. So we implemented a multi-process block coordinate descent optimization, with some tweaks to deal with cross-GPU performance drift. We also had to invest considerable effort into the actual benchmark runner: in-process CUPTI timing vs CUDA events, multi-process compilation and CUPTI analysis, and optimization of CUDA graph instantiation vs launch overhead due to GIL contention.
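For readers unfamiliar with block coordinate descent over a discrete tuning space, here is a minimal single-process sketch. It is not the optimizer used for this PR (which was not submitted and also handled multi-process evaluation and cross-GPU drift); the knob names and the synthetic `cost` function are hypothetical stand-ins for measured kernel time. Each block of knobs is searched exhaustively while all other knobs stay fixed, and rounds repeat until no block improves.

```python
import itertools


def block_coordinate_descent(space, blocks, cost, init, max_rounds=10):
    """space: knob -> allowed values; blocks: tuples of knobs searched jointly;
    cost: callable(config dict) -> float, e.g. a measured kernel time."""
    best = dict(init)
    best_cost = cost(best)
    for _ in range(max_rounds):
        improved = False
        for block in blocks:
            # Exhaustively try every combination within this block only.
            for values in itertools.product(*(space[k] for k in block)):
                cand = dict(best)
                cand.update(zip(block, values))
                c = cost(cand)
                if c < best_cost:
                    best, best_cost, improved = cand, c, True
        if not improved:
            break  # no block improved: a local optimum w.r.t. all blocks
    return best, best_cost


# Hypothetical knobs and a synthetic cost standing in for a benchmark run.
SPACE = {"num_warps": [2, 4, 8], "block_m": [16, 32, 64], "use_tma": [False, True]}
BLOCKS = [("num_warps", "block_m"), ("use_tma",)]


def synthetic_cost(cfg):
    return (cfg["num_warps"] - 4) ** 2 + (cfg["block_m"] - 32) ** 2 / 100 + (0 if cfg["use_tma"] else 1)
```

The search only ever evaluates the product of one block's value lists at a time, which is what makes a space far too large to enumerate tractable, at the price of finding a local rather than global optimum.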

These are honestly giant vibe coded messes, so we're removing the benchmarker from the repo and not attempting to submit the optimizer. No other human should have to review such things. But there's no arguing with the actual timing results.

Test Coverage

I've updated the replay unit tests to the new paradigm and added some more coverage scenarios, including better coverage of the rounding. Also added unit tests to cache manager for modified PNAT/double buffer behavior, and mamba metadata to test its new functionality of generating the replay work items, needed for efficient persistent specialized kernel pairs.

Also, due to other Ultra instability, I tested this quite a bit on Super and Ultra, which led to the discovery of the Triton compiler bug. After that I ran Super many times and the results seem reasonable, no 0/ corruption. Plus, given my efforts to find a self-bug in what turned out to be a compiler bug, this may be the most agentically inspected code of all time. (famous last words :) )

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

  • New Features

    • Added end-to-end replay-state-update support with preallocated replay buffers and runtime population for speculative decoding.
    • Runtime wiring to pass per-decode replay work items and write counts into replay kernels.
  • Tests

    • Added and expanded CUDA unit tests and correctness benchmarks covering replay update behavior, quantization-aware checks, stochastic rounding, and multi-head scenarios.
  • Chores

    • Removed an obsolete replay benchmark script.

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch from c1be203 to 32ebf3f Compare May 16, 2026 05:43
@hnover-nv hnover-nv changed the title Mamba checkpointing submit @coderabbit May 16, 2026
@hnover-nv hnover-nv changed the title @coderabbit @coderabbitai May 16, 2026
@hnover-nv

Collaborator Author

/bot run

@tensorrt-cicd

Collaborator

PR_Github #48673 [ run ] triggered by Bot. Commit: 32ebf3f Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #48673 [ run ] completed with state FAILURE. Commit: 32ebf3f
/LLM/main/L0_MergeRequest_PR pipeline #38453 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch from 32ebf3f to 1eae71e Compare May 17, 2026 01:44
@hnover-nv hnover-nv changed the title @coderabbitai [None][feat] Checkpointing variant of replay for MTP for mamba models May 17, 2026
@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch 2 times, most recently from a1f2532 to bab26f2 Compare May 17, 2026 21:09
@hnover-nv

Collaborator Author

/bot run

@tensorrt-cicd

Collaborator

PR_Github #48787 [ run ] triggered by Bot. Commit: bab26f2 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #48787 [ run ] completed with state SUCCESS. Commit: bab26f2
/LLM/main/L0_MergeRequest_PR pipeline #38551 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch 4 times, most recently from 8f1d0aa to 24f3dda Compare May 21, 2026 02:06
@hnover-nv

Collaborator Author

/bot run

@tensorrt-cicd

Collaborator

PR_Github #49541 [ run ] triggered by Bot. Commit: 24f3dda Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #49541 [ run ] completed with state SUCCESS. Commit: 24f3dda
/LLM/main/L0_MergeRequest_PR pipeline #39171 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch 4 times, most recently from 724fdf8 to 5c44835 Compare May 21, 2026 18:10
@hnover-nv

Collaborator Author

/bot run

@tensorrt-cicd

Collaborator

PR_Github #49775 [ run ] triggered by Bot. Commit: 5c44835 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #49775 [ run ] completed with state SUCCESS. Commit: 5c44835
/LLM/main/L0_MergeRequest_PR pipeline #39373 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch from 5c44835 to 5394fc9 Compare May 22, 2026 04:17
@hnover-nv

Collaborator Author

/bot run

hnover-nv added 11 commits June 17, 2026 13:39
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Treat replay work-item and n_writes resources as manager-backed
when speculative replay is enabled, so AD disaggregated cache
validation does not reject them as unmanaged persistent state.

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
The benchmark still passed precompute_num_stages, maxnreg, and num_ctas
after those kernel kwargs were removed, causing every bench invocation to fail
with TypeError. Remove the stale plumbing and fix _gen_from_cell_list yield to
match the reduced unpack count.

Also import the B200 precompute retunes for the int8/SR batch 256 and 512
default tuning cells.


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…r broke some paths.

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
@hnover-nv

Collaborator Author

/bot run

@tensorrt-cicd

Collaborator

PR_Github #54874 [ run ] triggered by Bot. Commit: ca32494 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54874 [ run ] completed with state FAILURE. Commit: ca32494
/LLM/main/L0_MergeRequest_PR pipeline #43879 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@hnover-nv

Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #54880 [ run ] triggered by Bot. Commit: ca32494 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54880 [ run ] completed with state FAILURE. Commit: ca32494
/LLM/main/L0_MergeRequest_PR pipeline #43885 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@hnover-nv

Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #54943 [ run ] triggered by Bot. Commit: ca32494 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54943 [ run ] completed with state SUCCESS. Commit: ca32494
/LLM/main/L0_MergeRequest_PR pipeline #43943 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hnover-nv

Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #54949 [ run ] triggered by Bot. Commit: 4538e22 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54949 [ run ] completed with state FAILURE. Commit: 4538e22
/LLM/main/L0_MergeRequest_PR pipeline #43949 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@hnover-nv

Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #54960 [ run ] triggered by Bot. Commit: 4538e22 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54960 [ run ] completed with state SUCCESS. Commit: 4538e22
/LLM/main/L0_MergeRequest_PR pipeline #43960 completed with status: 'SUCCESS'

CI Report

Link to invocation


Labels

None yet

Projects

None yet


9 participants