Commit 0cce2e5

Build Gemma4 packed FlexAttention block masks
1 parent 42e05cd commit 0cce2e5

1 file changed

Lines changed: 59 additions & 3 deletions

File tree

  • nemo_automodel/components/models/gemma4_moe

nemo_automodel/components/models/gemma4_moe/model.py

@@ -320,6 +320,8 @@ def _build_packed_gemma4_causal_mask_mapping(
     *,
     dtype: torch.dtype,
     sliding_window: int | None,
+    as_additive: bool = False,
+    as_block_mask: bool = False,
 ) -> dict[str, torch.Tensor]:
     """Build Gemma4 full/sliding masks for packed VLM sequences.
 
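The two new keyword-only flags select the output format of the mask mapping. A minimal sketch of the three modes, assuming the helper's leading positional arguments are packed_seq_ids and mm_token_type_ids (the tensors the hunks below operate on), that pack id 0 marks padding, and that mm_token_type_ids == 1 marks vision tokens:

import torch

# Toy packed batch (assumed conventions: pack id 0 = padding,
# mm_token_type_ids 1 = image token).
packed_seq_ids = torch.tensor([[1, 1, 1, 2, 2, 0]])
mm_token_type_ids = torch.tensor([[0, 1, 1, 0, 0, 0]])

# Default: boolean [B, 1, S, S] masks.
masks = _build_packed_gemma4_causal_mask_mapping(
    packed_seq_ids, mm_token_type_ids, dtype=torch.bfloat16, sliding_window=512
)

# Additive: 0.0 where attention is allowed, torch.finfo(dtype).min where blocked.
masks_additive = _build_packed_gemma4_causal_mask_mapping(
    packed_seq_ids, mm_token_type_ids, dtype=torch.bfloat16, sliding_window=512,
    as_additive=True,
)

# FlexAttention BlockMask objects; setting both flags raises ValueError.
masks_flex = _build_packed_gemma4_causal_mask_mapping(
    packed_seq_ids, mm_token_type_ids, dtype=torch.bfloat16, sliding_window=512,
    as_block_mask=True,
)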
@@ -335,12 +337,58 @@ def _build_packed_gemma4_causal_mask_mapping(
         f"got {tuple(mm_token_type_ids.shape)} vs {tuple(packed_seq_ids.shape)}"
     )
 
+    if as_additive and as_block_mask:
+        raise ValueError("Only one of as_additive and as_block_mask may be set.")
+
     batch_size, seq_len = packed_seq_ids.shape
     device = packed_seq_ids.device
     positions = torch.arange(seq_len, device=device)
     q_positions = positions.view(1, seq_len, 1)
     kv_positions = positions.view(1, 1, seq_len)
 
+    vision_group_ids = _vision_group_ids(mm_token_type_ids)
+
+    if as_block_mask:
+        from torch.nn.attention.flex_attention import create_block_mask
+
+        def _full_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+            q_pack_id = packed_seq_ids[batch_idx, q_idx]
+            kv_pack_id = packed_seq_ids[batch_idx, kv_idx]
+            allowed = (q_pack_id == kv_pack_id) & (q_pack_id > 0) & (kv_idx <= q_idx)
+            return torch.where(q_pack_id <= 0, kv_idx == 0, allowed)
+
+        def _sliding_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+            q_pack_id = packed_seq_ids[batch_idx, q_idx]
+            kv_pack_id = packed_seq_ids[batch_idx, kv_idx]
+            same_doc = (q_pack_id == kv_pack_id) & (q_pack_id > 0)
+            allowed = same_doc & (kv_idx <= q_idx)
+            if sliding_window is not None:
+                allowed = allowed & ((q_idx - kv_idx) < sliding_window)
+            q_group = vision_group_ids[batch_idx, q_idx]
+            kv_group = vision_group_ids[batch_idx, kv_idx]
+            same_vision_group = (q_group == kv_group) & (q_group >= 0)
+            allowed = (allowed | same_vision_group) & same_doc
+            return torch.where(q_pack_id <= 0, kv_idx == 0, allowed)
+
+        return {
+            "full_attention": create_block_mask(
+                _full_mask_mod,
+                B=batch_size,
+                H=None,
+                Q_LEN=seq_len,
+                KV_LEN=seq_len,
+                device=device,
+            ),
+            "sliding_attention": create_block_mask(
+                _sliding_mask_mod,
+                B=batch_size,
+                H=None,
+                Q_LEN=seq_len,
+                KV_LEN=seq_len,
+                device=device,
+            ),
+        }
+
     valid_q = packed_seq_ids[:, :, None] > 0
     valid_kv = packed_seq_ids[:, None, :] > 0
     same_doc = (packed_seq_ids[:, :, None] == packed_seq_ids[:, None, :]) & valid_q & valid_kv
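With as_block_mask set, the helper returns FlexAttention BlockMask objects instead of dense tensors, so the mask never has to be materialized at [B, 1, S, S]. A minimal sketch of consuming one of them through the public flex_attention entry point; the q/k/v shapes and the masks_flex name are illustrative, not from this repo:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 1, 8, 6, 64  # S must match the Q_LEN/KV_LEN the mask was built with
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# masks_flex is the mapping returned above with as_block_mask=True;
# H=None in create_block_mask lets one mask broadcast across all heads.
out = flex_attention(q, k, v, block_mask=masks_flex["sliding_attention"])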
@@ -351,15 +399,22 @@ def _build_packed_gemma4_causal_mask_mapping(
     if sliding_window is not None:
         sliding_mask = sliding_mask & ((q_positions - kv_positions) < sliding_window)
 
-    vision_group_ids = _vision_group_ids(mm_token_type_ids)
     same_vision_group = (vision_group_ids[:, :, None] == vision_group_ids[:, None, :]) & (
         vision_group_ids[:, :, None] >= 0
     )
     sliding_mask = (sliding_mask | same_vision_group) & same_doc
 
+    full_mask = full_mask.view(batch_size, 1, seq_len, seq_len)
+    sliding_mask = sliding_mask.view(batch_size, 1, seq_len, seq_len)
+
+    if as_additive:
+        min_dtype = torch.finfo(dtype).min
+        full_mask = torch.where(full_mask, torch.zeros((), dtype=dtype, device=device), min_dtype)
+        sliding_mask = torch.where(sliding_mask, torch.zeros((), dtype=dtype, device=device), min_dtype)
+
     return {
-        "full_attention": full_mask.view(batch_size, 1, seq_len, seq_len),
-        "sliding_attention": sliding_mask.view(batch_size, 1, seq_len, seq_len),
+        "full_attention": full_mask,
+        "sliding_attention": sliding_mask,
     }
 
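The as_additive branch turns the boolean masks into additive float biases: 0.0 where attention is allowed and torch.finfo(dtype).min where it is blocked, which is the format scaled_dot_product_attention accepts as a float attn_mask. A sketch under the same assumed names as above:

import torch
import torch.nn.functional as F

B, H, S, D = 1, 8, 6, 64
q = torch.randn(B, H, S, D, dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, dtype=torch.bfloat16)

# The [B, 1, S, S] additive mask broadcasts across the H heads; blocked
# positions carry finfo.min, which softmax maps to effectively zero weight.
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=masks_additive["full_attention"]
)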
@@ -488,6 +543,7 @@ def forward(
                 mm_token_type_ids.to(device=inputs_embeds.device),
                 dtype=inputs_embeds.dtype,
                 sliding_window=getattr(self.config, "sliding_window", None),
+                as_block_mask=getattr(self.config, "_attn_implementation", None) == "flex_attention",
             )
         elif use_vision_bidirectional_mask:
             from transformers.models.gemma4.modeling_gemma4 import create_causal_mask_mapping
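The forward wiring enables the BlockMask path whenever the config reports the flex_attention implementation. In Transformers, the private _attn_implementation attribute is normally populated from the attn_implementation argument at load time; a sketch of that assumed usage (the loader class and checkpoint id are placeholders, not from this commit):

from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "org/some-gemma4-checkpoint",  # placeholder, not a real checkpoint id
    attn_implementation="flex_attention",
)
# from_pretrained records the choice on model.config, so the getattr check
# in the diff above routes mask building through create_block_mask.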