Skip to content

Commit 404c40f

Browse files
authored
Simplest MPS backend: adding kernels using kernel builder and kernels-community (#1875)
* initial commit * fix * fix * fix linter * fix * add kernels-community kernel
1 parent eb3064f commit 404c40f

File tree

4 files changed

+150
-2
lines changed

4 files changed

+150
-2
lines changed

bitsandbytes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
if hasattr(torch, "xpu") and torch.xpu.is_available():
3939
from .backends.xpu import ops as xpu_ops
4040

41+
# Register the MPS (Apple Metal) backend ops when an MPS device is available.
# The import has the side effect of registering the kernels defined in
# backends/mps/ops.py; the bound name itself is unused afterwards.
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    from .backends.mps import ops as mps_ops
43+
4144
if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
4245
# In case not automatically imported
4346
import habana_frameworks.torch

bitsandbytes/backends/mps/__init__.py

Whitespace-only changes.

bitsandbytes/backends/mps/ops.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""MPS backend for bitsandbytes 4-bit quantization ops.
2+
3+
Uses Metal kernels from kernels-community/bitsandbytes-mps via the
4+
HuggingFace Kernels Hub.
5+
"""
6+
7+
from collections.abc import Sequence
8+
from math import prod
9+
10+
import torch
11+
12+
from ..._ops import register_kernel
13+
14+
# ---------------------------------------------------------------------------
# Quant-type mapping: BnB uses strings, our Metal kernel uses ints.
# ---------------------------------------------------------------------------
_QUANT_MAP: dict[str, int] = {"fp4": 1, "nf4": 2}
# Lazily-populated cache for the Hub kernel module; filled by _get_kernel().
_kernel = None
21+
def _get_kernel():
    """Return the bitsandbytes-mps kernel module, loading it on first use.

    The kernel is fetched from the HuggingFace Kernels Hub
    (``kernels-community/bitsandbytes-mps``) the first time any MPS op
    runs, then cached in the module-level ``_kernel`` global so the Hub
    lookup happens at most once per process.
    """
    global _kernel
    if _kernel is not None:
        return _kernel

    # Imported lazily so merely importing this module does not require
    # the `kernels` package unless an MPS op is actually executed.
    from kernels import get_kernel

    _kernel = get_kernel("kernels-community/bitsandbytes-mps")
    return _kernel
31+
32+
# ============================= quantize_4bit =================================
33+
34+
35+
@register_kernel("bitsandbytes::quantize_4bit", "mps")
def _(
    A: torch.Tensor,
    blocksize: int,
    quant_type: str,
    quant_storage: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Blockwise 4-bit quantization of ``A`` via the Metal kernel.

    Returns the packed quantized tensor (viewed as ``quant_storage`` with a
    trailing unit dimension) and the per-block absmax scales.
    """
    torch._check(blocksize in [64, 128, 256, 512])
    torch._check(quant_type in ("fp4", "nf4"))

    kernel = _get_kernel()
    quantized, absmax = kernel.quantize_4bit(A.contiguous(), blocksize, _QUANT_MAP[quant_type])

    # Reinterpret the packed bytes as the requested storage dtype and add a
    # trailing unit dim, giving shape (n, 1).
    return quantized.view(quant_storage).unsqueeze(1), absmax
52+
53+
# ============================ dequantize_4bit ================================
54+
55+
56+
def _dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize packed 4-bit data back to ``dtype`` with shape ``shape``."""
    # The Metal kernel consumes raw bytes; reinterpret non-uint8 storage.
    packed = A if A.dtype == torch.uint8 else A.view(torch.uint8)

    kernel = _get_kernel()
    values = kernel.dequantize_4bit(packed, absmax, blocksize, _QUANT_MAP[quant_type], prod(shape), dtype)
    return values.reshape(shape)
72+
73+
@register_kernel("bitsandbytes::dequantize_4bit", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Validate arguments, then dequantize via the shared implementation."""
    torch._check(blocksize in [64, 128, 256, 512])
    torch._check(quant_type in ("fp4", "nf4"))
    return _dequantize_4bit_impl(
        A=A,
        absmax=absmax,
        blocksize=blocksize,
        quant_type=quant_type,
        shape=shape,
        dtype=dtype,
    )
86+
87+
@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Out-variant of dequantize_4bit: writes the result into ``out``.

    Fix: the functional variant validates ``blocksize`` and ``quant_type``
    but this variant did not; the same checks are applied here for
    consistency. Also validates ``out`` up front, because ``Tensor.copy_``
    would otherwise silently cast/broadcast a mismatched output tensor.
    """
    torch._check(blocksize in [64, 128, 256, 512])
    torch._check(quant_type in ("fp4", "nf4"))
    torch._check(out.dtype == dtype)
    torch._check(tuple(out.shape) == tuple(shape))

    result = _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
    out.copy_(result)
100+
101+
# ================================ gemv_4bit ==================================
102+
103+
104+
def _gemv_4bit_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    """Matrix-vector product of ``A`` with 4-bit-quantized ``B``.

    Args:
        A: Input activation tensor.
        B: Packed 4-bit weight tensor; reinterpreted as uint8 bytes if
            stored under a different dtype.
        shapeB: Original (unquantized) shape of ``B``; ``shapeB[0]`` is
            taken as the number of output features.
        absmax: Per-block absmax scales for ``B``.
        code: Dequantization codebook tensor for ``B``'s quant type.
        blocksize: Quantization block size.

    Returns:
        The gemv result tensor produced by the Metal kernel.
    """
    # The Metal kernel consumes raw bytes regardless of the storage dtype.
    if B.dtype != torch.uint8:
        B = B.view(torch.uint8)

    # Heuristic: infer the quant type from the codebook's second entry.
    # Presumably the fp4 codebook has code[1] > 0 while nf4's is negative —
    # NOTE(review): confirm against the codebooks produced by quantize_4bit.
    quant_type_int = _QUANT_MAP["fp4"] if code[1] > 0 else _QUANT_MAP["nf4"]
    output_features = shapeB[0]

    k = _get_kernel()
    return k.gemv_4bit(A, B, absmax, output_features, blocksize, quant_type_int)
121+
122+
@register_kernel("bitsandbytes::gemv_4bit", "mps")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    """Thin registration wrapper delegating to ``_gemv_4bit_impl``."""
    return _gemv_4bit_impl(
        A=A,
        B=B,
        shapeB=shapeB,
        absmax=absmax,
        code=code,
        blocksize=blocksize,
    )
133+
134+
@register_kernel("bitsandbytes::gemv_4bit.out", "mps")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    out: torch.Tensor,
) -> None:
    """Out-variant of gemv_4bit: copies the product into ``out``."""
    out.copy_(_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize))

tests/test_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
219219
out_features = 1024
220220
in_features = 256
221221

222-
if device == "cpu" and blocksize > in_features:
223-
pytest.skip("CPU implementation only suppoer blocksize <= in_features")
222+
if device in ("cpu", "mps") and blocksize > in_features:
223+
pytest.skip("CPU/MPS implementation only supports blocksize <= in_features")
224224

225225
A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
226226
B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)

0 commit comments

Comments
 (0)