# train.dstack.yml
type: task
# The name is optional; if not specified, it is generated randomly
name: trl-train
python: 3.12
# Required environment variables
env:
  - HF_TOKEN
  - WANDB_API_KEY
  - HUB_MODEL_ID
  - ACCELERATE_LOG_LEVEL=info
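  # Variables listed without a value (e.g. HF_TOKEN) are not stored in this file;
  # dstack reads them from the environment where the CLI is run and passes them
  # to the container.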
# Commands of the task
commands:
  # Pin torch==2.6.0 to avoid building Flash Attention from source.
  # Prebuilt Flash Attention wheels are not available for the latest torch==2.7.0.
  - uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
  - uv pip install transformers bitsandbytes peft wandb
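  # flash_attn is installed with --no-build-isolation so its setup can import the
  # already-installed torch, which is why torch is pinned and installed first.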
  - uv pip install flash_attn --no-build-isolation
  - git clone https://github.com/huggingface/trl
  - cd trl
  - uv pip install .
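  # DSTACK_GPUS_PER_NODE is set by dstack at runtime to the number of GPUs
  # allocated on the node, so accelerate starts one process per GPU requested
  # under `resources` below.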
  - |
    accelerate launch \
      --config_file=examples/accelerate_configs/multi_gpu.yaml \
      --num_processes $DSTACK_GPUS_PER_NODE \
      trl/scripts/sft.py \
      --model_name meta-llama/Meta-Llama-3.1-8B \
      --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
      --dataset_text_field="text" \
      --per_device_train_batch_size 1 \
      --per_device_eval_batch_size 1 \
      --gradient_accumulation_steps 4 \
      --learning_rate 2e-4 \
      --report_to wandb \
      --bf16 \
      --max_seq_length 1024 \
      --lora_r 16 \
      --lora_alpha 32 \
      --lora_target_modules q_proj k_proj v_proj o_proj \
      --load_in_4bit \
      --use_peft \
      --attn_implementation "flash_attention_2" \
      --logging_steps=10 \
      --output_dir models/llama31 \
      --hub_model_id peterschmidt85/FineLlama-3.1-8B
resources:
  gpu:
    # 24GB or more VRAM
    memory: 24GB..
    # One or more GPUs
    count: 1..
  # Shared memory (for multi-GPU communication)
  shm_size: 24GB
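
# A minimal sketch of how a task like this is typically launched (assuming the
# file is saved as train.dstack.yml in a repo initialized with `dstack init`,
# and that HF_TOKEN, WANDB_API_KEY, and HUB_MODEL_ID are exported in the local
# shell):
#
#   dstack apply -f train.dstack.yml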