examples/single-node-training/trl

   - WANDB_API_KEY
   - HUB_MODEL_ID
 commands:
-  - pip install "transformers>=4.43.2"
-  - pip install bitsandbytes
-  - pip install flash-attn --no-build-isolation
-  - pip install peft
-  - pip install wandb
+  # Pin torch==2.6.0 to avoid building Flash Attention from source.
+  # Prebuilt Flash Attention wheels are not available for the latest torch==2.7.0.
+  - uv pip install torch==2.6.0
+  - uv pip install transformers bitsandbytes peft wandb
+  - uv pip install flash_attn --no-build-isolation
   - git clone https://github.com/huggingface/trl
   - cd trl
-  - pip install .
-  - |
+  - uv pip install .
+  - |
     accelerate launch \
       --config_file=examples/accelerate_configs/multi_gpu.yaml \
       --num_processes $DSTACK_GPUS_PER_NODE \
-      examples/scripts/sft.py \
+      trl/scripts/sft.py \
       --model_name meta-llama/Meta-Llama-3.1-8B \
       --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
       --dataset_text_field="text" \
@@ -44,14 +44,15 @@ commands:
       --report_to wandb \
       --bf16 \
       --max_seq_length 1024 \
-      --lora_r 16 --lora_alpha 32 \
+      --lora_r 16 \
+      --lora_alpha 32 \
       --lora_target_modules q_proj k_proj v_proj o_proj \
       --load_in_4bit \
       --use_peft \
       --attn_implementation "flash_attention_2" \
       --logging_steps=10 \
       --output_dir models/llama31 \
-      --hub_model_id $HUB_MODEL_ID
+      --hub_model_id peterschmidt85/FineLlama-3.1-8B

 resources:
   gpu:
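The pin matters because flash_attn only publishes prebuilt wheels for specific torch versions; a quick local sanity check (a sketch, not part of the config) is to run the same install sequence and confirm the import works without a source build having been triggered:

    uv pip install torch==2.6.0
    uv pip install flash_attn --no-build-isolation
    # If a prebuilt wheel matched the pinned torch, this prints both versions
    # and the flash_attn install above completed without compiling anything:
    python -c "import torch, flash_attn; print(torch.__version__, flash_attn.__version__)"

The second configuration below gets essentially the same treatment; it additionally drops `conda install cuda` and pins torchvision/torchaudio alongside torch.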
   - ACCELERATE_LOG_LEVEL=info
 # Commands of the task
 commands:
-  - conda install cuda
-  - pip install git+https://github.com/huggingface/transformers.git
-  - pip install bitsandbytes
-  - pip install flash-attn --no-build-isolation
-  - pip install peft
-  - pip install wandb
+  # Pin torch==2.6.0 to avoid building Flash Attention from source.
+  # Prebuilt Flash Attention wheels are not available for the latest torch==2.7.0.
+  - uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
+  - uv pip install transformers bitsandbytes peft wandb
+  - uv pip install flash_attn --no-build-isolation
   - git clone https://github.com/huggingface/trl
   - cd trl
-  - pip install .
+  - uv pip install .
   - |
     accelerate launch \
       --config_file=examples/accelerate_configs/multi_gpu.yaml \
       --num_processes $DSTACK_GPUS_PER_NODE \
-      examples/scripts/sft.py \
+      trl/scripts/sft.py \
       --model_name meta-llama/Meta-Llama-3.1-8B \
       --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
       --dataset_text_field="text" \
@@ -36,15 +35,15 @@ commands:
       --report_to wandb \
       --bf16 \
       --max_seq_length 1024 \
-      --lora_r 16 --lora_alpha 32 \
+      --lora_r 16 \
+      --lora_alpha 32 \
       --lora_target_modules q_proj k_proj v_proj o_proj \
       --load_in_4bit \
       --use_peft \
       --attn_implementation "flash_attention_2" \
       --logging_steps=10 \
       --output_dir models/llama31 \
-      --hub_model_id $HUB_MODEL_ID
-
+      --hub_model_id peterschmidt85/FineLlama-3.1-8B
 resources:
   gpu:
     # 24GB or more VRAM
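With both configs updated, the task is launched the usual way with the dstack CLI; the path and file name below are illustrative only, since the diff does not show the actual file names, and WANDB_API_KEY is expected to be set in the local environment because the config declares it without a value:

    # Hypothetical file name -- adjust to the actual config under the example directory:
    dstack apply -f examples/single-node-training/trl/train.dstack.yml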