4 changes: 2 additions & 2 deletions configs/gemma3-1b-eagle3.json
@@ -26,7 +26,7 @@
"transformers_version": "4.50.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 262145,
"draft_vocab_size": 32000,
"vocab_size": 262144,
"draft_vocab_size": null,
"target_model_type": "gemma3_text"
}
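The fix above drops the off-by-one vocab_size (Gemma3's text embedding table has 262144 = 2^18 entries) and lets draft_vocab_size fall back to the full target vocabulary. A minimal sanity check, assuming null means "no vocab reduction" (that interpretation is an assumption, not taken from the training code):

import json
from transformers import AutoTokenizer

# Hedged check: compare the draft config against the target tokenizer.
# Reading null draft_vocab_size as "use the full vocab" is an assumption.
with open("configs/gemma3-1b-eagle3.json") as f:
    cfg = json.load(f)
tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

draft_vocab = cfg["draft_vocab_size"] or cfg["vocab_size"]
print("config vocab_size:", cfg["vocab_size"])   # 262144 after this fix
print("tokenizer entries:", len(tok))            # may differ by padding/specials
print("effective draft vocab:", draft_vocab)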
32 changes: 32 additions & 0 deletions configs/gemma3-27b-eagle3.json
@@ -0,0 +1,32 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 2,
"eos_token_id": 1,
"pad_token_id": 0,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5376,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 4096,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 16,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
"sliding_window": 512,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 262208,
"draft_vocab_size": 12288,
"target_model_type": "gemma3_text"
}
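Unlike the 1B config, this one sets draft_vocab_size (12288) far below the target vocab_size (262208), so the draft head scores only a reduced token subset and its predictions must be mapped back to target-vocab ids. A minimal sketch of that lookup, assuming a d2t index buffer in the style of common EAGLE3 implementations (buffer name and construction are illustrative, not taken from this repo):

import torch

vocab_size, draft_vocab_size = 262208, 12288

# Assumed mapping: d2t[i] = target-vocab id for draft-vocab id i,
# typically built offline from token frequencies (placeholder values here).
d2t = torch.randint(0, vocab_size, (draft_vocab_size,))

draft_logits = torch.randn(4, draft_vocab_size)   # draft head output
draft_ids = draft_logits.argmax(dim=-1)           # ids in the reduced vocab
target_ids = d2t[draft_ids]                       # ids in the full target vocab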
32 changes: 32 additions & 0 deletions configs/gemma4-26b-a4b-eagle3.json
@@ -0,0 +1,32 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 2,
"eos_token_id": 1,
"pad_token_id": 0,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2816,
"initializer_range": 0.02,
"intermediate_size": 2112,
"max_position_embeddings": 4096,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
"sliding_window": 512,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 262144,
"draft_vocab_size": 262144,
"target_model_type": "gemma4_text"
}
174 changes: 174 additions & 0 deletions examples/regen_gemma4_26b_data.sh
@@ -0,0 +1,174 @@
#!/usr/bin/env bash
# Regenerate training data for Gemma4-26B Eagle3.
#
# This script:
# 1. Launches SGLang server(s) for Gemma4-26B on available GPUs.
# 2. Waits for the server(s) to become healthy.
# 3. Runs regenerate_train_data.py with thinking-ratio support.
# 4. Shuts down the server(s) on exit.
#
# Usage:
# bash examples/regen_gemma4_26b_data.sh
#
# Environment variables (override defaults):
#   MODEL          - HuggingFace model ID (default: google/gemma-4-26b-a4b-it)
#   TP_SIZE        - Tensor-parallel size (default: 1)
#   NUM_SERVERS    - Number of server instances (default: 8)
#   BASE_PORT      - First server port (default: 30000)
#   CONCURRENCY    - Requests per server (default: 128)
#   MAX_TOKENS     - Max generation tokens (default: 2048)
#   TEMPERATURE    - Sampling temperature (default: 1)
#   THINKING_RATIO - Fraction of samples regenerated with thinking (default: 0.7)
#   INPUT_FILE     - Input JSONL path (default: cache/dataset/ultrachat_train.jsonl)
#   OUTPUT_FILE    - Output JSONL path (default: outputs/dataset/ultrachat_regen_gemma4.jsonl)
#   NUM_SAMPLES    - Max samples to process (default: all)

set -euo pipefail

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# ── Configurable defaults ────────────────────────────────────────────────────
MODEL="${MODEL:-google/gemma-4-26b-a4b-it}"
TP_SIZE="${TP_SIZE:-1}"
NUM_SERVERS="${NUM_SERVERS:-8}"
BASE_PORT="${BASE_PORT:-30000}"
CONCURRENCY="${CONCURRENCY:-128}"
MAX_TOKENS="${MAX_TOKENS:-2048}"
TEMPERATURE="${TEMPERATURE:-1}"
THINKING_RATIO="${THINKING_RATIO:-0.7}"
INPUT_FILE="${INPUT_FILE:-$ROOT_DIR/cache/dataset/ultrachat_train.jsonl}"
OUTPUT_FILE="${OUTPUT_FILE:-$ROOT_DIR/outputs/dataset/ultrachat_regen_gemma4.jsonl}"
NUM_SAMPLES="${NUM_SAMPLES:-}"

# ── Derived ──────────────────────────────────────────────────────────────────
TOTAL_GPUS=$(( TP_SIZE * NUM_SERVERS ))
AVAIL_GPUS=$({ nvidia-smi -L 2>/dev/null || true; } | wc -l)

if [ "$AVAIL_GPUS" -lt "$TOTAL_GPUS" ]; then
echo "Error: Need ${TOTAL_GPUS} GPUs (${NUM_SERVERS} servers x TP ${TP_SIZE}) but only ${AVAIL_GPUS} available."
exit 1
fi

echo "============================================================"
echo " Gemma4-26B Data Regeneration"
echo "============================================================"
echo " Model: ${MODEL}"
echo " TP size: ${TP_SIZE}"
echo " Servers: ${NUM_SERVERS}"
echo " Ports: ${BASE_PORT}..$(( BASE_PORT + (NUM_SERVERS - 1) * 10 ))"
echo " Concurrency: ${CONCURRENCY} per server"
echo " Max tokens: ${MAX_TOKENS}"
echo " Temperature: ${TEMPERATURE}"
echo " Thinking ratio: ${THINKING_RATIO}"
echo " Input: ${INPUT_FILE}"
echo " Output: ${OUTPUT_FILE}"
echo "============================================================"

# ── Cleanup on exit ──────────────────────────────────────────────────────────
SERVER_PIDS=()

cleanup() {
echo ""
echo "Shutting down SGLang server(s)..."
for pid in "${SERVER_PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
kill "$pid" 2>/dev/null || true
fi
done
# Wait briefly then force-kill stragglers
sleep 2
for pid in "${SERVER_PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
kill -9 "$pid" 2>/dev/null || true
fi
done
echo "All servers stopped."
}
trap cleanup EXIT

# ── Launch servers ───────────────────────────────────────────────────────────
SERVER_ADDRESSES=()

for i in $(seq 0 $(( NUM_SERVERS - 1 ))); do
PORT=$(( BASE_PORT + i * 10 ))
GPU_START=$(( i * TP_SIZE ))
GPU_END=$(( GPU_START + TP_SIZE - 1 ))
CUDA_DEVICES=$(seq -s, "$GPU_START" "$GPU_END")

echo "Starting server $((i+1))/${NUM_SERVERS} on GPUs ${CUDA_DEVICES}, port ${PORT}..."

CUDA_VISIBLE_DEVICES="${CUDA_DEVICES}" python3 -m sglang.launch_server \
--model "${MODEL}" \
--tp "${TP_SIZE}" \
--port "${PORT}" \
--host 0.0.0.0 \
--cuda-graph-max-bs 128 \
--trust-remote-code --enable-torch-compile \
> "${ROOT_DIR}/cache/sglang_server_${PORT}.log" 2>&1 &

SERVER_PIDS+=($!)
SERVER_ADDRESSES+=("localhost:${PORT}")
done

# ── Wait for servers to be healthy ───────────────────────────────────────────
echo ""
echo "Waiting for servers to become healthy..."

wait_for_server() {
local addr=$1
local max_wait=600 # 10 minutes
local elapsed=0
while [ $elapsed -lt $max_wait ]; do
if curl -sf "http://${addr}/health" > /dev/null 2>&1; then
return 0
fi
sleep 5
elapsed=$(( elapsed + 5 ))
done
return 1
}

for addr in "${SERVER_ADDRESSES[@]}"; do
if wait_for_server "$addr"; then
echo " ${addr} is healthy."
else
echo "Error: ${addr} did not become healthy within 10 minutes."
echo "Check logs at: ${ROOT_DIR}/cache/sglang_server_*.log"
exit 1
fi
done

echo "All ${NUM_SERVERS} server(s) are ready."
echo "------------------------------------------------------------"

# ── Build regen command ──────────────────────────────────────────────────────
REGEN_ARGS=(
python3 "${ROOT_DIR}/scripts/regenerate_train_data.py"
--model "${MODEL}"
--is-reasoning-model
--thinking-ratio "${THINKING_RATIO}"
--concurrency "${CONCURRENCY}"
--max-tokens "${MAX_TOKENS}"
--temperature "${TEMPERATURE}"
--server-address "${SERVER_ADDRESSES[@]}"
--input-file-path "${INPUT_FILE}"
--output-file-path "${OUTPUT_FILE}"
--resume
)

if [ -n "${NUM_SAMPLES}" ]; then
REGEN_ARGS+=(--num-samples "${NUM_SAMPLES}")
fi

# ── Run regeneration ─────────────────────────────────────────────────────────
echo "Starting data regeneration..."
echo ""

mkdir -p "$(dirname "${OUTPUT_FILE}")"
"${REGEN_ARGS[@]}"

echo ""
echo "============================================================"
echo " Done! Output saved to: ${OUTPUT_FILE}"
echo "============================================================"
7 changes: 5 additions & 2 deletions examples/run_gemma3_1b_eagle3_online.sh
100644 → 100755
@@ -12,7 +12,7 @@ torchrun \
$ROOT_DIR/scripts/train_eagle3.py \
--target-model-path google/gemma-3-1b-it \
--draft-model-config $ROOT_DIR/configs/gemma3-1b-eagle3.json \
- --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
+ --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \
--output-dir $ROOT_DIR/outputs/gemma3-1b-eagle3-sharegpt \
--num-epochs 10 \
--batch-size 1 \
@@ -23,4 +23,7 @@
--cache-dir $ROOT_DIR/cache \
--attention-backend sdpa \
--target-model-backend hf \
- --log-interval 10
+ --log-interval 500 \
+ --eval-interval 2500 \
+ --save-interval 60000 \
+ --report-to tensorboard
31 changes: 31 additions & 0 deletions examples/run_gemma3_27b_eagle3_online.sh
@@ -0,0 +1,31 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels

# train eagle3 for gemma3-27b
NUM_GPUS=${1:-8}
TP_SIZE=${2:-8}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3.py \
--target-model-path google/gemma-3-27b-it \
--draft-model-config $ROOT_DIR/configs/gemma3-27b-eagle3.json \
--train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \
--output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3-ultrachat \
--eval-holdout-ratio 0.03 \
--num-epochs 10 \
--batch-size 8 \
--tp-size $TP_SIZE \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template gemma \
--cache-dir $ROOT_DIR/cache \
--attention-backend sdpa \
--target-model-backend hf \
--log-interval 500 \
--eval-interval 2500 \
--save-interval 5000 \
--report-to tensorboard \
--embedding-key=language_model.model.embed_tokens.weight
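--embedding-key tells the trainer which tensor in the target checkpoint seeds the draft embeddings; the multimodal gemma3-27b checkpoint nests the text stack under language_model.model, while gemma4 (below) uses model.language_model. A sketch of how such a key could be resolved, assuming a safetensors shard (the filename is a placeholder, not from this repo):

from safetensors import safe_open

# Hypothetical lookup: fetch the embedding matrix by its state-dict key.
key = "language_model.model.embed_tokens.weight"
with safe_open("model-00001-of-00012.safetensors", framework="pt") as f:
    embed = f.get_tensor(key)   # shape: (vocab_size, hidden_size)
print(embed.shape)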
31 changes: 31 additions & 0 deletions examples/run_gemma4_26b_eagle3_online.sh
@@ -0,0 +1,31 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels

# train eagle3 for gemma4-26b-a4b
NUM_GPUS=${1:-8}
TP_SIZE=${2:-2}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3.py \
--target-model-path google/gemma-4-26b-a4b-it \
--draft-model-config $ROOT_DIR/configs/gemma4-26b-a4b-eagle3.json \
--train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \
--output-dir $ROOT_DIR/outputs/gemma4-26b-a4b-eagle3-ultrachat \
--num-epochs 10 \
--batch-size 4 \
--tp-size $TP_SIZE \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template gemma-4 \
--cache-dir $ROOT_DIR/cache \
--attention-backend sdpa \
--target-model-backend hf \
--log-interval 500 \
--eval-interval 2500 \
--save-interval 10000 \
--report-to tensorboard \
--embedding-key=model.language_model.embed_tokens.weight \
--eval-holdout-ratio 0.05
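The script takes GPU count and TP size as positional arguments, e.g.:

# 8 GPUs total, tensor-parallel size 2 (the defaults in this script)
bash examples/run_gemma4_26b_eagle3_online.sh 8 2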
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
"torch==2.9.1",
"torchaudio==2.9.1",
"torchvision==0.24.1",
"transformers==4.57.1",
"transformers>=5.0.0",
"qwen-vl-utils==0.0.11",
"datasets",
"setuptools",
4 changes: 2 additions & 2 deletions requirements-rocm.txt
@@ -5,7 +5,7 @@ pre-commit
torch==2.8.0+rocm6.3
torchaudio==2.8.0+rocm6.3
torchvision==0.23.0+rocm6.3
- transformers==4.57.1
+ transformers>=5.0.0
qwen-vl-utils==0.0.11
datasets
setuptools
@@ -15,6 +15,6 @@ psutil
numpy
accelerate
pydantic
- sglang[all]==0.5.4
+ sglang[all]==0.5.9
openai-harmony
tensorboard