[GDN] Tricked kernels: ungated KKT + fused inference via similarity transform #797

Open
hypnopump wants to merge 3 commits into fla-org:main from hypnopump:gdn-tricked-kernels

@hypnopump hypnopump commented Mar 28, 2026

Contributor

Summary

Integrates the similarity-transform optimization (N = G·M·G⁻¹) into FLA's GDN forward and backward paths, following Simon Veitner's observation (also independently noted in Comba).

Instead of computing the gated coupling matrix N directly (requiring BT² exp ops), we:

  1. Compute ungated M (skip BT² exp ops)
  2. Solve M⁻¹ (same cost)
  3. Apply G/G⁻¹ diagonal scaling in the WY step (only 2·BT exp ops)

Net savings: BT² − 2·BT = 4096 − 128 = 3968 fewer exp ops per chunk (BT=64).
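The identity behind the three steps can be checked on a toy chunk. The sketch below is not the PR's Triton code: it assumes the coupling matrix is strictly lower-triangular with entries N_ij = M_ij · exp(g_i − g_j), uses made-up values for M and g, and forms the scaled inverse explicitly only to compare it with the direct solve. All helper names are hypothetical.

```python
import math

def strictly_lower(n, fn):
    """Build an n x n strictly lower-triangular matrix from fn(i, j)."""
    return [[fn(i, j) if i > j else 0.0 for j in range(n)] for i in range(n)]

def inv_unit_lower(L):
    """Invert (I + L) for strictly lower-triangular L by forward substitution."""
    n = len(L)
    X = [[0.0] * n for _ in range(n)]
    for c in range(n):  # solve (I + L) x = e_c, one column at a time
        for r in range(n):
            rhs = 1.0 if r == c else 0.0
            X[r][c] = rhs - sum(L[r][j] * X[j][c] for j in range(r))
    return X

def gated_inverse_via_transform(M, g):
    """(I + N)^-1 = G (I + M)^-1 G^-1 with G = diag(exp(g))."""
    n = len(g)
    X = inv_unit_lower(M)                 # ungated solve, no exp needed
    row = [math.exp(gi) for gi in g]      # BT exps for G
    col = [math.exp(-gi) for gi in g]     # BT exps for G^-1
    return [[row[i] * X[i][j] * col[j] for j in range(n)] for i in range(n)]

def exp_savings(bt):
    """exp ops saved per chunk: BT^2 for gated N versus 2*BT for G and G^-1."""
    return bt * bt - 2 * bt

BT = 4
g = [-0.1 * (i + 1) for i in range(BT)]         # toy cumulative log-gates
kk = lambda i, j: 0.05 * (i + 1) * (j + 2)       # stand-in for beta_i * <k_i, k_j>
M = strictly_lower(BT, kk)                                             # ungated
N = strictly_lower(BT, lambda i, j: kk(i, j) * math.exp(g[i] - g[j]))  # gated

direct = inv_unit_lower(N)
tricked = gated_inverse_via_transform(M, g)
max_err = max(abs(direct[i][j] - tricked[i][j])
              for i in range(BT) for j in range(BT))
```

Under these assumptions `max_err` stays at floating-point noise, and `exp_savings(64)` reproduces the 3968 figure quoted above.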

Changes

  • wy_fast.py (new): fusemaxxed kernel (inference, A in registers), tricked WY fwd/bwd kernels, ungated kkt+solve wrapper
  • chunk.py: Forward uses ungated kkt+solve + tricked WY; backward uses tricked WY bwd with precomputed w,u for cheap dG; inference auto-dispatches fusemaxxed (T≤8K) vs fused (T>8K); w/u caching for T≥2048

Benchmark results

All benchmarks on NVIDIA H100, B=1, H·Dh = 2048, bf16. Baseline: FLA 0.5.0 (commit f52529e). Averaged over 3 runs.

Forward (no_grad) — auto-dispatch inference

Auto-dispatch picks fusemaxxed (kkt+solve+WY fused, A in registers) for T≤8K, and fused (kkt+solve)+WY for T>8K.
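As a rough illustration of the dispatch rule just described, the selection could look like the sketch below. The function name, the returned labels, and the reading of "8K" as 8192 tokens are assumptions for illustration, not the PR's actual code.

```python
# Assumption: "8K" in the description means 8192 tokens.
FUSEMAXXED_MAX_T = 8 * 1024

def select_inference_path(T: int) -> str:
    """Pick the no-grad forward path from the sequence length T (hypothetical helper)."""
    if T <= FUSEMAXXED_MAX_T:
        # kkt + solve + WY in one kernel, A kept in registers
        return "fusemaxxed"
    # fused (kkt + solve) kernel followed by a separate WY kernel
    return "fused_kkt_solve_then_wy"
```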

| SeqLen | Dh=128 FLA (ms) | Dh=128 Tricked (ms) | Speedup | Dh=256 FLA (ms) | Dh=256 Tricked (ms) | Speedup |
|---|---|---|---|---|---|---|
| 1K | 0.452 | 0.324 | 1.39x | 0.453 | 0.323 | 1.40x |
| 2K | 0.454 | 0.324 | 1.40x | 0.460 | 0.332 | 1.39x |
| 4K | 0.459 | 0.347 | 1.32x | 0.538 | 0.428 | 1.26x |
| 8K | 0.572 | 0.529 | 1.08x | 0.717 | 0.641 | 1.12x |
| 16K | 0.899 | 0.790 | 1.14x | 1.092 | 1.084 | 1.01x |
| 32K | 1.611 | 1.416 | 1.14x | 1.974 | 1.971 | 1.00x |
| 64K | 3.014 | 2.648 | 1.14x | 3.726 | 3.722 | 1.00x |
| 128K | 5.844 | 5.128 | 1.14x | 7.201 | 7.177 | 1.00x |

Forward + Backward

| SeqLen | Dh=128 FLA (ms) | Dh=128 Tricked (ms) | Speedup | Dh=256 FLA (ms) | Dh=256 Tricked (ms) | Speedup |
|---|---|---|---|---|---|---|
| 1K | 1.487 | 1.741 | 0.85x | 1.611 | 1.521 | 1.06x |
| 2K | 1.807 | 1.696 | 1.07x | 1.582 | 1.449 | 1.09x |
| 4K | 1.835 | 1.698 | 1.08x | 1.911 | 1.886 | 1.01x |
| 8K | 2.121 | 2.049 | 1.04x | 2.858 | 2.820 | 1.01x |
| 16K | 3.272 | 3.157 | 1.04x | 4.773 | 4.727 | 1.01x |
| 32K | 5.840 | 5.629 | 1.04x | 8.783 | 8.762 | 1.00x |
| 64K | 10.992 | 10.524 | 1.04x | 16.881 | 16.789 | 1.01x |
| 128K | 21.373 | 20.393 | 1.05x | 33.133 | 32.878 | 1.01x |

Test plan

  • Correctness: verify tricked forward matches FLA baseline (max_diff < 0.1, cosine_sim > 0.999)
  • Correctness: verify tricked backward gradients match FLA baseline
  • Benchmark: reproduce forward speedups on H100
  • Test variable-length sequences (cu_seqlens path)
  • Test with l2norm and initial_state
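The two correctness thresholds in the test plan (max_diff < 0.1, cosine_sim > 0.999) can be expressed as a small helper. This is a minimal sketch on flat Python lists; the helper names are hypothetical, and a real test would presumably flatten the tricked and baseline tensors before comparing them.

```python
import math

def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two flat sequences."""
    return max(abs(x - y) for x, y in zip(a, b))

def cosine_sim(a, b):
    """Cosine similarity of two flat sequences."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def outputs_match(ref, out, max_diff_tol=0.1, cos_tol=0.999):
    """True when both test-plan criteria hold."""
    return max_abs_diff(ref, out) < max_diff_tol and cosine_sim(ref, out) > cos_tol

ref = [1.0, -2.0, 3.0]       # toy baseline output
out = [1.01, -1.99, 3.02]    # toy tricked output, slightly perturbed
ok = outputs_match(ref, out)
```

Checking both metrics together is a reasonable guard for bf16 outputs: the absolute bound catches large outliers, while the cosine bound catches a result that is uniformly scaled or misdirected.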

Summary by CodeRabbit

  • New Features

    • Added memory efficiency option for improved performance on sequences with varying lengths
    • Introduced optimized inference path for faster inference without backpropagation
  • Refactor

    • Streamlined internal computation flow for gated delta rule operations to reduce memory usage and improve backward pass efficiency

@coderabbitai

coderabbitai Bot commented Mar 28, 2026

Contributor

Walkthrough

Forward now separates KKT solve from WY computation (new "tricked" recompute/fused paths), adds a mem_efficient toggle and an inference-only fast path that avoids allocating A, and backward accepts optional cached (w,u) and uses new tricked WY prepare/backward kernels.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Core interface & autograd: fla/ops/gated_delta_rule/chunk.py | Updated forward/backward/autograd signatures to add mem_efficient and transpose_state_layout; threaded conditional caching of (w,u) through saved tensors; added chunk_gated_delta_rule_fwd_inference and inference fast-path when grads are disabled. |
| KKT-solve & fused inference kernels: fla/ops/gated_delta_rule/chunk_fwd.py | Added chunk_gated_delta_rule_fwd_kkt_solve which returns solved (I+A)^{-1} blocks A, new Triton fused inference kernel tricked_fusemaxxed_kernel, and tricked_fused_infer_fwd that produces (w,u) directly without storing A. |
| Tricked WY recompute & backward prep: fla/ops/gated_delta_rule/wy_fast.py | Added tricked_recompute_w_u_fwd and tricked_prepare_wy_repr_bwd plus Triton kernels implementing per-tile g normalization (G/G_inv/G_full), updated forward recompute algebra and backward gradient accumulation/propagation; backward host wrapper prepares dk/dv/db/dg and launches kernels. |

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant Autograd
  participant KKT as KKT_Solve (triton)
  participant TrickedInfer as Tricked_Infer (triton)
  participant WYRecomp as Tricked_Recomp (triton)
  participant BwdPrep as Tricked_Bwd (triton)

  Caller->>Autograd: call chunk_gated_delta_rule(..., mem_efficient)
  alt inference (no grads)
    Autograd->>TrickedInfer: tricked_fused_infer_fwd(k,v,beta,g_cum)
    TrickedInfer-->>Autograd: w,u
    Autograd-->>Caller: o, final_state
  else training
    Autograd->>KKT: chunk_gated_delta_rule_fwd_kkt_solve(k,v,beta,g)
    KKT-->>Autograd: A
    Autograd->>WYRecomp: tricked_recompute_w_u_fwd(k,v,beta,A,g)
    WYRecomp-->>Autograd: w,u
    Autograd-->>Caller: o, final_state (saves tensors for bwd)
  end
  Caller->>Autograd: backward(dw)
  Autograd->>BwdPrep: tricked_prepare_wy_repr_bwd(..., w,u, dw, du)
  BwdPrep-->>Autograd: dk,dv,db,dg
  Autograd-->>Caller: gradients

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • yzhangcs
  • zhiyuan1i

Poem

🐰 I hop through kernels, swift and spry,
I split A, then trick the WY,
For inference I skip the heavy haul,
For training I recompute it all,
A nibble, a tweak — I bounce and sprawl.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 35.71% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately describes the main focus of the PR: introducing 'tricked kernels' with ungated KKT solving and fused inference via similarity transform optimization. It directly matches the core changes in chunk.py (ungated KKT + tricked WY + inference dispatch) and the new supporting functions in wy_fast.py and chunk_fwd.py. |

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a 'tricked' path for the gated delta rule, optimizing the forward and backward passes by using G/G_inv scaling instead of gating in the coupling matrix. It adds a memory-efficient mode that allows trading compute for memory by toggling the caching of w and u tensors. Key additions include new Triton kernels for fused KKT, solve, and WY operations, as well as a dedicated inference path. Feedback highlights an inconsistency in the sequence length threshold for caching, unused parameters in the KKT solve function, and the presence of a debug barrier in the production code.

Comment thread fla/ops/gated_delta_rule/chunk.py
Comment thread fla/ops/gated_delta_rule/chunk.py
Comment thread fla/ops/gated_delta_rule/chunk.py
Comment thread fla/ops/gated_delta_rule/chunk_fwd.py
Comment thread fla/ops/gated_delta_rule/wy_fast.py Outdated
@hypnopump hypnopump force-pushed the gdn-tricked-kernels branch 2 times, most recently from fbf02fb to 8a50237 Compare March 28, 2026 20:23

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
fla/ops/gated_delta_rule/chunk_fwd.py (1)

768-789: Thread chunk_indices into this wrapper.

fla/ops/gated_delta_rule/chunk.py already computes chunk_indices on Lines 513-515 before entering the no-grad helper, but this function rebuilds them again on Lines 781-782. That duplicates varlen setup work on the inference hot path this PR is trying to speed up.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 768 - 789, The wrapper
tricked_fused_infer_fwd currently recomputes chunk_indices using
prepare_chunk_indices (lines building chunk_indices and NT) which duplicates
work already done upstream; modify tricked_fused_infer_fwd to accept
chunk_indices (and corresponding NT when needed) as an optional parameter
instead of recomputing it, remove the prepare_chunk_indices/len(...) recompute
logic, and thread the incoming chunk_indices through to the
tricked_fusemaxxed_kernel call (keep cu_seqlens handling as before). Update the
function signature (tricked_fused_infer_fwd) and its callers so the precomputed
chunk_indices from fla/ops/gated_delta_rule/chunk.py are passed in, ensuring the
kernel invocation still receives chunk_indices and NT correctly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 129-170: The no-grad fast path
(chunk_gated_delta_rule_fwd_inference) currently skips CP preprocessing
(chunk_gated_delta_rule_fwd_h_pre_process) when cp_context is set, which can
return local-only outputs; fix by adding a cp_context parameter to
chunk_gated_delta_rule_fwd_inference, thread that parameter through all callers,
and inside the function either call
chunk_gated_delta_rule_fwd_h_pre_process(...) before
chunk_gated_delta_rule_fwd_h(...) when cp_context is non-null or explicitly
raise an error if cp_context is provided while torch.is_grad_enabled() is False;
update all call sites that invoke chunk_gated_delta_rule_fwd_inference (and any
other inference helper copies at lines ~512-533) to pass the cp_context.
- Around line 122-125: Unify the predicate controlling cached w/u by introducing
a single constant (e.g., CACHE_WU_THRESHOLD = 2048) and use it everywhere
instead of duplicated literals: compute cache_wu = not mem_efficient and T >=
CACHE_WU_THRESHOLD in the chunk logic and change the conditional in
ChunkGatedDeltaRuleFunction.forward (and the unpack path that expects 5 vs 7
values) to use the same CACHE_WU_THRESHOLD and same comparison (use >= to match
the early return), so both branches agree on when a 7-tuple is returned; add a
regression test that runs the forward path with T == CACHE_WU_THRESHOLD to
ensure callers handle the 7-value return correctly.
- Around line 17-27: Remove the stale imports causing Ruff errors: delete
chunk_gated_delta_rule_fwd_intra from the import list in the chunk_fwd import
block and remove prepare_wy_repr_bwd and recompute_w_u_fwd from the wy_fast
import block in chunk.py; keep only the actually used symbols
(chunk_gated_delta_rule_fwd_kkt_solve, tricked_fused_infer_fwd,
tricked_prepare_wy_repr_bwd, tricked_recompute_w_u_fwd) so the file imports
match usage and the linter stops flagging unused names.

📥 Commits

Reviewing files that changed from the base of the PR and between f52529e and df90cac.

📒 Files selected for processing (3)
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/wy_fast.py

Comment thread fla/ops/gated_delta_rule/chunk.py
Comment thread fla/ops/gated_delta_rule/chunk.py
Comment thread fla/ops/gated_delta_rule/chunk.py

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
fla/ops/gated_delta_rule/chunk_fwd.py (1)

768-789: Reuse the caller’s chunk_indices in fused inference.

tricked_fused_infer_fwd() rebuilds the chunk map even though the no-grad path already has it. On varlen inference that adds avoidable preprocessing and drops the existing precomputed path.

♻️ Proposed change
 def tricked_fused_infer_fwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     g_cum: torch.Tensor,
     cu_seqlens: torch.LongTensor | None = None,
+    chunk_indices: torch.LongTensor | None = None,
     use_exp2: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@
-    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)

Pass the existing chunk_indices through from chunk_gated_delta_rule_fwd_inference().

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 768 - 789,
tricked_fused_infer_fwd currently recomputes chunk_indices via
prepare_chunk_indices even when a precomputed map exists; change its API to
accept an optional chunk_indices parameter and use the provided value when not
None (i.e., remove the prepare_chunk_indices(...) call and set chunk_indices =
chunk_indices_arg if given, else prepare_chunk_indices(cu_seqlens,...)). Update
any caller (notably chunk_gated_delta_rule_fwd_inference) to pass through its
precomputed chunk_indices into tricked_fused_infer_fwd so the no‑grad inference
path reuses the existing map instead of rebuilding it.
fla/ops/gated_delta_rule/wy_fast.py (1)

424-425: Make backward derive BT from A too.

Line 425 already keys the forward wrapper off A.shape[-1], but Line 588 resets BT to 64. If the forward chunk size ever changes, backward will launch against the wrong tile shape.

♻️ Proposed change
 def tricked_prepare_wy_repr_bwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
-    BT = 64
+    BT = A.shape[-1]

Also applies to: 587-594

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/wy_fast.py` around lines 424 - 425, The backward
pass is using a hardcoded tile size (BT = 64) which can mismatch the forward
wrapper that sets BT from A.shape[-1]; update the backward logic to derive BT
from A (e.g., BT = A.shape[-1]) instead of 64 so the backward tiling matches
forward. Locate occurrences where BT is reset to 64 (around the backward/chunked
loop that consumes A, k, v) and replace them with deriving BT from A.shape[-1];
apply the same change to the other backward block referenced (the 587-594 area)
so all backward branches use BT = A.shape[-1].
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 119-123: The boundary check for caching w/u is inconsistent:
change the token-threshold predicate so both sides use the same comparison (make
the check consistent between the cache_wu assignment that uses q.shape[1] and
the check inside
chunk_gated_delta_rule_fwd()/ChunkGatedDeltaRuleFunction.forward()); locate the
cache predicate (cache_wu = not mem_efficient and T >= 2048) and the
corresponding conditional in
chunk_gated_delta_rule_fwd()/ChunkGatedDeltaRuleFunction.forward() and make them
identical (e.g., use >= 2048 in both) so the forward and fwd wrapper
return/unpack the same tuple shape at T == 2048.
- Around line 509-530: The no-grad branch currently bypasses the CP
(context-parallel) preprocessing, causing incorrect state when cp_context is
set; modify the condition so that we only take the fused no-grad inference path
when cp_context is False (i.e. require both torch.is_grad_enabled() == False AND
cp_context is falsy), otherwise run the same CP-prepared path used by the
grad-enabled flow (reuse chunk_gated_delta_rule_fwd_h preprocessing steps and
chunk_indices preparation) and then call chunk_gated_delta_rule_fwd_h (or the
CP-aware forward) instead of chunk_gated_delta_rule_fwd_inference so CP
evaluation gets the CP-prepared state.

📥 Commits

Reviewing files that changed from the base of the PR and between f52529e and fbf02fb.

📒 Files selected for processing (3)
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/wy_fast.py

Comment thread fla/ops/gated_delta_rule/chunk.py
Comment thread fla/ops/gated_delta_rule/chunk.py Outdated

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (4)
fla/ops/gated_delta_rule/wy_fast.py (2)

406-413: Inconsistent allow_tf32 usage between u and w computations.

The u computation at line 406 uses allow_tf32=False but the w computation at line 413 does not specify this parameter (defaulting to True on supported hardware). This asymmetry may cause precision differences between the two outputs. Consider applying consistent precision settings.

♻️ Proposed fix for consistency
         b_w = tl.dot(b_A, (b_k * b_b[:, None]).to(b_k.dtype)) * b_G[:, None]
+        # Note: If allow_tf32=False is intentional for u, consider adding it here too
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/wy_fast.py` around lines 406 - 413, The tl.dot call
used to compute b_w (the expression "b_w = tl.dot(b_A, (b_k * b_b[:,
None]).to(b_k.dtype)) * b_G[:, None]") is missing the allow_tf32 flag, causing
inconsistent precision versus the b_u computation (which sets allow_tf32=False);
update the tl.dot for b_w to include allow_tf32=False so both dot products use
the same precision semantics.

550-562: Unused variable b_M computed but never used.

b_M is initialized at line 550 and accumulated at line 562, but it's never read or stored. This appears to be dead code that could be removed.

🧹 Proposed fix to remove dead code
-    b_M = tl.zeros([BT, BT], dtype=tl.float32)
     b_dA = tl.where(m_A, -b_dA, 0).to(k.dtype.element_ty)

     tl.debug_barrier()
     for i_k in range(tl.cdiv(K, BK)):
         off = i_k * BK
         p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, off), (BT, BK), (1, 0))
         p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, off), (BT, BK), (1, 0))
         b_k = tl.load(p_k, boundary_check=(0, 1))
         b_kt = tl.trans(b_k)
         b_ktb = b_kt * b_b[None, :]

-        b_M += tl.dot(b_k, b_kt)
         b_dkb = tl.dot(b_dA, b_k)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/wy_fast.py` around lines 550 - 562, The variable b_M
is allocated (b_M = tl.zeros(...)) and accumulated (b_M += tl.dot(...)) but
never used; remove the dead code by deleting the b_M allocation and the
accumulation expression inside the loop (references to b_M), leaving the rest of
the loop (p_k/p_dk, b_k, b_kt, b_ktb, etc.) unchanged; verify no other code
expects b_M or its value after the loop and remove any now-unused imports or
variables created solely for b_M if they become dead.
fla/ops/gated_delta_rule/chunk_fwd.py (2)

405-414: Unused parameter v in function signature.

The v parameter is declared but never used in chunk_gated_delta_rule_fwd_kkt_solve. The underlying kernel only operates on k, g, beta, and A. Consider removing it or documenting why it's kept for API consistency.

♻️ Option 1: Remove unused parameter
 def chunk_gated_delta_rule_fwd_kkt_solve(
     k: torch.Tensor,
-    v: torch.Tensor,
     beta: torch.Tensor,
     g: torch.Tensor | None = None,

Note: If kept for API consistency with chunk_gated_delta_rule_fwd_intra, consider adding a comment explaining this.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 405 - 414, The function
chunk_gated_delta_rule_fwd_kkt_solve declares an unused parameter v; remove v
from the signature and all internal/no-op mentions to avoid confusion OR if you
must preserve API parity with chunk_gated_delta_rule_fwd_intra keep the
parameter but mark it explicitly unused (rename to _v or add a clear comment)
and update any callers/tests accordingly; locate and edit the function named
chunk_gated_delta_rule_fwd_kkt_solve and any call sites to either drop the extra
argument or pass a dummy value, and if opting to keep the parameter add a
one-line comment inside the function explaining it is kept for API compatibility
with chunk_gated_delta_rule_fwd_intra.

697-737: Consistent allow_tf32=False usage for u but not for w.

Similar to tricked_wy_fwd_kernel, the fused kernel uses allow_tf32=False for the u computation (lines 737, 746-747, 756-758, 767-770) but uses default precision for w computation (lines 697, 706, 715, 724-725). This appears intentional but should be documented if the precision difference is deliberate.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 697 - 737, The w-path
uses tl.dot without allow_tf32 while the u-path explicitly sets allow_tf32=False
(see tl.dot calls producing b_w0/b_w1/b_w2/b_w3 vs b_u0 and subsequent b_u*),
causing inconsistent precision; either make the w tl.dot calls consistent by
adding allow_tf32=False to the tl.dot invocations that compute
b_w0/b_w1/b_w2/b_w3, or if the mixed precision is intentional, add a short
explanatory comment near the w and u tl.dot calls documenting why w can use TF32
but u must disable it; update all relevant tl.dot calls for w or add the single
comment so the intent is explicit and consistent across the kernel.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 344-346: The forward/backward caching threshold is inconsistent:
update the condition in ChunkGatedDeltaRuleFunction.forward to match
chunk_gated_delta_rule_fwd (use T >= 2048) so the same tuple shape is returned
for T=2048; locate the caching check in ChunkGatedDeltaRuleFunction.forward (the
variable cache_wu) and change the comparator from ">" to ">=" (or conversely
change chunk_gated_delta_rule_fwd to use ">" if you prefer that convention) so
both functions use the identical threshold and the returned tuple arity matches
what the autograd function expects.

📥 Commits

Reviewing files that changed from the base of the PR and between f52529e and 8a50237.

📒 Files selected for processing (3)
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/wy_fast.py

Comment thread fla/ops/gated_delta_rule/chunk.py

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_delta_rule/wy_fast.py`:
- Around line 390-414: The w-path currently uses the normalized b_G (exp(g -
b_g_max)) which loses the chunk max shift; restore the original G by multiplying
b_w by exp(b_g_max) (or exp2(b_g_max) when USE_EXP2) before storing. Concretely,
after computing b_w in the block that builds b_w (using b_k, b_b and b_G),
compute a scale = exp2(b_g_max) if USE_EXP2 else exp(b_g_max) and apply b_w =
b_w * scale[:, None] (use the same b_g_max used to normalize b_g); also apply
the identical change in the fused inference kernel mentioned (chunk_fwd.py
region 642-691) so both paths use the same max convention.
📥 Commits

Reviewing files that changed from the base of the PR and between 8a50237 and 3f037e2.

📒 Files selected for processing (2)
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/wy_fast.py

Comment thread fla/ops/gated_delta_rule/wy_fast.py

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_delta_rule/chunk_fwd.py`:
- Around line 405-444: The function chunk_gated_delta_rule_fwd_kkt_solve
currently declares an unused parameter v; remove v from the function signature
(and its type hints) and from the internal parameter list passed to the triton
kernel call, then update every call site that passes v (notably the call in
fla/ops/gated_delta_rule/chunk.py where v=v is passed) to stop providing that
argument; ensure any import/type usages referencing the old signature are
updated to the new signature so there are no mismatched-argument errors.
📥 Commits

Reviewing files that changed from the base of the PR and between 3f037e2 and 19145b9.

📒 Files selected for processing (2)
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/wy_fast.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/ops/gated_delta_rule/wy_fast.py

Comment thread fla/ops/gated_delta_rule/chunk_fwd.py
@hypnopump hypnopump changed the title [GDN] Port math-tricked kernels: ungated KKT + fused inference path [GDN] Tricked kernels: 1.23x inference speedup via ungated coupling + G/G_inv WY scaling Mar 31, 2026
@hypnopump hypnopump changed the title [GDN] Tricked kernels: 1.23x inference speedup via ungated coupling + G/G_inv WY scaling [GDN] Tricked kernels: ungated KKT + fused inference via similarity transform Mar 31, 2026
- Replace gated coupling matrix with ungated KKT + G/G_inv post-scaling
  in WY step, saving ~BT² exp operations in both forward and backward
- Add fusemaxxed kernel for inference (no_grad): fuses KKT+solve+WY
  into a single kernel where A never touches HBM
- Backward uses precomputed w,u for cheap dG gradient (no extra matmuls)
- Cache w,u from forward pass (T>2048) to skip recomputation in backward
  (mem_efficient=False by default)
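
The saving claimed in the first bullet can be checked by hand. The block below is a sketch only, under the assumption (not spelled out in these commit messages) that the gated coupling entries have the form N_ij = exp(g_i - g_j) · M_ij, with g the in-chunk cumulative log-gate. The symbols G, M and N follow the PR summary.

```latex
% Assumption: N_{ij} = e^{g_i - g_j} M_{ij}, with G diagonal.
\begin{aligned}
N &= G\,M\,G^{-1}, \qquad G = \operatorname{diag}\!\left(e^{g_1},\dots,e^{g_{BT}}\right) \\
(I + N)^{-1} &= \left(G\,(I + M)\,G^{-1}\right)^{-1} = G\,(I + M)^{-1}\,G^{-1} \\
\text{exp ops saved per chunk} &= BT^2 - 2\,BT = 64^2 - 2\cdot 64 = 3968 \quad (BT = 64)
\end{aligned}
```

Under that assumption the inverse of the gated system is the inverse of the ungated system conjugated by G, which is why the exponentials can be deferred to the WY step.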
…eanup

- Align cache_wu threshold to T >= 2048 in both forward and autograd wrapper
- Guard no-grad inference path to fall back when cp_context is active
- Remove unused b_M variable and tl.debug_barrier() from backward kernel
…utation

The g_max subtraction for numerical stability only cancels for u (where
G * G_inv = 1), but w = G * (A @ (k*beta)) has no inverse term to cancel
the shift. Use G_full = G * exp(g_max) for w so it matches exp(g) exactly.
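
The cancellation argument in this commit message can be reproduced on a toy chunk. The following is a minimal pure-Python sketch, not the kernel code: the helper names, the 3x3 matrix `A` and the vector `x` are made up for illustration, and `A` merely stands in for the solved triangular factor.

```python
import math

def scale_rows(d, m):
    """Return diag(d) @ m for a square matrix stored as nested lists."""
    return [[d[i] * x for x in row] for i, row in enumerate(m)]

def scale_cols(m, d):
    """Return m @ diag(d)."""
    return [[x * d[j] for j, x in enumerate(row)] for row in m]

def matvec(m, x):
    """Return m @ x."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]

# Toy chunk: cumulative log-gates g and an arbitrary lower-triangular A.
g = [-0.5, -1.25, -2.0]
A = [[1.0, 0.0, 0.0],
     [0.3, 1.0, 0.0],
     [-0.2, 0.7, 1.0]]
x = [0.4, -1.0, 2.5]

g_max = max(g)
G = [math.exp(gi) for gi in g]                        # exact exp(g)
G_inv = [math.exp(-gi) for gi in g]
G_shift = [math.exp(gi - g_max) for gi in g]          # stabilised exp(g - g_max)
G_shift_inv = [math.exp(-(gi - g_max)) for gi in g]

# u-like term: G A G^{-1} x. The shift enters once in G and once in G^{-1}.
u_exact = matvec(scale_cols(scale_rows(G, A), G_inv), x)
u_shift = matvec(scale_cols(scale_rows(G_shift, A), G_shift_inv), x)

# w-like term: G A x. Only G appears, so a factor exp(-g_max) is left over.
w_exact = matvec(scale_rows(G, A), x)
w_shift = matvec(scale_rows(G_shift, A), x)
w_fixed = [wi * math.exp(g_max) for wi in w_shift]    # G_full = G * exp(g_max)
```

Here `u_shift` agrees with `u_exact` up to rounding, `w_shift` is off by the factor `exp(-g_max)`, and `w_fixed` restores the exact value, which is the correction the commit describes.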
@hypnopump hypnopump force-pushed the gdn-tricked-kernels branch from 19145b9 to 7004d19 Compare March 31, 2026 16:53

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
fla/ops/gated_delta_rule/chunk.py (1)

55-75: ⚠️ Potential issue | 🔴 Critical

Don’t route GQA through the tricked WY helpers yet.

This forward/backward path is now unconditional, but tricked_recompute_w_u_fwd in fla/ops/gated_delta_rule/wy_fast.py:429-451 and tricked_prepare_wy_repr_bwd in fla/ops/gated_delta_rule/wy_fast.py:589-625 both derive their launch head count from k.shape[2]. The public API here still accepts H % Hq == 0, so grouped-head runs will only materialize the tricked WY state for Hq heads and backward will mirror the same mismatch. Please fall back to the old WY path, or explicitly reject GQA until those kernels take both H and Hq.

Also applies to: 189-201, 277-291
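
One way to read the "fall back or reject" suggestion is as a small dispatch guard. The sketch below is hypothetical: `use_tricked_wy` is not a function in the repository, and it only assumes, as the comment states, that the public API accepts head counts with `H % Hq == 0`.

```python
def use_tricked_wy(num_value_heads: int, num_key_heads: int) -> bool:
    """Hypothetical guard: take the tricked WY path only when there is no GQA.

    num_value_heads plays the role of H and num_key_heads the role of Hq.
    """
    if num_key_heads <= 0 or num_value_heads % num_key_heads != 0:
        raise ValueError(
            f"H={num_value_heads} must be a positive multiple of Hq={num_key_heads}"
        )
    # Grouped heads (H != Hq) should use the legacy WY path until the
    # tricked kernels accept both head counts.
    return num_value_heads == num_key_heads
```

A caller would branch on the result once and route grouped-head inputs to the previous WY helpers, so the forward and backward paths stay consistent.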

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk.py` around lines 55 - 75, The new forward
unconditionally calls the tricked WY helpers (tricked_recompute_w_u_fwd) which
assume the kernel launch head-count equals k.shape[2], causing mismatch for
grouped-head (GQA) runs; update the chunked forward/backward callers (the spots
invoking chunk_gated_delta_rule_fwd_kkt_solve and tricked_recompute_w_u_fwd) to
detect grouped-head mode (H != Hq or Hq < H) and either (a) fall back to the
legacy WY code path used previously for those ranges, or (b) explicitly
raise/return an error rejecting GQA until the tricked WY kernels
(tricked_recompute_w_u_fwd and tricked_prepare_wy_repr_bwd) are updated to
accept both H and Hq; ensure the check is applied at the other occurrences
mentioned (the other two forward blocks corresponding to the review: the blocks
around the other two ranges) so grouped-head runs never call the tricked WY
helpers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_delta_rule/chunk_fwd.py`:
- Around line 467-515: The fast-path in tricked_fusemaxxed_kernel (and the
corresponding lines in tricked_fused_infer_fwd) assumes a single head-count and
therefore uses k/v/w/u/beta/g strides derived from k.shape[2], which breaks when
model heads H != query-heads Hq; replicate the H↔Hq mapping used in
chunk_gated_delta_rule_fwd_kkt_solve_kernel so program IDs and pointer
arithmetic use the query-head index (map i_h -> i_hq or compute i_h_from_i_bh
the same way) and load/store v, beta, g, u using the correct per-head strides,
or alternatively gate this fast path to only run when H == Hq; update indexing
of k/v/w/u/beta/g and the early-return/program-launch condition accordingly.
- Around line 424-445: The Triton kernel launch for
chunk_gated_delta_rule_fwd_kkt_solve_kernel is missing the constexpr Hq
parameter, causing incorrect pointer math for grouped-query attention; derive
Hq from k.shape[2] and H from beta.shape[2] (as in the minimal fix for this
finding), keep the launch grid as (NT, B * H), and pass Hq=Hq as a keyword
argument alongside H so the kernel's key indexing (used in the k pointer
arithmetic) is correct.
- Around line 645-687: The max-reduction uses zero-filled tail lanes from the
boundary_check loads (b_g0..b_g3), so mask out invalid lanes using the
per-subchunk masks (m_tc0, m_tc1, m_tc2, m_tc3) by replacing invalid entries
with -inf (or a sufficiently large negative float) before computing b_g_max;
update the code that computes b_g_max (and then b_g0..b_g3 adjustments) to use
the masked b_g* values so the normalization and subsequent b_Ginv* / b_Gw*
calculations remain numerically stable on partial chunks.

📥 Commits

Reviewing files that changed from the base of the PR and between 19145b9 and 7004d19.

📒 Files selected for processing (3)
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/wy_fast.py
✅ Files skipped from review due to trivial changes (1)
  • fla/ops/gated_delta_rule/wy_fast.py

Comment on lines +424 to +445
    B, T, H, K = k.shape
    BT = chunk_size
    BC = 16

    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    A = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype)
    chunk_gated_delta_rule_fwd_kkt_solve_kernel[(NT, B * H)](
        k=k,
        g=g,
        beta=beta,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        USE_EXP2=use_exp2,

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the kernel definition and see what parameters it expects
head -100 fla/ops/gated_delta_rule/chunk_fwd.py | cat -n

Repository: fla-org/flash-linear-attention

Length of output: 3880


🏁 Script executed:

# Get more context around the kernel definition (around lines 40-41)
sed -n '1,80p' fla/ops/gated_delta_rule/chunk_fwd.py | cat -n

Repository: fla-org/flash-linear-attention

Length of output: 2870


🏁 Script executed:

# Look at the chunk_gated_delta_rule_fwd_kkt_solve_kernel definition
rg "chunk_gated_delta_rule_fwd_kkt_solve_kernel" fla/ops/gated_delta_rule/chunk_fwd.py -A 40 -B 5

Repository: fla-org/flash-linear-attention

Length of output: 3756


🏁 Script executed:

# Now let's look at both launchers - the old one around 378-388
sed -n '370,395p' fla/ops/gated_delta_rule/chunk_fwd.py | cat -n

Repository: fla-org/flash-linear-attention

Length of output: 928


🏁 Script executed:

# And the new one around 424-445
sed -n '420,450p' fla/ops/gated_delta_rule/chunk_fwd.py | cat -n

Repository: fla-org/flash-linear-attention

Length of output: 1144


🏁 Script executed:

# Let's get the full function signatures to understand parameter structure better
sed -n '1,500p' fla/ops/gated_delta_rule/chunk_fwd.py | grep -A 50 "def chunk_gated_delta_rule_fwd_kkt_solve"

Repository: fla-org/flash-linear-attention

Length of output: 3067


Pass Hq into this Triton launch.

chunk_gated_delta_rule_fwd_kkt_solve_kernel requires both H and Hq as constexpr parameters, but this wrapper derives H from k.shape[2] and never supplies Hq. The kernel uses Hq for key pointer indexing (line 81: k += (bos * Hq + i_h // (H // Hq)) * K) and will fail or produce incorrect results for grouped-query attention where Hq ≠ H.

Minimal fix
-    B, T, H, K = k.shape
+    B, T, Hq, K = k.shape
+    H = beta.shape[2]
     BT = chunk_size
     BC = 16
@@
     chunk_gated_delta_rule_fwd_kkt_solve_kernel[(NT, B * H)](
         k=k,
         g=g,
         beta=beta,
         A=A,
         cu_seqlens=cu_seqlens,
         chunk_indices=chunk_indices,
         T=T,
         H=H,
+        Hq=Hq,
         K=K,
         BT=BT,
         BC=BC,
         USE_EXP2=use_exp2,
     )
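
The pointer arithmetic quoted above (`k += (bos * Hq + i_h // (H // Hq)) * K`) can be mirrored in plain Python to see which key head each program reads. This is an illustrative sketch: `key_head_for` and `key_offset` are made-up helper names, not functions from the repository.

```python
def key_head_for(i_h: int, H: int, Hq: int) -> int:
    """Map a value/gate head index i_h to the key head it shares under GQA."""
    assert H % Hq == 0, "H must be a multiple of Hq"
    return i_h // (H // Hq)

def key_offset(bos: int, i_h: int, H: int, Hq: int, K: int) -> int:
    """Element offset into k, mirroring (bos * Hq + i_h // (H // Hq)) * K."""
    return (bos * Hq + key_head_for(i_h, H, Hq)) * K

# With H=8 value heads and Hq=2 key heads, heads 0..3 share key head 0
# and heads 4..7 share key head 1.
mapping = [key_head_for(i_h, H=8, Hq=2) for i_h in range(8)]
```

When `H == Hq` the mapping is the identity and the offset reduces to `(bos * H + i_h) * K`, which is why the missing argument only shows up with grouped heads.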

Comment on lines +467 to +515
def tricked_fusemaxxed_kernel(
    k,
    v,
    beta,
    g,
    w,
    u,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_EXP2: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Fused ungated kkt + solve_tril + tricked WY.

    Computes (I+A)^{-1} entirely in registers (A never hits HBM),
    then immediately uses it to compute w and u with G/G_inv scaling.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT >= T:
        return

    i_tc0 = i_t * BT
    i_tc1 = i_t * BT + BC
    i_tc2 = i_t * BT + 2 * BC
    i_tc3 = i_t * BT + 3 * BC

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    w += (bos * H + i_h) * K
    u += (bos * H + i_h) * V

⚠️ Potential issue | 🔴 Critical

The fused inference fast path is not GQA-safe yet.

tricked_fusemaxxed_kernel indexes k, v, beta, and g with one shared head count, and tricked_fused_infer_fwd sets that count from k.shape[2]. For H != Hq, this launches only B*Hq programs, uses the wrong stride for v/beta/g, and leaves the extra value heads in u untouched. Please mirror the H/Hq mapping already used by chunk_gated_delta_rule_fwd_kkt_solve_kernel, or gate this fast path to H == Hq.

Also applies to: 786-807
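
The stride problem described here follows from row-major indexing alone. The sketch below uses a made-up helper `flat_offset` and toy sizes; it assumes v and u are contiguous with H heads while the kernel computes offsets as if they had Hq heads.

```python
def flat_offset(t: int, h: int, num_heads: int, dim: int) -> int:
    """Element offset of row (t, h) in a contiguous [T, num_heads, dim] tensor."""
    return (t * num_heads + h) * dim

T, H, Hq, V = 3, 4, 2, 8  # toy sizes: 4 value heads, 2 key heads

# Row starts the kernel addresses when it is handed Hq as the head count.
touched = {flat_offset(t, h, Hq, V) for t in range(T) for h in range(Hq)}
# Row starts that actually exist in the [T, H, V] tensor.
valid = {flat_offset(t, h, H, V) for t in range(T) for h in range(H)}

# The kernel's notion of (t=1, h=0) lands on what is really (t=0, h=2).
aliased = flat_offset(1, 0, Hq, V) == flat_offset(0, 2, H, V)
```

In this toy case only half of the rows are addressed, and those that are addressed belong to different (time, head) positions than the kernel intends, which matches the review's point that gating the fast path on `H == Hq` is the simplest safe option.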


Comment on lines +645 to +687
    # Phase 5: Load gating, cast Ai blocks for WY matmuls
    p_g0 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc0,), (BC,), (0,))
    p_g1 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,))
    p_g2 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,))
    p_g3 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,))
    b_g0 = tl.load(p_g0, boundary_check=(0,)).to(tl.float32)
    b_g1 = tl.load(p_g1, boundary_check=(0,)).to(tl.float32)
    b_g2 = tl.load(p_g2, boundary_check=(0,)).to(tl.float32)
    b_g3 = tl.load(p_g3, boundary_check=(0,)).to(tl.float32)

    # Normalize g by subtracting max across all sub-chunks to prevent exp overflow
    b_g_max = tl.max(tl.maximum(tl.maximum(b_g0, b_g1), tl.maximum(b_g2, b_g3)), 0)
    b_g0 = b_g0 - b_g_max
    b_g1 = b_g1 - b_g_max
    b_g2 = b_g2 - b_g_max
    b_g3 = b_g3 - b_g_max

    if USE_EXP2:
        b_G0 = exp2(b_g0)
        b_Ginv0 = exp2(-b_g0)
        b_G1 = exp2(b_g1)
        b_Ginv1 = exp2(-b_g1)
        b_G2 = exp2(b_g2)
        b_Ginv2 = exp2(-b_g2)
        b_G3 = exp2(b_g3)
        b_Ginv3 = exp2(-b_g3)
        b_Gw0 = b_G0 * exp2(b_g_max)
        b_Gw1 = b_G1 * exp2(b_g_max)
        b_Gw2 = b_G2 * exp2(b_g_max)
        b_Gw3 = b_G3 * exp2(b_g_max)
    else:
        b_G0 = exp(b_g0)
        b_Ginv0 = exp(-b_g0)
        b_G1 = exp(b_g1)
        b_Ginv1 = exp(-b_g1)
        b_G2 = exp(b_g2)
        b_Ginv2 = exp(-b_g2)
        b_G3 = exp(b_g3)
        b_Ginv3 = exp(-b_g3)
        b_Gw0 = b_G0 * exp(b_g_max)
        b_Gw1 = b_G1 * exp(b_g_max)
        b_Gw2 = b_G2 * exp(b_g_max)
        b_Gw3 = b_G3 * exp(b_g_max)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

wc -l fla/ops/gated_delta_rule/chunk_fwd.py

Repository: fla-org/flash-linear-attention

Length of output: 115


🏁 Script executed:

cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '630,700p'

Repository: fla-org/flash-linear-attention

Length of output: 3462


🏁 Script executed:

cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '400,450p'

Repository: fla-org/flash-linear-attention

Length of output: 1823


🏁 Script executed:

# Search for mask definitions and boundary handling
rg -n "mask|boundary" fla/ops/gated_delta_rule/chunk_fwd.py | head -30

Repository: fla-org/flash-linear-attention

Length of output: 2186


🏁 Script executed:

# Search for i_tc0, i_tc1, etc definitions
rg -n "i_tc" fla/ops/gated_delta_rule/chunk_fwd.py

Repository: fla-org/flash-linear-attention

Length of output: 6009


🏁 Script executed:

cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '500,660p'

Repository: fla-org/flash-linear-attention

Length of output: 8237


🏁 Script executed:

cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '76,120p'

Repository: fla-org/flash-linear-attention

Length of output: 2446


🏁 Script executed:

# Look for how g_max normalization is handled in the other kernel
rg -n "g_max|b_Gw" fla/ops/gated_delta_rule/chunk_fwd.py | head -20

Repository: fla-org/flash-linear-attention

Length of output: 963


🏁 Script executed:

# Check the full first kernel to see if there's any masking pattern for g_max
cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '50,250p'

Repository: fla-org/flash-linear-attention

Length of output: 11235


Mask invalid lanes before reducing g_max.

The boundary_check loads zero-fill tail lanes. On any partial chunk where all valid gates are below zero, Line 656 picks the zero-filled out-of-bounds element as the max. This defeats the normalization for G_inv and can cause overflow instead of stabilizing it. Mask b_g* values (for example, using -inf on invalid lanes via m_tc*) before computing b_g_max in the reduction.
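
The effect can be reproduced without Triton. The sketch below assumes, as the comment states, that out-of-range lanes are read as zero; the helper names and the gate values are illustrative, and fp32 overflow is approximated by comparing against the largest finite float32 value.

```python
import math

FP32_MAX = 3.4028234663852886e38  # largest finite float32 value

def g_max_unmasked(g_valid, block):
    """Max over a block whose tail lanes were zero-filled by the load."""
    padded = g_valid + [0.0] * (block - len(g_valid))
    return max(padded)

def g_max_masked(g_valid, block):
    """Max over valid lanes only; invalid lanes are replaced by -inf first."""
    padded = g_valid + [0.0] * (block - len(g_valid))
    mask = [i < len(g_valid) for i in range(block)]
    return max(x if m else -math.inf for x, m in zip(padded, mask))

# Partial chunk: 3 valid lanes out of 16, all gates well below zero.
g_valid = [-90.0, -95.0, -100.0]
bad = g_max_unmasked(g_valid, 16)    # picks a padded zero
good = g_max_masked(g_valid, 16)     # picks the true maximum

# G_inv = exp(-(g - g_max)) with each choice of g_max.
ginv_bad = [math.exp(-(x - bad)) for x in g_valid]
ginv_good = [math.exp(-(x - good)) for x in g_valid]
```

With the unmasked maximum the shift is zero, so `G_inv` reaches `exp(100)` and exceeds the float32 range. With the masked maximum the largest entry is `exp(10)`, which is the stabilisation the normalization was meant to provide.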

