Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,14 @@ class InferenceGeneral(BaseModel):
max_target_length: int = Field(2048, description="Maximum sequence length for the model.")
max_prefill_predict_length: int = Field(64, description="Maximum length for the prefill stage in decoding.")
prompt: str = Field("I love to", description="The default prompt for sampling.")
system_prompt: str = Field(
"",
description=(
"Optional system prompt prepended to the chat message list when "
"use_chat_template=True. Required for the gemma4-e2b / gemma4-e4b -it "
"checkpoints which need a system role to produce coherent output."
),
)
load_from_prefill_dir: bool = Field(False, description="Reads prefill cache from directory instead of computing it.")
prefill_cache_dir: PathStr = Field("", description="Directory for the prefill cache.")
autoregressive_decode_assert: str = Field(
Expand Down
14 changes: 8 additions & 6 deletions src/maxtext/inference/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ def decode_with_vllm(config: Config) -> None:
prompts = [config.prompt]
if config.use_chat_template:
# Format the prompt using chat template if specified
messages = [
{"role": "user", "content": config.prompt},
]
messages = []
if config.system_prompt:
messages.append({"role": "system", "content": config.system_prompt})
messages.append({"role": "user", "content": config.prompt})
input_with_chat_template = tokenizer.apply_chat_template(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 The message creation logic is duplicated in both `decode_with_vllm` and `decode_with_tunix`. Consider refactoring this into a helper function to improve maintainability.
Suggested change
input_with_chat_template = tokenizer.apply_chat_template(
messages = []
if config.system_prompt:
messages.append({"role": "system", "content": config.system_prompt})
messages.append({"role": "user", "content": config.prompt})

messages,
tokenize=False, # Set to False to get the string
Expand Down Expand Up @@ -191,9 +192,10 @@ def decode_with_tunix(
prompts = [config.prompt]
if config.use_chat_template:
# Format the prompt using chat template if specified
messages = [
{"role": "user", "content": config.prompt},
]
messages = []
if config.system_prompt:
messages.append({"role": "system", "content": config.system_prompt})
messages.append({"role": "user", "content": config.prompt})
input_with_chat_template = tokenizer.apply_chat_template(
messages,
tokenize=False, # Set to False to get the string
Expand Down
9 changes: 7 additions & 2 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,10 +1026,14 @@ def forward_serve_vllm(
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
return query, []

if self.config.sliding_window_size > 0:
# Sliding window applies only to LOCAL_SLIDING layers; global layers must run
# full attention.
if (
self.attention_type == AttentionType.LOCAL_SLIDING
and self.config.sliding_window_size > 0
):
attention_chunk_size = self.config.sliding_window_size
else:
# Chunked attention currently not used in vLLM RPA.
attention_chunk_size = None

q_scale, k_scale, v_scale = None, None, None
Expand All @@ -1052,6 +1056,7 @@ def forward_serve_vllm(
q_scale,
k_scale,
v_scale,
update_kv_cache=not self.share_kv_layer,
)
return output, kv_cache

Expand Down
29 changes: 24 additions & 5 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ def __call__(
kv_caches[index] = kv_cache
global_layer_idx_offset += num_layers
elif cfg.decoder_block == DecoderBlockType.GEMMA4_SMALL:
y = self._apply_gemma4_small_layers(
y, kv_caches = self._apply_gemma4_small_layers(
y,
decoder_input_tokens,
decoder_segment_ids,
Expand Down Expand Up @@ -1378,8 +1378,12 @@ def _apply_gemma4_small_layers(
* ``per_layer_inputs`` from PLE, sliced per layer.
* ``shared_kv_states``: donor-layer-index → (key, value) for
downstream KV-shared layers to consume.
* ``kv_caches``: when running via the vLLM RPA path, the per-layer
cache buffer threaded back from the kernel. KV-shared layers
redirect to the donor's cache slot via ``cache_index_of``.
Scan-over-layers and pipeline parallelism are not supported.
Returns ``(y, kv_caches)``. Scan-over-layers and pipeline
parallelism are not supported.
"""
cfg = self.config
mesh = self.mesh
Expand All @@ -1397,6 +1401,18 @@ def _apply_gemma4_small_layers(
num_kv_shared = cfg.num_kv_shared_layers
shared_kv_states: dict[int, tuple[jax.Array, jax.Array]] = {}

# tpu-inference allocates one `kv_caches` slot per non-shared layer;
# KV-shared layers reuse the donor's slot. Map decoder layer index -> slot.
cache_index_of: dict[int, int] = {}
next_slot = 0
for lyr in range(cfg.num_decoder_layers):
donor_idx = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared)
if donor_idx is not None:
cache_index_of[lyr] = cache_index_of[donor_idx]
else:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Rebuilding the `cache_index_of` dictionary on every call to `_apply_gemma4_small_layers` is redundant since the layer types and sharing pattern are static for a given model configuration. While the overhead is likely negligible for tracing, precomputing this mapping during initialization would be cleaner.
Suggested change
else:
# tpu-inference allocates one `kv_caches` slot per non-shared layer;
# KV-shared layers reuse the donor's slot. Map decoder layer index -> slot.
cache_index_of: dict[int, int] = {}
next_slot = 0
for lyr in range(cfg.num_decoder_layers):
donor_idx = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared)
if donor_idx is not None:
cache_index_of[lyr] = cache_index_of[donor_idx]
else:
cache_index_of[lyr] = next_slot
next_slot += 1

cache_index_of[lyr] = next_slot
next_slot += 1

for lyr in range(cfg.num_decoder_layers):
attention_type = layer_types[lyr]
donor_idx = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared)
Expand Down Expand Up @@ -1433,8 +1449,9 @@ def _apply_gemma4_small_layers(

ple_slice = per_layer_inputs[..., lyr, :] if per_layer_inputs is not None else None

kv_cache = kv_caches[lyr] if kv_caches is not None else None
y = layer(
cache_idx = cache_index_of[lyr]
kv_cache = kv_caches[cache_idx] if kv_caches is not None else None
y, kv_cache = layer(
y,
decoder_segment_ids,
decoder_positions,
Expand All @@ -1449,8 +1466,10 @@ def _apply_gemma4_small_layers(
shared_key=shared_key,
shared_value=shared_value,
)
if kv_caches is not None and kv_cache is not None:
kv_caches[cache_idx] = kv_cache

return y
return y, kv_caches

# TODO(b/490118813): Relocate the following functions to their designated directories
# once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer()
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/gemma4_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def __call__(
h = h * jnp.asarray(self.layer_scalar.value, cfg.dtype)
h = nn.with_logical_constraint(h, self.activation_axis_names)

return h
return h, kv_cache


Gemma4SmallDecoderLayerToLinen = nnx_wrappers.to_linen_class(
Expand Down
65 changes: 64 additions & 1 deletion tests/end_to_end/tpu/gemma4/Run_Gemma4.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,67 @@ Set `model_name`/`tokenizer_path` to your variant (`gemma4-26b`, `gemma4-31b`) a
`ici_tensor_parallelism` to the number of chips — pass an explicit count (e.g. `4` on a v5p-8), not
`-1`, since `vllm_decode` forwards this value directly to vLLM's `tensor_parallel_size`.

> **Note:** `gemma4-e2b` / `gemma4-e4b` are not yet supported. They use cross-layer KV sharing, and will be supported soon.
#### E2B / E4B

`gemma4-e2b` and `gemma4-e4b` run through the same `vllm_decode` entry point as the larger variants, but the `-it` fine-tunes need **three things** the larger models tolerate without:

1. **A system prompt** ([per the HF model card](https://huggingface.co/google/gemma-4-E2B-it)) — without it the `-it` checkpoints drift off-topic at any temperature.
2. **Stochastic sampling** `temperature=1.0, top_p=0.95, top_k=64` (the model card's recommended settings). Greedy decoding tends to loop on these small checkpoints, independent of the MaxText path.
3. **The full stop-token set.** The upstream `google/gemma-4-*-it` repos declare `eos_token_id: [1, 106, 50]` (`<eos>`, `<turn|>`, `<|tool_response>`). If a converted checkpoint only carries `eos_token_id: 1`, end-of-turn `<turn|>` is no longer registered as a stop and generation runs to `max_tokens`. Using the upstream repo id for `tokenizer_path` keeps the full stop list automatically. A local checkpoint dir works equally well — just verify its `generation_config.json` carries the full list.

The CLI form, using the `system_prompt=` flag and the model card's sampling params:

```sh
python3 -m maxtext.inference.vllm_decode src/maxtext/configs/base.yml \
model_name=gemma4-e2b \
tokenizer_path=google/gemma-4-e2b-it \
load_parameters_path=${CONVERTED_CHECKPOINT} \
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
ici_tensor_parallelism=1 scan_layers=False \
system_prompt="You are a helpful assistant." \
prompt="Who was Albert Einstein?" use_chat_template=True \
decode_sampling_temperature=1.0 \
decode_sampling_nucleus_p=0.95 \
decode_sampling_top_k=64
```

Or via the Python API, useful for fixing a seed or stitching multiple requests:

```python
import maxtext.integration.vllm.maxtext_vllm_adapter as adapter
adapter.register()
from vllm import LLM
from vllm.sampling_params import SamplingParams
import transformers

llm = LLM(
model="google/gemma-4-e2b-it", # tokenizer + HF config dir
hf_overrides={"architectures": ["MaxTextForCausalLM"]},
additional_config={
"maxtext_config": {
"model_name": "gemma4-e2b", # or gemma4-e4b
"scan_layers": False,
"load_parameters_path": "${CONVERTED_CHECKPOINT}",
}
},
tensor_parallel_size=1, # set to chip count (e.g. 4 on v5p-8)
max_model_len=1024,
)

tok = transformers.AutoTokenizer.from_pretrained("google/gemma-4-e2b-it")
prompt = tok.apply_chat_template(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who was Albert Einstein?"},
],
tokenize=False,
add_generation_prompt=True,
)

out = llm.generate(
[prompt],
SamplingParams(temperature=1.0, top_p=0.95, top_k=64,
seed=42, max_tokens=300),
)
print(out[0].outputs[0].text)
```
Loading