Commit 2afd94e

fix(gemma4): Gemma4 packing, attention mask, and MoE routing fixes (#2116)
* propagate image_position_ids through VLM neat packing
* propagate mm_token_type_ids through VLM neat packing
* fix(models): convert 4D bool attention mask to additive format for eager attention
  The packed collater emits a 4D block-causal bool mask. Eager attention adds this directly to attn_weights (0/1 instead of 0/-inf), so no positions are masked — the model sees across sequence boundaries and future tokens. Also fixes _derive_padding_mask, which was applying logical_not to all mask shapes; for 4D masks the pad positions come from the diagonal.
* fix(gemma4): select top-k experts from router_probs not expert_scores
  Consistent with HF Gemma4Router which applies top-k on softmax probabilities, not raw logits.
* linter
* add tests
* cleanup
* fix(vlm): add configurable Gemma4 thinking-prefix injection and packed example config
* fix(gemma4): handle EP_SHARD mesh in state dict adapter checkpoint load
* update model in example
* linter

Signed-off-by: shruthan <shrutha.radhakrishna@servicenow.com>
1 parent 4767694 commit 2afd94e

11 files changed

Lines changed: 445 additions & 34 deletions

examples/vlm_finetune/gemma4/gemma4_26b_a4b_moe_packing.yaml

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
# Configuration for fine-tuning Gemma 4 26B-A4B MoE (128 experts) with MedPix dataset and sequence packing
# Requires 8 GPUs (FSDP2 + EP=8, 16 experts per GPU)
# torchrun --nproc-per-node=8 examples/vlm_finetune/finetune.py -c examples/vlm_finetune/gemma4/gemma4_26b_a4b_moe_packing.yaml

recipe: FinetuneRecipeForVLM

step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
  ckpt_every_steps: 500
  val_every_steps: 500
  num_epochs: 2

dist_env:
  backend: nccl
  timeout_minutes: 60

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 42
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: google/gemma-4-26B-A4B-it
  torch_dtype: torch.bfloat16
  trust_remote_code: true
  attn_implementation: eager
  backend:
    _target_: nemo_automodel.components.models.common.BackendConfig
    attn: te
    linear: te
    rms_norm: te
    rope_fusion: true
    dispatcher: deepep
    fake_balanced_gate: false
    enable_hf_state_dict_adapter: true
    enable_fsdp_optimizations: true
  text_config:
    # 26B-A4B does not use kv_shared layers (only used in 2B, 4B), hence use_cache: false.
    use_cache: false

processor:
  padding_side: right

checkpoint:
  enabled: true
  checkpoint_dir: vlm_checkpoints/gemma4_26b_a4b_moe_packing/
  model_save_format: torch_save
  save_consolidated: false

distributed:
  strategy: fsdp2
  dp_size: none
  tp_size: 1
  cp_size: 1
  ep_size: 8
  sequence_parallel: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: train[:1000]

packed_sequence:
  pretokenize: true
  max_length: 3072
  pack_size: 3072
  packing_ratio: 0.9
  drop_long_samples: true
  post_tokenize_hook_fn: nemo_automodel.components.datasets.vlm.collate_fns.gemma4_inject_thinking_prefix

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 4
  persistent_workers: true
  pin_memory: true

validation_dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: validation[:500]

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.gemma4_prefix_collate_fn

optimizer:
  _target_: torch.optim.AdamW
  lr: 2e-5
  weight_decay: 0.01
  betas: [0.9, 0.95]

freeze_config:
  freeze_embeddings: true
  freeze_vision_tower: true
  freeze_audio_tower: true
  freeze_language_model: false

# wandb:
#   project: <your-project>
#   entity: <your-entity>
#   name: <your-run-name>
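The header comment's parallelism claim (128 experts with ep_size: 8, so 16 experts per GPU) and the packing settings can be sanity-checked offline. A minimal sketch, assuming PyYAML is installed, that the config path matches the torchrun comment above, and that the key nesting is as shown:

import yaml

# Hypothetical offline check of the example config; the expert count comes from
# the header comment, not from a key in the YAML itself.
cfg_path = "examples/vlm_finetune/gemma4/gemma4_26b_a4b_moe_packing.yaml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

num_experts = 128  # from the header comment
ep_size = cfg["distributed"]["ep_size"]
print("experts per rank:", num_experts // ep_size)                # 16
print("pack_size:", cfg["packed_sequence"]["pack_size"])          # 3072
print("hook:", cfg["packed_sequence"]["post_tokenize_hook_fn"])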

nemo_automodel/components/datasets/vlm/collate_fns.py

Lines changed: 28 additions & 10 deletions
@@ -1496,6 +1496,12 @@ def _pad_1d(tensor, pad_value, target_len):
     labels = torch.stack([_pad_1d(x["labels"], LABEL_PAD, max_len) for x in batch])
     attention_mask = torch.stack([_pad_1d(x["attention_mask"], 0, max_len) for x in batch])
 
+    def _get_mm_token_type_ids(item):
+        v = item.get("mm_token_type_ids")
+        return v if v is not None else torch.zeros(0, dtype=torch.long)
+
+    mm_token_type_ids = torch.stack([_pad_1d(_get_mm_token_type_ids(x), 0, max_len) for x in batch])
+
     if use_flash:
         # Keep indexed [B, S] mask for flash_attn_varlen_func.
         # The patched _get_unpad_data will extract per-document cu_seqlens.

@@ -1526,6 +1532,7 @@ def _pad_mrope(pos, target_len):
             "labels": labels,
             "position_ids": position_ids,
             "attention_mask": attention_mask_out,
+            "mm_token_type_ids": mm_token_type_ids,
         }
 
     # Store indexed attention mask for loss functions that need per-sample

@@ -1541,7 +1548,7 @@ def _pad_mrope(pos, target_len):
         if tensors:
             result[key] = torch.cat(tensors, dim=0).to(torch.bfloat16)
 
-    for key in ("image_grid_thw", "video_grid_thw", "second_per_grid_ts"):
+    for key in ("image_grid_thw", "image_position_ids", "video_grid_thw", "second_per_grid_ts"):
         tensors = [x[key] for x in batch if key in x and x[key] is not None]
         if tensors:
             result[key] = torch.cat(tensors, dim=0)

@@ -1804,13 +1811,6 @@ def _inject_thinking_prefix_tokens(
 ) -> Dict[str, torch.Tensor]:
     """Insert ``<|channel>thought\\n<channel|>`` tokens after every ``<|turn>model\\n`` marker.
 
-    Gemma4 31B / 26B-A4B MoE instruction-tuned models always emit a thinking-
-    channel prefix before the actual response. When this prefix is absent from
-    training sequences the model predicts ``<|channel>`` but the label says
-    answer text, inflating initial loss to ~9. Injecting the prefix (masked
-    as -100 in labels) lets the model see its expected pattern and brings
-    initial loss down to ~3.
-
     Modifies ``input_ids``, ``attention_mask``, and ``mm_token_type_ids``
     (if present). Additionally, any other 2-D integer tensor whose second
     dimension matches ``input_ids`` is extended with zeros so that sequence

@@ -1885,6 +1885,25 @@ def _inject_thinking_prefix_tokens(
     return batch
 
 
+def gemma4_inject_thinking_prefix(
+    batch: Dict[str, torch.Tensor],
+    processor,
+) -> Dict[str, torch.Tensor]:
+    """Inject Gemma4's thinking-channel prefix after every assistant turn marker.
+
+    Gemma4 31B / 26B-A4B MoE instruction-tuned models always emit a thinking-
+    channel prefix before the actual response. When this prefix is absent from
+    training sequences the model predicts ``<|channel>`` but the label says
+    answer text, inflating initial loss to ~9. Injecting the prefix (masked
+    as -100 in labels) lets the model see its expected pattern and brings
+    initial loss down to ~3.
+
+    Safe no-op for non-Gemma4 tokenizers.
+    """
+    tokenizer = getattr(processor, "tokenizer", processor)
+    return _inject_thinking_prefix_tokens(batch, tokenizer)
+
+
 def gemma4_prefix_collate_fn(
     examples: Sequence[Dict[str, Any]],
     processor,

@@ -1900,8 +1919,7 @@ def gemma4_prefix_collate_fn(
     """
 
     def _inject(batch, proc):
-        tokenizer = getattr(proc, "tokenizer", proc)
-        batch = _inject_thinking_prefix_tokens(batch, tokenizer)
+        batch = gemma4_inject_thinking_prefix(batch, proc)
         if max_length is not None and batch["input_ids"].size(1) > max_length:
             for key in list(batch.keys()):
                 v = batch[key]
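The docstring above explains the thinking-prefix fix at the token level: insert the prefix after each assistant-turn marker and mask those positions with -100 so they carry no loss. A minimal, self-contained sketch of that idea; the marker and prefix ids are hypothetical placeholders, not Gemma4's real vocabulary, and this is not the _inject_thinking_prefix_tokens implementation itself:

import torch

IGNORE_INDEX = -100

def inject_prefix(input_ids, labels, marker_id, prefix_ids):
    """Insert prefix_ids after every marker_id token; mask the prefix in labels."""
    out_ids, out_labels = [], []
    for tok, lab in zip(input_ids.tolist(), labels.tolist()):
        out_ids.append(tok)
        out_labels.append(lab)
        if tok == marker_id:
            out_ids.extend(prefix_ids)
            out_labels.extend([IGNORE_INDEX] * len(prefix_ids))  # prefix carries no loss
    return torch.tensor(out_ids), torch.tensor(out_labels)

# Hypothetical ids: 7 marks the start of a model turn, [11, 12] is the "thinking" prefix.
ids = torch.tensor([1, 2, 7, 30, 31])
labs = torch.tensor([-100, -100, -100, 30, 31])
new_ids, new_labs = inject_prefix(ids, labs, marker_id=7, prefix_ids=[11, 12])
print(new_ids.tolist())   # [1, 2, 7, 11, 12, 30, 31]
print(new_labs.tolist())  # [-100, -100, -100, -100, -100, 30, 31]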

nemo_automodel/components/datasets/vlm/datasets.py

Lines changed: 12 additions & 1 deletion
@@ -922,12 +922,21 @@ class PreTokenizedDatasetWrapper(torch.utils.data.Dataset):
     ``pixel_values_videos``, ``video_grid_thw``).
     """
 
-    def __init__(self, dataset, processor, max_length=None, max_retries=10, truncate=False):
+    def __init__(
+        self,
+        dataset,
+        processor,
+        max_length=None,
+        max_retries=10,
+        truncate=False,
+        post_tokenize_hook=None,
+    ):
         self.dataset = dataset
         self.processor = processor
         self.max_length = max_length
         self.truncate = truncate
         self.max_retries = max_retries
+        self.post_tokenize_hook = post_tokenize_hook
         # Compatibility attributes expected by build_dataloader
         self.preload_media = False

@@ -998,6 +1007,8 @@ def __getitem__(self, idx):
             processor_kwargs["video_metadata"] = [video_metadata]
 
         result = self.processor(**processor_kwargs)
+        if self.post_tokenize_hook is not None:
+            result = self.post_tokenize_hook(result, self.processor)
 
         input_ids = result["input_ids"][0]  # (seq_len,)
         seq_len = input_ids.shape[0]
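The hook is called as post_tokenize_hook(result, self.processor) right after the processor runs, so any callable with that signature that returns the (possibly modified) result will work. A hedged usage sketch; log_length_hook, my_dataset, and my_processor are illustrative placeholders, and the keyword wiring assumes the constructor shown in this diff:

# Sketch of a custom post-tokenize hook with the same call signature used above.
def log_length_hook(result, processor):
    seq_len = result["input_ids"].shape[1]
    print(f"tokenized sequence length: {seq_len}")
    return result

# Hypothetical wiring; `my_dataset` and `my_processor` are placeholders.
# wrapper = PreTokenizedDatasetWrapper(
#     my_dataset,
#     my_processor,
#     max_length=3072,
#     post_tokenize_hook=log_length_hook,
# )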

nemo_automodel/components/datasets/vlm/neat_packing_vlm.py

Lines changed: 24 additions & 1 deletion
@@ -47,7 +47,14 @@
 
 logger = logging.getLogger(__name__)
 
-MEDIA_KEYS = ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts")
+MEDIA_KEYS = (
+    "pixel_values",
+    "image_grid_thw",
+    "image_position_ids",
+    "pixel_values_videos",
+    "video_grid_thw",
+    "second_per_grid_ts",
+)
 
 
 # ---------------------------------------------------------------------------

@@ -302,6 +309,10 @@ def _shift_sample(sample: dict, has_mrope: bool = False) -> dict:
     out["labels"] = sample["labels"][1:]
     out["attention_mask"] = sample["attention_mask"][:-1]
 
+    if (mm_ttids := sample.get("mm_token_type_ids")) is not None:
+        mm_ttids = torch.as_tensor(mm_ttids)
+        out["mm_token_type_ids"] = mm_ttids[0, :-1] if mm_ttids.ndim == 2 else mm_ttids[:-1]
+
     if has_mrope and "position_ids" in sample and sample["position_ids"] is not None:
         out["position_ids"] = sample["position_ids"][:, :-1]

@@ -321,11 +332,13 @@ def _build_packed_vlm_sample(
     all_input_ids: list[int] = []
     all_labels: list[int] = []
     all_attention_mask: list[int] = []
+    all_mm_token_type_ids: list[int] = []
     all_position_ids_1d: list[int] = []
     mrope_position_ids_list: list[torch.Tensor] = []
 
     pixel_values_list: list[torch.Tensor] = []
     image_grid_thw_list: list[torch.Tensor] = []
+    image_position_ids_list: list[torch.Tensor] = []
     pixel_values_videos_list: list[torch.Tensor] = []
     video_grid_thw_list: list[torch.Tensor] = []
     second_per_grid_ts_list: list[torch.Tensor] = []

@@ -345,6 +358,12 @@ def _build_packed_vlm_sample(
         all_labels.extend(labs)
         all_attention_mask.extend([seq_idx] * seq_len)
 
+        mm_ttids = sample.get("mm_token_type_ids")
+        if mm_ttids is not None:
+            all_mm_token_type_ids.extend(mm_ttids.tolist() if isinstance(mm_ttids, torch.Tensor) else mm_ttids)
+        else:
+            all_mm_token_type_ids.extend([0] * seq_len)
+
         if has_mrope and "position_ids" in sample:
             mrope_position_ids_list.append(sample["position_ids"])
         else:

@@ -355,6 +374,8 @@ def _build_packed_vlm_sample(
         if "image_grid_thw" in sample and sample["image_grid_thw"] is not None:
             n_images += sample["image_grid_thw"].shape[0]
             image_grid_thw_list.append(sample["image_grid_thw"])
+        if "image_position_ids" in sample and sample["image_position_ids"] is not None:
+            image_position_ids_list.append(sample["image_position_ids"])
         if "pixel_values_videos" in sample and sample["pixel_values_videos"] is not None:
             pixel_values_videos_list.append(sample["pixel_values_videos"])
         if "video_grid_thw" in sample and sample["video_grid_thw"] is not None:

@@ -368,6 +389,7 @@ def _build_packed_vlm_sample(
         "input_ids": torch.tensor(all_input_ids, dtype=torch.long),
         "labels": torch.tensor(all_labels, dtype=torch.long),
         "attention_mask": torch.tensor(all_attention_mask, dtype=torch.long),
+        "mm_token_type_ids": torch.tensor(all_mm_token_type_ids, dtype=torch.long),
         "n_images": n_images,
         "n_videos": n_videos,
     }

@@ -379,6 +401,7 @@ def _build_packed_vlm_sample(
 
     packed["pixel_values"] = torch.cat(pixel_values_list, dim=0) if pixel_values_list else None
     packed["image_grid_thw"] = torch.cat(image_grid_thw_list, dim=0) if image_grid_thw_list else None
+    packed["image_position_ids"] = torch.cat(image_position_ids_list, dim=0) if image_position_ids_list else None
     packed["pixel_values_videos"] = torch.cat(pixel_values_videos_list, dim=0) if pixel_values_videos_list else None
     packed["video_grid_thw"] = torch.cat(video_grid_thw_list, dim=0) if video_grid_thw_list else None
     packed["second_per_grid_ts"] = torch.cat(second_per_grid_ts_list, dim=0) if second_per_grid_ts_list else None

nemo_automodel/components/models/gemma4_moe/model.py

Lines changed: 25 additions & 4 deletions
@@ -114,9 +114,7 @@ def forward(self, x, token_mask=None, cp_mesh=None):
         expert_scores = self.proj(x_norm)
         router_probs = F.softmax(expert_scores, dim=-1)
 
-        # Top-k on raw scores (matching HF Gemma4Router behaviour)
-        _, indices = torch.topk(expert_scores, k=self.topk, dim=-1)
-        weights = router_probs.gather(-1, indices)
+        weights, indices = torch.topk(router_probs, k=self.topk, dim=-1)
         weights = weights / weights.sum(dim=-1, keepdim=True).clamp(min=1e-20)
         return weights, indices, None
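Since softmax is monotonic within a row, top-k on router_probs normally selects the same experts as top-k on expert_scores; the change mainly aligns the op order and returned weights with HF Gemma4Router, as the commit message notes. A standalone sketch comparing the two paths on random logits (shapes are made up for illustration):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 16)  # (tokens, experts); hypothetical shapes
probs = F.softmax(logits, dim=-1)
topk = 2

# Old path: indices from raw logits, weights gathered from probabilities.
_, idx_old = torch.topk(logits, k=topk, dim=-1)
w_old = probs.gather(-1, idx_old)

# New path (top-k on softmax probabilities, per the commit message).
w_new, idx_new = torch.topk(probs, k=topk, dim=-1)

print(torch.equal(idx_old, idx_new))  # True here: softmax preserves per-row ordering
print(torch.allclose(w_old / w_old.sum(-1, keepdim=True),
                     w_new / w_new.sum(-1, keepdim=True)))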

@@ -264,6 +262,26 @@ def forward(
         return x
 
 
+def _convert_bool_4d_mask_to_additive(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    """Convert a 4D bool allowed-mask to HF additive format (0.0 allowed, -inf masked)."""
+    if attention_mask.ndim != 4 or attention_mask.dtype != torch.bool:
+        return attention_mask
+    additive = torch.zeros(attention_mask.shape, dtype=dtype, device=attention_mask.device)
+    return additive.masked_fill(~attention_mask, torch.finfo(dtype).min)
+
+
+def _derive_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor:
+    """Derive 2D padding mask (True = pad) from 1D, 2D, or 4D attention mask."""
+    if attention_mask.ndim == 2:
+        return attention_mask == 0
+    if attention_mask.ndim == 4:
+        diagonal = torch.diagonal(attention_mask[:, 0], dim1=-2, dim2=-1)
+        if attention_mask.dtype == torch.bool:
+            return diagonal.logical_not()
+        return diagonal != 0
+    return attention_mask.bool().logical_not()
+
+
 # ---------------------------------------------------------------------------
 # Text model backend
 # ---------------------------------------------------------------------------
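To see why the additive conversion matters for eager attention: adding a 0/1 bool mask to attention scores masks nothing, while the additive form pushes disallowed positions to -inf before the softmax; the padding mask can then be read off the diagonal of the 4D mask, because a real token is always allowed to attend to itself. A minimal sketch mirroring the helpers above using only plain torch:

import torch

dtype = torch.float32
# 4D "allowed" bool mask [batch=1, heads=1, q=3, k=3]: block-causal, last token is padding.
allowed = torch.tensor([[[[True,  False, False],
                          [True,  True,  False],
                          [False, False, False]]]])

scores = torch.zeros(1, 1, 3, 3)

# Wrong: adding the bool mask directly leaves every position visible (0 vs 1, never -inf).
wrong = (scores + allowed.to(dtype)).softmax(-1)
print(wrong[0, 0, 0])  # all three keys still get weight: nothing was masked

# Right: convert to additive form first, as _convert_bool_4d_mask_to_additive does.
additive = torch.zeros_like(scores).masked_fill(~allowed, torch.finfo(dtype).min)
right = (scores + additive).softmax(-1)
print(right[0, 0, 0])  # only the first key gets weight for query 0

# Padding mask from the diagonal: a non-pad token can always attend to itself.
pad = torch.diagonal(allowed[:, 0], dim1=-2, dim2=-1).logical_not()
print(pad)  # tensor([[False, False,  True]])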
@@ -356,7 +374,10 @@ def forward(
             position_ids = cache_position.unsqueeze(0)
 
         if padding_mask is None and attention_mask is not None:
-            padding_mask = attention_mask.bool().logical_not()
+            padding_mask = _derive_padding_mask(attention_mask)
+
+        if attention_mask is not None:
+            attention_mask = _convert_bool_4d_mask_to_additive(attention_mask, inputs_embeds.dtype)
 
         hidden_states = inputs_embeds
