@@ -626,8 +626,8 @@ def test_moe_deepseek_pipeline_subset(self):
626626 "" ,
627627 get_test_config_path (),
628628 f"compiled_trainstep_file={ compiled_trainstep_file } " ,
629- "compile_topology=v5p-64 " ,
630- "compile_topology_num_slices=8 " ,
629+ "compile_topology=v5p-8 " ,
630+ "compile_topology_num_slices=2 " ,
631631 "use_iota_embed=true" ,
632632 "model_name=deepseek3-test" ,
633633 "megablox=True" ,
@@ -636,8 +636,8 @@ def test_moe_deepseek_pipeline_subset(self):
636636 "per_device_batch_size=1" ,
637637 "max_target_length=1024" ,
638638 "pipeline_parallel_layers=56" ,
639- "ici_expert_parallelism=16 " ,
640- "dcn_pipeline_parallelism=8 " ,
639+ "ici_expert_parallelism=8 " ,
640+ "dcn_pipeline_parallelism=2 " ,
641641 )
642642 )
643643
@@ -669,22 +669,22 @@ def test_moe_llama4_17b_16e(self):
669669 "" ,
670670 get_test_config_path (),
671671 f"compiled_trainstep_file={ compiled_trainstep_file } " ,
672- "compile_topology=v5p-128 " ,
672+ "compile_topology=v5p-16 " ,
673673 "compile_topology_num_slices=1" ,
674674 "model_name=llama4-17b-16e" ,
675675 "per_device_batch_size=1" ,
676676 "max_target_length=1024" ,
677677 "dtype=bfloat16" ,
678678 "weight_dtype=bfloat16" ,
679679 "scan_layers=True" ,
680- "ici_fsdp_parallelism=16" ,
681- "ici_tensor_parallelism=4" ,
680+ "ici_fsdp_parallelism=4" ,
681+ "ici_tensor_parallelism=2" ,
682+ "ici_expert_parallelism=2" ,
682683 )
683684 )
684685
685- @pytest .mark .cpu_only
686- def test_moe_gpt_oss_20b_sparse_matmul (self ):
687- compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"
686+ def _run_moe_gpt_oss_20b (self , suffix , matmul_args ):
687+ compiled_trainstep_file = f"/tmp/test_moe_gpt_oss_20b_{ suffix } .pickle"
688688 train_compile_main (
689689 (
690690 "" ,
@@ -698,33 +698,18 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
698698 "dtype=bfloat16" ,
699699 "weight_dtype=bfloat16" ,
700700 "scan_layers=True" ,
701- "sparse_matmul=True" ,
702- "megablox=True" ,
703701 "attention=flash" ,
702+ * matmul_args ,
704703 )
705704 )
706705
706+ @pytest .mark .cpu_only
707+ def test_moe_gpt_oss_20b_sparse_matmul (self ):
708+ self ._run_moe_gpt_oss_20b ("sparse_matmul" , ["sparse_matmul=True" , "megablox=True" ])
709+
707710 @pytest .mark .cpu_only
708711 def test_moe_gpt_oss_20b_dense_matmul (self ):
709- compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_dense_matmul.pickle"
710- train_compile_main (
711- (
712- "" ,
713- get_test_config_path (),
714- f"compiled_trainstep_file={ compiled_trainstep_file } " ,
715- "compile_topology=v5p-16" ,
716- "compile_topology_num_slices=1" ,
717- "model_name=gpt-oss-20b" ,
718- "per_device_batch_size=1" ,
719- "max_target_length=1024" ,
720- "dtype=bfloat16" ,
721- "weight_dtype=bfloat16" ,
722- "scan_layers=True" ,
723- "sparse_matmul=False" ,
724- "capacity_factor=-1" ,
725- "attention=flash" ,
726- )
727- )
712+ self ._run_moe_gpt_oss_20b ("dense_matmul" , ["sparse_matmul=False" , "capacity_factor=-1" ])
728713
729714 @pytest .mark .cpu_only
730715 def test_gpt3_6b (self ):
@@ -766,7 +751,7 @@ def test_qwen3_next(self):
766751 "" ,
767752 get_test_config_path (),
768753 f"compiled_trainstep_file={ compiled_trainstep_file } " ,
769- "compile_topology=v5p-256 " ,
754+ "compile_topology=v5p-8 " ,
770755 "compile_topology_num_slices=1" ,
771756 "model_name=qwen3-next-80b-a3b" ,
772757 "per_device_batch_size=1" ,
0 commit comments