Skip to content

Commit ca03ea7

Browse files
committed
pass kv_cache through scanned decoder layers
1 parent 47c6d0c commit ca03ea7

9 files changed

Lines changed: 65 additions & 55 deletions

File tree

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ class MaxTextForCausalLM(nnx.Module):
8686
of the decoding step.
8787
"""
8888

89+
# Signal to tpu-inference model_loader that this class manages its own
90+
# JIT-sharded initialization (via create_nnx_model with out_shardings).
91+
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
92+
_self_manages_sharding: bool = True
93+
8994
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
9095
"""Initializes the MaxTextForCausalLM model.
9196
@@ -232,7 +237,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
232237
if self.model is not None:
233238
return
234239

235-
with self.mesh, nn.logical_axis_rules(""):
240+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
236241
model, _ = model_creation_utils.create_nnx_model(
237242
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
238243
)

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -956,13 +956,14 @@ def forward_serve_vllm(
956956
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
957957
) from e
958958

959-
if rpa_kv_cache is None or rpa_metadata is None:
960-
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
961-
962959
query = query.reshape(-1, query.shape[2], query.shape[3])
963960
key = key.reshape(-1, key.shape[2], key.shape[3])
964961
value = value.reshape(-1, value.shape[2], value.shape[3])
965962

963+
if rpa_kv_cache is None or rpa_metadata is None:
964+
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
965+
return [], query
966+
966967
if self.config.sliding_window_size > 0:
967968
attention_chunk_size = self.config.sliding_window_size
968969
else:

