diff --git a/.coveragerc b/.coveragerc index aceea9571f..7e4acd5b39 100644 --- a/.coveragerc +++ b/.coveragerc @@ -10,7 +10,7 @@ omit = [paths] source = src/MaxText - src/MaxText + src/maxtext */site-packages/MaxText */site-packages/maxtext diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 39313ef66c..1b601dbe72 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,7 +42,7 @@ repos: # args: # - '--jobs=auto' # - '--keep-going' - # - 'src/MaxText/' + # - 'src/maxtext/' - repo: https://github.com/google/pyink rev: 24.10.1 diff --git a/README.md b/README.md index 938152e7e5..7489881d8a 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies * \[October 10, 2025\] Post-Training (SFT, RL) via [Tunix](https://github.com/google/tunix) is now available. * \[September 26, 2025\] Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage. * \[September 24, 2025\] The GPT-OSS family of models (20B, 120B) is now supported. -* \[September 15, 2025\] MaxText is now available as a [PyPI package](https://pypi.org/project/maxtext). Users can now [install maxtext through pip](https://maxtext.readthedocs.io/en/latest/guides/install_maxtext.html). +* \[September 15, 2025\] MaxText is now available as a [PyPI package](https://pypi.org/project/maxtext). Users can now [install maxtext through pip](https://maxtext.readthedocs.io/en/latest/install_maxtext.html). * \[September 5, 2025\] MaxText has moved to an `src` layout as part of [RESTRUCTURE.md](https://github.com/AI-Hypercomputer/maxtext/blob/aca5b24931ebcbadb55a82e56ebffe8024874028/RESTRUCTURE.md). For existing environments, please run `pip install -e .` from MaxText root. * \[August 13, 2025\] The Qwen3 2507 MoE family of models is now supported: MoEs: 235B Thinking & 480B Coder as well as existing dense models: 0.6B, 4B, 8B, 14B, and 32B. * \[July 27, 2025\] Updated TFLOPS/s calculation ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/1988)) to account for causal attention, dividing the attention flops in half. Accounted for sliding window and chunked attention reduced attention flops in [PR](https://github.com/AI-Hypercomputer/maxtext/pull/2009) and [PR](https://github.com/AI-Hypercomputer/maxtext/pull/2030). Changes impact large sequence configs, as explained in this [doc](https://maxtext.readthedocs.io/en/latest/reference/performance_metrics.html) diff --git a/benchmarks/convergence/c4_exp.py b/benchmarks/convergence/c4_exp.py index c349a935e3..daa2241893 100644 --- a/benchmarks/convergence/c4_exp.py +++ b/benchmarks/convergence/c4_exp.py @@ -23,7 +23,6 @@ from benchmarks.benchmark_utils import MaxTextModel, _add_to_model_dictionary from benchmarks.convergence.convergence_utils import DatasetHParams, ConvHParams, _setup_model_convergence_ - from benchmarks.maxtext_v5p_model_configs import deepseek_v3_ep_256_v5p_512 c4_pretrain_model_dict = {} diff --git a/benchmarks/disruption_management/monitor.py b/benchmarks/disruption_management/monitor.py index 31960a7d19..5952f5f935 100644 --- a/benchmarks/disruption_management/monitor.py +++ b/benchmarks/disruption_management/monitor.py @@ -29,7 +29,6 @@ import time from benchmarks.disruption_management.disruption_utils import wait_for_pod_to_start - from benchmarks.disruption_management.disruption_handler import DisruptionConfig from benchmarks.disruption_management.disruption_handler import TriggerType diff --git a/benchmarks/llama2_v6e-256_benchmarks.py b/benchmarks/llama2_v6e-256_benchmarks.py index 2b56e0e3e6..4cdebf0059 100644 --- a/benchmarks/llama2_v6e-256_benchmarks.py +++ b/benchmarks/llama2_v6e-256_benchmarks.py @@ -17,13 +17,12 @@ on a specific v6e-256 hardware setup using the XPK runner. """ -import maxtext_trillium_model_configs as model_configs +import os -from maxtext_xpk_runner import BenchmarkRunner -from maxtext_xpk_runner import HWConfig -from maxtext_xpk_runner import SWconfig -from maxtext_xpk_runner import xpk_benchmark_runner -from maxtext_xpk_runner import XpkConfig +from benchmarks import maxtext_trillium_model_configs as model_configs +from benchmarks.maxtext_xpk_runner import WorkloadConfig +from benchmarks.maxtext_xpk_runner import xpk_benchmark_runner +from benchmarks.maxtext_xpk_runner import XpkClusterConfig DATE = "20241009" @@ -35,34 +34,37 @@ DEVICE_TYPE = "v6e-256" NUM_SLICES = 1 BASE_OUTPUT_DIR = "gs://maxtext-experiments-tpem/" - -v6e_env_configs = SWconfig(base_docker_image=BASE_DOCKER_IMAGE, libtpu_version=DATE) -v6e_256_configs = HWConfig(num_slices=NUM_SLICES, device_type=DEVICE_TYPE) - -llama2_70b_4096 = BenchmarkRunner( - model_name=model_configs.llama2_70b_4096, - software_config=v6e_env_configs, - hardware_config=v6e_256_configs, -) - -llama2_7b_4096 = BenchmarkRunner( - model_name=model_configs.llama2_7b_4096, - software_config=v6e_env_configs, - hardware_config=v6e_256_configs, -) +XPK_PATH = os.path.join("~", "xpk") +BENCHMARK_STEPS = 20 def main() -> None: - cluster_config = XpkConfig( + cluster_config = XpkClusterConfig( cluster_name=CLUSTER_NAME, project=PROJECT, zone=ZONE, - num_slices=NUM_SLICES, device_type=DEVICE_TYPE, - base_output_directory=BASE_OUTPUT_DIR, ) - xpk_benchmark_runner(cluster_config, [llama2_7b_4096, llama2_70b_4096]) + workload_configs = [] + for model in [model_configs.llama2_7b_4096, model_configs.llama2_70b_4096]: + workload_configs.append( + WorkloadConfig( + model=model, + num_slices=NUM_SLICES, + device_type=DEVICE_TYPE, + base_output_directory=BASE_OUTPUT_DIR, + base_docker_image=BASE_DOCKER_IMAGE, + libtpu_type=None, + libtpu_nightly_version=DATE, + pathways_config=None, + xpk_path=XPK_PATH, + num_steps=BENCHMARK_STEPS, + priority="medium", + ) + ) + + xpk_benchmark_runner(cluster_config, workload_configs) if __name__ == "__main__": diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 2a7ebc8b66..f0510533c7 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -35,9 +35,9 @@ import omegaconf import benchmarks.maxtext_trillium_model_configs as model_configs +import benchmarks.xla_flags_library as xla_flags from benchmarks.globals import MAXTEXT_PKG_DIR from benchmarks.command_utils import run_command_with_updates -import benchmarks.xla_flags_library as xla_flags from benchmarks.disruption_management.disruption_handler import DisruptionConfig from benchmarks.disruption_management.disruption_manager import DisruptionManager from benchmarks.xpk_configs import XpkClusterConfig diff --git a/benchmarks/recipes/mcjax_long_running_recipe.py b/benchmarks/recipes/mcjax_long_running_recipe.py index 17222723d0..9fcd24b5bc 100644 --- a/benchmarks/recipes/mcjax_long_running_recipe.py +++ b/benchmarks/recipes/mcjax_long_running_recipe.py @@ -27,7 +27,7 @@ import benchmarks.maxtext_trillium_model_configs as model_configs import benchmarks.maxtext_xpk_runner as mxr from benchmarks.xpk_configs import XpkClusterConfig -from . import user_configs +from benchmarks.recipes import user_configs # Cluster Params CLUSTER = "v6e-256-cluster" diff --git a/benchmarks/recipes/pw_elastic_training_recipe.py b/benchmarks/recipes/pw_elastic_training_recipe.py index 7ab1f3b7e6..3a3f68aba0 100644 --- a/benchmarks/recipes/pw_elastic_training_recipe.py +++ b/benchmarks/recipes/pw_elastic_training_recipe.py @@ -25,11 +25,11 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -from . import args_helper as helper -from . import user_configs from benchmarks.disruption_management.disruption_handler import DisruptionMethod -from .runner_utils import generate_and_run_workloads +from benchmarks.recipes import args_helper as helper +from benchmarks.recipes import user_configs +from benchmarks.recipes.runner_utils import generate_and_run_workloads user_configs.USER_CONFIG.max_restarts = 10 COMPARE_WITH_MCJAX = True diff --git a/benchmarks/recipes/pw_headless_mode.py b/benchmarks/recipes/pw_headless_mode.py index eaac22782b..94064bf1fe 100644 --- a/benchmarks/recipes/pw_headless_mode.py +++ b/benchmarks/recipes/pw_headless_mode.py @@ -22,8 +22,8 @@ """ import benchmarks.recipes.args_helper as helper -from .. import maxtext_xpk_runner as mxr -from ..recipes.user_configs import USER_CONFIG +from benchmarks import maxtext_xpk_runner as mxr +from benchmarks.recipes.user_configs import USER_CONFIG def main() -> int: diff --git a/benchmarks/recipes/pw_long_running_recipe.py b/benchmarks/recipes/pw_long_running_recipe.py index f1f2b52c81..bc1b3019ba 100644 --- a/benchmarks/recipes/pw_long_running_recipe.py +++ b/benchmarks/recipes/pw_long_running_recipe.py @@ -27,13 +27,10 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -import recipes.args_helper as helper - -import maxtext_trillium_model_configs as model_configs - -import maxtext_xpk_runner as mxr - -from xpk_configs import XpkClusterConfig +import benchmarks.maxtext_trillium_model_configs as model_configs +import benchmarks.maxtext_xpk_runner as mxr +import benchmarks.recipes.args_helper as helper +from benchmarks.xpk_configs import XpkClusterConfig PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server" SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server" @@ -66,7 +63,7 @@ def main(): ) # Handle command line arguments using args_helper - should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=XPK_PATH) + should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, USER, xpk_path=XPK_PATH) if not should_continue: return diff --git a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py index 575150bbaa..7deeb3cabf 100644 --- a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py +++ b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py @@ -18,14 +18,14 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -from . import args_helper as helper -from .user_configs import UserConfig -from .user_configs import USER_CONFIG -from .runner_utils import generate_and_run_workloads -from . import parser_utils +from benchmarks.recipes import args_helper as helper +from benchmarks.recipes import parser_utils +from benchmarks.recipes.pw_utils import check_and_create_bucket +from benchmarks.recipes.runner_utils import generate_and_run_workloads +from benchmarks.recipes.user_configs import UserConfig +from benchmarks.recipes.user_configs import USER_CONFIG import argparse from google.cloud import storage -from .pw_utils import check_and_create_bucket def main(user_config) -> int: diff --git a/benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py b/benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py index 01ef72df63..2bb9760b74 100644 --- a/benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py +++ b/benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py @@ -22,10 +22,10 @@ import datetime import dataclasses import os -import args_helper as helper -from benchmarks import maxtext_trillium_model_configs as model_configs import benchmarks.maxtext_xpk_runner as mxr +from benchmarks import maxtext_trillium_model_configs as model_configs +from benchmarks.recipes import args_helper as helper from benchmarks.xpk_configs import XpkClusterConfig PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server" @@ -185,7 +185,7 @@ def main() -> int: ) # Handle command line arguments using args_helper - should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=XPK_PATH) + should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, os.environ["USER"], xpk_path=XPK_PATH) if not should_continue: return 0 diff --git a/benchmarks/recipes/pw_remote_python_recipe.py b/benchmarks/recipes/pw_remote_python_recipe.py index 62653bb3ab..f6ef732443 100644 --- a/benchmarks/recipes/pw_remote_python_recipe.py +++ b/benchmarks/recipes/pw_remote_python_recipe.py @@ -21,7 +21,7 @@ import os -import args_helper as helper +import benchmarks.recipes.args_helper as helper from benchmarks import maxtext_trillium_model_configs as model_configs from benchmarks import maxtext_xpk_runner as mxr @@ -40,7 +40,7 @@ def main(): xpk_path = "xpk" # Handle command line arguments using args_helper - should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=xpk_path) + should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, os.environ["USER"], xpk_path=xpk_path) if not should_continue: return diff --git a/benchmarks/recipes/pw_suspend_resume.py b/benchmarks/recipes/pw_suspend_resume.py index addc28c1e7..490d25c0d8 100644 --- a/benchmarks/recipes/pw_suspend_resume.py +++ b/benchmarks/recipes/pw_suspend_resume.py @@ -25,11 +25,10 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -from . import args_helper as helper -from . import user_configs - from benchmarks.disruption_management.disruption_handler import DisruptionMethod -from .runner_utils import generate_and_run_workloads +from benchmarks.recipes import args_helper as helper +from benchmarks.recipes import user_configs +from benchmarks.recipes.runner_utils import generate_and_run_workloads user_configs.USER_CONFIG.max_restarts = 3 DISRUPTION_METHOD = DisruptionMethod.SIGTERM diff --git a/benchmarks/recipes/pw_utils.py b/benchmarks/recipes/pw_utils.py index 4ec4bca9c3..3f044b9f1c 100644 --- a/benchmarks/recipes/pw_utils.py +++ b/benchmarks/recipes/pw_utils.py @@ -20,7 +20,7 @@ import typing -import maxtext_xpk_runner as mxr +import benchmarks.maxtext_xpk_runner as mxr from google.api_core.exceptions import ( NotFound, Conflict, diff --git a/benchmarks/recipes/runner_utils.py b/benchmarks/recipes/runner_utils.py index 43626b59f8..f1f45feda2 100644 --- a/benchmarks/recipes/runner_utils.py +++ b/benchmarks/recipes/runner_utils.py @@ -16,7 +16,7 @@ import logging -from .. import maxtext_xpk_runner as mxr +from benchmarks import maxtext_xpk_runner as mxr from benchmarks.benchmark_utils import Framework from benchmarks.disruption_management.disruption_manager import construct_disruption_configs diff --git a/benchmarks/recipes/user_configs.py b/benchmarks/recipes/user_configs.py index 5d283c8b47..0e2ad05e1e 100644 --- a/benchmarks/recipes/user_configs.py +++ b/benchmarks/recipes/user_configs.py @@ -27,10 +27,10 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -from .. import maxtext_trillium_model_configs as v6e_model_configs -from .. import maxtext_v5e_model_configs as v5e_model_configs -from .. import maxtext_v5p_model_configs as v5p_model_configs -from .pw_utils import build_user_models, get_cluster_config, get_pathways_config +from benchmarks import maxtext_trillium_model_configs as v6e_model_configs +from benchmarks import maxtext_v5e_model_configs as v5e_model_configs +from benchmarks import maxtext_v5p_model_configs as v5p_model_configs +from benchmarks.recipes.pw_utils import build_user_models, get_cluster_config, get_pathways_config AVAILABLE_MODELS_FRAMEWORKS = ["mcjax", "pathways"] diff --git a/codecov.yml b/codecov.yml index f5971c2a21..302d8bc243 100644 --- a/codecov.yml +++ b/codecov.yml @@ -27,7 +27,7 @@ codecov: token: 35742a22-fb1f-4839-97ff-b54da5588689 # By default file names in the coverage report will have their path in the file system, which in our -# runners would be /__w/maxtext/maxtext/src/MaxText/* but Codecov expects src/MaxText/* so we need to fix the path +# runners would be /__w/maxtext/maxtext/src/maxtext/* but Codecov expects src/maxtext/* so we need to fix the path fixes: # - ".*/maxtext/src/::src/" - "/github/workspace/::" @@ -35,13 +35,10 @@ ignore: - "src/maxtext/assets" - "src/maxtext/configs" - "src/maxtext/examples" - - "src/MaxText/experimental" + - "src/maxtext/experimental" - "src/maxtext/inference" - "src/maxtext/scratch_code" - - "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation - - "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft - - "src/MaxText/rl" # code moved to src/maxtext/trainers/post_train/rl - + - "src/MaxText" flags: # Updated ONLY by PRs (contains subset of tests, excluding scheduled_only). diff --git a/docs/guides.md b/docs/guides.md index 6c7e60bf06..bb50cefb32 100644 --- a/docs/guides.md +++ b/docs/guides.md @@ -18,45 +18,45 @@ Explore our how-to guides for optimizing, debugging, and managing your MaxText workloads. -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} ⚡ Optimization +:::\{grid-item-card} ⚡ Optimization :link: guides/optimization :link-type: doc Techniques for maximizing performance, including sharding strategies, Pallas kernels, and benchmarking. ::: -:::{grid-item-card} 💾 Data Pipelines +:::\{grid-item-card} 💾 Data Pipelines :link: guides/data_input_pipeline :link-type: doc Configure input pipelines using **Grain** (recommended for determinism), **HuggingFace**, or **TFDS**. ::: -:::{grid-item-card} 🔄 Checkpointing +:::\{grid-item-card} 🔄 Checkpointing :link: guides/checkpointing_solutions :link-type: doc Manage GCS checkpoints, handle preemption with emergency checkpointing, and configure multi-tier storage. ::: -:::{grid-item-card} 🔍 Monitoring & Debugging +:::\{grid-item-card} 🔍 Monitoring & Debugging :link: guides/monitoring_and_debugging :link-type: doc Tools for observability: goodput monitoring, hung job debugging, and Vertex AI TensorBoard integration. ::: -:::{grid-item-card} 🐍 Python Notebooks +:::\{grid-item-card} 🐍 Python Notebooks :link: guides/run_python_notebook :link-type: doc Interactive development guides for running MaxText on Google Colab or local JupyterLab environments. ::: -:::{grid-item-card} 🌱 Model Bringup +:::\{grid-item-card} 🌱 Model Bringup :link: guides/model_bringup :link-type: doc @@ -65,9 +65,10 @@ A step-by-step guide for the community to help expand MaxText's model library. :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- guides/optimization.md guides/data_input_pipeline.md guides/checkpointing_solutions.md diff --git a/docs/guides/checkpointing_solutions.md b/docs/guides/checkpointing_solutions.md index ee92b1dcab..f902e3f515 100644 --- a/docs/guides/checkpointing_solutions.md +++ b/docs/guides/checkpointing_solutions.md @@ -2,31 +2,31 @@ # Checkpointing -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 💾 GCS Checkpointing +:::\{grid-item-card} 💾 GCS Checkpointing :link: checkpointing_solutions/gcs_checkpointing :link-type: doc Standard checkpointing to Google Cloud Storage. ::: -:::{grid-item-card} 🚑 Emergency Checkpointing +:::\{grid-item-card} 🚑 Emergency Checkpointing :link: checkpointing_solutions/emergency_checkpointing :link-type: doc Handle preemption and recover training progress. ::: -:::{grid-item-card} 🗄️ Multi-tier checkpointing +:::\{grid-item-card} 🗄️ Multi-tier checkpointing :link: checkpointing_solutions/multi_tier_checkpointing :link-type: doc Optimize storage costs and performance with multi-tier usage. ::: -:::{grid-item-card} 🔁 Checkpoint conversion utilities +:::\{grid-item-card} 🔁 Checkpoint conversion utilities :link: checkpointing_solutions/convert_checkpoint :link-type: doc diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 30e9750bdf..eba5fc7261 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -1,6 +1,6 @@ -# Checkpoint conversion utilities +# Checkpoint Conversion Utilities -This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats. +This guide provides instructions to use [checkpoint conversion scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion) to convert model checkpoints bidirectionally between Hugging Face and MaxText formats. ## Supported models @@ -21,58 +21,39 @@ The following models are supported: ## Prerequisites -- Hugging Face requires Pytorch. -- Hugging Face model checkpoints require local disk space. - - The model files are always downloaded to a disk cache first before being loaded into memory (for more info, please consult Hugging Face [docs](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference)). The default local storage path for Hugging Face models is `$HOME/.cache/huggingface/hub` +- MaxText must be installed in a Python virtual environment. For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/install_maxtext.html). +- Hugging Face model checkpoints are cached locally at `$HOME/.cache/huggingface/hub` before conversion. Ensure you have sufficient disk space. +- Authenticate via the [Hugging Face CLI](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) if using private or gated models. ## Hugging Face to MaxText -Use the `to_maxtext.py` script to convert a Hugging Face model into a MaxText checkpoint. The script will automatically download the specified model from the Hugging Face Hub, perform conversion, and save converted checkpoints to given output directory. +Use the `to_maxtext.py` script to convert a Hugging Face model checkpoint into a MaxText checkpoint. The script will automatically download the specified model from the Hugging Face Hub, perform conversion, and save converted checkpoints to the given output directory. -\*\**For a complete example, see the test script at [`tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh) and [`tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh).* +> **Note:** For more information, checkout [qwen3-4b example script](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh) and [gemma3-4b example script](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh). -### Usage - -First, make sure python3 virtual environment for MaxText is set up and enabled. - -```bash -export VENV_NAME= # e.g., maxtext_venv -pip install uv -uv venv --python 3.12 --seed ${VENV_NAME?} -source ${VENV_NAME?}/bin/activate -``` - -Second, ensure you have the necessary dependencies installed (e.g., install PyTorch for checkpoint conversion and logit check). +### Setup Environment ```bash +# Install PyTorch (in MaxText virtual environment) python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -``` - -Third, setup following environment variables for conversion script - -```bash -# -- Model configuration -- -export MODEL_NAME= # e.g. 'llama3.1-8b-Instruct' -export HF_TOKEN= # your token to access gated HF repos -# -- MaxText configuration -- -export MODEL_CHECKPOINT_DIRECTORY= # e.g., gs://my-bucket/my-checkpoint-directory -# -- storage and format options -export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX. - -export LAZY_LOAD_TENSORS= # True to use lazy load, False to use eager load. +# Setup environment variables +export MODEL= # e.g. 'llama3.1-8b-Instruct' +export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-checkpoint-directory +export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX +export LAZY_LOAD_TENSORS= # Set to True to save RAM ``` -Finally, run below command to complete the conversion +### Run Conversion ```bash -# Optional: If run out of disk space when downloading HuggingFace safetensors, +# Optional: If you run out of disk space when downloading Hugging Face safetensors, # customize your "HF_HOME" to redirect the cache to a larger or mounted disk (e.g., on a TPU VM). # export HF_HOME="/dev/shm/huggingface_tmp" + python3 -m maxtext.checkpoint_conversion.to_maxtext \ - model_name=${MODEL_NAME?} \ - hf_access_token=${HF_TOKEN?} \ - base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \ + model_name=${MODEL?} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ scan_layers=True \ use_multimodal=false \ hardware=cpu \ @@ -82,82 +63,108 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \ --lazy_load_tensors=${LAZY_LOAD_TENSORS?} ``` -- `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`. -- `scan_layers`: Indicates if the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false). -- `use_multimodal`: Indicates if multimodality is used, important for Gemma3. -- `hf_access_token`: Your Hugging Face token. -- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`. -- `hardware=cpu`: run the conversion script on a CPU machine. -- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above. -- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes. -- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. +You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/items`. + +### Key Parameters -Above command will download the Hugging Face model to local machine if `hf_model_path` is unspecified, or reuse the checkpoint in `hf_model_path`. It will convert the checkpoint to the MaxText format and save it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`. +- `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7). +- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/reference/core_concepts/checkpoints.html) for more information. +- `use_multimodal`: Indicates if multimodality is used, important for Gemma3. +- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. +- `hardware=cpu`: The conversion script runs on a CPU machine. +- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, these should be False. +- `--lazy_load_tensors` (Optional): Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes. +- `--hf_model_path` (Optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. ## MaxText to Hugging Face Use the `to_huggingface.py` script to convert a MaxText checkpoint into the Hugging Face format. This is useful for sharing your models or integrating them with the Hugging Face ecosystem. -\*\**For a complete example, see the test script at [`tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh).* -### Usage +> **Note:** For more information, checkout [qwen3-4b example script](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh). + +### Setup Environment + +```bash +# Install PyTorch (in MaxText virtual environment) +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Setup environment variables +export MODEL= # e.g. 'qwen3-4b' +export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items +export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-checkpoint-directory +``` -The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub. +### Run Conversion + +The following command converts a MaxText checkpoint and saves it locally, to GCS (`gs://`), or uploads it directly to the Hugging Face Hub (`hf://`). ```bash python3 -m maxtext.checkpoint_conversion.to_huggingface \ - model_name= \ - load_parameters_path= \ - base_output_directory= \ + model_name=${MODEL?} \ + load_parameters_path=${MAXTEXT_CKPT_PATH?} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + hardware=cpu \ + skip_jax_distributed_system=true \ scan_layers=false \ use_multimodal=false \ - hf_access_token= \ weight_dtype=bfloat16 ``` -**Key arguments:** +### Key Parameters -- `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`). -- `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`). -- `scan_layers`: Indicates if the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false). -- `hf_access_token`: Your Hugging Face token. +- `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7). +- `load_parameters_path`: The path to the MaxText Orbax checkpoint. +- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/reference/core_concepts/checkpoints.html) for more information. - `use_multimodal`: Indicates if multimodality is used, important for Gemma3. -- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS), Hugging Face Hub or local. If not set, the default output directory is `Maxtext/tmp`. -- `weight_dtype`: dtype for MaxText weights. It affects the resulting HF weight dtype. Default value is `float32`. We recommend using `bfloat16` to save memory and speed up conversion. +- `hardware=cpu`: The conversion script runs on a CPU machine. +- `base_output_directory`: The path where the converted checkpoint will be stored; it can be Google Cloud Storage (GCS), Hugging Face Hub or local. +- `weight_dtype`: dtype for MaxText weights. It affects the resulting Hugging Face weight dtype. Default value is `float32`. We recommend using `bfloat16` to save memory and speed up conversion. ## Verifying conversion correctness -To ensure the conversion was successful, you can use the [`tests/utils/forward_pass_logit_checker.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/utils/forward_pass_logit_checker.py) script. It runs a forward pass on both the original and converted models and compares the output logits to verify conversion. It is used to verify the bidirectional conversion. +To ensure the conversion was successful, you can use the [test script](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/utils/forward_pass_logit_checker.py). It runs a forward pass on both the original and converted models and compares the output logits to verify conversion. It is used to verify the bidirectional conversion. -### Usage +> **Note:** This correctness test will only work when MaxText is installed from source by following the installation instructions [here](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/install_maxtext.html#from-source). + +### Setup Environment + +```bash +# Setup environment variables +export MODEL= # e.g. 'qwen3-4b' +export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items +export HF_CKPT_PATH= # e.g., gs://my-bucket/my-checkpoint-directory +``` + +### Run Correctness Test ```bash python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \ - tokenizer_path= \ - load_parameters_path= \ - model_name= \ + load_parameters_path=${MAXTEXT_CKPT_PATH?} \ + model_name=${MODEL?} \ + skip_jax_distributed_system=true \ scan_layers=false \ max_prefill_predict_length=4 \ max_target_length=8 \ use_multimodal=false \ --run_hf_model=True \ - --hf_model_path= \ + --hf_model_path=${HF_CKPT_PATH?} \ --max_kl_div=0.015 ``` -**Key arguments:** +### Key Parameters -- `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`). +- `load_parameters_path`: The path to the MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`). - `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`). -- `scan_layers`: Indicates if the output checkpoint is scanned (scan_layers=true) or unscanned (scan_layers=false). +- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/reference/core_concepts/checkpoints.html) for more information. - `use_multimodal`: Indicates if multimodality is used. -- `--run_hf_model` (optional): Indicates if loading Hugging Face model from the hf_model_path. If not set, it will compare the maxtext logits with pre-saved golden logits. -- `--hf_model_path` (optional): The path to the Hugging Face checkpoint (if `--run_hf_model=True`) -- `--golden_logits_path` (optional): The pre-saved golden logits. (if `--run_hf_model` is not set) +- `--run_hf_model` (Optional): Indicates if loading Hugging Face model from the hf_model_path. If not set, it will compare the maxtext logits with pre-saved golden logits. +- `--hf_model_path` (Optional): The path to the Hugging Face checkpoint (if `--run_hf_model=True`). +- `--golden_logits_path` (Optional): The pre-saved golden logits. (if `--run_hf_model` is not set). - `--max_kl_div`: Max KL divergence tolerance during comparisons. -**Example successful conversion verification:** +### Example of Successful Conversion Verification -Here is part of the output of forward_pass_logit_checker for the gemma2-2b. +Here is part of the output of `forward_pass_logit_checker` for the gemma2-2b. ``` --- Prompt: What is the --- @@ -207,33 +214,30 @@ Max KL divergence for a single token in the set: 0.003497 ______________________________________________________________________ -## Adding support for new models +## Troubleshooting and Development + +### Adding New Models To extend conversion support to a new model architecture, you must define its specific parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files. 1. **Add parameter mappings**: -- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer. -- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer. - -2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion. -3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`. -4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture. +- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer. -Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983) +- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer. -## Debugging tips +2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion. -If the converted checkpoint can not get loaded and got error like: "type \ is not a valid JAX type." +3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`. -- **Potential Cause**: The scan_layers flag is set wrong. +4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). This configuration must precisely match the MaxText model's architecture. -If a converted checkpoint loads without errors but produces incorrect output, consider these common issues: +Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983). -- **Symptom**: The model generates garbage or nonsensical tokens. +### Common Errors - - **Potential Cause**: The query/key/value (Q/K/V) or Out vectors weights were likely reshaped incorrectly during conversion. +- "Type ShapeDtypeStruct is not a valid JAX type": Usually caused by a mismatch in the `scan_layers` flag. -- **Symptom**: The model generates repetitive text sequences. +- If the converted checkpoint loads without errors but produces nonsensical output, likely an error in the Q/K/V weight reshaping logic during conversion. - - **Potential Cause**: The layer normalization parameters may have been converted incorrectly. +- If the model generates repetitive text sequences, check if layer normalization parameters were mapped correctly. diff --git a/docs/guides/data_input_pipeline.md b/docs/guides/data_input_pipeline.md index 8772db513e..0b65bdfa6c 100644 --- a/docs/guides/data_input_pipeline.md +++ b/docs/guides/data_input_pipeline.md @@ -42,7 +42,7 @@ In MaxText, this is best supported by the ArrayRecord format using the Grain inp - **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file. -- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met. +- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/maxtext/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met. ```{note} When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_batch_train`/`generate_padding_batch_eval` flag still solves this. diff --git a/docs/guides/data_input_pipeline/data_pipeline_perf.md b/docs/guides/data_input_pipeline/data_pipeline_perf.md index 2a787346fd..69cca60f36 100644 --- a/docs/guides/data_input_pipeline/data_pipeline_perf.md +++ b/docs/guides/data_input_pipeline/data_pipeline_perf.md @@ -33,8 +33,9 @@ Status of TPU and CPU during training stages. ## Prerequisite: asynchronous execution For this overlap to happen, the data pipeline (on the CPU) and the model computation (on the accelerator) must execute in parallel. You can verify this using a profiler. -* **Good (parallel)**: The trace on the right shows the CPU (bottom tracks) is busy fetching data at the same time the TPU (top track) is computing. -* **Bad (sequential)**: The trace on the left shows a **gap** in TPU utilization, where the TPU is idle. This gap is often caused by **forcing synchronization** (as explained in the following section), not necessarily a slow pipeline. While speeding up data loading might narrow this gap, only removing the synchronization eliminates the gap and achieves true parallelism. + +- **Good (parallel)**: The trace on the right shows the CPU (bottom tracks) is busy fetching data at the same time the TPU (top track) is computing. +- **Bad (sequential)**: The trace on the left shows a **gap** in TPU utilization, where the TPU is idle. This gap is often caused by **forcing synchronization** (as explained in the following section), not necessarily a slow pipeline. While speeding up data loading might narrow this gap, only removing the synchronization eliminates the gap and achieves true parallelism. ```{figure} ../../_static/data_profile.png Example profiles of sequential (left) vs. parallel (right) data loading with TPU computation. @@ -43,17 +44,20 @@ Example profiles of sequential (left) vs. parallel (right) data loading with TPU ## Common pitfall: forcing synchronization JAX's asynchronous dispatch allows the CPU to run ahead. However, this parallelism breaks if your host code (Python) tries to access the result of a computation before it's finished. -* **Example**: Calling `print(loss)` or `.block_until_ready()` on a JAX array from the current step forces the host to wait for the accelerator, stalling the data pipeline. -* **MaxText solution**: MaxText avoids this by using a metrics cache. It only prints the loss from the previous step, allowing the current step's computation and the next step's data loading to proceed in parallel (see [buffer_and_write_train_metrics()](https://github.com/AI-Hypercomputer/maxtext/blob/1c6f5a26dc155262d2ebdd68223397107dfd4b95/src/MaxText/metric_logger.py#L193) in `metric_logger.py`). + +- **Example**: Calling `print(loss)` or `.block_until_ready()` on a JAX array from the current step forces the host to wait for the accelerator, stalling the data pipeline. +- **MaxText solution**: MaxText avoids this by using a metrics cache. It only prints the loss from the previous step, allowing the current step's computation and the next step's data loading to proceed in parallel (see [buffer_and_write_train_metrics()](https://github.com/AI-Hypercomputer/maxtext/blob/1c6f5a26dc155262d2ebdd68223397107dfd4b95/src/MaxText/metric_logger.py#L193) in `metric_logger.py`). ## How to test your pipeline You can check if your data pipeline meets the performance goal in two ways: + 1. **Check the profile**: Look for gaps in the accelerator trace (like the "Bad" example above). If there are no gaps, your data loading is likely fast enough. 2. **Run in isolation**: You can benchmark training and dataloading separately with the following steps: run your training workload with synthetic data (`dataset_type=synthetic`) to get a target_step_time time; use a script (like `standalone_dataloader.py`) to time how long it takes to load data batches without training. If your data_loading_time is consistently less than your target_step_time, your data pipeline is not the bottleneck. However, if your step time with _real data_ is still slower than your target_step_time, it strongly suggests a forced synchronization issue. ## How to speed up a slow data pipeline If your profile confirms that data loading is parallel but still slower than computation, then data loading is the bottleneck. Here are a few ways to speed it up: + 1. **Tune Grain**: If you are using the [Grain data pipeline](data_input_grain.md), start by tuning the `grain_worker_count`. If adjusting the worker count isn't enough, use the [Grain performance and debugging tool](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_debugging_tutorial.html) to find the specific bottleneck. 2. **Pre-process offline**: Perform as much data preparation as possible offline. Apply only light-weight preprocessing during training. diff --git a/docs/guides/model_bringup.md b/docs/guides/model_bringup.md index 1ba7e7181a..5aadb3aee1 100644 --- a/docs/guides/model_bringup.md +++ b/docs/guides/model_bringup.md @@ -26,13 +26,13 @@ The first phase involves determining how the new model's architecture aligns wit **Tokenizer**: Supported [tokenizer options](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/input_pipeline/tokenizer.py) include `TikTokenTokenizer`, `SentencePieceTokenizer`, and `HFTokenizer`. -**Self-Attention & RoPE**: Available mechanisms include optimized [Flash Attention](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/MaxText/layers/attention_op.py#L1184) (supporting MHA, GQA, and MQA), Multi-head Latent Attention ([MLA](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/attention_mla.py)), and [Gated Delta Network](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/MaxText/models/qwen3.py#L358). MaxText also supports [Regular](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L108), [Llama](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L178), and [YaRN](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L282) variations of Rotary Positional Embeddings (RoPE). +**Self-Attention & RoPE**: Available mechanisms include optimized [Flash Attention](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/maxtext/layers/attention_op.py#L1184) (supporting MHA, GQA, and MQA), Multi-head Latent Attention ([MLA](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/attention_mla.py)), and [Gated Delta Network](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/maxtext/models/qwen3.py#L358). MaxText also supports [Regular](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L108), [Llama](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L178), and [YaRN](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L282) variations of Rotary Positional Embeddings (RoPE). **Multi-Layer Perceptron (MLP)**: The framework supports both traditional dense models and Mixture of Experts (MoE) architectures, including [configurations](https://maxtext.readthedocs.io/en/latest/reference/core_concepts/moe_configuration.html) for routed and shared experts. -**Normalization**: We support different [normalization strategies](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/normalizations.py), including RMSNorm and Gated RMSNorm. These can be configured before or after attention/MLP layers. +**Normalization**: We support different [normalization strategies](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/normalizations.py), including RMSNorm and Gated RMSNorm. These can be configured before or after attention/MLP layers. -**Decoder Layers**: Models can have multiple [decoder layers](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/models) with varying structures. The trend has evolved from entirely dense layers to purely MoE layers, and now towards a mix of both. +**Decoder Layers**: Models can have multiple [decoder layers](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/models) with varying structures. The trend has evolved from entirely dense layers to purely MoE layers, and now towards a mix of both. ## 2. (Optional) Feature Implementation @@ -58,7 +58,7 @@ Success starts with a clear map. You must align the parameter names from your so ### 3.2 Write Script -Use existing model scripts within the repository as templates to tailor the conversion logic for your specific architecture. We strongly recommended to use the [checkpoint conversion utility](https://maxtext.readthedocs.io/en/latest/guides/checkpointing_solutions/convert_checkpoint.html) rather than [standalone scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion/standalone_scripts). +Use existing model scripts within the repository as templates to tailor the conversion logic for your specific architecture. We strongly recommended to use the [checkpoint conversion utility](https://maxtext.readthedocs.io/en/latest/guides/checkpointing_solutions/convert_checkpoint.html) rather than [standalone scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion/standalone_scripts). ### 3.3 Verify Compatibility diff --git a/docs/guides/monitoring_and_debugging.md b/docs/guides/monitoring_and_debugging.md index 8b1e00985e..a1d89c568a 100644 --- a/docs/guides/monitoring_and_debugging.md +++ b/docs/guides/monitoring_and_debugging.md @@ -14,55 +14,54 @@ limitations under the License. --> - # Monitoring and debugging -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 🕵️ Features & Diagnostics +:::\{grid-item-card} 🕵️ Features & Diagnostics :link: monitoring_and_debugging/features_and_diagnostics :link-type: doc Diagnostic tools and features for monitoring MaxText. ::: -:::{grid-item-card} ☁️ GCP Observability +:::\{grid-item-card} ☁️ GCP Observability :link: monitoring_and_debugging/gcp_workload_observability :link-type: doc Observability for workloads running on Google Cloud Platform. ::: -:::{grid-item-card} 🚫 Hang Playbook +:::\{grid-item-card} 🚫 Hang Playbook :link: monitoring_and_debugging/megascale_hang_playbook :link-type: doc Troubleshooting guide for training hangs at megascale. ::: -:::{grid-item-card} 📈 Goodput +:::\{grid-item-card} 📈 Goodput :link: monitoring_and_debugging/monitor_goodput :link-type: doc Monitoring efficient training time (Goodput). ::: -:::{grid-item-card} 📊 Logs & Metrics +:::\{grid-item-card} 📊 Logs & Metrics :link: monitoring_and_debugging/understand_logs_and_metrics :link-type: doc Understanding MaxText logs and performance metrics. ::: -:::{grid-item-card} 📉 TensorBoard +:::\{grid-item-card} 📉 TensorBoard :link: monitoring_and_debugging/use_vertex_ai_tensorboard :link-type: doc Using Vertex AI TensorBoard for visualization. ::: -:::{grid-item-card} ⏱️ XProf +:::\{grid-item-card} ⏱️ XProf :link: monitoring_and_debugging/xprof_user_guide :link-type: doc @@ -71,9 +70,10 @@ Profiling performance with XProf. :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- monitoring_and_debugging/features_and_diagnostics.md monitoring_and_debugging/gcp_workload_observability.md monitoring_and_debugging/megascale_hang_playbook.md diff --git a/docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md b/docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md index e381e7e8ac..5653ca05b6 100644 --- a/docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md +++ b/docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md @@ -197,7 +197,7 @@ The **model FLOPs** are the floating point operations to perform model computati - The number of model FLOPs is dependent on model architecture, input size (batch size, sequence length), and gradient accumulation steps. It does not include optimization operations. - We break down the FLOPs into two parts: - "Learnable weight FLOPs" are matmuls between activations and learnable weights. Specifically, this occurs in embedding, feed forward networks, attention-related projections, and unembedding. - - "Attention FLOPs" are matmuls in attention score computation like $\mathrm{softmax}{\left(\frac{QK^\top}{\sqrt{d}}\right)} V$. + - "Attention FLOPs" are matmuls in attention score computation like $\\mathrm{softmax}{\\left(\\frac{QK^\\top}{\\sqrt{d}}\\right)} V$. One **TFLOP** (TeraFLOP) is equal to $10^{12}$ FLOPs. The log shows the theoretical estimate of **model TFLOP per device**: @@ -207,7 +207,7 @@ Per train step: split as 94.54% learnable weight flops and 5.46% attention flops ``` -In this example, given `model=deepseek2-16b`, `per_device_batch_size=24`, `max_target_length=2048`, and no gradient accumulation, we have $\text{model tflop per device} \approx 764.67$. +In this example, given `model=deepseek2-16b`, `per_device_batch_size=24`, `max_target_length=2048`, and no gradient accumulation, we have $\\text{model tflop per device} \\approx 764.67$. - 94.54% of the TFLOPs are attributed to learnable weight and 5.46% are attributed to attention. - As you will see next, this number is important for calculating performance metrics, such as TFLOP/s/device and Model FLOPs Utilization (MFU). @@ -233,8 +233,8 @@ completed step: 9, seconds: 5.667, TFLOP/s/device: 134.924, Tokens/s/device: 867 Before we dive deep here, recall a few numbers from previous sections: -- $\text{max target length} = 2048$, $\text{per device batch size} = 24$ -- $\text{model tflop per device} \approx 764.67$ (rounded), $\text{number of devices} = 4$ +- $\\text{max target length} = 2048$, $\\text{per device batch size} = 24$ +- $\\text{model tflop per device} \\approx 764.67$ (rounded), $\\text{number of devices} = 4$ ### 4.1. Performance metrics @@ -244,38 +244,38 @@ The performance metrics fluctuate at the beginning, and become stable towards th completed step: 9, seconds: 5.667, TFLOP/s/device: 134.924, Tokens/s/device: 8672.758, total_weights: 196608, loss: 10.374 ``` -As shown in `seconds: 5.667`, $\text{measured step time in seconds} \approx 5.667$ (rounded). +As shown in `seconds: 5.667`, $\\text{measured step time in seconds} \\approx 5.667$ (rounded). **TFLOP per second per device** - It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/metric_logger.py#L211-L213) as -$$\text{tflop/s/device} = \frac{\text{model tflop per device}}{\text{measured step time in seconds}}$$ +$$\\text{tflop/s/device} = \\frac{\\text{model tflop per device}}{\\text{measured step time in seconds}}$$ - Here we have `TFLOP/s/device: 134.924`. Let's try to verify manually: $764.67 / 5.667 = 134.934$. Not exactly the same but close, since the both tflop and time are rounded in the log. - Further, we can calculate **Model FLOPs Utilization (MFU)** from this: -$$\text{MFU} = \frac{\text{tflop/s/device}}{\text{peak hardware tflop/s}}$$ +$$\\text{MFU} = \\frac{\\text{tflop/s/device}}{\\text{peak hardware tflop/s}}$$ -For TPU v5p, $\text{peak hardware tflop/s}=459$. Thus, $134.924 / 459 = 29.40$%. Note this is an example for explanation with small batch size and sequence length, so the MFU is not optimal. +For TPU v5p, $\\text{peak hardware tflop/s}=459$. Thus, $134.924 / 459 = 29.40$%. Note this is an example for explanation with small batch size and sequence length, so the MFU is not optimal. **Tokens per second per device (throughput)** - It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/metric_logger.py#L215-L217) as -$$\text{token/s/device} = \frac{\text{number of tokens per device}}{\text{measured step time in seconds}}$$ +$$\\text{token/s/device} = \\frac{\\text{number of tokens per device}}{\\text{measured step time in seconds}}$$ - The numerator is from [calculate_tokens_training_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/maxtext_utils.py#L148) -$$\text{number of tokens per device} = \text{per device batch size} \times \text{max target length}$$ +$$\\text{number of tokens per device} = \\text{per device batch size} \\times \\text{max target length}$$ -- Here we have `Tokens/s/device: 8672.758`. Let's try to verify manually: $24 \times 2048 / 5.667 = 8673.372$. Not exactly the same but close, since the time is rounded in the log. +- Here we have `Tokens/s/device: 8672.758`. Let's try to verify manually: $24 \\times 2048 / 5.667 = 8673.372$. Not exactly the same but close, since the time is rounded in the log. ### 4.2. Learning metrics **Loss**. The loss is the key indicator of learning progress, which should decrease over training steps. In this example, the loss is `12.038` at Step 0 and decreases to `10.374` at Step 9. Ideally, we want the loss to converge to a small value with sufficiently large training steps. -**Total weights**. When discussing the throughput, we have $\text{number of tokens} = \text{per device batch size} \times \text{max target length} \times \text{number of device}$. In this example, $\text{number of tokens} = 24 \times 2048 \times 4 = 196608$. There are two types of tokens: real tokens and pad tokens. The pad tokens are placeholders introduced by data preprocessing: We truncate or pad each sentence to max target length. Only real tokens contribute to the learning signal (i.e., loss). Therefore, we monitor $\text{number of real tokens}$, which is shown as [total weights](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L151). +**Total weights**. When discussing the throughput, we have $\\text{number of tokens} = \\text{per device batch size} \\times \\text{max target length} \\times \\text{number of device}$. In this example, $\\text{number of tokens} = 24 \\times 2048 \\times 4 = 196608$. There are two types of tokens: real tokens and pad tokens. The pad tokens are placeholders introduced by data preprocessing: We truncate or pad each sentence to max target length. Only real tokens contribute to the learning signal (i.e., loss). Therefore, we monitor $\\text{number of real tokens}$, which is shown as [total weights](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L151). - Here we see `total_weights: 196608` for all steps. This is because we are using `dataset_type=synthetic`, where all sentences are generated with a length of `max_target_length=2048`. As a result, there are no pad tokens and total weights = number of tokens. - However, in real datasets, sentences can have variable lengths and total weights < number of tokens. For example, we can set `dataset_type=tfds dataset_path=gs://maxtext-dataset dataset_name='c4/en:3.0.1'`, and will see total weights smaller than `196608`: diff --git a/docs/guides/monitoring_and_debugging/use_vertex_ai_tensorboard.md b/docs/guides/monitoring_and_debugging/use_vertex_ai_tensorboard.md index 00daa984af..5bb0a439af 100644 --- a/docs/guides/monitoring_and_debugging/use_vertex_ai_tensorboard.md +++ b/docs/guides/monitoring_and_debugging/use_vertex_ai_tensorboard.md @@ -15,28 +15,34 @@ --> (vertex-ai-tensorboard)= + # Use Vertex AI Tensorboard MaxText supports automatic upload of logs collected in a directory to a Tensorboard instance in Vertex AI. For more information on how MaxText supports this feature, visit [cloud-accelerator-diagnostics](https://pypi.org/project/cloud-accelerator-diagnostics) PyPI package documentation. ## What is Vertex AI Tensorboard and Vertex AI Experiment -Vertex AI Tensorboard is a fully managed and enterprise-ready version of open-source Tensorboard. To learn more about Vertex AI Tensorboard, visit [this](https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-introduction). Vertex AI Experiment is a tool that helps to track and analyze an experiment run on Vertex AI Tensorboard. To learn more about Vertex AI Experiments, visit [this](https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments). + +Vertex AI Tensorboard is a fully managed and enterprise-ready version of open-source Tensorboard. To learn more about Vertex AI Tensorboard, visit [this](https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-introduction). Vertex AI Experiment is a tool that helps to track and analyze an experiment run on Vertex AI Tensorboard. To learn more about Vertex AI Experiments, visit [this](https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments). You can use a single Vertex AI Tensorboard instance to track and compare metrics from multiple Vertex AI Experiments. While you can view metrics from multiple Vertex AI Experiments within a single Tensorboard instance, the underlying log data for each experiment remains separate. ## Prerequisites -* Enable [Vertex AI API](https://cloud.google.com/vertex-ai/docs/start/cloud-environment#enable_vertexai_apis) in your Google Cloud console. -* Assign [Vertex AI User IAM role](https://cloud.google.com/vertex-ai/docs/general/access-control#aiplatform.user) to the service account used by the TPU VMs. This is required to create and access the Vertex AI Tensorboard in Google Cloud console. If you are using XPK for MaxText, the necessary Vertex AI User IAM role will be automatically assigned to your node pools by XPK – no need to assign it manually. + +- Enable [Vertex AI API](https://cloud.google.com/vertex-ai/docs/start/cloud-environment#enable_vertexai_apis) in your Google Cloud console. +- Assign [Vertex AI User IAM role](https://cloud.google.com/vertex-ai/docs/general/access-control#aiplatform.user) to the service account used by the TPU VMs. This is required to create and access the Vertex AI Tensorboard in Google Cloud console. If you are using XPK for MaxText, the necessary Vertex AI User IAM role will be automatically assigned to your node pools by XPK – no need to assign it manually. ## Upload logs to Vertex AI Tensorboard + **Scenario 1: Using XPK to run MaxText on GKE** XPK simplifies MaxText's Vertex AI Tensorboard integration. A Vertex Tensorboard instance and Experiment are automatically created by XPK during workload scheduling. Also, XPK automatically sets the necessary environment variables, eliminating the need to manually configure this in MaxText. Set `use_vertex_tensorboard=False` to avoid setting up Vertex Tensorboard again in MaxText. This is how the configuration will look like for running MaxText via XPK: + ``` use_vertex_tensorboard: False vertex_tensorboard_project: "" vertex_tensorboard_region: "" ``` + The above configuration will upload logs in `config.tensorboard_dir` to Vertex Tensorboard instance set as an environment variable by XPK. **Scenario 2: Running MaxText on GCE** @@ -51,11 +57,13 @@ use_vertex_tensorboard: True vertex_tensorboard_project: "test-project" # or vertex_tensorboard_project: "" vertex_tensorboard_region: "us-central1" ``` + The above configuration will try to create a Vertex AI Tensorboard instance named `test-project-tb-instance` and a Vertex AI Experiment named `test-run` in the `us-central1` region of `test-project`. If you set `vertex_tensorboard_project=""`, then the default project (`gcloud config get project`) set on the VM will be used to create the Vertex AI resources. It will only create these resources if they do not already exist. Also, the logs in `config.tensorboard_dir` will be uploaded to `test-project-tb-instance` Tensorboard instance and `test-run` Experiment in Vertex AI. **Scenario 2.2: Configuration to not upload logs to Vertex AI Tensorboard** The following configuration will not upload any log data collected in `config.tensorboard_dir` to Tensorboard in Vertex AI. + ``` use_vertex_tensorboard: False vertex_tensorboard_project: "" diff --git a/docs/guides/monitoring_and_debugging/xprof_user_guide.md b/docs/guides/monitoring_and_debugging/xprof_user_guide.md index 8facb43aa3..7f0d13ac7d 100644 --- a/docs/guides/monitoring_and_debugging/xprof_user_guide.md +++ b/docs/guides/monitoring_and_debugging/xprof_user_guide.md @@ -1,25 +1,19 @@ # Profiling with XProf - - ## Introduction to XProf [XProf](https://openxla.org/xprof) is a profiling and performance analysis tool for machine learning. You can use XProf to profile and analyze the training performance of AI models. XProf helps you understand how to optimize model performance, identify bottlenecks, and improve training efficiency. - -## Profiling in JAX +## Profiling in JAX XProf supports profiling JAX models, which is crucial for MaxText developers working with JAX. You can profile your JAX models using various methods, including: - -* **Programmatic Mode:** This provides more granular control over when and what to profile, allowing you to instrument your code with specific profiling markers. This method is integrated with MaxText code. - +- **Programmatic Mode:** This provides more granular control over when and what to profile, allowing you to instrument your code with specific profiling markers. This method is integrated with MaxText code. The following example shows how to trace a JAX operation in Python. - ``` import jax import jax.numpy as jnp @@ -35,24 +29,19 @@ with jax.profiler.StepTraceAnnotation("dot_product", step_num=iter): jax.profiler.stop() ``` - You can use [`jax.profiler.TraceAnnotation`](https://docs.jax.dev/en/latest/_autosummary/jax.profiler.TraceAnnotation.html) to add custom annotations to JAX traces. -* **Sampling Mode** This mode allows for continuous profiling by sampling data during model execution. -This mode has not yet been enabled in MaxText yet. Refer to [remote-profiling](https://docs.jax.dev/en/latest/profiling.html#remote-profiling) for manual capture/sampling. +- **Sampling Mode** This mode allows for continuous profiling by sampling data during model execution. + This mode has not yet been enabled in MaxText yet. Refer to [remote-profiling](https://docs.jax.dev/en/latest/profiling.html#remote-profiling) for manual capture/sampling. ## Profiling configuration in MaxText The following parameters control how profiling is executed within MaxText, allowing you to capture detailed performance data for analysis. -* `profiler` specifies the profiler backend to use for capturing performance traces. Options can be `xplane`, `nsys`. Default is "". `xplane` is for XLA/TPU and `nsys` is for CUDA/GPU. +- `profiler` specifies the profiler backend to use for capturing performance traces. Options can be `xplane`, `nsys`. Default is "". `xplane` is for XLA/TPU and `nsys` is for CUDA/GPU. -* `profiler_steps` defines the total number of steps to run during the profiling capture window. Default is 5 +- `profiler_steps` defines the total number of steps to run during the profiling capture window. Default is 5 -* `skip_first_n_steps_for_profiler` specifies the number of initial training steps to skip before the profiling capture begins. This is typically used to bypass model warmup and capture steady-state performance. default is 1. +- `skip_first_n_steps_for_profiler` specifies the number of initial training steps to skip before the profiling capture begins. This is typically used to bypass model warmup and capture steady-state performance. default is 1. For more information about XProf tools, see the [XProf documentation](https://openxla.org/xprof). - - - - diff --git a/docs/guides/optimization.md b/docs/guides/optimization.md index 308448298a..b24edbdbca 100644 --- a/docs/guides/optimization.md +++ b/docs/guides/optimization.md @@ -18,31 +18,31 @@ Explore techniques for maximizing performance, including model customization, sharding strategies, Pallas kernels, and benchmarking. -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 🛠️ Customizing Model Configs +:::\{grid-item-card} 🛠️ Customizing Model Configs :link: optimization/custom_model :link-type: doc Optimize and customize your LLM model configurations for higher performance (MFU) on TPUs. ::: -:::{grid-item-card} 🥞 Sharding Strategies +:::\{grid-item-card} 🥞 Sharding Strategies :link: optimization/sharding :link-type: doc Choose efficient sharding strategies (FSDP, TP, EP, PP) using Roofline Analysis and understand arithmetic intensity. ::: -:::{grid-item-card} ⚡ Pallas Kernels +:::\{grid-item-card} ⚡ Pallas Kernels :link: optimization/pallas_kernels_performance :link-type: doc -Optimize with Pallas kernels for fine-grained control. +Optimize with Pallas kernels for fine-grained control. ::: -:::{grid-item-card} 📈 Benchmarking & Tuning +:::\{grid-item-card} 📈 Benchmarking & Tuning :link: optimization/benchmark_and_performance :link-type: doc @@ -51,9 +51,10 @@ Guide to setting up benchmarks, performing performance tuning, and analyzing met :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- optimization/custom_model.md optimization/sharding.md optimization/pallas_kernels_performance.md diff --git a/docs/guides/optimization/benchmark_and_performance.md b/docs/guides/optimization/benchmark_and_performance.md index f0d1b15433..858bcb9673 100644 --- a/docs/guides/optimization/benchmark_and_performance.md +++ b/docs/guides/optimization/benchmark_and_performance.md @@ -18,7 +18,7 @@ Begin your benchmarking efforts by performing an arithmetic intensity analysis. Arithmetic intensity is calculated as the ratio of floating-point operations (FLOPs) to memory(bytes) or communication(bytes). -* **Arithmetic Intensity = FLOPs / Bytes** +- **Arithmetic Intensity = FLOPs / Bytes** This metric helps determine whether a computation is MXU-bound (high arithmetic intensity) or memory-bound/communication-bound (low arithmetic intensity). @@ -28,8 +28,8 @@ This metric helps determine whether a computation is MXU-bound (high arithmetic For benchmarking purposes, we collect the step time for training. This step time is then used to calculate MFU and throughputs, which provide insights into the utilization achieved for each benchmark workload. -* **MFU = flops_train_step / step_time / peak HW FLOPS** -* **Throughput = global tokens / step_time / number of devices** +- **MFU = flops_train_step / step_time / peak HW FLOPS** +- **Throughput = global tokens / step_time / number of devices** More detailed are explained in [](performance-metrics). @@ -51,7 +51,7 @@ Remat policies can be chosen from: `minimal_with_context`, `minimal`, `save_dot_ These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) -`minimal_with_context` consumes the most HBM memory, while `full` signifies minimal checkpointing, with everything being rematerialized. [More explanation and latest support](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/decoders.py#L287) +`minimal_with_context` consumes the most HBM memory, while `full` signifies minimal checkpointing, with everything being rematerialized. [More explanation and latest support](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/decoders.py#L287) **Custom policy** @@ -98,19 +98,19 @@ There are two methods for asynchronous collective offloading: 1. Offload Collectives to Sparse Core: - This method is recommended for v7x. To enable it, set the following flags from [[link](https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/xla_flags_library.py#L70)]: + This method is recommended for v7x. To enable it, set the following flags from \[[link](https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/xla_flags_library.py#L70)\]: -* `ENABLE_SPARSECORE_OFFLOADING_FOR_RS_AG_AR` -* `ENABLE_SPARSECORE_OFFLOADING_FOR_REDUCE_SCATTER` -* `ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER` -* `ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE` +- `ENABLE_SPARSECORE_OFFLOADING_FOR_RS_AG_AR` +- `ENABLE_SPARSECORE_OFFLOADING_FOR_REDUCE_SCATTER` +- `ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER` +- `ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE` - 2. Overlap Collective Using Continuation Fusion:** +2. Overlap Collective Using Continuation Fusion:\*\* - This method is recommended for v5p and v6e. To enable it, set the following flags [[link](https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/xla_flags_library.py#L39)]: + This method is recommended for v5p and v6e. To enable it, set the following flags \[[link](https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/xla_flags_library.py#L39)\]: -* `CF_FOR_ALL_GATHER` -* `CF_FOR_ALL_REDUCE` +- `CF_FOR_ALL_GATHER` +- `CF_FOR_ALL_REDUCE` Those XLA can be set via `LIBTPU_INIT_ARGS` diff --git a/docs/guides/optimization/pallas_kernels_performance.md b/docs/guides/optimization/pallas_kernels_performance.md index b4884f6b17..007c2f103c 100644 --- a/docs/guides/optimization/pallas_kernels_performance.md +++ b/docs/guides/optimization/pallas_kernels_performance.md @@ -58,7 +58,7 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth - **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large [L,L] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation. - - [`src/MaxText/kernels/attention/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/attention/splash_attention_kernel.py) + - [`src/maxtext/kernels/attention/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/kernels/attention/splash_attention_kernel.py) - **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine. @@ -69,9 +69,9 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth > This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts. - - [`src/MaxText/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py) + - [`src/maxtext/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/kernels/megablox/gmm.py) - **Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/MaxText/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/moe.py)). + **Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/maxtext/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/moe.py)). ## 🔧 The Pallas optimization workflow: code → profile → tune → repeat @@ -83,12 +83,14 @@ Give the kernel a clear name in traces and capture a profile. Always use [`jax.b import jax from jax import profiler -def my_op(...): + +def my_op(x): # This name shows up in Perfetto/TensorBoard traces with jax.named_scope("my_custom_kernel"): - out = my_kernel_wrapper(...) + out = my_kernel_wrapper(x) return out + # Capture a Perfetto/TensorBoard trace with profiler.trace("/tmp/tb_profile"): y = my_op(x) diff --git a/docs/guides/optimization/sharding.md b/docs/guides/optimization/sharding.md index 6e8a69cd0a..44a8c41f68 100644 --- a/docs/guides/optimization/sharding.md +++ b/docs/guides/optimization/sharding.md @@ -26,14 +26,14 @@ When considering different sharding strategies, the main concern is the amount o We illustrate our sharding notation with an example matmul: -$$B_xE \times EM = B_xM$$ +$$B_xE \\times EM = B_xM$$ Where B, E and M are names of dimensions and a subscript denotes sharding. For example, $B_xE$ is a 2-dimensional matrix sharded along the $B$ dimension, using the $x$ mesh axis. Dimensions without a subscript are not sharded. -This example is of standard data parallelism, only the batch dimension is sharded. Note that $B$ refers to the batch dimension, $B_x$ to the local shard of this dimension, whereas we use $\left|B\right|$ and $\left|B_x\right|$ to refer to the lengths of single axes, and $\left|x\right|$ as the degree of sharding on the x axis, e.g. $\left|B_x\right| = \left|B\right|/\left|x\right|$. We drop this $\left|\cdot\right|$ notation when there is a product to reduce clutter, e.g. we use $BEM_x$ instead of $\left|B\right|\left|E\right|\left|M_x\right|$. +This example is of standard data parallelism, only the batch dimension is sharded. Note that $B$ refers to the batch dimension, $B_x$ to the local shard of this dimension, whereas we use $\\left|B\\right|$ and $\\left|B_x\\right|$ to refer to the lengths of single axes, and $\\left|x\\right|$ as the degree of sharding on the x axis, e.g. $\\left|B_x\\right| = \\left|B\\right|/\\left|x\\right|$. We drop this $\\left|\\cdot\\right|$ notation when there is a product to reduce clutter, e.g. we use $BEM_x$ instead of $\\left|B\\right|\\left|E\\right|\\left|M_x\\right|$. We illustrate this notation on model parallelism as well: -$BM_x \times M_xE = BE \rightarrow \text{Reduce-Scatter (RS) over x} \rightarrow BE_x$ +$BM_x \\times M_xE = BE \\rightarrow \\text{Reduce-Scatter (RS) over x} \\rightarrow BE_x$ Explanation: Both the activations ($BM$) and weights ($ME$) are sharded on the M dimension. Thus each device is able to perform the matmul locally with its shard of the $M_x$ dimension, the local result is of the right global shape ($BE$) but is only a partial result - it needs to be summed with the other shards to get the full result. This is achieved with a reduce scatter (which does the summation and additionally shards the activations). Note that some flavors of tensor parallelism call for an all reduce instead a reduce scatter, but generally in maxtext we use a reduce scatter here. @@ -49,11 +49,11 @@ Explanation: Both the activations ($BM$) and weights ($ME$) are sharded on the M Note for the feed forward computation the batch and sequence dimensions act the same and thus we use only one $B$ axis (which you can think of as a token batch dimension, a reshaping of batch and sequence into one axis), but for context and sequence parallelism they act differently and thus we use both a $B$ and $S$ dimension and the $B$ dimension is really batch in sequences. For example a matmul with an explicit sequence dimension might look like -$$BSE \times EM = BSM$$ +$$BSE \\times EM = BSM$$ But for arithmetic intensity roofline analysis purposes the $B$ and $S$ axis act as one, and generally we omit the $S$ axis except for when its needed (context/sequence parallelism), thus we only write -$$BE \times EM = BM$$ +$$BE \\times EM = BM$$ We recognize this overloads the definition of $B$ but for arithmetic intensity purposes the only batch size that matters is batch in tokens - which imagines combining the $B$ and $S$ axis into one. @@ -69,9 +69,9 @@ We will see why this is a useful definition by walking through an example. We want to be compute bound (because there is a fixed amount of compute to perform), which means we want the compute to take longer than the communication. Consider the above example (model parallelism aka tensor parallelism) -$$ BM_x \times M_xE = BE \text{ (partial result)} \rightarrow \text{RS over x} \rightarrow BE_x $$ +$$ BM_x \\times M_xE = BE \\text{ (partial result)} \\rightarrow \\text{RS over x} \\rightarrow BE_x $$ -The compute is $BM_x \times M_xE = BE$ matmul, which takes $2BM_xE$ flops (you can think of this as $\left|B\right| * \left|E\right|$ dot products each of length $\left|M_x\right|$, thus there are $BEM_x$ multiplications and additions to perform. +The compute is $BM_x \\times M_xE = BE$ matmul, which takes $2BM_xE$ flops (you can think of this as $\\left|B\\right| * \\left|E\\right|$ dot products each of length $\\left|M_x\\right|$, thus there are $BEM_x$ multiplications and additions to perform. **Compute time** = Flops / compute speed = $2BEM_x$ / compute speed @@ -95,23 +95,23 @@ Operation Arithmetic Intensity > Hardware Arithmetic Intensity The LHS (Compute Flops / Comm bytes) is the “Operation” or “Model” arithmetic intensity, whereas the RHS (Compute Speed / comm speed) is the hardware arithmetic intensity. This re-arrangement has a huge benefit in that it separates model from hardware - the operational intensity is independent of the hardware. Note however that arithmetic has this funky unity of flops/byte - intuitively you can think of this as the amount of flops unlocked by communicating a certain amount of bytes. -Operation Arithmetic Intensity for this example: $2BM_xE$ flops / $2BE$ bytes = $\left|M_x\right|$ flops/byte +Operation Arithmetic Intensity for this example: $2BM_xE$ flops / $2BE$ bytes = $\\left|M_x\\right|$ flops/byte Hardware Arithmetic Intensity: Compute speed / comm speed -Example hardware for trillium (See https://cloud.google.com/tpu/docs/v6e), compute speed = $917$ TFLOPs, and comm speed of 1 ICI axis is $180$ GB/s so the ratio $917 * 10^12 / 180 * 10^ 9 = 5100$. Thus we would need $\left|M_x\right| > 5100$ (Operational AI > Hardware AI) to be compute bound for this operation. This is an example of key insights that arithmetic intensity gives us - it tells us we need a large $\left|M\right|$ to achieve high utilization for model parallelism because the operational intensity is proportional to $\left|M\right|$. +Example hardware for trillium (See https://cloud.google.com/tpu/docs/v6e), compute speed = $917$ TFLOPs, and comm speed of 1 ICI axis is $180$ GB/s so the ratio $917 * 10^12 / 180 * 10^ 9 = 5100$. Thus we would need $\\left|M_x\\right| > 5100$ (Operational AI > Hardware AI) to be compute bound for this operation. This is an example of key insights that arithmetic intensity gives us - it tells us we need a large $\\left|M\\right|$ to achieve high utilization for model parallelism because the operational intensity is proportional to $\\left|M\\right|$. ## Arithmetic Intensity: Mixed sharding strategies -When we use multiple sharding strategies together it seems intractable to keep track of all of the compute vs communication ratios. However it turns out (not obvious at first), that the arithmetic intensity analysis of a “pure” sharding strategy generalizes to when it's used in a mix. For instance, if we added data parallelism to the above tensor parallelism example then the batch dimension $B$ would also be sharded by a new mesh axes $y$. Both the compute and communication would decrease by this sharding factor $\left|y\right|$, and thus the ratio of compute to comms for tensor parallelism would remain the same ($\left|M\right|\left|x\right|$, independent of $\left|y\right|$). Concretely this would look like +When we use multiple sharding strategies together it seems intractable to keep track of all of the compute vs communication ratios. However it turns out (not obvious at first), that the arithmetic intensity analysis of a “pure” sharding strategy generalizes to when it's used in a mix. For instance, if we added data parallelism to the above tensor parallelism example then the batch dimension $B$ would also be sharded by a new mesh axes $y$. Both the compute and communication would decrease by this sharding factor $\\left|y\\right|$, and thus the ratio of compute to comms for tensor parallelism would remain the same ($\\left|M\\right|\\left|x\\right|$, independent of $\\left|y\\right|$). Concretely this would look like -$$B_yM_x \times M_xE = B_yE \rightarrow \text{RS over x } \rightarrow B_yE_x $$ +$$B_yM_x \\times M_xE = B_yE \\rightarrow \\text{RS over x } \\rightarrow B_yE_x $$ **Compute:** = $2B_yM_xE$ Flops **TP comms (RS)** = $2B_yE$ bytes -**Ratio (Arithmetic Intensity)** = $\left|M_x\right|$ Flops/byte +**Ratio (Arithmetic Intensity)** = $\\left|M_x\\right|$ Flops/byte This "independence" of sharding strategies is true for the main four parallelisms (data, model (tensor), pipeline, and expert). Note that data, fsdp, context and sequence parallelism are all roughly the same for the purpose of arithmetic intensity analysis since they shard the batch, as we will illustrate in the individual sections below. In addition both data and pipeline parallelism (microbatches) shard the batch which decreases the HBM arithmetic intensity. @@ -120,15 +120,15 @@ arithmetic intensity analysis since they shard the batch, as we will illustrate Sharding in maxtext is split into 3 layers -- **Physical** mesh axes (e.g. `data`, `fsdp`, `tensor`) defined [here](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/maxtext/configs/base.yml#L269) +- **Physical** mesh axes (e.g. `data`, `fsdp`, `tensor`) defined [here](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/configs/base.yml#L269) - Mesh is created via [create_device_mesh](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/max_utils.py#L576-L580) - Mesh given names in train.py via [Mesh](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/train.py#L594) -- **Logical** axes which map a meaningful axes name to physical axes defined [here](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/maxtext/configs/base.yml#L270) +- **Logical** axes which map a meaningful axes name to physical axes defined [here](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/configs/base.yml#L270) -- E.g. logical axes `activation_batch` is sharded by the physical axes of `data` and `fsdp` (among others) since those sharding strategies shard the batch. `Activation_batch` is a common axis among most activation tensors. Note that if we use `data_parallelism=4` and `fsdp_parallelism=2`, then the `activation_batch` dimension will get sharded over both, e.g. $4*2=8$ ways. +- E.g. logical axes `activation_batch` is sharded by the physical axes of `data` and `fsdp` (among others) since those sharding strategies shard the batch. `Activation_batch` is a common axis among most activation tensors. Note that if we use `data_parallelism=4` and `fsdp_parallelism=2`, then the `activation_batch` dimension will get sharded over both, e.g. $4\*2=8$ ways. - **Individual tensors** have sharding constraints - generally specified by logical rules @@ -163,11 +163,11 @@ The simplest parallelization is data parallelization. Each chip works on a diffe Roughly approximate the entire backward pass: -**Compute**: $4 * \text{local batch} * \text{params}$ +**Compute**: $4 * \\text{local batch} * \\text{params}$ -We saw above that each matmul performs $2 * \text{local batch} * \text{params}$ flops, it turns out that the backward pass requires twice as many flops as the forward pass. We don't derive this here but highly recommend reading these [slides](https://www.cs.toronto.edu/~rgrosse/courses/csc321_2017/slides/lec6.pdf) from University of Toronto to explain the mathematics and implementation of backprop. +We saw above that each matmul performs $2 * \\text{local batch} * \\text{params}$ flops, it turns out that the backward pass requires twice as many flops as the forward pass. We don't derive this here but highly recommend reading these [slides](https://www.cs.toronto.edu/~rgrosse/courses/csc321_2017/slides/lec6.pdf) from University of Toronto to explain the mathematics and implementation of backprop. -**Communicate**: All reduce size of params (`bf16`) : $4 * \text{params}$ (`2*` since `bf16`, another `2*` since an optimal all reduce algorithm turns out to require two passes of communicating data (generally a reduce scatter followed by an all-gather)) +**Communicate**: All reduce size of params (`bf16`) : $4 * \\text{params}$ (`2*` since `bf16`, another `2*` since an optimal all reduce algorithm turns out to require two passes of communicating data (generally a reduce scatter followed by an all-gather)) **Ratio (arithmetic intensity)**: `local_batch` @@ -181,22 +181,22 @@ e.g. the original activations have grown by a factor of `expert_per_token` and a `batch_per_expert` = `batch` * (`expert_per_token`/`expert`) = `batch` / `sparsity` -We denote the local `batch_per_expert` with $\beta$ and analyze an MoE feedfoward matmul to calculate arithmetic intensity: +We denote the local `batch_per_expert` with $\\beta$ and analyze an MoE feedfoward matmul to calculate arithmetic intensity: -$$\beta EX \times EMX = \beta MX$$ +$$\\beta EX \\times EMX = \\beta MX$$ -**Compute:** $4\beta EMX$ Flops (2x in backward pass) +**Compute:** $4\\beta EMX$ Flops (2x in backward pass) **Comms:** All Reduce Gradient of size $EMX$: $4EMX$ bytes -**Ratio (arithmetic intensity):** $\left|\beta\right| = \text{local batch} / \text{sparsity}$ +**Ratio (arithmetic intensity):** $\\left|\\beta\\right| = \\text{local batch} / \\text{sparsity}$ ### DP Arithmetic Intensity (Hierarchical) For a hierarchal mesh (TPU: within slice ICI, across slice DCN, GPU: within NVL domain, across NVL Domains), only one set of gradients need to be communicated across the slower network per slice/NVL Domain (as opposed to one set per chip). This is generally achieved for us automatically by the XLA compiler: -Reduce Scatter grads on fast network $\rightarrow$ All Reduce across slow $\rightarrow$ All Gather on faster network +Reduce Scatter grads on fast network $\\rightarrow$ All Reduce across slow $\\rightarrow$ All Gather on faster network We can compute the arithmetic intensity of these cross slice/NVL Domain comms by imagining the chips forming a slice or NVL Domain as one "super chip". This "super chip" processes all of the tokens within its domain, but it only has to share one copy of the gradients to its super chip neighbors. @@ -207,11 +207,11 @@ If the local per device batch size is `local batch`, then we can imagine each "s We can then perform the same arithmetic intensity analysis as before, and indeed get the same result: -**Compute (per super chip):** $4 * \text{super batch} * \text{params}$ flops +**Compute (per super chip):** $4 * \\text{super batch} * \\text{params}$ flops -**Comms (per super chip):** All reduce params $\rightarrow 4 * \text{params}$ bytes +**Comms (per super chip):** All reduce params $\\rightarrow 4 * \\text{params}$ bytes -**Ratio (arithmetic intensity):** $\text{super batch } (\text{super batch} / \text{sparsity} \text{ for sparse models})$ +**Ratio (arithmetic intensity):** $\\text{super batch } (\\text{super batch} / \\text{sparsity} \\text{ for sparse models})$ This illustrates there are more than one way to calculate arithmetic intensity - we could also derive the same expression from the chip level as long as we are consistent for the compute and comms - either both the compute and comms should be at the super chip level, or both should be at the regular chip level. @@ -230,9 +230,9 @@ Approximate a typical weight @ activation = activation matmul: Start with activations sharded like $B_xE$ and weights sharded like $E_xM$ (it doesn't matter which axis of weights is sharded). We must first All Gather (AG) the weights -$$E_xM \rightarrow \text{AG } x \rightarrow EM$$ +$$E_xM \\rightarrow \\text{AG } x \\rightarrow EM$$ -**Compute**: $B_xE \times EM = B_xM$ +**Compute**: $B_xE \\times EM = B_xM$ This takes $2B_xEM$ flops @@ -300,7 +300,7 @@ Shard the activations along the feature dimensions (e.g. model or `embed` dimens Analyze one pattern of TP as given above -$$ BM_x \times M_xE = BE \text{ (local partial result) } \rightarrow \text{ Reduce-Scatter (RS) } x \rightarrow BE_x $$ +$$ BM_x \\times M_xE = BE \\text{ (local partial result) } \\rightarrow \\text{ Reduce-Scatter (RS) } x \\rightarrow BE_x $$ **Compute:** $2BM_xE$ Flops @@ -308,11 +308,11 @@ $$ BM_x \times M_xE = BE \text{ (local partial result) } \rightarrow \text{ Redu **Ratio (arithmetic intensity)** -$\left|M_x\right| = \left|M\right|/\left|TP\right|$ +$\\left|M_x\\right| = \\left|M\\right|/\\left|TP\\right|$ Note this is one pattern of TP where the contracting dimension is sharded. By contrast for the initial feed forward matmul the non-contracting weight dimension is sharded: -$$BE_x \times EM_x \rightarrow \text{AG activations over } x\rightarrow BE \times EM_x = BM_x$$ +$$BE_x \\times EM_x \\rightarrow \\text{AG activations over } x\\rightarrow BE \\times EM_x = BM_x$$ This is the same amount of compute, and also the same amount of communication - again activations of $BE$ are communicated, but in this case it is an initial all-gathering instead of secondary all-reduce. Ideally these activations (all-gather or reduce scatter) can be overlapped with the compute by the XLA compiler - an idea called a **collective matmul**. This is fairly challenging for the compiler since the comms and compute do depend on each other - to achieve overlap the computation and communication have to be chunked into smaller pieces and pipelined. @@ -334,13 +334,13 @@ Similar to tensor parallelism, but instead of sharding the feed forward weights This is really just swapping $E$ and $M$ of the TP analysis above, but we will include it here: -$$BE_x \times E_xM = BM_x$$ +$$BE_x \\times E_xM = BM_x$$ **Compute:** $2BE_xM$ FLOPS **Communicate:** Reduce scatter $BM$ (`bf16`): $2BM$ bytes -**Ratio (arithmetic intensity):** $\left|E_x\right|=\left|E\right|/\left|TP\right|$ +**Ratio (arithmetic intensity):** $\\left|E_x\\right|=\\left|E\\right|/\\left|TP\\right|$ ## Expert Parallelism (EP) @@ -358,19 +358,19 @@ An all-to-all (A2A) is needed to move between data sharding (fsdp) prior to the Analyze only 1 feed forward matmul -$$ BEX_x \times EMX_x = BMX_x $$ +$$ BEX_x \\times EMX_x = BMX_x $$ -$$ 2BEX_x \text{ Flops} $$ +$$ 2BEX_x \\text{ Flops} $$ **Communicate** -$$ B_xEX \rightarrow A2A \rightarrow BEX_x $$ +$$ B_xEX \\rightarrow A2A \\rightarrow BEX_x $$ Ideally this `A2A` only requires moving around $BEX_x$ elements per shard, but it depends on if the hardware is connected with an all to all network (true for `GPUs` and `TPU DCN` but not for `TPU ICI`) With a true all-to-all network this takes $2BEX_x$ bytes. Over TPU ICI, an all-to-all is instead as costly as `1/4` of all gathering the entire activation as nicely drawn [here](https://jax-ml.github.io/scaling-book/sharding/#our-final-communication-primitive-the-alltoall) in jax's sharding doc. -**Ratio (arithmetic intensity)**: $2BEMX_x / 2BEX_x = \left|M\right|$ +**Ratio (arithmetic intensity)**: $2BEMX_x / 2BEX_x = \\left|M\\right|$ Note: The batch $B$ cancels in above arithmetic intensity - the batch dimension is present in both the compute and communication since we are communicating activations so cancels from the arithmetic intensity ratio regardless of how it is shaped (e.g.`batch` or `batch_per_exp`) @@ -424,8 +424,8 @@ Note that for MoE models, this arithmetic intensity grows by a factor of `expert ## Context Autoregressive -Context Autoregressive shards the KV cache on the sequence dimension. It shards feed forward layer by experts for both activations and weights. This is used for inference only, see [inference.yml](https://github.com/AI-Hypercomputer/maxtext/blob/353a45d57eb1f1cc02e5c8d9e7b18eaf634d7edc/maxtext/configs/inference/inference.yml#L4) for the modified logical axis rules for inference. +Context Autoregressive shards the KV cache on the sequence dimension. It shards feed forward layer by experts for both activations and weights. This is used for inference only, see [inference.yml](https://github.com/AI-Hypercomputer/maxtext/blob/2052c22e3219b9f3a3fd66813bc6be793d79c963/src/maxtext/configs/inference/inference.yml#L4) for the modified logical axis rules for inference. ## Autoregressive -Autoregressive shards weights, but not activations. This is used for inference only. See [inference.yml](https://github.com/AI-Hypercomputer/maxtext/blob/353a45d57eb1f1cc02e5c8d9e7b18eaf634d7edc/maxtext/configs/inference/inference.yml#L4) for the modified logical axis rules for inference. +Autoregressive shards weights, but not activations. This is used for inference only. See [inference.yml](https://github.com/AI-Hypercomputer/maxtext/blob/2052c22e3219b9f3a3fd66813bc6be793d79c963/src/maxtext/configs/inference/inference.yml#L4) for the modified logical axis rules for inference. diff --git a/docs/reference.md b/docs/reference.md index 8ccd78b0ea..5e90ebd4fd 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -18,31 +18,31 @@ Deep dive into MaxText architecture, models, and core concepts. -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 📊 Performance Metrics +:::\{grid-item-card} 📊 Performance Metrics :link: reference/performance_metrics :link-type: doc Understanding Model Flops Utilization (MFU), calculation methods, and why it matters for performance optimization. ::: -:::{grid-item-card} 🤖 Models +:::\{grid-item-card} 🤖 Models :link: reference/models :link-type: doc Supported models and architectures, including Llama, Qwen, and Mixtral. Details on tiering and new additions. ::: -:::{grid-item-card} 🏗️ Architecture +:::\{grid-item-card} 🏗️ Architecture :link: reference/architecture :link-type: doc High-level overview of MaxText design, JAX/XLA choices, and how components interact. ::: -:::{grid-item-card} 💡 Core Concepts +:::\{grid-item-card} 💡 Core Concepts :link: reference/core_concepts :link-type: doc @@ -51,9 +51,10 @@ Key concepts including checkpointing strategies, quantization, tiling, and Mixtu :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- reference/performance_metrics reference/models reference/architecture diff --git a/docs/reference/architecture.md b/docs/reference/architecture.md index 0732478d96..a3004b8f4c 100644 --- a/docs/reference/architecture.md +++ b/docs/reference/architecture.md @@ -1,16 +1,16 @@ # Architecture -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 🗺️ Overview +:::\{grid-item-card} 🗺️ Overview :link: architecture/architecture_overview :link-type: doc High-level overview of MaxText design and components. ::: -:::{grid-item-card} 📚 JAX/AI Libraries +:::\{grid-item-card} 📚 JAX/AI Libraries :link: architecture/jax_ai_libraries_chosen :link-type: doc @@ -19,9 +19,10 @@ Deep dive into the JAX and AI libraries used in MaxText. :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- architecture/architecture_overview.md architecture/jax_ai_libraries_chosen.md ``` diff --git a/docs/reference/architecture/architecture_overview.md b/docs/reference/architecture/architecture_overview.md index 4d3f16f5a9..1b73145dcd 100644 --- a/docs/reference/architecture/architecture_overview.md +++ b/docs/reference/architecture/architecture_overview.md @@ -33,7 +33,7 @@ The control plane of MaxText provides a structured yet flexible interface for us ### `base.yml`: the central configuration hub -Every MaxText job is governed by the same base YAML configuration file ([`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/01c7137d4e13878e38baae44dc99e588eaa50a70/src/maxtext/configs/base.yml)) with model-specific details and overrides passed through a second config (e.g. [`src/maxtext/configs/models/deepseek3-671b.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/01c7137d4e13878e38baae44dc99e588eaa50a70/src/maxtext/configs/models/deepseek3-671b.yml)). Finally, experiment-specific settings are passed on the command line. The contents of these together comprise all the hyperparameters and settings that define a run: +Every MaxText job is governed by the same base YAML configuration file ([`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/01c7137d4e13878e38baae44dc99e588eaa50a70/src/MaxText/configs/base.yml)) with model-specific details and overrides passed through a second config (e.g. [`src/maxtext/configs/models/deepseek3-671b.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/01c7137d4e13878e38baae44dc99e588eaa50a70/src/MaxText/configs/models/deepseek3-671b.yml)). Finally, experiment-specific settings are passed on the command line. The contents of these together comprise all the hyperparameters and settings that define a run: - Model architecture: Defines the core transformer structure, with parameters like `model_name` (e.g., 'llama2-7b'), `global_parameter_scale` for size, `base_emb_dim`, `base_num_heads`, the type of attention mechanism, and `quantization` settings (e.g., 'int8'). - Training and optimization: Controls the training process with settings like `steps`, `learning_rate`, optimizer parameters such as `adam_b1`, and the `per_device_batch_size`. @@ -161,7 +161,7 @@ Performance can be further tuned by setting specific XLA flags in the configurat One of the most significant performance levers available in MaxText is the integration of Google's Accurate Quantized Training (AQT) and Qwix libraries. These enable training with reduced numerical precision, reducing memory requirements and often increasing FLOPS, while maintaining model quality and convergence characteristics that are very close to the full-precision baseline. Integration into MaxText is seamless for the user. Quantization can be enabled by simply setting, for example, `quantization: 'int8'` in the configuration file. This flag activates quantization-aware layers (defined in -[`src/MaxText/layers/quantizations.py`](https://github.com/AI-Hypercomputer/maxtext/blob/db7b85be153e6b7ca387a8d02c991f9d35bae6bd/src/MaxText/layers/quantizations.py)) that are applied to the relevant dense layers within the model's Flax definition. The quantization library handles the complexities of simulating quantization during the forward and backward passes, allowing the model to learn weights that are robust to the reduced precision. +[`src/maxtext/layers/quantizations.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/quantizations.py)) that are applied to the relevant dense layers within the model's Flax definition. The quantization library handles the complexities of simulating quantization during the forward and backward passes, allowing the model to learn weights that are robust to the reduced precision. ## The ecosystem: interoperability and advanced features diff --git a/docs/reference/architecture/jax_ai_libraries_chosen.md b/docs/reference/architecture/jax_ai_libraries_chosen.md index 1562017548..ea1b7ad8d9 100644 --- a/docs/reference/architecture/jax_ai_libraries_chosen.md +++ b/docs/reference/architecture/jax_ai_libraries_chosen.md @@ -2,12 +2,12 @@ MaxText is built on a curated stack of JAX libraries, each chosen for a specific purpose. This document provides an opinionated view on *why* MaxText uses the following key components of the JAX ecosystem: -* **Flax (NNX)**: For ergonomic and functional model definition. -* **Optax**: For composable optimization. -* **Orbax**: For robust checkpointing. -* **Grain**: For deterministic, multi-host data loading. -* **Qwix**: For native JAX quantization. -* **Tunix**: For modular fine-tuning. +- **Flax (NNX)**: For ergonomic and functional model definition. +- **Optax**: For composable optimization. +- **Orbax**: For robust checkpointing. +- **Grain**: For deterministic, multi-host data loading. +- **Qwix**: For native JAX quantization. +- **Tunix**: For modular fine-tuning. This stack isn't just a random collection of tools; it represents a design philosophy centered around **explicitness, composability, and performance at scale**. @@ -15,13 +15,13 @@ This document provides an opinionated view on *why* MaxText uses these specific ## Flax: For functional model definition -**What is it?** Flax is a high-performance neural network library for JAX that is designed to be flexible, explicit, and easy to use. +**What is it?** Flax is a high-performance neural network library for JAX that is designed to be flexible, explicit, and easy to use. With its latest generation API, NNX, Flax provides a modern, object-oriented (OOP) approach that makes defining and managing models more intuitive and Pythonic. -1. **Explicit State Management**: Unlike stateful frameworks where parameters are hidden attributes of an object, Flax treats model parameters (`params`, `batch_stats`, etc.) as explicit arguments to its functions. This transparency is crucial for debugging and managing distributed state. -2. **Deep JAX Integration**: Flax's NNX is designed from the ground up to work seamlessly with JAX's powerful transformations like `jax.jit` and `jax.grad`. This enables high performance and scalability without sacrificing ease of use. -3. **Flexibility through PyTrees**: All model state is stored in standard JAX PyTrees (nested dictionaries), making it trivial to inspect and manipulate any part of the model. +1. **Explicit State Management**: Unlike stateful frameworks where parameters are hidden attributes of an object, Flax treats model parameters (`params`, `batch_stats`, etc.) as explicit arguments to its functions. This transparency is crucial for debugging and managing distributed state. +2. **Deep JAX Integration**: Flax's NNX is designed from the ground up to work seamlessly with JAX's powerful transformations like `jax.jit` and `jax.grad`. This enables high performance and scalability without sacrificing ease of use. +3. **Flexibility through PyTrees**: All model state is stored in standard JAX PyTrees (nested dictionaries), making it trivial to inspect and manipulate any part of the model. For more information on using Flax, please refer to https://github.com/google/flax @@ -29,9 +29,9 @@ For more information on using Flax, please refer to https://github.com/google/fl **What is it?** Optax is a gradient processing and optimization library for JAX. It reimagines the optimizer as a series of composable functional transformations. -1. **Decoupling Optimization from Parameters**: Optax completely separates the optimizer's state from the model's parameters, treating the update step as a pure function. -2. **The Power of `optax.chain`**: The core design pattern in Optax is chaining gradient transformations. This makes it easy to build custom optimizers by combining building blocks like gradient clipping, weight decay, and a learning rate schedule. -3. **Rich Library of Pre-Built Optimizers**: While Optax is ideal for building custom optimizers, it also comes with a wide range of popular optimizers like `optax.adamw`, `optax.adam`, and `optax.sgd` ready to use out-of-the-box. This provides the flexibility to start with a standard optimizer and only customize when needed. +1. **Decoupling Optimization from Parameters**: Optax completely separates the optimizer's state from the model's parameters, treating the update step as a pure function. +2. **The Power of `optax.chain`**: The core design pattern in Optax is chaining gradient transformations. This makes it easy to build custom optimizers by combining building blocks like gradient clipping, weight decay, and a learning rate schedule. +3. **Rich Library of Pre-Built Optimizers**: While Optax is ideal for building custom optimizers, it also comes with a wide range of popular optimizers like `optax.adamw`, `optax.adam`, and `optax.sgd` ready to use out-of-the-box. This provides the flexibility to start with a standard optimizer and only customize when needed. For more information on using Optax, please refer to https://github.com/google-deepmind/optax @@ -43,10 +43,10 @@ For more information on using Optax, please refer to https://github.com/google-d For massive models, saving and loading state is a critical part of the training infrastructure. -1. **Asynchronous Checkpointing**: It writes large checkpoints to storage in the background without stalling the expensive TPU/GPU accelerators, maximizing hardware utilization. -2. **Checkpoint Management**: Orbax provides a `CheckpointManager` that handles the entire lifecycle of checkpoints, including versioning, keeping the N most recent saves, and ensuring atomic writes to prevent corruption. -3. **Handling Scanned and Unscanned Formats**: For performance, MaxText uses `jax.lax.scan` over its transformer layers. This results in an efficient "scanned" checkpoint format where layer parameters are stacked along a single array axis. For interoperability with other frameworks and inference, a more standard "unscanned" (layer-by-layer) format is often required. Orbax is used to reliably save, load, and convert between both formats, enabling both efficient training, inference, and easy model sharing. -4. **Facilitating Checkpoint Conversion**: When importing models from other ecosystems like Hugging Face, Orbax provides the final, critical step. Conversion scripts first load external weights (e.g., from `.safetensors` files) and map them to a JAX PyTree. Orbax is then used to save this PyTree as a native, MaxText-compatible checkpoint, providing a robust and standardized endpoint for the conversion pipeline. +1. **Asynchronous Checkpointing**: It writes large checkpoints to storage in the background without stalling the expensive TPU/GPU accelerators, maximizing hardware utilization. +2. **Checkpoint Management**: Orbax provides a `CheckpointManager` that handles the entire lifecycle of checkpoints, including versioning, keeping the N most recent saves, and ensuring atomic writes to prevent corruption. +3. **Handling Scanned and Unscanned Formats**: For performance, MaxText uses `jax.lax.scan` over its transformer layers. This results in an efficient "scanned" checkpoint format where layer parameters are stacked along a single array axis. For interoperability with other frameworks and inference, a more standard "unscanned" (layer-by-layer) format is often required. Orbax is used to reliably save, load, and convert between both formats, enabling both efficient training, inference, and easy model sharing. +4. **Facilitating Checkpoint Conversion**: When importing models from other ecosystems like Hugging Face, Orbax provides the final, critical step. Conversion scripts first load external weights (e.g., from `.safetensors` files) and map them to a JAX PyTree. Orbax is then used to save this PyTree as a native, MaxText-compatible checkpoint, providing a robust and standardized endpoint for the conversion pipeline. For more information on using Orbax, please refer to https://github.com/google/orbax @@ -54,9 +54,9 @@ For more information on using Orbax, please refer to https://github.com/google/o **What is it?** Grain is a high-performance data loading library designed for deterministic, global shuffle and multi-host data loading. -1. **Deterministic by Design**: Grain allows storing data loader states, provides strong guarantees about data ordering and sharding even with preemptions, which is critical for reproducibility. +1. **Deterministic by Design**: Grain allows storing data loader states, provides strong guarantees about data ordering and sharding even with preemptions, which is critical for reproducibility. 2. **Global Shuffle**: Prevents local overfitting. -3. **Built for Multi-Host Training**: The using random access file format streamlines [data loading in the multi-host environments](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline.md#multihost-dataloading-best-practice). +3. **Built for Multi-Host Training**: The using random access file format streamlines [data loading in the multi-host environments](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline.md#multihost-dataloading-best-practice). Its APIs are explicitly designed for the multi-host paradigm, simplifying the process of ensuring that each host loads a unique shard of the global batch. @@ -66,8 +66,8 @@ For more information on using Grain, please refer to https://github.com/google/g **What is it?** Qwix is a Jax quantization library supporting Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ) -1. **Enables State-of-the-Art Techniques**: It provides core quantization formats (e.g., int8 & fp8) and functions to `quantize` and `dequantize` tensors, which are essential for modern efficient training methods. -2. **JAX-Native Integration**: Its operations and data types are designed to work seamlessly with JAX's transformations (`jit`, `pmap`) and PyTree data structures. +1. **Enables State-of-the-Art Techniques**: It provides core quantization formats (e.g., int8 & fp8) and functions to `quantize` and `dequantize` tensors, which are essential for modern efficient training methods. +2. **JAX-Native Integration**: Its operations and data types are designed to work seamlessly with JAX's transformations (`jit`, `pmap`) and PyTree data structures. We chose Qwix because it provides the necessary primitives **natively within the JAX ecosystem**. Using a non-native library would require inefficient "boundary crossing" to move data in and out of JAX's control. Qwix's functions are just another JAX operation, allowing them to be composed and JIT-compiled along with the rest of the model. @@ -81,11 +81,11 @@ For more information on how to quantize your model using Qwix, please refer to h MaxText leverages Tunix as its core library for post-training, offering a unified and high-performance platform for adapting base models. -1. **Unified Post-Training Framework**: Tunix provides a consistent API and infrastructure for various post-training techniques, reducing the need for separate implementations for SFT, RL, and PEFT. -2. **State-of-the-Art RL Integration**: Tunix integrates with vLLM for efficient RL sampling, enabling advanced algorithms like Group Relative Policy Optimization (GRPO). This allows for fine-tuning models based on complex reward signals. -3. **NNX Compatibility**: Tunix is designed to work with NNX, the latest generation of Flax, allowing it to leverage the newest JAX features and a more modern API. -4. **Modularity for PEFT**: While offering full fine-tuning, Tunix also maintains strong support for PEFT methods. Techniques like LoRA are implemented as composable Flax/NNX modules, allowing easy application to existing models without altering their core structure. +1. **Unified Post-Training Framework**: Tunix provides a consistent API and infrastructure for various post-training techniques, reducing the need for separate implementations for SFT, RL, and PEFT. +2. **State-of-the-Art RL Integration**: Tunix integrates with vLLM for efficient RL sampling, enabling advanced algorithms like Group Relative Policy Optimization (GRPO). This allows for fine-tuning models based on complex reward signals. +3. **NNX Compatibility**: Tunix is designed to work with NNX, the latest generation of Flax, allowing it to leverage the newest JAX features and a more modern API. +4. **Modularity for PEFT**: While offering full fine-tuning, Tunix also maintains strong support for PEFT methods. Techniques like LoRA are implemented as composable Flax/NNX modules, allowing easy application to existing models without altering their core structure. -We chose Tunix because it provides a **comprehensive, performant, and JAX-native solution for the entire post-training lifecycle**. Its integration with libraries like vLLM and its alignment with the NNX ecosystem make it a powerful tool for both full model adaptation and parameter-efficient tuning. +We chose Tunix because it provides a **comprehensive, performant, and JAX-native solution for the entire post-training lifecycle**. Its integration with libraries like vLLM and its alignment with the NNX ecosystem make it a powerful tool for both full model adaptation and parameter-efficient tuning. For more information on using Tunix, please refer to https://github.com/google/tunix diff --git a/docs/reference/core_concepts.md b/docs/reference/core_concepts.md index 07f861c671..4c03c48ddc 100644 --- a/docs/reference/core_concepts.md +++ b/docs/reference/core_concepts.md @@ -16,45 +16,45 @@ # Core concepts -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 💾 Checkpoints +:::\{grid-item-card} 💾 Checkpoints :link: core_concepts/checkpoints :link-type: doc Understanding checkpoint formats and strategies. ::: -:::{grid-item-card} ⚖️ Alternatives +:::\{grid-item-card} ⚖️ Alternatives :link: core_concepts/alternatives :link-type: doc Comparison with other frameworks like Megatron-LM. ::: -:::{grid-item-card} 📉 Quantization +:::\{grid-item-card} 📉 Quantization :link: core_concepts/quantization :link-type: doc Techniques for reducing model size and improving performance. ::: -:::{grid-item-card} 🧱 Tiling +:::\{grid-item-card} 🧱 Tiling :link: core_concepts/tiling :link-type: doc Understanding tiling strategies for partitioning logic. ::: -:::{grid-item-card} ⚡ JAX/XLA/Pallas +:::\{grid-item-card} ⚡ JAX/XLA/Pallas :link: core_concepts/jax_xla_and_pallas :link-type: doc How MaxText leverages JAX, XLA, and Pallas for efficiency. ::: -:::{grid-item-card} 🧠 MoE Configuration +:::\{grid-item-card} 🧠 MoE Configuration :link: core_concepts/moe_configuration :link-type: doc @@ -63,9 +63,10 @@ Configuring Mixture of Experts (MoE) models. :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- core_concepts/checkpoints.md core_concepts/alternatives.md core_concepts/quantization.md diff --git a/docs/reference/core_concepts/alternatives.md b/docs/reference/core_concepts/alternatives.md index e4dba355d7..3b16dc1a3d 100644 --- a/docs/reference/core_concepts/alternatives.md +++ b/docs/reference/core_concepts/alternatives.md @@ -19,5 +19,3 @@ MaxText is similar to [Nvidia/Megatron-LM](https://github.com/NVIDIA/Megatron-LM), a very well tuned LLM implementation targeting Nvidia GPUs. The two implementations achieve comparable MFUs. The difference in the codebases highlights the different programming strategies. MaxText is pure Python, relying heavily on the XLA compiler to achieve high performance. By contrast, Megatron-LM is a mix of Python and CUDA, relying on well-optimized CUDA kernels to achieve high performance. MaxText is heavily inspired by [MinGPT](https://github.com/karpathy/minGPT)/[NanoGPT](https://github.com/karpathy/nanoGPT), elegant standalone GPT implementations written in PyTorch and targeting Nvidia GPUs. MaxText is more complex, supporting more industry standard models and scaling to tens of thousands of chips. Ultimately MaxText has an MFU more than three times the [17%](https://twitter.com/karpathy/status/1613250489097027584?cxt=HHwWgIDUhbixteMsAAAA) reported most recently with that codebase, is massively scalable and implements a key-value cache for efficient auto-regressive decoding. - - diff --git a/docs/reference/core_concepts/checkpoints.md b/docs/reference/core_concepts/checkpoints.md index 08bffa9861..f918cf6dbf 100644 --- a/docs/reference/core_concepts/checkpoints.md +++ b/docs/reference/core_concepts/checkpoints.md @@ -98,6 +98,6 @@ Furthermore, MaxText supports emergency checkpointing, which saves a local copy - `local_checkpoint_directory`: The local path for storing emergency checkpoints. - `local_checkpoint_period`: The interval, in training steps, for saving local checkpoints. -More configs about checkpoints can be found in [here](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/maxtext/configs/base.yml#L23-L65). +More configs about checkpoints can be found in [here](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/configs/base.yml#L23-L65). For practical guides on checkpointing, please refer to [](checkpointing_solutions). diff --git a/docs/reference/core_concepts/jax_xla_and_pallas.md b/docs/reference/core_concepts/jax_xla_and_pallas.md index 8d080618de..bae02d343d 100644 --- a/docs/reference/core_concepts/jax_xla_and_pallas.md +++ b/docs/reference/core_concepts/jax_xla_and_pallas.md @@ -13,12 +13,12 @@ MaxText builds on the following core technologies: The following table provides a high-level overview to scaffold understanding before a more detailed exploration. -| Technology | Role in MaxText | Key Benefit for LLM Training | -| :---- | :---- | :---- | -| JAX | Programming Model & Transformations | Enables scalable, composable, and differentiable model definitions in pure Python. | -| JAX Pallas | Custom Kernel Language | Allows for hand-tuned, hardware-specific kernels for peak performance on novel operations (e.g., MoE, custom attention). | -| XLA | JAX Compiler | Automatically fuses operations and compiles HLO code, emitted by JAX into optimized LLO machine code for TPUs/GPUs. | -| Mosaic | Pallas Compiler | Compiles the Mosaic IR code emitted by JAX Pallas into LLO | +| Technology | Role in MaxText | Key Benefit for LLM Training | +| :--------- | :---------------------------------- | :----------------------------------------------------------------------------------------------------------------------- | +| JAX | Programming Model & Transformations | Enables scalable, composable, and differentiable model definitions in pure Python. | +| JAX Pallas | Custom Kernel Language | Allows for hand-tuned, hardware-specific kernels for peak performance on novel operations (e.g., MoE, custom attention). | +| XLA | JAX Compiler | Automatically fuses operations and compiles HLO code, emitted by JAX into optimized LLO machine code for TPUs/GPUs. | +| Mosaic | Pallas Compiler | Compiles the Mosaic IR code emitted by JAX Pallas into LLO | ## 1. JAX: the high-performance engine of MaxText @@ -58,7 +58,7 @@ JAX provides powerful, high-level abstractions for SPMD programming. [`jax.vmap` This unified scalability model is a key advantage of the JAX ecosystem. Training large LLMs requires different types of parallelism—data parallelism (splitting the batch), tensor parallelism (splitting a single matrix multiplication), and pipeline parallelism (splitting layers across devices). In many frameworks, implementing these requires different APIs, libraries, or coding patterns, which adds significant complexity. In JAX, all these forms of parallelism can be expressed through the single, unified concept of sharding tensors over a logical device mesh. For data parallelism, one shards the batch dimension of the input data. For tensor parallelism, one shards the weight matrices along their feature or output dimensions. -MaxText leverages this unification to great effect. The core model code remains largely agnostic to the parallelism strategy. Scalability is controlled primarily by changing the level of each kind of parallelism in configuration files. This abstraction is a primary reason MaxText can be described as both "simple" and "massively scalable," as the immense complexity of distributed execution is handled by JAX and the XLA compiler, rather than the user. +MaxText leverages this unification to great effect. The core model code remains largely agnostic to the parallelism strategy. Scalability is controlled primarily by changing the level of each kind of parallelism in configuration files. This abstraction is a primary reason MaxText can be described as both "simple" and "massively scalable," as the immense complexity of distributed execution is handled by JAX and the XLA compiler, rather than the user. ### 1.4. Composability: the JAX superpower diff --git a/docs/reference/core_concepts/moe_configuration.md b/docs/reference/core_concepts/moe_configuration.md index c5e42b7153..7ce7d63110 100644 --- a/docs/reference/core_concepts/moe_configuration.md +++ b/docs/reference/core_concepts/moe_configuration.md @@ -16,7 +16,7 @@ # Mixture of Experts (MoE) Configuration -This document provides a detailed explanation of the configuration parameters related to Mixture of Experts (MoE) models in MaxText. These settings control the model architecture, routing mechanisms, and performance optimizations. Default values and parameter definitions are located in `src/maxtext/configs/base.yml` and are primarily used in `src/MaxText/layers/moe.py`. +This document provides a detailed explanation of the configuration parameters related to Mixture of Experts (MoE) models in MaxText. These settings control the model architecture, routing mechanisms, and performance optimizations. Default values and parameter definitions are located in `src/maxtext/configs/base.yml` and are primarily used in `src/maxtext/layers/moe.py`. ## 1. Architecture @@ -30,7 +30,7 @@ MaxText supports both Dropless and Dropping strategies. Please refer to the deci Dropless: - [Tokamax Ragged Dot](https://github.com/openxla/tokamax/tree/main/tokamax/_src/ops/ragged_dot): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=True`. -- [Megablox](https://github.com/google/maxtext/tree/main/src/MaxText/kernels/megablox): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=True`. +- [Megablox](https://github.com/google/maxtext/tree/main/src/maxtext/kernels/megablox): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=True`. - [JAX Ragged Dot](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot.html): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=False`. - Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor=-1`. diff --git a/docs/reference/core_concepts/tiling.md b/docs/reference/core_concepts/tiling.md index f477cbcb5c..5d203e31aa 100644 --- a/docs/reference/core_concepts/tiling.md +++ b/docs/reference/core_concepts/tiling.md @@ -16,38 +16,35 @@ # Tiling - Often in high-performance computing, there's a trade-off between memory usage and computation time. **Tiling** (also known as chunking) is an optimization technique that prioritizes reducing memory usage, but sometimes at a cost to running time. The core idea is to partition a large tensor into smaller, more manageable blocks called **tiles**. Instead of loading the entire tensor into memory, a program processes these tiles sequentially. This significantly lowers the peak memory required to perform an operation. While processing data in chunks can introduce minor computational overhead, an efficient tiling strategy minimizes this cost. By preventing out-of-memory errors, tiling enables programs to handle larger problems than would otherwise be possible. This often leads to better hardware utilization and improved end-to-end performance metrics, such as MFU. - ## The Concept of Tiling The effectiveness of tiling stems from the **linearity** of many operations in LLMs. Operations like matrix multiplication and gradient calculation can be broken down into smaller, independent sub-problems. For instance, consider the matrix multiplication operation: -$$A[M, N] \times B[N, 1] = C[M, 1]$$ +$$A[M, N] \\times B[N, 1] = C[M, 1]$$ -If the matrices `A` is too large to fit into memory, you can tile the operation. By splitting matrix `A` into $K$ smaller chunks along its `M` dimension ($A_0[M/K, N], \dots, A_{K-1}[M/K, N]$), you can load each chunk separately and compute a corresponding portion of the output matrix `C` ($C_0[M/K, 1], \dots, C_{K-1}[M/K, 1]$) follows +If the matrices `A` is too large to fit into memory, you can tile the operation. By splitting matrix `A` into $K$ smaller chunks along its `M` dimension ($A_0[M/K, N], \\dots, A\_{K-1}[M/K, N]$), you can load each chunk separately and compute a corresponding portion of the output matrix `C` ($C_0[M/K, 1], \\dots, C\_{K-1}[M/K, 1]$) follows -$$A_i[M/K, N]\times B[N, 1] = C_i[M/K, 1] \quad \forall i=0, \dots, K-1.$$ +$$A_i[M/K, N]\\times B[N, 1] = C_i[M/K, 1] \\quad \\forall i=0, \\dots, K-1.$$ Finally, you can concatenate the smaller `C` matrices to form the complete result. This principle extends to the backward pass as well. Instead of computing the full gradient `dA[M, N]`, which also exceeds memory capacity, you can compute the gradient for each tile individually: -$$dC_i[M/K, 1] \times B^\intercal[1, N] = dA_i[M/K, N] \quad \forall i=0,\dots,K-1.$$ +$$dC_i[M/K, 1] \\times B^\\intercal[1, N] = dA_i[M/K, N] \\quad \\forall i=0,\\dots,K-1.$$ Similarly, the gradient on `B` is the accumulation -$$\sum_{i=0}^{K-1}dC_i^\intercal[1, M/K] \times A_i[M/K, N] = dB[N, 1] \quad \forall i=0,\dots,K-1.$$ - -This tiling approach reduces the peak memory usage from $\mathcal{O}(MN)$ to $\mathcal{O}(MN/K)$, which facilitates model training with limited memory resources. +$$\\sum\_{i=0}^{K-1}dC_i^\\intercal[1, M/K] \\times A_i[M/K, N] = dB[N, 1] \\quad \\forall i=0,\\dots,K-1.$$ +This tiling approach reduces the peak memory usage from $\\mathcal{O}(MN)$ to $\\mathcal{O}(MN/K)$, which facilitates model training with limited memory resources. ## Tiling in MaxText @@ -62,7 +59,6 @@ GA reduces the size of activations in memory at any given moment, which is cruci ![Illustration of gradient accumulation.](../../_static/gradient_accum.png) *Figure 1: Gradient accumulation tiles a global batch into smaller micro-batches.* - ### Vocabulary Tiling Vocabulary tiling is another memory-saving technique designed to handle the large vocabulary sizes in modern language models. @@ -82,6 +78,6 @@ Tiling is also crucial for managing data movement across the memory hierarchy (H ## Tiling vs. Sharding -**Tiling** and **sharding** are independent concepts that do not conflict; in fact, they are often used together. Sharding distributes a tensor across multiple devices, while tiling processes a tensor in chunks on the same device. +**Tiling** and **sharding** are independent concepts that do not conflict; in fact, they are often used together. Sharding distributes a tensor across multiple devices, while tiling processes a tensor in chunks on the same device. -To learn more about sharding in MaxText, please refer to the [sharding documentation](https://maxtext.readthedocs.io/en/latest/explanations/sharding.html). \ No newline at end of file +To learn more about sharding in MaxText, please refer to the [sharding documentation](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/optimization/sharding.html). diff --git a/docs/reference/models.md b/docs/reference/models.md index 060d40986b..6c1c65954d 100644 --- a/docs/reference/models.md +++ b/docs/reference/models.md @@ -1,16 +1,16 @@ # Models -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 🥇 Tiering +:::\{grid-item-card} 🥇 Tiering :link: models/tiering :link-type: doc Optimized model tiers (Gold, Silver) for various TPU generations. ::: -:::{grid-item-card} 🏗️ Supported Models +:::\{grid-item-card} 🏗️ Supported Models :link: models/supported_models_and_architectures :link-type: doc @@ -19,9 +19,10 @@ List of all supported models and architectures. :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- models/tiering.md models/supported_models_and_architectures.md ``` diff --git a/docs/reference/models/supported_models_and_architectures.md b/docs/reference/models/supported_models_and_architectures.md index 5001b26f56..f329c9ba2b 100644 --- a/docs/reference/models/supported_models_and_architectures.md +++ b/docs/reference/models/supported_models_and_architectures.md @@ -80,12 +80,12 @@ The following summarizes observed runtime efficiency and scaling behaviors of Ma - **Model Implementation Guides & Source Code:** - - **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/llama2/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/llama4.py) - - **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/gemma3.py) - - **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/mistral.py) - - **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/deepseek.py) - - **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3-Next Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/qwen3.py) | [Qwen3-Next Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/qwen3.py) - - **GPT-OSS**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md) | [GPT-OSS Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/gpt_oss.py) + - **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/llama2/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/llama4.py) + - **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma3.py) + - **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/mistral.py) + - **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/deepseek.py) + - **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3-Next Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/qwen3.py) | [Qwen3-Next Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/qwen3.py) + - **GPT-OSS**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md) | [GPT-OSS Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gpt_oss.py) - **Technical Explanations:** diff --git a/docs/reference/models/tiering.md b/docs/reference/models/tiering.md index 7ee0373b40..b757c07e77 100644 --- a/docs/reference/models/tiering.md +++ b/docs/reference/models/tiering.md @@ -1,4 +1,3 @@ - # Optimized models tiering For each of the TPU platforms listed below, we present a list of optimized models[^1] [^2] for pre-training. If you’re getting started with MaxText, or want to push performance, we recommend choosing a Gold model, with an accompanying pre-training recipe. @@ -11,33 +10,33 @@ For each of the TPU platforms listed below, we present a list of optimized model ### Gold -| Model | Recipe | Benchmark Configuration | MFU | Approx tokens/sec/device | -| :--- | :--- | :--- | :--- | :--- | -| Llama 2 70B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/trillium/Llama2-70B-MaxText) | 256, BF16, SL=4096 | 43.8% | 900 | -| Llama 3.1 8B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/trillium/Llama3.1-8B-MaxText/v6e-256) | 256 Chips, BF16, SL=8192 | 45.46% | 7,207 | -| Llama 3.1 70B | [Link](https://github.com/AI-Hypercomputer/maxtext/blob/92e59fdf547421f647590087f50fea5729da42d8/benchmarks/maxtext_trillium_model_configs.py#L959) | 256 Chips, BF16, SL=8192 | 50.33% | 960 | +| Model | Recipe | Benchmark Configuration | MFU | Approx tokens/sec/device | +| :------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------- | :----------------------- | :----- | :----------------------- | +| Llama 2 70B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/trillium/Llama2-70B-MaxText) | 256, BF16, SL=4096 | 43.8% | 900 | +| Llama 3.1 8B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/trillium/Llama3.1-8B-MaxText/v6e-256) | 256 Chips, BF16, SL=8192 | 45.46% | 7,207 | +| Llama 3.1 70B | [Link](https://github.com/AI-Hypercomputer/maxtext/blob/92e59fdf547421f647590087f50fea5729da42d8/benchmarks/maxtext_trillium_model_configs.py#L959) | 256 Chips, BF16, SL=8192 | 50.33% | 960 | ### Silver -| Model | Recipe | Benchmark Configuration | MFU | Approx tokens/sec/device | -| :--- | :--- | :--- | :--- | :--- | -| Llama 3.1 405B | [Link](https://github.com/AI-Hypercomputer/maxtext/blob/5e6a7caff904f67fa654fc0ae983a16156bc21f8/benchmarks/maxtext_trillium_model_configs.py#L723) | 256 Chips, BF16, SL=8192 | 38.55% | 123 | -| Mixtral 8X7B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/trillium/Mixtral-8x7B-MaxText) | 256 Chips, BF16, SL=4096 | 35.23% | 3,899 | -| Mixtral 8X22B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/trillium/Mixtral-8x22B-MaxText) | 256 Chips, BF16, SL=4096 | 36.2% | 1,326 | +| Model | Recipe | Benchmark Configuration | MFU | Approx tokens/sec/device | +| :------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------- | :----------------------- | :----- | :----------------------- | +| Llama 3.1 405B | [Link](https://github.com/AI-Hypercomputer/maxtext/blob/5e6a7caff904f67fa654fc0ae983a16156bc21f8/benchmarks/maxtext_trillium_model_configs.py#L723) | 256 Chips, BF16, SL=8192 | 38.55% | 123 | +| Mixtral 8X7B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/trillium/Mixtral-8x7B-MaxText) | 256 Chips, BF16, SL=4096 | 35.23% | 3,899 | +| Mixtral 8X22B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/trillium/Mixtral-8x22B-MaxText) | 256 Chips, BF16, SL=4096 | 36.2% | 1,326 | ## v5p ### Gold -| Model | Recipe | Benchmark Configuration | MFU | Approx tokens/sec/device | -| :--- | :--- | :--- | :--- | :--- | -| Llama 2 70B | [Link](https://github.com/AI-Hypercomputer/maxtext/blob/92e59fdf547421f647590087f50fea5729da42d8/benchmarks/maxtext_v5p_model_configs.py#L156) | 512 Chips, BF16, SL=4096 | 65.4% | 692 | +| Model | Recipe | Benchmark Configuration | MFU | Approx tokens/sec/device | +| :---------- | :--------------------------------------------------------------------------------------------------------------------------------------------- | :----------------------- | :---- | :----------------------- | +| Llama 2 70B | [Link](https://github.com/AI-Hypercomputer/maxtext/blob/92e59fdf547421f647590087f50fea5729da42d8/benchmarks/maxtext_v5p_model_configs.py#L156) | 512 Chips, BF16, SL=4096 | 65.4% | 692 | ### Silver -| Model | Recipe | Benchmark Configuration | MFU | Approx tokens/sec/device | -| :--- | :--- | :--- | :--- | :--- | -| Mixtral 8X7B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/v5p/Mixtral-8X7B-Maxtext) | 256 Chips(8x4x4), bf16, SL=4096 | 52.56% | 2,909 | +| Model | Recipe | Benchmark Configuration | MFU | Approx tokens/sec/device | +| :----------- | :-------------------------------------------------------------------------------------------------- | :------------------------------ | :----- | :----------------------- | +| Mixtral 8X7B | [Link](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/v5p/Mixtral-8X7B-Maxtext) | 256 Chips(8x4x4), bf16, SL=4096 | 52.56% | 2,909 | -[^1]: Performance results are subject to variations based on system configuration, software versions, and other factors. These benchmarks represent point-in-time measurements under specific conditions. -[^2]: Some older TFLOPS/s results are impacted by an updated calculation for causal attention ([PR #1988](https://github.com/AI-Hypercomputer/maxtext/pull/1988)), which halves the attention FLOPs. This change particularly affects configurations with large sequence lengths. For more details, please refer to the [performance metrics guide](https://maxtext.readthedocs.io/en/latest/reference/performance_metrics.html). +\[^1\]: Performance results are subject to variations based on system configuration, software versions, and other factors. These benchmarks represent point-in-time measurements under specific conditions. +\[^2\]: Some older TFLOPS/s results are impacted by an updated calculation for causal attention ([PR #1988](https://github.com/AI-Hypercomputer/maxtext/pull/1988)), which halves the attention FLOPs. This change particularly affects configurations with large sequence lengths. For more details, please refer to the [performance metrics guide](https://maxtext.readthedocs.io/en/latest/reference/performance_metrics.html). diff --git a/docs/reference/performance_metrics.md b/docs/reference/performance_metrics.md index abd10ec996..c2319a5cc7 100644 --- a/docs/reference/performance_metrics.md +++ b/docs/reference/performance_metrics.md @@ -15,6 +15,7 @@ --> (performance-metrics)= + # Performance metrics ## MFU @@ -23,46 +24,48 @@ Model Flops Utilization (MFU) is one of the most commonly used metrics to summar ### Definition -Model FLOPs are the floating point operations required to perform model computations regardless of implementation or hardware limitations. +Model FLOPs are the floating point operations required to perform model computations regardless of implementation or hardware limitations. For training, this corresponds to the operations in a single forward and backward pass (one model step). -$$ MFU:= \frac{\text{model flops/s}}{\text{peak hardware flops/s}} $$ +$$ MFU:= \\frac{\\text{model flops/s}}{\\text{peak hardware flops/s}} $$ -Model flops are generally easy to calculate/estimate theoretically since the model is mostly performing matmuls, and so we can just sum up the flops of each matmul. For example a $[A,B] \times [B,C] = [A,C]$ matmul has $2ABC$ flops. Hence, to calculate the observed model flops/s we can sum up the theoretical flops required in a training step of the model and then divide by the measured step time (in seconds). +Model flops are generally easy to calculate/estimate theoretically since the model is mostly performing matmuls, and so we can just sum up the flops of each matmul. For example a $[A,B] \\times [B,C] = [A,C]$ matmul has $2ABC$ flops. Hence, to calculate the observed model flops/s we can sum up the theoretical flops required in a training step of the model and then divide by the measured step time (in seconds). - $$ MFU = \frac{\text{model flops/s}}{\text{peak hardware flops/s}} = \frac{\text{theoretical model flops per step}}{\text{measured step time} \times \text{peak hardware flops/s}}$$ +$$ MFU = \\frac{\\text{model flops/s}}{\\text{peak hardware flops/s}} = \\frac{\\text{theoretical model flops per step}}{\\text{measured step time} \\times \\text{peak hardware flops/s}}$$ Furthermore, since $$ -\frac{\text{theoretical model flops per step}} - {\text{peak hardware flops/s}} -= \text{theoretically optimal step time} +\\frac{\\text{theoretical model flops per step}} +{\\text{peak hardware flops/s}} += \\text{theoretically optimal step time} $$ we also get that: $$ -MFU = \frac{\text{theoretically optimal step time}} - {\text{measured step time}} +MFU = \\frac{\\text{theoretically optimal step time}} +{\\text{measured step time}} $$ Finally, we can also look at throughput utilization. In each training step, the model processes $(batch_size x seq_length)$ tokens. Since the (optimal or measured) number of tokens per second is just the number of tokens per step divided by step time (optimal or measured, respectively), we get that: $$ -MFU = \frac{\text{theoretically optimal step time}} - {\text{measured step time}} = \frac{\text{number of tokens per step / optimal tokens/s}} - {\text{number of tokens per step / measured tokens/s}} = \frac{\text{measured tokens/s}} - {\text{optimal tokens/s}} +MFU = \\frac{\\text{theoretically optimal step time}} +{\\text{measured step time}} = \\frac{\\text{number of tokens per step / optimal tokens/s}} +{\\text{number of tokens per step / measured tokens/s}} = \\frac{\\text{measured tokens/s}} +{\\text{optimal tokens/s}} $$ Hence, MFU is the fraction of peak hardware performance actually utilized by the model, and can be expressed in different units — step time, throughput, or raw flops/s. ### MaxText calculating + reporting + In MaxText, we sum all of the matmuls performed in one step, see [calculate_tflops_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/maxtext_utils.py#L454) and divide it by the measured (via python `time.time()`) step time. In each step we print the resulting Model Flops per second [`per_device_tflops_per_sec`](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/metric_logger.py#L211-L213). One can calculate the MFU by dividing this number by the peak tflops of the hardware (e.g., $918e^{12}$ FLOPS/s for Trillium). ### Causal attention + Due to causality only half of the (query, key) pairs need to be computed, those with query_idx >= key_idx. This accounts for the fact only prior tokens can be used to predict future ones. Prior to https://github.com/AI-Hypercomputer/maxtext/pull/1988 MaxText did not account for sparsity for theoretical flops, and used Attention Flops ~= 4 * sequence^2 * batch * heads * head_dim @@ -75,27 +78,30 @@ Which maxtext now uses since this [PR/1988](https://github.com/AI-Hypercomputer/ Note that -$$ \text{Total Flops} = \text{Attention (quadratic in sequence) + Non-attention (linear)}$$ +$$ \\text{Total Flops} = \\text{Attention (quadratic in sequence) + Non-attention (linear)}$$ Thus the distinction between causal vs non causal flops is particularly important for long sequence when the attention flops dominate / are a significant fraction of the total flops. For 8k sequence length, the attention flops are generally around 10% of total flops (depending on exact model dims), whereas for 128k seq, the attention flops may be around 90%. Note however the attention flops also vary by attention type, e.g. sliding window flops are not quadratic in sequence, but are only linear in both sequence length and window length. We updated our model flops calculation to account for sliding window attention and chunked attention in [PR 2009](https://github.com/AI-Hypercomputer/maxtext/pull/2009) and [PR 2030](https://github.com/AI-Hypercomputer/maxtext/pull/2030). ### Why MFU + MFU is a very useful metric to understand your systems performance, but like step time or tokens/s, there are pros and cons of summarizing the system’s performance to a single number. **Pros** -* Clearly shows room left to improve, how much more the hardware is capable of. (e.g. 25% MFU means it's possible to get 4x more performance and 4x smaller step times). Note that achieving 100% is not practical due to many factors, but MFU score effectively shows how much room is left for optimization. -* Generalizable across hardwares, model, configs (e.g. batch sizes) + +- Clearly shows room left to improve, how much more the hardware is capable of. (e.g. 25% MFU means it's possible to get 4x more performance and 4x smaller step times). Note that achieving 100% is not practical due to many factors, but MFU score effectively shows how much room is left for optimization. +- Generalizable across hardwares, model, configs (e.g. batch sizes) **Cons** -* Care needs to be token to compare MFU across codebases that the model flops calculation are identcail (e.g. was causality taken into account in both code bases?) + +- Care needs to be token to compare MFU across codebases that the model flops calculation are identcail (e.g. was causality taken into account in both code bases?) Step time, tokens/s, and MFU all can be used to calculate how long training will take (e.g. how long will it take to train my model on $T$ tokens given $C$ chips?) -$$\begin{align*} -\text{training time} &= \text{step time} \times \text{num steps} \\ - &= \frac{T tokens}{\text{measured tokens per second per chip} \times C} \\ - &= \frac{\text{theoretical flops to train T tokens}}{\text{MFU} \times C \times \text{chip peak speed}} -\end{align*}$$ +$$\\begin{align\*} +\\text{training time} &= \\text{step time} \\times \\text{num steps} \\ +&= \\frac{T tokens}{\\text{measured tokens per second per chip} \\times C} \\ +&= \\frac{\\text{theoretical flops to train T tokens}}{\\text{MFU} \\times C \\times \\text{chip peak speed}} +\\end{align\*}$$ This shows any of step time, tokens/s or MFU can be used to determine how long training will take and are proportionally (or inversely proportionally) related. MFU is most useful to compare across different models/hardwares and while optimizing performance, whereas step time or tokens/second may be more useful when these are fixed. diff --git a/docs/run_maxtext.md b/docs/run_maxtext.md index a52a1351e3..e840892b6d 100644 --- a/docs/run_maxtext.md +++ b/docs/run_maxtext.md @@ -2,38 +2,38 @@ Choose your environment and orchestration method to run MaxText. -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 💻 Localhost / Single VM +:::\{grid-item-card} 💻 Localhost / Single VM :link: run_maxtext/run_maxtext_localhost :link-type: doc Get started quickly on a single machine. Clone the repo, install dependencies, and run your first training job on a single TPU or GPU VM. ::: -:::{grid-item-card} 🎮 Single-host GPU +:::\{grid-item-card} 🎮 Single-host GPU :link: run_maxtext/run_maxtext_single_host_gpu :link-type: doc Run MaxText on single-host NVIDIA GPUs (e.g., A3 High/Mega). Includes Docker setup, NVIDIA Container Toolkit installation, and 1B/7B model training examples. ::: -:::{grid-item-card} 🏗️ At scale with XPK (GKE) +:::\{grid-item-card} 🏗️ At scale with XPK (GKE) :link: run_maxtext/run_maxtext_via_xpk :link-type: doc Deploy to Google Kubernetes Engine (GKE) using XPK. Orchestrate large-scale training jobs on TPU or GPU clusters with simple CLI commands. ::: -:::{grid-item-card} 🌐 Multi-host via Pathways +:::\{grid-item-card} 🌐 Multi-host via Pathways :link: run_maxtext/run_maxtext_via_pathways :link-type: doc Run large-scale JAX jobs on TPUs using Pathways. Supports batch and headless (interactive) workloads on GKE. ::: -:::{grid-item-card} 🔌 Decoupled Mode +:::\{grid-item-card} 🔌 Decoupled Mode :link: run_maxtext/decoupled_mode :link-type: doc @@ -42,9 +42,10 @@ Run tests and local development without Google Cloud dependencies (no `gcloud`, :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- run_maxtext/run_maxtext_localhost.md run_maxtext/run_maxtext_single_host_gpu.md run_maxtext/run_maxtext_via_xpk.md diff --git a/docs/tutorials.md b/docs/tutorials.md index f171edae24..ffafa83865 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -18,31 +18,31 @@ Explore our tutorials to learn how to use MaxText, from your first run to advanced post-training techniques. -::::{grid} 1 2 2 2 +::::\{grid} 1 2 2 2 :gutter: 2 -:::{grid-item-card} 🚀 Getting Started +:::\{grid-item-card} 🚀 Getting Started :link: tutorials/first_run :link-type: doc Installation, prerequisites, verification, and your first training run. ::: -:::{grid-item-card} 📚 Pre-training +:::\{grid-item-card} 📚 Pre-training :link: tutorials/pretraining :link-type: doc Step-by-step guides for pre-training with real datasets like C4 using HuggingFace, Grain, or TFDS. ::: -:::{grid-item-card} 🧩 Post-training +:::\{grid-item-card} 🧩 Post-training :link: tutorials/post_training_index :link-type: doc Techniques for SFT, RL, and other post-training workflows on TPU. ::: -:::{grid-item-card} 📊 Inference +:::\{grid-item-card} 📊 Inference :link: tutorials/inference :link-type: doc @@ -51,9 +51,10 @@ Step-by-step guides for running inference of MaxText models on vLLM. :::: ```{toctree} -:hidden: -:maxdepth: 1 - +--- +hidden: +maxdepth: 1 +--- tutorials/first_run.md tutorials/pretraining.md tutorials/post_training_index.md diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 3b7468129b..58b71f80ac 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -66,7 +66,7 @@ In the same TPU VM where you just installed all the dependencies of MaxText, You #### Decoding in MaxText via notebook -You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. +You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. ### Run MaxText on NVIDIA GPUs diff --git a/docs/tutorials/post_training_index.md b/docs/tutorials/post_training_index.md index 685d5d9563..46a68f830b 100644 --- a/docs/tutorials/post_training_index.md +++ b/docs/tutorials/post_training_index.md @@ -9,6 +9,7 @@ We’re investing in performance, scale, algorithms, models, reliability, and ea ## The MaxText stack MaxText was co-designed with key Google led innovations to provide a unified post training experience: + - [MaxText model library](https://maxtext.readthedocs.io/en/latest/index.html#model-library) for JAX LLMs highly optimized for TPUs - [Tunix](https://github.com/google/tunix) for the latest algorithms and post-training techniques - [vLLM on TPU](https://github.com/vllm-project/tpu-inference) for high performance sampling (inference) for Reinforcement Learning (RL) @@ -19,13 +20,13 @@ MaxText was co-designed with key Google led innovations to provide a unified pos ## Supported techniques & models - **SFT (Supervised Fine-Tuning)** - * [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html) - * [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft_on_multi_host.html) + - [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html) + - [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft_on_multi_host.html) - **Multimodal SFT** - * [Multimodal Support](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/multimodal.html) + - [Multimodal Support](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/multimodal.html) - **Reinforcement Learning (RL)** - * [RL on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl.html) - * [RL on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl_on_multi_host.html) + - [RL on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl.html) + - [RL on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl_on_multi_host.html) ## Step by step RL @@ -42,10 +43,11 @@ Pathways is a single controller JAX runtime that was [designed and pressure test Pathways allows for fine grained resource allocation (subslice of a physical slice) and scheduling. This allows JAX developers to explore novel model architectures in an easy to develop single controller programming environment. Pathways supercharges RL with: + 1. **Multi-host Model Support:** Easily manages models that span multiple hosts. -1. **Unified Orchestration:** Controls both trainers and samplers from a single Python process. -1. **Efficient Data Transfer:** Optimally moves data between training and inference devices, utilizing ICI or DCN as needed. JAX Reshard primitives simplify integration. -1. **Flexible Resource Allocation:** Enables dedicating different numbers of accelerators to inference and training within the same job, adapting to workload bottlenecks (disaggregated setup). +2. **Unified Orchestration:** Controls both trainers and samplers from a single Python process. +3. **Efficient Data Transfer:** Optimally moves data between training and inference devices, utilizing ICI or DCN as needed. JAX Reshard primitives simplify integration. +4. **Flexible Resource Allocation:** Enables dedicating different numbers of accelerators to inference and training within the same job, adapting to workload bottlenecks (disaggregated setup). ## Getting started @@ -54,8 +56,9 @@ Start your Post-Training journey through quick experimentation with [Python Note ## More tutorials ```{toctree} -:maxdepth: 1 - +--- +maxdepth: 1 +--- posttraining/sft.md posttraining/sft_on_multi_host.md posttraining/rl.md diff --git a/docs/tutorials/posttraining/full_finetuning.md b/docs/tutorials/posttraining/full_finetuning.md index 45f6e9eb5b..8ebf2f525b 100644 --- a/docs/tutorials/posttraining/full_finetuning.md +++ b/docs/tutorials/posttraining/full_finetuning.md @@ -24,29 +24,19 @@ In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get st ## Install dependencies -```sh -# 1. Clone the repository -git clone https://github.com/AI-Hypercomputer/maxtext.git -cd maxtext - -# 2. Create virtual environment -export VENV_NAME= # e.g., maxtext_venv -pip install uv -uv venv --python 3.12 --seed ${VENV_NAME?} -source ${VENV_NAME?}/bin/activate - -# 3. Install dependencies in editable mode -uv pip install -e .[tpu] --resolution=lowest -install_maxtext_github_deps -``` +For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/install_maxtext.html) and use the `maxtext[tpu]` installation path to include all necessary dependencies. ## Setup environment variables +Follow the instructions [here](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) to login to Hugging Face using your access token using + +```bash +huggingface-cli login +``` + ```sh # -- Model configuration -- -export MODEL_NAME= # e.g., 'llama3.1-8b' -export MODEL_TOKENIZER= # e.g., 'meta-llama/Llama-3.1-8B-Instruct' -export HF_TOKEN= +export MODEL= # e.g., 'llama3.1-8b-Instruct' # -- MaxText configuration -- export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory @@ -62,15 +52,15 @@ This section explains how to prepare your model checkpoint for use with MaxText. If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section. ```sh -export MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items +export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items ``` ### Option 2: Converting a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```bash -export MODEL_CKPT_PATH= # gs://my-bucket/my-checkpoint-directory/0/items +export MAXTEXT_CKPT_PATH= # gs://my-bucket/my-checkpoint-directory/0/items ``` ## Dataset @@ -103,12 +93,10 @@ Below is a sample training script. python3 -m maxtext.trainers.pre_train.train \ run_name=${RUN_NAME?} \ base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ - load_parameters_path=${MODEL_CKPT_PATH?} \ - model_name=${MODEL_NAME?} \ + load_parameters_path=${MAXTEXT_CKPT_PATH?} \ + model_name=${MODEL?} \ dataset_path=${DATASET_GCS_BUCKET?} \ async_checkpointing=False \ - tokenizer_path=${MODEL_TOKENIZER?} \ - hf_access_token=${HF_TOKEN?} \ steps=10 per_device_batch_size=1 ``` diff --git a/docs/tutorials/posttraining/multimodal.md b/docs/tutorials/posttraining/multimodal.md index 65bbc1a78d..bf1428a9ca 100644 --- a/docs/tutorials/posttraining/multimodal.md +++ b/docs/tutorials/posttraining/multimodal.md @@ -25,7 +25,7 @@ Multimodal Large Language Models (LLMs) extend traditional text-only models by i ## Checkpoint Conversion -Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md)). +Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md)). Install pytorch: @@ -153,12 +153,12 @@ python -m maxtext.trainers.post_train.sft.train_sft_deprecated \ - **Setting appropriate prefill length**: To prevent truncation and ensure your full input (text + image) is processed, the prefill length should be set longer than the total combined length of your text tokens and image tokens. This combined length makes up the final sequence fed to the decoder. We recommend to estimate the combined sequence length from your full input and then add a buffer when setting your `max_prefill_predict_length` for decoding. Token estimation rules: - For text tokens, a good estimate is: - $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$. + $\\text{Text Tokens} \\approx 1.3 \\times \\text{Number of Words in Prompt}$. - For Gemma3, each image is resized to 896\*896 and contributes 256 tokens: - $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$. + $\\text{Total Tokens} \\approx \\text{Text Tokens} + \\text{Number of Images} * 256$. - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens: - $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$. + $\\text{Total Tokens} \\approx \\text{Text Tokens} + 144 \\times \\sum\_{i=1}^{N} \\text{Number of Tiles of Image}\_i$. diff --git a/docs/tutorials/posttraining/rl.md b/docs/tutorials/posttraining/rl.md index d8cfc17a44..ec486684e2 100644 --- a/docs/tutorials/posttraining/rl.md +++ b/docs/tutorials/posttraining/rl.md @@ -87,7 +87,7 @@ export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucke ### Option 2: Converting from a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```bash export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items diff --git a/docs/tutorials/posttraining/rl_on_multi_host.md b/docs/tutorials/posttraining/rl_on_multi_host.md index d1c20a68b2..033400a172 100644 --- a/docs/tutorials/posttraining/rl_on_multi_host.md +++ b/docs/tutorials/posttraining/rl_on_multi_host.md @@ -101,7 +101,7 @@ export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucke ### Option 2: Converting from a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```bash export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items diff --git a/docs/tutorials/posttraining/sft.md b/docs/tutorials/posttraining/sft.md index c7ed9f45c5..ad7613df95 100644 --- a/docs/tutorials/posttraining/sft.md +++ b/docs/tutorials/posttraining/sft.md @@ -68,7 +68,7 @@ export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucke ### Option 2: Converting a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```sh export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items diff --git a/docs/tutorials/posttraining/sft_on_multi_host.md b/docs/tutorials/posttraining/sft_on_multi_host.md index b54819cee0..766dd30e63 100644 --- a/docs/tutorials/posttraining/sft_on_multi_host.md +++ b/docs/tutorials/posttraining/sft_on_multi_host.md @@ -92,7 +92,7 @@ checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) ### Option 2: Converting a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```bash export MAXTEXT_CKPT_PATH= # gs://my-bucket/my-checkpoint-directory/0/items diff --git a/src/MaxText/README.md b/src/MaxText/README.md index c720fed8b8..a6b1ff3f55 100644 --- a/src/MaxText/README.md +++ b/src/MaxText/README.md @@ -14,9 +14,9 @@ # limitations under the License. --> -# src/MaxText +# src/maxtext -The contents of `src/MaxText` have moved to `src/MaxText` as part of a larger +The contents of `src/MaxText` have moved to `src/maxtext` as part of a larger [restructuring effort in MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/2790ed289c0c4cb704645d5d2ab91da26711b891/RESTRUCTURE.md). This directory only contains shim files to temporarily support legacy commands like `python3 -m MaxText.train ...`. These legacy commands are now deprecated and will be removed soon. Please migrate your existing commands and avoid using diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py b/src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py index 29ffb3ba28..3e0f1f4574 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py @@ -25,7 +25,7 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.llama_mistral_mixtral_orbax_to_hf src/maxtext/configs/base.yml base_output_directory=path/to/saving/intermediate_MaxText_files - load_parameters_path=/path/to/src/MaxText/checkpoint run_name= model_name= + load_parameters_path=/path/to/src/maxtext/checkpoint run_name= model_name= hardware=gpu hf_model_path=/local/path/to/save/HF/model/to diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index 4aa5429deb..94c70860c7 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -44,7 +44,7 @@ To convert a gemma2-2b MaxText checkpoint and save it to a local directory: export HF_AUTH_TOKEN="hf_YOUR_TOKEN" - python src/MaxText/checkpoint_conversion/to_huggingface.py \ + python src/maxtext/checkpoint_conversion/to_huggingface.py \ src/maxtext/configs/base.yml \ model_name="gemma2-2b" \ load_parameters_path="/path/to/your/maxtext/checkpoint/" \ diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index 26dbeb214b..a77893df2f 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -40,7 +40,7 @@ Example Usage: To convert a gemma2-2b model and save it to a specific directory: - /usr/bin/time -v python src/MaxText/checkpoint_conversion/to_maxtext.py \ + /usr/bin/time -v python src/maxtext/checkpoint_conversion/to_maxtext.py \ maxtext/configs/base.yml model_name="gemma2-2b" \ base_output_directory="/path/to/your/output/directory" \ hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \ @@ -51,7 +51,7 @@ To convert a 70B model with minimal RAM usage: - /usr/bin/time -v python src/MaxText/checkpoint_conversion/to_maxtext.py \ + /usr/bin/time -v python src/maxtext/checkpoint_conversion/to_maxtext.py \ maxtext/configs/base.yml model_name="llama3.1-70b" \ base_output_directory="gs://my-bucket/maxtext-checkpoints" \ hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \ diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 56c8ba4462..29ff6e7120 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -219,7 +219,7 @@ def _is_profiler_boundary_step(self, step): return step in boundary_steps def _maybe_abort_after_write_metrics(self, metrics): - """ This function checks whether we have nan or inf values in training""" + """This function checks whether we have nan or inf values in training""" loss = metrics["scalar"].get("learning/loss") if self.config.abort_on_nan_loss and np.isnan(loss): max_logging.log("Aborting training due to NaN loss.") @@ -227,7 +227,7 @@ def _maybe_abort_after_write_metrics(self, metrics): if self.config.abort_on_inf_loss and np.isinf(loss): max_logging.log("Aborting training due to Inf loss.") sys.exit(1) - + def write_metrics_locally(self, metrics, step): """Writes metrics locally for testing.""" with open(self.config.metrics_file, "a", encoding="utf8") as local_metrics_file: diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 814eefe0b5..3ff1c33153 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1141,7 +1141,7 @@ use_tokamax_splash: false use_jax_splash: false # vLLM Adapter Configurations -# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter) +# Path to the HuggingFace-style config directory for the adapter (e.g. src/maxtext/integration/vllm/maxtext_vllm_adapter) vllm_hf_config_path: "" # A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter. # This can be used to override specific settings without modifying the original config file. diff --git a/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh index 926dff0abc..6f9cb366a8 100644 --- a/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh @@ -11,5 +11,5 @@ set -euox pipefail RUNNAME=${1:-${RUNNAME:-some-run}} BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} -chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/v5p/gpt3_175b/gpt3_175b_base.sh +chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"/v5p/gpt3_175b/gpt3_175b_base.sh ./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 48 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/maxtext/examples/sft_llama3_demo_gpu.ipynb b/src/maxtext/examples/sft_llama3_demo_gpu.ipynb index aaa0fe5fcf..467f8b8917 100644 --- a/src/maxtext/examples/sft_llama3_demo_gpu.ipynb +++ b/src/maxtext/examples/sft_llama3_demo_gpu.ipynb @@ -1,603 +1,603 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "7687de92-1dfb-4237-b663-30cda55dc8e1", - "metadata": {}, - "source": [ - "# Supervised Fine-Tuning of Llama 3.1-8B on NVIDIA GPUs with JAX and MaxText\n", - "\n", - "## Overview\n", - "\n", - "This tutorial walks you through supervised fine-tuning (SFT) of Llama 3.1-8B on NVIDIA GPUs using JAX and MaxText. You'll learn how to take a pretrained Llama checkpoint, convert it into MaxText's native format, configure an SFT training run, and verify the result with a quick inference test.\n", - "\n", - "**What you'll do:**\n", - "1. Set up the environment and authenticate with Hugging Face\n", - "2. Download and convert the Llama 3.1-8B checkpoint to MaxText format\n", - "3. Configure and launch supervised fine-tuning on the UltraChat 200k dataset\n", - "4. Visualize training metrics with TensorBoard\n", - "5. Run a quick inference sanity check\n", - "\n", - "## Prerequisites\n", - "\n", - "### Make sure you have supported hardware\n", - "\n", - "**Hardware requirements.** Full-parameter SFT of Llama 3.1-8B is memory-intensive due to optimizer state, activations, and sharded model parameters. We recommend a system with **8 NVIDIA GPUs with at least 80 GB of memory each** (e.g., A100-80GB, H100-80GB, or H200). This allows the model, optimizer state, and activations to be cleanly sharded across devices without aggressive memory tuning.\n", - "\n", - "When running `nvidia-smi`, you should see eight or more visible GPUs, each reporting at least 80 GB of total memory, with recent drivers, CUDA 12.x+ support, and minimal memory usage before training starts." - ] + "cells": [ + { + "cell_type": "markdown", + "id": "7687de92-1dfb-4237-b663-30cda55dc8e1", + "metadata": {}, + "source": [ + "# Supervised Fine-Tuning of Llama 3.1-8B on NVIDIA GPUs with JAX and MaxText\n", + "\n", + "## Overview\n", + "\n", + "This tutorial walks you through supervised fine-tuning (SFT) of Llama 3.1-8B on NVIDIA GPUs using JAX and MaxText. You'll learn how to take a pretrained Llama checkpoint, convert it into MaxText's native format, configure an SFT training run, and verify the result with a quick inference test.\n", + "\n", + "**What you'll do:**\n", + "1. Set up the environment and authenticate with Hugging Face\n", + "2. Download and convert the Llama 3.1-8B checkpoint to MaxText format\n", + "3. Configure and launch supervised fine-tuning on the UltraChat 200k dataset\n", + "4. Visualize training metrics with TensorBoard\n", + "5. Run a quick inference sanity check\n", + "\n", + "## Prerequisites\n", + "\n", + "### Make sure you have supported hardware\n", + "\n", + "**Hardware requirements.** Full-parameter SFT of Llama 3.1-8B is memory-intensive due to optimizer state, activations, and sharded model parameters. We recommend a system with **8 NVIDIA GPUs with at least 80 GB of memory each** (e.g., A100-80GB, H100-80GB, or H200). This allows the model, optimizer state, and activations to be cleanly sharded across devices without aggressive memory tuning.\n", + "\n", + "When running `nvidia-smi`, you should see eight or more visible GPUs, each reporting at least 80 GB of total memory, with recent drivers, CUDA 12.x+ support, and minimal memory usage before training starts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dhc3l20703b", + "metadata": {}, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "c7671875-fea3-4abe-9f01-57c854f50f92", + "metadata": {}, + "source": [ + "### Get your Hugging Face token\n", + "\n", + "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "**Follow these steps to get your token:**\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", + "\n", + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", + "\n", + "4. **Copy the generated token**. You will need this in the later steps.\n", + "\n", + "**Follow these steps to store your token (only if running on Google Colab):**\n", + "\n", + "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", + "\n", + "2. Click **\"+ Add new secret\"**.\n", + "\n", + "3. Set the Name as **HF_TOKEN**.\n", + "\n", + "4. Paste your token into the Value field.\n", + "\n", + "5. Ensure the Notebook access toggle is turned On." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7999f002-65a1-4764-ba41-922f2fec43df", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " print(\"Running the notebook on Google Colab\")\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "markdown", + "id": "6b691306-88ee-47b3-b841-cd6d072f51eb", + "metadata": {}, + "source": [ + "### Authenticate with Hugging Face\n", + "\n", + "Verify that your Hugging Face token is set and valid by calling the Hub's `whoami` endpoint. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93a0030b-ab25-4d76-a173-1a464c19087a", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from huggingface_hub import HfApi\n", + "\n", + "if IN_COLAB:\n", + " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", + "else:\n", + " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", + "\n", + "if not HF_TOKEN:\n", + " from getpass import getpass\n", + " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n", + "\n", + "if not HF_TOKEN:\n", + " raise RuntimeError(\"Authentication failed: Hugging Face token is not set.\")\n", + "\n", + "# Ensure token is set in this process\n", + "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", + "\n", + "# Verify identity\n", + "api = HfApi()\n", + "user_info = api.whoami(token=HF_TOKEN)\n", + "username = user_info.get(\"name\") or \"Unknown user\"\n", + "\n", + "print(f\"Authenticated with Hugging Face successfully as: {username}\")" + ] + }, + { + "cell_type": "markdown", + "id": "887cd139-a776-43ad-aeea-8547fcd8d744", + "metadata": {}, + "source": [ + "### Acquire permission to use the gated model\n", + "\n", + "Llama 3.1-8B is a gated model, so you must explicitly request access before it can be downloaded. Visit the [model page](https://huggingface.co/meta-llama/Llama-3.1-8B) on Hugging Face, log in with the same account linked to your access token, and click **Request access**. You'll need to agree to Meta's license terms; approval is usually granted quickly but is not automatic. Once approved, your Hugging Face token will authorize downloads transparently. If you skip this step, model downloads will fail even with a valid token." + ] + }, + { + "cell_type": "markdown", + "id": "66b25ee3-1072-4c0f-965d-72e911873d1c", + "metadata": {}, + "source": [ + "## Get the model and convert it into MaxText format\n", + "\n", + "### Import dependencies\n", + "\n", + "#### Core libraries and installation\n", + "\n", + "Import the core libraries needed for this tutorial:\n", + "\n", + "- **JAX**: High-performance ML framework with automatic differentiation and XLA compilation\n", + "- **MaxText**: Google's production-grade training stack for JAX, providing model architectures, checkpoint management, and the SFT training loop\n", + "\n", + "The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with JAX, CUDA, and MaxText preinstalled. To install the dependencies manually:\n", + "\n", + "```bash\n", + "pip install 'jax[cuda13]' maxtext\n", + "```\n", + "\n", + "On top of it, for the model conversion step you will also need **Torch**, the CPU version would be enough:\n", + "\n", + "```bash\n", + "pip install torch --index-url https://download.pytorch.org/whl/cpu\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "531951be-2f24-455a-b9ff-b07aeeb2de1d", + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "from datetime import datetime\n", + "from pathlib import Path\n", + "import sys\n", + "import subprocess\n", + "import logging\n", + "\n", + "import transformers\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import MaxText\n", + "from maxtext.configs import pyconfig\n", + "from maxtext.trainers.post_train.sft import train_sft\n", + "\n", + "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", + "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")\n", + "\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")\n", + "print(f\"Number of available devices: {jax.local_device_count()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "483873c9-b3eb-4a95-be34-2b52aeb222e7", + "metadata": {}, + "source": [ + "#### Setting up the right parallel setup on GPU\n", + "\n", + "JAX supports two different parallel setups:\n", + "\n", + "1. *Single-host* (one machine)\n", + "\n", + "* Can be 1 GPU or multiple GPUs\n", + "* JAX will discover and use all local GPUs automatically\n", + "\n", + "Does not require `jax.distributed.initialize()`\n", + "\n", + "2. *Multi-host* (multiple machines / nodes)\n", + "\n", + "* Requires coordination across processes/hosts\n", + "* Requires `jax.distributed.initialize()` (or a launcher that does it)\n", + "\n", + "Needs coordinator and process metadata (address, process count, process index):\n", + "\n", + "`JAX_COORDINATOR_ADDRESS` (reachable host:port on process 0)\n", + "\n", + "`JAX_PROCESS_COUNT` (total number of processes/hosts)\n", + "\n", + "`JAX_PROCESS_INDEX` (0..count-1)\n", + "\n", + "**Example (2 hosts)**\n", + "\n", + "On host 0:\n", + "```bash\n", + "export JAX_COORDINATOR_ADDRESS=\"10.0.0.1:1234\"\n", + "export JAX_PROCESS_COUNT=2\n", + "export JAX_PROCESS_INDEX=0\n", + "```\n", + "\n", + "On host 1:\n", + "```bash\n", + "export JAX_COORDINATOR_ADDRESS=\"10.0.0.1:1234\"\n", + "export JAX_PROCESS_COUNT=2\n", + "export JAX_PROCESS_INDEX=1\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a86465ed-69d1-4d77-bb47-b29c203e5a60", + "metadata": {}, + "outputs": [], + "source": [ + "if not jax.distributed.is_initialized() and \"JAX_COORDINATOR_ADDRESS\" in os.environ:\n", + " jax.distributed.initialize()" + ] + }, + { + "cell_type": "markdown", + "id": "1687aa03-1549-429a-8156-571c7493ca3d", + "metadata": {}, + "source": [ + "### Define model paths and run configuration\n", + "\n", + "This block defines the core paths and identifiers used throughout the tutorial: the model name, tokenizer source, checkpoint location, and output directory. You can override `MODEL_CHECKPOINT_PATH` via an environment variable to point to an existing converted checkpoint and skip the conversion step." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1aa133cf-1168-4e87-8c4a-6fc34f1cf5cc", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = \"llama3.1-8b\"\n", + "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "\n", + "WORKSPACE_DIR = Path(\n", + " os.environ.get(\"WORKSPACE_DIR\", os.getcwd())\n", + ")\n", + "\n", + "# If set, use it; otherwise default to llama_checkpoint\n", + "MODEL_CHECKPOINT_PATH = os.environ.get(\"MODEL_CHECKPOINT_PATH\")\n", + "MODEL_CHECKPOINT_PATH = Path(MODEL_CHECKPOINT_PATH) if MODEL_CHECKPOINT_PATH else (WORKSPACE_DIR / \"llama_checkpoint\")\n", + "\n", + "print(f\"Model checkpoint directory: {MODEL_CHECKPOINT_PATH}\")\n", + "print(\"Tip: set MODEL_CHECKPOINT_PATH to a local directory to reuse an existing converted checkpoint.\")\n", + "\n", + "BASE_OUTPUT_DIRECTORY = Path(os.environ.get(\"BASE_OUTPUT_DIRECTORY\", str(WORKSPACE_DIR / \"sft_llama3_output\")))" + ] + }, + { + "cell_type": "markdown", + "id": "6b762437-1edb-4123-8257-90cb98028e97", + "metadata": {}, + "source": [ + "### Download and convert the Llama 3.1-8B checkpoint from Hugging Face\n", + "\n", + "This block downloads the pretrained Llama 3.1-8B weights from Hugging Face and converts them into MaxText's native checkpoint format. If a converted checkpoint already exists at the target path, this step is skipped entirely.\n", + "\n", + "The conversion runs in a CPU-only JAX context (`JAX_PLATFORMS=cpu`) to avoid unnecessary GPU memory allocation. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f028b9c-89ed-4301-9a73-2967694891d3", + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_dir = Path(MODEL_CHECKPOINT_PATH)\n", + "\n", + "def run_ckpt_conversion(\n", + " *,\n", + " maxtext_repo_root: str,\n", + " model_name: str,\n", + " output_dir: Path,\n", + " hf_token: str,\n", + " quiet: bool = True,\n", + ") -> None:\n", + " env = os.environ.copy()\n", + "\n", + " # Conversion should not touch GPUs\n", + " env[\"JAX_PLATFORMS\"] = \"cpu\"\n", + "\n", + " # Reduce verbosity (JAX/XLA/TensorFlow C++ logging)\n", + " env.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # 0=all, 1=INFO off, 2=INFO+WARNING off, 3=only FATAL\n", + "\n", + " cmd = [\n", + " sys.executable, \"-m\", \"MaxText.utils.ckpt_conversion.to_maxtext\",\n", + " f\"{maxtext_repo_root}/configs/base.yml\",\n", + " f\"model_name={model_name}\",\n", + " f\"base_output_directory={str(output_dir)}\",\n", + " f\"hf_access_token={hf_token}\",\n", + " \"use_multimodal=false\",\n", + " \"scan_layers=true\",\n", + " \"skip_jax_distributed_system=true\",\n", + " ]\n", + "\n", + " output_dir.parent.mkdir(parents=True, exist_ok=True)\n", + "\n", + " if quiet:\n", + " # Capture logs; show only if something goes wrong\n", + " result = subprocess.run(\n", + " cmd,\n", + " env=env,\n", + " stdout=subprocess.PIPE,\n", + " stderr=subprocess.PIPE,\n", + " text=True,\n", + " )\n", + " if result.returncode != 0:\n", + " print(\"Checkpoint conversion failed. Logs:\\n\")\n", + " if result.stdout:\n", + " print(\"----- stdout -----\")\n", + " print(result.stdout)\n", + " if result.stderr:\n", + " print(\"----- stderr -----\")\n", + " print(result.stderr)\n", + " raise RuntimeError(\"Checkpoint conversion failed. See logs above.\")\n", + " else:\n", + " # Verbose mode (streams logs)\n", + " subprocess.run(cmd, env=env, check=True)\n", + "\n", + " print(f\"Checkpoint successfully converted to MaxText format at: {output_dir}\")\n", + "\n", + "if ckpt_dir.exists():\n", + " print(f\"Converted checkpoint already exists at: {ckpt_dir}\")\n", + "else:\n", + " print(f\"Converting checkpoint to MaxText format → {ckpt_dir}\")\n", + " run_ckpt_conversion(\n", + " maxtext_repo_root=MAXTEXT_REPO_ROOT,\n", + " model_name=MODEL_NAME,\n", + " output_dir=ckpt_dir,\n", + " hf_token=HF_TOKEN,\n", + " quiet=True, \n", + " )\n", + "\n", + "if not ckpt_dir.exists():\n", + " raise RuntimeError(\"Model checkpoint conversion failed. See logs above.\")" + ] + }, + { + "cell_type": "markdown", + "id": "265e9555-f026-4bbc-a6fb-7b0cac6bd9da", + "metadata": {}, + "source": [ + "## Provide the training configuration\n", + "\n", + "This block builds the full MaxText SFT training configuration by loading the base `sft.yml` config and applying runtime overrides for the model, dataset, hyperparameters, and output paths. Each run is tagged with a timestamp-based name to keep outputs isolated across experiments. Key settings:\n", + "\n", + "- **Dataset:** [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), a large instruction-style conversational dataset commonly used for SFT of chat models.\n", + "- **Training:** 100 steps, learning rate 2e-5, sequence length 1024, bfloat16 precision.\n", + "- **Checkpoint source:** The converted MaxText checkpoint from the previous step.\n", + "\n", + "To use your own dataset, ensure it follows a compatible schema and is accessible via the Hugging Face Hub or a local path. MaxText handles dataset loading, sharding, and batching automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d22f54a-efe1-425d-abd4-ceb34561a9a1", + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_items_path = Path(MODEL_CHECKPOINT_PATH) / \"0\" / \"items\"\n", + "\n", + "if not os.environ.get(\"HF_TOKEN\"):\n", + " raise RuntimeError(\"HF_TOKEN is not set. Export it before loading the SFT config.\")\n", + "\n", + "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", + "\n", + "# Load configuration for SFT training\n", + "config_argv = [\n", + " \"\",\n", + " f\"{MAXTEXT_REPO_ROOT}/configs/sft.yml\",\n", + " f\"load_parameters_path={ckpt_items_path}\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " \"steps=100\",\n", + " \"per_device_batch_size=1\",\n", + " \"max_target_length=1024\",\n", + " \"learning_rate=2.0e-5\",\n", + " \"weight_dtype=bfloat16\",\n", + " \"dtype=bfloat16\",\n", + " \"hf_path=HuggingFaceH4/ultrachat_200k\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " \"hardware=gpu\",\n", + "]\n", + "\n", + "# Suppress the verbose per-parameter config dump (hundreds of INFO lines)\n", + "_pyconfig_logger = logging.getLogger(\"MaxText.pyconfig\")\n", + "_prev_level = _pyconfig_logger.level\n", + "_pyconfig_logger.setLevel(logging.WARNING)\n", + "\n", + "config = pyconfig.initialize(config_argv)\n", + "\n", + "_pyconfig_logger.setLevel(_prev_level)\n", + "\n", + "print(\"SFT configuration loaded:\")\n", + "print(f\" Model: {config.model_name}\")\n", + "print(f\" Training Steps: {config.steps}\")\n", + "print(f\" Max sequence length: {config.max_target_length}\")\n", + "print(f\" Output Directory: {config.base_output_directory}\")" + ] + }, + { + "cell_type": "markdown", + "id": "408c6100-20e9-4dcb-b9b6-42d68bb03ae7", + "metadata": {}, + "source": [ + "## Run the SFT training\n", + "\n", + "This section launches the SFT training loop. It runs MaxText's `sft_train` with the configuration defined above, reports progress, and saves checkpoints to the output directory. On completion, it prints the checkpoint and TensorBoard log paths." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb13ddd1-57a8-469e-940c-11fe6ae2a90d", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ.setdefault(\"LIBTPU_INIT_ARGS\", \"\")\n", + "\n", + "print(\"=\" * 60)\n", + "print(f\"Starting SFT training (run_name={RUN_NAME})\")\n", + "print(\"=\" * 60)\n", + "\n", + "try:\n", + " result = train_sft.train(config)\n", + "\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"Training completed successfully\")\n", + " print(\"=\" * 60)\n", + " print(f\"Checkpoints written to: {config.checkpoint_dir}\")\n", + " if hasattr(config, \"tensorboard_dir\"):\n", + " print(f\"TensorBoard logs: {config.tensorboard_dir}\")\n", + "\n", + " if isinstance(result, tuple) and len(result) == 2:\n", + " trainer, mesh = result\n", + "except Exception as e:\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"Training failed\")\n", + " print(\"=\" * 60)\n", + " print(f\"Error details: {e}\")\n", + " raise" + ] + }, + { + "cell_type": "markdown", + "id": "24e3a3e2-027a-4bb7-bb7f-163605010d03", + "metadata": {}, + "source": [ + "## Visualize training metrics with TensorBoard\n", + "\n", + "To monitor training loss and other metrics, launch TensorBoard in a separate terminal replacing `` with the TensorBoard logs path from the training log:\n", + "\n", + "```bash\n", + "export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python\n", + "tensorboard --logdir= --host 0.0.0.0 --port 6006 --load_fast=false\n", + "```\n", + "\n", + "Then open [http://127.0.0.1:6006/](http://127.0.0.1:6006/) in your browser. " + ] + }, + { + "cell_type": "markdown", + "id": "5acaf26b-72fe-404b-827a-297045547f5b", + "metadata": {}, + "source": [ + "## Test inference\n", + "\n", + "A quick sanity check to verify the fine-tuned model produces coherent output. The code below tokenizes a prompt using the Llama 3.1 chat template, then runs greedy autoregressive generation for up to 10 tokens, stopping early if the model produces an EOS token. This confirms the model loaded correctly and produces reasonable predictions.\n", + "\n", + "**Note:** this is naive autoregressive generation without KV-caching, so each step recomputes attention over the full sequence. For production use, consider a dedicated serving framework with KV-cache support." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f80390d8-a267-4d63-84c4-9c1c7a5a5944", + "metadata": {}, + "outputs": [], + "source": [ + "# Load tokenizer\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH, token=HF_TOKEN)\n", + "\n", + "# Get model from trainer\n", + "model = trainer.model\n", + "\n", + "# Format prompt using Llama chat template\n", + "prompt = \"What is the capital of France?\"\n", + "messages = [{\"role\": \"user\", \"content\": prompt}]\n", + "text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "# Tokenize\n", + "tokens = jnp.array(tokenizer(text)[\"input_ids\"])[None, :]\n", + "\n", + "# Greedy autoregressive generation\n", + "max_new_tokens = 10\n", + "generated_ids = []\n", + "eos_token_id = tokenizer.eos_token_id\n", + "\n", + "for _ in range(max_new_tokens):\n", + " seq_len = tokens.shape[1]\n", + " positions = jnp.arange(seq_len)[None, :]\n", + " attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]\n", + "\n", + " with mesh:\n", + " output = model(tokens, positions, None, attention_mask)\n", + " logits = output[0] if isinstance(output, tuple) else output\n", + "\n", + " next_token_id = int(jnp.argmax(logits[0, -1]))\n", + " generated_ids.append(next_token_id)\n", + "\n", + " if next_token_id == eos_token_id:\n", + " break\n", + "\n", + " tokens = jnp.concatenate([tokens, jnp.array([[next_token_id]])], axis=1)\n", + "\n", + "# Decode all generated tokens\n", + "generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"Generated ({len(generated_ids)} tokens): '{generated_text}'\")" + ] + }, + { + "cell_type": "markdown", + "id": "8a261a8c-55af-47ff-b68f-179abff5b623", + "metadata": {}, + "source": [ + "## Learn more\n", + "\n", + "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n", + "- **Configuration**: See `src/maxtext/configs/post_train/sft.yml` for all available options\n", + "- **Documentation**: Check `src/maxtext/trainers/post_train/sft/train_sft.py` for the `train` function implementation" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } }, - { - "cell_type": "code", - "execution_count": null, - "id": "dhc3l20703b", - "metadata": {}, - "outputs": [], - "source": [ - "!nvidia-smi" - ] - }, - { - "cell_type": "markdown", - "id": "c7671875-fea3-4abe-9f01-57c854f50f92", - "metadata": {}, - "source": [ - "### Get your Hugging Face token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need this in the later steps.\n", - "\n", - "**Follow these steps to store your token (only if running on Google Colab):**\n", - "\n", - "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", - "\n", - "2. Click **\"+ Add new secret\"**.\n", - "\n", - "3. Set the Name as **HF_TOKEN**.\n", - "\n", - "4. Paste your token into the Value field.\n", - "\n", - "5. Ensure the Notebook access toggle is turned On." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7999f002-65a1-4764-ba41-922f2fec43df", - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " print(\"Running the notebook on Google Colab\")\n", - " IN_COLAB = True\n", - "except ImportError:\n", - " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", - " IN_COLAB = False" - ] - }, - { - "cell_type": "markdown", - "id": "6b691306-88ee-47b3-b841-cd6d072f51eb", - "metadata": {}, - "source": [ - "### Authenticate with Hugging Face\n", - "\n", - "Verify that your Hugging Face token is set and valid by calling the Hub's `whoami` endpoint. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93a0030b-ab25-4d76-a173-1a464c19087a", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from huggingface_hub import HfApi\n", - "\n", - "if IN_COLAB:\n", - " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", - "else:\n", - " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", - "\n", - "if not HF_TOKEN:\n", - " from getpass import getpass\n", - " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n", - "\n", - "if not HF_TOKEN:\n", - " raise RuntimeError(\"Authentication failed: Hugging Face token is not set.\")\n", - "\n", - "# Ensure token is set in this process\n", - "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", - "\n", - "# Verify identity\n", - "api = HfApi()\n", - "user_info = api.whoami(token=HF_TOKEN)\n", - "username = user_info.get(\"name\") or \"Unknown user\"\n", - "\n", - "print(f\"Authenticated with Hugging Face successfully as: {username}\")" - ] - }, - { - "cell_type": "markdown", - "id": "887cd139-a776-43ad-aeea-8547fcd8d744", - "metadata": {}, - "source": [ - "### Acquire permission to use the gated model\n", - "\n", - "Llama 3.1-8B is a gated model, so you must explicitly request access before it can be downloaded. Visit the [model page](https://huggingface.co/meta-llama/Llama-3.1-8B) on Hugging Face, log in with the same account linked to your access token, and click **Request access**. You'll need to agree to Meta's license terms; approval is usually granted quickly but is not automatic. Once approved, your Hugging Face token will authorize downloads transparently. If you skip this step, model downloads will fail even with a valid token." - ] - }, - { - "cell_type": "markdown", - "id": "66b25ee3-1072-4c0f-965d-72e911873d1c", - "metadata": {}, - "source": [ - "## Get the model and convert it into MaxText format\n", - "\n", - "### Import dependencies\n", - "\n", - "#### Core libraries and installation\n", - "\n", - "Import the core libraries needed for this tutorial:\n", - "\n", - "- **JAX**: High-performance ML framework with automatic differentiation and XLA compilation\n", - "- **MaxText**: Google's production-grade training stack for JAX, providing model architectures, checkpoint management, and the SFT training loop\n", - "\n", - "The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with JAX, CUDA, and MaxText preinstalled. To install the dependencies manually:\n", - "\n", - "```bash\n", - "pip install 'jax[cuda13]' maxtext\n", - "```\n", - "\n", - "On top of it, for the model conversion step you will also need **Torch**, the CPU version would be enough:\n", - "\n", - "```bash\n", - "pip install torch --index-url https://download.pytorch.org/whl/cpu\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "531951be-2f24-455a-b9ff-b07aeeb2de1d", - "metadata": {}, - "outputs": [], - "source": [ - "# Imports\n", - "from datetime import datetime\n", - "from pathlib import Path\n", - "import sys\n", - "import subprocess\n", - "import logging\n", - "\n", - "import transformers\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import MaxText\n", - "from MaxText import pyconfig\n", - "from MaxText.sft.sft_trainer import train as sft_train, setup_trainer_state\n", - "\n", - "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", - "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")\n", - "\n", - "print(f\"JAX version: {jax.__version__}\")\n", - "print(f\"JAX devices: {jax.devices()}\")\n", - "print(f\"Number of available devices: {jax.local_device_count()}\")" - ] - }, - { - "cell_type": "markdown", - "id": "483873c9-b3eb-4a95-be34-2b52aeb222e7", - "metadata": {}, - "source": [ - "#### Setting up the right parallel setup on GPU\n", - "\n", - "JAX supports two different parallel setups:\n", - "\n", - "1. *Single-host* (one machine)\n", - "\n", - "* Can be 1 GPU or multiple GPUs\n", - "* JAX will discover and use all local GPUs automatically\n", - "\n", - "Does not require `jax.distributed.initialize()`\n", - "\n", - "2. *Multi-host* (multiple machines / nodes)\n", - "\n", - "* Requires coordination across processes/hosts\n", - "* Requires `jax.distributed.initialize()` (or a launcher that does it)\n", - "\n", - "Needs coordinator and process metadata (address, process count, process index):\n", - "\n", - "`JAX_COORDINATOR_ADDRESS` (reachable host:port on process 0)\n", - "\n", - "`JAX_PROCESS_COUNT` (total number of processes/hosts)\n", - "\n", - "`JAX_PROCESS_INDEX` (0..count-1)\n", - "\n", - "**Example (2 hosts)**\n", - "\n", - "On host 0:\n", - "```bash\n", - "export JAX_COORDINATOR_ADDRESS=\"10.0.0.1:1234\"\n", - "export JAX_PROCESS_COUNT=2\n", - "export JAX_PROCESS_INDEX=0\n", - "```\n", - "\n", - "On host 1:\n", - "```bash\n", - "export JAX_COORDINATOR_ADDRESS=\"10.0.0.1:1234\"\n", - "export JAX_PROCESS_COUNT=2\n", - "export JAX_PROCESS_INDEX=1\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a86465ed-69d1-4d77-bb47-b29c203e5a60", - "metadata": {}, - "outputs": [], - "source": [ - "if not jax.distributed.is_initialized() and \"JAX_COORDINATOR_ADDRESS\" in os.environ:\n", - " jax.distributed.initialize()" - ] - }, - { - "cell_type": "markdown", - "id": "1687aa03-1549-429a-8156-571c7493ca3d", - "metadata": {}, - "source": [ - "### Define model paths and run configuration\n", - "\n", - "This block defines the core paths and identifiers used throughout the tutorial: the model name, tokenizer source, checkpoint location, and output directory. You can override `MODEL_CHECKPOINT_PATH` via an environment variable to point to an existing converted checkpoint and skip the conversion step." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1aa133cf-1168-4e87-8c4a-6fc34f1cf5cc", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_NAME = \"llama3.1-8b\"\n", - "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", - "\n", - "WORKSPACE_DIR = Path(\n", - " os.environ.get(\"WORKSPACE_DIR\", os.getcwd())\n", - ")\n", - "\n", - "# If set, use it; otherwise default to llama_checkpoint\n", - "MODEL_CHECKPOINT_PATH = os.environ.get(\"MODEL_CHECKPOINT_PATH\")\n", - "MODEL_CHECKPOINT_PATH = Path(MODEL_CHECKPOINT_PATH) if MODEL_CHECKPOINT_PATH else (WORKSPACE_DIR / \"llama_checkpoint\")\n", - "\n", - "print(f\"Model checkpoint directory: {MODEL_CHECKPOINT_PATH}\")\n", - "print(\"Tip: set MODEL_CHECKPOINT_PATH to a local directory to reuse an existing converted checkpoint.\")\n", - "\n", - "BASE_OUTPUT_DIRECTORY = Path(os.environ.get(\"BASE_OUTPUT_DIRECTORY\", str(WORKSPACE_DIR / \"sft_llama3_output\")))" - ] - }, - { - "cell_type": "markdown", - "id": "6b762437-1edb-4123-8257-90cb98028e97", - "metadata": {}, - "source": [ - "### Download and convert the Llama 3.1-8B checkpoint from Hugging Face\n", - "\n", - "This block downloads the pretrained Llama 3.1-8B weights from Hugging Face and converts them into MaxText's native checkpoint format. If a converted checkpoint already exists at the target path, this step is skipped entirely.\n", - "\n", - "The conversion runs in a CPU-only JAX context (`JAX_PLATFORMS=cpu`) to avoid unnecessary GPU memory allocation. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3f028b9c-89ed-4301-9a73-2967694891d3", - "metadata": {}, - "outputs": [], - "source": [ - "ckpt_dir = Path(MODEL_CHECKPOINT_PATH)\n", - "\n", - "def run_ckpt_conversion(\n", - " *,\n", - " maxtext_repo_root: str,\n", - " model_name: str,\n", - " output_dir: Path,\n", - " hf_token: str,\n", - " quiet: bool = True,\n", - ") -> None:\n", - " env = os.environ.copy()\n", - "\n", - " # Conversion should not touch GPUs\n", - " env[\"JAX_PLATFORMS\"] = \"cpu\"\n", - "\n", - " # Reduce verbosity (JAX/XLA/TensorFlow C++ logging)\n", - " env.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # 0=all, 1=INFO off, 2=INFO+WARNING off, 3=only FATAL\n", - "\n", - " cmd = [\n", - " sys.executable, \"-m\", \"MaxText.utils.ckpt_conversion.to_maxtext\",\n", - " f\"{maxtext_repo_root}/configs/base.yml\",\n", - " f\"model_name={model_name}\",\n", - " f\"base_output_directory={str(output_dir)}\",\n", - " f\"hf_access_token={hf_token}\",\n", - " \"use_multimodal=false\",\n", - " \"scan_layers=true\",\n", - " \"skip_jax_distributed_system=true\",\n", - " ]\n", - "\n", - " output_dir.parent.mkdir(parents=True, exist_ok=True)\n", - "\n", - " if quiet:\n", - " # Capture logs; show only if something goes wrong\n", - " result = subprocess.run(\n", - " cmd,\n", - " env=env,\n", - " stdout=subprocess.PIPE,\n", - " stderr=subprocess.PIPE,\n", - " text=True,\n", - " )\n", - " if result.returncode != 0:\n", - " print(\"Checkpoint conversion failed. Logs:\\n\")\n", - " if result.stdout:\n", - " print(\"----- stdout -----\")\n", - " print(result.stdout)\n", - " if result.stderr:\n", - " print(\"----- stderr -----\")\n", - " print(result.stderr)\n", - " raise RuntimeError(\"Checkpoint conversion failed. See logs above.\")\n", - " else:\n", - " # Verbose mode (streams logs)\n", - " subprocess.run(cmd, env=env, check=True)\n", - "\n", - " print(f\"Checkpoint successfully converted to MaxText format at: {output_dir}\")\n", - "\n", - "if ckpt_dir.exists():\n", - " print(f\"Converted checkpoint already exists at: {ckpt_dir}\")\n", - "else:\n", - " print(f\"Converting checkpoint to MaxText format → {ckpt_dir}\")\n", - " run_ckpt_conversion(\n", - " maxtext_repo_root=MAXTEXT_REPO_ROOT,\n", - " model_name=MODEL_NAME,\n", - " output_dir=ckpt_dir,\n", - " hf_token=HF_TOKEN,\n", - " quiet=True, \n", - " )\n", - "\n", - "if not ckpt_dir.exists():\n", - " raise RuntimeError(\"Model checkpoint conversion failed. See logs above.\")" - ] - }, - { - "cell_type": "markdown", - "id": "265e9555-f026-4bbc-a6fb-7b0cac6bd9da", - "metadata": {}, - "source": [ - "## Provide the training configuration\n", - "\n", - "This block builds the full MaxText SFT training configuration by loading the base `sft.yml` config and applying runtime overrides for the model, dataset, hyperparameters, and output paths. Each run is tagged with a timestamp-based name to keep outputs isolated across experiments. Key settings:\n", - "\n", - "- **Dataset:** [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), a large instruction-style conversational dataset commonly used for SFT of chat models.\n", - "- **Training:** 100 steps, learning rate 2e-5, sequence length 1024, bfloat16 precision.\n", - "- **Checkpoint source:** The converted MaxText checkpoint from the previous step.\n", - "\n", - "To use your own dataset, ensure it follows a compatible schema and is accessible via the Hugging Face Hub or a local path. MaxText handles dataset loading, sharding, and batching automatically." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4d22f54a-efe1-425d-abd4-ceb34561a9a1", - "metadata": {}, - "outputs": [], - "source": [ - "ckpt_items_path = Path(MODEL_CHECKPOINT_PATH) / \"0\" / \"items\"\n", - "\n", - "if not os.environ.get(\"HF_TOKEN\"):\n", - " raise RuntimeError(\"HF_TOKEN is not set. Export it before loading the SFT config.\")\n", - "\n", - "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", - "\n", - "# Load configuration for SFT training\n", - "config_argv = [\n", - " \"\",\n", - " f\"{MAXTEXT_REPO_ROOT}/configs/sft.yml\",\n", - " f\"load_parameters_path={ckpt_items_path}\",\n", - " f\"model_name={MODEL_NAME}\",\n", - " \"steps=100\",\n", - " \"per_device_batch_size=1\",\n", - " \"max_target_length=1024\",\n", - " \"learning_rate=2.0e-5\",\n", - " \"weight_dtype=bfloat16\",\n", - " \"dtype=bfloat16\",\n", - " \"hf_path=HuggingFaceH4/ultrachat_200k\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", - " f\"run_name={RUN_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", - " \"hardware=gpu\",\n", - "]\n", - "\n", - "# Suppress the verbose per-parameter config dump (hundreds of INFO lines)\n", - "_pyconfig_logger = logging.getLogger(\"MaxText.pyconfig\")\n", - "_prev_level = _pyconfig_logger.level\n", - "_pyconfig_logger.setLevel(logging.WARNING)\n", - "\n", - "config = pyconfig.initialize(config_argv)\n", - "\n", - "_pyconfig_logger.setLevel(_prev_level)\n", - "\n", - "print(\"SFT configuration loaded:\")\n", - "print(f\" Model: {config.model_name}\")\n", - "print(f\" Training Steps: {config.steps}\")\n", - "print(f\" Max sequence length: {config.max_target_length}\")\n", - "print(f\" Output Directory: {config.base_output_directory}\")" - ] - }, - { - "cell_type": "markdown", - "id": "408c6100-20e9-4dcb-b9b6-42d68bb03ae7", - "metadata": {}, - "source": [ - "## Run the SFT training\n", - "\n", - "This section launches the SFT training loop. It runs MaxText's `sft_train` with the configuration defined above, reports progress, and saves checkpoints to the output directory. On completion, it prints the checkpoint and TensorBoard log paths." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bb13ddd1-57a8-469e-940c-11fe6ae2a90d", - "metadata": {}, - "outputs": [], - "source": [ - "os.environ.setdefault(\"LIBTPU_INIT_ARGS\", \"\")\n", - "\n", - "print(\"=\" * 60)\n", - "print(f\"Starting SFT training (run_name={RUN_NAME})\")\n", - "print(\"=\" * 60)\n", - "\n", - "try:\n", - " result = sft_train(config)\n", - "\n", - " print(\"\\n\" + \"=\" * 60)\n", - " print(\"Training completed successfully\")\n", - " print(\"=\" * 60)\n", - " print(f\"Checkpoints written to: {config.checkpoint_dir}\")\n", - " if hasattr(config, \"tensorboard_dir\"):\n", - " print(f\"TensorBoard logs: {config.tensorboard_dir}\")\n", - "\n", - " if isinstance(result, tuple) and len(result) == 2:\n", - " trainer, mesh = result\n", - "except Exception as e:\n", - " print(\"\\n\" + \"=\" * 60)\n", - " print(\"Training failed\")\n", - " print(\"=\" * 60)\n", - " print(f\"Error details: {e}\")\n", - " raise" - ] - }, - { - "cell_type": "markdown", - "id": "24e3a3e2-027a-4bb7-bb7f-163605010d03", - "metadata": {}, - "source": [ - "## Visualize training metrics with TensorBoard\n", - "\n", - "To monitor training loss and other metrics, launch TensorBoard in a separate terminal replacing `` with the TensorBoard logs path from the training log:\n", - "\n", - "```bash\n", - "export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python\n", - "tensorboard --logdir= --host 0.0.0.0 --port 6006 --load_fast=false\n", - "```\n", - "\n", - "Then open [http://127.0.0.1:6006/](http://127.0.0.1:6006/) in your browser. " - ] - }, - { - "cell_type": "markdown", - "id": "5acaf26b-72fe-404b-827a-297045547f5b", - "metadata": {}, - "source": [ - "## Test inference\n", - "\n", - "A quick sanity check to verify the fine-tuned model produces coherent output. The code below tokenizes a prompt using the Llama 3.1 chat template, then runs greedy autoregressive generation for up to 10 tokens, stopping early if the model produces an EOS token. This confirms the model loaded correctly and produces reasonable predictions.\n", - "\n", - "**Note:** this is naive autoregressive generation without KV-caching, so each step recomputes attention over the full sequence. For production use, consider a dedicated serving framework with KV-cache support." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f80390d8-a267-4d63-84c4-9c1c7a5a5944", - "metadata": {}, - "outputs": [], - "source": [ - "# Load tokenizer\n", - "tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH, token=HF_TOKEN)\n", - "\n", - "# Get model from trainer\n", - "model = trainer.model\n", - "\n", - "# Format prompt using Llama chat template\n", - "prompt = \"What is the capital of France?\"\n", - "messages = [{\"role\": \"user\", \"content\": prompt}]\n", - "text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "# Tokenize\n", - "tokens = jnp.array(tokenizer(text)[\"input_ids\"])[None, :]\n", - "\n", - "# Greedy autoregressive generation\n", - "max_new_tokens = 10\n", - "generated_ids = []\n", - "eos_token_id = tokenizer.eos_token_id\n", - "\n", - "for _ in range(max_new_tokens):\n", - " seq_len = tokens.shape[1]\n", - " positions = jnp.arange(seq_len)[None, :]\n", - " attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]\n", - "\n", - " with mesh:\n", - " output = model(tokens, positions, None, attention_mask)\n", - " logits = output[0] if isinstance(output, tuple) else output\n", - "\n", - " next_token_id = int(jnp.argmax(logits[0, -1]))\n", - " generated_ids.append(next_token_id)\n", - "\n", - " if next_token_id == eos_token_id:\n", - " break\n", - "\n", - " tokens = jnp.concatenate([tokens, jnp.array([[next_token_id]])], axis=1)\n", - "\n", - "# Decode all generated tokens\n", - "generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"Generated ({len(generated_ids)} tokens): '{generated_text}'\")" - ] - }, - { - "cell_type": "markdown", - "id": "8a261a8c-55af-47ff-b68f-179abff5b623", - "metadata": {}, - "source": [ - "## Learn more\n", - "\n", - "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n", - "- **Configuration**: See `src/maxtext/configs/post_train/sft.yml` for all available options\n", - "- **Documentation**: Check `src/MaxText/sft/sft_trainer.py` for the `sft_train` function implementation" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/src/maxtext/examples/sft_qwen3_demo.ipynb b/src/maxtext/examples/sft_qwen3_demo.ipynb index 0f6f2e8f59..c4fbf53529 100644 --- a/src/maxtext/examples/sft_qwen3_demo.ipynb +++ b/src/maxtext/examples/sft_qwen3_demo.ipynb @@ -1,584 +1,592 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "1nb_Ppf2ZUQL" - }, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_qwen3_demo.ipynb)\n", - "\n", - "# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FGbe4_YQZUQL" - }, - "source": [ - "## Overview\n", - "\n", - "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n", - "The primary goal is to demonstrate the end-to-end process of:\n", - "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n", - "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n", - "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n", - "\n", - "This notebook can run on the **public TPU v5e-1**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zolxPWhQZUQL" - }, - "source": [ - "## Prerequisites\n", - "\n", - "### Change Runtime Type (only if running on Google Colab)\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Change runtime type** from the dropdown menu.\n", - "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n", - "5. Click on **Save**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Rk_QpVVuZUQL" - }, - "source": [ - "### Get Your Hugging Face Token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", - " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need this in the later steps.\n", - "\n", - "**Follow these steps to store your token (only if running on Google Colab):**\n", - "\n", - "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", - "\n", - "2. Click **\"+ Add new secret\"**.\n", - "\n", - "3. Set the Name as **HF_TOKEN**.\n", - "\n", - "4. Paste your token into the Value field.\n", - "\n", - "5. Ensure the Notebook access toggle is turned On." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " print(\"Running the notebook on Google Colab\")\n", - " IN_COLAB = True\n", - "except ImportError:\n", - " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", - " IN_COLAB = False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D9ms-jTSZUQL" - }, - "source": [ - "## Installation: MaxText and Post training Dependencies\n", - "\n", - "**Running the notebook on Visual Studio or JupyterLab:** Before proceeding, create a virtual environment and install the required post-training dependencies by following `Option 3: Installing [tpu-post-train]` in the [MaxText installation guide](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source). Once the environment is set up, ensure the notebook is running within it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if IN_COLAB:\n", - " # Clone the MaxText repository\n", - " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", - " %cd maxtext\n", - "\n", - " # Install uv, a fast Python package installer\n", - " !pip install uv\n", - " \n", - " # Install MaxText and post-training dependencies\n", - " !uv pip install -e .[tpu-post-train] --resolution=lowest\n", - " !install_maxtext_tpu_post_train_extra_deps" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Session restart Instructions for Colab:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Restart session** from the dropdown menu.\n", - "\n", - "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Clexf-j7ZUQM" - }, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PkBI9A3JZUQM" - }, - "outputs": [], - "source": [ - "import jax\n", - "import os\n", - "import sys\n", - "import transformers\n", - "\n", - "from maxtext.configs import pyconfig\n", - "from maxtext.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n", - "from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", - "from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n", - "from maxtext.trainers.post_train.sft import train_sft\n", - "\n", - "# Suppress vLLM logging with a severity level below ERROR\n", - "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n", - "from tunix.rl.rollout import base_rollout\n", - "from tunix.rl.rollout.vllm_rollout import VllmRollout\n", - "\n", - "from datetime import datetime\n", - "from flax import nnx\n", - "from huggingface_hub import login\n", - "\n", - "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if not jax.distributed.is_initialized():\n", - " jax.distributed.initialize()\n", - "print(f\"JAX version: {jax.__version__}\")\n", - "print(f\"JAX devices: {jax.devices()}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JBbPN-uVZUQM" - }, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", - "except ImportError:\n", - " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", - "\n", - "# If not found in the environment, prompt the user for input securely\n", - "# getpass function ensures the token is hidden while you type\n", - "if not HF_TOKEN:\n", - " from getpass import getpass\n", - " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n", - "\n", - "if HF_TOKEN:\n", - " login(token=HF_TOKEN)\n", - " print(\"Authenticated with Hugging Face successfully!\")\n", - "else:\n", - " print(\"Authentication failed: Hugging Face token is not set.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aENuzm9iZUQM" - }, - "source": [ - "## Model Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RjPYYl3zZUQM" - }, - "outputs": [], - "source": [ - "MODEL_NAME = \"qwen3-0.6b\"\n", - "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n", - "tokenizer = transformers.AutoTokenizer.from_pretrained(\n", - " TOKENIZER_PATH,\n", - " token=HF_TOKEN,\n", - ")\n", - "\n", - "# set the path to the model checkpoint (excluding `/0/items`) or leave empty to download from HuggingFace\n", - "MODEL_CHECKPOINT_PATH = \"\"\n", - "if not MODEL_CHECKPOINT_PATH:\n", - " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/qwen_checkpoint\"\n", - " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", - " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", - "\n", - "\n", - "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%m-%S\")\n", - "\n", - "# This is the directory where the fine-tuned model checkpoint will be saved\n", - "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/maxtext_qwen06_output\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4L37Ij4NZUQM" - }, - "source": [ - "## Download Qwen3-0.6B Model Checkpoint from Hugging Face" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kJanDAc0ZUQM" - }, - "outputs": [], - "source": [ - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " # install torch for the conversion script\n", - " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", - "\n", - " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m maxtext.checkpoint_conversion.to_maxtext \\\n", - " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n", - " model_name={MODEL_NAME} \\\n", - " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", - " hf_access_token={HF_TOKEN} \\\n", - " use_multimodal=false \\\n", - " scan_layers=true \\\n", - " skip_jax_distributed_system=True\n", - "\n", - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PC-hILG0ZUQM" - }, - "source": [ - "## Dataset Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O3MLdr9kZUQM" - }, - "outputs": [], - "source": [ - "DATASET_NAME = \"openai/gsm8k\"\n", - "TRAIN_DATA_SPLIT = \"train\"\n", - "TEST_DATA_SPLIT = \"test\"\n", - "HF_DATA_DIR = \"main\"\n", - "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n", - "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/src/maxtext/examples/chat_templates/math_qa.json\"\n", - "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", - " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", - "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n", - "BATCH_SIZE = 1 # Number of test samples to process in a batch" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yeAHmxSYZUQM" - }, - "source": [ - "## MaxText Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "In-jdp1AAwrL" - }, - "outputs": [], - "source": [ - "%%capture\n", - "config = pyconfig.initialize(\n", - " [\n", - " \"\",\n", - " f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml\",\n", - " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", - " f\"model_name={MODEL_NAME}\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", - " f\"run_name={RUN_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", - " f\"hf_path={DATASET_NAME}\",\n", - " f\"train_split={TRAIN_DATA_SPLIT}\",\n", - " f\"hf_data_dir={HF_DATA_DIR}\",\n", - " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n", - " \"steps=500\",\n", - " \"per_device_batch_size=1\",\n", - " \"max_target_length=1024\",\n", - " \"learning_rate=3e-6\",\n", - " \"weight_dtype=bfloat16\",\n", - " \"dtype=bfloat16\",\n", - " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O9b0GWo-ZUQM" - }, - "source": [ - "## Initial Setup & Data Preparation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TDqFmvUCZUQM" - }, - "source": [ - "### Create Test Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wscWYxrtZUQM" - }, - "outputs": [], - "source": [ - "test_dataset = get_test_dataset(config, tokenizer)\n", - "test_dataset = test_dataset[:NUM_TEST_SAMPLES]\n", - "test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)\n", - "TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE\n", - "print(\n", - " f\"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. This will result in {TOTAL_BATCHES} total batches for the test run.\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bLSvOOEUZUQM" - }, - "source": [ - "### Create SFT Trainer State" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2IHsC0m6ZUQM" - }, - "outputs": [], - "source": [ - "trainer, mesh = train_sft.setup_trainer_state(config)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PpKtEqzFZUQM" - }, - "source": [ - "### Create vLLM Rollout" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3-pf_rbqZUQM" - }, - "outputs": [], - "source": [ - "tunix_model = TunixMaxTextAdapter(trainer.model)\n", - "vllm_rollout = VllmRollout(\n", - " model=tunix_model,\n", - " tokenizer=tokenizer,\n", - " cache_config_or_size=1280,\n", - " mesh=mesh,\n", - " rollout_config=base_rollout.RolloutConfig(\n", - " rollout_vllm_model_version=TOKENIZER_PATH,\n", - " rollout_vllm_hbm_utilization=0.8,\n", - " rollout_vllm_init_with_random_weights=True,\n", - " rollout_vllm_tpu_backend_type=\"jax\",\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "567gTxsEZUQM" - }, - "source": [ - "## Evaluation before SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OnACa3zCZUQM" - }, - "outputs": [], - "source": [ - "print(\"Running Pre-SFT Evaluation...\")\n", - "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "u5-M4iYkZUQN" - }, - "outputs": [], - "source": [ - "print(\"========================= Score for PRE-SFT Evaluation =========================\")\n", - "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", - "print(\n", - " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", - ")\n", - "print(\n", - " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EJE1ookSAzz-" - }, - "source": [ - "## SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "mgwpNgQYCJEd", - "tags": [] - }, - "outputs": [], - "source": [ - "print(\"Starting SFT Training...\")\n", - "trainer = train_sft.train_model(config, trainer, mesh)\n", - "print(\"SFT Training Complete!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WEdNYRhwZUQN" - }, - "source": [ - "## Evaluation after SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XcsZacZdZUQN" - }, - "outputs": [], - "source": [ - "print(\"Running Post-SFT Evaluation...\")\n", - "model = TunixMaxTextAdapter(trainer.model)\n", - "state = nnx.state(model)\n", - "vllm_rollout.update_params(state)\n", - "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "-JtYTPvJZUQN", - "tags": [] - }, - "outputs": [], - "source": [ - "print(\"========================= Score for POST-SFT Evaluation =========================\")\n", - "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", - "print(\n", - " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", - ")\n", - "print(\n", - " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", - ")" - ] - } - ], - "metadata": { - "accelerator": "TPU", - "colab": { - "gpuType": "V5E1", - "provenance": [] - }, - "kernelspec": { - "display_name": "maxtext_venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "1nb_Ppf2ZUQL" + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_qwen3_demo.ipynb)\n", + "\n", + "# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FGbe4_YQZUQL" + }, + "source": [ + "## Overview\n", + "\n", + "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n", + "The primary goal is to demonstrate the end-to-end process of:\n", + "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n", + "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n", + "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n", + "\n", + "This notebook can run on the **public TPU v5e-1**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zolxPWhQZUQL" + }, + "source": [ + "## Prerequisites\n", + "\n", + "### Change Runtime Type (only if running on Google Colab)\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Change runtime type** from the dropdown menu.\n", + "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n", + "5. Click on **Save**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Rk_QpVVuZUQL" + }, + "source": [ + "### Get Your Hugging Face Token\n", + "\n", + "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "**Follow these steps to get your token:**\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", + " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", + "\n", + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", + "\n", + "4. **Copy the generated token**. You will need this in the later steps.\n", + "\n", + "**Follow these steps to store your token (only if running on Google Colab):**\n", + "\n", + "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", + "\n", + "2. Click **\"+ Add new secret\"**.\n", + "\n", + "3. Set the Name as **HF_TOKEN**.\n", + "\n", + "4. Paste your token into the Value field.\n", + "\n", + "5. Ensure the Notebook access toggle is turned On." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o0gz1E8VtpsI" + }, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " print(\"Running the notebook on Google Colab\")\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D9ms-jTSZUQL" + }, + "source": [ + "## Installation: MaxText and Post training Dependencies\n", + "\n", + "**Running the notebook on Visual Studio or JupyterLab:** Before proceeding, create a virtual environment and install the required post-training dependencies by following `Option 3: Installing [tpu-post-train]` in the [MaxText installation guide](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source). Once the environment is set up, ensure the notebook is running within it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bjnwIv1YtpsI" + }, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " # Clone the MaxText repository\n", + " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", + " %cd maxtext\n", + "\n", + " # Install uv, a fast Python package installer\n", + " !pip install uv\n", + " \n", + " # Install MaxText and post-training dependencies\n", + " !uv pip install -e .[tpu-post-train] --resolution=lowest\n", + " !install_maxtext_tpu_post_train_extra_deps" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OKWBCMrstpsI" + }, + "source": [ + "**Session restart Instructions for Colab:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Restart session** from the dropdown menu.\n", + "\n", + "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Clexf-j7ZUQM" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PkBI9A3JZUQM" + }, + "outputs": [], + "source": [ + "import jax\n", + "import os\n", + "import sys\n", + "import transformers\n", + "\n", + "from maxtext.configs import pyconfig\n", + "from maxtext.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n", + "from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", + "from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n", + "from maxtext.trainers.post_train.sft import train_sft\n", + "\n", + "# Suppress vLLM logging with a severity level below ERROR\n", + "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n", + "from tunix.rl.rollout import base_rollout\n", + "from tunix.rl.rollout.vllm_rollout import VllmRollout\n", + "\n", + "from datetime import datetime\n", + "from flax import nnx\n", + "from huggingface_hub import login\n", + "\n", + "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NIiA2OletpsI" + }, + "outputs": [], + "source": [ + "if not jax.distributed.is_initialized():\n", + " jax.distributed.initialize()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JBbPN-uVZUQM" + }, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", + "except ImportError:\n", + " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", + "\n", + "# If not found in the environment, prompt the user for input securely\n", + "# getpass function ensures the token is hidden while you type\n", + "if not HF_TOKEN:\n", + " from getpass import getpass\n", + " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n", + "\n", + "if HF_TOKEN:\n", + " login(token=HF_TOKEN)\n", + " print(\"Authenticated with Hugging Face successfully!\")\n", + "else:\n", + " print(\"Authentication failed: Hugging Face token is not set.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aENuzm9iZUQM" + }, + "source": [ + "## Model Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RjPYYl3zZUQM" + }, + "outputs": [], + "source": [ + "MODEL_NAME = \"qwen3-0.6b\"\n", + "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(\n", + " TOKENIZER_PATH,\n", + " token=HF_TOKEN,\n", + ")\n", + "\n", + "# set the path to the model checkpoint (excluding `/0/items`) or leave empty to download from HuggingFace\n", + "MODEL_CHECKPOINT_PATH = \"\"\n", + "if not MODEL_CHECKPOINT_PATH:\n", + " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/qwen_checkpoint\"\n", + " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", + " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", + "\n", + "\n", + "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%m-%S\")\n", + "\n", + "# This is the directory where the fine-tuned model checkpoint will be saved\n", + "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/maxtext_qwen06_output\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4L37Ij4NZUQM" + }, + "source": [ + "## Download Qwen3-0.6B Model Checkpoint from Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kJanDAc0ZUQM" + }, + "outputs": [], + "source": [ + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " # install torch for the conversion script\n", + " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", + "\n", + " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m maxtext.checkpoint_conversion.to_maxtext \\\n", + " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n", + " model_name={MODEL_NAME} \\\n", + " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", + " hf_access_token={HF_TOKEN} \\\n", + " use_multimodal=false \\\n", + " scan_layers=true \\\n", + " skip_jax_distributed_system=True\n", + "\n", + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PC-hILG0ZUQM" + }, + "source": [ + "## Dataset Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "O3MLdr9kZUQM" + }, + "outputs": [], + "source": [ + "DATASET_NAME = \"openai/gsm8k\"\n", + "TRAIN_DATA_SPLIT = \"train\"\n", + "TEST_DATA_SPLIT = \"test\"\n", + "HF_DATA_DIR = \"main\"\n", + "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n", + "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/src/maxtext/examples/chat_templates/math_qa.json\"\n", + "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", + " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", + "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n", + "BATCH_SIZE = 1 # Number of test samples to process in a batch" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yeAHmxSYZUQM" + }, + "source": [ + "## MaxText Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "In-jdp1AAwrL" + }, + "outputs": [], + "source": [ + "%%capture\n", + "config = pyconfig.initialize(\n", + " [\n", + " \"\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml\",\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " f\"hf_path={DATASET_NAME}\",\n", + " f\"train_split={TRAIN_DATA_SPLIT}\",\n", + " f\"hf_data_dir={HF_DATA_DIR}\",\n", + " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n", + " \"steps=500\",\n", + " \"per_device_batch_size=1\",\n", + " \"max_target_length=1024\",\n", + " \"learning_rate=3e-6\",\n", + " \"weight_dtype=bfloat16\",\n", + " \"dtype=bfloat16\",\n", + " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O9b0GWo-ZUQM" + }, + "source": [ + "## Initial Setup & Data Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TDqFmvUCZUQM" + }, + "source": [ + "### Create Test Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wscWYxrtZUQM" + }, + "outputs": [], + "source": [ + "test_dataset = get_test_dataset(config, tokenizer)\n", + "test_dataset = test_dataset[:NUM_TEST_SAMPLES]\n", + "test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)\n", + "TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE\n", + "print(\n", + " f\"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. This will result in {TOTAL_BATCHES} total batches for the test run.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bLSvOOEUZUQM" + }, + "source": [ + "### Create SFT Trainer State" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2IHsC0m6ZUQM" + }, + "outputs": [], + "source": [ + "trainer, mesh = train_sft.setup_trainer_state(config)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PpKtEqzFZUQM" + }, + "source": [ + "### Create vLLM Rollout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3-pf_rbqZUQM" + }, + "outputs": [], + "source": [ + "tunix_model = TunixMaxTextAdapter(trainer.model)\n", + "vllm_rollout = VllmRollout(\n", + " model=tunix_model,\n", + " tokenizer=tokenizer,\n", + " cache_config_or_size=1280,\n", + " mesh=mesh,\n", + " rollout_config=base_rollout.RolloutConfig(\n", + " rollout_vllm_model_version=TOKENIZER_PATH,\n", + " rollout_vllm_hbm_utilization=0.8,\n", + " rollout_vllm_init_with_random_weights=True,\n", + " rollout_vllm_tpu_backend_type=\"jax\",\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "567gTxsEZUQM" + }, + "source": [ + "## Evaluation before SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OnACa3zCZUQM" + }, + "outputs": [], + "source": [ + "print(\"Running Pre-SFT Evaluation...\")\n", + "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "u5-M4iYkZUQN" + }, + "outputs": [], + "source": [ + "print(\"========================= Score for PRE-SFT Evaluation =========================\")\n", + "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", + "print(\n", + " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", + ")\n", + "print(\n", + " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EJE1ookSAzz-" + }, + "source": [ + "## SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "id": "mgwpNgQYCJEd", + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"Starting SFT Training...\")\n", + "trainer = train_sft.train_model(config, trainer, mesh)\n", + "print(\"SFT Training Complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WEdNYRhwZUQN" + }, + "source": [ + "## Evaluation after SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XcsZacZdZUQN" + }, + "outputs": [], + "source": [ + "print(\"Running Post-SFT Evaluation...\")\n", + "model = TunixMaxTextAdapter(trainer.model)\n", + "state = nnx.state(model)\n", + "vllm_rollout.update_params(state)\n", + "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "id": "-JtYTPvJZUQN", + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"========================= Score for POST-SFT Evaluation =========================\")\n", + "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", + "print(\n", + " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", + ")\n", + "print(\n", + " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", + ")" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V5E1", + "provenance": [] + }, + "kernelspec": { + "display_name": "maxtext_venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/src/maxtext/experimental/agent/ckpt_conversion_agent/README.md b/src/maxtext/experimental/agent/ckpt_conversion_agent/README.md index 44cd63656a..4730c22a9c 100644 --- a/src/maxtext/experimental/agent/ckpt_conversion_agent/README.md +++ b/src/maxtext/experimental/agent/ckpt_conversion_agent/README.md @@ -1,5 +1,5 @@ # Checkpoint conversion agent -The agent is used to automate the model-specific mappings of checkpoint conversion. It is designed to cooperate with the new checkpoint conversion [framework](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion). +The agent is used to automate the model-specific mappings of checkpoint conversion. It is designed to cooperate with the new checkpoint conversion [framework](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion). ## Quick starts To begin, you'll need: @@ -16,7 +16,7 @@ pip install -q -U "google-genai>=1.0.0" ## 1. Prepare the context file -The agent requires context files about the target and source model's parameter names and tensor shapes. You can generate them using the [`save_param.py`](../ckpt_conversion_agent/utils/save_param.py) script. The output directory defined by `config.base_output_directory`. The default is `src/MaxText/experimental/agent/ckpt_conversion_agent/context/` folder. +The agent requires context files about the target and source model's parameter names and tensor shapes. You can generate them using the [`save_param.py`](../ckpt_conversion_agent/utils/save_param.py) script. The output directory defined by `config.base_output_directory`. The default is `src/maxtext/experimental/agent/ckpt_conversion_agent/context/` folder. ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.utils.save_param src/maxtext/configs/base.yml \ per_device_batch_size=1 run_name=param_ model_name= scan_layers=false \ @@ -30,16 +30,16 @@ After it, you can get two `*.json` files in `config.base_output_directory` folde ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.step1 --target_model= \ - --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent --api_key= + --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key= ``` -Our engineer should check the `src/MaxText/experimental/agent/ckpt_conversion_agent/outputs/proposed_dsl.txt` for potential new DSL and assess if it's needed. Then we need to add this ops into `src/MaxText/experimental/agent/ckpt_conversion_agent/context/dsl.txt`. +Our engineer should check the `src/maxtext/experimental/agent/ckpt_conversion_agent/outputs/proposed_dsl.txt` for potential new DSL and assess if it's needed. Then we need to add this ops into `src/maxtext/experimental/agent/ckpt_conversion_agent/context/dsl.txt`. ### 2.2 Step 2: Generate mappings ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.step2 --target_model= \ - --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent --api_key= + --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key= ``` ## Evaluation and Debugging @@ -53,14 +53,14 @@ You can automatically verify the output by comparing the generated code against ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.evaluation --files ground_truth/.py \ - outputs/hook_fn.py --api_key= --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent + outputs/hook_fn.py --api_key= --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent ``` ### Manual Debugging (No Ground-Truth Code) If a ground-truth version isn't available, you'll need to debug the conversion manually. The recommended process is to: -1. Add the model mappings into [checkpoint conversion framework](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md#adding-support-for-new-models). +1. Add the model mappings into [checkpoint conversion framework](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md#adding-support-for-new-models). -2. Execute the conversion process layer-by-layer, using [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md#hugging-face-to-maxtext) or [to_huggingface.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md#maxtext-to-hugging-face). +2. Execute the conversion process layer-by-layer, using [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md#hugging-face-to-maxtext) or [to_huggingface.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md#maxtext-to-hugging-face). - If the tensor shape are not matched after conversion, error message will print out the parameter name that caused error. 3. After the conversion is done, run a decode to check the correctness of the generated code. @@ -73,7 +73,7 @@ python3 -m maxtext.inference.decode model_name=gemma3-4b tokenizer_path=src/maxt ``` If outputs are wrong, you can use jax.debug.print() to print the layer-wise mean/max/min values for debugging. -4. To further validate the converted checkpoint, we recommend to use the [forward_pass_logit_checker.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md#verifying-conversion-correctness) to compare the original ckpt with the converted ckpt: +4. To further validate the converted checkpoint, we recommend to use the [forward_pass_logit_checker.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md#verifying-conversion-correctness) to compare the original ckpt with the converted ckpt: ```bash python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \ tokenizer_path=assets/tokenizers/ \ @@ -121,5 +121,5 @@ Run the [One-shot agent Jyputer notebook](./baselines/one-shot-agent.ipynb) ### Prompt-chain Agent: ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.prompt_chain --target_model= \ - --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent --api_key= + --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key= ``` \ No newline at end of file diff --git a/src/maxtext/experimental/agent/ckpt_conversion_agent/baselines/one-shot-agent.ipynb b/src/maxtext/experimental/agent/ckpt_conversion_agent/baselines/one-shot-agent.ipynb index 0f3a149474..6c82fd4489 100644 --- a/src/maxtext/experimental/agent/ckpt_conversion_agent/baselines/one-shot-agent.ipynb +++ b/src/maxtext/experimental/agent/ckpt_conversion_agent/baselines/one-shot-agent.ipynb @@ -1,458 +1,458 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "bc539d4f", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright 2025 Google LLC\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "1f25b113", - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "%pip install -U -q 'google-genai>=1.0.0'" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "09a81a73", - "metadata": {}, - "outputs": [], - "source": [ - "from google import genai\n", - "from IPython.display import Markdown" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf2eab8b", - "metadata": {}, - "outputs": [], - "source": [ - "GOOGLE_API_KEY = \"\"\n", - "\n", - "client = genai.Client(api_key=GOOGLE_API_KEY)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f51eb3cd", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_ID = \"gemini-2.0-pro\"\n", - "target_model = \"Gemma3\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c7908d62", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "id": "bc539d4f", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2025 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Uploaded file 'files/96zbvn0jl8v9' as: https://generativelanguage.googleapis.com/v1beta/files/96zbvn0jl8v9\n", - "Uploaded file 'files/i81ci5tcyjwa' as: https://generativelanguage.googleapis.com/v1beta/files/i81ci5tcyjwa\n", - "Uploaded file 'files/poo15cv7l54e' as: https://generativelanguage.googleapis.com/v1beta/files/poo15cv7l54e\n", - "Uploaded file 'files/84qrwp42q92e' as: https://generativelanguage.googleapis.com/v1beta/files/84qrwp42q92e\n" - ] - } - ], - "source": [ - "param_file = client.files.upload(file=\"context/param_mapping.py\")\n", - "shape_file = client.files.upload(file=\"context/hf_shape.py\")\n", - "\n", - "print(f\"Uploaded file '{param_file.name}' as: {param_file.uri}\")\n", - "print(f\"Uploaded file '{shape_file.name}' as: {shape_file.uri}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "a8b3dcf0", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 2, + "id": "1f25b113", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install -U -q 'google-genai>=1.0.0'" + ] + }, { - "data": { - "text/markdown": [ - "```python\n", - "\"\"\"\n", - " Copyright 2025 Google LLC\n", - "\n", - " Licensed under the Apache License, Version 2.0 (the \"License\");\n", - " you may not use this file except in compliance with the License.\n", - " You may obtain a copy of the License at\n", - "\n", - " https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - " Unless required by applicable law or agreed to in writing, software\n", - " distributed under the License is distributed on an \"AS IS\" BASIS,\n", - " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - " See the License for the specific language governing permissions and\n", - " limitations under the License.\n", - " \"\"\"\n", - "\n", - "import numpy as np\n", - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "\n", - "def Gemma3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):\n", - " \"\"\"Returns mapping between MaxText and HuggingFace Gemma3 weight paths.\n", - "\n", - " Args:\n", - " config (dict): Model configuration dictionary containing at least 'num_hidden_layers'.\n", - " scan_layers (bool, optional): Whether the MaxText model uses layer scanning optimization.\n", - " When True, decoder layers are stacked into a single tensor [dim1, #layers, dim2].\n", - " Defaults to False.\n", - "\n", - " Returns:\n", - " dict: A mapping where:\n", - " - Keys are MaxText parameter paths\n", - " - Values are either:\n", - " - Single strings (HF parameter path) for unscanned parameters\n", - " - Lists of strings (HF parameter paths) for stacked layers when scan_layers=True\n", - " \"\"\"\n", - "\n", - " nlayers = config[\"num_hidden_layers\"]\n", - " mapping = {\n", - " \"params-token_embedder-embedding\": \"model.embed_tokens.weight\",\n", - " \"params-decoder-decoder_norm-scale\": \"model.norm.weight\",\n", - " }\n", - " if scan_layers:\n", - " mapping = {\n", - " **mapping,\n", - " \"params-decoder-layers-attention-key-kernel\": [\n", - " f\"model.layers.{i}.self_attn.k_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-attention-value-kernel\": [\n", - " f\"model.layers.{i}.self_attn.v_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-attention-query-kernel\": [\n", - " f\"model.layers.{i}.self_attn.q_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-attention-out-kernel\": [\n", - " f\"model.layers.{i}.self_attn.o_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-mlp-wi_0-kernel\": [\n", - " f\"model.layers.{i}.mlp.gate_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-mlp-wi_1-kernel\": [\n", - " f\"model.layers.{i}.mlp.up_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-mlp-wo-kernel\": [\n", - " f\"model.layers.{i}.mlp.down_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-rms_norm-scale\": [\n", - " f\"model.layers.{i}.input_layernorm.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-ffn_rms_norm-scale\": [\n", - " f\"model.layers.{i}.post_attention_layernorm.weight\" for i in range(nlayers)\n", - " ],\n", - " }\n", - " else:\n", - " for layer_idx in range(nlayers):\n", - " layer_mapping = {\n", - " f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": f\"model.layers.{layer_idx}.self_attn.k_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": f\"model.layers.{layer_idx}.self_attn.v_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": f\"model.layers.{layer_idx}.self_attn.q_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": f\"model.layers.{layer_idx}.self_attn.o_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": f\"model.layers.{layer_idx}.mlp.gate_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": f\"model.layers.{layer_idx}.mlp.up_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": f\"model.layers.{layer_idx}.mlp.down_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": f\"model.layers.{layer_idx}.input_layernorm.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": f\"model.layers.{layer_idx}.post_attention_layernorm.weight\",\n", - " }\n", - " mapping = {**mapping, **layer_mapping}\n", - " return mapping\n", - "\n", - "\n", - "def Gemma3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):\n", - " \"\"\"Creates parameter transformation functions for converting between MaxText and\n", - " HuggingFace formats.\n", - "\n", - " This function generates a mapping of transformation functions that handle the necessary\n", - " conversions between MaxText and HuggingFace parameter formats, including operations like\n", - " padding, reshaping, and scaling.\n", - "\n", - " Args:\n", - " config (dict): Model configuration dictionary that must contain:\n", - " - num_hidden_layers (int): Number of layers in the model\n", - " - head_dim (int): Dimension of attention heads\n", - " - hidden_size (int): Model's hidden dimension size\n", - "\n", - " scan_layers (bool, optional): Controls the output format for layer parameters:\n", - " - True: Returns transformation functions for batched layer parameters\n", - " - False: Returns transformation functions for individual layer parameters\n", - " Defaults to False.\n", - "\n", - " saving_to_hf (bool, optional): Determines the direction of transformation:\n", - " - True: MaxText → HuggingFace conversion\n", - " - False: HuggingFace → MaxText conversion\n", - " Defaults to False.\n", - "\n", - " Returns:\n", - " dict: Parameter transformation mapping where:\n", - " - Keys: MaxText parameter names (str)\n", - " - Values: Either:\n", - " - callable: Single transformation function\n", - " - list[callable]: List of transformation functions to be applied in sequence\n", - "\n", - " Transformation Details:\n", - " The function handles several types of parameter transformations:\n", - " 1. Embedding layer padding:\n", - " - HF shape: [vocab_size, d_model]\n", - " - MaxText shape: [padded_vocab_size, d_model] (padded for performance)\n", - " 2. Layer normalization scaling:\n", - " - Adds/subtracts 1.0 depending on direction\n", - " 3. Attention query scaling:\n", - " - Scales by sqrt(head_dim) or its inverse\n", - "\n", - " 4. Kernel reshaping:\n", - " - Handles dimension transposition and reshaping between formats\n", - " \"\"\"\n", - " nlayers = config[\"num_hidden_layers\"]\n", - "\n", - " def pad_hf_embedding_layer(input_tensor, target_shape):\n", - " \"\"\"Pads the HF embedding layer to match the MaxText embedding layer's shape.\n", - "\n", - " Note:\n", - " HF embedding weights shape = [vocab_size,d_model]\n", - " MaxText embedding weights shape = [padded_vocab_size,d_model]\n", - " MaxText pad Gemma3 embedding to padded_vocab_size for better performance.\n", - " \"\"\"\n", - " # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype\n", - " normalizer = np.dtype(\"float32\").type(config[\"hidden_size\"] ** 0.5)\n", - "\n", - " def to_hf():\n", - " target_tensor = input_tensor[: target_shape[0], : target_shape[1]]\n", - " # target_tensor = target_tensor / normalizer # no scale factor for embedding\n", - " target_tensor = target_tensor.astype(input_tensor.dtype)\n", - " return target_tensor\n", - "\n", - " def from_hf():\n", - " target_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)\n", - " target_tensor[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor\n", - " # target_tensor = target_tensor * normalizer # no scale factor for embedding\n", - " target_tensor = target_tensor.astype(input_tensor.dtype)\n", - " return target_tensor\n", - "\n", - " if saving_to_hf:\n", - " return to_hf()\n", - " else:\n", - " return from_hf()\n", - "\n", - " def reshape_kernel(input_tensor, target_shape):\n", - " def to_hf():\n", - " flipped_target_shape = np.flip(np.array(target_shape))\n", - " return input_tensor.reshape(flipped_target_shape).T\n", - "\n", - " def from_hf():\n", - " return input_tensor.T.reshape(target_shape)\n", - "\n", - " if saving_to_hf:\n", - " return to_hf()\n", - " else:\n", - " return from_hf()\n", - "\n", - " def scale_rmsnorm_layer(input_tensor, target_shape):\n", - " def to_hf():\n", - " return (input_tensor - 1.0).reshape(target_shape)\n", - "\n", - " def from_hf():\n", - " return (input_tensor + 1.0).reshape(target_shape)\n", - "\n", - " if saving_to_hf:\n", - " return to_hf()\n", - " else:\n", - " return from_hf()\n", - "\n", - " def scale_query_layer(input_tensor, target_shape):\n", - " def to_hf():\n", - " depth_scale = np.dtype(\"float32\").type(np.sqrt(config[\"head_dim\"]))\n", - " return (input_tensor * depth_scale).astype(input_tensor.dtype)\n", - "\n", - " def from_hf():\n", - " depth_scale = np.dtype(\"float32\").type(1 / np.sqrt(config[\"head_dim\"]))\n", - " return (input_tensor * depth_scale).astype(input_tensor.dtype)\n", - "\n", - " if saving_to_hf:\n", - " return to_hf()\n", - " else:\n", - " return from_hf()\n", - "\n", - " mapping = {\n", - " \"params-token_embedder-embedding\": pad_hf_embedding_layer,\n", - " \"params-decoder-decoder_norm-scale\": scale_rmsnorm_layer,\n", - " }\n", - " if scan_layers:\n", - " mapping = {\n", - " **mapping,\n", - " \"params-decoder-layers-attention-query-kernel\": [\n", - " reshape_kernel,\n", - " scale_query_layer,\n", - " ],\n", - " \"params-decoder-layers-attention-key-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-attention-value-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-mlp-wo-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-mlp-wi_1-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-mlp-wi_0-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-attention-out-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-rms_norm-scale\": scale_rmsnorm_layer,\n", - " \"params-decoder-layers-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n", - " }\n", - " else:\n", - " for layer_idx in range(nlayers):\n", - " mapping = {\n", - " **mapping,\n", - " f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": [\n", - " reshape_kernel,\n", - " scale_query_layer,\n", - " ],\n", - " f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": scale_rmsnorm_layer,\n", - " f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n", - " }\n", - " return mapping\n", - "\n", - "\n", - "def Gemma3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):\n", - " \"\"\"Returns mapping between HuggingFace weights path and weights shape.\n", - "\n", - " Args:\n", - " config (dict): Model configuration dictionary, defined in `model_configs.py`\n", - "\n", - " Returns:\n", - " dict: A mapping where:\n", - " - Keys are HuggingFace model parameter paths\n", - " - Values are parameter shape as a List\n", - " \"\"\"\n", - "\n", - " mapping = {\n", - " \"model.embed_tokens.weight\": [config[\"vocab_size\"], config[\"hidden_size\"]],\n", - " \"model.norm.weight\": [config[\"hidden_size\"]],\n", - " }\n", - " for layer_idx in range(config[\"num_hidden_layers\"]):\n", - " layer_mapping = {\n", - " f\"model.layers.{layer_idx}.input_layernorm.weight\": [config[\"hidden_size\"]],\n", - " f\"model.layers.{layer_idx}.post_attention_layernorm.weight\": [config[\"hidden_size\"]],\n", - " f\"model.layers.{layer_idx}.self_attn.q_proj.weight\": [\n", - " config[\"num_attention_heads\"] * config[\"head_dim\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.self_attn.k_proj.weight\": [\n", - " config[\"num_key_value_heads\"] * config[\"head_dim\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.self_attn.v_proj.weight\": [\n", - " config[\"num_key_value_heads\"] * config[\"head_dim\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.self_attn.o_proj.weight\": [\n", - " config[\"hidden_size\"],\n", - " config[\"num_attention_heads\"] * config[\"head_dim\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.mlp.gate_proj.weight\": [\n", - " config[\"intermediate_size\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.mlp.up_proj.weight\": [\n", - " config[\"intermediate_size\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.mlp.down_proj.weight\": [\n", - " config[\"hidden_size\"],\n", - " config[\"intermediate_size\"],\n", - " ],\n", - " }\n", - " mapping = {**mapping, **layer_mapping}\n", - " return mapping\n", - "\n", - "```" + "cell_type": "code", + "execution_count": 3, + "id": "09a81a73", + "metadata": {}, + "outputs": [], + "source": [ + "from google import genai\n", + "from IPython.display import Markdown" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf2eab8b", + "metadata": {}, + "outputs": [], + "source": [ + "GOOGLE_API_KEY = \"\"\n", + "\n", + "client = genai.Client(api_key=GOOGLE_API_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f51eb3cd", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_ID = \"gemini-2.0-pro\"\n", + "target_model = \"Gemma3\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7908d62", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Uploaded file 'files/96zbvn0jl8v9' as: https://generativelanguage.googleapis.com/v1beta/files/96zbvn0jl8v9\n", + "Uploaded file 'files/i81ci5tcyjwa' as: https://generativelanguage.googleapis.com/v1beta/files/i81ci5tcyjwa\n", + "Uploaded file 'files/poo15cv7l54e' as: https://generativelanguage.googleapis.com/v1beta/files/poo15cv7l54e\n", + "Uploaded file 'files/84qrwp42q92e' as: https://generativelanguage.googleapis.com/v1beta/files/84qrwp42q92e\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "param_file = client.files.upload(file=\"context/param_mapping.py\")\n", + "shape_file = client.files.upload(file=\"context/hf_shape.py\")\n", + "\n", + "print(f\"Uploaded file '{param_file.name}' as: {param_file.uri}\")\n", + "print(f\"Uploaded file '{shape_file.name}' as: {shape_file.uri}\")" ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a8b3dcf0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "```python\n", + "\"\"\"\n", + " Copyright 2025 Google LLC\n", + "\n", + " Licensed under the Apache License, Version 2.0 (the \"License\");\n", + " you may not use this file except in compliance with the License.\n", + " You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + " Unless required by applicable law or agreed to in writing, software\n", + " distributed under the License is distributed on an \"AS IS\" BASIS,\n", + " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + " See the License for the specific language governing permissions and\n", + " limitations under the License.\n", + " \"\"\"\n", + "\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "\n", + "def Gemma3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):\n", + " \"\"\"Returns mapping between MaxText and HuggingFace Gemma3 weight paths.\n", + "\n", + " Args:\n", + " config (dict): Model configuration dictionary containing at least 'num_hidden_layers'.\n", + " scan_layers (bool, optional): Whether the MaxText model uses layer scanning optimization.\n", + " When True, decoder layers are stacked into a single tensor [dim1, #layers, dim2].\n", + " Defaults to False.\n", + "\n", + " Returns:\n", + " dict: A mapping where:\n", + " - Keys are MaxText parameter paths\n", + " - Values are either:\n", + " - Single strings (HF parameter path) for unscanned parameters\n", + " - Lists of strings (HF parameter paths) for stacked layers when scan_layers=True\n", + " \"\"\"\n", + "\n", + " nlayers = config[\"num_hidden_layers\"]\n", + " mapping = {\n", + " \"params-token_embedder-embedding\": \"model.embed_tokens.weight\",\n", + " \"params-decoder-decoder_norm-scale\": \"model.norm.weight\",\n", + " }\n", + " if scan_layers:\n", + " mapping = {\n", + " **mapping,\n", + " \"params-decoder-layers-attention-key-kernel\": [\n", + " f\"model.layers.{i}.self_attn.k_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-attention-value-kernel\": [\n", + " f\"model.layers.{i}.self_attn.v_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-attention-query-kernel\": [\n", + " f\"model.layers.{i}.self_attn.q_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-attention-out-kernel\": [\n", + " f\"model.layers.{i}.self_attn.o_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-mlp-wi_0-kernel\": [\n", + " f\"model.layers.{i}.mlp.gate_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-mlp-wi_1-kernel\": [\n", + " f\"model.layers.{i}.mlp.up_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-mlp-wo-kernel\": [\n", + " f\"model.layers.{i}.mlp.down_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-rms_norm-scale\": [\n", + " f\"model.layers.{i}.input_layernorm.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-ffn_rms_norm-scale\": [\n", + " f\"model.layers.{i}.post_attention_layernorm.weight\" for i in range(nlayers)\n", + " ],\n", + " }\n", + " else:\n", + " for layer_idx in range(nlayers):\n", + " layer_mapping = {\n", + " f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": f\"model.layers.{layer_idx}.self_attn.k_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": f\"model.layers.{layer_idx}.self_attn.v_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": f\"model.layers.{layer_idx}.self_attn.q_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": f\"model.layers.{layer_idx}.self_attn.o_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": f\"model.layers.{layer_idx}.mlp.gate_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": f\"model.layers.{layer_idx}.mlp.up_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": f\"model.layers.{layer_idx}.mlp.down_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": f\"model.layers.{layer_idx}.input_layernorm.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": f\"model.layers.{layer_idx}.post_attention_layernorm.weight\",\n", + " }\n", + " mapping = {**mapping, **layer_mapping}\n", + " return mapping\n", + "\n", + "\n", + "def Gemma3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):\n", + " \"\"\"Creates parameter transformation functions for converting between MaxText and\n", + " HuggingFace formats.\n", + "\n", + " This function generates a mapping of transformation functions that handle the necessary\n", + " conversions between MaxText and HuggingFace parameter formats, including operations like\n", + " padding, reshaping, and scaling.\n", + "\n", + " Args:\n", + " config (dict): Model configuration dictionary that must contain:\n", + " - num_hidden_layers (int): Number of layers in the model\n", + " - head_dim (int): Dimension of attention heads\n", + " - hidden_size (int): Model's hidden dimension size\n", + "\n", + " scan_layers (bool, optional): Controls the output format for layer parameters:\n", + " - True: Returns transformation functions for batched layer parameters\n", + " - False: Returns transformation functions for individual layer parameters\n", + " Defaults to False.\n", + "\n", + " saving_to_hf (bool, optional): Determines the direction of transformation:\n", + " - True: MaxText → HuggingFace conversion\n", + " - False: HuggingFace → MaxText conversion\n", + " Defaults to False.\n", + "\n", + " Returns:\n", + " dict: Parameter transformation mapping where:\n", + " - Keys: MaxText parameter names (str)\n", + " - Values: Either:\n", + " - callable: Single transformation function\n", + " - list[callable]: List of transformation functions to be applied in sequence\n", + "\n", + " Transformation Details:\n", + " The function handles several types of parameter transformations:\n", + " 1. Embedding layer padding:\n", + " - HF shape: [vocab_size, d_model]\n", + " - MaxText shape: [padded_vocab_size, d_model] (padded for performance)\n", + " 2. Layer normalization scaling:\n", + " - Adds/subtracts 1.0 depending on direction\n", + " 3. Attention query scaling:\n", + " - Scales by sqrt(head_dim) or its inverse\n", + "\n", + " 4. Kernel reshaping:\n", + " - Handles dimension transposition and reshaping between formats\n", + " \"\"\"\n", + " nlayers = config[\"num_hidden_layers\"]\n", + "\n", + " def pad_hf_embedding_layer(input_tensor, target_shape):\n", + " \"\"\"Pads the HF embedding layer to match the MaxText embedding layer's shape.\n", + "\n", + " Note:\n", + " HF embedding weights shape = [vocab_size,d_model]\n", + " MaxText embedding weights shape = [padded_vocab_size,d_model]\n", + " MaxText pad Gemma3 embedding to padded_vocab_size for better performance.\n", + " \"\"\"\n", + " # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype\n", + " normalizer = np.dtype(\"float32\").type(config[\"hidden_size\"] ** 0.5)\n", + "\n", + " def to_hf():\n", + " target_tensor = input_tensor[: target_shape[0], : target_shape[1]]\n", + " # target_tensor = target_tensor / normalizer # no scale factor for embedding\n", + " target_tensor = target_tensor.astype(input_tensor.dtype)\n", + " return target_tensor\n", + "\n", + " def from_hf():\n", + " target_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)\n", + " target_tensor[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor\n", + " # target_tensor = target_tensor * normalizer # no scale factor for embedding\n", + " target_tensor = target_tensor.astype(input_tensor.dtype)\n", + " return target_tensor\n", + "\n", + " if saving_to_hf:\n", + " return to_hf()\n", + " else:\n", + " return from_hf()\n", + "\n", + " def reshape_kernel(input_tensor, target_shape):\n", + " def to_hf():\n", + " flipped_target_shape = np.flip(np.array(target_shape))\n", + " return input_tensor.reshape(flipped_target_shape).T\n", + "\n", + " def from_hf():\n", + " return input_tensor.T.reshape(target_shape)\n", + "\n", + " if saving_to_hf:\n", + " return to_hf()\n", + " else:\n", + " return from_hf()\n", + "\n", + " def scale_rmsnorm_layer(input_tensor, target_shape):\n", + " def to_hf():\n", + " return (input_tensor - 1.0).reshape(target_shape)\n", + "\n", + " def from_hf():\n", + " return (input_tensor + 1.0).reshape(target_shape)\n", + "\n", + " if saving_to_hf:\n", + " return to_hf()\n", + " else:\n", + " return from_hf()\n", + "\n", + " def scale_query_layer(input_tensor, target_shape):\n", + " def to_hf():\n", + " depth_scale = np.dtype(\"float32\").type(np.sqrt(config[\"head_dim\"]))\n", + " return (input_tensor * depth_scale).astype(input_tensor.dtype)\n", + "\n", + " def from_hf():\n", + " depth_scale = np.dtype(\"float32\").type(1 / np.sqrt(config[\"head_dim\"]))\n", + " return (input_tensor * depth_scale).astype(input_tensor.dtype)\n", + "\n", + " if saving_to_hf:\n", + " return to_hf()\n", + " else:\n", + " return from_hf()\n", + "\n", + " mapping = {\n", + " \"params-token_embedder-embedding\": pad_hf_embedding_layer,\n", + " \"params-decoder-decoder_norm-scale\": scale_rmsnorm_layer,\n", + " }\n", + " if scan_layers:\n", + " mapping = {\n", + " **mapping,\n", + " \"params-decoder-layers-attention-query-kernel\": [\n", + " reshape_kernel,\n", + " scale_query_layer,\n", + " ],\n", + " \"params-decoder-layers-attention-key-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-attention-value-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-mlp-wo-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-mlp-wi_1-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-mlp-wi_0-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-attention-out-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-rms_norm-scale\": scale_rmsnorm_layer,\n", + " \"params-decoder-layers-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n", + " }\n", + " else:\n", + " for layer_idx in range(nlayers):\n", + " mapping = {\n", + " **mapping,\n", + " f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": [\n", + " reshape_kernel,\n", + " scale_query_layer,\n", + " ],\n", + " f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": scale_rmsnorm_layer,\n", + " f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n", + " }\n", + " return mapping\n", + "\n", + "\n", + "def Gemma3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):\n", + " \"\"\"Returns mapping between HuggingFace weights path and weights shape.\n", + "\n", + " Args:\n", + " config (dict): Model configuration dictionary, defined in `model_configs.py`\n", + "\n", + " Returns:\n", + " dict: A mapping where:\n", + " - Keys are HuggingFace model parameter paths\n", + " - Values are parameter shape as a List\n", + " \"\"\"\n", + "\n", + " mapping = {\n", + " \"model.embed_tokens.weight\": [config[\"vocab_size\"], config[\"hidden_size\"]],\n", + " \"model.norm.weight\": [config[\"hidden_size\"]],\n", + " }\n", + " for layer_idx in range(config[\"num_hidden_layers\"]):\n", + " layer_mapping = {\n", + " f\"model.layers.{layer_idx}.input_layernorm.weight\": [config[\"hidden_size\"]],\n", + " f\"model.layers.{layer_idx}.post_attention_layernorm.weight\": [config[\"hidden_size\"]],\n", + " f\"model.layers.{layer_idx}.self_attn.q_proj.weight\": [\n", + " config[\"num_attention_heads\"] * config[\"head_dim\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.self_attn.k_proj.weight\": [\n", + " config[\"num_key_value_heads\"] * config[\"head_dim\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.self_attn.v_proj.weight\": [\n", + " config[\"num_key_value_heads\"] * config[\"head_dim\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.self_attn.o_proj.weight\": [\n", + " config[\"hidden_size\"],\n", + " config[\"num_attention_heads\"] * config[\"head_dim\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.mlp.gate_proj.weight\": [\n", + " config[\"intermediate_size\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.mlp.up_proj.weight\": [\n", + " config[\"intermediate_size\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.mlp.down_proj.weight\": [\n", + " config[\"hidden_size\"],\n", + " config[\"intermediate_size\"],\n", + " ],\n", + " }\n", + " mapping = {**mapping, **layer_mapping}\n", + " return mapping\n", + "\n", + "```" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = f\"\"\"\n", + " You are a code assist to help me find the checkpoint conversion from MaxText to HuggingFace. \n", + " The checkpoint does not fuse QKV vectors. \n", + " The transformer configs should be completely aligned with given model config for {target_model}\n", + " You need to generate the following code functions of {target_model} Model:\n", + " {target_model}_MAXTEXT_TO_HF_PARAM_MAPPING(); \n", + " {target_model}_MAXTEXT_TO_HF_PARAM_HOOK_FN();\n", + " {target_model}_HF_WEIGHTS_TO_SHAPE_MAPPING();\n", + "\"\"\"\n", + "\n", + "response = client.models.generate_content(model=MODEL_ID, contents=[prompt, param_file, shape_file])\n", + "\n", + "Markdown(response.text)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "agent_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" } - ], - "source": [ - "prompt = f\"\"\"\n", - " You are a code assist to help me find the checkpoint conversion from maxtext to huggingface. \n", - " The checkpoint does not fuse QKV vectors. \n", - " The transformer configs should be completely aligned with given model config for {target_model}\n", - " You need to generate the following code functions of {target_model} Model:\n", - " {target_model}_MAXTEXT_TO_HF_PARAM_MAPPING(); \n", - " {target_model}_MAXTEXT_TO_HF_PARAM_HOOK_FN();\n", - " {target_model}_HF_WEIGHTS_TO_SHAPE_MAPPING();\n", - "\"\"\"\n", - "\n", - "response = client.models.generate_content(model=MODEL_ID, contents=[prompt, param_file, shape_file])\n", - "\n", - "Markdown(response.text)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "agent_env", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 418fd970db..ff8cdde8a8 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -478,7 +478,7 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F if self._iterator is not None: # Follow MaxText's logic to handle multi-process saving - # Logic extracted from src/MaxText/common/checkpointing.py:save_checkpoint + # Logic extracted from src/maxtext/common/checkpointing.py:save_checkpoint data_iterator = self._iterator if not isinstance(data_iterator, list): data_iterator = [data_iterator] diff --git a/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py b/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py index 05c567e924..6e37e2df42 100644 --- a/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py +++ b/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py @@ -36,7 +36,7 @@ from itertools import islice from absl import app -from MaxText import pyconfig +from maxtext.configs import pyconfig from maxtext.utils import model_creation_utils from maxtext.input_pipeline import input_pipeline_interface from maxtext.utils import maxtext_utils diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index d77270f62e..5083a2194f 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -85,10 +85,10 @@ def get_maxtext_model(config, devices=None): """ Load MaxText model with Tunix adapter. # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. - # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if + # To create a scanned checkpoint, you can use /maxtext/src/maxtext/checkpoint_conversion/to_maxtext.py and if # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags: # export USE_PATHWAYS=1 - # python src/MaxText/checkpoint_conversion/to_maxtext.py \ + # python src/maxtext/checkpoint_conversion/to_maxtext.py \ # --model_name="gemma2-2b" \ # --base_output_directory="/path/to/your/output/directory" \ # --scan_layers=True \ diff --git a/tests/end_to_end/gpu/a3/test_llama2_7b.sh b/tests/end_to_end/gpu/a3/test_llama2_7b.sh index a832c66500..49d6915f54 100644 --- a/tests/end_to_end/gpu/a3/test_llama2_7b.sh +++ b/tests/end_to_end/gpu/a3/test_llama2_7b.sh @@ -16,7 +16,7 @@ idx=$(date +%Y-%m-%d-%H-%M) export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs export ASYNC_CHECKPOINTING=false -# We install torch CPU because the checkpoint conversion script "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/llama_or_mistral_ckpt.py does not need a TPU/GPU +# We install torch CPU because the checkpoint conversion script "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"/llama_or_mistral_ckpt.py does not need a TPU/GPU python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu # We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint @@ -29,7 +29,7 @@ gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/ # `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext-gpu -#Next, run the conversion script `src/MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path` +#Next, run the conversion script `src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path` python3 -m maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH} # We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory exactly inside `CONVERTED_CHECKPOINT_PATH`. This way it is easier to use this path in the `train.py` and `decode.py` commands diff --git a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md index 0efdf5729c..10651f7186 100644 --- a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md +++ b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md @@ -57,8 +57,8 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ ## Checkpoint conversion To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. Currently for V3, V3.1, and R1, it uses mixed precision fp8 & bf16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16: -* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly. -* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding. +* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly. +* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding. ## Fine-tuning diff --git a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh index 2e22f344b8..9f5d4c320b 100644 --- a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh @@ -20,8 +20,8 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V2-Lite' # Installing torch for checkpoint conversion and forward_pass_logit_checker.py python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# e.g., $HOME/maxtext/src/MaxText -export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}" +# e.g., $HOME/maxtext/src/maxtext +export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}" if [ -z "${BASE_OUTPUT_PATH}" ]; then # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. diff --git a/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh index 016d435133..f2b0d62dad 100644 --- a/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh @@ -18,8 +18,8 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3' # Installing torch for checkpoint conversion and forward_pass_logit_checker.py python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# e.g., $HOME/maxtext/src/MaxText -export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}" +# e.g., $HOME/maxtext/src/maxtext +export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}" if [ -z "${BASE_OUTPUT_PATH}" ]; then # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. diff --git a/tests/end_to_end/tpu/gemma/Run_Gemma.md b/tests/end_to_end/tpu/gemma/Run_Gemma.md index 33149c7d21..2fe8141156 100644 --- a/tests/end_to_end/tpu/gemma/Run_Gemma.md +++ b/tests/end_to_end/tpu/gemma/Run_Gemma.md @@ -19,7 +19,7 @@ Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights. You will have to consent to license for Gemma using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials). -After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [tests/end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/tests/end_to_end/tpu/gemma). +After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [tests/end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/tests/end_to_end/tpu/gemma). ## MaxText supports pretraining and finetuning with high performance diff --git a/tests/end_to_end/tpu/gemma3/Run_Gemma3.md b/tests/end_to_end/tpu/gemma3/Run_Gemma3.md index 962ae4a803..f95f39b54e 100644 --- a/tests/end_to_end/tpu/gemma3/Run_Gemma3.md +++ b/tests/end_to_end/tpu/gemma3/Run_Gemma3.md @@ -29,7 +29,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml model_n ``` ## Checkpoint Conversion -To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle](https://www.kaggle.com/models/google/gemma-3/flax/). You will need to accept the Gemma3 license through your Kaggle account and utilize your Kaggle [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials) for authentication. Once the weights are downloaded to your GCS bucket, use the [checkpoint conversion utils](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion#usage) to transform the checkpoint into a format compatible with MaxText. This script will also upload the converted checkpoints to a Google Cloud Storage (GCS) bucket. +To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle](https://www.kaggle.com/models/google/gemma-3/flax/). You will need to accept the Gemma3 license through your Kaggle account and utilize your Kaggle [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials) for authentication. Once the weights are downloaded to your GCS bucket, use the [checkpoint conversion utils](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion#usage) to transform the checkpoint into a format compatible with MaxText. This script will also upload the converted checkpoints to a Google Cloud Storage (GCS) bucket. ## Fine-tuning After the conversion, you will have a MaxText compatible checkpoint which allows you to fine-tune it with different datasets. One example command to fine-tune a Gemma3-4B model is as follows: diff --git a/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md b/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md index 0c70557b4f..de505cf439 100644 --- a/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md +++ b/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md @@ -31,7 +31,7 @@ hf download [openai/gpt-oss-20b|openai/gpt-oss-120b] --local-dir --output-path= --dtype-str=bf16 @@ -39,14 +39,14 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.dequantize_mxfp4 --i 3. Once downloaded and converted to BF16: -* run [convert_gpt_oss_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gpt_oss_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) scanned format for training and fine-tuning. +* run [convert_gpt_oss_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt_oss_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) scanned format for training and fine-tuning. ``` python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gpt_oss_ckpt --base-model-path \ --maxtext-model-path --model-size [gpt-oss-20b|gpt-oss-120b] ``` -* run [convert_gpt_oss_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gpt_oss_unscanned_ckpt.py) to convert the checkpoint to unscanned format in Orbax for decoding. +* run [convert_gpt_oss_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt_oss_unscanned_ckpt.py) to convert the checkpoint to unscanned format in Orbax for decoding. ``` python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gpt_oss_unscanned_ckpt --base-model-path \ diff --git a/tests/end_to_end/tpu/mixtral/Run_Mixtral.md b/tests/end_to_end/tpu/mixtral/Run_Mixtral.md index c0f93de5e5..243041ff27 100644 --- a/tests/end_to_end/tpu/mixtral/Run_Mixtral.md +++ b/tests/end_to_end/tpu/mixtral/Run_Mixtral.md @@ -19,7 +19,7 @@ [Mixtral](https://mistral.ai/news/mixtral-of-experts/) is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture. -To get started, follow the instructions at [mistral-inference](https://github.com/mistralai/mistral-inference) to download the model. Once downloaded, run [llama_or_mistral_ckpt.py](../../../src/MaxText/llama_or_mistral_ckpt.py) to convert the checkpoint for MaxText compatibility. You can then proceed with decoding, pretraining, and finetuning. You could find Mixtral 8x7B example in the [tests/end_to_end/tpu/mixtral/8x7b](../mixtral/8x7b) test scripts. +To get started, follow the instructions at [mistral-inference](https://github.com/mistralai/mistral-inference) to download the model. Once downloaded, run [llama_or_mistral_ckpt.py](../../../src/maxtext/llama_or_mistral_ckpt.py) to convert the checkpoint for MaxText compatibility. You can then proceed with decoding, pretraining, and finetuning. You could find Mixtral 8x7B example in the [tests/end_to_end/tpu/mixtral/8x7b](../mixtral/8x7b) test scripts. Additionally, Mixtral integrates with [MegaBlocks](https://arxiv.org/abs/2211.15841), an efficient dropless MoE strategy, which can be activated by setting both sparse_matmul and megablox flags to True (default). diff --git a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh index 04a7596688..fd11eed7fa 100644 --- a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh +++ b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh @@ -20,8 +20,8 @@ export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct' # Installing torch for checkpoint conversion and forward_pass_logit_checker.py python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# e.g., $HOME/maxtext/src/MaxText -export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}" +# e.g., $HOME/maxtext/src/maxtext +export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}" if [ -z "${BASE_OUTPUT_PATH}" ]; then # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. diff --git a/tests/end_to_end/tpu/test_grpo.sh b/tests/end_to_end/tpu/test_grpo.sh index 21bf5a6174..e69dbca005 100644 --- a/tests/end_to_end/tpu/test_grpo.sh +++ b/tests/end_to_end/tpu/test_grpo.sh @@ -54,6 +54,6 @@ ici_data_parallelism=${NUM_SAMPLERS} ici_tensor_parallelism=${DEVICES_PER_SAMPLE profiler=xplane skip_first_n_steps_for_profiler=10 profiler_steps=2" JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ - python3 src/MaxText/experimental/rl/grpo_trainer.py src/MaxText/experimental/rl/grpo.yml \ - ${COMMON_ARGS} ${TRAINING_ARGS} src/MaxText/experimental/rl/grpo_inference.yml \ + python3 src/maxtext/experimental/rl/grpo_trainer.py src/maxtext/experimental/rl/grpo.yml \ + ${COMMON_ARGS} ${TRAINING_ARGS} src/maxtext/experimental/rl/grpo_inference.yml \ ${COMMON_ARGS} ${INFERENCE_ARGS} diff --git a/tests/inference/test_llama2_7b_bf16.sh b/tests/inference/test_llama2_7b_bf16.sh index 82f360b7ec..b02fa32d9a 100755 --- a/tests/inference/test_llama2_7b_bf16.sh +++ b/tests/inference/test_llama2_7b_bf16.sh @@ -1,8 +1,8 @@ #!/bin/bash -CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" +CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}/configs/base.yml" if [ "${DECOUPLE_GCLOUD^^}" = "TRUE" ]; then - CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/decoupled_base_test.yml" + CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}/configs/decoupled_base_test.yml" fi # Define the arguments in an array diff --git a/tests/inference/test_llama2_7b_int8.sh b/tests/inference/test_llama2_7b_int8.sh index 4056467bc6..8e11e6ab48 100755 --- a/tests/inference/test_llama2_7b_int8.sh +++ b/tests/inference/test_llama2_7b_int8.sh @@ -1,8 +1,8 @@ #!/bin/bash -CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" +CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}/configs/base.yml" if [ "${DECOUPLE_GCLOUD^^}" = "TRUE" ]; then - CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/decoupled_base_test.yml" + CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}/configs/decoupled_base_test.yml" fi # Define the arguments in an array diff --git a/tests/integration/smoke/train_gpu_smoke_test.py b/tests/integration/smoke/train_gpu_smoke_test.py index de5ca3951a..1df3ef630b 100644 --- a/tests/integration/smoke/train_gpu_smoke_test.py +++ b/tests/integration/smoke/train_gpu_smoke_test.py @@ -53,6 +53,8 @@ def test_tiny_config(self): "enable_goodput_recording=False", "enable_checkpoint_cloud_logger=False", "monitor_goodput=False", + "abort_on_nan_loss=False", + "abort_on_inf_loss=False", ] ) diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index c176e53883..b7a31acae6 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -34,7 +34,7 @@ # For example: # tests/assets/logits_generation/golden_llama2-7b_export.ipynb -"""Check if the logits generated by a model's src/MaxText/HF implementation matches golden logits for the same inputs""" +"""Check if the logits generated by a model's src/maxtext/HF implementation matches golden logits for the same inputs""" import argparse import os diff --git a/tools/dev/code_style.sh b/tools/dev/code_style.sh index c540715158..44af4d74ec 100755 --- a/tools/dev/code_style.sh +++ b/tools/dev/code_style.sh @@ -18,7 +18,7 @@ set -e # Exit immediately if any command fails REPO_ROOT="${MAXTEXT_REPO_ROOT:-$PWD}" -FOLDERS_TO_FORMAT=("${MAXTEXT_PKG_DIR:-${REPO_ROOT}/src/MaxText}" "${REPO_ROOT}/pedagogical_examples") +FOLDERS_TO_FORMAT=("${MAXTEXT_PKG_DIR:-${REPO_ROOT}/src/maxtext}") LINE_LENGTH=$(grep -E "^max-line-length=" pylintrc | cut -d '=' -f 2) # Check for --check flag