|
34 | 34 | reduce_block_padding, |
35 | 35 | tensor_quant_dequant_fp8, |
36 | 36 | tensor_quant_dequant_int, |
| 37 | + unpack_weight_omni, |
37 | 38 | ) |
38 | 39 |
|
39 | 40 |
|
@@ -547,19 +548,21 @@ def __init__( |
547 | 548 | super().__init__() |
548 | 549 | self.quant_algo = quant_algo |
549 | 550 | weight_scale = weight_scale.to(weight.device) |
| 551 | + self.group_size = group_size |
550 | 552 | if "fp8" in quant_algo: |
551 | 553 | if "w4a8" in self.quant_algo: |
552 | 554 | max_value_group_wise = weight_scale.clone() |
| 555 | + # weight(bf16) -> fp8 -> int4 |
| 556 | + # dweight(int4) -> fp8 -> weight(bf16) |
553 | 557 | tensor_wise_scale = max_value_group_wise.max() / 448.0 |
554 | | - quant_weight, _ = quantize_weight_per_tensor_fp8(weight, tensor_wise_scale) |
555 | | - new_weight_bf16 = quant_weight.to(torch.bfloat16) * tensor_wise_scale |
556 | | - |
| 558 | + new_weight_bf16 = weight |
557 | 559 | new_weight_bf16_qdq = fake_quant_dequant( |
558 | 560 | new_weight_bf16, method="groupwise", bits=4, group_size=group_size |
559 | 561 | ) |
560 | 562 | quant_weight, _ = quantize_weight_int( |
561 | 563 | new_weight_bf16_qdq, max_value_group_wise, bits=4 |
562 | 564 | ) |
| 565 | + |
563 | 566 | quant_weight = pack_weight_to_int8(quant_weight) |
564 | 567 | del new_weight_bf16_qdq, new_weight_bf16 |
565 | 568 | self.weight_scale_int4 = torch.nn.Parameter( |
@@ -600,11 +603,30 @@ def forward(self, x): |
600 | 603 | raise ValueError(f"Unsupported quantization algorithm: {self.quant_algo}") |
601 | 604 |
|
602 | 605 | if "fp8" in self.quant_algo: |
| 606 | + if "w4a8" in self.quant_algo: |
| 607 | + # unpack, save as int32 |
| 608 | + weight = self.qweight.to(qinput.device) |
| 609 | + weight = unpack_weight_omni(weight, save_bit=4, pack_bit=8) |
| 610 | + weight_scale = self.weight_scale.to(qinput.device) |
| 611 | + |
| 612 | + scale = ( |
| 613 | + self.weight_scale_int4.float() |
| 614 | + .repeat_interleave(self.group_size, dim=-1) |
| 615 | + .to(qinput.device) |
| 616 | + ) # (out,in) |
| 617 | + # dequant to bf16 |
| 618 | + weight = weight * scale |
| 619 | + # quant to fp8 |
| 620 | + weight, _ = quantize_weight_per_tensor_fp8(weight, weight_scale) |
| 621 | + # to fp8 |
| 622 | + else: |
| 623 | + weight = self.weight.to(qinput.device) |
| 624 | + weight_scale = self.weight_scale.to(qinput.device) |
603 | 625 | output = gemm_fp8( |
604 | 626 | act=qinput, |
605 | 627 | act_scale=self.input_scale, |
606 | | - weight=self.weight, |
607 | | - weight_scale=self.weight_scale, |
| 628 | + weight=weight, |
| 629 | + weight_scale=weight_scale, |
608 | 630 | bias=self.bias, |
609 | 631 | out_dtype=x.dtype, |
610 | 632 | ) |
|
0 commit comments