Skip to content
Merged
84 changes: 46 additions & 38 deletions angelslim/compressor/quant/modules/awq/auto_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch

from .....utils import get_op_by_name, get_op_name, print_info, set_op_by_name
from ...core import mse_loss, per_block_weight_quant, weight_dequant
from ...core import mse_loss, weight_dequant
from ...modules.helper_layer import SmoothHelpModule
from .search import AWQSearch

Expand Down Expand Up @@ -76,10 +76,9 @@ def apply_scale(self, module, scales_list, input_feat_dict=None):

for layer in layers:
if layer.weight.dtype == torch.float8_e4m3fn:
weight = weight_dequant(layer.weight, layer.weight_scale_inv)
weight.mul_(scales.view(1, -1))
weight, _ = per_block_weight_quant(weight)
layer.weight.data.copy_(weight)
w = weight_dequant(layer.weight, layer.weight_scale_inv)
w.mul_(scales.view(1, -1))
layer.weight.data = w
else:
layer.weight.mul_(scales.view(1, -1))
for p in layer.parameters():
Expand All @@ -91,6 +90,27 @@ def apply_scale(self, module, scales_list, input_feat_dict=None):
new_module.convert_weight(scales)
set_op_by_name(module, prev_op_name, new_module)

elif type(prev_op) in self.observer_layer_classes:
if prev_op.weight.dtype == torch.float8_e4m3fn:
w = weight_dequant(prev_op.weight, prev_op.weight_scale_inv)
prev_op.weight.data = w
scales = scales.to(prev_op.weight.device)
prev_op.weight[-scales.size(0) :].div_(scales.view(-1, 1))
if prev_op.bias is not None:
prev_op.bias.div_(scales.view(-1))

for layer in layers:
if layer.weight.dtype == torch.float8_e4m3fn:
w = weight_dequant(layer.weight, layer.weight_scale_inv)
layer.weight.data = w
layer.weight.mul_(scales.view(1, -1))

for p in prev_op.parameters():
assert torch.isnan(p).sum() == 0
for layer in layers:
for p in layer.parameters():
assert torch.isnan(p).sum() == 0

if input_feat_dict is not None:
for layer_name in layer_names:
if layer_name not in input_feat_dict.keys():
Expand Down Expand Up @@ -164,14 +184,14 @@ def _auto_get_scale(
cache=cache,
)
)

