
Commit 5bb5489

build: Upgrade TRL version from 0.14 to 0.16 (#527)
* Upgrade TRL version
* Add attention mask in dataset
* Version increase to 0.16.1
* Prompt tuning arg assign num_virtual_tokens=0
* undocumented Prompt tuning and commented its unit tests
* Remove DataCollatorForSeq2Seq for tokenized dataset with packing
* Skipped PT tests
* PR Changes
* Fix lint
* Remove enable_reduce_loss_sum and _is_peft_model check

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
1 parent bb92699 commit 5bb5489

16 files changed: 2267 additions & 208 deletions

README.md

Lines changed: 0 additions & 51 deletions
```diff
@@ -10,7 +10,6 @@
 - [Tuning Techniques](#tuning-techniques)
 - [LoRA Tuning Example](#lora-tuning-example)
 - [GPTQ-LoRA with AutoGPTQ Tuning Example](#gptq-lora-with-autogptq-tuning-example)
-- [Prompt Tuning](#prompt-tuning)
 - [Fine Tuning](#fine-tuning)
 - [FMS Acceleration](#fms-acceleration)
 - [Extended Pre-Training](#extended-pre-training)
```
````diff
@@ -754,54 +753,6 @@ Note that with LoRA tuning technique, setting `all-linear` on `target_modules` r
 
 _________________________
 
-### Prompt Tuning:
-
-Specify `peft_method` to `'pt'`. You can additionally pass any arguments from [PromptTuningConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L63).
-```py
-# prompt_tuning_init can be either "TEXT" or "RANDOM"
-prompt_tuning_init: str = "TEXT"
-num_virtual_tokens: int = 8
-# prompt_tuning_init_text only applicable if prompt_tuning_init= "TEXT"
-prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:"
-tokenizer_name_or_path: str = "llama-7b-hf"
-```
-
-Example command you can run:
-
-```bash
-python tuning/sft_trainer.py \
---model_name_or_path $MODEL_PATH \
---training_data_path $TRAIN_DATA_PATH \
---output_dir $OUTPUT_PATH \
---num_train_epochs 5 \
---per_device_train_batch_size 1 \
---learning_rate 0.03 \
---response_template "\n### Label:" \
---dataset_text_field "output" \
---peft_method pt \
---tokenizer_name_or_path $MODEL_PATH \ # This field is optional and if not specified, tokenizer from model_name_or_path will be used
---prompt_tuning_init "RANDOM" \
---prompt_tuning_init_text "From the following input, identify target sentiment of following types: neutral, negative, positive"
-```
-
-Equally you can pass in a JSON configuration for running tuning. See [build doc](./build/README.md) for more details. The above can also be passed in as JSON:
-```json
-{
-    "model_name_or_path": $MODEL_PATH,
-    "training_data_path": $TRAIN_DATA_PATH,
-    "output_dir": $OUTPUT_PATH,
-    "num_train_epochs": 5.0,
-    "per_device_train_batch_size": 1,
-    "learning_rate": 0.03,
-    "response_template": "\n### Label:",
-    "dataset_text_field": "output",
-    "peft_method": "pt",
-    "tokenizer_name_or_path": $MODEL_PATH,
-    "prompt_tuning_init": "RANDOM",
-    "prompt_tuning_init_text": "From the following input, identify target sentiment of following types: neutral, negative, positive"
-}
-```
-
 ### Fine Tuning:
 
 Set `peft_method` to `'None'` or do not provide `peft_method` flag.
````
```diff
@@ -1041,6 +992,4 @@ Further details on enabling and using the trackers mentioned above can be found
 
 ## More Examples
 
-[Prompt Tuning on Twitter Complaints](examples/prompt_tuning_twitter_complaints/README.md)
-
 A good simple example can be found [here](examples/kfto-kueue-sft-trainer.yaml) which launches a Kubernetes-native `PyTorchJob` using the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator/) with [Kueue](https://github.com/kubernetes-sigs/kueue) for the queue management of tuning jobs.
```
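As an aside, the removed JSON example interpolated shell-style `$MODEL_PATH` placeholders, which is not valid JSON on its own. A minimal sketch of generating a well-formed config file instead; the paths here are hypothetical placeholders, not values from this repository:

```python
import json

# Hypothetical placeholder values; substitute real paths before use.
config = {
    "model_name_or_path": "/path/to/model",
    "training_data_path": "/path/to/train.jsonl",
    "output_dir": "/path/to/output",
    "num_train_epochs": 5.0,
    "per_device_train_batch_size": 1,
    "learning_rate": 0.03,
    "response_template": "\n### Label:",
    "dataset_text_field": "output",
    "peft_method": "pt",
    "prompt_tuning_init": "RANDOM",
}

# Serializing through json.dumps guarantees the file is valid JSON,
# unlike hand-edited text with unquoted $VARIABLE placeholders.
serialized = json.dumps(config, indent=4)
round_trip = json.loads(serialized)
print(round_trip["peft_method"])  # pt
```

Generating the file this way avoids the class of errors where a forgotten placeholder substitution leaves the config unparseable.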

examples/prompt_tuning_twitter_complaints/README.md

Lines changed: 0 additions & 65 deletions
This file was deleted.

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -34,7 +34,7 @@ dependencies = [
     "sentencepiece>=0.1.99,<0.3",
     "tokenizers>=0.13.3,<1.0",
     "tqdm>=4.66.2,<5.0",
-    "trl>=0.13,<0.15",
+    "trl>=0.13,<0.17",
     "peft>=0.8.0,<0.14",
     "protobuf>=5.28.0,<6.0.0",
     "datasets>=2.15.0,<4.0",
```
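The widened upper bound admits TRL 0.15 and 0.16 while still excluding 0.17. A minimal sketch of the minor-version check implied by the new specifier (simplified to major.minor comparison; real dependency resolution follows full PEP 440 semantics):

```python
def satisfies_trl_constraint(ver: str) -> bool:
    """Simplified check for 'trl>=0.13,<0.17' at major.minor granularity."""
    major, minor = (int(part) for part in ver.split(".")[:2])
    return (0, 13) <= (major, minor) < (0, 17)

print(satisfies_trl_constraint("0.14.0"))  # True: allowed before and after this change
print(satisfies_trl_constraint("0.16.1"))  # True: the version this commit targets
print(satisfies_trl_constraint("0.17.0"))  # False: still above the new upper bound
```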

tests/artifacts/predefined_data_configs/duplicate_columns.yaml

Lines changed: 0 additions & 1 deletion
```diff
@@ -7,7 +7,6 @@ datasets:
     data_handlers:
       - name: duplicate_columns
         arguments:
-          remove_columns: all
           batched: false
           fn_kwargs:
             old_column: "input_ids"
```

tests/artifacts/predefined_data_configs/skip_large_text_data_handler_template.yaml

Lines changed: 0 additions & 1 deletion
```diff
@@ -13,7 +13,6 @@ datasets:
           dataset_text_field: "output"
       - name: duplicate_columns
         arguments:
-          remove_columns: all
           batched: true
           fn_kwargs:
             old_column: "input_ids"
```

tests/artifacts/predefined_data_configs/tokenize_using_handler_and_train.yaml

Lines changed: 0 additions & 1 deletion
```diff
@@ -15,7 +15,6 @@ datasets:
           max_length: 1024
       - name: duplicate_columns
         arguments:
-          remove_columns: all
           batched: true
           fn_kwargs:
             old_column: "input_ids"
```
Binary file not shown.
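The `duplicate_columns` handler configured in the YAML files above copies one dataset column to a new name. Conceptually it behaves like the following stand-in (a hypothetical sketch, not the repository's implementation; the real handler also honors `batched` and, before this commit, `remove_columns`, and `new_column` here is an assumed parameter name):

```python
def duplicate_columns(example: dict, old_column: str, new_column: str) -> dict:
    """Copy old_column's value into new_column, leaving other keys intact."""
    out = dict(example)  # shallow copy so the input row is not mutated
    out[new_column] = out[old_column]
    return out

row = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}
dup = duplicate_columns(row, old_column="input_ids", new_column="labels")
print(sorted(dup))  # ['attention_mask', 'input_ids', 'labels']
```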
