3030
3131import torch
3232import torch .nn as nn
33+ import torch .nn .functional as F
3334
3435from executorch .examples .models .gemma4 .text_decoder import apply_rotary_emb
3536from executorch .extension .llm .modules .turboquant import TurboQuantKVCache
@@ -110,13 +111,117 @@ def _turboquant_attention_forward(
110111 return self .o_proj (y )
111112
112113
114+ def _fused_mlp_forward (self , x : torch .Tensor ) -> torch .Tensor :
115+ """Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection.
116+
117+ Identical math to ``down(gelu(gate(x)) * up(x))``: the single
118+ ``gate_up_proj`` emits ``[gate | up]`` concatenated on the last dim,
119+ which is then split. One W4A8 matmul (and one activation-quant of ``x``)
120+ instead of two.
121+ """
122+ h = self .gate_up_proj (x )
123+ gate = h [..., : self .intermediate_size ]
124+ up = h [..., self .intermediate_size :]
125+ return self .down_proj (F .gelu (gate , approximate = "tanh" ) * up )
126+
127+
128+ def _concat_coalesced_int4_along_n (a , b ):
129+ """Concatenate two ``CudaCoalescedInt4Tensor`` along the output (N) dim.
130+
131+ qdata is ``[N, K/2]`` and scale/zero_point are ``[N, n_groups]`` in the
132+ coalesced layout, so a per-output-row concat on dim 0 is exact: the W4A8
133+ dp4a matvec reads each output row's qdata/scale/zero independently, so
134+ out[:N_a] reproduces ``a`` and out[N_a:] reproduces ``b`` bit-for-bit.
135+ """
136+ from executorch .backends .cuda .coalesced_int4_tensor import CudaCoalescedInt4Tensor
137+
138+ return CudaCoalescedInt4Tensor (
139+ torch .cat ([a .qdata , b .qdata ], dim = 0 ),
140+ torch .cat ([a .scale , b .scale ], dim = 0 ),
141+ torch .cat ([a .zero_point , b .zero_point ], dim = 0 ),
142+ a .block_size ,
143+ torch .Size ([a .shape [0 ] + b .shape [0 ], a .shape [1 ]]),
144+ None ,
145+ a .activation_dtype ,
146+ )
147+
148+
149+ def _is_fuseable_int4_pair (gate_w , up_w ) -> bool :
150+ """True iff gate/up are both coalesced-int4 with matching K + block_size.
151+
152+ Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K
153+ weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone). ``act_pre_scale``
154+ is unused on this path but we require it absent so the concat stays exact.
155+ """
156+ from executorch .backends .cuda .coalesced_int4_tensor import CudaCoalescedInt4Tensor
157+
158+ return (
159+ isinstance (gate_w , CudaCoalescedInt4Tensor )
160+ and isinstance (up_w , CudaCoalescedInt4Tensor )
161+ and list (gate_w .block_size ) == list (up_w .block_size )
162+ and gate_w .shape [1 ] == up_w .shape [1 ]
163+ and gate_w .act_pre_scale is None
164+ and up_w .act_pre_scale is None
165+ )
166+
167+
168+ def _fuse_gate_up_proj (model : nn .Module ) -> None :
169+ """Fuse each MLP's ``gate_proj | up_proj`` into one ``gate_up_proj``.
170+
171+ gate and up share the same input, so the unfused path quantizes ``x`` to
172+ int8 twice and launches two W4A8 matvecs per layer. Fusing the weights
173+ into one ``[2*inter, hidden]`` tensor halves both. Weight bytes read are
174+ unchanged, so the win is launch + activation-quant overhead (decode is
175+ launch-bound). Only Q4_K (coalesced-int4) layers are fused; any layer
176+ with a non-int4 weight is left as two matmuls (still correct).
177+
178+ Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e.
179+ inside ``_export_cuda``), and is independent of TurboQuant.
180+ """
181+ n_fused = 0
182+ n_skipped = 0
183+ for layer in model .layers :
184+ mlp = getattr (layer , "mlp" , None )
185+ if mlp is None or not (hasattr (mlp , "gate_proj" ) and hasattr (mlp , "up_proj" )):
186+ continue
187+ gate_w = mlp .gate_proj .weight
188+ up_w = mlp .up_proj .weight
189+ if not _is_fuseable_int4_pair (gate_w , up_w ):
190+ n_skipped += 1
191+ continue
192+ inter = up_w .shape [0 ]
193+ hidden = up_w .shape [1 ]
194+ fused_w = _concat_coalesced_int4_along_n (gate_w , up_w )
195+
196+ # Container built on meta to avoid materializing a dense
197+ # [2*inter, hidden] weight before we overwrite it with fused_w.
198+ gate_up = nn .Linear (hidden , 2 * inter , bias = False , device = "meta" )
199+ gate_up .weight = nn .Parameter (fused_w , requires_grad = False )
200+ mlp .gate_up_proj = gate_up
201+ mlp .intermediate_size = inter
202+ del mlp .gate_proj
203+ del mlp .up_proj
204+ mlp .forward = types .MethodType (_fused_mlp_forward , mlp )
205+ n_fused += 1
206+
207+ msg = f"[gemma4_31b cuda] Fused gate+up on { n_fused } MLP layers"
208+ if n_skipped :
209+ msg += f" ({ n_skipped } skipped: non-int4 weights)"
210+ print (msg )
211+
212+
113213def cuda_source_transformations (
114214 model : nn .Module ,
115215 * ,
116216 use_turboquant : bool = False ,
117217) -> None :
118218 """Apply CUDA source transformations to a Gemma 4 31B model in place.
119219
220+ Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one
221+ activation-quant + one W4A8 matvec per layer instead of two; Q4_K
222+ coalesced-int4 layers only — other quant types are left untouched).
223+ Optionally also swaps full-attention KV caches for TurboQuant TQ4.
224+
120225 Args:
121226 model: ``Gemma4_31B`` instance to transform.
122227 use_turboquant: When True, swap full-attention layers' KV caches
@@ -125,6 +230,8 @@ def cuda_source_transformations(
125230 ``torch.ops.triton.tq4_sdpa``. Sliding-window layers are
126231 unaffected.
127232 """
233+ _fuse_gate_up_proj (model )
234+
128235 if not use_turboquant :
129236 return
130237
0 commit comments