diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 847c7ef7a..7597ed9f9 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -162,7 +162,7 @@ jobs:
- name: Run tests
run: pytest --durations=100
- test-cpu-ipex:
+ test-cpu-intel:
if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
needs: build-cpu
runs-on: banb-aws-general-8-plus-use1-public-80
@@ -186,7 +186,6 @@ jobs:
- name: Install dependencies
run: |
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
- pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
pip install -e ".[test]"
pip install pytest-cov
@@ -196,9 +195,6 @@ jobs:
- name: Show environment information
run: python -m torch.utils.collect_env
- - name: IPEX smoke test
- run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);"
-
- name: Run tests
run: pytest --durations=100
@@ -286,15 +282,6 @@ jobs:
fail-fast: false
matrix:
torch_version: ["2.7.1"] #["2.6.0", "2.7.1"]
- ipex: [false]
- # ipex: [true, false]
- # include:
- # - torch_version: "2.6.0"
- # ipex: true
- # ipex_version: "2.6.10+xpu"
- # - torch_version: "2.7.1"
- # ipex: true
- # ipex_version: "2.7.10+xpu"
runs-on:
group: bandb-itac-bmsprpvc1550-8-1gpu
env:
@@ -330,10 +317,6 @@ jobs:
- name: Install PyTorch
run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu
- - name: Install IPEX
- if: matrix.ipex == true
- run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-
- name: Install dependencies
run: |
pip install -e ".[test]"
diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index 56bfaa357..9a3ac46ac 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -4,8 +4,6 @@
import torch
-from .utils import ipex_cpu
-
_IS_TORCH_GTE_24 = False
if hasattr(torch.library, "register_fake"):
@@ -329,22 +327,3 @@ def _(
)
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
-
-
-if ipex_cpu:
- # Register the dequantize_nf4_ipex implementation
- torch.library.define(
- "bitsandbytes::dequantize_nf4_ipex",
- "(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
- )
-
- @register_fake("bitsandbytes::dequantize_nf4_ipex")
- def _(
- A: torch.Tensor,
- absmax: torch.Tensor,
- blocksize: int,
- shape: Sequence[int],
- dtype: torch.dtype,
- ) -> torch.Tensor:
- torch._check_is_size(blocksize)
- return torch.empty(shape, dtype=dtype, device=A.device)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index c28b301b9..cb761fe24 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -422,9 +422,9 @@ def matmul(
if threshold > 0.0:
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
- if state.is_training:
- if A.device.type in ("cpu", "xpu"):
- return MatMul8bitFp.apply(A, B, out, bias, state)
+ if state.is_training and A.device.type in ("cpu", "xpu"):
+ return MatMul8bitFp.apply(A, B, out, bias, state)
+
return MatMul8bitLt.apply(A, B, out, bias, state)
@@ -437,16 +437,6 @@ def matmul_4bit(
):
assert quant_state is not None
- if A.device.type == "cpu" and A.requires_grad == False:
- if getattr(quant_state, "ipex", False):
- # IPEX CPU will change weight to 4D so don't need transpose
- B = B.t() if B.dim() == 2 else B
- out = F.gemv_4bit(A, B, out, state=quant_state)
- if bias is not None:
- out += bias
- return out
- else:
- return MatMul4Bit.apply(A, B, out, bias, quant_state)
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index b715b1d00..78f9fef47 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -1,13 +1,14 @@
-from collections.abc import Sequence
import ctypes as ct
+import logging
import torch
from bitsandbytes.functional import get_ptr
from ..._ops import register_kernel
-from ...cextension import lib
-from ...utils import ipex_cpu
+from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
+
+logger = logging.getLogger(__name__)
# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
@@ -24,97 +25,80 @@ def _(A: torch.Tensor, B: torch.Tensor):
).reshape(*A.shape[:-1], B.shape[0])
-@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
-def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
- torch._check_is_size(blocksize)
-
- n = A.numel()
-
- # Only FP32 has c++ kernrl
- if A.dtype == torch.float32:
- blocks = -(n // -blocksize)
-
- absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
- out = torch.empty_like(A, dtype=torch.uint8)
-
- lib.cquantize_blockwise_cpu_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_longlong(blocksize),
- ct.c_longlong(n),
- )
- else:
- rem = n % blocksize
- has_rem = rem > 0
- blocks = n // blocksize + has_rem
- absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
- A_reshaped = A.reshape(n)
- A_com = A_reshaped[: n - rem]
- A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
- absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
- scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
- scaled_A = scaled_A.reshape(-1)
- if has_rem:
- absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
- scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
- scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
-
- diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
- out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
-
- return out, absmax
-
-
-@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
-def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
- torch._check_is_size(blocksize)
- torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-
- # Only FP32 has c++ kernrl
- if dtype == torch.float32:
- out = torch.empty_like(A, dtype=dtype)
-
- lib.cdequantize_blockwise_cpu_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_longlong(blocksize),
- ct.c_longlong(A.numel()),
- )
- else:
- out = code[A.reshape(-1).int()]
- blocks = out.shape[-1] // blocksize
- res = out.shape[-1] % blocksize
- if res != 0:
- out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
- out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
- out = out[: blocks * blocksize + res]
- out = out.reshape(A.shape)
-
- return out
-
-
-if ipex_cpu:
- from bitsandbytes.utils import _reverse_4bit_compress_format
-
- @register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
+if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
+ logger.info("Loading C++ bitsandbytes kernels for CPU")
+
+ @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
+ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
+ torch._check_is_size(blocksize)
+
+ n = A.numel()
+
+ # Only FP32 has a C++ kernel
+ if A.dtype == torch.float32:
+ blocks = -(n // -blocksize)
+
+ absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+ out = torch.empty_like(A, dtype=torch.uint8)
+
+ lib.cquantize_blockwise_cpu_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_longlong(blocksize),
+ ct.c_longlong(n),
+ )
+ else:
+ rem = n % blocksize
+ has_rem = rem > 0
+ blocks = n // blocksize + has_rem
+ absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+ A_reshaped = A.reshape(n)
+ A_com = A_reshaped[: n - rem]
+ A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
+ absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+ scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+ scaled_A = scaled_A.reshape(-1)
+ if has_rem:
+ absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+ scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+ scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
+
+ diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
+ out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
+
+ return out, absmax
+
+ @register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(
- A: torch.Tensor,
- absmax: torch.Tensor,
- blocksize: int,
- shape: Sequence[int],
- dtype: torch.dtype,
+ A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
- ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
- A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
- return torch.ops.bitsandbytes.dequantize_4bit.default(
- A,
- absmax,
- blocksize,
- "nf4",
- shape,
- dtype,
- )
+ torch._check_is_size(blocksize)
+ torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+
+ # Only FP32 has a C++ kernel
+ if dtype == torch.float32:
+ out = torch.empty_like(A, dtype=dtype)
+
+ lib.cdequantize_blockwise_cpu_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_longlong(blocksize),
+ ct.c_longlong(A.numel()),
+ )
+ else:
+ out = code[A.reshape(-1).int()]
+ blocks = out.shape[-1] // blocksize
+ res = out.shape[-1] % blocksize
+ if res != 0:
+ out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
+ out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
+ out = out[: blocks * blocksize + res]
+ out = out.reshape(A.shape)
+
+ return out
+else:
+ logger.warning("Loading pytorch bitsandbytes kernels for CPU because no native library found.")
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 372632d17..5cd9eac67 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -13,7 +13,7 @@
from torch import Tensor
from typing_extensions import deprecated
-from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict
+from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
from .cextension import HIP_ENVIRONMENT, lib
@@ -1055,16 +1055,6 @@ def dequantize_4bit(
if absmax.dtype != torch.float32:
absmax = absmax.float()
- # IPEX format is different, we need extra process.
- if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
- return torch.ops.bitsandbytes.dequantize_nf4_ipex(
- A,
- absmax,
- quant_state.blocksize,
- quant_state.shape,
- quant_state.dtype,
- )
-
if out is not None:
torch.ops.bitsandbytes.dequantize_4bit.out(
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
@@ -1633,25 +1623,6 @@ def gemv_4bit(
if state.nested:
absmax = dequantize_blockwise(absmax, state.state2) + state.offset
- if getattr(state, "ipex", False) and state.quant_type == "nf4":
- # compute_dtype: 1 indicates fp16, 2 indicates bf16
- compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
- out = torch.ops.torch_ipex.woq_linear(
- A,
- B,
- "nf4",
- state.shape,
- state.new_scales,
- state.new_zeros,
- None,
- None,
- state.blocksize,
- compute_dtype,
- 1,
- state.compensation,
- )
- return out
-
if out is not None:
torch.ops.bitsandbytes.gemv_4bit.out(
A,
@@ -2338,37 +2309,3 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
C = 127.0
-
-
-def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
- quant_state = linear.weight.quant_state
-
- if quant_state.nested:
- absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
- absmax += quant_state.offset
- if absmax.dtype != torch.float32:
- absmax = absmax.float()
-
- quant_state.absmax = absmax
- quant_state.nested = False
- delattr(quant_state, "state2")
-
- assert x.device.type == "cpu"
- converted_weight = _reverse_4bit_compress_format(linear.weight.data)
- new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
- converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
- "nf4",
- quant_state.shape, # weight shape
- quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
- None, # zero_points
- None, # bias
- None, # batch_size
- quant_state.blocksize,
- 2,
- )
-
- linear.weight.data = new_weight.data
- linear.weight.quant_state.ipex = True
- linear.weight.quant_state.new_scales = new_scales
- linear.weight.quant_state.new_zeros = new_zeros
- linear.weight.quant_state.compensation = compensation
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 9015665ee..464205fa5 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -12,14 +12,9 @@
import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
-from bitsandbytes.functional import QuantState, _enable_ipex_fusion
+from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
-from bitsandbytes.utils import (
- INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
- OutlierTracer,
- _reverse_4bit_compress_format,
- ipex_cpu,
-)
+from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
T = TypeVar("T", bound="torch.nn.Module")
@@ -444,7 +439,6 @@ def __init__(
self.compute_type_is_set = False if compute_dtype is None else True
self.quant_state = None
self.quant_storage = quant_storage
- self.ipex_linear_is_set = False
def set_compute_type(self, x):
if x.dtype in [torch.float32, torch.bfloat16]:
@@ -471,40 +465,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
save weight and bias,
then fill state_dict with components of quant_state
"""
- if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False):
- if self.weight.device.type == "cpu":
- original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
- self.weight, "nf4", self.weight.quant_state.shape, 2
- )
- self.weight.data = _reverse_4bit_compress_format(original_weight.data)
-
- self.weight.quant_state.ipex = False
- self.ipex_linear_is_set = False
-
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
if getattr(self.weight, "quant_state", None) is not None:
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
- def set_ipex_linear(self, x: torch.Tensor):
- if (
- not getattr(self.weight.quant_state, "ipex", False)
- and self.weight.data.dtype == torch.uint8
- and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
- and self.weight.quant_state.quant_type == "nf4"
- and x.device.type == "cpu"
- and not self.training
- and not x.requires_grad
- ):
- _enable_ipex_fusion(self, x)
-
def forward(self, x: torch.Tensor):
- # Check if ipex fusion can be used
- if not self.ipex_linear_is_set and ipex_cpu:
- self.set_ipex_linear(x)
- self.ipex_linear_is_set = True
-
fix_4bit_weight_quant_state_from_module(self)
# weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -520,8 +487,7 @@ def forward(self, x: torch.Tensor):
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
- # IPEX CPU will change weight to 4D so don't need transpose
- weight = self.weight.t() if self.weight.dim() == 2 else self.weight
+ weight = self.weight.t()
return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
@@ -676,7 +642,7 @@ def to(self, *args, **kwargs):
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
if device.type != "cpu" or self.data.dtype != torch.int8:
return self._quantize(device)
- elif self.data.dtype == torch.int8 and device.type == "cpu" and ipex_cpu:
+ elif self.data.dtype == torch.int8 and device.type == "cpu":
self.CB = self.data
new_param = Int8Params(
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 4328a241c..0828dd295 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -4,14 +4,6 @@
import torch
-try:
- # to support Intel CPU backend
- import intel_extension_for_pytorch as ipex
-
- ipex_cpu = ipex if ipex._C._has_cpu() else None
-except BaseException:
- ipex_cpu = None
-
def outlier_hook(module, input):
assert isinstance(module, torch.nn.Linear)
@@ -46,14 +38,6 @@ def outlier_hook(module, input):
hook.remove()
-# convert btw standard 4-bit compression format and ipex compression format
-def _reverse_4bit_compress_format(weight: torch.Tensor):
- out_1 = (weight & 0xF0) >> 4
- out_2 = (weight & 0xF) << 4
- out = out_1 | out_2
- return out
-
-
class OutlierTracer:
_instance = None
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index 9b3449870..7396c7dcf 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -138,8 +138,8 @@ We provide an early preview of support for AMD and Intel hardware as part of a d
| **Backend** | **Supported Versions** | **Python versions** | **Architecture Support** | **Status** |
|-------------|------------------------|---------------------------|-------------------------|------------|
| **AMD ROCm** | 6.1+ | 3.10+ | minimum CDNA - `gfx90a`, RDNA - `gfx1100` | Alpha |
-| **Intel CPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha |
-| **Intel GPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental |
+| **Intel CPU** | v2.4.0+ | 3.10+ | Intel CPU | Alpha |
+| **Intel GPU** | v2.7.0+ | 3.10+ | Intel GPU | Experimental |
| **Ascend NPU** | 2.1.0+ (`torch_npu`) | 3.10+ | Ascend NPU | Experimental |
For each supported backend, follow the respective instructions below:
@@ -179,7 +179,6 @@ pip install torch --index-url https://download.pytorch.org/whl/rocm6.3/
* A compatible PyTorch version with Intel XPU support is required. It is recommended to use the latest stable release. See [Getting Started on Intel GPU](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) for guidance.
-* The [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/) is recommended for performance improvements.
@@ -235,9 +234,9 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
-#### Intel CPU + XPU
+#### Intel CPU + GPU(XPU)
-CPU needs to build CPU C++ codes, while xpu needs to build sycl codes.
+The CPU backend requires building the C++ code, while the XPU backend requires building the SYCL code.
Run `export bnb_device=xpu` if you are using xpu, run `export bnb_device=cpu` if you are using cpu.
```
git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
@@ -245,7 +244,6 @@ cmake -DCOMPUTE_BACKEND=$bnb_device -S .
make
pip install -e .
```
-Note: You can run `pip install intel_extension_for_pytorch to get better performance on CPU`
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 86726bd44..0e5f7bc18 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -272,14 +272,11 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
# Test with gradients. Currently only works with threshold=0.
# Has a strange regression on Linux aarch64 CPU in torch==2.6.0.
- # There is also an issue with torch==2.7.0 on x86-64 with IPEX.
is_broken_platform = (
device == "cpu"
and platform.system() == "Linux"
- and (
- (platform.machine() == "aarch64" and (2, 6) <= torch.__version__ < (2, 7))
- or (platform.machine() == "x86_64" and bnb.functional.ipex_cpu)
- )
+ and platform.machine() == "aarch64"
+ and (2, 6) <= torch.__version__ < (2, 7)
)
if threshold == 0 and not is_broken_platform: