Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion angelslim/compressor/quant/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.

from .config import * # noqa: F401 F403
from .hook import DiTHook, PTQHook # noqa: F401
from .hook import PTQHook # noqa: F401
from .metrics import mse_loss, snr_loss # noqa: F401
from .packing_utils import dequantize_gemm, pack_weight_to_int8 # noqa: F401
from .quant_func import * # noqa: F401 F403
from .sample_func import EMASampler, MultiStepSampler # noqa: F401
from .save import DeepseekV3HfPTQSave # noqa: F401
from .save import DeepseekV3PTQSaveTRTLLM # noqa: F401
from .save import PTQDiffusionSave # noqa: F401
from .save import PTQOnlyScaleSave # noqa: F401
from .save import PTQPTMSave # noqa: F401
from .save import PTQSaveVllmHF # noqa: F401
from .save import PTQTorchSave # noqa: F401
Expand Down
12 changes: 4 additions & 8 deletions angelslim/compressor/quant/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@ def __init__(self, config, global_config=None):
quantization_args = config.quantization
self.quant_algo = quantization_args.name
self.quant_bit = quantization_args.bits
self.max_seq_length = global_config.max_seq_length
self.quant_helpers = quantization_args.quant_helpers
act_quant_method = quantization_args.quant_method.get("activation", None)
weight_quant_method = quantization_args.quant_method["weight"]
if global_config:
self.max_seq_length = global_config.max_seq_length
self.hidden_size = global_config.hidden_size
self.model_arch_type = global_config.model_arch_type

if "fp8" in self.quant_algo:
is_dynamic = "dynamic" if "dynamic" in self.quant_algo else "static"
Expand Down Expand Up @@ -94,8 +97,6 @@ def __init__(self, config, global_config=None):

if act_quant_method is not None:
self.quant_algo_info["a"] = f"fp8_{act_quant_method}-{is_dynamic}"
self.hidden_size = global_config.hidden_size
self.model_arch_type = global_config.model_arch_type
self.low_memory = config.quantization.low_memory
self.quant_analyse = config.quantization.quant_analyse
self.quant_vit = config.quantization.quant_vit
Expand All @@ -117,8 +118,6 @@ def __init__(self, config, global_config=None):
}
if act_quant_method is not None:
self.quant_algo_info["a"] = f"int8_{act_quant_method}-{is_dynamic}"
self.hidden_size = global_config.hidden_size
self.model_arch_type = global_config.model_arch_type
self.low_memory = config.quantization.low_memory
self.quant_analyse = config.quantization.quant_analyse
elif "int4_awq" in self.quant_algo:
Expand All @@ -135,8 +134,6 @@ def __init__(self, config, global_config=None):
"group_size": int(group_size),
"mse_range": quantization_args.quant_method["mse_range"],
}
self.hidden_size = global_config.hidden_size
self.model_arch_type = global_config.model_arch_type
self.low_memory = config.quantization.low_memory
elif "int4_gptq" in self.quant_algo or "int4_gptaq" in self.quant_algo:
self.act_observer = None
Expand All @@ -151,7 +148,6 @@ def __init__(self, config, global_config=None):
"group_size": group_size,
"ignore_layers": quantization_args.ignore_layers,
}
self.hidden_size = global_config.hidden_size

if "smooth" in self.quant_helpers:
self.smooth_alpha = quantization_args.smooth_alpha
Expand Down
69 changes: 1 addition & 68 deletions angelslim/compressor/quant/core/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import torch

from ..observers import ParentObserver, PTQObserver
from .quant_func import get_fp_maxval, get_fp_search_maxval

__all__ = ["PTQHook", "DiTHook"]
__all__ = ["PTQHook"]


class PTQHook:
Expand Down Expand Up @@ -119,66 +115,3 @@ def post_process(self):
if self.quant_model.quant_algo_dict["c_quant_algo"] == "fp8":
for k, v in self.quant_model.kv_cache_scales_dict.items():
self.quant_model.kv_cache_scales_dict[k] = v / maxval.type(v.dtype)


def _filter_func(name):
pattern = re.compile(
r".*(mlp_t5|pooler|style_embedder|x_embedder|t_embedder|extra_embedder).*"
)
return pattern.match(name) is not None


