Skip to content

Commit 62f0f17

Browse files
author
maxtext authors
committed
Merge pull request #1673 from AI-Hypercomputer:llama4_benchmark
PiperOrigin-RevId: 754175426
2 parents 18ccd37 + b923e59 commit 62f0f17

4 files changed

Lines changed: 6 additions & 7 deletions

File tree

MaxText/configs/models/llama4-17b-128e.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,12 @@ rope_max_timescale: 500000
3131
rope_type: "llama3.1"
3232
rope_use_scale: False
3333
num_experts: 128
34-
capacity_factor: -1.0 # TODO: this will be removed once we support dropless with megablox/ragged_dot
3534
shared_experts: 1
3635
num_experts_per_tok: 1
3736
use_qk_norm: False
3837
nope_layer_interval: 4 # Every fourth layer should NOT use RoPE
3938
interleave_moe_layer_step: 2 # Every 2nd layer is MoE layer, and 1st layer is dense layer
4039

41-
# TODO: delete the following variables once we add support for dropless with megablox/ragged_dot
42-
sparse_matmul: False
43-
megablox: False
44-
4540
temperature_tuning: True
4641
# Chunk attention is used on all RoPE layers
4742
# otherwise, on NoPE layers, use global attention

MaxText/configs/models/llama4-17b-16e.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ normalization_layer_epsilon: 1e-05
3131
rope_max_timescale: 500000
3232
rope_type: "llama3.1"
3333
num_experts: 16
34-
capacity_factor: -1.0 # TODO: this will be removed once we support dropless with megablox/ragged_dot
3534
shared_experts: 1
3635
num_experts_per_tok: 1
3736
use_qk_norm: True # Llama4 models apply an L2Norm to the Query and Keys after RoPE

MaxText/pyconfig.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ def validate_llama4_config(keys: dict):
262262
raise ValueError("Llama4 decoder has not been tested with capacity_factor >= 0 -- please set that value to -1 for now!")
263263
if keys["num_experts_per_tok"] > 1:
264264
raise ValueError("Only top-1 routing is supported for Llama4 for now!")
265+
if keys["scan_layers"]:
266+
raise ValueError("Only unscanned layer is supported for Llama4, and please set scan_layers=False for now!")
265267
if keys["base_num_decoder_layers"] % keys["interleave_moe_layer_step"] != 0:
266268
raise ValueError(
267269
f"The number of decoder layers ({keys['base_num_decoder_layers']}) must be divisible by interleave moe layer step ({keys['interleave_moe_layer_step']})"

MaxText/tests/train_compile_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,12 +567,15 @@ def test_moe_llama4_17b_16e(self):
567567
None,
568568
os.path.join(PKG_DIR, "configs", "base.yml"),
569569
f"compiled_trainstep_file={compiled_trainstep_file}",
570-
"compile_topology=v6e-256",
570+
"compile_topology=v5p-256",
571571
"compile_topology_num_slices=1",
572572
"model_name=llama4-17b-16e",
573573
"per_device_batch_size=1",
574574
"max_target_length=1024",
575575
"dtype=bfloat16",
576576
"weight_dtype=bfloat16",
577+
"scan_layers=False",
578+
"ici_fsdp_parallelism=32",
579+
"ici_tensor_parallelism=4",
577580
)
578581
)

0 commit comments

Comments
 (0)