Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lmdeploy/turbomind/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def _dequant_linear(linear: Linear, *, data_type) -> Linear:
"""Dequantize a quantized Linear to trivial.

``TrivialFormat.dequant`` is identity, so already-trivial inputs round-trip
safely. ``AWQFormat.dequant`` and ``FP8Format.dequant`` do real work.
GPTQ / CompressedTensor / MXFP4 inherit the base-class
safely. ``AWQFormat.dequant``, ``CompressedTensorFormat.dequant`` and
``FP8Format.dequant`` do real work. GPTQ / MXFP4 inherit the base-class
``NotImplementedError`` — calling ``_dequant_linear`` on one of those is a
broken-fusion-group configuration, and the raise names it at the call site.
"""
Expand Down
22 changes: 19 additions & 3 deletions lmdeploy/turbomind/weight_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def pack_u4_row(x: torch.Tensor) -> torch.Tensor:


def _zeros_int4_symmetric(scales: Tensor) -> Tensor:
"""Synthesize symmetric int4 zero-points (value = 8) matching *scales*
shape."""
return torch.full(scales.shape, 8, dtype=torch.uint8, device=scales.device)
"""Synthesize normalized symmetric int4 zero-points (value = 8) matching
*scales* shape."""
return torch.full(scales.shape, 8, dtype=scales.dtype, device=scales.device)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -329,6 +329,22 @@ def pack(self, tensor: Tensor, kind: str) -> PackedTensor:
def synthesize_zeros(self, scales: Tensor) -> Tensor:
return _zeros_int4_symmetric(scales)

def dequant(self, tensors, data_type):
weight = tensors['weight']
scales = tensors['scales']
zeros = tensors['zeros']

out_size = weight.shape[-1]
zeros = zeros[..., :out_size]

scales = scales.repeat_interleave(self.block_in, dim=0)[:weight.shape[0]]
zeros = zeros.repeat_interleave(self.block_in, dim=0)[:weight.shape[0]]
w = (weight.to(scales.dtype) - zeros.to(scales.dtype)) * scales
result: dict[str, Tensor] = {'weight': w}
if 'bias' in tensors:
result['bias'] = tensors['bias']
return result


class FP8Format(WeightFormat):
name = 'fp8'
Expand Down
Loading