Skip to content

Commit fd9717a

Browse files
committed
Merge remote-tracking branch 'origin/main' into davidsotomora-rl-chat-template
2 parents 885d96b + 56541ae commit fd9717a

33 files changed

Lines changed: 2103 additions & 260 deletions

.github/workflows/run_tests_against_package.yml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,15 @@ jobs:
162162
# Dynamically discover the 'nvidia' folder and prepend all its sub-library
163163
# directories (including nccl, cublas, cudnn) to LD_LIBRARY_PATH to prevent
164164
# JAX from partially loading incompatible system-level CUDA libraries.
165-
NVIDIA_DIR=$(find .venv/lib/ -maxdepth 3 -name "nvidia" -type d 2>/dev/null | head -n 1)
166-
if [ -n "${NVIDIA_DIR}" ]; then
167-
for dir in "${NVIDIA_DIR}"/*; do
168-
if [ -d "$dir/lib" ]; then
169-
export LD_LIBRARY_PATH=$(pwd)/$dir/lib:${LD_LIBRARY_PATH}
170-
fi
171-
done
165+
if [ -d ".venv/lib" ]; then
166+
NVIDIA_DIR=$(find .venv/lib/ -maxdepth 3 -name "nvidia" -type d 2>/dev/null | head -n 1)
167+
if [ -n "${NVIDIA_DIR}" ]; then
168+
for dir in "${NVIDIA_DIR}"/*; do
169+
if [ -d "$dir/lib" ]; then
170+
export LD_LIBRARY_PATH=$(pwd)/$dir/lib:${LD_LIBRARY_PATH}
171+
fi
172+
done
173+
fi
172174
fi
173175
fi
174176
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/bin/bash
2+
# Launch gpt-oss-20b distillation on TPU v7x.
3+
# Usage: bash scripts/distillation/distill_gpt_oss_20b.sh [submit|monitor|resume_until_done]
4+
#
5+
# Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own
6+
# image before running — the placeholders below will not work as-is. Everything
7+
# else has a working default. Example:
8+
#
9+
# DISTILL_GCS_BUCKET=gs://your-bucket \
10+
# XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \
11+
# bash scripts/distillation/distill_gpt_oss_20b.sh submit
12+
set -euo pipefail
13+
14+
MODE="${1:-submit}"
15+
REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd)
16+
cd "$REPO_ROOT"
17+
18+
# Your GCS bucket: run outputs, staged YAML, and tokenizer files all land here.
19+
export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-gs://YOUR-BUCKET}"
20+
21+
export XPK_WORKLOAD="${XPK_WORKLOAD:-goss-base-$(date +%Y%m%d-%H%M)}"
22+
export XPK_RUN_NAME="${XPK_RUN_NAME:-gpt_oss_20b_base}"
23+
export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}"
24+
export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}"
25+
export XPK_ZONE="${XPK_ZONE:-us-central1}"
26+
export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}"
27+
export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}"
28+
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"
29+
export XPK_PRIORITY="${XPK_PRIORITY:-high}"
30+
31+
export XPK_USE_GCSFUSE=1
32+
export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}"
33+
export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}"
34+
35+
# Stage HF tokenizer files (not in the image for gpt-oss).
36+
export XPK_TOKENIZER_GCS="${XPK_TOKENIZER_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/tokenizer-gpt-oss-20b/}"
37+
export XPK_TOKENIZER_LOCAL="${XPK_TOKENIZER_LOCAL:-/deps/src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer}"
38+
39+
LOCAL_YAML="src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml"
40+
export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}"
41+
export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_gpt_oss_20b.yml}"
42+
43+
# distill_beta=0: decoder feature loss is broken on gpt-oss.
44+
export DISTILL_ALPHA="${DISTILL_ALPHA:-0.5}"
45+
export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
46+
export DISTILL_BETA="${DISTILL_BETA:-0}"
47+
export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[]}"
48+
49+
# XLA flags tuned for ~17% MFU. sparse_core_collective_aggregator is required
50+
# by latency_hiding_layer_scheduler.
51+
export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=65536 \
52+
--xla_tpu_impure_enable_packed_bf16_math_ops=true \
53+
--xla_tpu_aggressive_opt_barrier_removal=true \
54+
--xla_tpu_enable_sparse_core_collective_aggregator=true \
55+
--xla_tpu_enable_latency_hiding_layer_scheduler=true \
56+
--xla_tpu_enable_layer_scheduler_for_dependent_collectives=true \
57+
--xla_tpu_enable_multi_compute_overlap_in_layer_scheduler=true \
58+
--xla_tpu_scheduler_percent_shared_memory_limit=150 \
59+
--xla_enable_async_all_gather=true \
60+
--xla_tpu_prefer_async_allgather_to_allreduce=true \
61+
--xla_max_concurrent_async_all_gathers=2 \
62+
--xla_max_concurrent_async_reduce_scatters=2 \
63+
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=false}"
64+
65+
if [ "$MODE" = "submit" ]; then
66+
gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS"
67+
fi
68+
69+
exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE"
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/bin/bash
2+
# Launch qwen3-30b-a3b-base distillation on TPU v7x.
3+
# Usage: bash scripts/distillation/distill_qwen3_30b_base.sh [submit|monitor|resume_until_done]
4+
#
5+
# Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own
6+
# image before running — the placeholders below will not work as-is. Everything
7+
# else has a working default. Example:
8+
#
9+
# DISTILL_GCS_BUCKET=gs://your-bucket \
10+
# XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \
11+
# bash scripts/distillation/distill_qwen3_30b_base.sh submit
12+
set -euo pipefail
13+
14+
MODE="${1:-submit}"
15+
REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd)
16+
cd "$REPO_ROOT"
17+
18+
# Your GCS bucket: run outputs and staged YAML land here.
19+
export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-gs://YOUR-BUCKET}"
20+
21+
export XPK_WORKLOAD="${XPK_WORKLOAD:-q30b-base-$(date +%Y%m%d-%H%M)}"
22+
export XPK_RUN_NAME="${XPK_RUN_NAME:-qwen3_30b_base}"
23+
export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}"
24+
export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}"
25+
export XPK_ZONE="${XPK_ZONE:-us-central1}"
26+
export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}"
27+
export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}"
28+
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"
29+
export XPK_PRIORITY="${XPK_PRIORITY:-high}"
30+
31+
export XPK_USE_GCSFUSE=1
32+
export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}"
33+
export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}"
34+
35+
LOCAL_YAML="src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml"
36+
export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}"
37+
export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_qwen3_30b_base.yml}"
38+
39+
export DISTILL_ALPHA="${DISTILL_ALPHA:-0.6}"
40+
export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
41+
export DISTILL_BETA="${DISTILL_BETA:-1.0}"
42+
export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[0,1,2,3,4,5,6,7]}"
43+
44+
# XLA flags tuned for ~20% MFU.
45+
export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=61440 \
46+
--xla_tpu_enable_all_experimental_scheduler_features=true \
47+
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
48+
--xla_tpu_host_transfer_overlap_limit=24 \
49+
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
50+
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
51+
--xla_tpu_enable_ag_backward_pipelining=true \
52+
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
53+
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
54+
--xla_max_concurrent_host_send_recv=100 \
55+
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
56+
--xla_latency_hiding_scheduler_rerun=2}"
57+
58+
if [ "$MODE" = "submit" ]; then
59+
gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS"
60+
fi
61+
62+
exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE"
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# gpt-oss-20b distillation. ~17% MFU on v7x.
2+
base_config: "base.yml"
3+
4+
# NOTE: load_parameters_path values are placeholders. Replace them with paths
5+
# to your own student and teacher checkpoints (Orbax format, ending in /0/items).
6+
student_overrides:
7+
model_name: "gpt-oss-20b"
8+
override_model_config: True
9+
base_num_query_heads: 32
10+
head_dim: 128
11+
base_num_kv_heads: 4
12+
load_parameters_path: "gs://YOUR-BUCKET/distillation/gpt_oss_20b/student/0/items"
13+
teacher_overrides:
14+
model_name: "gpt-oss-20b"
15+
load_parameters_path: "gs://YOUR-BUCKET/distillation/gpt_oss_20b/teacher/0/items"
16+
17+
# distill_beta=0: decoder feature loss is broken on gpt-oss.
18+
distill_alpha: 0.5
19+
distill_temperature: 1.0
20+
distill_beta: 0
21+
distill_layer_indices: []
22+
enable_nnx: True
23+
load_balance_loss_weight: 0.001
24+
25+
ici_fsdp_parallelism: 32
26+
ici_data_parallelism: 4
27+
28+
dataset_type: "grain"
29+
grain_file_type: "arrayrecord"
30+
# Launcher's gcsfuse override (CLI arg) replaces this for the student config.
31+
grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord"
32+
grain_worker_count: 16
33+
grain_ram_budget_mb: 1024
34+
grain_per_worker_buffer_size: 2
35+
grain_prefetch_buffer_size: 1024
36+
num_epoch: 10
37+
38+
tokenizer_path: "src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer"
39+
tokenizer_type: "huggingface"
40+
41+
max_target_length: 32768
42+
per_device_batch_size: 1
43+
gradient_accumulation_steps: 1
44+
45+
steps: 100000
46+
learning_rate_schedule_steps: 100000
47+
log_period: 10
48+
checkpoint_period: 500
49+
save_checkpoint_on_completion: True
50+
enable_checkpointing: True
51+
async_checkpointing: True
52+
skip_jax_distributed_system: False
53+
54+
learning_rate: 5.0e-5
55+
learning_rate_final_fraction: 0.2
56+
warmup_steps_fraction: 0.02
57+
adam_b1: 0.9
58+
adam_b2: 0.95
59+
adam_eps: 1.e-5
60+
adam_weight_decay: 0.01
61+
adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias']
62+
z_loss_multiplier: 1.0e-5
63+
float32_logits: True
64+
65+
# Tokamax splash attention. Layouts: HEAD_DIM_MINOR | SEQ_MINOR.
66+
attention: "flash"
67+
use_tokamax_splash: True
68+
sa_use_fused_bwd_kernel: True
69+
sa_block_q: 2048
70+
sa_block_kv: 2048
71+
sa_block_kv_compute: 2048
72+
sa_block_q_dkv: 2048
73+
sa_block_kv_dkv: 2048
74+
sa_block_kv_dkv_compute: 2048
75+
sa_block_q_dq: 2048
76+
sa_block_kv_dq: 2048
77+
sa_q_layout: "SEQ_MINOR"
78+
sa_k_layout: "SEQ_MINOR"
79+
sa_v_layout: "SEQ_MINOR"
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Qwen3-30b-a3b-base distillation. ~20% MFU on v7x.
2+
base_config: "base.yml"
3+
4+
# NOTE: load_parameters_path values are placeholders. Replace them with paths
5+
# to your own student and teacher checkpoints (Orbax format, ending in /0/items).
6+
student_overrides:
7+
model_name: "qwen3-30b-a3b-base"
8+
override_model_config: True
9+
base_num_query_heads: 16
10+
head_dim: 256
11+
base_num_kv_heads: 2
12+
rope_max_timescale: 1_000_000
13+
load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/student/0/items"
14+
teacher_overrides:
15+
override_model_config: True
16+
model_name: "qwen3-30b-a3b-base"
17+
rope_max_timescale: 1_000_000
18+
load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/teacher/0/items"
19+
20+
distill_alpha: 0.6
21+
distill_temperature: 1.0
22+
distill_beta: 1.0
23+
distill_layer_indices: [0,1,2,3,4,5,6,7]
24+
enable_nnx: True
25+
load_balance_loss_weight: 0.001
26+
27+
ici_fsdp_parallelism: -1
28+
ici_data_parallelism: 4
29+
30+
dataset_type: "grain"
31+
grain_file_type: "arrayrecord"
32+
# Launcher's gcsfuse override (CLI arg) replaces this for the student config.
33+
grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord"
34+
grain_worker_count: 16
35+
grain_ram_budget_mb: 1024
36+
grain_per_worker_buffer_size: 2
37+
grain_prefetch_buffer_size: 1024
38+
num_epoch: 10
39+
40+
tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer"
41+
tokenizer_type: "huggingface"
42+
43+
max_target_length: 8192
44+
per_device_batch_size: 4
45+
gradient_accumulation_steps: 1
46+
47+
steps: 100000
48+
learning_rate_schedule_steps: 100000
49+
log_period: 10
50+
checkpoint_period: 500
51+
save_checkpoint_on_completion: True
52+
enable_checkpointing: True
53+
async_checkpointing: True
54+
skip_jax_distributed_system: False
55+
56+
learning_rate: 1.0e-4
57+
learning_rate_final_fraction: 0.1
58+
warmup_steps_fraction: 0.08
59+
adam_b1: 0.9
60+
adam_b2: 0.95
61+
adam_eps: 1.e-5
62+
adam_weight_decay: 0.01
63+
adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias']
64+
z_loss_multiplier: 1.0e-5
65+
float32_logits: True
66+
67+
# Tokamax splash attention. Layouts: HEAD_DIM_MINOR | SEQ_MINOR.
68+
attention: "flash"
69+
use_tokamax_splash: True
70+
sa_use_fused_bwd_kernel: True
71+
sa_block_q: 1024
72+
sa_block_kv: 1024
73+
sa_block_kv_compute: 512
74+
sa_block_q_dkv: 2048
75+
sa_block_kv_dkv: 2048
76+
sa_block_kv_dkv_compute: 1024
77+
sa_block_q_dq: 1024
78+
sa_block_kv_dq: 1024
79+
sa_q_layout: "HEAD_DIM_MINOR"
80+
sa_k_layout: "SEQ_MINOR"
81+
sa_v_layout: "HEAD_DIM_MINOR"

0 commit comments

Comments
 (0)