[Gemma4] Gemma4VisionPatchEmbedder._position_embeddings materializes a ~19 GiB one-hot tensor that's mathematically a 2-row embedding lookup #46175

@kuso2006

System Info

  • transformers version: 5.6.2
  • Platform: Linux 5.15.0-41-generic (x86_64)
  • Python: 3.12
  • PyTorch: 2.10.0+cu128
  • GPU: 8 × NVIDIA H800 (80GB)
  • Model: google/gemma-4-E4B-it (any Gemma 4 multimodal variant)
  • Trainer: ZeRO-3 (DeepSpeed) full-parameter SFT, torch_dtype: bfloat16

Who can help?

@yonigozlan @molbap (vision models) @zucchini-nlp (multimodal models)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The current implementation of Gemma4VisionPatchEmbedder._position_embeddings in
src/transformers/models/gemma4/modeling_gemma4.py (line 561 in 5.6.2) computes
2D patch position embeddings via one-hot encoding followed by a batched matmul:

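def _position_embeddings(self, pixel_position_ids, padding_positions):
    """Prepare patch positions map for matmul with position embedding table."""
    # pixel_position_ids: [batch, num_patches, 2], one (x, y) id pair per patch
    clamped_positions = pixel_position_ids.clamp(min=0)
    # int64 one-hot: [batch, num_patches, 2, position_embedding_size]
    one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
    # permute to [batch, 2, num_patches, position_embedding_size]; .to() makes a full copy in the table dtype
    one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
    # table: [2, position_embedding_size, hidden] -> batched matmul gives [batch, 2, num_patches, hidden]
    position_embeddings = one_hot @ self.position_embedding_table
    # add the x and y contributions -> [batch, num_patches, hidden]
    position_embeddings = position_embeddings.sum(dim=1)
    position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
    return position_embeddings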

For Gemma-4-E4B-it (position_embedding_size=10240) with batch_size=4, 10 images per
sample and 2520 patches per image (the default max_soft_tokens=280 times pooling_kernel_size^2=9), this materializes:

  • one_hot: [40, 2520, 2, 10240] int64 → 15.38 GiB
  • one_hot.to(bf16) cast copy: same element count (permuted to [40, 2, 2520, 10240]), bf16 → 3.85 GiB
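These sizes follow directly from the shapes: element count times bytes per element, with 8 bytes for the int64 tensor returned by F.one_hot and 2 bytes for the bf16 copy. A plain-Python back-of-the-envelope check using the configuration values quoted above:

```python
GIB = 1024 ** 3


def tensor_gib(shape, bytes_per_elem):
    """Size in GiB of a dense tensor with the given shape and element size."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_elem / GIB


batch_size, images_per_sample = 4, 10
max_soft_tokens, pooling_kernel_size = 280, 3
patches_per_image = max_soft_tokens * pooling_kernel_size ** 2  # 2520
position_embedding_size = 10240

# F.one_hot output before the permute: [images, patches, 2, position_embedding_size]
shape = (batch_size * images_per_sample, patches_per_image, 2, position_embedding_size)
one_hot_gib = tensor_gib(shape, 8)  # int64
cast_gib = tensor_gib(shape, 2)     # bf16 copy, same element count

print(f"{one_hot_gib:.2f} GiB + {cast_gib:.2f} GiB = {one_hot_gib + cast_gib:.2f} GiB")
# 15.38 GiB + 3.85 GiB = 19.23 GiB
```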

Total ~19 GiB just for two embedding lookups. Peak-memory breakdown captured with
torch.cuda.memory._record_memory_history:

[1] alloc peak 36.00 GiB 
      19.57 GiB  (54.4%)  x10    vision tower (vision_model)
      14.09 GiB  (39.1%)        ZeRO-3 
       ...

Drilling into the 10 vision_model allocations:

one_hot @ line 581:   15.38 GiB (single int64 tensor, [40, 2520, 2, 10240])
one_hot.to bf16 @ 582: 3.85 GiB (bf16 cast of the same shape)

Stack tops:

_position_embeddings @ modeling_gemma4.py:581
forward @ modeling_gemma4.py:596

The one_hot @ table pattern is mathematically a 2-row embedding lookup:
one row of position_embedding_table[0] indexed by clamped[..., 0] plus one
row of position_embedding_table[1] indexed by clamped[..., 1]. It can be
replaced with F.embedding without materializing any of the
[..., position_embedding_size] intermediates.
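Why the matmul collapses to a lookup: each one-hot row holds a single 1.0, so multiplying it by a table selects exactly one table row, and every other product contributes zero. A toy-sized sketch for a single patch, assuming the [2, position_embedding_size, hidden] table layout implied by the code above (sizes, ids and the random table are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
size, hidden = 16, 4                  # toy sizes, not Gemma 4's
table = torch.randn(2, size, hidden)  # assumed layout: [2, position_embedding_size, hidden]
x, y = 3, 11                          # one patch's (x, y) position ids

ids = torch.tensor([x, y])
one_hot = F.one_hot(ids, num_classes=size).to(table)  # [2, size], a single 1.0 per row

# Per-axis matmul ([2, 1, size] @ [2, size, hidden] -> [2, 1, hidden]), then sum over the axis dim
via_matmul = (one_hot.unsqueeze(1) @ table).sum(dim=0).squeeze(0)  # [hidden]

# The same value as two direct row lookups
via_lookup = table[0, x] + table[1, y]

print(torch.equal(via_matmul, via_lookup))  # True
```

Because only one product per output element is nonzero, the matmul reproduces the table row exactly, which is why the replacement can be bit-exact outside autocast.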

Expected behavior

_position_embeddings should produce the same result without the
intermediate one-hot tensor. The following replacement is numerically
equivalent (bit-exact outside autocast; matches the original under autocast
to within bf16-matmul roundoff ~1.5e-2 absolute) and eliminates both the
15.38 GiB int64 tensor and its 3.85 GiB bf16 cast:

def _position_embeddings(self, pixel_position_ids, padding_positions):
    """Prepare patch positions map via direct embedding lookup.

    Mathematically equivalent to the original one_hot @ table + sum,
    but avoids the [batch, num_patches, 2, position_embedding_size]
    one-hot temporary (~15 GiB int64 + ~3.85 GiB bf16 cast at bs=4,
    n_img=10, position_embedding_size=10240).
    """
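    # The original .clamp(min=0) only bounds the lower side; F.one_hot()
    # rejected ids >= position_embedding_size. F.embedding would also fail
    # on them (index error, or a device-side assert on CUDA), so clamp both
    # sides. This maps out-of-range ids to the last row instead of raising.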
    clamped = pixel_position_ids.clamp(0, self.position_embedding_size - 1)
    x_emb = F.embedding(clamped[..., 0], self.position_embedding_table[0])
    y_emb = F.embedding(clamped[..., 1], self.position_embedding_table[1])
    position_embeddings = x_emb + y_emb
    # Under CUDA autocast, the original's trailing .sum(dim=1) returns fp32
    # (sum is on autocast's float32 op list); plain `+` stays bf16. Cast
    # explicitly, only inside autocast, to keep the downstream
    # `hidden_states + pos` dtype contract.
    if torch.is_autocast_enabled() and position_embeddings.dtype != torch.float32:
        position_embeddings = position_embeddings.to(torch.float32)
    position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
    return position_embeddings
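A CPU sanity check of the two formulations on toy sizes, including padding. It assumes padding patches carry negative ids (the original clamps them to 0 and then zeroes them via padding_positions). Both versions are rewritten as hypothetical free functions over the table, so the check runs without loading the model; the autocast branch is left out because the check runs without autocast:

```python
import torch
import torch.nn.functional as F


def one_hot_position_embeddings(table, pixel_position_ids, padding_positions):
    # Hypothetical free-function form of the current one_hot @ table + sum code.
    # table: [2, position_embedding_size, hidden]
    clamped = pixel_position_ids.clamp(min=0)
    one_hot = F.one_hot(clamped, num_classes=table.shape[1])
    one_hot = one_hot.permute(0, 2, 1, 3).to(table)
    pos = (one_hot @ table).sum(dim=1)
    return torch.where(padding_positions.unsqueeze(-1), 0.0, pos)


def lookup_position_embeddings(table, pixel_position_ids, padding_positions):
    # Hypothetical free-function form of the proposed F.embedding code (no autocast branch).
    clamped = pixel_position_ids.clamp(0, table.shape[1] - 1)
    pos = F.embedding(clamped[..., 0], table[0]) + F.embedding(clamped[..., 1], table[1])
    return torch.where(padding_positions.unsqueeze(-1), 0.0, pos)


torch.manual_seed(0)
batch, num_patches, size, hidden = 3, 8, 64, 16  # toy sizes
table = torch.randn(2, size, hidden)
ids = torch.randint(0, size, (batch, num_patches, 2))
padding = torch.zeros(batch, num_patches, dtype=torch.bool)
padding[:, -3:] = True  # last three patches of each image are padding
ids[padding] = -1       # assumption: padding patches carry negative ids

expected = one_hot_position_embeddings(table, ids, padding)
actual = lookup_position_embeddings(table, ids, padding)
print(torch.equal(expected, actual))  # True
```

The lookup version never allocates anything with a trailing position_embedding_size dimension; its largest temporaries are the two [batch, num_patches, hidden] row gathers.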
