System Info
- transformers version: 5.6.2
- Platform: Linux 5.15.0-41-generic (x86_64)
- Python: 3.12
- PyTorch: 2.10.0+cu128
- GPU: 8 × NVIDIA H800 (80GB)
- Model: google/gemma-4-E4B-it (any Gemma 4 multimodal variant)
- Trainer: ZeRO-3 (DeepSpeed) full-parameter SFT, torch_dtype: bfloat16
Who can help?
@yonigozlan @molbap (vision models) @zucchini-nlp (multimodal models)
Reproduction
The current implementation of Gemma4VisionPatchEmbedder._position_embeddings in
src/transformers/models/gemma4/modeling_gemma4.py (line 561 in 5.6.2) computes
2D patch position embeddings via one-hot encoding followed by a batched matmul:
```python
import torch
import torch.nn.functional as F


# Gemma4VisionPatchEmbedder method (transformers 5.6.2)
def _position_embeddings(self, pixel_position_ids, padding_positions):
    """Prepare patch positions map for matmul with positon embedding table."""
    clamped_positions = pixel_position_ids.clamp(min=0)
    # [B, N, 2] int64 -> [B, N, 2, position_embedding_size] int64
    one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
    # permute to [B, 2, N, S] and cast to the table's dtype/device (a second full copy)
    one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
    position_embeddings = one_hot @ self.position_embedding_table
    position_embeddings = position_embeddings.sum(dim=1)
    position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
    return position_embeddings
```
For Gemma-4-E4B-it (position_embedding_size=10240), with batch_size=4, 10 images per sample, and 2520 patches per image (the default max_soft_tokens=280 × pooling_kernel_size^2=9), this materializes:
- one_hot: [40, 2520, 2, 10240] int64 → 15.38 GiB
- one_hot.to(bf16) cast copy: same shape, bf16 → 3.85 GiB
That is ~19 GiB of temporaries for what amounts to two embedding lookups. Peak-memory breakdown captured via torch.cuda.memory._record_memory_history:
```text
[1] alloc peak 36.00 GiB
19.57 GiB (54.4%) x10 vision tower (vision_model)
14.09 GiB (39.1%) ZeRO-3
...
```
Drilling into the 10 vision_model allocations:
```text
one_hot @ line 581: 15.38 GiB (single int64 tensor, [40, 2520, 2, 10240])
one_hot.to bf16 @ 582: 3.85 GiB (bf16 cast of the same shape)
```
Stack tops:
```text
_position_embeddings @ modeling_gemma4.py:581
forward @ modeling_gemma4.py:596
```
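The 15.38 GiB and 3.85 GiB figures follow directly from the tensor shapes. A quick sanity check of that arithmetic, using only the sizes quoted in this report (nothing is allocated; the variable names are illustrative):

```python
# Back-of-the-envelope size of the one-hot temporaries in _position_embeddings.
from math import prod

GIB = 1024 ** 3

batch_size = 4
images_per_sample = 10
patches_per_image = 280 * 9          # max_soft_tokens * pooling_kernel_size**2
position_embedding_size = 10240

one_hot_shape = (batch_size * images_per_sample, patches_per_image, 2, position_embedding_size)
num_elements = prod(one_hot_shape)

int64_gib = num_elements * 8 / GIB   # F.one_hot returns int64 (8 bytes/element)
bf16_gib = num_elements * 2 / GIB    # the .to(bf16) copy has the same element count

print(one_hot_shape)                             # (40, 2520, 2, 10240)
print(f"{int64_gib:.2f} GiB int64")              # 15.38 GiB int64
print(f"{bf16_gib:.2f} GiB bf16")                # 3.85 GiB bf16
print(f"{int64_gib + bf16_gib:.2f} GiB total")   # 19.23 GiB total
```

The cost scales linearly with position_embedding_size, which is why a 10240-entry table turns a cheap lookup into a multi-GiB temporary.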
The one_hot @ table pattern is mathematically just two embedding lookups: row clamped[..., 0] of position_embedding_table[0] plus row clamped[..., 1] of position_embedding_table[1]. It can therefore be replaced with F.embedding without materializing any [..., position_embedding_size] intermediate.
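Written out per patch, assuming (as the permute and the table[0]/table[1] indexing imply) that position_embedding_table is T of shape [2, S, D] with S = position_embedding_size and D the hidden size, and writing c_{p,k} for clamped[..., k] at patch p and e_c for the one-hot row with a 1 in column c:

```latex
\mathrm{pos}_p
  = \sum_{k=0}^{1} \mathbf{e}_{c_{p,k}}^{\top} T_k
  = \sum_{k=0}^{1} T_k[c_{p,k}, :]
  = T_0[c_{p,0}, :] + T_1[c_{p,1}, :]
```

Each one-hot row selects exactly one row of T_k, so the matmul plus sum(dim=1) collapses to two row gathers; padding is handled separately by the torch.where in both versions.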
Expected behavior
_position_embeddings should produce the same result without materializing the intermediate one-hot tensor. The following replacement is numerically equivalent (bit-exact outside autocast; under autocast it matches the original to within bf16 matmul round-off, ~1.5e-2 absolute) and eliminates both the 15.38 GiB int64 tensor and its 3.85 GiB bf16 cast:
```python
import torch
import torch.nn.functional as F


# Proposed replacement for the Gemma4VisionPatchEmbedder method
def _position_embeddings(self, pixel_position_ids, padding_positions):
    """Prepare patch positions map via direct embedding lookup.

    Mathematically equivalent to the original one_hot @ table + sum,
    but avoids the [batch, num_patches, 2, position_embedding_size]
    one-hot temporary (~15 GiB int64 + ~3.85 GiB bf16 cast at bs=4,
    n_img=10, position_embedding_size=10240).
    """
    # Original .clamp(min=0) only bounds the lower side; one_hot() did
    # an implicit upper-bound check. F.embedding would read OOB on
    # overflow, so clamp both sides.
    clamped = pixel_position_ids.clamp(0, self.position_embedding_size - 1)
    x_emb = F.embedding(clamped[..., 0], self.position_embedding_table[0])
    y_emb = F.embedding(clamped[..., 1], self.position_embedding_table[1])
    position_embeddings = x_emb + y_emb
    # The original's trailing .sum(dim=1) promotes bf16 -> fp32 under
    # autocast (reduction-dtype rule); plain `+` doesn't, so cast
    # explicitly only inside autocast to keep the downstream
    # `hidden_states + pos` dtype contract.
    if torch.is_autocast_enabled() and position_embeddings.dtype != torch.float32:
        position_embeddings = position_embeddings.to(torch.float32)
    position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
    return position_embeddings
```
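A minimal CPU check of the equivalence claim, written as two hypothetical free functions (pos_emb_one_hot, pos_emb_lookup) rather than the Gemma4VisionPatchEmbedder method. It assumes position_embedding_table has shape [2, position_embedding_size, hidden] and that padded patches carry position id -1 (which is what the original clamp(min=0) suggests). It omits the proposal's autocast fp32 cast, so it only exercises the non-autocast path, where the report claims bit-exactness:

```python
import torch
import torch.nn.functional as F


def pos_emb_one_hot(table, pixel_position_ids, padding_positions):
    """Original formulation: one-hot over positions, matmul, sum the x/y terms."""
    clamped = pixel_position_ids.clamp(min=0)
    one_hot = F.one_hot(clamped, num_classes=table.shape[1])  # [B, N, 2, S] int64
    one_hot = one_hot.permute(0, 2, 1, 3).to(table)           # [B, 2, N, S] float
    pos = (one_hot @ table).sum(dim=1)                        # [B, 2, N, D] -> [B, N, D]
    return torch.where(padding_positions.unsqueeze(-1), 0.0, pos)


def pos_emb_lookup(table, pixel_position_ids, padding_positions):
    """Proposed formulation: two direct row lookups, no one-hot temporary."""
    clamped = pixel_position_ids.clamp(0, table.shape[1] - 1)
    pos = F.embedding(clamped[..., 0], table[0]) + F.embedding(clamped[..., 1], table[1])
    return torch.where(padding_positions.unsqueeze(-1), 0.0, pos)


torch.manual_seed(0)
B, N, S, D = 2, 6, 16, 4             # toy sizes: batch, patches, position_embedding_size, hidden
table = torch.randn(2, S, D)         # assumed layout: [2 (x/y), S, D], fp32
ids = torch.randint(0, S, (B, N, 2))
padding = torch.zeros(B, N, dtype=torch.bool)
padding[:, -2:] = True               # last two patches of each image are padding
ids[padding] = -1                    # assumed padding convention

ref = pos_emb_one_hot(table, ids, padding)
new = pos_emb_lookup(table, ids, padding)
print(torch.equal(ref, new))         # True (fp32, outside autocast)
```

In fp32 every one-hot row has a single 1.0, so each matmul output is exactly one table row plus exact zeros, which is why exact equality (rather than a tolerance) is the right check here.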