src/maxtext/layers/decoders.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,11 @@ def __call__(
792792
decoder_positions,
793793
deterministic,
794794
model_mode,
795+
previous_chunk,
796+
page_state,
797+
slot,
795798
)
799+
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
796800
if cfg.using_pipeline_parallelism:
797801
if cfg.pipeline_fsdp_ag_once:
798802
logical_partition_spec = self.pipeline_module.get_weight_sharding(
@@ -847,26 +851,12 @@ def __call__(
847851
else:
848852
if cfg.scan_layers:
849853
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
850-
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
851-
layer_call_kwargs = {
852-
"page_state": page_state,
853-
"previous_chunk": previous_chunk,
854-
"slot": slot,
855-
}
856854
dense_layer = RemattedBlockLayers[0]
857855
moe_layer = RemattedBlockLayers[1]
858856
if cfg.engram_layers:
859-
original_dense_call = dense_layer.__call__
860-
original_moe_call = moe_layer.__call__
861-
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
862-
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
863-
864857
common_kwargs = {
865858
"dense_layer": dense_layer,
866859
"moe_layer": moe_layer,
867-
"original_dense_call": original_dense_call,
868-
"original_moe_call": original_moe_call,
869-
"layer_call_kwargs": layer_call_kwargs,
870860
"decoder_segment_ids": decoder_segment_ids,
871861
"decoder_positions": decoder_positions,
872862
"deterministic": deterministic,
@@ -895,7 +885,6 @@ def __call__(
895885
**common_kwargs,
896886
)
897887
else:
898-
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
899888
y, _ = self.scan_decoder_layers(
900889
cfg,
901890
dense_layer,
@@ -905,7 +894,6 @@ def __call__(
905894
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
906895
model_mode=model_mode,
907896
)(y, *broadcast_args)
908-
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
909897
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
910898

911899
# If batch-split schedule is used and initialization is complete,
@@ -954,16 +942,38 @@ def __call__(
954942
"nope_layer_interval": self.config.nope_layer_interval,
955943
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
956944
}
957-
y, _ = self.scan_decoder_layers(
945+
946+
# Update broadcast_args and in_axes_tuple for vLLM RPA
947+
current_broadcast_args = list(broadcast_args)
948+
current_in_axes_tuple = list(in_axes_tuple)
949+
950+
if kv_caches is not None:
951+
# Stack kv_caches for scan: [num_layers, ...]
952+
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
953+
current_broadcast_args.append(stacked_kv_cache)
954+
current_in_axes_tuple.append(0) # Scan over the layer dimension
955+
else:
956+
current_broadcast_args.append(None)
957+
current_in_axes_tuple.append(nn.broadcast)
958+
959+
current_broadcast_args.append(attention_metadata)
960+
current_in_axes_tuple.append(nn.broadcast)
961+
962+
y, returned_kv_cache = self.scan_decoder_layers(
958963
cfg,
959964
RemattedBlockLayer,
960965
scan_length,
961966
"layers",
962967
mesh,
963-
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
968+
in_axes_tuple=tuple(current_in_axes_tuple),
964969
model_mode=model_mode,
965970
**layer_kwargs,
966-
)(y, *broadcast_args)
971+
)(y, *current_broadcast_args)
972+
973+
if kv_caches is not None and returned_kv_cache is not None:
974+
# Update the list of KV caches from the scanned results
975+
for i, cache in enumerate(returned_kv_cache):
976+
kv_caches[i] = cache
967977
else:
968978
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
969979
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
@@ -1173,10 +1183,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
11731183
"""Applies a single, unscanned Engram layer."""
11741184
layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
11751185
layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1176-
original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"]
1177-
layer_call_kwargs = kwargs["layer_call_kwargs"]
1186+
broadcast_args = kwargs["broadcast_args"]
11781187

1179-
layer.__call__ = original_call
11801188
y, _ = layer(
11811189
config=self.config,
11821190
mesh=self.mesh,
@@ -1186,14 +1194,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
11861194
layer_idx=current_idx,
11871195
)(
11881196
y,
1189-
kwargs["decoder_segment_ids"],
1190-
kwargs["decoder_positions"],
1191-
kwargs["deterministic"],
1192-
kwargs["model_mode"],
1197+
*broadcast_args,
11931198
decoder_input_tokens=kwargs["decoder_input_tokens"],
1194-
**layer_call_kwargs,
11951199
)
1196-
layer.__call__ = functools.partial(original_call, **layer_call_kwargs)
11971200
return y
11981201

11991202
def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs):

src/maxtext/models/gemma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from maxtext.layers.linears import Dropout, MlpBlock
3131
from maxtext.layers.normalizations import RMSNorm
3232
from maxtext.layers.quantizations import AqtQuantization as Quant
33+
from maxtext.inference import page_manager
3334
from maxtext.utils import max_utils
3435

3536

@@ -126,8 +127,7 @@ def __call__(
126127
deterministic,
127128
model_mode,
128129
previous_chunk=None,
129-
page_manager=None,
130-
page_state=None,
130+
page_state: None | page_manager.PageState = None,
131131
slot=None,
132132
kv_cache=None,
133133
attention_metadata=None,

src/maxtext/models/gpt_oss.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from maxtext.layers.attentions import Attention
3535
from maxtext.layers.normalizations import RMSNorm
3636
from maxtext.layers.quantizations import AqtQuantization as Quant
37+
from maxtext.inference import page_manager
3738
from maxtext.utils import max_utils
3839

3940
# -----------------------------------------
@@ -138,7 +139,7 @@ def __call__(
138139
deterministic,
139140
model_mode,
140141
previous_chunk=None,
141-
page_state=None,
142+
page_state: None | page_manager.PageState = None,
142143
slot=None,
143144
kv_cache=None,
144145
attention_metadata=None,
@@ -258,6 +259,11 @@ def __call__(
258259
decoder_positions,
259260
deterministic,
260261
model_mode,
262+
previous_chunk=None,
263+
page_state: None | page_manager.PageState = None,
264+
slot=None,
265+
kv_cache=None,
266+
attention_metadata=None,
261267
):
262268
cfg = self.config
263269

@@ -267,19 +273,19 @@ def __call__(
267273
for layer_id in range(cfg.inhomogeneous_layer_cycle_interval):
268274
layer_name = f"layers_{layer_id}"
269275
layer = getattr(self, layer_name)
270-
y = layer(
276+
y, kv_cache = layer(
271277
y,
272278
decoder_segment_ids,
273279
decoder_positions,
274280
deterministic,
275281
model_mode,
282+
previous_chunk=previous_chunk,
283+
page_state=page_state,
284+
slot=slot,
285+
kv_cache=kv_cache,
286+
attention_metadata=attention_metadata,
276287
)
277-
if cfg.scan_layers:
278-
y = y[0]
279-
if cfg.scan_layers:
280-
return y, None
281-
else:
282-
return y
288+
return y, kv_cache
283289

284290

285291
GptOssScannableBlockToLinen = nnx_wrappers.to_linen_class(

src/maxtext/models/llama2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ def __call__(
143143
decoder_positions,
144144
deterministic,
145145
model_mode,
146+
previous_chunk=None,
146147
slot: None | int = None,
147148
page_state: None | page_manager.PageState = None,
148-
previous_chunk=None,
149149
kv_cache=None,
150150
attention_metadata=None,
151151
):

src/maxtext/models/llama4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,9 @@ def __call__(
442442
decoder_positions,
443443
deterministic,
444444
model_mode,
445+
previous_chunk=None,
445446
slot: None | int = None,
446447
page_state: None | page_manager.PageState = None,
447-
previous_chunk=None,
448448
kv_cache=None,
449449
attention_metadata=None,
450450
):
@@ -570,9 +570,9 @@ def __call__(
570570
decoder_positions,
571571
deterministic,
572572
model_mode,
573+
previous_chunk=None,
573574
slot: None | int = None,
574575
page_state: None | page_manager.PageState = None,
575-
previous_chunk=None,
576576
):
577577

578578
cfg = self.config

src/maxtext/models/mistral.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax.numpy as jnp
2323
from jax.sharding import Mesh
2424
from maxtext.common.common_types import Config
25+
from maxtext.inference import page_manager
2526
from maxtext.layers import initializers, nnx_wrappers
2627
from maxtext.layers import quantizations
2728
from maxtext.layers.attentions import Attention
@@ -126,9 +127,9 @@ def __call__(
126127
decoder_positions,
127128
deterministic,
128129
model_mode,
129-
page_state: None | int = None,
130-
slot: None | int = None,
131130
previous_chunk=None,
131+
slot: None | int = None,
132+
page_state: None | page_manager.PageState = None,
132133
kv_cache=None,
133134
attention_metadata=None,
134135
):

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,10 +1235,7 @@ def __call__(
12351235
layer_output = intermediate_inputs + mlp_lnx
12361236
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
12371237

1238-
if self.config.scan_layers:
1239-
return layer_output, None
1240-
else:
1241-
return layer_output, kv_cache
1238+
return layer_output, kv_cache
12421239

12431240

12441241
# -----------------------------------------
@@ -1304,10 +1301,7 @@ def __call__(
13041301
layer_output = intermediate_inputs + mlp_lnx
13051302
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
13061303

1307-
if self.config.scan_layers:
1308-
return layer_output, None
1309-
else:
1310-
return layer_output, kv_cache
1304+
return layer_output, kv_cache
13111305

13121306

13131307
class Qwen3OmniMoeVisionPatchMerger(nnx.Module):

0 commit comments

Comments
 (0)