File tree Expand file tree Collapse file tree
nemo_automodel/components/models/gemma4_moe
@@ -246,6 +246,21 @@ def forward(
         # --- Attention ---
         residual = x
         x = self.input_layernorm(x)
+        attn_kwargs = kwargs
+        if getattr(self.config, "_attn_implementation", None) == "flex_attention" and "kernel_options" not in kwargs:
+            attn_kwargs = {
+                **kwargs,
+                "kernel_options": {
+                    "BLOCK_M": 32,
+                    "BLOCK_N": 32,
+                    "BLOCK_M1": 32,
+                    "BLOCK_N1": 32,
+                    "BLOCK_M2": 32,
+                    "BLOCK_N2": 32,
+                    "num_stages": 1,
+                    "num_warps": 4,
+                },
+            }
         x, _ = self.self_attn(
             hidden_states=x,
             position_embeddings=position_embeddings,
@@ -255,8 +270,10 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             mm_token_type_ids=mm_token_type_ids,
-            **kwargs,
+            **attn_kwargs,
         )
+        if getattr(self.config, "_attn_implementation", None) == "flex_attention" and padding_mask is not None:
+            x = x.masked_fill(padding_mask[..., None], 0)
         x = self.post_attention_layernorm(x)
         x = residual + x

You can’t perform that action at this time.
0 commit comments