Skip to content

Commit 91db1a9

Browse files
committed
feat: add support for mamba cp
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 4cf9d7e commit 91db1a9

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

tests/test_sft_trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ def test_parse_arguments(job_config):
410410
_,
411411
_,
412412
_,
413+
_,
413414
) = sft_trainer.parse_arguments(parser, job_config_copy)
414415
assert str(model_args.torch_dtype) == "torch.bfloat16"
415416
assert data_args.dataset_text_field == "output"
@@ -436,6 +437,7 @@ def test_parse_arguments_defaults(job_config):
436437
_,
437438
_,
438439
_,
440+
_,
439441
) = sft_trainer.parse_arguments(parser, job_config_defaults)
440442
assert str(model_args.torch_dtype) == "torch.bfloat16"
441443
assert model_args.use_flash_attn is False
@@ -459,6 +461,7 @@ def test_parse_arguments_peft_method(job_config):
459461
_,
460462
_,
461463
_,
464+
_,
462465
) = sft_trainer.parse_arguments(parser, job_config_pt)
463466

464467
assert isinstance(tune_config, peft_config.PromptTuningConfig)
@@ -478,6 +481,7 @@ def test_parse_arguments_peft_method(job_config):
478481
_,
479482
_,
480483
_,
484+
_,
481485
) = sft_trainer.parse_arguments(parser, job_config_lora)
482486
assert isinstance(tune_config, peft_config.LoraConfig)
483487
assert not tune_config.target_modules

0 commit comments

Comments
 (0)