Skip to content

Commit 4579891

Browse files
Add parametrize util for targeting parameters outside of nn.Linear modules
1 parent e54dc12 commit 4579891

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed

bitsandbytes/nn/parametrize.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from functools import partial
2+
from typing import Any, Literal, Optional
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.utils.parametrize as P
7+
8+
from .. import functional as F
9+
10+
11+
class Bnb4bitParametrization(nn.Module):
    """
    Parametrization that transparently dequantizes a 4-bit quantized parameter.

    The wrapped parameter is assumed to already hold quantized data by the time
    this parametrization is registered. Accessing the parameter triggers
    ``forward``, which restores the original floating-point representation.

    Args:
        quant_state (`F.QuantState`):
            The quantization state holding everything needed for dequantization.
    """

    def __init__(self, quant_state: F.QuantState):
        super().__init__()
        # Keep the state around so forward() can undo the quantization.
        self.quant_state = quant_state

    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
        """
        Dequantize the parameter when it is accessed.

        Args:
            quantized_param (`torch.Tensor`): The quantized tensor (the ``.original`` data).

        Returns:
            `torch.Tensor`: The dequantized tensor in the original shape and dtype.
        """
        return F.dequantize_4bit(quantized_param, self.quant_state)
39+
40+
41+
def replace_parameter_4bit(
    module: nn.Module,
    param_name: str,
    compress_statistics: bool = False,
    quant_type: Literal["nf4", "fp4"] = "nf4",
    blocksize: Optional[int] = None,
):
    """
    Swap a module parameter for a 4-bit quantized version managed by parametrization.

    The named parameter is quantized to 4-bit precision and replaced in place,
    and a parametrization layer is registered so the parameter is transparently
    dequantized whenever it is accessed (i.e. during the forward pass).

    Additionally, a state-dict post-hook is installed so the quantization state
    is serialized correctly whenever the model's state dict is produced.

    This is useful for MoE models or any scenario where parameters outside of
    nn.Linear layers should be quantized without changing the model's
    architecture.

    <Tip warning={true}>This feature is experimental and may change in future releases.</Tip>

    Args:
        module (`nn.Module`):
            The PyTorch module holding the parameter to quantize.
        param_name (`str`):
            The name of the parameter within the module to quantize.
        compress_statistics (`bool`, *optional*, defaults to `False`):
            Whether to compress quantization statistics to reduce memory usage.
        quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`):
            The quantization format to use.
        blocksize (`int`, *optional*, defaults to `None`):
            The block size for quantization. If None, uses the default block size.

    Raises:
        AttributeError: If the module does not have the specified parameter.
        TypeError: If the specified attribute is not an instance of nn.Parameter.
    """

    # Guard clauses: the attribute must exist and must actually be a Parameter.
    if not hasattr(module, param_name):
        raise AttributeError(f"Module does not have parameter '{param_name}'")

    target = getattr(module, param_name)

    if not isinstance(target, nn.Parameter):
        raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")

    # Quantize the existing parameter data to 4-bit.
    packed, quant_state = F.quantize_4bit(
        target.data,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
    )

    # Swap in the quantized payload; gradients are meaningless for packed data.
    setattr(module, param_name, nn.Parameter(packed, requires_grad=False))
    del target

    # Register the dequantizing parametrization for this parameter.
    P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)

    # Ensure saving a state dict emits the quantized weights plus quant state.
    module.register_state_dict_post_hook(
        partial(_parametrized_state_dict_post_hook, param_name=param_name)
    )
111+
112+
113+
def _parametrized_state_dict_post_hook(
    module: nn.Module,
    state_dict: dict[str, Any],
    prefix: str,
    local_metadata: Any,
    *,
    param_name: str = "weight",
    **kwargs: Any,
) -> None:
    """
    State-dict post-hook that rewrites a parametrized quantized parameter's keys.

    When a parameter is parametrized, its raw (quantized) data is saved under
    ``{prefix}parametrizations.{param_name}.original``. This hook renames that
    entry back to ``{prefix}{param_name}`` and stores the packed quantization
    state under ``{prefix}{param_name}.{key}`` entries, so the state dict looks
    like a regular quantized checkpoint.

    Args:
        module (`nn.Module`): The module whose state dict is being post-processed.
        state_dict (`dict[str, Any]`): The state dict, modified in place.
        prefix (`str`): The key prefix for this module's entries.
        local_metadata (`Any`): Part of the hook signature; unused here.
        param_name (`str`, *optional*, defaults to `"weight"`):
            The name of the parametrized parameter to clean up.
    """

    original_key = f"{prefix}parametrizations.{param_name}.original"

    if original_key in state_dict:
        # Create a clean entry.
        # The `parametrizations.{param_name}.original` key will have the quantized data,
        # but we would like to keep it in the state_dict as `{param_name}`.
        clean_key = f"{prefix}{param_name}"
        state_dict[clean_key] = state_dict.pop(original_key)

        # Internal sanity check; stripped under `python -O`.
        assert P.is_parametrized(module, param_name)

        # Find the parametrization, which should have the quantization state.
        # NOTE: `next(..., None)` can return None, hence the Optional annotation.
        parametrization: Optional[Bnb4bitParametrization] = next(
            filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
        )

        assert parametrization is not None, "Parametrization not found for the parameter."

        quant_state = parametrization.quant_state

        # Persist the quantization state alongside the cleaned parameter key.
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                state_dict[f"{prefix}{param_name}.{k}"] = v

0 commit comments

Comments
 (0)