Skip to content

Commit f67d8b1

Browse files
Merge pull request #3714 from AI-Hypercomputer:chengnuojin-attn-batch
PiperOrigin-RevId: 904043717
2 parents 62674fd + f6a0110 commit f67d8b1

19 files changed

Lines changed: 78 additions & 68 deletions

File tree

src/maxtext/common/common_types.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,17 @@
3232
AxisIdxes = tuple[int, ...]
3333

3434
BATCH = "activation_batch"
35+
BATCH_ATTN = "activation_batch_attn"
3536

36-
ATTN_LENGTH = "activation_attn_length"
37+
ATTN_LENGTH = "activation_length_attn"
3738

3839
LENGTH = "activation_length"
3940
PREFILL_LENGTH = "prefill_activation_length"
4041
Q_LENGTH = "activation_q_length"
4142
Q_LORA_UP_PROJ = "q_lora_up_proj"
4243
KV_LENGTH = "activation_kv_length"
4344
KV_LORA_UP_PROJ = "kv_lora_up_proj"
44-
ATTN_EMBED = "activation_attn_embed"
45+
ATTN_EMBED = "activation_embed_attn"
4546
EMBED = "activation_embed"
4647
HEAD = "activation_heads"
4748
PREFILL_KV_BATCH = "activation_prefill_kv_batch"

