Skip to content

Commit 3bff01d

Browse files
Fix: Python 3.14 compatibility with PyTorch 2.9 (#1831)
* Fix: Python 3.14 / torch.compile compatibility
* Skip torch.compile test on Python 3.14 and torch < 2.10 (not supported)
* Format
1 parent c664054 commit 3bff01d

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

bitsandbytes/backends/default/ops.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Sequence
2+
from functools import wraps
23
from math import prod, sqrt
34
from typing import Optional
45

@@ -8,6 +9,32 @@
89
from ..utils import CODE
910

1011

12+
def _try_torch_compile(func=None, **compile_kwargs):
13+
"""
14+
Wrapper around torch.compile that falls back to the original function if compilation fails.
15+
"""
16+
17+
def decorator(fn):
18+
try:
19+
compiled_fn = torch.compile(fn, **compile_kwargs)
20+
21+
@wraps(fn)
22+
def wrapper(*args, **kwargs):
23+
try:
24+
return compiled_fn(*args, **kwargs)
25+
except Exception:
26+
return fn(*args, **kwargs)
27+
28+
return wrapper
29+
except Exception:
30+
return fn
31+
32+
if func is None:
33+
return decorator
34+
else:
35+
return decorator(func)
36+
37+
1138
@register_kernel("bitsandbytes::int8_mm_dequant", "default")
1239
def _(
1340
A: torch.Tensor,
@@ -332,7 +359,7 @@ def _(
332359
}
333360

334361

335-
@torch.compile
362+
@_try_torch_compile
336363
def _optimizer_precondition_32bit(
337364
g: torch.Tensor,
338365
p: torch.Tensor,
@@ -393,7 +420,7 @@ def _optimizer_precondition_32bit(
393420
unorm_vec.add_(total_norm)
394421

395422

396-
@torch.compile
423+
@_try_torch_compile
397424
def _optimizer_update_32bit(
398425
g: torch.Tensor,
399426
p: torch.Tensor,

tests/test_linear4bit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import pickle
44
import platform
5+
import sys
56
from tempfile import TemporaryDirectory
67

78
import pytest
@@ -320,6 +321,9 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
320321
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
321322
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
322323
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
324+
@pytest.mark.skipif(
325+
torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
326+
)
323327
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
324328
if device == "hpu" and not is_supported_on_hpu(quant_type):
325329
pytest.skip("This configuration is not supported on HPU.")

tests/test_linear8bitlt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import pickle
55
import platform
6+
import sys
67
from tempfile import TemporaryDirectory
78

89
import pytest
@@ -234,6 +235,9 @@ def test_linear8bit_serialization(linear8bit):
234235
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
235236
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
236237
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
238+
@pytest.mark.skipif(
239+
torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
240+
)
237241
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
238242
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
239243
if device == "cuda" and platform.system() == "Windows":

0 commit comments

Comments
 (0)