Skip to content

Commit e09523d

Browse files
committed
deepseek sharding and mla attention plumbing
1 parent e3dbd54 commit e09523d

10 files changed

Lines changed: 258 additions & 41 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ logical_axis_rules: [
463463
['decode_length', ['sequence']],
464464
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
465465
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
466+
['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
466467
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
467468
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
468469
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -1110,6 +1111,8 @@ vllm_hf_config_path: ""
11101111
# A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter.
11111112
# This can be used to override specific settings without modifying the original config file.
11121113
vllm_hf_overrides: {}
1114+
# Path to yaml file for loading vLLM config
1115+
vllm_config_path: ""
11131116
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
11141117
vllm_additional_config: {}
11151118
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
base_config: "vllm.yml"
17+
18+
logical_axis_rules: [
19+
['activation_batch', []],
20+
['activation_batch_no_exp', []],
21+
['activation_embed_and_logits_batch', ['expert']],
22+
['activation_embed_and_logits_batch_sequence', ['expert']],
23+
['activation_heads', ['model']],
24+
['activation_kv_heads', ['model']],
25+
['activation_attn_length', ['expert']],
26+
['activation_attn_length_no_exp', []],
27+
['activation_length', ['data', 'expert']],
28+
['activation_length_no_exp', 'data'],
29+
['activation_q_length', ['expert']],
30+
['activation_attn_embed', 'model'],
31+
['activation_embed', ['model', 'attn_dp']],
32+
['activation_mlp', ['model', 'attn_dp', 'expert']],
33+
['activation_kv', ['model']],
34+
['activation_prefill_kv_batch', ['expert']],
35+
['activation_kv_batch', []],
36+
['activation_kv_batch_no_exp', []],
37+
['activation_kv_head_dim', ['model', 'attn_dp', 'expert']],
38+
['activation_vocab', ['model', 'attn_dp']],
39+
['activation_norm_length', []],
40+
['activation_exp', ['expert']],
41+
['decode_batch', ['expert']],
42+
['decode_length', []],
43+
['mlp_no_fsdp', ['model', 'attn_dp', 'expert']],
44+
['vocab', ['model', 'attn_dp', 'expert']],
45+
['heads', ['expert', 'attn_dp', 'model']],
46+
['q_heads', []],
47+
['kv_heads', []],
48+
['kv_head_dim', ['model', 'attn_dp', 'expert']],
49+
['kv', ['model', 'attn_dp', 'expert']],
50+
['kv', []],
51+
['embed', []],
52+
['mlp', ['model', 'attn_dp', 'expert']],
53+
['moe_mlp', []],
54+
['embed_tensor_transpose', ['attn_dp', 'model']],
55+
['embed_no_exp', []],
56+
['q_lora', []],
57+
['kv_lora', []],
58+
['norm', []],
59+
['cache_heads', ['model']],
60+
['exp', ['expert', 'attn_dp', 'model']],
61+
['paged_kv_heads', ['model']],
62+
['cache_batch_prefill', []],
63+
['cache_batch', []],
64+
['cache_sequence', []],
65+
['cache_heads_none', []],
66+
['cache_kv', []],
67+
['kv_lora_up_proj',['expert', 'attn_dp', 'model']],
68+
['q_lora_up_proj',['expert', 'attn_dp', 'model']],
69+
]

src/maxtext/configs/post_train/rl.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ max_num_seqs: null
154154
async_scheduling: True
155155
# stop generation when any of these strings is generated
156156
stop_strings: null
157+
# path to initialize vllm config
158+
vllm_config_path: 'src/maxtext/configs/inference/vllm.yml'
157159

158160
# ====== Checkpoint Configuration ======
159161
enable_checkpointing: True

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1610,6 +1610,7 @@ class VLLM(BaseModel):
16101610
description="Overrides for HuggingFace model config for MaxText model.",
16111611
)
16121612
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
1613+
vllm_config_path: str = Field("src/maxtext/configs/inference/vllm.yml", description="path to yaml file for loading vLLM config.")
16131614

16141615

16151616
class RL(BaseModel):

src/maxtext/inference/vllm_decode.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
use_chat_template=True
3030
"""
3131

32+
import copy
3233
import os
3334
from typing import Any, Sequence
3435

@@ -67,6 +68,21 @@ def decode_with_vllm(config: Config) -> None:
6768
config: MaxText config.
6869
"""
6970
# Prepare vLLM Arguments
71+
# Use user-provided vllm_additional_config as base (includes model-specific
72+
# overrides like base_num_decoder_layers, override_model_config, etc.), then
73+
# fill in defaults and runtime-derived values on top.
74+
additional_config = copy.deepcopy(config.vllm_additional_config) if config.vllm_additional_config else {}
75+
additional_config.setdefault("maxtext_config", {})
76+
additional_config["maxtext_config"].setdefault("model_name", config.model_name)
77+
additional_config["maxtext_config"].setdefault("weight_dtype", "bfloat16")
78+
additional_config["maxtext_config"].setdefault("allow_split_physical_axes", True)
79+
additional_config["maxtext_config"]["debug_sharding"] = config.debug_sharding
80+
additional_config.setdefault("sharding", {})
81+
additional_config["sharding"].setdefault("sharding_strategy", {})
82+
additional_config["sharding"]["sharding_strategy"].setdefault("enable_dp_attention", config.enable_dp_attention)
83+
# Pass vllm_config_path so the adapter can use it as the MaxText base config.
84+
additional_config.setdefault("vllm_config_path", str(config.vllm_config_path))
85+
7086
vllm_args = {
7187
"model": config.tokenizer_path,
7288
"max_model_len": config.max_target_length,
@@ -76,19 +92,7 @@ def decode_with_vllm(config: Config) -> None:
7692
"hf_overrides": config.vllm_hf_overrides,
7793
"gpu_memory_utilization": config.hbm_utilization_vllm,
7894
"async_scheduling": config.async_scheduling,
79-
"additional_config": {
80-
"maxtext_config": {
81-
"model_name": config.model_name,
82-
"weight_dtype": "bfloat16",
83-
"allow_split_physical_axes": True,
84-
"debug_sharding": config.debug_sharding,
85-
},
86-
"sharding": {
87-
"sharding_strategy": {
88-
"enable_dp_attention": config.enable_dp_attention,
89-
},
90-
},
91-
},
95+
"additional_config": additional_config,
9296
}
9397

9498
if config.load_parameters_path:
@@ -106,8 +110,7 @@ def decode_with_vllm(config: Config) -> None:
106110
f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..."
107111
)
108112

