
Commit 5448207

Merge pull request #3749 from AI-Hypercomputer:gagik-olmo-data
PiperOrigin-RevId: 911613385
2 parents baa48fb + 24500db commit 5448207

14 files changed

Lines changed: 2960 additions & 1 deletion
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# OLMo numpy pipeline (`dataset_type=olmo_grain`)

Grain-based input pipeline for AI2's pre-tokenized OLMo data mixes (e.g.
`OLMo-mix-0925-official.txt`). Reads headerless flat `.npy` token streams
from a gcsfuse mount, shards across hosts, optionally masks repeated-n-gram
instances, and yields the shapes the MaxText pretrain trainer expects.
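
To illustrate what reading a "headerless flat `.npy` token stream" can look like, the sketch below memory-maps such a file with `np.memmap` and slices fixed-length instances out of it. The helper names (`open_token_file`, `num_instances`, `get_instance`) are hypothetical and not taken from the pipeline's code; the `uint32` dtype follows the default described in the Notes section, and dropping a trailing partial chunk is an assumption.

```python
import numpy as np


def open_token_file(path, dtype=np.uint32):
    """Memory-map a headerless token file as a flat 1-D array (hypothetical helper)."""
    # np.memmap rather than np.load: the file has no .npy header to parse.
    return np.memmap(path, dtype=dtype, mode="r")


def num_instances(tokens, sequence_length):
    """Count whole fixed-length instances; a trailing partial chunk is ignored (assumption)."""
    return len(tokens) // sequence_length


def get_instance(tokens, index, sequence_length):
    """Copy instance `index` out of the mapped file into a regular ndarray."""
    start = index * sequence_length
    return np.array(tokens[start : start + sequence_length])
```

Because the data is mapped rather than loaded, only the pages backing the requested slice are read, which is why the files need to be visible under a local path.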

## Quick start

1. **Download the data** to a GCS bucket. `--mix-file` is a local AI2 manifest listing relative npy paths to fetch from AI2's public bucket (e.g. `OLMo-mix-0925-official.txt` for the 6T pretrain mix or `OLMo-midtraining-mix-0625-100B.txt` for the 100B midtraining mix).

   ```bash
   python tools/data_generation/download_olmo_data_to_gcs.py \
     --mix-file ./OLMo-mix-0925-official.txt \
     --gcs-dest gs://my-bucket/dataset/ \
     --staging-dir /mnt/local-ssd/olmo-staging \
     --workers 16
   ```

2. **Mount it read-only** with gcsfuse (`np.memmap` needs a local path):

   ```bash
   gcsfuse --implicit-dirs --o ro my-bucket /mnt/olmo-readonly
   ```

3. **Build the index**:

   ```bash
   python tools/data_generation/build_olmo_npy_index.py \
     --mix-file /path/to/OLMo-mix-0925-official.txt \
     --gcs-base gs://my-bucket/dataset/ \
     --tokenizer allenai/dolma3-tokenizer \
     --sequence-length 8192 \
     --output /path/to/olmo_index_seq8192.json
   ```

4. **Configure + run** the trainer:

   ```yaml
   dataset_type: olmo_grain
   olmo_index_path: /path/to/olmo_index_seq8192.json
   olmo_path_remap_from: "gs://my-bucket/"
   olmo_path_remap_to: "/mnt/olmo-readonly/"
   max_target_length: 8192 # must equal index sequence_length
   tokenizer_type: huggingface
   tokenizer_path: allenai/Olmo-3-7B-Instruct
   ```

See `scripts/run_olmo3_7b_grain_smoke.sh` for a runnable smoke launcher.
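
The `olmo_path_remap_from` / `olmo_path_remap_to` pair in step 4 amounts to a prefix rewrite from the `gs://` paths recorded in the index to the gcsfuse mount. A minimal sketch of that idea follows; `remap_path` is a hypothetical name, and the real pipeline may handle edge cases differently.

```python
def remap_path(path: str, remap_from: str, remap_to: str) -> str:
    """Rewrite `path` if it starts with `remap_from`; otherwise return it unchanged.

    Hypothetical helper illustrating the remap settings, not the pipeline's own code.
    """
    if remap_from and path.startswith(remap_from):
        return remap_to + path[len(remap_from):]
    return path


if __name__ == "__main__":
    print(remap_path("gs://my-bucket/dataset/part-0.npy", "gs://my-bucket/", "/mnt/olmo-readonly/"))
```

Under this reading, both prefixes should agree on whether they end in a slash, as they do in the YAML above; otherwise the rewritten path gains or loses a separator.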

## Resume

The sampler is stateless: the record at step *k* is a pure function of `(seed, shard, k)`. On startup, the trainer adapter reads the latest step from
`config.checkpoint_dir` and shifts the sampler so the data stream picks
up where it left off; no Grain iterator state is stored in the checkpoint.

`scripts/run_olmo3_7b_grain_resume_test.sh` validates this end-to-end.
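
One way to picture a sampler whose output depends only on `(seed, shard, k)` is sketched below. It is an illustration under stated assumptions, not the pipeline's actual shuffling scheme: it assumes a per-epoch permutation seeded from `(seed, epoch)` and a strided split across shards. The names `record_index` and `initial_step`, and the extra `num_shards` / `num_records` arguments, are hypothetical. The `resumed_step * per_host_batch` offset mirrors the description in the resume test script.

```python
import numpy as np


def record_index(seed, shard, num_shards, k, num_records):
    """Return the dataset index that `shard` reads at its k-th record.

    Pure function of its arguments: no iterator state is carried between calls.
    """
    position = k * num_shards + shard  # global position in the interleaved stream
    epoch, offset = divmod(position, num_records)
    # Assumption: a fresh permutation per epoch, derived only from (seed, epoch).
    rng = np.random.default_rng([seed, epoch])
    return int(rng.permutation(num_records)[offset])


def initial_step(resumed_step, per_host_batch):
    """Records already consumed by one host after `resumed_step` training steps."""
    return resumed_step * per_host_batch
```

Resuming then means starting the loop over `k` at `initial_step(...)` instead of 0; the indices produced from that point are identical to those an uninterrupted run would have produced. Recomputing a full permutation on every call is wasteful and only acceptable in a sketch.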

## Notes

- Files are headerless raw uint32 by default (matches AI2's published
  format). The numpy `.npy` extension is misleading.
- Documents may span instance boundaries; this matches OLMo-core.
- `olmo_apply_ngram_filter: True` (default) zeroes loss on instances with
  ≥ 32 repetitions of any 1–13-gram, per OLMo-core.
- For mixing pretraining + midtraining, build a combined index by
  concatenating the two .txt mix files.
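
The n-gram note above can be made concrete with a simplified sketch. It assumes that "repetitions" means back-to-back copies of the same n-gram (a periodic run) and that the whole instance is masked when one is found; OLMo-core's exact rule may differ. The function names are hypothetical, and the thresholds (13 and 32) come from the note.

```python
import numpy as np


def longest_true_run(mask):
    """Length of the longest run of consecutive True values."""
    best = run = 0
    for flag in mask.tolist():
        run = run + 1 if flag else 0
        best = max(best, run)
    return best


def has_repeated_ngram(tokens, max_period=13, min_repetitions=32):
    """True if some n-gram (n <= max_period) occurs >= min_repetitions times in a row."""
    tokens = np.asarray(tokens)
    for period in range(1, max_period + 1):
        if tokens.size < period * min_repetitions:
            break  # too short for this period, and for every longer one
        # tokens[i] == tokens[i - period] holds throughout a periodic stretch.
        matches = tokens[period:] == tokens[:-period]
        if longest_true_run(matches) >= period * (min_repetitions - 1):
            return True
    return False


def loss_mask(tokens):
    """All-False mask (no loss) for repetitive instances, all-True otherwise."""
    keep = not has_repeated_ngram(tokens)
    return np.full(len(tokens), keep, dtype=bool)
```

A run of `period * (min_repetitions - 1)` matches corresponds to `min_repetitions` consecutive copies, because the first copy has nothing earlier to match against.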

## Troubleshooting

| Symptom | Fix |
| ------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------- |
| `OLMo index sequence_length=N but config.max_target_length=M` | Rebuild the index with `--sequence-length M`. |
| `q_block_size=512 should divide q_seq_len=…` | Set `max_target_length` to a multiple of 512. |
| OOM during compile on a small TPU | Shrink with `override_model_config=True base_num_decoder_layers=N`, use `weight_dtype=bfloat16`. |
| Resume restarts at step 0 | Iterator log should print `resumed_step=N initial_step=…`; if both 0, `checkpoint_dir` is empty or wrong. |
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
#!/bin/bash
# End-to-end resume test for the OLMo grain pipeline (stateless sampler +
# step-derived initial_step). See scripts/run_olmo3_7b_grain_smoke.sh for
# the env-var contract; this script accepts the same vars.
#
# Plan:
#   Run A: train 50 steps from scratch, save checkpoint at step 50, exit.
#   Run B: relaunch with the SAME run_name (so the checkpoint dir is reused).
#          The trainer restores model state at step 50; our iterator factory
#          detects the latest checkpoint step and sets ``initial_step`` so
#          the data stream picks up at absolute position 50 * per_host_batch.
#          Train 25 more steps (to step 75).
#
# What success looks like:
#   * Run B's first step (step 51) reports a loss similar to Run A's step 50
#     loss. A spike or jump → model state didn't restore.
#   * No repeats: Run B's batches are NOT the same as Run A's batches at the
#     same absolute step. (Hard to assert without batch-content hashing in
#     the trainer; for the smoke we rely on the unit tests + loss continuity.)
#   * No regression: Run B's loss continues to decrease.
#
# Outputs:
#   ${LOG_A} - first 50 steps
#   ${LOG_B} - resumed 25 steps
#   $OUTPUT_DIR/<run_name>/checkpoints/ - Orbax checkpoint(s)

set -euo pipefail

MAXTEXT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
VENV_PATH="${VENV_PATH:-${MAXTEXT_ROOT}/maxtext_venv}"
HF_SECRETS="${HF_SECRETS:-}"
INDEX_PATH="${INDEX_PATH:?INDEX_PATH is required (path to olmo index JSON)}"
GCS_BASE="${GCS_BASE:?GCS_BASE is required (e.g. gs://my-bucket/)}"
LOCAL_MOUNT="${LOCAL_MOUNT:?LOCAL_MOUNT is required (gcsfuse mount path of GCS_BASE)}"
OUTPUT_DIR="${OUTPUT_DIR:-/tmp/olmo_resume_test_out}"
RUN_NAME="${RUN_NAME:-olmo_resume_$(date +%Y%m%d-%H%M%S)}"

# Where each run's stdout is teed. Keep them under OUTPUT_DIR so the
# script doesn't depend on a hard-coded absolute path.
LOG_A="${LOG_A:-${OUTPUT_DIR}/${RUN_NAME}.runA.log}"
LOG_B="${LOG_B:-${OUTPUT_DIR}/${RUN_NAME}.runB.log}"

PER_DEVICE_BATCH="${PER_DEVICE_BATCH:-1}"
SEQ_LEN="${SEQ_LEN:-8192}"
WEIGHT_DTYPE="${WEIGHT_DTYPE:-bfloat16}"
NUM_LAYERS="${NUM_LAYERS:-4}"
DATA_SEED="${DATA_SEED:-42}"

# Run A trains 50 steps + saves a checkpoint at step 50; Run B continues to 75.
STEPS_A="${STEPS_A:-50}"
STEPS_B="${STEPS_B:-75}"
CHECKPOINT_PERIOD="${CHECKPOINT_PERIOD:-50}"

# shellcheck disable=SC1090,SC1091
source "${VENV_PATH}/bin/activate"
if [[ -n "${HF_SECRETS:-}" && -f "${HF_SECRETS}" ]]; then
  # shellcheck disable=SC1090
  source "${HF_SECRETS}"
fi
: "${HF_TOKEN:?HF_TOKEN must be set (or HF_SECRETS pointing at a file that exports it)}"
export PYTHONPATH="${MAXTEXT_ROOT}/src:${PYTHONPATH:-}"
export PYTHONUNBUFFERED=1

mkdir -p "${OUTPUT_DIR}"

TOKENIZER_PATH="${TOKENIZER_PATH:-allenai/Olmo-3-7B-Instruct}"

run_train() {
  local steps="$1"
  local logfile="$2"
  echo "----- launching: steps=${steps} log=${logfile} -----"
  python -m maxtext.trainers.pre_train.train \
    "${MAXTEXT_ROOT}/src/maxtext/configs/base.yml" \
    model_name=olmo3-7b-pt \
    run_name="${RUN_NAME}" \
    base_output_directory="${OUTPUT_DIR}" \
    dataset_type=olmo_grain \
    olmo_index_path="${INDEX_PATH}" \
    olmo_path_remap_from="${GCS_BASE}" \
    olmo_path_remap_to="${LOCAL_MOUNT}" \
    data_shuffle_seed="${DATA_SEED}" \
    olmo_apply_ngram_filter=True \
    grain_worker_count=0 \
    per_device_batch_size="${PER_DEVICE_BATCH}" \
    max_target_length="${SEQ_LEN}" \
    steps="${steps}" \
    enable_checkpointing=True \
    async_checkpointing=False \
    checkpoint_period="${CHECKPOINT_PERIOD}" \
    save_checkpoint_on_completion=True \
    tokenizer_type=huggingface \
    tokenizer_path="${TOKENIZER_PATH}" \
    weight_dtype="${WEIGHT_DTYPE}" \
    override_model_config=True \
    base_num_decoder_layers="${NUM_LAYERS}" \
    sharding_tolerance=0.05 \
    2>&1 | tee "${logfile}"
}

echo "=== OLMo 3 grain resume test ==="
echo " run_name : ${RUN_NAME}"
echo " output_dir : ${OUTPUT_DIR}/${RUN_NAME}"
echo " per_device_bs : ${PER_DEVICE_BATCH}"
echo " seq_len : ${SEQ_LEN}"
echo " num_layers : ${NUM_LAYERS}"
echo " Run A steps : ${STEPS_A} (will checkpoint at step ${CHECKPOINT_PERIOD})"
echo " Run B steps : ${STEPS_B} (resumed via initial_step)"
echo

# Run A
run_train "${STEPS_A}" "${LOG_A}"

echo
echo "=== Run A done. Last 3 step events: ==="
grep -E "completed step:" "${LOG_A}" | tail -3
echo

# Run B (resume)
run_train "${STEPS_B}" "${LOG_B}"

echo
echo "=== Run B done ==="
echo "First 3 step events from Run B (expect step >= ${STEPS_A}):"
grep -E "completed step:" "${LOG_B}" | head -3
echo
echo "Last 3 step events from Run B:"
grep -E "completed step:" "${LOG_B}" | tail -3
echo

echo "=== Pass criteria (manual check): ==="
echo " 1. Run B's first step number >= ${STEPS_A} (model state restored)"
echo " 2. Run B's first step loss within ~5% of Run A's last step loss"
echo " (model continued, no re-init)"
echo " 3. Loss continues to decrease across Run B"
echo " 4. iterator log line shows 'resumed_step=${STEPS_A} initial_step=...' on Run B"
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
#!/bin/bash
# Smoke training run for OLMo 3 7B on the OLMo numpy grain pipeline.
#
# Validates that dataset_type=olmo_grain wires through the trainer, that
# OlmoNpyDataSource reads .npy data via a gcsfuse mount, and that 50 steps
# execute without crashes / shape mismatches with monotonically decreasing
# loss.
#
# Required env vars:
#   INDEX_PATH   JSON index from tools/data_generation/build_olmo_npy_index.py
#   GCS_BASE     gs:// prefix recorded in the index (e.g. gs://my-bucket/)
#   LOCAL_MOUNT  gcsfuse mount of GCS_BASE on this host
#   HF_TOKEN     HuggingFace token for the tokenizer (or HF_SECRETS=<file>)
# Optional: VENV_PATH, OUTPUT_DIR, PER_DEVICE_BATCH, SEQ_LEN, STEPS,
#           WEIGHT_DTYPE, NUM_LAYERS.
#
# Usage:
#   INDEX_PATH=/path/to/olmo_index_seq8192.json \
#   LOCAL_MOUNT=/mnt/your-mount/ \
#   GCS_BASE=gs://your-bucket/ \
#   HF_TOKEN=hf_... \
#   bash scripts/run_olmo3_7b_grain_smoke.sh

set -euo pipefail

MAXTEXT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"

VENV_PATH="${VENV_PATH:-${MAXTEXT_ROOT}/maxtext_venv}"
HF_SECRETS="${HF_SECRETS:-}"
INDEX_PATH="${INDEX_PATH:?INDEX_PATH is required (path to olmo index JSON)}"
GCS_BASE="${GCS_BASE:?GCS_BASE is required (e.g. gs://my-bucket/)}"
LOCAL_MOUNT="${LOCAL_MOUNT:?LOCAL_MOUNT is required (gcsfuse mount path of GCS_BASE)}"
OUTPUT_DIR="${OUTPUT_DIR:-/tmp/olmo_smoke_out}"

PER_DEVICE_BATCH="${PER_DEVICE_BATCH:-1}"
SEQ_LEN="${SEQ_LEN:-8192}"
STEPS="${STEPS:-50}"
DATA_SEED="${DATA_SEED:-42}"
# Smoke test uses a reduced model (bf16, 4 layers) so it fits small TPU
# slices; we're validating the data path, not full-size convergence.
WEIGHT_DTYPE="${WEIGHT_DTYPE:-bfloat16}"
NUM_LAYERS="${NUM_LAYERS:-4}"

RUN_NAME="${RUN_NAME:-olmo_grain_smoke_$(date +%Y%m%d-%H%M%S)}"

# Activate venv + load HF secrets.
# shellcheck disable=SC1090,SC1091
source "${VENV_PATH}/bin/activate"
if [[ -n "${HF_SECRETS:-}" && -f "${HF_SECRETS}" ]]; then
  # shellcheck disable=SC1090
  source "${HF_SECRETS}"
fi
: "${HF_TOKEN:?HF_TOKEN must be set (or HF_SECRETS pointing at a file that exports it)}"
export PYTHONPATH="${MAXTEXT_ROOT}/src:${PYTHONPATH:-}"
export PYTHONUNBUFFERED=1

mkdir -p "${OUTPUT_DIR}"

echo "=== OLMo 3 7B + olmo_grain smoke run ==="
echo " run_name : ${RUN_NAME}"
echo " index : ${INDEX_PATH}"
echo " path remap : ${GCS_BASE} -> ${LOCAL_MOUNT}"
echo " per_device_bs : ${PER_DEVICE_BATCH}"
echo " seq_len : ${SEQ_LEN}"
echo " steps : ${STEPS}"
echo " weight_dtype : ${WEIGHT_DTYPE}"
echo " num_layers : ${NUM_LAYERS} (full 7B has 32)"
echo " output_dir : ${OUTPUT_DIR}"
echo

# Data is already tokenized; the tokenizer is loaded only for pad/eos IDs +
# vocab_size checks. Olmo-3-7B-Instruct uses the same dolma3 tokenizer.
TOKENIZER_PATH="${TOKENIZER_PATH:-allenai/Olmo-3-7B-Instruct}"

python -m maxtext.trainers.pre_train.train \
  "${MAXTEXT_ROOT}/src/maxtext/configs/base.yml" \
  model_name=olmo3-7b-pt \
  run_name="${RUN_NAME}" \
  base_output_directory="${OUTPUT_DIR}" \
  dataset_type=olmo_grain \
  olmo_index_path="${INDEX_PATH}" \
  olmo_path_remap_from="${GCS_BASE}" \
  olmo_path_remap_to="${LOCAL_MOUNT}" \
  data_shuffle_seed="${DATA_SEED}" \
  olmo_apply_ngram_filter=True \
  grain_worker_count=0 \
  per_device_batch_size="${PER_DEVICE_BATCH}" \
  max_target_length="${SEQ_LEN}" \
  steps="${STEPS}" \
  enable_checkpointing=False \
  tokenizer_type=huggingface \
  tokenizer_path="${TOKENIZER_PATH}" \
  weight_dtype="${WEIGHT_DTYPE}" \
  override_model_config=True \
  base_num_decoder_layers="${NUM_LAYERS}" \
  sharding_tolerance=0.05

src/maxtext/configs/base.yml

Lines changed: 8 additions & 0 deletions
@@ -734,6 +734,14 @@ grain_use_elastic_iterator: False # For elastic training, set to this true and p
 # for using pathways
 colocated_python_data_input: False # experimental feature, under testing

+# OLMo numpy pipeline (dataset_type=olmo_grain). Worker count, buffer size,
+# and shuffle seed reuse grain_worker_count / grain_per_worker_buffer_size /
+# data_shuffle_seed.
+olmo_index_path: '' # JSON from tools/data_generation/build_olmo_npy_index.py
+olmo_path_remap_from: '' # rewrite index paths starting with this prefix...
+olmo_path_remap_to: '' # ...to this one (e.g. gs://bucket/ -> /mnt/.../ for gcsfuse).
+olmo_apply_ngram_filter: True # mask instances with repetitive n-grams (OLMo-core filter)
+
 # Training loop
 steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
 log_period: 100 # The frequency of Tensorboard flush, gcs metrics writing, and managed profiler metrics updating.

src/maxtext/configs/types.py

Lines changed: 28 additions & 0 deletions
@@ -177,6 +177,7 @@ class DatasetType(str, Enum):
   GRAIN = "grain"
   TFDS = "tfds"
   C4MLPERF = "c4_mlperf"
+  OLMO_GRAIN = "olmo_grain"


 class SamplingStrategy(str, Enum):
@@ -1158,6 +1159,32 @@ class GrainDataset(BaseModel):
   grain_shuffle_buffer_size: int = Field(100, description="Shuffle buffer size when using Parquet or TFRecord.")


+class OlmoGrainDataset(BaseModel):
+  """Configuration for the OLMo numpy fixed-seq-length input pipeline (dataset_type=olmo_grain).
+
+  Separate from the standard grain config because this pipeline reads
+  pre-tokenized fixed-length sequences from raw npy files (one ``int32``
+  token per element, ``sequence_length`` from an index JSON), not
+  arrayrecord/tfds shards — so flags like ``grain_train_files`` /
+  ``packing`` don't apply.
+
+  Worker count, per-worker buffer size, and shuffle seed reuse the standard
+  grain flags (``grain_worker_count``, ``grain_per_worker_buffer_size``,
+  ``data_shuffle_seed``); only OLMo-specific fields are listed here.
+  """
+
+  olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.")
+  olmo_path_remap_from: PathStr = Field(
+      "",
+      description="If set, rewrite index file paths starting with this prefix to olmo_path_remap_to.",
+  )
+  olmo_path_remap_to: PathStr = Field(
+      "",
+      description="Replacement prefix used together with olmo_path_remap_from (e.g. /mnt/disks/.../).",
+  )
+  olmo_apply_ngram_filter: bool = Field(True, description="Mask repetitive instances per OLMo-core's repetition filter.")
+
+
 class FineTuning(BaseModel):
   """Configuration for fine-tuning methods like DPO, SFT, and GRPO."""

@@ -2154,6 +2181,7 @@ class MaxTextConfig(
     TfdsDataset,
     HfDataset,
     GrainDataset,
+    OlmoGrainDataset,
     Tokenizer,
     # Inference
     InferenceGeneral,

src/maxtext/input_pipeline/input_pipeline_interface.py

Lines changed: 4 additions & 1 deletion
@@ -23,6 +23,8 @@
 from maxtext.input_pipeline.grain_data_processing import make_grain_eval_iterator
 from maxtext.input_pipeline.hf_data_processing import make_hf_train_iterator
 from maxtext.input_pipeline.hf_data_processing import make_hf_eval_iterator
+from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_train_iterator
+from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_eval_iterator
 from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator
 from maxtext.input_pipeline.tfds_data_processing import make_tfds_eval_iterator
 from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator
@@ -73,10 +75,11 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
       "grain": (make_grain_train_iterator, make_grain_eval_iterator),
       "hf": (make_hf_train_iterator, make_hf_eval_iterator),
       "c4_mlperf": (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator),
+      "olmo_grain": (make_olmo_grain_train_iterator, make_olmo_grain_eval_iterator),
   }

   # Collect train and eval iterators
-  if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf"]:
+  if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf", "olmo_grain"]:
     if config.dataset_type == "c4_mlperf":
       assert config.packing, "c4_mlperf dataloader only works with packing. For padded version, use tfds dataloader"
     train_iterator, eval_iterator = dataset_type_to_train_eval_iterator[config.dataset_type]
