Skip to content

Commit bfbbc5e

Browse files
committed
pass kv_cache through scanned decoder layers
1 parent 47c6d0c commit bfbbc5e

4 files changed

Lines changed: 40 additions & 15 deletions

File tree

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ class MaxTextForCausalLM(nnx.Module):
8585
tasks. It handles configuration generation, model initialization, and execution
8686
of the decoding step.
8787
"""
88+
# Signal to tpu-inference model_loader that this class manages its own
89+
# JIT-sharded initialization (via create_nnx_model with out_shardings).
90+
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
91+
_self_manages_sharding: bool = True
8892

8993
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
9094
"""Initializes the MaxTextForCausalLM model.
@@ -232,7 +236,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
232236
if self.model is not None:
233237
return
234238

235-
with self.mesh, nn.logical_axis_rules(""):
239+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
236240
model, _ = model_creation_utils.create_nnx_model(
237241
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
238242
)

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -956,13 +956,14 @@ def forward_serve_vllm(
956956
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
957957
) from e
958958

959-
if rpa_kv_cache is None or rpa_metadata is None:
960-
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
961-
962959
query = query.reshape(-1, query.shape[2], query.shape[3])
963960
key = key.reshape(-1, key.shape[2], key.shape[3])
964961
value = value.reshape(-1, value.shape[2], value.shape[3])
965962

963+
if rpa_kv_cache is None or rpa_metadata is None:
964+
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
965+
return [], query
966+
966967
if self.config.sliding_window_size > 0:
967968
attention_chunk_size = self.config.sliding_window_size
968969
else:

src/maxtext/layers/decoders.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,11 @@ def __call__(
792792
decoder_positions,
793793
deterministic,
794794
model_mode,
795+
previous_chunk,
796+
page_state,
797+
slot,
795798
)
799+
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
796800
if cfg.using_pipeline_parallelism:
797801
if cfg.pipeline_fsdp_ag_once:
798802
logical_partition_spec = self.pipeline_module.get_weight_sharding(
@@ -954,16 +958,38 @@ def __call__(
954958
"nope_layer_interval": self.config.nope_layer_interval,
955959
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
956960
}
957-
y, _ = self.scan_decoder_layers(
961+
962+
# Update broadcast_args and in_axes_tuple for vLLM RPA
963+
current_broadcast_args = list(broadcast_args)
964+
current_in_axes_tuple = list(in_axes_tuple)
965+
966+
if kv_caches is not None:
967+
# Stack kv_caches for scan: [num_layers, ...]
968+
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
969+
current_broadcast_args.append(stacked_kv_cache)
970+
current_in_axes_tuple.append(0) # Scan over the layer dimension
971+
else:
972+
current_broadcast_args.append(None)
973+
current_in_axes_tuple.append(nn.broadcast)
974+
975+
current_broadcast_args.append(attention_metadata)
976+
current_in_axes_tuple.append(nn.broadcast)
977+
978+
y, returned_kv_cache = self.scan_decoder_layers(
958979
cfg,
959980
RemattedBlockLayer,
960981
scan_length,
961982
"layers",
962983
mesh,
963-
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
984+
in_axes_tuple=tuple(current_in_axes_tuple),
964985
model_mode=model_mode,
965986
**layer_kwargs,
966-
)(y, *broadcast_args)
987+
)(y, *current_broadcast_args)
988+
989+
if kv_caches is not None and returned_kv_cache is not None:
990+
# Update the list of KV caches from the scanned results
991+
for i in range(len(kv_caches)):
992+
kv_caches[i] = returned_kv_cache[i]
967993
else:
968994
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
969995
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,10 +1235,7 @@ def __call__(
12351235
layer_output = intermediate_inputs + mlp_lnx
12361236
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
12371237

1238-
if self.config.scan_layers:
1239-
return layer_output, None
1240-
else:
1241-
return layer_output, kv_cache
1238+
return layer_output, kv_cache
12421239

12431240

12441241
# -----------------------------------------
@@ -1304,10 +1301,7 @@ def __call__(
13041301
layer_output = intermediate_inputs + mlp_lnx
13051302
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
13061303

1307-
if self.config.scan_layers:
1308-
return layer_output, None
1309-
else:
1310-
return layer_output, kv_cache
1304+
return layer_output, kv_cache
13111305

13121306

13131307
class Qwen3OmniMoeVisionPatchMerger(nnx.Module):

0 commit comments

Comments (0)