Skip to content

Commit ba92ee5

Browse files
committed
fix(vllm): Fix PagedAttention memory aliasing and unrolled loop compilation for scan_layers=True
1 parent 21c4433 commit ba92ee5

12 files changed

Lines changed: 206 additions & 58 deletions

File tree

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ class MaxTextForCausalLM(nnx.Module):
104104
of the decoding step.
105105
"""
106106

107+
# Signal to tpu-inference model_loader that this class manages its own
108+
# JIT-sharded initialization (via create_nnx_model with out_shardings).
109+
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
110+
_self_manages_sharding: bool = True
111+
107112
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
108113
"""Initializes the MaxTextForCausalLM model.
109114
@@ -250,7 +255,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
250255
if self.model is not None:
251256
return
252257

253-
with self.mesh, nn.logical_axis_rules(""):
258+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
254259
model, _ = model_creation_utils.create_nnx_model(
255260
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
256261
)

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -992,13 +992,14 @@ def forward_serve_vllm(
992992
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
993993
) from e
994994

995-
if rpa_kv_cache is None or rpa_metadata is None:
996-
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
997-
998995
query = query.reshape(-1, query.shape[2], query.shape[3])
999996
key = key.reshape(-1, key.shape[2], key.shape[3])
1000997
value = value.reshape(-1, value.shape[2], value.shape[3])
1001998

999+
if rpa_kv_cache is None or rpa_metadata is None:
1000+
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
1001+
return [], query
1002+
10021003
if self.config.sliding_window_size > 0:
10031004
attention_chunk_size = self.config.sliding_window_size
10041005
else:

src/maxtext/layers/nnx_decoders.py

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,23 @@ def pure_layer_fn(state_in, y_in):
428428

429429
return out
430430

431-
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs):
432-
"""Runs the layer stack using nnx.scan."""
431+
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches_stacked=None, **kwargs):
432+
"""Runs the layer stack using nnx.scan.
433+
434+
Args:
435+
layers: The stacked NNX module whose params are scanned over.
436+
x_in: The carry (hidden state) fed into the first layer.
437+
*args: Positional args broadcast to every layer call.
438+
length: Number of scan iterations (= number of layers).
439+
kv_caches_stacked: Optional pytree whose leaves have shape [num_layers, ...].
440+
When provided, the i-th slice is passed as `kv_cache=` to layer i and the
441+
updated caches are returned as a third element of the tuple.
442+
**kwargs: Keyword args forwarded to the layer (filtered by the layer signature).
443+
444+
Returns:
445+
(final_carry, updated_layers) when kv_caches_stacked is None.
446+
(final_carry, updated_layers, returned_kv_stacked) otherwise.
447+
"""
433448
policy = self.get_remat_policy()
434449
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
435450
graphdef, params, state = nnx.split(
@@ -450,35 +465,83 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
450465
# Filter kwargs to only include keys that exist in the layer's signature
451466
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
452467

468+
use_kv = kv_caches_stacked is not None
469+
453470
def layer_fn(carry, scanned_vars):
454471
# Unpack the sliced variables for THIS layer
455-
current_params, current_state = scanned_vars
472+
if use_kv:
473+
current_params, current_state, kv_cache_layer = scanned_vars
474+
else:
475+
current_params, current_state = scanned_vars
476+
kv_cache_layer = None
456477

457478
if self.config.parameter_memory_host_offload:
458479
current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
459480

460481
# Merge using the SLICED state
461482
layer = nnx.merge(graphdef, current_params, current_state)
462483

463-
# Run the layer (Filter kwargs if using the solution from previous turn)
464-
layer_out = layer(carry, *args, **valid_kwargs)
484+
# Build call kwargs, injecting per-layer kv_cache when available
485+
call_kwargs = dict(valid_kwargs)
486+
if kv_cache_layer is not None:
487+
call_kwargs["kv_cache"] = kv_cache_layer
465488

466-
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
489+
layer_out = layer(carry, *args, **call_kwargs)
490+
491+
if isinstance(layer_out, tuple):
492+
new_carry = layer_out[0]
493+
updated_kv = layer_out[1] if len(layer_out) > 1 else None
494+
else:
495+
new_carry = layer_out
496+
updated_kv = None
467497

468498
# Extract the updated state to return it
469-
# _, new_current_state = nnx.split(layer, nnx.Param, ...)
470499
new_current_state = nnx.state(layer)
500+
501+
if use_kv:
502+
return new_carry, (new_current_state, updated_kv)
471503
return new_carry, new_current_state
472504

473505
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
474506

475-
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
507+
if use_kv:
508+
# If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan
509+
# because scanning requires stacking the kv_caches list, which creates a copy
510+
# and breaks the in-place memory updates required by vLLM's PagedAttention.
511+
# Therefore, we must unroll the loop statically when kv_caches is provided.
512+
513+
# kv_caches_stacked is actually the original kv_caches list in this new flow
514+
kv_caches_list = kv_caches_stacked
515+
516+
current_carry = x_in
517+
518+
for i in range(length):
519+
# Statically slice the parameters and state for this layer
520+
current_params = jax.tree.map(lambda x: x[i], params)
521+
current_state = jax.tree.map(lambda x: x[i], state)
522+
523+
# Call the layer
524+
current_carry, (new_current_state, updated_kv) = layer_fn(
525+
current_carry, (current_params, current_state, kv_caches_list[i])
526+
)
527+
528+
# Update the list in-place (mutates the list passed by reference)
529+
kv_caches_list[i] = updated_kv
530+
531+
# We don't need to rebuild scanned_state or return it because during
532+
# inference with vLLM, parameters do not change and we don't need intermediates.
533+
return current_carry, layers, None
534+
else:
535+
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
536+
returned_kv_stacked = None
476537

477538
if scan_axis != 0:
478539
scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
479540
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
480541
scanned_state = nnx.State.merge(scanned_params, scanned_other)
481542

543+
if use_kv:
544+
return final_carry, nnx.merge(graphdef, scanned_state), returned_kv_stacked
482545
return final_carry, nnx.merge(graphdef, scanned_state)
483546

484547
def get_decoder_layers(self):
@@ -1001,7 +1064,19 @@ def __call__(
10011064
)
10021065
else:
10031066
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
1004-
y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
1067+
if kv_caches is not None:
1068+
# Pass the kv_caches list directly to avoid copying in jnp.stack,
1069+
# which breaks vLLM PagedAttention in-place memory updates.
1070+
# The _apply_layers_sequentially function will handle it by statically unrolling.
1071+
y, self.layers, returned_kv = self._apply_layers_sequentially(
1072+
self.layers, y, *layer_args, length=scan_length,
1073+
kv_caches_stacked=kv_caches, **layer_kwargs
1074+
)
1075+
# kv_caches list is updated in-place inside _apply_layers_sequentially
1076+
else:
1077+
y, self.layers = self._apply_layers_sequentially(
1078+
self.layers, y, *layer_args, length=scan_length, **layer_kwargs
1079+
)
10051080
else:
10061081
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg)
10071082

