
Gemma4: PLE (Per-Layer Embeddings) implementation is underdocumented and config is misleading #45206

@w4nderlust

Description


I was implementing Gemma4 inference from scratch (in Rust) and the Per-Layer Embeddings (PLE) system was by far the hardest part to get right. The config fields are misleading, the embedding type is non-obvious, and the full pipeline involves several undocumented steps. Sharing this in case it helps others and in case you want to improve the docs.

Problem 1: hidden_size_per_layer_input is ambiguous

The config says hidden_size_per_layer_input: 256, which sounds like it's the embedding dimension. But embed_tokens_per_layer.weight has shape [262144, 8960] where 8960 = 35 layers * 256. The actual embedding dimension is num_hidden_layers * hidden_size_per_layer_input, not hidden_size_per_layer_input alone.
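The arithmetic is easy to verify (values are the config fields and checkpoint shapes reported above; `full_ple_dim` is my own name for the product, not a config field):

```python
# Config values as reported above for google/gemma-4-E2B-it.
num_hidden_layers = 35
hidden_size_per_layer_input = 256  # the per-layer slice, NOT the table width
vocab_size = 262144

# The second dim of embed_tokens_per_layer.weight is the product of the two:
full_ple_dim = num_hidden_layers * hidden_size_per_layer_input
assert full_ple_dim == 8960  # matches the checkpoint shape [262144, 8960]
```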

This confused me because, reading __init__ in Gemma4TextModel, it looks like it creates nn.Embedding(vocab, 256), in which case loading the pretrained weight of shape [vocab, 8960] should fail. (It doesn't fail, because from_pretrained handles the resize, but that's not obvious from reading the code.)

Problem 2: embed_tokens_per_layer is secretly a Gemma4TextScaledWordEmbedding

The PLE embedding isn't a plain nn.Embedding. It's a Gemma4TextScaledWordEmbedding that multiplies the lookup result by sqrt(hidden_size_per_layer_input) = sqrt(256) = 16.0.

This isn't mentioned anywhere in the config, the docstrings, or the model card. I only found it by inspecting type(lm.embed_tokens_per_layer).__name__ after my outputs were 16x too small.
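For anyone reimplementing this, here is a minimal sketch of the behavior as I observed it (the class name and wiring below are my reconstruction, not the actual transformers source):

```python
import math
import torch
import torch.nn as nn

class ScaledWordEmbedding(nn.Embedding):
    """Sketch of what Gemma4TextScaledWordEmbedding appears to do:
    a plain embedding lookup multiplied by sqrt(embedding_dim)."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = math.sqrt(embedding_dim)  # sqrt(256) = 16.0 for PLE

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale

emb = ScaledWordEmbedding(32, 256)
out = emb(torch.tensor([[1, 2]]))
# out is the raw table lookup times 16.0 -- the factor my port was missing.
```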

Problem 3: The full PLE pipeline has undocumented steps

The actual PLE computation involves:

  1. Token-identity: embed_tokens_per_layer(input_ids) (scaled by sqrt(256)) -> reshape to [B, S, num_layers, ple_dim]
  2. Context-aware projection: per_layer_model_projection(inputs_embeds) (a Linear) -> scale by 1/sqrt(hidden_size) -> reshape to [B, S, num_layers, ple_dim] -> RMSNorm (per_layer_projection_norm)
  3. Combine: (context_projection + token_identity) * (1/sqrt(2))
  4. Each layer i gets per_layer_inputs[:, :, i, :]
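The four steps above can be sketched end to end like this (sizes other than num_layers and ple_dim are illustrative, and I use a bare rms_norm function as a stand-in for the model's per_layer_projection_norm; treat this as my reading of the code, not a verified reference):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
B, S = 2, 3
vocab, hidden_size = 100, 64        # illustrative sizes, not the real config
num_layers, ple_dim = 35, 256

embed_tokens_per_layer = nn.Embedding(vocab, num_layers * ple_dim)
per_layer_model_projection = nn.Linear(hidden_size, num_layers * ple_dim, bias=False)

def rms_norm(x, eps=1e-6):
    # stand-in for per_layer_projection_norm (RMSNorm over the last dim)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

input_ids = torch.randint(0, vocab, (B, S))
inputs_embeds = torch.randn(B, S, hidden_size)

# 1. token-identity path: scaled lookup, reshaped per layer
tok = embed_tokens_per_layer(input_ids) * math.sqrt(ple_dim)
tok = tok.view(B, S, num_layers, ple_dim)

# 2. context-aware path: Linear -> scale by 1/sqrt(hidden_size) -> reshape -> norm
ctx = per_layer_model_projection(inputs_embeds) * hidden_size ** -0.5
ctx = rms_norm(ctx.view(B, S, num_layers, ple_dim))

# 3. combine the two paths with the 1/sqrt(2) factor
per_layer_inputs = (ctx + tok) * 2.0 ** -0.5

# 4. layer i consumes its slice
layer_0_input = per_layer_inputs[:, :, 0, :]
assert per_layer_inputs.shape == (B, S, num_layers, ple_dim)
```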

This involves weights that aren't mentioned in the config at all:

  • per_layer_model_projection (Linear, hidden_size -> num_layers * ple_dim)
  • per_layer_projection_norm (RMSNorm, dim=ple_dim)
  • Two hardcoded scale factors: 1/sqrt(hidden_size) and 1/sqrt(2)

The get_per_layer_inputs() and project_per_layer_inputs() methods implement this, but there are no docstrings explaining the overall pipeline or the scale factors.

Suggestion

I'd suggest adding a docstring to Gemma4TextModel (or the config class) explaining:

  1. That hidden_size_per_layer_input is the per-layer dimension, and the total embedding dim is num_hidden_layers * hidden_size_per_layer_input
  2. That the PLE embedding is scaled by sqrt(hidden_size_per_layer_input)
  3. A brief description of the full PLE pipeline (token lookup + context projection + norm + combine with scale factors)
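Concretely, something along these lines (the wording is mine, just as a starting point):

```python
# Draft of the kind of docstring I have in mind for Gemma4TextModel:
PLE_DOCSTRING = """Per-Layer Embeddings (PLE).

`hidden_size_per_layer_input` is the dimension each layer receives, not the
embedding table width: `embed_tokens_per_layer` has shape
[vocab_size, num_hidden_layers * hidden_size_per_layer_input], and its lookup
is scaled by sqrt(hidden_size_per_layer_input).

Pipeline: scaled token lookup -> reshape to [B, S, num_layers, ple_dim];
context projection via `per_layer_model_projection` -> scale by
1/sqrt(hidden_size) -> reshape -> `per_layer_projection_norm`; the two paths
are summed and multiplied by 1/sqrt(2); layer i takes slice [:, :, i, :].
"""
```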

This would save a lot of pain for anyone implementing Gemma4 outside of HuggingFace transformers (e.g. llama.cpp, candle, mlx, etc.).

Environment

  • transformers 5.5.0
  • Model: google/gemma-4-E2B-it
