
Commit dd599c2

Fix LARS/LAMB optimizer support and non-contiguous tensor handling on XPU (#1902)
* enable lars on XPU
* fix optimizer
* fix contiguous

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 88e802c commit dd599c2
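For context, a minimal sketch (not part of the commit) of how the LARS/LAMB 8-bit optimizers wired up here might be exercised on an XPU device; the model, shapes, and hyperparameters are illustrative assumptions.

import torch
import bitsandbytes as bnb

# Illustrative setup: assumes a PyTorch build with Intel GPU ("xpu") support.
device = "xpu"
model = torch.nn.Linear(4096, 4096, device=device)

# LARS dispatches to the MOMENTUM kernel path and LAMB to the ADAM path (see the diffs below).
opt = bnb.optim.LARS8bit(model.parameters(), lr=1e-2, momentum=0.9)
# opt = bnb.optim.LAMB8bit(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 4096, device=device)).sum()
loss.backward()
opt.step()
opt.zero_grad()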

3 files changed: +18 −3 lines

bitsandbytes/backends/default/ops.py

Lines changed: 2 additions & 0 deletions
@@ -357,9 +357,11 @@ def _(
 
     name2optimizer_id = {
         "momentum": MOMENTUM,
+        "lars": MOMENTUM,
         "rmsprop": RMSPROP,
         "adagrad": ADAGRAD,
         "adam": ADAM,
+        "lamb": ADAM,
         "lion": LION,
         "ademamix": ADEMAMIX,
     }
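The mapping can reuse existing kernel ids because LARS keeps the same single momentum state as SGD-with-momentum and LAMB the same two Adam states; only the layer-wise scaling differs. A rough conceptual sketch of the LARS idea (illustrative pseudo-implementation, not the library's kernel; the trust-ratio form and constants are assumptions):

import torch

def lars_like_step(p, g, buf, lr=1e-2, beta=0.9, trust_coeff=0.02, eps=1e-12):
    # State update is identical to momentum SGD, hence the shared MOMENTUM kernel id.
    buf.mul_(beta).add_(g)
    # Layer-wise trust ratio scales the step by parameter norm / update norm (assumed form).
    trust = trust_coeff * p.norm() / (buf.norm() + eps)
    p.add_(buf, alpha=-(lr * float(trust)))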

bitsandbytes/backends/triton/kernels_optim.py

Lines changed: 14 additions & 1 deletion
@@ -24,9 +24,11 @@
 
 name2optimizer_id = {
     "momentum": MOMENTUM,
+    "lars": MOMENTUM,
     "rmsprop": RMSPROP,
     "adagrad": ADAGRAD,
     "adam": ADAM,
+    "lamb": ADAM,
     "lion": LION,
     "ademamix": ADEMAMIX,
 }
@@ -121,7 +123,8 @@ def _optimizer_precondition_1state_32bit(
 
     if OPTIMIZER_ID == 0: # MOMENTUM
         if step == 1:
-            s1_vals = g_vals
+            # Cast to fp32 to avoid type mismatch: s1_vals is fp32 but g_vals may be fp16.
+            s1_vals = g_vals.to(tl.float32)
         else:
             s1_vals = s1_vals * beta1 + g_vals
         update_norm = s1_vals * s1_vals
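The cast matters because the optimizer state is kept in fp32 while gradients may be loaded as fp16/bf16, and Triton generally requires both branches of a runtime conditional to produce values of the same element type. A plain-torch analogue of the mismatch (illustrative only, not the kernel):

import torch

g = torch.randn(4, dtype=torch.float16)   # gradient values, possibly half precision
s1 = torch.zeros(4, dtype=torch.float32)  # optimizer state is kept in fp32

first_step = g                    # fp16: gives the state a different type on step 1
later_steps = s1 * 0.9 + g        # fp32 via promotion on every other step
fixed = g.to(torch.float32)       # matches the fp32 state, as the kernel now does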
@@ -313,6 +316,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
         "preprocess": _optimizer_precondition_2state_32bit,
         "update": _optimizer_update_2state_32bit_triton_kernel,
     },
+    "lamb": {
+        "preprocess": _optimizer_precondition_2state_32bit,
+        "update": _optimizer_update_2state_32bit_triton_kernel,
+    },
     "ademamix": {
         "preprocess": _optimizer_precondition_2state_32bit,
         "update": _optimizer_update_2state_32bit_triton_kernel,
@@ -321,6 +328,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
         "preprocess": _optimizer_precondition_1state_32bit,
         "update": _optimizer_update_1state_32bit_triton_kernel,
     },
+    "lars": {
+        "preprocess": _optimizer_precondition_1state_32bit,
+        "update": _optimizer_update_1state_32bit_triton_kernel,
+    },
     "rmsprop": {
         "preprocess": _optimizer_precondition_1state_32bit,
         "update": _optimizer_update_1state_32bit_triton_kernel,
@@ -1065,9 +1076,11 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
 
 name2optimizer_fn = {
     "momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel,
+    "lars": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "adam": _optimizer_update_2state_8bit_blockwise_triton_kernel,
+    "lamb": _optimizer_update_2state_8bit_blockwise_triton_kernel,
     "lion": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel,
 }

bitsandbytes/backends/triton/ops.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
     torch._check_is_size(blocksize)
     # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
     with torch_accelerator_module.device(A.device):
-        out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize)
+        out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A.contiguous(), code, blocksize)
     return out, absmax.float()
@@ -30,7 +30,7 @@ def dequantize_blockwise(
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
     with torch_accelerator_module.device(A.device):
         out = kernels_8bit_quant.dequant_8bit_blockwise(
-            A,
+            A.contiguous(),
             absmax,
             code,
             blocksize,
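The .contiguous() calls matter because the Triton kernels address memory with flat block offsets, so a non-contiguous view (for example a transposed slice) would otherwise be read in the wrong order. A minimal sketch of the case this covers, assuming an available XPU device and the public blockwise-quantization helpers:

import torch
import bitsandbytes.functional as F

# A transposed view is non-contiguous; the backend now copies it to contiguous
# memory before handing it to the Triton kernels. Device "xpu" is an assumption.
A = torch.randn(64, 128, device="xpu", dtype=torch.float32).t()
assert not A.is_contiguous()

quantized, state = F.quantize_blockwise(A, blocksize=256)
restored = F.dequantize_blockwise(quantized, state)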
