Skip to content

Commit 45bd314

Browse files
authored
Enable CPU Optimizer Support for bitsandbytes (#1901)
* fix kernelk Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable cpu optimizer Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix ademamix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update tests and example Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * optimize Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix 8bit custom op Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update example Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update example Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix lint Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix storage Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix shape Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix dispatch Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix macos CI Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix GNUC flag check Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix c++ compile Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent dd599c2 commit 45bd314

File tree

8 files changed

+890
-81
lines changed

8 files changed

+890
-81
lines changed

bitsandbytes/backends/cpu/ops.py

100644100755
Lines changed: 295 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from collections.abc import Sequence
22
import ctypes as ct
33
import logging
4+
import math
45
from math import prod
6+
from typing import Optional
57

68
import torch
79

@@ -36,14 +38,12 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
3638
torch._check_is_size(blocksize)
3739

3840
n = A.numel()
41+
blocks = -(n // -blocksize)
3942

40-
# Only FP32 has c++ kernrl
41-
if A.dtype == torch.float32:
42-
blocks = -(n // -blocksize)
43-
44-
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
45-
out = torch.empty_like(A, dtype=torch.uint8)
43+
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
44+
out = torch.empty(A.shape, device=A.device, dtype=torch.uint8)
4645

46+
if A.dtype == torch.float32:
4747
lib.cquantize_blockwise_cpu_fp32(
4848
get_ptr(code),
4949
get_ptr(A),
@@ -52,20 +52,37 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
5252
ct.c_longlong(blocksize),
5353
ct.c_longlong(n),
5454
)
55+
elif A.dtype == torch.bfloat16:
56+
lib.cquantize_blockwise_cpu_bf16(
57+
get_ptr(code),
58+
get_ptr(A),
59+
get_ptr(absmax),
60+
get_ptr(out),
61+
ct.c_longlong(blocksize),
62+
ct.c_longlong(n),
63+
)
64+
elif A.dtype == torch.float16:
65+
lib.cquantize_blockwise_cpu_fp16(
66+
get_ptr(code),
67+
get_ptr(A),
68+
get_ptr(absmax),
69+
get_ptr(out),
70+
ct.c_longlong(blocksize),
71+
ct.c_longlong(n),
72+
)
5573
else:
74+
# Generic fallback for other dtypes
75+
A_flat = A.reshape(n).float()
5676
rem = n % blocksize
5777
has_rem = rem > 0
58-
blocks = n // blocksize + has_rem
59-
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
60-
A_reshaped = A.reshape(n)
61-
A_com = A_reshaped[: n - rem]
78+
A_com = A_flat[: n - rem]
6279
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
6380
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
6481
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
6582
scaled_A = scaled_A.reshape(-1)
6683
if has_rem:
67-
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
68-
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
84+
absmax[-1] = torch.abs(A_flat[n - rem :]).max()
85+
scaled_A_rem = torch.clamp(A_flat[n - rem :] * (1 / absmax[-1]), -1, 1)
6986
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
7087

7188
diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
@@ -248,19 +265,24 @@ def _(
248265
code: torch.Tensor,
249266
blocksize: int,
250267
) -> torch.Tensor:
251-
assert B.dtype == torch.uint8, "Only support uint8 qweight"
268+
if B.dtype != torch.uint8:
269+
B = B.contiguous().view(torch.uint8)
252270
dtype = A.dtype
253271
quant_type = "fp4" if code[1] > 0 else "nf4"
254272
# cpu fused op only support bf16 for now.
255273
if dtype != torch.bfloat16:
256274
A = A.to(torch.bfloat16)
275+
if absmax.dtype != torch.bfloat16:
276+
absmax = absmax.to(torch.bfloat16)
257277

258278
final_out_shape = (*A.shape[:-1], shapeB[0])
259279
A = A.reshape(-1, A.shape[-1])
260280
out_shape = (*A.shape[:-1], shapeB[0])
261281
if gemm_4bit_forward_kernel is not None:
262282
quant_type_num = 1 if quant_type == "fp4" else 0
263-
out = gemm_4bit_forward_kernel(A, B, absmax, blocksize, quant_type_num)
283+
# C++ kernel expects weight shape (N, K_packed), ensure 2D contiguous
284+
B_2d = B.reshape(shapeB[0], -1).contiguous()
285+
out = gemm_4bit_forward_kernel(A, B_2d, absmax, blocksize, quant_type_num)
264286
else:
265287
out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
266288
M = A.shape[0]
@@ -299,3 +321,262 @@ def _(
299321
out = out.to(dtype)
300322

301323
return out.reshape(final_out_shape)
324+
325+
326+
# ==================== CPU Optimizer Kernels ====================
327+
328+
329+
def _compute_update_norm_and_scale(
330+
update: torch.Tensor,
331+
unorm_vec: Optional[torch.Tensor],
332+
max_unorm: float,
333+
param_norm: float,
334+
) -> float:
335+
"""Compute trust-ratio scaling factor for LAMB/LARS and store update norm."""
336+
if max_unorm <= 0.0:
337+
return 1.0
338+
unorm = torch.norm(update).item()
339+
if unorm_vec is not None:
340+
unorm_vec.fill_(unorm)
341+
if unorm > max_unorm * param_norm:
342+
return (max_unorm * param_norm) / unorm
343+
return 1.0
344+
345+
346+
@torch.no_grad()
def _optimizer_update_32bit_cpu(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros: bool = False,
) -> None:
    """Pure-PyTorch CPU implementation of the 32-bit optimizer update.

    Applies one optimizer step in-place: the optimizer state tensors
    (``state1``/``state2``) are updated in place, and the parameter ``p`` is
    updated via a float32 working copy that is written back at the end.

    Args:
        optimizer_name: One of ``adam``, ``lamb``, ``ademamix``, ``momentum``,
            ``lars``, ``lion``, ``rmsprop``, ``adagrad``.
        g: Gradient tensor (any float dtype; upcast to fp32 internally).
        p: Parameter tensor, modified in place.
        state1: First optimizer state. For ``ademamix`` this holds both EMAs
            stacked as ``(2, *p.shape)``.
        state2: Second optimizer state (None for 1-state optimizers).
        unorm_vec: Optional output buffer receiving the update norm
            (LAMB/LARS trust-ratio support).
        max_unorm: Trust-ratio ceiling; <= 0 disables clipping.
        param_norm: Norm of ``p`` used for the trust ratio.
        beta1, beta2, beta3, alpha, eps, weight_decay, lr: Optimizer
            hyperparameters; ``beta3``/``alpha`` are only read by ``ademamix``.
        step: 1-based step count used for bias correction.
        gnorm_scale: Scale factor applied to the gradient (gradient clipping).
        skip_zeros: Accepted for interface parity with the native kernels.
            NOTE(review): not implemented in this CPU path — confirm no CPU
            caller relies on it.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    # All math is done in fp32 regardless of the parameter/gradient dtype.
    g_float = g.float() * gnorm_scale
    p_float = p.data.float()

    if optimizer_name in ("adam", "lamb"):
        # Adam / LAMB (2-state): m and v
        state1.mul_(beta1).add_(g_float, alpha=1.0 - beta1)
        state2.mul_(beta2).addcmul_(g_float, g_float, value=1.0 - beta2)

        # Bias corrections folded into step_size and the eps term so the
        # update can be formed with a single division.
        correction1 = 1.0 - beta1**step
        correction2 = math.sqrt(1.0 - beta2**step)
        step_size = -lr * correction2 / correction1

        # Decoupled (AdamW-style) weight decay.
        if weight_decay > 0.0:
            p_float.mul_(1.0 - lr * weight_decay)

        update = state1 / (state2.sqrt() + eps * correction2)

        # update_scale != 1.0 only for LAMB with max_unorm > 0.
        update_scale = _compute_update_norm_and_scale(update, unorm_vec, max_unorm, param_norm)
        p_float.add_(update, alpha=step_size * update_scale)

    elif optimizer_name == "ademamix":
        # AdEMAMix (2-state): state1 shape is (2, *p.shape), state1[0]=m1, state1[1]=m2
        m1 = state1[0]  # fast EMA (beta1)
        m2 = state1[1]  # slow EMA (beta3)
        nu = state2  # second moment

        m1.mul_(beta1).add_(g_float, alpha=1.0 - beta1)
        m2.mul_(beta3).add_(g_float, alpha=1.0 - beta3)
        nu.mul_(beta2).addcmul_(g_float, g_float, value=1.0 - beta2)

        correction1 = 1.0 - beta1**step
        correction2 = math.sqrt(1.0 - beta2**step)

        # Decoupled weight decay, as in the adam/lamb branch.
        if weight_decay > 0.0:
            p_float.mul_(1.0 - lr * weight_decay)

        # Only the fast EMA is bias-corrected; the slow EMA enters via alpha.
        mixed_momentum = (m1 / correction1) + (alpha * m2)
        adaptive_term = (nu.sqrt() / correction2) + eps
        p_float.add_(mixed_momentum / adaptive_term, alpha=-lr)

    elif optimizer_name in ("momentum", "lars"):
        # SGD with momentum / LARS (1-state)
        # Coupled weight decay: decay is added to the gradient itself.
        g_wd = g_float.add(p_float, alpha=weight_decay) if weight_decay > 0.0 else g_float

        if step == 1:
            # First step seeds the momentum buffer with the raw gradient.
            state1.copy_(g_wd)
        else:
            state1.mul_(beta1).add_(g_wd)

        # update_scale != 1.0 only for LARS with max_unorm > 0.
        update_scale = _compute_update_norm_and_scale(state1, unorm_vec, max_unorm, param_norm)
        p_float.add_(state1, alpha=-lr * update_scale)

    elif optimizer_name == "lion":
        # Lion (1-state sign update): only state1 (momentum) is used.
        if weight_decay > 0.0:
            p_float.mul_(1.0 - lr * weight_decay)

        # The parameter step uses a beta1-interpolation of momentum and
        # gradient (out of place), then momentum itself is updated with beta2.
        update = state1.mul(beta1).add(g_float, alpha=1.0 - beta1)
        p_float.add_(update.sign(), alpha=-lr)

        state1.mul_(beta2).add_(g_float, alpha=1.0 - beta2)

    elif optimizer_name == "rmsprop":
        # RMSprop (1-state)
        g_wd = g_float.add(p_float, alpha=weight_decay) if weight_decay > 0.0 else g_float
        state1.mul_(beta1).addcmul_(g_wd, g_wd, value=1.0 - beta1)

        update = g_wd / (state1.sqrt() + eps)
        update_scale = _compute_update_norm_and_scale(update, unorm_vec, max_unorm, param_norm)
        p_float.add_(update, alpha=-lr * update_scale)

    elif optimizer_name == "adagrad":
        # Adagrad (1-state): accumulates squared gradients without decay.
        g_wd = g_float.add(p_float, alpha=weight_decay) if weight_decay > 0.0 else g_float
        state1.addcmul_(g_wd, g_wd, value=1.0)

        update = g_wd / (state1.sqrt() + eps)
        p_float.add_(update, alpha=-lr)

    else:
        raise ValueError(f"Unsupported optimizer for CPU: {optimizer_name}")

    # Write back to original precision
    p.data.copy_(p_float)
451+
452+
453+
# Register the pure-PyTorch 32-bit optimizer update as the CPU backend kernel.
register_kernel("bitsandbytes::optimizer_update_32bit", "cpu")(_optimizer_update_32bit_cpu)
454+
455+
456+
@torch.no_grad()
def _dequant_blockwise_fp32_direct(
    A_uint8: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
    """Dequantize a blockwise-quantized uint8 tensor directly to float32.

    Thin convenience wrapper around the registered
    ``bitsandbytes::dequantize_blockwise`` op with the output dtype pinned
    to ``torch.float32``.
    """
    dequantize = torch.ops.bitsandbytes.dequantize_blockwise
    return dequantize(A_uint8, absmax, code, blocksize, torch.float32)
461+
462+
463+
def _quant_blockwise_fp32_direct(
    A_fp32: torch.Tensor, code: torch.Tensor, absmax_out: torch.Tensor, out_uint8: torch.Tensor, blocksize: int
) -> None:
    """Blockwise-quantize ``A_fp32`` and write results into caller buffers.

    Calls the registered ``bitsandbytes::quantize_blockwise`` op and copies
    the quantized bytes into ``out_uint8`` and the per-block scales into
    ``absmax_out`` in place.
    """
    quantized, scales = torch.ops.bitsandbytes.quantize_blockwise(A_fp32, code, blocksize)
    out_uint8.copy_(quantized)
    absmax_out.copy_(scales)
469+
470+
471+
def _optimizer_update_8bit_blockwise_cpu(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: torch.Tensor,
    qmap2: Optional[torch.Tensor],
    absmax1: torch.Tensor,
    absmax2: Optional[torch.Tensor],
    weight_decay: float,
    gnorm_scale: float,
    skip_zeros: bool = False,
) -> None:
    """Pure-PyTorch CPU implementation of the 8-bit blockwise optimizer update.

    Pipeline: dequantize the uint8 optimizer states to fp32 using their
    quantization maps and per-block absmax scales, perform the optimizer
    step in fp32, write the updated parameter back into ``p``, then
    re-quantize the states into ``state1``/``state2`` and refresh
    ``absmax1``/``absmax2`` in place.

    Args:
        optimizer_name: One of ``adam``, ``lamb``, ``ademamix``, ``momentum``,
            ``lars``, ``lion``, ``rmsprop``, ``adagrad``.
        g: Gradient tensor (upcast to fp32 internally).
        p: Parameter tensor, modified in place.
        state1: First quantized state (uint8). For ``ademamix`` it stacks
            two EMAs along dim 0.
        state2: Second quantized state, or None for 1-state optimizers.
        qmap1, qmap2: Quantization code books for state1/state2.
        absmax1, absmax2: Per-block scales for state1/state2; for
            ``ademamix`` absmax1 may be 2-D (one row per stacked EMA).
        beta1, beta2, beta3, alpha, eps, weight_decay, lr: Optimizer
            hyperparameters; ``beta3``/``alpha`` are only read by ``ademamix``.
        step: 1-based step count used for bias correction.
        gnorm_scale: Scale factor applied to the gradient.
        skip_zeros: Accepted for interface parity with the native kernels.
            NOTE(review): not implemented in this CPU path — confirm no CPU
            caller relies on it.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    # Fixed blocksize used for the 8-bit state quantization on CPU.
    blocksize = 256

    # Dequantize states
    if optimizer_name == "ademamix" and absmax1.ndim == 2:
        # Two stacked EMAs, each with its own absmax row.
        s1_1 = _dequant_blockwise_fp32_direct(state1[0], absmax1[0], qmap1, blocksize)
        s1_2 = _dequant_blockwise_fp32_direct(state1[1], absmax1[1], qmap1, blocksize)
        state1_fp32 = torch.stack([s1_1, s1_2])
    else:
        state1_fp32 = _dequant_blockwise_fp32_direct(state1, absmax1, qmap1, blocksize)

    state2_fp32 = None
    if state2 is not None and qmap2 is not None and absmax2 is not None:
        state2_fp32 = _dequant_blockwise_fp32_direct(state2, absmax2, qmap2, blocksize)

    # All math is done in fp32; the parameter is written back at the end.
    grad = g.float() * gnorm_scale
    p_fp32 = p.data.float()

    if optimizer_name in ("adam", "lamb"):
        state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
        state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

        correction1 = 1.0 - beta1**step
        correction2 = math.sqrt(1.0 - beta2**step)

        denom = (state2_fp32.sqrt() / correction2).add_(eps)
        # Decoupled (AdamW-style) weight decay.
        if weight_decay > 0.0:
            p_fp32.mul_(1.0 - lr * weight_decay)
        p_fp32.addcdiv_(state1_fp32, denom, value=-lr / correction1)

    elif optimizer_name == "ademamix":
        # m1 = fast EMA (beta1), m2 = slow EMA (beta3), nu = second moment.
        m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1]
        nu_fp32 = state2_fp32

        m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
        m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3)
        nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

        correction1 = 1.0 - beta1**step
        correction2 = math.sqrt(1.0 - beta2**step)

        update = (m1_fp32 / correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / correction2 + eps)
        if weight_decay > 0.0:
            p_fp32.mul_(1.0 - lr * weight_decay)
        p_fp32.add_(update, alpha=-lr)

        # NOTE(review): m1/m2 are views of state1_fp32, which was already
        # updated in place; this restack looks redundant — confirm.
        state1_fp32 = torch.stack([m1_fp32, m2_fp32])

    elif optimizer_name in ("momentum", "lars"):
        # Coupled weight decay: decay added into the gradient itself.
        grad.add_(p_fp32, alpha=weight_decay)
        if step == 1:
            # First step seeds the momentum buffer with the raw gradient.
            state1_fp32.copy_(grad)
        else:
            state1_fp32.mul_(beta1).add_(grad)
        p_fp32.add_(state1_fp32, alpha=-lr)

    elif optimizer_name == "lion":
        if weight_decay > 0.0:
            p_fp32.mul_(1.0 - lr * weight_decay)

        # Sign of the beta1-interpolated momentum/gradient drives the step;
        # the momentum buffer itself is then updated with beta2.
        update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1))
        p_fp32.add_(update_dir, alpha=-lr)

        state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2)

    elif optimizer_name == "rmsprop":
        grad.add_(p_fp32, alpha=weight_decay)
        state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1)
        p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)

    elif optimizer_name == "adagrad":
        grad.add_(p_fp32, alpha=weight_decay)
        state1_fp32.addcmul_(grad, grad, value=1.0)
        p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)

    else:
        raise ValueError(f"Unsupported optimizer for CPU 8-bit: {optimizer_name}")

    # Write the parameter back in its original precision.
    p.data.copy_(p_fp32)

    # Re-quantize states
    if optimizer_name == "ademamix":
        _quant_blockwise_fp32_direct(state1_fp32[0], qmap1, absmax1[0], state1[0], blocksize)
        _quant_blockwise_fp32_direct(state1_fp32[1], qmap1, absmax1[1], state1[1], blocksize)
        _quant_blockwise_fp32_direct(state2_fp32, qmap2, absmax2, state2, blocksize)
    else:
        _quant_blockwise_fp32_direct(state1_fp32, qmap1, absmax1, state1, blocksize)
        if state2_fp32 is not None:
            _quant_blockwise_fp32_direct(state2_fp32, qmap2, absmax2, state2, blocksize)
580+
581+
582+
# Register the pure-PyTorch 8-bit blockwise optimizer update as the CPU backend kernel.
register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cpu")(_optimizer_update_8bit_blockwise_cpu)

0 commit comments

Comments
 (0)