Skip to content

Commit 12b3d0b

Browse files
yghstill authored and woodchenwu committed
Add diffusion FLUX fp8_static quantization (Tencent#37)
1 parent 54ad709 commit 12b3d0b

37 files changed

Lines changed: 904 additions & 289 deletions

angelslim/compressor/quant/core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
# limitations under the License.
1414

1515
from .config import * # noqa: F401 F403
16-
from .hook import DiTHook, PTQHook # noqa: F401
16+
from .hook import PTQHook # noqa: F401
1717
from .metrics import mse_loss, snr_loss # noqa: F401
1818
from .packing_utils import dequantize_gemm, pack_weight_to_int8 # noqa: F401
1919
from .quant_func import * # noqa: F401 F403
2020
from .sample_func import EMASampler, MultiStepSampler # noqa: F401
2121
from .save import DeepseekV3HfPTQSave # noqa: F401
2222
from .save import DeepseekV3PTQSaveTRTLLM # noqa: F401
23+
from .save import PTQDiffusionSave # noqa: F401
24+
from .save import PTQOnlyScaleSave # noqa: F401
2325
from .save import PTQPTMSave # noqa: F401
2426
from .save import PTQSaveVllmHF # noqa: F401
2527
from .save import PTQTorchSave # noqa: F401

angelslim/compressor/quant/core/config.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,13 @@ def __init__(self, config, global_config=None):
5757
quantization_args = config.quantization
5858
self.quant_algo = quantization_args.name
5959
self.quant_bit = quantization_args.bits
60-
self.max_seq_length = global_config.max_seq_length
6160
self.quant_helpers = quantization_args.quant_helpers
6261
act_quant_method = quantization_args.quant_method.get("activation", None)
6362
weight_quant_method = quantization_args.quant_method["weight"]
63+
if global_config:
64+
self.max_seq_length = global_config.max_seq_length
65+
self.hidden_size = global_config.hidden_size
66+
self.model_arch_type = global_config.model_arch_type
6467

6568
if "fp8" in self.quant_algo:
6669
is_dynamic = "dynamic" if "dynamic" in self.quant_algo else "static"
@@ -94,8 +97,6 @@ def __init__(self, config, global_config=None):
9497

9598
if act_quant_method is not None:
9699
self.quant_algo_info["a"] = f"fp8_{act_quant_method}-{is_dynamic}"
97-
self.hidden_size = global_config.hidden_size
98-
self.model_arch_type = global_config.model_arch_type
99100
self.low_memory = config.quantization.low_memory
100101
self.quant_analyse = config.quantization.quant_analyse
101102
self.quant_vit = config.quantization.quant_vit
@@ -117,8 +118,6 @@ def __init__(self, config, global_config=None):
117118
}
118119
if act_quant_method is not None:
119120
self.quant_algo_info["a"] = f"int8_{act_quant_method}-{is_dynamic}"
120-
self.hidden_size = global_config.hidden_size
121-
self.model_arch_type = global_config.model_arch_type
122121
self.low_memory = config.quantization.low_memory
123122
self.quant_analyse = config.quantization.quant_analyse
124123
elif "int4_awq" in self.quant_algo:
@@ -135,8 +134,6 @@ def __init__(self, config, global_config=None):
135134
"group_size": int(group_size),
136135
"mse_range": quantization_args.quant_method["mse_range"],
137136
}
138-
self.hidden_size = global_config.hidden_size
139-
self.model_arch_type = global_config.model_arch_type
140137
self.low_memory = config.quantization.low_memory
141138
elif "int4_gptq" in self.quant_algo or "int4_gptaq" in self.quant_algo:
142139
self.act_observer = None
@@ -151,7 +148,6 @@ def __init__(self, config, global_config=None):
151148
"group_size": group_size,
152149
"ignore_layers": quantization_args.ignore_layers,
153150
}
154-
self.hidden_size = global_config.hidden_size
155151

156152
if "smooth" in self.quant_helpers:
157153
self.smooth_alpha = quantization_args.smooth_alpha

angelslim/compressor/quant/core/hook.py

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import re
16-
17-
import torch
18-
1915
from ..observers import ParentObserver, PTQObserver
2016
from .quant_func import get_fp_maxval, get_fp_search_maxval
2117

22-
__all__ = ["PTQHook", "DiTHook"]
18+
__all__ = ["PTQHook"]
2319

2420

