|
7 | 7 | import torch |
8 | 8 |
|
9 | 9 | import bitsandbytes as bnb |
10 | | -from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer |
| 10 | +from tests.helpers import ( |
| 11 | + TRUE_FALSE, |
| 12 | + describe_dtype, |
| 13 | + get_available_devices, |
| 14 | + id_formatter, |
| 15 | + torch_load_from_buffer, |
| 16 | + torch_save_to_buffer, |
| 17 | +) |
11 | 18 |
|
12 | 19 | storage = { |
13 | 20 | "uint8": torch.uint8, |
@@ -275,3 +282,72 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s |
275 | 282 | # there was a bug where deepcopy would modify the original object |
276 | 283 | assert dict_keys_before == dict_keys_after |
277 | 284 | assert dict_keys_before == dict_keys_deserialized |
| 285 | + |
| 286 | + |
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
    """Check that a stack of Linear4bit layers behaves identically under
    torch.compile and eager execution: same forward outputs (shape, device,
    dtype, values) and same input gradients.
    """
    if device == "cpu" and quant_type == "fp4":
        pytest.skip("FP4 is not supported for CPU")

    if fullgraph and torch.__version__ < (2, 8):
        pytest.skip("fullgraph mode requires torch 2.8 or higher")

    hidden_dim = 256
    n_samples = 16

    # Drop any compilation state left over from earlier parametrizations.
    torch.compiler.reset()

    # Build a small stack of Linear4bit layers.
    layers = []
    for _ in range(4):
        layers.append(
            bnb.nn.Linear4bit(
                hidden_dim,
                hidden_dim,
                bias=bias,
                compute_dtype=compute_dtype,
                compress_statistics=compress_statistics,
                quant_type=quant_type,
            )
        )
    model = torch.nn.Sequential(*layers).to(device)

    # Input batch for both the eager and compiled runs.
    inp = torch.randn(n_samples, hidden_dim, dtype=compute_dtype, device=device)

    # Eager-mode reference forward pass, taken before compiling.
    with torch.no_grad():
        expected = model(inp)

    # Compile the model with the parametrized graph/overhead settings.
    compiled_model = torch.compile(model, fullgraph=fullgraph, mode=mode)

    with torch.no_grad():
        actual = compiled_model(inp)

    # The compiled forward must match eager on metadata and values.
    assert actual.shape == expected.shape
    assert actual.device == expected.device
    assert actual.dtype == expected.dtype
    torch.testing.assert_close(actual, expected)

    # Input gradients must also agree between eager and compiled runs.
    inp.requires_grad_(True)
    eager_loss = model(inp).sum()
    eager_loss.backward()
    eager_grad = inp.grad.clone()

    inp.grad = None
    compiled_loss = compiled_model(inp).sum()
    compiled_loss.backward()
    compiled_grad = inp.grad.clone()

    torch.testing.assert_close(compiled_grad, eager_grad)
0 commit comments