109-
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
110-
argv_list = ["", str(vllm_config_path), "log_config=False"]
113+
argv_list = ["", str(config.vllm_config_path), "log_config=False"]
111114
vllm_config = pyconfig.initialize(argv_list)
112115

113116
with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
@@ -145,7 +148,7 @@ def decode_with_vllm(config: Config) -> None:
145148
max_tokens=max_tokens_to_generate,
146149
top_k=config.decode_sampling_top_k,
147150
top_p=config.decode_sampling_nucleus_p,
148-
seed=FLAGS.seed,
151+
# seed=FLAGS.seed,
149152
)
150153

151154
outputs = llm.generate(prompts, sampling_params)

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,12 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
6969
)
7070
overrides["load_parameters_path"] = None
7171

72-
# Add base config path to positional args
73-
base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
72+
# Add base config path to positional args — prefer the caller-supplied
73+
# vllm_config_path from additional_config, fall back to vllm.yml default.
74+
base_config_path = vllm_config.additional_config.get(
75+
"vllm_config_path",
76+
os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml"),
77+
)
7478
argv_list = ["", str(base_config_path)]
7579

7680
maxtext_config = pyconfig.initialize(argv_list, **overrides)

src/maxtext/layers/attention_mla.py

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import jax
2222
from jax.ad_checkpoint import checkpoint_name
2323
from jax.experimental import layout
24+
from jax.sharding import PartitionSpec as P
25+
from jax.experimental import shard_map
2426
import jax.numpy as jnp
2527
from jax.sharding import Mesh, NamedSharding
2628

@@ -623,7 +625,11 @@ def __init__(
623625
)
624626

625627
# Module attribute names must match names previously passed to Linen for checkpointing
626-
self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
628+
self.MlaKVCache_0 = (
629+
self.init_mla_kv_caches(inputs_kv_shape)
630+
if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa"
631+
else None
632+
)
627633

628634
def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
629635
"""Initializes the MLA-specific projections."""
@@ -941,15 +947,118 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
941947

942948
key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
943949
cached_values = [None, None]
944-
if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
950+
if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
945951
if self.config.mla_naive_kvcache:
946952
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
947953
else:
948954
cached_values = self.update_mla_kv_caches(
949955
low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk
950956
)
951957

