diff --git a/tests/artifacts/language_models/maykeye-tinyllama-v0/config.json b/tests/artifacts/language_models/maykeye-tinyllama-v0/config.json index 11735ecd10..b53c03fe40 100644 --- a/tests/artifacts/language_models/maykeye-tinyllama-v0/config.json +++ b/tests/artifacts/language_models/maykeye-tinyllama-v0/config.json @@ -14,7 +14,7 @@ "num_hidden_layers": 8, "pad_token_id": 0, "rms_norm_eps": 1e-06, - "tie_word_embeddings": false, + "tie_word_embeddings": true, "torch_dtype": "bfloat16", "transformers_version": "4.30.2", "use_cache": true, diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 9cf8191aa2..5981523175 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -27,8 +27,10 @@ # Third Party from datasets.exceptions import DatasetGenerationError, DatasetNotFoundError +from packaging import version from peft import LoraConfig as HFLoraConfig from transformers.trainer_callback import TrainerCallback +import peft import pytest import torch import transformers @@ -633,7 +635,9 @@ def test_run_causallm_lora_invalid_train_params(param_name, param_val, exc_msg): setattr(invalid_params, param_name, param_val) with pytest.raises(ValueError, match=exc_msg): - sft_trainer.train(MODEL_ARGS, DATA_ARGS, invalid_params, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, DATA_ARGS, invalid_params, copy.deepcopy(PEFT_LORA_ARGS) + ) @pytest.mark.parametrize( @@ -649,7 +653,9 @@ def test_run_causallm_lora_with_validation(dataset_path): data_args = copy.deepcopy(DATA_ARGS) data_args.validation_data_path = dataset_path - sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) _validate_training(tempdir, check_eval=True) @@ -670,7 +676,9 @@ def test_run_causallm_lora_with_validation_data_formatting(dataset_path): "### Text: {{element['Tweet text']}} \n\n### Label: {{text_label}}" ) - sft_trainer.train(MODEL_ARGS, data_args, train_args, 
PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) _validate_training(tempdir, check_eval=True) @@ -829,6 +837,154 @@ def test_successful_lora_target_modules_default_from_main(): "v_proj", }, "target_modules are not set to the default values." + os.environ.pop("SFT_TRAINER_CONFIG_JSON_ENV_VAR", None) + + +def test_run_causallm_lora_add_special_tokens(): + """Check if embed layer is added as modules_to_save when special tokens are added""" + with tempfile.TemporaryDirectory() as tempdir: + # write training output into the temporary directory + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + base_lora_args = copy.deepcopy(PEFT_LORA_ARGS) + base_lora_args.target_modules = ["q_proj"] + + # special tokens to add on top of the default data args + data_args = copy.deepcopy(DATA_ARGS) + data_args.add_special_tokens = [ + "<|test_token_1|>", + "<|test_token_2|>", + "<|test_token_3|>", + ] + + sft_trainer.train(MODEL_ARGS, data_args, train_args, base_lora_args) + + # validate lora tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "LORA") + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_path) + + assert adapter_config.get("modules_to_save") is not None + assert "embed_tokens" in adapter_config.get("modules_to_save") + + # Check if all special tokens passed are in tokenizer + for tok in data_args.add_special_tokens: + assert tok in tokenizer.vocab + + +@pytest.mark.parametrize( + "modules_to_save, expected", + [ + (None, []), + (["embed_tokens"], ["embed_tokens"]), + pytest.param( + ["lm_head"], + ["embed_tokens"], + marks=pytest.mark.skipif( + version.parse(peft.__version__) <= version.parse("0.18.0"), + reason="Not released in PEFT <= 0.18.0", + ), + ), + pytest.param( + ["embed_tokens", "lm_head"], + ["embed_tokens"], + marks=pytest.mark.skipif( + 
version.parse(peft.__version__) <= version.parse("0.18.0"), + reason="Not released in PEFT <= 0.18.0", + ), + ), + ], +) +def test_run_causallm_lora_tied_weights_in_modules_to_save(modules_to_save, expected): + """Check if a model with tied weights in modules to save is correctly trained""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + base_lora_args = copy.deepcopy(PEFT_LORA_ARGS) + base_lora_args.target_modules = ["q_proj"] + base_lora_args.modules_to_save = modules_to_save + + sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, base_lora_args) + + # validate lora tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "LORA") + + for module in expected: + assert module in adapter_config.get("modules_to_save") + + # Load the model and merge it + loaded_model = TunedCausalLM.load(checkpoint_path, MAYKEYE_TINY_LLAMA_CACHED) + merged_model = loaded_model.peft_model.merge_and_unload() + + # In all the cases Embedding and the LM layer should not have diverged + embed_layer = merged_model.get_input_embeddings() + lm_layer = merged_model.get_output_embeddings() + + assert torch.allclose(embed_layer.weight, lm_layer.weight) + assert embed_layer.weight.data_ptr() == lm_layer.weight.data_ptr() + + +@pytest.mark.parametrize( + "target_modules, expected", + [ + (["embed_tokens"], ["embed_tokens"]), + (["lm_head"], ["embed_tokens"]), + (["embed_tokens", "lm_head"], ["embed_tokens"]), + ], +) +@pytest.mark.skipif( + version.parse(peft.__version__) <= version.parse("0.18.0"), + reason="Not released in PEFT <= 0.18.0", +) +def test_run_causallm_lora_tied_weights_in_target_modules(target_modules, expected): + """Check if a model with tied weights in target_modules is correctly trained""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = 
copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + base_lora_args = copy.deepcopy(PEFT_LORA_ARGS) + base_lora_args.target_modules = target_modules + + sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, base_lora_args) + + # validate lora tuning configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "LORA") + + tm = adapter_config.get("target_modules") + for module in expected: + flag = False + + for t in tm: + if module in t: + flag = True + break + + assert flag, f"Expected {module} not found in target_modules config: {tm}" + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path, MAYKEYE_TINY_LLAMA_CACHED) + + # In all the cases Embedding and the LM layer should not have diverged + embed_layer = loaded_model.peft_model.get_input_embeddings() + lm_layer = loaded_model.peft_model.get_output_embeddings() + d_embed = embed_layer.get_delta_weight("default") + d_lm = lm_layer.get_delta_weight("default") + + assert embed_layer.weight.data_ptr() == lm_layer.weight.data_ptr() + assert torch.allclose( + d_embed, d_lm, atol=1e-6 + ), f"Max diff between deltas: {(d_embed - d_lm).abs().max()}" + ############################# Finetuning Tests ############################# @pytest.mark.parametrize( @@ -1816,7 +1972,9 @@ def test_tokenizer_has_no_eos_token(): # If we handled this badly, we would probably get something like a # TypeError: can only concatenate str (not "NoneType") to str error # when we go to apply the data formatter. 
- sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + model_args, DATA_ARGS, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) _validate_training(tempdir) @@ -1828,7 +1986,9 @@ def test_invalid_dataset_text_field(): data_args.dataset_text_field = "not found" with pytest.raises(KeyError): - sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS) + ) ### Tests that giving dataset_text_field as well as formatter template gives error @@ -1840,7 +2000,9 @@ def test_invalid_dataset_text_field_and_formatter_template(): ) with pytest.raises(ValueError): - sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS) + ) ### Tests passing formatter with invalid keys gives error @@ -1852,7 +2014,9 @@ def test_invalid_formatter_template(): ) with pytest.raises(KeyError): - sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS) + ) ### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing) @@ -1862,7 +2026,9 @@ def test_malformatted_data(): data_args.training_data_path = MALFORMATTED_DATA with pytest.raises((DatasetGenerationError, ValueError)): - sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS) + ) def test_empty_data(): @@ -1871,7 +2037,9 @@ def test_empty_data(): data_args.training_data_path = EMPTY_DATA with pytest.raises((DatasetGenerationError, ValueError)): - sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS) + ) ### Tests for bad tuning module configurations @@ -1900,7 +2068,9 @@ def 
test_no_packing_needs_dataset_text_field_or_data_formatter_template(): data_args.data_formatter_template = None with pytest.raises(ValueError): - sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) # TODO: Fix this case @@ -1914,7 +2084,9 @@ def test_no_packing_needs_reponse_template(): data_args.response_template = None with pytest.raises(ValueError): - sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) ### Tests for model dtype edge cases @@ -1931,7 +2103,9 @@ def test_bf16_still_tunes_if_unsupported(): model_args = copy.deepcopy(MODEL_ARGS) model_args.torch_dtype = "bfloat16" - sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + model_args, DATA_ARGS, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) _validate_training(tempdir) @@ -1944,7 +2118,9 @@ def test_bad_torch_dtype(): model_args.torch_dtype = "not a type" with pytest.raises(ValueError): - sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + model_args, DATA_ARGS, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) def test_run_with_additional_callbacks(): @@ -1958,7 +2134,7 @@ def test_run_with_additional_callbacks(): MODEL_ARGS, DATA_ARGS, train_args, - PEFT_LORA_ARGS, + copy.deepcopy(PEFT_LORA_ARGS), additional_callbacks=[TrainerCallback()], ) @@ -1977,7 +2153,7 @@ def test_run_with_bad_additional_callbacks(): MODEL_ARGS, DATA_ARGS, train_args, - PEFT_LORA_ARGS, + copy.deepcopy(PEFT_LORA_ARGS), additional_callbacks=["NotSupposedToBeHere"], ) @@ -1998,7 +2174,7 @@ def test_run_with_bad_experimental_metadata(): MODEL_ARGS, DATA_ARGS, train_args, - PEFT_LORA_ARGS, + copy.deepcopy(PEFT_LORA_ARGS), additional_callbacks=[TrainerCallback()], exp_metadata=metadata, ) @@ -2017,7 +2193,7 @@ def 
test_run_with_good_experimental_metadata(): MODEL_ARGS, DATA_ARGS, train_args, - PEFT_LORA_ARGS, + copy.deepcopy(PEFT_LORA_ARGS), additional_callbacks=[TrainerCallback()], exp_metadata=metadata, ) @@ -2040,7 +2216,9 @@ def test_pretokenized_dataset(dataset_path): data_args.dataset_text_field = None data_args.response_template = None data_args.training_data_path = dataset_path - sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) _validate_training(tempdir) @@ -2064,7 +2242,9 @@ def test_pretokenized_dataset_bad_args(dataset_text_field, response_template): # We should raise an error since we should not have a dataset text # field or a response template if we have pretokenized data with pytest.raises(ValueError): - sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) def test_pretokenized_dataset_wrong_format(): @@ -2082,7 +2262,9 @@ def test_pretokenized_dataset_wrong_format(): # need to directly add validation prior to the dataset generation since datasets # is essentially swallowing a KeyError here. 
with pytest.raises(ValueError): - sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS) + sft_trainer.train( + MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS) + ) ########################################################################### @@ -2115,7 +2297,7 @@ def test_run_with_bad_additional_data_handlers(additional_handlers): MODEL_ARGS, DATA_ARGS, train_args, - PEFT_LORA_ARGS, + copy.deepcopy(PEFT_LORA_ARGS), additional_data_handlers=additional_handlers, ) @@ -2130,7 +2312,7 @@ def test_run_with_additional_data_handlers_as_none(): MODEL_ARGS, DATA_ARGS, train_args, - PEFT_LORA_ARGS, + copy.deepcopy(PEFT_LORA_ARGS), additional_data_handlers=None, ) _validate_training(tempdir) @@ -2177,7 +2359,7 @@ def test_handler(element, **kwargs): MODEL_ARGS, DATA_ARGS, train_args, - PEFT_LORA_ARGS, + copy.deepcopy(PEFT_LORA_ARGS), additional_data_handlers={ TEST_HANDLER: DataHandler( op=test_handler, diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index fd759229c7..dc6a74174b 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -379,6 +379,33 @@ def train( added_tokens_dict = setup_tokenizer(tokenizer, data_args, model_args, model) + # If additional tokens are added, and we are doing LoRA + # we need to set the embedding layer as trainable + # and ensure that the weights are tied + if added_tokens_dict and isinstance(peft_config, LoraConfig): + if added_tokens_dict.get("num_new_tokens", 0) > 0: + modules_to_save = getattr(peft_config, "modules_to_save", []) or [] + target_modules = getattr(peft_config, "target_modules", []) or [] + + # If the initial model's weights are not tied, + # then we need to add both the embedding layer and the output layer + # If embedding layer or lm head is already targeted via `target_modules` + # then we skip adding it to `modules_to_save` since it is already adapted + # for changes + if not any(m in target_modules for m in ("embed_tokens", "lm_head")): + # TODO: @romit Enable adding both 
embed tokens and lm head to modules to save + # modules_to_save.extend(["embed_tokens", "lm_head"]) + modules_to_save.extend(["embed_tokens"]) + setattr(peft_config, "modules_to_save", modules_to_save) + + # This is safe to do for both tied and non-tied models + # `ensure_weight_tying` will be ignored if weights are not tied + # https://github.com/huggingface/peft/blob/v0.18.0.rc0/src/peft/tuners/tuners_utils.py#L1230 + setattr(peft_config, "ensure_weight_tying", True) + logger.info( + "Adding embed_tokens and lm_head as trainable modules due to vocab expansion" + ) + # Configure the collator and validate args related to packing prior to formatting the dataset data_collator = None logger.info("Packing is set to %s ", train_args.packing) diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index 806e593dcb..3e8f5bc18a 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -97,6 +97,21 @@ def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path): if hasattr(tuning_config, "alora_invocation_string"): delattr(tuning_config, "alora_invocation_string") + + # Make sure that weight tying is not broken in case + # the embedding layer is added as trainable under LoRA + if any( + m in (getattr(tuning_config, "modules_to_save", []) or []) + for m in ("embed_tokens", "lm_head") + ): + setattr(tuning_config, "ensure_weight_tying", True) + + if any( + m in (getattr(tuning_config, "target_modules", []) or []) + for m in ("embed_tokens", "lm_head") + ): + setattr(tuning_config, "ensure_weight_tying", True) + return tuning_config if isinstance(tuning_config, peft_config.PromptTuningConfig):