Skip to content

Commit 42c420c

Browse files
committed
test: lora for scattermoe
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent a848a9b commit 42c420c

1 file changed

Lines changed: 40 additions & 2 deletions

File tree

tests/test_sft_trainer.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,44 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
14471447
)
14481448

14491449

@pytest.mark.skipif(
    not is_fms_accelerate_available(plugins="moe"),
    reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
@pytest.mark.parametrize(
    "dataset_path",
    [
        TWITTER_COMPLAINTS_DATA_JSONL,
    ],
)
def test_run_moe_lora_and_inference(dataset_path):
    """Check if we can LoRA-tune a MoE model and check if hf checkpoint is created.

    Trains a small granite MoE model with the accelerated-moe (scattermoe)
    plugin enabled plus LoRA adapters, then runs inference against the
    converted HF checkpoint produced under the output dir.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        data_args = copy.deepcopy(DATA_ARGS)
        data_args.training_data_path = dataset_path
        model_args = copy.deepcopy(MODEL_ARGS)
        model_args.model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base"
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        lora_args = copy.deepcopy(PEFT_LORA_ARGS)
        lora_args.r = 16
        # Router doesn't work with LoRA test inference, so target only the
        # attention projections.
        lora_args.target_modules = ["q_proj", "v_proj", "o_proj", "k_proj"]
        # NOTE(review): ep_degree=False looks like it should be an int expert-
        # parallel degree (e.g. 1); bool(False) == 0 here — confirm intent.
        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=False))
        sft_trainer.train(
            model_args,
            data_args,
            train_args,
            lora_args,
            fast_moe_config=fast_moe_config,
        )
        # Inference must load the LoRA adapter on top of the base model, hence
        # the explicit base_model_name_or_path.
        _test_run_inference(
            checkpoint_path=os.path.join(
                _get_checkpoint_path(tempdir), "hf_converted_checkpoint"
            ),
            base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base",
        )
1487+
14501488
@pytest.mark.skipif(
14511489
not is_fms_accelerate_available(plugins="moe"),
14521490
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
@@ -1485,9 +1523,9 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
14851523
_validate_training(tempdir)
14861524

14871525

1488-
def _test_run_inference(checkpoint_path):
1526+
def _test_run_inference(checkpoint_path, base_model_name_or_path=None):
14891527
# Load the model
1490-
loaded_model = TunedCausalLM.load(checkpoint_path)
1528+
loaded_model = TunedCausalLM.load(checkpoint_path, base_model_name_or_path)
14911529

14921530
# Run inference on the text
14931531
output_inference = loaded_model.run(

0 commit comments

Comments
 (0)