Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions scripts/distillation/distill_gpt_oss_20b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/bin/bash
# Launch gpt-oss-20b distillation on TPU v7x.
# Usage: bash scripts/distillation/distill_gpt_oss_20b.sh [submit|monitor|resume_until_done]
#
# Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own
# image before running — the placeholders below will not work as-is. Everything
# else has a working default. Example:
#
# DISTILL_GCS_BUCKET=gs://your-bucket \
# XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \
# bash scripts/distillation/distill_gpt_oss_20b.sh submit
set -euo pipefail

MODE="${1:-submit}"
REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd)
cd "$REPO_ROOT"

# Your GCS bucket: run outputs, staged YAML, and tokenizer files all land here.
export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-gs://YOUR-BUCKET}"

export XPK_WORKLOAD="${XPK_WORKLOAD:-goss-base-$(date +%Y%m%d-%H%M)}"
export XPK_RUN_NAME="${XPK_RUN_NAME:-gpt_oss_20b_base}"
export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}"
export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}"
export XPK_ZONE="${XPK_ZONE:-us-central1}"
export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}"
export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}"
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Using a dev image as a demo default is acceptable, but it might be better to point to a more stable or public reference if available.
Suggested change
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"

export XPK_PRIORITY="${XPK_PRIORITY:-high}"

export XPK_USE_GCSFUSE=1
export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}"
export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}"

# Stage HF tokenizer files (not in the image for gpt-oss).
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Using a specific user's bucket as a default for `XPK_YAML_GCS` will cause the `submit` mode to fail for any other user due to lack of write permissions. Consider using a more generic placeholder or documenting this as a mandatory override.
Suggested change
# Stage HF tokenizer files (not in the image for gpt-oss).
export XPK_YAML_GCS="${XPK_YAML_GCS:-gs://YOUR-BUCKET/distill-configs/distillation_gpt_oss_20b.yml}"

export XPK_TOKENIZER_GCS="${XPK_TOKENIZER_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/tokenizer-gpt-oss-20b/}"
export XPK_TOKENIZER_LOCAL="${XPK_TOKENIZER_LOCAL:-/deps/src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer}"

LOCAL_YAML="src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml"
export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}"
export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_gpt_oss_20b.yml}"

# distill_beta=0: decoder feature loss is broken on gpt-oss.
export DISTILL_ALPHA="${DISTILL_ALPHA:-0.5}"
export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
export DISTILL_BETA="${DISTILL_BETA:-0}"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 For consistency with the `qwen3` script and the default in `run_distill_xpk.sh`, consider using `61440` (60MB) unless `65536` (64MB) was specifically found to be necessary for `gpt-oss-20b`.
Suggested change
export DISTILL_BETA="${DISTILL_BETA:-0}"
export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=61440 \

export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[]}"

# XLA flags tuned for ~17% MFU. sparse_core_collective_aggregator is required
# by latency_hiding_layer_scheduler.
export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_impure_enable_packed_bf16_math_ops=true \
--xla_tpu_aggressive_opt_barrier_removal=true \
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 For consistency with other scripts in the repository (e.g., `distill_qwen3_30b_base.sh` and `run_distill_xpk.sh`) and the default XLA flags defined in `benchmarks/xla_flags_library.py`, consider using `ENABLED` instead of `true` for this flag.
Suggested change
--xla_tpu_aggressive_opt_barrier_removal=true \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \

--xla_tpu_enable_sparse_core_collective_aggregator=true \
--xla_tpu_enable_latency_hiding_layer_scheduler=true \
--xla_tpu_enable_layer_scheduler_for_dependent_collectives=true \
--xla_tpu_enable_multi_compute_overlap_in_layer_scheduler=true \
--xla_tpu_scheduler_percent_shared_memory_limit=150 \
--xla_enable_async_all_gather=true \
--xla_tpu_prefer_async_allgather_to_allreduce=true \
--xla_max_concurrent_async_all_gathers=2 \
--xla_max_concurrent_async_reduce_scatters=2 \
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=false}"

if [ "$MODE" = "submit" ]; then
gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS"
fi

exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE"
62 changes: 62 additions & 0 deletions scripts/distillation/distill_qwen3_30b_base.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/bin/bash
# Launch qwen3-30b-a3b-base distillation on TPU v7x.
# Usage: bash scripts/distillation/distill_qwen3_30b_base.sh [submit|monitor|resume_until_done]
#
# Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own
# image before running — the placeholders below will not work as-is. Everything
# else has a working default. Example:
#
# DISTILL_GCS_BUCKET=gs://your-bucket \
# XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \
# bash scripts/distillation/distill_qwen3_30b_base.sh submit
set -euo pipefail

MODE="${1:-submit}"
REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd)
cd "$REPO_ROOT"

# Your GCS bucket: run outputs and staged YAML land here.
export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-gs://YOUR-BUCKET}"

