#!/bin/bash

# This script launches a DiLoCo pre-training workload on a GKE cluster using XPK.
#
# Every setting defined below can be overridden by exporting the variable of
# the same name before invoking the script.

# Strict mode: abort on a failing command (-e), on use of an unset variable
# (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# --- Environment Setup ---
# The script invokes the `xpk` executable directly, so probe PATH for it.
# `pip show` only inspects the environment that `pip` itself belongs to, which
# is not necessarily the one the `uv pip install` command below writes to.
if ! command -v xpk > /dev/null 2>&1; then
  echo "xpk not found in the environment. Please install it by running:" >&2
  echo "uv pip install -e .[runner] --resolution=lowest" >&2
  exit 1
fi

# --- Environment Variables ---
# Each variable keeps the value already present in the caller's environment
# and falls back to the default shown here when it is unset or empty.
export PROJECT_ID="${PROJECT_ID:-cloud-tpu-multipod-dev}"
export CLUSTER_NAME="${CLUSTER_NAME:-v5p-128-bodaborg-europe-west4-b}"
# NOTE(review): the default cluster name ends in "europe-west4-b" while this
# default carries no zone suffix -- confirm which form `xpk --zone` expects.
export ZONE="${ZONE:-europe-west4}"
export RESERVATION="${RESERVATION:-cloudtpu-20240716121201-595617744}"
# Change to your own GCS bucket for logging and checkpointing.
export BASE_OUTPUT_DIRECTORY="${BASE_OUTPUT_DIRECTORY:-gs://chriszuo-maxtext-logs}"
# Change to your own GCS bucket for datasets. Make sure the dataset exists.
export DATASET_PATH="${DATASET_PATH:-gs://chriszuo-maxtext-datasets}"
# Should be updated when later image versions come up.
export DOCKER_IMAGE="${DOCKER_IMAGE:-gcr.io/tpu-prod-env-multipod/maxtext_jax_stable:2026-06-04}"
# At least v5p-32 is needed to run Qwen3-30b-a3b. For v5p-8 you may need to
# decrease PER_DEVICE_BATCH_SIZE.
export TPU_TYPE="${TPU_TYPE:-v5p-128}"
# At least two slices are needed for DiLoCo to take effect.
export NUM_SLICES="${NUM_SLICES:-2}"
# Name of the run, used for logging purposes. The default combines the current
# user name with a timestamp so repeated launches get distinct names.
export WORKLOAD_NAME="${WORKLOAD_NAME:-$(whoami)-diloco-v5p-$(date +%Y%m%d-%H%M%S)}"

# --- Hyperparameters ---
# Training settings that are interpolated into the MaxText command line.
# Like the variables above, each one can be overridden from the environment.
export MODEL_NAME="${MODEL_NAME:-qwen3-30b-a3b}"
export PER_DEVICE_BATCH_SIZE="${PER_DEVICE_BATCH_SIZE:-8.0}"
export MAX_TARGET_LENGTH="${MAX_TARGET_LENGTH:-2048}"
export DILOCO_SYNC_PERIOD="${DILOCO_SYNC_PERIOD:-10}"
export DILOCO_OUTER_LR="${DILOCO_OUTER_LR:-0.3}"
export DILOCO_OUTER_MOMENTUM="${DILOCO_OUTER_MOMENTUM:-0.9}"
export TRAINING_STEPS="${TRAINING_STEPS:-100}"

# --- Variable Validation ---
# Fail fast on missing settings. Diagnostics are written to stderr so they do
# not get mixed into any captured stdout.
if [[ -z "${PROJECT_ID:-}" || -z "${CLUSTER_NAME:-}" || -z "${ZONE:-}" ]]; then
  echo "Error: PROJECT_ID, CLUSTER_NAME, or ZONE is not set." >&2
  exit 1
fi

if [[ -z "${BASE_OUTPUT_DIRECTORY:-}" || -z "${DATASET_PATH:-}" ]]; then
  echo "Error: BASE_OUTPUT_DIRECTORY or DATASET_PATH is not set." >&2
  exit 1
fi

# Reject non-integer values explicitly: a numeric comparison on such a value
# errors out rather than evaluating, which would let the script carry on
# without ever reaching the warning below.
if [[ ! "${NUM_SLICES:-}" =~ ^[0-9]+$ ]]; then
  echo "Error: NUM_SLICES must be a non-negative integer, got '${NUM_SLICES:-}'." >&2
  exit 1
fi

# The 10# prefix forces base 10 so a value such as "08" is not read as octal.
if (( 10#${NUM_SLICES} < 2 )); then
  echo "Warning: NUM_SLICES is less than 2. DiLoCo will not take effect." >&2
fi

# --- MaxText command ---
# Command string handed to `xpk workload create --command`. The arguments are
# listed one per array element for readability and then joined with single
# spaces, so the result is one flat line with no embedded newlines.
maxtext_args=(
  maxtext/configs/base.yml
  "run_name=$WORKLOAD_NAME"
  save_config_to_gcs=true
  "base_output_directory=$BASE_OUTPUT_DIRECTORY"
  "dataset_path=$DATASET_PATH"
  # The single quotes are part of the command string itself.
  "dataset_name='c4/en:3.0.1'"
  "eval_dataset_name='c4/en:3.0.1'"
  "model_name=$MODEL_NAME"
  tokenizer_type=huggingface
  tokenizer_path=maxtext/assets/tokenizers/qwen3-tokenizer
  "per_device_batch_size=$PER_DEVICE_BATCH_SIZE"
  "max_target_length=$MAX_TARGET_LENGTH"
  enable_diloco=true
  "dcn_diloco_parallelism=$NUM_SLICES"
  "diloco_sync_period=$DILOCO_SYNC_PERIOD"
  "diloco_outer_lr=$DILOCO_OUTER_LR"
  "diloco_outer_momentum=$DILOCO_OUTER_MOMENTUM"
  "steps=$TRAINING_STEPS"
)
# "${maxtext_args[*]}" joins the elements with the first character of IFS,
# which is a space because this script never modifies IFS.
MAXTEXT_COMMAND="cd /deps/src/ && python3 maxtext/trainers/pre_train/train.py ${maxtext_args[*]}"

# --- Workload Creation ---
# Submit the job through XPK; the training command assembled above is passed
# as a single --command argument.
echo "Submitting DiLoCo job to XPK..."
xpk workload create \
  --cluster="${CLUSTER_NAME}" \
  --project="${PROJECT_ID}" \
  --reservation="${RESERVATION}" \
  --zone="${ZONE}" \
  --tpu-type="${TPU_TYPE}" \
  --num-slices="${NUM_SLICES}" \
  --docker-image="${DOCKER_IMAGE}" \
  --workload="${WORKLOAD_NAME}" \
  --command="${MAXTEXT_COMMAND}"