diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 3ef7156372..8c96a19a76 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -211,6 +211,9 @@ def on_log(self, args, state, control, **kwargs): # log to wandb if wandb and is_master(): + logs = kwargs.get("logs") or {} + if logs: + wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step) for i, draft_acc in enumerate(average_acc): for j, step_acc in enumerate(draft_acc): wandb.log( diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index ae8a21eea4..7a566caaf0 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -62,14 +62,6 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi TRAIN_BS="${1#*=}" ;; - --medusa_num_heads*) - if [[ "$1" != *=* ]]; then shift; fi - MEDUSA_NUM_HEADS="${1#*=}" - ;; - --medusa_num_layers*) - if [[ "$1" != *=* ]]; then shift; fi - MEDUSA_NUM_LAYERS="${1#*=}" - ;; --eagle_config*) if [[ "$1" != *=* ]]; then shift; fi EAGLE_CONFIG="${1#*=}" @@ -110,6 +102,14 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi DRAFT_VOCAB_CACHE="${1#*=}" ;; + --num_nodes*) + if [[ "$1" != *=* ]]; then shift; fi + NUM_NODES="${1#*=}" + ;; + --head_node_ip*) + if [[ "$1" != *=* ]]; then shift; fi + HEAD_NODE_IP="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -120,10 +120,12 @@ done set -x -# Get the default value for save_steps based on the available number of GPUs -GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") +NUM_NODES=${NUM_NODES:-1} +GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} +TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) +echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" # Calculate save_steps 
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT)) +DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} MODE=${MODE:-"eagle3"} @@ -135,8 +137,6 @@ NUM_EPOCHS=${NUM_EPOCHS:-1} SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} LR=${LR:-"1e-4"} TRAIN_BS=${TRAIN_BS:-1} -MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} -MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} DISABLE_TQDM=${DISABLE_TQDM:-False} @@ -145,20 +145,19 @@ VLM_IMG_DIR=${VLM_IMG_DIR:-} AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} ESTIMATE_AR=${ESTIMATE_AR:-False} CP_SIZE=${CP_SIZE:-1} -DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))} +DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} LOG_STEPS=${LOG_STEPS:-100} DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} -if [[ "$MODE" == "medusa" ]]; then - SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" -elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then + +if [[ "$MODE" == "eagle3" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" else SPECULATIVE_ARGS="" fi else - echo "Only medusa, eagle1, eagle3 supported for now!" + echo "Only eagle3 supported for now!" 
exit 1 fi @@ -180,7 +179,7 @@ else VLM_ARGS="" fi -if [[ "$GPU_COUNT" -gt 1 ]]; then +if [[ "$TOTAL_GPU" -gt 1 ]]; then #Use FSDP2 when multi GPU available FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json" else @@ -195,10 +194,20 @@ else DRAFT_VOCAB_CACHE_ARGS="" fi +if [[ "$NUM_NODES" != 1 ]]; then + MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ + --num_machines $NUM_NODES \ + --machine_rank $SLURM_PROCID \ + --rdzv_backend c10d \ + --main_process_ip $HEAD_NODE_IP \ + --main_process_port 29500" +else + MULTI_NODE_ARGS="" +fi # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False -CMD="accelerate launch --mixed_precision bf16 main.py \ +CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 main.py \ --mode $MODE \ --eagle_decoder_type $EAGLE_DECODER_TYPE \ --model_name_or_path $MODEL \ diff --git a/examples/speculative_decoding/slurm.sh b/examples/speculative_decoding/slurm.sh new file mode 100644 index 0000000000..d8091a37fb --- /dev/null +++ b/examples/speculative_decoding/slurm.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +#SBATCH -A {account} +#SBATCH --job-name={job_name} +#SBATCH --nodes={num_nodes} --ntasks-per-node=1 --gpus-per-node={num_gpus_per_node} +#SBATCH -p {partition} +#SBATCH -t {time_limit} + +CONTAINER_IMAGE={container_image} +WORK_DIR={path_to_modelopt} + +CONTAINER_MOUNT="${WORK_DIR}:/modelopt" + +OUTPUT_DIR={path_to_output_dir} +MODEL={path_to_model_dir} +DATA={path_to_data_dir} +OFFLINE_DATA={path_to_offline_data_dir} + +CMD="./launch_train.sh --model $MODEL \ + --output_dir $OUTPUT_DIR \ + --data $DATA \ + --num_epochs 1 \ + --train_bs 1 \ + --lr 1e-4 \ + --eagle_config eagle_config.json \ + --training_seq_len 4096 \ + --save_steps 1000 \ + --estimate_ar True \ + --disable_tqdm True \ + --offline-data $OFFLINE_DATA \ + --num_nodes $SLURM_NNODES \ + --head_node_ip $(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1) \ +" + +srun -l \ + --mpi=pmix \ + --output=%x_%j.log \ + --container-workdir "/modelopt/examples/speculative_decoding" \ + --container-image ${CONTAINER_IMAGE} --container-mounts ${CONTAINER_MOUNT} \ + bash -lc "$CMD" + +set +x diff --git a/tests/examples/speculative_decoding/test_medusa.py b/tests/examples/speculative_decoding/test_medusa.py deleted file mode 100644 index 545b79d7ea..0000000000 --- a/tests/examples/speculative_decoding/test_medusa.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- -import pytest -from _test_utils.examples.run_command import run_example_command - - -# fmt: off -def _run_hf_ptq(model_path, output_dir, qformat): - run_example_command( - [ - "python", "hf_ptq.py", - "--pyt_ckpt_path", model_path, - "--batch_size", "1", - "--calib_size", "64", - "--export_path", output_dir, - "--qformat", qformat, - ], - "llm_ptq", - ) - - -def test_llama_medusa_fp8_qat(tiny_llama_path, tiny_daring_anteater_path, tmp_path): - medusa_path = tmp_path / "medusa-tinyllama" - - # Test Medusa - run_example_command( - [ - "./launch_train.sh", - "--model", tiny_llama_path, - "--data", tiny_daring_anteater_path, - "--num_epochs", "1", - "--lr", "1e-5", - "--mode", "medusa", - "--output_dir", medusa_path, - "--medusa_num_heads", "2", - "--medusa_num_layers", "1", - ], - "speculative_decoding", - ) - - pytest.skip("speculative decoding uses transformers 5.x, quantization example uses transformers 4.x") - - # Test PTQ on Medusa - _run_hf_ptq(medusa_path, tmp_path / "medusa-tinyllama-hf", "fp8") - - # Test QAT on Medusa - run_example_command( - [ - "./launch.sh", - "--model", medusa_path, - "--num_epochs", "1", - "--train_size", "128", - "--eval_size", "64", - "--lr", "1e-5", - "--output_dir", tmp_path / "medusa-tinyllama-qat-finetune", - "--quant_cfg", "FP8_DEFAULT_CFG", - "--calib_size", "64", - "--backend", "fsdp2", - ], - "llm_qat", - )