Skip to content

Commit 997510d

Browse files
TimDettmers and claude
committed
Fix GlobalOptimManager.override_config not propagating to optimizer (#1269)
override_config only wrote to pid2config, but get_config only read from index2config. When override_config was called after register_parameters (the documented usage order), the override was never seen by the optimizer. Fix by having get_config also check pid2config as a fallback after index2config, so overrides work regardless of call order. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a2c92f7 commit 997510d

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

bitsandbytes/optim/optimizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,13 @@ def get_config(self, gindex, pindex, group):
352352

353353
if (gindex, pindex) in self.mng.index2config:
354354
config.update(self.mng.index2config[(gindex, pindex)])
355+
356+
# Also check pid2config as a fallback so that override_config works
357+
# regardless of whether it was called before or after register_parameters.
358+
p = self.param_groups[gindex]["params"][pindex]
359+
if id(p) in self.mng.pid2config:
360+
config.update(self.mng.pid2config[id(p)])
361+
355362
return config
356363

357364
def init_state(self, group, p, gindex, pindex):

tests/test_optim.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,39 @@ def test_global_config(dim1, dim2, gtype, device):
302302
assert adam2.state[p3]["state2"].dtype == torch.uint8
303303

304304

305+
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
306+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
307+
def test_override_config_after_register(device):
308+
"""Test that override_config works when called after register_parameters (issue #1269)."""
309+
if device not in ["cuda", "xpu"]:
310+
pytest.skip("Optimizers are only supported on CUDA and XPU")
311+
312+
mng = bnb.optim.GlobalOptimManager.get_instance()
313+
mng.initialize()
314+
315+
p1 = torch.randn(64, 64, device="cpu") * 0.1
316+
p2 = torch.randn(64, 64, device="cpu") * 0.1
317+
318+
# Register first, override second (the documented order)
319+
mng.register_parameters([p1, p2])
320+
p1 = p1.to(device)
321+
p2 = p2.to(device)
322+
323+
# Override p2 to use 8-bit after register_parameters
324+
mng.override_config(p2, "optim_bits", 8)
325+
326+
adam = bnb.optim.Adam([p1, p2], lr=0.001, optim_bits=32)
327+
328+
# Run a step to trigger init_state
329+
p1.grad = torch.randn_like(p1) * 0.1
330+
p2.grad = torch.randn_like(p2) * 0.1
331+
adam.step()
332+
333+
# p1 should be 32-bit, p2 should be 8-bit
334+
assert adam.state[p1]["state1"].dtype == torch.float32
335+
assert adam.state[p2]["state1"].dtype == torch.uint8
336+
337+
305338
optimizer_names_8bit = [
306339
"adam8bit_blockwise",
307340
"lion8bit_blockwise",

0 commit comments

Comments (0)