Summary
For Gemma4 checkpoints whose text config has use_bidirectional_attention="vision" and enable_moe_block=True (e.g. gemma-4-26B-A4B-it), Gemma4MoETextModelBackend.forward() always builds plain causal + sliding masks, ignoring the use_bidirectional_attention flag. This diverges from HF's Gemma4Model.forward, which calls create_causal_mask_mapping to make tokens inside the same vision group bidirectionally visible.
As a result, the MoE backend forward pass numerically diverges from HF on multimodal inputs.
Affected code
nemo_automodel/components/models/gemma4_moe/model.py, Gemma4MoETextModelBackend.forward:
# current behavior: always plain causal masks, mm_token_type_ids is never plumbed in
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

mask_kwargs = {
    "config": self.config,
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "cache_position": cache_position,
    "past_key_values": past_key_values,
    "position_ids": position_ids,
}
causal_mask_mapping = {
    "full_attention": create_causal_mask(**mask_kwargs),
    "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}
Compare to HF transformers/models/gemma4/modeling_gemma4.py (Gemma4Model.forward), which uses:
if self.config.get_text_config().use_bidirectional_attention == "vision":
    causal_mask_mapping = create_causal_mask_mapping(
        self.config, inputs_embeds, attention_mask, past_key_values,
        position_ids, mm_token_type_ids, pixel_values,
        is_training=self.training,
    )
create_causal_mask_mapping reads mm_token_type_ids, groups contiguous vision tokens, and adds an or_mask_function that gives them bidirectional visibility within each group.
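For intuition, here is a minimal sketch of such a grouped bidirectional or-mask, written in the (batch_idx, head_idx, q_idx, kv_idx) convention that masking_utils uses for or_mask_function. This is only an illustration of the behavior described above, not HF's actual implementation, and the function name is made up:

```python
import torch

def make_vision_bidirectional_or_mask(mm_token_type_ids: torch.Tensor):
    """Illustrative or-mask: two positions may attend to each other (in both
    directions) iff they sit in the same contiguous run of vision tokens."""
    is_vision = mm_token_type_ids == 1                   # (batch, seq_len), True for vision tokens
    prev = torch.zeros_like(is_vision)
    prev[:, 1:] = is_vision[:, :-1]
    run_starts = is_vision & ~prev                       # True at the first token of each vision run
    group_ids = torch.cumsum(run_starts.long(), dim=-1) - 1  # 0-based id of the enclosing vision run
    group_ids = group_ids.masked_fill(~is_vision, -1)        # text tokens belong to no run

    def or_mask(batch_idx, head_idx, q_idx, kv_idx):
        q_group = group_ids[batch_idx, q_idx]
        kv_group = group_ids[batch_idx, kv_idx]
        # OR-ed with the causal mask, so it only ever *adds* visibility
        return (q_group >= 0) & (q_group == kv_group)

    return or_mask
```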
Reproduction
Teacher-forcing forward pass on gemma-4-26B-A4B-it with a single multimodal prompt (341 tokens, 64 generation tokens), bf16, SDPA attention, FSDP2 with EP=8. Measuring gen_kl_error = KL(P_HF || P_Automodel) on generation tokens:
|            | gen_kl_error |
|------------|--------------|
| Before fix | 0.034023     |
| After fix  | 0.006520     |
The residual ~0.006 is consistent with FSDP mixed-precision numerical noise observed on the dense 31B model (which always goes through HF forward).
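For reference, the gen_kl_error numbers above can be computed along these lines (a sketch only; the function name, tensor shapes, and gen_start index are illustrative, not the exact evaluation harness used for the table):

```python
import torch.nn.functional as F

def gen_kl_error(hf_logits, automodel_logits, gen_start: int) -> float:
    """Mean KL(P_HF || P_Automodel) over the generation-token positions of a
    teacher-forcing forward pass; both logits are (1, seq_len, vocab)."""
    p_log = F.log_softmax(hf_logits[0, gen_start:].float(), dim=-1)
    q_log = F.log_softmax(automodel_logits[0, gen_start:].float(), dim=-1)
    kl_per_token = (p_log.exp() * (p_log - q_log)).sum(dim=-1)  # KL(P || Q) per position
    return kl_per_token.mean().item()
```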
Proposed fix
Reference commit on a fork: jQizhang@ce4b225
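The direction of the change, sketched below with the assumptions made explicit (mm_token_type_ids and pixel_values being plumbed into the backend forward, the import path of create_causal_mask_mapping, and self.config already being the text config); the referenced commit is authoritative:

```python
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
# import path assumed; HF's Gemma4Model.forward uses this helper in modeling_gemma4.py
from transformers.models.gemma4.modeling_gemma4 import create_causal_mask_mapping

if getattr(self.config, "use_bidirectional_attention", None) == "vision":
    # Mirror HF: let create_causal_mask_mapping add the or_mask_function that
    # makes tokens within the same vision group mutually visible.
    causal_mask_mapping = create_causal_mask_mapping(
        self.config, inputs_embeds, attention_mask, past_key_values,
        position_ids, mm_token_type_ids, pixel_values,
        is_training=self.training,
    )
else:
    mask_kwargs = {
        "config": self.config,
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "position_ids": position_ids,
    }
    causal_mask_mapping = {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    }
```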