[GDN] Tricked kernels: ungated KKT + fused inference via similarity transform#797
[GDN] Tricked kernels: ungated KKT + fused inference via similarity transform#797hypnopump wants to merge 3 commits into
Conversation
WalkthroughForward now separates KKT solve from WY computation (new "tricked" recompute/fused paths), adds a mem_efficient toggle and an inference-only fast path that avoids allocating A, and backward accepts optional cached (w,u) and uses new tricked WY prepare/backward kernels. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller
participant Autograd
participant KKT as KKT_Solve (triton)
participant TrickedInfer as Tricked_Infer (triton)
participant WYRecomp as Tricked_Recomp (triton)
participant BwdPrep as Tricked_Bwd (triton)
Caller->>Autograd: call chunk_gated_delta_rule(..., mem_efficient)
alt inference (no grads)
Autograd->>TrickedInfer: tricked_fused_infer_fwd(k,v,beta,g_cum)
TrickedInfer-->>Autograd: w,u
Autograd-->>Caller: o, final_state
else training
Autograd->>KKT: chunk_gated_delta_rule_fwd_kkt_solve(k,v,beta,g)
KKT-->>Autograd: A
Autograd->>WYRecomp: tricked_recompute_w_u_fwd(k,v,beta,A,g)
WYRecomp-->>Autograd: w,u
Autograd-->>Caller: o, final_state (saves tensors for bwd)
end
Caller->>Autograd: backward(dw)
Autograd->>BwdPrep: tricked_prepare_wy_repr_bwd(..., w,u, dw, du)
BwdPrep-->>Autograd: dk,dv,db,dg
Autograd-->>Caller: gradients
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a 'tricked' path for the gated delta rule, optimizing the forward and backward passes by using G/G_inv scaling instead of gating in the coupling matrix. It adds a memory-efficient mode that allows trading compute for memory by toggling the caching of w and u tensors. Key additions include new Triton kernels for fused KKT, solve, and WY operations, as well as a dedicated inference path. Feedback highlights an inconsistency in the sequence length threshold for caching, unused parameters in the KKT solve function, and the presence of a debug barrier in the production code.
fbf02fb to
8a50237
Compare
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (1)
fla/ops/gated_delta_rule/chunk_fwd.py (1)
768-789: Threadchunk_indicesinto this wrapper.
fla/ops/gated_delta_rule/chunk.pyalready computeschunk_indiceson Lines 513-515 before entering the no-grad helper, but this function rebuilds them again on Lines 781-782. That duplicates varlen setup work on the inference hot path this PR is trying to speed up.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 768 - 789, The wrapper tricked_fused_infer_fwd currently recomputes chunk_indices using prepare_chunk_indices (lines building chunk_indices and NT) which duplicates work already done upstream; modify tricked_fused_infer_fwd to accept chunk_indices (and corresponding NT when needed) as an optional parameter instead of recomputing it, remove the prepare_chunk_indices/len(...) recompute logic, and thread the incoming chunk_indices through to the tricked_fusemaxxed_kernel call (keep cu_seqlens handling as before). Update the function signature (tricked_fused_infer_fwd) and its callers so the precomputed chunk_indices from fla/ops/gated_delta_rule/chunk.py are passed in, ensuring the kernel invocation still receives chunk_indices and NT correctly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 129-170: The no-grad fast path
(chunk_gated_delta_rule_fwd_inference) currently skips CP preprocessing
(chunk_gated_delta_rule_fwd_h_pre_process) when cp_context is set, which can
return local-only outputs; fix by adding a cp_context parameter to
chunk_gated_delta_rule_fwd_inference, thread that parameter through all callers,
and inside the function either call
chunk_gated_delta_rule_fwd_h_pre_process(...) before
chunk_gated_delta_rule_fwd_h(...) when cp_context is non-null or explicitly
raise an error if cp_context is provided while torch.is_grad_enabled() is False;
update all call sites that invoke chunk_gated_delta_rule_fwd_inference (and any
other inference helper copies at lines ~512-533) to pass the cp_context.
- Around line 122-125: Unify the predicate controlling cached w/u by introducing
a single constant (e.g., CACHE_WU_THRESHOLD = 2048) and use it everywhere
instead of duplicated literals: compute cache_wu = not mem_efficient and T >=
CACHE_WU_THRESHOLD in the chunk logic and change the conditional in
ChunkGatedDeltaRuleFunction.forward (and the unpack path that expects 5 vs 7
values) to use the same CACHE_WU_THRESHOLD and same comparison (use >= to match
the early return), so both branches agree on when a 7-tuple is returned; add a
regression test that runs the forward path with T == CACHE_WU_THRESHOLD to
ensure callers handle the 7-value return correctly.
- Around line 17-27: Remove the stale imports causing Ruff errors: delete
chunk_gated_delta_rule_fwd_intra from the import list in the chunk_fwd import
block and remove prepare_wy_repr_bwd and recompute_w_u_fwd from the wy_fast
import block in chunk.py; keep only the actually used symbols
(chunk_gated_delta_rule_fwd_kkt_solve, tricked_fused_infer_fwd,
tricked_prepare_wy_repr_bwd, tricked_recompute_w_u_fwd) so the file imports
match usage and the linter stops flagging unused names.
---
Nitpick comments:
In `@fla/ops/gated_delta_rule/chunk_fwd.py`:
- Around line 768-789: The wrapper tricked_fused_infer_fwd currently recomputes
chunk_indices using prepare_chunk_indices (lines building chunk_indices and NT)
which duplicates work already done upstream; modify tricked_fused_infer_fwd to
accept chunk_indices (and corresponding NT when needed) as an optional parameter
instead of recomputing it, remove the prepare_chunk_indices/len(...) recompute
logic, and thread the incoming chunk_indices through to the
tricked_fusemaxxed_kernel call (keep cu_seqlens handling as before). Update the
function signature (tricked_fused_infer_fwd) and its callers so the precomputed
chunk_indices from fla/ops/gated_delta_rule/chunk.py are passed in, ensuring the
kernel invocation still receives chunk_indices and NT correctly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: fab757ed-0cec-4b51-87b5-b3256af52161
📒 Files selected for processing (3)
fla/ops/gated_delta_rule/chunk.pyfla/ops/gated_delta_rule/chunk_fwd.pyfla/ops/gated_delta_rule/wy_fast.py
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
fla/ops/gated_delta_rule/chunk_fwd.py (1)
768-789: Reuse the caller’schunk_indicesin fused inference.
tricked_fused_infer_fwd()rebuilds the chunk map even though the no-grad path already has it. On varlen inference that adds avoidable preprocessing and drops the existing precomputed path.♻️ Proposed change
def tricked_fused_infer_fwd( k: torch.Tensor, v: torch.Tensor, beta: torch.Tensor, g_cum: torch.Tensor, cu_seqlens: torch.LongTensor | None = None, + chunk_indices: torch.LongTensor | None = None, use_exp2: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: @@ - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT)Pass the existing
chunk_indicesthrough fromchunk_gated_delta_rule_fwd_inference().🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 768 - 789, tricked_fused_infer_fwd currently recomputes chunk_indices via prepare_chunk_indices even when a precomputed map exists; change its API to accept an optional chunk_indices parameter and use the provided value when not None (i.e., remove the prepare_chunk_indices(...) call and set chunk_indices = chunk_indices_arg if given, else prepare_chunk_indices(cu_seqlens,...)). Update any caller (notably chunk_gated_delta_rule_fwd_inference) to pass through its precomputed chunk_indices into tricked_fused_infer_fwd so the no‑grad inference path reuses the existing map instead of rebuilding it.fla/ops/gated_delta_rule/wy_fast.py (1)
424-425: Make backward deriveBTfromAtoo.Line 425 already keys the forward wrapper off
A.shape[-1], but Line 588 resetsBTto64. If the forward chunk size ever changes, backward will launch against the wrong tile shape.♻️ Proposed change
def tricked_prepare_wy_repr_bwd( k: torch.Tensor, v: torch.Tensor, beta: torch.Tensor, @@ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] - BT = 64 + BT = A.shape[-1]Also applies to: 587-594
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/wy_fast.py` around lines 424 - 425, The backward pass is using a hardcoded tile size (BT = 64) which can mismatch the forward wrapper that sets BT from A.shape[-1]; update the backward logic to derive BT from A (e.g., BT = A.shape[-1]) instead of 64 so the backward tiling matches forward. Locate occurrences where BT is reset to 64 (around the backward/chunked loop that consumes A, k, v) and replace them with deriving BT from A.shape[-1]; apply the same change to the other backward block referenced (the 587-594 area) so all backward branches use BT = A.shape[-1].
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 119-123: The boundary check for caching w/u is inconsistent:
change the token-threshold predicate so both sides use the same comparison (make
the check consistent between the cache_wu assignment that uses q.shape[1] and
the check inside
chunk_gated_delta_rule_fwd()/ChunkGatedDeltaRuleFunction.forward()); locate the
cache predicate (cache_wu = not mem_efficient and T >= 2048) and the
corresponding conditional in
chunk_gated_delta_rule_fwd()/ChunkGatedDeltaRuleFunction.forward() and make them
identical (e.g., use >= 2048 in both) so the forward and fwd wrapper
return/unpack the same tuple shape at T == 2048.
- Around line 509-530: The no-grad branch currently bypasses the CP
(context-parallel) preprocessing, causing incorrect state when cp_context is
set; modify the condition so that we only take the fused no-grad inference path
when cp_context is False (i.e. require both torch.is_grad_enabled() == False AND
cp_context is falsy), otherwise run the same CP-prepared path used by the
grad-enabled flow (reuse chunk_gated_delta_rule_fwd_h preprocessing steps and
chunk_indices preparation) and then call chunk_gated_delta_rule_fwd_h (or the
CP-aware forward) instead of chunk_gated_delta_rule_fwd_inference so CP
evaluation gets the CP-prepared state.
---
Nitpick comments:
In `@fla/ops/gated_delta_rule/chunk_fwd.py`:
- Around line 768-789: tricked_fused_infer_fwd currently recomputes
chunk_indices via prepare_chunk_indices even when a precomputed map exists;
change its API to accept an optional chunk_indices parameter and use the
provided value when not None (i.e., remove the prepare_chunk_indices(...) call
and set chunk_indices = chunk_indices_arg if given, else
prepare_chunk_indices(cu_seqlens,...)). Update any caller (notably
chunk_gated_delta_rule_fwd_inference) to pass through its precomputed
chunk_indices into tricked_fused_infer_fwd so the no‑grad inference path reuses
the existing map instead of rebuilding it.
In `@fla/ops/gated_delta_rule/wy_fast.py`:
- Around line 424-425: The backward pass is using a hardcoded tile size (BT =
64) which can mismatch the forward wrapper that sets BT from A.shape[-1]; update
the backward logic to derive BT from A (e.g., BT = A.shape[-1]) instead of 64 so
the backward tiling matches forward. Locate occurrences where BT is reset to 64
(around the backward/chunked loop that consumes A, k, v) and replace them with
deriving BT from A.shape[-1]; apply the same change to the other backward block
referenced (the 587-594 area) so all backward branches use BT = A.shape[-1].
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: a48dd2f6-8381-482e-8209-bc0beaafa33d
📒 Files selected for processing (3)
fla/ops/gated_delta_rule/chunk.pyfla/ops/gated_delta_rule/chunk_fwd.pyfla/ops/gated_delta_rule/wy_fast.py
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (4)
fla/ops/gated_delta_rule/wy_fast.py (2)
406-413: Inconsistentallow_tf32usage betweenuandwcomputations.The
ucomputation at line 406 usesallow_tf32=Falsebut thewcomputation at line 413 does not specify this parameter (defaulting toTrueon supported hardware). This asymmetry may cause precision differences between the two outputs. Consider applying consistent precision settings.♻️ Proposed fix for consistency
b_w = tl.dot(b_A, (b_k * b_b[:, None]).to(b_k.dtype)) * b_G[:, None] + # Note: If allow_tf32=False is intentional for u, consider adding it here too🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/wy_fast.py` around lines 406 - 413, The tl.dot call used to compute b_w (the expression "b_w = tl.dot(b_A, (b_k * b_b[:, None]).to(b_k.dtype)) * b_G[:, None]") is missing the allow_tf32 flag, causing inconsistent precision versus the b_u computation (which sets allow_tf32=False); update the tl.dot for b_w to include allow_tf32=False so both dot products use the same precision semantics.
550-562: Unused variableb_Mcomputed but never used.
b_Mis initialized at line 550 and accumulated at line 562, but it's never read or stored. This appears to be dead code that could be removed.🧹 Proposed fix to remove dead code
- b_M = tl.zeros([BT, BT], dtype=tl.float32) b_dA = tl.where(m_A, -b_dA, 0).to(k.dtype.element_ty) tl.debug_barrier() for i_k in range(tl.cdiv(K, BK)): off = i_k * BK p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, off), (BT, BK), (1, 0)) p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, off), (BT, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kt = tl.trans(b_k) b_ktb = b_kt * b_b[None, :] - b_M += tl.dot(b_k, b_kt) b_dkb = tl.dot(b_dA, b_k)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/wy_fast.py` around lines 550 - 562, The variable b_M is allocated (b_M = tl.zeros(...)) and accumulated (b_M += tl.dot(...)) but never used; remove the dead code by deleting the b_M allocation and the accumulation expression inside the loop (references to b_M), leaving the rest of the loop (p_k/p_dk, b_k, b_kt, b_ktb, etc.) unchanged; verify no other code expects b_M or its value after the loop and remove any now-unused imports or variables created solely for b_M if they become dead.fla/ops/gated_delta_rule/chunk_fwd.py (2)
405-414: Unused parametervin function signature.The
vparameter is declared but never used inchunk_gated_delta_rule_fwd_kkt_solve. The underlying kernel only operates onk,g,beta, andA. Consider removing it or documenting why it's kept for API consistency.♻️ Option 1: Remove unused parameter
def chunk_gated_delta_rule_fwd_kkt_solve( k: torch.Tensor, - v: torch.Tensor, beta: torch.Tensor, g: torch.Tensor | None = None,Note: If kept for API consistency with
chunk_gated_delta_rule_fwd_intra, consider adding a comment explaining this.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 405 - 414, The function chunk_gated_delta_rule_fwd_kkt_solve declares an unused parameter v; remove v from the signature and all internal/no-op mentions to avoid confusion OR if you must preserve API parity with chunk_gated_delta_rule_fwd_intra keep the parameter but mark it explicitly unused (rename to _v or add a clear comment) and update any callers/tests accordingly; locate and edit the function named chunk_gated_delta_rule_fwd_kkt_solve and any call sites to either drop the extra argument or pass a dummy value, and if opting to keep the parameter add a one-line comment inside the function explaining it is kept for API compatibility with chunk_gated_delta_rule_fwd_intra.
697-737: Consistentallow_tf32=Falseusage forubut not forw.Similar to
tricked_wy_fwd_kernel, the fused kernel usesallow_tf32=Falsefor theucomputation (lines 737, 746-747, 756-758, 767-770) but uses default precision forwcomputation (lines 697, 706, 715, 724-725). This appears intentional but should be documented if the precision difference is deliberate.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 697 - 737, The w-path uses tl.dot without allow_tf32 while the u-path explicitly sets allow_tf32=False (see tl.dot calls producing b_w0/b_w1/b_w2/b_w3 vs b_u0 and subsequent b_u*), causing inconsistent precision; either make the w tl.dot calls consistent by adding allow_tf32=False to the tl.dot invocations that compute b_w0/b_w1/b_w2/b_w3, or if the mixed precision is intentional, add a short explanatory comment near the w and u tl.dot calls documenting why w can use TF32 but u must disable it; update all relevant tl.dot calls for w or add the single comment so the intent is explicit and consistent across the kernel.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 344-346: The forward/backward caching threshold is inconsistent:
update the condition in ChunkGatedDeltaRuleFunction.forward to match
chunk_gated_delta_rule_fwd (use T >= 2048) so the same tuple shape is returned
for T=2048; locate the caching check in ChunkGatedDeltaRuleFunction.forward (the
variable cache_wu) and change the comparator from ">" to ">=" (or conversely
change chunk_gated_delta_rule_fwd to use ">" if you prefer that convention) so
both functions use the identical threshold and the returned tuple arity matches
what the autograd function expects.
---
Nitpick comments:
In `@fla/ops/gated_delta_rule/chunk_fwd.py`:
- Around line 405-414: The function chunk_gated_delta_rule_fwd_kkt_solve
declares an unused parameter v; remove v from the signature and all
internal/no-op mentions to avoid confusion OR if you must preserve API parity
with chunk_gated_delta_rule_fwd_intra keep the parameter but mark it explicitly
unused (rename to _v or add a clear comment) and update any callers/tests
accordingly; locate and edit the function named
chunk_gated_delta_rule_fwd_kkt_solve and any call sites to either drop the extra
argument or pass a dummy value, and if opting to keep the parameter add a
one-line comment inside the function explaining it is kept for API compatibility
with chunk_gated_delta_rule_fwd_intra.
- Around line 697-737: The w-path uses tl.dot without allow_tf32 while the
u-path explicitly sets allow_tf32=False (see tl.dot calls producing
b_w0/b_w1/b_w2/b_w3 vs b_u0 and subsequent b_u*), causing inconsistent
precision; either make the w tl.dot calls consistent by adding allow_tf32=False
to the tl.dot invocations that compute b_w0/b_w1/b_w2/b_w3, or if the mixed
precision is intentional, add a short explanatory comment near the w and u
tl.dot calls documenting why w can use TF32 but u must disable it; update all
relevant tl.dot calls for w or add the single comment so the intent is explicit
and consistent across the kernel.
In `@fla/ops/gated_delta_rule/wy_fast.py`:
- Around line 406-413: The tl.dot call used to compute b_w (the expression "b_w
= tl.dot(b_A, (b_k * b_b[:, None]).to(b_k.dtype)) * b_G[:, None]") is missing
the allow_tf32 flag, causing inconsistent precision versus the b_u computation
(which sets allow_tf32=False); update the tl.dot for b_w to include
allow_tf32=False so both dot products use the same precision semantics.
- Around line 550-562: The variable b_M is allocated (b_M = tl.zeros(...)) and
accumulated (b_M += tl.dot(...)) but never used; remove the dead code by
deleting the b_M allocation and the accumulation expression inside the loop
(references to b_M), leaving the rest of the loop (p_k/p_dk, b_k, b_kt, b_ktb,
etc.) unchanged; verify no other code expects b_M or its value after the loop
and remove any now-unused imports or variables created solely for b_M if they
become dead.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ef56151b-219b-4212-b5fa-bf2c9ec2b651
📒 Files selected for processing (3)
fla/ops/gated_delta_rule/chunk.pyfla/ops/gated_delta_rule/chunk_fwd.pyfla/ops/gated_delta_rule/wy_fast.py
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_delta_rule/wy_fast.py`:
- Around line 390-414: The w-path currently uses the normalized b_G (exp(g -
b_g_max)) which loses the chunk max shift; restore the original G by multiplying
b_w by exp(b_g_max) (or exp2(b_g_max) when USE_EXP2) before storing. Concretely,
after computing b_w in the block that builds b_w (using b_k, b_b and b_G),
compute a scale = exp2(b_g_max) if USE_EXP2 else exp(b_g_max) and apply b_w =
b_w * scale[:, None] (use the same b_g_max used to normalize b_g); also apply
the identical change in the fused inference kernel mentioned (chunk_fwd.py
region 642-691) so both paths use the same max convention.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: efcd6100-a713-4f58-9a00-ee17fe1e955b
📒 Files selected for processing (2)
fla/ops/gated_delta_rule/chunk.pyfla/ops/gated_delta_rule/wy_fast.py
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_delta_rule/chunk_fwd.py`:
- Around line 405-444: The function chunk_gated_delta_rule_fwd_kkt_solve
currently declares an unused parameter v; remove v from the function signature
(and its type hints) and from the internal parameter list passed to the triton
kernel call, then update every call site that passes v (notably the call in
fla/ops/gated_delta_rule/chunk.py where v=v is passed) to stop providing that
argument; ensure any import/type usages referencing the old signature are
updated to the new signature so there are no mismatched-argument errors.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: c09a3901-af8e-43fe-8f3d-6630c1cb3704
📒 Files selected for processing (2)
fla/ops/gated_delta_rule/chunk_fwd.pyfla/ops/gated_delta_rule/wy_fast.py
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/ops/gated_delta_rule/wy_fast.py
- Replace gated coupling matrix with ungated KKT + G/G_inv post-scaling in WY step, saving ~BT² exp operations in both forward and backward - Add fusemaxxed kernel for inference (no_grad): fuses KKT+solve+WY into a single kernel where A never touches HBM - Backward uses precomputed w,u for cheap dG gradient (no extra matmuls) - Cache w,u from forward pass (T>2048) to skip recomputation in backward (mem_efficient=False by default)
…eanup - Align cache_wu threshold to T >= 2048 in both forward and autograd wrapper - Guard no-grad inference path to fall back when cp_context is active - Remove unused b_M variable and tl.debug_barrier() from backward kernel
…utation The g_max subtraction for numerical stability only cancels for u (where G * G_inv = 1), but w = G * (A @ (k*beta)) has no inverse term to cancel the shift. Use G_full = G * exp(g_max) for w so it matches exp(g) exactly.
19145b9 to
7004d19
Compare
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
fla/ops/gated_delta_rule/chunk.py (1)
55-75:⚠️ Potential issue | 🔴 CriticalDon’t route GQA through the tricked WY helpers yet.
This forward/backward path is now unconditional, but
tricked_recompute_w_u_fwdinfla/ops/gated_delta_rule/wy_fast.py:429-451andtricked_prepare_wy_repr_bwdinfla/ops/gated_delta_rule/wy_fast.py:589-625both derive their launch head count fromk.shape[2]. The public API here still acceptsH % Hq == 0, so grouped-head runs will only materialize the tricked WY state forHqheads and backward will mirror the same mismatch. Please fall back to the old WY path, or explicitly reject GQA until those kernels take bothHandHq.Also applies to: 189-201, 277-291
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk.py` around lines 55 - 75, The new forward unconditionally calls the tricked WY helpers (tricked_recompute_w_u_fwd) which assume the kernel launch head-count equals k.shape[2], causing mismatch for grouped-head (GQA) runs; update the chunked forward/backward callers (the spots invoking chunk_gated_delta_rule_fwd_kkt_solve and tricked_recompute_w_u_fwd) to detect grouped-head mode (H != Hq or Hq < H) and either (a) fall back to the legacy WY code path used previously for those ranges, or (b) explicitly raise/return an error rejecting GQA until the tricked WY kernels (tricked_recompute_w_u_fwd and tricked_prepare_wy_repr_bwd) are updated to accept both H and Hq; ensure the check is applied at the other occurrences mentioned (the other two forward blocks corresponding to the review: the blocks around the other two ranges) so grouped-head runs never call the tricked WY helpers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_delta_rule/chunk_fwd.py`:
- Around line 467-515: The fast-path in tricked_fusemaxxed_kernel (and the
corresponding lines in tricked_fused_infer_fwd) assumes a single head-count and
therefore uses k/v/w/u/beta/g strides derived from k.shape[2], which breaks when
model heads H != query-heads Hq; replicate the H↔Hq mapping used in
chunk_gated_delta_rule_fwd_kkt_solve_kernel so program IDs and pointer
arithmetic use the query-head index (map i_h -> i_hq or compute i_h_from_i_bh
the same way) and load/store v, beta, g, u using the correct per-head strides,
or alternatively gate this fast path to only run when H == Hq; update indexing
of k/v/w/u/beta/g and the early-return/program-launch condition accordingly.
- Around line 424-445: The Triton kernel launch for
chunk_gated_delta_rule_fwd_kkt_solve_kernel is missing the constexpr Hq
parameter causing incorrect pointer math for grouped-query attention; compute Hq
(e.g., Hq = g.shape[2] or otherwise obtain the grouped-query head count
available in this scope) and include it in the launch configuration alongside H
(i.e., change the launch tuple from (NT, B * H) to include Hq as a
template/constexpr parameter), and if the kernel signature expects Hq as an
explicit arg ensure you pass it as well so the kernel's key indexing (used in k
pointer arithmetic) is correct.
- Around line 645-687: The max-reduction uses zero-filled tail lanes from the
boundary_check loads (b_g0..b_g3), so mask out invalid lanes using the
per-subchunk masks (m_tc0, m_tc1, m_tc2, m_tc3) by replacing invalid entries
with -inf (or a sufficiently large negative float) before computing b_g_max;
update the code that computes b_g_max (and then b_g0..b_g3 adjustments) to use
the masked b_g* values so the normalization and subsequent b_Ginv* / b_Gw*
calculations remain numerically stable on partial chunks.
---
Outside diff comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 55-75: The new forward unconditionally calls the tricked WY
helpers (tricked_recompute_w_u_fwd) which assume the kernel launch head-count
equals k.shape[2], causing mismatch for grouped-head (GQA) runs; update the
chunked forward/backward callers (the spots invoking
chunk_gated_delta_rule_fwd_kkt_solve and tricked_recompute_w_u_fwd) to detect
grouped-head mode (H != Hq or Hq < H) and either (a) fall back to the legacy WY
code path used previously for those ranges, or (b) explicitly raise/return an
error rejecting GQA until the tricked WY kernels (tricked_recompute_w_u_fwd and
tricked_prepare_wy_repr_bwd) are updated to accept both H and Hq; ensure the
check is applied at the other occurrences mentioned (the other two forward
blocks corresponding to the review: the blocks around the other two ranges) so
grouped-head runs never call the tricked WY helpers.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 7f95c04d-57e1-4cd1-953e-5e7051d8b95d
📒 Files selected for processing (3)
fla/ops/gated_delta_rule/chunk.pyfla/ops/gated_delta_rule/chunk_fwd.pyfla/ops/gated_delta_rule/wy_fast.py
✅ Files skipped from review due to trivial changes (1)
- fla/ops/gated_delta_rule/wy_fast.py
| B, T, H, K = k.shape | ||
| BT = chunk_size | ||
| BC = 16 | ||
|
|
||
| if chunk_indices is None and cu_seqlens is not None: | ||
| chunk_indices = prepare_chunk_indices(cu_seqlens, BT) | ||
| NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) | ||
|
|
||
| A = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype) | ||
| chunk_gated_delta_rule_fwd_kkt_solve_kernel[(NT, B * H)]( | ||
| k=k, | ||
| g=g, | ||
| beta=beta, | ||
| A=A, | ||
| cu_seqlens=cu_seqlens, | ||
| chunk_indices=chunk_indices, | ||
| T=T, | ||
| H=H, | ||
| K=K, | ||
| BT=BT, | ||
| BC=BC, | ||
| USE_EXP2=use_exp2, |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the kernel definition and see what parameters it expects
head -100 fla/ops/gated_delta_rule/chunk_fwd.py | cat -nRepository: fla-org/flash-linear-attention
Length of output: 3880
🏁 Script executed:
# Get more context around the kernel definition (around lines 40-41)
sed -n '1,80p' fla/ops/gated_delta_rule/chunk_fwd.py | cat -nRepository: fla-org/flash-linear-attention
Length of output: 2870
🏁 Script executed:
# Look at the chunk_gated_delta_rule_fwd_kkt_solve_kernel definition
rg "chunk_gated_delta_rule_fwd_kkt_solve_kernel" fla/ops/gated_delta_rule/chunk_fwd.py -A 40 -B 5Repository: fla-org/flash-linear-attention
Length of output: 3756
🏁 Script executed:
# Now let's look at both launchers - the old one around 378-388
sed -n '370,395p' fla/ops/gated_delta_rule/chunk_fwd.py | cat -nRepository: fla-org/flash-linear-attention
Length of output: 928
🏁 Script executed:
# And the new one around 424-445
sed -n '420,450p' fla/ops/gated_delta_rule/chunk_fwd.py | cat -nRepository: fla-org/flash-linear-attention
Length of output: 1144
🏁 Script executed:
# Let's get the full function signatures to understand parameter structure better
sed -n '1,500p' fla/ops/gated_delta_rule/chunk_fwd.py | grep -A 50 "def chunk_gated_delta_rule_fwd_kkt_solve"Repository: fla-org/flash-linear-attention
Length of output: 3067
Pass Hq into this Triton launch.
chunk_gated_delta_rule_fwd_kkt_solve_kernel requires both H and Hq as constexpr parameters, but this wrapper derives H from k.shape[2] and never supplies Hq. The kernel uses Hq for key pointer indexing (line 81: k += (bos * Hq + i_h // (H // Hq)) * K) and will fail or produce incorrect results for grouped-query attention where Hq ≠ H.
Minimal fix
- B, T, H, K = k.shape
+ B, T, Hq, K = k.shape
+ H = beta.shape[2]
BT = chunk_size
BC = 16
@@
chunk_gated_delta_rule_fwd_kkt_solve_kernel[(NT, B * H)](
k=k,
g=g,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
+ Hq=Hq,
K=K,
BT=BT,
BC=BC,
USE_EXP2=use_exp2,
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 424 - 445, The Triton
kernel launch for chunk_gated_delta_rule_fwd_kkt_solve_kernel is missing the
constexpr Hq parameter causing incorrect pointer math for grouped-query
attention; compute Hq (e.g., Hq = g.shape[2] or otherwise obtain the
grouped-query head count available in this scope) and include it in the launch
configuration alongside H (i.e., change the launch tuple from (NT, B * H) to
include Hq as a template/constexpr parameter), and if the kernel signature
expects Hq as an explicit arg ensure you pass it as well so the kernel's key
indexing (used in k pointer arithmetic) is correct.
| def tricked_fusemaxxed_kernel( | ||
| k, | ||
| v, | ||
| beta, | ||
| g, | ||
| w, | ||
| u, | ||
| cu_seqlens, | ||
| chunk_indices, | ||
| T, | ||
| H: tl.constexpr, | ||
| K: tl.constexpr, | ||
| V: tl.constexpr, | ||
| BT: tl.constexpr, | ||
| BC: tl.constexpr, | ||
| BK: tl.constexpr, | ||
| BV: tl.constexpr, | ||
| USE_EXP2: tl.constexpr, | ||
| IS_VARLEN: tl.constexpr, | ||
| ): | ||
| """ | ||
| Fused ungated kkt + solve_tril + tricked WY. | ||
|
|
||
| Computes (I+A)^{-1} entirely in registers (A never hits HBM), | ||
| then immediately uses it to compute w and u with G/G_inv scaling. | ||
| """ | ||
| i_t, i_bh = tl.program_id(0), tl.program_id(1) | ||
| i_b, i_h = i_bh // H, i_bh % H | ||
|
|
||
| if IS_VARLEN: | ||
| i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) | ||
| bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) | ||
| T = eos - bos | ||
| else: | ||
| bos, eos = i_b * T, i_b * T + T | ||
|
|
||
| if i_t * BT >= T: | ||
| return | ||
|
|
||
| i_tc0 = i_t * BT | ||
| i_tc1 = i_t * BT + BC | ||
| i_tc2 = i_t * BT + 2 * BC | ||
| i_tc3 = i_t * BT + 3 * BC | ||
|
|
||
| k += (bos * H + i_h) * K | ||
| v += (bos * H + i_h) * V | ||
| w += (bos * H + i_h) * K | ||
| u += (bos * H + i_h) * V | ||
|
|
There was a problem hiding this comment.
The fused inference fast path is not GQA-safe yet.
tricked_fusemaxxed_kernel indexes k, v, beta, and g with one shared head count, and tricked_fused_infer_fwd sets that count from k.shape[2]. For H != Hq, this launches only B*Hq programs, uses the wrong stride for v/beta/g, and leaves the extra value heads in u untouched. Please mirror the H/Hq mapping already used by chunk_gated_delta_rule_fwd_kkt_solve_kernel, or gate this fast path to H == Hq.
Also applies to: 786-807
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 467 - 515, The fast-path
in tricked_fusemaxxed_kernel (and the corresponding lines in
tricked_fused_infer_fwd) assumes a single head-count and therefore uses
k/v/w/u/beta/g strides derived from k.shape[2], which breaks when model heads H
!= query-heads Hq; replicate the H↔Hq mapping used in
chunk_gated_delta_rule_fwd_kkt_solve_kernel so program IDs and pointer
arithmetic use the query-head index (map i_h -> i_hq or compute i_h_from_i_bh
the same way) and load/store v, beta, g, u using the correct per-head strides,
or alternatively gate this fast path to only run when H == Hq; update indexing
of k/v/w/u/beta/g and the early-return/program-launch condition accordingly.
| # Phase 5: Load gating, cast Ai blocks for WY matmuls | ||
| p_g0 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc0,), (BC,), (0,)) | ||
| p_g1 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,)) | ||
| p_g2 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,)) | ||
| p_g3 = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,)) | ||
| b_g0 = tl.load(p_g0, boundary_check=(0,)).to(tl.float32) | ||
| b_g1 = tl.load(p_g1, boundary_check=(0,)).to(tl.float32) | ||
| b_g2 = tl.load(p_g2, boundary_check=(0,)).to(tl.float32) | ||
| b_g3 = tl.load(p_g3, boundary_check=(0,)).to(tl.float32) | ||
|
|
||
| # Normalize g by subtracting max across all sub-chunks to prevent exp overflow | ||
| b_g_max = tl.max(tl.maximum(tl.maximum(b_g0, b_g1), tl.maximum(b_g2, b_g3)), 0) | ||
| b_g0 = b_g0 - b_g_max | ||
| b_g1 = b_g1 - b_g_max | ||
| b_g2 = b_g2 - b_g_max | ||
| b_g3 = b_g3 - b_g_max | ||
|
|
||
| if USE_EXP2: | ||
| b_G0 = exp2(b_g0) | ||
| b_Ginv0 = exp2(-b_g0) | ||
| b_G1 = exp2(b_g1) | ||
| b_Ginv1 = exp2(-b_g1) | ||
| b_G2 = exp2(b_g2) | ||
| b_Ginv2 = exp2(-b_g2) | ||
| b_G3 = exp2(b_g3) | ||
| b_Ginv3 = exp2(-b_g3) | ||
| b_Gw0 = b_G0 * exp2(b_g_max) | ||
| b_Gw1 = b_G1 * exp2(b_g_max) | ||
| b_Gw2 = b_G2 * exp2(b_g_max) | ||
| b_Gw3 = b_G3 * exp2(b_g_max) | ||
| else: | ||
| b_G0 = exp(b_g0) | ||
| b_Ginv0 = exp(-b_g0) | ||
| b_G1 = exp(b_g1) | ||
| b_Ginv1 = exp(-b_g1) | ||
| b_G2 = exp(b_g2) | ||
| b_Ginv2 = exp(-b_g2) | ||
| b_G3 = exp(b_g3) | ||
| b_Ginv3 = exp(-b_g3) | ||
| b_Gw0 = b_G0 * exp(b_g_max) | ||
| b_Gw1 = b_G1 * exp(b_g_max) | ||
| b_Gw2 = b_G2 * exp(b_g_max) | ||
| b_Gw3 = b_G3 * exp(b_g_max) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
wc -l fla/ops/gated_delta_rule/chunk_fwd.pyRepository: fla-org/flash-linear-attention
Length of output: 115
🏁 Script executed:
cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '630,700p'Repository: fla-org/flash-linear-attention
Length of output: 3462
🏁 Script executed:
cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '400,450p'Repository: fla-org/flash-linear-attention
Length of output: 1823
🏁 Script executed:
# Search for mask definitions and boundary handling
rg -n "mask|boundary" fla/ops/gated_delta_rule/chunk_fwd.py | head -30Repository: fla-org/flash-linear-attention
Length of output: 2186
🏁 Script executed:
# Search for i_tc0, i_tc1, etc definitions
rg -n "i_tc" fla/ops/gated_delta_rule/chunk_fwd.pyRepository: fla-org/flash-linear-attention
Length of output: 6009
🏁 Script executed:
cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '500,660p'Repository: fla-org/flash-linear-attention
Length of output: 8237
🏁 Script executed:
cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '76,120p'Repository: fla-org/flash-linear-attention
Length of output: 2446
🏁 Script executed:
# Look for how g_max normalization is handled in the other kernel
rg -n "g_max|b_Gw" fla/ops/gated_delta_rule/chunk_fwd.py | head -20Repository: fla-org/flash-linear-attention
Length of output: 963
🏁 Script executed:
# Check the full first kernel to see if there's any masking pattern for g_max
cat -n fla/ops/gated_delta_rule/chunk_fwd.py | sed -n '50,250p'Repository: fla-org/flash-linear-attention
Length of output: 11235
Mask invalid lanes before reducing g_max.
The boundary_check loads zero-fill tail lanes. On any partial chunk where all valid gates are below zero, Line 656 picks the zero-filled out-of-bounds element as the max. This defeats the normalization for G_inv and can cause overflow instead of stabilizing it. Mask b_g* values (for example, using -inf on invalid lanes via m_tc*) before computing b_g_max in the reduction.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_delta_rule/chunk_fwd.py` around lines 645 - 687, The
max-reduction uses zero-filled tail lanes from the boundary_check loads
(b_g0..b_g3), so mask out invalid lanes using the per-subchunk masks (m_tc0,
m_tc1, m_tc2, m_tc3) by replacing invalid entries with -inf (or a sufficiently
large negative float) before computing b_g_max; update the code that computes
b_g_max (and then b_g0..b_g3 adjustments) to use the masked b_g* values so the
normalization and subsequent b_Ginv* / b_Gw* calculations remain numerically
stable on partial chunks.
Summary
Integrates the similarity-transform optimization (
N = G·M·G⁻¹) into FLA's GDN forward and backward paths, following Simon Veitner's observation (also independently noted in Comba).Instead of computing the gated coupling matrix
Ndirectly (requiringBT²exp ops), we:M(skipBT²exp ops)M⁻¹(same cost)G/G⁻¹diagonal scaling in the WY step (only2·BTexp ops)Net savings: 3968 fewer exp ops per chunk (BT=64).
Changes
wy_fast.py(new): fusemaxxed kernel (inference, A in registers), tricked WY fwd/bwd kernels, ungated kkt+solve wrapperchunk.py: Forward uses ungated kkt+solve + tricked WY; backward uses tricked WY bwd with precomputed w,u for cheap dG; inference auto-dispatches fusemaxxed (T≤8K) vs fused (T>8K); w/u caching for T≥2048Benchmark results
All benchmarks on NVIDIA H100, B=1,
H·Dh = 2048, bf16. Baseline: FLA 0.5.0 (commitf52529e). Averaged over 3 runs.Forward (no_grad) — auto-dispatch inference
Auto-dispatch picks fusemaxxed (kkt+solve+WY fused, A in registers) for T≤8K, and fused (kkt+solve)+WY for T>8K.
Forward + Backward
Test plan
Summary by CodeRabbit
New Features
Refactor