Skip to content

Commit 857d562

Browse files
TimDettmers and claude
committed
Fix Params4bit/QuantState attribute access for FSDP state_dict traversal (#1405)
PyTorch's FSDP state_dict machinery (_get_fqns) resolves dotted FQN paths via getattr. For 4-bit quantized models, state_dict keys like "weight.absmax" and "weight.quant_state.bitsandbytes__nf4" require attribute access on Params4bit and QuantState objects that previously didn't exist. Add __getattr__ to Params4bit that proxies known QuantState attributes (including the quant_map→code alias used by as_dict serialization), and add __getattr__ to QuantState that handles the packed "bitsandbytes__*" keys. This allows FSDP's get_model_state_dict with cpu_offload=True to traverse the full FQN namespace without AttributeError. Verified with single-GPU FSDP integration test: without fix, fails with "'Params4bit' object has no attribute 'absmax'"; with fix, successfully produces all 7 state_dict keys. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 88c6c71 commit 857d562

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

bitsandbytes/functional.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,18 @@ def __init__(
487487
self.state2 = state2
488488
self.nested = state2 is not None
489489

490+
def __getattr__(self, name):
    """Resolve packed serialization keys such as ``bitsandbytes__nf4``.

    PyTorch's FSDP state_dict traversal (``_get_fqns``) walks dotted FQN
    paths with ``getattr``; the packed key "quant_state.bitsandbytes__nf4"
    therefore results in ``getattr(quant_state_obj, "bitsandbytes__nf4")``,
    which ordinary attribute lookup cannot satisfy — we answer it here from
    the packed ``as_dict`` serialization.

    Raises:
        AttributeError: for any name that is not a packed serialization key.
    """
    if name.startswith("bitsandbytes__"):
        serialized = self.as_dict(packed=True)
        try:
            return serialized["quant_state." + name]
        except KeyError:
            pass
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
501+
490502
def __getitem__(self, idx):
491503
"""
492504
ensures compatibility with older quant state scheme with nested lists.

bitsandbytes/nn/modules.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,43 @@ def __setstate__(self, state):
256256
self.bnb_quantized = state["bnb_quantized"]
257257
self.module = state["module"]
258258

259+
# Map from state_dict key names (as produced by QuantState.as_dict) to
# the actual QuantState attribute/access path. FSDP's _get_fqns() resolves
# dotted FQN keys via getattr, so "weight.quant_map" becomes
# getattr(weight, "quant_map") — we must map that to quant_state.code.
# Each value is a one-argument callable taking the QuantState; the
# "nested_*" accessors dereference state2 and therefore raise
# AttributeError when state2 is None (statistics not compressed), which
# the caller is expected to turn into a normal missing-attribute error.
_QUANT_STATE_ATTR_MAP = {
    # Direct QuantState attributes
    "absmax": lambda qs: qs.absmax,
    "code": lambda qs: qs.code,
    "blocksize": lambda qs: qs.blocksize,
    "dtype": lambda qs: qs.dtype,
    "shape": lambda qs: qs.shape,
    "offset": lambda qs: qs.offset,
    "state2": lambda qs: qs.state2,
    # as_dict serializes code → "quant_map"
    "quant_map": lambda qs: qs.code,
    "quant_type": lambda qs: qs.quant_type,
    # as_dict serializes nested state2 attributes under "nested_*" keys
    "nested_absmax": lambda qs: qs.state2.absmax,
    "nested_blocksize": lambda qs: qs.state2.blocksize,
    "nested_quant_map": lambda qs: qs.state2.code,
    "nested_dtype": lambda qs: qs.state2.dtype,
    # NOTE(review): intentionally the *outer* offset, presumably mirroring
    # how as_dict serializes "nested_offset" — confirm against
    # QuantState.as_dict before changing to qs.state2.offset.
    "nested_offset": lambda qs: qs.offset,
}
282+
283+
def __getattr__(self, name):
    """Expose QuantState-derived attributes under their state_dict key names.

    PyTorch's FSDP state_dict machinery traverses fully-qualified names
    (e.g. "weight.absmax") via getattr, so every key that QuantState.as_dict
    emits must be reachable as an attribute of this parameter object.

    Raises:
        AttributeError: when the name is not a known proxied key, when no
            quant_state is set yet, or when the underlying accessor fails.
    """
    resolver = self._QUANT_STATE_ATTR_MAP.get(name)
    # Read via __dict__ so a missing quant_state cannot re-enter __getattr__.
    state = None if resolver is None else self.__dict__.get("quant_state")
    if state is not None:
        try:
            return resolver(state)
        except AttributeError:
            # e.g. a "nested_*" accessor when statistics are not compressed.
            pass
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
295+
259296
def __deepcopy__(self, memo):
260297
new_instance = type(self).__new__(type(self))
261298
state = self.__getstate__()

tests/test_linear4bit.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,69 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
431431
grad_compiled = x.grad.clone()
432432

433433
torch.testing.assert_close(grad_compiled, grad_ref)
434+
435+
436+
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_quant_state_attr_access(device, quant_type, compress_statistics):
    """Test that Params4bit proxies QuantState attributes for FSDP state_dict traversal (#1405).

    PyTorch's FSDP state_dict machinery traverses FQN paths like
    'model.layers.0.weight.absmax' using getattr(). This test verifies
    that Params4bit and QuantState expose the attributes that appear as
    state_dict keys so that _get_fqns() traversal succeeds.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    layer = bnb.nn.Linear4bit(
        64,
        64,
        bias=False,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
    ).to(device)
    w = layer.weight

    assert w.quant_state is not None, "quant_state should be set after quantization"

    # Direct QuantState attributes proxied through Params4bit.
    assert torch.equal(w.absmax, w.quant_state.absmax)
    assert torch.equal(w.code, w.quant_state.code)

    # "quant_map" is how as_dict() serializes "code" — FSDP uses this key name.
    assert torch.equal(w.quant_map, w.quant_state.code)

    # QuantState packed key: as_dict(packed=True) produces
    # "quant_state.bitsandbytes__<type>", which FSDP resolves as
    # getattr(quant_state_obj, "bitsandbytes__<type>").
    packed_attr = f"bitsandbytes__{quant_type}"
    assert hasattr(w.quant_state, packed_attr)
    assert isinstance(getattr(w.quant_state, packed_attr), torch.Tensor)

    # Simulate the full FSDP _get_fqns traversal for every state_dict key.
    for key in w.quant_state.as_dict(packed=True):
        # Keys are relative to "weight.", e.g. "absmax" or "quant_state.bitsandbytes__nf4".
        resolved = w
        for part in key.split("."):
            resolved = getattr(resolved, part)
        assert resolved is not None

    # hasattr should return True for proxied attrs, False for unknown ones.
    for proxied in ("absmax", "code", "quant_map"):
        assert hasattr(w, proxied)
    assert not hasattr(w, "nonexistent_attribute")

    # Unknown attributes must still raise AttributeError.
    with pytest.raises(AttributeError, match="nonexistent_attribute"):
        _ = w.nonexistent_attribute

    # Verify that normal Params4bit attributes are unaffected by __getattr__.
    assert isinstance(w.quant_state, bnb.functional.QuantState)
    assert isinstance(w.bnb_quantized, bool)
    assert w.bnb_quantized is True

0 commit comments

Comments
 (0)