952-
return key, value, cached_values
958+
return key, value, cached_values, low_rank_main, key_rope
959+
960+
def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata):
961+
"""Forward function for vLLM serving with MLA attention.
962+
963+
Args:
964+
q_nope: Query nope part [T, N, qk_nope_head_dim]
965+
q_rope: Query rope part [T, N, qk_rope_head_dim]
966+
k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope)
967+
k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension)
968+
mla_kv_cache: The KV cache
969+
mla_metadata: Attention metadata
970+
"""
971+
md = mla_metadata
972+
try:
973+
# pylint: disable=import-outside-toplevel
974+
# pytype: disable=import-error
975+
from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
976+
from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes
977+
except ImportError as e:
978+
raise ImportError(
979+
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
980+
) from e
981+
982+
if mla_kv_cache is None or mla_metadata is None:
983+
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
984+
985+
wkv_b_kernel = self.wkv_b.kernel.value
986+
wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim]
987+
wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :]
988+
q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel)
989+
990+
def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
991+
def _initialize_block_sizes():
992+
# Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
993+
max_num_tokens = q_absorbed.shape[0]
994+
max_num_seqs = md.seq_lens.shape[0]
995+
num_page_indices = md.block_tables.shape[0]
996+
assert num_page_indices % max_num_seqs == 0
997+
pages_per_seq = num_page_indices // max_num_seqs
998+
# num_kv_pages_per_block = min(pages_per_seq, 16)
999+
bkv_p, bq_sz = get_tuned_block_sizes(
1000+
q_nope.dtype,
1001+
q_nope.dtype, # changed to q_nope dtype from mla_kv_cache.dtype
1002+
self.num_query_heads,
1003+
1, # num_kv_heads for MLA kernel
1004+
self.qk_nope_head_dim,
1005+
q_nope.shape[1], # page size ?? kv_cache.shape[1]
1006+
max_num_tokens,
1007+
pages_per_seq,
1008+
)
1009+
num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4)
1010+
num_queries_per_block = min(max_num_tokens, bq_sz, 4) # OOMS at 8
1011+
return num_kv_pages_per_block, num_queries_per_block
1012+
1013+
num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes()
1014+
output, kv_cache = mla_ragged_paged_attention(
1015+
q,
1016+
q_rope,
1017+
k,
1018+
k_rope,
1019+
kv_cache,
1020+
*args,
1021+
sm_scale=1.0,
1022+
num_kv_pages_per_block=num_kv_pages_per_block,
1023+
num_queries_per_block=num_queries_per_block,
1024+
)
1025+
return kv_cache, output
1026+
1027+
in_specs = (
1028+
P(("attn_dp", "model", "expert"), None, None), # q
1029+
P(("attn_dp", "model", "expert"), None, None), # q_rope
1030+
P(("attn_dp", "model", "expert"), None), # k
1031+
P(("attn_dp", "model", "expert"), None), # k_rope
1032+
P(("attn_dp", "model", "expert")), # kv_cache
1033+
P(("data", "attn_dp")), # md.seq_lens: Replicated
1034+
P(("data", "attn_dp")), # page_indices_flat: Replicated
1035+
P(("data", "attn_dp")), # query_start_loc: Replicated
1036+
P(("data", "attn_dp")), # distribution: Replicated
1037+
)
1038+
1039+
out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert")))
1040+
1041+
kv_cache, output = jax.jit(
1042+
shard_map.shard_map(
1043+
_mla_ragged_paged_attention,
1044+
mesh=self.mesh,
1045+
in_specs=in_specs,
1046+
out_specs=out_specs,
1047+
check_rep=False,
1048+
),
1049+
)(
1050+
q_absorbed,
1051+
q_rope,
1052+
k_latent,
1053+
k_rope,
1054+
mla_kv_cache,
1055+
md.seq_lens,
1056+
md.block_tables,
1057+
md.query_start_loc,
1058+
md.request_distribution,
1059+
)
1060+
output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel)
1061+
return kv_cache, output
9531062

9541063
def __call__(
9551064
self,
@@ -1005,7 +1114,7 @@ def __call__(
10051114
query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
10061115
if self.config.force_q_layout:
10071116
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
1008-
key, value, cached_values = self.mla_kv_projection(
1117+
key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection(
10091118
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
10101119
)
10111120
query = checkpoint_name(query, "query_proj")
@@ -1039,7 +1148,22 @@ def __call__(
10391148
)
10401149
unnormalized_out = unnormalized_out[..., : self.v_head_dim]
10411150
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
1151+
elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None:
1152+
batch, seq_len, num_heads, _ = query.shape
1153+
query = query.reshape(-1, query.shape[2], query.shape[3])
1154+
q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1)
1155+
1156+
k_latent = low_rank_main.reshape(-1, self.kv_lora_rank)
1157+
k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim)
1158+
1159+
updated_kv, attn_out = self.mla_rpa_vllm(
1160+
q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata
1161+
)
1162+
out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim)
1163+
kv_cache = updated_kv
10421164
else:
1165+
if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN:
1166+
model_mode = MODEL_MODE_TRAIN
10431167
out = self.attention_op(
10441168
query,
10451169
key,

0 commit comments

Comments
 (0)