Skip to content

Commit 1b53c79

Browse files
committed
Support scanned layers KV cache carry and fix vLLM integration
1 parent 21c4433 commit 1b53c79

18 files changed

Lines changed: 407 additions & 90 deletions

File tree

src/maxtext/configs/inference/vllm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ logical_axis_rules: [
7878
['cache_heads', ['model']],
7979
['exp', ['expert', 'attn_dp_expert']],
8080
['paged_kv_heads', ['model']],
81+
['layers', []],
8182
]
8283
data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
8384
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']

src/maxtext/inference/vllm_decode.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
from vllm import LLM
4949
from vllm.sampling_params import SamplingParams
5050
from maxtext.configs import pyconfig
51+
import maxtext.integration.vllm.maxtext_vllm_adapter as adapter
52+
adapter.register()
5153

5254
os.environ["SKIP_JAX_PRECOMPILE"] = "1"
5355
os.environ["NEW_MODEL_DESIGN"] = "1"
@@ -82,6 +84,7 @@ def decode_with_vllm(config: Config) -> None:
8284
"weight_dtype": "bfloat16",
8385
"allow_split_physical_axes": True,
8486
"debug_sharding": config.debug_sharding,
87+
"scan_layers": config.scan_layers,
8588
},
8689
"sharding": {
8790
"sharding_strategy": {
@@ -140,11 +143,15 @@ def decode_with_vllm(config: Config) -> None:
140143
f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
141144
)
142145

146+
# MaxText uses -1 to mean "disabled"; vLLM requires top_p in (0, 1].
147+
top_p = config.decode_sampling_nucleus_p if config.decode_sampling_nucleus_p > 0 else 1.0
148+
top_k = config.decode_sampling_top_k if config.decode_sampling_top_k > 0 else -1
149+
143150
sampling_params = SamplingParams(
144151
temperature=config.decode_sampling_temperature,
145152
max_tokens=max_tokens_to_generate,
146-
top_k=config.decode_sampling_top_k,
147-
top_p=config.decode_sampling_nucleus_p,
153+
top_k=top_k,
154+
top_p=top_p,
148155
)
149156

150157
outputs = llm.generate(prompts, sampling_params)

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ class MaxTextForCausalLM(nnx.Module):
104104
of the decoding step.
105105
"""
106106

107+
# Signal to tpu-inference model_loader that this class manages its own
108+
# JIT-sharded initialization (via create_nnx_model with out_shardings).
109+
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
110+
_self_manages_sharding: bool = True
111+
107112
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
108113
"""Initializes the MaxTextForCausalLM model.
109114
@@ -250,7 +255,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
250255
if self.model is not None:
251256
return
252257

253-
with self.mesh, nn.logical_axis_rules(""):
258+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
254259
model, _ = model_creation_utils.create_nnx_model(
255260
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
256261
)

src/maxtext/layers/attentions.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,7 @@ def forward_serve_vllm(
981981
value: Array,
982982
rpa_kv_cache: list[Array] | None = None,
983983
rpa_metadata: dict[str, Any] | None = None,
984-
) -> tuple[list[Array], Array]:
984+
) -> tuple[Array, list[Array]]:
985985
"""Forward function for vLLM serving with RPA attention."""
986986
try:
987987
# pylint: disable=import-outside-toplevel
@@ -992,13 +992,14 @@ def forward_serve_vllm(
992992
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
993993
) from e
994994

995-
if rpa_kv_cache is None or rpa_metadata is None:
996-
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
997-
998995
query = query.reshape(-1, query.shape[2], query.shape[3])
999996
key = key.reshape(-1, key.shape[2], key.shape[3])
1000997
value = value.reshape(-1, value.shape[2], value.shape[3])
1001998

999+
if rpa_kv_cache is None or rpa_metadata is None:
1000+
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
1001+
return query, []
1002+
10021003
if self.config.sliding_window_size > 0:
10031004
attention_chunk_size = self.config.sliding_window_size
10041005
else:
@@ -1026,7 +1027,7 @@ def forward_serve_vllm(
10261027
k_scale,
10271028
v_scale,
10281029
)
1029-
return kv_cache, output
1030+
return output, kv_cache
10301031

