Commit 0f0cff6

reorder logical rule and add embed_vocab

1 parent 6d86e74 | commit 0f0cff6

15 files changed | 131 additions & 97 deletions

src/maxtext/configs/base.yml

Lines changed: 81 additions & 55 deletions
@@ -454,92 +454,118 @@ shard_mode: "auto" # can be either auto or explicit
 custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
-  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
+  # ==========================================
+  # Vocabulary Embedding
+  # ==========================================
+  # Vocab Activations
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-  ['activation_length', ['sequence', 'context']],
-  ['activation_length', ['context']],
+  ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_vocab', ['tensor', 'tensor_transpose']],
+  ['activation_vocab', 'tensor_sequence'],
+  ['activation_vocab', ['sequence', 'context']],
+  # Vocab Weights
+  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # ==========================================
+  # Attention
+  # ==========================================
+  # Attention Activations
+  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
   ['activation_attn_length', ['sequence', 'context']],
   ['activation_attn_length', ['context']],
-  ['activation_length_moe', ['sequence', 'context']],
-  ['activation_length_moe', ['context']],
-  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
-  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
   ['activation_q_length', ['context']],
-  ['prefill_activation_length', ['sequence', 'context']],
-  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_kv_length', []],
   ['activation_attn_embed', ['tensor', 'tensor_transpose']],
-  ['activation_embed', ['tensor', 'tensor_transpose']],
-  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
-  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose']],
-  ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence','context']],
-  ['activation_stage', 'stage'],
-  ['activation_exp', ['expert']],
-  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['decode_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
-  ['decode_length', ['sequence']],
-  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
-  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  # Attention Weights
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
-  ['embed_tensor_transpose', ['tensor_transpose']],
+  ['qkv', []],
+  ['kv', []],
+  ['kv_head_dim', []],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
-  ["q_lora_up_proj",[]],
+  ["q_lora_up_proj", []],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
-  ["kv_lora_up_proj",[]],
+  ["kv_lora_up_proj", []],
+  # ==========================================
+  # Mixture of Experts (MoE)
+  # ==========================================
+  # MoE Activations
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_length_moe', ['sequence', 'context']],
+  ['activation_length_moe', ['context']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
+  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_exp', ['expert']],
+  # MoE Weights
+  ['exp', 'expert'],
+  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'context']],
+  # ==========================================
+  # Standard MLP / Dense Layers / Model Structure
+  # ==========================================
+  # Dense Activations
+  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['activation_length', ['sequence', 'context']],
+  ['activation_length', ['context']],
+  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_embed', ['tensor', 'tensor_transpose']],
+  ['activation_stage', 'stage'],
+  # General Weights
+  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
-  ['qkv', []],
-  ['kv', []],
-  ['kv_head_dim', []],
+  ['diloco', 'diloco'],
+  ['engram_dim', ['tensor']],
+  ['dense_layers', []],
+  ['moe_layers', []],
+  ['mhc', []],
+  # ==========================================
+  # Inference(Prefill, Decode, Cache)
+  # ==========================================
+  ['prefill_activation_length', ['sequence', 'context']],
+  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['decode_length', ['sequence']],
+  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
+  ['paged_kv_heads', ['tensor']],
   ['cache_batch_prefill', []],
   ['cache_batch', []],
   ['cache_heads_none', []],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
   ['cache_kv', []],
   ['cache_sequence', []],
-  ['exp', 'expert'],
-  ['exp_with_fsdp', 'fsdp'],
-  ['paged_kv_heads', ['tensor']],
   ['num_pages', []],
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
-  ['dense_layers', []],
-  ['moe_layers', []],
-  ['engram_dim', ['tensor']],
-  ['mhc', []],
-  ['diloco', 'diloco'],
-]
+  # ==========================================
+  # Deprecated / Scheduled for Removal
+  # ==========================================
+  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed_tensor_transpose', ['tensor_transpose']],
+  ['exp_with_fsdp', 'fsdp'],
+]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
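
The regrouping is mostly cosmetic, but rule order is load-bearing: MaxText resolves these rules through Flax's logical-axis machinery, which walks the list top-down and gives each logical axis the first rule whose mesh axes are still unclaimed on that tensor. The new `embed_vocab` axis lets the embedding table and logits kernel shard their feature dimension independently of the general `embed` axis used by the rest of the model. A minimal sketch of the resolution semantics (illustrative rule subset, not a MaxText API):

```python
from flax import linen as nn

# Illustrative subset of the rules above; order matters.
rules = (
    ("vocab", ("tensor", "tensor_transpose", "tensor_sequence", "autoregressive")),
    ("embed_vocab", ("fsdp", "fsdp_transpose", "sequence", "context", "expert")),
)

# The token-embedding table has logical axes ('vocab', 'embed_vocab').
spec = nn.logical_to_mesh_axes(("vocab", "embed_vocab"), rules)
print(spec)
# 'vocab' claims the tensor-parallel axes and 'embed_vocab' the
# FSDP-style axes; no mesh axis is assigned to two dims of one tensor.
```

Because matching is first-hit, moving a rule above another that shares mesh axes can change shardings even when no rule text changes, which is why a reorder like this one warrants the golden-file checks updated below.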

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ logical_axis_rules: [
   ['q_heads', ['tensor']],
   ['kv_heads', ['tensor']],
   ['embed', ['fsdp', 'expert']], # remove context from embed sharding
+  ['embed_vocab', ['fsdp', 'expert']],
   ['embed_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ logical_axis_rules: [
   ['activation_kv_batch', ['fsdp']],
   ['decode_batch', ['fsdp']],
   ['embed', ['fsdp']],
+  ['embed_vocab', ['fsdp']],
   ['embed_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
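
These one-line additions to each custom rule file matter because a logical axis with no matching rule is simply left unsharded. Without them, renaming the kernel axes to `embed_vocab` in the layers would silently replicate the embedding table under these overrides. A sketch of the failure mode (hypothetical stale rule set):

```python
from flax import linen as nn

# An override file that predates the rename: no 'embed_vocab' rule.
stale_rules = (("vocab", ("tensor",)), ("embed", ("fsdp",)))

print(nn.logical_to_mesh_axes(("vocab", "embed_vocab"), stale_rules))
# The 'embed_vocab' dimension resolves to None (replicated), silently
# losing FSDP sharding of the embedding table.
```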

src/maxtext/configs/inference/inference.yml

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,10 @@ logical_axis_rules: [
   ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
   ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
   ['embed', ['fsdp', 'sequence', 'expert']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
+  ['embed_vocab', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
+  ['embed_vocab', ['fsdp', 'sequence', 'expert']],
   ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
   ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
   ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
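
inference.yml gets four `embed_vocab` variants for the same reason `embed` has four: a rule is skipped when one of its mesh axes is already claimed by another dimension of the same tensor, so progressively smaller fallbacks are listed. A toy demonstration (hypothetical two-axis rules, not the real inference mesh):

```python
from flax import linen as nn

rules = (
    ("vocab", ("tensor",)),
    ("embed_vocab", ("fsdp", "tensor")),  # preferred, but 'tensor' may be taken
    ("embed_vocab", ("fsdp",)),           # fallback
)

print(nn.logical_to_mesh_axes(("vocab", "embed_vocab"), rules))
# 'vocab' claims 'tensor' first, so the first 'embed_vocab' rule is
# skipped and the fallback shards the feature dim over 'fsdp' alone.
```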

src/maxtext/configs/inference/vllm.yml

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,8 @@ logical_axis_rules: [
   ['kv', []],
   ['embed', ['expert', 'attn_dp_expert']],
   ['embed', ['attn_dp_expert']],
+  ['embed_vocab', ['expert', 'attn_dp_expert']],
+  ['embed_vocab', ['attn_dp_expert']],
   ['embed_moe', []],
   ['embed_moe', []],
   ['embed_tensor_transpose', ['attn_dp', 'model']],

src/maxtext/layers/decoders.py

Lines changed: 1 addition & 1 deletion
@@ -736,7 +736,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
       out_features_shape=cfg.vocab_size,
       weight_dtype=cfg.weight_dtype,
       dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability
-      kernel_axes=("embed", "vocab"),
+      kernel_axes=("embed_vocab", "vocab"),
       shard_mode=cfg.shard_mode,
       name="logits_dense",
       matmul_precision=self.config.matmul_precision,
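
`kernel_axes` names the logical axes of the weight; MaxText's `DenseGeneral` attaches those names to the parameter via Flax's logical partitioning, so this rename is all that is needed to route the logits kernel through the new `embed_vocab` rules. A simplified stand-in (a sketch, not the real `DenseGeneral` signature):

```python
import jax.numpy as jnp
from flax import linen as nn

class LogitsDense(nn.Module):
  """Sketch of an output head whose kernel carries logical axis names."""
  vocab_size: int

  @nn.compact
  def __call__(self, y):
    kernel = self.param(
        "kernel",
        # Attach logical axes to the param; logical_axis_rules later map
        # them to physical mesh axes.
        nn.with_logical_partitioning(
            nn.initializers.lecun_normal(), ("embed_vocab", "vocab")
        ),
        (y.shape[-1], self.vocab_size),
        jnp.float32,
    )
    return jnp.dot(y, kernel)
```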

src/maxtext/layers/embeddings.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def __init__(
         (self.num_embeddings, self.num_features),
         self.config.weight_dtype,
       ),
-      sharding=("vocab", "embed"),
+      sharding=("vocab", "embed_vocab"),
     )

   def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
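
On the NNX side the logical names ride along as variable metadata rather than an initializer wrapper, and the diff swaps the second name the same way. A hedged sketch of the pattern (not the full `Embed` module):

```python
import jax.numpy as jnp
from flax import nnx

class TokenEmbed(nnx.Module):
  def __init__(self, vocab_size: int, features: int, *, rngs: nnx.Rngs):
    init = nnx.initializers.normal(stddev=1.0)
    self.embedding = nnx.Param(
        init(rngs.params(), (vocab_size, features), jnp.float32),
        # Logical axis metadata, resolved later via logical_axis_rules.
        sharding=("vocab", "embed_vocab"),  # was ("vocab", "embed")
    )
```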

src/maxtext/layers/nnx_decoders.py

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ def __init__(
       out_features_shape=config.vocab_size,
       weight_dtype=config.weight_dtype,
       dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype,
-      kernel_axes=("embed", "vocab"),
+      kernel_axes=("embed_vocab", "vocab"),
       shard_mode=config.shard_mode,
       matmul_precision=self.config.matmul_precision,
       parameter_memory_host_offload=config.parameter_memory_host_offload,

tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/rule_default/logical_shardings.json

Lines changed: 6 additions & 6 deletions
@@ -133,7 +133,7 @@
   },
   ".params/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -318,7 +318,7 @@
   ".params/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -459,7 +459,7 @@
   },
   ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -644,7 +644,7 @@
   ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
@@ -781,7 +781,7 @@
   },
   ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": {
     "partition_spec": [
-      "embed",
+      "embed_vocab",
       "vocab"
     ],
     "shape": [
@@ -966,7 +966,7 @@
   ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": {
     "partition_spec": [
       "vocab",
-      "embed"
+      "embed_vocab"
     ],
     "shape": [
       102400,
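
The golden-file updates are mechanical: every `partition_spec` that previously named `embed` on the embedding table, the logits kernel, or their optimizer moments now names `embed_vocab`. Roughly how such an expectation reads (an illustrative check, not the actual test code):

```python
import json

path = ("tests/utils/sharding_info/deepseek2-16b/tpu7x-16/"
        "slice_1/rule_default/logical_shardings.json")
with open(path) as f:
  golden = json.load(f)

key = ".params/['params']/['token_embedder']/['embedding']"
assert golden[key]["partition_spec"] == ["vocab", "embed_vocab"]
```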

tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/rule_pure-fsdp/logical_shardings.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@
133133
},
134134
".params/['params']/['decoder']/['logits_dense']/['kernel']": {
135135
"partition_spec": [
136-
"embed",
136+
"embed_vocab",
137137
"vocab"
138138
],
139139
"shape": [
@@ -318,7 +318,7 @@
318318
".params/['params']/['token_embedder']/['embedding']": {
319319
"partition_spec": [
320320
"vocab",
321-
"embed"
321+
"embed_vocab"
322322
],
323323
"shape": [
324324
102400,
@@ -459,7 +459,7 @@
459459
},
460460
".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": {
461461
"partition_spec": [
462-
"embed",
462+
"embed_vocab",
463463
"vocab"
464464
],
465465
"shape": [
@@ -644,7 +644,7 @@
644644
".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": {
645645
"partition_spec": [
646646
"vocab",
647-
"embed"
647+
"embed_vocab"
648648
],
649649
"shape": [
650650
102400,
@@ -781,7 +781,7 @@
781781
},
782782
".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": {
783783
"partition_spec": [
784-
"embed",
784+
"embed_vocab",
785785
"vocab"
786786
],
787787
"shape": [
@@ -966,7 +966,7 @@
966966
".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": {
967967
"partition_spec": [
968968
"vocab",
969-
"embed"
969+
"embed_vocab"
970970
],
971971
"shape": [
972972
102400,
