@@ -111,63 +111,21 @@ def _init_mem_manager(self):
111111 def _init_att_backend (self ):
112112 # Gemma-4 has per-layer heterogeneous attention: sliding layers use
113113 # (head_dim=256, kv_heads=16); full-attn layers use (head_dim=512,
114- # kv_heads=4, k_eq_v). No single generic backend setup covers both:
115- # - FA3 caps head_dim at 256 -> can't run full-attn layers.
116- # - Flashinfer plans once per infer_state on a single shape -> can't
117- # accommodate heterogeneous layout at all.
118- # Strategy:
119- # - Prefill: sliding layers go through the gemma4_mm Triton kernel
120- # directly (handles SWA + image bidi); full-attn layers use the
121- # primary triton backend below. No FA3 in prefill — its
122- # image_token_end build asserts incompatible with SWA. Revisit
123- # when fa3 supports both simultaneously.
124- # - Decode: full-attn layers on triton (primary); sliding layers on
125- # fa3 (with SWA) when available — secondary backend set in
126- # _init_att_backend1.
127- fa3_loadable = self ._gemma4_fa3_loadable ()
128-
129- # Full-attn layers always go through triton.
114+ # kv_heads=4, k_eq_v). FA3 caps head_dim at 256 and flashinfer plans
115+ # once per infer_state on a single shape — both unworkable for the
116+ # heterogeneous layout. Both layer kinds go through triton.
117+ #
118+ # Primary backend = sliding layers. Sliding prefill bypasses the
119+ # backend and calls gemma4_mm directly (SWA + image bidi in one
120+ # pass); the prefill_att_state created here is unused but the
121+ # framework requires prefill_att_backend to be non-None.
130122 self .prefill_att_backend = TritonAttBackend (model = self )
131123 self .decode_att_backend = TritonAttBackend (model = self )
132124
133- self ._gemma4_sliding_decode_backend_kind = self ._resolve_gemma4_sliding_backend (
134- self .args .llm_decode_att_backend [0 ], fa3_loadable
135- )
136-
137125 def _init_att_backend1 (self ):
138- # Only decode needs the sliding-layer backend; prefill sliding goes
139- # through gemma4_mm Triton directly in the layer.
140- self .prefill_att_backend1 = None
141- self .decode_att_backend1 = self ._build_gemma4_sliding_backend (self ._gemma4_sliding_decode_backend_kind )
142-
143- @staticmethod
144- def _gemma4_fa3_loadable ():
145- from lightllm .utils .sgl_utils import flash_attn_with_kvcache
146-
147- return flash_attn_with_kvcache is not None
148-
149- @staticmethod
150- def _resolve_gemma4_sliding_backend (backend_name , fa3_loadable ):
151- assert backend_name in ("auto" , "triton" , "fa3" ), (
152- "Gemma-4 requires triton or fa3 for sliding layers; flashinfer is "
153- f"not wired for the heterogeneous layout. Got backend={ backend_name !r} ."
154- )
155- if backend_name == "auto" :
156- return "fa3" if fa3_loadable else "triton"
157- if backend_name == "fa3" :
158- assert fa3_loadable , (
159- "Requested --llm_*_att_backend=fa3 but flash_attn_with_kvcache "
160- "did not import (sgl_kernel missing or wrong arch)."
161- )
162- return backend_name
163-
164- def _build_gemma4_sliding_backend (self , backend_kind ):
165- if backend_kind == "fa3" :
166- from lightllm .common .basemodel .attention .fa3 .fp import Fa3AttBackend
167-
168- return Fa3AttBackend (model = self )
169- assert backend_kind == "triton"
170- return TritonAttBackend (model = self )
126+ # Secondary backend = full-attn layers (head_dim=512, plain causal).
127+ self .prefill_att_backend1 = TritonAttBackend (model = self )
128+ self .decode_att_backend1 = TritonAttBackend (model = self )
171129
172130 def _init_custom (self ):
173131 self ._init_to_get_rotary_gemma4 ()
0 commit comments