Skip to content

Commit efc7560

Browse files
authored
[executorch][cuda] fuse gate/up MLP projections (#20482)
Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct).

| decode length | main branch | current branch |
|---|---|---|
| 512 | 42.2 | 44.80 |
| 2K | 40.8 | 43.20 |
| 8K | 40.0 | 42.23 |
| 32K | 39.4 | 41.64 |
| 127K | 35.5 | 38.41 |

Next step: we will upstream this kind of operator fusion into the gemma4-31b model level when loading GGUF. #20481 is the draft PR.
1 parent 47af14f commit efc7560

2 files changed

Lines changed: 111 additions & 5 deletions

File tree

examples/models/gemma4_31b/cuda_source_transformations.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import torch
3232
import torch.nn as nn
33+
import torch.nn.functional as F
3334

3435
from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb
3536
from executorch.extension.llm.modules.turboquant import TurboQuantKVCache
@@ -110,13 +111,117 @@ def _turboquant_attention_forward(
110111
return self.o_proj(y)
111112

112113

114+
def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the MLP through the fused ``gate_up_proj`` projection.

    Stand-in for ``Gemma4MLP.forward`` once gate and up share one weight.
    ``gate_up_proj`` produces ``[gate | up]`` along the last dim; the first
    ``self.intermediate_size`` columns are taken as the gate half and the
    remaining columns as the up half, giving the same math as
    ``down(gelu(gate(x)) * up(x))`` with a single projection of ``x``.

    Args:
        x: Input activations fed to ``self.gate_up_proj``.

    Returns:
        The output of ``self.down_proj`` applied to the gated activations.
    """
    fused = self.gate_up_proj(x)
    split_at = self.intermediate_size
    gate_half, up_half = fused[..., :split_at], fused[..., split_at:]
    # Tanh-approximated GELU on the gate half only; up stays linear.
    activated = F.gelu(gate_half, approximate="tanh")
    return self.down_proj(activated * up_half)
127+
128+
def _concat_coalesced_int4_along_n(a, b):
129+
"""Concatenate two ``CudaCoalescedInt4Tensor`` along the output (N) dim.
130+
131+
qdata is ``[N, K/2]`` and scale/zero_point are ``[N, n_groups]`` in the
132+
coalesced layout, so a per-output-row concat on dim 0 is exact: the W4A8
133+
dp4a matvec reads each output row's qdata/scale/zero independently, so
134+
out[:N_a] reproduces ``a`` and out[N_a:] reproduces ``b`` bit-for-bit.
135+
"""
136+
from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor
137+
138+
return CudaCoalescedInt4Tensor(
139+
torch.cat([a.qdata, b.qdata], dim=0),
140+
torch.cat([a.scale, b.scale], dim=0),
141+
torch.cat([a.zero_point, b.zero_point], dim=0),
142+
a.block_size,
143+
torch.Size([a.shape[0] + b.shape[0], a.shape[1]]),
144+
None,
145+
a.activation_dtype,
146+
)
147+
148+
149+
def _is_fuseable_int4_pair(gate_w, up_w) -> bool:
150+
"""True iff gate/up are both coalesced-int4 with matching K + block_size.
151+
152+
Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K
153+
weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone). ``act_pre_scale``
154+
is unused on this path but we require it absent so the concat stays exact.
155+
"""
156+
from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor
157+
158+
return (
159+
isinstance(gate_w, CudaCoalescedInt4Tensor)
160+
and isinstance(up_w, CudaCoalescedInt4Tensor)
161+
and list(gate_w.block_size) == list(up_w.block_size)
162+
and gate_w.shape[1] == up_w.shape[1]
163+
and gate_w.act_pre_scale is None
164+
and up_w.act_pre_scale is None
165+
)
166+
167+
168+
def _fuse_gate_up_proj(model: nn.Module) -> None:
169+
"""Fuse each MLP's ``gate_proj | up_proj`` into one ``gate_up_proj``.
170+
171+
gate and up share the same input, so the unfused path quantizes ``x`` to
172+
int8 twice and launches two W4A8 matvecs per layer. Fusing the weights
173+
into one ``[2*inter, hidden]`` tensor halves both. Weight bytes read are
174+
unchanged, so the win is launch + activation-quant overhead (decode is
175+
launch-bound). Only Q4_K (coalesced-int4) layers are fused; any layer
176+
with a non-int4 weight is left as two matmuls (still correct).
177+
178+
Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e.
179+
inside ``_export_cuda``), and is independent of TurboQuant.
180+
"""
181+
n_fused = 0
182+
n_skipped = 0
183+
for layer in model.layers:
184+
mlp = getattr(layer, "mlp", None)
185+
if mlp is None or not (hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj")):
186+
continue
187+
gate_w = mlp.gate_proj.weight
188+
up_w = mlp.up_proj.weight
189+
if not _is_fuseable_int4_pair(gate_w, up_w):
190+
n_skipped += 1
191+
continue
192+
inter = up_w.shape[0]
193+
hidden = up_w.shape[1]
194+
fused_w = _concat_coalesced_int4_along_n(gate_w, up_w)
195+
196+
# Container built on meta to avoid materializing a dense
197+
# [2*inter, hidden] weight before we overwrite it with fused_w.
198+
gate_up = nn.Linear(hidden, 2 * inter, bias=False, device="meta")
199+
gate_up.weight = nn.Parameter(fused_w, requires_grad=False)
200+
mlp.gate_up_proj = gate_up
201+
mlp.intermediate_size = inter
202+
del mlp.gate_proj
203+
del mlp.up_proj
204+
mlp.forward = types.MethodType(_fused_mlp_forward, mlp)
205+
n_fused += 1
206+
207+
msg = f"[gemma4_31b cuda] Fused gate+up on {n_fused} MLP layers"
208+
if n_skipped:
209+
msg += f" ({n_skipped} skipped: non-int4 weights)"
210+
print(msg)
211+
212+
113213
def cuda_source_transformations(
114214
model: nn.Module,
115215
*,
116216
use_turboquant: bool = False,
117217
) -> None:
118218
"""Apply CUDA source transformations to a Gemma 4 31B model in place.
119219
220+
Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one
221+
activation-quant + one W4A8 matvec per layer instead of two; Q4_K
222+
coalesced-int4 layers only — other quant types are left untouched).
223+
Optionally also swaps full-attention KV caches for TurboQuant TQ4.
224+
120225
Args:
121226
model: ``Gemma4_31B`` instance to transform.
122227
use_turboquant: When True, swap full-attention layers' KV caches
@@ -125,6 +230,8 @@ def cuda_source_transformations(
125230
``torch.ops.triton.tq4_sdpa``. Sliding-window layers are
126231
unaffected.
127232
"""
233+
_fuse_gate_up_proj(model)
234+
128235
if not use_turboquant:
129236
return
130237

examples/models/gemma4_31b/export.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,11 @@ def _export_cuda(
189189
materialize_runtime_buffers(model, dtype=torch.bfloat16)
190190
mutable_buffer_metadata = _mutable_buffer_metadata(model)
191191

192-
if use_turboquant:
193-
from executorch.examples.models.gemma4_31b.cuda_source_transformations import (
194-
cuda_source_transformations,
195-
)
192+
from executorch.examples.models.gemma4_31b.cuda_source_transformations import (
193+
cuda_source_transformations,
194+
)
196195

197-
cuda_source_transformations(model, use_turboquant=True)
196+
cuda_source_transformations(model, use_turboquant=use_turboquant)
198197

199198
# Int4Tensor weights are used directly — no format conversion.
200199
# F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).

0 commit comments

Comments
 (0)