docs/guides/checkpointing_solutions/convert_checkpoint.md: 9 additions & 5 deletions
@@ -1,3 +1,5 @@
+(checkpoint-conversion)=
+
 # Checkpoint Conversion Utilities
This guide provides instructions to use [checkpoint conversion scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion) to convert model checkpoints bidirectionally between Hugging Face and MaxText formats.
@@ -23,10 +25,12 @@ The following models are supported:
 ## Prerequisites

-- MaxText must be installed in a Python virtual environment using the `maxtext[tpu]` option. For instructions on installing MaxText on your VM, please refer to the official [installation documentation](../../install_maxtext.md).
+- MaxText must be installed in a Python virtual environment using the `maxtext[tpu]` option. For instructions on installing MaxText on your VM, please refer to the official [installation documentation](install-from-source).
 - Hugging Face model checkpoints are cached locally at `$HOME/.cache/huggingface/hub` before conversion. Ensure you have sufficient disk space.
 - Authenticate via the [Hugging Face CLI](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) if using private or gated models.

+(hf-to-maxtext)=
+
 ## Hugging Face to MaxText
Use the `to_maxtext.py` script to convert a Hugging Face model checkpoint into a MaxText checkpoint. The script will automatically download the specified model from the Hugging Face Hub, perform conversion, and save converted checkpoints to the given output directory.
@@ -74,7 +78,7 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
 ### Key Parameters

 - `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7).
-- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information. **IMPORTANT:** This setting *must* match the `scan_layers` value used during model training or loading. A mismatch will cause PyTree loading errors (though MaxText will intercept these and raise a descriptive `ValueError` explaining the mismatch).
+- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [to the Checkpoints guide](checkpoints) for more information. **IMPORTANT:** This setting *must* match the `scan_layers` value used during model training or loading. A mismatch will cause PyTree loading errors (though MaxText will intercept these and raise a descriptive `ValueError` explaining the mismatch).
 - `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
 - `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local.
 - `hardware=cpu`: The conversion script runs on a CPU machine.
 - `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7).
 - `load_parameters_path`: The path to the MaxText Orbax checkpoint.
-- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information.
+- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [to the Checkpoints guide](checkpoints) for more information.
 - `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
 - `hardware=cpu`: The conversion script runs on a CPU machine.
 - `base_output_directory`: The path where the converted checkpoint will be stored; it can be Google Cloud Storage (GCS), Hugging Face Hub or local.
To ensure the conversion was successful, you can use the [test script](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/utils/forward_pass_logit_checker.py). It runs a forward pass on both the original and converted models and compares the output logits to verify conversion. It is used to verify the bidirectional conversion.
-> **Note:** This correctness test will only work when MaxText is installed from source by following the installation instructions [here](../../install_maxtext.md#from-source).
+> **Note:** This correctness test will only work when MaxText is installed from source by following the installation instructions [here](install-from-source).
 - `load_parameters_path`: The path to the MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
 - `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
-- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [here](../../reference/core_concepts/checkpoints.md) for more information.
+- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [to the Checkpoints guide](checkpoints) for more information.
 - `use_multimodal`: Indicates if multimodality is used.
 - `--run_hf_model` (Optional): Indicates if loading Hugging Face model from the hf_model_path. If not set, it will compare the maxtext logits with pre-saved golden logits.
 - `--hf_model_path` (Optional): The path to the Hugging Face checkpoint (if `--run_hf_model=True`).
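
Reader's note, not part of the changed file: the verification step described in this file compares output logits from the original and converted models. The sketch below only illustrates that kind of element-wise comparison in plain Python. It is not the logic of `forward_pass_logit_checker.py`; the function name `logits_match`, the tolerance defaults, and the sample values are all assumptions made for illustration.

```python
import math


def logits_match(reference, converted, atol=1e-2, rtol=1e-2):
    """Return (ok, max_abs_diff) for two equal-length sequences of logits.

    Hypothetical helper: the tolerances are illustrative defaults, not
    values taken from MaxText.
    """
    if len(reference) != len(converted):
        raise ValueError("logit vectors must have the same length")
    pairs = list(zip(reference, converted))
    # Largest absolute gap between corresponding logits.
    max_abs_diff = max((abs(r - c) for r, c in pairs), default=0.0)
    # Every pair must agree within the relative or absolute tolerance.
    ok = all(math.isclose(r, c, rel_tol=rtol, abs_tol=atol) for r, c in pairs)
    return ok, max_abs_diff


# Made-up numbers standing in for one token position's logits.
hf_logits = [2.31, -0.75, 0.10]
maxtext_logits = [2.312, -0.749, 0.101]
ok, max_abs_diff = logits_match(hf_logits, maxtext_logits)
print(ok, max_abs_diff)
```

Reporting the largest gap alongside the pass/fail flag makes it easier to tell a small numerical drift from a structural conversion error.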
docs/guides/data_input_pipeline.md: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ Currently MaxText has three data input pipelines:
|**[Hugging Face](data_input_pipeline/data_input_hf.md)**| datasets in [Hugging Face Hub](https://huggingface.co/datasets)<br>local/Cloud Storage datasets in json, parquet, arrow, csv, txt (sequential access) | no download needed, convenience; <br>multiple formats | limit scalability using the Hugging Face Hub (no limit using Cloud Storage); <br>non-deterministic with preemption<br>(deterministic without preemption)<br> |
|**[TFDS](data_input_pipeline/data_input_tfds.md)**| TFRecord (sequential access), available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview)| performant | only supports TFRecords; <br>non-deterministic with preemption<br>(deterministic without preemption) |
+(multihost-dataloading-best-practice)=
+
 ## Multihost dataloading best practice
Training in a multi-host environment presents unique challenges for data input pipelines. An effective data loading strategy must address three key issues:
docs/guides/data_input_pipeline/data_input_grain.md: 4 additions & 0 deletions
@@ -1,3 +1,5 @@
+(grain-pipeline)=
+
 # Grain pipeline

 ## The recommended input pipeline for determinism and resilience!
@@ -30,6 +32,8 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state
 - **Global shuffle**: This feature is only available when using Grain with [ArrayRecord](https://github.com/google/array_record) (random access) format, achieved by shuffling indices globally at the beginning of each epoch and then reading the elements according to the random order. This shuffle method effectively prevents local overfitting, leading to better training results.
 - **Hierarchical shuffle**: For sequential access format [Parquet](https://arrow.apache.org/docs/python/parquet.html), shuffle is performed by these steps: file shuffling, interleave from files, and window shuffle using a fixed size buffer.

+(using-grain)=
+
 ## Using Grain
1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random-access through row groups) and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord)(sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources/protocol.html) class.
docs/guides/model_bringup.md: 6 additions & 6 deletions
@@ -20,15 +20,15 @@ This documentation acts as the primary resource for efficiently integrating new
 ## 1. Architecture Analysis

-The first phase involves determining how the new model's architecture aligns with MaxText's existing capabilities. To facilitate this assessment, refer to the [MaxText architecture overview](../reference/architecture/architecture_overview.md) and [list of supported models](../reference/models/supported_models_and_architectures.md).
+The first phase involves determining how the new model's architecture aligns with MaxText's existing capabilities. To facilitate this assessment, refer to the [MaxText architecture overview](architecture-overview) and [list of supported models](supported-models).

-**Input Data Pipeline**: MaxText supports HuggingFace, Grain, and TFDS pipelines ([details](data_input_pipeline.md)). While synthetic data is typically used for initial performance benchmarks, the framework supports multiple modalities including text and image (audio and video - work in progress).
+**Input Data Pipeline**: MaxText supports HuggingFace, Grain, and TFDS pipelines ([details](data-input-pipeline)). While synthetic data is typically used for initial performance benchmarks, the framework supports multiple modalities including text and image (audio and video - work in progress).
**Tokenizer**: Supported [tokenizer options](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/input_pipeline/tokenizer.py) include `TikTokenTokenizer`, `SentencePieceTokenizer`, and `HFTokenizer`.
**Self-Attention & RoPE**: Available mechanisms include optimized [Flash Attention](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/maxtext/layers/attention_op.py#L1184) (supporting MHA, GQA, and MQA), Multi-head Latent Attention ([MLA](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/attention_mla.py)), and [Gated Delta Network](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/maxtext/models/qwen3.py#L358). MaxText also supports [Regular](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L108), [Llama](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L178), and [YaRN](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L282) variations of Rotary Positional Embeddings (RoPE).
-**Multi-Layer Perceptron (MLP)**: The framework supports both traditional dense models and Mixture of Experts (MoE) architectures, including [configurations](../reference/core_concepts/moe_configuration.md) for routed and shared experts.
+**Multi-Layer Perceptron (MLP)**: The framework supports both traditional dense models and Mixture of Experts (MoE) architectures, including [configurations](moe-configuration) for routed and shared experts.
**Normalization**: We support different [normalization strategies](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/normalizations.py), including RMSNorm and Gated RMSNorm. These can be configured before or after attention/MLP layers.
@@ -44,7 +44,7 @@ This step can be bypassed if the current MaxText codebase already supports all c
While most open-source models are distributed in Safetensors or PyTorch formats, MaxText requires conversion to the [Orbax](https://orbax.readthedocs.io/en/latest/) format.
-There are [two primary formats](../reference/core_concepts/checkpoints.md) for Orbax checkpoints within MaxText, and while both are technically compatible with training and inference, we recommend following these performance-optimized guidelines:
+There are [two primary formats](checkpoints) for Orbax checkpoints within MaxText, and while both are technically compatible with training and inference, we recommend following these performance-optimized guidelines:

 - **Scanned Format**: Recommended for **training** as it stacks layers for efficient processing via `jax.lax.scan`. To enable this, set `scan_layers=True`.
 - **Unscanned Format**: Recommended for **inference** to simplify loading individual layer parameters. To enable this, set `scan_layers=False`.
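
Reader's note, not part of the changed file: the two bullets above describe a difference in parameter-tree layout. The toy sketch below shows that difference with nested Python lists standing in for arrays. The key names (`layers_0`, `layers`, `kernel`) and both helper functions are hypothetical, and stacking along the leading axis is a simplification; the hunk does not state which names or axis MaxText uses.

```python
def stack_layers(unscanned, num_layers):
    """Unscanned -> scanned: merge per-layer entries into one stacked entry."""
    per_layer = [unscanned[f"layers_{i}"] for i in range(num_layers)]
    param_names = per_layer[0].keys()
    # Each parameter gains a new leading "layer" axis of length num_layers.
    return {
        "layers": {name: [layer[name] for layer in per_layer] for name in param_names}
    }


def unstack_layers(scanned):
    """Scanned -> unscanned: split the layer axis back into per-layer entries."""
    stacked = scanned["layers"]
    num_layers = len(next(iter(stacked.values())))
    return {
        f"layers_{i}": {name: values[i] for name, values in stacked.items()}
        for i in range(num_layers)
    }


# Two layers, each holding a single 2x2 "kernel".
unscanned = {
    "layers_0": {"kernel": [[1.0, 2.0], [3.0, 4.0]]},
    "layers_1": {"kernel": [[5.0, 6.0], [7.0, 8.0]]},
}
scanned = stack_layers(unscanned, num_layers=2)
roundtrip = unstack_layers(scanned)
print(sorted(scanned), sorted(roundtrip))
```

Because the two trees have different keys and shapes, a loader expecting one layout cannot read the other directly, which is consistent with the requirement that `scan_layers` match between conversion and loading.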
@@ -58,7 +58,7 @@ Success starts with a clear map. You must align the parameter names from your so
 ### 3.2 Write Script

-Use existing model scripts within the repository as templates to tailor the conversion logic for your specific architecture. We strongly recommended to use the [checkpoint conversion utility](checkpointing_solutions/convert_checkpoint.md) rather than [standalone scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion/standalone_scripts).
+Use existing model scripts within the repository as templates to tailor the conversion logic for your specific architecture. We strongly recommend using the [checkpoint conversion utility](checkpoint-conversion) rather than [standalone scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion/standalone_scripts).

 ### 3.3 Verify Compatibility
@@ -132,7 +132,7 @@ If you run the `forward_pass_logit_checker.py` to compare reference logits with
**Q: How to compile models for a target hardware without physical access?**
-**A:** If you need to compile your training run ahead of time, use the train_compile.py tool. This utility allows you to compile the primary train_step for specific target hardware without needing the actual devices on hand. It’s particularly useful for verifying your implementation's functionality on a local Cloud VM or a standard CPU. Please refer [here](monitoring_and_debugging/features_and_diagnostics.md#ahead-of-time-compilation-aot) for more examples.
+**A:** If you need to compile your training run ahead of time, use the `train_compile.py` tool. This utility allows you to compile the primary `train_step` for specific target hardware without needing the actual devices on hand. It’s particularly useful for verifying your implementation's functionality on a local Cloud VM or a standard CPU. Please refer [here](aot-compilation) for more examples.
**Q: My model is too large for my development machine. What should I do?**
docs/guides/optimization/custom_model.md: 1 addition & 1 deletion
@@ -254,7 +254,7 @@ Ironwood over ICI:
 - `3 * M * 8 / 2 > 12800`
 - `M > 1100`

-It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to the [link](../optimization/sharding.md) for specific challenges regarding PP + FSDP/DP.
+It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to the [link](sharding_on_TPUs) for specific challenges regarding PP + FSDP/DP.
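
Reader's note, not part of the changed file: the two context bullets in this hunk state an inequality and its solution. The arithmetic can be checked as below. The meaning of each constant comes from the guide's roofline section, which this hunk does not show, so the variable names here are neutral placeholders rather than the guide's terminology.

```python
# Solve 3 * M * 8 / 2 > 12800 for M.
coefficient_of_m = 3 * 8 / 2          # left-hand side grows by this much per unit of M
threshold = 12800 / coefficient_of_m  # M must exceed this value
print(coefficient_of_m, threshold)
```

The exact bound is a little under 1067, so the guide's `M > 1100` appears to be a rounded-up form of the same result.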
docs/guides/run_python_notebook.md: 2 additions & 2 deletions
@@ -86,7 +86,7 @@ To install, click the `Extensions` icon on the left sidebar (or press `Ctrl+Shif
 ### Step 3: Install MaxText and Dependencies

-To execute post-training notebooks on your TPU-VM, follow the official [MaxText installation guides](../install_maxtext.md#from-source) and specifically follow `Option 3: Installing [tpu-post-train]`. This will ensure all post-training dependencies are installed inside your virtual environment.
+To execute post-training notebooks on your TPU-VM, follow the official [MaxText installation guides](install-from-source) and specifically follow `Option 3: Installing [tpu-post-train]`. This will ensure all post-training dependencies are installed inside your virtual environment.
> **Note:** If you have previously installed MaxText with a different option (e.g., `maxtext[tpu]`), we strongly recommend using a fresh virtual environment for `maxtext[tpu-post-train]` to avoid potential library version conflicts.
@@ -139,7 +139,7 @@ pip3 install jupyterlab
 ### Step 3: Install MaxText and Dependencies

-To execute post-training notebooks on your TPU-VM, follow the official [MaxText installation guides](../install_maxtext.md#from-source) and specifically follow `Option 3: Installing [tpu-post-train]`. This will ensure all post-training dependencies are installed inside your virtual environment.
+To execute post-training notebooks on your TPU-VM, follow the official [MaxText installation guides](install-from-source) and specifically follow `Option 3: Installing [tpu-post-train]`. This will ensure all post-training dependencies are installed inside your virtual environment.
> **Note:** If you have previously installed MaxText with a different option (e.g., `maxtext[tpu]`), we strongly recommend using a fresh virtual environment for `maxtext[tpu-post-train]` to avoid potential library version conflicts.
docs/reference/architecture/jax_ai_libraries_chosen.md: 2 additions & 2 deletions
@@ -56,11 +56,11 @@ For more information on using Orbax, please refer to https://github.com/google/o
 1. **Deterministic by Design**: Grain allows storing data loader states, provides strong guarantees about data ordering and sharding even with preemptions, which is critical for reproducibility.
 2. **Global Shuffle**: Prevents local overfitting.
-3. **Built for Multi-Host Training**: The using random access file format streamlines [data loading in the multi-host environments](../../guides/data_input_pipeline.md#multihost-dataloading-best-practice).
+3. **Built for Multi-Host Training**: The using random access file format streamlines [data loading in the multi-host environments](multihost-dataloading-best-practice).
Its APIs are explicitly designed for the multi-host paradigm, simplifying the process of ensuring that each host loads a unique shard of the global batch.
-For more information on using Grain, please refer to https://github.com/google/grain and the grain guide in maxtext located [here](../../guides/data_input_pipeline/data_input_grain.md).
+For more information on using Grain, please refer to https://github.com/google/grain and the [Grain guide in MaxText](grain-pipeline).