|
| 1 | +<!-- Copyright 2026 The HuggingFace Team. All rights reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
| 4 | +the License. You may obtain a copy of the License at |
| 5 | +
|
| 6 | +http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +
|
| 8 | +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
| 9 | +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
| 10 | +specific language governing permissions and limitations under the License. --> |
| 11 | + |
| 12 | +# AutoRound |
| 13 | + |
| 14 | +[AutoRound](https://github.com/intel/auto-round) is an advanced quantization toolkit. It achieves high accuracy at ultra-low bit widths (2-4 bits) with minimal tuning by leveraging sign-gradient descent and providing broad hardware compatibility. See our papers [SignRoundV1](https://arxiv.org/pdf/2309.05516) and [SignRoundV2](https://arxiv.org/abs/2512.04746) for more details. |
| 15 | + |
| 16 | + |
| 17 | +Install `auto-round`(version ≥ 0.13.0): |
| 18 | + |
| 19 | +```bash |
| 20 | +pip install "auto-round>=0.13.0" |
| 21 | +``` |
| 22 | + |
| 23 | +To use the Marlin kernel for faster CUDA inference, install `gptqmodel`: |
| 24 | + |
| 25 | +```bash |
| 26 | +pip install "gptqmodel>=5.8.0" |
| 27 | +``` |
| 28 | + |
| 29 | +## Load a quantized model |
| 30 | + |
| 31 | +Load a pre-quantized AutoRound model by passing [`AutoRoundConfig`] to [`~ModelMixin.from_pretrained`]. The method works with any model that loads via [Accelerate](https://hf.co/docs/accelerate/index) and has `torch.nn.Linear` layers. |
| 32 | + |
| 33 | +You can use [`PipelineQuantizationConfig`] to quantize specific components of a pipeline: |
| 34 | + |
| 35 | +```python |
| 36 | +import torch |
| 37 | +from diffusers import DiffusionPipeline, PipelineQuantizationConfig, AutoRoundConfig |
| 38 | + |
| 39 | +pipeline_quant_config = PipelineQuantizationConfig( |
| 40 | + quant_mapping={"transformer": AutoRoundConfig(backend="auto")} |
| 41 | +) |
| 42 | +pipe = DiffusionPipeline.from_pretrained( |
| 43 | + "INCModel/Z-Image-W4A16-AutoRound", |
| 44 | + quantization_config=pipeline_quant_config, |
| 45 | + torch_dtype=torch.bfloat16, |
| 46 | + device_map="cuda", |
| 47 | +) |
| 48 | + |
| 49 | +image = pipe("a cat holding a sign that says hello").images[0] |
| 50 | +image.save("output.png") |
| 51 | +``` |
| 52 | + |
| 53 | +Or load a quantized model component directly: |
| 54 | + |
| 55 | +```python |
| 56 | +import torch |
| 57 | +from diffusers import ZImageTransformer2DModel, ZImagePipeline, AutoRoundConfig |
| 58 | + |
| 59 | +model_id = "INCModel/Z-Image-W4A16-AutoRound" |
| 60 | + |
| 61 | +quantization_config = AutoRoundConfig(backend="auto") |
| 62 | +transformer = ZImageTransformer2DModel.from_pretrained( |
| 63 | + model_id, |
| 64 | + subfolder="transformer", |
| 65 | + quantization_config=quantization_config, |
| 66 | + torch_dtype=torch.bfloat16, |
| 67 | + device_map="cuda", |
| 68 | +) |
| 69 | + |
| 70 | +pipe = ZImagePipeline.from_pretrained( |
| 71 | + model_id, |
| 72 | + transformer=transformer, |
| 73 | + torch_dtype=torch.bfloat16, |
| 74 | + device_map="cuda", |
| 75 | +) |
| 76 | + |
| 77 | +image = pipe("a cat holding a sign that says hello").images[0] |
| 78 | +image.save("output.png") |
| 79 | +``` |
| 80 | + |
| 81 | +> [!NOTE] |
| 82 | +> AutoRound in Diffusers only supports loading *pre-quantized* models. To quantize a model from scratch, use the [AutoRound CLI or Python API](https://github.com/intel/auto-round) directly, then load the result with Diffusers. |
| 83 | +
|
| 84 | +## torch.compile |
| 85 | + |
| 86 | +AutoRound is compatible with [`torch.compile`](../optimization/fp16#torchcompile) for faster inference. You can compile the quantized transformer (DiT) for better performance: |
| 87 | + |
| 88 | +```python |
| 89 | +import torch |
| 90 | +from diffusers import DiffusionPipeline, PipelineQuantizationConfig, AutoRoundConfig |
| 91 | + |
| 92 | +pipeline_quant_config = PipelineQuantizationConfig( |
| 93 | + quant_mapping={"transformer": AutoRoundConfig(backend="auto")} |
| 94 | +) |
| 95 | +pipe = DiffusionPipeline.from_pretrained( |
| 96 | + "INCModel/Z-Image-W4A16-AutoRound", |
| 97 | + quantization_config=pipeline_quant_config, |
| 98 | + torch_dtype=torch.bfloat16, |
| 99 | + device_map="cuda", |
| 100 | +) |
| 101 | + |
| 102 | +pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False) |
| 103 | +``` |
| 104 | + |
| 105 | +## Backends |
| 106 | + |
| 107 | +AutoRound supports multiple inference backends for Weight-only quantized model. The backend controls which kernel handles dequantization during the forward pass. Set the `backend` parameter in [`AutoRoundConfig`] to choose one: |
| 108 | + |
| 109 | +| Backend | Value | Device | Requirements | Notes | |
| 110 | +|---------|-------|--------|--------------|-------| |
| 111 | +| **Auto** | `"auto"` | Any | — | Default. Automatically selects the best available backend. | |
| 112 | +| **PyTorch** | `"torch"` | CPU / CUDA | — | Pure PyTorch implementation. Broadest compatibility. | |
| 113 | +| **Triton** | `"tritonv2"` | CUDA | `triton` | Triton-based kernel for GPU inference. | |
| 114 | +| **ExllamaV2** | `"exllamav2"` | CUDA | `gptqmodel>=5.8.0` | Good CUDA performance via the ExllamaV2 kernel. | |
| 115 | +| **Marlin** | `"marlin"` | CUDA | `gptqmodel>=5.8.0` | Best CUDA performance via the Marlin kernel. | |
| 116 | + |
| 117 | + |
| 118 | +```python |
| 119 | +from diffusers import AutoRoundConfig |
| 120 | + |
| 121 | +# Auto-select (default) |
| 122 | +config = AutoRoundConfig() |
| 123 | + |
| 124 | +# Explicit Triton backend for CUDA |
| 125 | +config = AutoRoundConfig(backend="tritonv2") |
| 126 | + |
| 127 | +# Marlin backend for best CUDA performance (requires gptqmodel>=5.8.0) |
| 128 | +config = AutoRoundConfig(backend="marlin") |
| 129 | + |
| 130 | +# ExllamaV2 backend for good CUDA performance (requires gptqmodel>=5.8.0) |
| 131 | +config = AutoRoundConfig(backend="exllamav2") |
| 132 | + |
| 133 | +# PyTorch backend for CPU/CUDA inference |
| 134 | +config = AutoRoundConfig(backend="torch") |
| 135 | +``` |
| 136 | + |
| 137 | + |
| 138 | +## Save and load |
| 139 | + |
| 140 | +<hfoptions id="save-and-load"> |
| 141 | +<hfoption id="save"> |
| 142 | + |
| 143 | +AutoRound requires data calibration to quantize a model. This is done outside of Diffusers using the [AutoRound library](https://github.com/intel/auto-round) directly: |
| 144 | + |
| 145 | +```python |
| 146 | +from auto_round import AutoRound |
| 147 | + |
| 148 | +autoround = AutoRound( |
| 149 | + "Tongyi-MAI/Z-Image", |
| 150 | + scheme="W4A16", # W4G128 symmetric |
| 151 | + enable_torch_compile=True, |
| 152 | + num_inference_steps=3, |
| 153 | + guidance_scale=7.5, |
| 154 | + dataset="coco2014", |
| 155 | +) |
| 156 | +autoround.quantize_and_save("Z-Image-W4A16-AutoRound") |
| 157 | +``` |
| 158 | + |
| 159 | +For more details on calibration options, see the [AutoRound documentation](https://github.com/intel/auto-round). |
| 160 | + |
| 161 | +</hfoption> |
| 162 | +<hfoption id="load"> |
| 163 | + |
| 164 | + |
| 165 | +```python |
| 166 | +import torch |
| 167 | +from diffusers import ZImageTransformer2DModel, ZImagePipeline |
| 168 | + |
| 169 | +model_id = "INCModel/Z-Image-W4A16-AutoRound" |
| 170 | + |
| 171 | +# The inference backend will be automatically selected. |
| 172 | +pipe = ZImagePipeline.from_pretrained( |
| 173 | + model_id, |
| 174 | + torch_dtype=torch.bfloat16, |
| 175 | + device_map="cuda", |
| 176 | +) |
| 177 | + |
| 178 | +image = pipe("a cat holding a sign that says hello").images[0] |
| 179 | +image.save("output.png") |
| 180 | +``` |
| 181 | +</hfoption> |
| 182 | +</hfoptions> |
| 183 | + |
| 184 | + |
| 185 | +### Supported Quantization Schemes |
| 186 | + |
| 187 | +AutoRound supports several Schemes: |
| 188 | + |
| 189 | +- **W4A16**(bits:4,group_size:128,sym:True,act_bits:16) |
| 190 | +- **W8A16**(bits:8,group_size:128,sym:True,act_bits:16) |
| 191 | +- **W3A16**(bits:3,group_size:128,sym:True,act_bits:16) |
| 192 | +- **W2A16**(bits:2,group_size:128,sym:True,act_bits:16) |
| 193 | +- **GGUF:Q4_K_M**(all Q*_K,Q*_0,Q*_1 provided by llamacpp are supported) |
| 194 | +- **NVFP4**(Experimental feature, recommend exporting to `llm_compressor` format.data_type nvfp4,act_data_type nvfp4,static_global_scale,group_size 16) |
| 195 | +- **MXFP4**(**Research feature, no real kernel**, Standard MXFP4, data_type mxfp,act_data_type mxfp,bits 4, act_bits 4, group_size 32) |
| 196 | +- **MXINT4**(**Research feature, no real kernel**, Standard MXINT4, data_type mxint,act_data_type mxint,bits 4, act_bits 4, group_size 32) |
| 197 | +- **MXFP4_RCEIL**(**Research feature,no real kernel**, NVIDIA's variant, data_type mxfp,act_data_type mxfp_rceil,bits 4, act_bits 4, group_size 32) |
| 198 | +- **MXFP8**(**Research feature, no real kernel**, data_type mxfp,act_data_type mxfp_rceil,group_size 32) |
| 199 | +- **FPW8A16**(**Research feature, no real kernel**, data_type fp8,group_size 0->per tensor ) |
| 200 | +- **FP8_STATIC**(**Research feature, no real kernel**, data_type:fp8,act_data_type:fp8,group_size -1 ->per channel, act_group_size=0->per tensor) |
| 201 | + |
| 202 | +Besides, you could modify the `group_size`, `bits`, `sym` and many other configs you want, though there are maybe no real kernels. |
| 203 | + |
| 204 | +## Resources |
| 205 | + |
| 206 | +- [Pre-quantized AutoRound models on the Hub](https://huggingface.co/models?search=autoround) |
0 commit comments