43 changes: 31 additions & 12 deletions examples/speculative_decoding/launch_train.sh
@@ -26,6 +26,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
MODEL="${1#*=}"
;;
--trust_remote_code*)
if [[ "$1" != *=* ]]; then shift; fi
TRUST_REMOTE_CODE="${1#*=}"
;;
--data*)
if [[ "$1" != *=* ]]; then shift; fi
DATA="${1#*=}"
@@ -114,6 +118,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
MIX_HIDDEN_STATES="${1#*=}"
;;
--fsdp*)
if [[ "$1" != *=* ]]; then shift; fi
FSDP="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
@@ -126,13 +134,21 @@ set -x

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
NUM_NODES=${NUM_NODES:-1}
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
if [[ "$NUM_NODES" != 1 ]]; then
#Multi Node Training
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
else
#Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES
TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())")
echo "Total GPUs: $TOTAL_GPU (Single Node Training)"
fi
# Calculate save_steps
DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU))

MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-True}
MODE=${MODE:-"eagle3"}
EAGLE_DECODER_TYPE=${EAGLE_DECODER_TYPE:-"llama"}
# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path
@@ -154,7 +170,7 @@ DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
LOG_STEPS=${LOG_STEPS:-100}
DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}

FSDP=${FSDP:-"False"}

if [[ "$MODE" == "eagle3" ]]; then
if [[ -n "$EAGLE_CONFIG" ]]; then
@@ -185,15 +201,17 @@ else
VLM_ARGS=""
fi

if [[ "$TOTAL_GPU" -gt 1 ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
#Otherwise, single GPU training
FSDP_ARGS=""
FSDP_ARGS=""
if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
# Use FSDP when multi GPU available, default to FSDP1
FSDP_ARGS="$FSDP_ARGS --fsdp 'full_shard'"
TRANSFORMERS_5=$(python -c "from packaging.version import Version; import transformers; print(Version(transformers.__version__) >= Version('5.0'))" 2>/dev/null)
if [[ "$TRANSFORMERS_5" == "True" ]]; then
# For transformers >= 5.0, use FSDP2
FSDP_ARGS="$FSDP_ARGS --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
fi
fi


if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then
DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE"
else
@@ -217,6 +235,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
--mode $MODE \
--eagle_decoder_type $EAGLE_DECODER_TYPE \
--model_name_or_path $MODEL \
--trust_remote_code $TRUST_REMOTE_CODE \
--training_seq_len $TRAINING_SEQ_LEN \
--dataloader_drop_last True \
--bf16 True \
@@ -251,4 +270,4 @@

start_time=$(date +%s)
sh -c "$CMD"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
12 changes: 8 additions & 4 deletions examples/speculative_decoding/main.py
@@ -60,6 +60,7 @@
@dataclass
class ModelArguments:
model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
trust_remote_code: bool = field(default=False)


@dataclass
@@ -170,9 +171,11 @@ def train():
if checkpoint:
with patch_transformers5_params_loading():
_, model = load_vlm_or_llm_with_kwargs(
checkpoint, torch_dtype="auto", trust_remote_code=True
checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code
)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained(
checkpoint, trust_remote_code=model_args.trust_remote_code
)
else:
# To avoid OOM for large models, we load and convert model on CPU first.
# Model will be moved to GPU during HF trainer.init().
@@ -181,7 +184,7 @@
model_args.model_name_or_path,
torch_dtype="auto",
device_map="cpu",
trust_remote_code=True,
trust_remote_code=model_args.trust_remote_code,
**offline_kwargs,
)
if use_offline_training:
@@ -191,7 +194,7 @@
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
model_max_length=training_args.training_seq_len,
trust_remote_code=True,
trust_remote_code=model_args.trust_remote_code,
)
if training_args.mode == "medusa":
config = {
@@ -209,6 +212,7 @@
"eagle_offline": use_offline_training,
"eagle_mix_hidden_states": eagle_args.mix_hidden_states,
"eagle_architecture_config": custom_config,
"eagle_train_length": training_args.training_seq_len,
}

mtsp.convert(model, [("eagle", config)])
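Note that the two entry points disagree on the default: launch_train.sh passes --trust_remote_code True unless told otherwise, while ModelArguments defaults to False when main.py is run directly. A minimal parsing sketch, assuming main.py uses the conventional transformers.HfArgumentParser wiring (the parser itself is not shown in this diff):

```python
# Sketch only; mirrors the ModelArguments dataclass above and assumes the
# dataclass is parsed with transformers.HfArgumentParser.
from dataclasses import dataclass, field

import transformers


@dataclass
class ModelArguments:
    model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    trust_remote_code: bool = field(default=False)


parser = transformers.HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=["--trust_remote_code", "True"])
assert model_args.trust_remote_code is True  # False when the flag is omitted
```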
18 changes: 18 additions & 0 deletions modelopt/torch/export/plugins/hf_spec_export.py
@@ -183,6 +183,24 @@ def _get_config_from_draft_or_base(key: str, model: nn.Module):
new_value = str(new_value).replace("torch.", "")
template_config[key] = new_value

# For long-context quality, we disable rope scaling during training
# and set yarn during export for inference.
eagle_train_length = getattr(self.model, "eagle_train_length", None)
if eagle_train_length is None:
raise ValueError("eagle_train_length is needed for rope scaling but not set.")
if self.model.eagle_config.rope_parameters["rope_type"] == "default":
template_config["rope_scaling"] = {
"rope_type": "yarn",
"factor": 32.0,
"original_max_position_embeddings": getattr(self.model, "eagle_train_length", 4096),
}

# In transformers 5.x, rope_theta lives under rope_parameters rather than in the main config
if not template_config.get("rope_theta"):
template_config["rope_theta"] = self.model.eagle_config.rope_parameters.get(
"rope_theta"
)

return template_config

def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
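Net effect: the draft module trains with plain rope ("rope_type": "default"), and only the exported config advertises yarn, with original_max_position_embeddings anchored to the training length. As a sketch, assuming eagle_train_length was left at its default of 2048, the exported draft config would contain an entry like:

```python
# Illustrative expected value only, assuming eagle_train_length == 2048 (the
# default) and rope_parameters["rope_type"] == "default" in the draft config.
exported_rope_scaling = {
    "rope_type": "yarn",
    "factor": 32.0,
    "original_max_position_embeddings": 2048,  # taken from model.eagle_train_length
}
```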
7 changes: 7 additions & 0 deletions modelopt/torch/speculative/config.py
@@ -110,3 +110,10 @@ class EagleConfig(ModeloptBaseConfig):
"Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost."
),
)

eagle_train_length: int = ModeloptField(
default=2048,
description=(
"The length of the training data. Used to set original_max_position_embeddings in rope_scaling."
),
)
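main.py forwards training_args.training_seq_len into this field (see its "eagle_train_length" config entry above); when converting a model by hand, it can be set the same way. A minimal sketch, assuming mtsp is modelopt.torch.speculative as aliased in main.py and that a base model is already loaded; a real conversion would also pass the other eagle keys main.py sets:

```python
import modelopt.torch.speculative as mtsp  # assumption: main.py's `mtsp` alias

# `model` is an already-loaded Hugging Face base model (loading not shown).
config = {"eagle_train_length": 8192}  # match the training sequence length; default is 2048
mtsp.convert(model, [("eagle", config)])
```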
11 changes: 3 additions & 8 deletions modelopt/torch/speculative/eagle/default_config.py
@@ -19,14 +19,8 @@
"hidden_act": "silu",
"torch_dtype": "bfloat16",
"position_embedding_type": "rope",
"rope_scaling": {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
"rope_theta": 500000.0,
"rope_scaling": {"rope_type": "default", "rope_theta": 10000},
"rope_theta": 10000,
"num_hidden_layers": 1,
"intermediate_size": 14336,
"num_attention_heads": 32,
@@ -90,6 +84,7 @@
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
"rope_theta": 50000.0,
},
"rope_theta": 50000.0,
"routed_scaling_factor": 2.827,
1 change: 1 addition & 0 deletions modelopt/torch/speculative/eagle/eagle_model.py
@@ -39,3 +39,4 @@ def modify(
self.eagle_decoder_type = config.eagle_decoder_type
self.eagle_ttt_steps = config.eagle_ttt_steps
self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
self.eagle_train_length = config.eagle_train_length