2521
class PTQHook:
@@ -119,66 +115,3 @@ def post_process(self):
119115
if self.quant_model.quant_algo_dict["c_quant_algo"] == "fp8":
120116
for k, v in self.quant_model.kv_cache_scales_dict.items():
121117
self.quant_model.kv_cache_scales_dict[k] = v / maxval.type(v.dtype)
122-
123-
124-
# Sub-module name fragments that exclude a layer from hooking. The pattern is
# compiled once at import time because _filter_func is called once per
# sub-module inside loops.
_FILTERED_NAME_PATTERN = re.compile(
    r".*(mlp_t5|pooler|style_embedder|x_embedder|t_embedder|extra_embedder).*"
)


def _filter_func(name):
    """Return True when ``name`` contains one of the excluded name fragments.

    Args:
        name (str): qualified sub-module name, as yielded by
            ``model.named_modules()``.

    Returns:
        bool: True if the name matches the exclusion pattern, else False.
    """
    return _FILTERED_NAME_PATTERN.match(name) is not None
129-
130-
131-
class DiTHook:
    """Capture inputs and outputs of selected Conv2d/Linear layers of a model.

    On construction, a forward hook is registered on every
    ``torch.nn.Conv2d`` / ``torch.nn.Linear`` sub-module whose qualified name
    contains ``"blocks"`` and is not rejected by ``_filter_func``. Each time
    such a layer runs, a ``(layer_name, tensor)`` pair is appended to
    ``output_activation`` and to ``input_activation``; the tensors are
    detached and moved to CPU.
    """

    def __init__(self, model):
        """
        Args:
            model(nn.Module, required): the model to be quant
        """
        self.model = model
        self.input_activation = []
        self.output_activation = []

        self._apply_hook()

    def _apply_hook(self):
        """Register the forward hooks and remember each hooked layer's name."""
        self._forward_hook_list = []
        # id(layer) -> qualified name, so the hook does not have to rescan
        # named_modules() on every forward call.
        self._layer_names = {}
        for name, sub_layer in self.model.named_modules():
            if _filter_func(name):
                continue
            if not isinstance(sub_layer, (torch.nn.Conv2d, torch.nn.Linear)):
                continue
            if "blocks" not in name:
                continue
            self._layer_names.setdefault(id(sub_layer), name)
            handle = sub_layer.register_forward_hook(self._forward_pre_hook)
            self._forward_hook_list.append(handle)

    def _resolve_layer_name(self, layer):
        """Return the qualified name of ``layer``, or "" when it is unknown."""
        known = self._layer_names.get(id(layer))
        if known is not None:
            return known
        # Fallback for layers that were not registered by _apply_hook: scan
        # the model once, skipping filtered names.
        for name, module in self.model.named_modules():
            if _filter_func(name):
                continue
            if module is layer:
                return name
        return ""

    def _forward_pre_hook(self, layer, input, output):
        """Record the layer's output and input activations.

        Despite its name, this callback is registered with
        ``register_forward_hook``, so it receives the layer's output as well
        as its input. Only the first element is kept when either is a tuple.
        """
        layer_name = self._resolve_layer_name(layer)
        out_tensor = output[0] if isinstance(output, tuple) else output
        self.output_activation.append((layer_name, out_tensor.detach().cpu()))
        in_tensor = input[0] if isinstance(input, tuple) else input
        self.input_activation.append((layer_name, in_tensor.detach().cpu()))

    def remove_hook(self):
        """Detach every registered forward hook from the model."""
        for hook in self._forward_hook_list:
            hook.remove()
        self._forward_hook_list = []

    def clean_activation_list(self):
        """Drop all recorded input and output activations."""
        self.input_activation = []
        self.output_activation = []

    def clean_acitvation_list(self):
        """Misspelled alias of ``clean_activation_list``, kept for callers."""
        self.clean_activation_list()

