Skip to content

Commit 577e7b5

Browse files
TimDettmers and claude authored
Fix Params4bit attribute access for FSDP state_dict traversal (#1866)
* Fix Params4bit/QuantState attribute access for FSDP state_dict traversal (#1405) PyTorch's FSDP state_dict machinery (_get_fqns) resolves dotted FQN paths via getattr. For 4-bit quantized models, state_dict keys like "weight.absmax" and "weight.quant_state.bitsandbytes__nf4" require attribute access on Params4bit and QuantState objects that previously didn't exist. Add __getattr__ to Params4bit that proxies known QuantState attributes (including the quant_map→code alias used by as_dict serialization), and add __getattr__ to QuantState that handles the packed "bitsandbytes__*" keys. This allows FSDP's get_model_state_dict with cpu_offload=True to traverse the full FQN namespace without AttributeError. Verified with single-GPU FSDP integration test: without fix, fails with "'Params4bit' object has no attribute 'absmax'"; with fix, successfully produces all 7 state_dict keys. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Add FSDP integration test for 4-bit state_dict save (#1405) Add a subprocess-based pytest test that launches a single-GPU FSDP process via torchrun to verify get_model_state_dict with cpu_offload=True works for QLoRA-style models with frozen Params4bit base weights. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 79ce819 commit 577e7b5

File tree

4 files changed

+224
-0
lines changed

4 files changed

+224
-0
lines changed

bitsandbytes/functional.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,18 @@ def __init__(
487487
self.state2 = state2
488488
self.nested = state2 is not None
489489

490+
def __getattr__(self, name):
    """Resolve packed state_dict attribute names such as ``bitsandbytes__nf4``.

    PyTorch's FSDP state_dict traversal (``_get_fqns``) walks dotted FQN paths
    with ``getattr``; the packed key ``quant_state.bitsandbytes__nf4`` therefore
    turns into ``getattr(quant_state_obj, "bitsandbytes__nf4")``, handled here.
    Any other unknown name falls through to a normal AttributeError.
    """
    if name.startswith("bitsandbytes__"):
        # as_dict(packed=True) stores the packed tensor under a dotted key.
        packed = self.as_dict(packed=True)
        try:
            return packed[f"quant_state.{name}"]
        except KeyError:
            pass
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
501+
490502
def __getitem__(self, idx):
491503
"""
492504
ensures compatibility with older quant state scheme with nested lists.

bitsandbytes/nn/modules.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,43 @@ def __setstate__(self, state):
255255
self.bnb_quantized = state["bnb_quantized"]
256256
self.module = state["module"]
257257

258+
# FSDP's _get_fqns() resolves dotted state_dict FQNs (e.g. "weight.quant_map")
# via getattr, so every key name produced by QuantState.as_dict must be
# reachable as an attribute on this parameter.  This table maps each such key
# to an accessor over the underlying QuantState object.
_QUANT_STATE_ATTR_MAP = dict(
    # Plain QuantState attributes.
    absmax=lambda qs: qs.absmax,
    code=lambda qs: qs.code,
    blocksize=lambda qs: qs.blocksize,
    dtype=lambda qs: qs.dtype,
    shape=lambda qs: qs.shape,
    offset=lambda qs: qs.offset,
    state2=lambda qs: qs.state2,
    # as_dict() serializes ``code`` under the key name "quant_map".
    quant_map=lambda qs: qs.code,
    quant_type=lambda qs: qs.quant_type,
    # as_dict() flattens the nested state2 under "nested_*" keys.  Note that
    # "nested_offset" maps to the OUTER state's offset, mirroring as_dict().
    nested_absmax=lambda qs: qs.state2.absmax,
    nested_blocksize=lambda qs: qs.state2.blocksize,
    nested_quant_map=lambda qs: qs.state2.code,
    nested_dtype=lambda qs: qs.state2.dtype,
    nested_offset=lambda qs: qs.offset,
)
281+
282+
def __getattr__(self, name):
    """Proxy known QuantState attribute names through this parameter.

    Only invoked when normal attribute lookup fails.  PyTorch's FSDP
    state_dict machinery traverses FQN paths via getattr, so keys like
    "weight.absmax" must resolve on the Params4bit object itself.
    """
    accessor = self._QUANT_STATE_ATTR_MAP.get(name)
    # Read via __dict__ to avoid re-entering __getattr__ during unpickling,
    # when "quant_state" may not have been restored yet.
    quant_state = self.__dict__.get("quant_state")
    if accessor is not None and quant_state is not None:
        try:
            return accessor(quant_state)
        except AttributeError:
            # e.g. nested-state accessors when statistics are not compressed.
            pass
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
294+
258295
def __deepcopy__(self, memo):
259296
new_instance = type(self).__new__(type(self))
260297
state = self.__getstate__()

tests/fsdp_state_dict_save.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""FSDP state_dict save integration test for 4-bit quantized models (#1405).
2+
3+
This script must be launched via torchrun (not directly):
4+
torchrun --nproc_per_node=1 tests/fsdp_state_dict_save.py
5+
6+
It wraps a QLoRA-style model (frozen 4-bit base + trainable adapter) in FSDP
7+
and calls get_model_state_dict with cpu_offload=True, which exercises the
8+
_get_fqns() getattr traversal that previously crashed with:
9+
AttributeError: 'Params4bit' object has no attribute 'absmax'
10+
"""
11+
12+
import sys
13+
14+
import torch
15+
import torch.distributed as dist
16+
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
17+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
18+
import torch.nn as nn
19+
20+
import bitsandbytes as bnb
21+
22+
23+
class SimpleQLoRAModel(nn.Module):
    """Minimal model with a frozen 4-bit base layer and a trainable adapter."""

    def __init__(self, quant_type="nf4"):
        super().__init__()
        # 4-bit quantized base projection (frozen by the caller, as in QLoRA).
        self.base = bnb.nn.Linear4bit(64, 64, bias=False, quant_type=quant_type)
        # Full-precision adapter that remains trainable.
        self.adapter = nn.Linear(64, 64, bias=False)

    def forward(self, x):
        # Sum of the quantized base path and the adapter path.
        adapted = self.adapter(x)
        return self.base(x) + adapted
33+
34+
35+
def main():
    """Run the FSDP 4-bit state_dict save check for nf4 and fp4.

    Exits 0 when get_model_state_dict succeeds and the expected keys are
    present for both quant types; exits 1 with the failures on stderr
    otherwise.  Must be launched under torchrun (see module docstring).
    """
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    errors = []

    try:
        for quant_type in ("nf4", "fp4"):
            # Construction is inside the try as well: previously a crash in
            # SimpleQLoRAModel(...) or FSDP(...) skipped the remaining quant
            # type and leaked the process group.
            try:
                model = SimpleQLoRAModel(quant_type=quant_type).to("cuda")

                # Freeze quantized base weights (as in real QLoRA)
                for p in model.base.parameters():
                    p.requires_grad = False

                # Tell FSDP to ignore the frozen quantized params (can't flatten int dtypes)
                ignored = list(model.base.parameters())
                fsdp_model = FSDP(model, device_id=rank, ignored_states=ignored, use_orig_params=True)

                # cpu_offload=True exercises the _get_fqns() getattr traversal
                # that previously crashed on Params4bit (#1405).
                options = StateDictOptions(full_state_dict=True, cpu_offload=True)
                state_dict = get_model_state_dict(fsdp_model, options=options)

                # Verify expected keys are present
                expected_substrings = ["base.weight", "absmax", "quant_map", "adapter.weight"]
                for substr in expected_substrings:
                    if not any(substr in k for k in state_dict):
                        errors.append(f"{quant_type}: missing key containing '{substr}' in {list(state_dict.keys())}")

                print(f"{quant_type}: SUCCESS ({len(state_dict)} keys)", flush=True)
            except Exception as e:
                errors.append(f"{quant_type}: {type(e).__name__}: {e}")
                print(f"{quant_type}: FAILED: {e}", flush=True)
    finally:
        # Always tear down the process group, even if setup crashed above.
        dist.destroy_process_group()

    if errors:
        print("\nFAILURES:\n" + "\n".join(errors), file=sys.stderr, flush=True)
        sys.exit(1)
    else:
        print("\nAll FSDP state_dict tests passed.", flush=True)
        sys.exit(0)


if __name__ == "__main__":
    main()

tests/test_linear4bit.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import copy
22
import os
3+
import pathlib
34
import pickle
45
import platform
6+
import subprocess
57
import sys
68
from tempfile import TemporaryDirectory
79

@@ -431,3 +433,96 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
431433
grad_compiled = x.grad.clone()
432434

433435
torch.testing.assert_close(grad_compiled, grad_ref)
436+
437+
438+
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_quant_state_attr_access(device, quant_type, compress_statistics):
    """Test that Params4bit proxies QuantState attributes for FSDP state_dict traversal (#1405).

    PyTorch's FSDP state_dict machinery traverses FQN paths like
    'model.layers.0.weight.absmax' using getattr(). This test verifies
    that Params4bit and QuantState expose the attributes that appear as
    state_dict keys so that _get_fqns() traversal succeeds.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")

    layer = bnb.nn.Linear4bit(
        64,
        64,
        bias=False,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
    ).to(device)
    weight = layer.weight

    assert weight.quant_state is not None, "quant_state should be set after quantization"
    quant_state = weight.quant_state

    # Direct QuantState attributes must be reachable through the parameter.
    assert torch.equal(weight.absmax, quant_state.absmax)
    assert torch.equal(weight.code, quant_state.code)

    # "quant_map" is the key name as_dict() uses for "code" — FSDP sees that name.
    assert torch.equal(weight.quant_map, quant_state.code)

    # Packed key: as_dict(packed=True) emits "quant_state.bitsandbytes__<type>",
    # which FSDP resolves as getattr(quant_state_obj, "bitsandbytes__<type>").
    packed_attr = f"bitsandbytes__{quant_type}"
    assert hasattr(quant_state, packed_attr)
    assert isinstance(getattr(quant_state, packed_attr), torch.Tensor)

    # Replay the full FSDP _get_fqns traversal for every state_dict key.
    # Each key is relative to "weight.", e.g. "absmax" or "quant_state.bitsandbytes__nf4".
    for key in quant_state.as_dict(packed=True):
        resolved = weight
        for part in key.split("."):
            resolved = getattr(resolved, part)
        assert resolved is not None

    # hasattr must be True for proxied attrs and False for unknown ones.
    for proxied in ("absmax", "code", "quant_map"):
        assert hasattr(weight, proxied)
    assert not hasattr(weight, "nonexistent_attribute")

    # Unknown attributes must still raise AttributeError.
    with pytest.raises(AttributeError, match="nonexistent_attribute"):
        _ = weight.nonexistent_attribute

    # Ordinary Params4bit attributes must be untouched by __getattr__.
    assert isinstance(weight.quant_state, bnb.functional.QuantState)
    assert isinstance(weight.bnb_quantized, bool)
    assert weight.bnb_quantized is True
502+
503+
504+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
@pytest.mark.skipif(
    not torch.distributed.is_nccl_available(),
    reason="FSDP test requires NCCL backend",
)
def test_fsdp_state_dict_save_4bit():
    """Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).

    Launches a single-GPU FSDP process via torchrun to exercise the real
    _get_fqns() code path that previously crashed with:
        AttributeError: 'Params4bit' object has no attribute 'absmax'
    """
    script = pathlib.Path(__file__).with_name("fsdp_state_dict_save.py")
    # Invoke torchrun through the current interpreter (`-m torch.distributed.run`)
    # instead of a bare `torchrun` binary on PATH, which may belong to a
    # different Python installation than the one running this test.
    result = subprocess.run(
        [sys.executable, "-m", "torch.distributed.run", "--nproc_per_node=1", str(script)],
        capture_output=True,
        text=True,
        timeout=120,
    )
    if result.returncode != 0:
        pytest.fail(
            f"FSDP state_dict test failed (exit {result.returncode}):\n"
            f"stdout: {result.stdout}\n"
            f"stderr: {result.stderr}"
        )

0 commit comments

Comments
 (0)