diff --git a/lmdeploy/turbomind/linear.py b/lmdeploy/turbomind/linear.py index edbdc94fe1..fc436c8cef 100644 --- a/lmdeploy/turbomind/linear.py +++ b/lmdeploy/turbomind/linear.py @@ -73,8 +73,8 @@ def _dequant_linear(linear: Linear, *, data_type) -> Linear: """Dequantize a quantized Linear to trivial. ``TrivialFormat.dequant`` is identity, so already-trivial inputs round-trip - safely. ``AWQFormat.dequant`` and ``FP8Format.dequant`` do real work. - GPTQ / CompressedTensor / MXFP4 inherit the base-class + safely. ``AWQFormat.dequant``, ``CompressedTensorFormat.dequant`` and + ``FP8Format.dequant`` do real work. GPTQ / MXFP4 inherit the base-class ``NotImplementedError`` — calling ``_dequant_linear`` on one of those is a broken-fusion-group configuration, and the raise names it at the call site. """ diff --git a/lmdeploy/turbomind/weight_format.py b/lmdeploy/turbomind/weight_format.py index 1a370c36ee..27eb35d586 100644 --- a/lmdeploy/turbomind/weight_format.py +++ b/lmdeploy/turbomind/weight_format.py @@ -75,9 +75,9 @@ def pack_u4_row(x: torch.Tensor) -> torch.Tensor: def _zeros_int4_symmetric(scales: Tensor) -> Tensor: - """Synthesize symmetric int4 zero-points (value = 8) matching *scales* - shape.""" - return torch.full(scales.shape, 8, dtype=torch.uint8, device=scales.device) + """Synthesize normalized symmetric int4 zero-points (value = 8) matching + *scales* shape.""" + return torch.full(scales.shape, 8, dtype=scales.dtype, device=scales.device) # --------------------------------------------------------------------------- @@ -329,6 +329,22 @@ def pack(self, tensor: Tensor, kind: str) -> PackedTensor: def synthesize_zeros(self, scales: Tensor) -> Tensor: return _zeros_int4_symmetric(scales) + def dequant(self, tensors, data_type): + weight = tensors['weight'] + scales = tensors['scales'] + zeros = tensors['zeros'] + + out_size = weight.shape[-1] + zeros = zeros[..., :out_size] + + scales = scales.repeat_interleave(self.block_in, dim=0)[:weight.shape[0]] + zeros = zeros.repeat_interleave(self.block_in, dim=0)[:weight.shape[0]] + w = (weight.to(scales.dtype) - zeros.to(scales.dtype)) * scales + result: dict[str, Tensor] = {'weight': w} + if 'bias' in tensors: + result['bias'] = tensors['bias'] + return result + class FP8Format(WeightFormat): name = 'fp8'