
Commit 64c12fd

[docs] Improve contribution guidelines for Quantization (huggingface#42870)

* update
* fix
* nit
* nit

1 parent f0d9cd1 commit 64c12fd

1 file changed: docs/source/en/quantization/contribute.md (95 additions & 16 deletions)
# Contribute

Transformers supports many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are still many more quantization approaches that haven't been integrated yet. To make adding and using these quantization methods with Transformers easier, use the [`~quantizers.HfQuantizer`] class. [`~quantizers.HfQuantizer`] is designed to be an internal helper class for adding a quantization method instead of something applied to every PyTorch module.

This guide will show you how to integrate a new quantization method with [`~quantizers.HfQuantizer`].

Before integrating a new quantization method into Transformers, ensure the method meets the following requirements.

- The method can run on commonly-used hardware (CPU, GPU, etc.).
- The method is wrapped in a [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) ([`~bitsandbytes.nn.Linear8bitLt`], [`~bitsandbytes.nn.Linear4bit`]), and the quantized linear layer should have the following definition.

```py
class Linear4bit(nn.Module):
    def __init__(self, ...):
        ...

    def forward(self, x):
        return my_4bit_kernel(x, self.weight, self.bias)
```

This way, Transformers models are easily quantized by replacing instances of [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) with a target class.

- The quantization method should be serializable. You can save the quantized weights locally or push them to the Hub.
- Make sure the package containing the quantization kernels/primitives is stable (no frequent breaking changes).
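
The module-swap idea above can be sketched framework-free. Everything below (`Module`, `Linear`, `Linear4bit`, `replace_linear`) is illustrative, not the Transformers or PyTorch API: it only shows the shape of the recursive replacement.

```python
# Framework-free sketch: walk a module tree and swap every plain Linear
# for a quantized stand-in. All names here are illustrative.
class Module:
    def __init__(self):
        self.children = {}

class Linear(Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

class Linear4bit(Linear):
    """Stand-in for a quantized linear layer."""

def replace_linear(module):
    # Recursively replace Linear instances with the target class.
    for name, child in module.children.items():
        if type(child) is Linear:
            module.children[name] = Linear4bit(child.in_features, child.out_features)
        else:
            replace_linear(child)
    return module

model = Module()
model.children["proj"] = Linear(4, 8)
block = Module()
block.children["fc"] = Linear(8, 8)
model.children["block"] = block

replace_linear(model)
```

In the real integration, the same traversal runs over `nn.Module` children and the target class comes from your integration file.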

Some quantization methods may require "pre-quantizing" the model through data calibration.

0. The best starting point would be to have a look at another quantization method such as Finegrained Fp8. You will have to update or create three files in total: the [config file](https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py), the [integration file](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py) and the [quantizer file](https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_finegrained_fp8.py).
1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [`_import_structure`](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [`src/transformers/__init__.py`](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.
2. Create a new file inside [src/transformers/quantizers/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers) named `quantizer_your_method.py`, and make it inherit from [`~quantizers.HfQuantizer`]. Make sure to add the new quantizer and quantization config to the quantization auto-mapping in [src/transformers/quantizers/auto.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers/auto.py).
3. Define the following class attributes and property methods for your quantization method:

- `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, only inference with pre-quantized weights is supported, not on-the-fly quantization.
- `is_serializable`: A property method to determine whether the method is serializable or not.
- `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).
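
As a hedged sketch, a hypothetical quantizer might declare these three attributes as follows. The class name and the chosen values are illustrative only; a real quantizer inherits from [`~quantizers.HfQuantizer`] and picks values matching its method.

```python
# Illustrative-only sketch of the three attributes on a hypothetical quantizer.
class MyMethodHfQuantizer:
    # Class attribute: this sketch quantizes on the fly, so no calibration data.
    requires_calibration = False

    @property
    def is_serializable(self):
        # Quantized weights can be saved locally or pushed to the Hub.
        return True

    @property
    def is_trainable(self):
        # This sketch does not support fine-tuning on top of quantized weights.
        return False

quantizer = MyMethodHfQuantizer()
```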

4. Write the `validate_environment` and `update_dtype` methods. These methods are called before creating the quantized model to ensure users use the right configuration. Refer to other quantizers for an example of how it is implemented.
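
A minimal sketch of what these two methods often do, assuming a hypothetical kernel package named `my_kernels`. The signatures and the package name are illustrative, not the exact [`~quantizers.HfQuantizer`] API; check existing quantizers for the real patterns.

```python
# Hedged sketch: fail early if the (hypothetical) kernel package is missing,
# and default the dtype when the user did not set one.
import importlib.util

class MyMethodHfQuantizer:
    def validate_environment(self, *args, **kwargs):
        # Raise before model creation if the required package is not installed.
        if importlib.util.find_spec("my_kernels") is None:
            raise ImportError("My-method quantization requires the `my_kernels` package.")

    def update_dtype(self, dtype):
        # Fall back to a half-precision default if no dtype was requested.
        return dtype if dtype is not None else "float16"

quantizer = MyMethodHfQuantizer()
```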

5. Write the `_process_model_before_weight_loading` method. In Transformers, the quantized models are initialized first on the `"meta"` device before loading the weights. This means the `_process_model_before_weight_loading` method takes care of manipulating the model skeleton to replace some modules ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) with the target modules (quantization modules).

You can define module replacement logic or any other utility method by creating a new file in [src/transformers/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.

6. Add the `get_quantize_ops` method to the quantizer class if the quantization method supports quantizing on the fly. In Transformers, we materialize each tensor and apply a sequence of operations to it; the quantization operation happens at the end. You need to create an `XXXQuantize` class, a subclass of `ConversionOps`, and add a `convert` method. In the `convert` method, quantize the weights and return a dictionary of quantized params.

7. Add the `get_weight_conversions` method to the quantizer class if the quantization method supports loading pre-quantized weights. In Transformers, we can collect multiple tensors and apply operations on them. This is particularly useful when tensors in the checkpoint need to be regrouped to re-create the quantized tensors.

9. Document everything! Make sure your quantization method is documented by adding a new file under `docs/source/en/quantization`.

10. Add tests by adding the package to our nightly Dockerfile inside `docker/transformers-quantization-latest-gpu`, then add a new test file in `tests/quantization/xxx`. Feel free to check out existing quantization methods to see how their tests are implemented.

## Files overview

| File | Purpose |
| -------------------------------------------- | ------------------------------------------------------------------------------------------------ |
| `utils/quantization_config.py` | Define `YourMethodConfig` inheriting from `QuantizationConfigMixin` |
| `quantizers/quantizer_your_method.py` | Implement `YourMethodHfQuantizer` inheriting from `HfQuantizer` |
| `integrations/your_method.py` | Implement `ConversionOps` subclasses and helper functions |
| `quantizers/auto.py` | Register quantizer and config in `AUTO_QUANTIZER_MAPPING` and `AUTO_QUANTIZATION_CONFIG_MAPPING` |
| `docs/source/en/quantization/your_method.md` | Document usage for users |
| `tests/quantization/your_method/` | Add integration tests |

## Understanding `get_quantize_ops` vs `get_weight_conversions`
These two methods handle different scenarios for loading weights. Understanding when to use each is essential.

### `get_quantize_ops` — Quantize on the fly

Use this when loading a **non-quantized checkpoint** (e.g., float16/bfloat16 weights) and quantizing during load.

```
Checkpoint: model.safetensors (float16 weights, for example)

get_quantize_ops → YourQuantize.convert()

Result: Quantized weights in memory
```

The `convert` method receives one tensor at a time, quantizes it, and can return a dictionary of quantized params, for example:

```py
class YourQuantize(ConversionOps):
    def convert(self, input_dict, model, full_layer_name, missing_keys, **kwargs):
        # input_dict = {"layer.weight": <float16 tensor>}
        value = list(input_dict.values())[0]
        module, tensor_name = get_module_from_name(model, full_layer_name)

        # Quantize and return the new params
        quantized, scale, zero_point = your_quantize_fn(value)
        return {
            full_layer_name: quantized,
            full_layer_name + ".scale": scale,
            full_layer_name + ".zero_point": zero_point,
        }
```

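
To make the example concrete, here is a toy, framework-free stand-in for `your_quantize_fn`, assuming a simple affine int8 scheme (`q = round(w / scale) + zero_point`). This is purely illustrative; real methods run their own quantization kernels and return method-specific tensors.

```python
# Toy affine int8 quantization: q = round(w / scale) + zero_point.
# Pure-Python illustration only; real methods call dedicated kernels.
def your_quantize_fn(values, zero_point=128):
    scale = max(abs(v) for v in values) / 127 or 1.0  # avoid a zero scale
    quantized = [round(v / scale) + zero_point for v in values]
    return quantized, scale, zero_point

quantized, scale, zero_point = your_quantize_fn([1.0, -1.0, 0.0])
```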
### `get_weight_conversions` — Load pre-quantized checkpoints
Use this when loading a **pre-quantized checkpoint** where the quantized weights are saved as several separate components (such as data, scale, and zero point), and these need to be combined into one tensor during loading. Not all quantization methods require this reconstruction step: for example, some methods like FP8 simply load weights and scales as-is, without combining them. Others, such as torchao, do require reassembling the quantized tensor from its multiple saved components.

```
Checkpoint: model.safetensors (quantized components)
  - layer._weight_qdata
  - layer._weight_scale
  - layer._weight_zero_point

get_weight_conversions → WeightConverter + YourDeserialize.convert()

Result: Reconstructed quantized tensor → layer.weight
```

The `WeightConverter` collects related tensors based on `source_patterns`, then passes them to your `convert` method:

```py
def get_weight_conversions(self):
    if self.pre_quantized:
        return [
            WeightConverter(
                source_patterns=["_weight_qdata", "_weight_scale", "_weight_zero_point"],
                target_patterns="weight",
                operations=[YourDeserialize(self)],
            ),
        ]
    return []


class YourDeserialize(ConversionOps):
    def convert(self, input_dict, model, full_layer_name, **kwargs):
        # input_dict contains all collected tensors
        # Reconstruct the quantized tensor from its saved components
        reconstructed_tensor = reconstruct_from_components(input_dict)
        return {full_layer_name: reconstructed_tensor}
```

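
As a toy illustration of what a `reconstruct_from_components` helper might do, here is a framework-free sketch assuming the same affine int8 scheme (`w ≈ (q - zero_point) * scale`). The key names mirror the checkpoint layout above, but the function itself is hypothetical; a real implementation would rebuild the method's quantized tensor type rather than plain floats.

```python
# Toy dequantization/reconstruction: w = (q - zero_point) * scale.
# Framework-free sketch for illustration only.
def reconstruct_from_components(input_dict):
    qdata = input_dict["layer._weight_qdata"]
    scale = input_dict["layer._weight_scale"]
    zero_point = input_dict["layer._weight_zero_point"]
    return [(q - zero_point) * scale for q in qdata]

components = {
    "layer._weight_qdata": [255, 1, 128],
    "layer._weight_scale": 0.5,
    "layer._weight_zero_point": 128,
}
weight = reconstruct_from_components(components)
```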
