Skip to content

Commit 2e6cd11

Browse files
Merge pull request #4053 from AI-Hypercomputer:agagik-gemma4e-vllm
PiperOrigin-RevId: 927575165
2 parents 16cc4f4 + 97c6f9a commit 2e6cd11

8 files changed

Lines changed: 218 additions & 22 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,6 +1569,14 @@ class InferenceGeneral(BaseModel):
15691569
max_target_length: int = Field(2048, description="Maximum sequence length for the model.")
15701570
max_prefill_predict_length: int = Field(64, description="Maximum length for the prefill stage in decoding.")
15711571
prompt: str = Field("I love to", description="The default prompt for sampling.")
1572+
system_prompt: str = Field(
1573+
"",
1574+
description=(
1575+
"Optional system prompt prepended to the chat message list when "
1576+
"use_chat_template=True. Required for the gemma4-e2b / gemma4-e4b -it "
1577+
"checkpoints which need a system role to produce coherent output."
1578+
),
1579+
)
15721580
load_from_prefill_dir: bool = Field(False, description="Reads prefill cache from directory instead of computing it.")
15731581
prefill_cache_dir: PathStr = Field("", description="Directory for the prefill cache.")
15741582
autoregressive_decode_assert: str = Field(

src/maxtext/inference/vllm_decode.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,6 @@
5050
from maxtext.configs import pyconfig
5151
import maxtext.integration.vllm.maxtext_vllm_adapter as adapter
5252

53-
adapter.register()
54-
55-
os.environ["SKIP_JAX_PRECOMPILE"] = "1"
56-
os.environ["NEW_MODEL_DESIGN"] = "1"
57-
5853

5954
# --- DEFINE FLAGS GLOBALLY ---
6055
FLAGS = flags.FLAGS
@@ -63,6 +58,15 @@
6358
flags.DEFINE_integer("seed", 42, "Random seed for sampling.")
6459

6560

61+
def build_chat_messages(config: Config) -> list[dict[str, str]]:
  """Assembles the message list that is fed to the tokenizer's chat template.

  Args:
    config: Run configuration; only ``prompt`` and ``system_prompt`` are read.

  Returns:
    A list of ``{"role", "content"}`` dicts. The user turn is always present;
    a system turn precedes it only when ``config.system_prompt`` is non-empty.
  """
  user_turn = {"role": "user", "content": config.prompt}
  # An empty system prompt means "no system turn at all", not an empty one.
  if not config.system_prompt:
    return [user_turn]
  return [{"role": "system", "content": config.system_prompt}, user_turn]
68+
69+
6670
def decode_with_vllm(config: Config) -> None:
6771
"""Decode using vLLM with a MaxText model implementation.
6872
@@ -127,11 +131,8 @@ def decode_with_vllm(config: Config) -> None:
127131
prompts = [config.prompt]
128132
if config.use_chat_template:
129133
# Format the prompt using chat template if specified
130-
messages = [
131-
{"role": "user", "content": config.prompt},
132-
]
133134
input_with_chat_template = tokenizer.apply_chat_template(
134-
messages,
135+
build_chat_messages(config),
135136
tokenize=False, # Set to False to get the string
136137
add_generation_prompt=True,
137138
add_special_tokens=False, # Prevent adding special tokens
@@ -191,11 +192,8 @@ def decode_with_tunix(
191192
prompts = [config.prompt]
192193
if config.use_chat_template:
193194
# Format the prompt using chat template if specified
194-
messages = [
195-
{"role": "user", "content": config.prompt},
196-
]
197195
input_with_chat_template = tokenizer.apply_chat_template(
198-
messages,
196+
build_chat_messages(config),
199197
tokenize=False, # Set to False to get the string
200198
add_generation_prompt=True,
201199
add_special_tokens=False, # Prevent adding special tokens
@@ -240,6 +238,12 @@ def decode_with_tunix(
240238

241239

242240
def main(argv: Sequence[str]) -> None:
241+
# Keep these in main(): registering the adapter and setting engine env flags
242+
# at import time would leak into any process that merely imports this module.
243+
adapter.register()
244+
os.environ["SKIP_JAX_PRECOMPILE"] = "1"
245+
os.environ["NEW_MODEL_DESIGN"] = "1"
246+
243247
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
244248
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
245249
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):

src/maxtext/layers/attentions.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,16 +1026,23 @@ def forward_serve_vllm(
10261026
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
10271027
return query, []
10281028

1029-
if self.config.sliding_window_size > 0:
1029+
# Sliding window applies only to LOCAL_SLIDING layers; global layers must run
1030+
# full attention.
1031+
if self.attention_type == AttentionType.LOCAL_SLIDING and self.config.sliding_window_size > 0:
10301032
attention_chunk_size = self.config.sliding_window_size
10311033
else:
1032-
# Chunked attention currently not used in vLLM RPA.
10331034
attention_chunk_size = None
10341035

10351036
q_scale, k_scale, v_scale = None, None, None
10361037

10371038
md = rpa_metadata
10381039

1040+
# With cross-layer KV sharing (Gemma 4 E2B / E4B), a KV-shared layer has no
1041+
# cache of its own: `rpa_kv_cache` here is the donor layer's cache, and
1042+
# attention must run against the K/V the donor already wrote for this
1043+
# position. Only the donor writes the cache; shared layers read it as-is.
1044+
update_kv_cache = not self.share_kv_layer
1045+
10391046
output, kv_cache = rpa_ops(
10401047
self.mesh,
10411048
query,
@@ -1052,6 +1059,7 @@ def forward_serve_vllm(
10521059
q_scale,
10531060
k_scale,
10541061
v_scale,
1062+
update_kv_cache=update_kv_cache,
10551063
)
10561064
return output, kv_cache
10571065

src/maxtext/layers/decoders.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,7 @@ def __call__(
11001100
kv_caches[index] = kv_cache
11011101
global_layer_idx_offset += num_layers
11021102
elif cfg.decoder_block == DecoderBlockType.GEMMA4_SMALL:
1103-
y = self._apply_gemma4_small_layers(
1103+
y, kv_caches = self._apply_gemma4_small_layers(
11041104
y,
11051105
decoder_input_tokens,
11061106
decoder_segment_ids,
@@ -1379,8 +1379,12 @@ def _apply_gemma4_small_layers(
13791379
* ``per_layer_inputs`` from PLE, sliced per layer.
13801380
* ``shared_kv_states``: donor-layer-index → (key, value) for
13811381
downstream KV-shared layers to consume.
1382+
* ``kv_caches``: when running via the vLLM RPA path, the per-layer
1383+
cache buffer threaded back from the kernel. KV-shared layers
1384+
redirect to the donor's cache slot via ``cache_index_of``.
13821385
1383-
Scan-over-layers and pipeline parallelism are not supported.
1386+
Returns ``(y, kv_caches)``. Scan-over-layers and pipeline
1387+
parallelism are not supported.
13841388
"""
13851389
cfg = self.config
13861390
mesh = self.mesh
@@ -1398,6 +1402,10 @@ def _apply_gemma4_small_layers(
13981402
num_kv_shared = cfg.num_kv_shared_layers
13991403
shared_kv_states: dict[int, tuple[jax.Array, jax.Array]] = {}
14001404

1405+
# tpu-inference allocates one `kv_caches` slot per non-shared layer;
1406+
# KV-shared layers reuse the donor's slot.
1407+
cache_index_of = gemma4_small.kv_cache_slot_map(layer_types, num_kv_shared)
1408+
14011409
for lyr in range(cfg.num_decoder_layers):
14021410
attention_type = layer_types[lyr]
14031411
donor_idx = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared)
@@ -1434,8 +1442,9 @@ def _apply_gemma4_small_layers(
14341442

14351443
ple_slice = per_layer_inputs[..., lyr, :] if per_layer_inputs is not None else None
14361444

1437-
kv_cache = kv_caches[lyr] if kv_caches is not None else None
1438-
y = layer(
1445+
cache_idx = cache_index_of[lyr]
1446+
kv_cache = kv_caches[cache_idx] if kv_caches is not None else None
1447+
y, kv_cache = layer(
14391448
y,
14401449
decoder_segment_ids,
14411450
decoder_positions,
@@ -1450,8 +1459,10 @@ def _apply_gemma4_small_layers(
14501459
shared_key=shared_key,
14511460
shared_value=shared_value,
14521461
)
1462+
if kv_caches is not None and kv_cache is not None:
1463+
kv_caches[cache_idx] = kv_cache
14531464

1454-
return y
1465+
return y, kv_caches
14551466

14561467
# TODO(b/490118813): Relocate the following functions to their designated directories
14571468
# once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer()

src/maxtext/models/gemma4_small.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,27 @@ def kv_donor_layer_idx(
106106
return None
107107

108108

109+
def kv_cache_slot_map(
    layer_types: tuple[AttentionType, ...],
    num_kv_shared_layers: int,
) -> dict[int, int]:
  """Builds the decoder-layer-index -> KV-cache-slot lookup.

  tpu-inference allocates one KV-cache slot per non-shared layer; a KV-shared
  layer is pointed at the slot already assigned to its donor layer.

  Args:
    layer_types: Attention type of every decoder layer, in layer order.
    num_kv_shared_layers: Forwarded to ``kv_donor_layer_idx`` to decide which
      layers are KV-shared.

  Returns:
    A dict with one entry per layer index, mapping it to its cache slot.
  """
  mapping: dict[int, int] = {}
  slots_handed_out = 0
  for layer_idx, _ in enumerate(layer_types):
    donor = kv_donor_layer_idx(layer_idx, layer_types, num_kv_shared_layers)
    if donor is None:
      # Layer owns its cache: take the next free slot.
      mapping[layer_idx] = slots_handed_out
      slots_handed_out += 1
      continue
    # Shared layer: alias the donor's slot (the donor was mapped earlier).
    mapping[layer_idx] = mapping[donor]
  return mapping
128+
129+
109130
def is_kv_donor_layer(
110131
layer_idx: int,
111132
layer_types: tuple[AttentionType, ...],
@@ -454,7 +475,7 @@ def __call__(
454475
h = h * jnp.asarray(self.layer_scalar.value, cfg.dtype)
455476
h = nn.with_logical_constraint(h, self.activation_axis_names)
456477

457-
return h
478+
return h, kv_cache
458479

459480

460481
Gemma4SmallDecoderLayerToLinen = nnx_wrappers.to_linen_class(

tests/end_to_end/tpu/gemma4/Run_Gemma4.md

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,67 @@ Set `model_name`/`tokenizer_path` to your variant (`gemma4-26b`, `gemma4-31b`) a
141141
`ici_tensor_parallelism` to the number of chips — pass an explicit count (e.g. `4` on a v5p-8), not
142142
`-1`, since `vllm_decode` forwards this value directly to vLLM's `tensor_parallel_size`.
143143

144-
> **Note:** `gemma4-e2b` / `gemma4-e4b` are not yet supported. They use cross-layer KV sharing, and will be supported soon.
144+
#### E2B / E4B
145+
146+
`gemma4-e2b` and `gemma4-e4b` run through the same `vllm_decode` entry point as the larger variants, but the `-it` fine-tunes need **three things** that the larger models can do without:
147+
148+
1. **A system prompt** ([per the HF model card](https://huggingface.co/google/gemma-4-E2B-it)) — without it the `-it` checkpoints drift off-topic at any temperature.
149+
2. **Stochastic sampling** `temperature=1.0, top_p=0.95, top_k=64` (the model card's recommended settings). Greedy decoding tends to loop on these small checkpoints, independent of the MaxText path.
150+
3. **The full stop-token set.** The upstream `google/gemma-4-*-it` repos declare `eos_token_id: [1, 106, 50]` (`<eos>`, `<turn|>`, `<|tool_response>`). If a converted checkpoint only carries `eos_token_id: 1`, end-of-turn `<turn|>` is no longer registered as a stop and generation runs to `max_tokens`. Using the upstream repo id for `tokenizer_path` keeps the full stop list automatically. A local checkpoint dir works equally well — just verify its `generation_config.json` carries the full list.
151+
152+
The CLI form, using the `system_prompt=` flag and the model card's sampling params:
153+
154+
```sh
155+
python3 -m maxtext.inference.vllm_decode src/maxtext/configs/base.yml \
156+
model_name=gemma4-e2b \
157+
tokenizer_path=google/gemma-4-e2b-it \
158+
load_parameters_path=${CONVERTED_CHECKPOINT} \
159+
vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
160+
ici_tensor_parallelism=1 scan_layers=False \
161+
system_prompt="You are a helpful assistant." \
162+
prompt="Who was Albert Einstein?" use_chat_template=True \
163+
decode_sampling_temperature=1.0 \
164+
decode_sampling_nucleus_p=0.95 \
165+
decode_sampling_top_k=64
166+
```
167+
168+
Or via the Python API, useful for fixing a seed or stitching multiple requests:
169+
170+
```python
171+
import maxtext.integration.vllm.maxtext_vllm_adapter as adapter
172+
adapter.register()
173+
from vllm import LLM
174+
from vllm.sampling_params import SamplingParams
175+
import transformers
176+
177+
llm = LLM(
178+
model="google/gemma-4-e2b-it", # tokenizer + HF config dir
179+
hf_overrides={"architectures": ["MaxTextForCausalLM"]},
180+
additional_config={
181+
"maxtext_config": {
182+
"model_name": "gemma4-e2b", # or gemma4-e4b
183+
"scan_layers": False,
184+
"load_parameters_path": "${CONVERTED_CHECKPOINT}",
185+
}
186+
},
187+
tensor_parallel_size=1, # set to chip count (e.g. 4 on v5p-8)
188+
max_model_len=1024,
189+
)
190+
191+
tok = transformers.AutoTokenizer.from_pretrained("google/gemma-4-e2b-it")
192+
prompt = tok.apply_chat_template(
193+
[
194+
{"role": "system", "content": "You are a helpful assistant."},
195+
{"role": "user", "content": "Who was Albert Einstein?"},
196+
],
197+
tokenize=False,
198+
add_generation_prompt=True,
199+
)
200+
201+
out = llm.generate(
202+
[prompt],
203+
SamplingParams(temperature=1.0, top_p=0.95, top_k=64,
204+
seed=42, max_tokens=300),
205+
)
206+
print(out[0].outputs[0].text)
207+
```
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for vllm_decode helpers."""
16+
17+
import types
18+
import unittest
19+
20+
import pytest
21+
22+
pytest.importorskip("vllm")
23+
pytest.importorskip("tunix")
24+
25+
from maxtext.inference.vllm_decode import build_chat_messages
26+
27+
28+
def _config(prompt: str, system_prompt: str):
  """Returns a minimal stand-in config exposing only the two fields under test."""
  fields = {"prompt": prompt, "system_prompt": system_prompt}
  return types.SimpleNamespace(**fields)
30+
31+
32+
class BuildChatMessagesTest(unittest.TestCase):
  """Covers how the vllm_decode CLI turns a config into chat messages."""

  def test_user_only_when_no_system_prompt(self):
    # An empty system prompt must not yield a system turn.
    cfg = _config("What is 2+2?", "")
    expected = [{"role": "user", "content": "What is 2+2?"}]
    self.assertEqual(build_chat_messages(cfg), expected)

  def test_system_prompt_prepended(self):
    # The system turn comes first, followed by the user turn.
    cfg = _config("Who was Albert Einstein?", "You are a helpful assistant.")
    expected = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who was Albert Einstein?"},
    ]
    self.assertEqual(build_chat_messages(cfg), expected)
48+
49+
50+
if __name__ == "__main__":
51+
unittest.main()

tests/unit/gemma4_small_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,35 @@ def test_no_kv_sharing_when_num_kv_shared_zero(self):
102102
self.assertFalse(gemma4_small.is_kv_donor_layer(i, layer_types, 0))
103103

104104

105+
class Gemma4SmallKvCacheSlotMapTest(unittest.TestCase):
  """Checks the layer -> KV-cache slot map consumed by the vLLM RPA path."""

  def _check_slot_map(self, model_name, num_layers, num_kv_shared):
    """Verifies the slot-map invariants for one model layout.

    Args:
      model_name: Model name passed to ``build_layer_types``.
      num_layers: Total number of decoder layers.
      num_kv_shared: Number of KV-shared layers passed to the slot-map builder.
    """
    layer_types = gemma4_small.build_layer_types(num_layers, model_name)
    slot_map = gemma4_small.kv_cache_slot_map(layer_types, num_kv_shared)
    first_shared = num_layers - num_kv_shared

    self.assertEqual(len(slot_map), num_layers)

    # The leading non-shared layers own consecutive slots 0..first_shared-1.
    owned_slots = [slot_map[i] for i in range(first_shared)]
    self.assertEqual(owned_slots, list(range(first_shared)))

    # Each shared layer aliases the slot of a donor with the same attention type.
    for lyr in range(first_shared, num_layers):
      donor = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared)
      label = f"layer {lyr}"
      self.assertEqual(slot_map[lyr], slot_map[donor], label)
      self.assertEqual(layer_types[lyr], layer_types[donor], label)

  def test_e2b_slot_map(self):
    self._check_slot_map("gemma4-e2b", 35, 20)

  def test_e4b_slot_map(self):
    self._check_slot_map("gemma4-e4b", 42, 18)

  def test_slot_map_without_sharing_is_identity(self):
    # With no shared layers every layer keeps its own index as its slot.
    layer_types = gemma4_small.build_layer_types(10, None)
    expected = {i: i for i in range(10)}
    self.assertEqual(gemma4_small.kv_cache_slot_map(layer_types, 0), expected)
133+
134+
105135
if __name__ == "__main__":
106136
unittest.main()

0 commit comments

Comments
 (0)