Skip to content

Commit fee2744

Browse files
committed
pass kv_cache through scanned decoder layers
1 parent 47c6d0c commit fee2744

4 files changed

Lines changed: 43 additions & 40 deletions

File tree

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ class MaxTextForCausalLM(nnx.Module):
8686
of the decoding step.
8787
"""
8888

89+
# Signal to tpu-inference model_loader that this class manages its own
90+
# JIT-sharded initialization (via create_nnx_model with out_shardings).
91+
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
92+
_self_manages_sharding: bool = True
93+
8994
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
9095
"""Initializes the MaxTextForCausalLM model.
9196
@@ -232,7 +237,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
232237
if self.model is not None:
233238
return
234239

235-
with self.mesh, nn.logical_axis_rules(""):
240+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
236241
model, _ = model_creation_utils.create_nnx_model(
237242
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
238243
)

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -956,13 +956,14 @@ def forward_serve_vllm(
956956
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
957957
) from e
958958

959-
if rpa_kv_cache is None or rpa_metadata is None:
960-
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
961-
962959
query = query.reshape(-1, query.shape[2], query.shape[3])
963960
key = key.reshape(-1, key.shape[2], key.shape[3])
964961
value = value.reshape(-1, value.shape[2], value.shape[3])
965962

963+
if rpa_kv_cache is None or rpa_metadata is None:
964+
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
965+
return [], query
966+
966967
if self.config.sliding_window_size > 0:
967968
attention_chunk_size = self.config.sliding_window_size
968969
else:

src/maxtext/layers/decoders.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,11 @@ def __call__(
792792
decoder_positions,
793793
deterministic,
794794
model_mode,
795+
previous_chunk,
796+
page_state,
797+
slot,
795798
)
799+
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
796800
if cfg.using_pipeline_parallelism:
797801
if cfg.pipeline_fsdp_ag_once:
798802
logical_partition_spec = self.pipeline_module.get_weight_sharding(
@@ -847,26 +851,12 @@ def __call__(
847851
else:
848852
if cfg.scan_layers:
849853
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
850-
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
851-
layer_call_kwargs = {
852-
"page_state": page_state,
853-
"previous_chunk": previous_chunk,
854-
"slot": slot,
855-
}
856854
dense_layer = RemattedBlockLayers[0]
857855
moe_layer = RemattedBlockLayers[1]
858856
if cfg.engram_layers:
859-
original_dense_call = dense_layer.__call__
860-
original_moe_call = moe_layer.__call__
861-
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
862-
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
863-
864857
common_kwargs = {
865858
"dense_layer": dense_layer,
866859
"moe_layer": moe_layer,
867-
"original_dense_call": original_dense_call,
868-
"original_moe_call": original_moe_call,
869-
"layer_call_kwargs": layer_call_kwargs,
870860
"decoder_segment_ids": decoder_segment_ids,
871861
"decoder_positions": decoder_positions,
872862
"deterministic": deterministic,
@@ -895,7 +885,6 @@ def __call__(
895885
**common_kwargs,
896886
)
897887
else:
898-
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
899888
y, _ = self.scan_decoder_layers(
900889
cfg,
901890
dense_layer,
@@ -905,7 +894,6 @@ def __call__(
905894
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
906895
model_mode=model_mode,
907896
)(y, *broadcast_args)
908-
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
909897
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
910898

911899
# If batch-split schedule is used and initialization is complete,
@@ -954,16 +942,38 @@ def __call__(
954942
"nope_layer_interval": self.config.nope_layer_interval,
955943
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
956944
}
957-
y, _ = self.scan_decoder_layers(
945+
946+
# Update broadcast_args and in_axes_tuple for vLLM RPA
947+
current_broadcast_args = list(broadcast_args)
948+
current_in_axes_tuple = list(in_axes_tuple)
949+
950+
if kv_caches is not None:
951+
# Stack kv_caches for scan: [num_layers, ...]
952+
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
953+
current_broadcast_args.append(stacked_kv_cache)
954+
current_in_axes_tuple.append(0) # Scan over the layer dimension
955+
else:
956+
current_broadcast_args.append(None)
957+
current_in_axes_tuple.append(nn.broadcast)
958+
959+
current_broadcast_args.append(attention_metadata)
960+
current_in_axes_tuple.append(nn.broadcast)
961+
962+
y, returned_kv_cache = self.scan_decoder_layers(
958963
cfg,
959964
RemattedBlockLayer,
960965
scan_length,
961966
"layers",
962967
mesh,
963-
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
968+
in_axes_tuple=tuple(current_in_axes_tuple),
964969
model_mode=model_mode,
965970
**layer_kwargs,
966-
)(y, *broadcast_args)
971+
)(y, *current_broadcast_args)
972+
973+
if kv_caches is not None and returned_kv_cache is not None:
974+
# Update the list of KV caches from the scanned results
975+
for i, cache in enumerate(returned_kv_cache):
976+
kv_caches[i] = cache
967977
else:
968978
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
969979
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
@@ -1173,10 +1183,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
11731183
"""Applies a single, unscanned Engram layer."""
11741184
layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
11751185
layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1176-
original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"]
1177-
layer_call_kwargs = kwargs["layer_call_kwargs"]
1186+
broadcast_args = kwargs["broadcast_args"]
11781187

1179-
layer.__call__ = original_call
11801188
y, _ = layer(
11811189
config=self.config,
11821190
mesh=self.mesh,
@@ -1186,14 +1194,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
11861194
layer_idx=current_idx,
11871195
)(
11881196
y,
1189-
kwargs["decoder_segment_ids"],
1190-
kwargs["decoder_positions"],
1191-
kwargs["deterministic"],
1192-
kwargs["model_mode"],
1197+
*broadcast_args,
11931198
decoder_input_tokens=kwargs["decoder_input_tokens"],
1194-
**layer_call_kwargs,
11951199
)
1196-
layer.__call__ = functools.partial(original_call, **layer_call_kwargs)
11971200
return y
11981201

11991202
def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs):

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,10 +1235,7 @@ def __call__(
12351235
layer_output = intermediate_inputs + mlp_lnx
12361236
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
12371237

1238-
if self.config.scan_layers:
1239-
return layer_output, None
1240-
else:
1241-
return layer_output, kv_cache
1238+
return layer_output, kv_cache
12421239

12431240

12441241
# -----------------------------------------
@@ -1304,10 +1301,7 @@ def __call__(
13041301
layer_output = intermediate_inputs + mlp_lnx
13051302
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
13061303

1307-
if self.config.scan_layers:
1308-
return layer_output, None
1309-
else:
1310-
return layer_output, kv_cache
1304+
return layer_output, kv_cache
13111305

13121306

13131307
class Qwen3OmniMoeVisionPatchMerger(nnx.Module):

0 commit comments

Comments
 (0)