Skip to content

Commit 4db09ea

Browse files
committed
deepseek sharding and mla attention plumbing
1 parent 3816629 commit 4db09ea

11 files changed

Lines changed: 255 additions & 32 deletions

File tree

src/MaxText/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from maxtext.trainers.post_train.dpo import dpo_utils
3535
from maxtext.utils import maxtext_utils
3636
from maxtext.utils import model_creation_utils
37-
from maxtext.utils.model_creation_utils import from_config
3837

3938
Transformer = models.Transformer
4039
transformer_as_linen = models.transformer_as_linen

src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ logical_axis_rules: [
430430
['decode_length', ['sequence']],
431431
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
432432
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
433+
['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
433434
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
434435
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
435436
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -1049,6 +1050,8 @@ use_jax_splash: false
10491050
# vLLM Adapter Configurations
10501051
# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
10511052
vllm_hf_config_path: ""
1053+
# Path to the YAML file used to load the vLLM config
1054+
vllm_config_path: ""
10521055
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
10531056
vllm_additional_config: {}
10541057
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]

src/MaxText/configs/rl.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ enable_dp_attention: False
149149
# Performance tuning for samplers
150150
max_num_batched_tokens: null
151151
max_num_seqs: null
152+
# Path to the YAML file used to initialize the vLLM config
153+
vllm_config_path: 'src/MaxText/configs/vllm.yml'
152154

153155
# ====== Checkpoint Configuration ======
154156
enable_checkpointing: True

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,6 +1491,7 @@ class VLLM(BaseModel):
14911491
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
14921492
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
14931493
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
1494+
vllm_config_path: str = Field("src/MaxText/configs/vllm.yml", description="path to yaml file for loading vLLM config.")
14941495

14951496

14961497
class RL(BaseModel):

src/MaxText/configs/vllm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ logical_axis_rules: [
5353
['decode_length', []],
5454
['mlp', ['model', 'attn_dp']],
5555
['mlp_no_fsdp', ['model', 'attn_dp']],
56+
['moe_mlp', ['model', 'attn_dp']],
5657
['vocab', ['model', 'attn_dp']],
5758
['heads', ['model']],
5859
['q_heads', ['model']],
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
base_config: "vllm.yml"
17+
18+
logical_axis_rules: [
19+
['activation_batch', ['']],
20+
['activation_batch_no_exp', []],
21+
['activation_embed_and_logits_batch', ['expert']],
22+
['activation_embed_and_logits_batch_sequence', ['expert']],
23+
['activation_heads', ['model']],
24+
['activation_kv_heads', ['model']],
25+
['activation_attn_length', ['expert']],
26+
['activation_attn_length_no_exp', []],
27+
['activation_length', ['data', 'expert']],
28+
['activation_length_no_exp', 'data'],
29+
['activation_q_length', ['expert']],
30+
['activation_attn_embed', 'model'],
31+
['activation_embed', ['model', 'attn_dp']],
32+
['activation_mlp', ['model', 'attn_dp', 'expert']],
33+
['activation_kv', ['model']],
34+
['activation_prefill_kv_batch', ['expert']],
35+
['activation_kv_batch', ['']],
36+
['activation_kv_batch_no_exp', []],
37+
['activation_kv_head_dim', ['model', 'attn_dp', 'expert']],
38+
['activation_vocab', ['model', 'attn_dp']],
39+
['activation_norm_length', []],
40+
['activation_exp', ['expert']],
41+
['decode_batch', ['expert']],
42+
['decode_length', []],
43+
['mlp_no_fsdp', ['model', 'attn_dp', 'expert']],
44+
['vocab', ['model', 'attn_dp', 'expert']],
45+
['heads', ['expert', 'attn_dp', 'model']],
46+
['q_heads', []],
47+
['kv_heads', []],
48+
['kv_head_dim', ['model', 'attn_dp', 'expert']],
49+
['kv', ['model', 'attn_dp', 'expert']],
50+
['kv', []],
51+
['embed', []],
52+
['mlp', ['model', 'attn_dp', 'expert']],
53+
['moe_mlp', []],
54+
['embed_tensor_transpose', ['attn_dp', 'model']],
55+
['embed_no_exp', []],
56+
['q_lora', []],
57+
['kv_lora', []],
58+
['norm', []],
59+
['cache_heads', ['model']],
60+
['exp', ['expert', 'attn_dp', 'model']],
61+
['paged_kv_heads', ['model']],
62+
['cache_batch_prefill', []],
63+
['cache_batch', []],
64+
['cache_sequence', []],
65+
['cache_heads_none', []],
66+
['cache_kv', []],
67+
['kv_lora_up_proj',['expert', 'attn_dp', 'model']],
68+
['q_lora_up_proj',['expert', 'attn_dp', 'model']],
69+
]

src/MaxText/layers/attention_mla.py

Lines changed: 128 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import jax
2222
from jax.ad_checkpoint import checkpoint_name
2323
from jax.experimental import layout
24+
from jax.sharding import PartitionSpec as P
25+
from jax.experimental import shard_map
2426
import jax.numpy as jnp
2527
from jax.sharding import Mesh, NamedSharding
2628

@@ -619,7 +621,11 @@ def __init__(
619621
)
620622

621623
# Module attribute names must match names previously passed to Linen for checkpointing
622-
self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
624+
self.MlaKVCache_0 = (
625+
self.init_mla_kv_caches(inputs_kv_shape)
626+
if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa"
627+
else None
628+
)
623629

624630
def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
625631
"""Initializes the MLA-specific projections."""
@@ -937,15 +943,118 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
937943

938944
key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
939945
cached_values = [None, None]
940-
if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
946+
if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
941947
if self.config.mla_naive_kvcache:
942948
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
943949
else:
944950
cached_values = self.update_mla_kv_caches(
945951
low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk
946952
)
947953

948-
return key, value, cached_values
954+
return key, value, cached_values, low_rank_main, key_rope
955+
956+
def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata):
957+
"""Forward function for vLLM serving with MLA attention.
958+
959+
Args:
960+
q_nope: Query nope part [T, N, qk_nope_head_dim]
961+
q_rope: Query rope part [T, N, qk_rope_head_dim]
962+
k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope)
963+
k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension)
964+
mla_kv_cache: The KV cache
965+
mla_metadata: Attention metadata
966+
"""
967+
md = mla_metadata
968+
try:
969+
# pylint: disable=import-outside-toplevel
970+
# pytype: disable=import-error
971+
from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
972+
from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes
973+
except ImportError as e:
974+
raise ImportError(
975+
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
976+
) from e
977+
978+
if mla_kv_cache is None or mla_metadata is None:
979+
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
980+
981+
wkv_b_kernel = self.wkv_b.kernel.value
982+
wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim]
983+
wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :]
984+
q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel)
985+
986+
def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
987+
def _initialize_block_sizes():
988+
# Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
989+
max_num_tokens = q_absorbed.shape[0]
990+
max_num_seqs = md.seq_lens.shape[0]
991+
num_page_indices = md.block_tables.shape[0]
992+
assert num_page_indices % max_num_seqs == 0
993+
pages_per_seq = num_page_indices // max_num_seqs
994+
# num_kv_pages_per_block = min(pages_per_seq, 16)
995+
bkv_p, bq_sz = get_tuned_block_sizes(
996+
q_nope.dtype,
997+
q_nope.dtype, # changed to q_nope dtype from mla_kv_cache.dtype
998+
self.num_query_heads,
999+
1, # num_kv_heads for MLA kernel
1000+
self.qk_nope_head_dim,
1001+
q_nope.shape[1], # page size — TODO(review): confirm whether this should be kv_cache.shape[1] instead
1002+
max_num_tokens,
1003+
pages_per_seq,
1004+
)
1005+
num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4)
1006+
num_queries_per_block = min(max_num_tokens, bq_sz, 4) # OOMS at 8
1007+
return num_kv_pages_per_block, num_queries_per_block
1008+
1009+
num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes()
1010+
output, kv_cache = mla_ragged_paged_attention(
1011+
q,
1012+
q_rope,
1013+
k,
1014+
k_rope,
1015+
kv_cache,
1016+
*args,
1017+
sm_scale=1.0,
1018+
num_kv_pages_per_block=num_kv_pages_per_block,
1019+
num_queries_per_block=num_queries_per_block,
1020+
)
1021+
return kv_cache, output
1022+
1023+
in_specs = (
1024+
P(("attn_dp", "model", "expert"), None, None), # q
1025+
P(("attn_dp", "model", "expert"), None, None), # q_rope
1026+
P(("attn_dp", "model", "expert"), None), # k
1027+
P(("attn_dp", "model", "expert"), None), # k_rope
1028+
P(("attn_dp", "model", "expert")), # kv_cache
1029+
P(("data", "attn_dp")), # md.seq_lens: Replicated
1030+
P(("data", "attn_dp")), # page_indices_flat: Replicated
1031+
P(("data", "attn_dp")), # query_start_loc: Replicated
1032+
P(("data", "attn_dp")), # distribution: Replicated
1033+
)
1034+
1035+
out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert")))
1036+
1037+
kv_cache, output = jax.jit(
1038+
shard_map.shard_map(
1039+
_mla_ragged_paged_attention,
1040+
mesh=self.mesh,
1041+
in_specs=in_specs,
1042+
out_specs=out_specs,
1043+
check_rep=False,
1044+
),
1045+
)(
1046+
q_absorbed,
1047+
q_rope,
1048+
k_latent,
1049+
k_rope,
1050+
mla_kv_cache,
1051+
md.seq_lens,
1052+
md.block_tables,
1053+
md.query_start_loc,
1054+
md.request_distribution,
1055+
)
1056+
output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel)
1057+
return kv_cache, output
9491058

