
Commit 0b434cf

pass kv_cache through scanned decoder layers
1 parent 47c6d0c commit 0b434cf

11 files changed

Lines changed: 100 additions & 58 deletions


src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 6 additions & 1 deletion
@@ -86,6 +86,11 @@ class MaxTextForCausalLM(nnx.Module):
       of the decoding step.
   """
 
+  # Signal to tpu-inference model_loader that this class manages its own
+  # JIT-sharded initialization (via create_nnx_model with out_shardings).
+  # When True, model_loader skips wrapping __init__ in an outer bare @jax.jit,
+  _self_manages_sharding: bool = True
+
   def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
     """Initializes the MaxTextForCausalLM model.

@@ -232,7 +237,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
     if self.model is not None:
       return
 
-    with self.mesh, nn.logical_axis_rules(""):
+    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
       model, _ = model_creation_utils.create_nnx_model(
           self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
       )
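Note on the second hunk: switching from nn.logical_axis_rules("") to the config's actual rules means the model is now created under the same logical-to-mesh mapping it will be sharded with. Below is a minimal, self-contained sketch of how Flax's logical_axis_rules context resolves logical axis names; the mesh axes and rule list here are illustrative placeholders, not MaxText's real configuration.

import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh

# Hypothetical mesh and rules, purely for illustration.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("data", "model"))
rules = (("batch", "data"), ("embed", "model"))

with mesh, nn.logical_axis_rules(rules):
  # Inside this context, logical names resolve to mesh axes; with empty rules
  # every logical axis would fall back to replicated placement.
  print(nn.logical_to_mesh_axes(("batch", "embed")))  # PartitionSpec('data', 'model')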

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 3 deletions
@@ -956,13 +956,14 @@ def forward_serve_vllm(
           "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
       ) from e
 
-    if rpa_kv_cache is None or rpa_metadata is None:
-      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
-
     query = query.reshape(-1, query.shape[2], query.shape[3])
     key = key.reshape(-1, key.shape[2], key.shape[3])
     value = value.reshape(-1, value.shape[2], value.shape[3])
 
+    if rpa_kv_cache is None or rpa_metadata is None:
+      # Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
+      return [], query
+
     if self.config.sliding_window_size > 0:
       attention_chunk_size = self.config.sliding_window_size
     else:
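This change turns a hard failure into a no-op: when the runner traces or initializes the model without a real KV cache, the op now returns placeholder outputs instead of raising. A minimal sketch of that guard pattern under hypothetical names (rpa_attention is not the actual MaxText function):

from typing import Any

import jax.numpy as jnp


def rpa_attention(query: jnp.ndarray, kv_cache: Any | None, metadata: Any | None):
  # Flatten query to [tokens, heads, head_dim], the layout a paged kernel expects.
  query = query.reshape(-1, query.shape[-2], query.shape[-1])
  if kv_cache is None or metadata is None:
    # Dry run (shape inference / JIT tracing): nothing to read or update, so
    # return an empty cache and the reshaped query as a stand-in output.
    return [], query
  # ... the real paged-attention kernel call would go here ...
  raise NotImplementedError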

src/maxtext/layers/decoders.py

Lines changed: 35 additions & 30 deletions
@@ -792,7 +792,13 @@ def __call__(
         decoder_positions,
         deterministic,
         model_mode,
+        previous_chunk,
+        page_state,
+        slot,
     )
+    in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
+    # Pipeline module only accepts (segment_ids, positions, deterministic, model_mode)
+    pipeline_broadcast_args = broadcast_args[:4]
     if cfg.using_pipeline_parallelism:
       if cfg.pipeline_fsdp_ag_once:
         logical_partition_spec = self.pipeline_module.get_weight_sharding(

@@ -828,9 +834,9 @@ def __call__(
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
             model_mode=model_mode,
         )(y, *broadcast_args)
-        y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
+        y = self.pipeline_module(y, *pipeline_broadcast_args, logical_partition_spec=logical_partition_spec)
       else:  # Not DeepSeek
-        y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
+        y = self.pipeline_module(y, *pipeline_broadcast_args, logical_partition_spec=logical_partition_spec)
       remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
       if remaining_layers > 0:
         logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)

@@ -847,26 +853,12 @@ def __call__(
     else:
       if cfg.scan_layers:
         if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
-          assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
-          layer_call_kwargs = {
-              "page_state": page_state,
-              "previous_chunk": previous_chunk,
-              "slot": slot,
-          }
           dense_layer = RemattedBlockLayers[0]
           moe_layer = RemattedBlockLayers[1]
           if cfg.engram_layers:
-            original_dense_call = dense_layer.__call__
-            original_moe_call = moe_layer.__call__
-            dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
-            moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
-
             common_kwargs = {
                 "dense_layer": dense_layer,
                 "moe_layer": moe_layer,
-                "original_dense_call": original_dense_call,
-                "original_moe_call": original_moe_call,
-                "layer_call_kwargs": layer_call_kwargs,
                 "decoder_segment_ids": decoder_segment_ids,
                 "decoder_positions": decoder_positions,
                 "deterministic": deterministic,

@@ -895,7 +887,6 @@ def __call__(
                 **common_kwargs,
             )
           else:
-            dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
             y, _ = self.scan_decoder_layers(
                 cfg,
                 dense_layer,

@@ -905,7 +896,6 @@ def __call__(
                 in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
                 model_mode=model_mode,
             )(y, *broadcast_args)
-            moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
             num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
 
             # If batch-split schedule is used and initialization is complete,

@@ -954,16 +944,38 @@ def __call__(
               "nope_layer_interval": self.config.nope_layer_interval,
               "interleave_moe_layer_step": self.config.interleave_moe_layer_step,
           }
-          y, _ = self.scan_decoder_layers(
+
+          # Update broadcast_args and in_axes_tuple for vLLM RPA
+          current_broadcast_args = list(broadcast_args)
+          current_in_axes_tuple = list(in_axes_tuple)
+
+          if kv_caches is not None:
+            # Stack kv_caches for scan: [num_layers, ...]
+            stacked_kv_cache = jnp.stack(kv_caches, axis=0)
+            current_broadcast_args.append(stacked_kv_cache)
+            current_in_axes_tuple.append(0)  # Scan over the layer dimension
+          else:
+            current_broadcast_args.append(None)
+            current_in_axes_tuple.append(nn.broadcast)
+
+          current_broadcast_args.append(attention_metadata)
+          current_in_axes_tuple.append(nn.broadcast)
+
+          y, returned_kv_cache = self.scan_decoder_layers(
               cfg,
               RemattedBlockLayer,
               scan_length,
               "layers",
               mesh,
-              in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+              in_axes_tuple=tuple(current_in_axes_tuple),
               model_mode=model_mode,
               **layer_kwargs,
-          )(y, *broadcast_args)
+          )(y, *current_broadcast_args)
+
+          if kv_caches is not None and returned_kv_cache is not None:
+            # Update the list of KV caches from the scanned results
+            for i, cache in enumerate(returned_kv_cache):
+              kv_caches[i] = cache
         else:
           if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
             assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."

@@ -1173,10 +1185,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
     """Applies a single, unscanned Engram layer."""
     layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
     layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
-    original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"]
-    layer_call_kwargs = kwargs["layer_call_kwargs"]
+    broadcast_args = kwargs["broadcast_args"]
 
-    layer.__call__ = original_call
     y, _ = layer(
         config=self.config,
         mesh=self.mesh,

@@ -1186,14 +1196,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
         layer_idx=current_idx,
     )(
         y,
-        kwargs["decoder_segment_ids"],
-        kwargs["decoder_positions"],
-        kwargs["deterministic"],
-        kwargs["model_mode"],
+        *broadcast_args,
         decoder_input_tokens=kwargs["decoder_input_tokens"],
-        **layer_call_kwargs,
     )
-    layer.__call__ = functools.partial(original_call, **layer_call_kwargs)
     return y
 
   def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs):
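The key idea in the scanned path above is that the per-layer KV caches are stacked into a single array so the scan can hand each layer its own slice (in_axes=0) while every other argument is broadcast, and the updated slices are written back into the Python list afterwards. A toy jax.lax.scan sketch of that pattern (not MaxText's scan_decoder_layers, and layer_step is a stand-in for a real decoder layer):

import jax
import jax.numpy as jnp

num_layers, seq, dim = 4, 8, 16
kv_caches = [jnp.zeros((seq, dim)) for _ in range(num_layers)]  # one cache per layer


def layer_step(y, layer_kv_cache):
  # Stand-in for a decoder layer: read from this layer's cache, then update it.
  y = y + layer_kv_cache.mean()
  updated_cache = layer_kv_cache + y  # toy "write" into the cache
  return y, updated_cache  # (carry, per-step output)


stacked_kv_cache = jnp.stack(kv_caches, axis=0)  # [num_layers, seq, dim]
y0 = jnp.ones((seq, dim))

# scan slices axis 0 of stacked_kv_cache (the analogue of in_axes=0 for that
# argument) and stacks the per-layer updated caches back along axis 0.
y_out, updated_stacked = jax.lax.scan(layer_step, y0, stacked_kv_cache)

for i in range(num_layers):
  kv_caches[i] = updated_stacked[i]  # write scanned results back into the list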

src/maxtext/models/gemma.py

Lines changed: 2 additions & 2 deletions
@@ -30,6 +30,7 @@
 from maxtext.layers.linears import Dropout, MlpBlock
 from maxtext.layers.normalizations import RMSNorm
 from maxtext.layers.quantizations import AqtQuantization as Quant
+from maxtext.inference import page_manager
 from maxtext.utils import max_utils
 
 

@@ -126,8 +127,7 @@ def __call__(
       deterministic,
       model_mode,
       previous_chunk=None,
-      page_manager=None,
-      page_state=None,
+      page_state: None | page_manager.PageState = None,
       slot=None,
       kv_cache=None,
       attention_metadata=None,

src/maxtext/models/gpt_oss.py

Lines changed: 14 additions & 8 deletions
@@ -34,6 +34,7 @@
 from maxtext.layers.attentions import Attention
 from maxtext.layers.normalizations import RMSNorm
 from maxtext.layers.quantizations import AqtQuantization as Quant
+from maxtext.inference import page_manager
 from maxtext.utils import max_utils
 
 # -----------------------------------------

@@ -138,7 +139,7 @@ def __call__(
       deterministic,
       model_mode,
       previous_chunk=None,
-      page_state=None,
+      page_state: None | page_manager.PageState = None,
       slot=None,
       kv_cache=None,
       attention_metadata=None,

@@ -258,6 +259,11 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      previous_chunk=None,
+      page_state: None | page_manager.PageState = None,
+      slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
 

@@ -267,19 +273,19 @@ def __call__(
     for layer_id in range(cfg.inhomogeneous_layer_cycle_interval):
       layer_name = f"layers_{layer_id}"
       layer = getattr(self, layer_name)
-      y = layer(
+      y, kv_cache = layer(
           y,
           decoder_segment_ids,
           decoder_positions,
           deterministic,
           model_mode,
+          previous_chunk=previous_chunk,
+          page_state=page_state,
+          slot=slot,
+          kv_cache=kv_cache,
+          attention_metadata=attention_metadata,
       )
-    if cfg.scan_layers:
-      y = y[0]
-    if cfg.scan_layers:
-      return y, None
-    else:
-      return y
+    return y, kv_cache
 
 
 GptOssScannableBlockToLinen = nnx_wrappers.to_linen_class(

src/maxtext/models/llama2.py

Lines changed: 1 addition & 1 deletion
@@ -143,9 +143,9 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
-      previous_chunk=None,
       kv_cache=None,
       attention_metadata=None,
   ):

src/maxtext/models/llama4.py

Lines changed: 6 additions & 2 deletions
@@ -442,9 +442,9 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
-      previous_chunk=None,
       kv_cache=None,
       attention_metadata=None,
   ):

@@ -570,9 +570,11 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
-      previous_chunk=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
 
     cfg = self.config

@@ -590,6 +592,8 @@ def __call__(
         previous_chunk=previous_chunk,
         page_state=page_state,
         slot=slot,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if cfg.scan_layers:
       y = y[0]

src/maxtext/models/mistral.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@
 import jax.numpy as jnp
 from jax.sharding import Mesh
 from maxtext.common.common_types import Config
+from maxtext.inference import page_manager
 from maxtext.layers import initializers, nnx_wrappers
 from maxtext.layers import quantizations
 from maxtext.layers.attentions import Attention

@@ -126,9 +127,9 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
-      page_state: None | int = None,
-      slot: None | int = None,
       previous_chunk=None,
+      slot: None | int = None,
+      page_state: None | page_manager.PageState = None,
       kv_cache=None,
       attention_metadata=None,
   ):

src/maxtext/models/olmo3.py

Lines changed: 10 additions & 0 deletions
@@ -267,6 +267,11 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
+      previous_chunk=None,
+      page_state=None,
+      slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
 

@@ -282,6 +287,11 @@ def __call__(
         decoder_positions,
         deterministic,
         model_mode,
+        previous_chunk=previous_chunk,
+        page_state=page_state,
+        slot=slot,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if cfg.scan_layers:
       y = y[0]

src/maxtext/models/qwen3.py

Lines changed: 6 additions & 8 deletions
@@ -896,6 +896,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ) -> tuple[Array, None]:
     """Applies the block of decoder layers to the input carry.
 

@@ -924,6 +926,8 @@ def __call__(
           previous_chunk,
           page_state,
           slot,
+          kv_cache=kv_cache,
+          attention_metadata=attention_metadata,
       )
 
       # The output of the block is the carry for the next scan iteration.

@@ -1235,10 +1239,7 @@ def __call__(
     layer_output = intermediate_inputs + mlp_lnx
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
-    if self.config.scan_layers:
-      return layer_output, None
-    else:
-      return layer_output, kv_cache
+    return layer_output, kv_cache
 
 
 # -----------------------------------------

@@ -1304,10 +1305,7 @@ def __call__(
     layer_output = intermediate_inputs + mlp_lnx
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
-    if self.config.scan_layers:
-      return layer_output, None
-    else:
-      return layer_output, kv_cache
+    return layer_output, kv_cache
 
 
 class Qwen3OmniMoeVisionPatchMerger(nnx.Module):
