diff --git a/examples/single-node-training/trl/README.md b/examples/single-node-training/trl/README.md
index 300c9255e..3374ec51c 100644
--- a/examples/single-node-training/trl/README.md
+++ b/examples/single-node-training/trl/README.md
@@ -21,19 +21,19 @@ env:
   - WANDB_API_KEY
   - HUB_MODEL_ID
 commands:
-  - pip install "transformers>=4.43.2"
-  - pip install bitsandbytes
-  - pip install flash-attn --no-build-isolation
-  - pip install peft
-  - pip install wandb
+  # Pin torch==2.6.0 to avoid building Flash Attention from source.
+  # Prebuilt Flash Attention wheels are not available for the latest torch==2.7.0.
+  - uv pip install torch==2.6.0
+  - uv pip install transformers bitsandbytes peft wandb
+  - uv pip install flash_attn --no-build-isolation
   - git clone https://github.com/huggingface/trl
   - cd trl
-  - pip install .
-  - |
+  - uv pip install .
+  - |
     accelerate launch \
       --config_file=examples/accelerate_configs/multi_gpu.yaml \
       --num_processes $DSTACK_GPUS_PER_NODE \
-      examples/scripts/sft.py \
+      trl/scripts/sft.py \
       --model_name meta-llama/Meta-Llama-3.1-8B \
       --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
      --dataset_text_field="text" \
@@ -44,14 +44,15 @@ commands:
       --report_to wandb \
       --bf16 \
       --max_seq_length 1024 \
-      --lora_r 16 --lora_alpha 32 \
+      --lora_r 16 \
+      --lora_alpha 32 \
       --lora_target_modules q_proj k_proj v_proj o_proj \
       --load_in_4bit \
       --use_peft \
       --attn_implementation "flash_attention_2" \
       --logging_steps=10 \
       --output_dir models/llama31 \
-      --hub_model_id $HUB_MODEL_ID
+      --hub_model_id peterschmidt85/FineLlama-3.1-8B
 
 resources:
   gpu:
diff --git a/examples/single-node-training/trl/train.dstack.yml b/examples/single-node-training/trl/train.dstack.yml
index de57c3c11..77a23f3b0 100644
--- a/examples/single-node-training/trl/train.dstack.yml
+++ b/examples/single-node-training/trl/train.dstack.yml
@@ -12,20 +12,19 @@ env:
   - ACCELERATE_LOG_LEVEL=info
 # Commands of the task
 commands:
-  - conda install cuda
-  - pip install git+https://github.com/huggingface/transformers.git
-  - pip install bitsandbytes
-  - pip install flash-attn --no-build-isolation
-  - pip install peft
-  - pip install wandb
+  # Pin torch==2.6.0 to avoid building Flash Attention from source.
+  # Prebuilt Flash Attention wheels are not available for the latest torch==2.7.0.
+  - uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
+  - uv pip install transformers bitsandbytes peft wandb
+  - uv pip install flash_attn --no-build-isolation
   - git clone https://github.com/huggingface/trl
   - cd trl
-  - pip install .
+  - uv pip install .
   - |
     accelerate launch \
       --config_file=examples/accelerate_configs/multi_gpu.yaml \
       --num_processes $DSTACK_GPUS_PER_NODE \
-      examples/scripts/sft.py \
+      trl/scripts/sft.py \
       --model_name meta-llama/Meta-Llama-3.1-8B \
       --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
       --dataset_text_field="text" \
@@ -36,15 +35,15 @@ commands:
       --report_to wandb \
       --bf16 \
       --max_seq_length 1024 \
-      --lora_r 16 --lora_alpha 32 \
+      --lora_r 16 \
+      --lora_alpha 32 \
       --lora_target_modules q_proj k_proj v_proj o_proj \
       --load_in_4bit \
       --use_peft \
       --attn_implementation "flash_attention_2" \
       --logging_steps=10 \
       --output_dir models/llama31 \
-      --hub_model_id $HUB_MODEL_ID
-
+      --hub_model_id peterschmidt85/FineLlama-3.1-8B
 resources:
   gpu:
     # 24GB or more VRAM
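
For reference, a minimal sketch of how the updated task could be launched once this diff is applied, assuming the dstack CLI is installed and the repo has been initialized with `dstack init` (the config path mirrors the file changed above; server/project setup is assumed):

    # Submit the training task defined in the updated configuration.
    # Resource matching (24GB+ VRAM GPU) is handled by the dstack server.
    dstack apply -f examples/single-node-training/trl/train.dstack.yml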