
Commit 142bcf1

Merge pull request #3759 from AI-Hypercomputer:gagik-qwen-flops
PiperOrigin-RevId: 907105602
2 parents: c1d057b + c964555

2 files changed

Lines changed: 58 additions & 9 deletions


src/maxtext/utils/maxtext_utils.py

Lines changed: 35 additions & 9 deletions
@@ -573,18 +573,22 @@ def calculate_mla_tflops_per_device(config):
   return qkv_flops, attention_flops, projection_flops
 
 
-def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim):
+def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim, in_dim=None):
   """Helper function to calculate matmul TFLOP in ffn based on MLP dimension.
 
   Applies to:
   - Dense FFN layers (mlp_dim = config.mlp_dim).
   - MoE FFN layers (mlp_dim = config.moe_mlp_dim),
     need to scale by shared_experts or num_experts_per_tok.
+  - Architectures that compress to a latent before the FFN (e.g. qwen3_custom_moe)
+    pass ``in_dim=config.moe_expert_input_dim``; defaults to ``config.emb_dim``.
   """
+  if in_dim is None:
+    in_dim = config.emb_dim
   ffn1_flops = (
-      2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim * len(config.mlp_activations)
+      2 * config.per_device_batch_size * config.max_target_length * mlp_dim * in_dim * len(config.mlp_activations)
   )
-  ffn2_flops = 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim
+  ffn2_flops = 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * in_dim
   return ffn1_flops + ffn2_flops
 
 
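As a quick illustration of the generalized helper: each activation branch and the down projection cost 2 * B * S * mlp_dim * in_dim matmul FLOPs, with in_dim falling back to emb_dim for dense layers. The sketch below is standalone and uses made-up dimensions, not values from any MaxText config:

# Standalone sketch of the FFN matmul count above; all dimensions are
# illustration-only assumptions, not read from a MaxText config.
def ffn_matmul_flops(batch, seq, mlp_dim, in_dim, num_activations):
  ffn1 = 2 * batch * seq * mlp_dim * in_dim * num_activations  # gate/up projections
  ffn2 = 2 * batch * seq * mlp_dim * in_dim                    # down projection
  return ffn1 + ffn2

dense = ffn_matmul_flops(4, 2048, mlp_dim=6144, in_dim=2048, num_activations=2)
latent = ffn_matmul_flops(4, 2048, mlp_dim=768, in_dim=768, num_activations=2)
print(dense / 1e12, latent / 1e12)  # forward-pass TFLOPs for one layer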

@@ -861,6 +865,14 @@ def calculate_tflops_training_per_device(config, log=True):
   ):
     total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
     is_ffn_flops_already_total = True
+  elif config.decoder_block == DecoderBlockType.QWEN3_CUSTOM_MOE:
+    # MoE operates at moe_expert_input_dim (compressed latent), not emb_dim.
+    in_dim = config.moe_expert_input_dim
+    gate_flops = 2 * config.per_device_batch_size * config.max_target_length * in_dim * config.num_experts
+    total_ffn_flops = (
+        gate_flops
+        + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim, in_dim=in_dim) * config.num_experts_per_tok
+    )
   else:
     gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
     total_ffn_flops = (
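Per token, the router is a single in_dim x num_experts matmul and only the top-k routed experts contribute FFN matmuls, so the whole MoE block scales linearly with the latent width. A rough numeric comparison with assumed shapes (the expert counts and widths below are placeholders, not the repo's qwen3 config):

# Placeholder shapes for illustration; not read from a MaxText config.
B, S, moe_mlp_dim, top_k, num_experts = 4, 2048, 768, 8, 128

def moe_ffn_flops(in_dim):
  router_flops = 2 * B * S * in_dim * num_experts
  expert_ffn = 2 * B * S * moe_mlp_dim * in_dim * (2 + 1)  # gate/up + down projections
  return router_flops + expert_ffn * top_k

print(moe_ffn_flops(2048) / moe_ffn_flops(768))  # ~2.7x more FLOPs without the compressed latent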
@@ -941,6 +953,24 @@ def calculate_tflops_training_per_device(config, log=True):
         / 10**12
     )
     attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
+  elif config.decoder_block == DecoderBlockType.QWEN3_CUSTOM_MOE:
+    # Attention output projects (num_query_heads * head_dim) -> attention_output_dim, not -> emb_dim.
+    qwen3_custom_proj_flops = (
+        2
+        * config.per_device_batch_size
+        * config.max_target_length
+        * config.attention_output_dim
+        * config.num_query_heads
+        * config.head_dim
+    )
+    # Each layer has a final up-projection: attention_output_dim -> emb_dim.
+    layer_up_proj_flops = (
+        2 * config.per_device_batch_size * config.max_target_length * config.attention_output_dim * config.emb_dim
+    )
+    per_layer_flops = qkv_flops + qwen3_custom_proj_flops + layer_up_proj_flops
+    total_weight_flops = total_ffn_flops_all_layers + per_layer_flops * config.num_decoder_layers + embedding_flops
+    learnable_weight_tflops = total_weight_flops * 3 / 10**12
+    attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
   elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
     gdn_weight_flops_per_layer, gdn_attn_flops_per_layer = calculate_gated_delta_net_flops_per_device(config)
     cycle_interval = config.inhomogeneous_layer_cycle_interval
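On the attention side, the output projection maps (num_query_heads * head_dim) to attention_output_dim instead of emb_dim, and each layer adds one attention_output_dim -> emb_dim up-projection; as in the other branches, weight FLOPs are tripled for forward plus backward and scaled to TFLOPs. A minimal standalone sketch of one layer's attention-weight matmuls, using assumed head counts and widths rather than the configured model:

# Assumed dimensions for illustration only; qkv_flops here follows the usual
# GQA projection count and stands in for the value computed earlier in
# calculate_tflops_training_per_device.
B, S = 4, 2048
emb_dim, attention_output_dim = 2048, 768
num_query_heads, num_kv_heads, head_dim = 32, 4, 128

qkv_flops = 2 * B * S * emb_dim * (num_query_heads + 2 * num_kv_heads) * head_dim
out_proj_flops = 2 * B * S * attention_output_dim * num_query_heads * head_dim
up_proj_flops = 2 * B * S * attention_output_dim * emb_dim
per_layer_flops = qkv_flops + out_proj_flops + up_proj_flops
print(per_layer_flops * 3 / 1e12)  # fwd + bwd TFLOPs for one layer's attention weights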
@@ -1386,18 +1416,14 @@ def setup_initial_state(
       out_shardings=state_mesh_shardings,
   )()
   sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m
-  if (
-      sparsity_enabled and raw_params
-  ):  # If we loaded a partial state, we need to merge it.
+  if sparsity_enabled and raw_params:  # If we loaded a partial state, we need to merge it.
 
     def _merge_params(p_raw, p_init):
       if isinstance(p_raw, jax.ShapeDtypeStruct):
         return p_init
       return p_raw
 
-    merged_params = jax.tree_util.tree_map(
-        _merge_params, raw_params, state.params
-    )
+    merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params)
     state = state.replace(params=merged_params)
   elif raw_params:
     state = state.replace(params=raw_params)
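The simplified merge is an ordinary jax.tree_util.tree_map over two pytrees with the same structure: leaves that are still abstract jax.ShapeDtypeStruct placeholders (i.e. not present in the loaded checkpoint) fall back to the freshly initialized value. A minimal standalone illustration with toy dict params rather than MaxText's TrainState:

import jax
import jax.numpy as jnp

# Toy pytrees standing in for raw_params / state.params.
raw_params = {
    "dense": jnp.ones((2, 2)),                            # loaded from the checkpoint
    "sparse": jax.ShapeDtypeStruct((2, 2), jnp.float32),  # placeholder, never loaded
}
init_params = {"dense": jnp.zeros((2, 2)), "sparse": jnp.zeros((2, 2))}

def _merge(p_raw, p_init):
  return p_init if isinstance(p_raw, jax.ShapeDtypeStruct) else p_raw

merged = jax.tree_util.tree_map(_merge, raw_params, init_params)
print(merged["dense"].sum(), merged["sparse"].sum())  # 4.0 0.0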

tests/unit/flop_calculation_test.py

Lines changed: 23 additions & 0 deletions
@@ -335,6 +335,29 @@ def test_gpt_oss_20b_flops(self):
     calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
     self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
 
+  def test_qwen3_custom_30b_a3b_flops(self):
+    """Test Qwen3 Custom 30B-A3B (compressed-latent MoE) FLOPs calculation.
+
+    The custom variant compresses attention output and MoE input to
+    attention_output_dim = moe_expert_input_dim = 768 (vs emb_dim = 2048),
+    then up-projects 768 -> 2048 once per layer. ~2.86B active parameters
+    per token (48 layers x (Q/K/V/O + gate + 8 routed experts + up_proj)
+    + unembedding).
+    """
+    cfg = self._initialize_model_config(
+        "qwen3-custom-30b-a3b",
+        max_target_length=2048,
+        per_device_batch_size=4,
+    )
+    kwargs = cfg.get_keys()
+    B = cfg.per_device_batch_size
+    S = cfg.max_target_length
+    attention_flops = self.compute_regular_attention_flops_per_device(kwargs)
+    golden_param_size = 2.86e9  # active params per token
+    golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops
+    calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
+    self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
+
   def test_deepseek32_671b_flops(self):
     """Test DeepSeek3.2-671b FLops calculation"""
     cfg = self._initialize_model_config(
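As a quick check of the golden value, plugging the test's own B = 4, S = 2048, and ~2.86e9 active parameters into the 6 * B * S * P rule gives roughly 140.6 learnable-weight TFLOPs per device, to which the helper's causal-attention term is added:

# Worked check of the golden value with the test's own numbers.
B, S, P = 4, 2048, 2.86e9
print(6 * B * S * P / 1e12)  # ~140.57; attention TFLOPs are added separately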
