@@ -8,7 +8,7 @@
 from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
 
 from ..._ops import register_kernel
-from ...cextension import lib, HIP_ENVIRONMENT
+from ...cextension import HIP_ENVIRONMENT, lib
 
 
 @register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -210,12 +210,12 @@ def _get_col_absmax(
 @register_kernel("bitsandbytes::quantize_blockwise", "cuda")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
-
-    if HIP_ENVIRONMENT:
-        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
-    else:
+
+    if HIP_ENVIRONMENT:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
+    else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
-
+
     torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
 
     n = A.numel()
@@ -269,11 +269,11 @@ def _(
 def _dequantize_blockwise_impl(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ) -> None:
-    if HIP_ENVIRONMENT:
-        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
-    else:
+
+    if HIP_ENVIRONMENT:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
+    else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
-
+
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     torch._check(
         dtype in [torch.float16, torch.bfloat16, torch.float32],
@@ -303,11 +303,11 @@ def _dequantize_blockwise_impl(
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    if HIP_ENVIRONMENT:
-        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
-    else:
+    if HIP_ENVIRONMENT:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
+    else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
-
+
     torch._check(quant_type in ["fp4", "nf4"])
     torch._check(
         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -385,11 +385,11 @@ def _dequantize_4bit_impl(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    if HIP_ENVIRONMENT:
-        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
-    else:
+    if HIP_ENVIRONMENT:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
+    else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
-
+
     torch._check(quant_type in ["fp4", "nf4"])
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
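Every hunk after the import fix touches the same guard: in ROCm (HIP) builds the accepted blockwise sizes are 128 through 4096, while CUDA builds additionally accept 64. The sketch below is illustrative only; the diff inlines these checks at each kernel, and the `_check_blocksize` helper and its `hip_environment` parameter are hypothetical names, not part of bitsandbytes. It shows the `torch._check` pattern used here, which raises a `RuntimeError` with a lazily evaluated message when the condition fails:

```python
import torch

# Hypothetical helper (not in the diff) showing the guard each hunk repeats.
# Per the diff, 64 is accepted only when not running under HIP.
def _check_blocksize(blocksize: int, hip_environment: bool) -> None:
    torch._check_is_size(blocksize)  # blocksize must be a valid non-negative size
    supported = [4096, 2048, 1024, 512, 256, 128] + ([] if hip_environment else [64])
    torch._check(
        blocksize in supported,
        lambda: f"blocksize must be one of {supported}, got {blocksize}",
    )

_check_blocksize(64, hip_environment=False)  # passes: 64 is in the CUDA list
try:
    _check_blocksize(64, hip_environment=True)  # 64 is absent from the HIP list
except RuntimeError as e:
    print(e)
```

Passing the message as a lambda means the f-string is only built when the check actually fails, keeping the passing path cheap.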
|