
[Bug] Gemma4 MoE backend missing vision-aware attention mask for use_bidirectional_attention="vision" models #1891

@jQizhang

Description


Summary

For Gemma4 checkpoints whose text config has use_bidirectional_attention="vision" and enable_moe_block=True (e.g. gemma-4-26B-A4B-it), Gemma4MoETextModelBackend.forward() always builds plain causal and sliding-window masks, ignoring the use_bidirectional_attention flag. This diverges from HF's Gemma4Model.forward, which calls create_causal_mask_mapping so that tokens within the same vision group are bidirectionally visible to one another.

As a result, the MoE backend forward pass numerically diverges from HF on multimodal inputs.

Affected code

nemo_automodel/components/models/gemma4_moe/model.py, Gemma4MoETextModelBackend.forward:

# current behavior: always plain causal masks, mm_token_type_ids is never plumbed in
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

mask_kwargs = {
    "config": self.config,
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "cache_position": cache_position,
    "past_key_values": past_key_values,
    "position_ids": position_ids,
}
causal_mask_mapping = {
    "full_attention": create_causal_mask(**mask_kwargs),
    "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}

Compare to HF transformers/models/gemma4/modeling_gemma4.py (Gemma4Model.forward), which uses:

if self.config.get_text_config().use_bidirectional_attention == "vision":
    causal_mask_mapping = create_causal_mask_mapping(
        self.config, inputs_embeds, attention_mask, past_key_values,
        position_ids, mm_token_type_ids, pixel_values,
        is_training=self.training,
    )

create_causal_mask_mapping reads mm_token_type_ids, groups contiguous vision tokens, and adds an or_mask_function that gives them bidirectional visibility within each group.
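For intuition, here is a minimal, illustrative sketch (not the HF implementation) of that idea: derive a group id for each contiguous run of vision tokens from mm_token_type_ids, then expose an OR-style mask callback that lets two tokens see each other whenever they share a group. The function names and the (batch, head, q_idx, kv_idx) callback signature here are assumptions for illustration only.

import torch
import torch.nn.functional as F

def build_vision_group_ids(mm_token_type_ids: torch.Tensor) -> torch.Tensor:
    """Assign a distinct positive id to each contiguous run of vision tokens (0 = text)."""
    is_vision = mm_token_type_ids == 1
    # a new group starts wherever a vision token is not preceded by another vision token
    prev_is_vision = F.pad(is_vision.long(), (1, 0))[..., :-1].bool()
    starts = is_vision & ~prev_is_vision
    return torch.cumsum(starts.long(), dim=-1) * is_vision

def vision_bidirectional_or_mask(group_ids: torch.Tensor):
    """Mask callback that is True whenever query and key tokens share a vision group,
    regardless of their relative order; OR-ing it with the causal mask yields
    bidirectional visibility inside each image while text tokens stay causal."""
    def or_mask(batch_idx, head_idx, q_idx, kv_idx):
        q_group = group_ids[batch_idx, q_idx]
        kv_group = group_ids[batch_idx, kv_idx]
        return (q_group > 0) & (q_group == kv_group)
    return or_mask

# e.g. mm_token_type_ids = [[0, 1, 1, 1, 0, 0]] -> group_ids = [[0, 1, 1, 1, 0, 0]],
# so the three image tokens attend to each other in both directions.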

Reproduction

Teacher-forcing forward pass on gemma-4-26B-A4B-it with a single multimodal prompt (341 tokens, 64 generation tokens), bf16, SDPA attention, and FSDP2 with EP=8, measuring gen_kl_error = KL(P_HF || P_Automodel) on the generation tokens:

              gen_kl_error
Before fix    0.034023
After fix     0.006520

The residual ~0.006 is consistent with the FSDP mixed-precision numerical noise observed on the dense 31B model (which always goes through the HF forward path).
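For reference, the metric above is computed roughly as follows; hf_logits, am_logits, and gen_mask are placeholders for the teacher-forced outputs of the two models and the generation-token positions, not names from the repo.

import torch
import torch.nn.functional as F

@torch.no_grad()
def gen_kl_error(hf_logits: torch.Tensor, am_logits: torch.Tensor, gen_mask: torch.Tensor) -> float:
    """KL(P_HF || P_Automodel), averaged over the generation-token positions.

    hf_logits, am_logits: [batch, seq, vocab] logits from teacher-forced forward passes.
    gen_mask: [batch, seq] bool mask selecting the generation tokens.
    """
    log_p = F.log_softmax(hf_logits.float(), dim=-1)
    log_q = F.log_softmax(am_logits.float(), dim=-1)
    kl_per_token = (log_p.exp() * (log_p - log_q)).sum(dim=-1)  # [batch, seq]
    return kl_per_token[gen_mask].mean().item()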

Proposed fix

Reference commit on a fork: jQizhang@ce4b225
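The reference commit is the authoritative change; its shape is roughly the branch below inside Gemma4MoETextModelBackend.forward. It assumes mm_token_type_ids and pixel_values are plumbed through to the backend forward, and the create_causal_mask_mapping import path and argument order simply mirror the HF snippet quoted above.

from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
# import path for create_causal_mask_mapping is an assumption; HF defines it alongside Gemma4Model
from transformers.models.gemma4.modeling_gemma4 import create_causal_mask_mapping

mask_kwargs = {
    "config": self.config,
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "cache_position": cache_position,
    "past_key_values": past_key_values,
    "position_ids": position_ids,
}
if getattr(self.config, "use_bidirectional_attention", None) == "vision" and mm_token_type_ids is not None:
    # vision-aware path: tokens in the same vision group become bidirectionally visible
    causal_mask_mapping = create_causal_mask_mapping(
        self.config, inputs_embeds, attention_mask, past_key_values,
        position_ids, mm_token_type_ids, pixel_values,
        is_training=self.training,
    )
else:
    # text-only / fallback path: plain causal + sliding-window masks, as before
    causal_mask_mapping = {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    }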
