Skip to content

Commit 084d294

Browse files
ailuntz authored (with co-author ailuntz)
Guard SCB access in Linear8bitLt (#1897)
* Guard SCB access in Linear8bitLt
* fix(nn): keep tied-weight SCB guards out of forward

Co-authored-by: ailuntz <ailuntz@ailuntzdeMac-mini.local>
1 parent e1dc75a commit 084d294

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

bitsandbytes/nn/modules.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,7 +1056,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
10561056
scb_name = "SCB"
10571057

10581058
# case 1: .cuda was called, SCB is in self.weight
1059-
param_from_weight = getattr(self.weight, scb_name)
1059+
param_from_weight = getattr(self.weight, scb_name, None)
10601060
# case 2: self.init_8bit_state was called, SCB is in self.state
10611061
param_from_state = getattr(self.state, scb_name)
10621062

@@ -1097,15 +1097,16 @@ def _load_from_state_dict(
10971097
for key in unexpected_copy:
10981098
input_name = key[len(prefix) :]
10991099
if input_name == "SCB":
1100-
if self.weight.SCB is None:
1100+
weight_scb = getattr(self.weight, "SCB", None)
1101+
if weight_scb is None:
11011102
# buffers not yet initialized, can't access them directly without quantizing first
11021103
raise RuntimeError(
11031104
"Loading a quantized checkpoint into non-quantized Linear8bitLt is "
11041105
"not supported. Please call module.cuda() before module.load_state_dict()",
11051106
)
11061107

11071108
input_param = state_dict[key]
1108-
self.weight.SCB.copy_(input_param)
1109+
weight_scb.copy_(input_param)
11091110

11101111
if self.state.SCB is not None:
11111112
self.state.SCB = self.weight.SCB

tests/test_linear8bitlt.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,26 @@ def test_linear8bit_serialization(linear8bit):
228228
assert (linear8bit.weight.CB == deserialized.weight.CB).all()
229229

230230

231+
def test_linear8bit_state_dict_skips_scb_for_tied_weight():
232+
linear = Linear8bitLt(8, 8, bias=False, has_fp16_weights=False)
233+
linear.weight = torch.nn.Parameter(torch.randn_like(linear.weight))
234+
235+
state_dict = linear.state_dict()
236+
237+
assert "SCB" not in state_dict
238+
assert "weight_format" not in state_dict
239+
240+
241+
def test_linear8bit_load_state_dict_raises_runtime_for_tied_weight():
242+
linear = Linear8bitLt(8, 8, bias=False, has_fp16_weights=False)
243+
linear.weight = torch.nn.Parameter(torch.randn_like(linear.weight))
244+
state_dict = linear.state_dict()
245+
state_dict["SCB"] = torch.ones(linear.out_features)
246+
247+
with pytest.raises(RuntimeError, match="Loading a quantized checkpoint into non-quantized Linear8bitLt"):
248+
linear.load_state_dict(state_dict, strict=False)
249+
250+
231251
@pytest.mark.parametrize("device", get_available_devices())
232252
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
233253
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))

0 commit comments

Comments (0)