10311032
def __call__(
10321033
self,
@@ -1169,7 +1170,7 @@ def __call__(
11691170

11701171
elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
11711172
batch, seq_len, num_heads, head_dim = query.shape
1172-
updated_kv, attn_out = self.forward_serve_vllm(
1173+
attn_out, updated_kv = self.forward_serve_vllm(
11731174
query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
11741175
)
11751176
out = attn_out.reshape(batch, seq_len, num_heads, head_dim)

src/maxtext/layers/decoders.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -981,16 +981,58 @@ def __call__(
981981
"nope_layer_interval": self.config.nope_layer_interval,
982982
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
983983
}
984-
y, _ = self.scan_decoder_layers(
985-
cfg,
986-
RemattedBlockLayer,
987-
scan_length,
988-
"layers",
989-
mesh,
990-
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
991-
model_mode=model_mode,
992-
**layer_kwargs,
993-
)(y, *broadcast_args)
984+
# Update broadcast_args and in_axes_tuple for vLLM RPA
985+
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
986+
current_broadcast_args = list(broadcast_args)
987+
current_in_axes_tuple = list(in_axes_tuple)
988+
989+
if kv_caches is not None:
990+
# Stack kv_caches for scan: [num_layers, ...]
991+
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
992+
993+
# We pass (y, stacked_kv_cache, 0) as the carry
994+
carry = (y, stacked_kv_cache, 0)
995+
996+
# We don't pass kv_cache as a scanned argument anymore
997+
998+
# Pass None for previous_chunk, slot, page_state, kv_cache to align with __call__ signature
999+
current_broadcast_args.extend([None, None, None, None, attention_metadata])
1000+
current_in_axes_tuple.extend([nn.broadcast] * 5)
1001+
1002+
max_logging.info(f"DEBUG: len(current_broadcast_args)={len(current_broadcast_args)}")
1003+
max_logging.info(f"DEBUG: current_broadcast_args={[type(a) for a in current_broadcast_args]}")
1004+
1005+
final_carry, _ = self.scan_decoder_layers(
1006+
cfg,
1007+
RemattedBlockLayer,
1008+
scan_length,
1009+
"layers",
1010+
mesh,
1011+
in_axes_tuple=tuple(current_in_axes_tuple),
1012+
model_mode=model_mode,
1013+
**layer_kwargs,
1014+
)(carry, *current_broadcast_args)
1015+
1016+
y, returned_kv_cache, _ = final_carry
1017+
1018+
# Update the list of KV caches from the scanned results
1019+
for i in range(cfg.num_decoder_layers):
1020+
kv_caches[i] = returned_kv_cache[i]
1021+
else:
1022+
# Fallback to old behavior if kv_caches is None (not vLLM RPA)
1023+
current_broadcast_args.append(None)
1024+
current_in_axes_tuple.append(nn.broadcast)
1025+
1026+
y, _ = self.scan_decoder_layers(
1027+
cfg,
1028+
RemattedBlockLayer,
1029+
scan_length,
1030+
"layers",
1031+
mesh,
1032+
in_axes_tuple=tuple(current_in_axes_tuple),
1033+
model_mode=model_mode,
1034+
**layer_kwargs,
1035+
)(y, *current_broadcast_args)
9941036
else:
9951037
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
9961038
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."

src/maxtext/layers/nnx_decoders.py

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,23 @@ def pure_layer_fn(state_in, y_in):
428428

429429
return out
430430

431-
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
432-
"""Runs the layer stack using nnx.scan."""
431+
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches_stacked=None, **kwargs):
432+
"""Runs the layer stack using nnx.scan.
433+
434+
Args:
435+
layers: The stacked NNX module whose params are scanned over.
436+
x_in: The carry (hidden state) fed into the first layer.
437+
*args: Positional args broadcast to every layer call.
438+
length: Number of scan iterations (= number of layers).
439+
kv_caches_stacked: Optional list of per-layer KV caches (despite the name, it is not stacked).
440+
When provided, the layers are unrolled in a Python loop instead of jax.lax.scan, entry i is
441+
passed as `kv_cache=` to layer i, and the list is updated in place with each layer's result.
442+
**kwargs: Keyword args forwarded to the layer (filtered by the layer signature).
443+
444+
Returns:
445+
(final_carry, updated_layers) when kv_caches_stacked is None.
446+
(final_carry, layers, None) otherwise; updated caches are written back into kv_caches_stacked.
447+
"""
433448
policy = self.get_remat_policy()
434449
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
435450
graphdef, params, state = nnx.split(
@@ -450,35 +465,83 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
450465
# Filter kwargs to only include keys that exist in the layer's signature
451466
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
452467

468+
use_kv = kv_caches_stacked is not None
469+
453470
def layer_fn(carry, scanned_vars):
454471
# Unpack the sliced variables for THIS layer
455-
current_params, current_state = scanned_vars
472+
if use_kv:
473+
current_params, current_state, kv_cache_layer = scanned_vars
474+
else:
475+
current_params, current_state = scanned_vars
476+
kv_cache_layer = None
456477

457478
if self.config.parameter_memory_host_offload:
458479
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
459480

460481
# Merge using the SLICED state
461482
layer = nnx.merge(graphdef, current_params, current_state)
462483

463-
# Run the layer (Filter kwargs if using the solution from previous turn)
464-
layer_out = layer(carry, *args, **valid_kwargs)
484+
# Build call kwargs, injecting per-layer kv_cache when available
485+
call_kwargs = dict(valid_kwargs)
486+
if kv_cache_layer is not None:
487+
call_kwargs["kv_cache"] = kv_cache_layer
465488

466-
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
489+
layer_out = layer(carry, *args, **call_kwargs)
490+
491+
if isinstance(layer_out, tuple):
492+
new_carry = layer_out[0]
493+
updated_kv = layer_out[1] if len(layer_out) > 1 else None
494+
else:
495+
new_carry = layer_out
496+
updated_kv = None
467497

468498
# Extract the updated state to return it
469-
# _, new_current_state = nnx.split(layer, nnx.Param, ...)
470499
new_current_state = nnx.state(layer)
500+
501+
if use_kv:
502+
return new_carry, (new_current_state, updated_kv)
471503
return new_carry, new_current_state
472504

473505
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
474506

475-
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
507+
if use_kv:
508+
# If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan
509+
# because scanning requires stacking the kv_caches list, which creates a copy
510+
# and breaks the in-place memory updates required by vLLM's PagedAttention.
511+
# Therefore, we must unroll the loop statically when kv_caches is provided.
512+
513+
# kv_caches_stacked is actually the original kv_caches list in this new flow
514+
kv_caches_list = kv_caches_stacked
515+
516+
current_carry = x_in
517+
518+
for i in range(length):
519+
# Statically slice the parameters and state for this layer
520+
current_params = jax.tree.map(lambda x: x[i], params)
521+
current_state = jax.tree.map(lambda x: x[i], state)
522+
523+
# Call the layer
524+
current_carry, (new_current_state, updated_kv) = layer_fn(
525+
current_carry, (current_params, current_state, kv_caches_list[i])
526+
)
527+
528+
# Update the list in-place (mutates the list passed by reference)
529+
kv_caches_list[i] = updated_kv
530+
531+
# We don't need to rebuild scanned_state or return it because during
532+
# inference with vLLM, parameters do not change and we don't need intermediates.
533+
return current_carry, layers, None
534+
else:
535+
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
536+
returned_kv_stacked = None
476537

477538
if scan_axis != 0:
478539
scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
479540
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
480541
scanned_state = nnx.State.merge(scanned_params, scanned_other)
481542

543+
if use_kv:
544+
return final_carry, nnx.merge(graphdef, scanned_state), returned_kv_stacked
482545
return final_carry, nnx.merge(graphdef, scanned_state)
483546

484547
def get_decoder_layers(self):
@@ -1001,7 +1064,19 @@ def __call__(
10011064
)
10021065
else:
10031066
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
1004-
y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
1067+
if kv_caches is not None:
1068+
# Pass the kv_caches list directly to avoid copying in jnp.stack,
1069+
# which breaks vLLM PagedAttention in-place memory updates.
1070+
# The _apply_layers_sequentially function will handle it by statically unrolling.
1071+
y, self.layers, returned_kv = self._apply_layers_sequentially(
1072+
self.layers, y, *layer_args, length=scan_length,
1073+
kv_caches_stacked=kv_caches, **layer_kwargs
1074+
)
1075+
# kv_caches list is updated in-place inside _apply_layers_sequentially
1076+
else:
1077+
y, self.layers = self._apply_layers_sequentially(
1078+
self.layers, y, *layer_args, length=scan_length, **layer_kwargs
1079+
)
10051080
else:
10061081
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg)
10071082

src/maxtext/models/gemma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from maxtext.layers.linears import Dropout, MlpBlock
3131
from maxtext.layers.normalizations import RMSNorm
3232
from maxtext.layers.quantizations import AqtQuantization as Quant
33+
from maxtext.inference import page_manager
3334
from maxtext.utils import max_utils
3435

3536

@@ -126,8 +127,7 @@ def __call__(
126127
deterministic,
127128
model_mode,
128129
previous_chunk=None,
129-
page_manager=None,
130-
page_state=None,
130+
page_state: None | page_manager.PageState = None,
131131
slot=None,
132132
kv_cache=None,
133133
attention_metadata=None,

src/maxtext/models/gemma3.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,13 @@ def __call__(
194194
):
195195
cfg = self.config
196196
# Unpack inputs if it's a tuple: either a 3-element scan carry (hidden_states, stacked_kv_cache, layer_idx) or a previous layer's (hidden_states, kv_cache)
197-
if isinstance(inputs, tuple):
197+
is_scan_carry = False
198+
if isinstance(inputs, tuple) and len(inputs) == 3:
199+
hidden_states, stacked_kv_cache, layer_idx = inputs
200+
kv_cache = stacked_kv_cache[layer_idx]
201+
inputs = hidden_states
202+
is_scan_carry = True
203+
elif isinstance(inputs, tuple):
198204
inputs = inputs[0]
199205
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
200206
inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -244,7 +250,15 @@ def __call__(
244250
jnp.sum(layer_output == 0) / jnp.size(layer_output),
245251
)
246252

247-
if cfg.scan_layers:
253+
if is_scan_carry:
254+
def update_cache(cache, val):
255+
if jnp.size(val) > 0:
256+
return cache.at[layer_idx].set(val)
257+
return cache
258+
259+
stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
260+
return (layer_output, stacked_kv_cache, layer_idx + 1), None
261+
elif cfg.scan_layers:
248262
return layer_output, None
249263
else:
250264
return layer_output, kv_cache

src/maxtext/models/gemma4.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,13 @@ def __call__(
322322
):
323323
cfg = self.config
324324
# Unpack inputs if it's a tuple: either a 3-element scan carry (hidden_states, stacked_kv_cache, layer_idx) or a previous layer's (hidden_states, kv_cache)
325-
if isinstance(inputs, tuple):
325+
is_scan_carry = False
326+
if isinstance(inputs, tuple) and len(inputs) == 3:
327+
hidden_states, stacked_kv_cache, layer_idx = inputs
328+
kv_cache = stacked_kv_cache[layer_idx]
329+
inputs = hidden_states
330+
is_scan_carry = True
331+
elif isinstance(inputs, tuple):
326332
inputs = inputs[0]
327333
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
328334
inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -383,7 +389,15 @@ def __call__(
383389
jnp.sum(layer_output == 0) / jnp.size(layer_output),
384390
)
385391

386-
if cfg.scan_layers:
392+
if is_scan_carry:
393+
def update_cache(cache, val):
394+
if jnp.size(val) > 0:
395+
return cache.at[layer_idx].set(val)
396+
return cache
397+
398+
stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
399+
return (layer_output, stacked_kv_cache, layer_idx + 1), None
400+
elif cfg.scan_layers:
387401
return layer_output, None
388402
else:
389403
return layer_output, kv_cache

0 commit comments

Comments
 (0)