Skip to content

Commit 4a11486

Browse files
authored
Enable multinode training for HF speculative decoding (#922)
## What does this PR do?

**Type of change:** New example

**Overview:** Modify the `launch_train.sh` script to enable multi-node training. Provide a Slurm template script.

## Usage

Fill in the required fields in `slurm.sh`, then use the command below to submit a multi-node job:

```bash
sbatch slurm.sh
```

## Testing

<!-- Mention how you have tested your change, if applicable. -->

## Before your PR is "*Ready for review*"

<!-- If you haven't finished some of the above items you can still open `Draft` PR. -->

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!--- If No, explain why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or backward-breaking changes. -->

## Additional Information

<!-- E.g. related issue. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added multi-node training support with new configuration options for distributed setups.
  * Introduced Slurm batch script for streamlined job submission to cluster environments.
* **Improvements**
  * Enhanced GPU resource management for multi-GPU and distributed training configurations.
  * Updated speculative decoding model support and validation handling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent ef5a2df commit 4a11486

4 files changed

Lines changed: 89 additions & 94 deletions

File tree

examples/speculative_decoding/eagle_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ def on_log(self, args, state, control, **kwargs):
211211

212212
# log to wandb
213213
if wandb and is_master():
214+
logs = kwargs.get("logs") or {}
215+
if logs:
216+
wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step)
214217
for i, draft_acc in enumerate(average_acc):
215218
for j, step_acc in enumerate(draft_acc):
216219
wandb.log(

examples/speculative_decoding/launch_train.sh

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,6 @@ while [ $# -gt 0 ]; do
6262
if [[ "$1" != *=* ]]; then shift; fi
6363
TRAIN_BS="${1#*=}"
6464
;;
65-
--medusa_num_heads*)
66-
if [[ "$1" != *=* ]]; then shift; fi
67-
MEDUSA_NUM_HEADS="${1#*=}"
68-
;;
69-
--medusa_num_layers*)
70-
if [[ "$1" != *=* ]]; then shift; fi
71-
MEDUSA_NUM_LAYERS="${1#*=}"
72-
;;
7365
--eagle_config*)
7466
if [[ "$1" != *=* ]]; then shift; fi
7567
EAGLE_CONFIG="${1#*=}"
@@ -110,6 +102,14 @@ while [ $# -gt 0 ]; do
110102
if [[ "$1" != *=* ]]; then shift; fi
111103
DRAFT_VOCAB_CACHE="${1#*=}"
112104
;;
105+
--num_nodes*)
106+
if [[ "$1" != *=* ]]; then shift; fi
107+
NUM_NODES="${1#*=}"
108+
;;
109+
--head_node_ip*)
110+
if [[ "$1" != *=* ]]; then shift; fi
111+
HEAD_NODE_IP="${1#*=}"
112+
;;
113113
*)
114114
>&2 printf "Error: Invalid argument ${1#*=}\n"
115115
exit 1
@@ -120,10 +120,12 @@ done
120120

121121
set -x
122122

123-
# Get the default value for save_steps based on the available number of GPUs
124-
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
123+
NUM_NODES=${NUM_NODES:-1}
124+
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
125+
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
126+
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
125127
# Calculate save_steps
126-
DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
128+
DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU))
127129

128130
MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
129131
MODE=${MODE:-"eagle3"}
@@ -135,8 +137,6 @@ NUM_EPOCHS=${NUM_EPOCHS:-1}
135137
SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
136138
LR=${LR:-"1e-4"}
137139
TRAIN_BS=${TRAIN_BS:-1}
138-
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
139-
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
140140
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
141141
OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
142142
DISABLE_TQDM=${DISABLE_TQDM:-False}
@@ -145,20 +145,19 @@ VLM_IMG_DIR=${VLM_IMG_DIR:-}
145145
AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}
146146
ESTIMATE_AR=${ESTIMATE_AR:-False}
147147
CP_SIZE=${CP_SIZE:-1}
148-
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}
148+
DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
149149
LOG_STEPS=${LOG_STEPS:-100}
150150
DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
151151

152-
if [[ "$MODE" == "medusa" ]]; then
153-
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
154-
elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then
152+
153+
if [[ "$MODE" == "eagle3" ]]; then
155154
if [[ -n "$EAGLE_CONFIG" ]]; then
156155
SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
157156
else
158157
SPECULATIVE_ARGS=""
159158
fi
160159
else
161-
echo "Only medusa, eagle1, eagle3 supported for now!"
160+
echo "Only eagle3 supported for now!"
162161
exit 1
163162
fi
164163

@@ -180,7 +179,7 @@ else
180179
VLM_ARGS=""
181180
fi
182181

183-
if [[ "$GPU_COUNT" -gt 1 ]]; then
182+
if [[ "$TOTAL_GPU" -gt 1 ]]; then
184183
# Use FSDP2 when multiple GPUs are available
185184
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json"
186185
else
@@ -195,10 +194,20 @@ else
195194
DRAFT_VOCAB_CACHE_ARGS=""
196195
fi
197196

197+
if [[ "$NUM_NODES" != 1 ]]; then
198+
MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \
199+
--num_machines $NUM_NODES \
200+
--machine_rank $SLURM_PROCID \
201+
--rdzv_backend c10d \
202+
--main_process_ip $HEAD_NODE_IP \
203+
--main_process_port 29500"
204+
else
205+
MULTI_NODE_ARGS=""
206+
fi
198207

199208
# Disable tokenizers parallelism to avoid warning
200209
export TOKENIZERS_PARALLELISM=False
201-
CMD="accelerate launch --mixed_precision bf16 main.py \
210+
CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 main.py \
202211
--mode $MODE \
203212
--eagle_decoder_type $EAGLE_DECODER_TYPE \
204213
--model_name_or_path $MODEL \
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/bin/bash
2+
3+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
#SBATCH -A {account}
19+
#SBATCH --job-name={job_name}
20+
#SBATCH --nodes={num_nodes} --ntasks-per-node=1 --gpus-per-node={num_gpus_per_node}
21+
#SBATCH -p {partition}
22+
#SBATCH -t {time_limit}
23+
24+
CONTAINER_IMAGE={container_image}
25+
WORK_DIR={path_to_modelopt}
26+
27+
CONTAINER_MOUNT="${WORK_DIR}:/modelopt"
28+
29+
OUTPUT_DIR={path_to_output_dir}
30+
MODEL={path_to_model_dir}
31+
DATA={path_to_data_dir}
32+
OFFLINE_DATA={path_to_offline_data_dir}
33+
34+
CMD="./launch_train.sh --model $MODEL \
35+
--output_dir $OUTPUT_DIR \
36+
--data $DATA \
37+
--num_epochs 1 \
38+
--train_bs 1 \
39+
--lr 1e-4 \
40+
--eagle_config eagle_config.json \
41+
--training_seq_len 4096 \
42+
--save_steps 1000 \
43+
--estimate_ar True \
44+
--disable_tqdm True \
45+
--offline-data $OFFLINE_DATA \
46+
--num_nodes $SLURM_NNODES \
47+
--head_node_ip $head_node_ip \
48+
"
49+
50+
srun -l \
51+
--mpi=pmix \
52+
--output=%x_%j_$DATETIME.log \
53+
--container-workdir "/modelopt/examples/speculative_decoding" \
54+
--container-image ${CONTAINER_IMAGE} --container-mounts ${CONTAINER_MOUNT} \
55+
bash -lc "$CMD"
56+
57+
set +x

tests/examples/speculative_decoding/test_medusa.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

0 commit comments

Comments
 (0)