From 77270bc4c43687fbc516cc5e7f3fb2f189f16977 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 23 Feb 2026 11:36:20 -0800 Subject: [PATCH 01/10] update launch script for multi node training Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 50 +++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index ae8a21eea4..c292100cf8 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -62,14 +62,6 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi TRAIN_BS="${1#*=}" ;; - --medusa_num_heads*) - if [[ "$1" != *=* ]]; then shift; fi - MEDUSA_NUM_HEADS="${1#*=}" - ;; - --medusa_num_layers*) - if [[ "$1" != *=* ]]; then shift; fi - MEDUSA_NUM_LAYERS="${1#*=}" - ;; --eagle_config*) if [[ "$1" != *=* ]]; then shift; fi EAGLE_CONFIG="${1#*=}" @@ -110,6 +102,14 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi DRAFT_VOCAB_CACHE="${1#*=}" ;; + --num_nodes*) + if [[ "$1" != *=* ]]; then shift; fi + NUM_NODES="${1#*=}" + ;; + --head_node_ip*) + if [[ "$1" != *=* ]]; then shift; fi + HEAD_NODE_IP="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -120,13 +120,15 @@ done set -x -# Get the default value for save_steps based on the available number of GPUs -GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") +NUM_NODES=${NUM_NODES:-1} +GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} +TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) +echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" # Calculate save_steps -DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT)) +DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} -MODE=${MODE:-"eagle3"} +MODE=${MODE:-"eagle"} EAGLE_DECODER_TYPE=${EAGLE_DECODER_TYPE:-"llama"} # Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path MODEL_BASENAME=$(basename "$MODEL") @@ -135,8 +137,6 @@ NUM_EPOCHS=${NUM_EPOCHS:-1} SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} LR=${LR:-"1e-4"} TRAIN_BS=${TRAIN_BS:-1} -MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} -MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} DISABLE_TQDM=${DISABLE_TQDM:-False} @@ -145,20 +145,19 @@ VLM_IMG_DIR=${VLM_IMG_DIR:-} AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} ESTIMATE_AR=${ESTIMATE_AR:-False} CP_SIZE=${CP_SIZE:-1} -DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))} +DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} LOG_STEPS=${LOG_STEPS:-100} DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} -if [[ "$MODE" == "medusa" ]]; then - SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" -elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then + +if [[ "$MODE" == "eagle" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" else SPECULATIVE_ARGS="" fi else - echo "Only medusa, eagle1, eagle3 supported for now!" + echo "Only eagle supported for now!" exit 1 fi @@ -180,7 +179,7 @@ else VLM_ARGS="" fi -if [[ "$GPU_COUNT" -gt 1 ]]; then +if [[ "$TOTAL_GPU" -gt 1 ]]; then #Use FSDP2 when multi GPU available FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json" else @@ -195,10 +194,19 @@ else DRAFT_VOCAB_CACHE_ARGS="" fi +if [[ "$HEAD_NODE_IP" != "" ]]; then + MULTI_NODE_ARGS="--num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \ + --num_machines $SLURM_NNODES \ + --rdzv_backend c10d \ + --main_process_ip $HEAD_NODE_IP \ + --main_process_port 29500" +else + MULTI_NODE_ARGS="" +fi # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False -CMD="accelerate launch --mixed_precision bf16 main.py \ +CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 main.py \ --mode $MODE \ --eagle_decoder_type $EAGLE_DECODER_TYPE \ --model_name_or_path $MODEL \ From b464a773ac0a6e337b1b042f4a6eafd1355f2b15 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 23 Feb 2026 11:46:20 -0800 Subject: [PATCH 02/10] debug Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c292100cf8..088f3baf98 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -195,8 +195,8 @@ else fi if [[ "$HEAD_NODE_IP" != "" ]]; then - MULTI_NODE_ARGS="--num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \ - --num_machines $SLURM_NNODES \ + MULTI_NODE_ARGS="--num_processes $((NUM_NODES * GPUS_PER_NODE)) \ + --num_machines $NUM_NODES \ --rdzv_backend c10d \ --main_process_ip $HEAD_NODE_IP \ --main_process_port 29500" From 558e708e07e2e6a78925862e3c254ce82d149105 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 23 Feb 2026 11:55:45 -0800 Subject: [PATCH 03/10] debug Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 088f3baf98..7d5c1a9961 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -128,7 +128,7 @@ echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} -MODE=${MODE:-"eagle"} +MODE=${MODE:-"eagle3"} EAGLE_DECODER_TYPE=${EAGLE_DECODER_TYPE:-"llama"} # Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path MODEL_BASENAME=$(basename "$MODEL") @@ -150,14 +150,14 @@ LOG_STEPS=${LOG_STEPS:-100} DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} -if [[ "$MODE" == "eagle" ]]; then +if [[ "$MODE" == "eagle3" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" else SPECULATIVE_ARGS="" fi else - echo "Only eagle supported for now!" + echo "Only eagle3 supported for now!" exit 1 fi From dbf315f7629615bcd67d071172033fb346a41878 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 23 Feb 2026 12:46:34 -0800 Subject: [PATCH 04/10] add machine rank Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 7d5c1a9961..ea3f62d6fc 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -110,6 +110,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi HEAD_NODE_IP="${1#*=}" ;; + --machine_rank*) + if [[ "$1" != *=* ]]; then shift; fi + MACHINE_RANK="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -121,6 +125,7 @@ done set -x NUM_NODES=${NUM_NODES:-1} +MACHINE_RANK=${MACHINE_RANK:-0} GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" @@ -195,8 +200,9 @@ else fi if [[ "$HEAD_NODE_IP" != "" ]]; then - MULTI_NODE_ARGS="--num_processes $((NUM_NODES * GPUS_PER_NODE)) \ + MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ --num_machines $NUM_NODES \ + --machine_rank $MACHINE_RANK \ --rdzv_backend c10d \ --main_process_ip $HEAD_NODE_IP \ --main_process_port 29500" From 08478e222395557e9849dcb32d0f5e737af6aa86 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 23 Feb 2026 13:30:58 -0800 Subject: [PATCH 05/10] debug Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index ea3f62d6fc..b4c8857d6c 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -106,14 +106,6 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi NUM_NODES="${1#*=}" ;; - --head_node_ip*) - if [[ "$1" != *=* ]]; then shift; fi - HEAD_NODE_IP="${1#*=}" - ;; - --machine_rank*) - if [[ "$1" != *=* ]]; then shift; fi - MACHINE_RANK="${1#*=}" - ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -125,7 +117,6 @@ done set -x NUM_NODES=${NUM_NODES:-1} -MACHINE_RANK=${MACHINE_RANK:-0} GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" @@ -202,9 +193,9 @@ fi if [[ "$HEAD_NODE_IP" != "" ]]; then MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ --num_machines $NUM_NODES \ - --machine_rank $MACHINE_RANK \ + --machine_rank $SLURM_PROCID \ --rdzv_backend c10d \ - --main_process_ip $HEAD_NODE_IP \ + --main_process_ip $(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) \ --main_process_port 29500" else MULTI_NODE_ARGS="" From 023b3d660fa13fd4813ced4c385c649ca369e5de Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 23 Feb 2026 13:47:56 -0800 Subject: [PATCH 06/10] debug Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index b4c8857d6c..8c2d93d608 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -106,6 +106,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi NUM_NODES="${1#*=}" ;; + --head_node_ip* + if [[ "$1" != *=* ]]; then shift; fi + HEAD_NODE_IP="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -190,12 +194,12 @@ else DRAFT_VOCAB_CACHE_ARGS="" fi -if [[ "$HEAD_NODE_IP" != "" ]]; then +if [[ "$NUM_NODES" != 1 ]]; then MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ --num_machines $NUM_NODES \ --machine_rank $SLURM_PROCID \ --rdzv_backend c10d \ - --main_process_ip $(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) \ + --main_process_ip $HEAD_NODE_IP \ --main_process_port 29500" else MULTI_NODE_ARGS="" From e8342675fba862d3558ba6e7f3bb8f565f2f8424 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 23 Feb 2026 14:50:47 -0800 Subject: [PATCH 07/10] add slurm template Signed-off-by: Ye Yu --- examples/speculative_decoding/launch_train.sh | 4 +- examples/speculative_decoding/slurm.sh | 42 +++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 examples/speculative_decoding/slurm.sh diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 8c2d93d608..7a566caaf0 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -106,7 +106,7 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi NUM_NODES="${1#*=}" ;; - --head_node_ip* + --head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi HEAD_NODE_IP="${1#*=}" ;; @@ -199,7 +199,7 @@ if [[ "$NUM_NODES" != 1 ]]; then --num_machines $NUM_NODES \ --machine_rank $SLURM_PROCID \ --rdzv_backend c10d \ - --main_process_ip $HEAD_NODE_IP \ + --main_process_ip $HEAD_NODE_IP \ --main_process_port 29500" else MULTI_NODE_ARGS="" diff --git a/examples/speculative_decoding/slurm.sh b/examples/speculative_decoding/slurm.sh new file mode 100644 index 0000000000..957c5f3975 --- /dev/null +++ b/examples/speculative_decoding/slurm.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +#SBATCH -A {account} +#SBATCH --job-name={job_name} +#SBATCH --nodes={num_nodes} --ntasks-per-node=1 --gpus-per-node={num_gpus_per_node} +#SBATCH -p {partition} +#SBATCH -t {time_limit} + +CONTAINER_IMAGE={container_image} +WORK_DIR={path_to_modelopt} + +CONTAINER_MOUNT="${WORK_DIR}:/modelopt" + +OUTPUT_DIR={path_to_output_dir} +MODEL={path_to_model_dir} +DATA={path_to_data_dir} +OFFLINE_DATA={path_to_offline_data_dir} + +CMD="./launch_train.sh --model $MODEL \ + --output_dir $OUTPUT_DIR \ + --data $DATA \ + --num_epochs 1 \ + --train_bs 1 \ + --lr 1e-4 \ + --eagle_config eagle_config.json \ + --training_seq_len 4096 \ + --save_steps 1000 \ + --estimate_ar True \ + --disable_tqdm True \ + --offline-data $OFFLINE_DATA \ + --num_nodes $SLURM_NNODES \ + --head_node_ip $head_node_ip \ +" + +srun -l \ + --mpi=pmix \ + --output=%x_%j_$DATETIME.log \ + --container-workdir "/modelopt/examples/speculative_decoding" \ + --container-image ${CONTAINER_IMAGE} --container-mounts ${CONTAINER_MOUNT} \ + bash -lc "$CMD" + +set +x From 5f63c5b6077fb24ec4062c8f3644cfe022a4c94a Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 24 Feb 2026 09:07:20 -0800 Subject: [PATCH 08/10] formatting Signed-off-by: Ye Yu --- examples/speculative_decoding/slurm.sh | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/examples/speculative_decoding/slurm.sh b/examples/speculative_decoding/slurm.sh index 957c5f3975..d8091a37fb 100644 --- a/examples/speculative_decoding/slurm.sh +++ b/examples/speculative_decoding/slurm.sh @@ -1,5 +1,20 @@ #!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + #SBATCH -A {account} #SBATCH --job-name={job_name} #SBATCH --nodes={num_nodes} --ntasks-per-node=1 --gpus-per-node={num_gpus_per_node} From 6295092fe6cc51b785d3b63130d52b7f7f82291c Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 24 Feb 2026 09:34:20 -0800 Subject: [PATCH 09/10] log to wandb Signed-off-by: Ye Yu --- examples/speculative_decoding/eagle_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 3ef7156372..8c96a19a76 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -211,6 +211,9 @@ def on_log(self, args, state, control, **kwargs): # log to wandb if wandb and is_master(): + logs = kwargs.get("logs") or {} + if logs: + wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step) for i, draft_acc in enumerate(average_acc): for j, step_acc in enumerate(draft_acc): wandb.log( From 56677425d292b9f0a3c714b1b1e2f9d0b5aae7cc Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 24 Feb 2026 12:57:02 -0800 Subject: [PATCH 10/10] deprecate medusa test Signed-off-by: Ye Yu --- .../speculative_decoding/test_medusa.py | 74 ------------------- 1 file changed, 74 deletions(-) delete mode 100644 tests/examples/speculative_decoding/test_medusa.py diff --git a/tests/examples/speculative_decoding/test_medusa.py b/tests/examples/speculative_decoding/test_medusa.py deleted file mode 100644 index 545b79d7ea..0000000000 --- a/tests/examples/speculative_decoding/test_medusa.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -from _test_utils.examples.run_command import run_example_command - - -# fmt: off -def _run_hf_ptq(model_path, output_dir, qformat): - run_example_command( - [ - "python", "hf_ptq.py", - "--pyt_ckpt_path", model_path, - "--batch_size", "1", - "--calib_size", "64", - "--export_path", output_dir, - "--qformat", qformat, - ], - "llm_ptq", - ) - - -def test_llama_medusa_fp8_qat(tiny_llama_path, tiny_daring_anteater_path, tmp_path): - medusa_path = tmp_path / "medusa-tinyllama" - - # Test Medusa - run_example_command( - [ - "./launch_train.sh", - "--model", tiny_llama_path, - "--data", tiny_daring_anteater_path, - "--num_epochs", "1", - "--lr", "1e-5", - "--mode", "medusa", - "--output_dir", medusa_path, - "--medusa_num_heads", "2", - "--medusa_num_layers", "1", - ], - "speculative_decoding", - ) - - pytest.skip("speculative decoding uses transformers 5.x, quantization example uses transformers 4.x") - - # Test PTQ on Medusa - _run_hf_ptq(medusa_path, tmp_path / "medusa-tinyllama-hf", "fp8") - - # Test QAT on Medusa - run_example_command( - [ - "./launch.sh", - "--model", medusa_path, - "--num_epochs", "1", - "--train_size", "128", - "--eval_size", "64", - "--lr", "1e-5", - "--output_dir", tmp_path / "medusa-tinyllama-qat-finetune", - "--quant_cfg", "FP8_DEFAULT_CFG", - "--calib_size", "64", - "--backend", "fsdp2", - ], - "llm_qat", - )