|
| 1 | +#!/usr/bin/env bash |
| 2 | +# Regenerate training data for Gemma4-26B Eagle3. |
| 3 | +# |
| 4 | +# This script: |
| 5 | +# 1. Launches SGLang server(s) for Gemma4-26B on available GPUs. |
| 6 | +# 2. Waits for the server(s) to become healthy. |
| 7 | +# 3. Runs regenerate_train_data.py with thinking-ratio support. |
| 8 | +# 4. Shuts down the server(s) on exit. |
| 9 | +# |
| 10 | +# Usage: |
| 11 | +# bash examples/regen_gemma4_26b_data.sh |
| 12 | +# |
| 13 | +# Environment variables (override defaults): |
| 14 | +# MODEL - HuggingFace model ID (default: google/gemma-4-26b-a4b-it) |
#   TP_SIZE        - Tensor-parallel size (default: 1)
#   NUM_SERVERS    - Number of server instances (default: 8)
| 17 | +# BASE_PORT - First server port (default: 30000) |
| 18 | +# CONCURRENCY - Requests per server (default: 128) |
#   MAX_TOKENS     - Max generation tokens (default: 2048)
#   TEMPERATURE    - Sampling temperature (default: 1)
| 21 | +# THINKING_RATIO - Fraction with thinking (default: 0.7) |
#   INPUT_FILE     - Input JSONL path (default: cache/dataset/ultrachat_train.jsonl)
#   OUTPUT_FILE    - Output JSONL path (default: outputs/dataset/ultrachat_regen_gemma4.jsonl)
| 24 | +# NUM_SAMPLES - Max samples to process (default: all) |
| 25 | + |
set -euo pipefail

# Resolve the directory holding this script, then the repository root above it.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# ── Configurable defaults ────────────────────────────────────────────────────
MODEL="${MODEL:-google/gemma-4-26b-a4b-it}"
TP_SIZE="${TP_SIZE:-1}"
NUM_SERVERS="${NUM_SERVERS:-8}"
BASE_PORT="${BASE_PORT:-30000}"
CONCURRENCY="${CONCURRENCY:-128}"
MAX_TOKENS="${MAX_TOKENS:-2048}"
TEMPERATURE="${TEMPERATURE:-1}"
THINKING_RATIO="${THINKING_RATIO:-0.7}"
INPUT_FILE="${INPUT_FILE:-$ROOT_DIR/cache/dataset/ultrachat_train.jsonl}"
OUTPUT_FILE="${OUTPUT_FILE:-$ROOT_DIR/outputs/dataset/ultrachat_regen_gemma4.jsonl}"
NUM_SAMPLES="${NUM_SAMPLES:-}"

# ── Derived ──────────────────────────────────────────────────────────────────
TOTAL_GPUS=$(( TP_SIZE * NUM_SERVERS ))

# Count visible GPUs. NB: the previous form
#   AVAIL_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l || echo 0)
# was wrong under `pipefail`: when nvidia-smi is missing, `wc -l` still prints
# "0" AND the `|| echo 0` fallback fires, producing "0\n0" and breaking the
# numeric comparison below. Guard with `command -v` instead.
AVAIL_GPUS=0
if command -v nvidia-smi >/dev/null 2>&1; then
  AVAIL_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) || AVAIL_GPUS=0
fi

if [ "$AVAIL_GPUS" -lt "$TOTAL_GPUS" ]; then
  # Diagnostics go to stderr so they are not swallowed by stdout redirection.
  echo "Error: Need ${TOTAL_GPUS} GPUs (${NUM_SERVERS} servers x TP ${TP_SIZE}) but only ${AVAIL_GPUS} available." >&2
  exit 1
fi
| 52 | + |
# ── Run banner ───────────────────────────────────────────────────────────────
# Print the effective configuration so run logs record exactly what was used.
cat <<BANNER
============================================================
 Gemma4-26B Data Regeneration
============================================================
 Model: ${MODEL}
 TP size: ${TP_SIZE}
 Servers: ${NUM_SERVERS}
 Ports: ${BASE_PORT}..$(( BASE_PORT + (NUM_SERVERS - 1) * 10 ))
 Concurrency: ${CONCURRENCY} per server
 Max tokens: ${MAX_TOKENS}
 Temperature: ${TEMPERATURE}
 Thinking ratio: ${THINKING_RATIO}
 Input: ${INPUT_FILE}
 Output: ${OUTPUT_FILE}
============================================================
BANNER
| 67 | + |
# ── Cleanup on exit ──────────────────────────────────────────────────────────
SERVER_PIDS=()

