Summary
matmul_4bit() mutates quant_state.dtype in-place on CPU (_functions.py:387):
```python
if A.device.type == "cpu":
    quant_state.dtype = A.dtype
```
Under torch.compile(fullgraph=True) with activation checkpointing, Dynamo flags this as:
```
torch._dynamo.exc.Unsupported: HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
```
This only affects CPU — on CUDA the branch is not taken, so GPU compilation works fine.
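Why Dynamo objects is easiest to see stripped of torch entirely: the CPU branch writes the input dtype onto an object created outside the traced region, a side effect that fullgraph tracing inside a higher-order op (the checkpoint) cannot replay. A minimal torch-free sketch, where `FakeQuantState` and `matmul_mutating` are hypothetical stand-ins, not bitsandbytes code:

```python
from dataclasses import dataclass

# Hypothetical stand-in for bitsandbytes' QuantState; only dtype matters here.
@dataclass
class FakeQuantState:
    dtype: str = "bfloat16"

def matmul_mutating(a_dtype, state):
    # Mirrors the CPU branch of matmul_4bit(): the input dtype is written
    # onto the shared state object, so the call has a side effect that is
    # observable outside its own scope.
    state.dtype = a_dtype
    return state.dtype

shared = FakeQuantState()              # lives in the "outer scope"
result = matmul_mutating("float32", shared)
print(shared.dtype)                    # the outer-scope object was changed
```

This is exactly the shape of mutation Dynamo's SideEffects tracking refuses inside a checkpointed subgraph.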
How discovered
The regression test added in #1916 (test_linear4bit_torch_compile_activation_checkpointing) exercises fullgraph=True + torch.utils.checkpoint. It passes on all GPU backends but fails on CPU with torch 2.10.0 due to this pre-existing mutation. The test now skips on CPU to keep #1916 focused on the __getattr__ → @property fix.
Suggested fix
Refactor matmul_4bit() to avoid mutating quant_state.dtype in-place. For example, pass the input dtype as a parameter to the downstream functions instead of writing it onto the shared QuantState object. This would make matmul_4bit safe for Dynamo tracing on CPU.
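A minimal sketch of that refactor, again with hypothetical stand-ins rather than real bitsandbytes code: the CPU branch selects an effective dtype and passes it to the downstream call instead of writing it back onto the shared object.

```python
from dataclasses import dataclass

# Hypothetical stand-ins; downstream() represents the dequantize/gemv
# path that currently reads quant_state.dtype.
@dataclass
class FakeQuantState:
    dtype: str = "bfloat16"

def downstream(data, dtype):
    # Receives the dtype as an explicit argument instead of reading it
    # off a mutated QuantState.
    return (data, dtype)

def matmul_pure(a, a_dtype, state, on_cpu):
    # Select the effective dtype without touching the shared QuantState,
    # so the function is side-effect free and safe for Dynamo tracing.
    effective_dtype = a_dtype if on_cpu else state.dtype
    return downstream(a, effective_dtype)

shared = FakeQuantState()
out = matmul_pure([1.0], "float32", shared, on_cpu=True)
assert out == ([1.0], "float32")
assert shared.dtype == "bfloat16"      # shared state untouched
```

The key property is that `matmul_pure` only reads `state`, so two modules sharing one QuantState, or a recomputed checkpoint region, can never observe each other's dtype choices.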
Reproducer
```python
import torch
import torch.utils.checkpoint
import bitsandbytes as bnb

class CheckpointedNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            bnb.nn.Linear4bit(256, 256, bias=False, compute_dtype=torch.bfloat16, quant_type="nf4")
            for _ in range(4)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x

net = CheckpointedNet().to("cpu")
compiled = torch.compile(net, fullgraph=True, backend="inductor")
x = torch.randn(16, 256, dtype=torch.bfloat16, requires_grad=True)
compiled(x).sum().backward()  # raises Unsupported: Mutating a variable not in the current scope
```
Environment
- torch 2.10.0
- CPU only (CUDA is unaffected)
- Observed on linux-x64, linux-aarch64, linux-x64-icelake, macOS