diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index f522e40164..b26e30ac76 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1569,6 +1569,14 @@ class InferenceGeneral(BaseModel): max_target_length: int = Field(2048, description="Maximum sequence length for the model.") max_prefill_predict_length: int = Field(64, description="Maximum length for the prefill stage in decoding.") prompt: str = Field("I love to", description="The default prompt for sampling.") + system_prompt: str = Field( + "", + description=( + "Optional system prompt prepended to the chat message list when " + "use_chat_template=True. Required for the gemma4-e2b / gemma4-e4b -it " + "checkpoints which need a system role to produce coherent output." + ), + ) load_from_prefill_dir: bool = Field(False, description="Reads prefill cache from directory instead of computing it.") prefill_cache_dir: PathStr = Field("", description="Directory for the prefill cache.") autoregressive_decode_assert: str = Field( diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index c2b1e5e5d2..0962a33b5f 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -50,11 +50,6 @@ from maxtext.configs import pyconfig import maxtext.integration.vllm.maxtext_vllm_adapter as adapter -adapter.register() - -os.environ["SKIP_JAX_PRECOMPILE"] = "1" -os.environ["NEW_MODEL_DESIGN"] = "1" - # --- DEFINE FLAGS GLOBALLY --- FLAGS = flags.FLAGS @@ -63,6 +58,15 @@ flags.DEFINE_integer("seed", 42, "Random seed for sampling.") +def build_chat_messages(config: Config) -> list[dict[str, str]]: + """Builds the chat message list, prepending a system prompt when set.""" + messages = [] + if config.system_prompt: + messages.append({"role": "system", "content": config.system_prompt}) + messages.append({"role": "user", "content": config.prompt}) + return messages + + def decode_with_vllm(config: Config) -> None: """Decode using vLLM with a MaxText model implementation. @@ -127,11 +131,8 @@ def decode_with_vllm(config: Config) -> None: prompts = [config.prompt] if config.use_chat_template: # Format the prompt using chat template if specified - messages = [ - {"role": "user", "content": config.prompt}, - ] input_with_chat_template = tokenizer.apply_chat_template( - messages, + build_chat_messages(config), tokenize=False, # Set to False to get the string add_generation_prompt=True, add_special_tokens=False, # Prevent adding special tokens @@ -191,11 +192,8 @@ def decode_with_tunix( prompts = [config.prompt] if config.use_chat_template: # Format the prompt using chat template if specified - messages = [ - {"role": "user", "content": config.prompt}, - ] input_with_chat_template = tokenizer.apply_chat_template( - messages, + build_chat_messages(config), tokenize=False, # Set to False to get the string add_generation_prompt=True, add_special_tokens=False, # Prevent adding special tokens @@ -240,6 +238,12 @@ def decode_with_tunix( def main(argv: Sequence[str]) -> None: + # Keep these in main(): registering the adapter and setting engine env flags + # at import time would leak into any process that merely imports this module. + adapter.register() + os.environ["SKIP_JAX_PRECOMPILE"] = "1" + os.environ["NEW_MODEL_DESIGN"] = "1" + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 635506291d..93c54e25a6 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -1026,16 +1026,23 @@ def forward_serve_vllm( # Return dummy values for dry runs (e.g. during model initialization or JIT tracing) return query, [] - if self.config.sliding_window_size > 0: + # Sliding window applies only to LOCAL_SLIDING layers; global layers must run + # full attention. + if self.attention_type == AttentionType.LOCAL_SLIDING and self.config.sliding_window_size > 0: attention_chunk_size = self.config.sliding_window_size else: - # Chunked attention currently not used in vLLM RPA. attention_chunk_size = None q_scale, k_scale, v_scale = None, None, None md = rpa_metadata + # With cross-layer KV sharing (Gemma 4 E2B / E4B), a KV-shared layer has no + # cache of its own: `rpa_kv_cache` here is the donor layer's cache, and + # attention must run against the K/V the donor already wrote for this + # position. Only the donor writes the cache; shared layers read it as-is. + update_kv_cache = not self.share_kv_layer + output, kv_cache = rpa_ops( self.mesh, query, @@ -1052,6 +1059,7 @@ def forward_serve_vllm( q_scale, k_scale, v_scale, + update_kv_cache=update_kv_cache, ) return output, kv_cache diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index d2deb7192a..7259ea051a 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -1100,7 +1100,7 @@ def __call__( kv_caches[index] = kv_cache global_layer_idx_offset += num_layers elif cfg.decoder_block == DecoderBlockType.GEMMA4_SMALL: - y = self._apply_gemma4_small_layers( + y, kv_caches = self._apply_gemma4_small_layers( y, decoder_input_tokens, decoder_segment_ids, @@ -1379,8 +1379,12 @@ def _apply_gemma4_small_layers( * ``per_layer_inputs`` from PLE, sliced per layer. * ``shared_kv_states``: donor-layer-index → (key, value) for downstream KV-shared layers to consume. + * ``kv_caches``: when running via the vLLM RPA path, the per-layer + cache buffer threaded back from the kernel. KV-shared layers + redirect to the donor's cache slot via ``cache_index_of``. - Scan-over-layers and pipeline parallelism are not supported. + Returns ``(y, kv_caches)``. Scan-over-layers and pipeline + parallelism are not supported. """ cfg = self.config mesh = self.mesh @@ -1398,6 +1402,10 @@ def _apply_gemma4_small_layers( num_kv_shared = cfg.num_kv_shared_layers shared_kv_states: dict[int, tuple[jax.Array, jax.Array]] = {} + # tpu-inference allocates one `kv_caches` slot per non-shared layer; + # KV-shared layers reuse the donor's slot. + cache_index_of = gemma4_small.kv_cache_slot_map(layer_types, num_kv_shared) + for lyr in range(cfg.num_decoder_layers): attention_type = layer_types[lyr] donor_idx = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared) @@ -1434,8 +1442,9 @@ def _apply_gemma4_small_layers( ple_slice = per_layer_inputs[..., lyr, :] if per_layer_inputs is not None else None - kv_cache = kv_caches[lyr] if kv_caches is not None else None - y = layer( + cache_idx = cache_index_of[lyr] + kv_cache = kv_caches[cache_idx] if kv_caches is not None else None + y, kv_cache = layer( y, decoder_segment_ids, decoder_positions, @@ -1450,8 +1459,10 @@ def _apply_gemma4_small_layers( shared_key=shared_key, shared_value=shared_value, ) + if kv_caches is not None and kv_cache is not None: + kv_caches[cache_idx] = kv_cache - return y + return y, kv_caches # TODO(b/490118813): Relocate the following functions to their designated directories # once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer() diff --git a/src/maxtext/models/gemma4_small.py b/src/maxtext/models/gemma4_small.py index 242a66469e..ca33470bf2 100644 --- a/src/maxtext/models/gemma4_small.py +++ b/src/maxtext/models/gemma4_small.py @@ -106,6 +106,27 @@ def kv_donor_layer_idx( return None +def kv_cache_slot_map( + layer_types: tuple[AttentionType, ...], + num_kv_shared_layers: int, +) -> dict[int, int]: + """Maps decoder layer index -> KV-cache slot. + + tpu-inference allocates one KV-cache slot per non-shared layer; KV-shared + layers reuse their donor's slot. + """ + slot_of: dict[int, int] = {} + next_slot = 0 + for lyr in range(len(layer_types)): + donor_idx = kv_donor_layer_idx(lyr, layer_types, num_kv_shared_layers) + if donor_idx is not None: + slot_of[lyr] = slot_of[donor_idx] + else: + slot_of[lyr] = next_slot + next_slot += 1 + return slot_of + + def is_kv_donor_layer( layer_idx: int, layer_types: tuple[AttentionType, ...], @@ -454,7 +475,7 @@ def __call__( h = h * jnp.asarray(self.layer_scalar.value, cfg.dtype) h = nn.with_logical_constraint(h, self.activation_axis_names) - return h + return h, kv_cache Gemma4SmallDecoderLayerToLinen = nnx_wrappers.to_linen_class( diff --git a/tests/end_to_end/tpu/gemma4/Run_Gemma4.md b/tests/end_to_end/tpu/gemma4/Run_Gemma4.md index 1df948b34f..d1cc91de17 100644 --- a/tests/end_to_end/tpu/gemma4/Run_Gemma4.md +++ b/tests/end_to_end/tpu/gemma4/Run_Gemma4.md @@ -141,4 +141,67 @@ Set `model_name`/`tokenizer_path` to your variant (`gemma4-26b`, `gemma4-31b`) a `ici_tensor_parallelism` to the number of chips — pass an explicit count (e.g. `4` on a v5p-8), not `-1`, since `vllm_decode` forwards this value directly to vLLM's `tensor_parallel_size`. -> **Note:** `gemma4-e2b` / `gemma4-e4b` are not yet supported. They use cross-layer KV sharing, and will be supported soon. +#### E2B / E4B + +`gemma4-e2b` and `gemma4-e4b` run through the same `vllm_decode` entry point as the larger variants, but the `-it` fine-tunes need **three things** the larger models tolerate without: + +1. **A system prompt** ([per the HF model card](https://huggingface.co/google/gemma-4-E2B-it)) — without it the `-it` checkpoints drift off-topic at any temperature. +2. **Stochastic sampling** `temperature=1.0, top_p=0.95, top_k=64` (the model card's recommended settings). Greedy decoding tends to loop on these small checkpoints, independent of the MaxText path. +3. **The full stop-token set.** The upstream `google/gemma-4-*-it` repos declare `eos_token_id: [1, 106, 50]` (``, ``, `<|tool_response>`). If a converted checkpoint only carries `eos_token_id: 1`, end-of-turn `` is no longer registered as a stop and generation runs to `max_tokens`. Using the upstream repo id for `tokenizer_path` keeps the full stop list automatically. A local checkpoint dir works equally well — just verify its `generation_config.json` carries the full list. + +The CLI form, using the `system_prompt=` flag and the model card's sampling params: + +```sh +python3 -m maxtext.inference.vllm_decode src/maxtext/configs/base.yml \ + model_name=gemma4-e2b \ + tokenizer_path=google/gemma-4-e2b-it \ + load_parameters_path=${CONVERTED_CHECKPOINT} \ + vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \ + ici_tensor_parallelism=1 scan_layers=False \ + system_prompt="You are a helpful assistant." \ + prompt="Who was Albert Einstein?" use_chat_template=True \ + decode_sampling_temperature=1.0 \ + decode_sampling_nucleus_p=0.95 \ + decode_sampling_top_k=64 +``` + +Or via the Python API, useful for fixing a seed or stitching multiple requests: + +```python +import maxtext.integration.vllm.maxtext_vllm_adapter as adapter +adapter.register() +from vllm import LLM +from vllm.sampling_params import SamplingParams +import transformers + +llm = LLM( + model="google/gemma-4-e2b-it", # tokenizer + HF config dir + hf_overrides={"architectures": ["MaxTextForCausalLM"]}, + additional_config={ + "maxtext_config": { + "model_name": "gemma4-e2b", # or gemma4-e4b + "scan_layers": False, + "load_parameters_path": "${CONVERTED_CHECKPOINT}", + } + }, + tensor_parallel_size=1, # set to chip count (e.g. 4 on v5p-8) + max_model_len=1024, +) + +tok = transformers.AutoTokenizer.from_pretrained("google/gemma-4-e2b-it") +prompt = tok.apply_chat_template( + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who was Albert Einstein?"}, + ], + tokenize=False, + add_generation_prompt=True, +) + +out = llm.generate( + [prompt], + SamplingParams(temperature=1.0, top_p=0.95, top_k=64, + seed=42, max_tokens=300), +) +print(out[0].outputs[0].text) +``` diff --git a/tests/post_training/unit/vllm_decode_test.py b/tests/post_training/unit/vllm_decode_test.py new file mode 100644 index 0000000000..a312efacc7 --- /dev/null +++ b/tests/post_training/unit/vllm_decode_test.py @@ -0,0 +1,51 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for vllm_decode helpers.""" + +import types +import unittest + +import pytest + +pytest.importorskip("vllm") +pytest.importorskip("tunix") + +from maxtext.inference.vllm_decode import build_chat_messages + + +def _config(prompt: str, system_prompt: str): + return types.SimpleNamespace(prompt=prompt, system_prompt=system_prompt) + + +class BuildChatMessagesTest(unittest.TestCase): + """Chat-message construction for the vllm_decode CLI.""" + + def test_user_only_when_no_system_prompt(self): + messages = build_chat_messages(_config("What is 2+2?", "")) + self.assertEqual(messages, [{"role": "user", "content": "What is 2+2?"}]) + + def test_system_prompt_prepended(self): + messages = build_chat_messages(_config("Who was Albert Einstein?", "You are a helpful assistant.")) + self.assertEqual( + messages, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who was Albert Einstein?"}, + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/gemma4_small_test.py b/tests/unit/gemma4_small_test.py index 4bd528d414..880be61487 100644 --- a/tests/unit/gemma4_small_test.py +++ b/tests/unit/gemma4_small_test.py @@ -102,5 +102,35 @@ def test_no_kv_sharing_when_num_kv_shared_zero(self): self.assertFalse(gemma4_small.is_kv_donor_layer(i, layer_types, 0)) +class Gemma4SmallKvCacheSlotMapTest(unittest.TestCase): + """Layer -> KV-cache slot mapping used by the vLLM RPA path.""" + + def _check_slot_map(self, model_name, num_layers, num_kv_shared): + """Asserts slot-map invariants for the given model layout.""" + layer_types = gemma4_small.build_layer_types(num_layers, model_name) + slot_map = gemma4_small.kv_cache_slot_map(layer_types, num_kv_shared) + + num_slots = num_layers - num_kv_shared + self.assertEqual(len(slot_map), num_layers) + # Non-shared layers get consecutive slots 0..num_slots-1. + self.assertEqual([slot_map[i] for i in range(num_slots)], list(range(num_slots))) + # Shared layers reuse the slot of a donor with the same attention type. + for lyr in range(num_slots, num_layers): + donor = gemma4_small.kv_donor_layer_idx(lyr, layer_types, num_kv_shared) + self.assertEqual(slot_map[lyr], slot_map[donor], f"layer {lyr}") + self.assertEqual(layer_types[lyr], layer_types[donor], f"layer {lyr}") + + def test_e2b_slot_map(self): + self._check_slot_map("gemma4-e2b", 35, 20) + + def test_e4b_slot_map(self): + self._check_slot_map("gemma4-e4b", 42, 18) + + def test_slot_map_without_sharing_is_identity(self): + layer_types = gemma4_small.build_layer_types(10, None) + slot_map = gemma4_small.kv_cache_slot_map(layer_types, 0) + self.assertEqual(slot_map, {i: i for i in range(10)}) + + if __name__ == "__main__": unittest.main()