export XPK_WORKLOAD="${XPK_WORKLOAD:-q30b-base-$(date +%Y%m%d-%H%M)}"
export XPK_RUN_NAME="${XPK_RUN_NAME:-qwen3_30b_base}"
export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}"
export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}"
export XPK_ZONE="${XPK_ZONE:-us-central1}"
export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}"
export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}"
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Consider using a more generic or stable base image reference for the demo default.
Suggested change
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"
export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}"

export XPK_PRIORITY="${XPK_PRIORITY:-high}"

export XPK_USE_GCSFUSE=1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Similar to the GPT-OSS script, the default `XPK_YAML_GCS` points to a specific user's bucket, which will prevent other users from using the `submit` mode out of the box.
Suggested change
export XPK_USE_GCSFUSE=1
export XPK_YAML_GCS="${XPK_YAML_GCS:-gs://YOUR-BUCKET/distill-configs/distillation_qwen3_30b_base.yml}"

export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}"
export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}"

LOCAL_YAML="src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml"
export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}"
export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_qwen3_30b_base.yml}"

export DISTILL_ALPHA="${DISTILL_ALPHA:-0.6}"
export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}"
export DISTILL_BETA="${DISTILL_BETA:-1.0}"
export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[0,1,2,3,4,5,6,7]}"

# XLA flags tuned for ~20% MFU.
export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=61440 \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2}"

if [ "$MODE" = "submit" ]; then
gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS"
fi

exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE"
79 changes: 79 additions & 0 deletions src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# gpt-oss-20b distillation. ~17% MFU on v7x.
base_config: "base.yml"

# NOTE: load_parameters_path values are placeholders. Replace them with paths
# to your own student and teacher checkpoints (Orbax format, ending in /0/items).
student_overrides:
model_name: "gpt-oss-20b"
override_model_config: True
base_num_query_heads: 32
head_dim: 128
base_num_kv_heads: 4
load_parameters_path: "gs://YOUR-BUCKET/distillation/gpt_oss_20b/student/0/items"
teacher_overrides:
model_name: "gpt-oss-20b"
load_parameters_path: "gs://YOUR-BUCKET/distillation/gpt_oss_20b/teacher/0/items"

# distill_beta=0: decoder feature loss is broken on gpt-oss.
distill_alpha: 0.5
distill_temperature: 1.0
distill_beta: 0
distill_layer_indices: []
enable_nnx: True
load_balance_loss_weight: 0.001
Comment thread
gagika marked this conversation as resolved.

ici_fsdp_parallelism: 32
ici_data_parallelism: 4

dataset_type: "grain"
grain_file_type: "arrayrecord"
# Launcher's gcsfuse override (CLI arg) replaces this for the student config.
grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord"
grain_worker_count: 16
grain_ram_budget_mb: 1024
grain_per_worker_buffer_size: 2
grain_prefetch_buffer_size: 1024
num_epoch: 10

tokenizer_path: "src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer"
tokenizer_type: "huggingface"

max_target_length: 32768
per_device_batch_size: 1
gradient_accumulation_steps: 1

steps: 100000
learning_rate_schedule_steps: 100000
log_period: 10
checkpoint_period: 500
save_checkpoint_on_completion: True
enable_checkpointing: True
async_checkpointing: True
skip_jax_distributed_system: False

learning_rate: 5.0e-5
learning_rate_final_fraction: 0.2
warmup_steps_fraction: 0.02
adam_b1: 0.9
adam_b2: 0.95
adam_eps: 1.e-5
adam_weight_decay: 0.01
adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias']
z_loss_multiplier: 1.0e-5
float32_logits: True

# Tokamax splash attention. Layouts: HEAD_DIM_MINOR | SEQ_MINOR.
attention: "flash"
use_tokamax_splash: True
sa_use_fused_bwd_kernel: True
sa_block_q: 2048
sa_block_kv: 2048
sa_block_kv_compute: 2048
sa_block_q_dkv: 2048
sa_block_kv_dkv: 2048
sa_block_kv_dkv_compute: 2048
sa_block_q_dq: 2048
sa_block_kv_dq: 2048
sa_q_layout: "SEQ_MINOR"
sa_k_layout: "SEQ_MINOR"
sa_v_layout: "SEQ_MINOR"
81 changes: 81 additions & 0 deletions src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Qwen3-30b-a3b-base distillation. ~20% MFU on v7x.
base_config: "base.yml"

# NOTE: load_parameters_path values are placeholders. Replace them with paths
# to your own student and teacher checkpoints (Orbax format, ending in /0/items).
student_overrides:
model_name: "qwen3-30b-a3b-base"
override_model_config: True
base_num_query_heads: 16
head_dim: 256
base_num_kv_heads: 2
rope_max_timescale: 1_000_000
load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/student/0/items"
teacher_overrides:
override_model_config: True
model_name: "qwen3-30b-a3b-base"
rope_max_timescale: 1_000_000
load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/teacher/0/items"

