Skip to content

Commit fb75045

Browse files
committed
fix
1 parent 91051f0 commit fb75045

2 files changed

Lines changed: 14 additions & 56 deletions

File tree

lightllm/models/gemma4/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ def _context_attention_kernel(
245245
return o_tensor.view(q.shape)
246246

247247
# Full-attn layers: head_dim=512, no SWA, no image bidi — standard
248-
# triton via the primary backend.
249-
o_tensor = infer_state.prefill_att_state.prefill_att(
248+
# triton via backend1.
249+
o_tensor = infer_state.prefill_att_state1.prefill_att(
250250
q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor
251251
)
252252
return o_tensor.view(q.shape)
@@ -260,7 +260,7 @@ def _token_attention_kernel(
260260
) -> torch.Tensor:
261261
_k, _v = self._get_layer_kv(infer_state)
262262
_q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
263-
att_state = infer_state.decode_att_state1 if self.is_sliding else infer_state.decode_att_state
263+
att_state = infer_state.decode_att_state if self.is_sliding else infer_state.decode_att_state1
264264
o_tensor = att_state.decode_att(q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor)
265265
return o_tensor.view(q.shape)
266266

lightllm/models/gemma4/model.py

Lines changed: 11 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -111,63 +111,21 @@ def _init_mem_manager(self):
111111
def _init_att_backend(self):
112112
# Gemma-4 has per-layer heterogeneous attention: sliding layers use
113113
# (head_dim=256, kv_heads=16); full-attn layers use (head_dim=512,
114-
# kv_heads=4, k_eq_v). No single generic backend setup covers both:
115-
# - FA3 caps head_dim at 256 -> can't run full-attn layers.
116-
# - Flashinfer plans once per infer_state on a single shape -> can't
117-
# accommodate heterogeneous layout at all.
118-
# Strategy:
119-
# - Prefill: sliding layers go through the gemma4_mm Triton kernel
120-
# directly (handles SWA + image bidi); full-attn layers use the
121-
# primary triton backend below. No FA3 in prefill — its
122-
# image_token_end build asserts incompatible with SWA. Revisit
123-
# when fa3 supports both simultaneously.
124-
# - Decode: full-attn layers on triton (primary); sliding layers on
125-
# fa3 (with SWA) when available — secondary backend set in
126-
# _init_att_backend1.
127-
fa3_loadable = self._gemma4_fa3_loadable()
128-
129-
# Full-attn layers always go through triton.
114+
# kv_heads=4, k_eq_v). FA3 caps head_dim at 256 and flashinfer plans
115+
# once per infer_state on a single shape — both unworkable for the
116+
# heterogeneous layout. Both layer kinds go through triton.
117+
#
118+
# Primary backend = sliding layers. Sliding prefill bypasses the
119+
# backend and calls gemma4_mm directly (SWA + image bidi in one
120+
# pass); the prefill_att_state created here is unused but the
121+
# framework requires prefill_att_backend to be non-None.
130122
self.prefill_att_backend = TritonAttBackend(model=self)
131123
self.decode_att_backend = TritonAttBackend(model=self)
132124

133-
self._gemma4_sliding_decode_backend_kind = self._resolve_gemma4_sliding_backend(
134-
self.args.llm_decode_att_backend[0], fa3_loadable
135-
)
136-
137125
def _init_att_backend1(self):
138-
# Only decode needs the sliding-layer backend; prefill sliding goes
139-
# through gemma4_mm Triton directly in the layer.
140-
self.prefill_att_backend1 = None
141-
self.decode_att_backend1 = self._build_gemma4_sliding_backend(self._gemma4_sliding_decode_backend_kind)
142-
143-
@staticmethod
144-
def _gemma4_fa3_loadable():
145-
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
146-
147-
return flash_attn_with_kvcache is not None
148-
149-
@staticmethod
150-
def _resolve_gemma4_sliding_backend(backend_name, fa3_loadable):
151-
assert backend_name in ("auto", "triton", "fa3"), (
152-
"Gemma-4 requires triton or fa3 for sliding layers; flashinfer is "
153-
f"not wired for the heterogeneous layout. Got backend={backend_name!r}."
154-
)
155-
if backend_name == "auto":
156-
return "fa3" if fa3_loadable else "triton"
157-
if backend_name == "fa3":
158-
assert fa3_loadable, (
159-
"Requested --llm_*_att_backend=fa3 but flash_attn_with_kvcache "
160-
"did not import (sgl_kernel missing or wrong arch)."
161-
)
162-
return backend_name
163-
164-
def _build_gemma4_sliding_backend(self, backend_kind):
165-
if backend_kind == "fa3":
166-
from lightllm.common.basemodel.attention.fa3.fp import Fa3AttBackend
167-
168-
return Fa3AttBackend(model=self)
169-
assert backend_kind == "triton"
170-
return TritonAttBackend(model=self)
126+
# Secondary backend = full-attn layers (head_dim=512, plain causal).
127+
self.prefill_att_backend1 = TritonAttBackend(model=self)
128+
self.decode_att_backend1 = TritonAttBackend(model=self)
171129

172130
def _init_custom(self):
173131
self._init_to_get_rotary_gemma4()

0 commit comments

Comments
 (0)