From 993cff58ae5232e5be0ec360b0cde49b746b2019 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 24 Jun 2026 01:03:55 -0700 Subject: [PATCH 1/4] [gemma4_31b][cuda] Export Gemma4-31B @128k under 32 GB Three CUDA-export memory optimizations: - tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit. - int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights (N>65536, i.e. only the lm_head). Avoids transiently materializing the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a shim and the M>4 prefill inline path is below the threshold, so this never enters the runtime graph -> zero runtime / accuracy impact. Applied unconditionally (no flag). - cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile (gated behind low_memory_mode). A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the _unlift_graph clone and at serialization, and excluded from constant dedup (resize_(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers). Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the exported model runs correctly (output "...Paris."). 
--- backends/aoti/aoti_backend.py | 25 ++- backends/cuda/cuda_backend.py | 172 +++++++++++++++++- .../quantize_op_dispatch/int4_dispatch.py | 46 ++++- backends/cuda/triton/kernels/tq4_sdpa.py | 5 + 4 files changed, 229 insertions(+), 19 deletions(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 91a8a60078e..22f6feeab6c 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -112,6 +112,21 @@ def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None: """ return + @classmethod + def move_program_to_device( + cls, + edge_program: ExportedProgram, + device: str, + compile_specs: List[CompileSpec], + ) -> ExportedProgram: + """Move the exported program to the target device for compilation. + + Default implementation moves everything (params, buffers, constants) via + ``move_to_device_pass``. Concrete backends may override to keep large + non-parameter tensors off the device during a low-memory export. + """ + return move_to_device_pass(edge_program, device) + @classmethod def release_moved_tensors( cls, @@ -196,9 +211,13 @@ def preprocess( decomposition_table = cls.get_decomposition_table() options = cls.get_aoti_compile_options(compile_specs) - # Move the edge_program to the target device - device_edge_program = move_to_device_pass( - edge_program, device_name if device_name != "metal" else "mps" + # Move the edge_program to the target device. Routed through a hook so + # backends can keep large non-parameter tensors (e.g. KV-cache buffers) + # off the device during a low-memory export. 
+ device_edge_program = cls.move_program_to_device( + edge_program, + device_name if device_name != "metal" else "mps", + compile_specs, ) # Replace view_copy with view diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f9f23a842f9..1781c5bfd39 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -61,29 +61,81 @@ def _is_cpu_clone_active() -> bool: return getattr(_CPU_CLONE_GUARD, "active", False) +def _full_zeros_preserving_strides(x: torch.Tensor, device) -> torch.Tensor: + """Allocate a zero-filled tensor matching ``x``'s size/stride/dtype on ``device``. + + Used to re-synthesize KV-cache buffers whose storage was freed (``resize_(0)``) + during the low-memory device move. KV content is all zeros, so this exactly + reproduces the buffer for both the lifted graph value and serialization. + """ + needed = 1 + for size, stride in zip(x.size(), x.stride()): + needed += (size - 1) * stride + buf = torch.zeros(int(needed), dtype=x.dtype, device=device) + return torch.as_strided(buf, x.size(), x.stride()) + + +def _is_emptied(x) -> bool: + return ( + isinstance(x, torch.Tensor) + and x.numel() > 0 + and x.untyped_storage().nbytes() == 0 + ) + + @contextlib.contextmanager def _compile_time_cpu_clones(target_device: torch.device): """Force AOTI's mutated-buffer clones onto CPU while preserving the serialized constants' target device.""" - from torch._inductor import compile_fx as _cfx + from torch._inductor import compile_fx as _cfx, graph as _graph from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu as _Cpp + from torch._inductor.graph import GraphLowering as _GL orig_clone = _cfx.clone_preserve_strides orig_codegen_device = _Cpp.codegen_device + orig_get_const = _GL.get_original_value_of_constant + orig_is_same = _graph.is_same_tensor + + def _is_same_skip_emptied(data, value): + # KV buffers freed via resize_(0) all have data_ptr 0, so the stock + # is_same_tensor would treat every same-shape KV 
constant as a duplicate + # and collapse the 60 layers' caches into one — the runtime needs each + # FQN's own buffer, so the collapsed ones load uninitialized garbage. + # Never dedup an emptied tensor. + if _is_emptied(data) or _is_emptied(value): + return False + return orig_is_same(data, value) def _cpu_clone_preserve_strides(x: torch.Tensor) -> torch.Tensor: - # `clone_preserve_strides` is shared by `_unlift_graph` (clones - # lifted buffers — can be safely kept on CPU) and by autotuning code - # in `triton_heuristics.py` (clones for benchmark — must stay on - # GPU for Triton). Discriminate by caller frame so we only force - # CPU clones for the buffer-lifting path. + # `clone_preserve_strides` is shared by `_unlift_graph` (clones lifted + # buffers — can be safely kept on CPU) and by autotuning code in + # `triton_heuristics.py` (clones for benchmark — must stay on GPU for + # Triton). Discriminate by caller frame so we only force CPU clones for + # the buffer-lifting path. import sys caller = sys._getframe(1).f_code.co_name if caller == "_unlift_graph": + # KV-cache buffers are emptied (storage resize_(0)) by the low-memory + # device move so they never occupy GPU memory during compile. Their + # content is all zeros, so re-synthesize zeros (on CPU, strides + # preserved) instead of cloning the now-empty storage. + if _is_emptied(x): + return _full_zeros_preserving_strides(x, "cpu") return orig_clone(x).cpu() return orig_clone(x) + def _get_const_synthesize_zeros(self, name): + # AOTI serializes each constant via get_original_value_of_constant -> + # _to_bytes. For KV buffers we freed with resize_(0) this would otherwise + # fall back to the empty-storage constant and write 0 bytes, producing a + # .ptd with an uninitialized cache. Re-synthesize the zeros so the blob + # holds a correctly-zeroed KV cache. 
+ value = orig_get_const(self, name) + if _is_emptied(value): + return _full_zeros_preserving_strides(value, "cpu") + return value + def _codegen_device_target_aware(self, device): # Translate accidental CPU device strings back to the model target # device only when a constant we forced to CPU is being serialized. @@ -99,6 +151,8 @@ def _codegen_device_target_aware(self, device): _cfx.clone_preserve_strides = _cpu_clone_preserve_strides _Cpp.codegen_device = _codegen_device_target_aware + _GL.get_original_value_of_constant = _get_const_synthesize_zeros + _graph.is_same_tensor = _is_same_skip_emptied prev_active = getattr(_CPU_CLONE_GUARD, "active", False) _CPU_CLONE_GUARD.active = True try: @@ -107,6 +161,89 @@ def _codegen_device_target_aware(self, device): _CPU_CLONE_GUARD.active = prev_active _cfx.clone_preserve_strides = orig_clone _Cpp.codegen_device = orig_codegen_device + _GL.get_original_value_of_constant = orig_get_const + _graph.is_same_tensor = orig_is_same + + +def _is_kv_buffer(name, v) -> bool: + return ( + isinstance(v, torch.Tensor) + and not isinstance(v, torch.nn.Parameter) + and "kv_cache" in name + ) + + +def _empty_strided_on_device(v, location): + """A device tensor with v's shape/stride/dtype but zero (freed) storage.""" + t = torch.empty_strided(v.shape, v.stride(), dtype=v.dtype, device=location) + t.untyped_storage().resize_(0) # free bytes, keep device + shape/stride + return t + + +def _move_graph_nodes_to_device(graph_module, location): + """Point node device kwargs / aten.to.device targets / meta vals at location.""" + import torch.utils._pytree as pytree + + def _to_loc(v): + return v.to(location) if isinstance(v, torch.Tensor) else v + + for m in graph_module.modules(): + if not isinstance(m, torch.fx.GraphModule): + continue + for node in m.graph.nodes: + if "device" in node.kwargs: + node.kwargs = {**node.kwargs, "device": location} + if node.op == "call_function" and node.target is torch.ops.aten.to.device: + args = 
list(node.args) + args[1] = location + node.args = tuple(args) + node.meta["val"] = pytree.tree_map(_to_loc, node.meta.get("val")) + + +def _move_to_device_resize_kv(ep, location): + """``move_to_device_pass`` variant that frees KV-cache storage on-device. + + Mirrors ``torch.export.passes.move_to_device_pass`` exactly, except KV-cache + buffers (FQN contains ``kv_cache``) are placed on ``location`` but with their + storage immediately freed via ``resize_(0)``. This keeps ``device == + location`` — so the fake-tensor device check on the ``index_copy`` cache + update passes (``self`` and ``values`` both on cuda) — while no real KV bytes + occupy the device during the AOTI compile. KV content is all zeros, so the + emptied tensors are re-synthesized as zeros at the ``_unlift_graph`` clone + (see ``_compile_time_cpu_clones``), which is reused as both the lifted initial + value and the serialized ``.ptd`` constant. The empty/free is interleaved per + tensor so the transient device peak is a single KV buffer, not the whole cache. + Only ``kv_cache`` tensors are emptied (they are the lone large zero-buffers); + every other tensor is moved normally so non-zero content is never lost. 
+ """ + import torch.utils._pytree as pytree + + for k, v in ep.state_dict.items(): + if isinstance(v, torch.nn.Parameter): + ep._state_dict[k] = torch.nn.Parameter(v.to(location), v.requires_grad) + elif _is_kv_buffer(k, v): + ep._state_dict[k] = _empty_strided_on_device(v, location) + else: + ep._state_dict[k] = v.to(location) + + for k, v in ep.constants.items(): + if isinstance(v, torch.Tensor): + ep._constants[k] = ( + _empty_strided_on_device(v, location) + if _is_kv_buffer(k, v) + else v.to(location) + ) + + if ep.example_inputs is not None: + args, kwargs = ep.example_inputs + ep._example_inputs = ( + pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), args), + pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), kwargs), + ) + + _move_graph_nodes_to_device(ep.graph_module, location) + ep.validate() + return ep @final @@ -424,6 +561,29 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool: return spec.value.decode("utf-8").upper() == "ON" return False + @classmethod + def move_program_to_device( + cls, + edge_program, + device: str, + compile_specs: List[CompileSpec], + ): + """Move the program to ``device`` for AOTI compile. + + On a low-memory export (``low_memory_mode="ON"``) the KV-cache buffers — + which can be 10+ GiB at long context — are placed on-device but with their + storage freed (``resize_(0)``), so they never occupy device memory during + the autotune / cpp_wrapper compile while still satisfying the device-match + check on the cache update. They are re-synthesized as zeros for the lifted + graph and the serialized blob. This activates automatically with low-memory + mode. Other (non-low-memory) exports use the stock pass. 
+ """ + from torch.export.passes import move_to_device_pass + + if not cls._is_low_memory_mode(compile_specs): + return move_to_device_pass(edge_program, device) + return _move_to_device_resize_kv(edge_program, device) + @classmethod def release_moved_tensors( cls, diff --git a/backends/cuda/quantize_op_dispatch/int4_dispatch.py b/backends/cuda/quantize_op_dispatch/int4_dispatch.py index c3b8921e2fe..1b8c370eecf 100644 --- a/backends/cuda/quantize_op_dispatch/int4_dispatch.py +++ b/backends/cuda/quantize_op_dispatch/int4_dispatch.py @@ -60,11 +60,29 @@ def _cuda(self, qdata, scale, zero, group_size): return _dequant_matmul(self, qdata, scale, zero, group_size) +# Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size, +# e.g. 262144) runs through the int4_plain_mm custom op (M=1); AOTI executes that +# op's CUDA impl during autotune / cpp_wrapper codegen, where it transiently holds +# ~5 full-size bf16 temporaries (low/high/data/data-z/w_deq) — ~10 GiB for a +# 262144-row weight even though the final w_deq is only ~2.6 GiB. Chunking along N +# caps that at ~chunk rows. It is numerically identical (F.linear output rows are +# independent), and because only the lm_head (custom-op) path crosses the N +# threshold — never the M>4 prefill inline path — it never enters the runtime +# graph: ZERO runtime / accuracy impact. Applied unconditionally to any weight +# whose row count exceeds the threshold. +_DEQUANT_N_THRESHOLD = 65536 +_DEQUANT_N_CHUNK = 32768 + + def _dequant_matmul(x, qdata, scale, zero, group_size): """Dequant INT4 weights to input dtype and call F.linear. scale/zero are in the coalesced [N, n_groups] layout (baked into the weight constant at pack time), aligned row-for-row with qdata's [N, *]. + + Large weights (N > threshold, i.e. the lm_head) are chunked along N to bound + the dequant intermediate (see note above); smaller weights take the original + single-shot dequant. 
""" N, K_half = qdata.shape K = K_half * 2 @@ -72,16 +90,24 @@ def _dequant_matmul(x, qdata, scale, zero, group_size): gs_half = group_size // 2 dtype = x.dtype - p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half) - low = (p & 0x0F).to(dtype) - high = ((p >> 4) & 0x0F).to(dtype) - data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) - - s = scale.to(dtype).unsqueeze(-1) - z = zero.to(dtype).unsqueeze(-1) - w_deq = ((data - z) * s).reshape(N, K) - - return F.linear(x, w_deq) + def _dq(qd, sc, ze, rows): + p = qd.to(torch.uint8).reshape(rows, n_groups, gs_half) + low = (p & 0x0F).to(dtype) + high = ((p >> 4) & 0x0F).to(dtype) + data = torch.stack([low, high], dim=-1).reshape(rows, n_groups, group_size) + s = sc.to(dtype).unsqueeze(-1) + z = ze.to(dtype).unsqueeze(-1) + w_deq = ((data - z) * s).reshape(rows, K) + return F.linear(x, w_deq) + + if N <= _DEQUANT_N_THRESHOLD: + return _dq(qdata, scale, zero, N) + + outs = [] + for i in range(0, N, _DEQUANT_N_CHUNK): + j = min(i + _DEQUANT_N_CHUNK, N) + outs.append(_dq(qdata[i:j], scale[i:j], zero[i:j], j - i)) + return torch.cat(outs, dim=-1) # --------------------------------------------------------------------------- diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py index 10f02c7fa3c..7a41eaf92c1 100644 --- a/backends/cuda/triton/kernels/tq4_sdpa.py +++ b/backends/cuda/triton/kernels/tq4_sdpa.py @@ -294,6 +294,10 @@ def _tq4_sdpa_fwd_kernel_body( triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, 
num_stages=3), ], key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) @@ -410,6 +414,7 @@ def _tq4_sdpa_fwd_kernel_m64( triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3), ], key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) From 92d62c974345d8fd387d6698f5e161b542ec9939 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 25 Jun 2026 08:55:35 -0700 Subject: [PATCH 2/4] Fix TurboQuant KV zeroed by low-mem export (993cff58ae): _is_kv_buffer only frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T --- backends/cuda/cuda_backend.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 1781c5bfd39..b328a05df54 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -166,11 +166,29 @@ def _codegen_device_target_aware(self, device): def _is_kv_buffer(name, v) -> bool: - return ( - isinstance(v, torch.Tensor) - and not isinstance(v, torch.nn.Parameter) - and "kv_cache" in name - ) + """True only for an actual KV-cache *content* buffer that is safe to free. + + The low-memory path (``_move_to_device_resize_kv``) frees every buffer this + matches and re-synthesizes it as ZEROS in both the lifted graph and the + serialized ``.ptd`` (see ``_full_zeros_preserving_strides`` / + ``_get_const_synthesize_zeros``). That is only valid for genuine KV *content*, + which is all-zeros at export time (caches start empty). + + It must NOT match the non-zero constants that some KV-cache modules register + alongside the cache — e.g. 
TurboQuant registers its codebook/rotation + (``centroids``/``boundaries``/``rotation``/``rotation_T``) as buffers on the + ``kv_cache`` module, so their FQNs also contain ``kv_cache``. Freeing+zeroing + those silently corrupts the serialized model (TQ4 dequant -> 0 -> garbage). + Gate on the buffer actually being all-zeros so only empty KV content is freed; + this is robust to any future constant name (a non-zero buffer is never freed). + """ + if not isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter): + return False + if "kv_cache" not in name or v.numel() == 0 or v.is_meta: + return False + # Only the genuinely all-zero KV content may be freed + re-zeroed; non-zero + # constants (TurboQuant centroids/rotation/...) must be preserved as-is. + return bool(torch.count_nonzero(v) == 0) def _empty_strided_on_device(v, location): From 4025660ac810cf796f6a19c06692b1777e0ac145 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 25 Jun 2026 08:56:51 -0700 Subject: [PATCH 3/4] [executorch][cuda] gemma4_31b: fuse gate/up MLP projections (default-on) Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct). Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to ~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at every measured context (cuda_graph ON): ctx ET llama 512 44.80 42.77 2K 43.20 41.97 8K 42.23 41.23 32K 41.64 40.27 127K 38.41 35.97 TurboQuant KV compression kept; prefill restored (6-8x) with no regression; output quality preserved. 
Test Plan: - Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and prefill (T=4). - Export + run: fused module exported via CudaPartitioner and executed through executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs "Fused gate+up on 60 MLP layers". - Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode. --- .../gemma4_31b/cuda_source_transformations.py | 107 ++++++++++++++++++ examples/models/gemma4_31b/export.py | 9 +- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 666d0c44e9d..6609178e084 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -30,6 +30,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb from executorch.extension.llm.modules.turboquant import TurboQuantKVCache @@ -110,6 +111,105 @@ def _turboquant_attention_forward( return self.o_proj(y) +def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: + """Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection. + + Identical math to ``down(gelu(gate(x)) * up(x))``: the single + ``gate_up_proj`` emits ``[gate | up]`` concatenated on the last dim, + which is then split. One W4A8 matmul (and one activation-quant of ``x``) + instead of two. + """ + h = self.gate_up_proj(x) + gate = h[..., : self.intermediate_size] + up = h[..., self.intermediate_size :] + return self.down_proj(F.gelu(gate, approximate="tanh") * up) + + +def _concat_coalesced_int4_along_n(a, b): + """Concatenate two ``CudaCoalescedInt4Tensor`` along the output (N) dim. 
+ + qdata is ``[N, K/2]`` and scale/zero_point are ``[N, n_groups]`` in the + coalesced layout, so a per-output-row concat on dim 0 is exact: the W4A8 + dp4a matvec reads each output row's qdata/scale/zero independently, so + out[:N_a] reproduces ``a`` and out[N_a:] reproduces ``b`` bit-for-bit. + """ + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + + return CudaCoalescedInt4Tensor( + torch.cat([a.qdata, b.qdata], dim=0), + torch.cat([a.scale, b.scale], dim=0), + torch.cat([a.zero_point, b.zero_point], dim=0), + a.block_size, + torch.Size([a.shape[0] + b.shape[0], a.shape[1]]), + None, + a.activation_dtype, + ) + + +def _is_fuseable_int4_pair(gate_w, up_w) -> bool: + """True iff gate/up are both coalesced-int4 with matching K + block_size. + + Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K + weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone). ``act_pre_scale`` + is unused on this path but we require it absent so the concat stays exact. + """ + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + + return ( + isinstance(gate_w, CudaCoalescedInt4Tensor) + and isinstance(up_w, CudaCoalescedInt4Tensor) + and list(gate_w.block_size) == list(up_w.block_size) + and gate_w.shape[1] == up_w.shape[1] + and gate_w.act_pre_scale is None + and up_w.act_pre_scale is None + ) + + +def _fuse_gate_up_proj(model: nn.Module) -> None: + """Fuse each MLP's ``gate_proj | up_proj`` into one ``gate_up_proj``. + + gate and up share the same input, so the unfused path quantizes ``x`` to + int8 twice and launches two W4A8 matvecs per layer. Fusing the weights + into one ``[2*inter, hidden]`` tensor halves both. Weight bytes read are + unchanged, so the win is launch + activation-quant overhead (decode is + launch-bound). Only Q4_K (coalesced-int4) layers are fused; any layer + with a non-int4 weight is left as two matmuls (still correct). 
+ + Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e. + inside ``_export_cuda``), and is independent of TurboQuant. + """ + n_fused = 0 + n_skipped = 0 + for layer in model.layers: + mlp = getattr(layer, "mlp", None) + if mlp is None or not (hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj")): + continue + gate_w = mlp.gate_proj.weight + up_w = mlp.up_proj.weight + if not _is_fuseable_int4_pair(gate_w, up_w): + n_skipped += 1 + continue + inter = up_w.shape[0] + hidden = up_w.shape[1] + fused_w = _concat_coalesced_int4_along_n(gate_w, up_w) + + # Container built on meta to avoid materializing a dense + # [2*inter, hidden] weight before we overwrite it with fused_w. + gate_up = nn.Linear(hidden, 2 * inter, bias=False, device="meta") + gate_up.weight = nn.Parameter(fused_w, requires_grad=False) + mlp.gate_up_proj = gate_up + mlp.intermediate_size = inter + del mlp.gate_proj + del mlp.up_proj + mlp.forward = types.MethodType(_fused_mlp_forward, mlp) + n_fused += 1 + + msg = f"[gemma4_31b cuda] Fused gate+up on {n_fused} MLP layers" + if n_skipped: + msg += f" ({n_skipped} skipped: non-int4 weights)" + print(msg) + + def cuda_source_transformations( model: nn.Module, *, @@ -117,6 +217,11 @@ def cuda_source_transformations( ) -> None: """Apply CUDA source transformations to a Gemma 4 31B model in place. + Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one + activation-quant + one W4A8 matvec per layer instead of two; Q4_K + coalesced-int4 layers only — other quant types are left untouched). + Optionally also swaps full-attention KV caches for TurboQuant TQ4. + Args: model: ``Gemma4_31B`` instance to transform. use_turboquant: When True, swap full-attention layers' KV caches @@ -125,6 +230,8 @@ def cuda_source_transformations( ``torch.ops.triton.tq4_sdpa``. Sliding-window layers are unaffected. 
""" + _fuse_gate_up_proj(model) + if not use_turboquant: return diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index d9e16bc34df..b2b2264178a 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -182,12 +182,11 @@ def _export_cuda( materialize_runtime_buffers(model, dtype=torch.bfloat16) - if use_turboquant: - from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( - cuda_source_transformations, - ) + from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( + cuda_source_transformations, + ) - cuda_source_transformations(model, use_turboquant=True) + cuda_source_transformations(model, use_turboquant=use_turboquant) # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). From 74c2a9d224f6c63519a070d4efacb5dee2f8f26b Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 25 Jun 2026 09:38:29 -0700 Subject: [PATCH 4/4] [gemma4_31b][cuda] length-aware bf16 global attention + head_dim-agnostic prefill autotune - Global (full-attention) bf16 layers: bound SDPA to a runtime kv_len scalar (CUDA-graph-safe) instead of the full max_seq_len KV buffer -> O(context) decode; restores decode scaling (was flat ~36.5 t/s at all depths -> 46.5@512, 34.9@127K). (sdpa.py kv_len path + cuda_source_transformations.py _lenaware_attention_forward; global layers only, sliding + turbo untouched) - Prefill global full-attention: replace fixed m32/m64 BLOCK_M selection with a head_dim-keyed autotuned _sdpa_fwd_kernel + register-budget prune (BLOCK_M*HEAD_DIM <= 4096*num_warps), fixing acc[64,512] fp32 register spill at head_dim=512. Prefill +24% @8K, +63% @32K, +117% @127K; head_dim-agnostic (no split-D needed for D<=512). 
(sdpa.py) - Validated: output bitwise-identical to prior kernel (cos=1.0, D=64/128/256/512), no decode regression; non-tq prefill now beats llama.cpp at all 5 cells and turbo TQ4 at 4/5. Op-level autotune profiling (A100) confirms the config set is near-optimal (in-set optimum at every regime; only <=1.3% marginal candidates). --- backends/cuda/triton/kernels/sdpa.py | 309 +++++++++++------- .../gemma4_31b/cuda_source_transformations.py | 94 ++++++ 2 files changed, 289 insertions(+), 114 deletions(-) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index fb665e538bf..37989349ea7 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -45,6 +45,15 @@ def _is_power_of_2(n: int) -> bool: return n > 0 and (n & (n - 1)) == 0 +# KV length at/above which decode (L_q == 1) uses the split-K flash-decoding +# kernel instead of the standard kernel. Mirrors the threshold the CUDA +# replacement pass uses to pick triton.sdpa_decode_splitk. +_SPLITK_LKV_THRESHOLD = 256 + +# FlashDecoding++ unified-max constant used by the split-K decode path. +_DEFAULT_SPLITK_PHI = 5.0 + + def _next_power_of_2(x: int) -> int: """Get the next power of 2 >= x, clamped to [16, 256]. @@ -160,6 +169,7 @@ def _sdpa_fwd_kernel_non_pow2( v_ptr, o_ptr, mask_ptr, + kv_len_ptr, B, H_grid, LQ, @@ -191,6 +201,7 @@ def _sdpa_fwd_kernel_non_pow2( BLOCK_D: tl.constexpr, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_KV_LEN: tl.constexpr, NUM_GROUPS: tl.constexpr, PACK_GQA: tl.constexpr, ): @@ -254,9 +265,15 @@ def _sdpa_fwd_kernel_non_pow2( NEG_INF: tl.constexpr = float("-inf") - for start_n in tl.range(0, LK, BLOCK_N, num_stages=2): + # Bound the KV loop to valid (filled) positions; see pow2 body for details. 
+ if HAS_KV_LEN: + kv_len = tl.load(kv_len_ptr) + else: + kv_len = LK + + for start_n in tl.range(0, kv_len, BLOCK_N, num_stages=2): offs_n = start_n + tl.arange(0, BLOCK_N) - kv_col_mask = offs_n < LK + kv_col_mask = offs_n < kv_len k_ptrs = k_base + (offs_n[:, None] * stride_kl + offs_d[None, :] * stride_kd) k = tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) @@ -332,6 +349,7 @@ def _sdpa_fwd_kernel_body( V_ptr, O_ptr, Mask_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -358,6 +376,7 @@ def _sdpa_fwd_kernel_body( sm_scale: tl.float32, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_KV_LEN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr, @@ -422,6 +441,18 @@ def _sdpa_fwd_kernel_body( offs_n_init = tl.arange(0, BLOCK_N) + # Bound the KV loop to the number of valid (filled) positions instead of the + # full pre-allocated buffer Lk. For decode this is input_pos+1; for a prefill + # chunk it is chunk_end. This makes full-attention (global) layers O(context) + # rather than O(max_seq_len) — the empty tail of the cache is never touched. + # kv_len is read from a GPU scalar so the bound updates across CUDA-graph + # replays (decode is graph-captured). When not provided (HAS_KV_LEN False) it + # falls back to Lk, preserving the original behavior exactly. + if HAS_KV_LEN: + kv_len = tl.load(KV_LEN_ptr) + else: + kv_len = Lk + # Window-aware early-exit. A KV block that is fully masked (sliding-window # or causal) contributes nothing to the online softmax — every entry is # -inf, so p=0 and m_i/l_i/acc are left unchanged. We detect such blocks up @@ -434,7 +465,7 @@ def _sdpa_fwd_kernel_body( if IS_CAUSAL: max_seq_pos = tl.max(seq_pos) - for start_n in tl.range(0, Lk, BLOCK_N): + for start_n in tl.range(0, kv_len, BLOCK_N): offs_n = start_n + offs_n_init # Decide whether any row in this tile actually attends to this KV block. 
@@ -444,7 +475,7 @@ def _sdpa_fwd_kernel_body( + (seq_pos[:, None] * stride_mq) + (offs_n[None, :] * stride_mk) ) - mn_mask = row_valid[:, None] & (offs_n[None, :] < Lk) + mn_mask = row_valid[:, None] & (offs_n[None, :] < kv_len) mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) block_active = tl.sum(mask_block.to(tl.int32)) > 0 elif IS_CAUSAL: @@ -461,7 +492,7 @@ def _sdpa_fwd_kernel_body( + (offs_n[:, None] * stride_kn) + (offs_d[None, :] * stride_kd) ) - k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + k_mask = (offs_n[:, None] < kv_len) & (offs_d[None, :] < HEAD_DIM) k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32) @@ -493,7 +524,7 @@ def _sdpa_fwd_kernel_body( + (offs_n[:, None] * stride_vn) + (offs_d[None, :] * stride_vd) ) - v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + v_mask = (offs_n[:, None] < kv_len) & (offs_d[None, :] < HEAD_DIM) v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) p_bf16 = p_f32.to(tl.bfloat16) @@ -523,111 +554,64 @@ def _sdpa_fwd_kernel_body( tl.store(o_ptrs, acc.to(tl.bfloat16), mask=o_mask) -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), - ], - key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], -) -@triton.jit -def _sdpa_fwd_kernel_m64( - Q_ptr, - K_ptr, - V_ptr, - O_ptr, - Mask_ptr, - B, - H_grid, - Lq, - Lk, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_ob, - stride_oh, - stride_om, - 
stride_od, - stride_mb, - stride_mq, - stride_mk, - sm_scale: tl.float32, - HAS_MASK: tl.constexpr, - IS_CAUSAL: tl.constexpr, - HEAD_DIM: tl.constexpr, - NUM_GROUPS: tl.constexpr, - PACK_GQA: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - _sdpa_fwd_kernel_body( - Q_ptr, - K_ptr, - V_ptr, - O_ptr, - Mask_ptr, - B, - H_grid, - Lq, - Lk, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_ob, - stride_oh, - stride_om, - stride_od, - stride_mb, - stride_mq, - stride_mk, - sm_scale, - HAS_MASK=HAS_MASK, - IS_CAUSAL=IS_CAUSAL, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - HEAD_DIM=HEAD_DIM, - NUM_GROUPS=NUM_GROUPS, - PACK_GQA=PACK_GQA, - ) +# Prefill / standard-path tile configs. ONE autotuned kernel spanning BLOCK_M in +# {16..128}; `_sdpa_prefill_prune` drops configs whose fp32 accumulator +# acc[BLOCK_M, HEAD_DIM] would spill registers for the runtime HEAD_DIM, so the +# kernel is high-occupancy for any power-of-2 HEAD_DIM (64/128/256/512). This +# replaces the old fixed BLOCK_M=64 (m64) / BLOCK_M=32 (m32) wrappers + Python +# CTA-count selector: at HEAD_DIM=512 the m64 path spilled acc[64,512] fp32 +# (128 KB/CTA -> ~280 reg spills -> ~30 TFLOP/s); the autotuner now picks a +# non-spilling, well-pipelined tile per HEAD_DIM (e.g. BLOCK_M=32 at 512).
+_SDPA_PREFILL_CONFIGS = [ + triton.Config({"BLOCK_M": 16, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3), +] + + +def _sdpa_prefill_prune(configs, nargs, **kwargs): + """Drop configs whose fp32 acc[BLOCK_M, HEAD_DIM] would spill registers. + + Keeps ``BLOCK_M * HEAD_DIM <= 4096 * num_warps`` (the measured A100 no-spill + boundary: HEAD_DIM=512 -> BLOCK_M<=32 at 4 warps / <=64 at 8 warps; + HEAD_DIM=128 -> BLOCK_M<=128 at 4 warps). This guarantees a high-occupancy + pick for any HEAD_DIM and a non-empty result (the BLOCK_M=16 configs satisfy + the budget for every HEAD_DIM<=1024). SMEM-OOR tiles (large + BLOCK_N*HEAD_DIM*num_stages) are pruned by the autotuner at benchmark time. 
+ """ + head_dim = kwargs.get("HEAD_DIM") + if head_dim is None and nargs is not None: + head_dim = nargs.get("HEAD_DIM") + if head_dim is None: + return configs + kept = [c for c in configs if c.kwargs["BLOCK_M"] * head_dim <= 4096 * c.num_warps] + if not kept: + kept = [min(configs, key=lambda c: c.kwargs["BLOCK_M"] / c.num_warps)] + return kept @triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), - ], + configs=_SDPA_PREFILL_CONFIGS, key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], + prune_configs_by={"early_config_prune": _sdpa_prefill_prune}, ) @triton.jit -def _sdpa_fwd_kernel_m32( +def _sdpa_fwd_kernel( Q_ptr, K_ptr, V_ptr, O_ptr, Mask_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -654,6 +638,7 @@ def _sdpa_fwd_kernel_m32( sm_scale: tl.float32, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_KV_LEN: tl.constexpr, HEAD_DIM: tl.constexpr, NUM_GROUPS: tl.constexpr, PACK_GQA: tl.constexpr, @@ -666,6 +651,7 @@ def _sdpa_fwd_kernel_m32( V_ptr, O_ptr, Mask_ptr, + KV_LEN_ptr, B, H_grid, Lq, @@ -692,6 +678,7 @@ def _sdpa_fwd_kernel_m32( sm_scale, HAS_MASK=HAS_MASK, IS_CAUSAL=IS_CAUSAL, + HAS_KV_LEN=HAS_KV_LEN, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM, @@ -785,6 +772,8 @@ def _launch_pow2_kernel( is_causal: bool, num_groups: int, pack_gqa: bool, + kv_len_ptr: Optional[torch.Tensor] = None, + HAS_KV_LEN: bool = False, ) -> None: """Launch power-of-2 optimized SDPA kernel.""" stride_qb, stride_qh, stride_qm, stride_qd = query.stride() @@ -802,18 +791,18 @@ def _launch_pow2_kernel( def grid(meta): return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid) - total_ctas_m64 = ((Lq_packed + 63) // 64) * (B * H_grid) - threshold = 4 * 84 - kernel = ( - 
_sdpa_fwd_kernel_m32 if total_ctas_m64 < threshold else _sdpa_fwd_kernel_m64 - ) - - wrap_triton(kernel)[grid]( + # Single autotuned kernel: the config set spans BLOCK_M in {16..128} and + # `_sdpa_prefill_prune` keeps only non-spilling tiles for this HEAD_DIM, so + # the autotuner picks a high-occupancy tile (small BLOCK_M for large HEAD_DIM, + # larger BLOCK_M / more CTAs for small problems) — subsuming the old + # CTA-count m32/m64 selector. + wrap_triton(_sdpa_fwd_kernel)[grid]( query, key, value, out, Mask_ptr if HAS_MASK else 0, + kv_len_ptr if HAS_KV_LEN else 0, B, H_grid, L_q, @@ -840,6 +829,7 @@ def grid(meta): sm_scale, HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, + HAS_KV_LEN=HAS_KV_LEN, HEAD_DIM=D, NUM_GROUPS=num_groups, PACK_GQA=pack_gqa, @@ -863,6 +853,8 @@ def _launch_non_pow2_kernel( is_causal: bool, num_groups: int, pack_gqa: bool, + kv_len_ptr: Optional[torch.Tensor] = None, + HAS_KV_LEN: bool = False, ) -> None: """Launch non-power-of-2 SDPA kernel with dynamic HEAD_DIM masking.""" stride_qb, stride_qh, stride_qm, stride_qd = query.stride() @@ -902,6 +894,7 @@ def grid_non_pow2(meta): value, out, mask_ptr, + kv_len_ptr if HAS_KV_LEN else 0, B, H_grid, L_q, @@ -933,6 +926,7 @@ def grid_non_pow2(meta): BLOCK_D=BLOCK_D, HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, + HAS_KV_LEN=HAS_KV_LEN, NUM_GROUPS=num_groups, PACK_GQA=pack_gqa, num_warps=num_warps, @@ -950,6 +944,7 @@ def sdpa( is_causal: bool = False, scale: float = 0.0, enable_gqa: bool = False, + kv_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Triton fused Scaled Dot-Product Attention with GQA pack optimization. @@ -967,6 +962,15 @@ def sdpa( is_causal: apply causal masking scale: attention scale (default: 1/sqrt(D)) enable_gqa: allow H_q != H_kv (GQA/MQA) + kv_len: Optional GPU int scalar = number of valid (filled) KV positions. + When provided, the inner KV loop is bounded to ``kv_len`` instead of + the full pre-allocated ``L_kv``, making attention O(context) instead + of O(max_seq_len). 
It is read on-device (no host sync) so the bound + updates correctly under CUDA-graph replay (decode). For decode pass + ``input_pos + 1``; for a prefill chunk pass ``chunk_end``. When None + the loop runs over the full ``L_kv`` (original behavior). Supplying + it for an L_q==1 decode with a large buffer also routes through the + split-K flash-decoding kernel for occupancy. Returns: Output tensor [B, H_q, L_q, D], dtype torch.bfloat16 """ @@ -984,6 +988,54 @@ def sdpa( "For decode (L_q < L_kv), use an explicit bool mask instead." ) + out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype) + sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale + HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params( + attn_mask, B, L_q, L_kv + ) + + # Optional length bound: device int32 scalar, clamped to the buffer size for + # OOB safety. Reshaped to [1] so the kernel can ``tl.load`` element 0. No + # ``.item()`` — keeps it CUDA-graph-safe (value updates on replay). + HAS_KV_LEN = kv_len is not None + if HAS_KV_LEN: + kv_len_t = torch.clamp( + kv_len.reshape(1).to(torch.int32), max=int(L_kv) + ).contiguous() + else: + kv_len_t = None + + # Split-K decode dispatch: L_q == 1 with a kv_len bound and a large KV + # buffer. Flash-decoding partitions the KV sequence across many CTAs for + # better occupancy (L_q=1 launches too few CTAs otherwise). The split is + # static (from buffer size L_kv, not the runtime kv_len value) so it is + # export/AOTI-traceable; the kernel still bounds each split's loop by kv_len + # on-device (CUDA-graph safe). Only taken when kv_len is supplied, so callers + # that don't pass kv_len keep the exact original (standard-kernel) dispatch. 
+ if HAS_KV_LEN and L_q == 1 and _is_power_of_2(D) and L_kv >= _SPLITK_LKV_THRESHOLD: + _launch_decode_splitk( + query, + key, + value, + out, + B, + H_q, + H_kv, + L_kv, + D, + sm_scale, + HAS_MASK, + Mask_ptr, + stride_mb, + stride_mq, + stride_mk, + num_groups, + _DEFAULT_SPLITK_PHI, + kv_len_t, + HAS_KV_LEN, + ) + return out + # Decide whether to pack GQA based on tile utilization heuristic. # Use the actual BLOCK_M that the launched kernel will use: # - non-pow2 path always uses BLOCK_M=32 @@ -995,12 +1047,6 @@ def sdpa( block_m = 32 if total_ctas_m64 < 4 * 84 else 64 pack_gqa = _should_pack_gqa(L_q, num_groups, block_m) - out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype) - sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale - HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params( - attn_mask, B, L_q, L_kv - ) - if _is_power_of_2(D): _launch_pow2_kernel( query, @@ -1022,6 +1068,8 @@ def sdpa( is_causal, num_groups, pack_gqa, + kv_len_t, + HAS_KV_LEN, ) else: _launch_non_pow2_kernel( @@ -1041,6 +1089,8 @@ def sdpa( is_causal, num_groups, pack_gqa, + kv_len_t, + HAS_KV_LEN, ) return out @@ -1058,6 +1108,7 @@ def _sdpa_abstract( is_causal: bool = False, scale: float = 0.0, enable_gqa: bool = False, + kv_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Abstract/fake implementation for torch.export. @@ -1104,6 +1155,7 @@ def _sdpa_decode_splitk_kernel( O_partial_ptr, L_partial_ptr, Mask_ptr, + KV_LEN_ptr, B, H_kv, Lk, @@ -1133,6 +1185,7 @@ def _sdpa_decode_splitk_kernel( phi: tl.float32, chunk_size, HAS_MASK: tl.constexpr, + HAS_KV_LEN: tl.constexpr, BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr, NUM_GROUPS: tl.constexpr, @@ -1144,7 +1197,15 @@ def _sdpa_decode_splitk_kernel( h_kv = pid_bh % H_kv start_n = split_id * chunk_size - end_n = tl.minimum(start_n + chunk_size, Lk) + # Bound the decode KV sweep to the valid (filled) positions. 
Splits whose + # chunk starts past kv_len do no work (end_n <= start_n) and store the zero + # partials they were initialized with, so the reduce is unaffected. kv_len is + # read on-device (CUDA-graph safe); falls back to Lk when not provided. + if HAS_KV_LEN: + kv_len = tl.load(KV_LEN_ptr) + else: + kv_len = Lk + end_n = tl.minimum(start_n + chunk_size, kv_len) offs_d = tl.arange(0, HEAD_DIM) offs_g = tl.arange(0, BLOCK_G) @@ -1293,6 +1354,8 @@ def _launch_decode_splitk( stride_mk: int, num_groups: int, phi: float, + kv_len_ptr: Optional[torch.Tensor] = None, + HAS_KV_LEN: bool = False, ) -> None: num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128) chunk_size = triton.cdiv(L_kv, num_splits) @@ -1319,6 +1382,7 @@ def _launch_decode_splitk( O_partial, L_partial, Mask_ptr if HAS_MASK else 0, + kv_len_ptr if HAS_KV_LEN else 0, B, H_kv, L_kv, @@ -1348,6 +1412,7 @@ def _launch_decode_splitk( phi, chunk_size, HAS_MASK=HAS_MASK, + HAS_KV_LEN=HAS_KV_LEN, HEAD_DIM=D, NUM_GROUPS=num_groups, BLOCK_G=_next_power_of_2_unclamped(num_groups), @@ -1387,6 +1452,7 @@ def sdpa_decode_splitk( scale: float = 0.0, enable_gqa: bool = False, phi: float = 5.0, + kv_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Split-K flash-decoding SDPA for L_q=1 (decode step). @@ -1396,6 +1462,10 @@ def sdpa_decode_splitk( Signature mirrors sdpa() for drop-in use with torch.cond dispatch. enable_gqa is accepted but ignored — GQA is handled natively via H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1. + + kv_len: optional GPU int scalar bounding the KV sweep to the valid + (filled) positions (O(context) instead of O(max_seq_len)). Read + on-device, CUDA-graph safe. When None, sweeps the full L_kv. 
""" _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) @@ -1431,6 +1501,14 @@ def sdpa_decode_splitk( attn_mask, B, L_q, L_kv ) + HAS_KV_LEN = kv_len is not None + if HAS_KV_LEN: + kv_len_t = torch.clamp( + kv_len.reshape(1).to(torch.int32), max=int(L_kv) + ).contiguous() + else: + kv_len_t = None + _launch_decode_splitk( query, key, @@ -1449,6 +1527,8 @@ def sdpa_decode_splitk( stride_mk, num_groups, phi, + kv_len_t, + HAS_KV_LEN, ) return out @@ -1464,6 +1544,7 @@ def _sdpa_decode_splitk_abstract( scale: float = 0.0, enable_gqa: bool = False, phi: float = 5.0, + kv_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" B, H_q, L_q, D = query.shape diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 6609178e084..635161390d7 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -25,6 +25,11 @@ import types +# Importing this module registers ``torch.ops.triton.sdpa`` / +# ``torch.ops.triton.sdpa_decode_splitk`` (the length-aware bf16 attention ops +# used by the non-TurboQuant full-attention path below). +import executorch.backends.cuda.triton.kernels.sdpa # noqa: F401 + # Importing this module registers ``torch.ops.triton.tq4_sdpa``. import executorch.backends.cuda.triton.kernels.tq4_sdpa # noqa: F401 @@ -111,6 +116,79 @@ def _turboquant_attention_forward( return self.o_proj(y) +def _lenaware_attention_forward( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + attn_mask: torch.Tensor, +) -> torch.Tensor: + """Drop-in ``Gemma4Attention.forward`` for full-attention layers on the + non-TurboQuant CUDA path that bounds SDPA to the valid context length. 
+ + Identical to the default forward (plain bf16 KV cache) except the final + ``F.scaled_dot_product_attention`` is replaced with + ``torch.ops.triton.sdpa(..., kv_len=...)``. Passing ``kv_len`` bounds the + attention KV loop to the actual filled context instead of the full + pre-allocated buffer (``max_seq_len`` for global layers), making decode + O(context) instead of O(max_seq_len) — and routes L_q==1 decode through the + length-aware split-K flash-decoding kernel. Sliding-window layers are not + patched (they already use a bounded ring buffer). + """ + B, T, _ = x.shape + + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim) + raw_k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + if self.k_eq_v: + raw_v = raw_k + else: + raw_v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(raw_k) + v = self.v_norm(raw_v) + + # (B, H, T, D) for SDPA / KV cache. + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # RoPE: same code path as default forward. + freqs = torch.outer(input_pos.float(), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos = torch.cos(emb) + sin = torch.sin(emb) + q, k = apply_rotary_emb(q, k, cos, sin) + + # Update cache and read back the full (pre-allocated) K/V buffers. + k, v = self.kv_cache.update(input_pos, k, v) + + # Number of valid (filled) KV positions = input_pos[0] + T. Passing this to + # sdpa bounds its KV loop to the actual context instead of the full + # pre-allocated buffer (max_seq_len for global layers), making attention + # O(context) instead of O(max_seq_len). Kept as a GPU scalar (no ``.item()``) + # so the bound is captured correctly by the decode CUDA graph. Decode: T=1 -> + # input_pos+1; prefill chunk: T -> chunk_end. + kv_len = input_pos[0] + input_pos.shape[0] + + # ``scale=self.scaling`` (= 1.0 for Gemma 4) — Gemma's QK-norm has absorbed + # the 1/sqrt(d) factor into trained weights. 
``enable_gqa=True`` lets the + # kernel handle the head ratio without materializing expanded K/V. + y = torch.ops.triton.sdpa( + q, + k, + v, + attn_mask, + 0.0, # dropout_p + False, # is_causal: attn_mask already encodes causal masking + self.scaling, + True, # enable_gqa + kv_len, + ) + + y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.o_proj(y) + + def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: """Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection. @@ -233,6 +311,22 @@ def cuda_source_transformations( _fuse_gate_up_proj(model) if not use_turboquant: + # Non-TurboQuant path: keep the bf16 KV cache but bound full-attention + # SDPA to the valid context length via a runtime kv_len scalar (routes + # through torch.ops.triton.sdpa, which dispatches L_q==1 decode to the + # length-aware split-K flash-decoding kernel). Sliding-window layers + # already use a bounded ring buffer, so they are left untouched. + n_bounded = 0 + for layer in model.layers: + attn = layer.self_attn + if attn.is_sliding: + continue + attn.forward = types.MethodType(_lenaware_attention_forward, attn) + n_bounded += 1 + print( + f"[gemma4_31b cuda] length-aware SDPA: bounded {n_bounded} " + f"full-attention layers to runtime kv_len (O(context) attention)" + ) return config = model.config