
Commit b3090a2

Update TRL Single Node example to uv (#2715)
1 parent de0bf5f commit b3090a2

2 files changed

Lines changed: 21 additions & 21 deletions

File tree

examples/single-node-training/trl/README.md
examples/single-node-training/trl/train.dstack.yml

Lines changed: 11 additions & 10 deletions
@@ -21,19 +21,19 @@ env:
   - WANDB_API_KEY
   - HUB_MODEL_ID
 commands:
-  - pip install "transformers>=4.43.2"
-  - pip install bitsandbytes
-  - pip install flash-attn --no-build-isolation
-  - pip install peft
-  - pip install wandb
+  # Pin torch==2.6.0 to avoid building Flash Attention from source.
+  # Prebuilt Flash Attention wheels are not available for the latest torch==2.7.0.
+  - uv pip install torch==2.6.0
+  - uv pip install transformers bitsandbytes peft wandb
+  - uv pip install flash_attn --no-build-isolation
   - git clone https://github.com/huggingface/trl
   - cd trl
-  - pip install .
-  - |
+  - uv pip install .
+  - |
     accelerate launch \
       --config_file=examples/accelerate_configs/multi_gpu.yaml \
       --num_processes $DSTACK_GPUS_PER_NODE \
-      examples/scripts/sft.py \
+      trl/scripts/sft.py \
       --model_name meta-llama/Meta-Llama-3.1-8B \
       --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
       --dataset_text_field="text" \
@@ -44,14 +44,15 @@ commands:
       --report_to wandb \
       --bf16 \
       --max_seq_length 1024 \
-      --lora_r 16 --lora_alpha 32 \
+      --lora_r 16 \
+      --lora_alpha 32 \
       --lora_target_modules q_proj k_proj v_proj o_proj \
       --load_in_4bit \
       --use_peft \
       --attn_implementation "flash_attention_2" \
       --logging_steps=10 \
       --output_dir models/llama31 \
-      --hub_model_id $HUB_MODEL_ID
+      --hub_model_id peterschmidt85/FineLlama-3.1-8B
 
 resources:
   gpu:
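As an aside, a quick way to confirm on the instance that the torch pin took effect and flash-attn came from a prebuilt wheel rather than a source build (a hypothetical sanity check, not part of the committed example):

    # Hypothetical check: inspect the installed torch and flash-attn packages
    uv pip show torch
    uv pip show flash-attn
    python -c "import torch, flash_attn; print(torch.__version__, flash_attn.__version__)"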

examples/single-node-training/trl/train.dstack.yml

Lines changed: 10 additions & 11 deletions
@@ -12,20 +12,19 @@ env:
   - ACCELERATE_LOG_LEVEL=info
 # Commands of the task
 commands:
-  - conda install cuda
-  - pip install git+https://github.com/huggingface/transformers.git
-  - pip install bitsandbytes
-  - pip install flash-attn --no-build-isolation
-  - pip install peft
-  - pip install wandb
+  # Pin torch==2.6.0 to avoid building Flash Attention from source.
+  # Prebuilt Flash Attention wheels are not available for the latest torch==2.7.0.
+  - uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
+  - uv pip install transformers bitsandbytes peft wandb
+  - uv pip install flash_attn --no-build-isolation
   - git clone https://github.com/huggingface/trl
   - cd trl
-  - pip install .
+  - uv pip install .
   - |
     accelerate launch \
       --config_file=examples/accelerate_configs/multi_gpu.yaml \
       --num_processes $DSTACK_GPUS_PER_NODE \
-      examples/scripts/sft.py \
+      trl/scripts/sft.py \
       --model_name meta-llama/Meta-Llama-3.1-8B \
       --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
       --dataset_text_field="text" \
@@ -36,15 +35,15 @@ commands:
       --report_to wandb \
       --bf16 \
       --max_seq_length 1024 \
-      --lora_r 16 --lora_alpha 32 \
+      --lora_r 16 \
+      --lora_alpha 32 \
       --lora_target_modules q_proj k_proj v_proj o_proj \
       --load_in_4bit \
       --use_peft \
       --attn_implementation "flash_attention_2" \
       --logging_steps=10 \
       --output_dir models/llama31 \
-      --hub_model_id $HUB_MODEL_ID
-
+      --hub_model_id peterschmidt85/FineLlama-3.1-8B
 resources:
   gpu:
     # 24GB or more VRAM
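To try the updated configuration, the standard dstack workflow should apply, along these lines (a sketch; assumes dstack is installed and configured against a project):

    # Submit the training task defined by the updated configuration
    dstack apply -f examples/single-node-training/trl/train.dstack.yml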
