Skip to content

Commit f176eab

Browse files
Add torch.compile tests
1 parent d475533 commit f176eab

File tree

4 files changed

+135
-4
lines changed

4 files changed

+135
-4
lines changed

bitsandbytes/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,14 +771,14 @@ def quantize_blockwise(
771771
qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
772772
quant_state = QuantState(
773773
absmax=qabsmax,
774-
code=code,
774+
code=code.to(A.device, copy=True),
775775
blocksize=blocksize,
776776
dtype=A.dtype,
777777
offset=offset,
778778
state2=state2,
779779
)
780780
else:
781-
quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
781+
quant_state = QuantState(absmax=_absmax, code=code.to(A.device, copy=True), blocksize=blocksize, dtype=A.dtype)
782782

783783
# TODO(matthewdouglas): Deprecate out kwarg
784784
out = out.copy_(_out) if out is not None else _out

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def forward(self, x: torch.Tensor):
493493

494494
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
495495

496-
return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
496+
return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
497497

498498

499499
class LinearFP4(Linear4bit):

tests/test_linear4bit.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
import torch
88

99
import bitsandbytes as bnb
10-
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
10+
from tests.helpers import (
11+
TRUE_FALSE,
12+
describe_dtype,
13+
get_available_devices,
14+
id_formatter,
15+
torch_load_from_buffer,
16+
torch_save_to_buffer,
17+
)
1118

1219
storage = {
1320
"uint8": torch.uint8,
@@ -275,3 +282,72 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
275282
# there was a bug where deepcopy would modify the original object
276283
assert dict_keys_before == dict_keys_after
277284
assert dict_keys_before == dict_keys_deserialized
285+
286+
287+
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
    """Check that a stack of Linear4bit layers produces identical forward
    outputs and input gradients under eager execution and torch.compile."""
    if device == "cpu" and quant_type == "fp4":
        pytest.skip("FP4 is not supported for CPU")

    if fullgraph and torch.__version__ < (2, 8):
        pytest.skip("fullgraph mode requires torch 2.8 or higher")

    features = 256
    n_samples = 16

    # Drop any compilation state left over from earlier tests.
    torch.compiler.reset()

    # Build a small stack of Linear4bit layers to exercise the quantized path.
    layers = [
        bnb.nn.Linear4bit(
            features,
            features,
            bias=bias,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        for _ in range(4)
    ]
    model = torch.nn.Sequential(*layers).to(device)

    sample = torch.randn(n_samples, features, dtype=compute_dtype, device=device)

    # Eager forward pass serves as the reference.
    with torch.no_grad():
        expected = model(sample)

    compiled_model = torch.compile(model, fullgraph=fullgraph, mode=mode)

    # Forward pass through the compiled model.
    with torch.no_grad():
        actual = compiled_model(sample)

    # Compiled output must match the eager reference exactly in metadata
    # and numerically within tolerance.
    assert actual.shape == expected.shape
    assert actual.device == expected.device
    assert actual.dtype == expected.dtype
    torch.testing.assert_close(actual, expected)

    # Backward pass: gradients w.r.t. the input must also agree.
    sample.requires_grad_(True)
    model(sample).sum().backward()
    eager_grad = sample.grad.clone()

    sample.grad = None
    compiled_model(sample).sum().backward()
    compiled_grad = sample.grad.clone()

    torch.testing.assert_close(compiled_grad, eager_grad)

tests/test_linear8bitlt.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,58 @@ def test_linear8bit_serialization(linear8bit):
224224
# check for a bug where SCB and CB were not copied
225225
assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
226226
assert (linear8bit.weight.CB == deserialized.weight.CB).all()
227+
228+
229+
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
    """Check that a stack of Linear8bitLt layers produces identical forward
    outputs (and, for threshold=0, input gradients) under eager execution
    and torch.compile.
    """
    dim = 256
    batch_size = 16

    # Drop any compilation state left over from earlier tests.
    torch.compiler.reset()

    # NOTE: a previous revision called `torch._dynamo.config.patch()` here with
    # no arguments; the returned context manager was discarded, making the call
    # a no-op, so it has been removed.

    # Create a small network with Linear8bitLt layers
    net = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
    ).to(device)

    # With a nonzero outlier threshold the number of outlier columns is
    # data-dependent, so fullgraph capture needs dynamic output shape support.
    dynamic_output_shapes = fullgraph and threshold > 0
    with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes):
        # Create input tensor
        x = torch.randn(batch_size, dim, dtype=torch.float16, device=device)

        # Get reference output before compilation
        with torch.no_grad():
            ref_output = net(x)

        # Compile the model
        compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)

        # Get output from compiled model
        with torch.no_grad():
            compiled_output = compiled_net(x)

        # Check outputs match
        assert compiled_output.shape == ref_output.shape
        assert compiled_output.device == ref_output.device
        assert compiled_output.dtype == ref_output.dtype
        torch.testing.assert_close(compiled_output, ref_output)

        # Test with gradients. Currently only works with threshold=0.
        if threshold == 0:
            x.requires_grad_(True)
            y1 = net(x).sum()
            y1.backward()
            grad_ref = x.grad.clone()

            x.grad = None
            y2 = compiled_net(x).sum()
            y2.backward()
            grad_compiled = x.grad.clone()

            torch.testing.assert_close(grad_compiled, grad_ref)

0 commit comments

Comments
 (0)