-
Notifications
You must be signed in to change notification settings - Fork 364
Enable multinode training for HF speculative decoding #922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
77270bc
b464a77
558e708
dbf315f
08478e2
023b3d6
e834267
5f63c5b
6295092
5667742
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -62,14 +62,6 @@ while [ $# -gt 0 ]; do | |||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$1" != *=* ]]; then shift; fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| TRAIN_BS="${1#*=}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ;; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --medusa_num_heads*) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$1" != *=* ]]; then shift; fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| MEDUSA_NUM_HEADS="${1#*=}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ;; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --medusa_num_layers*) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$1" != *=* ]]; then shift; fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| MEDUSA_NUM_LAYERS="${1#*=}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ;; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --eagle_config*) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$1" != *=* ]]; then shift; fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| EAGLE_CONFIG="${1#*=}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -110,6 +102,14 @@ while [ $# -gt 0 ]; do | |||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$1" != *=* ]]; then shift; fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| DRAFT_VOCAB_CACHE="${1#*=}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ;; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --num_nodes*) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$1" != *=* ]]; then shift; fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| NUM_NODES="${1#*=}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ;; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --head_node_ip*) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$1" != *=* ]]; then shift; fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| HEAD_NODE_IP="${1#*=}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ;; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| *) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| >&2 printf "Error: Invalid argument ${1#*=}\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| exit 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -120,10 +120,12 @@ done | |||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| set -x | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Get the default value for save_steps based on the available number of GPUs | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| NUM_NODES=${NUM_NODES:-1} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Calculate save_steps | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| MODE=${MODE:-"eagle3"} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -135,8 +137,6 @@ NUM_EPOCHS=${NUM_EPOCHS:-1} | |||||||||||||||||||||||||||||||||||||||||||||||||||
| SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| LR=${LR:-"1e-4"} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| TRAIN_BS=${TRAIN_BS:-1} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| DISABLE_TQDM=${DISABLE_TQDM:-False} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -145,20 +145,19 @@ VLM_IMG_DIR=${VLM_IMG_DIR:-} | |||||||||||||||||||||||||||||||||||||||||||||||||||
| AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ESTIMATE_AR=${ESTIMATE_AR:-False} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| CP_SIZE=${CP_SIZE:-1} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| LOG_STEPS=${LOG_STEPS:-100} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
147
to
149
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate
✅ Suggested fix CP_SIZE=${CP_SIZE:-1}
-DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
+if (( TOTAL_GPU % CP_SIZE != 0 )); then
+ echo "CP_SIZE ($CP_SIZE) must evenly divide TOTAL_GPU ($TOTAL_GPU)."
+ exit 1
+fi
+DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
+if (( DP_SHARD_SIZE < 1 )); then
+ echo "DP_SHARD_SIZE must be >= 1."
+ exit 1
+fi📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||
| DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$MODE" == "medusa" ]]; then | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$MODE" == "eagle3" ]]; then | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ -n "$EAGLE_CONFIG" ]]; then | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| SPECULATIVE_ARGS="" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| echo "Only medusa, eagle1, eagle3 supported for now!" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| echo "Only eagle3 supported for now!" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| exit 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -180,7 +179,7 @@ else | |||||||||||||||||||||||||||||||||||||||||||||||||||
| VLM_ARGS="" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$GPU_COUNT" -gt 1 ]]; then | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$TOTAL_GPU" -gt 1 ]]; then | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #Use FSDP2 when multi GPU available | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -195,10 +194,20 @@ else | |||||||||||||||||||||||||||||||||||||||||||||||||||
| DRAFT_VOCAB_CACHE_ARGS="" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [[ "$NUM_NODES" != 1 ]]; then | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --num_machines $NUM_NODES \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --machine_rank $SLURM_PROCID \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --rdzv_backend c10d \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --main_process_ip $HEAD_NODE_IP \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --main_process_port 29500" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+197
to
+203
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fail fast when required multi-node vars are missing. When 🚦 Suggested fix if [[ "$NUM_NODES" != 1 ]]; then
+ if [[ -z "$HEAD_NODE_IP" ]]; then
+ echo "HEAD_NODE_IP is required when NUM_NODES > 1."
+ exit 1
+ fi
+ if [[ -z "$SLURM_PROCID" ]]; then
+ echo "SLURM_PROCID is required when NUM_NODES > 1."
+ exit 1
+ fi
MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \
--num_machines $NUM_NODES \
--machine_rank $SLURM_PROCID \
--rdzv_backend c10d \
--main_process_ip $HEAD_NODE_IP \
--main_process_port 29500"
else
MULTI_NODE_ARGS=""
fi📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| MULTI_NODE_ARGS="" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Disable tokenizers parallelism to avoid warning | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| export TOKENIZERS_PARALLELISM=False | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| CMD="accelerate launch --mixed_precision bf16 main.py \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 main.py \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --mode $MODE \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --eagle_decoder_type $EAGLE_DECODER_TYPE \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| --model_name_or_path $MODEL \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,57 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #!/bin/bash | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #SBATCH -A {account} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #SBATCH --job-name={job_name} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #SBATCH --nodes={num_nodes} --ntasks-per-node=1 --gpus-per-node={num_gpus_per_node} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #SBATCH -p {partition} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #SBATCH -t {time_limit} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CONTAINER_IMAGE={container_image} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| WORK_DIR={path_to_modelopt} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CONTAINER_MOUNT="${WORK_DIR}:/modelopt" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| OUTPUT_DIR={path_to_output_dir} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| MODEL={path_to_model_dir} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DATA={path_to_data_dir} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| OFFLINE_DATA={path_to_offline_data_dir} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CMD="./launch_train.sh --model $MODEL \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --output_dir $OUTPUT_DIR \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --data $DATA \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --num_epochs 1 \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --train_bs 1 \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --lr 1e-4 \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --eagle_config eagle_config.json \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --training_seq_len 4096 \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --save_steps 1000 \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --estimate_ar True \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --disable_tqdm True \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --offline-data $OFFLINE_DATA \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --num_nodes $SLURM_NNODES \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --head_node_ip $head_node_ip \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+24
to
+47
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Define HEAD_NODE_IP (and use consistent casing) before passing it.
🔧 Suggested fix CONTAINER_IMAGE={container_image}
WORK_DIR={path_to_modelopt}
+HEAD_NODE_IP={head_node_ip}
CONTAINER_MOUNT="${WORK_DIR}:/modelopt"
OUTPUT_DIR={path_to_output_dir}
MODEL={path_to_model_dir}
DATA={path_to_data_dir}
OFFLINE_DATA={path_to_offline_data_dir}
+
+if [[ "${SLURM_NNODES:-1}" -gt 1 && -z "$HEAD_NODE_IP" ]]; then
+ echo "HEAD_NODE_IP is required for multi-node runs."
+ exit 1
+fi
CMD="./launch_train.sh --model $MODEL \
--output_dir $OUTPUT_DIR \
--data $DATA \
--num_epochs 1 \
--train_bs 1 \
--lr 1e-4 \
--eagle_config eagle_config.json \
--training_seq_len 4096 \
--save_steps 1000 \
--estimate_ar True \
--disable_tqdm True \
--offline-data $OFFLINE_DATA \
--num_nodes $SLURM_NNODES \
- --head_node_ip $head_node_ip \
+ --head_node_ip $HEAD_NODE_IP \
"📝 Committable suggestion
Suggested change
🧰 Tools🪛 Shellcheck (0.11.0)[warning] 9-9: This { is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 9-9: This } is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 10-10: This { is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 10-10: This } is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 14-14: This { is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 14-14: This } is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 15-15: This { is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 15-15: This } is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 16-16: This { is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 16-16: This } is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 17-17: This { is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 17-17: This } is literal. Check expression (missing ;/\n?) or quote it. (SC1083) [warning] 32-32: head_node_ip is referenced but not assigned. (SC2154) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| srun -l \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --mpi=pmix \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --output=%x_%j_$DATETIME.log \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --container-workdir "/modelopt/examples/speculative_decoding" \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --container-image ${CONTAINER_IMAGE} --container-mounts ${CONTAINER_MOUNT} \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bash -lc "$CMD" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| set +x | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard against zero GPUs and zero
save_steps.If
nvidia-smireturns 0 GPUs (or fails),TOTAL_GPUbecomes 0 and the script will divide by zero. Also, largeTOTAL_GPUvalues can driveDEFAULT_SAVE_STEPSto 0, which is invalid for trainers.🛡️ Suggested fix
NUM_NODES=${NUM_NODES:-1} GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) +if (( TOTAL_GPU <= 0 )); then + echo "No GPUs detected. Set GPU_PER_NODE/NUM_NODES explicitly." + exit 1 +fi echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" # Calculate save_steps DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) +if (( DEFAULT_SAVE_STEPS < 1 )); then + DEFAULT_SAVE_STEPS=1 +fi📝 Committable suggestion
🤖 Prompt for AI Agents