@@ -203,26 +203,41 @@ def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim):
203203 return ffn1_flops + ffn2_flops
204204
205205
def calculate_routed_and_shared_ffn_tflops_per_device(config):
    """Calculate per-device FFN FLOPs for decoders that mix dense and MoE layers.

    Covers decoder blocks whose layer stack is split into dense FFN layers and
    MoE layers with shared plus routed experts (DeepSeek and Llama4). The
    dense/MoE split itself is delegated to ``get_dense_moe_layers``.

    Args:
        config: Model configuration. Reads ``per_device_batch_size``,
            ``max_target_length``, ``emb_dim``, ``num_experts``, ``mlp_dim``,
            ``moe_mlp_dim``, ``shared_experts`` and ``num_experts_per_tok``.

    Returns:
        The FFN cost summed over every dense and MoE layer. Despite the
        function name, no 10**12 scaling is applied here.

    Raises:
        ValueError: Propagated from ``get_dense_moe_layers`` when the decoder
            block type is not supported.
    """
    # Router (gating) matmul: every token is scored against every expert.
    gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts

    # Due to the mixed decoder layers, each per-layer cost is multiplied by the
    # number of layers of its own kind (dense or MoE).
    num_dense_layers, num_moe_layers = get_dense_moe_layers(config)
    dense_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * num_dense_layers

    # Shared and routed experts use the same hidden size, so the cost of a
    # single expert is computed once and reused for both.
    single_expert_flops = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim)
    shared_experts_flops = single_expert_flops * config.shared_experts
    routed_experts_flops = single_expert_flops * config.num_experts_per_tok

    moe_ffn_flops = (gate_flops + shared_experts_flops + routed_experts_flops) * num_moe_layers
    return dense_ffn_flops + moe_ffn_flops
217217
218218
def get_dense_moe_layers(config):
    """Return how many decoder layers are dense and how many are MoE.

    Args:
        config: Model configuration. Reads ``decoder_block`` and
            ``num_decoder_layers``, plus ``first_num_dense_layers`` for
            DeepSeek or ``interleave_moe_layer_step`` for Llama4.

    Returns:
        A ``(num_dense_layers, num_moe_layers)`` tuple.

    Raises:
        ValueError: If ``decoder_block`` is neither DeepSeek nor Llama4, or if
            ``interleave_moe_layer_step`` is not a positive number for Llama4.
    """
    if config.decoder_block == DecoderBlockType.DEEPSEEK:
        # The first `first_num_dense_layers` layers are dense; the rest are MoE.
        num_dense_layers = config.first_num_dense_layers
        num_moe_layers = config.num_decoder_layers - num_dense_layers
    elif config.decoder_block == DecoderBlockType.LLAMA4:
        # Guard the divisor: zero would raise a bare ZeroDivisionError and a
        # negative step would silently produce negative layer counts.
        if config.interleave_moe_layer_step < 1:
            raise ValueError(
                f"interleave_moe_layer_step must be >= 1, got {config.interleave_moe_layer_step}."
            )
        # One MoE layer per `interleave_moe_layer_step` decoder layers.
        num_moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step
        num_dense_layers = config.num_decoder_layers - num_moe_layers
    else:
        raise ValueError("Currently we only support DeepSeek and Llama4 calculation.")

    return num_dense_layers, num_moe_layers
232+
233+
219234def calculate_tflops_training_per_device (config , log = True ):
220235 """Calculate training TFLOP"""
221236 # MLP flops
222237 if config .num_experts > 1 :
223238 # calculation based on dropless implementation
224- if config .decoder_block == DecoderBlockType .DEEPSEEK :
225- total_ffn_flops = calculate_deepseek_ffn_tflops_per_device (config )
239+ if config .decoder_block == DecoderBlockType .DEEPSEEK or config . decoder_block == DecoderBlockType . LLAMA4 :
240+ total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device (config )
226241 else :
227242 gate_flops = 2 * config .per_device_batch_size * config .max_target_length * config .emb_dim * config .num_experts
228243 total_ffn_flops = (
@@ -263,7 +278,7 @@ def calculate_tflops_training_per_device(config, log=True):
263278 attention_tflops , learnable_weight_tflops = calculate_gemma2_tflops_training_per_device (
264279 config , total_ffn_flops , qkv_flops , projection_flops , embedding_flops
265280 )
266- elif config .decoder_block == DecoderBlockType .DEEPSEEK :
281+ elif config .decoder_block == DecoderBlockType .DEEPSEEK or config . decoder_block == DecoderBlockType . LLAMA4 :
267282 learnable_weight_tflops = (
268283 (total_ffn_flops + (qkv_flops + projection_flops ) * config .num_decoder_layers + embedding_flops ) * 3 / 10 ** 12
269284 )
0 commit comments