|
42 | 42 | from tests.artifacts.language_models import MAYKEYE_TINY_LLAMA_CACHED, TINYMIXTRAL_MOE |
43 | 43 | from tests.artifacts.predefined_data_configs import ( |
44 | 44 | CHAT_TEMPLATE_JINJA, |
| 45 | + DATA_CONFIG_CUSTOM_SPLIT_NAME, |
45 | 46 | DATA_CONFIG_DUPLICATE_COLUMNS, |
46 | 47 | DATA_CONFIG_INVALID_BASE64_CHAT_TEMPLATE, |
47 | 48 | DATA_CONFIG_MULTIPLE_DATASETS_ODM_YAML, |
|
61 | 62 | GRANITE_3_1_B_CHAT_TEMPLATE, |
62 | 63 | ) |
63 | 64 | from tests.artifacts.testdata import ( |
| 65 | + CHAT_DATA_HF_HOSTED_CUSTOM_SPLIT, |
64 | 66 | CHAT_DATA_MULTI_TURN, |
65 | 67 | CHAT_DATA_MULTI_TURN_CONVERSATIONS, |
66 | 68 | CHAT_DATA_MULTI_TURN_GRANITE_3_1B, |
@@ -949,16 +951,16 @@ def test_run_causallm_lora_add_special_tokens(): |
949 | 951 | ["lm_head"], |
950 | 952 | ["embed_tokens"], |
951 | 953 | marks=pytest.mark.skipif( |
952 | | - version.parse(peft.__version__) <= version.parse("0.18.0"), |
953 | | - reason="Not released in PEFT <= 0.18.0", |
| 954 | + version.parse(peft.__version__) <= version.parse("0.18.1"), |
| 955 | + reason="Not released in PEFT <= 0.18.1", |
954 | 956 | ), |
955 | 957 | ), |
956 | 958 | pytest.param( |
957 | 959 | ["embed_tokens", "lm_head"], |
958 | 960 | ["embed_tokens"], |
959 | 961 | marks=pytest.mark.skipif( |
960 | | - version.parse(peft.__version__) <= version.parse("0.18.0"), |
961 | | - reason="Not released in PEFT <= 0.18.0", |
| 962 | + version.parse(peft.__version__) <= version.parse("0.18.1"), |
| 963 | + reason="Not released in PEFT <= 0.18.1", |
962 | 964 | ), |
963 | 965 | ), |
964 | 966 | ], |
@@ -1010,8 +1012,8 @@ def test_run_causallm_lora_tied_weights_in_modules_to_save(modules_to_save, expe |
1010 | 1012 | ], |
1011 | 1013 | ) |
1012 | 1014 | @pytest.mark.skipif( |
1013 | | - version.parse(peft.__version__) <= version.parse("0.18.0"), |
1014 | | - reason="Not released in PEFT <= 0.18.0", |
| 1015 | + version.parse(peft.__version__) <= version.parse("0.18.1"), |
| 1016 | + reason="Not released in PEFT <= 0.18.1", |
1015 | 1017 | ) |
1016 | 1018 | def test_run_causallm_lora_tied_weights_in_target_modules(target_modules, expected): |
1017 | 1019 | """Check if a model with tied weights in target_modules is correctly trained""" |
@@ -1916,6 +1918,61 @@ def test_run_moe_ft_with_save_model_dir(dataset_path): |
1916 | 1918 | assert os.path.exists(os.path.join(save_model_dir)) |
1917 | 1919 |
|
1918 | 1920 |
|
@pytest.mark.parametrize(
    "datafiles, dataconfigfile",
    [
        (
            [CHAT_DATA_HF_HOSTED_CUSTOM_SPLIT, CHAT_DATA_HF_HOSTED_CUSTOM_SPLIT],
            DATA_CONFIG_CUSTOM_SPLIT_NAME,
        )
    ],
)
def test_run_chat_style_ft_using_custom_split_name(datafiles, dataconfigfile):
    """Check if we can select custom split for a dataset.

    The data config template is loaded, the ``data_paths`` of its i-th dataset
    entry is replaced by ``[datafiles[i]]``, and the result is written to a
    temporary YAML file that is passed to training via
    ``data_args.data_config_path``. After training, the checkpoint is loaded
    and a single inference call is made to confirm the model produces output
    containing the prompt text.

    Args:
        datafiles: one data path per dataset entry in the data config.
        dataconfigfile: path to the YAML data config used as a template.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        data_args = copy.deepcopy(DATA_ARGS)
        # Everything data related is driven by the data config, so the
        # single-dataset fields are cleared.
        data_args.training_data_path = None
        data_args.response_template = None
        data_args.dataset_text_field = None
        data_args.chat_template = CHAT_TEMPLATE_JINJA

        model_args = copy.deepcopy(MODEL_ARGS)
        model_args.model_name_or_path = TINYMIXTRAL_MOE
        model_args.tokenizer_name_or_path = TINYMIXTRAL_MOE

        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir

        # Point each dataset entry of the template at the parametrized file.
        with open(dataconfigfile, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        for i, d in enumerate(data["datasets"]):
            d["data_paths"] = [datafiles[i]]

        # delete=False keeps the file on disk after the handle is closed so
        # that training can reopen it by name; closing it here also guarantees
        # the YAML is flushed before it is read.
        with tempfile.NamedTemporaryFile(
            "w", delete=False, suffix=".yaml"
        ) as temp_yaml_file:
            yaml.dump(data, temp_yaml_file)
        data_args.data_config_path = temp_yaml_file.name

        try:
            sft_trainer.train(model_args, data_args, train_args)

            # validate the configs
            _validate_training(tempdir)
            checkpoint_path = _get_checkpoint_path(tempdir)

            # Load the model
            loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

            # Run inference on the text
            output_inference = loaded_model.run(
                '<|user|>\nProvide two rhyming words for the word "love"\n\
            <nopace></s><|assistant|>',
                max_new_tokens=50,
            )
            assert len(output_inference) > 0
            assert 'Provide two rhyming words for the word "love"' in output_inference
        finally:
            # The config file lives outside tempdir and was created with
            # delete=False, so it has to be removed explicitly or every run
            # leaves a stray .yaml behind.
            os.remove(temp_yaml_file.name)
| 1974 | + |
| 1975 | + |
1919 | 1976 | ############################# Helper functions ############################# |
1920 | 1977 | def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): |
1921 | 1978 | train_args = copy.deepcopy(training_args) |
|
0 commit comments