diff --git a/angelslim/compressor/quant/core/save.py b/angelslim/compressor/quant/core/save.py index 71b6025a..21d2695c 100644 --- a/angelslim/compressor/quant/core/save.py +++ b/angelslim/compressor/quant/core/save.py @@ -401,7 +401,6 @@ def save(self, save_path): if fused_act_scale_dict: for k, v in fused_act_scale_dict.items(): - torch.distributed.all_reduce(v, op=torch.distributed.ReduceOp.MAX) _save_path = os.path.join( save_path, "{}.input_scale.{}.pt".format(k, _index) ) @@ -412,6 +411,7 @@ def save(self, save_path): ) torch.save(v, _save_path) else: + torch.distributed.all_reduce(v, op=torch.distributed.ReduceOp.MAX) if self.rank == 0: torch.save(v, _save_path) print_info("save act scales done.") @@ -419,6 +419,9 @@ def save(self, save_path): if self.quant_model.weight_scales_dict: for k, v in self.quant_model.weight_scales_dict.items(): max_value_group_wise = v + # fp8 pertensor scale + fused_max_value = fused_weight_fp8_scale_dict[k] + # if weight quant is int4 and act quant is fp8, extra save int4 absmax if ( self.quant_model.quant_algo_dict["w_quant_algo"] == "int4" @@ -442,14 +445,11 @@ def save(self, save_path): _save_path, self.quant_model.quant_algo_dict["all_reduce"], ) + scale = (fused_max_value.max() / 448.0).to(fused_max_value.dtype) + elif self.quant_model.quant_algo_dict["w_quant_algo"] == "fp8": + scale = fused_max_value.max().to(fused_max_value.dtype) - # fp8 pertensor scale - fused_max_value = fused_weight_fp8_scale_dict[k] - scale = (fused_max_value.max() / 448.0).to(fused_max_value.dtype) assert scale.numel() == 1 - print_info(f"before all reduce scale = {scale}") - torch.distributed.all_reduce(scale, op=torch.distributed.ReduceOp.MAX) - print_info(f"after all reduce scale = {scale}") if "experts" in k and "shared_experts" not in k: _save_path = os.path.join( @@ -457,6 +457,11 @@ def save(self, save_path): ) torch.save(scale, _save_path) else: + print_info(f"before all reduce scale = {scale}") + torch.distributed.all_reduce( + scale, 
op=torch.distributed.ReduceOp.MAX + ) + print_info(f"after all reduce scale = {scale}") _save_path = os.path.join( save_path, "{}.weight_scale.{}.pt".format(k, _index) ) @@ -492,6 +497,8 @@ def save(self, save_path): if os.path.exists(tmp_path): shutil.rmtree(tmp_path) + if os.path.exists(save_path): + shutil.rmtree(save_path) parent_dir = os.path.dirname( self.quant_model.model.ori_model_path.rstrip("/") ) diff --git a/angelslim/compressor/quant/modules/awq/auto_scale.py b/angelslim/compressor/quant/modules/awq/auto_scale.py index d8036b40..c24fdb59 100644 --- a/angelslim/compressor/quant/modules/awq/auto_scale.py +++ b/angelslim/compressor/quant/modules/awq/auto_scale.py @@ -15,7 +15,7 @@ import torch from .....utils import get_op_by_name, get_op_name, print_info, set_op_by_name -from ...core import mse_loss +from ...core import mse_loss, per_block_weight_quant, weight_dequant from ...modules.helper_layer import SmoothHelpModule from .search import AWQSearch @@ -75,7 +75,13 @@ def apply_scale(self, module, scales_list, input_feat_dict=None): assert torch.isnan(p).sum() == 0, f"nan in {prev_op_name} weight" for layer in layers: - layer.weight.mul_(scales.view(1, -1)) + if layer.weight.dtype == torch.float8_e4m3fn: + weight = weight_dequant(layer.weight, layer.weight_scale_inv) + weight.mul_(scales.view(1, -1)) + weight, _ = per_block_weight_quant(weight) + layer.weight.data.copy_(weight) + else: + layer.weight.mul_(scales.view(1, -1)) for p in layer.parameters(): assert torch.isnan(p).sum() == 0, f"nan in {layer_names} weight" @@ -144,31 +150,59 @@ def _auto_get_scale( scales_list = [] print_info(input_feat.keys()) - scales_list.append( - _auto_get_scale( - layer_name="attn.qkv", - prev_op=module.input_layernorm, - layers=[ - module.self_attn.q_proj, - module.self_attn.k_proj, - module.self_attn.v_proj, - ], - inp=input_feat["self_attn.q_proj"], - module2inspect=module.self_attn, - cache=cache, + if self.model_type == "deepseek_v3": + scales_list.append( + 
_auto_get_scale( + layer_name="attn.qkv", + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_a_proj, + module.self_attn.kv_a_proj_with_mqa, + ], + inp=input_feat["self_attn.q_a_proj"], + module2inspect=module.self_attn, + cache=cache, + ) ) - ) - # attention output - if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + # attention output scales_list.append( _auto_get_scale( layer_name="attn.o", - prev_op=module.self_attn.v_proj, + prev_op=module.self_attn.kv_b_proj, layers=[module.self_attn.o_proj], inp=input_feat["self_attn.o_proj"], ) ) + else: + scales_list.append( + _auto_get_scale( + layer_name="attn.qkv", + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + cache=cache, + ) + ) + + # attention output + if ( + module.self_attn.v_proj.weight.shape + == module.self_attn.o_proj.weight.shape + ): + scales_list.append( + _auto_get_scale( + layer_name="attn.o", + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) if hasattr(module.mlp, "gate"): print_info("auto scale -> MoeAWQ") @@ -221,6 +255,58 @@ def _auto_get_scale( inp=input_feat[f"mlp.experts.{i}.down_proj"], ) ) + elif self.model_type == "deepseek_v3": + # share_mlp fc1 + scales_list.append( + _auto_get_scale( + layer_name="shared_experts.gate_proj", + prev_op=module.post_attention_layernorm, + layers=[ + module.mlp.shared_experts.gate_proj, + module.mlp.shared_experts.up_proj, + ], + inp=input_feat["mlp"], + module2inspect=module.mlp, + cache=cache, + ) + ) + # share_mlp fc2 + scales_list.append( + _auto_get_scale( + layer_name="shared_experts.down_proj", + prev_op=module.mlp.shared_experts.up_proj, + layers=[module.mlp.shared_experts.down_proj], + inp=input_feat["mlp.shared_experts.down_proj"].view( + input_feat["mlp"].shape[0], 
input_feat["mlp"].shape[1], -1 + ), + ) + ) + # fc1 + scales_list.append( + _auto_get_scale( + layer_name="expert.gate_proj", + prev_op=module.post_attention_layernorm, + layers=[ + w + for expert in module.mlp.experts + for w in [expert.gate_proj, expert.up_proj] + ] + + [module.mlp.gate], + inp=input_feat["mlp"], + module2inspect=module.mlp, + cache=cache, + ) + ) + # fc2 + for i, expert in enumerate(module.mlp.experts): + scales_list.append( + _auto_get_scale( + layer_name=f"expert.{i}.down_proj", + prev_op=expert.up_proj, + layers=[expert.down_proj], + inp=input_feat[f"mlp.experts.{i}.down_proj"].unsqueeze(0), + ) + ) else: # fc1 scales_list.append( diff --git a/angelslim/compressor/quant/modules/awq/awq.py b/angelslim/compressor/quant/modules/awq/awq.py index b7edec92..297eabe3 100644 --- a/angelslim/compressor/quant/modules/awq/awq.py +++ b/angelslim/compressor/quant/modules/awq/awq.py @@ -158,9 +158,9 @@ def run(self, dataloader): if not self.low_memory: outs = outs.to(dev) self.inps = self.inps.to(dev) - subset = find_layers(layer) + subset = find_layers(layer, layers=self.observer_layer_classes) - if self.model_arch_type in ["qwen3_moe", "hunyuan_v1_moe"]: + if self.model_arch_type in ["qwen3_moe", "hunyuan_v1_moe", "deepseek_v3"]: subset = { **subset, "mlp": layer.mlp, @@ -334,7 +334,7 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]): def _convert_llm(self): for i in tqdm(range(len(self.layers)), desc="AWQ"): - subset = find_layers(self.layers[i]) + subset = find_layers(self.layers[i], layers=self.observer_layer_classes) self._apply_quant(self.layers[i], subset) def convert(self): diff --git a/angelslim/compressor/quant/modules/awq/search.py b/angelslim/compressor/quant/modules/awq/search.py index 37fb47a1..87880097 100644 --- a/angelslim/compressor/quant/modules/awq/search.py +++ b/angelslim/compressor/quant/modules/awq/search.py @@ -16,7 +16,12 @@ from torch.nn import Linear from .....utils import get_best_device, print_info -from 
...core import mse_loss, pseudo_quantize_tensor +from ...core import ( + mse_loss, + per_block_weight_quant, + pseudo_quantize_tensor, + weight_dequant, +) print_func = print_info @@ -95,7 +100,13 @@ def search_by_block( scales = act_abs_max_tmp.pow(ratio).clamp(min=1e-4).view(-1) scales = scales / (scales.max() * scales.min()).sqrt() for layer in layers: - layer.weight.mul_(scales.view(1, -1)) + if layer.weight.dtype == torch.float8_e4m3fn: + weight = weight_dequant(layer.weight, layer.weight_scale_inv) + weight.mul_(scales.view(1, -1)) + weight, _ = per_block_weight_quant(weight) + layer.weight.data.copy_(weight) + else: + layer.weight.mul_(scales.view(1, -1)) if type(layer) in [Linear]: quant_dequant_weight = pseudo_quantize_tensor( layer.weight, diff --git a/angelslim/compressor/quant/ptq.py b/angelslim/compressor/quant/ptq.py index 5d8ba2e3..1b917b3b 100644 --- a/angelslim/compressor/quant/ptq.py +++ b/angelslim/compressor/quant/ptq.py @@ -13,7 +13,6 @@ # limitations under the License. 
import torch -import torch.nn as nn from ...utils import find_parent_layer_and_sub_name, print_info from ..compressor_factory import CompressorFactory @@ -63,7 +62,7 @@ def __init__(self, model, slim_config=None): hidden_size=hidden_size, model_arch_type=model_arch_type, mse_range=self.quant_model.quant_config.quant_algo_info["mse_range"], - observer_layer_classes=[nn.Linear], + observer_layer_classes=self.quant_model.observer_layer_classes, low_memory=self.quant_model.quant_config.low_memory, ) elif "fp8" in self.quant_algo: diff --git a/angelslim/models/base_model.py b/angelslim/models/base_model.py index a28a9dd7..bd8ecec9 100644 --- a/angelslim/models/base_model.py +++ b/angelslim/models/base_model.py @@ -50,6 +50,7 @@ def __init__( self.tokenizer = None self.modal_type = "LLM" self.pre_transformer_module_names = ["model.embed_tokens"] + self.observer_layer_classes = [torch.nn.Linear] def from_pretrained( self, diff --git a/angelslim/models/diffusion/flux.py b/angelslim/models/diffusion/flux.py index 84d663cc..8e63007c 100644 --- a/angelslim/models/diffusion/flux.py +++ b/angelslim/models/diffusion/flux.py @@ -143,9 +143,8 @@ def get_observer_layers(self): "norm1_context.linear", ] self.quant_module = self.model.transformer - obs_layers = [nn.Linear] observer_layers_dict = {} - layers_dict = find_layers(self.quant_module, layers=obs_layers) + layers_dict = find_layers(self.quant_module, layers=self.observer_layer_classes) ignore_layers = self.skip_layer_names() for name, module in layers_dict.items(): diff --git a/angelslim/models/llm/deepseek.py b/angelslim/models/llm/deepseek.py index 1871c02f..08c0b33a 100644 --- a/angelslim/models/llm/deepseek.py +++ b/angelslim/models/llm/deepseek.py @@ -51,6 +51,7 @@ def __init__( self.block_name = "model.layers" self.column_parallel_linear_class = ColumnParallelLinear self.row_parallel_linear_class = RowParallelLinear + self.observer_layer_classes = [nn.Linear, Linear] torch.set_default_dtype(torch.bfloat16) def 
from_pretrained( @@ -96,8 +97,9 @@ def from_pretrained( def get_observer_layers(self): names = self.quant_config.quant_algo_info["ignore_layers"] - obs_layers = [nn.Linear, Linear] - observer_layers_dict = find_layers(self.model, layers=obs_layers) + observer_layers_dict = find_layers( + self.model, layers=self.observer_layer_classes + ) observer_layers_dict = { k: v for k, v in observer_layers_dict.items() diff --git a/angelslim/models/llm/hunyuan_dense.py b/angelslim/models/llm/hunyuan_dense.py index f8c808fb..96848896 100644 --- a/angelslim/models/llm/hunyuan_dense.py +++ b/angelslim/models/llm/hunyuan_dense.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch.nn as nn - from ...compressor.quant.core import PTQSaveVllmHF from ...utils.utils import find_layers from ..base_model import BaseLLMModel @@ -45,8 +43,9 @@ def get_observer_layers(self): "mlp.down_proj", "mlp.gate_and_up_proj", ] - obs_layers = [nn.Linear] - observer_layers_dict = find_layers(self.model, layers=obs_layers) + observer_layers_dict = find_layers( + self.model, layers=self.observer_layer_classes + ) observer_layers_dict = { k: v diff --git a/angelslim/models/llm/hunyuan_moe.py b/angelslim/models/llm/hunyuan_moe.py index 82ae5ff7..5ea210c2 100644 --- a/angelslim/models/llm/hunyuan_moe.py +++ b/angelslim/models/llm/hunyuan_moe.py @@ -14,8 +14,6 @@ import re -import torch.nn as nn - from ...compressor.quant.core import PTQSaveVllmHF from ...utils.utils import find_layers from ..base_model import BaseLLMModel @@ -51,8 +49,9 @@ def get_observer_layers(self): r"model\.layers\.\d+\.mlp\.experts\.\d+\.down_proj", ] - obs_layers = [nn.Linear] - observer_layers_dict = find_layers(self.model, layers=obs_layers) + observer_layers_dict = find_layers( + self.model, layers=self.observer_layer_classes + ) compiled_patterns = [re.compile(pattern) for pattern in expert_pattern] diff --git a/angelslim/models/llm/llama.py 
b/angelslim/models/llm/llama.py index b3f52241..2ec5b400 100644 --- a/angelslim/models/llm/llama.py +++ b/angelslim/models/llm/llama.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch.nn as nn - from ...compressor.quant.core import PTQSaveVllmHF from ...utils.utils import find_layers from ..base_model import BaseLLMModel @@ -43,8 +41,9 @@ def get_observer_layers(self): "mlp.gate_proj", "mlp.down_proj", ] - obs_layers = [nn.Linear] - observer_layers_dict = find_layers(self.model, layers=obs_layers) + observer_layers_dict = find_layers( + self.model, layers=self.observer_layer_classes + ) observer_layers_dict = { k: v for k, v in observer_layers_dict.items() diff --git a/angelslim/models/llm/modeling_deepseek.py b/angelslim/models/llm/modeling_deepseek.py index 878db915..3b32c9ce 100755 --- a/angelslim/models/llm/modeling_deepseek.py +++ b/angelslim/models/llm/modeling_deepseek.py @@ -574,9 +574,11 @@ def forward( q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) + freqs_cis = freqs_cis.to(q_pe.device) q_pe = apply_rotary_emb(q_pe, freqs_cis) kv = self.kv_a_proj_with_mqa(x) kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + freqs_cis = freqs_cis.to(k_pe.device) k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) if attn_impl == "naive": q = torch.cat([q_nope, q_pe], dim=-1) diff --git a/angelslim/models/llm/qwen.py b/angelslim/models/llm/qwen.py index 60b22779..f06fc010 100644 --- a/angelslim/models/llm/qwen.py +++ b/angelslim/models/llm/qwen.py @@ -14,8 +14,6 @@ import re -import torch.nn as nn - from ...compressor.quant.core import PTQSaveVllmHF from ...utils.utils import find_layers from ..base_model import BaseLLMModel @@ -45,9 +43,8 @@ def get_observer_layers(self): "gate_proj", "down_proj", ] - obs_layers = [nn.Linear] observer_layers_dict = {} - layers_dict = find_layers(self.model, layers=obs_layers) + 
layers_dict = find_layers(self.model, layers=self.observer_layer_classes) ignore_layers = self.skip_layer_names() for name, module in layers_dict.items(): diff --git a/angelslim/models/llm/seed_oss.py b/angelslim/models/llm/seed_oss.py index 7e1ca3df..c7c2dcfb 100644 --- a/angelslim/models/llm/seed_oss.py +++ b/angelslim/models/llm/seed_oss.py @@ -14,8 +14,6 @@ import re -import torch.nn as nn - from ...compressor.quant.core import PTQSaveVllmHF from ..base_model import BaseLLMModel from ..model_factory import SlimModelFactory @@ -44,9 +42,8 @@ def get_observer_layers(self): "gate_proj", "down_proj", ] - obs_layers = [nn.Linear] observer_layers_dict = {} - layers_dict = self.find_layers(self.model, layers=obs_layers) + layers_dict = self.find_layers(self.model, layers=self.observer_layer_classes) ignore_layers = self.skip_layer_names() for name, module in layers_dict.items(): diff --git a/angelslim/models/vlm/qwen_vl.py b/angelslim/models/vlm/qwen_vl.py index 6617a10c..d9cec6e6 100644 --- a/angelslim/models/vlm/qwen_vl.py +++ b/angelslim/models/vlm/qwen_vl.py @@ -13,7 +13,6 @@ # limitations under the License. 
import torch -import torch.nn as nn import torch.nn.functional as F from tqdm import tqdm from transformers import ( @@ -93,9 +92,8 @@ def get_observer_layers(self): vit_names = ["qkv", "proj"] names.extend(vit_names) - obs_layers = [nn.Linear] observer_layers_dict = {} - layers_dict = find_layers(self.model, layers=obs_layers) + layers_dict = find_layers(self.model, layers=self.observer_layer_classes) ignore_layers = self.skip_layer_names() for name, module in layers_dict.items(): diff --git a/configs/deepseek_r1/int4_awq/deepseek_r1_int4_awq.yaml b/configs/deepseek_r1/int4_awq/deepseek_r1_int4_awq.yaml new file mode 100644 index 00000000..09aed313 --- /dev/null +++ b/configs/deepseek_r1/int4_awq/deepseek_r1_int4_awq.yaml @@ -0,0 +1,37 @@ +# Global configuration of pipeline +global: + save_path: ./output + +# Simplified Configuration for LLM compression +model: + name: DeepSeek + model_path: deepseek-ai/DeepSeek-R1 + trust_remote_code: true + low_cpu_mem_usage: true + use_cache: false + torch_dtype: fp8 + device_map: cpu + +# Compression configuration +compression: + name: PTQ + quantization: + name: int4_awq # Supported: fp8_static, w4a8_fp8, int4_awq + bits: 4 # Quantization bits (4/8) + quant_method: + weight: "per-group" + group_size: 128 + zero_point: true + mse_range: false + ignore_layers: # Skip quantization for these layers + - "lm_head" + - "model.embed_tokens" + - "model.layers.61." 
+ +# Dataset for calibration +dataset: + name: TextDataset + data_path: your/data/path + max_seq_length: 4096 + num_samples: 128 + batch_size: 1 diff --git a/docs/source/models/deepseek/deepseek_quant.md b/docs/source/models/deepseek/deepseek_quant.md index fd82a7c7..7ee519be 100644 --- a/docs/source/models/deepseek/deepseek_quant.md +++ b/docs/source/models/deepseek/deepseek_quant.md @@ -117,3 +117,37 @@ torchrun \ python3 tools/run.py -c configs/deepseek_r1/w4a8_fp8/deepseek_r1_w4a8_fp8_low_memmory.yaml ``` +## INT4-AWQ量化 + +DeepSeekR1的INT4-AWQ量化,可使用vllm部署。其中权重为per-group的粒度,group-size可选64/128;激活为动态per-token量化,可支持激活数据类型为int8/fp8。具体可见对应PR: +```shell +https://github.com/vllm-project/vllm/pull/24722 +``` + + +您可以量化`AngelSlim/configs/deepseek_r1`下面带有`int4_awq`字段的模型类型。 + +### 配置 + +INT4-AWQ `config.yaml`文件参数配置,您可以参考`configs/deepseek_r1/int4_awq`路径下的文件,下面是参数信息介绍。 + +#### model配置 +- `name`:填写`DeepSeek`。 +- `torch_dtype`:权重加载时使用的数据类型。设置为`fp8`时可直接加载HF上的deepseek-ai/DeepSeek-R1-0528模型。设置为`bf16`时需要将fp8权重进行转换后再进行量化。 +- `device_map`:设置为`cpu`。 + +#### Compression配置 +- `name`:压缩策略,填写`PTQ`。 +- `quantization.name`:压缩算法选填`int4_awq`。 +- `quantization.bits`:量化比特数,设置为4。 +- `quantization.quant_method`:主要指定权重的量化粒度,设置为`per-group`。 +- `quantization.group_size`:权重量化的分组大小。 +- `quantization.zero_point`:权重量化偏置,设置为`True`。 +- `quantization.ignore_layers`:指定模型中不需要量化的层。 + +### INT4-AWQ量化 + +您可以通过下面代码启动INT4-AWQ量化流程: +```shell +python3 tools/run.py --config configs/deepseek_r1/int4_awq/deepseek_r1_int4_awq.yaml +``` \ No newline at end of file