Commit caa6034

Tune Gemma4 decoder FlexAttention execution
1 parent 0cce2e5

1 file changed

Lines changed: 18 additions & 1 deletion


nemo_automodel/components/models/gemma4_moe/model.py

@@ -246,6 +246,21 @@ def forward(
         # --- Attention ---
         residual = x
         x = self.input_layernorm(x)
+        attn_kwargs = kwargs
+        if getattr(self.config, "_attn_implementation", None) == "flex_attention" and "kernel_options" not in kwargs:
+            attn_kwargs = {
+                **kwargs,
+                "kernel_options": {
+                    "BLOCK_M": 32,
+                    "BLOCK_N": 32,
+                    "BLOCK_M1": 32,
+                    "BLOCK_N1": 32,
+                    "BLOCK_M2": 32,
+                    "BLOCK_N2": 32,
+                    "num_stages": 1,
+                    "num_warps": 4,
+                },
+            }
         x, _ = self.self_attn(
             hidden_states=x,
             position_embeddings=position_embeddings,
@@ -255,8 +270,10 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             mm_token_type_ids=mm_token_type_ids,
-            **kwargs,
+            **attn_kwargs,
         )
+        if getattr(self.config, "_attn_implementation", None) == "flex_attention" and padding_mask is not None:
+            x = x.masked_fill(padding_mask[..., None], 0)
         x = self.post_attention_layernorm(x)
         x = residual + x
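For context: kernel_options is a tuning hook on PyTorch's FlexAttention API. BLOCK_M/BLOCK_N pick the forward-pass Triton tile sizes, the BLOCK_M1/N1/M2/N2 variants cover the backward pass, and num_stages/num_warps control Triton software pipelining and per-kernel parallelism. Below is a minimal sketch of passing the same options to torch.nn.attention.flex_attention directly; the shapes, dtype, and device are illustrative assumptions, not taken from this commit, and the options generally only take effect on the torch.compile (Triton) path.

import torch
from torch.nn.attention.flex_attention import flex_attention

# kernel_options are consumed by the Triton template, so compile the op first.
flex_attention = torch.compile(flex_attention)

# Illustrative shapes: (batch, heads, seq_len, head_dim); requires a CUDA device.
q = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.bfloat16)

out = flex_attention(
    q, k, v,
    kernel_options={
        "BLOCK_M": 32,    # forward tile size along the query dimension
        "BLOCK_N": 32,    # forward tile size along the key/value dimension
        "num_stages": 1,  # Triton software-pipelining stages
        "num_warps": 4,   # warps per Triton program instance
    },
)

Smaller tiles and a single pipeline stage trade peak throughput for lower register and shared-memory pressure, which is one plausible reason for pinning them here instead of relying on autotuning.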
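The second hunk zeroes the attention output at padded positions after the FlexAttention call. A minimal sketch of the broadcast, assuming padding_mask is a boolean (batch, seq_len) tensor that is True at padded positions (the name follows the diff; the shapes are illustrative):

import torch

batch, seq_len, hidden = 2, 4, 8
x = torch.randn(batch, seq_len, hidden)

# Hypothetical mask: True marks padding (here, the last two tokens of the
# second sequence).
padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
padding_mask[1, 2:] = True

# padding_mask[..., None] has shape (batch, seq_len, 1) and broadcasts over
# the hidden dimension, zeroing every feature at padded positions.
x = x.masked_fill(padding_mask[..., None], 0)

assert (x[1, 2:] == 0).all()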