src/maxtext/configs/base.yml

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -464,13 +464,14 @@ logical_axis_rules: [
464464
# Attention
465465
# ==========================================
466466
# Attention Activations
467+
['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
467468
['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
468469
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
469-
['activation_attn_length', ['sequence', 'context']],
470-
['activation_attn_length', ['context']],
470+
['activation_length_attn', ['sequence', 'context']],
471+
['activation_length_attn', ['context']],
471472
['activation_q_length', ['context']],
472473
['activation_kv_length', []],
473-
['activation_attn_embed', ['tensor', 'tensor_transpose']],
474+
['activation_embed_attn', ['tensor', 'tensor_transpose']],
474475
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
475476
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
476477
['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
@@ -514,7 +515,7 @@ logical_axis_rules: [
514515
# ==========================================
515516
# Dense Activations
516517
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
517-
# Note activation batch and length also get used in attention and vocab
518+
# Note activation batch and length also get used in vocab
518519
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
519520
['activation_length', ['sequence', 'context']],
520521
['activation_length', ['context']],

src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,13 @@ logical_axis_rules: [
3131
# Attention
3232
# ==========================================
3333
# Attention Activations
34+
['activation_batch_attn', ['data', 'fsdp']],
3435
['activation_heads', []],
3536
['activation_kv_heads', []],
36-
['activation_attn_length', ['expert']],
37+
['activation_length_attn', ['expert']],
3738
['activation_q_length', ['expert']],
3839
['activation_kv_length', []],
39-
['activation_attn_embed', []],
40+
['activation_embed_attn', []],
4041
['activation_kv', []],
4142
['activation_kv_batch', ['data', 'fsdp']],
4243
['activation_kv_head_dim', []],

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,15 +30,16 @@ mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']
3030
data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
3131
logical_axis_rules: [
3232
['activation_batch', ['data', 'fsdp', 'expert']],
33+
['activation_batch_attn', ['data', 'fsdp', 'expert']],
3334
['activation_batch_moe', ['data', 'fsdp']],
3435
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
3536
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
3637
['activation_heads', ['tensor']],
3738
['activation_kv_heads', ['tensor']],
3839
['activation_length', ['context']],
39-
['activation_attn_length', ['context']],
40+
['activation_length_attn', ['context']],
4041
['activation_q_length', ['context']],
41-
['activation_attn_embed', ['tensor']],
42+
['activation_embed_attn', ['tensor']],
4243
['activation_norm_length', ['context']],
4344
['activation_norm_length_moe', ['context']],
4445
['activation_embed', ['tensor']],

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ mesh_axes: ['fsdp']
1818
data_sharding: [['fsdp']]
1919
logical_axis_rules: [
2020
['activation_batch', ['fsdp']],
21+
['activation_batch_attn', ['fsdp']],
2122
['activation_batch_moe', ['fsdp']],
2223
['activation_embed_and_logits_batch', ['fsdp']],
2324
['activation_embed_and_logits_batch_sequence', ['fsdp']],

src/maxtext/configs/inference/vllm.yml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,15 +31,16 @@ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131
logical_axis_rules: [
3232
['activation_batch', ['data']],
3333
['activation_batch_moe', ['data']],
34+
['activation_batch_attn', ['data']],
3435
['activation_embed_and_logits_batch', ['data', 'expert']],
3536
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
3637
['activation_heads', ['model', 'expert']],
3738
['activation_kv_heads', ['model', 'expert']],
38-
['activation_attn_length', []],
39+
['activation_length_attn', []],
3940
['activation_length', []],
4041
['activation_length_moe', []],
4142
['activation_q_length', ['expert', 'attn_dp_expert']],
42-
['activation_attn_embed', 'model'],
43+
['activation_embed_attn', 'model'],
4344
# Expert is missing explicitly from activation_embed despite using TP.
4445
# We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
4546
# due to the output sharding of the fused_moe_gmm kernel in tpu-inference.

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'cont
6060
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
6161
logical_axis_rules: [
6262
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63+
['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6364
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6465
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6566
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],

src/maxtext/configs/models/deepseek3-671b-batchsplit.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ mesh_axes: ['data', 'fsdp', 'expert', 'context']
6363
data_sharding: [['data', 'fsdp', 'expert', 'context']]
6464
logical_axis_rules: [
6565
['activation_batch', ['data', 'fsdp', 'expert', 'context']],
66+
['activation_batch_attn', ['data', 'fsdp', 'expert', 'context']],
6667
['activation_batch_moe', ['data', 'fsdp', 'expert', 'context']],
6768
['activation_embed_and_logits_batch', ['data', 'fsdp', 'expert', 'context']],
6869
['activation_kv_batch', ['data', 'fsdp', 'expert', 'context']],

src/maxtext/layers/attention_mla.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
Array,
3737
AxisIdxes,
3838
AxisNames,
39-
BATCH,
39+
BATCH_ATTN,
4040
CACHE_BATCH,
4141
CACHE_BATCH_PREFILL,
4242
CACHE_SEQUENCE,
@@ -424,8 +424,8 @@ def mla_as_linen(
424424
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
425425
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
426426
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
427-
input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
428-
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
427+
input_axis_names: AxisNames = (BATCH_ATTN, LENGTH, EMBED),
428+
out_axis_names: AxisNames = (BATCH_ATTN, LENGTH, HEAD, D_KV),
429429
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
430430
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
431431
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -562,8 +562,8 @@ def __init__(
562562
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
563563
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
564564
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
565-
input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
566-
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
565+
input_axis_names: AxisNames = (BATCH_ATTN, LENGTH, EMBED),
566+
out_axis_names: AxisNames = (BATCH_ATTN, LENGTH, HEAD, D_KV),
567567
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
568568
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
569569
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -1153,7 +1153,7 @@ def __call__(
11531153
else:
11541154
inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
11551155
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
1156-
out_logical_name = (BATCH, LENGTH, HEAD, D_KV)
1156+
out_logical_name = (BATCH_ATTN, LENGTH, HEAD, D_KV)
11571157

11581158
if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None:
11591159
decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32)

src/maxtext/layers/attention_op.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@
3838
AttentionType,
3939
AxisIdxes,
4040
AxisNames,
41-
BATCH,
41+
BATCH_ATTN,
4242
CACHE_BATCH,
4343
CACHE_BATCH_PREFILL,
4444
CACHE_HEADS,
@@ -297,8 +297,8 @@ def attention_op_as_linen(
297297
float32_qk_product: bool = False,
298298
max_prefill_predict_length: int = -1,
299299
float32_logits: bool = False,
300-
flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
301-
flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
300+
flash_axis_names_q: AxisNames = (BATCH_ATTN, HEAD, LENGTH, D_KV),
301+
flash_axis_names_kv: AxisNames = (BATCH_ATTN, HEAD, KV_LENGTH, D_KV),
302302
flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
303303
prefill_cache_logical_axis_names: AxisNames = (
304304
CACHE_BATCH_PREFILL,
@@ -394,8 +394,8 @@ def __init__(
394394
float32_qk_product: bool = False,
395395
max_prefill_predict_length: int = -1,
396396
float32_logits: bool = False,
397-
flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
398-
flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
397+
flash_axis_names_q: AxisNames = (BATCH_ATTN, HEAD, LENGTH, D_KV),
398+
flash_axis_names_kv: AxisNames = (BATCH_ATTN, HEAD, KV_LENGTH, D_KV),
399399
flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
400400
prefill_cache_logical_axis_names: AxisNames = (
401401
CACHE_BATCH_PREFILL,
@@ -1144,13 +1144,13 @@ def tpu_flash_attention(
11441144
segment_axis_names_kv = None
11451145
sink_axis_names = self._logical_to_mesh_axes((HEAD,))
11461146
if decoder_segment_ids is not None:
1147-
segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH))
1148-
segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH))
1147+
segment_axis_names_q = self._logical_to_mesh_axes((BATCH_ATTN, Q_LENGTH))
1148+
segment_axis_names_kv = self._logical_to_mesh_axes((BATCH_ATTN, KV_LENGTH))
11491149

11501150
axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
11511151
axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
11521152
axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
1153-
indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
1153+
indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_ATTN, Q_LENGTH, KV_LENGTH))
11541154

11551155
global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
11561156
global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
@@ -1730,7 +1730,7 @@ def compute_local_attention(
17301730
if model_mode == MODEL_MODE_AUTOREGRESSIVE and self.is_partition_in_decode(q_seq_len):
17311731
local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV))
17321732
elif model_mode == MODEL_MODE_PREFILL:
1733-
local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV))
1733+
local_out = partitioning.with_sharding_constraint(local_out, (BATCH_ATTN, KV_LENGTH, HEAD, D_KV))
17341734

17351735
if self.reshape_q and q_seq_len == 1:
17361736
local_max = local_max[:, 0:1, :, :]
@@ -1774,7 +1774,7 @@ def apply_attention_dot(
17741774

17751775
# special sharding for decode
17761776
q_seq_len = query.shape[1]
1777-
prefill_qkv_sharding = (BATCH, PREFILL_LENGTH, HEAD, D_KV)
1777+
prefill_qkv_sharding = (BATCH_ATTN, PREFILL_LENGTH, HEAD, D_KV)
17781778
decode_qkv_sharding = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)
17791779
if self.is_partition_in_decode(q_seq_len):
17801780
query = partitioning.with_sharding_constraint(query, decode_qkv_sharding)
@@ -1799,7 +1799,9 @@ def apply_attention_dot(
17991799
if self.is_partition_in_decode(q_seq_len):
18001800
attn_weights = partitioning.with_sharding_constraint(attn_weights, (KV_LENGTH, HEAD, None, None, None))
18011801
elif model_mode == MODEL_MODE_PREFILL:
1802-
attn_weights = partitioning.with_sharding_constraint(attn_weights, (BATCH, HEAD, None, PREFILL_LENGTH, KV_LENGTH))
1802+
attn_weights = partitioning.with_sharding_constraint(
1803+
attn_weights, (BATCH_ATTN, HEAD, None, PREFILL_LENGTH, KV_LENGTH)
1804+
)
18031805

18041806
if self.attn_logits_soft_cap:
18051807
attn_weights = jnp.tanh(attn_weights / self.attn_logits_soft_cap)
@@ -1846,7 +1848,7 @@ def apply_attention_dot(
18461848
if self.is_partition_in_decode(q_seq_len):
18471849
attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None))
18481850
elif model_mode == MODEL_MODE_PREFILL:
1849-
attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH, HEAD, None, PREFILL_LENGTH, KV_LENGTH))
1851+
attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH_ATTN, HEAD, None, PREFILL_LENGTH, KV_LENGTH))
18501852
if attn_mask is not None:
18511853
attn_weights = apply_mask_to_logits(attn_weights, attn_mask)
18521854

0 commit comments

Comments
 (0)