[None][feat] Checkpointing variant of replay for MTP for mamba models #14203
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Treat replay work-item and n_writes resources as manager-backed when speculative replay is enabled, so AD disaggregated cache validation does not reject them as unmanaged persistent state. Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
The benchmark still passed precompute_num_stages, maxnreg, and num_ctas after those kernel kwargs were removed, causing every bench invocation to fail with TypeError. Remove the stale plumbing and fix _gen_from_cell_list yield to match the reduced unpack count. Also import the B200 precompute retunes for the int8/SR batch 256 and 512 default tuning cells. Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…r broke some paths. Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Description
Summary of Changes
Quick summary; see Ideas below for more detail.
We replace the old per-step replay with a "checkpointing" replay that doesn't write out the mamba state every step. Main changes:
While we were here:
Timing results
These are the times from the start of conv1d to the end of the last replay kernel. The replay state update is split into 2 or 3 kernels, and conv1d uses PDL (programmatic dependent launch) into the first one. It's a kernel microbenchmark that tries to simulate hot and cold inputs, based on which inputs come out of in_proj in the mamba2_mixer and which do not. The runtime depends on the mix of write and nowrite requests, so to get a single score, we took the distribution of Super's acceptance lengths on a dataset, run with draft length 5, treated those acceptance lengths as transitions of a Markov process, found its stationary distribution, and drew dummy requests with prev_num_accepted_token sampled from that distribution.
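The PNAT sampling described above can be sketched in NumPy. This is a minimal sketch under assumptions the PR does not spell out: the Markov state is taken to be PNAT (tokens in the history buffer since the last checkpoint), per-step acceptance lengths are drawn i.i.d. from the measured distribution over 1..T, and the window is 16. pnat_stationary, write_fraction, and sample_pnat are hypothetical names, not functions from the PR.

```python
import numpy as np


def pnat_stationary(accept_probs, T, window=16):
    """Stationary distribution of PNAT under an assumed Markov model.

    accept_probs[a - 1] = P(accepted length == a) for a in 1..T (assumed i.i.d. per step).
    Returns pi, where pi[p] is the long-run fraction of steps that start with PNAT == p.
    """
    probs = np.asarray(accept_probs, dtype=np.float64)
    assert probs.shape == (T,) and T <= window
    probs = probs / probs.sum()
    n = window + 1  # PNAT ranges over 0..window
    P = np.zeros((n, n))
    for pnat in range(n):
        # No room for T new tokens -> write a checkpoint, history restarts at 0.
        base = pnat if pnat + T <= window else 0
        for a, p in enumerate(probs, start=1):
            P[pnat, base + a] += p
    # Solve pi @ P = pi subject to sum(pi) = 1 as a least-squares system.
    A = np.vstack([P.T - np.eye(n), np.ones((1, n))])
    b = np.zeros(n + 1)
    b[-1] = 1.0
    pi, *_ = np.linalg.lstsq(A, b, rcond=None)
    pi = np.clip(pi, 0.0, None)  # remove tiny negative round-off
    return pi / pi.sum()


def write_fraction(pi, T, window=16):
    """Fraction of steps that must write a checkpoint (PNAT + T > window)."""
    pnat = np.arange(len(pi))
    return float(pi[pnat + T > window].sum())


def sample_pnat(pi, batch, seed=0):
    """Draw dummy per-request PNAT values for a benchmark batch."""
    rng = np.random.default_rng(seed)
    return rng.choice(len(pi), size=batch, p=pi)
```

Under this model, with draft length 5 (T = 6) and every draft accepted, PNAT alternates between 6 and 12, so half of the steps write a checkpoint.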
Batch sizes are with respect to Nemotron v3 Super with TP=8. Across the Nemotron families all mamba heads have the same shape, so whether you vary Super vs. Ultra or TEP vs. ADP, all that really matters is the number of mamba heads. I did not have time to collect end-to-end numbers, but the up-front metadata computation is minimal, and unlike the earlier change from "copy every draft state out" to replay, this should only change kernel time.
Listed below are the timings against old replay and against the FlashInfer checkpointing kernels from flashinfer-ai/flashinfer#3324. Despite being Triton, our kernels mostly win, thanks to aggressive autotuning and the specialized kernel-pair breakdown for larger batches (see below). Igor is actively working on FlashInfer/CUDA equivalents, which will surely be faster. These numbers used to be faster still, but our tunings ran into a Triton compiler bug and we didn't have time to fully retune with the workaround. Maybe later. See the bullet under "Other Tweaks" below.
We are faster than old replay everywhere, and faster than the new FlashInfer checkpointing kernels at 39 of 44 batch/dtype combinations. More tuning time would probably help, especially for configurations that use persistent_main rather than persistent dynamic.
Ideas
Building on #13453, here we take the replay idea a step further, to a "checkpointing" replay. This replaces the old replay version in our code.
This section walks through the ideas in detail; skip below for the tl;dr summary.
To summarize "old" replay: we saved a mamba state every step. It simply lagged the true state by one step, because at write time we don't yet know how many tokens will be accepted. Compared with the prior method of saving every possible state at each step, this substantially reduced memory traffic, both in the kernel and at step end, when we would otherwise copy the "winning" state over. And because only one state had to be materialized, we could skip a lot of computation and accelerate what remained with tensor cores.
But there's no reason we have to write even one state every step. We can keep an SSM state checkpoint arbitrarily far back, store all the inputs needed since then, and reconstruct the state on demand. When tokens are rejected, we just overwrite them with the next step's new ones. Of course, taken to the limit, this just becomes linear-ish attention in which you keep the whole KV cache.
How often should we checkpoint? One could try to balance the memory traffic of the replay window against that of the state, but we're not purely memory bound. The easiest choice follows from our tensor core usage: every operation with a dimension of T (draft length + 1) must be padded to at least 16, so 16 is a natural history size. In the prior PR you can see that the old kernels' runtimes were quite flat with respect to T up to 16, even though they used the larger T for both state replay and output generation. Here we replay up to 16 tokens from history but output only T.
The algorithm, then, is: at every step, if the number of previously accepted tokens (PNAT) + T > window size, there is no room to store the T new tokens, so we write a checkpoint and restart the history buffer. Otherwise, we have room to append.
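A toy NumPy model makes the checkpoint-or-append bookkeeping concrete. It assumes a single head, no dA/dt decay (state += outer(x, B), output = state @ C, the same simplification used in The Rectangle), and acceptance supplied by the caller; CheckpointReplay and its methods are hypothetical, not the PR's kernels or cache manager. What it illustrates is that a checkpoint plus a replay of the accepted history reproduces a plain token-by-token recurrence, while the state is written only when PNAT + T exceeds the window.

```python
import numpy as np

WINDOW = 16  # history capacity; 16 matches the padded tensor-core tile


class CheckpointReplay:
    """Hypothetical single-head toy model of checkpointing replay (no dA/dt decay).

    The state has shape (P, N): head_dim x dstate.
    """

    def __init__(self, P, N):
        self.ckpt = np.zeros((P, N))    # last written checkpoint
        self.hist_x = np.zeros((0, P))  # accepted inputs since the checkpoint
        self.hist_B = np.zeros((0, N))
        self._new = None                # this step's tokens, pending acceptance

    @property
    def pnat(self):
        return len(self.hist_x)

    def step(self, x, B, C):
        """Process T new tokens (x: (T, P); B, C: (T, N)). Returns (outputs, wrote)."""
        T = len(x)
        wrote = self.pnat + T > WINDOW
        if wrote:
            # No room to append: fold accepted history into a new checkpoint, restart buffer.
            self.ckpt = self.ckpt + self.hist_x.T @ self.hist_B
            self.hist_x, self.hist_B = self.hist_x[:0], self.hist_B[:0]
        # Replay the history into a scratch state; it is never written back.
        state = self.ckpt + self.hist_x.T @ self.hist_B
        out = np.empty((T, x.shape[1]))
        for i in range(T):
            state = state + np.outer(x[i], B[i])
            out[i] = state @ C[i]
        self._new = (x, B)
        return out, wrote

    def accept(self, k):
        """Append the first k new tokens; rejected ones are simply dropped here."""
        x, B = self._new
        self.hist_x = np.concatenate([self.hist_x, x[:k]])
        self.hist_B = np.concatenate([self.hist_B, B[:k]])
```

In a real buffer the rejected slots would be overwritten by the next step's tokens rather than dropped, but the arithmetic is the same.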
This approach already saves time. On steps that don't write, we issue no state-store instructions, and for quantized, stochastically rounded, or block-scaled SSM states we also skip the computation that precedes the write. Also, by quantizing less often, we accumulate less quantization error, which means we could quantize more aggressively. For example, int8 with stochastic rounding might be viable if it only happens every 8 tokens, but not every token. Speculative decoding already gives some of that benefit, but it depends on the acceptance rate. Here, we guarantee at least 16 - T tokens between checkpoints, and there can be as many as 16.
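A toy experiment illustrates the quantization-error argument. Everything in it is an illustrative assumption rather than the PR's quantization scheme: a scalar state per lane, Gaussian per-token increments, a fixed int8 scale, and a simple unbiased stochastic-rounding quantizer. sr_int8 and final_rms_error are hypothetical names.

```python
import numpy as np


def sr_int8(v, scale, rng):
    """Stochastically round v / scale to int8 and dequantize (unbiased in expectation)."""
    q = np.floor(v / scale + rng.random(v.shape))
    return np.clip(q, -127, 127) * scale


def final_rms_error(quantize_every, steps=64, lanes=100_000, scale=4 / 127, seed=0):
    """RMS error of the final state when quantizing once every `quantize_every` tokens."""
    rng = np.random.default_rng(seed)
    exact = np.zeros(lanes)    # full-precision reference state
    stored = np.zeros(lanes)   # quantized checkpoint
    pending = np.zeros(lanes)  # full-precision contribution of tokens since the checkpoint
    for t in range(1, steps + 1):
        inc = rng.normal(0.0, 0.05, lanes)
        exact += inc
        pending += inc
        if t % quantize_every == 0:
            # Checkpoint: fold the pending tokens in and quantize once.
            stored = sr_int8(stored + pending, scale, rng)
            pending[:] = 0.0
    return float(np.sqrt(np.mean((stored + pending - exact) ** 2)))
```

Because each stochastic rounding adds an independent zero-mean error, the final RMS error grows roughly with the square root of the number of quantizations, so quantizing every 8 tokens in this toy lands well below quantizing every token.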
This approach also unlocks a number of additional speedups:
The Rectangle
Ignoring dA/dt scaling factors, the original replay algorithm is:
last_step_state = very_old_state + (old_X^t @ old_B)
outputs = new_C @ one_back_state^t + (new_C @ new_B^t) @ new_X  # bypass all the draft states and go straight to the output tokens
That C @ B^t is what we call the CB matrix; it is a T x T lower-triangular square, padded to 16x16.
Here one_back_state is the last_step_state from the first line, and old_* are last step's new_*, restricted to the first k of them, where k is the number of tokens accepted last step.
Every matrix multiply has at least one dimension equal to T, padded to 16.
But if we aren't writing, then just as we go straight from one_back_state to the output tokens, we can go straight from very_old_state. Plugging last_step_state in and expanding gives:
outputs = new_C @ very_old_state^t + (new_C @ (old_B cat new_B)^t) @ (old_X cat new_X)
Now the CB matrix is a rectangular lower-triangular T x (PNAT + T) matrix (new token i sees all PNAT old tokens plus new tokens 0..i), and it fits in our window size (16) exactly when PNAT + T <= window_size. Oh look, our non-checkpointing condition!
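The rectangle identity can be checked numerically. This NumPy sketch follows the document's notation and its "ignoring dA/dt" simplification; it assumes shapes X: (tokens, P), B and C: (tokens, N), state: (P, N), and makes the causal masking explicit, which the shorthand leaves implicit. two_stage and rectangle are hypothetical names, and the window check only models the padding constraint, not a real kernel limit.

```python
import numpy as np


def two_stage(very_old_state, old_X, old_B, new_X, new_B, new_C):
    """Old replay: materialize last_step_state, then apply a square T x T CB matrix."""
    last_step_state = very_old_state + old_X.T @ old_B   # (P, N)
    cb = np.tril(new_C @ new_B.T)                        # (T, T), causal
    return new_C @ last_step_state.T + cb @ new_X        # (T, P)


def rectangle(very_old_state, old_X, old_B, new_X, new_B, new_C, window=16):
    """Checkpointing replay: skip last_step_state, use a T x (PNAT + T) CB matrix."""
    pnat, T = len(old_X), len(new_X)
    if pnat + T > window:
        raise ValueError("PNAT + T exceeds the window; a checkpoint write is required")
    all_B = np.concatenate([old_B, new_B])               # (PNAT + T, N)
    all_X = np.concatenate([old_X, new_X])               # (PNAT + T, P)
    # Row i keeps columns j <= PNAT + i: every old token plus new tokens 0..i.
    cb = np.tril(new_C @ all_B.T, k=pnat)                # (T, PNAT + T)
    return new_C @ very_old_state.T + cb @ all_X         # (T, P)
```

The rectangle form never builds the (P, N) intermediate state, which is where the savings on the large old-state matmul come from.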
Computing the old state is one of the largest matmuls, so skipping it is a substantial saving. Due to implementation details, however, the rectangle is not always faster at small batch sizes.
constexpr-specialized Persistent Kernels
A request now sometimes writes a checkpoint and sometimes doesn't, and when it doesn't, it may even use a different algorithm. We can make the nowrite algorithm choice a constexpr that we tune per batch/dtype/rounding mode, but not write vs. nowrite, since that varies per request. Or can we? If we do specialize on it, the compiler can exploit it heavily: nowrite should need less live state, even more so when it uses the rectangle or when the write path uses stochastic rounding or quantization, and the compiler generates very different code. We also build on our prior PR, where we tune parameters like M (the slice of head_dim handled by one block), number of warps, etc. Separate kernels let us tune these knobs for each case.
Ideally we'd launch one correctly sized grid for all the write requests and one for the nowrite requests. But we're using CUDA graphs in PyTorch, so the grid size can't vary from step to step. Instead, we would need to launch enough CTAs to cover all requests as write, and again all requests as nowrite. Although each kernel can exit early on requests that aren't its type, that overhead made specialized kernels uncompetitive.
Except we can use persistent kernels. As long as there is enough work to fill all the SMs a few times over, there's no real overhead. And if we add a secondary data structure that presents the requests partitioned by write/nowrite, each kernel can just loop over its portion. If we pack other needed data into that structure, it doesn't even add latency.
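The host-side idea can be sketched in plain Python. This is a hypothetical model, not the PR's Triton code or its metadata layout: build_work_items partitions request indices into write and nowrite lists (packing PNAT next to each index as an example of "other needed data"), and persistent_kernel simulates a fixed grid of CTAs, each striding over its kernel's list, so the grid size never depends on the batch's write/nowrite split. In a CUDA-graph setting the lists would presumably live in fixed-capacity device buffers with a count; that detail is an assumption.

```python
from dataclasses import dataclass

WINDOW = 16


@dataclass(frozen=True)
class WorkItem:
    request: int  # index into the batch
    pnat: int     # packed alongside so the kernel needs no extra lookup


def build_work_items(pnat, T, window=WINDOW):
    """Partition requests into (write, nowrite) lists; this would run once per step."""
    write, nowrite = [], []
    for r, p in enumerate(pnat):
        (write if p + T > window else nowrite).append(WorkItem(r, p))
    return write, nowrite


def persistent_kernel(items, num_ctas, process):
    """Simulate a fixed-size persistent grid: CTA pid handles items pid, pid + num_ctas, ..."""
    for pid in range(num_ctas):  # on a GPU these CTAs run concurrently
        for i in range(pid, len(items), num_ctas):
            process(pid, items[i])
```

Each specialized kernel only ever touches its own list, so no CTA wastes time inspecting and skipping requests of the other type.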
This makes specialized kernels a huge win at medium and large batches. At smaller batches a "dynamic" (i.e., unspecialized, non-constexpr) kernel still wins, probably because in PyTorch we can't easily set up optimal PDL between the precompute kernel and the two main kernels, and we didn't have time to explore emulating it with atomics. The dynamic kernel also benefits from being persistent. Similarly constrained by time, the precompute kernel is non-persistent and dynamic, whereas making it persistent, specialized, or both would surely help.
Other tweaks
Tunings
We implemented a huge number of tunable knobs, and can pick which knobs to use based on batch size, dtype, and whether we are doing stochastic rounding. But how do we pick the best? First there is a mode choice: a single dynamic main kernel or a pair of specialized main kernels. Each of these then offers many tuning knobs. When we pick the specialized kernels, there are a total of 17 active knobs (the code has more, but a few either don't compile or weren't tuned because they seemed dead), a mixture of boolean and integer, resulting in over 61 billion possible combinations. While we'd hoped to drop a lot of them, exhaustive searches of selected subspaces showed that none of the 17 was always dead. Obviously we can't exhaustively search the whole space for even one input configuration, let alone for all supported batch/dtype/rounding combinations. So we implemented a multi-process block coordinate descent optimizer, with some tweaks to deal with cross-GPU performance drift. We also had to invest considerable effort in the benchmark runner itself: in-process CUPTI timing vs. CUDA events, multiple processes for compilation and CUPTI analysis, and balancing CUDA graph instantiation against launch overhead caused by GIL contention.
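For readers curious about the search strategy, here is a minimal single-process sketch of block coordinate descent over discrete knobs. The PR's actual optimizer (multi-process, with cross-GPU drift handling) is not being submitted and this is not it; the function name, knob names, value grids, and toy cost used to exercise it are all invented for illustration.

```python
import itertools


def block_coordinate_descent(space, cost, start, block_size=2, max_rounds=10):
    """Minimize cost(config) over a discrete space by exhaustively searching small knob blocks.

    space: dict mapping knob -> list of candidate values.
    start: initial config dict. Returns (best_config, best_cost).
    """
    best = dict(start)
    best_cost = cost(best)
    knobs = list(space)
    for _ in range(max_rounds):
        improved = False
        for block in itertools.combinations(knobs, block_size):
            # Try every value combination for this block while other knobs stay fixed.
            for values in itertools.product(*(space[k] for k in block)):
                cand = {**best, **dict(zip(block, values))}
                c = cost(cand)
                if c < best_cost:
                    best, best_cost, improved = cand, c, True
        if not improved:
            break  # a full round found nothing better: a block-wise local optimum
    return best, best_cost
```

Searching pairs of knobs jointly lets the descent escape some interactions that one-knob-at-a-time search would miss, at the cost of more benchmark runs per round; in practice each `cost` call would be a timed benchmark rather than a formula.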
These are honestly giant vibe coded messes, so we're removing the benchmarker from the repo and not attempting to submit the optimizer. No other human should have to review such things. But there's no arguing with the actual timing results.
Test Coverage
I've updated the replay unit tests to the new paradigm and added more coverage scenarios, including better coverage of the rounding. I also added unit tests for the cache manager's modified PNAT/double-buffer behavior, and for the mamba metadata's new functionality of generating the replay work items, which the persistent specialized kernel pairs need to be efficient.
Also, because of unrelated Ultra instability, I tested this quite a bit on Super and Ultra, which led to the discovery of the Triton compiler bug. After that I ran Super many times and the results seem reasonable, no 0/ corruption. Plus, given my efforts to find a bug of my own in what turned out to be a compiler bug, this may be the most agentically inspected code of all time. (famous last words :) )
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added: either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.
Summary by CodeRabbit
New Features
Tests
Chores