Skip to content

Commit 49ea7f7

Browse files
committed
deepseek sharding and mla attention plumbing
1 parent 17d805e commit 49ea7f7

9 files changed

Lines changed: 234 additions & 23 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ logical_axis_rules: [
451451
['decode_length', ['sequence']],
452452
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
453453
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
454+
['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
454455
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
455456
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
456457
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -1096,6 +1097,8 @@ vllm_hf_config_path: ""
10961097
# A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter.
10971098
# This can be used to override specific settings without modifying the original config file.
10981099
vllm_hf_overrides: {}
1100+
# Path to yaml file for loading vLLM config
1101+
vllm_config_path: ""
10991102
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
11001103
vllm_additional_config: {}
11011104
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]

src/maxtext/configs/inference/vllm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ logical_axis_rules: [
5353
['decode_length', []],
5454
['mlp', ['model', 'attn_dp']],
5555
['mlp_no_fsdp', ['model', 'attn_dp']],
56+
['moe_mlp', ['model', 'attn_dp']],
5657
['vocab', ['model', 'attn_dp']],
5758
['heads', ['model']],
5859
['q_heads', ['model']],
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
base_config: "vllm.yml"
17+
18+
logical_axis_rules: [
19+
    ['activation_batch', []],
20+
['activation_batch_no_exp', []],
21+
['activation_embed_and_logits_batch', ['expert']],
22+
['activation_embed_and_logits_batch_sequence', ['expert']],
23+
['activation_heads', ['model']],
24+
['activation_kv_heads', ['model']],
25+
['activation_attn_length', ['expert']],
26+
['activation_attn_length_no_exp', []],
27+
['activation_length', ['data', 'expert']],
28+
['activation_length_no_exp', 'data'],
29+
['activation_q_length', ['expert']],
30+
['activation_attn_embed', 'model'],
31+
['activation_embed', ['model', 'attn_dp']],
32+
['activation_mlp', ['model', 'attn_dp', 'expert']],
33+
['activation_kv', ['model']],
34+
['activation_prefill_kv_batch', ['expert']],
35+
    ['activation_kv_batch', []],
36+
['activation_kv_batch_no_exp', []],
37+
['activation_kv_head_dim', ['model', 'attn_dp', 'expert']],
38+
['activation_vocab', ['model', 'attn_dp']],
39+
['activation_norm_length', []],
40+
['activation_exp', ['expert']],
41+
['decode_batch', ['expert']],
42+
['decode_length', []],
43+
['mlp_no_fsdp', ['model', 'attn_dp', 'expert']],
44+
['vocab', ['model', 'attn_dp', 'expert']],
45+
['heads', ['expert', 'attn_dp', 'model']],
46+
['q_heads', []],
47+
['kv_heads', []],
48+
['kv_head_dim', ['model', 'attn_dp', 'expert']],
49+
['kv', ['model', 'attn_dp', 'expert']],
50+
['kv', []],
51+
['embed', []],
52+
['mlp', ['model', 'attn_dp', 'expert']],
53+
['moe_mlp', []],
54+
['embed_tensor_transpose', ['attn_dp', 'model']],
55+
['embed_no_exp', []],
56+
['q_lora', []],
57+
['kv_lora', []],
58+
['norm', []],
59+
['cache_heads', ['model']],
60+
['exp', ['expert', 'attn_dp', 'model']],
61+
['paged_kv_heads', ['model']],
62+
['cache_batch_prefill', []],
63+
['cache_batch', []],
64+
['cache_sequence', []],
65+
['cache_heads_none', []],
66+
['cache_kv', []],
67+
    ['kv_lora_up_proj', ['expert', 'attn_dp', 'model']],
68+
    ['q_lora_up_proj', ['expert', 'attn_dp', 'model']],
69+
]

src/maxtext/configs/post_train/rl.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ enable_dp_attention: False
149149
# Performance tuning for samplers
150150
max_num_batched_tokens: null
151151
max_num_seqs: null
152+
# Path to yaml file for loading vLLM config
153+
vllm_config_path: 'src/maxtext/configs/inference/vllm.yml'
152154

153155
# ====== Checkpoint Configuration ======
154156
enable_checkpointing: True

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,6 +1570,7 @@ class VLLM(BaseModel):
15701570
default_factory=dict, description="Overrides for HuggingFace model config for MaxText model."
15711571
)
15721572
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
1573+
vllm_config_path: str = Field("src/maxtext/configs/inference/vllm.yml", description="Path to yaml file for loading vLLM config.")
15731574

15741575

15751576
class RL(BaseModel):

src/maxtext/layers/attention_mla.py

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import jax
2222
from jax.ad_checkpoint import checkpoint_name
2323
from jax.experimental import layout
24+
from jax.sharding import PartitionSpec as P
25+
from jax.experimental import shard_map
2426
import jax.numpy as jnp
2527
from jax.sharding import Mesh, NamedSharding
2628

@@ -623,7 +625,11 @@ def __init__(
623625
)
624626

625627
# Module attribute names must match names previously passed to Linen for checkpointing
626-
self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
628+
self.MlaKVCache_0 = (
629+
self.init_mla_kv_caches(inputs_kv_shape)
630+
if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa"
631+
else None
632+
)
627633

628634
def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
629635
"""Initializes the MLA-specific projections."""
@@ -941,15 +947,118 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
941947

942948
key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
943949
cached_values = [None, None]
944-
if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
950+
if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
945951
if self.config.mla_naive_kvcache:
946952
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
947953
else:
948954
cached_values = self.update_mla_kv_caches(
949955
low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk
950956
)
951957

952-
return key, value, cached_values
958+
return key, value, cached_values, low_rank_main, key_rope
959+
960+
def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata):
961+
"""Forward function for vLLM serving with MLA attention.
962+
963+
Args:
964+
q_nope: Query nope part [T, N, qk_nope_head_dim]
965+
q_rope: Query rope part [T, N, qk_rope_head_dim]
966+
k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope)
967+
k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension)
968+
mla_kv_cache: The KV cache
969+
mla_metadata: Attention metadata
970+
"""
971+
md = mla_metadata
972+
try:
973+
# pylint: disable=import-outside-toplevel
974+
# pytype: disable=import-error
975+
from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
976+
from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes
977+
except ImportError as e:
978+
raise ImportError(
979+
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
980+
) from e
981+
982+
if mla_kv_cache is None or mla_metadata is None:
983+
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
984+
985+
wkv_b_kernel = self.wkv_b.kernel.value
986+
wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim]
987+
wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :]
988+
q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel)
989+
990+
def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
991+
def _initialize_block_sizes():
992+
# Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
993+
max_num_tokens = q_absorbed.shape[0]
994+
max_num_seqs = md.seq_lens.shape[0]
995+
num_page_indices = md.block_tables.shape[0]
996+
assert num_page_indices % max_num_seqs == 0
997+
pages_per_seq = num_page_indices // max_num_seqs
998+
# num_kv_pages_per_block = min(pages_per_seq, 16)
999+
bkv_p, bq_sz = get_tuned_block_sizes(
1000+
q_nope.dtype,
1001+
q_nope.dtype, # changed to q_nope dtype from mla_kv_cache.dtype
1002+
self.num_query_heads,
1003+
1, # num_kv_heads for MLA kernel
1004+
self.qk_nope_head_dim,
1005+
q_nope.shape[1], # page size ?? kv_cache.shape[1]
1006+
max_num_tokens,
1007+
pages_per_seq,
1008+
)
1009+
num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4)
1010+
num_queries_per_block = min(max_num_tokens, bq_sz, 4) # OOMS at 8
1011+
return num_kv_pages_per_block, num_queries_per_block
1012+
1013+
num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes()
1014+
output, kv_cache = mla_ragged_paged_attention(
1015+
q,
1016+
q_rope,
1017+
k,
1018+
k_rope,
1019+
kv_cache,
1020+
*args,
1021+
sm_scale=1.0,
1022+
num_kv_pages_per_block=num_kv_pages_per_block,
1023+
num_queries_per_block=num_queries_per_block,
1024+
)
1025+
return kv_cache, output
1026+
1027+
in_specs = (
1028+
P(("attn_dp", "model", "expert"), None, None), # q
1029+
P(("attn_dp", "model", "expert"), None, None), # q_rope
1030+
P(("attn_dp", "model", "expert"), None), # k
1031+
P(("attn_dp", "model", "expert"), None), # k_rope
1032+
P(("attn_dp", "model", "expert")), # kv_cache
1033+
P(("data", "attn_dp")), # md.seq_lens: Replicated
1034+
P(("data", "attn_dp")), # page_indices_flat: Replicated
1035+
P(("data", "attn_dp")), # query_start_loc: Replicated
1036+
P(("data", "attn_dp")), # distribution: Replicated
1037+
)
1038+
1039+
out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert")))
1040+
1041+
kv_cache, output = jax.jit(
1042+
shard_map.shard_map(
1043+
_mla_ragged_paged_attention,
1044+
mesh=self.mesh,
1045+
in_specs=in_specs,
1046+
out_specs=out_specs,
1047+
check_rep=False,
1048+
),
1049+
)(
1050+
q_absorbed,
1051+
q_rope,
1052+
k_latent,
1053+
k_rope,
1054+
mla_kv_cache,
1055+
md.seq_lens,
1056+
md.block_tables,
1057+
md.query_start_loc,
1058+
md.request_distribution,
1059+
)
1060+
output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel)
1061+
return kv_cache, output
9531062

