Skip to content

Commit 3e4e02b

Browse files
committed
Add Llama4 tflops calculation
1 parent 1efc736 commit 3e4e02b

2 files changed

Lines changed: 26 additions & 7 deletions

File tree

MaxText/maxtext_utils.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,26 +203,41 @@ def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim):
203203
return ffn1_flops + ffn2_flops
204204

205205

206-
def calculate_deepseek_ffn_tflops_per_device(config):
206+
def calculate_routed_and_shared_ffn_tflops_per_device(config):
  """Compute per-device FFN FLOPs for decoders that mix dense and MoE layers.

  Used for decoder blocks whose MoE layers combine shared experts with routed
  experts. The dense/MoE layer split comes from `get_dense_moe_layers`, so any
  decoder block that helper rejects makes this function raise as well.

  Args:
    config: model config; reads per_device_batch_size, max_target_length,
      emb_dim, num_experts, mlp_dim, moe_mlp_dim, shared_experts and
      num_experts_per_tok (plus whatever `get_dense_moe_layers` reads).

  Returns:
    The dense-layer FFN FLOPs plus the MoE-layer FFN FLOPs, summed over all
    layers. No 10**12 scaling is applied here.
  """
  # Gate term, counted once per MoE layer.
  gating = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts

  # Decoder layers are a mix of dense and MoE, so each kind of FFN cost is
  # scaled by its own layer count.
  dense_layer_count, moe_layer_count = get_dense_moe_layers(config)

  dense_total = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * dense_layer_count

  # Per MoE layer: every shared expert runs, plus the experts selected per token.
  shared_part = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.shared_experts
  routed_part = calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok
  moe_total = (gating + shared_part + routed_part) * moe_layer_count

  return dense_total + moe_total
217217

218218

219+
def get_dense_moe_layers(config):
  """Return the number of dense and MoE decoder layers for the configured model.

  DeepSeek: `config.first_num_dense_layers` layers are dense and the remaining
  `config.num_decoder_layers - config.first_num_dense_layers` are MoE.
  Llama4: `config.num_decoder_layers // config.interleave_moe_layer_step`
  layers are MoE and the rest are dense.

  Args:
    config: model config; reads decoder_block, num_decoder_layers and either
      first_num_dense_layers (DeepSeek) or interleave_moe_layer_step (Llama4).

  Returns:
    A `(num_dense_layers, num_moe_layers)` tuple.

  Raises:
    ValueError: if the decoder block is neither DeepSeek nor Llama4, or if the
      Llama4 interleave step is not a positive number.
  """
  if config.decoder_block == DecoderBlockType.DEEPSEEK:
    num_dense_layers = config.first_num_dense_layers
    num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers
  elif config.decoder_block == DecoderBlockType.LLAMA4:
    step = config.interleave_moe_layer_step
    # A zero step would raise a bare ZeroDivisionError below and a negative
    # step would yield a negative MoE layer count; fail with a clear message.
    if step <= 0:
      raise ValueError(f"interleave_moe_layer_step must be a positive integer, got {step}.")
    num_moe_layers = config.num_decoder_layers // step
    num_dense_layers = config.num_decoder_layers - num_moe_layers
  else:
    raise ValueError("Currently we only support DeepSeek and Llama4 calculation.")

  return num_dense_layers, num_moe_layers
232+
233+
219234
def calculate_tflops_training_per_device(config, log=True):
220235
"""Calculate training TFLOP"""
221236
# MLP flops
222237
if config.num_experts > 1:
223238
# calculation based on dropless implementation
224-
if config.decoder_block == DecoderBlockType.DEEPSEEK:
225-
total_ffn_flops = calculate_deepseek_ffn_tflops_per_device(config)
239+
if config.decoder_block == DecoderBlockType.DEEPSEEK or config.decoder_block == DecoderBlockType.LLAMA4:
240+
total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
226241
else:
227242
gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
228243
total_ffn_flops = (
@@ -263,7 +278,7 @@ def calculate_tflops_training_per_device(config, log=True):
263278
attention_tflops, learnable_weight_tflops = calculate_gemma2_tflops_training_per_device(
264279
config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
265280
)
266-
elif config.decoder_block == DecoderBlockType.DEEPSEEK:
281+
elif config.decoder_block == DecoderBlockType.DEEPSEEK or config.decoder_block == DecoderBlockType.LLAMA4:
267282
learnable_weight_tflops = (
268283
(total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
269284
)

MaxText/pyconfig.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ def validate_llama4_config(keys: dict):
262262
raise ValueError("Llama4 decoder has not been tested with capacity_factor >= 0 -- please set that value to -1 for now!")
263263
if keys["num_experts_per_tok"] > 1:
264264
raise ValueError("Only top-1 routing is supported for Llama4 for now!")
265+
if keys["base_num_decoder_layers"] % keys["interleave_moe_layer_step"] != 0:
266+
raise ValueError(
267+
f"The number of decoder layers ({keys['base_num_decoder_layers']}) must be divisible by interleave moe layer step ({keys['interleave_moe_layer_step']})"
268+
)
265269

266270

267271
def validate_model_name(s: str) -> bool:

0 commit comments

Comments
 (0)