Merged
51 changes: 0 additions & 51 deletions README.md
@@ -10,7 +10,6 @@
- [Tuning Techniques](#tuning-techniques)
- [LoRA Tuning Example](#lora-tuning-example)
- [GPTQ-LoRA with AutoGPTQ Tuning Example](#gptq-lora-with-autogptq-tuning-example)
- [Prompt Tuning](#prompt-tuning)
- [Fine Tuning](#fine-tuning)
- [FMS Acceleration](#fms-acceleration)
- [Extended Pre-Training](#extended-pre-training)
@@ -754,54 +753,6 @@ Note that with LoRA tuning technique, setting `all-linear` on `target_modules` r

_________________________

### Prompt Tuning:

Specify `peft_method` to `'pt'` . You can additionally pass any arguments from [PromptTuningConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L63).
```py
# prompt_tuning_init can be either "TEXT" or "RANDOM"
prompt_tuning_init: str = "TEXT"
num_virtual_tokens: int = 8
# prompt_tuning_init_text only applicable if prompt_tuning_init= "TEXT"
prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:"
tokenizer_name_or_path: str = "llama-7b-hf"
```

Example command you can run:

```bash
python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--training_data_path $TRAIN_DATA_PATH \
--output_dir $OUTPUT_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 1 \
--learning_rate 0.03 \
--response_template "\n### Label:" \
--dataset_text_field "output" \
--peft_method pt \
--tokenizer_name_or_path $MODEL_PATH \ # This field is optional and if not specified, tokenizer from model_name_or_path will be used
--prompt_tuning_init "RANDOM" \
--prompt_tuning_init_text "From the following input, identify target sentiment of following types: neutral, negative, positive"
```

Equally you can pass in a JSON configuration for running tuning. See [build doc](./build/README.md) for more details. The above can also be passed in as JSON:
```json
{
"model_name_or_path": $MODEL_PATH,
"training_data_path": $TRAIN_DATA_PATH,
"output_dir": $OUTPUT_PATH,
"num_train_epochs": 5.0,
"per_device_train_batch_size": 1,
"learning_rate": 0.03,
"response_template": "\n### Label:",
"dataset_text_field": "output",
"peft_method": "pt",
"tokenizer_name_or_path": $MODEL_PATH,
"prompt_tuning_init": "RANDOM",
"prompt_tuning_init_text": "From the following input, identify target sentiment of following types: neutral, negative, positive"
}
```
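
The JSON above uses shell-style placeholders (`$MODEL_PATH` and so on), which JSON itself does not expand. As a minimal sketch, assuming the paths are available as environment variables of the same names, a concrete config could be generated as follows. The helper name `build_prompt_tuning_config` and its fallback behaviour are hypothetical and not part of this repository:

```python
import json
import os


def build_prompt_tuning_config(env=None):
    """Return the prompt tuning arguments shown above as a plain dict.

    Hypothetical helper: path values are read from `env` (defaults to
    os.environ); a missing variable falls back to its placeholder name.
    """
    env = os.environ if env is None else env
    model_path = env.get("MODEL_PATH", "$MODEL_PATH")
    return {
        "model_name_or_path": model_path,
        "training_data_path": env.get("TRAIN_DATA_PATH", "$TRAIN_DATA_PATH"),
        "output_dir": env.get("OUTPUT_PATH", "$OUTPUT_PATH"),
        "num_train_epochs": 5.0,
        "per_device_train_batch_size": 1,
        "learning_rate": 0.03,
        "response_template": "\n### Label:",
        "dataset_text_field": "output",
        "peft_method": "pt",
        # Same value as model_name_or_path, mirroring the example above.
        "tokenizer_name_or_path": model_path,
        "prompt_tuning_init": "RANDOM",
        "prompt_tuning_init_text": (
            "From the following input, identify target sentiment of "
            "following types: neutral, negative, positive"
        ),
    }


if __name__ == "__main__":
    # Print valid JSON that could be saved to a file and passed to the trainer.
    print(json.dumps(build_prompt_tuning_config(), indent=2))
```

Serializing through `json.dumps` guarantees that string values are quoted and that the newline in `response_template` is escaped correctly.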

### Fine Tuning:

Set `peft_method` to `'None'` or do not provide `peft_method` flag.
@@ -1041,6 +992,4 @@ Further details on enabling and using the trackers mentioned above can be found

## More Examples

[Prompt Tuning on Twitter Complaints](examples/prompt_tuning_twitter_complaints/README.md)

A good simple example can be found [here](examples/kfto-kueue-sft-trainer.yaml) which launches a Kubernetes-native `PyTorchJob` using the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator/) with [Kueue](https://github.com/kubernetes-sigs/kueue) for the queue management of tuning jobs.
65 changes: 0 additions & 65 deletions examples/prompt_tuning_twitter_complaints/README.md

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
"sentencepiece>=0.1.99,<0.3",
"tokenizers>=0.13.3,<1.0",
"tqdm>=4.66.2,<5.0",
-"trl>=0.13,<0.15",
+"trl>=0.13,<0.17",
"peft>=0.8.0,<0.14",
"protobuf>=5.28.0,<6.0.0",
"datasets>=2.15.0,<4.0",
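
The `trl` change in this hunk raises the upper bound from `<0.15` to `<0.17`. A rough sketch of what that admits, using plain tuple comparison on dotted release strings, is shown below. The helper names and the candidate version strings are illustrative assumptions, and pre-release or post-release suffixes are deliberately not handled (a real resolver follows the full version specifier rules):

```python
def parse_release(text, width=3):
    """Parse a plain dotted release such as "0.16" into a zero-padded int tuple."""
    parts = [int(p) for p in text.split(".")]
    return tuple(parts + [0] * (width - len(parts)))


def in_range(version, lower, upper):
    """True when lower <= version < upper, comparing release tuples."""
    return parse_release(lower) <= parse_release(version) < parse_release(upper)


# Illustrative version strings only; not a list of actual trl releases.
candidates = ["0.12.2", "0.13.0", "0.15.2", "0.16.1", "0.17.0"]
old_ok = [v for v in candidates if in_range(v, "0.13", "0.15")]
new_ok = [v for v in candidates if in_range(v, "0.13", "0.17")]
print("old range:", old_ok)
print("new range:", new_ok)
```

Under these assumptions, the 0.15.x and 0.16.x strings pass only the new range, while 0.17.0 is still excluded by the exclusive upper bound.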
@@ -7,7 +7,6 @@ datasets:
data_handlers:
- name: duplicate_columns
arguments:
remove_columns: all
batched: false
fn_kwargs:
old_column: "input_ids"
@@ -13,7 +13,6 @@ datasets:
dataset_text_field: "output"
- name: duplicate_columns
arguments:
remove_columns: all
batched: true
fn_kwargs:
old_column: "input_ids"
@@ -15,7 +15,6 @@ datasets:
max_length: 1024
- name: duplicate_columns
arguments:
remove_columns: all
batched: true
fn_kwargs:
old_column: "input_ids"
Binary file not shown.