|
#!/bin/bash
# Launcher for gpt-oss-20b distillation on TPU v7x.
#
# Usage:
#   bash scripts/distillation/distill_gpt_oss_20b.sh [submit|monitor|resume_until_done]
#
# The mode defaults to "submit" and is forwarded unchanged to run_distill_xpk.sh.
# This wrapper only fills in environment defaults and, in submit mode, uploads
# the distillation YAML before handing over.
#
# Every gs:// and gcr.io value below (e.g. gs://agagik-us/...) is a DEMO DEFAULT
# so the launcher runs out of the box. Each one can be overridden from the
# environment: point XPK_BASE_OUTPUT_DIR, XPK_BASE_IMAGE, XPK_TOKENIZER_GCS and
# XPK_YAML_GCS at your own bucket and image before running for real.
set -euo pipefail

MODE="${1:-submit}"

# Work from the repository root (two directories above this script) so the
# relative paths used further down resolve.
REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd)
cd "$REPO_ROOT"

# ---- Workload identity, cluster placement, output location and image --------
# `${VAR:=default}` assigns only when VAR is unset or empty, so a value already
# present in the environment always wins.
#
# NOTE(review): the default workload name embeds the current minute, so a later
# invocation (e.g. `monitor`) computes a different default. Presumably
# XPK_WORKLOAD has to be set explicitly to address an existing workload —
# confirm against run_distill_xpk.sh.
: "${XPK_WORKLOAD:=goss-base-$(date +%Y%m%d-%H%M)}"
: "${XPK_RUN_NAME:=gpt_oss_20b_base}"
: "${XPK_CLUSTER:=bodaborg-super-xpk-x8p}"
: "${XPK_PROJECT:=cloud-tpu-multipod-dev}"
# NOTE(review): "us-central1" has no -a/-b style suffix, i.e. it looks like a
# region rather than a zone — confirm this is what the cluster expects.
: "${XPK_ZONE:=us-central1}"
: "${XPK_DEVICE_TYPE:=tpu7x-4x4x4}"
: "${XPK_BASE_OUTPUT_DIR:=gs://agagik-us/distillation}"
: "${XPK_BASE_IMAGE:=gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"
: "${XPK_PRIORITY:=high}"
export XPK_WORKLOAD XPK_RUN_NAME XPK_CLUSTER XPK_PROJECT XPK_ZONE
export XPK_DEVICE_TYPE XPK_BASE_OUTPUT_DIR XPK_BASE_IMAGE XPK_PRIORITY

# ---- Dataset ----------------------------------------------------------------
# XPK_USE_GCSFUSE is forced to 1 here; unlike its neighbours it is not
# overridable from the environment.
export XPK_USE_GCSFUSE=1
: "${XPK_DATASET_BUCKET:=maxtext-dataset}"
: "${XPK_DATASET_SUBPATH:=array-record/climbmix/*.arrayrecord}"
export XPK_DATASET_BUCKET XPK_DATASET_SUBPATH

# ---- Tokenizer --------------------------------------------------------------
# GCS source and in-container destination for the HF tokenizer files, which are
# not in the image for gpt-oss.
# NOTE(review): the staging copy itself is presumably performed by
# run_distill_xpk.sh; it does not happen in this file.
: "${XPK_TOKENIZER_GCS:=gs://agagik-us/distill-configs/tokenizer-gpt-oss-20b/}"
: "${XPK_TOKENIZER_LOCAL:=/deps/src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer}"
export XPK_TOKENIZER_GCS XPK_TOKENIZER_LOCAL

# ---- Distillation config ----------------------------------------------------
# XPK_DISTILL_CONFIG is the file uploaded in submit mode; XPK_YAML_GCS is where
# it is uploaded to.
LOCAL_YAML="src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml"
: "${XPK_DISTILL_CONFIG:=$LOCAL_YAML}"
: "${XPK_YAML_GCS:=gs://agagik-us/distill-configs/distillation_gpt_oss_20b.yml}"
export XPK_DISTILL_CONFIG XPK_YAML_GCS

# ---- Distillation hyper-parameters ------------------------------------------
# DISTILL_BETA defaults to 0 because the decoder feature loss is broken on
# gpt-oss.
: "${DISTILL_ALPHA:=0.5}"
: "${DISTILL_TEMPERATURE:=1.0}"
: "${DISTILL_BETA:=0}"
: "${DISTILL_LAYER_INDICES:=[]}"
export DISTILL_ALPHA DISTILL_TEMPERATURE DISTILL_BETA DISTILL_LAYER_INDICES

# ---- libtpu / XLA flags -----------------------------------------------------
# Flags tuned for ~17% MFU. sparse_core_collective_aggregator is required by
# latency_hiding_layer_scheduler.
default_libtpu_flags=(
  --xla_tpu_scoped_vmem_limit_kib=65536
  --xla_tpu_impure_enable_packed_bf16_math_ops=true
  --xla_tpu_aggressive_opt_barrier_removal=true
  --xla_tpu_enable_sparse_core_collective_aggregator=true
  --xla_tpu_enable_latency_hiding_layer_scheduler=true
  --xla_tpu_enable_layer_scheduler_for_dependent_collectives=true
  --xla_tpu_enable_multi_compute_overlap_in_layer_scheduler=true
  --xla_tpu_scheduler_percent_shared_memory_limit=150
  --xla_enable_async_all_gather=true
  --xla_tpu_prefer_async_allgather_to_allreduce=true
  --xla_max_concurrent_async_all_gathers=2
  --xla_max_concurrent_async_reduce_scatters=2
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false
)
# "${array[*]}" joins the elements with the first character of IFS (a space
# here, since IFS is left at its default), yielding one single-line flag string.
default_libtpu_args="${default_libtpu_flags[*]}"
: "${XPK_LIBTPU_INIT_ARGS:=${default_libtpu_args}}"
export XPK_LIBTPU_INIT_ARGS

# ---- Upload config (submit only), then hand over ----------------------------
# Only a submission uploads the YAML; every other mode skips straight to the
# runner script.
case "$MODE" in
  submit)
    gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS"
    ;;
esac

# Replace this process with the shared runner, passing the mode through as-is.
exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE"
0 commit comments