|
| 1 | +# MaxText distillation trainer |
| 2 | + |
| 3 | +Distillation from a large teacher LLM into a (optionally pruned) student on |
| 4 | +TPU, via JAX/MaxText on top of Tunix's PEFT trainer. This README is |
| 5 | +operational — for concepts, see |
| 6 | +[`knowledge_distillation.md`](../../../../../docs/tutorials/posttraining/knowledge_distillation.md) |
| 7 | +and [`post_training_index.md`](../../../../../docs/tutorials/post_training_index.md). |
| 8 | + |
| 9 | +Canonical launcher: [`scripts/run_distill_xpk.sh`](scripts/run_distill_xpk.sh) |
| 10 | +(see its header for all env vars). |
| 11 | + |
| 12 | + |
| 13 | +## 1. Pick a config |
| 14 | + |
| 15 | +Configs live in [`src/maxtext/configs/post_train/`](../../../configs/post_train/): |
| 16 | + |
| 17 | +| File | Student | Teacher | Notes | |
| 18 | +|---|---|---|---| |
| 19 | +| `distillation.yml` | llama3.1-8b | llama3.1-8b | Baseline | |
| 20 | +| `distillation-sft.yml` | llama3.1-8b | llama3.1-8b | Distillation + SFT mix | |
| 21 | + |
| 22 | +**Override `num_epoch` to a value > 1 if you want to train for more than one |
| 23 | +epoch** (e.g. `num_epoch=10`) — the base default is 1, and the input pipeline |
| 24 | +(Grain) iterates the dataset only once before stopping, so longer runs run |
| 25 | +out of data mid-training. Pass it as a CLI override (`… distillation.yml num_epoch=10`) or |
| 26 | +edit the YAML directly. |
| 27 | + |
| 28 | +## 2. Single-host smoke test |
| 29 | + |
| 30 | +Validate your config + checkpoint paths on a single TPU VM (no xpk, no GKE) |
| 31 | +before scaling to a cluster. The default `distillation.yml` (llama3.1-8b |
| 32 | +student + 8b teacher) needs a slice large enough to hold both models in |
| 33 | +HBM — ≥ v5p-16 in practice; a v5p-8 only fits with bf16 weights or a |
| 34 | +shrunken student (see below): |
| 35 | + |
| 36 | +```bash |
| 37 | +source <your-venv>/bin/activate |
| 38 | +PYTHONPATH=$PWD/src python -m maxtext.trainers.post_train.distillation.train_distill \ |
| 39 | + src/maxtext/configs/post_train/distillation.yml \ |
| 40 | + run_name=local_smoke \ |
| 41 | + base_output_directory=gs://<bucket>/distill_smoke \ |
| 42 | + steps=5 |
| 43 | +``` |
| 44 | + |
| 45 | +Smaller TPU? Shrink the model with overrides like `base_emb_dim=... |
| 46 | +base_num_decoder_layers=... base_num_query_heads=...` — pass |
| 47 | +`override_model_config=True` so the CLI overrides actually take effect |
| 48 | +(default is `False`). |
| 49 | + |
| 50 | + |
| 51 | +## 3. Cluster auth |
| 52 | + |
| 53 | +```bash |
| 54 | +pip install git+https://github.com/AI-Hypercomputer/xpk.git |
| 55 | + |
| 56 | +# Kubeconfig (use --dns-endpoint; IP endpoints are often stale). |
| 57 | +# Use --zone for zonal clusters, --region for regional ones. |
| 58 | +gcloud container clusters get-credentials <cluster> \ |
| 59 | + --zone=<zone> --project=<project> --dns-endpoint |
| 60 | + |
| 61 | +# Verify RBAC in the default namespace: |
| 62 | +kubectl auth can-i create roles --namespace=default # must print: yes |
| 63 | +``` |
| 64 | + |
| 65 | +If `can-i` prints `no`, ask a cluster admin to bind |
| 66 | +`roles/container.admin` to your user. |
| 67 | + |
| 68 | +## 4. Build + push the image (one time) |
| 69 | + |
| 70 | +The flow is: build the MaxText base → `prep_image` rebuilds `$XPK_BASE_IMAGE` |
| 71 | +**in place** (same local tag) with tunix layered on top → `docker_upload_runner.sh` |
| 72 | +pushes that modified local tag to GCR under `$CLOUD_IMAGE_NAME`. |
| 73 | + |
| 74 | +```bash |
| 75 | +# Local tag prep_image rebuilds; registry name docker_upload_runner pushes to. |
| 76 | +export XPK_BASE_IMAGE=maxtext_base_image:stable |
| 77 | +export CLOUD_IMAGE_NAME=gcr.io/<your-project>/maxtext_base_image:stable |
| 78 | + |
| 79 | +# Base image with MaxText + TPU deps. |
| 80 | +sudo bash src/dependencies/scripts/docker_build_dependency_image.sh \ |
| 81 | + MODE=stable WORKFLOW=post-training |
| 82 | + |
| 83 | +# Layer tunix + re-pin jax/libtpu for libtpu compat. Rebuilds $XPK_BASE_IMAGE in place. |
| 84 | +bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh prep_image |
| 85 | + |
| 86 | +# Push the modified local tag to GCR so later submits pull from the registry (no buildx). |
| 87 | +sudo bash src/dependencies/scripts/docker_upload_runner.sh \ |
| 88 | + CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME} |
| 89 | +``` |
| 90 | + |
| 91 | +## 5. Submit |
| 92 | + |
| 93 | +```bash |
| 94 | +export XPK_CLUSTER=<cluster> |
| 95 | +export XPK_PROJECT=<project> |
| 96 | +export XPK_ZONE=<zone> |
| 97 | +export XPK_DEVICE_TYPE=tpu7x-4x4x4 |
| 98 | +export XPK_BASE_IMAGE=${CLOUD_IMAGE_NAME} # slash in name → --docker-image auto-selected |
| 99 | +export XPK_BASE_OUTPUT_DIR=gs://<bucket>/distillation |
| 100 | +export XPK_RUN_NAME=<experiment> # default: distill_run; set per experiment |
| 101 | + # to scope checkpoints + TB under |
| 102 | + # ${XPK_BASE_OUTPUT_DIR}/${XPK_WORKLOAD}/${XPK_RUN_NAME}/ |
| 103 | + |
| 104 | +bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh submit |
| 105 | +``` |
| 106 | + |
| 107 | +The launcher writes the workload name to `~/.xpk_last_workload` for §6/§7. |
| 108 | +Override `XPK_WORKLOAD` if you want; keep it ≲16 chars — some clusters |
| 109 | +cap derived resource names at 49. See the script header for other env |
| 110 | +vars (`DISTILL_ALPHA`, `DISTILL_BETA`, `STEPS_OVERRIDE`, etc.). |
| 111 | + |
| 112 | +## 6. Monitor |
| 113 | + |
| 114 | +```bash |
| 115 | +WL=$(cat ~/.xpk_last_workload) |
| 116 | +POD=$(kubectl get pods -l jobset.sigs.k8s.io/jobset-name=$WL,batch.kubernetes.io/job-completion-index=0 -o name | head -1) |
| 117 | +kubectl logs -f ${POD} -c jax-tpu-1 | grep "Train step" |
| 118 | +``` |
| 119 | + |
| 120 | +## 7. Resume |
| 121 | + |
| 122 | +Submit again with the **same `XPK_BASE_OUTPUT_DIR` + `XPK_WORKLOAD` + `XPK_RUN_NAME`** — |
| 123 | +checkpoints live at `${XPK_BASE_OUTPUT_DIR}/${XPK_WORKLOAD}/${XPK_RUN_NAME}/checkpoints/`, |
| 124 | +and `maybe_restore` picks up the latest one. All three must match the |
| 125 | +previous submit (the launcher writes the workload name to `~/.xpk_last_workload`). |
| 126 | +For auto-retry: |
| 127 | + |
| 128 | +```bash |
| 129 | +STEPS_OVERRIDE=108000 \ |
| 130 | + bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh resume_until_done |
| 131 | +``` |
| 132 | + |
| 133 | +## 8. Checkpoint retention |
| 134 | + |
| 135 | +| Field | Default | Description | |
| 136 | +|---|---|---| |
| 137 | +| `checkpoint_period` | 2000 | Save a checkpoint every N training steps | |
| 138 | +| `max_num_checkpoints_to_keep` | `None` | Keep the most recent N checkpoints. `None` keeps all | |
| 139 | + |
0 commit comments