diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 847c7ef7a..7597ed9f9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -162,7 +162,7 @@ jobs: - name: Run tests run: pytest --durations=100 - test-cpu-ipex: + test-cpu-intel: if: github.repository == 'bitsandbytes-foundation/bitsandbytes' needs: build-cpu runs-on: banb-aws-general-8-plus-use1-public-80 @@ -186,7 +186,6 @@ jobs: - name: Install dependencies run: | pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu - pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ pip install -e ".[test]" pip install pytest-cov @@ -196,9 +195,6 @@ jobs: - name: Show environment information run: python -m torch.utils.collect_env - - name: IPEX smoke test - run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);" - - name: Run tests run: pytest --durations=100 @@ -286,15 +282,6 @@ jobs: fail-fast: false matrix: torch_version: ["2.7.1"] #["2.6.0", "2.7.1"] - ipex: [false] - # ipex: [true, false] - # include: - # - torch_version: "2.6.0" - # ipex: true - # ipex_version: "2.6.10+xpu" - # - torch_version: "2.7.1" - # ipex: true - # ipex_version: "2.7.10+xpu" runs-on: group: bandb-itac-bmsprpvc1550-8-1gpu env: @@ -330,10 +317,6 @@ jobs: - name: Install PyTorch run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu - - name: Install IPEX - if: matrix.ipex == true - run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ - - name: Install dependencies run: | pip install -e ".[test]" diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index 56bfaa357..9a3ac46ac 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -4,8 +4,6 @@ import torch -from .utils import 
ipex_cpu - _IS_TORCH_GTE_24 = False if hasattr(torch.library, "register_fake"): @@ -329,22 +327,3 @@ def _( ) torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - - -if ipex_cpu: - # Register the dequantize_nf4_ipex implementation - torch.library.define( - "bitsandbytes::dequantize_nf4_ipex", - "(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor", - ) - - @register_fake("bitsandbytes::dequantize_nf4_ipex") - def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - shape: Sequence[int], - dtype: torch.dtype, - ) -> torch.Tensor: - torch._check_is_size(blocksize) - return torch.empty(shape, dtype=dtype, device=A.device) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index c28b301b9..cb761fe24 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -422,9 +422,9 @@ def matmul( if threshold > 0.0: state.threshold = threshold # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU - if state.is_training: - if A.device.type in ("cpu", "xpu"): - return MatMul8bitFp.apply(A, B, out, bias, state) + if state.is_training and A.device.type in ("cpu", "xpu"): + return MatMul8bitFp.apply(A, B, out, bias, state) + return MatMul8bitLt.apply(A, B, out, bias, state) @@ -437,16 +437,6 @@ def matmul_4bit( ): assert quant_state is not None - if A.device.type == "cpu" and A.requires_grad == False: - if getattr(quant_state, "ipex", False): - # IPEX CPU will change weight to 4D so don't need transpose - B = B.t() if B.dim() == 2 else B - out = F.gemv_4bit(A, B, out, state=quant_state) - if bias is not None: - out += bias - return out - else: - return MatMul4Bit.apply(A, B, out, bias, quant_state) if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % 
quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index b715b1d00..78f9fef47 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -1,13 +1,14 @@ -from collections.abc import Sequence import ctypes as ct +import logging import torch from bitsandbytes.functional import get_ptr from ..._ops import register_kernel -from ...cextension import lib -from ...utils import ipex_cpu +from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib + +logger = logging.getLogger(__name__) # torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+. # However, we can overflow if we use this without AVX512_VNNI support. @@ -24,97 +25,80 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) -@register_kernel("bitsandbytes::quantize_blockwise", "cpu") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - n = A.numel() - - # Only FP32 has c++ kernrl - if A.dtype == torch.float32: - blocks = -(n // -blocksize) - - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(n), - ) - else: - rem = n % blocksize - has_rem = rem > 0 - blocks = n // blocksize + has_rem - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - A_reshaped = A.reshape(n) - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 
/ absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - - diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) - out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cpu") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - torch._check_is_size(blocksize) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - - # Only FP32 has c++ kernrl - if dtype == torch.float32: - out = torch.empty_like(A, dtype=dtype) - - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), - ) - else: - out = code[A.reshape(-1).int()] - blocks = out.shape[-1] // blocksize - res = out.shape[-1] % blocksize - if res != 0: - out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0) - out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) - out = out[: blocks * blocksize + res] - out = out.reshape(A.shape) - - return out - - -if ipex_cpu: - from bitsandbytes.utils import _reverse_4bit_compress_format - - @register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu") +if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): + logger.info("Loading C++ bitsandbytes kernels for CPU") + + @register_kernel("bitsandbytes::quantize_blockwise", "cpu") + def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + n = A.numel() + + # Only FP32 has a C++ kernel + if A.dtype == torch.float32: + blocks = -(n // -blocksize) + + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out),
+ ct.c_longlong(blocksize), + ct.c_longlong(n), + ) + else: + rem = n % blocksize + has_rem = rem > 0 + blocks = n // blocksize + has_rem + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) + out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) + + return out, absmax + + @register_kernel("bitsandbytes::dequantize_blockwise", "cpu") def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - shape: Sequence[int], - dtype: torch.dtype, + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype ) -> torch.Tensor: - ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2) - A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1) - return torch.ops.bitsandbytes.dequantize_4bit.default( - A, - absmax, - blocksize, - "nf4", - shape, - dtype, - ) + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + + # Only FP32 has a C++ kernel + if dtype == torch.float32: + out = torch.empty_like(A, dtype=dtype) + + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + else: + out = code[A.reshape(-1).int()] + blocks = out.shape[-1] // blocksize + res = out.shape[-1] % blocksize + if res != 0:
+ out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0) + out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) + out = out[: blocks * blocksize + res] + out = out.reshape(A.shape) + + return out +else: + logger.warning("Loading pytorch bitsandbytes kernels for CPU because no native library found.") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 372632d17..5cd9eac67 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -13,7 +13,7 @@ from torch import Tensor from typing_extensions import deprecated -from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict +from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict from .cextension import HIP_ENVIRONMENT, lib @@ -1055,16 +1055,6 @@ def dequantize_4bit( if absmax.dtype != torch.float32: absmax = absmax.float() - # IPEX format is different, we need extra process. - if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4": - return torch.ops.bitsandbytes.dequantize_nf4_ipex( - A, - absmax, - quant_state.blocksize, - quant_state.shape, - quant_state.dtype, - ) - if out is not None: torch.ops.bitsandbytes.dequantize_4bit.out( A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out @@ -1633,25 +1623,6 @@ def gemv_4bit( if state.nested: absmax = dequantize_blockwise(absmax, state.state2) + state.offset - if getattr(state, "ipex", False) and state.quant_type == "nf4": - # compute_dtype: 1 indicates fp16, 2 indicates bf16 - compute_dtype = 2 if A.dtype == torch.bfloat16 else 1 - out = torch.ops.torch_ipex.woq_linear( - A, - B, - "nf4", - state.shape, - state.new_scales, - state.new_zeros, - None, - None, - state.blocksize, - compute_dtype, - 1, - state.compensation, - ) - return out - if out is not None: torch.ops.bitsandbytes.gemv_4bit.out( A, @@ -2338,37 +2309,3 @@ def spmm_coo_very_sparse(cooA, B, 
dequant_stats=None, out=None): C = 127.0 - - -def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor): - quant_state = linear.weight.quant_state - - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - quant_state.absmax = absmax - quant_state.nested = False - delattr(quant_state, "state2") - - assert x.device.type == "cpu" - converted_weight = _reverse_4bit_compress_format(linear.weight.data) - new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight( - converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), - "nf4", - quant_state.shape, # weight shape - quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales - None, # zero_points - None, # bias - None, # batch_size - quant_state.blocksize, - 2, - ) - - linear.weight.data = new_weight.data - linear.weight.quant_state.ipex = True - linear.weight.quant_state.new_scales = new_scales - linear.weight.quant_state.new_zeros = new_zeros - linear.weight.quant_state.compensation = compensation diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 9015665ee..464205fa5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,14 +12,9 @@ import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import QuantState, _enable_ipex_fusion +from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import ( - INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, - OutlierTracer, - _reverse_4bit_compress_format, - ipex_cpu, -) +from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer T = TypeVar("T", bound="torch.nn.Module") @@ -444,7 +439,6 @@ def __init__( self.compute_type_is_set = False if 
compute_dtype is None else True self.quant_state = None self.quant_storage = quant_storage - self.ipex_linear_is_set = False def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -471,40 +465,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): save weight and bias, then fill state_dict with components of quant_state """ - if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False): - if self.weight.device.type == "cpu": - original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight( - self.weight, "nf4", self.weight.quant_state.shape, 2 - ) - self.weight.data = _reverse_4bit_compress_format(original_weight.data) - - self.weight.quant_state.ipex = False - self.ipex_linear_is_set = False - super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias if getattr(self.weight, "quant_state", None) is not None: for k, v in self.weight.quant_state.as_dict(packed=True).items(): destination[prefix + "weight." 
+ k] = v if keep_vars else v.detach() - def set_ipex_linear(self, x: torch.Tensor): - if ( - not getattr(self.weight.quant_state, "ipex", False) - and self.weight.data.dtype == torch.uint8 - and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 - and self.weight.quant_state.quant_type == "nf4" - and x.device.type == "cpu" - and not self.training - and not x.requires_grad - ): - _enable_ipex_fusion(self, x) - def forward(self, x: torch.Tensor): - # Check if ipex fusion can be used - if not self.ipex_linear_is_set and ipex_cpu: - self.set_ipex_linear(x) - self.ipex_linear_is_set = True - fix_4bit_weight_quant_state_from_module(self) # weights are cast automatically as Int8Params, but the bias has to be cast manually @@ -520,8 +487,7 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - # IPEX CPU will change weight to 4D so don't need transpose - weight = self.weight.t() if self.weight.dim() == 2 else self.weight + weight = self.weight.t() return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) @@ -676,7 +642,7 @@ def to(self, *args, **kwargs): if device is not None and device.type != "meta" and self.data.device.type == "cpu": if device.type != "cpu" or self.data.dtype != torch.int8: return self._quantize(device) - elif self.data.dtype == torch.int8 and device.type == "cpu" and ipex_cpu: + elif self.data.dtype == torch.int8 and device.type == "cpu": self.CB = self.data new_param = Int8Params( diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 4328a241c..0828dd295 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -4,14 +4,6 @@ import torch -try: - # to support Intel CPU backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None -except BaseException: - ipex_cpu = None - def outlier_hook(module, input): assert isinstance(module, torch.nn.Linear) 
@@ -46,14 +38,6 @@ def outlier_hook(module, input): hook.remove() -# convert btw standard 4-bit compression format and ipex compression format -def _reverse_4bit_compress_format(weight: torch.Tensor): - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out - - class OutlierTracer: _instance = None diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 9b3449870..7396c7dcf 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -138,8 +138,8 @@ We provide an early preview of support for AMD and Intel hardware as part of a d | **Backend** | **Supported Versions** | **Python versions** | **Architecture Support** | **Status** | |-------------|------------------------|---------------------------|-------------------------|------------| | **AMD ROCm** | 6.1+ | 3.10+ | minimum CDNA - `gfx90a`, RDNA - `gfx1100` | Alpha | -| **Intel CPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha | -| **Intel GPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental | +| **Intel CPU** | v2.4.0+ | 3.10+ | Intel CPU | Alpha | +| **Intel GPU** | v2.7.0+ | 3.10+ | Intel GPU | Experimental | | **Ascend NPU** | 2.1.0+ (`torch_npu`) | 3.10+ | Ascend NPU | Experimental | For each supported backend, follow the respective instructions below: @@ -179,7 +179,6 @@ pip install torch --index-url https://download.pytorch.org/whl/rocm6.3/ * A compatible PyTorch version with Intel XPU support is required. It is recommended to use the latest stable release. See [Getting Started on Intel GPU](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) for guidance. -* The [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/) is recommended for performance improvements. @@ -235,9 +234,9 @@ pip install -e . 
# `-e` for "editable" install, when developing BNB (otherwise -#### Intel CPU + XPU +#### Intel CPU + GPU(XPU) -CPU needs to build CPU C++ codes, while xpu needs to build sycl codes. +The CPU backend requires building the C++ sources, while the XPU backend requires building the SYCL sources. Run `export bnb_device=xpu` if you are using xpu, run `export bnb_device=cpu` if you are using cpu. ``` git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ @@ -245,7 +244,6 @@ cmake -DCOMPUTE_BACKEND=$bnb_device -S . make pip install -e . ``` -Note: You can run `pip install intel_extension_for_pytorch to get better performance on CPU` diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 86726bd44..0e5f7bc18 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -272,14 +272,11 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): # Test with gradients. Currently only works with threshold=0. # Has a strange regression on Linux aarch64 CPU in torch==2.6.0. - # There is also an issue with torch==2.7.0 on x86-64 with IPEX. is_broken_platform = ( device == "cpu" and platform.system() == "Linux" - and ( - (platform.machine() == "aarch64" and (2, 6) <= torch.__version__ < (2, 7)) - or (platform.machine() == "x86_64" and bnb.functional.ipex_cpu) - ) + and platform.machine() == "aarch64" + and (2, 6) <= torch.__version__ < (2, 7) ) if threshold == 0 and not is_broken_platform: