diff --git a/scripts/distillation/distill_gpt_oss_20b.sh b/scripts/distillation/distill_gpt_oss_20b.sh new file mode 100755 index 0000000000..0b6ccbcaf5 --- /dev/null +++ b/scripts/distillation/distill_gpt_oss_20b.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Launch gpt-oss-20b distillation on TPU v7x. +# Usage: bash scripts/distillation/distill_gpt_oss_20b.sh [submit|monitor|resume_until_done] +# +# Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own +# image before running — the placeholders below will not work as-is. Everything +# else has a working default. Example: +# +# DISTILL_GCS_BUCKET=gs://your-bucket \ +# XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \ +# bash scripts/distillation/distill_gpt_oss_20b.sh submit +set -euo pipefail + +MODE="${1:-submit}" +REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd) +cd "$REPO_ROOT" + +# Your GCS bucket: run outputs, staged YAML, and tokenizer files all land here. +export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-gs://YOUR-BUCKET}" + +export XPK_WORKLOAD="${XPK_WORKLOAD:-goss-base-$(date +%Y%m%d-%H%M)}" +export XPK_RUN_NAME="${XPK_RUN_NAME:-gpt_oss_20b_base}" +export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}" +export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}" +export XPK_ZONE="${XPK_ZONE:-us-central1}" +export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}" +export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}" +export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" +export XPK_PRIORITY="${XPK_PRIORITY:-high}" + +export XPK_USE_GCSFUSE=1 +export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}" +export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}" + +# Stage HF tokenizer files (not in the image for gpt-oss). +export XPK_TOKENIZER_GCS="${XPK_TOKENIZER_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/tokenizer-gpt-oss-20b/}" +export XPK_TOKENIZER_LOCAL="${XPK_TOKENIZER_LOCAL:-/deps/src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer}" + +LOCAL_YAML="src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml" +export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}" +export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_gpt_oss_20b.yml}" + +# distill_beta=0: decoder feature loss is broken on gpt-oss. +export DISTILL_ALPHA="${DISTILL_ALPHA:-0.5}" +export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" +export DISTILL_BETA="${DISTILL_BETA:-0}" +export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[]}" + +# XLA flags tuned for ~17% MFU. sparse_core_collective_aggregator is required +# by latency_hiding_layer_scheduler. +export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=65536 \ +--xla_tpu_impure_enable_packed_bf16_math_ops=true \ +--xla_tpu_aggressive_opt_barrier_removal=true \ +--xla_tpu_enable_sparse_core_collective_aggregator=true \ +--xla_tpu_enable_latency_hiding_layer_scheduler=true \ +--xla_tpu_enable_layer_scheduler_for_dependent_collectives=true \ +--xla_tpu_enable_multi_compute_overlap_in_layer_scheduler=true \ +--xla_tpu_scheduler_percent_shared_memory_limit=150 \ +--xla_enable_async_all_gather=true \ +--xla_tpu_prefer_async_allgather_to_allreduce=true \ +--xla_max_concurrent_async_all_gathers=2 \ +--xla_max_concurrent_async_reduce_scatters=2 \ +--xla_tpu_enable_async_collective_fusion_fuse_all_gather=false}" + +if [ "$MODE" = "submit" ]; then + gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS" +fi + +exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE" diff --git a/scripts/distillation/distill_qwen3_30b_base.sh b/scripts/distillation/distill_qwen3_30b_base.sh new file mode 100755 index 0000000000..8c179228c4 --- /dev/null +++ b/scripts/distillation/distill_qwen3_30b_base.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Launch qwen3-30b-a3b-base distillation on TPU v7x. +# Usage: bash scripts/distillation/distill_qwen3_30b_base.sh [submit|monitor|resume_until_done] +# +# Set DISTILL_GCS_BUCKET to your own GCS bucket and XPK_BASE_IMAGE to your own +# image before running — the placeholders below will not work as-is. Everything +# else has a working default. Example: +# +# DISTILL_GCS_BUCKET=gs://your-bucket \ +# XPK_BASE_IMAGE=gcr.io/your-project/maxtext_base_image:tag \ +# bash scripts/distillation/distill_qwen3_30b_base.sh submit +set -euo pipefail + +MODE="${1:-submit}" +REPO_ROOT=$(cd "$(dirname "$0")/../.." && pwd) +cd "$REPO_ROOT" + +# Your GCS bucket: run outputs and staged YAML land here. +export DISTILL_GCS_BUCKET="${DISTILL_GCS_BUCKET:-gs://YOUR-BUCKET}" + +export XPK_WORKLOAD="${XPK_WORKLOAD:-q30b-base-$(date +%Y%m%d-%H%M)}" +export XPK_RUN_NAME="${XPK_RUN_NAME:-qwen3_30b_base}" +export XPK_CLUSTER="${XPK_CLUSTER:-bodaborg-super-xpk-x8p}" +export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}" +export XPK_ZONE="${XPK_ZONE:-us-central1}" +export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}" +export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-${DISTILL_GCS_BUCKET}/distillation}" +export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" +export XPK_PRIORITY="${XPK_PRIORITY:-high}" + +export XPK_USE_GCSFUSE=1 +export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}" +export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}" + +LOCAL_YAML="src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml" +export XPK_DISTILL_CONFIG="${XPK_DISTILL_CONFIG:-$LOCAL_YAML}" +export XPK_YAML_GCS="${XPK_YAML_GCS:-${DISTILL_GCS_BUCKET}/distill-configs/distillation_qwen3_30b_base.yml}" + +export DISTILL_ALPHA="${DISTILL_ALPHA:-0.6}" +export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" +export DISTILL_BETA="${DISTILL_BETA:-1.0}" +export DISTILL_LAYER_INDICES="${DISTILL_LAYER_INDICES:-[0,1,2,3,4,5,6,7]}" + +# XLA flags tuned for ~20% MFU. +export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=61440 \ +--xla_tpu_enable_all_experimental_scheduler_features=true \ +--xla_tpu_enable_scheduler_memory_pressure_tracking=true \ +--xla_tpu_host_transfer_overlap_limit=24 \ +--xla_tpu_aggressive_opt_barrier_removal=ENABLED \ +--xla_lhs_prioritize_async_depth_over_stall=ENABLED \ +--xla_tpu_enable_ag_backward_pipelining=true \ +--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ +--xla_should_add_loop_invariant_op_in_chain=ENABLED \ +--xla_max_concurrent_host_send_recv=100 \ +--xla_tpu_scheduler_percent_shared_memory_limit=100 \ +--xla_latency_hiding_scheduler_rerun=2}" + +if [ "$MODE" = "submit" ]; then + gcloud storage cp "$XPK_DISTILL_CONFIG" "$XPK_YAML_GCS" +fi + +exec bash src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh "$MODE" diff --git a/src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml b/src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml new file mode 100644 index 0000000000..e3a81809b5 --- /dev/null +++ b/src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml @@ -0,0 +1,79 @@ +# gpt-oss-20b distillation. ~17% MFU on v7x. +base_config: "base.yml" + +# NOTE: load_parameters_path values are placeholders. Replace them with paths +# to your own student and teacher checkpoints (Orbax format, ending in /0/items). +student_overrides: + model_name: "gpt-oss-20b" + override_model_config: True + base_num_query_heads: 32 + head_dim: 128 + base_num_kv_heads: 4 + load_parameters_path: "gs://YOUR-BUCKET/distillation/gpt_oss_20b/student/0/items" +teacher_overrides: + model_name: "gpt-oss-20b" + load_parameters_path: "gs://YOUR-BUCKET/distillation/gpt_oss_20b/teacher/0/items" + +# distill_beta=0: decoder feature loss is broken on gpt-oss. +distill_alpha: 0.5 +distill_temperature: 1.0 +distill_beta: 0 +distill_layer_indices: [] +enable_nnx: True +load_balance_loss_weight: 0.001 + +ici_fsdp_parallelism: 32 +ici_data_parallelism: 4 + +dataset_type: "grain" +grain_file_type: "arrayrecord" +# Launcher's gcsfuse override (CLI arg) replaces this for the student config. +grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord" +grain_worker_count: 16 +grain_ram_budget_mb: 1024 +grain_per_worker_buffer_size: 2 +grain_prefetch_buffer_size: 1024 +num_epoch: 10 + +tokenizer_path: "src/maxtext/assets/tokenizers/gpt-oss-20b-tokenizer" +tokenizer_type: "huggingface" + +max_target_length: 32768 +per_device_batch_size: 1 +gradient_accumulation_steps: 1 + +steps: 100000 +learning_rate_schedule_steps: 100000 +log_period: 10 +checkpoint_period: 500 +save_checkpoint_on_completion: True +enable_checkpointing: True +async_checkpointing: True +skip_jax_distributed_system: False + +learning_rate: 5.0e-5 +learning_rate_final_fraction: 0.2 +warmup_steps_fraction: 0.02 +adam_b1: 0.9 +adam_b2: 0.95 +adam_eps: 1.e-5 +adam_weight_decay: 0.01 +adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias'] +z_loss_multiplier: 1.0e-5 +float32_logits: True + +# Tokamax splash attention. Layouts: HEAD_DIM_MINOR | SEQ_MINOR. +attention: "flash" +use_tokamax_splash: True +sa_use_fused_bwd_kernel: True +sa_block_q: 2048 +sa_block_kv: 2048 +sa_block_kv_compute: 2048 +sa_block_q_dkv: 2048 +sa_block_kv_dkv: 2048 +sa_block_kv_dkv_compute: 2048 +sa_block_q_dq: 2048 +sa_block_kv_dq: 2048 +sa_q_layout: "SEQ_MINOR" +sa_k_layout: "SEQ_MINOR" +sa_v_layout: "SEQ_MINOR" diff --git a/src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml b/src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml new file mode 100644 index 0000000000..a154f0e292 --- /dev/null +++ b/src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml @@ -0,0 +1,81 @@ +# Qwen3-30b-a3b-base distillation. ~20% MFU on v7x. +base_config: "base.yml" + +# NOTE: load_parameters_path values are placeholders. Replace them with paths +# to your own student and teacher checkpoints (Orbax format, ending in /0/items). +student_overrides: + model_name: "qwen3-30b-a3b-base" + override_model_config: True + base_num_query_heads: 16 + head_dim: 256 + base_num_kv_heads: 2 + rope_max_timescale: 1_000_000 + load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/student/0/items" +teacher_overrides: + override_model_config: True + model_name: "qwen3-30b-a3b-base" + rope_max_timescale: 1_000_000 + load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/teacher/0/items" + +distill_alpha: 0.6 +distill_temperature: 1.0 +distill_beta: 1.0 +distill_layer_indices: [0,1,2,3,4,5,6,7] +enable_nnx: True +load_balance_loss_weight: 0.001 + +ici_fsdp_parallelism: -1 +ici_data_parallelism: 4 + +dataset_type: "grain" +grain_file_type: "arrayrecord" +# Launcher's gcsfuse override (CLI arg) replaces this for the student config. +grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord" +grain_worker_count: 16 +grain_ram_budget_mb: 1024 +grain_per_worker_buffer_size: 2 +grain_prefetch_buffer_size: 1024 +num_epoch: 10 + +tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer" +tokenizer_type: "huggingface" + +max_target_length: 8192 +per_device_batch_size: 4 +gradient_accumulation_steps: 1 + +steps: 100000 +learning_rate_schedule_steps: 100000 +log_period: 10 +checkpoint_period: 500 +save_checkpoint_on_completion: True +enable_checkpointing: True +async_checkpointing: True +skip_jax_distributed_system: False + +learning_rate: 1.0e-4 +learning_rate_final_fraction: 0.1 +warmup_steps_fraction: 0.08 +adam_b1: 0.9 +adam_b2: 0.95 +adam_eps: 1.e-5 +adam_weight_decay: 0.01 +adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias'] +z_loss_multiplier: 1.0e-5 +float32_logits: True + +# Tokamax splash attention. Layouts: HEAD_DIM_MINOR | SEQ_MINOR. +attention: "flash" +use_tokamax_splash: True +sa_use_fused_bwd_kernel: True +sa_block_q: 1024 +sa_block_kv: 1024 +sa_block_kv_compute: 512 +sa_block_q_dkv: 2048 +sa_block_kv_dkv: 2048 +sa_block_kv_dkv_compute: 1024 +sa_block_q_dq: 1024 +sa_block_kv_dq: 1024 +sa_q_layout: "HEAD_DIM_MINOR" +sa_k_layout: "SEQ_MINOR" +sa_v_layout: "HEAD_DIM_MINOR" diff --git a/src/maxtext/configs/post_train/distillation_qwen3_30b_base_pdbs8.yml b/src/maxtext/configs/post_train/distillation_qwen3_30b_base_pdbs8.yml new file mode 100644 index 0000000000..56f079decc --- /dev/null +++ b/src/maxtext/configs/post_train/distillation_qwen3_30b_base_pdbs8.yml @@ -0,0 +1,85 @@ +# Qwen3-30b-a3b-base distillation, pdbs=8 + activation offload. ~22% MFU on v7x. +base_config: "base.yml" + +# NOTE: load_parameters_path values are placeholders. Replace them with paths +# to your own student and teacher checkpoints (Orbax format, ending in /0/items). +student_overrides: + model_name: "qwen3-30b-a3b-base" + override_model_config: True + base_num_query_heads: 16 + head_dim: 256 + base_num_kv_heads: 2 + rope_max_timescale: 1_000_000 + load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/student/0/items" +teacher_overrides: + override_model_config: True + model_name: "qwen3-30b-a3b-base" + rope_max_timescale: 1_000_000 + load_parameters_path: "gs://YOUR-BUCKET/distillation/qwen3_30b/teacher/0/items" + +distill_alpha: 0.6 +distill_temperature: 1.0 +distill_beta: 1.0 +distill_layer_indices: [0,1,2,3,4,5,6,7] +enable_nnx: True +load_balance_loss_weight: 0.001 + +ici_fsdp_parallelism: -1 +ici_data_parallelism: 4 + +dataset_type: "grain" +grain_file_type: "arrayrecord" +# Launcher's gcsfuse override (CLI arg) replaces this for the student config. +grain_train_files: "gs://maxtext-dataset/array-record/climbmix/*.arrayrecord" +grain_worker_count: 16 +grain_ram_budget_mb: 1024 +grain_per_worker_buffer_size: 2 +grain_prefetch_buffer_size: 1024 +num_epoch: 10 + +tokenizer_path: "src/maxtext/assets/tokenizers/qwen3-tokenizer" +tokenizer_type: "huggingface" + +max_target_length: 8192 +per_device_batch_size: 8 +gradient_accumulation_steps: 1 + +# Activation offload to fit pdbs=8. +remat_policy: 'custom' +decoder_layer_input: 'offload' + +steps: 100000 +learning_rate_schedule_steps: 100000 +log_period: 10 +checkpoint_period: 500 +save_checkpoint_on_completion: True +enable_checkpointing: True +async_checkpointing: True +skip_jax_distributed_system: False + +learning_rate: 1.0e-4 +learning_rate_final_fraction: 0.1 +warmup_steps_fraction: 0.08 +adam_b1: 0.9 +adam_b2: 0.95 +adam_eps: 1.e-5 +adam_weight_decay: 0.01 +adamw_mask: ['.*embedding.*', '.*norm.*', '.*bias'] +z_loss_multiplier: 1.0e-5 +float32_logits: True + +# Tokamax splash attention. Layouts: HEAD_DIM_MINOR | SEQ_MINOR. +attention: "flash" +use_tokamax_splash: True +sa_use_fused_bwd_kernel: True +sa_block_q: 1024 +sa_block_kv: 1024 +sa_block_kv_compute: 512 +sa_block_q_dkv: 2048 +sa_block_kv_dkv: 2048 +sa_block_kv_dkv_compute: 1024 +sa_block_q_dq: 1024 +sa_block_kv_dq: 1024 +sa_q_layout: "HEAD_DIM_MINOR" +sa_k_layout: "SEQ_MINOR" +sa_v_layout: "HEAD_DIM_MINOR" diff --git a/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh b/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh index ee051839a2..91078508d9 100644 --- a/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh +++ b/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh @@ -85,6 +85,11 @@ # XPK_DATASET_SUBPATH default: array-record/climbmix/*.arrayrecord # The script always sets grain_train_files from these # two, overriding the YAML in both modes. +# XPK_HF_CACHE_DIR default: /dev/shm/hf — HF `datasets` Arrow cache dir. +# tmpfs by default so the full HF dataset doesn't fill +# the ~10GB ephemeral quota and Evict the pod (hit on +# gpt-oss stage1+stage2, ~23GB). No-op for grain runs. +# Point at a local SSD if the dataset exceeds host RAM. # STEPS_OVERRIDE default: empty — yml `steps` is used unless set # CHECKPOINT_PERIOD_OVERRIDE default: empty — yml `checkpoint_period` is used # MAX_RETRIES default: 10 — only used by resume_until_done @@ -149,6 +154,7 @@ require_env() { : "${XPK_USE_GCSFUSE:=1}" : "${XPK_DATASET_BUCKET:=maxtext-dataset}" : "${XPK_DATASET_SUBPATH:=array-record/climbmix/*.arrayrecord}" +: "${XPK_HF_CACHE_DIR:=/dev/shm/hf}" : "${MAX_RETRIES:=10}" # Feature-mapping / distillation loss hyperparameters. @@ -198,6 +204,35 @@ else grain_files_override="grain_train_files=gs://${XPK_DATASET_BUCKET}/${XPK_DATASET_SUBPATH}" fi +# Optional: stage the YAML from GCS instead of baking via upload_runner. +yaml_prelude="" +if [ -n "${XPK_YAML_GCS:-}" ]; then + yaml_prelude="gcloud storage cp \"${XPK_YAML_GCS}\" \"${XPK_DISTILL_CONFIG}\";" +fi + +# Optional: stage HF tokenizer files from GCS for models whose tokenizer isn't +# baked into the image (e.g. gpt-oss). +tokenizer_prelude="" +if [ -n "${XPK_TOKENIZER_GCS:-}" ] && [ -n "${XPK_TOKENIZER_LOCAL:-}" ]; then + tokenizer_prelude="mkdir -p \"${XPK_TOKENIZER_LOCAL}\" && gcloud storage rsync \"${XPK_TOKENIZER_GCS}\" \"${XPK_TOKENIZER_LOCAL}\";" +fi + +# Default v7x XLA flags. The default vmem limit (32 MB) is too small for +# tokamax splash backward; we need ≥60 MB. +default_libtpu_args="--xla_tpu_scoped_vmem_limit_kib=61440 \ +--xla_tpu_enable_all_experimental_scheduler_features=true \ +--xla_tpu_enable_scheduler_memory_pressure_tracking=true \ +--xla_tpu_host_transfer_overlap_limit=24 \ +--xla_tpu_aggressive_opt_barrier_removal=ENABLED \ +--xla_lhs_prioritize_async_depth_over_stall=ENABLED \ +--xla_tpu_enable_ag_backward_pipelining=true \ +--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ +--xla_should_add_loop_invariant_op_in_chain=ENABLED \ +--xla_max_concurrent_host_send_recv=100 \ +--xla_tpu_scheduler_percent_shared_memory_limit=100 \ +--xla_latency_hiding_scheduler_rerun=2" +libtpu_init_args=$(printf '%s' "${XPK_LIBTPU_INIT_ARGS:-$default_libtpu_args}" | tr -s '[:space:]' ' ') + # -------------------------- prep_image -------------------------- # Adds tunix and repins jax/libtpu on top of $XPK_BASE_IMAGE, then retags # the result as $XPK_BASE_IMAGE (the original tag is overwritten). @@ -279,6 +314,11 @@ submit_workload() { echo "Image flag: $image_flag=$XPK_BASE_IMAGE" # PYTHONPATH covers both image flows: /deps/src (upload_runner-baked) and /app/src (xpk crane overlay). + # TMPDIR=/dev/shm (set in --command below): XPK Pathways mounts /dev/shm as a + # disk-backed emptyDir by default (NOT real tmpfs), so this redirect just moves + # scratch off the image filesystem — it does NOT consume RAM. Make /dev/shm + # Memory-backed via a kubectl patch if you want true tmpfs. (Comment lives here, + # not inline: a `#` in the quoted --command would comment out the rest.) xpk workload create \ --cluster "$XPK_CLUSTER" \ --workload "$XPK_WORKLOAD" \ @@ -290,6 +330,11 @@ submit_workload() { "$image_flag=$XPK_BASE_IMAGE" \ --command "export PYTHONPATH=/deps/src:/app/src; \ export BASE_OUTPUT_DIRECTORY=${OUTPUT_DIR}; \ +export LIBTPU_INIT_ARGS='${libtpu_init_args}'; \ +export TMPDIR=/dev/shm; export JAX_COMPILATION_CACHE_DIR=/dev/shm/jax_cache; \ +export HF_HOME=${XPK_HF_CACHE_DIR}; export HF_DATASETS_CACHE=${XPK_HF_CACHE_DIR}/datasets; mkdir -p ${XPK_HF_CACHE_DIR}/datasets; \ +${yaml_prelude} \ +${tokenizer_prelude} \ ${gcsfuse_prelude} \ python3 -m maxtext.trainers.post_train.distillation.train_distill ${XPK_DISTILL_CONFIG} \ run_name=${XPK_RUN_NAME} \