diff --git a/angelslim/compressor/quant/core/__init__.py b/angelslim/compressor/quant/core/__init__.py index 66cf1493..5dd61769 100644 --- a/angelslim/compressor/quant/core/__init__.py +++ b/angelslim/compressor/quant/core/__init__.py @@ -13,13 +13,15 @@ # limitations under the License. from .config import * # noqa: F401 F403 -from .hook import DiTHook, PTQHook # noqa: F401 +from .hook import PTQHook # noqa: F401 from .metrics import mse_loss, snr_loss # noqa: F401 from .packing_utils import dequantize_gemm, pack_weight_to_int8 # noqa: F401 from .quant_func import * # noqa: F401 F403 from .sample_func import EMASampler, MultiStepSampler # noqa: F401 from .save import DeepseekV3HfPTQSave # noqa: F401 from .save import DeepseekV3PTQSaveTRTLLM # noqa: F401 +from .save import PTQDiffusionSave # noqa: F401 +from .save import PTQOnlyScaleSave # noqa: F401 from .save import PTQPTMSave # noqa: F401 from .save import PTQSaveVllmHF # noqa: F401 from .save import PTQTorchSave # noqa: F401 diff --git a/angelslim/compressor/quant/core/config.py b/angelslim/compressor/quant/core/config.py index a562859b..5ea979e0 100644 --- a/angelslim/compressor/quant/core/config.py +++ b/angelslim/compressor/quant/core/config.py @@ -57,10 +57,13 @@ def __init__(self, config, global_config=None): quantization_args = config.quantization self.quant_algo = quantization_args.name self.quant_bit = quantization_args.bits - self.max_seq_length = global_config.max_seq_length self.quant_helpers = quantization_args.quant_helpers act_quant_method = quantization_args.quant_method.get("activation", None) weight_quant_method = quantization_args.quant_method["weight"] + if global_config: + self.max_seq_length = global_config.max_seq_length + self.hidden_size = global_config.hidden_size + self.model_arch_type = global_config.model_arch_type if "fp8" in self.quant_algo: is_dynamic = "dynamic" if "dynamic" in self.quant_algo else "static" @@ -94,8 +97,6 @@ def __init__(self, config, global_config=None): if act_quant_method is not None: self.quant_algo_info["a"] = f"fp8_{act_quant_method}-{is_dynamic}" - self.hidden_size = global_config.hidden_size - self.model_arch_type = global_config.model_arch_type self.low_memory = config.quantization.low_memory self.quant_analyse = config.quantization.quant_analyse self.quant_vit = config.quantization.quant_vit @@ -117,8 +118,6 @@ def __init__(self, config, global_config=None): } if act_quant_method is not None: self.quant_algo_info["a"] = f"int8_{act_quant_method}-{is_dynamic}" - self.hidden_size = global_config.hidden_size - self.model_arch_type = global_config.model_arch_type self.low_memory = config.quantization.low_memory self.quant_analyse = config.quantization.quant_analyse elif "int4_awq" in self.quant_algo: @@ -135,8 +134,6 @@ def __init__(self, config, global_config=None): "group_size": int(group_size), "mse_range": quantization_args.quant_method["mse_range"], } - self.hidden_size = global_config.hidden_size - self.model_arch_type = global_config.model_arch_type self.low_memory = config.quantization.low_memory elif "int4_gptq" in self.quant_algo or "int4_gptaq" in self.quant_algo: self.act_observer = None @@ -151,7 +148,6 @@ def __init__(self, config, global_config=None): "group_size": group_size, "ignore_layers": quantization_args.ignore_layers, } - self.hidden_size = global_config.hidden_size if "smooth" in self.quant_helpers: self.smooth_alpha = quantization_args.smooth_alpha diff --git a/angelslim/compressor/quant/core/hook.py b/angelslim/compressor/quant/core/hook.py index 1f77fb32..a9c6821a 100644 --- a/angelslim/compressor/quant/core/hook.py +++ b/angelslim/compressor/quant/core/hook.py @@ -12,14 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re - -import torch - from ..observers import ParentObserver, PTQObserver from .quant_func import get_fp_maxval, get_fp_search_maxval -__all__ = ["PTQHook", "DiTHook"] +__all__ = ["PTQHook"] class PTQHook: @@ -119,66 +115,3 @@ def post_process(self): if self.quant_model.quant_algo_dict["c_quant_algo"] == "fp8": for k, v in self.quant_model.kv_cache_scales_dict.items(): self.quant_model.kv_cache_scales_dict[k] = v / maxval.type(v.dtype) - - -def _filter_func(name): - pattern = re.compile( - r".*(mlp_t5|pooler|style_embedder|x_embedder|t_embedder|extra_embedder).*" - ) - return pattern.match(name) is not None - - -class DiTHook: - def __init__(self, model): - """ - Args: - model(nn.Moudle, required): the model to be quant - """ - self.model = model - self.input_activation = [] - self.output_activation = [] - - self._apply_hook() - - def _apply_hook(self): - self._forward_hook_list = [] - for name, sub_layer in self.model.named_modules(): - if _filter_func(name): - continue - if isinstance(sub_layer, (torch.nn.Conv2d, torch.nn.Linear)): - if "blocks" in name: - # handle - forward_pre_hook_handle = sub_layer.register_forward_hook( - self._forward_pre_hook - ) - self._forward_hook_list.append(forward_pre_hook_handle) - - def _forward_pre_hook(self, layer, input, output): - layer_name = "" - for name, module in self.model.named_modules(): - if _filter_func(name): - continue - if module == layer: - layer_name = name - break - x = ( - output[0].detach().cpu() - if isinstance(output, tuple) - else output.detach().cpu() - ) - self.output_activation.append((layer_name, x)) - y = ( - input[0].detach().cpu() - if isinstance(input, tuple) - else input.detach().cpu() - ) - self.input_activation.append((layer_name, y)) - - def remove_hook(self): - for hook in self._forward_hook_list: - hook.remove() - self._forward_hook_list = [] - - def clean_acitvation_list(self): - self.input_activation = [] - self.output_activation = [] diff --git a/angelslim/compressor/quant/core/quant_func.py b/angelslim/compressor/quant/core/quant_func.py index 8fae3de8..5e79cc5a 100644 --- a/angelslim/compressor/quant/core/quant_func.py +++ b/angelslim/compressor/quant/core/quant_func.py @@ -380,3 +380,52 @@ def weight_dequant( ) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y + + +# This function is copied from DeepSeek-V3 (MIT License): +# Copyright (c) 2023 DeepSeek-AI +# Original source: https://github.com/deepseek-ai/DeepSeek-V3 +@triton.jit +def weight_quant(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """Quantizes FP32 weights to FP8 format using block-wise quantization.""" + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + max_val = tl.max(tl.abs(x)) + scale = max_val / 448.0 + scale = tl.where(max_val == 0.0, 1.0, scale) + y = x / scale + y = y.to(y_ptr.dtype.element_ty) + + tl.store(y_ptr + offs, y, mask=mask) + tl.store(s_ptr + pid_m * n + pid_n, scale) + + +def per_block_weight_quant( + x: torch.Tensor, block_size: int = 128 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantizes FP32 weight tensor to FP8 format using block-wise quantization.""" + assert x.is_contiguous() + assert x.dim() == 2 + + M, N = x.size() + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + m_blocks = triton.cdiv(M, block_size) + n_blocks = triton.cdiv(N, block_size) + s = torch.empty((m_blocks, n_blocks), dtype=torch.float32, device=x.device) + + grid = lambda meta: ( # noqa: E731 + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + + weight_quant[grid](x, y, s, M, N, BLOCK_SIZE=block_size) + + return y, s diff --git a/angelslim/compressor/quant/core/save.py b/angelslim/compressor/quant/core/save.py index 220f6e02..1c846de8 100644 --- a/angelslim/compressor/quant/core/save.py +++ b/angelslim/compressor/quant/core/save.py @@ -28,8 +28,8 @@ from tqdm import tqdm from transformers.models.deepseek_v3 import DeepseekV3Config -from ....utils import print_info -from ..modules import QDQModule, QDQSingleModule +from ....utils import find_layers, find_parent_layer_and_sub_name, print_info +from ..modules import QDQModule, QDQSingleModule, QLinear from .packing_utils import pack_weight_to_int8 from .quant_func import fake_quant_dequant, tensor_quant, weight_dequant @@ -188,6 +188,96 @@ def save(self, save_path): self.quant_model.tokenizer.save_pretrained(save_path) +class PTQDiffusionSave(PTQSaveBase): + def __init__(self, quant_model): + super().__init__(quant_model=quant_model) + + def save(self, save_path): + a_quant_algo = self.quant_model.quant_config.quant_algo_info["a"] + ignored_layers = self.quant_model.skip_layer_names() + + static_q_dict = { + "quantization_config": { + "quant_method": "fp8", + "activation_scheme": ( + "dynamic" if "dynamic" in a_quant_algo else "static" + ), + "ignored_layers": ignored_layers, + } + } + + os.makedirs(save_path, exist_ok=True) + with open(os.path.join(save_path, "hf_quant_config.json"), "w") as f: + json.dump(static_q_dict, f, indent=4) + + save_scales = {} + layers_dict = find_layers( + self.quant_model.get_model().transformer, layers=[QDQModule] + ) + for name, sub_layer in layers_dict.items(): + parent_layer, sub_name = find_parent_layer_and_sub_name( + self.quant_model.get_model().transformer, name + ) + q_module = QLinear( + quant_algo=sub_layer.quant_algo, + weight=sub_layer.weight, + bias=sub_layer.bias, + weight_scale=sub_layer.weight_scale.data.clone().detach(), + input_scale=sub_layer.input_scale.data.clone().detach(), + ) + setattr(parent_layer, sub_name, q_module) + save_scales[name + ".input_scale"] = sub_layer.input_scale + save_scales[name + ".weight_scale"] = sub_layer.weight_scale + + self.quant_model.get_model().save_pretrained(save_path) + safetensor_file = os.path.join(save_path, "model-scales.safetensors") + safe_save(save_scales, safetensor_file) + + +class PTQOnlyScaleSave(PTQSaveBase): + def __init__(self, quant_model): + super().__init__(quant_model=quant_model) + + def save(self, save_path): + a_quant_algo = self.quant_model.quant_config.quant_algo_info["a"] + ignored_layers = self.quant_model.skip_layer_names() + + static_q_dict = { + "quantization_config": { + "quant_method": "fp8", + "activation_scheme": ( + "dynamic" if "dynamic" in a_quant_algo else "static" + ), + "ignored_layers": ignored_layers, + } + } + + os.makedirs(save_path, exist_ok=True) + with open(os.path.join(save_path, "hf_quant_config.json"), "w") as f: + json.dump(static_q_dict, f, indent=4) + + save_scales = {} + new_model_index = { + "metadata": {}, + "weight_map": {}, + } + safetensor_name = "model-scales.safetensors" + for name, value in self.quant_model.act_scales_dict.items(): + save_scales[name + ".input_scale"] = value + new_model_index["weight_map"][name + ".input_scale"] = safetensor_name + for name, value in self.quant_model.weight_scales_dict.items(): + save_scales[name + ".weight_scale"] = value + new_model_index["weight_map"][name + ".weight_scale"] = safetensor_name + + safetensor_file = os.path.join(save_path, safetensor_name) + safe_save(save_scales, safetensor_file) + + # update model index json + new_model_index_file = os.path.join(save_path, "model.safetensors.index.json") + with open(new_model_index_file, "w") as f: + json.dump(new_model_index, f, indent=2) + + class PTQTorchSave(PTQSaveBase): def __init__(self, quant_model): super(PTQTorchSave, self).__init__(quant_model=quant_model) @@ -594,7 +684,7 @@ def merge_model(self, input_path, save_model_path, mp=16): param_list.append(param) newparam = torch.cat(param_list, dim=0) new_save_dict[k] = newparam - print(f"shape of {k}: {new_save_dict[k].shape}") + print_info(f"shape of {k}: {new_save_dict[k].shape}") index_dict["weight_map"][k] = str(filename) safe_save(new_save_dict, os.path.join(save_model_path, filename)) # process others @@ -625,7 +715,7 @@ def merge_model(self, input_path, save_model_path, mp=16): index_dict, filename, ) - print(f"shape of {k}: {new_save_dict[k].shape}") + print_info(f"shape of {k}: {new_save_dict[k].shape}") safe_save(new_save_dict, os.path.join(save_model_path, filename)) # update scales map diff --git a/angelslim/compressor/quant/modules/__init__.py b/angelslim/compressor/quant/modules/__init__.py index 2468a4e4..07962bdf 100644 --- a/angelslim/compressor/quant/modules/__init__.py +++ b/angelslim/compressor/quant/modules/__init__.py @@ -21,6 +21,7 @@ from .helper_layer import GPTQQuantLinear # noqa: F401 from .helper_layer import QDQModule # noqa: F401 from .helper_layer import QDQSingleModule # noqa: F401 +from .helper_layer import QLinear # noqa: F401 from .helper_layer import SmoothHelpModule # noqa: F401 from .helper_layer import WQLinearGEMM # noqa: F401 from .int8.int8 import INT8 # noqa: F401 diff --git a/angelslim/compressor/quant/modules/awq/awq.py b/angelslim/compressor/quant/modules/awq/awq.py index e0c6012f..b7edec92 100644 --- a/angelslim/compressor/quant/modules/awq/awq.py +++ b/angelslim/compressor/quant/modules/awq/awq.py @@ -22,7 +22,7 @@ from huggingface_hub import save_torch_state_dict from tqdm import tqdm -from .....utils import get_best_device, print_info, set_op_by_name +from .....utils import find_layers, get_best_device, print_info, set_op_by_name from ...core import pseudo_quantize_tensor from ...modules.catcher import Catcher from ...modules.helper_layer import WQLinearGEMM @@ -60,10 +60,7 @@ def __init__( super(AWQ, self).__init__() self.model = model self.modal_type = self.model.modal_type - if self.modal_type == "VLM": - self.layers = self.model.model.model.language_model.layers - else: - self.layers = self.model.model.model.layers + self.layers = self.model.get_quant_module() self.quant_bits = self.model.quant_config.quant_bit self.group_size = self.model.quant_config.quant_algo_info["group_size"] self.zero_point = self.model.quant_config.quant_algo_info["zero_point"] @@ -161,7 +158,7 @@ def run(self, dataloader): if not self.low_memory: outs = outs.to(dev) self.inps = self.inps.to(dev) - subset = self._find_layers(layer) + subset = find_layers(layer) if self.model_arch_type in ["qwen3_moe", "hunyuan_v1_moe"]: subset = { @@ -302,22 +299,6 @@ def forward(self, x): if self.modal_type in ["LLM", "VLM"]: self.model.tokenizer.save_pretrained(save_dir) - def _find_layers(self, module, layers=None, name=""): - if not layers: - layers = self.isinstance_list - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update( - self._find_layers( - child, - layers=layers, - name=name + "." + name1 if name != "" else name1, - ) - ) - return res - def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]): for name, linear_layer in named_linears.items(): if "mlp.gate." in name: @@ -353,7 +334,7 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]): def _convert_llm(self): for i in tqdm(range(len(self.layers)), desc="AWQ"): - subset = self._find_layers(self.layers[i]) + subset = find_layers(self.layers[i]) self._apply_quant(self.layers[i], subset) def convert(self): @@ -361,11 +342,5 @@ def convert(self): Saves scales and inserts QDQ modules. """ print_info("Start convert model...") - if self.modal_type in ["LLM", "VLM"]: - self._convert_llm() - elif self.modal_type == "AIGC": - pass - else: - print_info("current {} modal type not support".format(self.modal_type)) - raise NotImplementedError + self._convert_llm() print_info("convert model done.") diff --git a/angelslim/compressor/quant/modules/fp8/fp8.py b/angelslim/compressor/quant/modules/fp8/fp8.py index f76f35a1..2aca7928 100644 --- a/angelslim/compressor/quant/modules/fp8/fp8.py +++ b/angelslim/compressor/quant/modules/fp8/fp8.py @@ -43,18 +43,13 @@ def __init__( """ super(FP8, self).__init__() self.model = model - self.modal_type = self.model.modal_type - if self.modal_type == "VLM": - self.layers = self.model.model.model.language_model.layers - else: - self.layers = self.model.model.model.layers + self.layers = self.model.get_quant_module() self.quant_bits = self.model.quant_config.quant_bit self.seq_length = seq_length self.hidden_size = hidden_size self.model_arch_type = model_arch_type self.low_memory = low_memory self.dtype = torch.bfloat16 - torch.set_default_dtype(self.dtype) self.scales_dict = {} self.inps = None diff --git a/angelslim/compressor/quant/modules/gptq/gptq.py b/angelslim/compressor/quant/modules/gptq/gptq.py index 6fb05f4e..71ee33b3 100644 --- a/angelslim/compressor/quant/modules/gptq/gptq.py +++ b/angelslim/compressor/quant/modules/gptq/gptq.py @@ -20,7 +20,7 @@ from huggingface_hub import save_torch_state_dict from tqdm import tqdm -from .....utils import print_info +from .....utils import find_layers, print_info from ...modules.catcher import Catcher from ...modules.helper_layer import GPTQQuantLinear from .gptaq_module import GPTAQModule @@ -36,10 +36,7 @@ def __init__( super(GPTQ, self).__init__() self.model = model self.modal_type = self.model.modal_type - if self.modal_type == "VLM": - self.layers = self.model.model.model.language_model.layers - else: - self.layers = self.model.model.model.layers + self.layers = self.model.get_quant_module() self.layers_block_name = self.model.block_name self.quant_bits = self.model.quant_config.quant_bit self.group_size = self.model.quant_config.quant_algo_info["group_size"] @@ -98,7 +95,7 @@ def run(self, dataloader): for i in range(len(layers)): layer = layers[i].to(inps.device) - subset = self._find_layers(layer) + subset = find_layers(layer) print_info("subset:{}".format(subset)) self.gptq = {} if "gptaq" in self.quant_algo: @@ -246,12 +243,12 @@ def _pack_model( model.cpu() print_info("Packing model...") - layers = self._find_layers(model) + layers = find_layers(model) layers = {n: layers[n] for n in quantizers} self._make_quant(model, quantizers, bits, group_size) - qlayers = self._find_layers(model, [GPTQQuantLinear]) + qlayers = find_layers(model, [GPTQQuantLinear]) with tctl.threadpool_limits(limits=1): pbar = tqdm(qlayers.keys(), leave=True) @@ -286,13 +283,7 @@ def convert(self): Saves scales and inserts QDQ modules. """ print_info("Start convert model...") - if self.modal_type in ["LLM", "VLM"]: - self._convert_llm() - elif self.modal_type == "AIGC": - pass - else: - print_info("current {} modal type not support".format(self.modal_type)) - raise NotImplementedError + self._convert_llm() print_info("convert model done.") def save(self, save_dir: str, shard_size="5GB", safetensors=True): @@ -357,19 +348,3 @@ def _recurse_setattr(self, module, name, value): else: name, rest = name.split(".", 1) self._recurse_setattr(getattr(module, name), rest, value) - - def _find_layers(self, module, layers=None, name=""): - if not layers: - layers = [torch.nn.Linear] - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update( - self._find_layers( - child, - layers=layers, - name=name + "." + name1 if name != "" else name1, - ) - ) - return res diff --git a/angelslim/compressor/quant/modules/helper_layer.py b/angelslim/compressor/quant/modules/helper_layer.py index d99488c9..e05f6dfe 100644 --- a/angelslim/compressor/quant/modules/helper_layer.py +++ b/angelslim/compressor/quant/modules/helper_layer.py @@ -652,3 +652,57 @@ def forward(self, x): qoutput = quantize_activation_per_tensor_fp8(output, self.output_scale) output = qoutput.to(output.dtype) * self.output_scale return output + + +class QLinear(torch.nn.Module): + def __init__( + self, + quant_algo: QuantConfig, + weight: torch.nn.Parameter, + weight_scale: torch.nn.Parameter, + bias: torch.nn.Parameter, + input_scale: Optional[torch.nn.Parameter] = None, + ): + super().__init__() + self.quant_algo = quant_algo + self.weight = weight + + self.weight_scale = ( + weight_scale.view(-1) if weight_scale.ndim == 0 else weight_scale + ) + self.bias = bias + if input_scale is not None: + self.input_scale = ( + input_scale.view(-1) if input_scale.ndim == 0 else input_scale + ) + else: + self.input_scale = None + + def forward(self, x): + if self.input_scale: + if "fp8" in self.quant_algo: + qinput = quantize_activation_per_tensor_fp8(x, self.input_scale) + elif "int8" in self.quant_algo: + qinput = tensor_quant_dequant_int(x, self.input_scale, bits=8) + else: + raise ValueError( + f"Unsupported quantization algorithm: {self.quant_algo}" + ) + + if "fp8" in self.quant_algo: + output = gemm_fp8( + act=qinput, + act_scale=self.input_scale, + weight=self.weight, + weight_scale=self.weight_scale, + bias=self.bias, + out_dtype=x.dtype, + ) + elif "int8" in self.quant_algo: + output = torch.nn.functional.linear( + x, self.weight * self.weight_scale, bias=self.bias + ) + else: + raise ValueError(f"Unsupported quantization algorithm: {self.quant_algo}") + + return output diff --git a/angelslim/compressor/quant/modules/int8/int8.py b/angelslim/compressor/quant/modules/int8/int8.py index 87746373..798a927e 100644 --- a/angelslim/compressor/quant/modules/int8/int8.py +++ b/angelslim/compressor/quant/modules/int8/int8.py @@ -42,7 +42,7 @@ def __init__( super(INT8, self).__init__() self.model = model self.modal_type = self.model.modal_type - self.layers = self.model.model.model.layers + self.layers = self.model.get_quant_module() self.quant_bits = self.model.quant_config.quant_bit self.seq_length = seq_length self.hidden_size = hidden_size diff --git a/angelslim/compressor/quant/ptq.py b/angelslim/compressor/quant/ptq.py index b3d0ee2b..138ac5f8 100644 --- a/angelslim/compressor/quant/ptq.py +++ b/angelslim/compressor/quant/ptq.py @@ -36,11 +36,10 @@ def __init__(self, model, slim_config=None): self.quant_model = model # init ptq config of model self.quant_model.init_ptq(slim_config) - self.modal_type = self.quant_model.modal_type - self.layers = self.quant_model.get_model() + self.layers = self.quant_model.get_quant_module() self.quant_algo = self.quant_model.quant_config.quant_algo self.quant_helpers = self.quant_model.quant_config.quant_helpers - if self.modal_type in ["LLM", "VLM"]: + if "fp8" in self.quant_algo or "int8" in self.quant_algo: # Add ptq observer hook self.ptq_hook = PTQHook(self.quant_model) self.ptq_hook.apply_hook() @@ -51,8 +50,7 @@ def __init__(self, model, slim_config=None): self.gptq = GPTQ( self.quant_model, seq_length=max_seq_length, hidden_size=hidden_size ) - - if "awq" in self.quant_algo: + elif "awq" in self.quant_algo: max_seq_length = self.quant_model.quant_config.max_seq_length hidden_size = self.quant_model.quant_config.hidden_size model_arch_type = self.quant_model.quant_config.model_arch_type @@ -65,7 +63,7 @@ def __init__(self, model, slim_config=None): observer_layer_classes=[nn.Linear], low_memory=self.quant_model.quant_config.low_memory, ) - if "fp8" in self.quant_algo: + elif "fp8" in self.quant_algo: max_seq_length = self.quant_model.quant_config.max_seq_length hidden_size = self.quant_model.quant_config.hidden_size model_arch_type = self.quant_model.quant_config.model_arch_type @@ -86,7 +84,7 @@ def __init__(self, model, slim_config=None): model_arch_type=model_arch_type, low_memory=self.quant_model.quant_config.low_memory, ) - if "int8" in self.quant_algo: + elif "int8" in self.quant_algo: max_seq_length = self.quant_model.quant_config.max_seq_length hidden_size = self.quant_model.quant_config.hidden_size model_arch_type = self.quant_model.quant_config.model_arch_type @@ -97,6 +95,11 @@ def __init__(self, model, slim_config=None): model_arch_type=model_arch_type, low_memory=self.quant_model.quant_config.low_memory, ) + else: + raise NotImplementedError( + f"[AngelSlim Error] algo {self.quant_algo} is not support" + ) + if "smooth" in self.quant_helpers: self.smooth = SmoothQuant( self.quant_model, @@ -130,13 +133,9 @@ def convert(self): elif "lepto" in self.quant_algo: self.fp8.convert() else: - if self.modal_type in ["LLM", "VLM"]: - if "smooth" in self.quant_helpers: - self.smooth.convert() - self._convert_llm() - else: - print_info("current {} modal type not support".format(self.modal_type)) - raise NotImplementedError + if "smooth" in self.quant_helpers: + self.smooth.convert() + self._convert() print_info("convert model done.") def save(self, save_path: str): @@ -174,7 +173,7 @@ def save(self, save_path: str): save_func = self.quant_model.get_save_func()(self.quant_model) save_func.save(save_path) - def _convert_llm(self): + def _convert(self): # 1. get act, weight and kv-cache scale for name, sub_layer in self.ptq_hook.quant_layers_dict.items(): if ( diff --git a/angelslim/data/__init__.py b/angelslim/data/__init__.py index 2eb27d0f..97d05f98 100644 --- a/angelslim/data/__init__.py +++ b/angelslim/data/__init__.py @@ -6,4 +6,5 @@ from .dataloader import DataLoaderFactory # noqa: F401 from .multimodal_dataset import MultiModalDataset # noqa: F401 +from .text2image_dataset import Text2ImageDataset # noqa: F401 from .text_dataset import TextDataset # noqa: F401 diff --git a/angelslim/data/base_dataset.py b/angelslim/data/base_dataset.py index 941e0d8c..3604e0b5 100644 --- a/angelslim/data/base_dataset.py +++ b/angelslim/data/base_dataset.py @@ -38,17 +38,24 @@ def __getitem__(self, idx: int) -> Dict: @staticmethod def collate_fn(batch: List[Dict]) -> Dict: - """Custom collate function to batch dictionary items""" + """Custom collate function for batching""" collated = {} + + # Get all keys from the first sample for key in batch[0].keys(): - # Skip non-tensor items - if not isinstance(batch[0][key], torch.Tensor): - continue - - # Stack tensors of the same type - tensors = [item[key] for item in batch] - if all(t.dim() == 0 for t in tensors): # Handle scalar tensors - collated[key] = torch.stack(tensors) + # Collect the values for this key from all samples + values = [item[key] for item in batch] + + # If the value is numeric, convert directly to tensor + if isinstance(values[0], torch.Tensor): + if all(t.dim() == 0 for t in values): # Handle scalar tensors + collated[key] = torch.stack(values) + else: + collated[key] = torch.cat(values, dim=0) + elif isinstance(values[0], (int, float)): + collated[key] = torch.tensor(values) else: - collated[key] = torch.cat(tensors, dim=0) + # For text and other non-numeric types, keep as a list + collated[key] = values + return collated diff --git a/angelslim/data/dataloader.py b/angelslim/data/dataloader.py index 1073f54a..89052182 100644 --- a/angelslim/data/dataloader.py +++ b/angelslim/data/dataloader.py @@ -20,6 +20,7 @@ from .base_dataset import BaseDataset from .multimodal_dataset import MultiModalDataset +from .text2image_dataset import Text2ImageDataset from .text_dataset import TextDataset @@ -37,6 +38,7 @@ def create_data_loader( data_source: Union[str, Dict] = None, data_type: str = "auto", num_workers: int = 0, + inference_settings: Dict = None, ) -> DataLoader: """ Create appropriate DataLoader based on data source @@ -51,6 +53,12 @@ def create_data_loader( data_source: File path or HF dataset dict data_type: "text", "multimodal" or "auto" num_workers: Number of workers for DataLoader + inference_settings: Settings for text-to-image inference + - height: Image height + - width: Image width + - guidance_scale: Guidance scale for inference + - num_inference_steps: Number of inference steps + - max_sequence_length: Maximum sequence length for text inputs Returns: PyTorch DataLoader ready for use @@ -82,6 +90,12 @@ def create_data_loader( data_source=data_source, is_hf_dataset=not os.path.isfile(data_source), ) + elif data_type == "Text2ImageDataset": + dataset = Text2ImageDataset( + data_path=data_source, + num_samples=num_samples, + inference_settings=inference_settings, + ) else: raise ValueError(f"Unsupported data type: {data_type}") diff --git a/angelslim/data/text2image_dataset.py b/angelslim/data/text2image_dataset.py new file mode 100644 index 00000000..a41146e0 --- /dev/null +++ b/angelslim/data/text2image_dataset.py @@ -0,0 +1,68 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from .base_dataset import BaseDataset + + +class Text2ImageDataset(BaseDataset): + """Dataset for text-only data in Parquet or JSONL formats""" + + def __init__( + self, + data_path: str, + num_samples: int = -1, + inference_settings: dict = None, + ): + self.data = [] + self.seed = inference_settings["seed"] + self.height = inference_settings["height"] + self.width = inference_settings["width"] + self.guidance_scale = inference_settings["guidance_scale"] + self.num_inference_steps = inference_settings["num_inference_steps"] + self.max_sequence_length = inference_settings["max_sequence_length"] + self._load_data(data_path, num_samples) + + def _load_data(self, data_path: str, num_samples: int): + if data_path.endswith(".jsonl"): + self._load_jsonl_data(data_path, num_samples) + else: + raise ValueError("Unsupported file format. Only JSONL is supported.") + + def _load_jsonl_data(self, data_path: str, num_samples: int): + line_count = 0 + with open(data_path, "r") as f: + for line in f: + if num_samples > 0 and line_count >= num_samples: + break + + data = json.loads(line) + + # Validate format + assert "input" in data, "JSON format error" + + self.data.append( + { + "input": data["input"], + "seed": self.seed, + "height": self.height, + "width": self.width, + "guidance_scale": self.guidance_scale, + "num_inference_steps": self.num_inference_steps, + "max_sequence_length": self.max_sequence_length, + } + ) + + line_count += 1 diff --git a/angelslim/engine.py b/angelslim/engine.py index adb06825..75d712df 100644 --- a/angelslim/engine.py +++ b/angelslim/engine.py @@ -23,7 +23,12 @@ from .compressor import CompressorFactory from .data.dataloader import DataLoaderFactory from .models import SlimModelFactory -from .utils import default_compress_config, get_package_info, print_info +from .utils import ( + default_compress_config, + get_package_info, + parse_json_full_config, + print_info, +) DEFAULT_COMPRESSION_CONFIG = { "fp8_static": default_compress_config.default_fp8_static_config(), @@ -32,7 +37,6 @@ "int4_awq": default_compress_config.default_int4_awq_config(), "int4_gptq": default_compress_config.default_int4_gptq_config(), "w4a8_fp8": default_compress_config.default_w4a8_fp8_static_config(), - "int4_gptaq": default_compress_config.default_int4_gptaq_config(), } @@ -131,6 +135,7 @@ def prepare_data( batch_size=1, num_samples=128, shuffle=True, + inference_settings=None, ) -> Optional[Any]: """Prepare compression dataset""" if custom_dataloader is not None: @@ -153,6 +158,7 @@ def prepare_data( shuffle=shuffle, num_samples=num_samples, data_source=data_path, + inference_settings=inference_settings, ) self.max_seq_length = max_length @@ -247,13 +253,85 @@ def save( print_info(f"Compressed model saved to {save_path}") - def infer(self, input_prompt: str, **kwargs) -> Any: + +class InferEngine(Engine): + def __init__(self): + """ + Initialize engine configuration + """ + super().__init__() + self.slim_model = None + self.tokenizer = None + self.dataloader = None + self.compressor = None + self.compress_type = None + self.model_path = None + self.max_seq_length = None + + def from_pretrained( + self, + model_path, + torch_dtype=None, + device_map=None, + trust_remote_code=None, + low_cpu_mem_usage=None, + use_cache=None, + ) -> Any: + """Load pretrained model and tokenizer + Args: + model_path (str): Path to the pretrained model. + torch_dtype (str): Data type for the model weights. + device_map (str): Device map for the model. + trust_remote_code (bool): Whether to trust remote code. + low_cpu_mem_usage (bool): Whether to use low CPU memory usage mode. + use_cache (bool): Whether to use cache during loading. + cache_dir (str, optional): Directory to cache the model. + """ + assert model_path, "model_path must be specified." + # load slim config + slim_config_path = os.path.join(model_path, "angelslim_config.json") + if not os.path.exists(slim_config_path): + raise FileNotFoundError( + f"angelslim_config.json not found in {model_path}. " + "Please ensure the model is compressed with Angelslim." + ) + slim_config = parse_json_full_config(slim_config_path) + if torch_dtype: + slim_config.model_config.torch_dtype = torch_dtype + if device_map: + slim_config.model_config.device_map = device_map + if trust_remote_code is not None: + slim_config.model_config.trust_remote_code = trust_remote_code + if low_cpu_mem_usage is not None: + slim_config.model_config.low_cpu_mem_usage = low_cpu_mem_usage + if use_cache is not None: + slim_config.model_config.use_cache = use_cache + + self.slim_model = SlimModelFactory.create( + slim_config.model_config.name, deploy_backend="huggingface" + ) + + self.slim_model.from_pretrained( + model_path=model_path, + torch_dtype=slim_config.model_config.torch_dtype, + device_map=slim_config.model_config.device_map, + trust_remote_code=slim_config.model_config.trust_remote_code, + low_cpu_mem_usage=slim_config.model_config.low_cpu_mem_usage, + use_cache=slim_config.model_config.use_cache, + compress_config=slim_config.compression_config, + ) + + self.series = SlimModelFactory.get_series_by_models( + slim_config.model_config.name + ) + + def generate(self, input_prompt: str, **kwargs) -> Any: """Run inference with the compressed model Args: input_prompt (str): Input prompt for the model. """ if not self.slim_model or not self.slim_model.model: - raise RuntimeError("Model not initialized. Call prepare_model() first") + raise RuntimeError("Model not initialized. Call from_pretrained() first") if self.series in ["LLM", "VLM"]: return self.slim_model.generate( diff --git a/angelslim/models/base_model.py b/angelslim/models/base_model.py index 36d6524f..93be5f5e 100644 --- a/angelslim/models/base_model.py +++ b/angelslim/models/base_model.py @@ -44,11 +44,6 @@ def __init__( model: Optional[torch.nn.Module] = None, deploy_backend: Optional[str] = "vllm", ): - assert deploy_backend in [ - "vllm", - "huggingface", - "trtllm", - ], f"Unsupported deploy backend {deploy_backend}" self.deploy_backend = deploy_backend self.model = model self.tokenizer = None @@ -114,13 +109,20 @@ def skip_layer_names(self): def get_model(self): return self.model + def get_quant_module(self): + """ + Returns the module that will be quantized. + This is typically the main transformer module of the model. + """ + return self.model.model.layers + def get_qdq_module(self, sub_layer, name): act_scale, weight_scale = None, None if name in self.act_scales_dict: act_scale = self.act_scales_dict[name] if name in self.weight_scales_dict: weight_scale = self.weight_scales_dict[name] - if self.deploy_backend in ["vllm", "huggingface"]: + if self.deploy_backend in ["vllm", "huggingface", "trtllm", "tensorrt"]: q_linear = QDQModule( quant_algo=self.quant_config.quant_algo, weight=sub_layer.weight, @@ -130,9 +132,7 @@ def get_qdq_module(self, sub_layer, name): ) else: print_info( - "[Slim] current {} deploy_backend not support".format( - self.deploy_backend - ) + "current {} deploy_backend not support".format(self.deploy_backend) ) raise NotImplementedError return q_linear @@ -211,30 +211,6 @@ def get_quant_config(self): def is_all_reduce(self): return False - def build_hf_model(self, model_path): - model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype="auto", - device_map="auto", - use_flash_attention_2=True, - trust_remote_code=True, - ) - return model - - def find_layers(self, module, layers=None, name=""): - if type(module) in layers and name not in self.skip_layer_names(): - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update( - self.find_layers( - child, - layers=layers, - name=name + "." + name1 if name != "" else name1, - ) - ) - return res - def get_pre_transformer_modules(self): pre_transformer_modules_dict = {} for full_name in self.pre_transformer_module_names: @@ -344,7 +320,7 @@ def __getattr__(self, item): return super().__getattr__(item) -class BaseDiffusionModel(metaclass=ABCMeta): +class BaseDiffusionModel(BaseLLMModel): """ Base class for diffusion model compression, providing common functionalities such as initialization, quantization configuration, and model handling. @@ -357,10 +333,14 @@ class BaseDiffusionModel(metaclass=ABCMeta): def __init__( self, model: Optional[torch.nn.Module] = None, - deploy_backend: Optional[str] = "torch", + deploy_backend: Optional[str] = "huggingface", ): + super().__init__( + model=model, + deploy_backend=deploy_backend, + ) assert deploy_backend in [ - "torch", + "huggingface", "tensorrt", ], f"Unsupported deploy backend {deploy_backend}" self.deploy_backend = deploy_backend @@ -383,3 +363,9 @@ def get_observer_layers(self): @abstractmethod def get_save_func(self): pass + + def skip_layer_names(self): + return self.quant_config.quant_algo_info.get("ignore_layers", []) + + def get_model(self): + return self.model diff --git a/angelslim/models/diffusion/flux.py b/angelslim/models/diffusion/flux.py index efc09057..11ead538 100644 --- a/angelslim/models/diffusion/flux.py +++ b/angelslim/models/diffusion/flux.py @@ -12,9 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import torch +import torch.nn as nn from diffusers import FluxPipeline +from safetensors.torch import load_file +from tqdm import tqdm +from ...compressor.quant.core import PTQDiffusionSave, PTQOnlyScaleSave, QuantConfig +from ...compressor.quant.modules import QLinear +from ...utils.utils import find_layers, find_parent_layer_and_sub_name from ..base_model import BaseDiffusionModel from ..model_factory import SlimModelFactory @@ -24,13 +32,14 @@ class FLUX(BaseDiffusionModel): def __init__( self, model=None, - deploy_backend="torch", + deploy_backend="huggingface", ): super().__init__( model=model, deploy_backend=deploy_backend, ) self.model_type = "flux" + self.block_name = "transformer_blocks" self.cache_helper = None def from_pretrained( @@ -39,6 +48,8 @@ def from_pretrained( torch_dtype="auto", cache_dir=None, use_cache_helper=False, + compress_config=None, + **kwargs, ): """ Load a pretrained FLUX model. @@ -46,12 +57,26 @@ def from_pretrained( model_path (str): Path to the pretrained model. torch_dtype (str): Data type for the model weights. cache_dir (str): Directory to cache the model. + use_cache_helper (bool): Whether to use cache helper for optimization. + compress_config (dict): Compression configuration. """ - self.model = FluxPipeline.from_pretrained( - model_path, - torch_dtype=torch_dtype, - cache_dir=cache_dir, - ) + # load the model from the specified path + if compress_config and compress_config.name == "PTQ": + self.model = FluxQuantPipeline.from_pretrained( + model_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + ) + scales_dicts = load_file( + os.path.join(model_path, "model-scales.safetensors") + ) + self.model.quantize(compress_config, scales_dicts) + else: + self.model = FluxPipeline.from_pretrained( + model_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + ) if use_cache_helper: self.model.cache_helper = self.cache_helper @@ -90,7 +115,124 @@ def generate( ).images[0] def get_observer_layers(self): - pass + names = [ + "attn.to_q", + "attn.to_k", + "attn.to_v", + "norm.linear", + "proj_mlp", + "proj_out", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "attn.to_out", + "to_out.0", + "0.proj", + "net.0", + "net.2", + "norm1.linear", + "norm1_context.linear", + ] + self.quant_module = self.model.transformer + obs_layers = [nn.Linear] + observer_layers_dict = {} + layers_dict = find_layers(self.quant_module, layers=obs_layers) + + ignore_layers = self.skip_layer_names() + for name, module in layers_dict.items(): + if self.block_name in name and ( + name.split(".")[-1] in names + or name.split(".")[-2] + "." + name.split(".")[-1] in names + ): + observer_layers_dict[name] = module + else: + ignore_layers.append(name) + self.quant_config.quant_algo_info["ignore_layers"] = ignore_layers + + return observer_layers_dict + + def get_quant_module(self): + """ + Returns the module that will be quantized. + This is typically the main transformer module of the model. + """ + return self.model.transformer def get_save_func(self): - pass + if self.deploy_backend in ["huggingface"]: + return PTQDiffusionSave + elif self.deploy_backend in ["tensorrt"]: + return PTQOnlyScaleSave + else: + raise NotImplementedError( + f"deploy_backend {self.deploy_backend} is not supported for saving." + ) + + def model_forward(self, dataloader, **kwargs): + assert dataloader is not None, "Dataloader must be provided for model forward." + with torch.no_grad(): + for batch in tqdm(dataloader, desc="calibrating...", total=len(dataloader)): + generator = torch.Generator().manual_seed(batch["seed"].item()) + self.model( + prompt=batch["input"], + height=batch["height"].item(), + width=batch["width"].item(), + guidance_scale=batch["guidance_scale"].item(), + num_inference_steps=batch["num_inference_steps"].item(), + max_sequence_length=batch["max_sequence_length"].item(), + generator=generator, + ).images[0] + + +class FluxQuantPipeline(FluxPipeline): + def __init__( + self, + scheduler, + vae, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + transformer, + image_encoder=None, + feature_extractor=None, + ): + super().__init__( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + def quantize(self, compress_config, scales_dicts): + """ + Quantize the transformer layers of the FLUX model. + This method replaces the linear in the transformer with quant. + """ + quant_config = QuantConfig(compress_config) + layers_dict = find_layers(self.transformer, layers=[nn.Linear]) + + for name, sub_layer in layers_dict.items(): + if name in quant_config.quant_algo_info["ignore_layers"]: + continue + parent_layer, sub_name = find_parent_layer_and_sub_name( + self.transformer, name + ) + act_scale = scales_dicts[name + ".input_scale"] + weight_scale = scales_dicts[name + ".weight_scale"] + + qdq_module = QLinear( + quant_algo=quant_config.quant_algo, + weight=sub_layer.weight, + weight_scale=weight_scale, + bias=sub_layer.bias, + input_scale=act_scale, + ) + + setattr(parent_layer, sub_name, qdq_module) diff --git a/angelslim/models/llm/deepseek.py b/angelslim/models/llm/deepseek.py index e01ce125..9aef0ca2 100644 --- a/angelslim/models/llm/deepseek.py +++ b/angelslim/models/llm/deepseek.py @@ -27,7 +27,7 @@ weight_dequant, ) from ...compressor.quant.modules import QDQModule -from ...utils import print_info +from ...utils import find_layers, print_info from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory from .modeling_deepseek import ( @@ -52,6 +52,7 @@ def __init__( self.block_name = "model.layers" self.column_parallel_linear_class = ColumnParallelLinear self.row_parallel_linear_class = RowParallelLinear + torch.set_default_dtype(torch.bfloat16) def from_pretrained( self, @@ -97,7 +98,7 @@ def from_pretrained( def get_observer_layers(self): names = self.quant_config.quant_algo_info["ignore_layers"] obs_layers = [nn.Linear, Linear] - observer_layers_dict = self.find_layers(self.model, layers=obs_layers) + observer_layers_dict = find_layers(self.model, layers=obs_layers) observer_layers_dict = { k: v for k, v in observer_layers_dict.items() diff --git a/angelslim/models/llm/hunyuan_dense.py b/angelslim/models/llm/hunyuan_dense.py index 963c4447..f8c808fb 100644 --- a/angelslim/models/llm/hunyuan_dense.py +++ b/angelslim/models/llm/hunyuan_dense.py @@ -15,6 +15,7 @@ import torch.nn as nn from ...compressor.quant.core import PTQSaveVllmHF +from ...utils.utils import find_layers from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory @@ -45,7 +46,7 @@ def get_observer_layers(self): "mlp.gate_and_up_proj", ] obs_layers = [nn.Linear] - observer_layers_dict = self.find_layers(self.model, layers=obs_layers) + observer_layers_dict = find_layers(self.model, layers=obs_layers) observer_layers_dict = { k: v diff --git a/angelslim/models/llm/hunyuan_moe.py b/angelslim/models/llm/hunyuan_moe.py index 391cbf80..82ae5ff7 100644 --- a/angelslim/models/llm/hunyuan_moe.py +++ b/angelslim/models/llm/hunyuan_moe.py @@ -17,6 +17,7 @@ import torch.nn as nn from ...compressor.quant.core import PTQSaveVllmHF +from ...utils.utils import find_layers from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory @@ -51,7 +52,7 @@ def get_observer_layers(self): ] obs_layers = [nn.Linear] - observer_layers_dict = self.find_layers(self.model, layers=obs_layers) + observer_layers_dict = find_layers(self.model, layers=obs_layers) compiled_patterns = [re.compile(pattern) for pattern in expert_pattern] diff --git a/angelslim/models/llm/llama.py b/angelslim/models/llm/llama.py index a5b4e13e..b3f52241 100644 --- a/angelslim/models/llm/llama.py +++ b/angelslim/models/llm/llama.py @@ -15,6 +15,7 @@ import torch.nn as nn from ...compressor.quant.core import PTQSaveVllmHF +from ...utils.utils import find_layers from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory @@ -43,7 +44,7 @@ def get_observer_layers(self): "mlp.down_proj", ] obs_layers = [nn.Linear] - observer_layers_dict = self.find_layers(self.model, layers=obs_layers) + observer_layers_dict = find_layers(self.model, layers=obs_layers) observer_layers_dict = { k: v for k, v in observer_layers_dict.items() diff --git a/angelslim/models/llm/qwen.py b/angelslim/models/llm/qwen.py index 6033bd5d..ff7eeb29 100644 --- a/angelslim/models/llm/qwen.py +++ b/angelslim/models/llm/qwen.py @@ -17,6 +17,7 @@ import torch.nn as nn from ...compressor.quant.core import PTQSaveVllmHF +from ...utils.utils import find_layers from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory @@ -46,7 +47,7 @@ def get_observer_layers(self): ] obs_layers = [nn.Linear] observer_layers_dict = {} - layers_dict = self.find_layers(self.model, layers=obs_layers) + layers_dict = find_layers(self.model, layers=obs_layers) ignore_layers = self.skip_layer_names() for name, module in layers_dict.items(): diff --git a/angelslim/models/vlm/qwen_vl.py b/angelslim/models/vlm/qwen_vl.py index f72558a8..6617a10c 100644 --- a/angelslim/models/vlm/qwen_vl.py +++ b/angelslim/models/vlm/qwen_vl.py @@ -23,7 +23,7 @@ ) from ...compressor.quant.core import PTQVLMSaveVllmHF -from ...utils import print_info +from ...utils import find_layers, print_info from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory @@ -95,7 +95,7 @@ def get_observer_layers(self): obs_layers = [nn.Linear] observer_layers_dict = {} - layers_dict = self.find_layers(self.model, layers=obs_layers) + layers_dict = find_layers(self.model, layers=obs_layers) ignore_layers = self.skip_layer_names() for name, module in layers_dict.items(): @@ -169,6 +169,13 @@ def model_forward(self, dataloader, **kwargs): calibrated_cnt += 1 pass + def get_quant_module(self): + """ + Returns the module that will be quantized. + This is typically the main transformer module of the model. + """ + return self.model.model.language_model.layers + def get_save_func(self): if self.deploy_backend in ["vllm", "huggingface"]: return PTQVLMSaveVllmHF diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py index 8d2405d9..e8f87898 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_parser import SlimConfigParser # noqa: F401 +from .config_parser import SlimConfigParser, parse_json_full_config # noqa: F401 from .default_compress_config import * # noqa: F401 F403 from .utils import common_prefix # noqa: F401 +from .utils import find_layers # noqa: F401 from .utils import find_parent_layer_and_sub_name # noqa: F401 from .utils import get_best_device # noqa: F401 from .utils import get_op_by_name # noqa: F401 diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py index fe1bec67..9f266a39 100644 --- a/angelslim/utils/config_parser.py +++ b/angelslim/utils/config_parser.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from dataclasses import dataclass, field from typing import Any, Dict, List, Optional @@ -117,6 +118,7 @@ class DatasetConfig: num_samples: int = field(default=256) batch_size: int = field(default=1) shuffle: bool = field(default=False) + inference_settings: Optional[Dict[str, Any]] = field(default=None) @dataclass @@ -401,6 +403,76 @@ def get_default_config() -> FullConfig: ) +def parse_json_compression_config_section(compress_config: dict) -> CompressionConfig: + """ + Parses the compression_config field from a JSON configuration file + + Args: + compress_config: Dictionary containing compression configuration data + + Returns: + CompressionConfig instance initialized with the parsed data + """ + # Extract compression method name (required field) + name = compress_config["name"] + + # Parse quantization configuration + quant_data = compress_config.get("quantization") + quantization = None + # Create QuantizationConfig if quantization data exists + if quant_data: + quantization = QuantizationConfig(**quant_data) + + # Parse cache configuration + cache_data = compress_config.get("cache") + cache = None + # Create CacheConfig if cache data exists + if cache_data: + cache = CacheConfig(**cache_data) + + # Create and return the CompressionConfig instance + return CompressionConfig(name=name, quantization=quantization, cache=cache) + + +def parse_json_full_config(json_file_path: str) -> FullConfig: + """ + Parses a JSON configuration file into a FullConfig instance + + Args: + json_file_path: Path to JSON configuration file + + Returns: + Fully populated FullConfig instance containing all configuration sections + """ + with open(json_file_path, "r") as f: + config_data = json.load(f) + + # Parse model configuration section + model_config = ModelConfig(**config_data["model_config"]) + + # Parse compression configuration section + comp_config = parse_json_compression_config_section( + config_data["compression_config"] + ) + + # Parse other configuration sections with default fallbacks + dataset_config, global_config, infer_config = None, None, None + if config_data.get("dataset_config", {}): + dataset_config = DatasetConfig(**config_data["dataset_config"]) + if config_data.get("global_config", {}): + global_config = GlobalConfig(**config_data["global_config"]) + if config_data.get("infer_config", {}): + infer_config = InferenceConfig(**config_data["infer_config"]) + + return FullConfig( + model_config=model_config, + compression_config=comp_config, + dataset_config=dataset_config, + global_config=global_config, + infer_config=infer_config, + ) + + def print_config(config, indent=0): """ Print the configuration in a structured YAML-like format diff --git a/angelslim/utils/utils.py b/angelslim/utils/utils.py index a3757506..83729c78 100644 --- a/angelslim/utils/utils.py +++ b/angelslim/utils/utils.py @@ -70,6 +70,23 @@ def find_parent_layer_and_sub_name(model, name): return parent_layer, sub_name +def find_layers(module, layers=None, name=""): + if not layers: + layers = [torch.nn.Linear] + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + find_layers( + child, + layers=layers, + name=name + "." + name1 if name != "" else name1, + ) + ) + return res + + def get_tensor_item(x): return x.item() diff --git a/configs/flux/flux-1-schnell_deepcache.yaml b/configs/flux/flux-1-schnell_deepcache.yaml index 54067c18..564a5a82 100644 --- a/configs/flux/flux-1-schnell_deepcache.yaml +++ b/configs/flux/flux-1-schnell_deepcache.yaml @@ -1,7 +1,7 @@ # Global configuration of pipeline global: save_path: ./output - deploy_backend: torch + deploy_backend: huggingface # Simplified Configuration for LLM compression model: diff --git a/configs/flux/flux-1-schnell_fp8_static.yaml b/configs/flux/flux-1-schnell_fp8_static.yaml new file mode 100644 index 00000000..62243b93 --- /dev/null +++ b/configs/flux/flux-1-schnell_fp8_static.yaml @@ -0,0 +1,37 @@ +# Global configuration of pipeline +global: + save_path: ./output + deploy_backend: huggingface + +# Simplified Configuration for LLM compression +model: + name: FLUX + model_path: black-forest-labs/FLUX.1-schnell + cache_dir: NULL + torch_dtype: bfloat16 + +# Compression configuration +compression: + name: PTQ + quantization: + name: fp8_static + bits: 8 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + ignore_layers: # Skip quantization for these layers + - "time_text_embed" + +# Dataset for calibration +dataset: + name: Text2ImageDataset + data_path: ./dataset/text2image_data/text2image_example_data.jsonl + num_samples: 2 + batch_size: 1 + inference_settings: + height: 1024 + width: 1024 + guidance_scale: 3.5 + num_inference_steps: 50 + max_sequence_length: 512 + seed: 42 diff --git a/configs/flux/flux-1-schnell_fp8_static_trt.yaml b/configs/flux/flux-1-schnell_fp8_static_trt.yaml new file mode 100644 index 00000000..89d61b80 --- /dev/null +++ b/configs/flux/flux-1-schnell_fp8_static_trt.yaml @@ -0,0 +1,37 @@ +# Global configuration of pipeline +global: + save_path: ./output + deploy_backend: tensorrt + +# Simplified Configuration for LLM compression +model: + name: FLUX + model_path: black-forest-labs/FLUX.1-schnell + cache_dir: NULL + torch_dtype: bfloat16 + +# Compression configuration +compression: + name: PTQ + quantization: + name: fp8_static + bits: 8 + quant_method: + weight: "per-tensor" + activation: "per-tensor" + ignore_layers: # Skip quantization for these layers + - "time_text_embed" + +# Dataset for calibration +dataset: + name: Text2ImageDataset + data_path: ./dataset/text2image_data/text2image_example_data.jsonl + num_samples: 2 + batch_size: 1 + inference_settings: + height: 1024 + width: 1024 + guidance_scale: 3.5 + num_inference_steps: 50 + max_sequence_length: 512 + seed: 42 diff --git a/configs/flux/flux-1-schnell_teacache.yaml b/configs/flux/flux-1-schnell_teacache.yaml index bb75126f..7ecc4288 100644 --- a/configs/flux/flux-1-schnell_teacache.yaml +++ b/configs/flux/flux-1-schnell_teacache.yaml @@ -1,7 +1,7 @@ # Global configuration of pipeline global: save_path: ./output - deploy_backend: torch + deploy_backend: huggingface # Simplified Configuration for LLM compression model: diff --git a/dataset/text2image_data/text2image_example_data.jsonl b/dataset/text2image_data/text2image_example_data.jsonl new file mode 100755 index 00000000..61e7b7ee --- /dev/null +++ b/dataset/text2image_data/text2image_example_data.jsonl @@ -0,0 +1,2 @@ +{"input": "A cat holding a sign that says hello world.", "type": "text2image"} +{"input": "A beautiful landscape with mountains and a river.", "type": "text2image"} \ No newline at end of file diff --git a/docs/source/models/flux/flux_cache.md b/docs/source/models/flux/flux_cache.md index 9c20b90b..e89c3644 100644 --- a/docs/source/models/flux/flux_cache.md +++ b/docs/source/models/flux/flux_cache.md @@ -3,13 +3,13 @@ ## DeepCache ```shell -python tools/run.py -c configs/flux/flux-1-schnell_deepcache.yaml \ +python tools/infer.py -c configs/flux/flux-1-schnell_deepcache.yaml \ --input-prompt "A beautiful landscape with mountains and a river." ``` ## DeepCache ```shell -python tools/run.py -c configs/flux/flux-1-schnell_teacache.yaml \ +python tools/infer.py -c configs/flux/flux-1-schnell_teacache.yaml \ --input-prompt "A beautiful landscape with mountains and a river." ``` \ No newline at end of file diff --git a/tools/fp8_quant_analyse.py b/tools/fp8_quant_analyse.py index 4a5446a7..acb9fb19 100644 --- a/tools/fp8_quant_analyse.py +++ b/tools/fp8_quant_analyse.py @@ -45,7 +45,9 @@ def quant_analyse(args): if __name__ == "__main__": - global_parser = argparse.ArgumentParser(description="全局参数", add_help=True) + global_parser = argparse.ArgumentParser( + description="Quantization analyse", add_help=True + ) global_parser.add_argument( "--analyse-type", type=str, diff --git a/tools/infer.py b/tools/infer.py new file mode 100644 index 00000000..e7852d70 --- /dev/null +++ b/tools/infer.py @@ -0,0 +1,118 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from angelslim.engine import InferEngine +from angelslim.utils import get_yaml_prefix_simple +from angelslim.utils.config_parser import SlimConfigParser, print_config + + +def get_args(): + parser = argparse.ArgumentParser(description="AngelSlim") + parser.add_argument("--model-path", type=str, default=None) + parser.add_argument("--input-prompt", type=str, default=None) + parser.add_argument("-c", "--config", type=str, default=None) + parser.add_argument("--save-path", type=str, default="./output/") + + args = parser.parse_args() + return args + + +def merge_config(config, args): + """ + Merge command line arguments into the configuration dictionary. + + Args: + config (dict): Configuration dictionary to be updated. + args (argparse.Namespace): Parsed command line arguments. + """ + if args.save_path is not None: + config.global_config.save_path = args.save_path + if args.model_path is not None: + config.model_config.model_path = args.model_path + config.global_config.save_path = os.path.join( + config.global_config.save_path, + get_yaml_prefix_simple(args.config), + ) + + +def infer(config, args): + """ + Evaluate the compression process. + This function is a placeholder for future evaluation logic. + """ + assert ( + config or args.model_path + ), "Please provide a model path or a configuration file." + slim_engine = InferEngine() + + if config: + # Step 1: Initialize configurations + model_config = config.model_config + compress_config = config.compression_config + global_config = config.global_config + infer_config = config.infer_config + + # Step 2: Prepare model + slim_engine.prepare_model( + model_name=model_config.name, + model_path=model_config.model_path, + torch_dtype=model_config.torch_dtype, + device_map=model_config.device_map, + trust_remote_code=model_config.trust_remote_code, + low_cpu_mem_usage=model_config.low_cpu_mem_usage, + use_cache=model_config.use_cache, + cache_dir=model_config.cache_dir, + deploy_backend=global_config.deploy_backend, + ) + + # Step 4: Initialize compressor + slim_engine.prepare_compressor( + compress_name=compress_config.name, + compress_config=compress_config, + global_config=global_config, + ) + else: + slim_engine.from_pretrained(model_path=args.model_path) + + if config and infer_config: + output = slim_engine.generate(args.input_prompt, **infer_config.__dict__) + else: + output = slim_engine.generate(args.input_prompt) + if slim_engine.series == "Diffusion": + # Save the generated image + if config and global_config: + save_path = os.path.join(global_config.save_path, "output_image.png") + else: + save_path = os.path.join(args.save_path, "output_image.png") + + # Ensure the directory exists + if save_path: + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + output.save(save_path) + + +if __name__ == "__main__": + args = get_args() + config = None + if args.config: + parser = SlimConfigParser() + config = parser.parse(args.config) + merge_config(config, args) + print_config(config) + assert args.input_prompt, "Please provide an input prompt for inference." + infer(config, args) diff --git a/tools/run.py b/tools/run.py index d0315b0a..39685525 100644 --- a/tools/run.py +++ b/tools/run.py @@ -29,7 +29,6 @@ def get_args(): parser.add_argument("-c", "--config", type=str, required=True) parser.add_argument("--model-path", type=str, default=None) parser.add_argument("--save-path", type=str, default=None) - parser.add_argument("--input-prompt", type=str, default=None) parser.add_argument("--multi-nodes", action="store_true") args = parser.parse_args() return args @@ -90,6 +89,7 @@ def multi_nodes_run(config): trust_remote_code=model_config.trust_remote_code, low_cpu_mem_usage=model_config.low_cpu_mem_usage, use_cache=model_config.use_cache, + cache_dir=model_config.cache_dir, deploy_backend=global_config.deploy_backend, using_multi_nodes=True, ) @@ -104,6 +104,7 @@ def multi_nodes_run(config): batch_size=dataset_config.batch_size, num_samples=dataset_config.num_samples, shuffle=dataset_config.shuffle, + inference_settings=dataset_config.inference_settings, ) # Step 6: Initialize compressor @@ -146,6 +147,7 @@ def run(config): trust_remote_code=model_config.trust_remote_code, low_cpu_mem_usage=model_config.low_cpu_mem_usage, use_cache=model_config.use_cache, + cache_dir=model_config.cache_dir, deploy_backend=global_config.deploy_backend, ) @@ -159,6 +161,7 @@ def run(config): batch_size=dataset_config.batch_size, num_samples=dataset_config.num_samples, shuffle=dataset_config.shuffle, + inference_settings=dataset_config.inference_settings, ) # Step 5: Initialize compressor @@ -175,64 +178,13 @@ def run(config): slim_engine.save(global_config.save_path, config) -def infer(config, input_prompt): - """ - Evaluate the compression process. - This function is a placeholder for future evaluation logic. - """ - # Step 1: Initialize configurations - model_config = config.model_config - compress_config = config.compression_config - global_config = config.global_config - infer_config = config.infer_config - - # Step 2: Execute complete pipeline - slim_engine = Engine() - - # Step 3: Prepare model - slim_engine.prepare_model( - model_name=model_config.name, - model_path=model_config.model_path, - torch_dtype=model_config.torch_dtype, - device_map=model_config.device_map, - trust_remote_code=model_config.trust_remote_code, - low_cpu_mem_usage=model_config.low_cpu_mem_usage, - use_cache=model_config.use_cache, - cache_dir=model_config.cache_dir, - deploy_backend=global_config.deploy_backend, - ) - - # Step 4: Initialize compressor - slim_engine.prepare_compressor( - compress_name=compress_config.name, - compress_config=compress_config, - global_config=global_config, - ) - - # Step 5: Run inference - output = slim_engine.infer(input_prompt, **infer_config.__dict__) - if slim_engine.series == "Diffusion": - # Save the generated image - if global_config.save_path: - save_path = os.path.join(global_config.save_path, "output_image.png") - else: - save_path = "output_image.png" - - # Ensure the directory exists - os.makedirs(os.path.dirname(save_path), exist_ok=True) - - output.save(save_path) - - if __name__ == "__main__": args = get_args() parser = SlimConfigParser() config = parser.parse(args.config) merge_config(config, args) print_config(config) - if args.input_prompt: - infer(config, args.input_prompt) - elif args.multi_nodes: + if args.multi_nodes: multi_nodes_run(config) else: run(config)