9541063
def __call__(
9551064
self,
@@ -1005,7 +1114,7 @@ def __call__(
10051114
query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
10061115
if self.config.force_q_layout:
10071116
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
1008-
key, value, cached_values = self.mla_kv_projection(
1117+
key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection(
10091118
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
10101119
)
10111120
query = checkpoint_name(query, "query_proj")
@@ -1039,7 +1148,22 @@ def __call__(
10391148
)
10401149
unnormalized_out = unnormalized_out[..., : self.v_head_dim]
10411150
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
1151+
elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None:
1152+
batch, seq_len, num_heads, _ = query.shape
1153+
query = query.reshape(-1, query.shape[2], query.shape[3])
1154+
q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1)
1155+
1156+
k_latent = low_rank_main.reshape(-1, self.kv_lora_rank)
1157+
k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim)
1158+
1159+
updated_kv, attn_out = self.mla_rpa_vllm(
1160+
q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata
1161+
)
1162+
out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim)
1163+
kv_cache = updated_kv
10421164
else:
1165+
if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN:
1166+
model_mode = MODEL_MODE_TRAIN
10431167
out = self.attention_op(
10441168
query,
10451169
key,

src/maxtext/layers/moe.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -351,16 +351,16 @@ def __init__(
351351

352352
if self.config.shard_exp_on_fsdp:
353353
# special sharding for dsv3
354-
self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
355-
self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
354+
self.wi_kernel_axes = ("embed_no_exp", None, "moe_mlp")
355+
self.wo_kernel_axes = ("embed_no_exp", "moe_mlp", None)
356356
elif self.config.use_2d_fsdp_sharding:
357357
self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
358358
self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
359359
elif self.config.use_batch_split_schedule:
360360
self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
361361
else:
362-
self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
363-
self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
362+
self.wi_kernel_axes = ("exp", "embed_no_exp", "moe_mlp")
363+
self.wo_kernel_axes = ("exp", "moe_mlp", "embed_no_exp")
364364

365365
if self.config.attention == "vllm_rpa":
366366
# vLLM uses 'model' as the tensor parallelism axis name
@@ -1378,11 +1378,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13781378

13791379
if self.config.moe_fsdp_use_two_stage_all_gather:
13801380
# Unshard on fsdp axis
1381-
w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
1382-
w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
1381+
w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp"))
1382+
w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp"))
13831383

13841384
# Unshard on fsdp_transpose axis
1385-
wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))
1385+
wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "moe_mlp", "embed_tensor_transpose"))
13861386

13871387
# Make sure XLA does not optimize by combining above All-Gather to unshard
13881388
# on FSDP axis and the subsequent unshard on fsdp_transpose axis
@@ -1830,7 +1830,7 @@ def dense_matmul(
18301830
dispatch_axis,
18311831
)
18321832
with jax.named_scope("wi_0"):
1833-
w0_kernel_axes = ("exp", None, "mlp")
1833+
w0_kernel_axes = ("exp", None, "moe_mlp")
18341834
w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
18351835
layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
18361836
mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
@@ -1847,7 +1847,7 @@ def dense_matmul(
18471847
)
18481848
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
18491849
with jax.named_scope("wi_1"):
1850-
w1_kernel_axes = ("exp", None, "mlp")
1850+
w1_kernel_axes = ("exp", None, "moe_mlp")
18511851
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
18521852
layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
18531853
mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
@@ -1864,7 +1864,7 @@ def dense_matmul(
18641864
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
18651865
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
18661866
with jax.named_scope("wo"):
1867-
wo_kernel_axes = ("exp", "mlp", None)
1867+
wo_kernel_axes = ("exp", "moe_mlp", None)
18681868
wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
18691869
intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
18701870
mlp_down_einsum,

0 commit comments

Comments
 (0)