Skip to content

Commit c735f58

Browse files
feat: Update QuantLLM to v2.1 with new quantization methods and enhanced kernel functionality
1 parent 5ab4986 commit c735f58

4 files changed

Lines changed: 92 additions & 6 deletions

File tree

quantllm/__init__.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
QuantLLM v2.0 - Ultra-fast LLM Quantization & GGUF Export
2+
QuantLLM v2.1 - Ultra-fast LLM Quantization & GGUF Export
33
44
The simplest way to load, quantize, fine-tune, and export LLMs.
55
@@ -13,16 +13,19 @@
1313
>>> from quantllm import turbo
1414
>>>
1515
>>> # Load any model (auto-quantizes to 4-bit)
16-
>>> model = turbo("meta-llama/Llama-3.2-3B")
16+
>>> model = turbo(
17+
... "meta-llama/Llama-3.2-3B",
18+
... config={"format": "gguf", "quantization": "Q4_K_M", "push_format": "gguf"},
19+
... )
1720
>>>
1821
>>> # Generate text
1922
>>> model.generate("Hello, world!")
2023
>>>
2124
>>> # Export to GGUF with Q4_K_M quantization
22-
>>> model.export("gguf", "model.Q4_K_M.gguf", quantization="Q4_K_M")
25+
>>> model.export()
2326
>>>
2427
>>> # Push to HuggingFace Hub
25-
>>> model.push("username/my-model", format="gguf", quantization="Q4_K_M")
28+
>>> model.push("username/my-model")
2629
"""
2730

2831
import os
@@ -32,6 +35,7 @@
3235
from .core import (
3336
turbo,
3437
TurboModel,
38+
register_architecture,
3539
SmartConfig,
3640
HardwareProfiler,
3741
ModelAnalyzer,
@@ -73,7 +77,7 @@
7377
# Configure logging (minimal by default)
7478
configure_logging("WARNING")
7579

76-
__version__ = "2.0.0"
80+
__version__ = "2.1.0rc1"
7781
__title__ = "QuantLLM"
7882
__description__ = "Ultra-fast LLM Quantization & Export (GGUF, ONNX, MLX)"
7983
__author__ = "Dark Coder"
@@ -114,6 +118,7 @@ def show_banner(force: bool = False):
114118
# Main API
115119
"turbo",
116120
"TurboModel",
121+
"register_architecture",
117122
"SmartConfig",
118123
"HardwareProfiler",
119124
"ModelAnalyzer",

quantllm/kernels/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@
77
from .triton import (
88
TritonQuantizedLinear,
99
fused_dequant_matmul,
10+
int4_matmul,
1011
is_triton_available,
12+
triton_q4_0_quantize,
13+
triton_q8_0_quantize,
1114
)
1215

1316
__all__ = [
1417
"TritonQuantizedLinear",
1518
"fused_dequant_matmul",
19+
"int4_matmul",
1620
"is_triton_available",
21+
"triton_q4_0_quantize",
22+
"triton_q8_0_quantize",
1723
]

quantllm/kernels/triton/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@
77
from .quantized_linear import (
88
TritonQuantizedLinear,
99
fused_dequant_matmul,
10+
int4_matmul,
1011
is_triton_available,
12+
triton_q4_0_quantize,
13+
triton_q8_0_quantize,
1114
)
1215

1316
__all__ = [
1417
"TritonQuantizedLinear",
1518
"fused_dequant_matmul",
19+
"int4_matmul",
1620
"is_triton_available",
21+
"triton_q4_0_quantize",
22+
"triton_q8_0_quantize",
1723
]

quantllm/kernels/triton/quantized_linear.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
Performance: ~2-3x faster than separate dequant + matmul
99
"""
1010

11-
from typing import Optional, Tuple
11+
from typing import Callable, Dict, Optional, Tuple
1212
import torch
1313
import torch.nn as nn
1414

@@ -27,6 +27,69 @@ def is_triton_available() -> bool:
2727
return _TRITON_AVAILABLE
2828

2929

