@@ -29,59 +29,77 @@ weight_dtype: bfloat16
2929# -------------- Logical Axis Rules --------------
3030mesh_axes : ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131logical_axis_rules : [
32- ['activation_batch', ['data']],
33- ['activation_batch_moe', ['data']],
34- ['activation_embed_and_logits_batch', ['data', 'expert']],
35- ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
32+ # ==========================================
33+ # Vocabulary Embedding
34+ # ==========================================
35+ # Vocab Activations
36+ ['activation_embed_and_logits_batch', ['data', 'attn_dp_expert']],
37+ ['activation_embed_and_logits_batch_sequence', ['data', 'attn_dp_expert']],
38+ ['activation_vocab', ['model', 'expert']],
39+ # Vocab Weights
40+ ['vocab', ['model', 'expert']],
41+ ['embed_vocab', []],
42+ # ==========================================
43+ # Attention
44+ # ==========================================
45+ # Attention Activations
3646 ['activation_heads', ['model', 'expert']],
3747 ['activation_kv_heads', ['model', 'expert']],
38- ['activation_attn_length', []],
39- ['activation_length', []],
40- ['activation_length_moe', []],
41- ['activation_q_length', ['expert', 'attn_dp_expert']],
42- ['activation_attn_embed', 'model'],
43- # Expert is missing explicitly from activation_embed despite using TP.
44- # We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
45- # due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
46- ['activation_embed', ['model', 'attn_dp']],
47- ['activation_embed_moe', ['model', 'attn_dp']],
48- ['activation_mlp', ['model', 'attn_dp']],
49- ['activation_mlp_moe', ['model', 'attn_dp']],
50- ['activation_kv', ['model']],
51- ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
52- ['activation_kv_batch', ['data']],
53- ['activation_kv_head_dim', ['model']],
54- ['activation_vocab', ['model', 'attn_dp']],
55- ['activation_norm_length', []],
56- ['activation_norm_length_moe', []],
57- ['activation_exp', ['expert', 'attn_dp_expert']],
58- ['decode_batch', ['data']],
59- ['decode_batch_moe', ['data']],
60- ['decode_length', []],
61- ['mlp', ['model', 'attn_dp']],
62- ['mlp_moe', ['model', 'attn_dp']],
63- ['mlp_no_fsdp', ['model', 'attn_dp']],
64- ['vocab', ['model', 'attn_dp']],
65- # Expert is intended to act like TP for attention.
66- # We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
48+ ['activation_attn_embed', ['model', 'expert']],
49+ ['activation_kv', ['model', 'expert']],
50+ ['activation_kv_batch', ['data', 'attn_dp_expert']],
51+ ['activation_kv_head_dim', []],
52+ # Attention Weights
6753 ['heads', ['model', 'expert']],
6854 ['q_heads', ['model', 'expert']],
6955 ['kv_heads', ['model', 'expert']],
70- ['kv_head_dim ', []],
56+ ['qkv', []],
7157 ['kv', []],
72- ['embed', ['expert', 'attn_dp_expert']],
73- ['embed', ['attn_dp_expert']],
74- ['embed_vocab', ['expert', 'attn_dp_expert']],
75- ['embed_vocab', ['attn_dp_expert']],
76- ['embed_moe', []],
77- ['embed_moe', []],
78- ['embed_tensor_transpose', ['attn_dp', 'model']],
79- ['q_lora', ['expert', 'attn_dp_expert']],
80- ['kv_lora', ['expert', 'attn_dp_expert']],
81- ['norm', []],
82- ['cache_heads', ['model']],
58+ ['kv_head_dim', []],
59+ ['q_lora', []],
60+ ['q_lora_up_proj', []],
61+ ['kv_lora', []],
62+ ['kv_lora_up_proj', []],
63+ # ==========================================
64+ # Mixture of Experts (MoE)
65+ # ==========================================
66+ # MoE Activations
67+ ['activation_batch_moe', ['data']],
68+ ['activation_embed_moe', ['model']],
69+ ['activation_mlp_moe', []],
70+ ['activation_exp', ['expert', 'attn_dp_expert']],
71+ # MoE Weights
8372 ['exp', ['expert', 'attn_dp_expert']],
84- ['paged_kv_heads', ['model']],
85- ]
73+ ['mlp_moe', []],
74+ ['embed_moe', ['data']],
75+ # ==========================================
76+ # Standard MLP / Dense Layers / Model Structure
77+ # ==========================================
78+ # Dense Activations
79+ ['activation_mlp', ['model', 'expert']],
80+ # Note: activation batch and length are also used in attention and vocab
81+ ['activation_batch', ['data', 'attn_dp_expert']],
82+ ['activation_embed', ['model', 'expert']],
83+ # General Weights
84+ ['mlp', ['model', 'expert']],
85+ ['embed', []],
86+ ['norm', ['model', 'expert']],
87+ # ==========================================
88+ # Inference (Prefill, Decode, Cache)
89+ # ==========================================
90+ ['activation_prefill_kv_batch', ['data', 'attn_dp_expert']],
91+ ['decode_batch', ['data', 'attn_dp_expert']],
92+ ['cache_heads', ['model', 'expert']],
93+ ['cache_heads', ['model']],
94+ ['paged_kv_heads', ['model', 'expert']],
95+ ['cache_batch_prefill', []],
96+ ['cache_batch', []],
97+ ['cache_heads_none', []],
98+ ['cache_kv', []],
99+ ['cache_sequence', []],
100+ ['num_pages', []],
101+ ['tokens_per_page', []],
102+ ['paged_kv_head_dim_size', []],
103+ ]
86104data_sharding : [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
87105input_data_sharding_logical_axes : ['activation_embed_and_logits_batch']
0 commit comments