distill_alpha: 0.6
distill_temperature: 1.0
distill_beta: 1.0
distill_layer_indices: [0,1,2,3,4,5,6,7]
enable_nnx: True
load_balance_loss_weight: 0.001

ici_fsdp_parallelism: -1
ici_data_parallelism: 4

dataset_type: "grain"
grain_file_type: "arrayrecord"
# Launcher's gcsfuse override (CLI arg) replaces this for the student config.
grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord"
grain_worker_count: 16
grain_ram_budget_mb: 1024
grain_per_worker_buffer_size: 2
grain_prefetch_buffer_size: 1024
num_epoch: 10

tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer"
tokenizer_type: "huggingface"

max_target_length: 8192
per_device_batch_size: 4
gradient_accumulation_steps: 1

steps: 100000
learning_rate_schedule_steps: 100000
log_period: 10
checkpoint_period: 500
save_checkpoint_on_completion: True
enable_checkpointing: True
async_checkpointing: True
skip_jax_distributed_system: False

learning_rate: 1.0e-4
learning_rate_final_fraction: 0.1
warmup_steps_fraction: 0.08
adam_b1: 0.9
adam_b2: 0.95
adam_eps: 1.e-5
adam_weight_decay: 0.01
adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias']
z_loss_multiplier: 1.0e-5
float32_logits: True

# Tokamax splash attention. Layouts: HEAD_DIM_MINOR | SEQ_MINOR.
attention: "flash"
use_tokamax_splash: True
sa_use_fused_bwd_kernel: True
sa_block_q: 1024
sa_block_kv: 1024
sa_block_kv_compute: 512
sa_block_q_dkv: 2048
sa_block_kv_dkv: 2048
sa_block_kv_dkv_compute: 1024
sa_block_q_dq: 1024
sa_block_kv_dq: 1024
sa_q_layout: "HEAD_DIM_MINOR"
sa_k_layout: "SEQ_MINOR"
sa_v_layout: "HEAD_DIM_MINOR"
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Qwen3-30b-a3b-base distillation, pdbs=8 + activation offload. ~22% MFU on v7x.
base_config: "base.yml"

# NOTE: load_parameters_path values are placeholders. Replace them with paths
# to your own student and teacher checkpoints (Orbax format, ending in /0/items).
student_overrides:
model_name: "qwen3-30b-a3b-base"
override_model_config: True
base_num_query_heads: 16
head_dim: 256
base_num_kv_heads: 2
rope_max_timescale: 1_000_000
load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/student/0/items"
teacher_overrides:
override_model_config: True
model_name: "qwen3-30b-a3b-base"
rope_max_timescale: 1_000_000
load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/teacher/0/items"

distill_alpha: 0.6
distill_temperature: 1.0
distill_beta: 1.0
distill_layer_indices: [0,1,2,3,4,5,6,7]
enable_nnx: True
load_balance_loss_weight: 0.001

ici_fsdp_parallelism: -1
ici_data_parallelism: 4

dataset_type: "grain"
grain_file_type: "arrayrecord"
# Launcher's gcsfuse override (CLI arg) replaces this for the student config.
grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord"
grain_worker_count: 16
grain_ram_budget_mb: 1024
grain_per_worker_buffer_size: 2
grain_prefetch_buffer_size: 1024
num_epoch: 10

tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer"
tokenizer_type: "huggingface"

max_target_length: 8192
per_device_batch_size: 8
gradient_accumulation_steps: 1

# Activation offload to fit pdbs=8.
remat_policy: 'custom'
decoder_layer_input: 'offload'

steps: 100000
learning_rate_schedule_steps: 100000
log_period: 10
checkpoint_period: 500
save_checkpoint_on_completion: True
enable_checkpointing: True
async_checkpointing: True
skip_jax_distributed_system: False

learning_rate: 1.0e-4
learning_rate_final_fraction: 0.1
warmup_steps_fraction: 0.08
adam_b1: 0.9
adam_b2: 0.95
adam_eps: 1.e-5
adam_weight_decay: 0.01
adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias']
z_loss_multiplier: 1.0e-5
float32_logits: True

# Tokamax splash attention. Layouts: HEAD_DIM_MINOR | SEQ_MINOR.
attention: "flash"
use_tokamax_splash: True
sa_use_fused_bwd_kernel: True
sa_block_q: 1024
sa_block_kv: 1024
sa_block_kv_compute: 512
sa_block_q_dkv: 2048
sa_block_kv_dkv: 2048
sa_block_kv_dkv_compute: 1024
sa_block_q_dq: 1024
sa_block_kv_dq: 1024
sa_q_layout: "HEAD_DIM_MINOR"
sa_k_layout: "SEQ_MINOR"
sa_v_layout: "HEAD_DIM_MINOR"
Loading
Loading