Skip to content

Commit 042ae7d

Browse files
committed
pass kv_cache through scanned decoder layers
1 parent 21c4433 commit 042ae7d

11 files changed

Lines changed: 100 additions & 58 deletions

File tree

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ class MaxTextForCausalLM(nnx.Module):
104104
of the decoding step.
105105
"""
106106

107+
# Signal to tpu-inference model_loader that this class manages its own
108+
# JIT-sharded initialization (via create_nnx_model with out_shardings).
109+
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
110+
_self_manages_sharding: bool = True
111+
107112
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
108113
"""Initializes the MaxTextForCausalLM model.
109114
@@ -250,7 +255,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
250255
if self.model is not None:
251256
return
252257

253-
with self.mesh, nn.logical_axis_rules(""):
258+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
254259
model, _ = model_creation_utils.create_nnx_model(
255260
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
256261
)

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -992,13 +992,14 @@ def forward_serve_vllm(
992992
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
993993
) from e
994994

995-
if rpa_kv_cache is None or rpa_metadata is None:
996-
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
997-
998995
query = query.reshape(-1, query.shape[2], query.shape[3])
999996
key = key.reshape(-1, key.shape[2], key.shape[3])
1000997
value = value.reshape(-1, value.shape[2], value.shape[3])
1001998

999+
if rpa_kv_cache is None or rpa_metadata is None:
1000+
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
1001+
return [], query
1002+
10021003
if self.config.sliding_window_size > 0:
10031004
attention_chunk_size = self.config.sliding_window_size
10041005
else:

