Enable Gemma 4 E2B / E4B inference via vLLM RPA #4053

Draft
gagika wants to merge 1 commit into main from agagik-gemma4e-vllm

gagika (Collaborator) commented Jun 3, 2026

Description

Adds KV-shared layer wiring, adds a system_prompt flag to vllm_decode (required by the E-family -it checkpoints), and documents a verified inference recipe in Run_Gemma4.md.

Key changes

  • gemma4_small.py — decoder layer returns the kernel-updated kv_cache (was dropped).
  • decoders.py — KV-shared layers redirect to the donor's kv_caches slot via a layer→slot map; cache is written back per layer.
  • attentions.py — sliding-window only on LOCAL_SLIDING layers; KV-shared layers no longer overwrite the donor's cache (update_kv_cache=False).
  • vllm_decode.py / types.py — new system_prompt config knob, prepended as a system message when use_chat_template=True.
  • Run_Gemma4.md — E2B / E4B recipe: system prompt, model-card sampling, full eos_token_id [1, 106, 50] stop-token set.
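
To illustrate why the recipe lists the full stop-token set, here is a minimal sketch of cutting a generated sequence at the first stop id. The ids [1, 106, 50] are taken from the recipe above; the helper name `truncate_at_stop` and the sample token ids are hypothetical, and this is not the stopping logic vLLM itself runs.

```python
from typing import Iterable, Sequence

# Stop-token ids quoted in the recipe above (eos_token_id for E2B / E4B).
STOP_TOKEN_IDS = frozenset({1, 106, 50})


def truncate_at_stop(
    token_ids: Sequence[int], stop_ids: Iterable[int] = STOP_TOKEN_IDS
) -> list[int]:
  """Return the tokens that precede the first stop id (hypothetical helper)."""
  stop_set = set(stop_ids)
  kept = []
  for tok in token_ids:
    if tok in stop_set:
      break  # Generation ends at the first id found in the stop set.
    kept.append(tok)
  return kept
```

With only id 1 in the set, a sequence such as `[5, 106, 7, 1]` would not be cut at 106, which is one way output can run past the intended end of turn.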

Tests

On v5p-8 with E2B and E4B (vLLM, TP=1): top-1 logits under greedy decoding were cross-checked against the native checkpoint and are bit-identical. With the documented recipe, both models generate coherent output and stop cleanly. Both the CLI and the Python API were verified.
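
A greedy cross-check of this kind can be sketched as below. This is an illustrative assumption about the shape of such a check, not the script used for the PR: the function names are hypothetical, logits are plain nested lists of shape [positions][vocab], and only the argmax token ids are compared, which is a weaker condition than bit-identical logit values.

```python
from typing import Sequence


def top1_ids(logits: Sequence[Sequence[float]]) -> list[int]:
  """Greedy (argmax) token id at each position."""
  return [max(range(len(row)), key=row.__getitem__) for row in logits]


def top1_mismatches(
    a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]
) -> list[int]:
  """Positions where the two runs disagree on the greedy token."""
  ids_a, ids_b = top1_ids(a), top1_ids(b)
  if len(ids_a) != len(ids_b):
    raise ValueError("logit sequences differ in length")
  return [i for i, (x, y) in enumerate(zip(ids_a, ids_b)) if x != y]
```

An empty result means the two implementations would emit the same greedy token at every compared position.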

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

github-actions Bot commented Jun 3, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

## 📋 Review Summary

This Pull Request successfully enables inference for Gemma 4 E2B/E4B models by implementing cross-layer KV sharing within the vLLM RPA (Ragged Paged Attention) path. The implementation is technically sound, handling the complex layer-to-slot mapping required for shared KV caches while maintaining compatibility with existing inference workflows.

🔍 General Feedback

  • Robust KV-Sharing Implementation: The mapping logic in decoders.py correctly handles the redirection of shared layers to donor slots, ensuring efficient memory usage during TPU inference.
  • Improved Attention Logic: The fix in attentions.py to restrict sliding window attention to LOCAL_SLIDING layers is a necessary correction for hybrid attention models like Gemma 4.
  • Clear Documentation: The added recipes in Run_Gemma4.md provide essential guidance on system prompts and sampling parameters required for coherent output from these smaller checkpoints.

  donor_idx = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared)
  if donor_idx is not None:
    cache_index_of[lyr] = cache_index_of[donor_idx]
  else:

🟡 Rebuilding the `cache_index_of` dictionary on every call to `_apply_gemma4_small_layers` is redundant since the layer types and sharing pattern are static for a given model configuration. While the overhead is likely negligible for tracing, precomputing this mapping during initialization would be cleaner.
Suggested change
else:
# tpu-inference allocates one `kv_caches` slot per non-shared layer;
# KV-shared layers reuse the donor's slot. Map decoder layer index -> slot.
cache_index_of: dict[int, int] = {}
next_slot = 0
for lyr in range(cfg.num_decoder_layers):
  donor_idx = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared)
  if donor_idx is not None:
    cache_index_of[lyr] = cache_index_of[donor_idx]
  else:
    cache_index_of[lyr] = next_slot
    next_slot += 1
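
One way to act on the precomputation suggestion is to move the loop into a pure function that is called once per model configuration. The sketch below is self-contained so it can run on its own: `kv_donor_layer_idx` here is a hypothetical stand-in, not the real `gemma4_small.kv_donor_layer_idx`, and it assumes that the last `num_kv_shared` layers reuse the last non-shared layer of the same attention type. `build_cache_index` and the string layer types are likewise illustrative.

```python
from typing import Optional, Sequence


def kv_donor_layer_idx(
    lyr: int, layer_types: Sequence[str], num_kv_shared: int
) -> Optional[int]:
  """Hypothetical donor rule, NOT the real gemma4_small implementation.

  Assumption: the last `num_kv_shared` layers are KV-shared, and each one
  reuses the last non-shared layer that has the same attention type.
  """
  first_shared = len(layer_types) - num_kv_shared
  if lyr < first_shared:
    return None  # Non-shared layer: it owns its own cache slot.
  for cand in range(first_shared - 1, -1, -1):
    if layer_types[cand] == layer_types[lyr]:
      return cand
  raise ValueError(f"no donor of type {layer_types[lyr]!r} for layer {lyr}")


def build_cache_index(
    layer_types: Sequence[str], num_kv_shared: int
) -> dict[int, int]:
  """Map decoder layer index -> kv_caches slot, computed once per config."""
  cache_index_of: dict[int, int] = {}
  next_slot = 0
  for lyr in range(len(layer_types)):
    donor_idx = kv_donor_layer_idx(lyr, layer_types, num_kv_shared)
    if donor_idx is not None:
      # Shared layer: point at the slot the donor already owns.
      cache_index_of[lyr] = cache_index_of[donor_idx]
    else:
      cache_index_of[lyr] = next_slot
      next_slot += 1
  return cache_index_of
```

Because the result depends only on static configuration, it could be stored as an attribute at initialization and read inside the layer loop, rather than rebuilt on every call.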

if config.system_prompt:
  messages.append({"role": "system", "content": config.system_prompt})
messages.append({"role": "user", "content": config.prompt})
input_with_chat_template = tokenizer.apply_chat_template(

🟢 The message creation logic is duplicated in both `decode_with_vllm` and `decode_with_tunix`. Consider refactoring this into a helper function to improve maintainability.
Suggested change
input_with_chat_template = tokenizer.apply_chat_template(
messages = []
if config.system_prompt:
  messages.append({"role": "system", "content": config.system_prompt})
messages.append({"role": "user", "content": config.prompt})
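
The helper the comment asks for could look roughly like this. The name `build_chat_messages` and its signature are hypothetical; the message shapes mirror the snippet above, and an empty or missing system prompt is simply skipped.

```python
from typing import Optional


def build_chat_messages(
    prompt: str, system_prompt: Optional[str] = None
) -> list[dict[str, str]]:
  """Build the chat message list; the system message is optional."""
  messages = []
  if system_prompt:
    # Prepend the system message only when one is configured.
    messages.append({"role": "system", "content": system_prompt})
  messages.append({"role": "user", "content": prompt})
  return messages
```

Both `decode_with_vllm` and `decode_with_tunix` could then call `build_chat_messages(config.prompt, config.system_prompt)` and pass the result to `tokenizer.apply_chat_template`, assuming both paths build the same two-message structure.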

codecov Bot commented Jun 3, 2026

Codecov Report

❌ Patch coverage is 0% with 17 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/layers/decoders.py | 0.00% | 15 Missing ⚠️ |
| src/maxtext/layers/attentions.py | 0.00% | 1 Missing ⚠️ |
| src/maxtext/models/gemma4_small.py | 0.00% | 1 Missing ⚠️ |
