Skip to content

Commit 75ef76b

Browse files
TimDettmers and claude authored
Fix GlobalOptimManager.override_config not propagating to optimizer (#1269) (#1869)
override_config only wrote to pid2config, but get_config only read from index2config. When override_config was called after register_parameters (the documented usage order), the override was never seen by the optimizer. Fix by having get_config also check pid2config as a fallback after index2config, so overrides work regardless of call order. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3934632 commit 75ef76b

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

bitsandbytes/optim/optimizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,13 @@ def get_config(self, gindex, pindex, group):
350350

351351
if (gindex, pindex) in self.mng.index2config:
352352
config.update(self.mng.index2config[(gindex, pindex)])
353+
354+
# Also check pid2config as a fallback so that override_config works
355+
# regardless of whether it was called before or after register_parameters.
356+
p = self.param_groups[gindex]["params"][pindex]
357+
if id(p) in self.mng.pid2config:
358+
config.update(self.mng.pid2config[id(p)])
359+
353360
return config
354361

355362
def init_state(self, group, p, gindex, pindex):

tests/test_optim.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,39 @@ def test_global_config(dim1, dim2, gtype, device):
309309
assert adam2.state[p3]["state2"].dtype == torch.uint8
310310

311311

312+
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
313+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
314+
def test_override_config_after_register(device):
315+
"""Test that override_config works when called after register_parameters (issue #1269)."""
316+
if device not in ["cuda", "xpu"]:
317+
pytest.skip("Optimizers are only supported on CUDA and XPU")
318+
319+
mng = bnb.optim.GlobalOptimManager.get_instance()
320+
mng.initialize()
321+
322+
p1 = torch.randn(64, 64, device="cpu") * 0.1
323+
p2 = torch.randn(64, 64, device="cpu") * 0.1
324+
325+
# Register first, override second (the documented order)
326+
mng.register_parameters([p1, p2])
327+
p1 = p1.to(device)
328+
p2 = p2.to(device)
329+
330+
# Override p2 to use 8-bit after register_parameters
331+
mng.override_config(p2, "optim_bits", 8)
332+
333+
adam = bnb.optim.Adam([p1, p2], lr=0.001, optim_bits=32)
334+
335+
# Run a step to trigger init_state
336+
p1.grad = torch.randn_like(p1) * 0.1
337+
p2.grad = torch.randn_like(p2) * 0.1
338+
adam.step()
339+
340+
# p1 should be 32-bit, p2 should be 8-bit
341+
assert adam.state[p1]["state1"].dtype == torch.float32
342+
assert adam.state[p2]["state1"].dtype == torch.uint8
343+
344+
312345
optimizer_names_8bit = [
313346
"adam8bit_blockwise",
314347
"lion8bit_blockwise",

0 commit comments

Comments (0)