Merged
2 changes: 1 addition & 1 deletion .coveragerc
@@ -10,7 +10,7 @@ omit =
[paths]
source =
src/MaxText
src/MaxText
src/maxtext
*/site-packages/MaxText
*/site-packages/maxtext

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -42,7 +42,7 @@ repos:
# args:
# - '--jobs=auto'
# - '--keep-going'
# - 'src/MaxText/'
# - 'src/maxtext/'

- repo: https://github.com/google/pyink
rev: 24.10.1
2 changes: 1 addition & 1 deletion README.md
@@ -57,7 +57,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies
* \[October 10, 2025\] Post-Training (SFT, RL) via [Tunix](https://github.com/google/tunix) is now available.
* \[September 26, 2025\] Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage.
* \[September 24, 2025\] The GPT-OSS family of models (20B, 120B) is now supported.
* \[September 15, 2025\] MaxText is now available as a [PyPI package](https://pypi.org/project/maxtext). Users can now [install maxtext through pip](https://maxtext.readthedocs.io/en/latest/guides/install_maxtext.html).
* \[September 15, 2025\] MaxText is now available as a [PyPI package](https://pypi.org/project/maxtext). Users can now [install maxtext through pip](https://maxtext.readthedocs.io/en/latest/install_maxtext.html).
* \[September 5, 2025\] MaxText has moved to an `src` layout as part of [RESTRUCTURE.md](https://github.com/AI-Hypercomputer/maxtext/blob/aca5b24931ebcbadb55a82e56ebffe8024874028/RESTRUCTURE.md). For existing environments, please run `pip install -e .` from MaxText root.
* \[August 13, 2025\] The Qwen3 2507 MoE family of models is now supported: MoEs: 235B Thinking & 480B Coder as well as existing dense models: 0.6B, 4B, 8B, 14B, and 32B.
* \[July 27, 2025\] Updated TFLOPS/s calculation ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/1988)) to account for causal attention, dividing the attention flops in half. Accounted for sliding window and chunked attention reduced attention flops in [PR](https://github.com/AI-Hypercomputer/maxtext/pull/2009) and [PR](https://github.com/AI-Hypercomputer/maxtext/pull/2030). Changes impact large sequence configs, as explained in this [doc](https://maxtext.readthedocs.io/en/latest/reference/performance_metrics.html)
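The July 27 news item above notes that the TFLOPS/s calculation now divides attention FLOPs in half for causal attention, since each query only attends to the keys at or before its own position (about s²/2 of the s×s score matrix). A rough sketch of that accounting — the function name and FLOP formula here are illustrative, not MaxText's actual implementation:

```python
def attention_tflops_per_step(batch, seq_len, num_heads, head_dim, causal=True):
    """Rough attention FLOPs for one forward pass of one attention layer.

    QK^T and the attention-weighted V product each cost about
    2 * batch * heads * seq_len^2 * head_dim multiply-adds, so 4x total.
    A causal mask means only ~half of the seq_len x seq_len score matrix
    is ever computed or used, halving the count.
    """
    flops = 4 * batch * num_heads * seq_len**2 * head_dim
    if causal:
        flops /= 2  # queries attend to ~seq_len/2 keys on average
    return flops / 1e12  # TFLOPs


full = attention_tflops_per_step(1, 4096, 32, 128, causal=False)
causal = attention_tflops_per_step(1, 4096, 32, 128, causal=True)
print(full, causal)  # causal is exactly half of full
```

The impact grows quadratically with sequence length, which is why the linked doc calls out large-sequence configs as the ones most affected.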
1 change: 0 additions & 1 deletion benchmarks/convergence/c4_exp.py
@@ -23,7 +23,6 @@

from benchmarks.benchmark_utils import MaxTextModel, _add_to_model_dictionary
from benchmarks.convergence.convergence_utils import DatasetHParams, ConvHParams, _setup_model_convergence_

from benchmarks.maxtext_v5p_model_configs import deepseek_v3_ep_256_v5p_512

c4_pretrain_model_dict = {}
1 change: 0 additions & 1 deletion benchmarks/disruption_management/monitor.py
@@ -29,7 +29,6 @@
import time

from benchmarks.disruption_management.disruption_utils import wait_for_pod_to_start

from benchmarks.disruption_management.disruption_handler import DisruptionConfig
from benchmarks.disruption_management.disruption_handler import TriggerType

52 changes: 27 additions & 25 deletions benchmarks/llama2_v6e-256_benchmarks.py
@@ -17,13 +17,12 @@
on a specific v6e-256 hardware setup using the XPK runner.
"""

import maxtext_trillium_model_configs as model_configs
import os

from maxtext_xpk_runner import BenchmarkRunner
from maxtext_xpk_runner import HWConfig
from maxtext_xpk_runner import SWconfig
from maxtext_xpk_runner import xpk_benchmark_runner
from maxtext_xpk_runner import XpkConfig
from benchmarks import maxtext_trillium_model_configs as model_configs
from benchmarks.maxtext_xpk_runner import WorkloadConfig
from benchmarks.maxtext_xpk_runner import xpk_benchmark_runner
from benchmarks.maxtext_xpk_runner import XpkClusterConfig


DATE = "20241009"
@@ -35,34 +34,37 @@
DEVICE_TYPE = "v6e-256"
NUM_SLICES = 1
BASE_OUTPUT_DIR = "gs://maxtext-experiments-tpem/"

v6e_env_configs = SWconfig(base_docker_image=BASE_DOCKER_IMAGE, libtpu_version=DATE)
v6e_256_configs = HWConfig(num_slices=NUM_SLICES, device_type=DEVICE_TYPE)

llama2_70b_4096 = BenchmarkRunner(
model_name=model_configs.llama2_70b_4096,
software_config=v6e_env_configs,
hardware_config=v6e_256_configs,
)

llama2_7b_4096 = BenchmarkRunner(
model_name=model_configs.llama2_7b_4096,
software_config=v6e_env_configs,
hardware_config=v6e_256_configs,
)
XPK_PATH = os.path.join("~", "xpk")
BENCHMARK_STEPS = 20


def main() -> None:
cluster_config = XpkConfig(
cluster_config = XpkClusterConfig(
cluster_name=CLUSTER_NAME,
project=PROJECT,
zone=ZONE,
num_slices=NUM_SLICES,
device_type=DEVICE_TYPE,
base_output_directory=BASE_OUTPUT_DIR,
)

xpk_benchmark_runner(cluster_config, [llama2_7b_4096, llama2_70b_4096])
workload_configs = []
for model in [model_configs.llama2_7b_4096, model_configs.llama2_70b_4096]:
workload_configs.append(
WorkloadConfig(
model=model,
num_slices=NUM_SLICES,
device_type=DEVICE_TYPE,
base_output_directory=BASE_OUTPUT_DIR,
base_docker_image=BASE_DOCKER_IMAGE,
libtpu_type=None,
libtpu_nightly_version=DATE,
pathways_config=None,
xpk_path=XPK_PATH,
num_steps=BENCHMARK_STEPS,
priority="medium",
)
)

xpk_benchmark_runner(cluster_config, workload_configs)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion benchmarks/maxtext_xpk_runner.py
@@ -35,9 +35,9 @@
import omegaconf

import benchmarks.maxtext_trillium_model_configs as model_configs
import benchmarks.xla_flags_library as xla_flags
from benchmarks.globals import MAXTEXT_PKG_DIR
from benchmarks.command_utils import run_command_with_updates
import benchmarks.xla_flags_library as xla_flags
from benchmarks.disruption_management.disruption_handler import DisruptionConfig
from benchmarks.disruption_management.disruption_manager import DisruptionManager
from benchmarks.xpk_configs import XpkClusterConfig
2 changes: 1 addition & 1 deletion benchmarks/recipes/mcjax_long_running_recipe.py
@@ -27,7 +27,7 @@
import benchmarks.maxtext_trillium_model_configs as model_configs
import benchmarks.maxtext_xpk_runner as mxr
from benchmarks.xpk_configs import XpkClusterConfig
from . import user_configs
from benchmarks.recipes import user_configs

# Cluster Params
CLUSTER = "v6e-256-cluster"
6 changes: 3 additions & 3 deletions benchmarks/recipes/pw_elastic_training_recipe.py
@@ -25,11 +25,11 @@

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from . import args_helper as helper
from . import user_configs

from benchmarks.disruption_management.disruption_handler import DisruptionMethod
from .runner_utils import generate_and_run_workloads
from benchmarks.recipes import args_helper as helper
from benchmarks.recipes import user_configs
from benchmarks.recipes.runner_utils import generate_and_run_workloads

user_configs.USER_CONFIG.max_restarts = 10
COMPARE_WITH_MCJAX = True
4 changes: 2 additions & 2 deletions benchmarks/recipes/pw_headless_mode.py
@@ -22,8 +22,8 @@
"""

import benchmarks.recipes.args_helper as helper
from .. import maxtext_xpk_runner as mxr
from ..recipes.user_configs import USER_CONFIG
from benchmarks import maxtext_xpk_runner as mxr
from benchmarks.recipes.user_configs import USER_CONFIG


def main() -> int:
13 changes: 5 additions & 8 deletions benchmarks/recipes/pw_long_running_recipe.py
@@ -27,13 +27,10 @@
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import recipes.args_helper as helper

import maxtext_trillium_model_configs as model_configs

import maxtext_xpk_runner as mxr

from xpk_configs import XpkClusterConfig
import benchmarks.maxtext_trillium_model_configs as model_configs
import benchmarks.maxtext_xpk_runner as mxr
import benchmarks.recipes.args_helper as helper
from benchmarks.xpk_configs import XpkClusterConfig

PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
@@ -66,7 +63,7 @@ def main():
)

# Handle command line arguments using args_helper
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=XPK_PATH)
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, USER, xpk_path=XPK_PATH)

if not should_continue:
return
12 changes: 6 additions & 6 deletions benchmarks/recipes/pw_mcjax_benchmark_recipe.py
@@ -18,14 +18,14 @@

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from . import args_helper as helper
from .user_configs import UserConfig
from .user_configs import USER_CONFIG
from .runner_utils import generate_and_run_workloads
from . import parser_utils
from benchmarks.recipes import args_helper as helper
from benchmarks.recipes import parser_utils
from benchmarks.recipes.pw_utils import check_and_create_bucket
from benchmarks.recipes.runner_utils import generate_and_run_workloads
from benchmarks.recipes.user_configs import UserConfig
from benchmarks.recipes.user_configs import USER_CONFIG
import argparse
from google.cloud import storage
from .pw_utils import check_and_create_bucket


def main(user_config) -> int:
6 changes: 3 additions & 3 deletions benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py
@@ -22,10 +22,10 @@
import datetime
import dataclasses
import os
import args_helper as helper

from benchmarks import maxtext_trillium_model_configs as model_configs
import benchmarks.maxtext_xpk_runner as mxr
from benchmarks import maxtext_trillium_model_configs as model_configs
from benchmarks.recipes import args_helper as helper
from benchmarks.xpk_configs import XpkClusterConfig

PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
@@ -185,7 +185,7 @@ def main() -> int:
)

# Handle command line arguments using args_helper
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=XPK_PATH)
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, os.environ["USER"], xpk_path=XPK_PATH)

if not should_continue:
return 0
4 changes: 2 additions & 2 deletions benchmarks/recipes/pw_remote_python_recipe.py
@@ -21,7 +21,7 @@

import os

import args_helper as helper
import benchmarks.recipes.args_helper as helper

from benchmarks import maxtext_trillium_model_configs as model_configs
from benchmarks import maxtext_xpk_runner as mxr
@@ -40,7 +40,7 @@ def main():
xpk_path = "xpk"

# Handle command line arguments using args_helper
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=xpk_path)
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, os.environ["USER"], xpk_path=xpk_path)

if not should_continue:
return
7 changes: 3 additions & 4 deletions benchmarks/recipes/pw_suspend_resume.py
@@ -25,11 +25,10 @@

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from . import args_helper as helper
from . import user_configs

from benchmarks.disruption_management.disruption_handler import DisruptionMethod
from .runner_utils import generate_and_run_workloads
from benchmarks.recipes import args_helper as helper
from benchmarks.recipes import user_configs
from benchmarks.recipes.runner_utils import generate_and_run_workloads

user_configs.USER_CONFIG.max_restarts = 3
DISRUPTION_METHOD = DisruptionMethod.SIGTERM
2 changes: 1 addition & 1 deletion benchmarks/recipes/pw_utils.py
@@ -20,7 +20,7 @@

import typing

import maxtext_xpk_runner as mxr
import benchmarks.maxtext_xpk_runner as mxr
from google.api_core.exceptions import (
NotFound,
Conflict,
2 changes: 1 addition & 1 deletion benchmarks/recipes/runner_utils.py
@@ -16,7 +16,7 @@

import logging

from .. import maxtext_xpk_runner as mxr
from benchmarks import maxtext_xpk_runner as mxr
from benchmarks.benchmark_utils import Framework
from benchmarks.disruption_management.disruption_manager import construct_disruption_configs

8 changes: 4 additions & 4 deletions benchmarks/recipes/user_configs.py
@@ -27,10 +27,10 @@

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from .. import maxtext_trillium_model_configs as v6e_model_configs
from .. import maxtext_v5e_model_configs as v5e_model_configs
from .. import maxtext_v5p_model_configs as v5p_model_configs
from .pw_utils import build_user_models, get_cluster_config, get_pathways_config
from benchmarks import maxtext_trillium_model_configs as v6e_model_configs
from benchmarks import maxtext_v5e_model_configs as v5e_model_configs
from benchmarks import maxtext_v5p_model_configs as v5p_model_configs
from benchmarks.recipes.pw_utils import build_user_models, get_cluster_config, get_pathways_config


AVAILABLE_MODELS_FRAMEWORKS = ["mcjax", "pathways"]
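The pattern repeated across these benchmark and recipe files — replacing relative imports like `from .. import maxtext_xpk_runner` with absolute, `benchmarks.`-rooted ones — matters because relative imports only resolve when a module is loaded as part of its package; a file executed directly as a script has no parent package and fails with `ImportError`. A minimal, self-contained sketch of the difference, using a hypothetical `benchmarks_demo` package rather than the real repo layout:

```python
import os
import sys
import tempfile

# Build a tiny throwaway package mirroring the benchmarks/recipes layout.
tmp = tempfile.mkdtemp()
pkg = os.path.join(tmp, "benchmarks_demo")  # hypothetical stand-in package
os.makedirs(os.path.join(pkg, "recipes"))
open(os.path.join(pkg, "__init__.py"), "w").close()
open(os.path.join(pkg, "recipes", "__init__.py"), "w").close()
with open(os.path.join(pkg, "runner.py"), "w") as f:
    f.write("NAME = 'runner'\n")
with open(os.path.join(pkg, "recipes", "recipe.py"), "w") as f:
    # Absolute import rooted at the package, as this PR now uses everywhere:
    f.write("from benchmarks_demo import runner\nNAME = runner.NAME\n")

sys.path.insert(0, tmp)
from benchmarks_demo.recipes import recipe  # resolves via the package root

print(recipe.NAME)  # prints "runner"

# Had recipe.py used "from .. import runner" and been run directly as a
# script, Python would raise:
#   ImportError: attempted relative import with no known parent package
```

The same reasoning explains why several of these files could also drop their `sys.path.append(parent_dir)` workarounds: with absolute imports, running from the repo root (or an installed package) is sufficient.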
9 changes: 3 additions & 6 deletions codecov.yml
@@ -27,21 +27,18 @@
codecov:
token: 35742a22-fb1f-4839-97ff-b54da5588689
# By default file names in the coverage report will have their path in the file system, which in our
# runners would be /__w/maxtext/maxtext/src/MaxText/* but Codecov expects src/MaxText/* so we need to fix the path
# runners would be /__w/maxtext/maxtext/src/maxtext/* but Codecov expects src/maxtext/* so we need to fix the path
fixes:
# - ".*/maxtext/src/::src/"
- "/github/workspace/::"
ignore:
- "src/maxtext/assets"
- "src/maxtext/configs"
- "src/maxtext/examples"
- "src/MaxText/experimental"
- "src/maxtext/experimental"
- "src/maxtext/inference"
- "src/maxtext/scratch_code"
- "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation
- "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft
- "src/MaxText/rl" # code moved to src/maxtext/trainers/post_train/rl

- "src/MaxText"

flags:
# Updated ONLY by PRs (contains subset of tests, excluding scheduled_only).
21 changes: 11 additions & 10 deletions docs/guides.md
@@ -18,45 +18,45 @@

Explore our how-to guides for optimizing, debugging, and managing your MaxText workloads.

::::{grid} 1 2 2 2
::::\{grid} 1 2 2 2
:gutter: 2

:::{grid-item-card} ⚡ Optimization
:::\{grid-item-card} ⚡ Optimization
:link: guides/optimization
:link-type: doc

Techniques for maximizing performance, including sharding strategies, Pallas kernels, and benchmarking.
:::

:::{grid-item-card} 💾 Data Pipelines
:::\{grid-item-card} 💾 Data Pipelines
:link: guides/data_input_pipeline
:link-type: doc

Configure input pipelines using **Grain** (recommended for determinism), **HuggingFace**, or **TFDS**.
:::

:::{grid-item-card} 🔄 Checkpointing
:::\{grid-item-card} 🔄 Checkpointing
:link: guides/checkpointing_solutions
:link-type: doc

Manage GCS checkpoints, handle preemption with emergency checkpointing, and configure multi-tier storage.
:::

:::{grid-item-card} 🔍 Monitoring & Debugging
:::\{grid-item-card} 🔍 Monitoring & Debugging
:link: guides/monitoring_and_debugging
:link-type: doc

Tools for observability: goodput monitoring, hung job debugging, and Vertex AI TensorBoard integration.
:::

:::{grid-item-card} 🐍 Python Notebooks
:::\{grid-item-card} 🐍 Python Notebooks
:link: guides/run_python_notebook
:link-type: doc

Interactive development guides for running MaxText on Google Colab or local JupyterLab environments.
:::

:::{grid-item-card} 🌱 Model Bringup
:::\{grid-item-card} 🌱 Model Bringup
:link: guides/model_bringup
:link-type: doc

@@ -65,9 +65,10 @@ A step-by-step guide for the community to help expand MaxText's model library.
::::

```{toctree}
:hidden:
:maxdepth: 1

---
hidden:
maxdepth: 1
---
guides/optimization.md
guides/data_input_pipeline.md
guides/checkpointing_solutions.md