Skip to content

Commit 4019407

Browse files
committed
enable xpu lars
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 2ad0744 commit 4019407

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

bitsandbytes/backends/default/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,11 @@ def _(
357357

358358
name2optimizer_id = {
359359
"momentum": MOMENTUM,
360+
"lars": MOMENTUM,
360361
"rmsprop": RMSPROP,
361362
"adagrad": ADAGRAD,
362363
"adam": ADAM,
364+
"lamb": ADAM,
363365
"lion": LION,
364366
"ademamix": ADEMAMIX,
365367
}

bitsandbytes/backends/triton/kernels_optim.py

100644100755
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424

2525
name2optimizer_id = {
2626
"momentum": MOMENTUM,
27+
"lars": MOMENTUM,
2728
"rmsprop": RMSPROP,
2829
"adagrad": ADAGRAD,
2930
"adam": ADAM,
31+
"lamb": ADAM,
3032
"lion": LION,
3133
"ademamix": ADEMAMIX,
3234
}
@@ -313,6 +315,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
313315
"preprocess": _optimizer_precondition_2state_32bit,
314316
"update": _optimizer_update_2state_32bit_triton_kernel,
315317
},
318+
"lamb": {
319+
"preprocess": _optimizer_precondition_2state_32bit,
320+
"update": _optimizer_update_2state_32bit_triton_kernel,
321+
},
316322
"ademamix": {
317323
"preprocess": _optimizer_precondition_2state_32bit,
318324
"update": _optimizer_update_2state_32bit_triton_kernel,
@@ -321,6 +327,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
321327
"preprocess": _optimizer_precondition_1state_32bit,
322328
"update": _optimizer_update_1state_32bit_triton_kernel,
323329
},
330+
"lars": {
331+
"preprocess": _optimizer_precondition_1state_32bit,
332+
"update": _optimizer_update_1state_32bit_triton_kernel,
333+
},
324334
"rmsprop": {
325335
"preprocess": _optimizer_precondition_1state_32bit,
326336
"update": _optimizer_update_1state_32bit_triton_kernel,
@@ -1065,9 +1075,11 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
10651075

10661076
name2optimizer_fn = {
10671077
"momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel,
1078+
"lars": _optimizer_update_1state_8bit_blockwise_triton_kernel,
10681079
"rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel,
10691080
"adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel,
10701081
"adam": _optimizer_update_2state_8bit_blockwise_triton_kernel,
1082+
"lamb": _optimizer_update_2state_8bit_blockwise_triton_kernel,
10711083
"lion": _optimizer_update_1state_8bit_blockwise_triton_kernel,
10721084
"ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel,
10731085
}

0 commit comments

Comments
 (0)