Skip to content

Commit 1088ec5

Browse files
Updates for device agnosticism (#1601)
* Include device support tags for transformers multi-backend compatability; add xpu() and cpu() to Params4bit * Make test suite more device-agnostic * Additional device agnostic tests * Additional device agnosticism for tests * Add BNB_TEST_DEVICE env var to manually select device for unit tests * Include device support tags for transformers multi-backend compatability; add xpu() and cpu() to Params4bit * Make test suite more device-agnostic * Additional device agnostic tests * Additional device agnosticism for tests * Add BNB_TEST_DEVICE env var to manually select device for unit tests * Small bugfix for int8 test * Exclude backward() from code coverage reports * Params4bit: don't try to quantize when moving to meta device
1 parent 97073cd commit 1088ec5

File tree

13 files changed

+400
-444
lines changed

13 files changed

+400
-444
lines changed

bitsandbytes/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020
from .optim import adam
2121

2222
# This is a signal for integrations with transformers/diffusers.
23-
# Eventually, we will remove this and check based on release version.
23+
# Eventually we may remove this but it is currently required for compatibility.
2424
features = {"multi-backend"}
2525
supported_torch_devices = {
26-
"cuda",
2726
"cpu",
28-
# "mps",
29-
# "xpu",
30-
# "hpu",
31-
# "npu",
27+
"cuda", # NVIDIA/AMD GPU
28+
"xpu", # Intel GPU
29+
"hpu", # Gaudi
30+
"npu", # Ascend NPU
31+
"mps", # Apple Silicon
3232
}
3333

3434
if torch.cuda.is_available():

bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
284284
dtype=torch.float16,
285285
)
286286

287-
if state.threshold > 0.0 and subA is not None:
287+
if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
288288
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
289289

290290
if req_gradA:

bitsandbytes/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
341341
for i in range(gap):
342342
values.append(0)
343343
values.sort()
344-
code = torch.Tensor(values)
344+
code = torch.tensor(values)
345345
code /= code.max()
346346

347347
return code

bitsandbytes/nn/modules.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,15 @@ def _quantize(self, device):
306306
self.bnb_quantized = True
307307
return self
308308

309+
def cpu(self):
310+
return self.to(device="cpu")
311+
309312
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
310313
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
311314

315+
def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
316+
return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
317+
312318
@overload
313319
def to(
314320
self: T,
@@ -326,7 +332,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
326332
def to(self, *args, **kwargs):
327333
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
328334

329-
if device is not None and device.type == "cuda" and not self.bnb_quantized:
335+
if device is not None and device.type != "meta" and not self.bnb_quantized:
330336
return self._quantize(device)
331337
else:
332338
if self.quant_state is not None:

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ include = ["bitsandbytes*"]
7979
[tool.setuptools.dynamic]
8080
version = {attr = "bitsandbytes.__version__"}
8181

82+
[tool.coverage.report]
83+
exclude_also = [
84+
# exclude backward() functions from coverage, as they are invoked from C++
85+
'def backward\(ctx'
86+
]
87+
8288
[tool.pytest.ini_options]
8389
addopts = "-rP -m 'not slow and not benchmark and not deprecated'"
8490
# ; --cov=bitsandbytes

tests/helpers.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import functools
12
from io import BytesIO
23
from itertools import product
4+
import os
35
import random
46
from typing import Any
57

@@ -13,6 +15,38 @@
1315
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool)
1416

1517

18+
@functools.cache
19+
def get_available_devices():
20+
if "BNB_TEST_DEVICE" in os.environ:
21+
# If the environment variable is set, use it directly.
22+
return [os.environ["BNB_TEST_DEVICE"]]
23+
24+
devices = ["cpu"]
25+
26+
if hasattr(torch, "accelerator"):
27+
# PyTorch 2.6+ - determine accelerator using agnostic API.
28+
if torch.accelerator.is_available():
29+
devices += [str(torch.accelerator.current_accelerator())]
30+
else:
31+
if torch.cuda.is_available():
32+
devices += ["cuda"]
33+
34+
if torch.backends.mps.is_available():
35+
devices += ["mps"]
36+
37+
if hasattr(torch, "xpu") and torch.xpu.is_available():
38+
devices += ["xpu"]
39+
40+
custom_backend_name = torch._C._get_privateuse1_backend_name()
41+
custom_backend_module = getattr(torch, custom_backend_name, None)
42+
custom_backend_is_available_fn = getattr(custom_backend_module, "is_available", None)
43+
44+
if custom_backend_is_available_fn and custom_backend_module.is_available():
45+
devices += [custom_backend_name]
46+
47+
return devices
48+
49+
1650
def torch_save_to_buffer(obj):
1751
buffer = BytesIO()
1852
torch.save(obj, buffer)