9501059
def __call__(
9511060
self,
@@ -1001,7 +1110,7 @@ def __call__(
10011110
query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
10021111
if self.config.force_q_layout:
10031112
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
1004-
key, value, cached_values = self.mla_kv_projection(
1113+
key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection(
10051114
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
10061115
)
10071116
query = checkpoint_name(query, "query_proj")
@@ -1034,8 +1143,22 @@ def __call__(
10341143
)
10351144
unnormalized_out = unnormalized_out[..., : self.v_head_dim]
10361145
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
1146+
elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None:
1147+
batch, seq_len, num_heads, _ = query.shape
1148+
query = query.reshape(-1, query.shape[2], query.shape[3])
1149+
q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1)
1150+
1151+
k_latent = low_rank_main.reshape(-1, self.kv_lora_rank)
1152+
k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim)
1153+
1154+
updated_kv, attn_out = self.mla_rpa_vllm(
1155+
q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata
1156+
)
1157+
out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim)
1158+
kv_cache = updated_kv
10371159
else:
1038-
# Pass the index_mask to the Attention Op
1160+
if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN:
1161+
model_mode = MODEL_MODE_TRAIN
10391162
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)
10401163

10411164
if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:

src/MaxText/layers/deepseek.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# pylint: disable=arguments-differ
1717
# pylint: disable=no-name-in-module
1818

