@@ -320,6 +320,8 @@ def _build_packed_gemma4_causal_mask_mapping(
     *,
     dtype: torch.dtype,
     sliding_window: int | None,
+    as_additive: bool = False,
+    as_block_mask: bool = False,
 ) -> dict[str, torch.Tensor]:
     """Build Gemma4 full/sliding masks for packed VLM sequences.

@@ -335,12 +337,58 @@ def _build_packed_gemma4_causal_mask_mapping(
             f"got {tuple(mm_token_type_ids.shape)} vs {tuple(packed_seq_ids.shape)}"
         )

+    if as_additive and as_block_mask:
+        raise ValueError("Only one of as_additive and as_block_mask may be set.")
+
     batch_size, seq_len = packed_seq_ids.shape
     device = packed_seq_ids.device
     positions = torch.arange(seq_len, device=device)
     q_positions = positions.view(1, seq_len, 1)
     kv_positions = positions.view(1, 1, seq_len)

+    vision_group_ids = _vision_group_ids(mm_token_type_ids)
+
+    if as_block_mask:
+        from torch.nn.attention.flex_attention import create_block_mask
+
+        def _full_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+            q_pack_id = packed_seq_ids[batch_idx, q_idx]
+            kv_pack_id = packed_seq_ids[batch_idx, kv_idx]
+            allowed = (q_pack_id == kv_pack_id) & (q_pack_id > 0) & (kv_idx <= q_idx)
+            return torch.where(q_pack_id <= 0, kv_idx == 0, allowed)
+
+        def _sliding_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+            q_pack_id = packed_seq_ids[batch_idx, q_idx]
+            kv_pack_id = packed_seq_ids[batch_idx, kv_idx]
+            same_doc = (q_pack_id == kv_pack_id) & (q_pack_id > 0)
+            allowed = same_doc & (kv_idx <= q_idx)
+            if sliding_window is not None:
+                allowed = allowed & ((q_idx - kv_idx) < sliding_window)
+            q_group = vision_group_ids[batch_idx, q_idx]
+            kv_group = vision_group_ids[batch_idx, kv_idx]
+            same_vision_group = (q_group == kv_group) & (q_group >= 0)
+            allowed = (allowed | same_vision_group) & same_doc
+            return torch.where(q_pack_id <= 0, kv_idx == 0, allowed)
+
+        return {
+            "full_attention": create_block_mask(
+                _full_mask_mod,
+                B=batch_size,
+                H=None,
+                Q_LEN=seq_len,
+                KV_LEN=seq_len,
+                device=device,
+            ),
+            "sliding_attention": create_block_mask(
+                _sliding_mask_mod,
+                B=batch_size,
+                H=None,
+                Q_LEN=seq_len,
+                KV_LEN=seq_len,
+                device=device,
+            ),
+        }
+
     valid_q = packed_seq_ids[:, :, None] > 0
     valid_kv = packed_seq_ids[:, None, :] > 0
     same_doc = (packed_seq_ids[:, :, None] == packed_seq_ids[:, None, :]) & valid_q & valid_kv
@@ -351,15 +399,22 @@ def _build_packed_gemma4_causal_mask_mapping(
     if sliding_window is not None:
         sliding_mask = sliding_mask & ((q_positions - kv_positions) < sliding_window)

-    vision_group_ids = _vision_group_ids(mm_token_type_ids)
     same_vision_group = (vision_group_ids[:, :, None] == vision_group_ids[:, None, :]) & (
         vision_group_ids[:, :, None] >= 0
     )
     sliding_mask = (sliding_mask | same_vision_group) & same_doc

+    full_mask = full_mask.view(batch_size, 1, seq_len, seq_len)
+    sliding_mask = sliding_mask.view(batch_size, 1, seq_len, seq_len)
+
+    if as_additive:
+        min_dtype = torch.finfo(dtype).min
+        full_mask = torch.where(full_mask, torch.zeros((), dtype=dtype, device=device), min_dtype)
+        sliding_mask = torch.where(sliding_mask, torch.zeros((), dtype=dtype, device=device), min_dtype)
+
     return {
-        "full_attention": full_mask.view(batch_size, 1, seq_len, seq_len),
-        "sliding_attention": sliding_mask.view(batch_size, 1, seq_len, seq_len),
+        "full_attention": full_mask,
+        "sliding_attention": sliding_mask,
     }


@@ -488,6 +543,7 @@ def forward(
                 mm_token_type_ids.to(device=inputs_embeds.device),
                 dtype=inputs_embeds.dtype,
                 sliding_window=getattr(self.config, "sliding_window", None),
+                as_block_mask=getattr(self.config, "_attn_implementation", None) == "flex_attention",
             )
         elif use_vision_bidirectional_mask:
             from transformers.models.gemma4.modeling_gemma4 import create_causal_mask_mapping
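
For context, here is a minimal sketch of how the flex-attention path could be consumed. The tensor shapes, values, and the call site are illustrative assumptions, not part of this change; `flex_attention` and `create_block_mask` are the real `torch.nn.attention.flex_attention` APIs used by the diff.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical inputs: shapes and values are illustrative only.
batch, heads, seq, head_dim = 2, 8, 256, 64
q = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=torch.bfloat16)
packed_seq_ids = torch.ones(batch, seq, dtype=torch.long, device="cuda")      # one doc, no padding
mm_token_type_ids = torch.zeros(batch, seq, dtype=torch.long, device="cuda")  # no vision tokens

masks = _build_packed_gemma4_causal_mask_mapping(
    packed_seq_ids,
    mm_token_type_ids,
    dtype=q.dtype,
    sliding_window=512,
    as_block_mask=True,
)
# BlockMask objects plug directly into flex_attention; the precomputed
# block layout lets the kernel skip tiles that are fully masked.
out = flex_attention(q, k, v, block_mask=masks["sliding_attention"])
```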
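The `as_additive=True` branch instead yields float masks in the layout that `torch.nn.functional.scaled_dot_product_attention` accepts. A sketch under the same hypothetical inputs as above:

```python
import torch.nn.functional as F

masks = _build_packed_gemma4_causal_mask_mapping(
    packed_seq_ids,
    mm_token_type_ids,
    dtype=q.dtype,
    sliding_window=512,
    as_additive=True,
)
# Each mask is (batch, 1, seq, seq) and broadcasts over heads:
# 0.0 where attention is allowed, finfo(dtype).min where it is blocked.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=masks["full_attention"])
```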