Skip to content

Commit 2ad0744

Browse files
committed
fix dispatch
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent af7410d commit 2ad0744

File tree

1 file changed

+7
-25
lines changed
  • bitsandbytes/backends/cpu

1 file changed

+7
-25
lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
3131
).reshape(*A.shape[:-1], B.shape[0])
3232

3333

34-
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
34+
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary) and _has_avx512:
3535

3636
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
3737
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -457,33 +457,15 @@ def _optimizer_update_32bit_cpu(
457457
def _dequant_blockwise_fp32_direct(
458458
A_uint8: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int
459459
) -> torch.Tensor:
460-
"""Dequantize blockwise via direct C lib call, avoiding torch.ops dispatch overhead."""
461-
n = A_uint8.numel()
462-
out = torch.empty(n, dtype=torch.float32, device=A_uint8.device)
463-
lib.cdequantize_blockwise_cpu_fp32(
464-
get_ptr(code),
465-
get_ptr(A_uint8.reshape(-1)),
466-
get_ptr(absmax),
467-
get_ptr(out),
468-
ct.c_longlong(blocksize),
469-
ct.c_longlong(n),
470-
)
471-
return out.reshape(A_uint8.shape)
460+
return torch.ops.bitsandbytes.dequantize_blockwise(A_uint8, absmax, code, blocksize, torch.float32)
472461

473462

474463
def _quant_blockwise_fp32_direct(
475464
A_fp32: torch.Tensor, code: torch.Tensor, absmax_out: torch.Tensor, out_uint8: torch.Tensor, blocksize: int
476465
) -> None:
477-
"""Quantize blockwise via direct C lib call, writing into existing buffers (zero-alloc)."""
478-
n = A_fp32.numel()
479-
lib.cquantize_blockwise_cpu_fp32(
480-
get_ptr(code),
481-
get_ptr(A_fp32.reshape(-1)),
482-
get_ptr(absmax_out),
483-
get_ptr(out_uint8.reshape(-1)),
484-
ct.c_longlong(blocksize),
485-
ct.c_longlong(n),
486-
)
466+
out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A_fp32, code, blocksize)
467+
out_uint8.copy_(out)
468+
absmax_out.copy_(absmax)
487469

488470

489471
def _optimizer_update_8bit_blockwise_cpu(
@@ -509,7 +491,7 @@ def _optimizer_update_8bit_blockwise_cpu(
509491
) -> None:
510492
blocksize = 256
511493

512-
# Dequantize states — direct C lib calls (no torch.ops dispatch overhead)
494+
# Dequantize states
513495
if optimizer_name == "ademamix" and absmax1.ndim == 2:
514496
s1_1 = _dequant_blockwise_fp32_direct(state1[0], absmax1[0], qmap1, blocksize)
515497
s1_2 = _dequant_blockwise_fp32_direct(state1[1], absmax1[1], qmap1, blocksize)
@@ -586,7 +568,7 @@ def _optimizer_update_8bit_blockwise_cpu(
586568

587569
p.data.copy_(p_fp32)
588570

589-
# Re-quantize states — direct C lib calls, zero-alloc (write into existing buffers)
571+
# Re-quantize states
590572
if optimizer_name == "ademamix":
591573
_quant_blockwise_fp32_direct(state1_fp32[0], qmap1, absmax1[0], state1[0], blocksize)
592574
_quant_blockwise_fp32_direct(state1_fp32[1], qmap1, absmax1[1], state1[1], blocksize)

0 commit comments

Comments
 (0)