2424
# Lookup table from optimizer name (string) to the optimizer ID constant used for it.
# Several names may resolve to the same ID: "lars" maps to the same ID as
# "momentum" (MOMENTUM), and "lamb" maps to the same ID as "adam" (ADAM).
# NOTE(review): presumably LARS/LAMB differ from momentum/Adam only in the
# arguments the caller passes to the shared kernels - confirm against call sites.
2525name2optimizer_id = {
2626 "momentum" : MOMENTUM ,
27+ "lars" : MOMENTUM ,
2728 "rmsprop" : RMSPROP ,
2829 "adagrad" : ADAGRAD ,
2930 "adam" : ADAM ,
31+ "lamb" : ADAM ,
3032 "lion" : LION ,
3133 "ademamix" : ADEMAMIX ,
3234}
@@ -313,6 +315,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
313315 "preprocess" : _optimizer_precondition_2state_32bit ,
314316 "update" : _optimizer_update_2state_32bit_triton_kernel ,
315317 },
318+ "lamb" : {
319+ "preprocess" : _optimizer_precondition_2state_32bit ,
320+ "update" : _optimizer_update_2state_32bit_triton_kernel ,
321+ },
316322 "ademamix" : {
317323 "preprocess" : _optimizer_precondition_2state_32bit ,
318324 "update" : _optimizer_update_2state_32bit_triton_kernel ,
@@ -321,6 +327,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
321327 "preprocess" : _optimizer_precondition_1state_32bit ,
322328 "update" : _optimizer_update_1state_32bit_triton_kernel ,
323329 },
330+ "lars" : {
331+ "preprocess" : _optimizer_precondition_1state_32bit ,
332+ "update" : _optimizer_update_1state_32bit_triton_kernel ,
333+ },
324334 "rmsprop" : {
325335 "preprocess" : _optimizer_precondition_1state_32bit ,
326336 "update" : _optimizer_update_1state_32bit_triton_kernel ,
@@ -1065,9 +1075,11 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
10651075
# Lookup table from optimizer name (string) to the 8-bit blockwise Triton update kernel.
# Only two kernels are referenced here:
#   - the 1-state kernel for "momentum", "lars", "rmsprop", "adagrad" and "lion";
#   - the 2-state kernel for "adam", "lamb" and "ademamix".
# "lars" and "lamb" reuse the same kernels as "momentum" and "adam" respectively.
10661076name2optimizer_fn = {
10671077 "momentum" : _optimizer_update_1state_8bit_blockwise_triton_kernel ,
1078+ "lars" : _optimizer_update_1state_8bit_blockwise_triton_kernel ,
10681079 "rmsprop" : _optimizer_update_1state_8bit_blockwise_triton_kernel ,
10691080 "adagrad" : _optimizer_update_1state_8bit_blockwise_triton_kernel ,
10701081 "adam" : _optimizer_update_2state_8bit_blockwise_triton_kernel ,
1082+ "lamb" : _optimizer_update_2state_8bit_blockwise_triton_kernel ,
10711083 "lion" : _optimizer_update_1state_8bit_blockwise_triton_kernel ,
10721084 "ademamix" : _optimizer_update_2state_8bit_blockwise_triton_kernel ,
10731085}
0 commit comments