diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index ed59ed2f2..ddcff8c8c 100644 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -1,6 +1,6 @@ from collections.abc import Sequence import ctypes as ct -import warnings +import logging import torch @@ -10,6 +10,8 @@ from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib from ..utils import triton_available +logger = logging.getLogger(__name__) + def _dequantize_4bit_impl( A: torch.Tensor, @@ -135,6 +137,7 @@ def _gemv_4bit_impl( # SYCL should be faster for xpu, so at first checking if it is available. if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): + logger.info("Loading sycl bitsandbytes kernels for XPU") @register_kernel("bitsandbytes::dequantize_4bit", "xpu") def _( @@ -201,6 +204,7 @@ def _( torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) elif triton_available: + logger.info("Loading triton bitsandbytes kernels for XPU") from ..triton import ops as triton_ops register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) @@ -211,6 +215,4 @@ def _( register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) else: - warnings.warn( - "XPU available but no native library or triton packages found. Please follow the installation instructions in the documentation." - ) + logger.warning("Loading pytorch bitsandbytes kernels for XPU because no native library or triton packages found.") diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 29101c76c..c7e407efd 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -303,19 +303,27 @@ def get_native_library() -> BNBNativeLibrary: ROCM_GPU_ARCH = get_rocm_gpu_arch() -try: - if torch.version.hip: - HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" - else: - HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" +HIP_ENVIRONMENT = False +BNB_BACKEND = "CPU" +if torch.version.hip: + HIP_ENVIRONMENT = True + BNB_BACKEND = "ROCm" +elif torch.cuda.is_available(): + BNB_BACKEND = "CUDA" +elif torch._C._has_xpu: + BNB_BACKEND = "XPU" +try: lib = get_native_library() except Exception as e: - error_msg = str(e) - logger.error( - f"bitsandbytes library load error: {error_msg}", - exc_info=True, - ) + if BNB_BACKEND in ("CPU", "XPU"): + lib = ErrorHandlerMockBNBNativeLibrary("XPU/CPU can run without native library.") + else: + error_msg = str(e) + logger.error( + f"bitsandbytes library load error: {error_msg}", + exc_info=True, + ) - # create a mock with error messaging as fallback - lib = ErrorHandlerMockBNBNativeLibrary(error_msg) + # create a mock with error messaging as fallback + lib = ErrorHandlerMockBNBNativeLibrary(error_msg)