diff --git a/.github/workflows/check_docs_build.yml b/.github/workflows/check_docs_build.yml
index 59d18110b0..e755ff27c5 100644
--- a/.github/workflows/check_docs_build.yml
+++ b/.github/workflows/check_docs_build.yml
@@ -16,24 +16,47 @@ jobs:
         uses: actions/checkout@v5
         with:
           persist-credentials: false
+          fetch-depth: 0
+
+      - name: Check if documentation changed
+        id: check
+        run: |
+          git fetch origin ${GITHUB_BASE_REF}
+          CHANGED_FILES=$(git diff --name-only origin/${GITHUB_BASE_REF}...HEAD)
+
+          # Check for documentation changes (Markdown files or anything under docs/)
+          if echo "$CHANGED_FILES" | grep -E '\.(md)$|^docs/' > /dev/null; then
+            echo "Documentation files changed, enabling docs build."
+            echo "build_docs=true" >> $GITHUB_OUTPUT
+          else
+            echo "No documentation changes, skipping docs build."
+            echo "build_docs=false" >> $GITHUB_OUTPUT
+          fi

       - name: Install uv and set the Python version
+        if: steps.check.outputs.build_docs == 'true'
         uses: astral-sh/setup-uv@eb1897b8dc4b5d5bfe39a428a8f2304605e0983c # v7.0.0
         with:
           python-version: '3.12'
           enable-cache: true

       - name: Set venv
+        if: steps.check.outputs.build_docs == 'true'
         run: uv venv --python 3.12 $GITHUB_WORKSPACE/venv

       - name: Install dependencies
+        if: steps.check.outputs.build_docs == 'true'
         run: . $GITHUB_WORKSPACE/venv/bin/activate && uv pip install -r src/dependencies/requirements/requirements_docs.txt

       - name: Build documentation
+        if: steps.check.outputs.build_docs == 'true'
         run: |
           . $GITHUB_WORKSPACE/venv/bin/activate
           uv pip install -e . --no-deps
           uv pip install torch
+          # Verify links; the build fails if broken links are found
+          sphinx-build -b linkcheck docs docs/_build/linkcheck -q --keep-going
+          # Generate the actual website
           sphinx-build -b html docs docs/_build/html
         env:
           JAX_PLATFORMS: cpu
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 23cb9381ba..0dd851418d 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -29,5 +29,5 @@ This project follows

 All submissions, including submissions by project members, require review. We
 use GitHub pull requests for this purpose. Consult
-[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+[GitHub Help](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) for more
 information on using pull requests.
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index cc71573ee1..626d242fcd 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -124,6 +124,43 @@
     os.path.join("reference", "api_generated", "dependencies.github_deps.install_pre_train_deps.rst"),
 ]

+# -- Options for linkcheck ----------------------------------------------------
+# Report timeouts as broken so every failing link surfaces with status 'broken'
+linkcheck_report_timeouts_as_broken = True
+
+# Enable anchor checking so Sphinx also verifies the #section-name part of links
+linkcheck_anchors = True
+
+# Ignore dynamically generated anchors, which can cause false positives in link checking
+linkcheck_anchors_ignore = [
+    r"^L\d+",
+    r"badput-breakdown-details",
+    r"online-inference",
+    r"install-the-seed-env-tool",
+    r"collect-stack-traces",
+    r"1-prerequisites",
+]
+
+# Do not report these allowed redirects, to reduce noise in the output
+linkcheck_allowed_redirects = {
+    r"https://github\.com/google/maxtext": r"https://github\.com/AI-Hypercomputer/maxtext/.*",
+    r"https://cloud\.google\.com/.*": r"https://docs\.cloud\.google\.com/.*",
+    r"https://jax\.readthedocs\.io/.*": r"https://docs\.jax\.dev/.*",
+    r"https://twitter\.com/.*": r"https://x\.com/.*",
+    r"https://www\.sphinx-doc\.org": r"https://www\.sphinx-doc\.org/en/master/.*",
+    r"https://.*\.readthedocs\.io": r"https://.*\.readthedocs\.io/en/.*",
+}
+
+# Ignore specific links that are known to be inaccessible during the build process
+linkcheck_ignore = [
+    # Google Auth/Console redirects, which require login
+    r"https://accounts\.google\.com/.*",
+    r"https://console\.cloud\.google\.com/.*",
+    r"https://cla\.developers\.google\.com/.*",
+    # GitHub commit history links, which frequently trigger rate limiting (429)
+    r"https://github\.com/jax-ml/jax/commits/.*",
+]
+
 # -- Autogenerate API documentation ------------------------------------------

 def run_apidoc(_):
diff --git a/docs/guides/data_input_pipeline/data_input_grain.md b/docs/guides/data_input_pipeline/data_input_grain.md
index 6b061cc1a1..8d4d0e06e1 100644
--- a/docs/guides/data_input_pipeline/data_input_grain.md
+++ b/docs/guides/data_input_pipeline/data_input_grain.md
@@ -32,7 +32,7 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state

 ## Using Grain

-1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random-access through row groups) and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord)(sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class.
+1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups) and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources/protocol.html) class.
    - **Community Resource**: The MaxText community has created a [ArrayRecord Documentation](https://array-record.readthedocs.io/). Note: we appreciate the contribution from the community, but as of now it has not been verified by the MaxText or ArrayRecord developers yet.

 2. If the dataset is hosted on a Cloud Storage bucket, the path `gs://` can be provided directly. However, for the best performance, it's recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This will significantly improve the perf for the ArrayRecord format as it allows meta data caching to speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount.
@@ -43,11 +43,11 @@
 MOUNT_PATH=${MOUNT_PATH?} \
 [FILE_PATH=${MOUNT_PATH?}/my_dataset]
 ```

-Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://cloud.google.com/storage/docs/cloud-storage-fuse/performance#improve-first-time-reads)).
+Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` to pre-fill the metadata cache (see ["Performance tuning best practices" in the Google Cloud documentation](https://docs.cloud.google.com/storage/docs/cloud-storage-fuse/performance)).

 1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path.
-2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.
+2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.
 3. ArrayRecord Only: For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and a comma (,) for weights. The weights will be automatically normalized to sum to 1.0.

 For example:
diff --git a/docs/guides/model_bringup.md b/docs/guides/model_bringup.md
index 5aadb3aee1..fb231a49e6 100644
--- a/docs/guides/model_bringup.md
+++ b/docs/guides/model_bringup.md
@@ -93,7 +93,7 @@ For models with existing Hugging Face support, you can validate parity using the

 ### 5.2 Eval Benchmark

-MaxText integrates with benchmark libraries like [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) and [evalchemy](https://github.com/mlfoundations/evalchemy) to facilitate rapid verification of common inference scores ([guide](../../benchmarks/api_server)). This is particularly useful for validating decoding outputs or assessing model performance when logits deviate slightly from reference values.
+MaxText integrates with benchmark libraries like [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) and [evalchemy](https://github.com/mlfoundations/evalchemy) to facilitate rapid verification of common inference scores ([guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/api_server/README.md)). This is particularly useful for validating decoding outputs or assessing model performance when logits deviate slightly from reference values.

 ## 6. Completion Checklist
diff --git a/docs/guides/monitoring_and_debugging/megascale_hang_playbook.md b/docs/guides/monitoring_and_debugging/megascale_hang_playbook.md
index 3341824727..bc69ddad2c 100644
--- a/docs/guides/monitoring_and_debugging/megascale_hang_playbook.md
+++ b/docs/guides/monitoring_and_debugging/megascale_hang_playbook.md
@@ -14,7 +14,7 @@ Much of this guide is geared towards providing Google with the right data to hel

 1. Use `JAX` 0.6 or up, and enable JAX distributed service. This version of JAX contains additional logging that can help identify which workers are experiencing issues.
 2. Generate an HLO dump using the `--xla_dump_to` flag when initializing your workload. This is discussed in the [XLA Documentation](https://openxla.org/xla/hlo_dumps).
-3. Run your workload with stack traces enabled. XPK users should follow the [XPK-specific instructions](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#collect-stack-traces). Note the `--deploy-stacktrace-sidecar` flag when running the XPK workload command.
+3. Run your workload with stack traces enabled. XPK users should follow the [XPK-specific instructions](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/troubleshooting.md#collect-stack-traces). Note the `--deploy-stacktrace-sidecar` flag when running the XPK workload command.
 4. Set `--vmodule=real_program_continuator=1` to enable verbose logging for the TPU program execution status.

 ## Locate the Megascale Hang Detected Error
@@ -70,7 +70,7 @@ If the TPU listed in the log shows a non-zero program counter, it is very likely

 If the logged TPU shows a program counter of 0, it is likely that the TPU is waiting on input. We can attempt to confirm the worker is hung during the input pipeline using the stack trace library found in the [cloud-tpu-diagnostics package](https://pypi.org/project/cloud-tpu-diagnostics/).

-XPK users should follow the [XPK-specific instructions](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#collect-stack-traces) to emit stack traces. Note the `--deploy-stacktrace-sidecar` flag when running the XPK workload command.
+XPK users should follow the [XPK-specific instructions](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/troubleshooting.md#collect-stack-traces) to emit stack traces. Note the `--deploy-stacktrace-sidecar` flag when running the XPK workload command.

 Customers can then query Cloud Logging for the stack trace logs from the outlier TPU. The stack trace log will help users determine where in the Python code the program was during the hang.
diff --git a/docs/guides/monitoring_and_debugging/monitor_goodput.md b/docs/guides/monitoring_and_debugging/monitor_goodput.md
index ca949d6079..b662499452 100644
--- a/docs/guides/monitoring_and_debugging/monitor_goodput.md
+++ b/docs/guides/monitoring_and_debugging/monitor_goodput.md
@@ -30,7 +30,7 @@ Goodput is the metric that measures the efficiency of model training jobs, i.e.

 Badput is the metric that measures time that a workload spent on anything that is not productive training proportional to the total time spent by the workload. For example, the time spent in accelerator initialization, training preparation, program startup, data loading, portions of checkpointing, disruptions and wasted progress since the last checkpoint etc. all contribute to Badput.

-The ML Goodput Measurement library exposes Badput Breakdown. Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details)
+The ML Goodput Measurement library exposes Badput Breakdown. Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement/blob/main/README.md#badput-breakdown-details).

 ## What is Step Time Deviation
@@ -69,8 +69,8 @@ following access scope during node pool creation:

 XPK adds this access scope to the GPU, TPU and CPU node pools, so XPK is the
 recommended method to create clusters and node-pools in you intend to run your
 workloads on GKE. Instructions on how to create clusters using XPK can be
-found [here](https://github.com/AI-Hypercomputer/xpk/blob/main/README.md#cluster-create) and how to create workloads using XPK can be found
-[here](https://github.com/AI-Hypercomputer/xpk/blob/main/README.md#workload-create).
+found [here](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/usage/clusters.md) and how to create workloads using XPK can be found
+[here](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/usage/workloads.md).

 ```{note}
 Access Scopes are immutable and workloads can only be migrated to new node pools with required access scopes. Access scopes on already created clusters cannot be updated.
@@ -131,7 +131,7 @@ If checkpointing is enabled, please enable the `enable_checkpoint_cloud_logger`

 #### Visualize Goodput, Badput and step deviation on Google Cloud Monitoring

-By default, performance data ([goodput](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/goodput_time), [badput](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/badput_time), and [step deviation](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance)) is automatically sent to Google Cloud Monitoring, enabling visualization on dashboards.
+By default, performance data ([goodput](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c), [badput](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c), and [step deviation](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c)) is automatically sent to Google Cloud Monitoring, enabling visualization on dashboards.
 This feature leverages Google VM metadata (project ID, location, accelerator
 type) and supports replica IDs for uniquely identifying workloads in multi-replica
@@ -184,13 +184,13 @@ Goodput, Badput and Step Time Deviation metrics can be monitored using GCM Metri

 2. Navigate to [Metrics Explorer](https://console.cloud.google.com/monitoring/metrics-explorer). Initiate metric selection by clicking `Select a metric` then search for and select the `Workload` resource. Subsequently, choose the `Workload` metric category.

-   a. [**Productive Time:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/goodput_time)
+   a. [**Productive Time:**](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c)
    Represents the cumulative duration the workload spent on productive tasks, measured by `compute.googleapis.com/workload/goodput_time`.\

-   b. [**Non-Productive Time:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/badput_time)
+   b. [**Non-Productive Time:**](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c)
    Represents the cumulative duration the workload spent on non-productive tasks, measured by `compute.googleapis.com/workload/badput_time`.\

-   c. [**Performance:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance)
+   c. [**Performance:**](https://docs.cloud.google.com/monitoring/api/metrics_gcp_c)
    Represents the workload's performance metric, specifically step deviation in this context, measured by `compute.googleapis.com/workload/performance`.
diff --git a/docs/guides/monitoring_and_debugging/use_vertex_ai_tensorboard.md b/docs/guides/monitoring_and_debugging/use_vertex_ai_tensorboard.md
index 5bb0a439af..51fa346313 100644
--- a/docs/guides/monitoring_and_debugging/use_vertex_ai_tensorboard.md
+++ b/docs/guides/monitoring_and_debugging/use_vertex_ai_tensorboard.md
@@ -28,7 +28,7 @@ You can use a single Vertex AI Tensorboard instance to track and compare metrics

 ## Prerequisites

-- Enable [Vertex AI API](https://cloud.google.com/vertex-ai/docs/start/cloud-environment#enable_vertexai_apis) in your Google Cloud console.
+- Enable [Vertex AI API](https://docs.cloud.google.com/vertex-ai/docs/start/cloud-environment#set_up_a_project) in your Google Cloud console.
 - Assign [Vertex AI User IAM role](https://cloud.google.com/vertex-ai/docs/general/access-control#aiplatform.user) to the service account used by the TPU VMs. This is required to create and access the Vertex AI Tensorboard in Google Cloud console. If you are using XPK for MaxText, the necessary Vertex AI User IAM role will be automatically assigned to your node pools by XPK – no need to assign it manually.

 ## Upload logs to Vertex AI Tensorboard
diff --git a/docs/guides/optimization/custom_model.md b/docs/guides/optimization/custom_model.md
index 7bbb93edf1..3ba6a1df59 100644
--- a/docs/guides/optimization/custom_model.md
+++ b/docs/guides/optimization/custom_model.md
@@ -254,7 +254,7 @@ Ironwood over ICI:

 - `3 * M * 8 / 2 > 12800`
 - `M > 1100`

-It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to the [link](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/sharding.md#pp--fsdpdp) for specific challenges regarding PP + FSDP/DP.
+It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to the [sharding guide](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/optimization/sharding.html) for specific challenges regarding PP + FSDP/DP.

 ## Step 4. Analyze experiments
diff --git a/docs/guides/optimization/pallas_kernels_performance.md b/docs/guides/optimization/pallas_kernels_performance.md
index 007c2f103c..6ec1b2a51a 100644
--- a/docs/guides/optimization/pallas_kernels_performance.md
+++ b/docs/guides/optimization/pallas_kernels_performance.md
@@ -69,8 +69,6 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth

 > This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts.

- - [`src/maxtext/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/kernels/megablox/gmm.py)
-
 - **Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/maxtext/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/moe.py)).

 ## 🔧 The Pallas optimization workflow: code → profile → tune → repeat
diff --git a/docs/reference/architecture/jax_ai_libraries_chosen.md b/docs/reference/architecture/jax_ai_libraries_chosen.md
index ea1b7ad8d9..4dac03eb44 100644
--- a/docs/reference/architecture/jax_ai_libraries_chosen.md
+++ b/docs/reference/architecture/jax_ai_libraries_chosen.md
@@ -56,7 +56,7 @@ For more information on using Orbax, please refer to https://github.com/google/o

 1. **Deterministic by Design**: Grain allows storing data loader states, provides strong guarantees about data ordering and sharding even with preemptions, which is critical for reproducibility.
 2. **Global Shuffle**: Prevents local overfitting.
-3. **Built for Multi-Host Training**: The using random access file format streamlines [data loading in the multi-host environments](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline.md#multihost-dataloading-best-practice).
+3. **Built for Multi-Host Training**: Using a random-access file format streamlines [data loading in multi-host environments](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/data_input_pipeline.html#multihost-dataloading-best-practice).
 Its APIs are explicitly designed for the multi-host paradigm, simplifying the process of ensuring that each host loads a unique shard of the global batch.
diff --git a/docs/reference/core_concepts.md b/docs/reference/core_concepts.md
index 4c03c48ddc..209118aab3 100644
--- a/docs/reference/core_concepts.md
+++ b/docs/reference/core_concepts.md
@@ -69,6 +69,7 @@ maxdepth: 1
 ---
 core_concepts/checkpoints.md
 core_concepts/alternatives.md
+core_concepts/batch_size.md
 core_concepts/quantization.md
 core_concepts/tiling.md
 core_concepts/jax_xla_and_pallas.md
diff --git a/docs/reference/core_concepts/batch_size.md b/docs/reference/core_concepts/batch_size.md
index c1a17764a4..134a495c86 100644
--- a/docs/reference/core_concepts/batch_size.md
+++ b/docs/reference/core_concepts/batch_size.md
@@ -34,7 +34,7 @@ You can set `per_device_batch_size` and `gradient_accumulation_steps` in `config

 `global_batch_to_load` = `global_batch_size_to_train_on x expansion_factor_real_data`

-When `expansion_factor_real_data > 1`, only a subset of hosts read data from the source (e.g., a GCS bucket). These "loading hosts" read more data than they need for their own devices and distribute the surplus to other "non-loading" hosts. This reduces the number of concurrent connections to the data source, which can significantly improve I/O throughput. When set to between 0 and 1, it's for grain pipeline to use a smaller chip count to read checkpoint from a larger chip count job. Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md#using-grain.
+When `expansion_factor_real_data > 1`, only a subset of hosts read data from the source (e.g., a GCS bucket). These "loading hosts" read more data than they need for their own devices and distribute the surplus to other "non-loading" hosts. This reduces the number of concurrent connections to the data source, which can significantly improve I/O throughput. When set to a value between 0 and 1, the Grain pipeline can use a smaller chip count to read a checkpoint written by a job with a larger chip count. Details in https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/data_input_pipeline/data_input_grain.html#using-grain.

 ## Gradient Accumulation Steps
diff --git a/docs/reference/models/supported_models_and_architectures.md b/docs/reference/models/supported_models_and_architectures.md
index f329c9ba2b..e1dfa1957a 100644
--- a/docs/reference/models/supported_models_and_architectures.md
+++ b/docs/reference/models/supported_models_and_architectures.md
@@ -10,7 +10,7 @@ MaxText is an open-source, high-performance LLM framework written in Python/JAX.

 - **Supported Precisions**: FP32, BF16, INT8, and FP8.
 - **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection.
-- **Quantization**: Via **Qwix** (recommended) and AQT. See Quantization [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/quantization.md).
+- **Quantization**: Via **Qwix** (recommended) and AQT. See the Quantization [Guide](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/reference/core_concepts/quantization.html).
 - **Diagnostics**: Structured error context via **`cloud_tpu_diagnostics`** (filters stack traces to user code), simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**.
 - **Multi-Token Prediction (MTP)**: Enables token efficient training with mutli-token prediction.
 - **Elastic Training**: Fault-tolorent and dynamic scale-up/scale-down on Cloud TPUs with Pathways.
diff --git a/docs/reference/models/tiering.md b/docs/reference/models/tiering.md
index b757c07e77..ba83b30271 100644
--- a/docs/reference/models/tiering.md
+++ b/docs/reference/models/tiering.md
@@ -1,6 +1,6 @@
 # Optimized models tiering

-For each of the TPU platforms listed below, we present a list of optimized models[^1] [^2] for pre-training. If you’re getting started with MaxText, or want to push performance, we recommend choosing a Gold model, with an accompanying pre-training recipe.
+For each of the TPU platforms listed below, we present a list of optimized models[1] [2] for pre-training. If you’re getting started with MaxText, or want to push performance, we recommend choosing a Gold model, with an accompanying pre-training recipe.

 - **Gold Tier**: Fully Optimized Models certified to run with maximum efficiency on Cloud TPUs. They are thoroughly refined for the highest possible performance, making them ideal for production-critical workloads requiring peak throughput.

@@ -38,5 +38,6 @@ For each of the TPU platforms listed below, we present a list of optimized model
 | :----------- | :-------------------------------------------------------------------------------------------------- | :------------------------------ | :----- | :----------------------- |
 | Mixtral 8X7B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/v5p/Mixtral-8X7B-Maxtext)  | 256 Chips(8x4x4), bf16, SL=4096 | 52.56% | 2,909 |

-\[^1\]: Performance results are subject to variations based on system configuration, software versions, and other factors. These benchmarks represent point-in-time measurements under specific conditions.
-\[^2\]: Some older TFLOPS/s results are impacted by an updated calculation for causal attention ([PR #1988](https://github.com/AI-Hypercomputer/maxtext/pull/1988)), which halves the attention FLOPs. This change particularly affects configurations with large sequence lengths. For more details, please refer to the [performance metrics guide](https://maxtext.readthedocs.io/en/latest/reference/performance_metrics.html).
+\[1\]: Performance results are subject to variations based on system configuration, software versions, and other factors. These benchmarks represent point-in-time measurements under specific conditions.
+
+\[2\]: Some older TFLOPS/s results are impacted by an updated calculation for causal attention ([PR #1988](https://github.com/AI-Hypercomputer/maxtext/pull/1988)), which halves the attention FLOPs. This change particularly affects configurations with large sequence lengths. For more details, please refer to the [performance metrics guide](https://maxtext.readthedocs.io/en/latest/reference/performance_metrics.html).
diff --git a/docs/release_notes.md b/docs/release_notes.md
index 7192da4d77..80f675f1c5 100644
--- a/docs/release_notes.md
+++ b/docs/release_notes.md
@@ -32,7 +32,7 @@ MaxText is [available in PyPI](https://pypi.org/project/maxtext/) and can be ins

 ### v0.2.0

-# Changes
+#### Changes

 - New `tpu-post-train` target in PyPI. Please also use this installation option for running vllm_decode. See the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) for more info.
 - [Qwen3-Next](https://github.com/AI-Hypercomputer/maxtext/blob/7656eb8d1c9eb0dd91e617a6fdf6ad805221221a/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) is now supported.
@@ -50,7 +50,7 @@ MaxText is [available in PyPI](https://pypi.org/project/maxtext/) and can be ins
 - Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage.
 - The GPT-OSS family of models (20B, 120B) is now supported.

-# Deprecations
+#### Deprecations

 - Many MaxText modules have changed locations. Core commands like train, decode, sft, etc. will still work as expected temporarily. Please update your commands to the latest file locations
 - install_maxtext_github_deps installation script replaced with install_maxtext_tpu_github_deps
diff --git a/docs/tutorials/inference.md b/docs/tutorials/inference.md
index 6b7bf432d1..adc94cee64 100644
--- a/docs/tutorials/inference.md
+++ b/docs/tutorials/inference.md
@@ -130,7 +130,7 @@ curl http://localhost:8000/v1/completions \
 > **_NOTE:_**
 > You will need a HuggingFace token to run this command in addition to a MaxText model checkpoint. Please see the following [guide](https://huggingface.co/docs/hub/en/security-tokens) to generate one.

-To use a MaxText model architecture for samplers in reinforcement learning algorithms like GRPO, we can override the vLLM model architecture and pass in MaxText specific config arguments similar to the [online inference](online-inference) use-case. An example of an RL command using the MaxText model for samplers can be found below:
+To use a MaxText model architecture for samplers in reinforcement learning algorithms like GRPO, we can override the vLLM model architecture and pass in MaxText-specific config arguments similar to the [online inference](https://maxtext.readthedocs.io/en/latest/tutorials/inference.html#online-inference) use case. An example of an RL command using the MaxText model for samplers can be found below:

 ```bash
 python3 -m src.maxtext.trainers.post_train.rl.train_rl \
diff --git a/docs/tutorials/post_training_index.md b/docs/tutorials/post_training_index.md
index 46a68f830b..0ed355e16b 100644
--- a/docs/tutorials/post_training_index.md
+++ b/docs/tutorials/post_training_index.md
@@ -10,7 +10,7 @@ We’re investing in performance, scale, algorithms, models, reliability, and ea

 MaxText was co-designed with key Google led innovations to provide a unified post training experience:

-- [MaxText model library](https://maxtext.readthedocs.io/en/latest/index.html#model-library) for JAX LLMs highly optimized for TPUs
+- [MaxText model library](https://maxtext.readthedocs.io/en/latest/reference/models/supported_models_and_architectures.html#supported-model-families) for JAX LLMs highly optimized for TPUs
 - [Tunix](https://github.com/google/tunix) for the latest algorithms and post-training techniques
 - [vLLM on TPU](https://github.com/vllm-project/tpu-inference) for high performance sampling (inference) for Reinforcement Learning (RL)
 - [Pathways](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro) for multi-host inference (sampling) and highly efficient weight transfer
@@ -38,7 +38,7 @@ Here is an example of the steps you might go through to run a Reinforcement Lear

 ## What is Pathways and why is it key for RL?

-Pathways is a single controller JAX runtime that was [designed and pressure tested internally at Google DeepMind](https://blog.google/technology/ai/introducing-pathways-next-generation-ai-architecture/) over many years. Now available on Google Cloud, it is designed to coordinate distributed computations across thousands of accelerators from a single Python program. It efficiently performs data transfers between accelerators both within a slice using ICI (Inter-chip Interconnect) and across slices over DCN (Data Center Network).
+Pathways is a single controller JAX runtime that was [designed and pressure tested internally at Google DeepMind](https://blog.google/innovation-and-ai/products/introducing-pathways-next-generation-ai-architecture/) over many years. Now available on Google Cloud, it is designed to coordinate distributed computations across thousands of accelerators from a single Python program. It efficiently performs data transfers between accelerators both within a slice using ICI (Inter-chip Interconnect) and across slices over DCN (Data Center Network).

 Pathways allows for fine grained resource allocation (subslice of a physical slice) and scheduling. This allows JAX developers to explore novel model architectures in an easy to develop single controller programming environment.
diff --git a/docs/tutorials/posttraining/knowledge_distillation.md b/docs/tutorials/posttraining/knowledge_distillation.md
index 8a068c66c7..79b21d3e09 100644
--- a/docs/tutorials/posttraining/knowledge_distillation.md
+++ b/docs/tutorials/posttraining/knowledge_distillation.md
@@ -104,7 +104,7 @@ mkdir -p ${BASE_DIRECTORY?}

 ### Obtain and prepare the teacher model

-For the teacher model, we will use **vLLM** to run inference. vLLM can load Hugging Face checkpoints directly, so **no conversion to MaxText format is needed** for the teacher. Ensure the teacher model is supported on TPU vLLM (refer to the [vLLM TPU recommended models](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/#text-only-models) for the latest list).
+For the teacher model, we will use **vLLM** to run inference. vLLM can load Hugging Face checkpoints directly, so **no conversion to MaxText format is needed** for the teacher. Ensure the teacher model is supported on TPU vLLM (refer to the [vLLM TPU recommended models](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features) for the latest list).

 You can simply download the model from Hugging Face to your local directory:
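
As a local sanity check of the linkcheck setup added in `docs/conf.py` and the workflow above, the same `sphinx-build` commands can be run by hand. This is a minimal sketch, not part of the patch: it assumes the repository root as the working directory and the docs requirements already installed; `output.txt` is where Sphinx's linkcheck builder writes its report.

```bash
# Check external links; the command exits non-zero if broken links are found
sphinx-build -b linkcheck docs docs/_build/linkcheck -q --keep-going

# Inspect the linkcheck report
cat docs/_build/linkcheck/output.txt

# Build the HTML site, as the workflow does after linkcheck passes
sphinx-build -b html docs docs/_build/html
```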