@@ -31,7 +31,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
3131 ).reshape (* A .shape [:- 1 ], B .shape [0 ])
3232
3333
34- if not isinstance (lib , ErrorHandlerMockBNBNativeLibrary ):
34+ if not isinstance (lib , ErrorHandlerMockBNBNativeLibrary ) and _has_avx512 :
3535
3636 @register_kernel ("bitsandbytes::quantize_blockwise" , "cpu" )
3737 def _ (A : torch .Tensor , code : torch .Tensor , blocksize : int ) -> tuple [torch .Tensor , torch .Tensor ]:
@@ -457,33 +457,15 @@ def _optimizer_update_32bit_cpu(
457457def _dequant_blockwise_fp32_direct (
458458 A_uint8 : torch .Tensor , absmax : torch .Tensor , code : torch .Tensor , blocksize : int
459459) -> torch .Tensor :
460- """Dequantize blockwise via direct C lib call, avoiding torch.ops dispatch overhead."""
461- n = A_uint8 .numel ()
462- out = torch .empty (n , dtype = torch .float32 , device = A_uint8 .device )
463- lib .cdequantize_blockwise_cpu_fp32 (
464- get_ptr (code ),
465- get_ptr (A_uint8 .reshape (- 1 )),
466- get_ptr (absmax ),
467- get_ptr (out ),
468- ct .c_longlong (blocksize ),
469- ct .c_longlong (n ),
470- )
471- return out .reshape (A_uint8 .shape )
460+ return torch .ops .bitsandbytes .dequantize_blockwise (A_uint8 , absmax , code , blocksize , torch .float32 )
472461
473462
474463def _quant_blockwise_fp32_direct (
475464 A_fp32 : torch .Tensor , code : torch .Tensor , absmax_out : torch .Tensor , out_uint8 : torch .Tensor , blocksize : int
476465) -> None :
477- """Quantize blockwise via direct C lib call, writing into existing buffers (zero-alloc)."""
478- n = A_fp32 .numel ()
479- lib .cquantize_blockwise_cpu_fp32 (
480- get_ptr (code ),
481- get_ptr (A_fp32 .reshape (- 1 )),
482- get_ptr (absmax_out ),
483- get_ptr (out_uint8 .reshape (- 1 )),
484- ct .c_longlong (blocksize ),
485- ct .c_longlong (n ),
486- )
466+ out , absmax = torch .ops .bitsandbytes .quantize_blockwise (A_fp32 , code , blocksize )
467+ out_uint8 .copy_ (out )
468+ absmax_out .copy_ (absmax )
487469
488470
489471def _optimizer_update_8bit_blockwise_cpu (
@@ -509,7 +491,7 @@ def _optimizer_update_8bit_blockwise_cpu(
509491) -> None :
510492 blocksize = 256
511493
512- # Dequantize states — direct C lib calls (no torch.ops dispatch overhead)
494+ # Dequantize states
513495 if optimizer_name == "ademamix" and absmax1 .ndim == 2 :
514496 s1_1 = _dequant_blockwise_fp32_direct (state1 [0 ], absmax1 [0 ], qmap1 , blocksize )
515497 s1_2 = _dequant_blockwise_fp32_direct (state1 [1 ], absmax1 [1 ], qmap1 , blocksize )
@@ -586,7 +568,7 @@ def _optimizer_update_8bit_blockwise_cpu(
586568
587569 p .data .copy_ (p_fp32 )
588570
589- # Re-quantize states — direct C lib calls, zero-alloc (write into existing buffers)
571+ # Re-quantize states
590572 if optimizer_name == "ademamix" :
591573 _quant_blockwise_fp32_direct (state1_fp32 [0 ], qmap1 , absmax1 [0 ], state1 [0 ], blocksize )
592574 _quant_blockwise_fp32_direct (state1_fp32 [1 ], qmap1 , absmax1 [1 ], state1 [1 ], blocksize )
0 commit comments