forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquant_module.py
More file actions
274 lines (222 loc) · 11.3 KB
/
quant_module.py
File metadata and controls
274 lines (222 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for quantization modules."""
import contextlib
import warnings
from typing import Any
import torch
import torch.nn as nn
from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
from modelopt.torch.utils.distributed import ParallelState
from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
from ...utils import is_torch_export_mode
from .tensor_quantizer import SequentialQuantizer, TensorQuantizer
# Explicit public API of this module: the names exported by `from ... import *`
# and considered stable for external use.
__all__ = [
    "QuantInputBase",
    "QuantLinearConvBase",
    "QuantModule",
    "QuantModuleRegistry",
]
class QuantModule(DynamicModule):
    """A base class for quantized modules.

    In addition, the class also provides ``parallel_state`` attribute that can be used to access
    the parallel state of the module.
    """

    # Backing attribute for the ``parallel_state`` property; may be unset until
    # ``_setup`` or ``_initialize_parallel_state`` assigns it.
    _parallel_state: ParallelState

    @classmethod
    @torch.no_grad()
    def convert(cls, module: nn.Module, **setup_kwargs: Any) -> "QuantModule":
        """Convert the module to a dynamic module.

        Args:
            module: The original module to convert in place.
            **setup_kwargs: Forwarded to the base class ``convert`` / ``_setup``.

        Returns:
            The converted module.
        """
        module = super().convert(module, **setup_kwargs)
        # Set up parallel state now that the module is converted, but only if the
        # subclass `_setup` did not already assign one.
        if module.parallel_state is None:
            module._initialize_parallel_state()
        return module

    @property
    def parallel_state(self) -> ParallelState | None:
        """Return the parallel state of the quant module, or ``None`` if not yet set."""
        return getattr(self, "_parallel_state", None)

    @parallel_state.setter
    def parallel_state(self, parallel_state: ParallelState):
        """Set the parallel state of the dynamic module."""
        assert isinstance(parallel_state, ParallelState), (
            "parallel_state must be a ParallelState object!"
        )
        self._parallel_state = parallel_state

    def _initialize_parallel_state(self):
        """Initialize the parallel state of the dynamic module.

        This method is called only if the `QuantModule` does not have a `parallel_state` attribute
        after `_setup` is called.
        """
        if torch.distributed.is_initialized():
            # Warn because tensor-parallel modules must set their own parallel_state;
            # the default below only knows about the default process group.
            warnings.warn(
                f"Distributed training is initialized but no parallel_state is set for {type(self)}. "
                "Using default parallel_state which has data_parallel_group set to the default process group and "
                "tensor_parallel_group is unspecified. "
                "If you are using tensor parallelism for this module, you should set the parallel_state "
                "in its `_setup` method."
            )
        self.parallel_state = ParallelState(data_parallel_group=None)

    def modelopt_post_restore(self, prefix: str = ""):
        """Post-restore to correctly configure the TensorQuantizer states.

        TensorQuantizer states are restored to their shape before saving. Now we need to further
        configure them.

        1. For non-sharded modules this simply involves moving the TensorQuantizer states to the
           right device. This applies for regular PyTorch models and HuggingFace models.
        2. For sharded modules the restored states of TensorQuantizer could be incorrect. This is
           because parallelism such as TP might have been changed between saving and restoring.
           So we need to re-calculate the state shapes. Hence such modules should override this
           and implement their own logic.

        Args:
            prefix: State-dict prefix of this module; used only in the warning message.
        """
        # Find a parameter or buffer that does NOT belong to a TensorQuantizer; its
        # device tells us where the rest of the module already lives.
        non_tq_param_or_buffer = None
        for name, param_or_buffer in self.state_dict().items():
            parent = self.get_submodule(name.rsplit(".", 1)[0]) if "." in name else self
            if not isinstance(parent, TensorQuantizer):
                non_tq_param_or_buffer = param_or_buffer
                break
        if non_tq_param_or_buffer is None:
            # No non-quantizer state to infer the device from; leave it to the caller.
            warnings.warn(
                f"Could not identify the device for TensorQuantizer states of {prefix}. "
                "Please move the model to the right device now. This can be done by calling "
                "`model.to(device)`."
            )
            return
        # Move the TensorQuantizer states to the right device (dtype should have been restored).
        for module in self.modules():
            if isinstance(module, TensorQuantizer):
                module.to(non_tq_param_or_buffer.device)

    def iter_weights_for_calibration(self):
        """Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration."""
        from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names

        for weight_name in weight_attr_names(self):
            weight_quantizer = getattr(self, quantizer_attr_names(weight_name).weight_quantizer)
            yield getattr(self, weight_name), weight_quantizer

    def fold_weight(self, keep_attrs: bool = False):
        """Fold the weight for faster eval.

        Applies each fake-quant weight quantizer to its weight in place and disables the
        quantizer afterwards.

        Args:
            keep_attrs: If True, keep quantizer attributes such as ``_amax`` and
                ``_pre_quant_scale`` after folding; otherwise delete them.
        """
        # Handle all attributes that end with "weight_quantizer" (e.g. `weight_quantizer`
        # or `<name>_weight_quantizer`).
        for name in dir(self):
            attr = getattr(self, name)
            if (
                name.endswith("weight_quantizer")
                and isinstance(attr, TensorQuantizer)
                and attr.fake_quant
            ):
                # The weight attribute name is the quantizer name minus the "_quantizer"
                # suffix, e.g. "weight_quantizer" -> "weight",
                # "xyz_weight_quantizer" -> "xyz_weight".
                weight_name = name[: -len("_quantizer")]
                assert hasattr(self, weight_name), (
                    f"{name} doesn't have a corresponding {weight_name} in {self.__class__.__name__}"
                )
                weight = getattr(self, weight_name)
                # Quantize in float for accuracy, then cast back to the weight's dtype.
                weight.data.copy_(attr(weight.float()).to(weight.dtype))
                attr.disable()
                if not keep_attrs:
                    _attrs = [
                        "_pre_quant_scale",
                        "_amax",
                    ]
                    for attr_name in _attrs:
                        if hasattr(attr, attr_name):
                            delattr(attr, attr_name)
# Global registry mapping original module classes to their QuantModule counterparts.
QuantModuleRegistry = _DMRegistryCls("Quant", QuantModule)
class QuantInputBase(QuantModule):
    """Base class for modules where the input is quantized."""

    input_quantizer: TensorQuantizer
    output_quantizer: TensorQuantizer
    default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
    default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR

    def _is_forward_in_mro(self, bound_or_func) -> bool:
        """Return True if ``bound_or_func`` is a bound method whose underlying function is a
        ``forward`` implementation in this instance's MRO (i.e. not an external monkey-patch).

        Hoisted out of ``forward`` so it is not re-created on every call.
        """
        fn = getattr(bound_or_func, "__func__", None)
        if fn is None:
            # Not a bound method (e.g. a plain function): treat as external.
            return False
        return any(cls.__dict__.get("forward") is fn for cls in type(self).mro())

    def forward(self, input, *args, **kwargs):
        """Quantize the input before calling the original forward method."""
        input = self.input_quantizer(input)
        # Check MR: https://github.com/NVIDIA/Model-Optimizer/pull/824
        # NOTE: the previous `pre_fwd is getattr(self, "forward")` identity check was dead code:
        # a bound method object is created afresh on every attribute access, so `is` never
        # matched. The MRO check below covers that case correctly. Also treat an explicit
        # `_forward_pre_dm = None` as "not set" instead of crashing on `None(...)`.
        pre_fwd = getattr(self, "_forward_pre_dm", None)
        if pre_fwd is not None and not self._is_forward_in_mro(pre_fwd):
            # `_forward_pre_dm` is an external monkey-patched forward: call it directly.
            output = pre_fwd(input, *args, **kwargs)
        else:
            output = super().forward(input, *args, **kwargs)
        if isinstance(output, tuple):
            # Quantize only the first element; pass the rest through unchanged.
            return (self.output_quantizer(output[0]), *output[1:])
        return self.output_quantizer(output)

    def _setup(self):
        """Patch the module's forward method to quantize the input."""
        self._register_temp_attribute(
            "input_quantizer", TensorQuantizer(self.default_quant_desc_input)
        )
        self._register_temp_attribute(
            "output_quantizer", TensorQuantizer(self.default_quant_desc_output)
        )
        # Output quantization is opt-in; disabled by default.
        self.output_quantizer.disable()
class QuantLinearConvBase(QuantInputBase):
    """Base class for quantized linear modules.

    Quantized linear modules are modules where both the input and the weight are quantized.
    """

    weight_quantizer: TensorQuantizer | SequentialQuantizer
    _enable_weight_quantization: bool
    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR

    @contextlib.contextmanager
    def quantize_weight(self):
        """Context in which `self.weight` is quantized."""
        self._enable_weight_quantization = True
        try:
            yield
        finally:
            # Always reset, even if the wrapped forward raised.
            self._enable_weight_quantization = False

    @staticmethod
    def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
        # Quantize only inside a quantize_weight() context or during torch.export tracing.
        do_quantize = module._enable_weight_quantization or is_torch_export_mode()
        return module.weight_quantizer(weight) if do_quantize else weight

    def forward(self, input, *args, **kwargs):
        """Quantize the input and the weight before calling the original forward method."""
        # self.quantize_weight() sets attributes, which torch.export does not allow;
        # fall back to a no-op context there (_get_quantized_weight still quantizes).
        guard = contextlib.nullcontext() if is_torch_export_mode() else self.quantize_weight()
        with guard:
            return super().forward(input, *args, **kwargs)

    def _setup(self):
        """Register the weight quantizer and the dynamic quantized-weight attribute."""
        super()._setup()
        self._register_temp_attribute("_enable_weight_quantization", False)
        self._register_temp_attribute(
            "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
        )
        self._register_dynamic_attribute("weight", self._get_quantized_weight)
class _LegacyQuantInputBaseMixin:
    """A mixin to support legacy quantized modules which needs to have an __init__ method."""

    _quantized_cls = QuantInputBase
    default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
    default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR

    def __init__(self, *args, quant_desc_input=None, **kwargs):
        """Initialize the module with its original __init__ and patch its forward."""
        # Pin the chosen input descriptor on the instance (falls back to the class default),
        # then run the original constructor and convert the module in place.
        chosen_input_desc = quant_desc_input or self.default_quant_desc_input
        self.default_quant_desc_input = chosen_input_desc
        super().__init__(*args, **kwargs)
        QuantModuleRegistry.convert(self)
class _LegacyQuantLinearConvBaseMixin(_LegacyQuantInputBaseMixin):
    """A mixin to support legacy quantized modules which needs to have an __init__ method."""

    _quantized_cls = QuantLinearConvBase
    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR

    def __init__(self, *args, quant_desc_input=None, quant_desc_weight=None, **kwargs):
        """Initialize the module with its original __init__ and patch its forward."""
        # Pin the weight descriptor before delegating; the parent mixin handles the
        # input descriptor and the registry conversion.
        chosen_weight_desc = quant_desc_weight or self.default_quant_desc_weight
        self.default_quant_desc_weight = chosen_weight_desc
        super().__init__(*args, quant_desc_input=quant_desc_input, **kwargs)