|
| 1 | +#!/bin/bash |
| 2 | +# End-to-end resume test for the OLMo grain pipeline (stateless sampler + |
| 3 | +# step-derived initial_step). See scripts/run_olmo3_7b_grain_smoke.sh for |
| 4 | +# the env-var contract; this script accepts the same vars. |
| 5 | +# |
| 6 | +# Plan: |
| 7 | +# Run A: train 50 steps from scratch, save checkpoint at step 50, exit. |
| 8 | +# Run B: relaunch with the SAME run_name (so the checkpoint dir is reused). |
| 9 | +# The trainer restores model state at step 50; our iterator factory |
| 10 | +# detects the latest checkpoint step and sets ``initial_step`` so |
| 11 | +# the data stream picks up at absolute position 50 * per_host_batch. |
| 12 | +# Train 25 more steps (to step 75). |
| 13 | +# |
| 14 | +# What success looks like: |
| 15 | +# * Run B's first step (step 51) reports a loss similar to Run A's step 50 |
| 16 | +# loss. A spike or jump → model state didn't restore. |
| 17 | +# * No repeats: Run B's batches are NOT the same as Run A's batches at the |
| 18 | +# same absolute step. (Hard to assert without batch-content hashing in |
| 19 | +# the trainer; for the smoke we rely on the unit tests + loss continuity.) |
| 20 | +# * No regression: Run B's loss continues to decrease. |
| 21 | +# |
| 22 | +# Outputs: |
| 23 | +# ${LOG_A} — first 50 steps |
| 24 | +# ${LOG_B} — resumed 25 steps |
| 25 | +# $OUTPUT_DIR/<run_name>/checkpoints/ — Orbax checkpoint(s) |
| 26 | + |
| 27 | +set -euo pipefail |
| 28 | + |
| 29 | +MAXTEXT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" |
| 30 | +VENV_PATH="${VENV_PATH:-${MAXTEXT_ROOT}/maxtext_venv}" |
| 31 | +HF_SECRETS="${HF_SECRETS:-}" |
| 32 | +INDEX_PATH="${INDEX_PATH:?INDEX_PATH is required (path to olmo index JSON)}" |
| 33 | +GCS_BASE="${GCS_BASE:?GCS_BASE is required (e.g. gs://my-bucket/)}" |
| 34 | +LOCAL_MOUNT="${LOCAL_MOUNT:?LOCAL_MOUNT is required (gcsfuse mount path of GCS_BASE)}" |
| 35 | +OUTPUT_DIR="${OUTPUT_DIR:-/tmp/olmo_resume_test_out}" |
| 36 | +RUN_NAME="${RUN_NAME:-olmo_resume_$(date +%Y%m%d-%H%M%S)}" |
| 37 | + |
| 38 | +# Where each run's stdout is teed. Keep them under OUTPUT_DIR so the |
| 39 | +# script doesn't depend on a hard-coded absolute path. |
| 40 | +LOG_A="${LOG_A:-${OUTPUT_DIR}/${RUN_NAME}.runA.log}" |
| 41 | +LOG_B="${LOG_B:-${OUTPUT_DIR}/${RUN_NAME}.runB.log}" |
| 42 | + |
| 43 | +PER_DEVICE_BATCH="${PER_DEVICE_BATCH:-1}" |
| 44 | +SEQ_LEN="${SEQ_LEN:-8192}" |
| 45 | +WEIGHT_DTYPE="${WEIGHT_DTYPE:-bfloat16}" |
| 46 | +NUM_LAYERS="${NUM_LAYERS:-4}" |
| 47 | +DATA_SEED="${DATA_SEED:-42}" |
| 48 | + |
| 49 | +# Run A trains 50 steps + saves a checkpoint at step 50; Run B continues to 75. |
| 50 | +STEPS_A="${STEPS_A:-50}" |
| 51 | +STEPS_B="${STEPS_B:-75}" |
| 52 | +CHECKPOINT_PERIOD="${CHECKPOINT_PERIOD:-50}" |
| 53 | + |
| 54 | +# shellcheck disable=SC1090,SC1091 |
| 55 | +source "${VENV_PATH}/bin/activate" |
| 56 | +if [[ -n "${HF_SECRETS:-}" && -f "${HF_SECRETS}" ]]; then |
| 57 | + # shellcheck disable=SC1090 |
| 58 | + source "${HF_SECRETS}" |
| 59 | +fi |
| 60 | +: "${HF_TOKEN:?HF_TOKEN must be set (or HF_SECRETS pointing at a file that exports it)}" |
| 61 | +export PYTHONPATH="${MAXTEXT_ROOT}/src:${PYTHONPATH:-}" |
| 62 | +export PYTHONUNBUFFERED=1 |
| 63 | + |
| 64 | +mkdir -p "${OUTPUT_DIR}" |
| 65 | + |
| 66 | +TOKENIZER_PATH="${TOKENIZER_PATH:-allenai/Olmo-3-7B-Instruct}" |
| 67 | + |
| 68 | +run_train() { |
| 69 | + local steps="$1" |
| 70 | + local logfile="$2" |
| 71 | + echo "----- launching: steps=${steps} log=${logfile} -----" |
| 72 | + python -m maxtext.trainers.pre_train.train \ |
| 73 | + "${MAXTEXT_ROOT}/src/maxtext/configs/base.yml" \ |
| 74 | + model_name=olmo3-7b-pt \ |
| 75 | + run_name="${RUN_NAME}" \ |
| 76 | + base_output_directory="${OUTPUT_DIR}" \ |
| 77 | + dataset_type=olmo_grain \ |
| 78 | + olmo_index_path="${INDEX_PATH}" \ |
| 79 | + olmo_path_remap_from="${GCS_BASE}" \ |
| 80 | + olmo_path_remap_to="${LOCAL_MOUNT}" \ |
| 81 | + data_shuffle_seed="${DATA_SEED}" \ |
| 82 | + olmo_apply_ngram_filter=True \ |
| 83 | + grain_worker_count=0 \ |
| 84 | + per_device_batch_size="${PER_DEVICE_BATCH}" \ |
| 85 | + max_target_length="${SEQ_LEN}" \ |
| 86 | + steps="${steps}" \ |
| 87 | + enable_checkpointing=True \ |
| 88 | + async_checkpointing=False \ |
| 89 | + checkpoint_period="${CHECKPOINT_PERIOD}" \ |
| 90 | + save_checkpoint_on_completion=True \ |
| 91 | + tokenizer_type=huggingface \ |
| 92 | + tokenizer_path="${TOKENIZER_PATH}" \ |
| 93 | + weight_dtype="${WEIGHT_DTYPE}" \ |
| 94 | + override_model_config=True \ |
| 95 | + base_num_decoder_layers="${NUM_LAYERS}" \ |
| 96 | + sharding_tolerance=0.05 \ |
| 97 | + 2>&1 | tee "${logfile}" |
| 98 | +} |
| 99 | + |
| 100 | +echo "=== OLMo 3 grain resume test ===" |
| 101 | +echo " run_name : ${RUN_NAME}" |
| 102 | +echo " output_dir : ${OUTPUT_DIR}/${RUN_NAME}" |
| 103 | +echo " per_device_bs : ${PER_DEVICE_BATCH}" |
| 104 | +echo " seq_len : ${SEQ_LEN}" |
| 105 | +echo " num_layers : ${NUM_LAYERS}" |
| 106 | +echo " Run A steps : ${STEPS_A} (will checkpoint at step ${CHECKPOINT_PERIOD})" |
| 107 | +echo " Run B steps : ${STEPS_B} (resumed via initial_step)" |
| 108 | +echo |
| 109 | + |
| 110 | +# Run A |
| 111 | +run_train "${STEPS_A}" "${LOG_A}" |
| 112 | + |
| 113 | +echo |
| 114 | +echo "=== Run A done. Last 3 step events: ===" |
| 115 | +grep -E "completed step:" "${LOG_A}" | tail -3 |
| 116 | +echo |
| 117 | + |
| 118 | +# Run B (resume) |
| 119 | +run_train "${STEPS_B}" "${LOG_B}" |
| 120 | + |
| 121 | +echo |
| 122 | +echo "=== Run B done ===" |
| 123 | +echo "First 3 step events from Run B (expect step >= ${STEPS_A}):" |
| 124 | +grep -E "completed step:" "${LOG_B}" | head -3 |
| 125 | +echo |
| 126 | +echo "Last 3 step events from Run B:" |
| 127 | +grep -E "completed step:" "${LOG_B}" | tail -3 |
| 128 | +echo |
| 129 | + |
| 130 | +echo "=== Pass criteria (manual check): ===" |
| 131 | +echo " 1. Run B's first step number >= ${STEPS_A} (model state restored)" |
| 132 | +echo " 2. Run B's first step loss within ~5% of Run A's last step loss" |
| 133 | +echo " (model continued, no re-init)" |
| 134 | +echo " 3. Loss continues to decrease across Run B" |
| 135 | +echo " 4. iterator log line shows 'resumed_step=${STEPS_A} initial_step=...' on Run B" |
0 commit comments