23 changes: 23 additions & 0 deletions .github/workflows/check_docs_build.yml
@@ -16,24 +16,47 @@ jobs:
uses: actions/checkout@v5
with:
persist-credentials: false
fetch-depth: 0

- name: Check if only documentation changed
id: check
run: |
git fetch origin ${GITHUB_BASE_REF}
CHANGED_FILES=$(git diff --name-only origin/${GITHUB_BASE_REF}...HEAD)

# Check for documentation changes
if echo "$CHANGED_FILES" | grep -E '\.(md)$|^docs/' > /dev/null; then
echo "Documentation files changed, enabling docs build."
echo "build_docs=true" >> $GITHUB_OUTPUT
else
echo "No documentation changes, skipping docs build."
echo "build_docs=false" >> $GITHUB_OUTPUT
fi

- name: Install uv and set the Python version
if: steps.check.outputs.build_docs == 'true'
uses: astral-sh/setup-uv@eb1897b8dc4b5d5bfe39a428a8f2304605e0983c # v7.0.0
with:
python-version: '3.12'
enable-cache: true

- name: Set venv
if: steps.check.outputs.build_docs == 'true'
run: uv venv --python 3.12 $GITHUB_WORKSPACE/venv

- name: Install dependencies
if: steps.check.outputs.build_docs == 'true'
run: . $GITHUB_WORKSPACE/venv/bin/activate && uv pip install -r src/dependencies/requirements/requirements_docs.txt

- name: Build documentation
if: steps.check.outputs.build_docs == 'true'
run: |
. $GITHUB_WORKSPACE/venv/bin/activate
uv pip install -e . --no-deps
uv pip install torch
# verify links; the build fails if errors are found
sphinx-build -b linkcheck docs docs/_build/linkcheck -q --keep-going
# generates the actual website
sphinx-build -b html docs docs/_build/html
env:
JAX_PLATFORMS: cpu
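The detection step above can be exercised outside CI. A minimal sketch, assuming `origin/main` stands in for the PR base branch that `GITHUB_BASE_REF` supplies in the workflow:

```bash
# Dry-run of the docs-only detection against a local checkout.
BASE_REF=main
git fetch origin "${BASE_REF}"
CHANGED_FILES=$(git diff --name-only "origin/${BASE_REF}...HEAD")

# Same pattern as the workflow: Markdown files or anything under docs/.
if echo "$CHANGED_FILES" | grep -E '\.(md)$|^docs/' > /dev/null; then
  echo "Docs changed: the docs build would run."
else
  echo "No docs changes: the docs build would be skipped."
fi
```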
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -29,5 +29,5 @@ This project follows

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
[GitHub Help](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) for more
information on using pull requests.
37 changes: 37 additions & 0 deletions docs/conf.py
@@ -124,6 +124,43 @@
os.path.join("reference", "api_generated", "dependencies.github_deps.install_pre_train_deps.rst"),
]

# -- Options for linkcheck ----------------------------------------------------
# Treat link timeouts as broken so they are reported in the output
linkcheck_report_timeouts_as_broken = True

# Enable anchor checking so Sphinx looks for the #section-name
linkcheck_anchors = True

# Ignore dynamically generated anchors in the documentation, which can cause false positives in link checking
linkcheck_anchors_ignore = [
r"^L\d+",
r"badput-breakdown-details",
r"online-inference",
r"install-the-seed-env-tool",
r"collect-stack-traces",
r"1-prerequisites",
]

# Disable reporting of allowed redirects to reduce noise in the output
linkcheck_allowed_redirects = {
r"https://github\.com/google/maxtext": r"https://github\.com/AI-Hypercomputer/maxtext/.*",
r"https://cloud\.google\.com/.*": r"https://docs\.cloud\.google\.com/.*",
r"https://jax\.readthedocs\.io/.*": r"https://docs\.jax\.dev/.*",
r"https://twitter\.com/.*": r"https://x\.com/.*",
r"https://www\.sphinx-doc\.org": r"https://www\.sphinx-doc\.org/en/master/.*",
r"https://.*\.readthedocs\.io": r"https://.*\.readthedocs\.io/en/.*",
}

# Ignore specific links that are known to be inaccessible during the build process
linkcheck_ignore = [
# Ignore Google Auth/Console redirects which require login
r"https://accounts\.google\.com/.*",
r"https://console\.cloud\.google\.com/.*",
r"https://cla\.developers\.google\.com/.*",
# Ignore GitHub commit history links which frequently trigger rate limiting (429)
r"https://github\.com/jax-ml/jax/commits/.*",
]


# -- Autogenerate API documentation ------------------------------------------
def run_apidoc(_):
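To iterate on the linkcheck settings above without a CI round-trip, the builder can be run locally. A minimal sketch mirroring the workflow's invocation; Sphinx writes the per-link results to `output.txt` in the build directory:

```bash
# Run the linkcheck builder using the options defined in docs/conf.py.
# -q keeps output terse; --keep-going reports every failure before exiting non-zero.
sphinx-build -b linkcheck docs docs/_build/linkcheck -q --keep-going
# Entries whose status is 'broken' are the ones that fail the build.
cat docs/_build/linkcheck/output.txt
```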
6 changes: 3 additions & 3 deletions docs/guides/data_input_pipeline/data_input_grain.md
@@ -32,7 +32,7 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state

## Using Grain

1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups) and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. To convert a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random-access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class.
1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups) and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. To convert a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random-access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources/protocol.html) class.
- **Community Resource**: The MaxText community has created an [ArrayRecord Documentation](https://array-record.readthedocs.io/) site. Note: we appreciate this contribution from the community, but it has not yet been verified by the MaxText or ArrayRecord developers.
2. If the dataset is hosted on a Cloud Storage bucket, the path `gs://` can be provided directly. However, for the best performance, it's recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This significantly improves performance for the ArrayRecord format, as metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path on each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount.

@@ -43,11 +43,11 @@ MOUNT_PATH=${MOUNT_PATH?} \
[FILE_PATH=${MOUNT_PATH?}/my_dataset]
```

Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://cloud.google.com/storage/docs/cloud-storage-fuse/performance#improve-first-time-reads)).
Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://docs.cloud.google.com/storage/docs/cloud-storage-fuse/performance)).

1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path.

2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.
2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.

3. ArrayRecord only: for multi-source blending, you can specify multiple data sources with their respective weights, using a semicolon (;) to separate sources and a comma (,) to pair each source with its weight. The weights will be automatically normalized to sum to 1.0. For example:
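A hypothetical illustration of this syntax with placeholder paths, following the comma/semicolon convention described above:

```
grain_train_files=/tmp/gcsfuse/dataset_a/*.array_record,0.75;/tmp/gcsfuse/dataset_b/*.array_record,0.25
```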

2 changes: 1 addition & 1 deletion docs/guides/model_bringup.md
@@ -93,7 +93,7 @@ For models with existing Hugging Face support, you can validate parity using the

### 5.2 Eval Benchmark

MaxText integrates with benchmark libraries like [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) and [evalchemy](https://github.com/mlfoundations/evalchemy) to facilitate rapid verification of common inference scores ([guide](../../benchmarks/api_server)). This is particularly useful for validating decoding outputs or assessing model performance when logits deviate slightly from reference values.
MaxText integrates with benchmark libraries like [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) and [evalchemy](https://github.com/mlfoundations/evalchemy) to facilitate rapid verification of common inference scores ([guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/api_server/README.md)). This is particularly useful for validating decoding outputs or assessing model performance when logits deviate slightly from reference values.

## 6. Completion Checklist

@@ -14,7 +14,7 @@ Much of this guide is geared towards providing Google with the right data to hel

1. Use `JAX` 0.6 or later, and enable the JAX distributed service. This version of JAX contains additional logging that can help identify which workers are experiencing issues.
2. Generate an HLO dump using the `--xla_dump_to` flag when initializing your workload. This is discussed in the [XLA Documentation](https://openxla.org/xla/hlo_dumps).
3. Run your workload with stack traces enabled. XPK users should follow the [XPK-specific instructions](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#collect-stack-traces). Note the `--deploy-stacktrace-sidecar` flag when running the XPK workload command.
3. Run your workload with stack traces enabled. XPK users should follow the [XPK-specific instructions](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/troubleshooting.md#collect-stack-traces). Note the `--deploy-stacktrace-sidecar` flag when running the XPK workload command.
4. Set `--vmodule=real_program_continuator=1` to enable verbose logging for the TPU program execution status.
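A minimal launch sketch combining steps 2 and 4; the dump path and the training entry point are placeholders, and how the flags reach your workload depends on your launcher:

```bash
# Step 2: dump HLO for later analysis via the XLA_FLAGS environment variable.
export XLA_FLAGS="--xla_dump_to=/tmp/hlo_dump"
# Step 4: forward the verbose TPU execution-status logging flag to the workload.
python3 train.py --vmodule=real_program_continuator=1
```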

## Locate the Megascale Hang Detected Error
@@ -70,7 +70,7 @@ If the TPU listed in the log shows a non-zero program counter, it is very likely

If the logged TPU shows a program counter of 0, it is likely that the TPU is waiting on input. We can attempt to confirm the worker is hung during the input pipeline using the stack trace library found in the [cloud-tpu-diagnostics package](https://pypi.org/project/cloud-tpu-diagnostics/).

XPK users should follow the [XPK-specific instructions](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#collect-stack-traces) to emit stack traces. Note the `--deploy-stacktrace-sidecar` flag when running the XPK workload command.
XPK users should follow the [XPK-specific instructions](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/troubleshooting.md#collect-stack-traces) to emit stack traces. Note the `--deploy-stacktrace-sidecar` flag when running the XPK workload command.

Customers can then query Cloud Logging for the stack trace logs from the outlier TPU. The stack trace log will help users determine where in the Python code the program was during the hang.

14 changes: 7 additions & 7 deletions docs/guides/monitoring_and_debugging/monitor_goodput.md
@@ -30,7 +30,7 @@ Goodput is the metric that measures the efficiency of model training jobs, i.e.

Badput is the metric that measures the time a workload spends on anything other than productive training, as a proportion of the total time spent by the workload. For example, time spent in accelerator initialization, training preparation, program startup, data loading, portions of checkpointing, disruptions, and wasted progress since the last checkpoint all contribute to Badput.

The ML Goodput Measurement library exposes Badput Breakdown. Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details).
The ML Goodput Measurement library exposes Badput Breakdown. Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement/blob/main/README.md#badput-breakdown-details).

## What is Step Time Deviation

@@ -69,8 +69,8 @@ following access scope during node pool creation:
XPK adds this access scope to the GPU, TPU and CPU node pools, so XPK is the recommended method to create clusters and node pools if you intend to run your workloads on GKE.

Instructions on how to create clusters using XPK can be
found [here](https://github.com/AI-Hypercomputer/xpk/blob/main/README.md#cluster-create) and how to create workloads using XPK can be found
[here](https://github.com/AI-Hypercomputer/xpk/blob/main/README.md#workload-create).
found [here](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/usage/clusters.md) and how to create workloads using XPK can be found
[here](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/usage/workloads.md).
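As a rough sketch of both commands (flag spellings can vary across XPK releases; the cluster name, TPU type, and slice count are placeholders):

```bash
# Create a cluster, then submit a workload to it with XPK.
xpk cluster create --cluster my-cluster --tpu-type=v5litepod-16 --num-slices=1
xpk workload create --cluster my-cluster --workload my-job \
  --tpu-type=v5litepod-16 --num-slices=1 \
  --command "python3 train.py"
```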

```{note}
Access Scopes are immutable and workloads can only be migrated to new node pools with required access scopes. Access scopes on already created clusters cannot be updated.
@@ -131,7 +131,7 @@ If checkpointing is enabled, please enable the `enable_checkpoint_cloud_logger`

#### Visualize Goodput, Badput and step deviation on Google Cloud Monitoring

By default, performance data ([goodput](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/goodput_time), [badput](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/badput_time), and [step deviation](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance)) is automatically sent to Google Cloud Monitoring, enabling visualization on dashboards.
By default, performance data ([goodput](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c), [badput](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c), and [step deviation](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c)) is automatically sent to Google Cloud Monitoring, enabling visualization on dashboards.

This feature leverages Google VM metadata (project ID, location, accelerator type)
and supports replica IDs for uniquely identifying workloads in multi-replica
@@ -184,13 +184,13 @@ Goodput, Badput and Step Time Deviation metrics can be monitored using GCM Metri

2. Navigate to [Metrics Explorer](https://console.cloud.google.com/monitoring/metrics-explorer). Initiate metric selection by clicking `Select a metric` then search for and select the `Workload` resource. Subsequently, choose the `Workload` metric category.

a. [**Productive Time:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/goodput_time)
a. [**Productive Time:**](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c)
Represents the cumulative duration the workload spent on productive tasks,
measured by `compute.googleapis.com/workload/goodput_time`.\
b. [**Non-Productive Time:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/badput_time)
b. [**Non-Productive Time:**](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c)
Represents the cumulative duration the workload spent on non-productive tasks,
measured by `compute.googleapis.com/workload/badput_time`.\
c. [**Performance:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance)
c. [**Performance:**](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c)
Represents the workload's performance metric, specifically step deviation
in this context, measured by `compute.googleapis.com/workload/performance`.
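The same series can also be pulled programmatically through the Cloud Monitoring API. A rough sketch using `curl`; `PROJECT_ID` and the time window are placeholders:

```bash
# List raw goodput_time series via the Monitoring v3 timeSeries.list method.
curl -s -G \
  -H "Authorization: Bearer $(gcloud auth print-access-token)" \
  "https://monitoring.googleapis.com/v3/projects/PROJECT_ID/timeSeries" \
  --data-urlencode 'filter=metric.type="compute.googleapis.com/workload/goodput_time"' \
  --data-urlencode 'interval.startTime=2025-01-01T00:00:00Z' \
  --data-urlencode 'interval.endTime=2025-01-02T00:00:00Z'
```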

@@ -28,7 +28,7 @@ You can use a single Vertex AI Tensorboard instance to track and compare metrics

## Prerequisites

- Enable [Vertex AI API](https://cloud.google.com/vertex-ai/docs/start/cloud-environment#enable_vertexai_apis) in your Google Cloud console.
- Enable [Vertex AI API](https://docs.cloud.google.com/vertex-ai/docs/start/cloud-environment#set_up_a_project) in your Google Cloud console.
- Assign [Vertex AI User IAM role](https://cloud.google.com/vertex-ai/docs/general/access-control#aiplatform.user) to the service account used by the TPU VMs. This is required to create and access the Vertex AI Tensorboard in Google Cloud console. If you are using XPK for MaxText, the necessary Vertex AI User IAM role will be automatically assigned to your node pools by XPK – no need to assign it manually.
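If you are not using XPK, the role can be granted manually; a sketch with placeholder project and service-account values:

```bash
# Grant the Vertex AI User role to the service account used by the TPU VMs.
gcloud projects add-iam-policy-binding PROJECT_ID \
  --member="serviceAccount:SA_EMAIL" \
  --role="roles/aiplatform.user"
```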

## Upload logs to Vertex AI Tensorboard
2 changes: 1 addition & 1 deletion docs/guides/optimization/custom_model.md
@@ -254,7 +254,7 @@
- `3 * M * 8 / 2 > 12800`
- `M > 1100`

It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to the [link](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/sharding.md#pp--fsdpdp) for specific challenges regarding PP + FSDP/DP.
It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to the [link](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/optimization/sharding.html) for specific challenges regarding PP + FSDP/DP.

## Step 4. Analyze experiments

2 changes: 0 additions & 2 deletions docs/guides/optimization/pallas_kernels_performance.md
@@ -69,8 +69,6 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth

> This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts.

- [`src/maxtext/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/kernels/megablox/gmm.py)

**Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/maxtext/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/moe.py)).

## 🔧 The Pallas optimization workflow: code → profile → tune → repeat
2 changes: 1 addition & 1 deletion docs/reference/architecture/jax_ai_libraries_chosen.md
@@ -56,7 +56,7 @@ For more information on using Orbax, please refer to https://github.com/google/o

1. **Deterministic by Design**: Grain allows storing the data loader's state and provides strong guarantees about data ordering and sharding even with preemptions, which is critical for reproducibility.
2. **Global Shuffle**: Prevents local overfitting.
3. **Built for Multi-Host Training**: Using a random-access file format streamlines [data loading in multi-host environments](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline.md#multihost-dataloading-best-practice).
3. **Built for Multi-Host Training**: Using a random-access file format streamlines [data loading in multi-host environments](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/data_input_pipeline.html#multihost-dataloading-best-practice).

Its APIs are explicitly designed for the multi-host paradigm, simplifying the process of ensuring that each host loads a unique shard of the global batch.
