Commit c31bb1c

xin3he, stevhliu, github-actions[bot], and sayakpaul authored
Integrate AutoRound into Diffusers (#13552)
* support auto_round
* add document and unit tests
* fix CI
* Apply suggestions from code review
* update document and overwrite the default quantization_config with specified backend.
* add UT and fix bug
* update per comments
* update per comments
* fix compile error in doc
* Apply style fixes
* small nits
* Add auto_round dependency to the versions table
* fix make deps_table_check_updated
* fix CI

Signed-off-by: Xin He <xin3.he@intel.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 89512d2 commit c31bb1c

17 files changed

Lines changed: 747 additions & 3 deletions

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -180,6 +180,8 @@
     title: quanto
   - local: quantization/modelopt
     title: NVIDIA ModelOpt
+  - local: quantization/autoround
+    title: AutoRound
   title: Quantization
 - isExpanded: false
   sections:
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
<!-- Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoRound

[AutoRound](https://github.com/intel/auto-round) is an advanced quantization toolkit. It uses sign-gradient descent to reach high accuracy at ultra-low bit widths (2-4 bits) with minimal tuning, and it supports a broad range of hardware. See the [SignRoundV1](https://arxiv.org/pdf/2309.05516) and [SignRoundV2](https://arxiv.org/abs/2512.04746) papers for more details.

Install `auto-round` (version 0.13.0 or later):

```bash
pip install "auto-round>=0.13.0"
```

To use the Marlin kernel for faster CUDA inference, install `gptqmodel`:

```bash
pip install "gptqmodel>=5.8.0"
```

## Load a quantized model

Load a pre-quantized AutoRound model by passing [`AutoRoundConfig`] to [`~ModelMixin.from_pretrained`]. The method works with any model that loads via [Accelerate](https://hf.co/docs/accelerate/index) and has `torch.nn.Linear` layers.

You can use [`PipelineQuantizationConfig`] to quantize specific components of a pipeline:

```python
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, AutoRoundConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": AutoRoundConfig(backend="auto")}
)
pipe = DiffusionPipeline.from_pretrained(
    "INCModel/Z-Image-W4A16-AutoRound",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

image = pipe("a cat holding a sign that says hello").images[0]
image.save("output.png")
```

Or load a quantized model component directly:

```python
import torch
from diffusers import ZImageTransformer2DModel, ZImagePipeline, AutoRoundConfig

model_id = "INCModel/Z-Image-W4A16-AutoRound"

quantization_config = AutoRoundConfig(backend="auto")
transformer = ZImageTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

pipe = ZImagePipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

image = pipe("a cat holding a sign that says hello").images[0]
image.save("output.png")
```

> [!NOTE]
> AutoRound in Diffusers only supports loading *pre-quantized* models. To quantize a model from scratch, use the [AutoRound CLI or Python API](https://github.com/intel/auto-round) directly, then load the result with Diffusers.

## torch.compile

AutoRound is compatible with [`torch.compile`](../optimization/fp16#torchcompile). Compile the quantized transformer (DiT) to speed up inference:

```python
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, AutoRoundConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": AutoRoundConfig(backend="auto")}
)
pipe = DiffusionPipeline.from_pretrained(
    "INCModel/Z-Image-W4A16-AutoRound",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False)
```

## Backends

AutoRound supports multiple inference backends for weight-only quantized models. The backend controls which kernel handles dequantization during the forward pass. Set the `backend` parameter in [`AutoRoundConfig`] to choose one:

| Backend | Value | Device | Requirements | Notes |
|---------|-------|--------|--------------|-------|
| **Auto** | `"auto"` | Any | | Default. Automatically selects the best available backend. |
| **PyTorch** | `"torch"` | CPU / CUDA | | Pure PyTorch implementation. Broadest compatibility. |
| **Triton** | `"tritonv2"` | CUDA | `triton` | Triton-based kernel for GPU inference. |
| **ExllamaV2** | `"exllamav2"` | CUDA | `gptqmodel>=5.8.0` | Good CUDA performance via the ExllamaV2 kernel. |
| **Marlin** | `"marlin"` | CUDA | `gptqmodel>=5.8.0` | Best CUDA performance via the Marlin kernel. |

```python
from diffusers import AutoRoundConfig

# Auto-select (default)
config = AutoRoundConfig()

# Explicit Triton backend for CUDA
config = AutoRoundConfig(backend="tritonv2")

# Marlin backend for best CUDA performance (requires gptqmodel>=5.8.0)
config = AutoRoundConfig(backend="marlin")

# ExllamaV2 backend for good CUDA performance (requires gptqmodel>=5.8.0)
config = AutoRoundConfig(backend="exllamav2")

# PyTorch backend for CPU/CUDA inference
config = AutoRoundConfig(backend="torch")
```

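As a rough illustration of the backend table, the sketch below picks a `backend` value from what is importable in the current environment. The helper name `pick_backend` and its `has_cuda` and `installed` parameters are hypothetical, it does not check package versions, and it is not how `backend="auto"` selects a kernel internally.

```python
import importlib.util


def pick_backend(has_cuda, installed=None):
    """Return a backend string from the table (hypothetical helper, not part of Diffusers)."""
    if installed is None:
        # Probe the optional packages without importing them.
        installed = {
            name
            for name in ("gptqmodel", "triton")
            if importlib.util.find_spec(name) is not None
        }
    if not has_cuda:
        return "torch"  # the only explicit backend listed for CPU
    if "gptqmodel" in installed:
        return "marlin"  # listed as the best CUDA performance
    if "triton" in installed:
        return "tritonv2"
    return "torch"


print(pick_backend(has_cuda=False))  # torch
print(pick_backend(has_cuda=True, installed={"triton"}))  # tritonv2
```

The returned string can be passed as `AutoRoundConfig(backend=...)`. In most cases leaving the default `"auto"` is simpler, since the selection then happens for you.
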
## Save and load

<hfoptions id="save-and-load">
<hfoption id="save">

AutoRound requires calibration data to quantize a model. This is done outside of Diffusers using the [AutoRound library](https://github.com/intel/auto-round) directly:

```python
from auto_round import AutoRound

autoround = AutoRound(
    "Tongyi-MAI/Z-Image",
    scheme="W4A16",  # W4G128 symmetric
    enable_torch_compile=True,
    num_inference_steps=3,
    guidance_scale=7.5,
    dataset="coco2014",
)
autoround.quantize_and_save("Z-Image-W4A16-AutoRound")
```

For more details on calibration options, see the [AutoRound documentation](https://github.com/intel/auto-round).

</hfoption>
<hfoption id="load">

```python
import torch
from diffusers import ZImageTransformer2DModel, ZImagePipeline

model_id = "INCModel/Z-Image-W4A16-AutoRound"

# The inference backend will be automatically selected.
pipe = ZImagePipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

image = pipe("a cat holding a sign that says hello").images[0]
image.save("output.png")
```
</hfoption>
</hfoptions>

### Supported quantization schemes

AutoRound supports several schemes:

- **W4A16** (`bits=4`, `group_size=128`, `sym=True`, `act_bits=16`)
- **W8A16** (`bits=8`, `group_size=128`, `sym=True`, `act_bits=16`)
- **W3A16** (`bits=3`, `group_size=128`, `sym=True`, `act_bits=16`)
- **W2A16** (`bits=2`, `group_size=128`, `sym=True`, `act_bits=16`)
- **GGUF:Q4_K_M** (all `Q*_K`, `Q*_0`, and `Q*_1` types provided by llama.cpp are supported)
- **NVFP4** (experimental feature; exporting to the `llm_compressor` format is recommended; `data_type=nvfp4`, `act_data_type=nvfp4`, `static_global_scale`, `group_size=16`)
- **MXFP4** (**research feature, no real kernel**; standard MXFP4; `data_type=mxfp`, `act_data_type=mxfp`, `bits=4`, `act_bits=4`, `group_size=32`)
- **MXINT4** (**research feature, no real kernel**; standard MXINT4; `data_type=mxint`, `act_data_type=mxint`, `bits=4`, `act_bits=4`, `group_size=32`)
- **MXFP4_RCEIL** (**research feature, no real kernel**; NVIDIA's variant; `data_type=mxfp`, `act_data_type=mxfp_rceil`, `bits=4`, `act_bits=4`, `group_size=32`)
- **MXFP8** (**research feature, no real kernel**; `data_type=mxfp`, `act_data_type=mxfp_rceil`, `group_size=32`)
- **FPW8A16** (**research feature, no real kernel**; `data_type=fp8`, `group_size=0`, that is, per tensor)
- **FP8_STATIC** (**research feature, no real kernel**; `data_type=fp8`, `act_data_type=fp8`, `group_size=-1`, that is, per channel; `act_group_size=0`, that is, per tensor)

You can also customize `group_size`, `bits`, `sym`, and many other settings, though a real kernel may not exist for the resulting configuration.

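To make the integer weight-only entries concrete, here is a minimal sketch that stores their listed settings in a dictionary and estimates the storage cost per weight. The names `INT_SCHEMES` and `approx_bits_per_weight` are hypothetical, and the estimate assumes one 16-bit scale per group and no zero point (symmetric quantization); real packed formats may carry different overhead.

```python
# Settings copied from the integer weight-only schemes in the list.
INT_SCHEMES = {
    "W2A16": {"bits": 2, "group_size": 128, "sym": True, "act_bits": 16},
    "W3A16": {"bits": 3, "group_size": 128, "sym": True, "act_bits": 16},
    "W4A16": {"bits": 4, "group_size": 128, "sym": True, "act_bits": 16},
    "W8A16": {"bits": 8, "group_size": 128, "sym": True, "act_bits": 16},
}


def approx_bits_per_weight(scheme, scale_bits=16):
    """Weight bits plus one scale shared by each group (scale_bits=16 is an assumption)."""
    cfg = INT_SCHEMES[scheme]
    return cfg["bits"] + scale_bits / cfg["group_size"]


for name in sorted(INT_SCHEMES):
    print(name, approx_bits_per_weight(name))
```

Under these assumptions, a larger `group_size` lowers the per-weight overhead of the scales, at the cost of coarser scaling.
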
## Resources

- [Pre-quantized AutoRound models on the Hub](https://huggingface.co/models?search=autoround)

setup.py

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@
     "onnx",
     "optimum_quanto>=0.2.6",
     "gguf>=0.10.0",
+    "auto-round>=0.13.0",
     "torchao>=0.7.0",
     "bitsandbytes>=0.43.3",
     "nvidia_modelopt[hf]>=0.33.1",

src/diffusers/__init__.py

Lines changed: 21 additions & 0 deletions
@@ -7,6 +7,7 @@
     OptionalDependencyNotAvailable,
     _LazyModule,
     is_accelerate_available,
+    is_auto_round_available,
     is_bitsandbytes_available,
     is_flax_available,
     is_gguf_available,
@@ -123,6 +124,18 @@
 else:
     _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")

+try:
+    if not is_auto_round_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_auto_round_objects
+
+    _import_structure["utils.dummy_auto_round_objects"] = [
+        name for name in dir(dummy_auto_round_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("AutoRoundConfig")
+
 try:
     if not is_onnx_available():
         raise OptionalDependencyNotAvailable()
@@ -982,6 +995,14 @@
     else:
         from .quantizers.quantization_config import NVIDIAModelOptConfig

+    try:
+        if not is_auto_round_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_auto_round_objects import *
+    else:
+        from .quantizers.quantization_config import AutoRoundConfig
+
     try:
         if not is_onnx_available():
             raise OptionalDependencyNotAvailable()
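
The guard in this file follows a common optional-dependency pattern: probe for the package, and export a placeholder when it is missing. The sketch below shows the idea in isolation. The helper names `is_available` and `make_placeholder` and the `OPTIONAL_EXPORTS` table are hypothetical, and this is not Diffusers' actual `_LazyModule` or dummy-object machinery.

```python
import importlib.util


def is_available(module_name):
    """Return True if the module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


def make_placeholder(name, requirement):
    """Build a class that raises a helpful ImportError when instantiated."""

    def __init__(self, *args, **kwargs):
        raise ImportError(f"`{name}` requires `{requirement}`; try: pip install {requirement}")

    return type(name, (), {"__init__": __init__})


# Hypothetical table: public name -> (module to probe, pip requirement).
OPTIONAL_EXPORTS = {"AutoRoundConfig": ("auto_round", "auto-round")}

# Only names whose backing package is missing get a placeholder.
placeholders = {
    name: make_placeholder(name, requirement)
    for name, (module, requirement) in OPTIONAL_EXPORTS.items()
    if not is_available(module)
}
print(sorted(placeholders))
```

With a placeholder in place, importing the name still succeeds and the error surfaces only when the object is used, which keeps the top-level import cheap and safe.
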

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
     "onnx": "onnx",
     "optimum_quanto": "optimum_quanto>=0.2.6",
     "gguf": "gguf>=0.10.0",
+    "auto-round": "auto-round>=0.13.0",
     "torchao": "torchao>=0.7.0",
     "bitsandbytes": "bitsandbytes>=0.43.3",
     "nvidia_modelopt[hf]": "nvidia_modelopt[hf]>=0.33.1",

src/diffusers/quantizers/auto.py

Lines changed: 17 additions & 0 deletions
@@ -18,10 +18,12 @@

 import warnings

+from .autoround import AutoRoundQuantizer
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
 from .gguf import GGUFQuantizer
 from .modelopt import NVIDIAModelOptQuantizer
 from .quantization_config import (
+    AutoRoundConfig,
     BitsAndBytesConfig,
     GGUFQuantizationConfig,
     NVIDIAModelOptConfig,
@@ -41,6 +43,7 @@
     "quanto": QuantoQuantizer,
     "torchao": TorchAoHfQuantizer,
     "modelopt": NVIDIAModelOptQuantizer,
+    "auto-round": AutoRoundQuantizer,
 }

 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -50,6 +53,7 @@
     "quanto": QuantoConfig,
     "torchao": TorchAoConfig,
     "modelopt": NVIDIAModelOptConfig,
+    "auto-round": AutoRoundConfig,
 }

@@ -143,6 +147,19 @@ def merge_quantization_configs(
         if isinstance(quantization_config, NVIDIAModelOptConfig):
             quantization_config.check_model_patching()

+        if quantization_config_from_args is not None and isinstance(quantization_config, AutoRoundConfig):
+            # For AutoRound, allow overriding fields like `backend` from user args,
+            # since the model config may store a default value (e.g. backend="auto").
+            for key, value in quantization_config_from_args.__dict__.items():
+                if key in ("quant_method",):
+                    continue
+                if hasattr(quantization_config, key) and getattr(quantization_config, key) != value:
+                    warnings.warn(
+                        f"Overriding `{key}` in the model's quantization_config with value {value!r} "
+                        f"from the user-provided `quantization_config`."
+                    )
+                    setattr(quantization_config, key, value)
+
         if warning_msg != "":
             warnings.warn(warning_msg)
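
The override loop in this hunk can be sketched in isolation as follows. `StubConfig` and `apply_user_overrides` are hypothetical stand-ins, not Diffusers classes or functions, and the sketch assumes a value is reassigned only when it differs from the one stored with the model.

```python
import warnings


class StubConfig:
    """Hypothetical stand-in for a quantization config; not a Diffusers class."""

    def __init__(self, quant_method="auto-round", backend="auto", bits=4):
        self.quant_method = quant_method
        self.backend = backend
        self.bits = bits


def apply_user_overrides(model_config, user_config, skip=("quant_method",)):
    """Copy differing attributes from user_config onto model_config, warning on each change."""
    for key, value in vars(user_config).items():
        if key in skip:
            continue
        if hasattr(model_config, key) and getattr(model_config, key) != value:
            warnings.warn(f"Overriding `{key}` with value {value!r} from the user-provided config.")
            setattr(model_config, key, value)
    return model_config


stored = StubConfig(backend="auto")  # what the checkpoint's config might hold
requested = StubConfig(backend="marlin")  # what the user passed in

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    merged = apply_user_overrides(stored, requested)

print(merged.backend, len(caught))  # marlin 1
```

Only `backend` differs between the two stub configs, so exactly one warning is recorded and `bits` is left untouched.
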

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .autoround_quantizer import AutoRoundQuantizer
