-
Notifications
You must be signed in to change notification settings - Fork 434
Expand file tree
/
Copy pathlaunch_train.sh
More file actions
executable file
·83 lines (73 loc) · 3.12 KB
/
launch_train.sh
File metadata and controls
executable file
·83 lines (73 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
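# Usage:
# Single node: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml model.model_name_or_path=xxx
# Multi-node: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml --num_nodes 2 --head_node_ip <IP> [--machine_rank <RANK>]
# With overrides: ./launch_train.sh --config my.yaml model.model_name_or_path=xxx training.output_dir=yyy
#
# Launcher flags accept either "--flag value" or "--flag=value". In multi-node runs
# the per-node GPU count is taken from the GPU_PER_NODE environment variable if set
# (otherwise it is counted with nvidia-smi), and --machine_rank falls back to
# $SLURM_PROCID when omitted.
#
# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py; all
# training config lives in the YAML. mixed_precision is fixed to bf16.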
set -eo pipefail

# Directory that holds this script (symlinks resolved); main.py is looked up next to it.
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"

# Launcher options and their defaults; parse_args overwrites them from the command line.
CONFIG_FILE=""
NUM_NODES=1
HEAD_NODE_IP=""
MACHINE_RANK=""
EXTRA_ARGS=()

#######################################
# Parses the launcher's command line.
# Recognized flags (--config, --num_nodes, --head_node_ip, --machine_rank) take
# their value either as "--flag value" or "--flag=value". Every other argument
# is appended, untouched, to EXTRA_ARGS.
# Globals:
#   CONFIG_FILE, NUM_NODES, HEAD_NODE_IP, MACHINE_RANK, EXTRA_ARGS (written)
# Arguments:
#   The script's positional parameters.
# Outputs:
#   An error message on stderr when a flag is given without a value.
# Returns:
#   0 on success; exits the shell with status 1 on a missing flag value.
#######################################
parse_args() {
  local opt val
  while [ $# -gt 0 ]; do
    case "$1" in
      # Prefix globs are kept as-is so every spelling accepted before still works.
      --config* | --num_nodes* | --head_node_ip* | --machine_rank*)
        opt="$1"
        if [[ "$opt" == *=* ]]; then
          # "--flag=value": strip only the flag prefix, so the value may itself contain '='.
          val="${opt#*=}"
        elif [ $# -ge 2 ]; then
          # "--flag value": take the next word verbatim (no '=' stripping on the value).
          val="$2"
          shift
        else
          # Without this check the trailing `shift` would fail and `set -e`
          # would end the script without any message.
          >&2 echo "Error: $opt requires a value"
          exit 1
        fi
        case "$opt" in
          --config*) CONFIG_FILE="$val" ;;
          --num_nodes*) NUM_NODES="$val" ;;
          --head_node_ip*) HEAD_NODE_IP="$val" ;;
          --machine_rank*) MACHINE_RANK="$val" ;;
        esac
        ;;
      *)
        EXTRA_ARGS+=("$1")
        ;;
    esac
    shift
  done
}

parse_args "$@"
# --config is mandatory (main.py is always invoked with it). When it is missing,
# print the usage line, listing every flag the launcher accepts, to stderr and abort.
if [ -z "$CONFIG_FILE" ]; then
  >&2 echo "Usage: ./launch_train.sh --config <yaml_file> [--num_nodes N] [--head_node_ip IP] [--machine_rank RANK] [key=value ...]"
  exit 1
fi
# Work out TOTAL_GPU and log it.
if [[ "$NUM_NODES" == "1" ]]; then
  # Single node: report the number of CUDA devices torch can see.
  TOTAL_GPU="$(python3 -c "import torch; print(torch.cuda.device_count())")"
  echo "Total GPUs: $TOTAL_GPU (single node)"
else
  # Multi-node: honor a caller-provided GPU_PER_NODE, otherwise count the GPUs
  # nvidia-smi lists on this node (one output line per GPU).
  if [[ -z "$GPU_PER_NODE" ]]; then
    GPU_PER_NODE="$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)"
  fi
  TOTAL_GPU=$(( NUM_NODES * GPU_PER_NODE ))
  echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
fi
#######################################
# Builds MULTI_NODE_ARGS, the extra `accelerate launch` flags for multi-node runs.
# Globals:
#   NUM_NODES, TOTAL_GPU, HEAD_NODE_IP, MACHINE_RANK, SLURM_PROCID (read)
#   MULTI_NODE_ARGS (written; left empty when NUM_NODES is 1)
# Outputs:
#   An error message on stderr when a required multi-node setting is missing.
# Returns:
#   0 on success; 1 when a multi-node run has no head node IP or no machine rank.
#######################################
build_multi_node_args() {
  MULTI_NODE_ARGS=()
  # An unset NUM_NODES is treated like the single-node default of 1.
  if [[ "${NUM_NODES-1}" == "1" ]]; then
    return 0
  fi

  # machine_rank defaults to $SLURM_PROCID; override --machine_rank if node 0 isn't a trainer.
  local rank="${MACHINE_RANK:-${SLURM_PROCID:-}}"

  # Fail early with a clear message instead of handing empty strings to accelerate.
  if [[ -z "$HEAD_NODE_IP" ]]; then
    >&2 echo "Error: --head_node_ip is required when --num_nodes is not 1"
    return 1
  fi
  if [[ -z "$rank" ]]; then
    >&2 echo "Error: machine rank is unknown; pass --machine_rank or set SLURM_PROCID"
    return 1
  fi

  # --multi_gpu is required even at 1 GPU/node, else accelerate won't form the DDP group.
  MULTI_NODE_ARGS=(
    --multi_gpu
    --num_processes "$TOTAL_GPU"
    --num_machines "$NUM_NODES"
    --machine_rank "$rank"
    --main_process_ip "$HEAD_NODE_IP"
    --main_process_port 29500
  )
}

build_multi_node_args || exit 1
# Exported so the launched training processes inherit it.
export TOKENIZERS_PARALLELISM=False

# Assemble the launch command as an argv array rather than a string for `sh -c`
# (which would word-split overrides and run embedded substitutions).
CMD=(accelerate launch --mixed_precision bf16)
CMD+=("${MULTI_NODE_ARGS[@]}")
CMD+=("${SCRIPT_DIR}/main.py" --config "$CONFIG_FILE")
CMD+=("${EXTRA_ARGS[@]}")

# Trace from here on so the exact command line is logged, then time the run.
set -x
start_time=$(date +%s)
"${CMD[@]}"
echo "Total time: $(( $(date +%s) - start_time )) seconds"