@@ -29,56 +29,77 @@ weight_dtype: bfloat16
2929# -------------- Logical Axis Rules --------------
3030mesh_axes : ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131logical_axis_rules : [
32- ['activation_batch', ['data']],
33- ['activation_batch_moe', ['data']],
34- ['activation_batch_attn', ['data']],
35- ['activation_embed_and_logits_batch', ['data', 'expert']],
36- ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
37- ['activation_heads', ['model', 'expert']],
38- ['activation_kv_heads', ['model', 'expert']],
39- ['activation_length_attn', []],
40- ['activation_length', []],
41- ['activation_length_moe', []],
42- ['activation_q_length', ['expert', 'attn_dp_expert']],
43- ['activation_embed_attn', 'model'],
44- # Expert is missing explicitly from activation_embed despite using TP.
45- # We are going for a replicate-AR style of TP as opposed to our typical AG-RS style of TP
46- # due to the output sharding of the fused_moe_gmm kernel in tpu-inference.
47- ['activation_embed', ['model', 'attn_dp']],
48- ['activation_embed_moe', ['model', 'attn_dp']],
49- ['activation_mlp', ['model', 'attn_dp']],
50- ['activation_mlp_moe', ['model', 'attn_dp']],
51- ['activation_kv', ['model']],
52- ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
53- ['activation_kv_batch', ['data']],
54- ['activation_kv_head_dim', ['model']],
55- ['activation_vocab', ['model', 'attn_dp']],
56- ['activation_norm_length', []],
57- ['activation_norm_length_moe', []],
58- ['activation_exp', ['expert', 'attn_dp_expert']],
59- ['decode_batch', ['data']],
60- ['decode_batch_moe', ['data']],
61- ['decode_length', []],
62- ['mlp', ['model', 'attn_dp']],
63- ['mlp_moe', ['model', 'attn_dp']],
64- ['mlp_no_fsdp', ['model', 'attn_dp']],
65- ['vocab', ['model', 'attn_dp']],
66- # Expert is intended to act like TP for attention.
67- # We target two all-reduces, one at the end of attention out projection and one at the end of the feedforward.
68- ['heads', ['model', 'expert']],
69- ['q_heads', ['model', 'expert']],
70- ['kv_heads', ['model', 'expert']],
71- ['kv_head_dim', []],
32+ # ==========================================
33+ # Vocabulary Embedding
34+ # ==========================================
35+ # Vocab Activations
36+ ['activation_embed_and_logits_batch', ['data', 'attn_dp', 'attn_dp_expert']],
37+ ['activation_embed_and_logits_batch_sequence', ['data', 'attn_dp', 'attn_dp_expert']],
38+ ['activation_vocab', ['expert', 'model']],
39+ # Vocab Weights
40+ ['vocab', []],
41+ ['embed_vocab', []],
42+ # ==========================================
43+ # Attention
44+ # ==========================================
45+ # Attention Activations
46+ ['activation_batch_attn', ['data', 'attn_dp', 'attn_dp_expert']],
47+ ['activation_heads', ['expert', 'model']],
48+ ['activation_kv_heads', ['expert', 'model']],
49+ ['activation_embed_attn', []],
50+ ['activation_kv', []],
51+ ['activation_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
52+ ['activation_kv_head_dim', []],
53+ # Attention Weights
54+ ['heads', ['expert', 'model']],
55+ ['q_heads', ['expert', 'model']],
56+ ['kv_heads', ['expert', 'model']],
57+ ['qkv', []],
7258 ['kv', []],
73- ['embed', []],
59+ ['kv_head_dim', []],
60+ ['q_lora', []],
61+ ["q_lora_up_proj", []],
62+ ['kv_lora', []],
63+ ["kv_lora_up_proj", []],
64+ # ==========================================
65+ # Mixture of Experts (MoE)
66+ # ==========================================
67+ # MoE Activations
68+ ['activation_batch_moe', ['data', 'attn_dp', 'attn_dp_expert']],
69+ ['activation_embed_moe', ['attn_dp', 'model']],
70+ ['activation_mlp_moe', ['attn_dp', 'model']],
71+ ['activation_exp', ['attn_dp_expert', 'expert']],
72+ # MoE Weights
73+ ['exp', ['attn_dp_expert', 'expert']],
74+ ['mlp_moe', ['attn_dp', 'model']],
7475 ['embed_moe', []],
75- ['embed_tensor_transpose', ['attn_dp', 'model']],
76- ['q_lora', ['expert', 'attn_dp_expert']],
77- ['kv_lora', ['expert', 'attn_dp_expert']],
76+ # ==========================================
77+ # Standard MLP / Dense Layers / Model Structure
78+ # ==========================================
79+ # Dense Activations
80+ ['activation_mlp', ['attn_dp', 'model']],
81+ # Note: activation batch and length are also used in attention and vocab
82+ ['activation_batch', ['data', 'attn_dp', 'attn_dp_expert']],
83+ ['activation_embed', []],
84+ # General Weights
85+ ['mlp', ['attn_dp', 'model']],
86+ ['embed', []],
7887 ['norm', []],
79- ['cache_heads', ['model']],
80- ['exp', ['expert', 'attn_dp_expert']],
81- ['paged_kv_heads', ['model']],
82- ]
88+ # ==========================================
89+ # Inference (Prefill, Decode, Cache)
90+ # ==========================================
91+ ['activation_prefill_kv_batch', ['data', 'attn_dp', 'attn_dp_expert']],
92+ ['decode_batch', ['data', 'attn_dp', 'attn_dp_expert']],
93+ ['cache_heads', ['expert', 'model']],
94+ ['paged_kv_heads', ['expert', 'model']],
95+ ['cache_batch_prefill', []],
96+ ['cache_batch', []],
97+ ['cache_heads_none', []],
98+ ['cache_kv', []],
99+ ['cache_sequence', []],
100+ ['num_pages', []],
101+ ['tokens_per_page', []],
102+ ['paged_kv_head_dim_size', []],
103+ ]
83104data_sharding : [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
84105input_data_sharding_logical_axes : ['activation_embed_and_logits_batch']
0 commit comments