Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ jobs:
- name: Run tests
run: pytest --durations=100

test-cpu-ipex:
test-cpu-intel:
if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
needs: build-cpu
runs-on: banb-aws-general-8-plus-use1-public-80
Expand All @@ -186,7 +186,6 @@ jobs:
- name: Install dependencies
run: |
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
pip install -e ".[test]"
pip install pytest-cov

Expand All @@ -196,9 +195,6 @@ jobs:
- name: Show environment information
run: python -m torch.utils.collect_env

- name: IPEX smoke test
run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);"

- name: Run tests
run: pytest --durations=100

Expand Down Expand Up @@ -286,15 +282,6 @@ jobs:
fail-fast: false
matrix:
torch_version: ["2.7.1"] #["2.6.0", "2.7.1"]
ipex: [false]
# ipex: [true, false]
# include:
# - torch_version: "2.6.0"
# ipex: true
# ipex_version: "2.6.10+xpu"
# - torch_version: "2.7.1"
# ipex: true
# ipex_version: "2.7.10+xpu"
runs-on:
group: bandb-itac-bmsprpvc1550-8-1gpu
env:
Expand Down Expand Up @@ -330,10 +317,6 @@ jobs:
- name: Install PyTorch
run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu

- name: Install IPEX
if: matrix.ipex == true
run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

- name: Install dependencies
run: |
pip install -e ".[test]"
Expand Down
21 changes: 0 additions & 21 deletions bitsandbytes/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import torch

from .utils import ipex_cpu

_IS_TORCH_GTE_24 = False

if hasattr(torch.library, "register_fake"):
Expand Down Expand Up @@ -329,22 +327,3 @@ def _(
)
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")


# The IPEX-specific op is only declared when `ipex_cpu` (imported from .utils) is truthy.
if ipex_cpu:
# Declare the schema of the `bitsandbytes::dequantize_nf4_ipex` custom op.
# Only the signature is defined here; no device kernel is attached by this call.
torch.library.define(
"bitsandbytes::dequantize_nf4_ipex",
"(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
)

# Fake implementation for the op: it performs no dequantization and only
# describes the result (requested `shape` and `dtype`, on the same device as `A`).
@register_fake("bitsandbytes::dequantize_nf4_ipex")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
# Validate that `blocksize` is usable as a size before building the output.
torch._check_is_size(blocksize)
# `absmax` is not consulted here; the returned tensor is uninitialized.
return torch.empty(shape, dtype=dtype, device=A.device)
16 changes: 3 additions & 13 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,9 @@ def matmul(
if threshold > 0.0:
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
if state.is_training:
if A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)
if state.is_training and A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)

return MatMul8bitLt.apply(A, B, out, bias, state)


Expand All @@ -437,16 +437,6 @@ def matmul_4bit(
):
assert quant_state is not None

if A.device.type == "cpu" and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
# IPEX CPU will change weight to 4D so don't need transpose
B = B.t() if B.dim() == 2 else B
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
Expand Down
174 changes: 79 additions & 95 deletions bitsandbytes/backends/cpu/ops.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from collections.abc import Sequence
import ctypes as ct
import logging

import torch

from bitsandbytes.functional import get_ptr

from ..._ops import register_kernel
from ...cextension import lib
from ...utils import ipex_cpu
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib

logger = logging.getLogger(__name__)

# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
Expand All @@ -24,97 +25,80 @@ def _(A: torch.Tensor, B: torch.Tensor):
).reshape(*A.shape[:-1], B.shape[0])


# CPU kernel for `bitsandbytes::quantize_blockwise`.
# Treats A as a flat sequence of `blocksize`-element blocks (the last block may be
# shorter) and returns `(out, absmax)`: `out` is a uint8 tensor with A's shape and
# `absmax` is a float32 tensor with one entry per block.
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)

n = A.numel()

# Only FP32 has a C++ kernel
if A.dtype == torch.float32:
# Ceiling division, so a trailing partial block still gets its own absmax slot.
blocks = -(n // -blocksize)

# Both buffers are uninitialized here and are passed to the native call as outputs.
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)

lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)
else:
# Pure-PyTorch fallback for every other dtype.
# `rem` is the number of trailing elements that do not fill a whole block.
rem = n % blocksize
has_rem = rem > 0
# `has_rem` is a bool; adding it contributes one extra block when a remainder exists.
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_reshaped = A.reshape(n)
# Full blocks only: one row per block.
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
# Scale each block by the reciprocal of its absmax and clamp to [-1, 1].
# NOTE(review): a block whose absmax is 0 gives 1 / 0 here — confirm how all-zero blocks should behave.
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
# The trailing partial block is scaled separately with its own absmax.
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

# Distance from every scaled value to every entry of `code`; the index of the
# nearest entry becomes the quantized value.
diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)

return out, absmax


# CPU kernel for `bitsandbytes::dequantize_blockwise`.
# Turns uint8 indices in A back into values of `dtype` using the lookup table `code`
# and the per-block scales in `absmax`; the result has A's shape.
@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

# Only FP32 has a C++ kernel
if dtype == torch.float32:
# Uninitialized output buffer handed to the native call.
out = torch.empty_like(A, dtype=dtype)

lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
else:
# Pure-PyTorch fallback: look up each index in `code`.
# assumes `code` is 1-D so that `out` is a flat vector — TODO confirm
out = code[A.reshape(-1).int()]
blocks = out.shape[-1] // blocksize
res = out.shape[-1] % blocksize
if res != 0:
# Zero-pad the tail so the vector can be viewed as rows of `blocksize` elements.
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
# Multiply every block (row) by its absmax, then flatten again.
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
# `blocks * blocksize + res` is the pre-padding length, so this drops the padding.
out = out[: blocks * blocksize + res]
out = out.reshape(A.shape)

return out


if ipex_cpu:
from bitsandbytes.utils import _reverse_4bit_compress_format

@register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
logger.info("Loading C++ bitsandbytes kernels for CPU")

# CPU kernel for `bitsandbytes::quantize_blockwise`.
# Treats A as a flat sequence of `blocksize`-element blocks (the last block may be
# shorter) and returns `(out, absmax)`: `out` is a uint8 tensor with A's shape and
# `absmax` is a float32 tensor with one entry per block.
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)

n = A.numel()

# Only FP32 has a C++ kernel
if A.dtype == torch.float32:
# Ceiling division, so a trailing partial block still gets its own absmax slot.
blocks = -(n // -blocksize)

# Both buffers are uninitialized here and are passed to the native call as outputs.
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)

lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)
else:
# Pure-PyTorch fallback for every other dtype.
# `rem` is the number of trailing elements that do not fill a whole block.
rem = n % blocksize
has_rem = rem > 0
# `has_rem` is a bool; adding it contributes one extra block when a remainder exists.
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_reshaped = A.reshape(n)
# Full blocks only: one row per block.
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
# Scale each block by the reciprocal of its absmax and clamp to [-1, 1].
# NOTE(review): a block whose absmax is 0 gives 1 / 0 here — confirm how all-zero blocks should behave.
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
# The trailing partial block is scaled separately with its own absmax.
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

# Distance from every scaled value to every entry of `code`; the index of the
# nearest entry becomes the quantized value.
diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)

return out, absmax

@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
shape: Sequence[int],
dtype: torch.dtype,
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
return torch.ops.bitsandbytes.dequantize_4bit.default(
A,
absmax,
blocksize,
"nf4",
shape,
dtype,
)
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

# Only FP32 has c++ kernrl
if dtype == torch.float32:
out = torch.empty_like(A, dtype=dtype)

lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
else:
out = code[A.reshape(-1).int()]
blocks = out.shape[-1] // blocksize
res = out.shape[-1] % blocksize
if res != 0:
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
out = out[: blocks * blocksize + res]
out = out.reshape(A.shape)

return out
else:
logger.warning("Loading pytorch bitsandbytes kernels for CPU because no native library found.")
65 changes: 1 addition & 64 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch import Tensor
from typing_extensions import deprecated

from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

from .cextension import HIP_ENVIRONMENT, lib

Expand Down Expand Up @@ -1055,16 +1055,6 @@ def dequantize_4bit(
if absmax.dtype != torch.float32:
absmax = absmax.float()

# IPEX format is different, we need extra process.
if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
return torch.ops.bitsandbytes.dequantize_nf4_ipex(
A,
absmax,
quant_state.blocksize,
quant_state.shape,
quant_state.dtype,
)

if out is not None:
torch.ops.bitsandbytes.dequantize_4bit.out(
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
Expand Down Expand Up @@ -1633,25 +1623,6 @@ def gemv_4bit(
if state.nested:
absmax = dequantize_blockwise(absmax, state.state2) + state.offset

if getattr(state, "ipex", False) and state.quant_type == "nf4":
# compute_dtype: 1 indicates fp16, 2 indicates bf16
compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
out = torch.ops.torch_ipex.woq_linear(
A,
B,
"nf4",
state.shape,
state.new_scales,
state.new_zeros,
None,
None,
state.blocksize,
compute_dtype,
1,
state.compensation,
)
return out

if out is not None:
torch.ops.bitsandbytes.gemv_4bit.out(
A,
Expand Down Expand Up @@ -2338,37 +2309,3 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):


C = 127.0


# Repack a 4-bit linear layer's weight through `ipex_prepack.woq_linear_pack_weight`
# and record the packing results on the weight's quant_state.
#
# linear: module whose `weight` carries a `quant_state`; both are mutated in place.
# x: only inspected for its device, which must be CPU.
# Returns nothing; all effects are on `linear.weight`.
def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
quant_state = linear.weight.quant_state

if quant_state.nested:
# Flatten the nested quantization: dequantize absmax via `state2`, add the
# stored offset, and keep the result as a plain float32 tensor.
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()

quant_state.absmax = absmax
quant_state.nested = False
delattr(quant_state, "state2")

assert x.device.type == "cpu"
# Reorder the packed 4-bit data with `_reverse_4bit_compress_format` before packing.
converted_weight = _reverse_4bit_compress_format(linear.weight.data)
# The fourth returned value is discarded.
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
"nf4",
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
None, # zero_points
None, # bias
None, # batch_size
quant_state.blocksize,
# NOTE(review): meaning of the trailing 2 is not shown here — presumably a dtype/format selector; confirm against the IPEX op.
2,
)

# Replace the weight data with the packed form and flag the quant_state so that
# code checking `getattr(quant_state, "ipex", False)` can detect the packed layout.
linear.weight.data = new_weight.data
linear.weight.quant_state.ipex = True
linear.weight.quant_state.new_scales = new_scales
linear.weight.quant_state.new_zeros = new_zeros
linear.weight.quant_state.compensation = compensation
Loading
Loading