fsdp.dstack.yml
type: task
name: trl-train-fsdp-distrib
# Size of the cluster
nodes: 2
image: nvcr.io/nvidia/pytorch:25.01-py3
# Required environment variables
env:
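  # Variables listed without a value (HF_TOKEN, WANDB_API_KEY, HUB_MODEL_ID)
  # must be set in the environment from which the run is submitted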
  - HF_TOKEN
  - WANDB_API_KEY
  - HUB_MODEL_ID
  - MODEL_ID=meta-llama/Llama-3.1-8B
  - ACCELERATE_LOG_LEVEL=info
# Commands of the task
commands:
  - pip install transformers bitsandbytes peft wandb
  - git clone https://github.com/huggingface/trl
  - cd trl
  - pip install .
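  # DSTACK_MASTER_NODE_IP, DSTACK_NODE_RANK, DSTACK_NODES_NUM, and
  # DSTACK_GPUS_NUM (total GPUs across all nodes) are injected by dstack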
  - |
    accelerate launch \
      --config_file=examples/accelerate_configs/fsdp1.yaml \
      --main_process_ip=$DSTACK_MASTER_NODE_IP \
      --main_process_port=8008 \
      --machine_rank=$DSTACK_NODE_RANK \
      --num_processes=$DSTACK_GPUS_NUM \
      --num_machines=$DSTACK_NODES_NUM \
      trl/scripts/sft.py \
      --model_name_or_path $MODEL_ID \
      --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
      --dataset_text_field="text" \
      --per_device_train_batch_size 1 \
      --per_device_eval_batch_size 1 \
      --gradient_accumulation_steps 4 \
      --learning_rate 2e-4 \
      --report_to wandb \
      --bf16 \
      --max_seq_length 1024 \
      --attn_implementation flash_attention_2 \
      --logging_steps=10 \
      --output_dir /checkpoints/llama31-ft \
      --hub_model_id $HUB_MODEL_ID \
      --torch_dtype bfloat16
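# Per-node resources: eight GPUs with 80GB of VRAM each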
resources:
  gpu: 80GB:8
  shm_size: 128GB
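# Map the host's /checkpoints directory into the container so
# checkpoints persist on the instance across runs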
volumes:
  - /checkpoints:/checkpoints
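# To submit: `dstack apply -f fsdp.dstack.yml` (path assumed; adjust to
# where this file lives in your repo)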