@@ -155,6 +155,12 @@ base_num_kv_heads: 16
 base_mlp_dim : 7168
 base_num_decoder_layers : 16
 head_dim : 128
+attention_output_dim : -1
+local_num_query_heads : -1
+local_num_kv_heads : -1
+global_num_query_heads : -1
+global_num_kv_heads : -1
+attention_layer_hybrid_ratio : -1
 mlp_activations : ["silu", "linear"]
 mlp_activations_limit : -1.0
 dropout_rate : 0.0
@@ -184,6 +190,11 @@ num_experts_per_tok: 1
 megablox : true
 sparse_matmul : true
 capacity_factor : -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
+ragged_buffer_factor : -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
+                            # By default (-1), this buffer will be worst-case size to ensure no dropping.
+                            # When set to 1.0, this buffer is set to the size assuming perfectly balanced routing. If the routing dictates
+                            # a size larger than this, then tokens will be dropped.
+                            # In general, if ragged_buffer_factor>0, the ragged_buffer_size is balanced_size * ragged_buffer_factor.
 load_balance_loss_weight : 0.0 # weight for the load balance loss
 use_random_routing : false # whether to use random routing for debug/test purpose
 use_custom_sort_vjp : true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
@@ -240,6 +251,8 @@ use_2d_fsdp_sharding: False
 
 # deepseek moe
 base_moe_mlp_dim : 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
+moe_model_dim : -1 # dimension of tokens entering the MoE layer.
+shared_expert_mlp_dim : -1 # intermediate dimension of the shared expert.
 first_num_dense_layers : 0 # number of initial dense layers in the model
 shared_experts : 1
 routed_scaling_factor : 1.0 # scaling factor for routing scores
@@ -485,6 +498,7 @@ logical_axis_rules: [
   ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
   ['embed_no_exp', ['fsdp', 'sequence', 'context']],
   ['embed_tensor_transpose', ['tensor_transpose']],
+  ['attention_out_proj', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],