# attention output
scales_list.append(
_auto_get_scale(
layer_name="attn.o",
prev_op=module.self_attn.kv_b_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
layer_name="attn.q_b_proj",
prev_op=module.self_attn.q_a_layernorm,
layers=[
module.self_attn.q_b_proj,
],
inp=input_feat["self_attn.q_b_proj"],
)
)
else:
Expand Down Expand Up @@ -256,42 +276,20 @@ def _auto_get_scale(
)
)
elif self.model_type == "deepseek_v3":
# share_mlp fc1
scales_list.append(
_auto_get_scale(
layer_name="shared_experts.gate_proj",
prev_op=module.post_attention_layernorm,
layers=[
module.mlp.shared_experts.gate_proj,
module.mlp.shared_experts.up_proj,
],
inp=input_feat["mlp"],
module2inspect=module.mlp,
cache=cache,
)
)
# share_mlp fc2
scales_list.append(
_auto_get_scale(
layer_name="shared_experts.down_proj",
prev_op=module.mlp.shared_experts.up_proj,
layers=[module.mlp.shared_experts.down_proj],
inp=input_feat["mlp.shared_experts.down_proj"].view(
input_feat["mlp"].shape[0], input_feat["mlp"].shape[1], -1
),
)
)
# fc1
scales_list.append(
_auto_get_scale(
layer_name="expert.gate_proj",
layer_name="moe",
prev_op=module.post_attention_layernorm,
layers=[
w
for expert in module.mlp.experts
for w in [expert.gate_proj, expert.up_proj]
]
+ [module.mlp.gate],
+ [
module.mlp.shared_experts.gate_proj,
module.mlp.shared_experts.up_proj,
],
inp=input_feat["mlp"],
module2inspect=module.mlp,
cache=cache,
Expand All @@ -307,6 +305,16 @@ def _auto_get_scale(
inp=input_feat[f"mlp.experts.{i}.down_proj"].unsqueeze(0),
)
)
scales_list.append(
_auto_get_scale(
layer_name="shared_experts.down_proj",
prev_op=module.mlp.shared_experts.up_proj,
layers=[module.mlp.shared_experts.down_proj],
inp=input_feat["mlp.shared_experts.down_proj"].view(
input_feat["mlp"].shape[0], input_feat["mlp"].shape[1], -1
),
)
)
else:
# fc1
scales_list.append(
Expand Down
18 changes: 14 additions & 4 deletions angelslim/compressor/quant/modules/awq/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tqdm import tqdm

from .....utils import find_layers, get_best_device, print_info, set_op_by_name
from ...core import pseudo_quantize_tensor
from ...core import pseudo_quantize_tensor, weight_dequant
from ...modules.catcher import Catcher
from ...modules.helper_layer import WQLinearGEMM
from .auto_clip import AutoLayerClip
Expand Down Expand Up @@ -290,7 +290,10 @@ def forward(self, x):
force_contiguous=True,
shared_tensors_to_discard=self.model.model._tied_weights_keys,
)
self.model.model.config.torch_dtype = "float16"
if self.model_arch_type == "deepseek_v3":
self.model.model.config.torch_dtype = "bfloat16"
else:
self.model.model.config.torch_dtype = "float16"
self.model.model.config.to_json_file(os.path.join(save_dir, "config.json"))

# save processor and tokenizer
Expand All @@ -303,8 +306,15 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
for name, linear_layer in named_linears.items():
if "mlp.gate." in name:
continue
# NOTE: small regression in perplexity if linear layer uses .cpu().float()
linear_layer = linear_layer.to(get_best_device()).half()
if linear_layer.weight.dtype == torch.float8_e4m3fn:
linear_layer = linear_layer.to("cuda")
w = weight_dequant(linear_layer.weight, linear_layer.weight_scale_inv)
linear_layer = linear_layer.to("cpu")
w = w.to("cpu")
linear_layer.weight.data = w
else:
# NOTE: small regression in perplexity if linear uses .cpu().float()
linear_layer = linear_layer.to(get_best_device()).half()

linear_layer.weight.data, scales, zeros = pseudo_quantize_tensor(
linear_layer.weight.data,
Expand Down
32 changes: 16 additions & 16 deletions angelslim/compressor/quant/modules/awq/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,9 @@
# limitations under the License.

import torch
from torch.nn import Linear

from .....utils import get_best_device, print_info
from ...core import (
mse_loss,
per_block_weight_quant,
pseudo_quantize_tensor,
weight_dequant,
)
from ...core import mse_loss, pseudo_quantize_tensor, weight_dequant

print_func = print_info

Expand Down Expand Up @@ -101,13 +95,10 @@ def search_by_block(
scales = scales / (scales.max() * scales.min()).sqrt()
for layer in layers:
if layer.weight.dtype == torch.float8_e4m3fn:
weight = weight_dequant(layer.weight, layer.weight_scale_inv)
weight.mul_(scales.view(1, -1))
weight, _ = per_block_weight_quant(weight)
layer.weight.data.copy_(weight)
else:
layer.weight.mul_(scales.view(1, -1))
if type(layer) in [Linear]:
w = weight_dequant(layer.weight, layer.weight_scale_inv)
layer.weight.data = w
layer.weight.mul_(scales.view(1, -1))
if type(layer) in self.observer_layer_classes:
quant_dequant_weight = pseudo_quantize_tensor(
layer.weight,
w_bit=self.bits_length,
Expand All @@ -121,7 +112,16 @@ def search_by_block(
layer_name, new_act, block, cache
).to(act.device)

loss = self.loss_function(origin_out, new_out).to(torch.float32)
try:
loss = self.loss_function(origin_out, new_out).to(torch.float32)
except RuntimeError as e:
if "CUDA out of memory" in str(e):
print_func("switch cpu to compute loss...")
origin_out = origin_out.cpu()
new_out = new_out.cpu()
loss = self.loss_function(origin_out, new_out).to(torch.float32)
else:
raise

if loss < best_error:
print_func("find better ratio: {}, loss: {}".format(ratio, loss))
Expand All @@ -130,7 +130,7 @@ def search_by_block(
best_scales = scales

for layer, w in zip(layers, org_w):
layer.weight.data.copy_(w)
layer.weight.data = w.to(act.device)

origin_out = origin_out.detach().cpu()
new_out = w.detach().cpu()
Expand Down
7 changes: 6 additions & 1 deletion angelslim/models/llm/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def __init__(
self.block_name = "model.layers"
self.column_parallel_linear_class = ColumnParallelLinear
self.row_parallel_linear_class = RowParallelLinear
self.observer_layer_classes = [nn.Linear, Linear]
self.observer_layer_classes = [
nn.Linear,
Linear,
ColumnParallelLinear,
RowParallelLinear,
]
torch.set_default_dtype(torch.bfloat16)

def from_pretrained(
Expand Down
12 changes: 12 additions & 0 deletions docs/source/models/deepseek/deepseek_quant.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,18 @@ INT4-AWQ `confg.yaml`文件参数配置,您可以参考`config/deepseek_r1/int
### INT4-AWQ量化

您可以通过下面代码启动INT4-AWQ量化流程:
#### AWQ算法
```shell
python3 tools/run.py --config configs/deepseek_r1/int4_awq/deepseek_r1_int4_awq.yaml
```
#### 快速转换AWQ格式
```shell
python3 tools/convert_int4_awq_offline.py \
--bit 4 \
--group-size xx \ # 128或者64
--zero-point xx \ # 设置为True时在vllm中对应AWQ,设置为False对应GPTQ
--num-workers xx \ # 线程数
--input_path 权重路径 \
--output_path 保存路径 \
--exclude-patterns None \ # 设置不量化模块,默认为None
```
Loading