Skip to content

Commit a78bdab

Browse files
committed
deepseek sharding and mla attention plumbing
1 parent 496ed40 commit a78bdab

9 files changed

Lines changed: 253 additions & 23 deletions

File tree

src/MaxText/layers/attention_mla.py

Lines changed: 128 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import jax
2222
from jax.ad_checkpoint import checkpoint_name
2323
from jax.experimental import layout
24+
from jax.sharding import PartitionSpec as P
25+
from jax.experimental import shard_map
2426
import jax.numpy as jnp
2527
from jax.sharding import Mesh, NamedSharding
2628

@@ -619,7 +621,11 @@ def __init__(
619621
)
620622

621623
# Module attribute names must match names previously passed to Linen for checkpointing
622-
self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
624+
self.MlaKVCache_0 = (
625+
self.init_mla_kv_caches(inputs_kv_shape)
626+
if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa"
627+
else None
628+
)
623629

624630
def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
625631
"""Initializes the MLA-specific projections."""
@@ -937,15 +943,118 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
937943

938944
key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
939945
cached_values = [None, None]
940-
if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
946+
if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
941947
if self.config.mla_naive_kvcache:
942948
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
943949
else:
944950
cached_values = self.update_mla_kv_caches(
945951
low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk
946952
)
947953

948-
return key, value, cached_values
954+
return key, value, cached_values, low_rank_main, key_rope
955+
956+
def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata):
  """Forward function for vLLM serving with MLA attention.

  Runs "absorbed" MLA attention via the tpu_inference ragged-paged-attention
  (RPA) kernel: the no-RoPE query is first projected into the latent KV space
  with the key half of ``wkv_b`` (so the kernel attends in ``kv_lora_rank``
  dimensions against the un-expanded latent cache), and the value half of
  ``wkv_b`` is applied to the kernel output afterwards.

  Args:
    q_nope: Query nope part [T, N, qk_nope_head_dim]
    q_rope: Query rope part [T, N, qk_rope_head_dim]
    k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope)
    k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension)
    mla_kv_cache: The KV cache
    mla_metadata: Attention metadata (must expose seq_lens, block_tables,
      query_start_loc, request_distribution)

  Returns:
    A (kv_cache, output) tuple: the updated paged KV cache and the attention
    output of shape [T, N, v_head_dim] (after re-expansion through wv_b).

  Raises:
    ImportError: if the tpu_inference (vllm-tpu) kernels are not installed.
    ValueError: if the KV cache or attention metadata is missing.
  """
  md = mla_metadata
  try:
    # pylint: disable=import-outside-toplevel
    # pytype: disable=import-error
    from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
    from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes
  except ImportError as e:
    raise ImportError(
        "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
    ) from e

  if mla_kv_cache is None or mla_metadata is None:
    raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")

  # Split the combined up-projection into its key (first qk_nope_head_dim
  # columns) and value (remaining columns) halves.
  wkv_b_kernel = self.wkv_b.kernel.value
  wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim]
  wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :]
  # Absorb wk_b into the query: [T, N, H] x [A, N, H] -> [T, N, A] so the
  # kernel can score directly against the latent (rank-A) cache entries.
  q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel)

  def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
    # Per-shard body executed under shard_map; *args carries the (replicated)
    # metadata tensors in the order listed in in_specs below.
    def _initialize_block_sizes():
      # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
      max_num_tokens = q_absorbed.shape[0]
      max_num_seqs = md.seq_lens.shape[0]
      num_page_indices = md.block_tables.shape[0]
      # The flat block table is assumed to be padded to a fixed number of
      # pages per sequence.
      assert num_page_indices % max_num_seqs == 0
      pages_per_seq = num_page_indices // max_num_seqs
      # num_kv_pages_per_block = min(pages_per_seq, 16)
      bkv_p, bq_sz = get_tuned_block_sizes(
          q_nope.dtype,
          q_nope.dtype,  # changed to q_nope dtype from mla_kv_cache.dtype
          self.num_query_heads,
          1,  # num_kv_heads for MLA kernel
          self.qk_nope_head_dim,
          q_nope.shape[1],  # page size ?? kv_cache.shape[1]
          max_num_tokens,
          pages_per_seq,
      )
      # Clamp the tuned sizes; 4 was chosen empirically for memory headroom.
      num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4)
      num_queries_per_block = min(max_num_tokens, bq_sz, 4)  # OOMS at 8
      return num_kv_pages_per_block, num_queries_per_block

    num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes()
    # NOTE(review): sm_scale=1.0 presumes the query was already scaled
    # upstream (e.g. in the query projection) — confirm against the caller.
    output, kv_cache = mla_ragged_paged_attention(
        q,
        q_rope,
        k,
        k_rope,
        kv_cache,
        *args,
        sm_scale=1.0,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
    )
    return kv_cache, output

  # Token/sequence dims are sharded over the fused (attn_dp, model, expert)
  # axis; head/feature dims stay unsharded.
  # NOTE(review): the metadata specs use P(("data", "attn_dp")) although the
  # inline comments say "Replicated" — confirm which is intended.
  in_specs = (
      P(("attn_dp", "model", "expert"), None, None),  # q
      P(("attn_dp", "model", "expert"), None, None),  # q_rope
      P(("attn_dp", "model", "expert"), None),  # k
      P(("attn_dp", "model", "expert"), None),  # k_rope
      P(("attn_dp", "model", "expert")),  # kv_cache
      P(("data", "attn_dp")),  # md.seq_lens: Replicated
      P(("data", "attn_dp")),  # page_indices_flat: Replicated
      P(("data", "attn_dp")),  # query_start_loc: Replicated
      P(("data", "attn_dp")),  # distribution: Replicated
  )

  out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert")))

  # check_rep=False: the kernel mutates the paged cache in place, so shard_map
  # cannot verify replication invariants.
  kv_cache, output = jax.jit(
      shard_map.shard_map(
          _mla_ragged_paged_attention,
          mesh=self.mesh,
          in_specs=in_specs,
          out_specs=out_specs,
          check_rep=False,
      ),
  )(
      q_absorbed,
      q_rope,
      k_latent,
      k_rope,
      mla_kv_cache,
      md.seq_lens,
      md.block_tables,
      md.query_start_loc,
      md.request_distribution,
  )
  # Re-expand from latent space to per-head values: [T, N, A] x [A, N, H] -> [T, N, H].
  output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel)
  return kv_cache, output
