Skip to content

Replace chunked FLA with recurrent gated delta rule for T=1 decode#18667

Merged
Gasoonjia merged 24 commits into
mainfrom
recurrent-fla
Apr 10, 2026
Merged

Replace chunked FLA with recurrent gated delta rule for T=1 decode#18667
Gasoonjia merged 24 commits into
mainfrom
recurrent-fla

Conversation

@Gasoonjia
Copy link
Copy Markdown
Contributor

@Gasoonjia Gasoonjia commented Apr 2, 2026

The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode while is a good fit for prefilling stage.
This PR decomposes the single forward method with separate prefill and decode methods that use different FLA implementation (chunked for prefilll while recurrent for decode) while share KV cache, conv_state, and recurrent_state to boost the decode performance from 77.7 token/s to 88.3 token/s, while maintain the prefill performance.

The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode.
Replace with plain PyTorch einsum ops that Inductor can fuse:
- FLA GPU time: 1.085ms → 0.344ms/step (-68%)
- Total GPU time: 12.0ms → 9.0ms/step (-25%)
- Export changed to static T=1 with enable_dynamic_shape=False
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Apr 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18667

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 2 New Failures, 25 Cancelled Jobs, 10 Pending, 14 Unrelated Failures

As of commit 89d3fe6 with merge base 5e8a0df (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 2, 2026
@github-actions
Copy link
Copy Markdown

github-actions Bot commented Apr 2, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Comment thread examples/models/qwen3_5_moe/model.py Outdated
Gasoonjia and others added 9 commits April 2, 2026 20:49
Move decode/prefill dispatch inside the chunk_gated_delta_rule triton_op
instead of using torch.cond at model level. This follows the same pattern
as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids torch.cond
incompatibility with AOTI's FunctionalTensor pipeline.

Changes:
- chunk_gated_delta_rule.py: Add fused recurrent Triton kernel for T=1,
  refactor chunked pipeline into _launch_chunked(), dispatch via Python
  if inside the @triton_op wrapper
- model.py: Remove torch.cond from GatedDeltaNet.forward(), call
  triton_op directly (dispatch is internal)
- export.py: Single-method export with dynamic seq_len dim
- main.cpp: Fix create_text_llm_runner API signature
Only chunk_gated_delta_rule.py needs modification — dispatch logic
is internal to the triton_op, no model/export/runner changes needed.
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive
  reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for
  T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch
  paths and chunk boundary edge cases
- Grid changed from (B*H,) to (V//BV, B*H) — 4x more blocks, better SM
  occupancy (128 blocks vs 32 on A100)
- BV reduced from 128 to 32 — lower register pressure, no spilling
- Removed unnecessary .contiguous() copies on squeezed inputs
- Removed debug print from triton_op dispatch
- GPU kernel time: 6us (3.47x faster than Inductor-fused native ops)
- Split model into prefill (chunked FLA triton_op) and decode (native PyTorch
  recurrent delta rule) methods with explicit state passing
- Add runtime_specs processing in CudaBackend::init() so LoadBackendOptionsMap
  options (skip_copy_output_to_cpu, use_shared_cuda_stream) take effect
- Keep state tensors GPU-resident across method calls; only copy logits to CPU
  for sampling via cudaMemcpy
- Achieves 77.4 tok/s decode (3.75x over naive dual-method baseline)

Modified files:
- cuda_backend.cpp: read runtime_specs in init() for skip_copy + shared stream
- main.cpp: dual-method runner with GPU-resident state, logits CPU copy helper
- CMakeLists.txt: link CUDA::cudart for cudaMemcpy
- model.py: dual-method model definition (prefill + decode)
- export.py: export script for dual-method PTE
Revert from explicit state passing back to registered buffers with
in-place updates (KVCache, conv_state, recurrent_state). Export with
share_mutable_buffers=True so both prefill and forward methods share
mutable state via mem_id=2. C++ runner uses share_memory_arenas=true
and only passes (tokens, input_pos) — no CUDA runtime dependency.

Results: 84.5 tok/s (up from 77.4), 0 select_scatter ops in profile,
65 D2H memcpy (logits only).
@Gasoonjia Gasoonjia changed the title [WIP] Replace chunked FLA with recurrent gated delta rule for T=1 decode Replace chunked FLA with recurrent gated delta rule for T=1 decode Apr 6, 2026
output, state = torch.ops.triton.chunk_gated_delta_rule(
q, k, v, g, beta, self.recurrent_state[:B]
)
if T == 1:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess since AOTI traces through there is no way to support this (i.e. prefill vs. decode paths) at runtime, right? I am asking because this ultimately has implications on differentiating kernels based on shape at runtime for instance.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we put it in a same method then we can't support it. Here we do twice export with differnt T and make each exported result a single method.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes torch.cond works. I mean if we follow the current if-else pattern it won't work.

gasoonjia and others added 4 commits April 6, 2026 18:38
Add runtime buffer sharing between AOTI containers so that prefill and
decode methods operate on the same GPU tensors (KV cache, conv_state,
etc.) without unnecessary H2D/D2H copies or getter/setter overhead.

The first container to initialize extracts its constants (keyed by
original FQN). Subsequent containers with matching FQNs are updated via
AOTInductorModelContainerUpdateUserManagedConstantBufferPairs to point
to the same GPU memory (user_managed = true, no copy).

Also switch main.cpp prefill to token-by-token decode path while the
chunked FLA triton_op numerical issue is being resolved.

Tested E2E: "What is the capital of France?" → "Paris" with 966
constants shared between prefill and decode containers on A100.
- cuda_backend.cpp: Use codegen name (from GetConstantName) instead of
  original FQN when calling UpdateUserManagedConstantBufferPairs. The AOTI
  API matches against internal codegen names, not FQNs — using FQNs caused
  silent no-op sharing, breaking KV cache flow between prefill and decode.

- main.cpp: Add chunked prefill path using the "prefill" method (T>=2) with
  cudaDeviceSynchronize between prefill and decode for cross-stream safety.
  Add --decode_only flag to fall back to token-by-token decode for all tokens.

- inference.py: Update docstring to reflect that chunked FLA is used in PTE
  mode (not eager).

Verified E2E: "What is the capital of France?" → "The capital of France is Paris."
Prefill: 105 tok/s (chunked FLA), Decode: 87 tok/s (recurrent delta rule).
- cuda_backend.cpp: Replace debug printf with ET_LOG for errors/info only
- main.cpp: Remove --decode_only flag, keep only chunked prefill path
- cuda_backend.cpp: Replace ET_CHECK_OK_OR_RETURN_ERROR with explicit error
  handling + cudaDeviceSynchronize after weight transfer, add logging for
  missing weights_blob
- main.cpp: Support single "forward" method fallback when prefill/decode
  not available, use prefill_method variable, remove debug printf
@Gasoonjia Gasoonjia marked this pull request as ready for review April 9, 2026 20:40
@Gasoonjia Gasoonjia requested a review from lucylq as a code owner April 9, 2026 20:40
"""Export model to .pte via torch.export + CUDA backend."""
"""Export model to .pte via torch.export + CUDA backend.

Exports two methods:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the selection of FLA is the only thing, then why do we need to have two different methods? Can't we just do torch.cond on the Lq?

Copy link
Copy Markdown
Contributor Author

@Gasoonjia Gasoonjia Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i've tried the single method with torch.cond; it can generate correct output after some fix but perf is not as good as our expectation due to:

  1. torch.cond operator blocks operator fusion between recurrent FLA and other pytorch ops surround it, blocking the perf gain
  2. after introduce torch.cond we increase about 20% kernel invoke -- increasing cpu time
  3. we can't solve 2 by using cuda graph cuz cuda graph need a static graph and can't support torch.cond

you can find my experiment here: https://github.com/pytorch/executorch/tree/torch-cond-single-method

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah perf may be bad because it may require sync.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm tbh just from my profiling result no extra sync is observed

if (it != shared_constant_tensors_.end()) {
// UpdateUserManagedConstantBufferPairs matches against the
// codegen constant name (underscored), not the original FQN.
pairs.push_back({internal_name.c_str(), it->second});
Copy link
Copy Markdown
Contributor

@digantdesai digantdesai Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we put a check to make sure the shape or size is also same besides fqn?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm we can add it but for our case both containers are exported from the same model, so matching FQNs guarantee matching shapes by construction.

}
}

std::lock_guard<std::mutex> guard(shared_constants_mutex_);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we expect multiple init() calls from different threads? If yes, do we have any tests for it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no in aoti-cuda we only expect one thread.

}

if (!pairs.empty()) {
auto update_err = handle->update_user_managed_constant_buffer_pairs(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is KV-Cache buffer constant or just in name?
If its modified during the run then, perhaps a naive question, who manages mutually exclusive access to the buffer itself? I.e. for some reason someone runs Preill() followed by decode twice on two threads.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KV cache is mutable -- AOTI's "constants" API is a misnomer that covers both immutable weights and mutable state buffers.

For aoti backend now we don't support mutli-thread scenerio so that we don't manager the mutually exclusive.

I think we may need to add some regulation on that. Let me try to make it in another PR.

if (!pairs.empty()) {
auto update_err = handle->update_user_managed_constant_buffer_pairs(
handle->container_handle,
pairs.data(),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we are using share_mutable_buffers for memory planning and ET owns the memory shouldn't these two pointers already be same? sorry if this is also a naive question

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. share_mutable_buffers operates at the ExecuTorch memory planning level while AOTI containers are opaque to ET's memory planner; their internal GPU allocations (ConstantMap) are managed by the AOTI runtime, not by ET.
This mechanism bridges the gap by sharing AOTI-internal buffers across containers at runtime. Both are needed for the full pipeline.

Copy link
Copy Markdown
Contributor

@digantdesai digantdesai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments, but LGTM. Stamping to unblock you (and myself :p)

@Gasoonjia Gasoonjia merged commit 266ff2d into main Apr 10, 2026
171 of 243 checks passed
@Gasoonjia Gasoonjia deleted the recurrent-fla branch April 10, 2026 20:44
jpiat pushed a commit to jpiat/executorch that referenced this pull request Apr 14, 2026
…ytorch#18667)

The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode
while is a good fit for prefilling stage.
This PR decomposes the single `forward` method with separate `prefill`
and `decode` methods that use different FLA implementation (chunked for
`prefilll` while recurrent for `decode`) while share KV cache,
conv_state, and recurrent_state to boost the decode performance from
77.7 token/s to 88.3 token/s, while maintain the prefill performance.

---------

Co-authored-by: gasoonjia <gasoonjia@fb.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/cuda CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants