Skip to content

Commit bad4c86

Browse files
Fix torch.compile graph breaks from Params4bit __getattr__ (#1904)
Replace __getattr__ + _QUANT_STATE_ATTR_MAP on Params4bit with @property descriptors. Dynamo cannot trace __getattr__ on torch.Tensor subclasses, causing graph breaks that multiply under activation checkpointing. Properties use the descriptor protocol which Dynamo handles correctly. Add regression test that compiles Linear4bit with fullgraph=True and torch.utils.checkpoint to catch this class of issue. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4986b43 commit bad4c86

File tree

2 files changed

+150
-37
lines changed

2 files changed

+150
-37
lines changed

bitsandbytes/nn/modules.py

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -258,42 +258,85 @@ def __setstate__(self, state):
258258
self.bnb_quantized = state["bnb_quantized"]
259259
self.module = state["module"]
260260

261-
# Map from state_dict key names (as produced by QuantState.as_dict) to
262-
# the actual QuantState attribute/access path. FSDP's _get_fqns() resolves
263-
# dotted FQN keys via getattr, so "weight.quant_map" becomes
264-
# getattr(weight, "quant_map") — we must map that to quant_state.code.
265-
_QUANT_STATE_ATTR_MAP = {
266-
# Direct QuantState attributes
267-
"absmax": lambda qs: qs.absmax,
268-
"code": lambda qs: qs.code,
269-
"blocksize": lambda qs: qs.blocksize,
270-
"dtype": lambda qs: qs.dtype,
271-
"shape": lambda qs: qs.shape,
272-
"offset": lambda qs: qs.offset,
273-
"state2": lambda qs: qs.state2,
274-
# as_dict serializes code → "quant_map"
275-
"quant_map": lambda qs: qs.code,
276-
"quant_type": lambda qs: qs.quant_type,
277-
# as_dict serializes nested state2 attributes under "nested_*" keys
278-
"nested_absmax": lambda qs: qs.state2.absmax,
279-
"nested_blocksize": lambda qs: qs.state2.blocksize,
280-
"nested_quant_map": lambda qs: qs.state2.code,
281-
"nested_dtype": lambda qs: qs.state2.dtype,
282-
"nested_offset": lambda qs: qs.offset,
283-
}
284-
285-
def __getattr__(self, name):
286-
# Proxy known QuantState attributes so that PyTorch's FSDP state_dict
287-
# machinery (which traverses FQN paths via getattr) can find them.
288-
accessor = self._QUANT_STATE_ATTR_MAP.get(name)
289-
if accessor is not None:
290-
quant_state = self.__dict__.get("quant_state")
291-
if quant_state is not None:
292-
try:
293-
return accessor(quant_state)
294-
except AttributeError:
295-
pass
296-
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
261+
# Properties that proxy QuantState attributes for FSDP state_dict traversal.
262+
# FSDP's _get_fqns() resolves dotted FQN keys via getattr, e.g. "weight.absmax"
263+
# becomes getattr(weight, "absmax"). Using @property instead of __getattr__
264+
# avoids torch.compile graph breaks (see #1904), since Dynamo can trace
265+
# descriptor protocol access but not __getattr__ on Tensor subclasses.
266+
#
267+
# Note: attributes that collide with Params4bit instance attrs (blocksize,
268+
# quant_type) or Tensor attrs (dtype, shape) are intentionally omitted —
269+
# they are packed into the bitsandbytes__* blob and not traversed by FSDP.
270+
271+
@property
272+
def absmax(self):
273+
qs = self.__dict__.get("quant_state")
274+
if qs is not None:
275+
return qs.absmax
276+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'absmax'")
277+
278+
@property
279+
def code(self):
280+
qs = self.__dict__.get("quant_state")
281+
if qs is not None:
282+
return qs.code
283+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'code'")
284+
285+
@property
286+
def quant_map(self):
287+
qs = self.__dict__.get("quant_state")
288+
if qs is not None:
289+
return qs.code
290+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'quant_map'")
291+
292+
@property
293+
def offset(self):
294+
qs = self.__dict__.get("quant_state")
295+
if qs is not None:
296+
return qs.offset
297+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'offset'")
298+
299+
@property
300+
def state2(self):
301+
qs = self.__dict__.get("quant_state")
302+
if qs is not None:
303+
return qs.state2
304+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'state2'")
305+
306+
@property
307+
def nested_absmax(self):
308+
qs = self.__dict__.get("quant_state")
309+
if qs is not None and qs.state2 is not None:
310+
return qs.state2.absmax
311+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_absmax'")
312+
313+
@property
314+
def nested_blocksize(self):
315+
qs = self.__dict__.get("quant_state")
316+
if qs is not None and qs.state2 is not None:
317+
return qs.state2.blocksize
318+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_blocksize'")
319+
320+
@property
321+
def nested_quant_map(self):
322+
qs = self.__dict__.get("quant_state")
323+
if qs is not None and qs.state2 is not None:
324+
return qs.state2.code
325+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_quant_map'")
326+
327+
@property
328+
def nested_dtype(self):
329+
qs = self.__dict__.get("quant_state")
330+
if qs is not None and qs.state2 is not None:
331+
return qs.state2.dtype
332+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_dtype'")
333+
334+
@property
335+
def nested_offset(self):
336+
qs = self.__dict__.get("quant_state")
337+
if qs is not None:
338+
return qs.offset
339+
raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_offset'")
297340

298341
def __deepcopy__(self, memo):
299342
new_instance = type(self).__new__(type(self))

tests/test_linear4bit.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,76 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
434434
torch.testing.assert_close(grad_compiled, grad_ref)
435435

436436

437+
@pytest.mark.parametrize("device", get_available_devices())
438+
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
439+
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
440+
@pytest.mark.skipif(torch.__version__ < (2, 8, 0, "dev"), reason="fullgraph requires torch 2.8+")
441+
@pytest.mark.skipif(
442+
torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
443+
)
444+
def test_linear4bit_torch_compile_activation_checkpointing(device, quant_type, compress_statistics):
445+
"""Regression test for #1904: __getattr__ on Params4bit causes graph breaks under torch.compile.
446+
447+
Activation checkpointing replays the forward pass during backward, which multiplies
448+
attribute accesses on Params4bit. If __getattr__ is defined (instead of @property),
449+
Dynamo cannot trace through it and creates graph breaks. With fullgraph=True, this
450+
causes torch.compile to raise an error rather than silently degrading performance.
451+
"""
452+
if device == "hpu" and not is_supported_on_hpu(quant_type):
453+
pytest.skip("This configuration is not supported on HPU.")
454+
if device == "cuda" and platform.system() == "Windows":
455+
pytest.skip("Triton is not officially supported on Windows")
456+
457+
dim = 256
458+
batch_size = 16
459+
compute_dtype = torch.bfloat16
460+
461+
torch.compiler.reset()
462+
463+
class CheckpointedNet(torch.nn.Module):
464+
def __init__(self):
465+
super().__init__()
466+
self.layers = torch.nn.ModuleList(
467+
[
468+
bnb.nn.Linear4bit(
469+
dim,
470+
dim,
471+
bias=False,
472+
compute_dtype=compute_dtype,
473+
compress_statistics=compress_statistics,
474+
quant_type=quant_type,
475+
)
476+
for _ in range(4)
477+
]
478+
)
479+
480+
def forward(self, x):
481+
for layer in self.layers:
482+
x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
483+
return x
484+
485+
net = CheckpointedNet().to(device)
486+
487+
x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device, requires_grad=True)
488+
489+
# Reference output (eager)
490+
ref_output = net(x)
491+
ref_output.sum().backward()
492+
grad_ref = x.grad.clone()
493+
x.grad = None
494+
495+
# Compiled with fullgraph=True — will raise if there are graph breaks
496+
compile_backend = "hpu_backend" if device == "hpu" else "inductor"
497+
compiled_net = torch.compile(net, fullgraph=True, backend=compile_backend)
498+
499+
compiled_output = compiled_net(x)
500+
compiled_output.sum().backward()
501+
grad_compiled = x.grad.clone()
502+
503+
torch.testing.assert_close(compiled_output, ref_output)
504+
torch.testing.assert_close(grad_compiled, grad_ref)
505+
506+
437507
@pytest.mark.parametrize("device", get_available_devices())
438508
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
439509
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@@ -494,7 +564,7 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
494564
with pytest.raises(AttributeError, match="nonexistent_attribute"):
495565
_ = w.nonexistent_attribute
496566

497-
# Verify that normal Params4bit attributes are unaffected by __getattr__
567+
# Verify that normal Params4bit instance attributes are unaffected
498568
assert isinstance(w.quant_state, bnb.functional.QuantState)
499569
assert isinstance(w.bnb_quantized, bool)
500570
assert w.bnb_quantized is True

0 commit comments

Comments (0)