Skip to content

Commit 96215fd

Browse files
Merge pull request #3372 from AI-Hypercomputer:mohit/scan_vllm_serve
PiperOrigin-RevId: 907233362
2 parents 142bcf1 + 0ed706e commit 96215fd

17 files changed

Lines changed: 416 additions & 95 deletions

File tree

src/maxtext/inference/vllm_decode.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
from vllm import LLM
4949
from vllm.sampling_params import SamplingParams
5050
from maxtext.configs import pyconfig
51+
import maxtext.integration.vllm.maxtext_vllm_adapter as adapter
52+
53+
adapter.register()
5154

5255
os.environ["SKIP_JAX_PRECOMPILE"] = "1"
5356
os.environ["NEW_MODEL_DESIGN"] = "1"
@@ -83,6 +86,7 @@ def decode_with_vllm(config: Config) -> None:
8386
"allow_split_physical_axes": True,
8487
"debug_sharding": config.debug_sharding,
8588
"prefuse_moe_weights": config.prefuse_moe_weights,
89+
"scan_layers": config.scan_layers,
8690
},
8791
"sharding": {
8892
"sharding_strategy": {
@@ -141,11 +145,15 @@ def decode_with_vllm(config: Config) -> None:
141145
f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
142146
)
143147

148+
# MaxText uses -1 to mean "disabled"; vLLM requires top_p in (0, 1].
149+
top_p = config.decode_sampling_nucleus_p if config.decode_sampling_nucleus_p > 0 else 1.0
150+
top_k = config.decode_sampling_top_k if config.decode_sampling_top_k > 0 else -1
151+
144152
sampling_params = SamplingParams(
145153
temperature=config.decode_sampling_temperature,
146154
max_tokens=max_tokens_to_generate,
147-
top_k=config.decode_sampling_top_k,
148-
top_p=config.decode_sampling_nucleus_p,
155+
top_k=top_k,
156+
top_p=top_p,
149157
)
150158

151159
outputs = llm.generate(prompts, sampling_params)

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ class MaxTextForCausalLM(nnx.Module):
104104
of the decoding step.
105105
"""
106106

107+
# Signal to tpu-inference model_loader that this class manages its own
108+
# JIT-sharded initialization (via create_nnx_model with out_shardings).
109+
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
110+
_self_manages_sharding: bool = True
111+
107112
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
108113
"""Initializes the MaxTextForCausalLM model.
109114

src/maxtext/layers/attentions.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,7 @@ def forward_serve_vllm(
981981
value: Array,
982982
rpa_kv_cache: list[Array] | None = None,
983983
rpa_metadata: dict[str, Any] | None = None,
984-
) -> tuple[list[Array], Array]:
984+
) -> tuple[Array, list[Array]]:
985985
"""Forward function for vLLM serving with RPA attention."""
986986
try:
987987
# pylint: disable=import-outside-toplevel
@@ -992,13 +992,14 @@ def forward_serve_vllm(
992992
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
993993
) from e
994994

995-
if rpa_kv_cache is None or rpa_metadata is None:
996-
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
997-
998995
query = query.reshape(-1, query.shape[2], query.shape[3])
999996
key = key.reshape(-1, key.shape[2], key.shape[3])
1000997
value = value.reshape(-1, value.shape[2], value.shape[3])
1001998

999+
if rpa_kv_cache is None or rpa_metadata is None:
1000+
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
1001+
return query, []
1002+
10021003
if self.config.sliding_window_size > 0:
10031004
attention_chunk_size = self.config.sliding_window_size
10041005
else:
@@ -1026,7 +1027,7 @@ def forward_serve_vllm(
10261027
k_scale,
10271028
v_scale,
10281029
)
1029-
return kv_cache, output
1030+
return output, kv_cache
10301031

