Skip to content

Commit 3ef6032

Browse files
TimDettmers and claude committed
Fix 32-bit error message and add LARS test coverage
Fix the error message in _optimizer_update_32bit_impl that incorrectly displayed str2optimizer8bit_blockwise keys instead of str2optimizer32bit keys. Add LARS to the parametrized 32-bit optimizer tests using PytorchLARS as the reference implementation, with bf16 skip since momentum kernels lack a bf16 variant. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4ee7547 commit 3ef6032

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ def _optimizer_update_32bit_impl(
637637
optim_fns = str2optimizer32bit.get(optimizer_name, None)
638638
if optim_fns is None:
639639
raise ValueError(
640-
f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
640+
f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer32bit.keys())}"
641641
)
642642
if g.dtype == torch.float32:
643643
optim_func = optim_fns[0]

tests/test_optim.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def rm_path(path):
102102
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
103103
)
104104

105+
str2optimizers["lars"] = (
106+
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
107+
lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
108+
)
109+
105110
str2optimizers["rmsprop"] = (
106111
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
107112
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
@@ -118,6 +123,7 @@ def rm_path(path):
118123
str2statenames["lion"] = [("exp_avg", "state1")]
119124
str2statenames["paged_lion"] = [("exp_avg", "state1")]
120125
str2statenames["momentum"] = [("momentum_buffer", "state1")]
126+
str2statenames["lars"] = [("momentum_buffer", "state1")]
121127
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
122128
str2statenames["rmsprop"] = [("square_avg", "state1")]
123129

@@ -155,6 +161,7 @@ def rm_path(path):
155161
"paged_adamw",
156162
"paged_adam",
157163
"momentum",
164+
"lars",
158165
"rmsprop",
159166
"lion",
160167
"paged_lion",
@@ -181,7 +188,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
181188
if optim_name.startswith("paged_") and device == "xpu":
182189
pytest.skip("Paged optimizers are not supported on XPU currently.")
183190

184-
if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
191+
if gtype == torch.bfloat16 and optim_name in ["momentum", "lars", "rmsprop"]:
185192
pytest.skip()
186193
if dim1 == 1 and dim2 == 1:
187194
return

0 commit comments

Comments (0)