@@ -373,6 +373,8 @@ def test_moe_dropping_bf16(self):
373373 "use_iota_embed=true" ,
374374 "compile_topology_num_slices=1" ,
375375 "model_name=mixtral-8x7b" ,
376+ "override_model_config=true" ,
377+ "base_num_decoder_layers=8" ,
376378 "sparse_matmul=False" ,
377379 "capacity_factor=1" ,
378380 "per_device_batch_size=4" ,
@@ -442,6 +444,8 @@ def test_moe_megablox_ring_ep_random(self):
442444 "use_iota_embed=true" ,
443445 "compile_topology_num_slices=1" ,
444446 "model_name=deepseek3-test" ,
447+ "override_model_config=true" ,
448+ "base_num_decoder_layers=8" ,
445449 "sparse_matmul=True" ,
446450 "megablox=True" ,
447451 "per_device_batch_size=4" ,
@@ -466,6 +470,8 @@ def test_moe_ragged_dot_bf16(self):
466470 "use_iota_embed=true" ,
467471 "compile_topology_num_slices=1" ,
468472 "model_name=mixtral-8x7b" ,
473+ "override_model_config=true" ,
474+ "base_num_decoder_layers=8" ,
469475 "sparse_matmul=True" ,
470476 "megablox=False" ,
471477 "per_device_batch_size=4" ,
@@ -488,6 +494,8 @@ def test_moe_dense_bf16(self):
488494 "use_iota_embed=true" ,
489495 "compile_topology_num_slices=1" ,
490496 "model_name=mixtral-8x7b" ,
497+ "override_model_config=true" ,
498+ "base_num_decoder_layers=8" ,
491499 "sparse_matmul=False" ,
492500 "capacity_factor=-1" ,
493501 "per_device_batch_size=4" ,
@@ -606,6 +614,8 @@ def test_moe_deepseek_with_device_limit(self):
606614 "use_iota_embed=true" ,
607615 "compile_topology_num_slices=1" ,
608616 "model_name=deepseek3-test" ,
617+ "override_model_config=true" ,
618+ "base_num_decoder_layers=8" ,
609619 "sparse_matmul=True" ,
610620 "megablox=False" ,
611621 "per_device_batch_size=1" ,
@@ -626,8 +636,8 @@ def test_moe_deepseek_pipeline_subset(self):
626636 "" ,
627637 get_test_config_path (),
628638 f"compiled_trainstep_file={ compiled_trainstep_file } " ,
629- "compile_topology=v5p-64 " ,
630- "compile_topology_num_slices=8 " ,
639+ "compile_topology=v5p-8 " ,
640+ "compile_topology_num_slices=2 " ,
631641 "use_iota_embed=true" ,
632642 "model_name=deepseek3-test" ,
633643 "megablox=True" ,
@@ -636,8 +646,9 @@ def test_moe_deepseek_pipeline_subset(self):
636646 "per_device_batch_size=1" ,
637647 "max_target_length=1024" ,
638648 "pipeline_parallel_layers=56" ,
639- "ici_expert_parallelism=16" ,
640- "dcn_pipeline_parallelism=8" ,
649+ "ici_expert_parallelism=4" ,
650+ "ici_fsdp_parallelism=1" ,
651+ "dcn_pipeline_parallelism=2" ,
641652 )
642653 )
643654
@@ -669,22 +680,23 @@ def test_moe_llama4_17b_16e(self):
669680 "" ,
670681 get_test_config_path (),
671682 f"compiled_trainstep_file={ compiled_trainstep_file } " ,
672- "compile_topology=v5p-128 " ,
683+ "compile_topology=v5p-16 " ,
673684 "compile_topology_num_slices=1" ,
674685 "model_name=llama4-17b-16e" ,
686+ "override_model_config=true" ,
687+ "base_num_decoder_layers=4" ,
675688 "per_device_batch_size=1" ,
676689 "max_target_length=1024" ,
677690 "dtype=bfloat16" ,
678691 "weight_dtype=bfloat16" ,
679692 "scan_layers=True" ,
680- "ici_fsdp_parallelism=16 " ,
681- "ici_tensor_parallelism=4 " ,
693+ "ici_fsdp_parallelism=4 " ,
694+ "ici_tensor_parallelism=2 " ,
682695 )
683696 )
684697
685- @pytest .mark .cpu_only
686- def test_moe_gpt_oss_20b_sparse_matmul (self ):
687- compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"
698+ def _run_moe_gpt_oss_20b (self , suffix , matmul_args ):
699+ compiled_trainstep_file = f"/tmp/test_moe_gpt_oss_20b_{ suffix } .pickle"
688700 train_compile_main (
689701 (
690702 "" ,
@@ -693,38 +705,25 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
693705 "compile_topology=v5p-16" ,
694706 "compile_topology_num_slices=1" ,
695707 "model_name=gpt-oss-20b" ,
708+ "override_model_config=true" ,
709+ "base_num_decoder_layers=8" ,
696710 "per_device_batch_size=1" ,
697711 "max_target_length=1024" ,
698712 "dtype=bfloat16" ,
699713 "weight_dtype=bfloat16" ,
700714 "scan_layers=True" ,
701- "sparse_matmul=True" ,
702- "megablox=True" ,
703715 "attention=flash" ,
716+ * matmul_args ,
704717 )
705718 )
706719
720+ @pytest .mark .cpu_only
721+ def test_moe_gpt_oss_20b_sparse_matmul (self ):
722+ self ._run_moe_gpt_oss_20b ("sparse_matmul" , ["sparse_matmul=True" , "megablox=True" ])
723+
707724 @pytest .mark .cpu_only
708725 def test_moe_gpt_oss_20b_dense_matmul (self ):
709- compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_dense_matmul.pickle"
710- train_compile_main (
711- (
712- "" ,
713- get_test_config_path (),
714- f"compiled_trainstep_file={ compiled_trainstep_file } " ,
715- "compile_topology=v5p-16" ,
716- "compile_topology_num_slices=1" ,
717- "model_name=gpt-oss-20b" ,
718- "per_device_batch_size=1" ,
719- "max_target_length=1024" ,
720- "dtype=bfloat16" ,
721- "weight_dtype=bfloat16" ,
722- "scan_layers=True" ,
723- "sparse_matmul=False" ,
724- "capacity_factor=-1" ,
725- "attention=flash" ,
726- )
727- )
726+ self ._run_moe_gpt_oss_20b ("dense_matmul" , ["sparse_matmul=False" , "capacity_factor=-1" ])
728727
729728 @pytest .mark .cpu_only
730729 def test_gpt3_6b (self ):
@@ -867,6 +866,8 @@ def test_olmo3_7b(self):
867866 "compile_topology=v5p-8" ,
868867 "compile_topology_num_slices=1" ,
869868 "model_name=olmo3-7b" ,
869+ "override_model_config=true" ,
870+ "base_num_decoder_layers=8" ,
870871 "per_device_batch_size=1" ,
871872 "scan_layers=True" ,
872873 "max_target_length=1024" ,
0 commit comments