Changes from all commits (40 commits)
f2814ad  base.yml formatting auto (Shuwen-Fang, Mar 13, 2026)
5f458fe  moe formatting changes (Shuwen-Fang, Mar 13, 2026)
05f9096  check vma changes (Shuwen-Fang, Mar 13, 2026)
8696f9a  remove input variable name due to internal change of name (NuojCheng, Mar 11, 2026)
9c6ed15  enable pp with batch split ds (NuojCheng, Feb 13, 2026)
a7fd50d  add another layer of custom vjp (NuojCheng, Feb 25, 2026)
b709e0f  add new pipeline weight prefetching config (NuojCheng, Mar 6, 2026)
0697c4d  refactor pr (NuojCheng, Mar 6, 2026)
dc2228c  retrigger CI (gulsumgudukbay, Mar 10, 2026)
d7ab416  Remove post_training_local_dependencies.Dockerfile (SurbhiJainUSC, Mar 11, 2026)
1a693a3  update tokamax group sizes for pipeline (NuojCheng, Mar 12, 2026)
3996e45  Update user docs to drop config file path (dipannita08, Mar 11, 2026)
252e34c  Make tokenizer_path to be non-mandatory (A9isha, Mar 10, 2026)
c17d5b1  PR #2831: Migrate Decoder to NNX (hsuan-lun-chiang, Mar 16, 2026)
319992e  docs: simplify checkpoint storage flags for Pathways workloads (igorts-git, Mar 12, 2026)
5afca54  Add option to start test_batch in train_rl from a specific index, als… (A9isha, Mar 14, 2026)
6740d85  Add qwen2 implementation (ChingTsai, Dec 26, 2025)
010816a  Move tests, rto_setup.sh and preflight.sh to docker image (SurbhiJainUSC, Mar 19, 2026)
1d1704d  Update post-training docs to point to single source of truth for inst… (SurbhiJainUSC, Mar 19, 2026)
c6da8da  add custom mesh and logical rule support (NuojCheng, Mar 18, 2026)
a32d52d  PR #3449: Move install_maxtext_extra_deps to dependencies directory (bvandermoon, Mar 20, 2026)
8feb476  add qwen3-base variants and qwen3-1.7b (andytwigg, Mar 14, 2026)
1dbde05  Fix src/MaxText references in GPU/runner Dockerfiles (bvandermoon, Mar 20, 2026)
91f0a87  Skip Tokamax RaggedDotGroupSizes for FP8 (BirdsOfAFthr, Mar 19, 2026)
8c30d30  Update moe.py (BirdsOfAFthr, Mar 19, 2026)
e547f7a  format (BirdsOfAFthr, Mar 19, 2026)
0deb4cb  Update moe.py (BirdsOfAFthr, Mar 19, 2026)
3adf889  split logical names in moe module (NuojCheng, Mar 20, 2026)
3df0d31  base.yml formatting auto (Shuwen-Fang, Mar 13, 2026)
a5a7cee  moe formatting changes (Shuwen-Fang, Mar 13, 2026)
99631fd  check vma changes (Shuwen-Fang, Mar 13, 2026)
331882c  remove input variable name due to internal change of name (NuojCheng, Mar 11, 2026)
7cd4ede  enable pp with batch split ds (NuojCheng, Feb 13, 2026)
665877f  add another layer of custom vjp (NuojCheng, Feb 25, 2026)
3b46bd5  add new pipeline weight prefetching config (NuojCheng, Mar 6, 2026)
42b4e8a  refactor pr (NuojCheng, Mar 6, 2026)
48555d8  retrigger CI (gulsumgudukbay, Mar 10, 2026)
8b774cf  Remove post_training_local_dependencies.Dockerfile (SurbhiJainUSC, Mar 11, 2026)
84ae8f7  update tokamax group sizes for pipeline (NuojCheng, Mar 12, 2026)
8a2c5a3  Update user docs to drop config file path (dipannita08, Mar 11, 2026)
1 change: 1 addition & 0 deletions .github/workflows/build_and_push_docker_image.yml
@@ -117,6 +117,7 @@ jobs:
MODE=${{ inputs.build_mode }}
WORKFLOW=${{ inputs.workflow }}
PACKAGE_DIR=./src
+ TESTS_DIR=./tests
JAX_VERSION=NONE
LIBTPU_VERSION=NONE
INCLUDE_TEST_ASSETS=true
16 changes: 10 additions & 6 deletions PREFLIGHT.md
@@ -1,35 +1,39 @@
# Optimization 1: Multihost recommended network settings
- We included all the recommended network settings in [rto_setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/rto_setup.sh).

+ We included all the recommended network settings in [rto_setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/rto_setup.sh).

[preflight.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/preflight.sh) will help you apply them based on GCE or GKE platform.

Before you run ML workload on Multihost with GCE or GKE, simply apply `bash preflight.sh PLATFORM=[GCE or GKE]` to leverage the best DCN network performance.

Here is an example for GCE:

```
- bash src/dependencies/scripts/preflight.sh PLATFORM=GCE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
+ bash preflight.sh PLATFORM=GCE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
```

Here is an example for GKE:

```
- bash src/dependencies/scripts/preflight.sh PLATFORM=GKE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
+ bash preflight.sh PLATFORM=GKE && python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
```

# Optimization 2: Numa binding (You can only apply this to v4 and v5p)

NUMA binding is recommended for enhanced performance, as it reduces memory latency and maximizes data throughput, ensuring that your high-performance applications operate more efficiently and effectively.

- For GCE,
+ For GCE,
[preflight.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/preflight.sh) will help you install `numactl` dependency, so you can use it directly, here is an example:

```
- bash src/dependencies/scripts/preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
+ bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
```

For GKE,
`numactl` should be built into your docker image from [maxtext_tpu_dependencies.Dockerfile](https://github.com/google/maxtext/blob/main/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile), so you can use it directly if you built the maxtext docker image. Here is an example

```
- bash src/dependencies/scripts/preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
+ bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train run_name=${YOUR_JOB_NAME?}
```

1. `numactl`: This is the command-line tool used for controlling NUMA policy for processes or shared memory. It's particularly useful on multi-socket systems where memory locality can impact performance.
9 changes: 9 additions & 0 deletions docs/tutorials/first_run.md
@@ -39,6 +39,15 @@ multiple hosts but is a good way to learn about MaxText.
2. For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html).
3. After installation completes, run training on synthetic data with the following command:

+ ```sh
+ python3 -m venv ~/venv-maxtext
+ source ~/venv-maxtext/bin/activate
+ bash tools/setup/setup.sh
+ pre-commit install
+ ```
+
+ 4. After installation completes, run training on synthetic data with the following command:

```sh
python3 -m maxtext.trainers.pre_train.train \
run_name=${YOUR_JOB_NAME?} \
4 changes: 2 additions & 2 deletions docs/tutorials/posttraining/sft_on_multi_host.md
@@ -113,7 +113,7 @@ xpk workload create \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command "python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
--command "python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane hf_path=${DATASET_NAME?} train_split=${TRAIN_SPLIT?} train_data_columns=${TRAIN_DATA_COLUMNS?}"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
@@ -131,7 +131,7 @@ xpk workload create-pathways \
--workload=${WORKLOAD_NAME?} \
--tpu-type=${TPU_TYPE?} \
--num-slices=${TPU_SLICE?} \
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True"
--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} tokenizer_path=${TOKENIZER_PATH?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
```

Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -40,7 +40,7 @@ Repository = "https://github.com/AI-Hypercomputer/maxtext.git"
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/MaxText", "src/maxtext", "src/dependencies"]
packages = ["src/MaxText", "src/maxtext", "src/dependencies/github_deps", "src/dependencies"]

# TODO: Add this hook back when it handles device-type parsing
# [tool.hatch.build.targets.wheel.hooks.custom]
@@ -41,6 +41,9 @@ ENV ENV_DEVICE=$DEVICE
ARG PACKAGE_DIR
ENV PACKAGE_DIR=$PACKAGE_DIR

+ ARG TESTS_DIR
+ ENV TESTS_DIR=$TESTS_DIR
+
ENV MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets
ENV MAXTEXT_TEST_ASSETS_ROOT=/deps/tests/assets
ENV MAXTEXT_PKG_DIR=/deps/src/maxtext
@@ -63,9 +66,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \

# Now copy the remaining code (source files that may change frequently)
COPY ${PACKAGE_DIR}/maxtext/ src/maxtext/
- COPY ${PACKAGE_DIR}/MaxText/ src/MaxText/
- COPY tests*/ tests/
- COPY benchmarks*/ benchmarks/
+ COPY ${TESTS_DIR}*/ tests/

# Download test assets from GCS if building image with test assets
ARG INCLUDE_TEST_ASSETS=false
@@ -38,6 +38,9 @@ ENV ENV_DEVICE=$DEVICE
ARG PACKAGE_DIR
ENV PACKAGE_DIR=$PACKAGE_DIR

+ ARG TESTS_DIR
+ ENV TESTS_DIR=$TESTS_DIR
+
ENV MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets
ENV MAXTEXT_TEST_ASSETS_ROOT=/deps/tests/assets
ENV MAXTEXT_PKG_DIR=/deps/src/maxtext
@@ -63,9 +66,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \

# Now copy the remaining code (source files that may change frequently)
COPY ${PACKAGE_DIR}/maxtext/ src/maxtext/
- COPY ${PACKAGE_DIR}/MaxText/ src/MaxText/
- COPY tests*/ tests/
- COPY benchmarks*/ benchmarks/
+ COPY ${TESTS_DIR}*/ tests/

# Download test assets from GCS if building image with test assets
ARG INCLUDE_TEST_ASSETS=false
12 changes: 7 additions & 5 deletions src/dependencies/github_deps/install_post_train_deps.py
@@ -34,12 +34,14 @@ def main():
"""
os.environ["VLLM_TARGET_DEVICE"] = "tpu"

- current_dir = os.path.dirname(os.path.abspath(__file__))
- repo_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
- extra_deps_path = os.path.join(current_dir, "post_train_deps.txt")
- if not os.path.exists(extra_deps_path):
-   raise FileNotFoundError(f"Dependencies file not found at {extra_deps_path}")
+ # Adjust this path if your post_train_deps.txt is in a different location,
+ # e.g., script_dir / "data" / "post_train_deps.txt"
+ extra_deps_file = script_dir / "post_train_deps.txt"
+
+ if not extra_deps_file.exists():
+   print(f"Error: '{extra_deps_file}' not found.")
+   print("Please ensure 'post_train_deps.txt' is in the correct location relative to the script.")
+   sys.exit(1)
# Check if 'uv' is available in the environment
try:
subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True)
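Note that the replacement code references `script_dir`, which is not defined anywhere in this hunk; presumably it is a `pathlib.Path` set earlier in the file. A minimal sketch of the new lookup under that assumption (the `script_dir` definition below is an assumption, the rest mirrors the hunk):

```python
import sys
from pathlib import Path

# Assumption: defined earlier in install_post_train_deps.py, not shown in this hunk.
script_dir = Path(__file__).resolve().parent

extra_deps_file = script_dir / "post_train_deps.txt"
if not extra_deps_file.exists():
  print(f"Error: '{extra_deps_file}' not found.")
  print("Please ensure 'post_train_deps.txt' is in the correct location relative to the script.")
  sys.exit(1)
```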
8 changes: 8 additions & 0 deletions src/dependencies/github_deps/install_pre_train_deps.py
@@ -37,6 +37,14 @@ def main():
if not os.path.exists(extra_deps_path):
raise FileNotFoundError(f"Dependencies file not found at {extra_deps_path}")

+ # Adjust this path if your pre_train_deps.txt is in a different location,
+ # e.g., script_dir / "data" / "pre_train_deps.txt"
+ extra_deps_file = script_dir / "pre_train_deps.txt"
+
+ if not extra_deps_file.exists():
+   print(f"Error: '{extra_deps_file}' not found.")
+   print("Please ensure 'pre_train_deps.txt' is in the correct location relative to the script.")
+   sys.exit(1)
# Check if 'uv' is available in the environment
try:
subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True)
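Both installer hunks end at the same `uv` bootstrap call. The install step that follows the bootstrap is cut off in both diffs, so the final `uv pip install -r` invocation below is an assumption; a minimal sketch of the shared flow:

```python
import subprocess
import sys
from pathlib import Path

def install_pinned_deps(requirements: Path) -> None:
  """Bootstrap uv via pip, then install the pinned requirements with it."""
  # Mirrors the subprocess call visible at the end of both hunks.
  subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True)
  # Assumption: the truncated code then hands the requirements file to uv.
  subprocess.run([sys.executable, "-m", "uv", "pip", "install", "-r", str(requirements)], check=True)
```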
42 changes: 39 additions & 3 deletions src/dependencies/scripts/docker_build_dependency_image.sh
@@ -20,8 +20,43 @@

# For instructions on building the MaxText Docker image, please refer to the https://maxtext.readthedocs.io/en/latest/build_maxtext.html.

- PACKAGE_DIR="${PACKAGE_DIR:-src}"
- echo "PACKAGE_DIR: $PACKAGE_DIR"
+ # Build docker image with stable dependencies
+ ## bash src/dependencies/scripts/docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=stable
+
+ # Build docker image with nightly dependencies
+ ## bash src/dependencies/scripts/docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=nightly
+
+ # Build docker image with stable dependencies and a pinned JAX_VERSION for TPUs
+ ## bash src/dependencies/scripts/docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
+
+ # Build docker image with a pinned JAX_VERSION and a pinned LIBTPU_VERSION for TPUs
+ ## bash src/dependencies/scripts/docker_build_dependency_image.sh MODE={{stable|nightly}} JAX_VERSION=0.8.1 LIBTPU_VERSION=0.0.31.dev20251119+nightly
+
+ # Build docker image with a custom libtpu.so for TPUs
+ # Note: libtpu.so file must be present in the root directory of the MaxText repository
+ ## bash src/dependencies/scripts/docker_build_dependency_image.sh MODE={{stable|nightly}}
+
+ # Build docker image with nightly dependencies and a pinned JAX_VERSION for GPUs
+ # Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax
+ ## bash src/dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109
+
+ # ==================================
+ # POST-TRAINING BUILD EXAMPLES
+ # ==================================
+
+ # Build docker image with post-training dependencies
+ ## bash src/dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training
+
+ if [ "${BASH_SOURCE-}" ]; then
+   this_file="${BASH_SOURCE[0]}"
+ elif [ "${ZSH_VERSION-}" ]; then
+   # shellcheck disable=SC2296
+   this_file="${(%):-%x}"
+ else
+   this_file="${0}"
+ fi
+
+ MAXTEXT_REPO_ROOT="${MAXTEXT_REPO_ROOT:-$(CDPATH='' cd -- "$(dirname -- "${this_file}")"'/../../..' && pwd)}"

# Enable "exit immediately if any command fails" option
set -e
@@ -71,6 +106,7 @@ docker_build_args=(
"MODE=${MODE}"
"JAX_VERSION=${JAX_VERSION}"
"PACKAGE_DIR=${PACKAGE_DIR}"
"TESTS_DIR=${TESTS_DIR}"
)

run_docker_build() {
@@ -104,7 +140,7 @@ build_tpu_image() {
fi

echo "Building docker image with arguments: ${docker_build_args[*]}"
run_docker_build "$PACKAGE_DIR/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile" "${docker_build_args[@]}"
run_docker_build "$MAXTEXT_REPO_ROOT/src/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile" "${docker_build_args[@]}"
}

if [[ ${DEVICE} == "gpu" ]]; then
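The body of `run_docker_build` is outside this diff. For orientation, a hypothetical Python equivalent of how a list of `NAME=value` entries such as `docker_build_args` is typically expanded into `--build-arg` flags (the function shape and flag layout are illustrative, not the script's actual code):

```python
import subprocess

def run_docker_build(dockerfile: str, *build_args: str) -> None:
  """Hypothetical mirror of the shell helper: run_docker_build <Dockerfile> NAME=value ..."""
  cmd = ["docker", "build", "-f", dockerfile]
  for arg in build_args:  # e.g. "MODE=stable", "TESTS_DIR=./tests"
    cmd += ["--build-arg", arg]
  cmd.append(".")  # assumption: build context is the repository root
  subprocess.run(cmd, check=True)
```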