88Performance: ~2-3x faster than separate dequant + matmul
99"""
1010
11- from typing import Optional , Tuple
11+ from typing import Callable , Dict , Optional , Tuple
1212import torch
1313import torch .nn as nn
1414
@@ -27,6 +27,69 @@ def is_triton_available() -> bool:
2727 return _TRITON_AVAILABLE
2828
2929
def triton_q8_0_quantize(weight: torch.Tensor, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a weight matrix to Q8_0 format (per-column symmetric int8).

    The scale of each column is ``max(|column|, eps) / 127``. All arithmetic is
    carried out in at least float32: in float16 the default ``eps`` and small
    scales round to zero, which would turn the division below into a division
    by zero (NaN for all-zero columns, saturation for tiny ones).

    Args:
        weight: 2D tensor; the reduction runs over dim 0, so each column gets
            its own scale.
        eps: Lower bound applied to the per-column max-abs before dividing.

    Returns:
        qweight: int8 tensor [in_features, out_features]
        scales: fp tensor [1, out_features], cast back to ``weight.dtype``
            (so a very small scale can still underflow in half precision).

    Raises:
        ValueError: If ``weight`` is not 2-dimensional.
    """
    if weight.dim() != 2:
        raise ValueError(f"Q8_0 quantization expects a 2D tensor, got shape={tuple(weight.shape)}")

    # Upcast half/bfloat16 to float32; float32 and float64 inputs are left as they are.
    compute_dtype = torch.promote_types(weight.dtype, torch.float32)
    w = weight.to(compute_dtype)

    max_abs = w.abs().amax(dim=0, keepdim=True).clamp(min=eps)
    scale = max_abs / 127.0
    qweight = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return qweight, scale.to(weight.dtype)
45+
46+
def triton_q4_0_quantize(weight: torch.Tensor, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a weight matrix to Q4_0 format (per-column symmetric 4-bit stored in int8).

    The scale of each column is ``max(|column|, eps) / 7``. All arithmetic is
    carried out in at least float32: in float16 the default ``eps`` and small
    scales round to zero, which would turn the division below into a division
    by zero (NaN for all-zero columns, saturation for tiny ones).

    Args:
        weight: 2D tensor; the reduction runs over dim 0, so each column gets
            its own scale.
        eps: Lower bound applied to the per-column max-abs before dividing.

    Returns:
        qweight: int8 tensor [in_features, out_features] with values in [-8, 7]
        scales: fp tensor [1, out_features], cast back to ``weight.dtype``
            (so a very small scale can still underflow in half precision).

    Raises:
        ValueError: If ``weight`` is not 2-dimensional.
    """
    if weight.dim() != 2:
        raise ValueError(f"Q4_0 quantization expects a 2D tensor, got shape={tuple(weight.shape)}")

    # Upcast half/bfloat16 to float32; float32 and float64 inputs are left as they are.
    compute_dtype = torch.promote_types(weight.dtype, torch.float32)
    w = weight.to(compute_dtype)

    max_abs = w.abs().amax(dim=0, keepdim=True).clamp(min=eps)
    scale = max_abs / 7.0
    qweight = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)
    return qweight, scale.to(weight.dtype)
62+
63+
def int4_matmul(
    x: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Run an INT4 matmul by delegating to ``fused_dequant_matmul`` with all-zero zero points.

    Args:
        x: Input [..., in_features]
        qweight: Quantized int4 values stored in int8, shape [in_features, out_features]
        scales: Per-column scales, shape [1, out_features] or [in_features/group, out_features]
        bias: Optional bias [out_features]

    Returns:
        Whatever ``fused_dequant_matmul`` returns for these arguments.
    """
    scale_rows = scales.shape[0]

    if scale_rows == 1:
        # Per-column quantization: one [1, N] row of zero points and a single
        # group that spans every row of qweight.
        zero_points = scales.new_zeros((1, scales.shape[1]))
        rows_per_group = qweight.shape[0]
    else:
        # Grouped quantization: zero points mirror the scales' shape; the group
        # size is the number of qweight rows per scale row, never below 1.
        zero_points = scales.new_zeros(scales.shape)
        rows_per_group = max(qweight.shape[0] // scale_rows, 1)

    return fused_dequant_matmul(
        x=x,
        qweight=qweight,
        scales=scales,
        zeros=zero_points,
        bias=bias,
        group_size=rows_per_group,
    )
91+
92+
3093if _TRITON_AVAILABLE :
3194 @triton .jit
3295 def _fused_dequant_matmul_kernel (
@@ -462,3 +525,9 @@ def extra_repr(self) -> str:
462525 f'group_size={ self .group_size } , '
463526 f'triton={ self ._use_triton } '
464527 )
528+
529+
# Registry of quantizers keyed by quantization format name; each entry takes a
# weight tensor and returns a (qweight, scales) pair.
triton_quantizers: Dict[str, Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]] = dict(
    q4_0=triton_q4_0_quantize,
    q8_0=triton_q8_0_quantize,
)
0 commit comments