Skip to content

Commit f3099c3

Browse files
committed
update test
1 parent 7f78228 commit f3099c3

2 files changed

Lines changed: 50 additions & 19 deletions

File tree

README.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,25 +117,25 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp
117117
* Gemma 2 (2B, 9B, 27B)
118118
* Gemma 1 (2B, 7B)
119119
* Alibaba
120-
* Qwen 2.5 (1.5B, 7B, 14B)
121-
* Qwen 3 MoE 2507 (235B, 480B)
122-
* Qwen 3 MoE (30B, 235B)
120+
* Qwen 3 Next (80B)
121+
* Qwen 3 MoE (30B, 235B), Qwen 3 MoE 2507 (235B, 480B)
123122
* Qwen 3 Dense (0.6B, 1.7B, 4B, 8B, 14B, 32B)
124-
* DeepSeek
123+
* Qwen 2.5 (1.5B, 7B, 14B)
124+
* DeepSeek AI
125125
* DeepSeek V3.2 (671B)
126126
* DeepSeek V3.1 (671B)
127-
* DeepSeek V3 0324 (671B) & DeepSeek R1 0528 (671B)
127+
* DeepSeek V3 0324 (671B), DeepSeek R1 0528 (671B)
128128
* DeepSeek V2 (16B, 236B)
129-
* Kimi
130-
* Kimi K2
129+
* Moonshot AI
130+
* Kimi K2 (1T)
131131
* Meta
132132
* Llama 4 Scout (109B) & Maverick (400B)
133-
* Llama 3.3 70B, 3.1 (8B, 70B, 405B), 3.0 (8B, 70B, 405B)
133+
* Llama 3.3 (70B), 3.1 (8B, 70B, 405B), 3.0 (8B, 70B, 405B)
134134
* Llama 2 (7B, 13B, 70B)
135-
* Open AI
135+
* OpenAI
136136
* GPT-OSS (20B, 120B)
137137
* GPT3 (52K, 6B, 22B, 175B)
138-
* Mistral
138+
* Mistral AI
139139
* Mixtral (8x7B, 8x22B)
140140
* Mistral (7B)
141141
* Diffusion Models

tests/unit/train_compile_test.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def test_gpt3_6b(self):
734734
"",
735735
get_test_config_path(),
736736
f"compiled_trainstep_file={compiled_trainstep_file}",
737-
"compile_topology=v5p-256",
737+
"compile_topology=v5p-8",
738738
"compile_topology_num_slices=1",
739739
"model_name=gpt3-6b",
740740
"per_device_batch_size=1",
@@ -766,7 +766,7 @@ def test_qwen3_next(self):
766766
"",
767767
get_test_config_path(),
768768
f"compiled_trainstep_file={compiled_trainstep_file}",
769-
"compile_topology=v5p-256",
769+
"compile_topology=v5p-64",
770770
"compile_topology_num_slices=1",
771771
"model_name=qwen3-next-80b-a3b",
772772
"per_device_batch_size=1",
@@ -796,9 +796,6 @@ def test_deepseek32(self):
796796
"use_tokamax_splash=True",
797797
"dtype=bfloat16",
798798
"weight_dtype=bfloat16",
799-
# without_device_limit
800-
"n_routing_groups=-1",
801-
"topk_routing_group=-1",
802799
)
803800
)
804801

@@ -948,9 +945,9 @@ def test_circular_pipeline_ag_per_repeat_ep_ds(self):
948945
)
949946

950947
@pytest.mark.cpu_only
951-
def test_qk_clip(self):
952-
"""AOT test for qk-clip with DeepSeek3 Tiny model"""
953-
compiled_trainstep_file = "/tmp/test_qk_clip.pickle"
948+
def test_qk_clip_with_dot_product(self):
949+
"""AOT test for AdamW optimizer with QK clip on dot product attention for DeepSeek3 Tiny model"""
950+
compiled_trainstep_file = "/tmp/test_qk_clip_with_dot_product.pickle"
954951
train_compile_main(
955952
(
956953
"",
@@ -963,13 +960,47 @@ def test_qk_clip(self):
963960
"sparse_matmul=True",
964961
"megablox=True",
965962
"use_tokamax_gmm=False",
966-
# TODO(agagik): update to flash after support
963+
"max_target_length=128",
964+
"per_device_batch_size=1",
965+
"dtype=bfloat16",
966+
"weight_dtype=float32",
967+
# dot product
967968
"attention=dot_product",
968969
"use_tokamax_splash=True",
970+
# qk
971+
"use_qk_clip=true",
972+
"qk_clip_threshold=100",
973+
)
974+
)
975+
976+
@pytest.mark.cpu_only
977+
def test_muon_clip_with_tokamax_splash(self):
978+
"""AOT test for Muon optimizer with QK clip on tokamax splash attention for DeepSeek3 Tiny model"""
979+
compiled_trainstep_file = "/tmp/test_muon_clip_with_tokamax_splash.pickle"
980+
train_compile_main(
981+
(
982+
"",
983+
get_test_config_path(),
984+
f"compiled_trainstep_file={compiled_trainstep_file}",
985+
"compile_topology=v5p-8",
986+
"compile_topology_num_slices=1",
987+
"model_name=deepseek3-tiny",
988+
"scan_layers=True",
989+
"sparse_matmul=True",
990+
"megablox=True",
991+
"use_tokamax_gmm=False",
969992
"max_target_length=128",
970993
"per_device_batch_size=1",
971994
"dtype=bfloat16",
972995
"weight_dtype=float32",
996+
# tokamax splash
997+
"attention=flash",
998+
"use_tokamax_splash=True",
999+
# muon
1000+
"opt_type=muon",
1001+
"muon_consistent_rms=0.2",
1002+
"muon_weight_decay=0.1",
1003+
# qk
9731004
"use_qk_clip=true",
9741005
"qk_clip_threshold=100",
9751006
)

0 commit comments

Comments (0)