
Commit 9567c0a

pyc96tcligg
authored and committed
Add gemma4 data regen.
1 parent 155854b commit 9567c0a

8 files changed: 360 additions & 28 deletions


configs/gemma4-26b-a4b-eagle3.json

Lines changed: 32 additions & 0 deletions
{
  "architectures": [
    "LlamaForCausalLMEagle3"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 2816,
  "initializer_range": 0.02,
  "intermediate_size": 2112,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 1,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": 512,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.50.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 262144,
  "draft_vocab_size": 262144,
  "target_model_type": "gemma4_text"
}
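
This draft config pairs a single Llama-style decoder layer with Gemma's tokenizer; the vocab fields are what tie it to the target. A minimal sanity-check sketch (plain json plus assertions; the relative path follows the file listing above, nothing project-specific is assumed):

import json

# Load the Eagle3 draft config added in this commit.
with open("configs/gemma4-26b-a4b-eagle3.json") as f:
    cfg = json.load(f)

# Eagle3 drafts with a single decoder layer and must share the target's
# vocabulary so the draft head scores the same token IDs the target emits.
assert cfg["num_hidden_layers"] == 1
assert cfg["draft_vocab_size"] == cfg["vocab_size"] == 262144

# head_dim is set explicitly (128) rather than derived, so do not assume
# hidden_size == num_attention_heads * head_dim here (2816 != 32 * 128).
print(cfg["hidden_size"], cfg["num_attention_heads"], cfg["head_dim"])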

examples/regen_gemma4_26b_data.sh

Lines changed: 174 additions & 0 deletions
#!/usr/bin/env bash
# Regenerate training data for Gemma4-26B Eagle3.
#
# This script:
#   1. Launches SGLang server(s) for Gemma4-26B on available GPUs.
#   2. Waits for the server(s) to become healthy.
#   3. Runs regenerate_train_data.py with thinking-ratio support.
#   4. Shuts down the server(s) on exit.
#
# Usage:
#   bash examples/regen_gemma4_26b_data.sh
#
# Environment variables (override defaults):
#   MODEL          - HuggingFace model ID (default: google/gemma-4-26b-a4b-it)
#   TP_SIZE        - Tensor-parallel size (default: 1)
#   NUM_SERVERS    - Number of server instances (default: 8)
#   BASE_PORT      - First server port (default: 30000)
#   CONCURRENCY    - Requests per server (default: 128)
#   MAX_TOKENS     - Max generation tokens (default: 2048)
#   TEMPERATURE    - Sampling temperature (default: 1)
#   THINKING_RATIO - Fraction with thinking (default: 0.7)
#   INPUT_FILE     - Input JSONL path (default: cache/dataset/ultrachat_train.jsonl)
#   OUTPUT_FILE    - Output JSONL path (default: outputs/dataset/ultrachat_regen_gemma4.jsonl)
#   NUM_SAMPLES    - Max samples to process (default: all)
#   PYTHON_BIN     - Python used to launch SGLang servers (default: python3)

set -euo pipefail

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# ── Configurable defaults ────────────────────────────────────────────────────
MODEL="${MODEL:-google/gemma-4-26b-a4b-it}"
TP_SIZE="${TP_SIZE:-1}"
NUM_SERVERS="${NUM_SERVERS:-8}"
BASE_PORT="${BASE_PORT:-30000}"
CONCURRENCY="${CONCURRENCY:-128}"
MAX_TOKENS="${MAX_TOKENS:-2048}"
TEMPERATURE="${TEMPERATURE:-1}"
THINKING_RATIO="${THINKING_RATIO:-0.7}"
INPUT_FILE="${INPUT_FILE:-$ROOT_DIR/cache/dataset/ultrachat_train.jsonl}"
OUTPUT_FILE="${OUTPUT_FILE:-$ROOT_DIR/outputs/dataset/ultrachat_regen_gemma4.jsonl}"
NUM_SAMPLES="${NUM_SAMPLES:-}"
PYTHON_BIN="${PYTHON_BIN:-python3}"

# ── Derived ──────────────────────────────────────────────────────────────────
TOTAL_GPUS=$(( TP_SIZE * NUM_SERVERS ))
# `|| true` (not `|| echo 0`) so a missing nvidia-smi yields wc's "0" alone.
AVAIL_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l || true)

if [ "$AVAIL_GPUS" -lt "$TOTAL_GPUS" ]; then
  echo "Error: Need ${TOTAL_GPUS} GPUs (${NUM_SERVERS} servers x TP ${TP_SIZE}) but only ${AVAIL_GPUS} available."
  exit 1
fi

echo "============================================================"
echo " Gemma4-26B Data Regeneration"
echo "============================================================"
echo " Model:          ${MODEL}"
echo " TP size:        ${TP_SIZE}"
echo " Servers:        ${NUM_SERVERS}"
echo " Ports:          ${BASE_PORT}..$(( BASE_PORT + (NUM_SERVERS - 1) * 10 ))"
echo " Concurrency:    ${CONCURRENCY} per server"
echo " Max tokens:     ${MAX_TOKENS}"
echo " Temperature:    ${TEMPERATURE}"
echo " Thinking ratio: ${THINKING_RATIO}"
echo " Input:          ${INPUT_FILE}"
echo " Output:         ${OUTPUT_FILE}"
echo "============================================================"

# ── Cleanup on exit ──────────────────────────────────────────────────────────
SERVER_PIDS=()

cleanup() {
  echo ""
  echo "Shutting down SGLang server(s)..."
  for pid in "${SERVER_PIDS[@]}"; do
    if kill -0 "$pid" 2>/dev/null; then
      kill "$pid" 2>/dev/null || true
    fi
  done
  # Wait briefly, then force-kill stragglers.
  sleep 2
  for pid in "${SERVER_PIDS[@]}"; do
    if kill -0 "$pid" 2>/dev/null; then
      kill -9 "$pid" 2>/dev/null || true
    fi
  done
  echo "All servers stopped."
}
trap cleanup EXIT

# ── Launch servers ───────────────────────────────────────────────────────────
SERVER_ADDRESSES=()

for i in $(seq 0 $(( NUM_SERVERS - 1 ))); do
  PORT=$(( BASE_PORT + i * 10 ))
  GPU_START=$(( i * TP_SIZE ))
  GPU_END=$(( GPU_START + TP_SIZE - 1 ))
  CUDA_DEVICES=$(seq -s, "$GPU_START" "$GPU_END")

  echo "Starting server $((i+1))/${NUM_SERVERS} on GPUs ${CUDA_DEVICES}, port ${PORT}..."

  CUDA_VISIBLE_DEVICES="${CUDA_DEVICES}" "${PYTHON_BIN}" -m sglang.launch_server \
    --model "${MODEL}" \
    --tp "${TP_SIZE}" \
    --port "${PORT}" \
    --host 0.0.0.0 \
    --cuda-graph-max-bs 128 \
    --trust-remote-code --enable-torch-compile \
    > "${ROOT_DIR}/cache/sglang_server_${PORT}.log" 2>&1 &

  SERVER_PIDS+=($!)
  SERVER_ADDRESSES+=("localhost:${PORT}")
done

# ── Wait for servers to be healthy ───────────────────────────────────────────
echo ""
echo "Waiting for servers to become healthy..."

wait_for_server() {
  local addr=$1
  local max_wait=600  # 10 minutes
  local elapsed=0
  while [ $elapsed -lt $max_wait ]; do
    if curl -sf "http://${addr}/health" > /dev/null 2>&1; then
      return 0
    fi
    sleep 5
    elapsed=$(( elapsed + 5 ))
  done
  return 1
}

for addr in "${SERVER_ADDRESSES[@]}"; do
  if wait_for_server "$addr"; then
    echo " ${addr} is healthy."
  else
    echo "Error: ${addr} did not become healthy within 10 minutes."
    echo "Check logs at: ${ROOT_DIR}/cache/sglang_server_*.log"
    exit 1
  fi
done

echo "All ${NUM_SERVERS} server(s) are ready."
echo "------------------------------------------------------------"

# ── Build regen command ──────────────────────────────────────────────────────
REGEN_ARGS=(
  python3 "${ROOT_DIR}/scripts/regenerate_train_data.py"
  --model "${MODEL}"
  --is-reasoning-model
  --thinking-ratio "${THINKING_RATIO}"
  --concurrency "${CONCURRENCY}"
  --max-tokens "${MAX_TOKENS}"
  --temperature "${TEMPERATURE}"
  --server-address "${SERVER_ADDRESSES[@]}"
  --input-file-path "${INPUT_FILE}"
  --output-file-path "${OUTPUT_FILE}"
  --resume
)

if [ -n "${NUM_SAMPLES}" ]; then
  REGEN_ARGS+=(--num-samples "${NUM_SAMPLES}")
fi

# ── Run regeneration ─────────────────────────────────────────────────────────
echo "Starting data regeneration..."
echo ""

mkdir -p "$(dirname "${OUTPUT_FILE}")"
"${REGEN_ARGS[@]}"

echo ""
echo "============================================================"
echo " Done! Output saved to: ${OUTPUT_FILE}"
echo "============================================================"

examples/run_gemma3_27b_eagle3_online.sh

Lines changed: 4 additions & 4 deletions
@@ -15,16 +15,16 @@ torchrun \
     --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \
     --output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3-ultrachat \
     --num-epochs 10 \
-    --batch-size 2 \
+    --batch-size 8 \
     --tp-size $TP_SIZE \
     --learning-rate 1e-4 \
     --max-length 2048 \
     --chat-template gemma \
     --cache-dir $ROOT_DIR/cache \
     --attention-backend sdpa \
     --target-model-backend hf \
-    --log-interval 100 \
-    --eval-interval 500 \
-    --save-interval 10000 \
+    --log-interval 500 \
+    --eval-interval 2500 \
+    --save-interval 5000 \
     --report-to tensorboard \
     --embedding-key=language_model.model.embed_tokens.weight
Lines changed: 31 additions & 0 deletions
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels

# train eagle3 for gemma4-26b-a4b
NUM_GPUS=${1:-8}
TP_SIZE=${2:-2}

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3.py \
    --target-model-path google/gemma-4-26b-a4b-it \
    --draft-model-config $ROOT_DIR/configs/gemma4-26b-a4b-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \
    --output-dir $ROOT_DIR/outputs/gemma4-26b-a4b-eagle3-ultrachat \
    --num-epochs 10 \
    --batch-size 4 \
    --tp-size $TP_SIZE \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template gemma-4 \
    --cache-dir $ROOT_DIR/cache \
    --attention-backend sdpa \
    --target-model-backend hf \
    --log-interval 500 \
    --eval-interval 2500 \
    --save-interval 10000 \
    --report-to tensorboard \
    --embedding-key=model.language_model.embed_tokens.weight \
    --eval-holdout-ratio 0.05
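
One easy thing to get wrong here is --embedding-key, which differs from the gemma3 script above (model.language_model.embed_tokens.weight vs. language_model.model.embed_tokens.weight). A hedged way to double-check the key before a long training run, assuming the target checkpoint is published as sharded safetensors on the Hub:

import json

from huggingface_hub import hf_hub_download

# Fetch only the safetensors index (a small JSON file), not the weights.
index_path = hf_hub_download(
    "google/gemma-4-26b-a4b-it", "model.safetensors.index.json"
)
with open(index_path) as f:
    weight_map = json.load(f)["weight_map"]

# The value passed as --embedding-key must appear here verbatim.
print([k for k in weight_map if "embed_tokens" in k])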

scripts/regenerate_train_data.py

Lines changed: 21 additions & 1 deletion
@@ -1,6 +1,6 @@
 """
 This script will re-generate the dataset from target model,
-which better aligns the draft model with the target models output distribution.
+which better aligns the draft model with the target model's output distribution.
 
 Usage:
 1. Set up one or more SGLang servers for the target model.
@@ -60,6 +60,15 @@ def parse_arguments():
         action="store_true",
         help="Whether the model is a GPT-OSS model",
     )
+    model_group.add_argument(
+        "--thinking-ratio",
+        type=float,
+        default=None,
+        help="Fraction of requests sent with thinking enabled (0 to 1). "
+        "Requires --is-reasoning-model. When set, each request randomly "
+        "enables or disables thinking based on this ratio. "
+        "E.g., 0.7 means 70%% of samples use thinking, 30%% do not.",
+    )
 
     # sampling params
     sampling_params_group = parser.add_argument_group("sampling parameters")
@@ -184,6 +193,9 @@ def build_query_kwargs(args, messages, max_tokens=None):
     extra_body = {}
     if args.top_k is not None:
         extra_body["top_k"] = args.top_k
+    if args.thinking_ratio is not None:
+        enable_thinking = random.random() < args.thinking_ratio
+        extra_body["chat_template_kwargs"] = {"enable_thinking": enable_thinking}
     if extra_body:
         query_kwargs["extra_body"] = extra_body
     if args.is_gpt_oss:
@@ -255,11 +267,19 @@ def main():
     if args.max_tokens <= 0:
         raise ValueError("Max tokens must be greater than 0")
 
+    if args.thinking_ratio is not None:
+        if not (0.0 <= args.thinking_ratio <= 1.0):
+            raise ValueError("--thinking-ratio must be between 0.0 and 1.0")
+        if not args.is_reasoning_model:
+            raise ValueError("--thinking-ratio requires --is-reasoning-model")
+
     print(f"Configuration:")
     print(f"  Model path: {args.model}")
     print(f"  Max tokens: {args.max_tokens}")
     print(f"  Concurrency: {args.concurrency}")
     print(f"  Temperature: {args.temperature}")
+    if args.thinking_ratio is not None:
+        print(f"  Thinking ratio: {args.thinking_ratio:.0%}")
     print(f"  API URL: {args.server_address}")
     print(f"  Input file: {args.input_file_path}")
     print(f"  Output file: {args.output_file_path}")