src/maxtext/layers/decoders.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,13 @@ def __call__(
795795
decoder_positions,
796796
deterministic,
797797
model_mode,
798+
previous_chunk,
799+
page_state,
800+
slot,
798801
)
802+
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
803+
# Pipeline module only accepts (segment_ids, positions, deterministic, model_mode)
804+
pipeline_broadcast_args = broadcast_args[:4]
799805
if cfg.using_pipeline_parallelism:
800806
logical_partition_spec = (
801807
self.pipeline_module.get_weight_sharding(y, decoder_segment_ids, decoder_positions, deterministic, model_mode)
@@ -830,9 +836,9 @@ def __call__(
830836
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
831837
model_mode=model_mode,
832838
)(y, *broadcast_args)
833-
y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
839+
y = self.pipeline_module(y, *pipeline_broadcast_args, logical_partition_spec=logical_partition_spec)
834840
else: # Not DeepSeek
835-
y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
841+
y = self.pipeline_module(y, *pipeline_broadcast_args, logical_partition_spec=logical_partition_spec)
836842
remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
837843
if remaining_layers > 0:
838844
logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
@@ -849,26 +855,12 @@ def __call__(
849855
else:
850856
if cfg.scan_layers:
851857
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
852-
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
853-
layer_call_kwargs = {
854-
"page_state": page_state,
855-
"previous_chunk": previous_chunk,
856-
"slot": slot,
857-
}
858858
dense_layer = RemattedBlockLayers[0]
859859
moe_layer = RemattedBlockLayers[1]
860860
if cfg.engram_layers:
861-
original_dense_call = dense_layer.__call__
862-
original_moe_call = moe_layer.__call__
863-
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
864-
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
865-
866861
common_kwargs = {
867862
"dense_layer": dense_layer,
868863
"moe_layer": moe_layer,
869-
"original_dense_call": original_dense_call,
870-
"original_moe_call": original_moe_call,
871-
"layer_call_kwargs": layer_call_kwargs,
872864
"decoder_segment_ids": decoder_segment_ids,
873865
"decoder_positions": decoder_positions,
874866
"deterministic": deterministic,
@@ -897,7 +889,6 @@ def __call__(
897889
**common_kwargs,
898890
)
899891
else:
900-
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
901892
y, _ = self.scan_decoder_layers(
902893
cfg,
903894
dense_layer,
@@ -907,7 +898,6 @@ def __call__(
907898
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
908899
model_mode=model_mode,
909900
)(y, *broadcast_args)
910-
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
911901
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
912902

913903
# If batch-split schedule is used and initialization is complete,
@@ -981,16 +971,38 @@ def __call__(
981971
"nope_layer_interval": self.config.nope_layer_interval,
982972
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
983973
}
984-
y, _ = self.scan_decoder_layers(
974+
975+
# Update broadcast_args and in_axes_tuple for vLLM RPA
976+
current_broadcast_args = list(broadcast_args)
977+
current_in_axes_tuple = list(in_axes_tuple)
978+
979+
if kv_caches is not None:
980+
# Stack kv_caches for scan: [num_layers, ...]
981+
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
982+
current_broadcast_args.append(stacked_kv_cache)
983+
current_in_axes_tuple.append(0) # Scan over the layer dimension
984+
else:
985+
current_broadcast_args.append(None)
986+
current_in_axes_tuple.append(nn.broadcast)
987+
988+
current_broadcast_args.append(attention_metadata)
989+
current_in_axes_tuple.append(nn.broadcast)
990+
991+
y, returned_kv_cache = self.scan_decoder_layers(
985992
cfg,
986993
RemattedBlockLayer,
987994
scan_length,
988995
"layers",
989996
mesh,
990-
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
997+
in_axes_tuple=tuple(current_in_axes_tuple),
991998
model_mode=model_mode,
992999
**layer_kwargs,
993-
)(y, *broadcast_args)
1000+
)(y, *current_broadcast_args)
1001+
1002+
if kv_caches is not None and returned_kv_cache is not None:
1003+
# Update the list of KV caches from the scanned results
1004+
for i, cache in enumerate(returned_kv_cache):
1005+
kv_caches[i] = cache
9941006
else:
9951007
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
9961008
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
@@ -1295,10 +1307,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
12951307
"""Applies a single, unscanned Engram layer."""
12961308
layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
12971309
layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1298-
original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"]
1299-
layer_call_kwargs = kwargs["layer_call_kwargs"]
1310+
broadcast_args = kwargs["broadcast_args"]
13001311

1301-
layer.__call__ = original_call
13021312
y, _ = layer(
13031313
config=self.config,
13041314
mesh=self.mesh,
@@ -1308,14 +1318,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
13081318
layer_idx=current_idx,
13091319
)(
13101320
y,
1311-
kwargs["decoder_segment_ids"],
1312-
kwargs["decoder_positions"],
1313-
kwargs["deterministic"],
1314-
kwargs["model_mode"],
1321+
*broadcast_args,
13151322
decoder_input_tokens=kwargs["decoder_input_tokens"],
1316-
**layer_call_kwargs,
13171323
)
1318-
layer.__call__ = functools.partial(original_call, **layer_call_kwargs)
13191324
return y
13201325

13211326
def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs):

src/maxtext/models/gemma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from maxtext.layers.linears import Dropout, MlpBlock
3131
from maxtext.layers.normalizations import RMSNorm
3232
from maxtext.layers.quantizations import AqtQuantization as Quant
33+
from maxtext.inference import page_manager
3334
from maxtext.utils import max_utils
3435

3536

@@ -126,8 +127,7 @@ def __call__(
126127
deterministic,
127128
model_mode,
128129
previous_chunk=None,
129-
page_manager=None,
130-
page_state=None,
130+
page_state: None | page_manager.PageState = None,
131131
slot=None,
132132
kv_cache=None,
133133
attention_metadata=None,

src/maxtext/models/gpt_oss.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from maxtext.layers.attentions import Attention
3535
from maxtext.layers.normalizations import RMSNorm
3636
from maxtext.layers.quantizations import AqtQuantization as Quant
37+
from maxtext.inference import page_manager
3738
from maxtext.utils import max_utils
3839

3940
# -----------------------------------------
@@ -138,7 +139,7 @@ def __call__(
138139
deterministic,
139140
model_mode,
140141
previous_chunk=None,
141-
page_state=None,
142+
page_state: None | page_manager.PageState = None,
142143
slot=None,
143144
kv_cache=None,
144145
attention_metadata=None,
@@ -258,6 +259,11 @@ def __call__(
258259
decoder_positions,
259260
deterministic,
260261
model_mode,
262+
previous_chunk=None,
263+
page_state: None | page_manager.PageState = None,
264+
slot=None,
265+
kv_cache=None,
266+
attention_metadata=None,
261267
):
262268
cfg = self.config
263269

@@ -267,19 +273,19 @@ def __call__(
267273
for layer_id in range(cfg.inhomogeneous_layer_cycle_interval):
268274
layer_name = f"layers_{layer_id}"
269275
layer = getattr(self, layer_name)
270-
y = layer(
276+
y, kv_cache = layer(
271277
y,
272278
decoder_segment_ids,
273279
decoder_positions,
274280
deterministic,
275281
model_mode,
282+
previous_chunk=previous_chunk,
283+
page_state=page_state,
284+
slot=slot,
285+
kv_cache=kv_cache,
286+
attention_metadata=attention_metadata,
276287
)
277-
if cfg.scan_layers:
278-
y = y[0]
279-
if cfg.scan_layers:
280-
return y, None
281-
else:
282-
return y
288+
return y, kv_cache
283289

284290

285291
GptOssScannableBlockToLinen = nnx_wrappers.to_linen_class(

src/maxtext/models/llama2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ def __call__(
143143
decoder_positions,
144144
deterministic,
145145
model_mode,
146+
previous_chunk=None,
146147
slot: None | int = None,
147148
page_state: None | page_manager.PageState = None,
148-
previous_chunk=None,
149149
kv_cache=None,
150150
attention_metadata=None,
151151
):

src/maxtext/models/llama4.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,9 @@ def __call__(
442442
decoder_positions,
443443
deterministic,
444444
model_mode,
445+
previous_chunk=None,
445446
slot: None | int = None,
446447
page_state: None | page_manager.PageState = None,
447-
previous_chunk=None,
448448
kv_cache=None,
449449
attention_metadata=None,
450450
):
@@ -570,9 +570,11 @@ def __call__(
570570
decoder_positions,
571571
deterministic,
572572
model_mode,
573+
previous_chunk=None,
573574
slot: None | int = None,
574575
page_state: None | page_manager.PageState = None,
575-
previous_chunk=None,
576+
kv_cache=None,
577+
attention_metadata=None,
576578
):
577579

578580
cfg = self.config
@@ -590,6 +592,8 @@ def __call__(
590592
previous_chunk=previous_chunk,
591593
page_state=page_state,
592594
slot=slot,
595+
kv_cache=kv_cache,
596+
attention_metadata=attention_metadata,
593597
)
594598
if cfg.scan_layers:
595599
y = y[0]

src/maxtext/models/mistral.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax.numpy as jnp
2323
from jax.sharding import Mesh
2424
from maxtext.common.common_types import Config
25+
from maxtext.inference import page_manager
2526
from maxtext.layers import initializers, nnx_wrappers
2627
from maxtext.layers import quantizations
2728
from maxtext.layers.attentions import Attention
@@ -126,9 +127,9 @@ def __call__(
126127
decoder_positions,
127128
deterministic,
128129
model_mode,
129-
page_state: None | int = None,
130-
slot: None | int = None,
131130
previous_chunk=None,
131+
slot: None | int = None,
132+
page_state: None | page_manager.PageState = None,
132133
kv_cache=None,
133134
attention_metadata=None,
134135
):

src/maxtext/models/olmo3.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,11 @@ def __call__(
267267
decoder_positions,
268268
deterministic,
269269
model_mode,
270+
previous_chunk=None,
271+
page_state=None,
272+
slot=None,
273+
kv_cache=None,
274+
attention_metadata=None,
270275
):
271276
cfg = self.config
272277

@@ -282,6 +287,11 @@ def __call__(
282287
decoder_positions,
283288
deterministic,
284289
model_mode,
290+
previous_chunk=previous_chunk,
291+
page_state=page_state,
292+
slot=slot,
293+
kv_cache=kv_cache,
294+
attention_metadata=attention_metadata,
285295
)
286296
if cfg.scan_layers:
287297
y = y[0]

src/maxtext/models/qwen3.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,8 @@ def __call__(
896896
previous_chunk=None,
897897
page_state: None | page_manager.PageState = None,
898898
slot: None | int = None,
899+
kv_cache=None,
900+
attention_metadata=None,
899901
) -> tuple[Array, None]:
900902
"""Applies the block of decoder layers to the input carry.
901903
@@ -924,6 +926,8 @@ def __call__(
924926
previous_chunk,
925927
page_state,
926928
slot,
929+
kv_cache=kv_cache,
930+
attention_metadata=attention_metadata,
927931
)
928932

929933
# The output of the block is the carry for the next scan iteration.
@@ -1235,10 +1239,7 @@ def __call__(
12351239
layer_output = intermediate_inputs + mlp_lnx
12361240
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
12371241

1238-
if self.config.scan_layers:
1239-
return layer_output, None
1240-
else:
1241-
return layer_output, kv_cache
1242+
return layer_output, kv_cache
12421243

12431244

12441245
# -----------------------------------------
@@ -1304,10 +1305,7 @@ def __call__(
13041305
layer_output = intermediate_inputs + mlp_lnx
13051306
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
13061307

1307-
if self.config.scan_layers:
1308-
return layer_output, None
1309-
else:
1310-
return layer_output, kv_cache
1308+
return layer_output, kv_cache
13111309

13121310

13131311
class Qwen3OmniMoeVisionPatchMerger(nnx.Module):

0 commit comments

Comments (0)