Commit 52ff5ef

deprecate sequence axis

Removes the deprecated 'sequence' mesh axis from the sharding configs: it is dropped from mesh_axes, data_sharding, and every logical_axis_rules entry in base.yml and inference.yml, leaving 'context' (and 'context_autoregressive' for inference) to shard the sequence dimension.

1 parent f67d8b1 commit 52ff5ef

9 files changed: 62 additions & 1273 deletions

src/maxtext/configs/base.yml
27 additions, 31 deletions
@@ -445,29 +445,27 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s
 # Parallelism
 shard_mode: "auto" # can be either auto or explicit
 custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
-mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   # ==========================================
   # Vocabulary Embedding
   # ==========================================
   # Vocab Activations
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'expert']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose']],
   ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence', 'context']],
   # Vocab Weights
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
   # ==========================================
   # Attention
   # ==========================================
   # Attention Activations
   ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-  ['activation_length_attn', ['sequence', 'context']],
+  ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_length_attn', ['context']],
   ['activation_q_length', ['context']],
   ['activation_kv_length', []],
@@ -482,52 +480,50 @@ logical_axis_rules: [
   ['qkv', []],
   ['kv', []],
   ['kv_head_dim', []],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+  ['q_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['q_lora', ['fsdp', 'context', 'expert']],
   ["q_lora_up_proj", []],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
-  ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'context', 'tensor_transpose', 'expert']],
+  ['kv_lora', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['kv_lora', ['fsdp', 'context', 'expert']],
   ["kv_lora_up_proj", []],
   # ==========================================
   # Mixture of Experts (MoE)
   # ==========================================
   # MoE Activations
   ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
-  ['activation_length_moe', ['sequence', 'context']],
   ['activation_length_moe', ['context']],
-  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context']],
   ['activation_embed_moe', ['tensor', 'tensor_transpose']],
   ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_exp', ['expert']],
   # MoE Weights
   ['exp', 'expert'],
   ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'context']],
   # ==========================================
   # Standard MLP / Dense Layers / Model Structure
   # ==========================================
   # Dense Activations
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   # Note activation batch and length also get used in vocab
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_length', ['sequence', 'context']],
   ['activation_length', ['context']],
-  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_norm_length', ['tensor_sequence', 'context']],
   ['activation_embed', ['tensor', 'tensor_transpose']],
   ['activation_stage', 'stage'],
   # General Weights
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
   ['diloco', 'diloco'],
@@ -538,11 +534,11 @@ logical_axis_rules: [
   # ==========================================
   # Inference(Prefill, Decode, Cache)
   # ==========================================
-  ['prefill_activation_length', ['sequence', 'context']],
-  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['prefill_activation_length', ['context']],
+  ['prefill_activation_norm_length', ['tensor_sequence', 'context']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['decode_length', ['sequence']],
+  ['decode_length', []],
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
   ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
   ['paged_kv_heads', ['tensor']],
@@ -562,7 +558,7 @@ logical_axis_rules: [
   ['exp_with_fsdp', 'fsdp'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
 # Determines which physical axis plays the role of context parallelism for input data processing and load balancing
 # only supports "context" or "expert" (when custom_mesh_and_rule=ep-as-cp)
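For readers tracking what these rules do: each logical_axis_rules entry maps a logical axis name to the mesh axes it may be sharded over, and Flax resolves them into a PartitionSpec. Below is a minimal sketch, assuming a hypothetical four-axis mesh and a small rule subset; it is not MaxText's own sharding code, just an illustration of the resolution path now that 'sequence' is gone and 'context' carries sequence sharding.

```python
# Minimal sketch: how logical_axis_rules resolve to a PartitionSpec.
# The mesh shape and rule subset are illustrative, not MaxText's.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding
from flax.linen import logical_to_mesh_axes

# 'sequence' is no longer a mesh axis; 'context' shards the sequence dim.
mesh_axes = ('data', 'fsdp', 'context', 'tensor')
devices = np.array(jax.devices()).reshape(1, 1, 1, -1)
mesh = Mesh(devices, mesh_axes)

rules = [
    ('activation_batch', ('data', 'fsdp')),
    ('activation_length', ('context',)),  # was ('sequence', 'context')
    ('activation_embed', ('tensor',)),
]

# Resolve a [batch, length, embed] activation to mesh axes, then to a
# concrete sharding over the mesh.
spec = logical_to_mesh_axes(
    ('activation_batch', 'activation_length', 'activation_embed'), rules
)
sharding = NamedSharding(mesh, spec)
print(spec)  # batch on ('data', 'fsdp'), length on 'context', embed on 'tensor'
```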

src/maxtext/configs/inference/inference.yml
19 additions, 19 deletions
@@ -3,24 +3,24 @@ base_config: "base.yml"
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-  ['activation_length', ['context_autoregressive', 'sequence']],
+  ['activation_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_length', ['context_autoregressive']],
   ['activation_length', ['context_autoregressive']],
   ['activation_q_length', ['context_autoregressive']],
   ['activation_kv_length', ['context_autoregressive']],
-  ['activation_norm_length', ['tensor_sequence', 'sequence']],
+  ['activation_norm_length', ['tensor_sequence']],
   ['activation_embed', ['tensor_transpose']],
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
   ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
+  ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose']],
   ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence', 'context_autoregressive']],
+  ['activation_vocab', ['context_autoregressive']],
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert', 'context_autoregressive']],
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
@@ -32,18 +32,18 @@ logical_axis_rules: [
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'expert']],
-  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed_vocab', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
-  ['embed_vocab', ['fsdp', 'sequence', 'expert']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
-  ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'expert']],
+  ['embed', ['fsdp', 'tensor_transpose', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'expert']],
+  ['embed', ['fsdp', 'expert']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'tensor_transpose', 'expert']],
+  ['embed_vocab', ['fsdp', 'tensor_transpose', 'expert']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'expert']],
+  ['embed_vocab', ['fsdp', 'expert']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'context_autoregressive', 'tensor_transpose']],
+  ['embed_moe', ['fsdp', 'context_autoregressive', 'tensor_transpose']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'context_autoregressive']],
+  ['embed_moe', ['fsdp', 'context_autoregressive']],
   ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['layers', 'stage'],
   ['kv', []],
@@ -62,4 +62,4 @@ logical_axis_rules: [
   ['paged_kv_head_dim_size', []],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
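A pattern worth noting in both files: the same logical name ('embed', 'embed_vocab', 'embed_moe', 'q_lora', ...) is listed several times with progressively fewer mesh axes. In Flax's rule resolution a rule is skipped when one of its mesh axes is already claimed by another dimension of the same array, so the repeated entries act as ordered fallbacks, from most to least sharded. A small self-contained sketch with made-up rules (not the full config) illustrates this:

```python
# Sketch of the fallback pattern: duplicate rules for one logical axis.
# Rules are made up for illustration; they are not the full config.
from flax.linen import logical_to_mesh_axes

rules = [
    ('embed', ('fsdp', 'tensor_transpose')),  # preferred: most sharded
    ('embed', ('fsdp',)),                     # fallback if axes conflict
    ('mlp', ('tensor_transpose',)),           # conflicts with 'embed' above
    ('mlp', ('tensor',)),                     # so this fallback applies
]

# For a weight with dims ('embed', 'mlp'): the first 'embed' rule claims
# 'fsdp' and 'tensor_transpose'; the first 'mlp' rule then conflicts on
# 'tensor_transpose' and is skipped, so 'mlp' falls back to 'tensor'.
print(logical_to_mesh_axes(('embed', 'mlp'), rules))
# -> PartitionSpec: embed on ('fsdp', 'tensor_transpose'), mlp on 'tensor'
```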