class DiTHook:
    """Record input/output activations of Linear and Conv2d layers of a model.

    On construction a forward hook is registered on every ``torch.nn.Linear``
    or ``torch.nn.Conv2d`` sub-module whose qualified name contains
    ``"blocks"`` and is not rejected by ``_filter_func``. Each forward pass
    through such a layer appends a ``(layer_name, tensor)`` tuple to
    ``output_activation`` and to ``input_activation``; the tensors are
    detached and moved to CPU.
    """

    def __init__(self, model):
        """
        Args:
            model(nn.Module, required): the model to be quant
        """
        self.model = model
        self.input_activation = []
        self.output_activation = []

        self._apply_hook()

    def _apply_hook(self):
        """Register the recording hook on every eligible sub-layer."""
        self._forward_hook_list = []
        # layer -> qualified name, recorded once here so the hook does not
        # have to walk the whole module tree on every forward call.
        self._layer_names = {}
        for name, sub_layer in self.model.named_modules():
            if _filter_func(name):
                continue
            if isinstance(sub_layer, (torch.nn.Conv2d, torch.nn.Linear)):
                if "blocks" in name:
                    self._layer_names[sub_layer] = name
                    handle = sub_layer.register_forward_hook(
                        self._forward_pre_hook
                    )
                    self._forward_hook_list.append(handle)

    def _resolve_layer_name(self, layer):
        """Return the qualified name of ``layer``, or ``""`` when unknown."""
        known = self._layer_names.get(layer)
        if known is not None:
            return known
        # Fallback for layers that were not registered through _apply_hook:
        # scan the module tree once and remember the answer.
        for name, module in self.model.named_modules():
            if _filter_func(name):
                continue
            if module is layer:
                self._layer_names[layer] = name
                return name
        return ""

    def _forward_pre_hook(self, layer, input, output):
        """Forward hook: store the layer's output, then its input.

        Despite its name this is registered with ``register_forward_hook``,
        so it runs after the layer's forward and receives both the inputs
        and the output.

        Args:
            layer: the module whose forward just ran.
            input: the layer inputs; the first element is recorded when this
                is a tuple.
            output: the layer output; the first element is recorded when this
                is a tuple.
        """
        layer_name = self._resolve_layer_name(layer)
        out_tensor = output[0] if isinstance(output, tuple) else output
        self.output_activation.append((layer_name, out_tensor.detach().cpu()))
        in_tensor = input[0] if isinstance(input, tuple) else input
        self.input_activation.append((layer_name, in_tensor.detach().cpu()))

    def remove_hook(self):
        """Detach every registered hook from the model."""
        for hook in self._forward_hook_list:
            hook.remove()
        self._forward_hook_list = []

    def clean_activation_list(self):
        """Drop all activations recorded so far."""
        self.input_activation = []
        self.output_activation = []

    def clean_acitvation_list(self):
        """Misspelled name kept for existing callers; same as
        ``clean_activation_list``."""
        self.clean_activation_list()
49 changes: 49 additions & 0 deletions angelslim/compressor/quant/core/quant_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,52 @@ def weight_dequant(
)
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y


