4747 AttentionAndDistributedPackingConfig ,
4848 FastMoeConfig ,
4949 FusedOpsAndKernelsConfig ,
50+ MCPConfig ,
5051 ODMConfig ,
5152 QuantizedLoraConfig ,
5253 get_additional_accel_framework_callbacks ,
@@ -87,6 +88,7 @@ def train(
8788 AttentionAndDistributedPackingConfig
8889 ] = None ,
8990 fast_moe_config : Optional [FastMoeConfig ] = None ,
91+ mcp_config : Optional [MCPConfig ] = None ,
9092 additional_data_handlers : Optional [Dict [str , DataHandler ]] = None ,
9193) -> tuple [SFTTrainer , dict ]:
9294 """Call the SFTTrainer
@@ -198,6 +200,8 @@ def train(
198200 )
199201 if fast_moe_config is not None and fast_moe_config .fast_moe is None :
200202 fast_moe_config = None
203+ if mcp_config is not None and mcp_config .cp is None :
204+ mcp_config = None
201205 if fast_moe_config is not None :
202206 # If LoRA with ScatterMoE detected, raise warning
203207 accepted_layers = ["all-linear" ]
@@ -261,6 +265,7 @@ def train(
261265 quantized_lora_config ,
262266 fusedops_kernels_config ,
263267 odm_config ,
268+ mcp_config ,
264269 ).get_framework ()
265270
266271 # option to set multimodal var here
@@ -567,6 +572,7 @@ def get_parser():
567572 FusedOpsAndKernelsConfig ,
568573 AttentionAndDistributedPackingConfig ,
569574 FastMoeConfig ,
575+ MCPConfig ,
570576 TrackerConfigs ,
571577 )
572578 )
@@ -648,6 +654,7 @@ def parse_arguments(parser, json_config=None):
648654 fusedops_kernels_config ,
649655 attention_and_distributed_packing_config ,
650656 fast_moe_config ,
657+ mcp_config ,
651658 tracker_configs ,
652659 ) = parser .parse_dict (json_config , allow_extra_keys = True )
653660 peft_method = json_config .get ("peft_method" )
@@ -667,6 +674,7 @@ def parse_arguments(parser, json_config=None):
667674 fusedops_kernels_config ,
668675 attention_and_distributed_packing_config ,
669676 fast_moe_config ,
677+ mcp_config ,
670678 tracker_configs ,
671679 additional ,
672680 _ ,
@@ -703,6 +711,7 @@ def parse_arguments(parser, json_config=None):
703711 fusedops_kernels_config ,
704712 attention_and_distributed_packing_config ,
705713 fast_moe_config ,
714+ mcp_config ,
706715 tracker_configs ,
707716 exp_metadata ,
708717 )
@@ -725,6 +734,7 @@ def main():
725734 fusedops_kernels_config ,
726735 attention_and_distributed_packing_config ,
727736 fast_moe_config ,
737+ mcp_config ,
728738 tracker_configs ,
729739 exp_metadata ,
730740 ) = parse_arguments (parser , job_config )
@@ -746,6 +756,7 @@ def main():
746756 "AADP (fms-acceleration) Config" : attention_and_distributed_packing_config ,
747757 "Fused Ops Kernels Config" : fusedops_kernels_config ,
748758 "Fast MoE Config" : fast_moe_config ,
759+ "MCP Config" : mcp_config ,
749760 "Tracker Config" : tracker_configs ,
750761 "Extra Metadata" : exp_metadata ,
751762 "Trainer Controller Config" : trainer_controller_args ,
@@ -789,6 +800,7 @@ def main():
789800 quantized_lora_config = quantized_lora_config ,
790801 fusedops_kernels_config = fusedops_kernels_config ,
791802 attention_and_distributed_packing_config = attention_and_distributed_packing_config ,
803+ mcp_config = mcp_config ,
792804 fast_moe_config = fast_moe_config ,
793805 )
794806 except (MemoryError , OutOfMemoryError ) as e :
0 commit comments