1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
# This rule only uses FSDP. Pure FSDP is the go-to sharding strategy
# for small-scale training, and this rule simplifies the overall configuration.
# The device mesh has a single axis, 'fsdp'; every rule below either maps a
# logical axis onto that one mesh axis or leaves it unsharded.
mesh_axes: ['fsdp']
data_sharding: [['fsdp']]
# Each entry is [logical_axis_name, mesh_axes]. An empty list means the logical
# axis is not sharded over any mesh axis.
logical_axis_rules: [
  # Batch/data dimensions sharded on fsdp
  ['activation_batch', ['fsdp']],
  ['activation_batch_no_exp', ['fsdp']],
  ['activation_batch_moe', ['fsdp']],
  # NOTE(review): the diff this was recovered from elides file lines 24-27
  # at this point; restore those rules from the full file before use.
  ['activation_kv_batch', ['fsdp']],
  ['activation_kv_batch_no_exp', ['fsdp']],
  ['decode_batch', ['fsdp']],
  # Weight dimensions sharded on fsdp
  ['embed', ['fsdp']],
  ['embed_no_exp', ['fsdp']],
  ['embed_moe', ['fsdp']],
  ['embed_no_exp_moe', ['fsdp']],
  ['q_lora', ['fsdp']],
  ['kv_lora', ['fsdp']],
  # NOTE(review): bare string here while sibling rules use a one-element list;
  # presumably equivalent -- confirm before normalizing.
  ['exp_with_fsdp', 'fsdp'],
  # All other axes are unsharded (tensor/sequence/expert axes do not exist in pure-fsdp)
  ['activation_heads', []],
  ['activation_kv_heads', []],
  ['activation_length', []],
  ['activation_attn_length', []],
  ['activation_attn_length_no_exp', []],
  ['activation_length_no_exp', []],
  ['activation_norm_length', []],
  ['activation_q_length', []],
  ['activation_q_length_no_exp', []],
  ['prefill_activation_length', []],
  ['prefill_activation_norm_length', []],
  ['activation_kv_length', []],
  ['activation_attn_embed', []],
  ['activation_embed', []],
  ['activation_mlp', []],
  ['activation_kv', []],
  ['activation_kv_head_dim', []],
  ['activation_vocab', []],
  ['activation_stage', []],
  ['activation_exp', []],
  ['decode_length', []],
  ['mlp', []],
  ['mlp_no_fsdp', []],
  ['vocab', []],
  ['heads', []],
  ['q_heads', []],
  ['kv_heads', []],
  ['embed_tensor_transpose', []],
  ['q_lora_up_proj', []],
  ['kv_lora_up_proj', []],
  ['norm', []],
  ['layers', []],
  ['qkv', []],
  ['kv', []],
  ['kv_head_dim', []],
  ['cache_batch_prefill', []],
  ['cache_batch', []],
  ['cache_heads_none', []],
  ['cache_heads', []],
  ['cache_kv', []],
  ['cache_sequence', []],
  ['exp', []],
  ['paged_kv_heads', []],
  ['num_pages', []],
  ['tokens_per_page', []],
  ['paged_kv_head_dim_size', []],
  ['dense_layers', []],
  ['moe_layers', []],
  ['num_activations', []],
  ['engram_dim', []],
  ['mhc', []],
  ['diloco', []],
]
0 commit comments