19-
from typing import Optional
19+
from typing import Optional, Any
2020

2121
from flax import nnx
2222
from jax.ad_checkpoint import checkpoint_name
@@ -154,9 +154,11 @@ def attention_op(
154154
previous_chunk=None,
155155
page_state: None | page_manager.PageState = None,
156156
slot: None | int = None,
157+
kv_cache: None | jnp.ndarray = None,
158+
attention_metadata: None | dict[str, Any] = None,
157159
):
158160
"""Executes the attention layer."""
159-
attention_result, _ = self.self_attention(
161+
attention_result, kv_cache = self.self_attention(
160162
x,
161163
x,
162164
decoder_positions,
@@ -167,8 +169,10 @@ def attention_op(
167169
previous_chunk=previous_chunk,
168170
page_state=page_state,
169171
slot=slot,
172+
kv_cache=kv_cache,
173+
attention_metadata=attention_metadata,
170174
)
171-
return self.with_logical_constraint(attention_result)
175+
return self.with_logical_constraint(attention_result), kv_cache
172176

173177
@property
174178
def logical_axis_names(self):
@@ -229,23 +233,27 @@ def self_attention_with_norm_op(
229233
previous_chunk=None,
230234
page_state: None | page_manager.PageState = None,
231235
slot: None | int = None,
236+
kv_cache: None | jnp.ndarray = None,
237+
attention_metadata: None | dict[str, Any] = None,
232238
):
233239
"""self-attention with normalization"""
234240
lnx = self.pre_attention_norm_op(inputs)
235241

236-
attention_lnx = self.attention_op(
242+
attention_lnx, kv_cache = self.attention_op(
237243
lnx,
238244
decoder_segment_ids,
239245
decoder_positions,
240246
deterministic,
241247
previous_chunk,
242248
page_state,
243249
slot,
250+
kv_cache,
251+
attention_metadata,
244252
)
245253
intermediate_inputs = inputs + attention_lnx
246254
# Normalization
247255
hidden_states = self.post_attention_norm_op(intermediate_inputs)
248-
return hidden_states, intermediate_inputs
256+
return hidden_states, intermediate_inputs, kv_cache
249257

250258

251259
class DeepSeekDenseLayer(DeepSeekGenericLayer):
@@ -298,14 +306,16 @@ def __call__(
298306
x = self.with_logical_constraint(inputs)
299307
x = checkpoint_name(x, "decoder_layer_input")
300308

301-
hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
309+
hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op(
302310
x,
303311
decoder_segment_ids,
304312
decoder_positions,
305313
deterministic,
306314
previous_chunk,
307315
page_state,
308316
slot,
317+
kv_cache,
318+
attention_metadata,
309319
)
310320

311321
mlp_lnx = self.mlp_op(hidden_states, deterministic)
@@ -384,14 +394,16 @@ def __call__(
384394
x = self.with_logical_constraint(inputs)
385395
x = checkpoint_name(x, "decoder_layer_input")
386396

387-
hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
397+
hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op(
388398
x,
389399
decoder_segment_ids,
390400
decoder_positions,
391401
deterministic,
392402
previous_chunk,
393403
page_state,
394404
slot,
405+
kv_cache,
406+
attention_metadata,
395407
)
396408

397409
mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)

0 commit comments

Comments
 (0)