9491058

9501059
def __call__(
9511060
self,
@@ -1001,7 +1110,7 @@ def __call__(
10011110
query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
10021111
if self.config.force_q_layout:
10031112
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
1004-
key, value, cached_values = self.mla_kv_projection(
1113+
key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection(
10051114
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
10061115
)
10071116
query = checkpoint_name(query, "query_proj")
@@ -1032,8 +1141,22 @@ def __call__(
10321141
)
10331142
unnormalized_out = unnormalized_out[..., : self.v_head_dim]
10341143
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
1144+
elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None:
1145+
batch, seq_len, num_heads, _ = query.shape
1146+
query = query.reshape(-1, query.shape[2], query.shape[3])
1147+
q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1)
1148+
1149+
k_latent = low_rank_main.reshape(-1, self.kv_lora_rank)
1150+
k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim)
1151+
1152+
updated_kv, attn_out = self.mla_rpa_vllm(
1153+
q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata
1154+
)
1155+
out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim)
1156+
kv_cache = updated_kv
10351157
else:
1036-
# Pass the index_mask to the Attention Op
1158+
if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN:
1159+
model_mode = MODEL_MODE_TRAIN
10371160
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values, index_mask=index_mask)
10381161

10391162
out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")

src/MaxText/layers/deepseek.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# pylint: disable=arguments-differ
1717
# pylint: disable=no-name-in-module
1818

19-
from typing import Optional
19+
from typing import Optional, Any
2020

2121
from flax import nnx
2222
from jax.ad_checkpoint import checkpoint_name
@@ -154,9 +154,11 @@ def attention_op(
154154
previous_chunk=None,
155155
page_state: None | page_manager.PageState = None,
156156
slot: None | int = None,
157+
kv_cache: None | jnp.ndarray = None,
158+
attention_metadata: None | dict[str, Any] = None,
157159
):
158160
"""Executes the attention layer."""
159-
attention_result, _ = self.self_attention(
161+
attention_result, kv_cache = self.self_attention(
160162
x,
161163
x,
162164
decoder_positions,
@@ -167,8 +169,10 @@ def attention_op(
167169
previous_chunk=previous_chunk,
168170
page_state=page_state,
169171
slot=slot,
172+
kv_cache=kv_cache,
173+
attention_metadata=attention_metadata,
170174
)
171-
return self.with_logical_constraint(attention_result)
175+
return self.with_logical_constraint(attention_result), kv_cache
172176

173177
@property
174178
def logical_axis_names(self):
@@ -229,23 +233,27 @@ def self_attention_with_norm_op(
229233
previous_chunk=None,
230234
page_state: None | page_manager.PageState = None,
231235
slot: None | int = None,
236+
kv_cache: None | jnp.ndarray = None,
237+
attention_metadata: None | dict[str, Any] = None,
232238
):
233239
"""self-attention with normalization"""
234240
lnx = self.pre_attention_norm_op(inputs)
235241

236-
attention_lnx = self.attention_op(
242+
attention_lnx, kv_cache = self.attention_op(
237243
lnx,
238244
decoder_segment_ids,
239245
decoder_positions,
240246
deterministic,
241247
previous_chunk,
242248
page_state,
243249
slot,
250+
kv_cache,
251+
attention_metadata,
244252
)
245253
intermediate_inputs = inputs + attention_lnx
246254
# Normalization
247255
hidden_states = self.post_attention_norm_op(intermediate_inputs)
248-
return hidden_states, intermediate_inputs
256+
return hidden_states, intermediate_inputs, kv_cache
249257

250258

251259
class DeepSeekDenseLayer(DeepSeekGenericLayer):
@@ -298,14 +306,16 @@ def __call__(
298306
x = self.with_logical_constraint(inputs)
299307
x = checkpoint_name(x, "decoder_layer_input")
300308

301-
hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
309+
hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op(
302310
x,
303311
decoder_segment_ids,
304312
decoder_positions,
305313
deterministic,
306314
previous_chunk,
307315
page_state,
308316
slot,
317+
kv_cache,
318+
attention_metadata,
309319
)
310320

311321
mlp_lnx = self.mlp_op(hidden_states, deterministic)
@@ -384,14 +394,16 @@ def __call__(
384394
x = self.with_logical_constraint(inputs)
385395
x = checkpoint_name(x, "decoder_layer_input")
386396

387-
hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
397+
hidden_states, intermediate_inputs, kv_cache = self.self_attention_with_norm_op(
388398
x,
389399
decoder_segment_ids,
390400
decoder_positions,
391401
deterministic,
392402
previous_chunk,
393403
page_state,
394404
slot,
405+
kv_cache,
406+
attention_metadata,
395407
)
396408

397409
mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)

src/MaxText/layers/moe.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -351,16 +351,16 @@ def __init__(
351351

352352
if self.config.shard_exp_on_fsdp:
353353
# special sharding for dsv3
354-
self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
355-
self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
354+
self.wi_kernel_axes = ("embed_no_exp", None, "moe_mlp")
355+
self.wo_kernel_axes = ("embed_no_exp", "moe_mlp", None)
356356
elif self.config.use_2d_fsdp_sharding:
357357
self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
358358
self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
359359
elif self.config.use_batch_split_schedule:
360360
self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
361361
else:
362-
self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
363-
self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
362+
self.wi_kernel_axes = ("exp", "embed_no_exp", "moe_mlp")
363+
self.wo_kernel_axes = ("exp", "moe_mlp", "embed_no_exp")
364364

365365
if self.config.attention == "vllm_rpa":
366366
# vLLM uses 'model' as the tensor parallelism axis name
@@ -1377,11 +1377,11 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13771377

13781378
if self.config.moe_fsdp_use_two_stage_all_gather:
13791379
# Unshard on fsdp axis
1380-
w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
1381-
w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
1380+
w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp"))
1381+
w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "moe_mlp"))
13821382

13831383
# Unshard on fsdp_transpose axis
1384-
wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))
1384+
wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "moe_mlp", "embed_tensor_transpose"))
13851385

