Gemma4: PLE (Per-Layer Embeddings) implementation is underdocumented and config is misleading #45206
Description
I was implementing Gemma4 inference from scratch (in Rust) and the Per-Layer Embeddings (PLE) system was by far the hardest part to get right. The config fields are misleading, the embedding type is non-obvious, and the full pipeline involves several undocumented steps. Sharing this in case it helps others and in case you want to improve the docs.
Problem 1: hidden_size_per_layer_input is ambiguous
The config says hidden_size_per_layer_input: 256, which sounds like it's the embedding dimension. But embed_tokens_per_layer.weight has shape [262144, 8960] where 8960 = 35 layers * 256. The actual embedding dimension is num_hidden_layers * hidden_size_per_layer_input, not hidden_size_per_layer_input alone.
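A quick sketch of the shape relationship, using the values from this issue (the config field names follow the issue, not a verified schema):

```python
# Values reported in this issue for google/gemma-4-E2B-it
num_hidden_layers = 35
hidden_size_per_layer_input = 256  # per-layer dim, NOT the full embedding dim
vocab_size = 262144

# The checkpoint tensor embed_tokens_per_layer.weight is [vocab, total_dim],
# where total_dim is the per-layer dim times the number of layers:
total_dim = num_hidden_layers * hidden_size_per_layer_input
print(total_dim)                 # 8960
print((vocab_size, total_dim))   # (262144, 8960) -- matches the checkpoint shape
```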
This confused me because, reading the `__init__` of `Gemma4TextModel`, it looks like it creates `nn.Embedding(vocab, 256)`, in which case loading the pretrained weight of shape `[vocab, 8960]` would fail. (It doesn't fail because `from_pretrained` handles the resize, but that's not obvious from reading the code.)
Problem 2: embed_tokens_per_layer is secretly a Gemma4TextScaledWordEmbedding
The PLE embedding isn't a plain nn.Embedding. It's a Gemma4TextScaledWordEmbedding that multiplies the lookup result by sqrt(hidden_size_per_layer_input) = sqrt(256) = 16.0.
This isn't mentioned anywhere in the config, the docstrings, or the model card. I only found it by inspecting type(lm.embed_tokens_per_layer).__name__ after my outputs were 16x too small.
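A minimal sketch of the scaling behavior, assuming the class is a plain lookup whose result is multiplied by a constant scale (plain Python stands in for `nn.Embedding` here; `scaled_lookup` is a hypothetical name, not a transformers API):

```python
import math

ple_dim = 256
embed_scale = math.sqrt(ple_dim)  # sqrt(256) = 16.0

def scaled_lookup(weight_row):
    # weight_row: one row of embed_tokens_per_layer.weight.
    # The scaled-embedding class multiplies the raw table entry by embed_scale,
    # which is why outputs are 16x what a plain nn.Embedding would return.
    return [w * embed_scale for w in weight_row]

raw = [0.5, -1.0, 0.25]
print(scaled_lookup(raw))  # each value is 16x the raw table entry
```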
Problem 3: The full PLE pipeline has undocumented steps
The actual PLE computation involves:
- Token-identity: `embed_tokens_per_layer(input_ids)` (scaled by `sqrt(256)`) -> reshape to `[B, S, num_layers, ple_dim]`
- Context-aware projection: `per_layer_model_projection(inputs_embeds)` (a Linear) -> scale by `1/sqrt(hidden_size)` -> reshape to `[B, S, num_layers, ple_dim]` -> RMSNorm (`per_layer_projection_norm`)
- Combine: `(context_projection + token_identity) * (1/sqrt(2))`
- Each layer `i` gets `per_layer_inputs[:, :, i, :]`
This involves weights that aren't mentioned in the config at all:
- `per_layer_model_projection` (Linear, `hidden_size -> num_layers * ple_dim`)
- `per_layer_projection_norm` (RMSNorm, `dim=ple_dim`)
- Two hardcoded scale factors: `1/sqrt(hidden_size)` and `1/sqrt(2)`
The get_per_layer_inputs() and project_per_layer_inputs() methods implement this, but there are no docstrings explaining the overall pipeline or the scale factors.
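The steps above can be sketched end-to-end at toy size. This assumes exactly the pipeline listed in this issue; all weights here are random numpy stand-ins for the checkpoint tensors, and `rms_norm` is a stand-in for `per_layer_projection_norm`:

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, num_layers, ple_dim, hidden_size = 1, 3, 4, 8, 16
vocab = 32

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm over the last dim (stand-in for per_layer_projection_norm)
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight

# Random stand-ins for the checkpoint tensors
embed_tokens_per_layer = rng.normal(size=(vocab, num_layers * ple_dim))
proj_w = rng.normal(size=(hidden_size, num_layers * ple_dim))  # per_layer_model_projection
norm_w = np.ones(ple_dim)

input_ids = rng.integers(0, vocab, size=(B, S))
inputs_embeds = rng.normal(size=(B, S, hidden_size))

# 1. Token-identity path: lookup, scale by sqrt(ple_dim), reshape
tok = embed_tokens_per_layer[input_ids] * np.sqrt(ple_dim)
tok = tok.reshape(B, S, num_layers, ple_dim)

# 2. Context path: linear projection, scale by 1/sqrt(hidden_size), reshape, RMSNorm
ctx = (inputs_embeds @ proj_w) * (hidden_size ** -0.5)
ctx = rms_norm(ctx.reshape(B, S, num_layers, ple_dim), norm_w)

# 3. Combine with 1/sqrt(2); layer i then slices out its own input
per_layer_inputs = (ctx + tok) * (2.0 ** -0.5)
layer0_input = per_layer_inputs[:, :, 0, :]
print(per_layer_inputs.shape, layer0_input.shape)  # (1, 3, 4, 8) (1, 3, 8)
```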
Suggestion
It would help to add a docstring to `Gemma4TextModel` (or the config class) explaining:
- That `hidden_size_per_layer_input` is the per-layer dimension, and the total embedding dim is `num_hidden_layers * hidden_size_per_layer_input`
- That the PLE embedding is scaled by `sqrt(hidden_size_per_layer_input)`
- A brief description of the full PLE pipeline (token lookup + context projection + norm + combine with scale factors)
This would save a lot of pain for anyone implementing Gemma4 outside of HuggingFace transformers (e.g. llama.cpp, candle, mlx, etc.).
Environment
- transformers 5.5.0
- Model: google/gemma-4-E2B-it