# Stop every tracked server process: polite SIGTERM first, short grace
# period, then SIGKILL for anything still alive.
cleanup() {
  local pid
  echo ""
  echo "Shutting down SGLang server(s)..."
  for pid in "${SERVER_PIDS[@]}"; do
    kill -0 "$pid" 2>/dev/null && kill "$pid" 2>/dev/null || true
  done
  sleep 2  # grace period before escalating to SIGKILL
  for pid in "${SERVER_PIDS[@]}"; do
    kill -0 "$pid" 2>/dev/null && kill -9 "$pid" 2>/dev/null || true
  done
  echo "All servers stopped."
}
trap cleanup EXIT
| 89 | + |
# ── Launch servers ───────────────────────────────────────────────────────────
# Interpreter used to start SGLang. Override PYTHON_BIN to point at a specific
# virtualenv; the previous hardcoded /home/pyc_google_com/... path only worked
# on one machine.
PYTHON_BIN="${PYTHON_BIN:-python3}"

SERVER_ADDRESSES=()

# Server logs are redirected into cache/; create it up front so the
# redirection does not fail on a fresh checkout.
mkdir -p "${ROOT_DIR}/cache"

for i in $(seq 0 $(( NUM_SERVERS - 1 ))); do
  PORT=$(( BASE_PORT + i * 10 ))      # ports spaced by 10 per instance
  GPU_START=$(( i * TP_SIZE ))        # each server owns a contiguous GPU slice
  GPU_END=$(( GPU_START + TP_SIZE - 1 ))
  CUDA_DEVICES=$(seq -s, "$GPU_START" "$GPU_END")

  echo "Starting server $((i+1))/${NUM_SERVERS} on GPUs ${CUDA_DEVICES}, port ${PORT}..."

  CUDA_VISIBLE_DEVICES="${CUDA_DEVICES}" "${PYTHON_BIN}" -m sglang.launch_server \
    --model "${MODEL}" \
    --tp "${TP_SIZE}" \
    --port "${PORT}" \
    --host 0.0.0.0 \
    --cuda-graph-max-bs 128 \
    --trust-remote-code --enable-torch-compile \
    > "${ROOT_DIR}/cache/sglang_server_${PORT}.log" 2>&1 &

  SERVER_PIDS+=($!)
  SERVER_ADDRESSES+=("localhost:${PORT}")
done
| 113 | + |
# ── Wait for servers to be healthy ───────────────────────────────────────────
echo ""
echo "Waiting for servers to become healthy..."

# Poll the server's /health endpoint once every 5 seconds. Returns 0 as soon
# as it answers, 1 after the 10-minute budget (120 attempts) is exhausted.
wait_for_server() {
  local addr=$1
  local attempt
  for (( attempt = 0; attempt < 120; attempt++ )); do
    if curl -sf "http://${addr}/health" > /dev/null 2>&1; then
      return 0
    fi
    sleep 5
  done
  return 1
}
| 131 | + |
# Block until every instance passes its health check; bail out (the EXIT trap
# reaps the servers) if any instance fails to come up within the budget.
for addr in "${SERVER_ADDRESSES[@]}"; do
  if ! wait_for_server "$addr"; then
    echo "Error: ${addr} did not become healthy within 10 minutes."
    echo "Check logs at: ${ROOT_DIR}/cache/sglang_server_*.log"
    exit 1
  fi
  echo " ${addr} is healthy."
done

echo "All ${NUM_SERVERS} server(s) are ready."
echo "------------------------------------------------------------"
| 144 | + |
# ── Build regen command ──────────────────────────────────────────────────────
# Assemble argv as an array so paths and values containing spaces stay intact.
REGEN_ARGS=(python3 "${ROOT_DIR}/scripts/regenerate_train_data.py")
REGEN_ARGS+=(--model "${MODEL}" --is-reasoning-model)
REGEN_ARGS+=(--thinking-ratio "${THINKING_RATIO}")
REGEN_ARGS+=(--concurrency "${CONCURRENCY}")
REGEN_ARGS+=(--max-tokens "${MAX_TOKENS}")
REGEN_ARGS+=(--temperature "${TEMPERATURE}")
REGEN_ARGS+=(--server-address "${SERVER_ADDRESSES[@]}")
REGEN_ARGS+=(--input-file-path "${INPUT_FILE}")
REGEN_ARGS+=(--output-file-path "${OUTPUT_FILE}")
REGEN_ARGS+=(--resume)

# Optional cap on how many samples to process.
if [[ -n "${NUM_SAMPLES}" ]]; then
  REGEN_ARGS+=(--num-samples "${NUM_SAMPLES}")
fi
| 163 | + |
# ── Run regeneration ─────────────────────────────────────────────────────────
printf '%s\n' "Starting data regeneration..." ""

# Create the output directory lazily so a fresh checkout works out of the box.
mkdir -p "$(dirname "${OUTPUT_FILE}")"
"${REGEN_ARGS[@]}"

printf '%s\n' "" \
  "============================================================" \
  " Done! Output saved to: ${OUTPUT_FILE}" \
  "============================================================"
0 commit comments