
Add fused GatedDeltaNet decode Triton kernel #18865

Merged
Gasoonjia merged 61 commits into main from fused-deltanet-decode on Apr 27, 2026

Conversation

@Gasoonjia
Contributor

@Gasoonjia Gasoonjia commented Apr 14, 2026

Fuse Q/K/V split, L2 normalization, head repeat, gating computation, and delta-rule recurrent state update into a single Triton kernel for decode (T=1). Replaces ~6 small AOTI-generated kernels with one, reducing GatedDeltaNet kernel time by ~62%.
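For illustration, a minimal Triton sketch of this kind of fused decode kernel is shown below. It covers only the L2 normalization and delta-rule state update, assumes pre-split contiguous inputs with `V % BV == 0`, and omits the Q/K/V split, head repeat, and gate computation that the PR's kernel also fuses; all names are illustrative, not the PR's code.

```python
import triton
import triton.language as tl

@triton.jit
def fused_gdn_decode_kernel(q_ptr, k_ptr, v_ptr, g_ptr, beta_ptr, state_ptr, o_ptr,
                            K: tl.constexpr, V: tl.constexpr, BV: tl.constexpr):
    # One program per (value-dim tile, batch*head); contiguous layouts assumed,
    # no masking for brevity.
    pid_v = tl.program_id(0)
    pid_bh = tl.program_id(1)
    offs_k = tl.arange(0, K)
    offs_v = pid_v * BV + tl.arange(0, BV)

    q = tl.load(q_ptr + pid_bh * K + offs_k).to(tl.float32)
    k = tl.load(k_ptr + pid_bh * K + offs_k).to(tl.float32)
    v = tl.load(v_ptr + pid_bh * V + offs_v).to(tl.float32)
    g = tl.load(g_ptr + pid_bh).to(tl.float32)        # per-head log decay
    beta = tl.load(beta_ptr + pid_bh).to(tl.float32)  # per-head mixing coefficient

    # Fused L2 normalization of q and k.
    q = q / tl.sqrt(tl.sum(q * q) + 1e-6)
    k = k / tl.sqrt(tl.sum(k * k) + 1e-6)

    # Recurrent state tile S: [K, BV], updated in place.
    s_ptrs = state_ptr + pid_bh * K * V + offs_k[:, None] * V + offs_v[None, :]
    S = tl.load(s_ptrs).to(tl.float32)

    # Gated delta rule: decay, rank-1 delta correction, read-out.
    S = S * tl.exp(g)
    v_delta = beta * (v - tl.sum(S * k[:, None], axis=0))
    S = S + k[:, None] * v_delta[None, :]
    o = tl.sum(S * q[:, None], axis=0)

    tl.store(s_ptrs, S)
    tl.store(o_ptr + pid_bh * V + offs_v, o)
```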

Config performance (compared with the previous optimization):

| Config | Perf (Δ vs. previous optimization) |
|---|---|
| p=128, d=128 | 156.7 (+6.8) |
| p=128, d=512 | 160.8 (+7.1) |
| p=256, d=128 | 156.1 (+6.5) |
| p=256, d=512 | 160.8 (+7.3) |
| p=512, d=128 | 156.0 (+6.9) |
| p=512, d=512 | 160.9 (+8.0) |
| p=1024, d=128 | 156.3 (+7.7) |
| p=1024, d=512 | 160.0 (+6.6) |
| p=2048, d=128 | 154.8 (+6.4) |
| p=2048, d=512 | 160.6 (+7.5) |
| **Average** | 158.3 (+7.1) |

Gasoonjia and others added 24 commits April 1, 2026 23:06
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode.
Replace with plain PyTorch einsum ops that Inductor can fuse:
- FLA GPU time: 1.085ms → 0.344ms/step (-68%)
- Total GPU time: 12.0ms → 9.0ms/step (-25%)
- Export changed to static T=1 with enable_dynamic_shape=False
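For reference, the T=1 step amounts to roughly the following eager PyTorch. This is a sketch, not the PR's code: shapes are illustrative and the per-head `log_decay`/`beta` gating inputs are assumed to be precomputed by the gating projections.

```python
import torch
import torch.nn.functional as F

def gdn_decode_step(mixed_qkv, log_decay, beta, state,
                    num_k_heads, num_v_heads, head_k_dim, head_v_dim):
    """One T=1 decode step as plain PyTorch ops that Inductor can fuse."""
    B = mixed_qkv.shape[0]
    q, k, v = torch.split(
        mixed_qkv,
        [num_k_heads * head_k_dim, num_k_heads * head_k_dim, num_v_heads * head_v_dim],
        dim=-1)
    q = F.normalize(q.view(B, num_k_heads, head_k_dim), dim=-1)
    k = F.normalize(k.view(B, num_k_heads, head_k_dim), dim=-1)
    v = v.view(B, num_v_heads, head_v_dim)

    # GQA-style head repeat so q/k match the number of value heads.
    rep = num_v_heads // num_k_heads
    q = q.repeat_interleave(rep, dim=1)
    k = k.repeat_interleave(rep, dim=1)

    # Gated delta rule: decay the state, apply a rank-1 delta correction, read out.
    state = state * torch.exp(log_decay)[..., None, None]   # [B, Hv, Dk, Dv]
    delta = beta[..., None] * (v - torch.einsum('bhk,bhkv->bhv', k, state))
    state = state + torch.einsum('bhk,bhv->bhkv', k, delta)
    out = torch.einsum('bhk,bhkv->bhv', q, state)
    return out, state
```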
Move decode/prefill dispatch inside the chunk_gated_delta_rule triton_op
instead of using torch.cond at model level. This follows the same pattern
as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids torch.cond
incompatibility with AOTI's FunctionalTensor pipeline.

Changes:
- chunk_gated_delta_rule.py: Add fused recurrent Triton kernel for T=1,
  refactor chunked pipeline into _launch_chunked(), dispatch via Python
  if inside the @triton_op wrapper
- model.py: Remove torch.cond from GatedDeltaNet.forward(), call
  triton_op directly (dispatch is internal)
- export.py: Single-method export with dynamic seq_len dim
- main.cpp: Fix create_text_llm_runner API signature
Only chunk_gated_delta_rule.py needs modification — the dispatch logic is
internal to the triton_op, so no model/export/runner changes are needed.
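A sketch of the dispatch shape, assuming the torch.library.triton_op decorator; `_launch_chunked` is named in the commit above, while `_launch_fused_recurrent`, the op namespace, and the simplified signature are illustrative.

```python
import torch
from torch.library import triton_op

@triton_op("executorch::chunk_gated_delta_rule", mutates_args={})
def chunk_gated_delta_rule(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           g: torch.Tensor, beta: torch.Tensor,
                           state: torch.Tensor) -> torch.Tensor:
    # Plain Python `if` inside the op: export/AOTI traces the branch matching
    # the (static) sequence length, so no torch.cond is needed at model level.
    T = q.shape[1]
    if T == 1:
        # Fused recurrent Triton kernel for decode (launched via wrap_triton
        # inside the helper).
        return _launch_fused_recurrent(q, k, v, g, beta, state)
    # Chunked FLA pipeline for prefill.
    return _launch_chunked(q, k, v, g, beta, state)
```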
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive
  reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for
  T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch
  paths and chunk boundary edge cases
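Sketched shape of such a check; all callables here are hypothetical stand-ins, not the PR's test code.

```python
import torch

@torch.no_grad()
def check_dispatch_paths(chunk_gated_delta_rule, naive_reference, make_inputs):
    # The same op must match the naive reference for T=1 (fused recurrent path)
    # and for lengths around chunk boundaries (chunked path).
    for T in (1, 2, 32, 63, 64, 65, 128, 256):
        q, k, v, g, beta, state = make_inputs(T)
        ref = naive_reference(q, k, v, g, beta, state.clone())
        out = chunk_gated_delta_rule(q, k, v, g, beta, state.clone())
        torch.testing.assert_close(out, ref, rtol=2e-3, atol=2e-3)
```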
- Grid changed from (B*H,) to (V//BV, B*H) — 4x more blocks, better SM
  occupancy (128 blocks vs 32 on A100)
- BV reduced from 128 to 32 — lower register pressure, no spilling
- Removed unnecessary .contiguous() copies on squeezed inputs
- Removed debug print from triton_op dispatch
- GPU kernel time: 6us (3.47x faster than Inductor-fused native ops)
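As a launch-side sketch of the grid change described above (shapes and the kernel name match the illustrative kernel sketched in the PR description, not the PR's code):

```python
import triton

B, H, K, V = 1, 32, 128, 128     # illustrative decode shapes
BV = 32                           # value-dim tile, reduced from 128

grid_before = (B * H,)                      # 32 programs: poor SM occupancy
grid_after = (triton.cdiv(V, BV), B * H)    # 128 programs on the same shapes
# fused_gdn_decode_kernel[grid_after](q, k, v, g, beta, state, out, K=K, V=V, BV=BV)
```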
- Split model into prefill (chunked FLA triton_op) and decode (native PyTorch
  recurrent delta rule) methods with explicit state passing
- Add runtime_specs processing in CudaBackend::init() so LoadBackendOptionsMap
  options (skip_copy_output_to_cpu, use_shared_cuda_stream) take effect
- Keep state tensors GPU-resident across method calls; only copy logits to CPU
  for sampling via cudaMemcpy
- Achieves 77.4 tok/s decode (3.75x over naive dual-method baseline)

Modified files:
- cuda_backend.cpp: read runtime_specs in init() for skip_copy + shared stream
- main.cpp: dual-method runner with GPU-resident state, logits CPU copy helper
- CMakeLists.txt: link CUDA::cudart for cudaMemcpy
- model.py: dual-method model definition (prefill + decode)
- export.py: export script for dual-method PTE
Revert from explicit state passing back to registered buffers with
in-place updates (KVCache, conv_state, recurrent_state). Export with
share_mutable_buffers=True so both prefill and forward methods share
mutable state via mem_id=2. C++ runner uses share_memory_arenas=true
and only passes (tokens, input_pos) — no CUDA runtime dependency.

Results: 84.5 tok/s (up from 77.4), 0 select_scatter ops in profile,
65 D2H memcpy (logits only).
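A minimal model-side sketch of the buffer-based pattern; shapes, names, and the simplified update are illustrative, and the cross-method sharing itself is configured at export time via share_mutable_buffers as described above.

```python
import torch
from torch import nn

class GatedDeltaNetDecode(nn.Module):
    """Decode path with a registered mutable buffer updated in place."""

    def __init__(self, num_heads: int, head_k_dim: int, head_v_dim: int):
        super().__init__()
        # Recurrent state lives in a registered buffer so both the prefill and
        # decode methods can mutate the same GPU-resident tensor.
        self.register_buffer(
            "recurrent_state",
            torch.zeros(1, num_heads, head_k_dim, head_v_dim),
        )

    def forward(self, q, k, v, decay, beta):
        S = self.recurrent_state
        # Delta-rule update written as in-place buffer mutations, so export
        # records buffer mutations instead of extra state outputs.
        S.mul_(decay[..., None, None])
        delta = beta[..., None] * (v - torch.einsum('bhk,bhkv->bhv', k, S))
        S.add_(torch.einsum('bhk,bhv->bhkv', k, delta))
        return torch.einsum('bhk,bhkv->bhv', q, S)
```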
Add runtime buffer sharing between AOTI containers so that prefill and
decode methods operate on the same GPU tensors (KV cache, conv_state,
etc.) without unnecessary H2D/D2H copies or getter/setter overhead.

The first container to initialize extracts its constants (keyed by
original FQN). Subsequent containers with matching FQNs are updated via
AOTInductorModelContainerUpdateUserManagedConstantBufferPairs to point
to the same GPU memory (user_managed = true, no copy).

Also switch main.cpp prefill to token-by-token decode path while the
chunked FLA triton_op numerical issue is being resolved.

Tested E2E: "What is the capital of France?" → "Paris" with 966
constants shared between prefill and decode containers on A100.
- cuda_backend.cpp: Use codegen name (from GetConstantName) instead of
  original FQN when calling UpdateUserManagedConstantBufferPairs. The AOTI
  API matches against internal codegen names, not FQNs — using FQNs caused
  silent no-op sharing, breaking KV cache flow between prefill and decode.

- main.cpp: Add chunked prefill path using the "prefill" method (T>=2) with
  cudaDeviceSynchronize between prefill and decode for cross-stream safety.
  Add --decode_only flag to fall back to token-by-token decode for all tokens.

- inference.py: Update docstring to reflect that chunked FLA is used in PTE
  mode (not eager).

Verified E2E: "What is the capital of France?" → "The capital of France is Paris."
Prefill: 105 tok/s (chunked FLA), Decode: 87 tok/s (recurrent delta rule).
- cuda_backend.cpp: Replace debug printf with ET_LOG for errors/info only
- main.cpp: Remove --decode_only flag, keep only chunked prefill path
- cuda_backend.cpp: Replace ET_CHECK_OK_OR_RETURN_ERROR with explicit error
  handling + cudaDeviceSynchronize after weight transfer, add logging for
  missing weights_blob
- main.cpp: Support single "forward" method fallback when prefill/decode
  not available, use prefill_method variable, remove debug printf
Implements CUDA graph support in the CUDA backend to reduce CPU kernel
launch overhead during autoregressive decoding:

- cuda_backend.cpp: 3-phase execution (warmup → capture → replay) with
  static input/output GPU buffers, cudaMemcpyAsync for I/O, and
  cudaGraphInstantiateFlagAutoFreeOnLaunch for cudaMallocAsync compat
- cuda_delegate_handle.h: CUDA graph state (phase, graph objects, static
  buffer metadata) with RAII cleanup in destructor
- main.cpp: --cuda_graph flag that sets BackendOptions before load_method
- test_model_e2e.sh: Enable --cuda_graph for Qwen3.5 MoE CI, set
  PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync

Benchmark (A100, Qwen3.5-35B-A3B HQQ-INT4): 82→98 tok/s (1.20x)
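The same warmup → capture → replay pattern, sketched in Python with torch.cuda.CUDAGraph for illustration; the actual implementation lives in cuda_backend.cpp and uses the CUDA C APIs listed above, and `step_fn` here is a hypothetical stand-in for one decode step.

```python
import torch

def decode_with_cuda_graph(step_fn, tokens, n_warmup: int = 3):
    # Static input/output buffers: the captured graph always reads and writes
    # the same GPU memory.
    static_in = tokens[0].clone()

    # 1) Warmup on a side stream so lazy init (autotuning, workspace allocation)
    #    happens outside the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmup):
            static_out = step_fn(static_in)
    torch.cuda.current_stream().wait_stream(s)

    # 2) Capture one decode step into a CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = step_fn(static_in)

    # 3) Replay: copy the new token into the static buffer, relaunch the whole
    #    captured kernel sequence with a single CPU-side call, read the result.
    outputs = []
    for t in tokens:
        static_in.copy_(t)
        graph.replay()
        outputs.append(static_out.clone())
    return outputs
```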
Fuse Q/K/V split, L2 normalization, head repeat, gating computation,
and delta-rule recurrent state update into a single Triton kernel for
decode (T=1). Replaces ~6 small AOTI-generated kernels with one,
reducing GatedDeltaNet kernel time by ~62% and improving end-to-end
decode throughput by ~2% (106 -> 108.5 tok/s on A100).

pytorch-bot Bot commented Apr 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18865

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 14, 2026
@Gasoonjia Gasoonjia temporarily deployed to upload-benchmark-results April 14, 2026 07:39 — with GitHub Actions Inactive
Gasoonjia and others added 11 commits April 22, 2026 22:14
Replace online softmax (per-tile max tracking + cross-split rescaling)
with a unified maximum value (phi=5.0) approach from FlashDecoding++.

Key changes:
- Split kernel: subtract fixed phi instead of tracking running max m_i,
  eliminating alpha rescaling between tiles
- Reduce kernel: simple summation of partial outputs instead of
  max-aware weighted combination; removes M_partial buffer
- ~12.9% average kernel-level speedup (6.8%-20.1% range) by saving
  HBM bandwidth (no M_partial reads/writes) and reducing ALU ops

The unified phi works because exp(qk - phi) is numerically stable
for typical attention score ranges, and the fixed constant allows
all splits to compute independently without synchronization.
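A shape-level Python sketch of the unified-max scheme; tensor names and the tiling are illustrative (the real change is inside the Triton split/reduce kernels in sdpa.py), but it shows why the fixed phi cancels in the final division.

```python
import torch

PHI = 5.0  # unified max constant from the commit; assumes attention scores stay
           # in a range where exp(score - PHI) neither overflows nor underflows.

def split_partial(q, k_tile, v_tile, scale):
    # Each split exponentiates against the fixed PHI instead of a running max,
    # so splits need no cross-tile rescaling and no M_partial buffer.
    s = (q @ k_tile.transpose(-1, -2)) * scale          # [B, H, 1, T_tile]
    p = torch.exp(s - PHI)
    return p @ v_tile, p.sum(dim=-1, keepdim=True)      # partial numerator, denominator

def reduce_partials(nums, dens):
    # Reduce step: plain sums; the common exp(-PHI) factor cancels on division,
    # recovering exact softmax(q k^T) @ v.
    return sum(nums) / sum(dens)
```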
Keep only sdpa.py changes on this branch; revert all other files
(aoti_delegate_handle.h, benchmark_sdpa.py, cuda_backend.cpp,
main.cpp, model.py) to their main branch state.
@Gasoonjia Gasoonjia changed the base branch from cuda-graph to gasoonjia/flashdecoding-pp-async-softmax April 23, 2026 05:41
@Gasoonjia
Contributor Author

Updated summary with latest perf

Base automatically changed from gasoonjia/flashdecoding-pp-async-softmax to main April 27, 2026 07:15
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia merged commit 6968475 into main Apr 27, 2026
116 of 125 checks passed
@Gasoonjia Gasoonjia deleted the fused-deltanet-decode branch April 27, 2026 07:17

Labels

ciflow/cuda, CLA Signed