angelslim/compressor/quant/core/quant_func.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,52 @@ def weight_dequant(
380380
)
381381
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
382382
return y
383+
384+
385+
# This function is copied from DeepSeek-V3 (MIT License):
386+
# Copyright (c) 2023 DeepSeek-AI
387+
# Original source: https://github.com/deepseek-ai/DeepSeek-V3
388+
@triton.jit
def weight_quant(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """Quantizes a weight matrix to FP8 using block-wise quantization.

    Each program instance handles one BLOCK_SIZE x BLOCK_SIZE tile of the
    row-major M x N input. It computes the tile's absolute maximum, derives
    a scale of ``max / 448.0``, stores the scaled tile into ``y_ptr`` and the
    scale into ``s_ptr`` at ``[pid_m, pid_n]``.

    Args:
        x_ptr: pointer to the input weights (read as float32).
        y_ptr: pointer to the quantized output; its element type decides the
            stored dtype.
        s_ptr: pointer to the per-tile scales, laid out with
            ``cdiv(N, BLOCK_SIZE)`` entries per tile row.
        M: number of rows of the input.
        N: number of columns of the input.
        BLOCK_SIZE: tile edge length (compile-time constant).
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]

    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Out-of-bounds lanes must read as 0.0: without `other`, their value is
    # undefined and would leak into the tile maximum below for edge tiles.
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    max_val = tl.max(tl.abs(x))
    scale = max_val / 448.0
    # An all-zero tile would give scale 0 and divide by zero; use 1.0 instead.
    scale = tl.where(max_val == 0.0, 1.0, scale)
    y = x / scale
    y = y.to(y_ptr.dtype.element_ty)

    tl.store(y_ptr + offs, y, mask=mask)
    tl.store(s_ptr + pid_m * n + pid_n, scale)
409+
410+
411+
def per_block_weight_quant(
    x: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D weight tensor to FP8 with one scale per square block.

    Args:
        x: contiguous 2-D weight tensor.
        block_size: edge length of each quantization block.

    Returns:
        A ``(quantized, scales)`` pair: ``quantized`` has the shape of ``x``
        with dtype ``torch.float8_e4m3fn``; ``scales`` is a float32 tensor of
        shape ``(cdiv(M, block_size), cdiv(N, block_size))`` on ``x``'s device.
    """
    assert x.is_contiguous()
    assert x.dim() == 2

    rows, cols = x.size()
    quantized = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scale_shape = (triton.cdiv(rows, block_size), triton.cdiv(cols, block_size))
    scales = torch.empty(scale_shape, dtype=torch.float32, device=x.device)

    def launch_grid(meta):
        # One program instance per block along each dimension.
        block = meta["BLOCK_SIZE"]
        return (triton.cdiv(rows, block), triton.cdiv(cols, block))

    weight_quant[launch_grid](
        x, quantized, scales, rows, cols, BLOCK_SIZE=block_size
    )

    return quantized, scales

angelslim/compressor/quant/core/save.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from tqdm import tqdm
2929
from transformers.models.deepseek_v3 import DeepseekV3Config
3030

31-
from ....utils import print_info
32-
from ..modules import QDQModule, QDQSingleModule
31+
from ....utils import find_layers, find_parent_layer_and_sub_name, print_info
32+
from ..modules import QDQModule, QDQSingleModule, QLinear
3333
from .packing_utils import pack_weight_to_int8
3434
from .quant_func import fake_quant_dequant, tensor_quant, weight_dequant
3535

@@ -188,6 +188,96 @@ def save(self, save_path):
188188
self.quant_model.tokenizer.save_pretrained(save_path)
189189

190190

191+
class PTQDiffusionSave(PTQSaveBase):
    """Save a quantized diffusion model together with its FP8 scales.

    ``save`` writes ``hf_quant_config.json``, replaces every ``QDQModule``
    found under the model's ``transformer`` with a ``QLinear`` built from the
    same weight, bias and scales, saves the model with ``save_pretrained``
    and stores the collected scales in ``model-scales.safetensors``.
    """

    def __init__(self, quant_model):
        super().__init__(quant_model=quant_model)

    def save(self, save_path):
        """Write the quantization config, model and scales into ``save_path``.

        Args:
            save_path: output directory; created if it does not exist.
        """
        act_algo = self.quant_model.quant_config.quant_algo_info["a"]
        skipped_layers = self.quant_model.skip_layer_names()

        if "dynamic" in act_algo:
            activation_scheme = "dynamic"
        else:
            activation_scheme = "static"
        quant_meta = {
            "quantization_config": {
                "quant_method": "fp8",
                "activation_scheme": activation_scheme,
                "ignored_layers": skipped_layers,
            }
        }

        os.makedirs(save_path, exist_ok=True)
        config_file = os.path.join(save_path, "hf_quant_config.json")
        with open(config_file, "w") as f:
            json.dump(quant_meta, f, indent=4)

        scales = {}
        qdq_layers = find_layers(
            self.quant_model.get_model().transformer, layers=[QDQModule]
        )
        for layer_name, qdq_layer in qdq_layers.items():
            parent, attr_name = find_parent_layer_and_sub_name(
                self.quant_model.get_model().transformer, layer_name
            )
            # Swap the QDQ layer in place for a QLinear carrying cloned scales.
            replacement = QLinear(
                quant_algo=qdq_layer.quant_algo,
                weight=qdq_layer.weight,
                bias=qdq_layer.bias,
                weight_scale=qdq_layer.weight_scale.data.clone().detach(),
                input_scale=qdq_layer.input_scale.data.clone().detach(),
            )
            setattr(parent, attr_name, replacement)
            scales[layer_name + ".input_scale"] = qdq_layer.input_scale
            scales[layer_name + ".weight_scale"] = qdq_layer.weight_scale

        self.quant_model.get_model().save_pretrained(save_path)
        safe_save(scales, os.path.join(save_path, "model-scales.safetensors"))
235+
236+
237+
class PTQOnlyScaleSave(PTQSaveBase):
    """Save only the quantization scales and their index, not the weights.

    ``save`` writes ``hf_quant_config.json``, stores the activation and
    weight scales in ``model-scales.safetensors`` and writes a
    ``model.safetensors.index.json`` mapping every scale key to that file.
    """

    def __init__(self, quant_model):
        super().__init__(quant_model=quant_model)

    def save(self, save_path):
        """Write the quantization config, scales and index into ``save_path``.

        Args:
            save_path: output directory; created if it does not exist.
        """
        act_algo = self.quant_model.quant_config.quant_algo_info["a"]
        skipped_layers = self.quant_model.skip_layer_names()

        if "dynamic" in act_algo:
            activation_scheme = "dynamic"
        else:
            activation_scheme = "static"
        quant_meta = {
            "quantization_config": {
                "quant_method": "fp8",
                "activation_scheme": activation_scheme,
                "ignored_layers": skipped_layers,
            }
        }

        os.makedirs(save_path, exist_ok=True)
        config_file = os.path.join(save_path, "hf_quant_config.json")
        with open(config_file, "w") as f:
            json.dump(quant_meta, f, indent=4)

        scales_file_name = "model-scales.safetensors"
        scales = {}
        weight_map = {}
        # Activation scales are inserted first, then weight scales.
        for layer_name, scale in self.quant_model.act_scales_dict.items():
            key = layer_name + ".input_scale"
            scales[key] = scale
            weight_map[key] = scales_file_name
        for layer_name, scale in self.quant_model.weight_scales_dict.items():
            key = layer_name + ".weight_scale"
            scales[key] = scale
            weight_map[key] = scales_file_name

        safe_save(scales, os.path.join(save_path, scales_file_name))

        # update model index json
        index_content = {"metadata": {}, "weight_map": weight_map}
        index_file = os.path.join(save_path, "model.safetensors.index.json")
        with open(index_file, "w") as f:
            json.dump(index_content, f, indent=2)
279+
280+
191281
class PTQTorchSave(PTQSaveBase):
192282
def __init__(self, quant_model):
193283
super(PTQTorchSave, self).__init__(quant_model=quant_model)
@@ -594,7 +684,7 @@ def merge_model(self, input_path, save_model_path, mp=16):
594684
param_list.append(param)
595685
newparam = torch.cat(param_list, dim=0)
596686
new_save_dict[k] = newparam
597-
print(f"shape of {k}: {new_save_dict[k].shape}")
687+
print_info(f"shape of {k}: {new_save_dict[k].shape}")
598688
index_dict["weight_map"][k] = str(filename)
599689
safe_save(new_save_dict, os.path.join(save_model_path, filename))
600690
# process others
@@ -625,7 +715,7 @@ def merge_model(self, input_path, save_model_path, mp=16):
625715
index_dict,
626716
filename,
627717
)
628-
print(f"shape of {k}: {new_save_dict[k].shape}")
718+
print_info(f"shape of {k}: {new_save_dict[k].shape}")
629719
safe_save(new_save_dict, os.path.join(save_model_path, filename))
630720

631721
# update scales map

angelslim/compressor/quant/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .helper_layer import GPTQQuantLinear # noqa: F401
2222
from .helper_layer import QDQModule # noqa: F401
2323
from .helper_layer import QDQSingleModule # noqa: F401
24+
from .helper_layer import QLinear # noqa: F401
2425
from .helper_layer import SmoothHelpModule # noqa: F401
2526
from .helper_layer import WQLinearGEMM # noqa: F401
2627
from .int8.int8 import INT8 # noqa: F401

0 commit comments

Comments
 (0)