@@ -573,18 +573,22 @@ def calculate_mla_tflops_per_device(config):
   return qkv_flops, attention_flops, projection_flops
 
 
-def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim):
+def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim, in_dim=None):
   """Helper function to calculate matmul TFLOP in ffn based on MLP dimension.
 
   Applies to:
     - Dense FFN layers (mlp_dim = config.mlp_dim).
     - MoE FFN layers (mlp_dim = config.moe_mlp_dim),
       need to scale by shared_experts or num_experts_per_tok.
+    - Architectures that compress to a latent before the FFN (e.g. qwen3_custom_moe)
+      pass ``in_dim=config.moe_expert_input_dim``; defaults to ``config.emb_dim``.
   """
+  if in_dim is None:
+    in_dim = config.emb_dim
   ffn1_flops = (
-      2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim * len(config.mlp_activations)
+      2 * config.per_device_batch_size * config.max_target_length * mlp_dim * in_dim * len(config.mlp_activations)
   )
-  ffn2_flops = 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim
+  ffn2_flops = 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * in_dim
   return ffn1_flops + ffn2_flops
 
 
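For intuition, here is a minimal standalone sketch of the helper's arithmetic, assuming invented sizes (none of the numbers below come from a real config): each matmul of a `(B*T, in_dim)` activation against an `(in_dim, mlp_dim)` weight costs `2*B*T*in_dim*mlp_dim` FLOPs, and the first FFN matmul repeats once per entry in `mlp_activations` for gated variants.

```python
from types import SimpleNamespace

# Illustrative sizes only; real values come from the model config.
config = SimpleNamespace(
    per_device_batch_size=4,
    max_target_length=2048,
    emb_dim=4096,
    mlp_activations=["silu", "linear"],  # gated FFN => two up-projections
)

def ffn_matmul_flops(config, mlp_dim, in_dim=None):
  """Mirrors calculate_ffn_mamtul_tflops_per_device: 2*B*T*mlp_dim*in_dim per matmul."""
  if in_dim is None:
    in_dim = config.emb_dim
  bt = config.per_device_batch_size * config.max_target_length
  ffn1 = 2 * bt * mlp_dim * in_dim * len(config.mlp_activations)  # up/gate projections
  ffn2 = 2 * bt * mlp_dim * in_dim  # down projection back to in_dim
  return ffn1 + ffn2

# Dense FFN at emb_dim vs. a compressed latent of width 1024:
print(ffn_matmul_flops(config, mlp_dim=14336) / 1e12)                # ~2.89 TFLOPs
print(ffn_matmul_flops(config, mlp_dim=14336, in_dim=1024) / 1e12)   # ~0.72 TFLOPs
```

Passing a smaller `in_dim` scales all three matmuls down linearly, which is the point of compressing to a latent before the FFN.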
@@ -861,6 +865,14 @@ def calculate_tflops_training_per_device(config, log=True):
   ):
     total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
     is_ffn_flops_already_total = True
+  elif config.decoder_block == DecoderBlockType.QWEN3_CUSTOM_MOE:
+    # MoE operates at moe_expert_input_dim (compressed latent), not emb_dim.
+    in_dim = config.moe_expert_input_dim
+    gate_flops = 2 * config.per_device_batch_size * config.max_target_length * in_dim * config.num_experts
+    total_ffn_flops = (
+        gate_flops
+        + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim, in_dim=in_dim) * config.num_experts_per_tok
+    )
   else:
     gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
     total_ffn_flops = (
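A hedged sketch of what the new branch totals, again with made-up sizes: the router gate is one `in_dim -> num_experts` matmul per token, and each token then runs `num_experts_per_tok` experts whose FFN matmuls operate at the compressed latent width instead of `emb_dim`.

```python
# Invented sizes; only the formula shapes match the diff above.
B, T = 4, 2048
moe_expert_input_dim = 1024          # compressed latent fed to the MoE block
moe_mlp_dim = 2816
num_experts, num_experts_per_tok = 64, 8
num_activations = 2                  # len(config.mlp_activations) for a gated FFN

bt = B * T
gate_flops = 2 * bt * moe_expert_input_dim * num_experts
# Same shape as calculate_ffn_mamtul_tflops_per_device(config, moe_mlp_dim, in_dim=...):
one_expert_flops = 2 * bt * moe_mlp_dim * moe_expert_input_dim * (num_activations + 1)
total_ffn_flops = gate_flops + one_expert_flops * num_experts_per_tok
print(gate_flops / 1e9, total_ffn_flops / 1e12)  # ~1.07 GFLOPs gate, ~1.13 TFLOPs total
```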
@@ -941,6 +953,24 @@ def calculate_tflops_training_per_device(config, log=True):
         / 10**12
     )
     attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
+  elif config.decoder_block == DecoderBlockType.QWEN3_CUSTOM_MOE:
+    # Attention output projects (num_query_heads * head_dim) -> attention_output_dim, not -> emb_dim.
+    qwen3_custom_proj_flops = (
+        2
+        * config.per_device_batch_size
+        * config.max_target_length
+        * config.attention_output_dim
+        * config.num_query_heads
+        * config.head_dim
+    )
+    # Each layer has a final up-projection: attention_output_dim -> emb_dim.
+    layer_up_proj_flops = (
+        2 * config.per_device_batch_size * config.max_target_length * config.attention_output_dim * config.emb_dim
+    )
+    per_layer_flops = qkv_flops + qwen3_custom_proj_flops + layer_up_proj_flops
+    total_weight_flops = total_ffn_flops_all_layers + per_layer_flops * config.num_decoder_layers + embedding_flops
+    learnable_weight_tflops = total_weight_flops * 3 / 10**12
+    attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
   elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
     gdn_weight_flops_per_layer, gdn_attn_flops_per_layer = calculate_gated_delta_net_flops_per_device(config)
     cycle_interval = config.inhomogeneous_layer_cycle_interval
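Why this branch exists: in the default path the attention output projection maps `(num_query_heads * head_dim) -> emb_dim`, while here it targets the narrower `attention_output_dim`, and a separate per-layer up-projection restores `emb_dim`. A small sketch under assumed sizes, with `qkv_flops` stubbed since it is computed earlier in the function:

```python
# Assumed sizes; qkv_flops is a stand-in for the value computed earlier.
B, T = 4, 2048
num_query_heads, head_dim = 32, 128
attention_output_dim, emb_dim = 2048, 4096
qkv_flops = 0  # placeholder for the QKV projection FLOPs from earlier code

bt = B * T
# Output projection: (num_query_heads * head_dim) -> attention_output_dim.
proj_flops = 2 * bt * attention_output_dim * num_query_heads * head_dim
# Final per-layer up-projection: attention_output_dim -> emb_dim.
up_proj_flops = 2 * bt * attention_output_dim * emb_dim
per_layer_flops = qkv_flops + proj_flops + up_proj_flops
# The trailing *3 in the diff counts the forward pass (1x) plus backward (~2x).
print(per_layer_flops * 3 / 1e12)
```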
@@ -1386,18 +1416,14 @@ def setup_initial_state(
       out_shardings=state_mesh_shardings,
   )()
   sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m
-  if (
-      sparsity_enabled and raw_params
-  ):  # If we loaded a partial state, we need to merge it.
+  if sparsity_enabled and raw_params:  # If we loaded a partial state, we need to merge it.
 
     def _merge_params(p_raw, p_init):
       if isinstance(p_raw, jax.ShapeDtypeStruct):
         return p_init
       return p_raw
 
-    merged_params = jax.tree_util.tree_map(
-        _merge_params, raw_params, state.params
-    )
+    merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params)
     state = state.replace(params=merged_params)
   elif raw_params:
     state = state.replace(params=raw_params)
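The merge logic here is only reformatted, not changed; its behavior is easy to demonstrate. When a checkpoint restores partially, unrestored leaves come back as abstract `jax.ShapeDtypeStruct` placeholders, and the `tree_map` keeps the freshly initialized value for exactly those leaves. A toy example with made-up parameter shapes:

```python
import jax
import jax.numpy as jnp

init_params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
# A partially loaded checkpoint: "b" came back as an abstract placeholder.
raw_params = {"w": jnp.full((2, 2), 7.0), "b": jax.ShapeDtypeStruct((2,), jnp.float32)}

def _merge(p_raw, p_init):
  # Abstract leaves mean the value was never restored; keep the fresh init.
  return p_init if isinstance(p_raw, jax.ShapeDtypeStruct) else p_raw

merged = jax.tree_util.tree_map(_merge, raw_params, init_params)
print(merged["w"][0, 0], merged["b"])  # 7.0 from the checkpoint, zeros from init
```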