Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions bitsandbytes/backends/xpu/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence
import ctypes as ct
import warnings
import logging

import torch

Expand All @@ -10,6 +10,8 @@
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
from ..utils import triton_available

logger = logging.getLogger(__name__)


def _dequantize_4bit_impl(
A: torch.Tensor,
Expand Down Expand Up @@ -135,6 +137,7 @@ def _gemv_4bit_impl(

# SYCL should be faster on XPU, so first check whether it is available.
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
logger.info("Loading sycl bitsandbytes kernels for XPU")

@register_kernel("bitsandbytes::dequantize_4bit", "xpu")
def _(
Expand Down Expand Up @@ -201,6 +204,7 @@ def _(
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
elif triton_available:
logger.info("Loading triton bitsandbytes kernels for XPU")
from ..triton import ops as triton_ops

register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
Expand All @@ -211,6 +215,4 @@ def _(
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
warnings.warn(
"XPU available but no native library or triton packages found. Please follow the installation instructions in the documentation."
)
logger.warning("Loading pytorch bitsandbytes kernels for XPU because no native library or triton packages found.")
32 changes: 20 additions & 12 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,19 +303,27 @@ def get_native_library() -> BNBNativeLibrary:

ROCM_GPU_ARCH = get_rocm_gpu_arch()

try:
if torch.version.hip:
HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"
else:
HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA"
HIP_ENVIRONMENT = False
BNB_BACKEND = "CPU"
if torch.version.hip:
HIP_ENVIRONMENT = True
BNB_BACKEND = "ROCm"
elif torch.cuda.is_available():
BNB_BACKEND = "CUDA"
elif torch._C._has_xpu:
BNB_BACKEND = "XPU"

try:
lib = get_native_library()
except Exception as e:
error_msg = str(e)
logger.error(
f"bitsandbytes library load error: {error_msg}",
exc_info=True,
)
if BNB_BACKEND in ("CPU", "XPU"):
lib = ErrorHandlerMockBNBNativeLibrary("XPU/CPU can run without native library.")
else:
error_msg = str(e)
logger.error(
f"bitsandbytes library load error: {error_msg}",
exc_info=True,
)

# create a mock with error messaging as fallback
lib = ErrorHandlerMockBNBNativeLibrary(error_msg)
# create a mock with error messaging as fallback
lib = ErrorHandlerMockBNBNativeLibrary(error_msg)
Loading