tests/test_autograd.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
BOOLEAN_TRIPLES,
77
TRUE_FALSE,
88
describe_dtype,
9+
get_available_devices,
910
id_formatter,
1011
)
1112

1213
TRANSPOSE_VALS = [(False, True), (False, False)]
1314

1415

16+
@pytest.mark.parametrize("device", get_available_devices())
1517
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
1618
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
1719
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
@@ -27,32 +29,38 @@
2729
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
2830
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
2931
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
30-
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias):
32+
def test_matmullt(
33+
device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
34+
):
35+
if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
36+
# TODO: Deprecate/remove?
37+
pytest.skip("switchback_bnb only works on CUDA.")
38+
3139
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
3240
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
33-
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
41+
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)
3442
if has_bias == False:
3543
req_grad = list(req_grad)
3644
req_grad[2] = False
3745

3846
for i in range(3):
3947
# normal multiply
4048
if funcs[0] in [torch.mm, torch.matmul]:
41-
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
49+
A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
4250
if decomp == 6.0:
4351
with torch.no_grad():
4452
A[:, outlier_dim] = 6.0
45-
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
53+
B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
4654
target = torch.randn(
4755
size=(dim2, dim4),
48-
device="cuda",
56+
device=device,
4957
requires_grad=req_grad[1],
5058
dtype=dtype,
5159
)
5260
bias = None
5361
bias2 = None
5462
if has_bias:
55-
bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
63+
bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
5664
bias2 = bias.clone()
5765
torch.nn.init.xavier_uniform_(B)
5866
B2 = B.clone()
@@ -91,7 +99,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
9199
if has_fp16_weights:
92100
if any(req_grad):
93101
out_bnb.data.copy_(out_torch)
94-
torch.cuda.synchronize()
102+
if device == "cuda":
103+
torch.cuda.synchronize()
95104
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
96105
loss_bnb.backward()
97106
gradA1 = A.grad
@@ -135,6 +144,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
135144
torch.testing.assert_close(gradBias1, gradBias2)
136145

137146

147+
@pytest.mark.parametrize("device", get_available_devices())
138148
@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1"))
139149
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
140150
@pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3"))
@@ -147,6 +157,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
147157
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
148158
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type"))
149159
def test_matmul_4bit(
160+
device,
150161
dim1,
151162
dim2,
152163
dim3,
@@ -159,6 +170,9 @@ def test_matmul_4bit(
159170
compress_statistics,
160171
quant_type,
161172
):
173+
if device == "cpu" and quant_type == "fp4":
174+
pytest.skip("Only nf4 is supported on CPU")
175+
162176
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
163177
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
164178
if has_bias == False:
@@ -168,13 +182,13 @@ def test_matmul_4bit(
168182
for i in range(3):
169183
# normal multiply
170184
if funcs[0] in [torch.mm, torch.matmul]:
171-
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
172-
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
173-
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
185+
A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype)
186+
B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype)
187+
target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype)
174188
bias = None
175189
bias2 = None
176190
if has_bias:
177-
bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2])
191+
bias = torch.randn(dim4, device=device, dtype=dtype, requires_grad=req_grad[2])
178192
bias2 = bias.clone()
179193
torch.nn.init.xavier_uniform_(B)
180194

@@ -204,7 +218,8 @@ def test_matmul_4bit(
204218
# assert err < 0.20
205219
if any(req_grad):
206220
out_bnb.data.copy_(out_torch)
207-
torch.cuda.synchronize()
221+
if device == "cuda":
222+
torch.cuda.synchronize()
208223
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
209224
loss_bnb.backward()
210225
gradA1 = A.grad

0 commit comments

Comments
 (0)