docs/reference/architecture/architecture_overview.md
3 additions & 3 deletions
@@ -63,7 +63,7 @@ The table below summarizes some of the most critical parameters in base.yml and
 | dataset_type | input_pipeline.py | Specifies the data loader backend ('tfds', 'grain', 'hf'). |
 | enable_checkpointing | checkpointing.py, train.py | Enables or disables saving model state. |
 | async_checkpointing | checkpointing.py, train.py | If True, saves checkpoints without blocking the training loop. |
-| quantization | layers.py, optimizers.py | Enables quantization, e.g., 'int8' for AQT or Qwix. |
+| quantization | layers.py, optimizers.py | Enables quantization, e.g., 'int8' for Qwix or legacy AQT (deprecated).|
 | compile_topology | train_compile.py | Specifies the target hardware topology for AOT compilation. |

 ## Core architectural components
@@ -82,7 +82,7 @@ While the base model implementations are typically simple, MaxText is equipped t

 - Advanced attention mechanisms: The architecture is not limited to standard self-attention. It supports variants like Grouped-Query Attention (GQA), Multi-Query Attention (MQA) and Multi-headed Latent Attention (MLA). Since, like MoE, attention can be a performance hot-spot in transformers, attention is typically implemented in [Pallas](https://docs.jax.dev/en/latest/pallas/index.html) kernels, with Splash (Sparse, Flash) Attention being the default for training.

-- Quantization: The framework seamlessly integrates with Google's Accurate Quantized Training (AQT) and Qwix libraries. Quantization logic is applied at the layer level.
+- Quantization: The framework seamlessly integrates with the Qwix and Google's Accurate Quantized Training (AQT, deprecated) libraries. Quantization logic is applied at the layer level.

 The modularity of this design is clearly demonstrated by third-party extensions. For instance, the NVIDIA maxtext-jaxpp fork was able to add support for pipeline parallelism by inserting jaxpp.pipeline_enter_stage hooks directly into the \_\_call\_\_ method of the Decoder class, a testament to the codebase's modularity and extensibility.

@@ -158,7 +158,7 @@ Performance can be further tuned by setting specific XLA flags in the configurat

 ### Quantization for throughput boost

-One of the most significant performance levers available in MaxText is the integration of Google's Accurate Quantized Training (AQT) and Qwix libraries. These enable training with reduced numerical precision, reducing memory requirements and often increasing FLOPS, while maintaining model quality and convergence characteristics that are very close to the full-precision baseline.
+One of the most significant performance levers available in MaxText is the integration of the Qwix and Google's Accurate Quantized Training (AQT, deprecated) libraries. These enable training with reduced numerical precision, reducing memory requirements and often increasing FLOPS, while maintaining model quality and convergence characteristics that are very close to the full-precision baseline.

 Integration into MaxText is seamless for the user. Quantization can be enabled by simply setting, for example, `quantization: 'int8'` in the configuration file. This flag activates quantization-aware layers (defined in
 [`src/maxtext/layers/quantizations.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/quantizations.py)) that are applied to the relevant dense layers within the model's Flax definition. The quantization library handles the complexities of simulating quantization during the forward and backward passes, allowing the model to learn weights that are robust to the reduced precision.
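
As a rough illustration of what "simulating quantization" means in the paragraph above, the sketch below applies a symmetric absmax int8 quantize-dequantize round trip to a list of weights. It is a standalone toy in plain Python, not MaxText, Qwix, or AQT code: the function name `fake_quant_int8` is hypothetical, and real libraries work on tensors, may pick scales per channel or per tile, and also handle how gradients flow through the rounding step.

```python
def fake_quant_int8(values):
    """Quantize to int8 with a single absmax scale, then dequantize to floats."""
    absmax = max((abs(v) for v in values), default=0.0)
    if absmax == 0.0:
        return [0.0 for _ in values]  # all zeros (or empty): nothing to scale
    scale = absmax / 127.0  # symmetric int8 range [-127, 127]
    # Round each value to the nearest integer step and clamp to the int8 range.
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    # Map the integers back to floats; this is what the next layer would see.
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.003, 1.0]
simulated = fake_quant_int8(weights)
# The round-trip error per value is bounded by half a quantization step.
max_error = max(abs(w - s) for w, s in zip(weights, simulated))
```

Training against `simulated` rather than `weights` is what lets a model adapt to the rounding error; the small weight `0.003` collapsing to zero shows the kind of information loss the model has to tolerate.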
docs/reference/core_concepts/quantization.md
8 additions & 5 deletions
@@ -20,7 +20,7 @@

 Quantization in deep learning is the process of reducing the precision of numbers used to represent a model's weights and/or activations. Instead of using higher-precision floating-point formats like 32-bit floats (`float32`) or 16-bit brain floats (`bfloat16`), quantization maps these values to lower-precision numerical formats, most commonly 8-bit integers (`int8`) or floats (`fp8`).

-MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT).
+MaxText supports quantization via the [Qwix](https://github.com/google/qwix) library. Accurate Quantized Training (AQT) is deprecated and will be removed in a future release. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT).

 ## Why use quantization?

@@ -40,7 +40,7 @@ The primary trade-off with quantization is between the model accuracy and comput
 - Impact on Gradients: Gradients during backpropagation can have very different, often wider, distributions than weights or activations, making them more sensitive to quantization errors.
 - Convergence Issues: The approximations introduced by quantization can sometimes hinder the model's ability to converge during training.

-To overcome the challenges of quantization, libraries like Google's Accurate Quantized Training (AQT) and its successor Qwix (used in MaxText) employ a suite of advanced techniques. These methods ensure that models can be trained with low-precision arithmetic without significant loss in accuracy and with stable convergence.
+To overcome the challenges of quantization, libraries like Google's Accurate Quantized Training (AQT, deprecated) and its successor Qwix (used in MaxText) employ a suite of advanced techniques. These methods ensure that models can be trained with low-precision arithmetic without significant loss in accuracy and with stable convergence.

 ## How Quantized Training (QT) works with Qwix

@@ -56,16 +56,16 @@ By integrating the quantization simulation directly into the training, the model

 ## Using Quantization in MaxText

-You can enable quantization in MaxText by setting flags in your configuration file (e.g., `base.yml`) or via the command line. MaxText supports two quantization libraries: Qwix (recommended) and AQT.
+You can enable quantization in MaxText by setting flags in your configuration file (e.g., `base.yml`) or via the command line. MaxText supports Qwix (recommended) and the legacy AQT library (deprecated).

 ### Configuration Flags

 The primary flags to control quantization are:

 - `use_qwix_quantization`: A boolean flag.
   - Set to `True` to enable quantization using the Qwix library.
-  - Set to `False` (or omit) to use the AQT library if `quantization` is set.
-- `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or AQT.
+  - Set to `False` (or omit) to use the AQT library (deprecated) if `quantization` is set.
+- `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or legacy AQT.
 - `quantization_calibration_method`: The calibration method for weights and activations (e.g., `"absmax"`). This is mainly for Qwix.

 ### Qwix Quantization (Recommended)
@@ -127,6 +127,9 @@ model = qwix.quantize_model(model, qwix.QtProvider(rule))

 ### AQT Quantization

+> [!WARNING]
+> **DEPRECATION NOTICE**: AQT quantization is deprecated and will be removed in a future release. Please migrate to Qwix by setting `use_qwix_quantization=True`.
+
 If `use_qwix_quantization` is `False` or not set, you can still apply quantization using the AQT library by setting the `quantization` flag. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e).
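
The flag interaction described in this file can be summarized in a small sketch. This is a hypothetical helper written for illustration, not MaxText's actual config handling: the function name `select_quantization_backend` and its return values are assumptions, as is the treatment of an empty `quantization` value as "no quantization". Only the flag names `use_qwix_quantization` and `quantization` come from the docs.

```python
import warnings


def select_quantization_backend(quantization="", use_qwix_quantization=False):
    """Return "qwix", "aqt", or None for a given pair of config flags."""
    if not quantization:
        return None  # assumption: an empty `quantization` disables quantization
    if use_qwix_quantization:
        return "qwix"
    # Legacy path: AQT applies when `quantization` is set but Qwix is not enabled.
    warnings.warn(
        "AQT quantization is deprecated; set use_qwix_quantization=True.",
        DeprecationWarning,
        stacklevel=2,
    )
    return "aqt"


backend = select_quantization_backend("int8", use_qwix_quantization=True)
```

Emitting a `DeprecationWarning` on the legacy path mirrors the intent of the notice above: existing AQT configurations keep working, but users are nudged toward Qwix before the removal.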
docs/reference/models/supported_models_and_architectures.md
2 additions & 2 deletions
@@ -10,7 +10,7 @@ MaxText is an open-source, high-performance LLM framework written in Python/JAX.

 - **Supported Precisions**: FP32, BF16, INT8, and FP8.
 - **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection.
-- **Quantization**: Via **Qwix** (recommended) and AQT. See Quantization [Guide](../reference/core_concepts/quantization.md).
+- **Quantization**: Via **Qwix** (recommended) and AQT (deprecated). See Quantization [Guide](../reference/core_concepts/quantization.md).
 - **Diagnostics**: Simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**.
 - **Multi-Token Prediction (MTP)**: Enables token efficient training with multi-token prediction.
 - **Elastic Training**: Fault-tolerant and dynamic scale-up/scale-down on Cloud TPUs with Pathways.
@@ -74,7 +74,7 @@ MaxText supports a wide range of parallelism strategies for scaling training and
 The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs.

 - **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it.
-- **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
+- **Quantization**: MaxText supports quantization via both the Qwix and AQT (deprecated) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
 - **MoE**: The Mixture-of-Experts implementation features dropless routing with efficient kernels including Megablox, `jax.lax.ragged_dot`, and Tokamax Ragged Dot.
 - **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
 - **Long-Context Optimizations**: Implements various efficient attention mechanisms, including: Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, Multi-Head Latent Attention (MLA). They reduce the KV-cache size, making it possible to handle long contexts efficiently.
-use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
+use_qwix_quantization: false #[DEPRECATED: AQT will be removed in a future release. It is strongly recommended to set use_qwix_quantization to true] whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
 use_manual_quantization: false # a flag if to use manual quantization for batch split. Only used if use_batch_split_schedule is true.
 # quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80