src/maxtext/models/gemma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from maxtext.layers.linears import Dropout, MlpBlock
3131
from maxtext.layers.normalizations import RMSNorm
3232
from maxtext.layers.quantizations import AqtQuantization as Quant
33+
from maxtext.inference import page_manager
3334
from maxtext.utils import max_utils
3435

3536

@@ -126,8 +127,7 @@ def __call__(
126127
deterministic,
127128
model_mode,
128129
previous_chunk=None,
129-
page_manager=None,
130-
page_state=None,
130+
page_state: None | page_manager.PageState = None,
131131
slot=None,
132132
kv_cache=None,
133133
attention_metadata=None,

src/maxtext/models/gpt_oss.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from maxtext.layers.attentions import Attention
3535
from maxtext.layers.normalizations import RMSNorm
3636
from maxtext.layers.quantizations import AqtQuantization as Quant
37+
from maxtext.inference import page_manager
3738
from maxtext.utils import max_utils
3839

3940
# -----------------------------------------
@@ -138,7 +139,7 @@ def __call__(
138139
deterministic,
139140
model_mode,
140141
previous_chunk=None,
141-
page_state=None,
142+
page_state: None | page_manager.PageState = None,
142143
slot=None,
143144
kv_cache=None,
144145
attention_metadata=None,
@@ -258,6 +259,11 @@ def __call__(
258259
decoder_positions,
259260
deterministic,
260261
model_mode,
262+
previous_chunk=None,
263+
page_state: None | page_manager.PageState = None,
264+
slot=None,
265+
kv_cache=None,
266+
attention_metadata=None,
261267
):
262268
cfg = self.config
263269

@@ -267,19 +273,19 @@ def __call__(
267273
for layer_id in range(cfg.inhomogeneous_layer_cycle_interval):
268274
layer_name = f"layers_{layer_id}"
269275
layer = getattr(self, layer_name)
270-
y = layer(
276+
y, kv_cache = layer(
271277
y,
272278
decoder_segment_ids,
273279
decoder_positions,
274280
deterministic,
275281
model_mode,
282+
previous_chunk=previous_chunk,
283+
page_state=page_state,
284+
slot=slot,
285+
kv_cache=kv_cache,
286+
attention_metadata=attention_metadata,
276287
)
277-
if cfg.scan_layers:
278-
y = y[0]
279-
if cfg.scan_layers:
280-
return y, None
281-
else:
282-
return y
288+
return y, kv_cache
283289

284290

285291
GptOssScannableBlockToLinen = nnx_wrappers.to_linen_class(

src/maxtext/models/llama2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ def __call__(
143143
decoder_positions,
144144
deterministic,
145145
model_mode,
146+
previous_chunk=None,
146147
slot: None | int = None,
147148
page_state: None | page_manager.PageState = None,
148-
previous_chunk=None,
149149
kv_cache=None,
150150
attention_metadata=None,
151151
):

src/maxtext/models/llama4.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,9 @@ def __call__(
442442
decoder_positions,
443443
deterministic,
444444
model_mode,
445+
previous_chunk=None,
445446
slot: None | int = None,
446447
page_state: None | page_manager.PageState = None,
447-
previous_chunk=None,
448448
kv_cache=None,
449449
attention_metadata=None,
450450
):
@@ -570,9 +570,11 @@ def __call__(
570570
decoder_positions,
571571
deterministic,
572572
model_mode,
573+
previous_chunk=None,
573574
slot: None | int = None,
574575
page_state: None | page_manager.PageState = None,
575-
previous_chunk=None,
576+
kv_cache=None,
577+
attention_metadata=None,
576578
):
577579

578580
cfg = self.config
@@ -590,6 +592,8 @@ def __call__(
590592
previous_chunk=previous_chunk,
591593
page_state=page_state,
592594
slot=slot,
595+
kv_cache=kv_cache,
596+
attention_metadata=attention_metadata,
593597
)
594598
if cfg.scan_layers:
595599
y = y[0]

src/maxtext/models/mistral.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax.numpy as jnp
2323
from jax.sharding import Mesh
2424
from maxtext.common.common_types import Config
25+
from maxtext.inference import page_manager
2526
from maxtext.layers import initializers, nnx_wrappers
2627
from maxtext.layers import quantizations
2728
from maxtext.layers.attentions import Attention
@@ -126,9 +127,9 @@ def __call__(
126127
decoder_positions,
127128
deterministic,
128129
model_mode,
129-
page_state: None | int = None,
130-
slot: None | int = None,
131130
previous_chunk=None,
131+
slot: None | int = None,
132+
page_state: None | page_manager.PageState = None,
132133
kv_cache=None,
133134
attention_metadata=None,
134135
):

src/maxtext/models/olmo3.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,11 @@ def __call__(
267267
decoder_positions,
268268
deterministic,
269269
model_mode,
270+
previous_chunk=None,
271+
page_state=None,
272+
slot=None,
273+
kv_cache=None,
274+
attention_metadata=None,
270275
):
271276
cfg = self.config
272277

@@ -282,6 +287,11 @@ def __call__(
282287
decoder_positions,
283288
deterministic,
284289
model_mode,
290+
previous_chunk=previous_chunk,
291+
page_state=page_state,
292+
slot=slot,
293+
kv_cache=kv_cache,
294+
attention_metadata=attention_metadata,
285295
)
286296
if cfg.scan_layers:
287297
y = y[0]

src/maxtext/models/qwen3.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,8 @@ def __call__(
896896
previous_chunk=None,
897897
page_state: None | page_manager.PageState = None,
898898
slot: None | int = None,
899+
kv_cache=None,
900+
attention_metadata=None,
899901
) -> tuple[Array, None]:
900902
"""Applies the block of decoder layers to the input carry.
901903
@@ -924,6 +926,8 @@ def __call__(
924926
previous_chunk,
925927
page_state,
926928
slot,
929+
kv_cache=kv_cache,
930+
attention_metadata=attention_metadata,
927931
)
928932

929933
# The output of the block is the carry for the next scan iteration.
@@ -1235,10 +1239,7 @@ def __call__(
12351239
layer_output = intermediate_inputs + mlp_lnx
12361240
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
12371241

1238-
if self.config.scan_layers:
1239-
return layer_output, None
1240-
else:
1241-
return layer_output, kv_cache
1242+
return layer_output, kv_cache
12421243

12431244

12441245
# -----------------------------------------
@@ -1304,10 +1305,7 @@ def __call__(
13041305
layer_output = intermediate_inputs + mlp_lnx
13051306
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
13061307

1307-
if self.config.scan_layers:
1308-
return layer_output, None
1309-
else:
1310-
return layer_output, kv_cache
1308+
return layer_output, kv_cache
13111309

13121310

13131311
class Qwen3OmniMoeVisionPatchMerger(nnx.Module):

0 commit comments

Comments
 (0)