Skip to content

Commit 5c4341b

Browse files
committed
trial
1 parent 41267f2 commit 5c4341b

5 files changed

Lines changed: 70 additions & 13 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,7 @@ use_tokamax_splash: false
11731173
use_jax_splash: false
11741174

11751175
# vLLM Adapter Configurations
1176+
hbm_utilization_vllm: 0.5
11761177
# Path to the HuggingFace-style config directory for the adapter (e.g. src/maxtext/integration/vllm/maxtext_vllm_adapter)
11771178
vllm_hf_config_path: ""
11781179
# A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter.

src/maxtext/configs/inference/vllm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ logical_axis_rules: [
7878
['cache_heads', ['model']],
7979
['exp', ['expert', 'attn_dp_expert']],
8080
['paged_kv_heads', ['model']],
81+
['layers', []],
8182
]
8283
data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
8384
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']

src/maxtext/inference/vllm_decode.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def decode_with_vllm(config: Config) -> None:
8282
"weight_dtype": "bfloat16",
8383
"allow_split_physical_axes": True,
8484
"debug_sharding": config.debug_sharding,
85+
"scan_layers": config.scan_layers,
8586
},
8687
"sharding": {
8788
"sharding_strategy": {
@@ -140,11 +141,15 @@ def decode_with_vllm(config: Config) -> None:
140141
f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
141142
)
142143

144+
# MaxText uses -1 to mean "disabled"; vLLM requires top_p in (0, 1].
145+
top_p = config.decode_sampling_nucleus_p if config.decode_sampling_nucleus_p > 0 else 1.0
146+
top_k = config.decode_sampling_top_k if config.decode_sampling_top_k > 0 else -1
147+
143148
sampling_params = SamplingParams(
144149
temperature=config.decode_sampling_temperature,
145150
max_tokens=max_tokens_to_generate,
146-
top_k=config.decode_sampling_top_k,
147-
top_p=config.decode_sampling_nucleus_p,
151+
top_k=top_k,
152+
top_p=top_p,
148153
)
149154

150155
outputs = llm.generate(prompts, sampling_params)

src/maxtext/layers/decoders.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -981,16 +981,54 @@ def __call__(
981981
"nope_layer_interval": self.config.nope_layer_interval,
982982
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
983983
}
984-
y, _ = self.scan_decoder_layers(
985-
cfg,
986-
RemattedBlockLayer,
987-
scan_length,
988-
"layers",
989-
mesh,
990-
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
991-
model_mode=model_mode,
992-
**layer_kwargs,
993-
)(y, *broadcast_args)
984+
# Update broadcast_args and in_axes_tuple for vLLM RPA
985+
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
986+
current_broadcast_args = list(broadcast_args)
987+
current_in_axes_tuple = list(in_axes_tuple)
988+
989+
current_broadcast_args.append(attention_metadata)
990+
current_in_axes_tuple.append(nn.broadcast)
991+
992+
if kv_caches is not None:
993+
# Stack kv_caches for scan: [num_layers, ...]
994+
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
995+
996+
# We pass (y, stacked_kv_cache, 0) as the carry
997+
carry = (y, stacked_kv_cache, 0)
998+
999+
# We don't pass kv_cache as a scanned argument anymore
1000+
1001+
final_carry, _ = self.scan_decoder_layers(
1002+
cfg,
1003+
RemattedBlockLayer,
1004+
scan_length,
1005+
"layers",
1006+
mesh,
1007+
in_axes_tuple=tuple(current_in_axes_tuple),
1008+
model_mode=model_mode,
1009+
**layer_kwargs,
1010+
)(carry, *current_broadcast_args)
1011+
1012+
y, returned_kv_cache, _ = final_carry
1013+
1014+
# Update the list of KV caches from the scanned results
1015+
for i in range(cfg.num_decoder_layers):
1016+
kv_caches[i] = returned_kv_cache[i]
1017+
else:
1018+
# Fallback to old behavior if kv_caches is None (not vLLM RPA)
1019+
current_broadcast_args.append(None)
1020+
current_in_axes_tuple.append(nn.broadcast)
1021+
1022+
y, _ = self.scan_decoder_layers(
1023+
cfg,
1024+
RemattedBlockLayer,
1025+
scan_length,
1026+
"layers",
1027+
mesh,
1028+
in_axes_tuple=tuple(current_in_axes_tuple),
1029+
model_mode=model_mode,
1030+
**layer_kwargs,
1031+
)(y, *current_broadcast_args)
9941032
else:
9951033
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
9961034
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."

src/maxtext/models/qwen3.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,14 @@ def __call__(
12851285
attention_metadata: None | dict[str, Any] = None,
12861286
):
12871287
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
1288+
is_scan_carry = False
1289+
if isinstance(inputs, tuple) and len(inputs) == 3:
1290+
hidden_states, stacked_kv_cache, layer_idx = inputs
1291+
kv_cache = stacked_kv_cache[layer_idx]
1292+
inputs = hidden_states
1293+
is_scan_carry = True
1294+
elif isinstance(inputs, tuple):
1295+
inputs = inputs[0]
12881296
if isinstance(inputs, tuple):
12891297
inputs = inputs[0]
12901298
hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
@@ -1305,7 +1313,11 @@ def __call__(
13051313
layer_output = intermediate_inputs + mlp_lnx
13061314
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
13071315

1308-
return layer_output, kv_cache
1316+
if is_scan_carry:
1317+
stacked_kv_cache = stacked_kv_cache.at[layer_idx].set(kv_cache)
1318+
return (layer_output, stacked_kv_cache, layer_idx + 1), None
1319+
else:
1320+
return layer_output, kv_cache
13091321

13101322

13111323
class Qwen3OmniMoeVisionPatchMerger(nnx.Module):

0 commit comments

Comments (0)