Skip to content

Commit 1fe2125

Browse files
howardzhang-cvgithub-actions[bot]sayakpaul
authored
remove str option for quantization config in torchao (#13291)
* remove str option for quantization config in torchao * Apply style fixes * minor fixes * Added AOBaseConfig docs to torchao.md * minor fixes for removing str option torchao * minor change to add back int and uint check * minor fixes * minor fixes to tests * Update tests/quantization/torchao/test_torchao.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update docs/source/en/quantization/torchao.md Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update tests/quantization/torchao/test_torchao.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * version=2 update to test_torchao.py --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 7298f5b commit 1fe2125

File tree

6 files changed

+174
-580
lines changed

6 files changed

+174
-580
lines changed

docs/source/en/api/pipelines/cogvideox.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,15 @@ The quantized CogVideoX 5B model below requires ~16GB of VRAM.
4141

4242
```py
4343
import torch
44-
from diffusers import CogVideoXPipeline, AutoModel
44+
from diffusers import CogVideoXPipeline, AutoModel, TorchAoConfig
4545
from diffusers.quantizers import PipelineQuantizationConfig
4646
from diffusers.hooks import apply_group_offloading
4747
from diffusers.utils import export_to_video
48+
from torchao.quantization import Int8WeightOnlyConfig
4849

4950
# quantize weights to int8 with torchao
5051
pipeline_quant_config = PipelineQuantizationConfig(
51-
quant_backend="torchao",
52-
quant_kwargs={"quant_type": "int8wo"},
53-
components_to_quantize="transformer"
52+
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}
5453
)
5554

5655
# fp8 layerwise weight-casting

docs/source/en/quantization/torchao.md

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
2929
from torchao.quantization import Int8WeightOnlyConfig
3030

3131
pipeline_quant_config = PipelineQuantizationConfig(
32-
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
33-
)
34-
pipeline = DiffusionPipeline.from_pretrained(
35-
"black-forest-labs/FLUX.1-dev",
36-
quantization_config=pipeline_quant_config,
37-
torch_dtype=torch.bfloat16,
38-
device_map="cuda"
39-
)
40-
```
41-
42-
For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.
43-
44-
```py
45-
import torch
46-
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
47-
48-
pipeline_quant_config = PipelineQuantizationConfig(
49-
quant_mapping={"transformer": TorchAoConfig("int8wo")}
32+
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128, version=2))}
5033
)
5134
pipeline = DiffusionPipeline.from_pretrained(
5235
"black-forest-labs/FLUX.1-dev",
@@ -91,18 +74,15 @@ Weight-only quantization stores the model weights in a specific low-bit data typ
9174

9275
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
9376

94-
The quantization methods supported are as follows:
77+
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods. An exhaustive list of configuration options are available [here](https://docs.pytorch.org/ao/main/workflows/inference.html#inference-workflows).
9578

96-
| **Category** | **Full Function Names** | **Shorthands** |
97-
|--------------|-------------------------|----------------|
98-
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
99-
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
100-
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
101-
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
79+
Some example popular quantization configurations are as follows:
10280

103-
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
104-
105-
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
81+
| **Category** | **Configuration Classes** |
82+
|---|---|
83+
| **Integer quantization** | [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int4WeightOnlyConfig.html), [`Int8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8WeightOnlyConfig.html), [`Int8DynamicActivationInt8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8DynamicActivationInt8WeightConfig.html) |
84+
| **Floating point 8-bit quantization** | [`Float8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8WeightOnlyConfig.html), [`Float8DynamicActivationFloat8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8DynamicActivationFloat8WeightConfig.html) |
85+
| **Unsigned integer quantization** | [`IntxWeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.IntxWeightOnlyConfig.html) |
10686

10787
## Serializing and Deserializing quantized models
10888

@@ -111,8 +91,9 @@ To serialize a quantized model in a given dtype, first load the model with the d
11191
```python
11292
import torch
11393
from diffusers import AutoModel, TorchAoConfig
94+
from torchao.quantization import Int8WeightOnlyConfig
11495

115-
quantization_config = TorchAoConfig("int8wo")
96+
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
11697
transformer = AutoModel.from_pretrained(
11798
"black-forest-labs/Flux.1-Dev",
11899
subfolder="transformer",
@@ -137,18 +118,19 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
137118
image.save("output.png")
138119
```
139120

140-
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
121+
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4` weight-only, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
141122

142123
```python
143124
import torch
144125
from accelerate import init_empty_weights
145126
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
127+
from torchao.quantization import IntxWeightOnlyConfig
146128

147129
# Serialize the model
148130
transformer = AutoModel.from_pretrained(
149131
"black-forest-labs/Flux.1-Dev",
150132
subfolder="transformer",
151-
quantization_config=TorchAoConfig("uint4wo"),
133+
quantization_config=TorchAoConfig(IntxWeightOnlyConfig(dtype=torch.uint4)),
152134
torch_dtype=torch.bfloat16,
153135
)
154136
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")

0 commit comments

Comments
 (0)