Replace chunked FLA with recurrent gated delta rule for T=1 decode#18667
Conversation
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode. Replace with plain PyTorch einsum ops that Inductor can fuse: - FLA GPU time: 1.085ms → 0.344ms/step (-68%) - Total GPU time: 12.0ms → 9.0ms/step (-25%) - Export changed to static T=1 with enable_dynamic_shape=False
This PR needs a
|
Move decode/prefill dispatch inside the chunk_gated_delta_rule triton_op instead of using torch.cond at model level. This follows the same pattern as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids torch.cond incompatibility with AOTI's FunctionalTensor pipeline. Changes: - chunk_gated_delta_rule.py: Add fused recurrent Triton kernel for T=1, refactor chunked pipeline into _launch_chunked(), dispatch via Python if inside the @triton_op wrapper - model.py: Remove torch.cond from GatedDeltaNet.forward(), call triton_op directly (dispatch is internal) - export.py: Single-method export with dynamic seq_len dim - main.cpp: Fix create_text_llm_runner API signature
Only chunk_gated_delta_rule.py needs modification — dispatch logic is internal to the triton_op, no model/export/runner changes needed.
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive
reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for
T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch
paths and chunk boundary edge cases
- Grid changed from (B*H,) to (V//BV, B*H) — 4x more blocks, better SM occupancy (128 blocks vs 32 on A100) - BV reduced from 128 to 32 — lower register pressure, no spilling - Removed unnecessary .contiguous() copies on squeezed inputs - Removed debug print from triton_op dispatch - GPU kernel time: 6us (3.47x faster than Inductor-fused native ops)
- Split model into prefill (chunked FLA triton_op) and decode (native PyTorch recurrent delta rule) methods with explicit state passing - Add runtime_specs processing in CudaBackend::init() so LoadBackendOptionsMap options (skip_copy_output_to_cpu, use_shared_cuda_stream) take effect - Keep state tensors GPU-resident across method calls; only copy logits to CPU for sampling via cudaMemcpy - Achieves 77.4 tok/s decode (3.75x over naive dual-method baseline) Modified files: - cuda_backend.cpp: read runtime_specs in init() for skip_copy + shared stream - main.cpp: dual-method runner with GPU-resident state, logits CPU copy helper - CMakeLists.txt: link CUDA::cudart for cudaMemcpy - model.py: dual-method model definition (prefill + decode) - export.py: export script for dual-method PTE
Revert from explicit state passing back to registered buffers with in-place updates (KVCache, conv_state, recurrent_state). Export with share_mutable_buffers=True so both prefill and forward methods share mutable state via mem_id=2. C++ runner uses share_memory_arenas=true and only passes (tokens, input_pos) — no CUDA runtime dependency. Results: 84.5 tok/s (up from 77.4), 0 select_scatter ops in profile, 65 D2H memcpy (logits only).
| output, state = torch.ops.triton.chunk_gated_delta_rule( | ||
| q, k, v, g, beta, self.recurrent_state[:B] | ||
| ) | ||
| if T == 1: |
There was a problem hiding this comment.
I guess since AOTI traces through there is no way to support this (i.e. prefill vs. decode paths) at runtime, right? I am asking because this ultimately has implications on differentiating kernels based on shape at runtime for instance.
There was a problem hiding this comment.
if we put it in a same method then we can't support it. Here we do twice export with differnt T and make each exported result a single method.
There was a problem hiding this comment.
This seems to be working for me - https://github.com/pytorch/executorch/pull/18759/changes#diff-0c0f640c387c16ef6e725941e2281a0f971b769feaf9716760865ddd968943afR293
There was a problem hiding this comment.
Yes torch.cond works. I mean if we follow the current if-else pattern it won't work.
Add runtime buffer sharing between AOTI containers so that prefill and decode methods operate on the same GPU tensors (KV cache, conv_state, etc.) without unnecessary H2D/D2H copies or getter/setter overhead. The first container to initialize extracts its constants (keyed by original FQN). Subsequent containers with matching FQNs are updated via AOTInductorModelContainerUpdateUserManagedConstantBufferPairs to point to the same GPU memory (user_managed = true, no copy). Also switch main.cpp prefill to token-by-token decode path while the chunked FLA triton_op numerical issue is being resolved. Tested E2E: "What is the capital of France?" → "Paris" with 966 constants shared between prefill and decode containers on A100.
- cuda_backend.cpp: Use codegen name (from GetConstantName) instead of original FQN when calling UpdateUserManagedConstantBufferPairs. The AOTI API matches against internal codegen names, not FQNs — using FQNs caused silent no-op sharing, breaking KV cache flow between prefill and decode. - main.cpp: Add chunked prefill path using the "prefill" method (T>=2) with cudaDeviceSynchronize between prefill and decode for cross-stream safety. Add --decode_only flag to fall back to token-by-token decode for all tokens. - inference.py: Update docstring to reflect that chunked FLA is used in PTE mode (not eager). Verified E2E: "What is the capital of France?" → "The capital of France is Paris." Prefill: 105 tok/s (chunked FLA), Decode: 87 tok/s (recurrent delta rule).
- cuda_backend.cpp: Replace debug printf with ET_LOG for errors/info only - main.cpp: Remove --decode_only flag, keep only chunked prefill path
- cuda_backend.cpp: Replace ET_CHECK_OK_OR_RETURN_ERROR with explicit error handling + cudaDeviceSynchronize after weight transfer, add logging for missing weights_blob - main.cpp: Support single "forward" method fallback when prefill/decode not available, use prefill_method variable, remove debug printf
| """Export model to .pte via torch.export + CUDA backend.""" | ||
| """Export model to .pte via torch.export + CUDA backend. | ||
|
|
||
| Exports two methods: |
There was a problem hiding this comment.
if the selection of FLA is the only thing, then why do we need to have two different methods? Can't we just do torch.cond on the Lq?
There was a problem hiding this comment.
i've tried the single method with torch.cond; it can generate correct output after some fix but perf is not as good as our expectation due to:
torch.condoperator blocks operator fusion between recurrent FLA and other pytorch ops surround it, blocking the perf gain- after introduce
torch.condwe increase about 20% kernel invoke -- increasing cpu time - we can't solve 2 by using cuda graph cuz cuda graph need a static graph and can't support
torch.cond
you can find my experiment here: https://github.com/pytorch/executorch/tree/torch-cond-single-method
There was a problem hiding this comment.
yeah perf may be bad because it may require sync.
There was a problem hiding this comment.
hmm tbh just from my profiling result no extra sync is observed
| if (it != shared_constant_tensors_.end()) { | ||
| // UpdateUserManagedConstantBufferPairs matches against the | ||
| // codegen constant name (underscored), not the original FQN. | ||
| pairs.push_back({internal_name.c_str(), it->second}); |
There was a problem hiding this comment.
Should we put a check to make sure the shape or size is also same besides fqn?
There was a problem hiding this comment.
hmm we can add it but for our case both containers are exported from the same model, so matching FQNs guarantee matching shapes by construction.
| } | ||
| } | ||
|
|
||
| std::lock_guard<std::mutex> guard(shared_constants_mutex_); |
There was a problem hiding this comment.
Do we expect multiple init() calls from different threads? If yes, do we have any tests for it?
There was a problem hiding this comment.
no in aoti-cuda we only expect one thread.
| } | ||
|
|
||
| if (!pairs.empty()) { | ||
| auto update_err = handle->update_user_managed_constant_buffer_pairs( |
There was a problem hiding this comment.
Is KV-Cache buffer constant or just in name?
If its modified during the run then, perhaps a naive question, who manages mutually exclusive access to the buffer itself? I.e. for some reason someone runs Preill() followed by decode twice on two threads.
There was a problem hiding this comment.
KV cache is mutable -- AOTI's "constants" API is a misnomer that covers both immutable weights and mutable state buffers.
For aoti backend now we don't support mutli-thread scenerio so that we don't manager the mutually exclusive.
I think we may need to add some regulation on that. Let me try to make it in another PR.
| if (!pairs.empty()) { | ||
| auto update_err = handle->update_user_managed_constant_buffer_pairs( | ||
| handle->container_handle, | ||
| pairs.data(), |
There was a problem hiding this comment.
if we are using share_mutable_buffers for memory planning and ET owns the memory shouldn't these two pointers already be same? sorry if this is also a naive question
There was a problem hiding this comment.
Not really. share_mutable_buffers operates at the ExecuTorch memory planning level while AOTI containers are opaque to ET's memory planner; their internal GPU allocations (ConstantMap) are managed by the AOTI runtime, not by ET.
This mechanism bridges the gap by sharing AOTI-internal buffers across containers at runtime. Both are needed for the full pipeline.
digantdesai
left a comment
There was a problem hiding this comment.
Left some comments, but LGTM. Stamping to unblock you (and myself :p)
…ytorch#18667) The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode while is a good fit for prefilling stage. This PR decomposes the single `forward` method with separate `prefill` and `decode` methods that use different FLA implementation (chunked for `prefilll` while recurrent for `decode`) while share KV cache, conv_state, and recurrent_state to boost the decode performance from 77.7 token/s to 88.3 token/s, while maintain the prefill performance. --------- Co-authored-by: gasoonjia <gasoonjia@fb.com>
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode while is a good fit for prefilling stage.
This PR decomposes the single
forwardmethod with separateprefillanddecodemethods that use different FLA implementation (chunked forprefilllwhile recurrent fordecode) while share KV cache, conv_state, and recurrent_state to boost the decode performance from 77.7 token/s to 88.3 token/s, while maintain the prefill performance.