#!/bin/bash
# Launch gpt-oss-20b distillation on TPU v7x.
# Usage: bash scripts/distillation/distill_gpt_oss_20b.sh [submit|monitor|resume_until_done]
#
# The mode argument defaults to "submit" and is passed through unchanged to
# run_distill_xpk.sh. In "submit" mode this wrapper additionally uploads the
# distillation YAML to GCS before handing off.
#
# Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own
# image before running — the defaults below will not work as-is. Everything
# else has a working default, and every setting written as "${VAR:-default}"
# can be overridden from the environment. Example:
#
#   DISTILL_GCS_BUCKET=gs://your-bucket \
#   XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \
#   bash scripts/distillation/distill_gpt_oss_20b.sh submit
set -euo pipefail

# Print an error message to stderr and abort.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# Default value of DISTILL_GCS_BUCKET; used to detect an unconfigured run.
readonly PLACEHOLDER_BUCKET="gs://YOUR-BUCKET"

MODE="${1:-submit}"

# Run from the repository root (two levels above this script) so the relative
# paths used below resolve regardless of the caller's working directory.
REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd)
cd "$REPO_ROOT"

# Your GCS bucket: run outputs, staged YAML, and tokenizer files all land here.
export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-${PLACEHOLDER_BUCKET}}"

# Workload identity and target cluster. The default workload name embeds the
# current date and time (minute resolution).
export XPK_WORKLOAD="${XPK_WORKLOAD:-goss-base-$(date +%Y%m%d-%H%M)}"
export XPK_RUN_NAME="${XPK_RUN_NAME:-gpt_oss_20b_base}"
export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}"
export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}"
export XPK_ZONE="${XPK_ZONE:-us-central1}"
export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}"
export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}"
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"
export XPK_PRIORITY="${XPK_PRIORITY:-high}"

# Dataset location. XPK_USE_GCSFUSE is always set to 1 here (not overridable).
export XPK_USE_GCSFUSE=1
export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}"
export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}"

# Stage HF tokenizer files (not in the image for gpt-oss).
# This wrapper only exports the source/destination paths; the staging itself
# is not performed in this file.
export XPK_TOKENIZER_GCS="${XPK_TOKENIZER_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/tokenizer-gpt-oss-20b/}"
export XPK_TOKENIZER_LOCAL="${XPK_TOKENIZER_LOCAL:-/deps/src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer}"

# Distillation config: local source path and the GCS path it is uploaded to.
LOCAL_YAML="src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml"
export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}"
export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_gpt_oss_20b.yml}"

# distill_beta=0: decoder feature loss is broken on gpt-oss.
export DISTILL_ALPHA="${DISTILL_ALPHA:-0.5}"
export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
export DISTILL_BETA="${DISTILL_BETA:-0}"
export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[]}"

# XLA flags tuned for ~17% MFU. sparse_core_collective_aggregator is required
# by latency_hiding_layer_scheduler.
# Each backslash-newline inside the double quotes is a line continuation, so
# the default expands to a single line of space-separated flags.
export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_impure_enable_packed_bf16_math_ops=true \
--xla_tpu_aggressive_opt_barrier_removal=true \
--xla_tpu_enable_sparse_core_collective_aggregator=true \
--xla_tpu_enable_latency_hiding_layer_scheduler=true \
--xla_tpu_enable_layer_scheduler_for_dependent_collectives=true \
--xla_tpu_enable_multi_compute_overlap_in_layer_scheduler=true \
--xla_tpu_scheduler_percent_shared_memory_limit=150 \
--xla_enable_async_all_gather=true \
--xla_tpu_prefer_async_allgather_to_allreduce=true \
--xla_max_concurrent_async_all_gathers=2 \
--xla_max_concurrent_async_reduce_scatters=2 \
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=false}"

if [[ "$MODE" == "submit" ]]; then
  # Fail fast with an actionable message instead of letting the upload fail:
  # the placeholder bucket is not a usable destination.
  if [[ "$XPK_YAML_GCS" == "${PLACEHOLDER_BUCKET}"* ]]; then
    die "XPK_YAML_GCS resolves to '${XPK_YAML_GCS}', which is under the placeholder bucket; set DISTILL_GCS_BUCKET (or XPK_YAML_GCS) to your own GCS location."
  fi

  # Upload the distillation config to the GCS path exported as XPK_YAML_GCS.
  gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS"
fi

# Hand off to the shared launcher; exec replaces this process, so its exit
# status becomes this script's exit status.
exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE"