
Commit fe674ae

Merge pull request #3748 from AI-Hypercomputer:gagik-update-xpk
PiperOrigin-RevId: 906660149
2 parents b61299e + 34712fe commit fe674ae

2 files changed

Lines changed: 169 additions & 13 deletions


Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
# MaxText distillation trainer

Distillation from a large teacher LLM into an (optionally pruned) student on
TPU, via JAX/MaxText on top of Tunix's PEFT trainer. This README is
operational; for concepts, see
[`knowledge_distillation.md`](../../../../../docs/tutorials/posttraining/knowledge_distillation.md)
and [`post_training_index.md`](../../../../../docs/tutorials/post_training_index.md).

Canonical launcher: [`scripts/run_distill_xpk.sh`](scripts/run_distill_xpk.sh)
(see its header for all env vars).

## 1. Pick a config

Configs live in [`src/maxtext/configs/post_train/`](../../../configs/post_train/):

| File | Student | Teacher | Notes |
|---|---|---|---|
| `distillation.yml` | llama3.1-8b | llama3.1-8b | Baseline |
| `distillation-sft.yml` | llama3.1-8b | llama3.1-8b | Distillation + SFT mix |

**Override `num_epoch` to a value > 1 if you want to train for more than one
epoch** (e.g. `num_epoch=10`). The base default is 1, so the input pipeline
(Grain) iterates the dataset only once before stopping, and any run longer
than one pass over the dataset runs out of data mid-training. Pass it as a
CLI override (`… distillation.yml num_epoch=10`) or edit the YAML directly.

## 2. Single-host smoke test

Validate your config + checkpoint paths on a single TPU VM (no xpk, no GKE)
before scaling to a cluster. The default `distillation.yml` (llama3.1-8b
student + 8b teacher) needs a slice large enough to hold both models in
HBM: v5p-16 or larger in practice. A v5p-8 fits only with bf16 weights or a
shrunken student (see below).

```bash
source <your-venv>/bin/activate
PYTHONPATH=$PWD/src python -m maxtext.trainers.post_train.distillation.train_distill \
  src/maxtext/configs/post_train/distillation.yml \
  run_name=local_smoke \
  base_output_directory=gs://<bucket>/distill_smoke \
  steps=5
```

Smaller TPU? Shrink the model with overrides like `base_emb_dim=...
base_num_decoder_layers=... base_num_query_heads=...`, and pass
`override_model_config=True` so the CLI overrides actually take effect
(the default is `False`).

## 3. Cluster auth

```bash
pip install git+https://github.com/AI-Hypercomputer/xpk.git

# Kubeconfig (use --dns-endpoint; IP endpoints are often stale).
# Use --zone for zonal clusters, --region for regional ones.
gcloud container clusters get-credentials <cluster> \
  --zone=<zone> --project=<project> --dns-endpoint

# Verify RBAC in the default namespace:
kubectl auth can-i create roles --namespace=default   # must print: yes
```

If `can-i` prints `no`, ask a cluster admin to bind
`roles/container.admin` to your user.

## 4. Build + push the image (one time)

The flow is: build the MaxText base → `prep_image` rebuilds `$XPK_BASE_IMAGE`
**in place** (same local tag) with tunix layered on top → `docker_upload_runner.sh`
pushes that modified local tag to GCR under `$CLOUD_IMAGE_NAME`.

```bash
# Local tag prep_image rebuilds; registry name docker_upload_runner pushes to.
export XPK_BASE_IMAGE=maxtext_base_image:stable
export CLOUD_IMAGE_NAME=gcr.io/<your-project>/maxtext_base_image:stable

# Base image with MaxText + TPU deps.
sudo bash src/dependencies/scripts/docker_build_dependency_image.sh \
  MODE=stable WORKFLOW=post-training

# Layer tunix + re-pin jax/libtpu for libtpu compat. Rebuilds $XPK_BASE_IMAGE in place.
bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh prep_image

# Push the modified local tag to GCR so later submits pull from the registry (no buildx).
sudo bash src/dependencies/scripts/docker_upload_runner.sh \
  CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
```

## 5. Submit

```bash
export XPK_CLUSTER=<cluster>
export XPK_PROJECT=<project>
export XPK_ZONE=<zone>
export XPK_DEVICE_TYPE=tpu7x-4x4x4
export XPK_BASE_IMAGE=${CLOUD_IMAGE_NAME}   # slash in name → --docker-image auto-selected
export XPK_BASE_OUTPUT_DIR=gs://<bucket>/distillation
export XPK_RUN_NAME=<experiment>            # default: distill_run; set per experiment
                                            # to scope checkpoints + TB under
                                            # ${XPK_BASE_OUTPUT_DIR}/${XPK_WORKLOAD}/${XPK_RUN_NAME}/

bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh submit
```

The launcher writes the workload name to `~/.xpk_last_workload` for §6/§7.
Override `XPK_WORKLOAD` if you want, but keep it to about 16 characters or
fewer: some clusters cap derived resource names at 49 characters. See the
script header for other env vars (`DISTILL_ALPHA`, `DISTILL_BETA`,
`STEPS_OVERRIDE`, etc.).

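The length guidance above can be checked before submitting. The following is a minimal sketch and not part of `run_distill_xpk.sh`: `default_workload_name` and `check_workload_name` are hypothetical helper names, the 16-character limit is this README's rule of thumb, and the lowercase DNS-label pattern is an assumption about what the cluster accepts in derived resource names.

```shell
#!/usr/bin/env bash
# Hypothetical helpers for illustration; not part of run_distill_xpk.sh.

# Build a name in the launcher's documented default shape:
# "d-" + first 8 chars of the user name + "-" + $RANDOM (0..32767),
# which is at most 2 + 8 + 1 + 5 = 16 characters.
default_workload_name() {
  local user="${1:-${USER:-anon}}"
  printf 'd-%s-%s\n' "${user:0:8}" "${RANDOM}"
}

# Return 0 if the name is at most max_len chars (default 16) and looks like
# a lowercase DNS label (assumed requirement); otherwise explain on stderr.
check_workload_name() {
  local LC_ALL=C
  local name="$1" max_len="${2:-16}"
  if (( ${#name} > max_len )); then
    echo "workload name '${name}' is ${#name} chars; keep it <= ${max_len}" >&2
    return 1
  fi
  if [[ ! "$name" =~ ^[a-z0-9]([a-z0-9-]*[a-z0-9])?$ ]]; then
    echo "workload name '${name}' must use lowercase letters, digits and '-'" >&2
    return 1
  fi
}

# Example use before `... run_distill_xpk.sh submit`:
#   export XPK_WORKLOAD="$(default_workload_name)"
#   check_workload_name "$XPK_WORKLOAD" || exit 1
```

A user name containing characters such as `.` or uppercase letters would fail the second check, which is a prompt to set `XPK_WORKLOAD` by hand.
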
## 6. Monitor

```bash
WL=$(cat ~/.xpk_last_workload)
POD=$(kubectl get pods -l jobset.sigs.k8s.io/jobset-name=$WL,batch.kubernetes.io/job-completion-index=0 -o name | head -1)
kubectl logs -f ${POD} -c jax-tpu-1 | grep "Train step"
```

## 7. Resume

Submit again with the **same `XPK_BASE_OUTPUT_DIR` + `XPK_WORKLOAD` + `XPK_RUN_NAME`**.
Checkpoints live at `${XPK_BASE_OUTPUT_DIR}/${XPK_WORKLOAD}/${XPK_RUN_NAME}/checkpoints/`,
and `maybe_restore` picks up the latest one. All three must match the
previous submit (the launcher writes the workload name to `~/.xpk_last_workload`).
For automatic retries (up to `MAX_RETRIES`, default 10), use `resume_until_done`:

```bash
STEPS_OVERRIDE=108000 \
  bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh resume_until_done
```

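To see which step a resume would likely start from, the checkpoint directory can be listed and the highest step picked out. This is a sketch under the assumption that each checkpoint sits in a subdirectory named by its step number; `latest_step` is a hypothetical helper, and the launcher's own `resume_until_done` logic may differ.

```shell
#!/usr/bin/env bash
# Hypothetical helper for illustration; not part of run_distill_xpk.sh.

# Read paths from stdin (one per line) and print the largest purely numeric
# last path component. Prints nothing if no entry is numeric.
latest_step() {
  local line name max=""
  while IFS= read -r line; do
    name="${line%/}"      # drop a trailing slash, as printed for directories
    name="${name##*/}"    # keep only the last path component
    [[ "$name" =~ ^[0-9]+$ ]] || continue
    # 10# forces base 10 so names with leading zeros are not read as octal.
    if [[ -z "$max" ]] || (( 10#$name > 10#$max )); then
      max="$name"
    fi
  done
  if [[ -n "$max" ]]; then
    printf '%s\n' "$max"
  fi
}

# Example use (assumes gsutil is installed and authenticated):
#   gsutil ls "${XPK_BASE_OUTPUT_DIR}/${XPK_WORKLOAD}/${XPK_RUN_NAME}/checkpoints/" | latest_step
```
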
## 8. Checkpoint retention

| Field | Default | Description |
|---|---|---|
| `checkpoint_period` | 2000 | Save a checkpoint every N training steps |
| `max_num_checkpoints_to_keep` | `None` | Keep the most recent N checkpoints. `None` keeps all. |

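As a rough sizing aid for the two fields above, the number of checkpoints a run leaves in the output directory follows from simple arithmetic. The sketch below assumes exactly one save per full `checkpoint_period` and ignores any extra save at the start or end of a run; the function names are hypothetical.

```shell
#!/usr/bin/env bash
# Hypothetical helpers for illustration; not part of MaxText or the launcher.

# Periodic checkpoints written by a run of `steps` steps that saves every
# `period` steps (integer division: partial periods do not count).
checkpoints_written() {
  local steps="$1" period="$2"
  echo $(( steps / period ))
}

# Checkpoints left on disk: all of them when max_keep is "None",
# otherwise the smaller of the number written and max_keep.
checkpoints_kept() {
  local written max_keep="${3:-None}"
  written="$(checkpoints_written "$1" "$2")"
  if [[ "$max_keep" == "None" ]] || (( written < max_keep )); then
    echo "$written"
  else
    echo "$max_keep"
  fi
}

# Example: the 108000-step run from section 7 with the default period of 2000.
#   checkpoints_kept 108000 2000 None   # all periodic checkpoints are kept
#   checkpoints_kept 108000 2000 5      # only the most recent 5 are kept
```

With the defaults, a 108000-step run keeps 108000 / 2000 = 54 periodic checkpoints, so setting `max_num_checkpoints_to_keep` is worth considering for long runs.
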

src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh

Lines changed: 30 additions & 13 deletions
@@ -60,21 +60,30 @@
 # ${XPK_BASE_OUTPUT_DIR}/${XPK_WORKLOAD}/)
 #
 # OPTIONAL env vars (with defaults):
-# XPK_BASE_IMAGE default: maxtext_base_image
-# XPK_WORKLOAD default: distill-${USER}-${RANDOM}
+# XPK_BASE_IMAGE default: maxtext_base_image. A slash in the name
+# (e.g. gcr.io/...) switches xpk from --base-docker-image
+# (buildx re-push on each submit) to --docker-image
+# (pull from registry). Prefer the registry path after
+# the first submit.
+# XPK_WORKLOAD default: d-${USER:0:8}-${RANDOM} (at most 16 chars).
+# Keep ≲16 chars: some clusters cap derived
+# resource names at 49 chars
+# (default-jobset-<workload>-<5>-...-<5>).
 # XPK_PRIORITY default: medium
 # XPK_NUM_SLICES default: 1
 # XPK_DISTILL_CONFIG default: src/maxtext/configs/post_train/distillation.yml
 # XPK_RUN_NAME default: distill_run — passed as MaxText run_name; becomes
 # the subdir under base_output_directory where checkpoints
 # and TB logs land (...${OUTPUT_DIR}/${XPK_RUN_NAME}/...).
 # resume_until_done lists this subdir to find the latest step.
-# XPK_USE_GCSFUSE default: 0 — if 1, mount XPK_DATASET_BUCKET via gcsfuse
-# and override grain_train_files to the local mount path
-# (useful for ArrayRecord; gs:// paths are slower)
-# XPK_DATASET_BUCKET default: maxtext-dataset (only used when XPK_USE_GCSFUSE=1)
+# XPK_USE_GCSFUSE default: 1 — mount XPK_DATASET_BUCKET via gcsfuse and
+# point grain at the local mount path. ~10x faster than
+# direct gs:// reads for ArrayRecord shards. Set to 0
+# to bypass gcsfuse and read directly from gs://.
+# XPK_DATASET_BUCKET default: maxtext-dataset
 # XPK_DATASET_SUBPATH default: array-record/climbmix/*.arrayrecord
-# (relative to the bucket; only used when XPK_USE_GCSFUSE=1)
+# The script always sets grain_train_files from these
+# two, overriding the YAML in both modes.
 # STEPS_OVERRIDE default: empty — yml `steps` is used unless set
 # CHECKPOINT_PERIOD_OVERRIDE default: empty — yml `checkpoint_period` is used
 # MAX_RETRIES default: 10 — only used by resume_until_done
@@ -122,12 +131,12 @@ require_env() {
 
 # -------------------------- defaults --------------------------
 : "${XPK_BASE_IMAGE:=maxtext_base_image}"
-: "${XPK_WORKLOAD:=distill-${USER:-anon}-${RANDOM}}"
+: "${XPK_WORKLOAD:=d-${USER:0:8}-${RANDOM}}"
 : "${XPK_PRIORITY:=medium}"
 : "${XPK_NUM_SLICES:=1}"
 : "${XPK_DISTILL_CONFIG:=src/maxtext/configs/post_train/distillation.yml}"
 : "${XPK_RUN_NAME:=distill_run}"
-: "${XPK_USE_GCSFUSE:=0}"
+: "${XPK_USE_GCSFUSE:=1}"
 : "${XPK_DATASET_BUCKET:=maxtext-dataset}"
 : "${XPK_DATASET_SUBPATH:=array-record/climbmix/*.arrayrecord}"
 : "${MAX_RETRIES:=10}"
@@ -169,14 +178,14 @@ if [ -n "${CHECKPOINT_PERIOD_OVERRIDE:-}" ]; then
   extra_cli="$extra_cli checkpoint_period=${CHECKPOINT_PERIOD_OVERRIDE}"
 fi
 
-# Optional gcsfuse prelude — direct gs:// reads of ArrayRecord shards are slow,
-# so mounting the bucket and pointing grain at the local path is recommended.
+# Build grain_train_files (configs leave it empty); pick local mount or direct gs://.
 gcsfuse_prelude=""
-grain_files_override=""
 if [ "$XPK_USE_GCSFUSE" = "1" ]; then
   gcsfuse_prelude="bash src/dependencies/scripts/setup_gcsfuse.sh \
     DATASET_GCS_BUCKET=${XPK_DATASET_BUCKET} MOUNT_PATH=/tmp/gcsfuse;"
   grain_files_override="grain_train_files=/tmp/gcsfuse/${XPK_DATASET_SUBPATH}"
+else
+  grain_files_override="grain_train_files=gs://${XPK_DATASET_BUCKET}/${XPK_DATASET_SUBPATH}"
 fi
 
 # -------------------------- prep_image --------------------------
@@ -226,6 +235,14 @@ submit_workload() {
   echo "Config: $XPK_DISTILL_CONFIG"
   [ -n "$extra_cli" ] && echo "Overrides: $extra_cli"
 
+  # Registry path (contains slash) → --docker-image (pull on cluster);
+  # local tag → --base-docker-image (buildx re-push).
+  local image_flag="--base-docker-image"
+  if [[ "$XPK_BASE_IMAGE" == *"/"* ]]; then
+    image_flag="--docker-image"
+  fi
+  echo "Image flag: $image_flag=$XPK_BASE_IMAGE"
+
   xpk workload create \
     --cluster "$XPK_CLUSTER" \
     --workload "$XPK_WORKLOAD" \
@@ -234,7 +251,7 @@ submit_workload() {
     --num-slices="$XPK_NUM_SLICES" \
     --project="$XPK_PROJECT" \
     --zone="$XPK_ZONE" \
-    --base-docker-image="$XPK_BASE_IMAGE" \
+    "$image_flag=$XPK_BASE_IMAGE" \
     --command "export PYTHONPATH=/app/src; \
       export BASE_OUTPUT_DIRECTORY=${OUTPUT_DIR}; \
       ${gcsfuse_prelude} \