# This function is copied from DeepSeek-V3 (MIT License):
# Copyright (c) 2023 DeepSeek-AI
# Original source: https://github.com/deepseek-ai/DeepSeek-V3
@triton.jit
def weight_quant(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """Quantizes FP32 weights to FP8 format using block-wise quantization.

    Each program instance handles one ``BLOCK_SIZE x BLOCK_SIZE`` tile of the
    row-major ``M x N`` input: it computes the tile's absolute maximum,
    derives one scale from it, writes the scaled tile to ``y_ptr`` and the
    scale to ``s_ptr``.

    Args:
        x_ptr: pointer to the input weights (row-major, ``M x N``).
        y_ptr: pointer to the output buffer; values are cast to its element
            type before being stored.
        s_ptr: pointer to the per-tile scales, indexed as
            ``pid_m * cdiv(N, BLOCK_SIZE) + pid_n``.
        M: number of rows of the input.
        N: number of columns of the input.
        BLOCK_SIZE: tile edge length (compile-time constant).
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]

    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Masked-out lanes of edge tiles are explicitly filled with 0 so they can
    # never influence the abs-max below, instead of relying on whatever value
    # a masked load yields when no fill value is given.
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    max_val = tl.max(tl.abs(x))
    # 448.0 is presumably the largest finite magnitude of the FP8 E4M3 format
    # that the output buffer is expected to use -- confirm against the caller.
    scale = max_val / 448.0
    # An all-zero tile would give scale 0 and a 0/0 division; use 1.0 instead.
    scale = tl.where(max_val == 0.0, 1.0, scale)
    y = x / scale
    y = y.to(y_ptr.dtype.element_ty)

    tl.store(y_ptr + offs, y, mask=mask)
    tl.store(s_ptr + pid_m * n + pid_n, scale)


def per_block_weight_quant(
    x: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Block-wise FP8 quantization of a 2-D weight tensor.

    Args:
        x: contiguous 2-D tensor to quantize.
        block_size: edge length of the square tiles that share one scale.

    Returns:
        A ``(y, s)`` pair: ``y`` has the shape of ``x`` with dtype
        ``torch.float8_e4m3fn``; ``s`` is a float32 tensor of shape
        ``(cdiv(M, block_size), cdiv(N, block_size))`` on the device of ``x``,
        holding one scale per tile.

    Raises:
        AssertionError: if ``x`` is not contiguous or not 2-D.
    """
    assert x.is_contiguous()
    assert x.dim() == 2

    rows, cols = x.size()
    quantized = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scale_shape = (
        triton.cdiv(rows, block_size),
        triton.cdiv(cols, block_size),
    )
    scales = torch.empty(scale_shape, dtype=torch.float32, device=x.device)

    def launch_grid(meta):
        # One program instance per (row tile, column tile).
        return (
            triton.cdiv(rows, meta["BLOCK_SIZE"]),
            triton.cdiv(cols, meta["BLOCK_SIZE"]),
        )

    weight_quant[launch_grid](
        x, quantized, scales, rows, cols, BLOCK_SIZE=block_size
    )

    return quantized, scales
98 changes: 94 additions & 4 deletions angelslim/compressor/quant/core/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from tqdm import tqdm
from transformers.models.deepseek_v3 import DeepseekV3Config

from ....utils import print_info
from ..modules import QDQModule, QDQSingleModule
from ....utils import find_layers, find_parent_layer_and_sub_name, print_info
from ..modules import QDQModule, QDQSingleModule, QLinear
from .packing_utils import pack_weight_to_int8
from .quant_func import fake_quant_dequant, tensor_quant, weight_dequant

Expand Down Expand Up @@ -188,6 +188,96 @@ def save(self, save_path):
self.quant_model.tokenizer.save_pretrained(save_path)


class PTQDiffusionSave(PTQSaveBase):
    """Saver that swaps the transformer's ``QDQModule`` layers for ``QLinear``
    and writes the model together with its FP8 scales."""

    def __init__(self, quant_model):
        super().__init__(quant_model=quant_model)

    def save(self, save_path):
        """Write the quantization config, the model and its scales.

        Produces under ``save_path``:
          * ``hf_quant_config.json`` -- quant method, activation scheme and
            ignored layers;
          * the output of the model's ``save_pretrained``;
          * ``model-scales.safetensors`` -- ``<layer>.input_scale`` and
            ``<layer>.weight_scale`` for every ``QDQModule`` found in the
            model's ``transformer``.

        Args:
            save_path: target directory; created if it does not exist.
        """
        act_algo = self.quant_model.quant_config.quant_algo_info["a"]
        skipped = self.quant_model.skip_layer_names()

        if "dynamic" in act_algo:
            scheme = "dynamic"
        else:
            scheme = "static"
        quant_cfg = {
            "quantization_config": {
                "quant_method": "fp8",
                "activation_scheme": scheme,
                "ignored_layers": skipped,
            }
        }

        os.makedirs(save_path, exist_ok=True)
        cfg_file = os.path.join(save_path, "hf_quant_config.json")
        with open(cfg_file, "w") as f:
            json.dump(quant_cfg, f, indent=4)

        scales = {}
        qdq_layers = find_layers(
            self.quant_model.get_model().transformer, layers=[QDQModule]
        )
        for layer_name, qdq in qdq_layers.items():
            parent, attr_name = find_parent_layer_and_sub_name(
                self.quant_model.get_model().transformer, layer_name
            )
            # Replace the QDQ layer in place with a QLinear built from the
            # same weight/bias and cloned scales.
            replacement = QLinear(
                quant_algo=qdq.quant_algo,
                weight=qdq.weight,
                bias=qdq.bias,
                weight_scale=qdq.weight_scale.data.clone().detach(),
                input_scale=qdq.input_scale.data.clone().detach(),
            )
            setattr(parent, attr_name, replacement)
            scales[layer_name + ".input_scale"] = qdq.input_scale
            scales[layer_name + ".weight_scale"] = qdq.weight_scale

        self.quant_model.get_model().save_pretrained(save_path)
        scales_file = os.path.join(save_path, "model-scales.safetensors")
        safe_save(scales, scales_file)


class PTQOnlyScaleSave(PTQSaveBase):
    """Saver that writes only the quantization config and the scales, with an
    index file mapping each scale to its safetensors file."""

    def __init__(self, quant_model):
        super().__init__(quant_model=quant_model)

    def save(self, save_path):
        """Write the quantization config, the scales and a weight-map index.

        Produces under ``save_path``:
          * ``hf_quant_config.json`` -- quant method, activation scheme and
            ignored layers;
          * ``model-scales.safetensors`` -- ``<name>.input_scale`` entries
            from ``act_scales_dict`` followed by ``<name>.weight_scale``
            entries from ``weight_scales_dict``;
          * ``model.safetensors.index.json`` -- maps every saved key to
            ``model-scales.safetensors``.

        Args:
            save_path: target directory; created if it does not exist.
        """
        act_algo = self.quant_model.quant_config.quant_algo_info["a"]
        skipped = self.quant_model.skip_layer_names()

        if "dynamic" in act_algo:
            scheme = "dynamic"
        else:
            scheme = "static"
        quant_cfg = {
            "quantization_config": {
                "quant_method": "fp8",
                "activation_scheme": scheme,
                "ignored_layers": skipped,
            }
        }

        os.makedirs(save_path, exist_ok=True)
        cfg_file = os.path.join(save_path, "hf_quant_config.json")
        with open(cfg_file, "w") as f:
            json.dump(quant_cfg, f, indent=4)

        scales = {}
        weight_map = {}
        index = {
            "metadata": {},
            "weight_map": weight_map,
        }
        scales_name = "model-scales.safetensors"
        # Activation scales first, then weight scales, so key order in both
        # the tensor dict and the index matches the order they are stored in.
        scale_groups = (
            (".input_scale", self.quant_model.act_scales_dict),
            (".weight_scale", self.quant_model.weight_scales_dict),
        )
        for suffix, group in scale_groups:
            for layer_name, tensor in group.items():
                key = layer_name + suffix
                scales[key] = tensor
                weight_map[key] = scales_name

        safe_save(scales, os.path.join(save_path, scales_name))

        # update model index json
        index_file = os.path.join(save_path, "model.safetensors.index.json")
        with open(index_file, "w") as f:
            json.dump(index, f, indent=2)


class PTQTorchSave(PTQSaveBase):
def __init__(self, quant_model):
super(PTQTorchSave, self).__init__(quant_model=quant_model)
Expand Down Expand Up @@ -594,7 +684,7 @@ def merge_model(self, input_path, save_model_path, mp=16):
param_list.append(param)
newparam = torch.cat(param_list, dim=0)
new_save_dict[k] = newparam
print(f"shape of {k}: {new_save_dict[k].shape}")
print_info(f"shape of {k}: {new_save_dict[k].shape}")
index_dict["weight_map"][k] = str(filename)
safe_save(new_save_dict, os.path.join(save_model_path, filename))
# process others
Expand Down Expand Up @@ -625,7 +715,7 @@ def merge_model(self, input_path, save_model_path, mp=16):
index_dict,
filename,
)
print(f"shape of {k}: {new_save_dict[k].shape}")
print_info(f"shape of {k}: {new_save_dict[k].shape}")
safe_save(new_save_dict, os.path.join(save_model_path, filename))

# update scales map
Expand Down
1 change: 1 addition & 0 deletions angelslim/compressor/quant/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .helper_layer import GPTQQuantLinear # noqa: F401
from .helper_layer import QDQModule # noqa: F401
from .helper_layer import QDQSingleModule # noqa: F401
from .helper_layer import QLinear # noqa: F401
from .helper_layer import SmoothHelpModule # noqa: F401
from .helper_layer import WQLinearGEMM # noqa: F401
from .int8.int8 import INT8 # noqa: F401
Expand Down
Loading