
[Gemma4] Add docstrings for Per-Layer Embeddings (PLE) pipeline#45207

Open
w4nderlust wants to merge 1 commit into huggingface:main from w4nderlust:gemma4-ple-docs

Conversation

@w4nderlust
Contributor

Fixes #45206

What does this PR do?

Adds documentation for the Gemma4 Per-Layer Embeddings (PLE) system, which is currently pretty hard to reverse-engineer from the code alone.

I ran into this while implementing Gemma4 inference from scratch in Rust. The PLE system has several non-obvious aspects that aren't documented anywhere:

  1. hidden_size_per_layer_input (256) is the per-layer dimension, but the actual embedding weight is [vocab, num_layers * 256] = [262144, 8960] because the embeddings for all layers are packed into a single tensor
  2. The embedding is a Gemma4TextScaledWordEmbedding that silently multiplies by sqrt(256) = 16 - this took me a while to track down
  3. The full pipeline has a context-aware projection step (per_layer_model_projection + scale + RMSNorm) whose output is combined with the token lookup before being passed to the layers, with specific scale factors (1/sqrt(hidden_size) and 1/sqrt(2))
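To make point 1 and 2 concrete, here is a minimal NumPy sketch of the packed lookup. The shapes and the sqrt(256) scaling follow the description above; the function name mirrors get_per_layer_inputs(), but the vocab size is scaled down and this is an illustration, not the transformers code.

```python
import numpy as np

# Illustrative sizes; the real weight is [262144, 8960] (8960 = 35 * 256).
vocab, num_layers, d_ple = 1000, 35, 256
rng = np.random.default_rng(0)
weight = rng.standard_normal((vocab, num_layers * d_ple)).astype(np.float32)

def get_per_layer_inputs(input_ids: np.ndarray) -> np.ndarray:
    """Look up packed PLE rows, apply the hidden scaling, unpack per layer."""
    flat = weight[input_ids]              # [batch, seq, num_layers * d_ple]
    flat = flat * np.sqrt(d_ple)          # ScaledWordEmbedding: * sqrt(256) = 16
    return flat.reshape(*input_ids.shape, num_layers, d_ple)

ids = np.array([[4, 7, 9]])
ple = get_per_layer_inputs(ids)
print(ple.shape)  # (1, 3, 35, 256)
```

The key trap is that a naive reader of the config expects an embedding of width 256, not num_layers * 256.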

This PR adds:

  • Expanded config docstring for hidden_size_per_layer_input explaining the packed layout, scaling, and full pipeline
  • Docstrings for get_per_layer_inputs() and project_per_layer_inputs()
  • A comment on the PLE init block pointing to the pipeline methods

Hopefully this saves some pain for others implementing Gemma4 outside of transformers.
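For the context-aware step (point 3 above), the sketch below shows one plausible ordering of the projection, the two scale factors, and the RMSNorm. The exact order of operations inside transformers may differ; hidden_size here is made up, and only the ingredients (per_layer_model_projection, 1/sqrt(hidden_size), RMSNorm, 1/sqrt(2)) come from the PR description.

```python
import numpy as np

hidden_size, num_layers, d_ple = 640, 35, 256   # hidden_size is illustrative
rng = np.random.default_rng(1)
per_layer_projection = rng.standard_normal(
    (hidden_size, num_layers * d_ple)).astype(np.float32) * 0.02

def rms_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def project_per_layer_inputs(hidden_states, per_layer_inputs):
    # Context-aware component: project hidden states, scale by 1/sqrt(hidden_size)
    ctx = (hidden_states @ per_layer_projection) * hidden_size ** -0.5
    ctx = ctx.reshape(*hidden_states.shape[:-1], num_layers, d_ple)
    ctx = rms_norm(ctx)
    # Combine with the token-identity lookup; 1/sqrt(2) keeps the sum's variance stable
    return (ctx + per_layer_inputs) * 2 ** -0.5

h = rng.standard_normal((1, 3, hidden_size)).astype(np.float32)
ple = rng.standard_normal((1, 3, num_layers, d_ple)).astype(np.float32)
out = project_per_layer_inputs(h, ple)
print(out.shape)  # (1, 3, 35, 256)
```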

The PLE system is complex and underdocumented, which makes it hard
for third-party implementations (llama.cpp, candle, mlx, etc.) to
get right. This adds:

- Config docstring for hidden_size_per_layer_input explaining that
  the actual embedding dim is num_hidden_layers * hidden_size_per_layer_input,
  the embedding is scaled by sqrt(hidden_size_per_layer_input), and
  describing the full two-component pipeline

- Docstring for get_per_layer_inputs() explaining the token-identity
  component and the packed-to-4D reshape

- Docstring for project_per_layer_inputs() explaining the context-aware
  projection, normalization, and combination with scale factors

- Comment on the PLE init block pointing to the pipeline methods

Fixes huggingface#45206
@github-actions
Contributor

github-actions bot commented Apr 3, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4



Development

Successfully merging this pull request may close these issues.

Gemma4: PLE (Per-Layer Embeddings) implementation is underdocumented and config is misleading
