Skip to content

Commit 72ddb3c

Browse files
Deprecate AQT quantization in MaxText
1 parent 4bcec6a commit 72ddb3c

7 files changed

Lines changed: 24 additions & 12 deletions

File tree

docs/reference/architecture/architecture_overview.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ The table below summarizes some of the most critical parameters in base.yml and
6363
| dataset_type | input_pipeline.py | Specifies the data loader backend ('tfds', 'grain', 'hf'). |
6464
| enable_checkpointing | checkpointing.py, train.py | Enables or disables saving model state. |
6565
| async_checkpointing | checkpointing.py, train.py | If True, saves checkpoints without blocking the training loop. |
66-
| quantization | layers.py, optimizers.py | Enables quantization, e.g., 'int8' for AQT or Qwix. |
66+
| quantization | layers.py, optimizers.py | Enables quantization, e.g., 'int8' for Qwix or legacy AQT (deprecated). |
6767
| compile_topology | train_compile.py | Specifies the target hardware topology for AOT compilation. |
6868

6969
## Core architectural components
@@ -82,7 +82,7 @@ While the base model implementations are typically simple, MaxText is equipped t
8282

8383
- Advanced attention mechanisms: The architecture is not limited to standard self-attention. It supports variants like Grouped-Query Attention (GQA), Multi-Query Attention (MQA) and Multi-headed Latent Attention (MLA). Since, like MoE, attention can be a performance hot-spot in transformers, attention is typically implemented in [Pallas](https://docs.jax.dev/en/latest/pallas/index.html) kernels, with Splash (Sparse, Flash) Attention being the default for training.
8484

85-
- Quantization: The framework seamlessly integrates with Google's Accurate Quantized Training (AQT) and Qwix libraries. Quantization logic is applied at the layer level.
85+
- Quantization: The framework seamlessly integrates with the Qwix and Google's Accurate Quantized Training (AQT, deprecated) libraries. Quantization logic is applied at the layer level.
8686

8787
The modularity of this design is clearly demonstrated by third-party extensions. For instance, the NVIDIA maxtext-jaxpp fork was able to add support for pipeline parallelism by inserting jaxpp.pipeline_enter_stage hooks directly into the \_\_call\_\_ method of the Decoder class, a testament to the codebase's modularity and extensibility.
8888

@@ -158,7 +158,7 @@ Performance can be further tuned by setting specific XLA flags in the configurat
158158

159159
### Quantization for throughput boost
160160

161-
One of the most significant performance levers available in MaxText is the integration of Google's Accurate Quantized Training (AQT) and Qwix libraries. These enable training with reduced numerical precision, reducing memory requirements and often increasing FLOPS, while maintaining model quality and convergence characteristics that are very close to the full-precision baseline.
161+
One of the most significant performance levers available in MaxText is the integration of the Qwix and Google's Accurate Quantized Training (AQT, deprecated) libraries. These enable training with reduced numerical precision, reducing memory requirements and often increasing FLOPS, while maintaining model quality and convergence characteristics that are very close to the full-precision baseline.
162162

163163
Integration into MaxText is seamless for the user. Quantization can be enabled by simply setting, for example, `quantization: 'int8'` in the configuration file. This flag activates quantization-aware layers (defined in
164164
[`src/maxtext/layers/quantizations.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/quantizations.py)) that are applied to the relevant dense layers within the model's Flax definition. The quantization library handles the complexities of simulating quantization during the forward and backward passes, allowing the model to learn weights that are robust to the reduced precision.

docs/reference/core_concepts/quantization.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
Quantization in deep learning is the process of reducing the precision of numbers used to represent a model's weights and/or activations. Instead of using higher-precision floating-point formats like 32-bit floats (`float32`) or 16-bit brain floats (`bfloat16`), quantization maps these values to lower-precision numerical formats, most commonly 8-bit integers (`int8`) or floats (`fp8`).
2222

23-
MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT).
23+
MaxText supports quantization via the [Qwix](https://github.com/google/qwix) library. Accurate Quantized Training (AQT) is deprecated and will be removed in a future release. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT).
2424

2525
## Why use quantization?
2626

@@ -40,7 +40,7 @@ The primary trade-off with quantization is between the model accuracy and comput
4040
- Impact on Gradients: Gradients during backpropagation can have very different, often wider, distributions than weights or activations, making them more sensitive to quantization errors.
4141
- Convergence Issues: The approximations introduced by quantization can sometimes hinder the model's ability to converge during training.
4242

43-
To overcome the challenges of quantization, libraries like Google's Accurate Quantized Training (AQT) and its successor Qwix (used in MaxText) employ a suite of advanced techniques. These methods ensure that models can be trained with low-precision arithmetic without significant loss in accuracy and with stable convergence.
43+
To overcome the challenges of quantization, libraries like Google's Accurate Quantized Training (AQT, deprecated) and its successor Qwix (used in MaxText) employ a suite of advanced techniques. These methods ensure that models can be trained with low-precision arithmetic without significant loss in accuracy and with stable convergence.
4444

4545
## How Quantized Training (QT) works with Qwix
4646

@@ -56,16 +56,16 @@ By integrating the quantization simulation directly into the training, the model
5656

5757
## Using Quantization in MaxText
5858

59-
You can enable quantization in MaxText by setting flags in your configuration file (e.g., `base.yml`) or via the command line. MaxText supports two quantization libraries: Qwix (recommended) and AQT.
59+
You can enable quantization in MaxText by setting flags in your configuration file (e.g., `base.yml`) or via the command line. MaxText supports Qwix (recommended) and the legacy AQT library (deprecated).
6060

6161
### Configuration Flags
6262

6363
The primary flags to control quantization are:
6464

6565
- `use_qwix_quantization`: A boolean flag.
6666
- Set to `True` to enable quantization using the Qwix library.
67-
- Set to `False` (or omit) to use the AQT library if `quantization` is set.
68-
- `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or AQT.
67+
- Set to `False` (or omit) to use the AQT library (deprecated) if `quantization` is set.
68+
- `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or legacy AQT.
6969
- `quantization_calibration_method`: The calibration method for weights and activations (e.g., `"absmax"`). This is mainly for Qwix.
7070

7171
### Qwix Quantization (Recommended)
@@ -127,6 +127,9 @@ model = qwix.quantize_model(model, qwix.QtProvider(rule))
127127

128128
### AQT Quantization
129129

130+
> [!WARNING]
131+
> **DEPRECATION NOTICE**: AQT quantization is deprecated and will be removed in a future release. Please migrate to Qwix by setting `use_qwix_quantization=True`.
132+
130133
If `use_qwix_quantization` is `False` or not set, you can still apply quantization using the AQT library by setting the `quantization` flag. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e).
131134

132135
#### `quantization` values for AQT

docs/reference/models/supported_models_and_architectures.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ MaxText is an open-source, high-performance LLM framework written in Python/JAX.
1010

1111
- **Supported Precisions**: FP32, BF16, INT8, and FP8.
1212
- **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection.
13-
- **Quantization**: Via **Qwix** (recommended) and AQT. See Quantization [Guide](../reference/core_concepts/quantization.md).
13+
- **Quantization**: Via **Qwix** (recommended) and AQT (deprecated). See Quantization [Guide](../reference/core_concepts/quantization.md).
1414
- **Diagnostics**: Simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**.
1515
- **Multi-Token Prediction (MTP)**: Enables token efficient training with multi-token prediction.
1616
- **Elastic Training**: Fault-tolerant and dynamic scale-up/scale-down on Cloud TPUs with Pathways.
@@ -74,7 +74,7 @@ MaxText supports a wide range of parallelism strategies for scaling training and
7474
The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs.
7575

7676
- **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it.
77-
- **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
77+
- **Quantization**: MaxText supports quantization via both the Qwix and AQT (deprecated) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
7878
- **MoE**: The Mixture-of-Experts implementation features dropless routing with efficient kernels including Megablox, `jax.lax.ragged_dot`, and Tokamax Ragged Dot.
7979
- **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
8080
- **Long-Context Optimizations**: Implements various efficient attention mechanisms, including: Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, Multi-Head Latent Attention (MLA). They reduce the KV-cache size, making it possible to handle long contexts efficiently.

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ save_quantized_params_path: ""
143143
# when left as is, corresponds to training
144144
# accepted values are "inference"
145145
model_call_mode: ""
146-
use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
146+
use_qwix_quantization: false # [DEPRECATED: AQT will be removed in a future release. It is strongly recommended to set use_qwix_quantization to true] whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
147147
use_manual_quantization: false # a flag if to use manual quantization for batch split. Only used if use_batch_split_schedule is true.
148148
# quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80
149149
weight_quantization_calibration_method: "absmax"

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2571,6 +2571,14 @@ def get_num_target_devices():
25712571
}
25722572
self.num_slices = max_utils.get_num_slices(raw_keys_for_num_slices)
25732573

2574+
# Check for AQT deprecation warning
2575+
if self.quantization and not self.use_qwix_quantization:
2576+
if self.quantization not in ("fp8", "nanoo_fp8") and not self.quantization.startswith("te_"):
2577+
logger.warning(
2578+
"WARNING: AQT quantization is deprecated and will be removed in a future release. "
2579+
"Please migrate to Qwix by setting use_qwix_quantization=True."
2580+
)
2581+
25742582
# Default quantization sharding count to number of local devices if not set.
25752583
if self.quantization_local_shard_count == -1:
25762584
try:

src/maxtext/layers/quantizations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,8 +759,8 @@ def get_fp8_full_qwix_rule_w_sparsity(config: Config):
759759

760760

761761
def get_quantization_rule(config: Config):
762-
763762
"""Returns a list of qwix.QtRule from `dtype`."""
763+
764764
def make_qt_rule(dtype) -> list[qwix.QtRule]:
765765
return [
766766
qwix.QtRule(

tests/unit/quantizations_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def test_configure_quantization_replicate_scale(self):
149149
quant = _configure_quantization(quant_str="int8", mode_str=quant_mode, replicate_scale=True)
150150
self.assertEqual(quant.replicate_scale, True)
151151

152+
@pytest.mark.cpu_only
152153
def test_configure_quantization_is_int8(self):
153154
for quant_mode in ["train", "serve", "convert"]:
154155
quant = _configure_quantization(quant_str="int8", mode_str=quant_mode)

0 commit comments

Comments
 (0)