@@ -1447,6 +1447,44 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
14471447 )
14481448
14491449
@pytest.mark.skipif(
    not is_fms_accelerate_available(plugins="moe"),
    reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
@pytest.mark.parametrize(
    "dataset_path",
    [
        TWITTER_COMPLAINTS_DATA_JSONL,
    ],
)
def test_run_moe_lora_and_inference(dataset_path):
    """LoRA-tune a MoE model with the accelerated-moe plugin and verify that the
    converted HF checkpoint can be loaded and used for inference.

    Fix: ``ep_degree`` was passed as ``False`` — a bool, numerically 0, which is
    not a valid expert-parallel degree. Use 1 (single device, no expert
    parallelism), matching the single-EP variant of the fine-tuning test.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        data_args = copy.deepcopy(DATA_ARGS)
        data_args.training_data_path = dataset_path
        model_args = copy.deepcopy(MODEL_ARGS)
        model_args.model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base"
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        lora_args = copy.deepcopy(PEFT_LORA_ARGS)
        lora_args.r = 16
        # Router doesn't work with LoRA test inference, so target only the
        # attention projection modules.
        lora_args.target_modules = ["q_proj", "v_proj", "o_proj", "k_proj"]
        # ep_degree must be a positive int; 1 means no expert parallelism.
        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
        sft_trainer.train(
            model_args,
            data_args,
            train_args,
            lora_args,
            fast_moe_config=fast_moe_config,
        )
        # The checkpoint stores LoRA adapters only, so inference needs the
        # base model to merge/attach them onto.
        _test_run_inference(
            checkpoint_path=os.path.join(
                _get_checkpoint_path(tempdir), "hf_converted_checkpoint"
            ),
            base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base",
        )
14501488@pytest .mark .skipif (
14511489 not is_fms_accelerate_available (plugins = "moe" ),
14521490 reason = "Only runs if fms-accelerate is installed along with accelerated-moe plugin" ,
@@ -1485,9 +1523,9 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
14851523 _validate_training (tempdir )
14861524
14871525
1488- def _test_run_inference (checkpoint_path ):
1526+ def _test_run_inference (checkpoint_path , base_model_name_or_path = None ):
14891527 # Load the model
1490- loaded_model = TunedCausalLM .load (checkpoint_path )
1528+ loaded_model = TunedCausalLM .load (checkpoint_path , base_model_name_or_path )
14911529
14921530 # Run inference on the text
14931531 output_inference = loaded_model .run (
0 commit comments