
Commit 7e3f19f

Merge pull request #3624 from AI-Hypercomputer:chengnuojin-no-exp2

PiperOrigin-RevId: 897341308
2 parents 6f3c5a4 + 1d60d2d, commit 7e3f19f

19 files changed: 96 additions & 124 deletions

Summary of the change: the `*_NO_EXP` logical-axis variants (`ATTN_LENGTH_NO_EXP`, `Q_LENGTH_NO_EXP`, `KV_BATCH_NO_EXP`) and the `EP_AS_CONTEXT` special-casing in the attention layers are removed, the corresponding `logical_axis_rules` entries are collapsed so each logical name keeps a single rule without the 'expert' mesh axis, the `log_config` default in `base.yml` flips from True to False, and the affected golden sharding dumps are regenerated.


src/maxtext/common/common_types.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -34,12 +34,10 @@
 BATCH = "activation_batch"
 
 ATTN_LENGTH = "activation_attn_length"
-ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"
 
 LENGTH = "activation_length"
 PREFILL_LENGTH = "prefill_activation_length"
 Q_LENGTH = "activation_q_length"
-Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
 Q_LORA_UP_PROJ = "q_lora_up_proj"
 KV_LENGTH = "activation_kv_length"
 KV_LORA_UP_PROJ = "kv_lora_up_proj"
@@ -48,7 +46,6 @@
 HEAD = "activation_heads"
 PREFILL_KV_BATCH = "activation_prefill_kv_batch"
 KV_BATCH = "activation_kv_batch"
-KV_BATCH_NO_EXP = "activation_kv_batch_no_exp"
 KV_HEAD = "activation_kv_heads"
 KV_HEAD_DIM = "activation_kv_head_dim"
 D_KV = "activation_kv"
```

src/maxtext/configs/base.yml

Lines changed: 3 additions & 5 deletions
```diff
@@ -469,8 +469,7 @@ logical_axis_rules: [
   ['activation_length_moe', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
-  ['activation_q_length', ['context', 'expert']],
-  ['activation_q_length_no_exp', ['context']],
+  ['activation_q_length', ['context']],
   ['prefill_activation_length', ['sequence', 'context']],
   ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_kv_length', []],
@@ -480,8 +479,7 @@ logical_axis_rules: [
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose']],
@@ -978,7 +976,7 @@ xprof_e2e_enable_fw_power_level_event: False
 xprof_e2e_enable_fw_thermal_event: False
 profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.
 
-log_config: True # Prints the config (after defaults have been set by pyconfig logic)
+log_config: False # Prints the config (after defaults have been set by pyconfig logic)
 debug_sharding: False # Prints model weights sharding info
 
 # Checkpoint Structured logging
```
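The collapsed rules are why the `_no_exp` names could be retired: once `'expert'` is dropped from the base rule, the plain name resolves exactly as the old `*_no_exp` name did. (The adjacent duplicate `activation_vocab` entries rely on Flax trying rules in order and skipping a rule whose mesh axes are already taken by another dimension, so the second entry acts as a fallback.) A hedged sketch, with rule values copied from the old and new configs above:

```python
# Before/after resolution of the q-length axis; rule values taken from the diff.
import flax.linen as nn

old_rules = (
    ("activation_q_length", ("context", "expert")),
    ("activation_q_length_no_exp", ("context",)),
)
new_rules = (("activation_q_length", ("context",)),)

before = nn.logical_to_mesh_axes(("activation_q_length_no_exp",), old_rules)
after = nn.logical_to_mesh_axes(("activation_q_length",), new_rules)
assert before == after  # both resolve to PartitionSpec(('context',))
```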

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 4 additions & 5 deletions
```diff
@@ -36,18 +36,17 @@ logical_axis_rules: [
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
   ['activation_heads', ['tensor']],
   ['activation_kv_heads', ['tensor']],
-  ['activation_length', ['context', 'expert']],
-  ['activation_attn_length', ['context', 'expert']],
-  ['activation_q_length', ['context', 'expert']],
+  ['activation_length', ['context']],
+  ['activation_attn_length', ['context']],
+  ['activation_q_length', ['context']],
   ['activation_attn_embed', ['tensor']],
   ['activation_norm_length', ['context']],
   ['activation_norm_length_moe', ['context']],
   ['activation_embed', ['tensor']],
   ['activation_embed_moe', ['tensor']],
   ['activation_mlp', ['tensor']],
   ['activation_kv', ['tensor']],
-  ['activation_kv_batch', ['data', 'fsdp', 'expert']],
-  ['activation_kv_batch_no_exp', ['data', 'fsdp']],
+  ['activation_kv_batch', ['data', 'fsdp']],
   ['activation_kv_head_dim', ['tensor']],
   ['activation_vocab', ['tensor']],
   ['activation_stage', 'stage'],
```

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 0 additions & 1 deletion
```diff
@@ -24,7 +24,6 @@ logical_axis_rules: [
   ['activation_embed_and_logits_batch_sequence', ['fsdp']],
   ['activation_prefill_kv_batch', ['fsdp']],
   ['activation_kv_batch', ['fsdp']],
-  ['activation_kv_batch_no_exp', ['fsdp']],
   ['decode_batch', ['fsdp']],
   ['embed', ['fsdp']],
   ['embed_no_exp', ['fsdp']],
```

src/maxtext/configs/inference/vllm.yml

Lines changed: 2 additions & 4 deletions
```diff
@@ -36,8 +36,7 @@ logical_axis_rules: [
   ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
   ['activation_heads', ['model', 'expert']],
   ['activation_kv_heads', ['model', 'expert']],
-  ['activation_attn_length', ['expert']],
-  ['activation_attn_length_no_exp', []],
+  ['activation_attn_length', []],
   ['activation_length', ['data']],
   ['activation_length_moe', ['data', 'expert']],
   ['activation_length_moe', 'data'],
@@ -48,8 +47,7 @@ logical_axis_rules: [
   ['activation_mlp', ['model', 'attn_dp']],
   ['activation_kv', ['model']],
   ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
-  ['activation_kv_batch', ['data', 'expert', 'attn_dp_expert']],
-  ['activation_kv_batch_no_exp', ['data']],
+  ['activation_kv_batch', ['data']],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model', 'attn_dp']],
   ['activation_norm_length', []],
```
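All of these configs feed the same mechanism: the YAML list-of-lists is normalized into tuples before being handed to Flax. A sketch of that plumbing with a hypothetical two-rule config (MaxText's actual loading goes through `pyconfig`, which this does not reproduce):

```python
# Sketch: turning YAML logical_axis_rules into Flax-compatible rule tuples.
# The YAML snippet is a toy stand-in for the config files above.
import yaml
import flax.linen as nn

cfg = yaml.safe_load("""
logical_axis_rules: [
  ['activation_attn_length', ['context']],
  ['activation_kv_batch', ['data']],
]
""")

# Flax expects (name, axes) pairs; YAML lists become tuples.
rules = tuple((name, tuple(axes)) for name, axes in cfg["logical_axis_rules"])

with nn.logical_axis_rules(rules):
  spec = nn.logical_to_mesh_axes(("activation_kv_batch", "activation_attn_length"))
print(spec)  # e.g. PartitionSpec(('data',), ('context',))
```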

src/maxtext/layers/attention_op.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -63,7 +63,7 @@
     MODEL_MODE_PREFILL,
     MODEL_MODE_TRAIN,
     PREFILL_LENGTH,
-    Q_LENGTH_NO_EXP,
+    Q_LENGTH,
 )
 from maxtext.inference import page_manager
 from maxtext.inference.kvcache import KVQuant, KVTensor
@@ -1134,13 +1134,13 @@ def tpu_flash_attention(
     segment_axis_names_kv = None
     sink_axis_names = self._logical_to_mesh_axes((HEAD,))
     if decoder_segment_ids is not None:
-      segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP))
+      segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH))
       segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH))
 
     axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
     axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
     axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
-    indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP, KV_LENGTH))
+    indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
 
     global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
     global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
@@ -1269,11 +1269,11 @@ def wrap_splash_kernel(single_head_mask):
         return splash_kernel
 
       splash_kernel = wrap_splash_kernel(single_head_mask)
-      segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
+      segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
     elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP:
       if self.config.use_max_logit_estimate > 0:
         sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
-      segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
+      segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,))
     else:
       # Create multi-head mask
       multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
```
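For context on what these call sites compute: `_logical_to_mesh_axes` is assumed to be a thin wrapper over Flax's `logical_to_mesh_axes`, and the resulting `PartitionSpec`s serve as the segment-id and indexer-mask specs for the splash-attention kernel. A standalone sketch mirroring the shape of the calls above (rules and wrapper are assumptions, not the MaxText implementation):

```python
# Standalone sketch of the logical->mesh translation done at these call sites.
import flax.linen as nn

BATCH = "activation_batch"
Q_LENGTH = "activation_q_length"
KV_LENGTH = "activation_kv_length"

# Illustrative rules; a None value leaves that dimension unsharded.
rules = ((BATCH, ("data", "fsdp")), (Q_LENGTH, ("context",)), (KV_LENGTH, None))

def _logical_to_mesh_axes(names):
  # Same call shape as self._logical_to_mesh_axes in AttentionOp.
  return nn.logical_to_mesh_axes(names, rules)

segment_axis_names_q = _logical_to_mesh_axes((BATCH, Q_LENGTH))
indexer_mask_axis_names = _logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
```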

src/maxtext/layers/attentions.py

Lines changed: 10 additions & 29 deletions
```diff
@@ -34,7 +34,6 @@
     AxisNames,
     AxisIdxes,
     ATTN_LENGTH,
-    ATTN_LENGTH_NO_EXP,
     DType,
     Config,
     Array,
@@ -44,12 +43,10 @@
     KV_HEAD,
     KV_HEAD_DIM,
     KV_BATCH,
-    KV_BATCH_NO_EXP,
     ATTN_EMBED,
     MODEL_MODE_AUTOREGRESSIVE,
     MODEL_MODE_TRAIN,
     MODEL_MODE_PREFILL,
-    EP_AS_CONTEXT,
     AttentionType,
 )
 from maxtext.layers import nnx_wrappers
@@ -141,14 +138,11 @@ def attention_as_linen(
     prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
-    out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
+    query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    input_axis_names: AxisNames = (BATCH, ATTN_LENGTH, ATTN_EMBED),
+    out_axis_names: AxisNames = (BATCH, ATTN_LENGTH, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -208,9 +202,6 @@ def attention_as_linen(
       query_axis_names=query_axis_names,
       key_axis_names=key_axis_names,
       value_axis_names=value_axis_names,
-      ep_query_axis_names=ep_query_axis_names,
-      ep_key_axis_names=ep_key_axis_names,
-      ep_value_axis_names=ep_value_axis_names,
       input_axis_names=input_axis_names,
       out_axis_names=out_axis_names,
       prefill_input_axis_names=prefill_input_axis_names,
@@ -304,14 +295,11 @@ def __init__(
     prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
-    out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
+    query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    input_axis_names: AxisNames = (BATCH, ATTN_LENGTH, ATTN_EMBED),
+    out_axis_names: AxisNames = (BATCH, ATTN_LENGTH, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -413,9 +401,6 @@ def __init__(
     self.query_axis_names = query_axis_names
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
-    self.ep_query_axis_names = ep_query_axis_names
-    self.ep_key_axis_names = ep_key_axis_names
-    self.ep_value_axis_names = ep_value_axis_names
     self.input_axis_names = input_axis_names
     self.out_axis_names = out_axis_names
     self.prefill_input_axis_names = prefill_input_axis_names
@@ -1161,10 +1146,6 @@ def __call__(
       query = self._maybe_shard_with_logical(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV))
       key = self._maybe_shard_with_logical(key, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV))
       value = self._maybe_shard_with_logical(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV))
-    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      query = self._maybe_shard_with_logical(query, self.ep_query_axis_names)
-      key = self._maybe_shard_with_logical(key, self.ep_key_axis_names)
-      value = self._maybe_shard_with_logical(value, self.ep_value_axis_names)
     else:
       query = self._maybe_shard_with_logical(query, self.query_axis_names)
       key = self._maybe_shard_with_logical(key, self.key_axis_names)
```

tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json

Lines changed: 6 additions & 6 deletions
```diff
@@ -2,31 +2,31 @@
   "Activation Sharding Dump": [
     {
       "attentions/inputs_q: bfloat16[192,2048,2880]": {
-        "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')",
+        "logic_axes": "('activation_batch', 'activation_attn_length', 'activation_attn_embed')",
         "PartitionSpec": "P('fsdp', None, None)"
       }
     },
     {
       "attentions/inputs_kv: bfloat16[192,2048,2880]": {
-        "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')",
+        "logic_axes": "('activation_batch', 'activation_attn_length', 'activation_attn_embed')",
         "PartitionSpec": "P('fsdp', None, None)"
       }
     },
     {
       "attentions/query: bfloat16[192,2048,64,64]": {
-        "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
+        "logic_axes": "('activation_kv_batch', 'activation_attn_length', 'activation_kv_heads', 'activation_kv_head_dim')",
         "PartitionSpec": "P('fsdp', None, None, None)"
       }
     },
     {
       "attentions/key: bfloat16[192,2048,8,64]": {
-        "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
+        "logic_axes": "('activation_kv_batch', 'activation_attn_length', 'activation_kv_heads', 'activation_kv_head_dim')",
         "PartitionSpec": "P('fsdp', None, None, None)"
       }
     },
     {
       "attentions/value: bfloat16[192,2048,8,64]": {
-        "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
+        "logic_axes": "('activation_kv_batch', 'activation_attn_length', 'activation_kv_heads', 'activation_kv_head_dim')",
         "PartitionSpec": "P('fsdp', None, None, None)"
       }
     },
@@ -50,7 +50,7 @@
     },
     {
       "attentions/out: bfloat16[192,2048,64,64]": {
-        "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')",
+        "logic_axes": "('activation_batch', 'activation_attn_length', 'activation_heads', 'activation_kv')",
         "PartitionSpec": "P('fsdp', None, None, None)"
       }
     },
```

tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json

Lines changed: 6 additions & 6 deletions
```diff
@@ -2,31 +2,31 @@
   "Activation Sharding Dump": [
     {
       "attentions/inputs_q: bfloat16[768,2048,2880]": {
-        "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')",
+        "logic_axes": "('activation_batch', 'activation_attn_length', 'activation_attn_embed')",
         "PartitionSpec": "P(('data', 'fsdp'), None, None)"
       }
     },
     {
       "attentions/inputs_kv: bfloat16[768,2048,2880]": {
-        "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_attn_embed')",
+        "logic_axes": "('activation_batch', 'activation_attn_length', 'activation_attn_embed')",
         "PartitionSpec": "P(('data', 'fsdp'), None, None)"
       }
     },
     {
       "attentions/query: bfloat16[768,2048,64,64]": {
-        "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
+        "logic_axes": "('activation_kv_batch', 'activation_attn_length', 'activation_kv_heads', 'activation_kv_head_dim')",
         "PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
       }
     },
     {
       "attentions/key: bfloat16[768,2048,8,64]": {
-        "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
+        "logic_axes": "('activation_kv_batch', 'activation_attn_length', 'activation_kv_heads', 'activation_kv_head_dim')",
         "PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
       }
     },
     {
       "attentions/value: bfloat16[768,2048,8,64]": {
-        "logic_axes": "('activation_kv_batch', 'activation_attn_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')",
+        "logic_axes": "('activation_kv_batch', 'activation_attn_length', 'activation_kv_heads', 'activation_kv_head_dim')",
         "PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
       }
     },
@@ -50,7 +50,7 @@
     },
     {
       "attentions/out: bfloat16[768,2048,64,64]": {
-        "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')",
+        "logic_axes": "('activation_batch', 'activation_attn_length', 'activation_heads', 'activation_kv')",
         "PartitionSpec": "P(('data', 'fsdp'), None, None, None)"
       }
     },
```
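These golden files exist to be diffed mechanically; as a hypothetical sanity check, one could verify that the regenerated dumps carry no retired axis names (the directory layout and JSON structure are taken from the files above):

```python
# Hypothetical check over the regenerated sharding dumps: no "logic_axes"
# entry should still use a retired *_no_exp logical axis name.
import json
import pathlib

for path in pathlib.Path("tests/utils/sharding_info").rglob("input_shardings.json"):
  dump = json.loads(path.read_text())
  for entry in dump["Activation Sharding Dump"]:
    for tensor_name, info in entry.items():
      assert "_no_exp" not in info["logic_axes"], (path, tensor_name)
```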