30+
def triton_q8_0_quantize(weight: torch.Tensor, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a weight matrix to Q8_0 format (per-column symmetric int8).

    Every column gets its own scale, chosen so that the column's largest
    magnitude maps to 127. The arithmetic is carried out in at least float32
    and only the returned scale is cast back to the input dtype.

    Args:
        weight: 2D weight tensor. The reduction runs over dim 0, so one scale
            is produced per column.
        eps: Lower bound applied to each column's max magnitude so that an
            all-zero column never produces a zero divisor.

    Returns:
        qweight: int8 tensor [in_features, out_features]
        scales: fp tensor [1, out_features]

    Raises:
        ValueError: If ``weight`` is not 2D.
    """
    if weight.dim() != 2:
        raise ValueError(f"Q8_0 quantization expects a 2D tensor, got shape={tuple(weight.shape)}")

    # The default eps (1e-8) is smaller than the smallest float16 subnormal, so
    # clamping in half precision would round it to 0 and an all-zero column
    # would compute 0 / 0 = NaN. Promote to >= float32 for the math instead.
    compute_dtype = torch.promote_types(weight.dtype, torch.float32)
    values = weight.to(compute_dtype)

    max_abs = values.abs().amax(dim=0, keepdim=True).clamp(min=eps)
    scale = max_abs / 127.0
    qweight = torch.clamp(torch.round(values / scale), -128, 127).to(torch.int8)
    return qweight, scale.to(weight.dtype)
45+
46+
47+
def triton_q4_0_quantize(weight: torch.Tensor, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a weight matrix to Q4_0 format (per-column symmetric 4-bit stored in int8).

    Every column gets its own scale, chosen so that the column's largest
    magnitude maps to 7. The 4-bit codes are not packed: each one occupies a
    full int8 element. The arithmetic is carried out in at least float32 and
    only the returned scale is cast back to the input dtype.

    Args:
        weight: 2D weight tensor. The reduction runs over dim 0, so one scale
            is produced per column.
        eps: Lower bound applied to each column's max magnitude so that an
            all-zero column never produces a zero divisor.

    Returns:
        qweight: int8 tensor [in_features, out_features] with values in [-8, 7]
        scales: fp tensor [1, out_features]

    Raises:
        ValueError: If ``weight`` is not 2D.
    """
    if weight.dim() != 2:
        raise ValueError(f"Q4_0 quantization expects a 2D tensor, got shape={tuple(weight.shape)}")

    # The default eps (1e-8) is smaller than the smallest float16 subnormal, so
    # clamping in half precision would round it to 0 and an all-zero column
    # would compute 0 / 0 = NaN. Promote to >= float32 for the math instead.
    compute_dtype = torch.promote_types(weight.dtype, torch.float32)
    values = weight.to(compute_dtype)

    max_abs = values.abs().amax(dim=0, keepdim=True).clamp(min=eps)
    scale = max_abs / 7.0
    qweight = torch.clamp(torch.round(values / scale), -8, 7).to(torch.int8)
    return qweight, scale.to(weight.dtype)
62+
63+
64+
def int4_matmul(
    x: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    INT4 matmul that delegates to ``fused_dequant_matmul``.

    The quantization is treated as symmetric: an all-zero ``zeros`` tensor is
    built here, and the group size is derived from the shapes of ``qweight``
    and ``scales`` before everything is forwarded to the fused routine.

    Args:
        x: Input [..., in_features]
        qweight: Quantized int4 values stored in int8, shape [in_features, out_features]
        scales: Per-column scales, shape [1, out_features] or [in_features/group, out_features]
        bias: Optional bias [out_features]

    Returns:
        Whatever ``fused_dequant_matmul`` returns for these arguments.
    """
    scale_rows = scales.shape[0]

    if scale_rows == 1:
        # A single row of scales: one [1, N] row of zero points, and the
        # whole input dimension forms one group.
        zero_points = scales.new_zeros((1, scales.shape[1]))
        rows_per_group = qweight.shape[0]
    else:
        # Grouped scales: zero points mirror the scales' shape; the group size
        # is floored and kept at a minimum of 1.
        zero_points = scales.new_zeros(scales.shape)
        rows_per_group = max(qweight.shape[0] // scale_rows, 1)

    return fused_dequant_matmul(
        x=x,
        qweight=qweight,
        scales=scales,
        zeros=zero_points,
        bias=bias,
        group_size=rows_per_group,
    )
91+
92+
3093
if _TRITON_AVAILABLE:
3194
@triton.jit
3295
def _fused_dequant_matmul_kernel(
@@ -462,3 +525,9 @@ def extra_repr(self) -> str:
462525
f'group_size={self.group_size}, '
463526
f'triton={self._use_triton}'
464527
)
528+
529+
530+
# Lookup table from a lowercase quantization-format name to the function that
# quantizes a weight tensor into a (qweight, scales) pair.
_QuantizeFn = Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]

triton_quantizers: Dict[str, _QuantizeFn] = {
    "q4_0": triton_q4_0_quantize,
    "q8_0": triton_q8_0_quantize,
}

0 commit comments

Comments
 (0)