Skip to content

Commit 91f0cef

Browse files
Mr-Neutr0n, TimDettmers, claude
authored
Add LARS to str2optimizer32bit dictionary (#1855)
* Add LARS to str2optimizer32bit dictionary LARS optimizer was missing from str2optimizer32bit, causing KeyError when using LARS32bit optimizer. LARS uses momentum-based kernels since it's essentially SGD with momentum plus layerwise adaptive learning rates. Fixes #1810 * Fix 32-bit error message and add LARS test coverage Fix the error message in _optimizer_update_32bit_impl that incorrectly displayed str2optimizer8bit_blockwise keys instead of str2optimizer32bit keys. Add LARS to the parametrized 32-bit optimizer tests using PytorchLARS as the reference implementation, with bf16 skip since momentum kernels lack a bf16 variant. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Tim Dettmers <tim.dettmers@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 577e7b5 commit 91f0cef

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,10 @@ def _gemv_4bit_impl(
578578
lib.cademamix32bit_grad_fp16,
579579
lib.cademamix32bit_grad_bf16,
580580
),
581+
"lars": (
582+
lib.cmomentum32bit_grad_32,
583+
lib.cmomentum32bit_grad_16,
584+
),
581585
}
582586

583587
str2optimizer8bit_blockwise = {
@@ -637,7 +641,7 @@ def _optimizer_update_32bit_impl(
637641
optim_fns = str2optimizer32bit.get(optimizer_name, None)
638642
if optim_fns is None:
639643
raise ValueError(
640-
f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
644+
f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer32bit.keys())}"
641645
)
642646
if g.dtype == torch.float32:
643647
optim_func = optim_fns[0]

tests/test_optim.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def rm_path(path):
102102
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
103103
)
104104

105+
str2optimizers["lars"] = (
106+
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
107+
lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
108+
)
109+
105110
str2optimizers["rmsprop"] = (
106111
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
107112
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
@@ -118,6 +123,7 @@ def rm_path(path):
118123
str2statenames["lion"] = [("exp_avg", "state1")]
119124
str2statenames["paged_lion"] = [("exp_avg", "state1")]
120125
str2statenames["momentum"] = [("momentum_buffer", "state1")]
126+
str2statenames["lars"] = [("momentum_buffer", "state1")]
121127
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
122128
str2statenames["rmsprop"] = [("square_avg", "state1")]
123129

@@ -155,6 +161,7 @@ def rm_path(path):
155161
"paged_adamw",
156162
"paged_adam",
157163
"momentum",
164+
"lars",
158165
"rmsprop",
159166
"lion",
160167
"paged_lion",
@@ -181,7 +188,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
181188
if optim_name.startswith("paged_") and device == "xpu":
182189
pytest.skip("Paged optimizers are not supported on XPU currently.")
183190

184-
if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
191+
if gtype == torch.bfloat16 and optim_name in ["momentum", "lars", "rmsprop"]:
185192
pytest.skip()
186193
if dim1 == 1 and dim2 == 1:
187194
return

0 commit comments

Comments (0)