
Commit 6288c23

Add Gemma 4 FLOPs & fix sliding window flops computations
1 parent ce8a7de commit 6288c23

5 files changed: 755 additions & 315 deletions


src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
@@ -253,7 +253,7 @@ use_2d_fsdp_sharding: False
 # deepseek moe
 base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
 first_num_dense_layers: 0 # number of initial dense layers in the model
-shared_experts: 1
+shared_experts: 0
 routed_scaling_factor: 1.0 # scaling factor for routing scores
 routed_score_func: "" # scoring function for routing
 routed_bias: False # a flag if a learnable bias is added for routing

src/maxtext/configs/types.py

Lines changed: 1 addition & 1 deletion
@@ -752,7 +752,7 @@ class DeepSeekMoE(BaseModel):
 
   base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
   first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
-  shared_experts: PositiveInt = Field(1, description="Number of shared experts.")
+  shared_experts: NonNegativeInt = Field(0, description="Number of shared experts.")
   routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")
   routed_score_func: str = Field("", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').")
   routed_bias: bool = Field(False, description="Whether to add a bias term for routing.")
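
The `PositiveInt` to `NonNegativeInt` switch is what makes the new `shared_experts: 0` default validate at all. A minimal, self-contained sketch of the difference (hypothetical model names; pydantic API as used in types.py):

from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt, ValidationError

class OldField(BaseModel):
  shared_experts: PositiveInt = Field(1)  # old schema: 0 is rejected

class NewField(BaseModel):
  shared_experts: NonNegativeInt = Field(0)  # new schema: 0 (no shared experts) is allowed

NewField(shared_experts=0)  # validates fine
try:
  OldField(shared_experts=0)
except ValidationError as err:
  print(err)  # PositiveInt requires a value > 0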

src/maxtext/utils/maxtext_utils.py

Lines changed: 156 additions & 37 deletions
@@ -217,19 +217,18 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
   Calculate training TFLOP for Gemma2 as in Gemma2 we combine [local_attention, global_attention] into one decoder
   layer and we use sliding window attention in local_attention
   """
-  noncausal_attention_flops = (
-      # global attention
-      4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
-      +
-      # local attention
+  window = min(config.sliding_window_size, config.max_target_length)
+  global_causal_flops = (
+      2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
+  )
+  local_causal_flops = (
       4
       * config.per_device_batch_size
-      * config.max_target_length
-      * min(config.sliding_window_size, config.max_target_length)
+      * (config.max_target_length * window - 0.5 * window**2)
       * config.num_query_heads
       * config.head_dim
   )
-  causal_attention_flops = noncausal_attention_flops / 2
+  causal_attention_flops = global_causal_flops + local_causal_flops
   attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
 
   # multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer
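
This is the substantive fix. The old code charged local attention 4 * B * T * W * heads * head_dim and then halved the total for causality, but a causal mask removes almost nothing from a sliding-window band when T >> W: each query t attends to min(t + 1, W) keys, for an exact pair count of T*W - W*(W-1)/2 when T >= W, which the new code approximates as T*W - 0.5*W**2. The halved pre-fix figure, 0.5*T*W, therefore undercounted local layers by roughly 2x. A standalone sanity check with made-up T and W:

# Sanity check for the causal sliding-window pair count used above.
T, W = 4096, 1024  # hypothetical sequence length and window size

exact = sum(min(t + 1, W) for t in range(T))  # = T*W - W*(W-1)/2 for T >= W
approx = T * W - 0.5 * W**2                   # closed form used in this commit
old = T * W / 2                               # what the halved pre-fix formula counted

print(exact, approx, old)        # 3670528, 3670016.0, 2097152.0
print((exact - approx) / exact)  # ~0.014%: exact and closed form differ only by W/2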
@@ -241,7 +240,7 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
 
 
 def calculate_mixed_attention_model_tflops_training_per_device(
-    config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length
+    config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length
 ):
   """
   Calculate training TFLOPs for models with a mixed attention pattern of local
@@ -252,34 +251,125 @@ def calculate_mixed_attention_model_tflops_training_per_device(
   num_global_layers = num_layers // attention_pattern_length
   num_local_layers = num_layers - num_global_layers
 
-  # FLOPs for a single global attention layer (full attention)
-  # Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim
-  global_attention_flops_per_layer = (
-      4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
+  # Global causal attention uses a multiplier of 2 (instead of 4 for non-causal)
+  # since we only compute the lower triangular half of the attention matrix.
+  global_causal_flops_per_layer = (
+      2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
   )
 
-  # FLOPs for a single local attention layer (sliding window)
-  # Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim
-  local_attention_flops_per_layer = (
+  # Local sliding window attention directly computes the exact causal interactions
+  # via the formula `(T * W - 0.5 * W^2)`. Therefore, we use the base multiplier of 4.
+  window = min(config.sliding_window_size, config.max_target_length)
+  local_causal_flops_per_layer = (
       4
       * config.per_device_batch_size
-      * config.max_target_length
-      * min(config.sliding_window_size, config.max_target_length)
+      * (config.max_target_length * window - 0.5 * window**2)
       * config.num_query_heads
       * config.head_dim
   )
 
-  # Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local)
-  noncausal_attention_flops = (
-      num_global_layers * global_attention_flops_per_layer + num_local_layers * local_attention_flops_per_layer
+  causal_attention_flops = (
+      num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer
+  )
+
+  # Convert to TFLOPs and multiply by 3 for fwd/bwd pass
+  attention_tflops = causal_attention_flops * 3 / 10**12
+
+  total_learnable_flops = total_ffn_flops_all_layers
+
+  total_learnable_flops += (qkv_flops + projection_flops) * num_layers + embedding_flops
+
+  learnable_weight_tflops = total_learnable_flops * 3 / 10**12
+
+  return attention_tflops, learnable_weight_tflops
+
+
+def calculate_gemma4_tflops_training_per_device(
+    config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length
+):
+  """
+  Calculate training TFLOPs for Gemma 4.
+  Gemma 4 has specific quirks:
+  - Different QKV projection sizes for local vs. global layers.
+  - Global-only KV sharing and varying global head dimensions.
+  """
+  num_layers = config.num_decoder_layers
+
+  num_global_layers = num_layers // attention_pattern_length
+  num_local_layers = num_layers - num_global_layers
+
+  kv_multiplier = 1 if config.share_kv_projections else 2
+  global_head_dim = config.global_head_dim or config.head_dim
+  global_num_kv_heads = config.global_num_kv_heads or config.num_kv_heads
+
+  # Global causal attention uses a multiplier of 2 (instead of 4 for non-causal)
+  # since we only compute the lower triangular half of the attention matrix.
+  global_causal_flops_per_layer = (
+      2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * global_head_dim
+  )
+
+  # Local sliding window attention directly computes the exact causal interactions
+  # via the formula `(T * W - 0.5 * W^2)`. Therefore, we use the base multiplier of 4.
+  window = min(config.sliding_window_size, config.max_target_length)
+  local_causal_flops_per_layer = (
+      4
+      * config.per_device_batch_size
+      * (config.max_target_length * window - 0.5 * window**2)
+      * config.num_query_heads
+      * config.head_dim
+  )
+
+  causal_attention_flops = (
+      num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer
   )
-  causal_attention_flops = noncausal_attention_flops / 2
 
   # Convert to TFLOPs and multiply by 3 for fwd/bwd pass
   attention_tflops = causal_attention_flops * 3 / 10**12
 
-  # Learnable weights (FFN, QKV, Projections) are present in every layer.
-  learnable_weight_tflops = ((total_ffn_flops + qkv_flops + projection_flops) * num_layers + embedding_flops) * 3 / 10**12
+  global_qkv_flops_per_layer = (
+      2
+      * config.per_device_batch_size
+      * config.max_target_length
+      * config.emb_dim
+      * (config.num_query_heads + kv_multiplier * global_num_kv_heads)
+      * global_head_dim
+  )
+  global_projection_flops_per_layer = (
+      2
+      * config.per_device_batch_size
+      * config.max_target_length
+      * config.emb_dim
+      * config.num_query_heads
+      * global_head_dim
+  )
+
+  # Local layers never share KV projections (kv_multiplier is always 2).
+  local_qkv_flops_per_layer = (
+      2
+      * config.per_device_batch_size
+      * config.max_target_length
+      * config.emb_dim
+      * (config.num_query_heads + 2 * config.num_kv_heads)
+      * config.head_dim
+  )
+  local_projection_flops_per_layer = (
+      2
+      * config.per_device_batch_size
+      * config.max_target_length
+      * config.emb_dim
+      * config.num_query_heads
+      * config.head_dim
+  )
+
+  total_learnable_flops = total_ffn_flops_all_layers
+
+  total_learnable_flops += (
+      (local_qkv_flops_per_layer + local_projection_flops_per_layer) * num_local_layers
+      + (global_qkv_flops_per_layer + global_projection_flops_per_layer) * num_global_layers
+      + embedding_flops
+  )
+
+  learnable_weight_tflops = total_learnable_flops * 3 / 10**12
 
   return attention_tflops, learnable_weight_tflops
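
To see how the new helper splits work across layer types, here is a hypothetical driver with a toy config (every value invented; `SimpleNamespace` stands in for the real MaxText config object). With `attention_pattern_length=6`, 12 layers split into 2 global and 10 local layers; the global layers get their own head dim and KV-head count, and `share_kv_projections=True` drops their KV multiplier to 1:

from types import SimpleNamespace

toy_config = SimpleNamespace(
    num_decoder_layers=12,
    per_device_batch_size=1,
    max_target_length=8192,
    sliding_window_size=1024,
    num_query_heads=16,
    num_kv_heads=8,
    head_dim=128,
    emb_dim=2048,
    share_kv_projections=True,  # global layers share K/V -> kv_multiplier = 1
    global_head_dim=256,        # global layers may use a wider head dim
    global_num_kv_heads=4,      # and their own KV head count
)

# FFN and embedding terms zeroed out here to isolate the attention accounting.
attn_tflops, weight_tflops = calculate_gemma4_tflops_training_per_device(
    toy_config, total_ffn_flops_all_layers=0.0, embedding_flops=0.0, attention_pattern_length=6
)
print(attn_tflops, weight_tflops)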

@@ -496,11 +586,19 @@ def get_dense_moe_layers(config):
   elif config.decoder_block == DecoderBlockType.LLAMA4:
     num_moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step
     num_dense_layers = config.num_decoder_layers - num_moe_layers
+    return num_dense_layers, num_moe_layers
   elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
+    return 0, config.num_decoder_layers
+  elif config.decoder_block == DecoderBlockType.DEFAULT:
+    raise ValueError("Unsupported decoder block for dense/MoE layer calculation")
+
+  num_experts = getattr(config, "num_experts", 0)
+  if num_experts > 1:
     num_moe_layers = config.num_decoder_layers
     num_dense_layers = 0
   else:
-    raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.")
+    num_moe_layers = 0
+    num_dense_layers = config.num_decoder_layers
 
   return num_dense_layers, num_moe_layers
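
The rewritten tail of `get_dense_moe_layers` means decoder blocks without a bespoke branch no longer raise: any config with `num_experts > 1` is treated as MoE in every layer, anything else as fully dense, and only `DecoderBlockType.DEFAULT` still errors out. A self-contained restatement of the fallback rule (illustrative only, not the MaxText function itself):

def dense_moe_split(num_layers: int, num_experts: int) -> tuple[int, int]:
  """Generic fallback: returns (num_dense_layers, num_moe_layers)."""
  if num_experts > 1:
    return 0, num_layers  # every layer is MoE
  return num_layers, 0    # every layer is dense

assert dense_moe_split(32, 8) == (0, 32)
assert dense_moe_split(32, 1) == (32, 0)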

@@ -601,6 +699,7 @@ def calculate_gemma3_vision_layers_tflops_per_device(config):
     learnable_weight_flops += 2 * vision_embedder_flops # only projector is learnable, add fwd+optimizer
   else:
     learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer
+    total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass
 
   # Convert to TFLOPs
   learnable_weight_tflops = learnable_weight_flops / 1e12
@@ -663,6 +762,7 @@ def calculate_llama4_vision_layers_tflops_per_device(config):
     learnable_weight_flops += 2 * projector_flops # only projector is learnable, add fwd+optimizer
   else:
     learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer
+    total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass
 
   # Convert to TFLOPs
   learnable_weight_tflops = learnable_weight_flops / 1e12
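
Both vision-tower fixes apply the file's usual training-cost convention: total ≈ 3x the forward matmul FLOPs, since the backward pass costs roughly two matmuls (gradients w.r.t. activations and w.r.t. weights) for every forward one. Before this commit the vision paths scaled only `learnable_weight_flops` by 3 while `total_attn_flops` stayed forward-only, so the two terms used inconsistent accounting.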
@@ -726,28 +826,40 @@ def calculate_vision_encoder_tflops(config):
 def calculate_tflops_training_per_device(config, log=True):
   """Calculate training TFLOP"""
   # MLP flops
+  is_ffn_flops_already_total = False
   if config.num_experts > 1:
     # calculation based on dropless implementation
-    if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT):
+    if config.decoder_block in (
+        DecoderBlockType.DEEPSEEK,
+        DecoderBlockType.LLAMA4,
+        DecoderBlockType.QWEN3_NEXT,
+        DecoderBlockType.GEMMA4,
+    ):
       total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
+      is_ffn_flops_already_total = True
     else:
       gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
       total_ffn_flops = (
-          gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * config.num_experts_per_tok
+          gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok
       )
   else:
     total_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim)
 
+  total_ffn_flops_all_layers = (
+      total_ffn_flops if is_ffn_flops_already_total else total_ffn_flops * config.num_decoder_layers
+  )
+
   # Attention flops
   if config.attention_type == "mla":
     qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
   else:
+    kv_multiplier = 1 if config.share_kv_projections else 2
     qkv_flops = (
         2
         * config.per_device_batch_size
        * config.max_target_length
         * config.emb_dim
-        * (config.num_query_heads + 2 * config.num_kv_heads)
+        * (config.num_query_heads + kv_multiplier * config.num_kv_heads)
         * config.head_dim
     )
     noncausal_attention_flops = (
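
The `is_ffn_flops_already_total` flag exists because the branches disagree on units: `calculate_routed_and_shared_ffn_tflops_per_device` returns FLOPs already summed over all layers (dense and MoE layer counts can differ), while the other branches return a per-layer figure. A tiny illustrative helper for the normalization (not MaxText code; numbers invented):

def ffn_flops_all_layers(ffn_flops: float, already_total: bool, num_layers: int) -> float:
  # Per-layer figures get scaled by depth; already-summed figures pass through.
  return ffn_flops if already_total else ffn_flops * num_layers

# e.g. a dense model reporting 1.2e12 FFN FLOPs per layer across 48 layers:
assert ffn_flops_all_layers(1.2e12, already_total=False, num_layers=48) == 57.6e12
assert ffn_flops_all_layers(9.9e13, already_total=True, num_layers=48) == 9.9e13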
@@ -768,7 +880,8 @@ def calculate_tflops_training_per_device(config, log=True):
     # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
     causal_attention_flops = noncausal_attention_flops / 2
 
-  # Embedding flops
+  # Embedding flops (counts only the unembedding projection; the embedding lookup is a gather operation
+  # that performs no dense math, matching standard MFU hardware calculations)
   embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size
 
   # Combine flops with number of decoder layers
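
For scale, with invented numbers (B=1, T=8192, emb_dim=2048, vocab_size=256000) the unembedding term is 2 * 1 * 8192 * 2048 * 256000 ≈ 8.6e12 FLOPs per step, while the input-embedding lookup is a pure gather and contributes no matmul FLOPs, which is why only a single vocab-sized term appears here.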
@@ -778,26 +891,30 @@ def calculate_tflops_training_per_device(config, log=True):
     )
   elif config.decoder_block == DecoderBlockType.GEMMA3:
     attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device(
-        config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6
+        config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6
     )
   elif config.decoder_block == DecoderBlockType.GPT_OSS:
     attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device(
-        config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=2
+        config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=2
     )
   elif config.decoder_block == DecoderBlockType.LLAMA4:
     # Use the new helper to calculate attention TFLOPs correctly.
     attention_tflops = calculate_llama4_attention_tflops(config)
     # The learnable weight calculation remains the same as it correctly handles Llama4's MoE structure.
     learnable_weight_tflops = (
-        (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
+        (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops)
+        * 3
+        / 10**12
     )
   elif config.decoder_block == DecoderBlockType.GEMMA4:
-    attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device(
-        config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6
+    attention_tflops, learnable_weight_tflops = calculate_gemma4_tflops_training_per_device(
+        config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length=6
     )
   elif config.decoder_block == DecoderBlockType.DEEPSEEK:
     learnable_weight_tflops = (
-        (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
+        (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops)
+        * 3
+        / 10**12
     )
     attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
   elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
@@ -808,7 +925,7 @@ def calculate_tflops_training_per_device(config, log=True):
 
     # Weights TFLOPs:
     total_weights = (
-        total_ffn_flops
+        total_ffn_flops_all_layers
         + embedding_flops
         + (qkv_flops + projection_flops) * num_full_attn_layers
         + gdn_weight_flops_per_layer * num_linear_attn_layers
@@ -821,7 +938,9 @@ def calculate_tflops_training_per_device(config, log=True):
   else:
     # multiply by 3 for both feed forward and back propagation flops
     learnable_weight_tflops = (
-        ((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
+        (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops)
+        * 3
+        / 10**12
     )
     attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

tests/integration/smoke/train_using_ragged_dot_smoke_test.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ def test_tiny_config(self, quantization: str):
         "decoder_block=deepseek",
         "attention_type=mla",
         "num_experts=2",
+        "shared_experts=1",
         # Enable sparse_matmul.
         "sparse_matmul=True",
         # Enable ragged_dot.
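
Presumably the test pins `shared_experts=1` because the config default changed to 0 in this commit; without it, the DeepSeek smoke test would silently stop exercising the shared-expert path.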