13861386
# Make sure XLA does not optimize by combining above All-Gather to unshard
13871387
# on FSDP axis and the subsequent unshard on fsdp_transpose axis
@@ -1829,7 +1829,7 @@ def dense_matmul(
18291829
dispatch_axis,
18301830
)
18311831
with jax.named_scope("wi_0"):
1832-
w0_kernel_axes = ("exp", None, "mlp")
1832+
w0_kernel_axes = ("exp", None, "moe_mlp")
18331833
w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
18341834
layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
18351835
mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
@@ -1846,7 +1846,7 @@ def dense_matmul(
18461846
)
18471847
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
18481848
with jax.named_scope("wi_1"):
1849-
w1_kernel_axes = ("exp", None, "mlp")
1849+
w1_kernel_axes = ("exp", None, "moe_mlp")
18501850
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
18511851
layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
18521852
mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
@@ -1863,7 +1863,7 @@ def dense_matmul(
18631863
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
18641864
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
18651865
with jax.named_scope("wo"):
1866-
wo_kernel_axes = ("exp", "mlp", None)
1866+
wo_kernel_axes = ("exp", "moe_mlp", None)
18671867
wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
18681868
intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
18691869
mlp_down_einsum,

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ logical_axis_rules: [
443443
['decode_length', ['sequence']],
444444
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
445445
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
446+
['moe_mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
446447
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
447448
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
448449
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -1083,6 +1084,8 @@ use_jax_splash: false
10831084
# vLLM Adapter Configurations
10841085
# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
10851086
vllm_hf_config_path: ""
1087+
# Path to yaml file for loading vLLM config
1088+
vllm_config_path: ""
10861089
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
10871090
vllm_additional_config: {}
10881091
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]

src/maxtext/configs/inference/vllm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ logical_axis_rules: [
5353
['decode_length', []],
5454
['mlp', ['model', 'attn_dp']],
5555
['mlp_no_fsdp', ['model', 'attn_dp']],
56+
['moe_mlp', ['model', 'attn_dp']],
5657
['vocab', ['model', 'attn_dp']],
5758
['heads', ['model']],
5859
['q_heads', ['model']],

src/maxtext/configs/post_train/rl.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ enable_dp_attention: False
149149
# Performance tuning for samplers
150150
max_num_batched_tokens: null
151151
max_num_seqs: null
152+
# Path to the YAML file used to initialize the vLLM config
153+
vllm_config_path: 'src/MaxText/configs/vllm.yml'
152154

153155
# ====== Checkpoint Configuration ======
154156
enable_checkpointing: True

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,6 +1557,7 @@ class VLLM(BaseModel):
15571557
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
15581558
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
15591559
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
1560+
vllm_config_path: str = Field("src/MaxText/configs/vllm.yml", description="path to yaml file for loading vLLM config.")
15601561

15611562

15621563
class RL(BaseModel):

0 commit comments

Comments (0)