@@ -29,59 +29,78 @@ weight_dtype: bfloat16
2929# -------------- Logical Axis Rules --------------
3030mesh_axes : ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131logical_axis_rules : [
32- ['activation_batch', ['data']],
33- ['activation_batch_moe', ['data']],
34- ['activation_embed_and_logits_batch', ['data', 'expert']],
35- ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
32+ # ==========================================
33+ # Vocabulary Embedding
34+ # ==========================================
35+ # Vocab Activations
36+ ['activation_embed_and_logits_batch', ['data']],
37+ ['activation_embed_and_logits_batch_sequence', ['data']],
38+ ['activation_vocab', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
39+ # Vocab Weights
40+ ['vocab', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
41+ ['embed_vocab', []],
42+ # ==========================================
43+ # Attention
44+ # ==========================================
45+ # Attention Activations
46+ ['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
3647 ['activation_heads', ['model', 'expert']],
3748 ['activation_kv_heads', ['model', 'expert']],
38- ['activation_attn_length', []],
39- ['activation_length', []],
40- ['activation_length_moe', []],
41- ['activation_q_length', ['expert', 'attn_dp_expert']],
42- ['activation_attn_embed', 'model'],
43- # Expert is missing explicitly from activation_embed despite using TP.
44- # We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
45- # due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
46- ['activation_embed', ['model', 'attn_dp']],
47- ['activation_embed_moe', ['model', 'attn_dp']],
48- ['activation_mlp', ['model', 'attn_dp']],
49- ['activation_mlp_moe', ['model', 'attn_dp']],
50- ['activation_kv', ['model']],
51- ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
52- ['activation_kv_batch', ['data']],
53- ['activation_kv_head_dim', ['model']],
54- ['activation_vocab', ['model', 'attn_dp']],
55- ['activation_norm_length', []],
56- ['activation_norm_length_moe', []],
57- ['activation_exp', ['expert', 'attn_dp_expert']],
58- ['decode_batch', ['data']],
59- ['decode_batch_moe', ['data']],
60- ['decode_length', []],
61- ['mlp', ['model', 'attn_dp']],
62- ['mlp_moe', ['model', 'attn_dp']],
63- ['mlp_no_fsdp', ['model', 'attn_dp']],
64- ['vocab', ['model', 'attn_dp']],
65- # Expert is intended to act like TP for attention.
66- # We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
49+ ['activation_attn_embed', []],
50+ ['activation_kv', ['model', 'expert']],
51+ ['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
52+ ['activation_kv_head_dim', []],
53+ # Attention Weights
6754 ['heads', ['model', 'expert']],
6855 ['q_heads', ['model', 'expert']],
6956 ['kv_heads', ['model', 'expert']],
70- ['kv_head_dim ', []],
57+ ['qkv', []],
7158 ['kv', []],
72- ['embed', ['expert', 'attn_dp_expert']],
73- ['embed', ['attn_dp_expert']],
74- ['embed_vocab', ['expert', 'attn_dp_expert']],
75- ['embed_vocab', ['attn_dp_expert']],
76- ['embed_moe', []],
59+ ['kv_head_dim', []],
60+ ['q_lora', []],
61+ ["q_lora_up_proj", []],
62+ ['kv_lora', []],
63+ ["kv_lora_up_proj", []],
64+ # ==========================================
65+ # Mixture of Experts (MoE)
66+ # ==========================================
67+ # MoE Activations
68+ ['activation_batch_moe', ['data']],
69+ ['activation_embed_moe', ['model']],
70+ ['activation_mlp_moe', []],
71+ ['activation_exp', ['expert', 'attn_dp', 'attn_dp_expert']],
72+ # MoE Weights
73+ ['exp', ['expert', 'attn_dp', 'attn_dp_expert']],
74+ ['mlp_moe', []],
7775 ['embed_moe', []],
78- ['embed_tensor_transpose', ['attn_dp', 'model']],
79- ['q_lora', ['expert', 'attn_dp_expert']],
80- ['kv_lora', ['expert', 'attn_dp_expert']],
76+ # ==========================================
77+ # Standard MLP / Dense Layers / Model Structure
78+ # ==========================================
79+ # Dense Activations
80+ ['activation_mlp', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
81+ # Note activation batch and length also get used in attention and vocab
82+ ['activation_batch', ['data']],
83+ ['activation_embed', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
84+ # General Weights
85+ ['mlp', ['model', 'expert', 'attn_dp', 'attn_dp_expert']],
86+ ['embed', []],
8187 ['norm', []],
88+ # ==========================================
89+ # Inference (Prefill, Decode, Cache)
90+ # ==========================================
91+ ['activation_prefill_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
92+ ['decode_batch', ['data', 'attn_dp', 'attn_dp_expert']],
93+ ['cache_heads', ['model', 'expert']],
8294 ['cache_heads', ['model']],
83- ['exp', ['expert', 'attn_dp_expert']],
84- ['paged_kv_heads', ['model']],
85- ]
95+ ['paged_kv_heads', ['model', 'expert']],
96+ ['cache_batch_prefill', []],
97+ ['cache_batch', []],
98+ ['cache_heads_none', []],
99+ ['cache_kv', []],
100+ ['cache_sequence', []],
101+ ['num_pages', []],
102+ ['tokens_per_page', []],
103+ ['paged_kv_head_dim_size', []],
104+ ]
86105data_sharding : [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
87106input_data_sharding_logical_axes : ['activation_embed_and_logits_batch']
0 commit comments