10311032
def __call__(
10321033
self,
@@ -1169,7 +1170,7 @@ def __call__(
11691170

11701171
elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
11711172
batch, seq_len, num_heads, head_dim = query.shape
1172-
updated_kv, attn_out = self.forward_serve_vllm(
1173+
attn_out, updated_kv = self.forward_serve_vllm(
11731174
query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
11741175
)
11751176
out = attn_out.reshape(batch, seq_len, num_heads, head_dim)

src/maxtext/layers/decoders.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -989,16 +989,58 @@ def __call__(
989989
"nope_layer_interval": self.config.nope_layer_interval,
990990
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
991991
}
992-
y, _ = self.scan_decoder_layers(
993-
cfg,
994-
RemattedBlockLayer,
995-
scan_length,
996-
"layers",
997-
mesh,
998-
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
999-
model_mode=model_mode,
1000-
**layer_kwargs,
1001-
)(y, *broadcast_args)
992+
# Update broadcast_args and in_axes_tuple for vLLM RPA
993+
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
994+
current_broadcast_args = list(broadcast_args)
995+
current_in_axes_tuple = list(in_axes_tuple)
996+
997+
if kv_caches is not None:
998+
# Stack kv_caches for scan: [num_layers, ...]
999+
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
1000+
1001+
# We pass (y, stacked_kv_cache, 0) as the carry
1002+
carry = (y, stacked_kv_cache, 0)
1003+
1004+
# We don't pass kv_cache as a scanned argument anymore
1005+
1006+
# Pass None for previous_chunk, slot, page_state, kv_cache to align with __call__ signature
1007+
current_broadcast_args.extend([None, None, None, None, attention_metadata])
1008+
current_in_axes_tuple.extend([nn.broadcast] * 5)
1009+
1010+
max_logging.info(f"DEBUG: len(current_broadcast_args)={len(current_broadcast_args)}")
1011+
max_logging.info(f"DEBUG: current_broadcast_args={[type(a) for a in current_broadcast_args]}")
1012+
1013+
final_carry, _ = self.scan_decoder_layers(
1014+
cfg,
1015+
RemattedBlockLayer,
1016+
scan_length,
1017+
"layers",
1018+
mesh,
1019+
in_axes_tuple=tuple(current_in_axes_tuple),
1020+
model_mode=model_mode,
1021+
**layer_kwargs,
1022+
)(carry, *current_broadcast_args)
1023+
1024+
y, returned_kv_cache, _ = final_carry
1025+
1026+
# Update the list of KV caches from the scanned results
1027+
for i in range(cfg.num_decoder_layers):
1028+
kv_caches[i] = returned_kv_cache[i]
1029+
else:
1030+
# Fallback to old behavior if kv_caches is None (not vLLM RPA)
1031+
current_broadcast_args.append(None)
1032+
current_in_axes_tuple.append(nn.broadcast)
1033+
1034+
y, _ = self.scan_decoder_layers(
1035+
cfg,
1036+
RemattedBlockLayer,
1037+
scan_length,
1038+
"layers",
1039+
mesh,
1040+
in_axes_tuple=tuple(current_in_axes_tuple),
1041+
model_mode=model_mode,
1042+
**layer_kwargs,
1043+
)(y, *current_broadcast_args)
10021044
else:
10031045
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
10041046
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."

src/maxtext/layers/nnx_decoders.py

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,23 @@ def pure_layer_fn(state_in, y_in):
428428

429429
return out
430430

431-
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
432-
"""Runs the layer stack using nnx.scan."""
431+
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches_stacked=None, **kwargs):
432+
"""Runs the layer stack using nnx.scan.
433+
434+
Args:
435+
layers: The stacked NNX module whose params are scanned over.
436+
x_in: The carry (hidden state) fed into the first layer.
437+
*args: Positional args broadcast to every layer call.
438+
length: Number of scan iterations (= number of layers).
439+
kv_caches_stacked: Optional list of per-layer KV caches (one entry per layer).
440+
When provided, entry i is passed as `kv_cache=` to layer i, the layers are
441+
unrolled in a Python loop, and the list is updated in place.
442+
**kwargs: Keyword args forwarded to the layer (filtered by the layer signature).
443+
444+
Returns:
445+
(final_carry, updated_layers, None) when kv_caches_stacked is None.
446+
(final_carry, layers, None) otherwise; updated caches are written back into the list.
447+
"""
433448
policy = self.get_remat_policy()
434449
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
435450
graphdef, params, state = nnx.split(
@@ -450,36 +465,80 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
450465
# Filter kwargs to only include keys that exist in the layer's signature
451466
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
452467

468+
use_kv = kv_caches_stacked is not None
469+
453470
def layer_fn(carry, scanned_vars):
454471
# Unpack the sliced variables for THIS layer
455-
current_params, current_state = scanned_vars
472+
if use_kv:
473+
current_params, current_state, kv_cache_layer = scanned_vars
474+
else:
475+
current_params, current_state = scanned_vars
476+
kv_cache_layer = None
456477

457478
if self.config.parameter_memory_host_offload:
458479
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
459480

460481
# Merge using the SLICED state
461482
layer = nnx.merge(graphdef, current_params, current_state)
462483

463-
# Run the layer (Filter kwargs if using the solution from previous turn)
464-
layer_out = layer(carry, *args, **valid_kwargs)
484+
# Build call kwargs, injecting per-layer kv_cache when available
485+
call_kwargs = dict(valid_kwargs)
486+
if kv_cache_layer is not None:
487+
call_kwargs["kv_cache"] = kv_cache_layer
488+
489+
layer_out = layer(carry, *args, **call_kwargs)
465490

466-
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
491+
if isinstance(layer_out, tuple):
492+
new_carry = layer_out[0]
493+
updated_kv = layer_out[1] if len(layer_out) > 1 else None
494+
else:
495+
new_carry = layer_out
496+
updated_kv = None
467497

468498
# Extract the updated state to return it
469-
# _, new_current_state = nnx.split(layer, nnx.Param, ...)
470499
new_current_state = nnx.state(layer)
500+
501+
if use_kv:
502+
return new_carry, (new_current_state, updated_kv)
471503
return new_carry, new_current_state
472504

473505
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
474506

475-
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
507+
if use_kv:
508+
# If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan
509+
# because scanning requires stacking the kv_caches list, which creates a copy
510+
# and breaks the in-place memory updates required by vLLM's PagedAttention.
511+
# Therefore, we must unroll the loop statically when kv_caches is provided.
512+
513+
# kv_caches_stacked is actually the original kv_caches list in this new flow
514+
kv_caches_list = kv_caches_stacked
515+
516+
current_carry = x_in
517+
518+
for i in range(length):
519+
# Statically slice the parameters and state for this layer
520+
current_params = jax.tree.map(lambda x, i=i: x[i], params)
521+
current_state = jax.tree.map(lambda x, i=i: x[i], state)
522+
523+
# Call the layer
524+
current_carry, (_, updated_kv) = layer_fn(current_carry, (current_params, current_state, kv_caches_list[i]))
525+
526+
# Update the list in-place (mutates the list passed by reference)
527+
kv_caches_list[i] = updated_kv
528+
529+
# We don't need to rebuild scanned_state or return it because during
530+
# inference with vLLM, parameters do not change and we don't need intermediates.
531+
return current_carry, layers, None
532+
else:
533+
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
534+
returned_kv_stacked = None
476535

477536
if scan_axis != 0:
478537
scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
479538
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
480539
scanned_state = nnx.State.merge(scanned_params, scanned_other)
481540

482-
return final_carry, nnx.merge(graphdef, scanned_state)
541+
return final_carry, nnx.merge(graphdef, scanned_state), returned_kv_stacked if use_kv else None
483542

484543
def get_decoder_layers(self):
485544
"""Retrieves decoder layer classes based on config using a dictionary lookup."""
@@ -859,7 +918,7 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args
859918
chunk_stack = nnx.merge(graphdef, chunk_state)
860919

861920
# Apply sequentially
862-
y, chunk_stack = self._apply_layers_sequentially(
921+
y, chunk_stack, _ = self._apply_layers_sequentially(
863922
chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {})
864923
)
865924

@@ -966,7 +1025,7 @@ def __call__(
9661025
**common_kwargs,
9671026
)
9681027
else:
969-
y, self.dense_layers = self._apply_layers_sequentially(
1028+
y, self.dense_layers, _ = self._apply_layers_sequentially(
9701029
self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs
9711030
)
9721031

@@ -984,7 +1043,7 @@ def __call__(
9841043
num_layers=num_moe,
9851044
)
9861045
else:
987-
y, self.moe_layer = self._apply_layers_sequentially(
1046+
y, self.moe_layer, _ = self._apply_layers_sequentially(
9881047
self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs
9891048
)
9901049
elif self.is_gemma3:
@@ -1001,7 +1060,18 @@ def __call__(
10011060
)
10021061
else:
10031062
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
1004-
y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
1063+
if kv_caches is not None:
1064+
# Pass the kv_caches list directly to avoid copying in jnp.stack,
1065+
# which breaks vLLM PagedAttention in-place memory updates.
1066+
# The _apply_layers_sequentially function will handle it by statically unrolling.
1067+
y, self.layers, _ = self._apply_layers_sequentially(
1068+
self.layers, y, *layer_args, length=scan_length, kv_caches_stacked=kv_caches, **layer_kwargs
1069+
)
1070+
# kv_caches list is updated in-place inside _apply_layers_sequentially
1071+
else:
1072+
y, self.layers, _ = self._apply_layers_sequentially(
1073+
self.layers, y, *layer_args, length=scan_length, **layer_kwargs
1074+
)
10051075
else:
10061076
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg)
10071077

@@ -1085,7 +1155,7 @@ def _apply_gemma3_scanned_blocks(
10851155

10861156
# Apply the main scan over the full blocks
10871157
if scan_length > 0:
1088-
y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
1158+
y, self.layers, _ = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
10891159

10901160
# Apply any remaining layers that did not fit into a full scanned block
10911161
num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length

src/maxtext/models/gemma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from maxtext.layers.linears import Dropout, MlpBlock
3131
from maxtext.layers.normalizations import RMSNorm
3232
from maxtext.layers.quantizations import AqtQuantization as Quant
33+
from maxtext.inference import page_manager
3334
from maxtext.utils import max_utils
3435

3536

@@ -126,8 +127,7 @@ def __call__(
126127
deterministic,
127128
model_mode,
128129
previous_chunk=None,
129-
page_manager=None,
130-
page_state=None,
130+
page_state: None | page_manager.PageState = None,
131131
slot=None,
132132
kv_cache=None,
133133
attention_metadata=None,

src/maxtext/models/gemma3.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,13 @@ def __call__(
194194
):
195195
cfg = self.config
196196
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
197-
if isinstance(inputs, tuple):
197+
is_scan_carry = False
198+
if isinstance(inputs, tuple) and len(inputs) == 3:
199+
hidden_states, stacked_kv_cache, layer_idx = inputs
200+
kv_cache = stacked_kv_cache[layer_idx]
201+
inputs = hidden_states
202+
is_scan_carry = True
203+
elif isinstance(inputs, tuple):
198204
inputs = inputs[0]
199205
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
200206
inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -244,7 +250,16 @@ def __call__(
244250
jnp.sum(layer_output == 0) / jnp.size(layer_output),
245251
)
246252

247-
if cfg.scan_layers:
253+
if is_scan_carry:
254+
255+
def update_cache(cache, val):
256+
if jnp.size(val) > 0:
257+
return cache.at[layer_idx].set(val)
258+
return cache
259+
260+
stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
261+
return (layer_output, stacked_kv_cache, layer_idx + 1), None
262+
elif cfg.scan_layers:
248263
return layer_output, None
249264
else:
250265
return layer_output, kv_cache

0 commit comments

Comments
 (0)