Replace chunked FLA with recurrent gated delta rule for T=1 decode#18667

Draft
Gasoonjia wants to merge 14 commits into main from recurrent-fla

Conversation

Contributor

@Gasoonjia commented Apr 2, 2026

The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode, although it is a good fit for the prefill stage.
This PR splits the single forward method into separate prefill and decode methods that use different FLA implementations (chunked for prefill, recurrent for decode) while sharing the KV cache, conv_state, and recurrent_state. This boosts decode performance from 77.7 token/s to xxx token/s while maintaining prefill performance.

The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode.
Replace with plain PyTorch einsum ops that Inductor can fuse:
- FLA GPU time: 1.085ms → 0.344ms/step (-68%)
- Total GPU time: 12.0ms → 9.0ms/step (-25%)
- Export changed to static T=1 with enable_dynamic_shape=False
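
For context, here is a minimal sketch of the single-token recurrent gated delta rule written with plain PyTorch ops that Inductor can fuse. Shapes and argument names are assumptions for illustration, not the PR's actual code:

import torch

def gated_delta_rule_step(q, k, v, g, beta, state):
    # Assumed shapes (illustrative only):
    #   q, k:  [B, H, K]    query/key for the single decoded token
    #   v:     [B, H, V]    value for the single decoded token
    #   g:     [B, H]       log-space decay gate
    #   beta:  [B, H]       write strength
    #   state: [B, H, K, V] recurrent state carried across decode steps
    # Decay the running state.
    state = state * torch.exp(g)[..., None, None]
    # Delta-rule correction: what the state currently predicts for key k.
    pred = torch.einsum("bhk,bhkv->bhv", k, state)
    delta = (v - pred) * beta[..., None]
    # Rank-1 write of the corrected value back into the state.
    state = state + torch.einsum("bhk,bhv->bhkv", k, delta)
    # Read out with the query.
    out = torch.einsum("bhk,bhkv->bhv", q, state)
    return out, state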

pytorch-bot bot commented Apr 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18667

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 21 Pending

As of commit 2b36797 with merge base 3466332:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Apr 2, 2026

github-actions bot commented Apr 2, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Review thread (code context):

q, k, v, g, beta, self.recurrent_state[:B]
# Recurrent gated delta rule — single-step update.
# The model is exported with static T=1 and the C++ runner does
# token-by-token prefill (enable_dynamic_shape=False), so T is always 1.
Contributor

Any impact on prefill performance?

Contributor Author


It currently cuts prefill performance roughly in half; the ongoing fix will make prefill use the chunked implementation while decode uses the recurrent one.

Gasoonjia and others added 9 commits April 2, 2026 20:49
Move decode/prefill dispatch inside the chunk_gated_delta_rule triton_op
instead of using torch.cond at model level. This follows the same pattern
as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids torch.cond
incompatibility with AOTI's FunctionalTensor pipeline.

Changes:
- chunk_gated_delta_rule.py: Add a fused recurrent Triton kernel for T=1,
  refactor the chunked pipeline into _launch_chunked(), and dispatch via a
  plain Python if inside the @triton_op wrapper
- model.py: Remove torch.cond from GatedDeltaNet.forward(), call
  triton_op directly (dispatch is internal)
- export.py: Single-method export with dynamic seq_len dim
- main.cpp: Fix create_text_llm_runner API signature
Only chunk_gated_delta_rule.py needs modification — dispatch logic
is internal to the triton_op, no model/export/runner changes needed.
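
A rough sketch of the dispatch pattern described above; the launcher names and the op signature are placeholders, not the PR's actual code:

def chunk_gated_delta_rule(q, k, v, g, beta, state):
    # The branch lives inside the custom op's body, so the exported graph
    # sees a single op call and no torch.cond is needed at model level.
    T = q.shape[2]  # assumed layout [B, H, T, K]
    if T == 1:
        return _launch_recurrent(q, k, v, g, beta, state)  # fused T=1 kernel
    return _launch_chunked(q, k, v, g, beta, state)        # chunked 6-kernel pipeline

def _launch_recurrent(q, k, v, g, beta, state):
    raise NotImplementedError  # placeholder for the fused recurrent launcher

def _launch_chunked(q, k, v, g, beta, state):
    raise NotImplementedError  # placeholder for the chunked pipeline launcher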
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive
  reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for
  T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch
  paths and chunk boundary edge cases
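
For intuition, a naive sequential reference of the kind such tests compare against can be written as a loop over positions (a sketch with assumed shapes, not FLA's actual reference implementation):

import torch

def naive_gated_delta_rule(q, k, v, g, beta, state):
    # Assumed shapes: q, k [B, H, T, K]; v [B, H, T, V]; g, beta [B, H, T];
    # state [B, H, K, V]. Returns outputs [B, H, T, V] and the final state.
    outs = []
    for t in range(q.shape[2]):
        state = state * torch.exp(g[:, :, t])[..., None, None]
        pred = torch.einsum("bhk,bhkv->bhv", k[:, :, t], state)
        delta = (v[:, :, t] - pred) * beta[:, :, t, None]
        state = state + torch.einsum("bhk,bhv->bhkv", k[:, :, t], delta)
        outs.append(torch.einsum("bhk,bhkv->bhv", q[:, :, t], state))
    return torch.stack(outs, dim=2), state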
- Grid changed from (B*H,) to (V//BV, B*H) — 4x more blocks, better SM
  occupancy (128 blocks vs 32 on A100)
- BV reduced from 128 to 32 — lower register pressure, no spilling
- Removed unnecessary .contiguous() copies on squeezed inputs
- Removed debug print from triton_op dispatch
- GPU kernel time: 6us (3.47x faster than Inductor-fused native ops)
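
The occupancy arithmetic behind the grid change, using assumed illustrative sizes (B=1, H=32, V=128) that reproduce the block counts quoted above:

import math

B, H, V = 1, 32, 128        # assumed problem sizes for illustration
BV = 32                     # value-dim tile, reduced from 128
old_grid = (B * H,)                       # 32 blocks
new_grid = (math.ceil(V / BV), B * H)     # (4, 32) -> 128 blocks
print(old_grid, new_grid, new_grid[0] * new_grid[1])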
- Split model into prefill (chunked FLA triton_op) and decode (native PyTorch
  recurrent delta rule) methods with explicit state passing
- Add runtime_specs processing in CudaBackend::init() so LoadBackendOptionsMap
  options (skip_copy_output_to_cpu, use_shared_cuda_stream) take effect
- Keep state tensors GPU-resident across method calls; only copy logits to CPU
  for sampling via cudaMemcpy
- Achieves 77.4 tok/s decode (3.75x over naive dual-method baseline)

Modified files:
- cuda_backend.cpp: read runtime_specs in init() for skip_copy + shared stream
- main.cpp: dual-method runner with GPU-resident state, logits CPU copy helper
- CMakeLists.txt: link CUDA::cudart for cudaMemcpy
- model.py: dual-method model definition (prefill + decode)
- export.py: export script for dual-method PTE
Revert from explicit state passing back to registered buffers with
in-place updates (KVCache, conv_state, recurrent_state). Export with
share_mutable_buffers=True so both prefill and forward methods share
mutable state via mem_id=2. C++ runner uses share_memory_arenas=true
and only passes (tokens, input_pos) — no CUDA runtime dependency.

Results: 84.5 tok/s (up from 77.4), 0 select_scatter ops in profile,
65 D2H memcpy (logits only).
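
A minimal sketch of the registered-buffer pattern this commit describes; the module structure, names, and shapes are assumptions for illustration:

import torch
import torch.nn as nn

class SharedRecurrentState(nn.Module):
    def __init__(self, max_batch, heads, k_dim, v_dim):
        super().__init__()
        # A registered (mutable) buffer holds the recurrent state. When the
        # export shares mutable buffers across methods, prefill and decode
        # read and write the same GPU-resident storage.
        self.register_buffer(
            "recurrent_state", torch.zeros(max_batch, heads, k_dim, v_dim)
        )

    def update(self, new_state):
        # In-place write into the buffer instead of returning the state to
        # the caller, so the exported program records a buffer mutation.
        b = new_state.shape[0]
        self.recurrent_state[:b].copy_(new_state)
        return self.recurrent_state[:b]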
@Gasoonjia changed the title from "[WIP] Replace chunked FLA with recurrent gated delta rule for T=1 decode" to "Replace chunked FLA with recurrent gated delta rule for T=1 decode" on Apr 6, 2026

Labels

ciflow/cuda, CLA Signed
