@@ -217,19 +217,18 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
217217 Calculate training TFLOP for Gemma2 as in Gemma2 we combine [local_attention, global_attention] into one decoder
218218 layer and we use sliding window attention in local_attention
219219 """
220- noncausal_attention_flops = (
221- # global attention
222- 4 * config .per_device_batch_size * config .max_target_length ** 2 * config .num_query_heads * config .head_dim
223- +
224- # local attention
220+ window = min ( config . sliding_window_size , config . max_target_length )
221+ global_causal_flops = (
222+ 2 * config .per_device_batch_size * config .max_target_length ** 2 * config .num_query_heads * config .head_dim
223+ )
224+ local_causal_flops = (
225225 4
226226 * config .per_device_batch_size
227- * config .max_target_length
228- * min (config .sliding_window_size , config .max_target_length )
227+ * (config .max_target_length * window - 0.5 * window ** 2 )
229228 * config .num_query_heads
230229 * config .head_dim
231230 )
232- causal_attention_flops = noncausal_attention_flops / 2
231+ causal_attention_flops = global_causal_flops + local_causal_flops
233232 attention_tflops = causal_attention_flops * config .num_decoder_layers * 3 / 10 ** 12
234233
235234 # multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer
@@ -241,7 +240,7 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
241240
242241
243242def calculate_mixed_attention_model_tflops_training_per_device (
244- config , total_ffn_flops , qkv_flops , projection_flops , embedding_flops , attention_pattern_length
243+ config , total_ffn_flops_all_layers , qkv_flops , projection_flops , embedding_flops , attention_pattern_length
245244):
246245 """
247246 Calculate training TFLOPs for models with a mixed attention pattern of local
@@ -252,34 +251,125 @@ def calculate_mixed_attention_model_tflops_training_per_device(
252251 num_global_layers = num_layers // attention_pattern_length
253252 num_local_layers = num_layers - num_global_layers
254253
255- # FLOPs for a single global attention layer (full attention )
256- # Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim
257- global_attention_flops_per_layer = (
258- 4 * config .per_device_batch_size * config .max_target_length ** 2 * config .num_query_heads * config .head_dim
254+  # Global causal attention uses a multiplier of 2 (instead of 4 for non-causal)
255+  # since we only compute the lower triangular half of the attention matrix.
256+ global_causal_flops_per_layer = (
257+ 2 * config .per_device_batch_size * config .max_target_length ** 2 * config .num_query_heads * config .head_dim
259258 )
260259
261- # FLOPs for a single local attention layer (sliding window)
262- # Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim
263- local_attention_flops_per_layer = (
260+  # Local sliding-window attention counts the exact number of causal query-key interactions via the
261+  # closed form `(T * W - 0.5 * W^2)`, so the non-causal base multiplier of 4 applies with no extra halving.
262+ window = min (config .sliding_window_size , config .max_target_length )
263+ local_causal_flops_per_layer = (
264264 4
265265 * config .per_device_batch_size
266- * config .max_target_length
267- * min (config .sliding_window_size , config .max_target_length )
266+ * (config .max_target_length * window - 0.5 * window ** 2 )
268267 * config .num_query_heads
269268 * config .head_dim
270269 )
271270
272- # Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local)
273- noncausal_attention_flops = (
274- num_global_layers * global_attention_flops_per_layer + num_local_layers * local_attention_flops_per_layer
271+ causal_attention_flops = (
272+ num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer
273+ )
274+
275+ # Convert to TFLOPs and multiply by 3 for fwd/bwd pass
276+ attention_tflops = causal_attention_flops * 3 / 10 ** 12
277+
278+ total_learnable_flops = total_ffn_flops_all_layers
279+
280+ total_learnable_flops += (qkv_flops + projection_flops ) * num_layers + embedding_flops
281+
282+ learnable_weight_tflops = total_learnable_flops * 3 / 10 ** 12
283+
284+ return attention_tflops , learnable_weight_tflops
285+
286+
287+ def calculate_gemma4_tflops_training_per_device (
288+ config , total_ffn_flops_all_layers , embedding_flops , attention_pattern_length
289+ ):
290+ """
291+ Calculate training TFLOPs for Gemma 4.
292+ Gemma 4 has specific quirks:
293+ - Different QKV projection sizes for local vs. global layers.
294+ - Global-only KV sharing and varying global head dimensions.
295+ """
296+ num_layers = config .num_decoder_layers
297+
298+ num_global_layers = num_layers // attention_pattern_length
299+ num_local_layers = num_layers - num_global_layers
300+
301+ kv_multiplier = 1 if config .share_kv_projections else 2
302+ global_head_dim = config .global_head_dim or config .head_dim
303+ global_num_kv_heads = config .global_num_kv_heads or config .num_kv_heads
304+
305+ # Global causal attention uses a multiplier of 2 (instead of 4 for non-causal)
306+ # since we only compute the lower triangular half of the attention matrix.
307+ global_causal_flops_per_layer = (
308+ 2 * config .per_device_batch_size * config .max_target_length ** 2 * config .num_query_heads * global_head_dim
309+ )
310+
311+  # Local sliding-window attention counts the exact number of causal query-key interactions via the
312+  # closed form `(T * W - 0.5 * W^2)`, so the non-causal base multiplier of 4 applies with no extra halving.
313+ window = min (config .sliding_window_size , config .max_target_length )
314+ local_causal_flops_per_layer = (
315+ 4
316+ * config .per_device_batch_size
317+ * (config .max_target_length * window - 0.5 * window ** 2 )
318+ * config .num_query_heads
319+ * config .head_dim
320+ )
321+
322+ causal_attention_flops = (
323+ num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer
275324 )
276- causal_attention_flops = noncausal_attention_flops / 2
277325
278326 # Convert to TFLOPs and multiply by 3 for fwd/bwd pass
279327 attention_tflops = causal_attention_flops * 3 / 10 ** 12
280328
281- # Learnable weights (FFN, QKV, Projections) are present in every layer.
282- learnable_weight_tflops = ((total_ffn_flops + qkv_flops + projection_flops ) * num_layers + embedding_flops ) * 3 / 10 ** 12
329+ global_qkv_flops_per_layer = (
330+ 2
331+ * config .per_device_batch_size
332+ * config .max_target_length
333+ * config .emb_dim
334+ * (config .num_query_heads + kv_multiplier * global_num_kv_heads )
335+ * global_head_dim
336+ )
337+ global_projection_flops_per_layer = (
338+ 2
339+ * config .per_device_batch_size
340+ * config .max_target_length
341+ * config .emb_dim
342+ * config .num_query_heads
343+ * global_head_dim
344+ )
345+
346+ # Local layers never share KV projections (kv_multiplier is always 2).
347+ local_qkv_flops_per_layer = (
348+ 2
349+ * config .per_device_batch_size
350+ * config .max_target_length
351+ * config .emb_dim
352+ * (config .num_query_heads + 2 * config .num_kv_heads )
353+ * config .head_dim
354+ )
355+ local_projection_flops_per_layer = (
356+ 2
357+ * config .per_device_batch_size
358+ * config .max_target_length
359+ * config .emb_dim
360+ * config .num_query_heads
361+ * config .head_dim
362+ )
363+
364+ total_learnable_flops = total_ffn_flops_all_layers
365+
366+ total_learnable_flops += (
367+ (local_qkv_flops_per_layer + local_projection_flops_per_layer ) * num_local_layers
368+ + (global_qkv_flops_per_layer + global_projection_flops_per_layer ) * num_global_layers
369+ + embedding_flops
370+ )
371+
372+ learnable_weight_tflops = total_learnable_flops * 3 / 10 ** 12
283373
284374 return attention_tflops , learnable_weight_tflops
285375
@@ -496,11 +586,19 @@ def get_dense_moe_layers(config):
496586 elif config .decoder_block == DecoderBlockType .LLAMA4 :
497587 num_moe_layers = config .num_decoder_layers // config .interleave_moe_layer_step
498588 num_dense_layers = config .num_decoder_layers - num_moe_layers
589+ return num_dense_layers , num_moe_layers
499590 elif config .decoder_block == DecoderBlockType .QWEN3_NEXT :
591+ return 0 , config .num_decoder_layers
592+ elif config .decoder_block == DecoderBlockType .DEFAULT :
593+ raise ValueError ("Unsupported decoder block for dense/MoE layer calculation" )
594+
595+ num_experts = getattr (config , "num_experts" , 0 )
596+ if num_experts > 1 :
500597 num_moe_layers = config .num_decoder_layers
501598 num_dense_layers = 0
502599 else :
503- raise ValueError ("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation." )
600+ num_moe_layers = 0
601+ num_dense_layers = config .num_decoder_layers
504602
505603 return num_dense_layers , num_moe_layers
506604
@@ -601,6 +699,7 @@ def calculate_gemma3_vision_layers_tflops_per_device(config):
601699 learnable_weight_flops += 2 * vision_embedder_flops # only projector is learnable, add fwd+optimizer
602700 else :
603701 learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer
702+ total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass
604703
605704 # Convert to TFLOPs
606705 learnable_weight_tflops = learnable_weight_flops / 1e12
@@ -663,6 +762,7 @@ def calculate_llama4_vision_layers_tflops_per_device(config):
663762 learnable_weight_flops += 2 * projector_flops # only projector is learnable, add fwd+optimizer
664763 else :
665764 learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer
765+ total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass
666766
667767 # Convert to TFLOPs
668768 learnable_weight_tflops = learnable_weight_flops / 1e12
@@ -726,28 +826,40 @@ def calculate_vision_encoder_tflops(config):
726826def calculate_tflops_training_per_device (config , log = True ):
727827 """Calculate training TFLOP"""
728828 # MLP flops
829+ is_ffn_flops_already_total = False
729830 if config .num_experts > 1 :
730831 # calculation based on dropless implementation
731- if config .decoder_block in (DecoderBlockType .DEEPSEEK , DecoderBlockType .LLAMA4 , DecoderBlockType .QWEN3_NEXT ):
832+ if config .decoder_block in (
833+ DecoderBlockType .DEEPSEEK ,
834+ DecoderBlockType .LLAMA4 ,
835+ DecoderBlockType .QWEN3_NEXT ,
836+ DecoderBlockType .GEMMA4 ,
837+ ):
732838 total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device (config )
839+ is_ffn_flops_already_total = True
733840 else :
734841 gate_flops = 2 * config .per_device_batch_size * config .max_target_length * config .emb_dim * config .num_experts
735842 total_ffn_flops = (
736- gate_flops + calculate_ffn_mamtul_tflops_per_device (config , config .mlp_dim ) * config .num_experts_per_tok
843+ gate_flops + calculate_ffn_mamtul_tflops_per_device (config , config .moe_mlp_dim ) * config .num_experts_per_tok
737844 )
738845 else :
739846 total_ffn_flops = calculate_ffn_mamtul_tflops_per_device (config , config .mlp_dim )
740847
848+ total_ffn_flops_all_layers = (
849+ total_ffn_flops if is_ffn_flops_already_total else total_ffn_flops * config .num_decoder_layers
850+ )
851+
741852 # Attention flops
742853 if config .attention_type == "mla" :
743854 qkv_flops , causal_attention_flops , projection_flops = calculate_mla_tflops_per_device (config )
744855 else :
856+ kv_multiplier = 1 if config .share_kv_projections else 2
745857 qkv_flops = (
746858 2
747859 * config .per_device_batch_size
748860 * config .max_target_length
749861 * config .emb_dim
750- * (config .num_query_heads + 2 * config .num_kv_heads )
862+ * (config .num_query_heads + kv_multiplier * config .num_kv_heads )
751863 * config .head_dim
752864 )
753865 noncausal_attention_flops = (
@@ -768,7 +880,8 @@ def calculate_tflops_training_per_device(config, log=True):
768880 # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
769881 causal_attention_flops = noncausal_attention_flops / 2
770882
771- # Embedding flops
883+  # Embedding flops (counts only the unembedding projection; the embedding lookup is a gather
884+  # that performs no dense math, consistent with standard MFU accounting conventions)
772885 embedding_flops = 2 * config .per_device_batch_size * config .max_target_length * config .emb_dim * config .vocab_size
773886
774887 # Combine flops with number of decoder layers
@@ -778,26 +891,30 @@ def calculate_tflops_training_per_device(config, log=True):
778891 )
779892 elif config .decoder_block == DecoderBlockType .GEMMA3 :
780893 attention_tflops , learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device (
781- config , total_ffn_flops , qkv_flops , projection_flops , embedding_flops , attention_pattern_length = 6
894+ config , total_ffn_flops_all_layers , qkv_flops , projection_flops , embedding_flops , attention_pattern_length = 6
782895 )
783896 elif config .decoder_block == DecoderBlockType .GPT_OSS :
784897 attention_tflops , learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device (
785- config , total_ffn_flops , qkv_flops , projection_flops , embedding_flops , attention_pattern_length = 2
898+ config , total_ffn_flops_all_layers , qkv_flops , projection_flops , embedding_flops , attention_pattern_length = 2
786899 )
787900 elif config .decoder_block == DecoderBlockType .LLAMA4 :
788901 # Use the new helper to calculate attention TFLOPs correctly.
789902 attention_tflops = calculate_llama4_attention_tflops (config )
790903 # The learnable weight calculation remains the same as it correctly handles Llama4's MoE structure.
791904 learnable_weight_tflops = (
792- (total_ffn_flops + (qkv_flops + projection_flops ) * config .num_decoder_layers + embedding_flops ) * 3 / 10 ** 12
905+ (total_ffn_flops_all_layers + (qkv_flops + projection_flops ) * config .num_decoder_layers + embedding_flops )
906+ * 3
907+ / 10 ** 12
793908 )
794909 elif config .decoder_block == DecoderBlockType .GEMMA4 :
795- attention_tflops , learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device (
796- config , total_ffn_flops , qkv_flops , projection_flops , embedding_flops , attention_pattern_length = 6
910+ attention_tflops , learnable_weight_tflops = calculate_gemma4_tflops_training_per_device (
911+ config , total_ffn_flops_all_layers , embedding_flops , attention_pattern_length = 6
797912 )
798913 elif config .decoder_block == DecoderBlockType .DEEPSEEK :
799914 learnable_weight_tflops = (
800- (total_ffn_flops + (qkv_flops + projection_flops ) * config .num_decoder_layers + embedding_flops ) * 3 / 10 ** 12
915+ (total_ffn_flops_all_layers + (qkv_flops + projection_flops ) * config .num_decoder_layers + embedding_flops )
916+ * 3
917+ / 10 ** 12
801918 )
802919 attention_tflops = causal_attention_flops * config .num_decoder_layers * 3 / 10 ** 12
803920 elif config .decoder_block == DecoderBlockType .QWEN3_NEXT :
@@ -808,7 +925,7 @@ def calculate_tflops_training_per_device(config, log=True):
808925
809926 # Weights TFLOPs:
810927 total_weights = (
811- total_ffn_flops
928+ total_ffn_flops_all_layers
812929 + embedding_flops
813930 + (qkv_flops + projection_flops ) * num_full_attn_layers
814931 + gdn_weight_flops_per_layer * num_linear_attn_layers
@@ -821,7 +938,9 @@ def calculate_tflops_training_per_device(config, log=True):
821938 else :
822939 # multiply by 3 for both feed forward and back propagation flops
823940 learnable_weight_tflops = (
824- ((total_ffn_flops + qkv_flops + projection_flops ) * config .num_decoder_layers + embedding_flops ) * 3 / 10 ** 12
941+ (total_ffn_flops_all_layers + (qkv_flops + projection_flops ) * config .num_decoder_layers + embedding_flops )
942+ * 3
943+ / 10 ** 12
825944 )
826945 attention_tflops = causal_attention_flops * config .num_decoder_layers * 3 / 10 ** 12
827946
0 commit comments