11from collections .abc import Sequence
22import ctypes as ct
33import logging
4+ import math
45from math import prod
6+ from typing import Optional
57
68import torch
79
@@ -36,14 +38,12 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
3638 torch ._check_is_size (blocksize )
3739
3840 n = A .numel ()
41+ blocks = - (n // - blocksize )
3942
40- # Only FP32 has c++ kernrl
41- if A .dtype == torch .float32 :
42- blocks = - (n // - blocksize )
43-
44- absmax = torch .empty ((blocks ,), device = A .device , dtype = torch .float32 )
45- out = torch .empty_like (A , dtype = torch .uint8 )
43+ absmax = torch .empty ((blocks ,), device = A .device , dtype = torch .float32 )
44+ out = torch .empty (A .shape , device = A .device , dtype = torch .uint8 )
4645
46+ if A .dtype == torch .float32 :
4747 lib .cquantize_blockwise_cpu_fp32 (
4848 get_ptr (code ),
4949 get_ptr (A ),
@@ -52,20 +52,37 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
5252 ct .c_longlong (blocksize ),
5353 ct .c_longlong (n ),
5454 )
55+ elif A .dtype == torch .bfloat16 :
56+ lib .cquantize_blockwise_cpu_bf16 (
57+ get_ptr (code ),
58+ get_ptr (A ),
59+ get_ptr (absmax ),
60+ get_ptr (out ),
61+ ct .c_longlong (blocksize ),
62+ ct .c_longlong (n ),
63+ )
64+ elif A .dtype == torch .float16 :
65+ lib .cquantize_blockwise_cpu_fp16 (
66+ get_ptr (code ),
67+ get_ptr (A ),
68+ get_ptr (absmax ),
69+ get_ptr (out ),
70+ ct .c_longlong (blocksize ),
71+ ct .c_longlong (n ),
72+ )
5573 else :
74+ # Generic fallback for other dtypes
75+ A_flat = A .reshape (n ).float ()
5676 rem = n % blocksize
5777 has_rem = rem > 0
58- blocks = n // blocksize + has_rem
59- absmax = torch .zeros ((blocks ,), device = A .device , dtype = torch .float32 )
60- A_reshaped = A .reshape (n )
61- A_com = A_reshaped [: n - rem ]
78+ A_com = A_flat [: n - rem ]
6279 A_com_reshaped = A_com .reshape (n // blocksize , blocksize )
6380 absmax [: blocks - has_rem ] = torch .abs (A_com_reshaped ).max (dim = - 1 )[0 ]
6481 scaled_A = torch .clamp (A_com_reshaped * (1 / absmax [: blocks - has_rem ].view (- 1 , 1 )), - 1 , 1 )
6582 scaled_A = scaled_A .reshape (- 1 )
6683 if has_rem :
67- absmax [- 1 ] = torch .abs (A_reshaped [n - rem :]).max ()
68- scaled_A_rem = torch .clamp (A_reshaped [n - rem :] * (1 / absmax [- 1 ]), - 1 , 1 )
84+ absmax [- 1 ] = torch .abs (A_flat [n - rem :]).max ()
85+ scaled_A_rem = torch .clamp (A_flat [n - rem :] * (1 / absmax [- 1 ]), - 1 , 1 )
6986 scaled_A = torch .cat ([scaled_A , scaled_A_rem ], dim = 0 )
7087
7188 diff = torch .abs (scaled_A .unsqueeze (- 1 ) - code .to (scaled_A .device ))
@@ -248,19 +265,24 @@ def _(
248265 code : torch .Tensor ,
249266 blocksize : int ,
250267 ) -> torch .Tensor :
251- assert B .dtype == torch .uint8 , "Only support uint8 qweight"
268+ if B .dtype != torch .uint8 :
269+ B = B .contiguous ().view (torch .uint8 )
252270 dtype = A .dtype
253271 quant_type = "fp4" if code [1 ] > 0 else "nf4"
254272 # cpu fused op only support bf16 for now.
255273 if dtype != torch .bfloat16 :
256274 A = A .to (torch .bfloat16 )
275+ if absmax .dtype != torch .bfloat16 :
276+ absmax = absmax .to (torch .bfloat16 )
257277
258278 final_out_shape = (* A .shape [:- 1 ], shapeB [0 ])
259279 A = A .reshape (- 1 , A .shape [- 1 ])
260280 out_shape = (* A .shape [:- 1 ], shapeB [0 ])
261281 if gemm_4bit_forward_kernel is not None :
262282 quant_type_num = 1 if quant_type == "fp4" else 0
263- out = gemm_4bit_forward_kernel (A , B , absmax , blocksize , quant_type_num )
283+ # C++ kernel expects weight shape (N, K_packed), ensure 2D contiguous
284+ B_2d = B .reshape (shapeB [0 ], - 1 ).contiguous ()
285+ out = gemm_4bit_forward_kernel (A , B_2d , absmax , blocksize , quant_type_num )
264286 else :
265287 out = torch .empty (out_shape , dtype = A .dtype , device = A .device )
266288 M = A .shape [0 ]
@@ -299,3 +321,262 @@ def _(
299321 out = out .to (dtype )
300322
301323 return out .reshape (final_out_shape )
324+
325+
326+ # ==================== CPU Optimizer Kernels ====================
327+
328+
329+ def _compute_update_norm_and_scale (
330+ update : torch .Tensor ,
331+ unorm_vec : Optional [torch .Tensor ],
332+ max_unorm : float ,
333+ param_norm : float ,
334+ ) -> float :
335+ """Compute trust-ratio scaling factor for LAMB/LARS and store update norm."""
336+ if max_unorm <= 0.0 :
337+ return 1.0
338+ unorm = torch .norm (update ).item ()
339+ if unorm_vec is not None :
340+ unorm_vec .fill_ (unorm )
341+ if unorm > max_unorm * param_norm :
342+ return (max_unorm * param_norm ) / unorm
343+ return 1.0
344+
345+
@torch.no_grad()
def _optimizer_update_32bit_cpu(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros: bool = False,
) -> None:
    """Pure-PyTorch CPU implementation of the 32-bit optimizer update.

    Applies one optimizer step to ``p`` in place and mutates the optimizer
    state tensors (``state1``/``state2``) in place. All math is done in fp32;
    the result is copied back in the parameter's original dtype.

    Args:
        optimizer_name: one of "adam", "lamb", "ademamix", "momentum",
            "lars", "lion", "rmsprop", "adagrad".
        g: gradient tensor; promoted to fp32 and scaled by ``gnorm_scale``.
        p: parameter tensor, updated in place.
        state1: first optimizer state. For "ademamix" it holds two momentum
            EMAs stacked along dim 0 (``state1[0]``/``state1[1]``).
        state2: second optimizer state (squared-gradient EMA) or None for
            single-state optimizers.
        unorm_vec: optional buffer receiving the update norm when
            ``max_unorm > 0`` (LAMB/LARS trust-ratio clipping).
        max_unorm: trust-ratio clipping threshold; <= 0 disables clipping.
        param_norm: parameter norm used by the trust-ratio clipping.
        beta1: first EMA coefficient.
        beta2: second EMA coefficient (or Lion's state EMA coefficient).
        beta3: slow-EMA coefficient; used only by "ademamix".
        alpha: slow-EMA mixing weight; used only by "ademamix".
        eps: numerical-stability term added to denominators.
        weight_decay: decoupled weight decay, applied as
            ``p *= 1 - lr * weight_decay`` (or coupled via the gradient for
            momentum/LARS/rmsprop/adagrad).
        step: 1-based step count, used for bias correction.
        lr: learning rate.
        gnorm_scale: gradient scaling factor (e.g. from global-norm clipping).
        skip_zeros: accepted for interface compatibility; not implemented here.

    Raises:
        ValueError: if ``optimizer_name`` is not supported.
    """
    # Work in fp32 regardless of parameter/gradient dtype.
    g_float = g.float() * gnorm_scale
    p_float = p.data.float()

    if optimizer_name in ("adam", "lamb"):
        # Adam / LAMB (2-state): state1 = m (first moment), state2 = v (second moment).
        state1.mul_(beta1).add_(g_float, alpha=1.0 - beta1)
        state2.mul_(beta2).addcmul_(g_float, g_float, value=1.0 - beta2)

        correction1 = 1.0 - beta1**step
        correction2 = math.sqrt(1.0 - beta2**step)
        step_size = -lr * correction2 / correction1

        # Decoupled (AdamW-style) weight decay.
        if weight_decay > 0.0:
            p_float.mul_(1.0 - lr * weight_decay)

        # eps is scaled by correction2 so that
        #   step_size * update == -lr * m_hat / (sqrt(v_hat) + eps)
        # after the bias corrections folded into step_size.
        update = state1 / (state2.sqrt() + eps * correction2)

        # LAMB trust-ratio clipping (no-op when max_unorm <= 0).
        update_scale = _compute_update_norm_and_scale(update, unorm_vec, max_unorm, param_norm)
        p_float.add_(update, alpha=step_size * update_scale)

    elif optimizer_name == "ademamix":
        # AdEMAMix: state1 shape is (2, *p.shape) with state1[0]=fast EMA (m1),
        # state1[1]=slow EMA (m2); state2 is the squared-gradient EMA (nu).
        m1 = state1[0]
        m2 = state1[1]
        nu = state2

        m1.mul_(beta1).add_(g_float, alpha=1.0 - beta1)
        m2.mul_(beta3).add_(g_float, alpha=1.0 - beta3)
        nu.mul_(beta2).addcmul_(g_float, g_float, value=1.0 - beta2)

        # Bias corrections apply to m1 and nu only; the slow EMA m2 is mixed in raw.
        correction1 = 1.0 - beta1**step
        correction2 = math.sqrt(1.0 - beta2**step)

        if weight_decay > 0.0:
            p_float.mul_(1.0 - lr * weight_decay)

        mixed_momentum = (m1 / correction1) + (alpha * m2)
        adaptive_term = (nu.sqrt() / correction2) + eps
        p_float.add_(mixed_momentum / adaptive_term, alpha=-lr)

    elif optimizer_name in ("momentum", "lars"):
        # SGD with momentum / LARS (1-state). Weight decay is coupled
        # (added to the gradient, not decoupled from it).
        g_wd = g_float.add(p_float, alpha=weight_decay) if weight_decay > 0.0 else g_float

        # First step seeds the buffer; afterwards: s = beta1 * s + g (no dampening).
        if step == 1:
            state1.copy_(g_wd)
        else:
            state1.mul_(beta1).add_(g_wd)

        # LARS trust-ratio clipping (no-op when max_unorm <= 0).
        update_scale = _compute_update_norm_and_scale(state1, unorm_vec, max_unorm, param_norm)
        p_float.add_(state1, alpha=-lr * update_scale)

    elif optimizer_name == "lion":
        # Lion (1-state sign update): state1 is the momentum EMA.
        if weight_decay > 0.0:
            p_float.mul_(1.0 - lr * weight_decay)

        # The update direction interpolates with beta1 *before* the state is
        # refreshed with beta2 below — order matters here.
        update = state1.mul(beta1).add(g_float, alpha=1.0 - beta1)
        p_float.add_(update.sign(), alpha=-lr)

        state1.mul_(beta2).add_(g_float, alpha=1.0 - beta2)

    elif optimizer_name == "rmsprop":
        # RMSprop (1-state): state1 is the squared-gradient EMA.
        g_wd = g_float.add(p_float, alpha=weight_decay) if weight_decay > 0.0 else g_float
        state1.mul_(beta1).addcmul_(g_wd, g_wd, value=1.0 - beta1)

        update = g_wd / (state1.sqrt() + eps)
        update_scale = _compute_update_norm_and_scale(update, unorm_vec, max_unorm, param_norm)
        p_float.add_(update, alpha=-lr * update_scale)

    elif optimizer_name == "adagrad":
        # Adagrad (1-state): state1 accumulates squared gradients without decay.
        g_wd = g_float.add(p_float, alpha=weight_decay) if weight_decay > 0.0 else g_float
        state1.addcmul_(g_wd, g_wd, value=1.0)

        update = g_wd / (state1.sqrt() + eps)
        p_float.add_(update, alpha=-lr)

    else:
        raise ValueError(f"Unsupported optimizer for CPU: {optimizer_name}")

    # Write back to original precision
    p.data.copy_(p_float)
451+
452+
453+ register_kernel ("bitsandbytes::optimizer_update_32bit" , "cpu" )(_optimizer_update_32bit_cpu )
454+
455+
@torch.no_grad()
def _dequant_blockwise_fp32_direct(
    A_uint8: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
    """Dequantize a blockwise-quantized uint8 tensor to a new fp32 tensor.

    Thin convenience wrapper over the registered
    ``bitsandbytes::dequantize_blockwise`` op with the dtype pinned to fp32.
    """
    dequantize = torch.ops.bitsandbytes.dequantize_blockwise
    return dequantize(A_uint8, absmax, code, blocksize, torch.float32)
461+
462+
@torch.no_grad()
def _quant_blockwise_fp32_direct(
    A_fp32: torch.Tensor, code: torch.Tensor, absmax_out: torch.Tensor, out_uint8: torch.Tensor, blocksize: int
) -> None:
    """Blockwise-quantize ``A_fp32`` into preallocated output buffers.

    Calls the registered ``bitsandbytes::quantize_blockwise`` op and copies the
    quantized codes into ``out_uint8`` and the per-block scales into
    ``absmax_out`` in place.

    FIX: decorated with ``@torch.no_grad()`` to match the sibling
    ``_dequant_blockwise_fp32_direct`` wrapper — the in-place ``copy_`` into
    optimizer state buffers must never participate in autograd.
    """
    out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A_fp32, code, blocksize)
    out_uint8.copy_(out)
    absmax_out.copy_(absmax)
469+
470+
@torch.no_grad()
def _optimizer_update_8bit_blockwise_cpu(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: torch.Tensor,
    qmap2: Optional[torch.Tensor],
    absmax1: torch.Tensor,
    absmax2: Optional[torch.Tensor],
    weight_decay: float,
    gnorm_scale: float,
    skip_zeros: bool = False,
) -> None:
    """Pure-PyTorch CPU implementation of the 8-bit blockwise optimizer update.

    Dequantizes the uint8 optimizer state to fp32, applies one optimizer step
    to ``p`` in place, then re-quantizes the updated state back into the
    original uint8/absmax buffers in place.

    Args:
        optimizer_name: one of "adam", "lamb", "ademamix", "momentum",
            "lars", "lion", "rmsprop", "adagrad".
        g: gradient tensor; promoted to fp32 and scaled by ``gnorm_scale``.
        p: parameter tensor, updated in place.
        state1: quantized first state. For "ademamix" it packs two momentum
            EMAs as (2, *p.shape).
        state2: quantized second state (squared-gradient EMA) or None.
        qmap1 / qmap2: quantization code maps for state1 / state2.
        absmax1 / absmax2: per-block scales for state1 / state2; for
            "ademamix" absmax1 may be 2-D (one row per packed EMA).
        beta1 / beta2 / beta3 / alpha: optimizer coefficients; beta3 and
            alpha are only used by "ademamix".
        eps: numerical-stability term added to denominators.
        weight_decay: weight decay (decoupled for adam/ademamix/lion,
            coupled via the gradient otherwise).
        step: 1-based step count, used for bias correction.
        lr: learning rate.
        gnorm_scale: gradient scaling factor.
        skip_zeros: accepted for interface compatibility; not implemented here.

    Raises:
        ValueError: if ``optimizer_name`` is not supported.
    """
    # Block size of the 8-bit state quantization (matches the CUDA kernels).
    blocksize = 256

    # AdEMAMix packs two EMAs into state1; when absmax1 is 2-D each slice was
    # quantized independently and must be (de)quantized per-slice.
    split_ademamix = optimizer_name == "ademamix" and absmax1.ndim == 2

    # --- Dequantize state to fp32 -----------------------------------------
    if split_ademamix:
        s1_1 = _dequant_blockwise_fp32_direct(state1[0], absmax1[0], qmap1, blocksize)
        s1_2 = _dequant_blockwise_fp32_direct(state1[1], absmax1[1], qmap1, blocksize)
        state1_fp32 = torch.stack([s1_1, s1_2])
    else:
        state1_fp32 = _dequant_blockwise_fp32_direct(state1, absmax1, qmap1, blocksize)

    state2_fp32 = None
    if state2 is not None and qmap2 is not None and absmax2 is not None:
        state2_fp32 = _dequant_blockwise_fp32_direct(state2, absmax2, qmap2, blocksize)

    grad = g.float() * gnorm_scale
    p_fp32 = p.data.float()

    if optimizer_name in ("adam", "lamb"):
        # state1 = m (first moment), state2 = v (second moment).
        state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
        state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

        correction1 = 1.0 - beta1**step
        correction2 = math.sqrt(1.0 - beta2**step)

        # Bias-corrected denominator: sqrt(v_hat) + eps.
        denom = (state2_fp32.sqrt() / correction2).add_(eps)
        # Decoupled (AdamW-style) weight decay.
        if weight_decay > 0.0:
            p_fp32.mul_(1.0 - lr * weight_decay)
        p_fp32.addcdiv_(state1_fp32, denom, value=-lr / correction1)

    elif optimizer_name == "ademamix":
        # m1 = fast EMA, m2 = slow EMA (views into state1_fp32), nu = second moment.
        m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1]
        nu_fp32 = state2_fp32

        # In-place updates on the views mutate state1_fp32 directly, so no
        # re-stacking is needed before re-quantization.
        m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
        m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3)
        nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

        correction1 = 1.0 - beta1**step
        correction2 = math.sqrt(1.0 - beta2**step)

        update = (m1_fp32 / correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / correction2 + eps)
        if weight_decay > 0.0:
            p_fp32.mul_(1.0 - lr * weight_decay)
        p_fp32.add_(update, alpha=-lr)

    elif optimizer_name in ("momentum", "lars"):
        # Coupled weight decay: fold it into the gradient.
        if weight_decay > 0.0:
            grad.add_(p_fp32, alpha=weight_decay)
        # First step seeds the buffer; afterwards: s = beta1 * s + g (no dampening).
        if step == 1:
            state1_fp32.copy_(grad)
        else:
            state1_fp32.mul_(beta1).add_(grad)
        p_fp32.add_(state1_fp32, alpha=-lr)

    elif optimizer_name == "lion":
        if weight_decay > 0.0:
            p_fp32.mul_(1.0 - lr * weight_decay)

        # Direction uses beta1 interpolation *before* the state is refreshed
        # with beta2 below — order matters.
        update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1))
        p_fp32.add_(update_dir, alpha=-lr)

        state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2)

    elif optimizer_name == "rmsprop":
        if weight_decay > 0.0:
            grad.add_(p_fp32, alpha=weight_decay)
        state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1)
        p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)

    elif optimizer_name == "adagrad":
        if weight_decay > 0.0:
            grad.add_(p_fp32, alpha=weight_decay)
        state1_fp32.addcmul_(grad, grad, value=1.0)
        p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)

    else:
        raise ValueError(f"Unsupported optimizer for CPU 8-bit: {optimizer_name}")

    p.data.copy_(p_fp32)

    # --- Re-quantize updated state back into the uint8 buffers ------------
    # BUGFIX: mirror the dequantization layout. Previously the ademamix branch
    # indexed absmax1[0]/absmax1[1] unconditionally; with a 1-D absmax1 that
    # selects scalar elements and corrupts the re-quantized state.
    if split_ademamix:
        _quant_blockwise_fp32_direct(state1_fp32[0], qmap1, absmax1[0], state1[0], blocksize)
        _quant_blockwise_fp32_direct(state1_fp32[1], qmap1, absmax1[1], state1[1], blocksize)
    else:
        _quant_blockwise_fp32_direct(state1_fp32, qmap1, absmax1, state1, blocksize)
    if state2_fp32 is not None:
        _quant_blockwise_fp32_direct(state2_fp32, qmap2, absmax2, state2, blocksize)
580+
581+
582+ register_kernel ("bitsandbytes::optimizer_update_8bit_blockwise" , "cpu" )(_optimizer_update_8bit_blockwise_cpu )
0 commit comments