diff --git a/.github/scripts/build-rocm.sh b/.github/scripts/build-rocm.sh new file mode 100644 index 000000000..b508fac69 --- /dev/null +++ b/.github/scripts/build-rocm.sh @@ -0,0 +1,21 @@ +#!/bin/bash +declare build_arch +declare build_os +declare rocm_version + +set -xeuo pipefail +bnb_rocm_arch="gfx90a;gfx942;gfx1100" +if [ "${build_os:0:6}" == ubuntu ]; then + image=rocm/dev-ubuntu-22.04:${rocm_version}-complete + echo "Using image $image" + docker run --rm --platform "linux/$build_arch" -i \ + -w /src -v "$PWD:/src" "$image" sh -c \ + "apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \ + && cmake --build ." +fi + +output_dir="output/${build_os}/${build_arch}" +mkdir -p "${output_dir}" +(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fbaa27d56..3673ac608 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -102,10 +102,55 @@ jobs: path: output/* retention-days: 7 + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-22.04] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + uses: docker/setup-qemu-action@v3 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + retention-days: 7 + build-wheels: needs: - build-shared-libs - build-shared-libs-cuda + - build-shared-libs-rocm strategy: matrix: os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] @@ -171,10 +216,10 @@ jobs: path: tmp/ pattern: "bdist_wheel_*" merge-multiple: true - + - name: Inspect tmp directory after downloading artifacts run: ls -alFR tmp/ - + - name: Move and rename wheel files with pattern replacement run: | mkdir -p wheels/ @@ -199,7 +244,7 @@ jobs: - name: Inspect wheels directory after renaming files run: ls -alFR wheels/ - + - name: Delete old pre-release (if exists) run: | gh release delete continuous-release_main --cleanup-tag -y || true @@ -213,7 +258,7 @@ jobs: This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. - **How to install:** + **How to install:** Pick the correct command for your platform and run it in your terminal: ENDOFMARKDOWN @@ -228,7 +273,7 @@ jobs: done cat >> body.md << 'ENDOFMARKDOWN' - > **Note:** + > **Note:** > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. ENDOFMARKDOWN diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index fd7b7b9a2..9089d6fc2 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT +from ...cextension import HIP_ENVIRONMENT, lib @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -210,12 +210,12 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") n = A.numel() @@ -269,11 +269,11 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( dtype in [torch.float16, torch.bfloat16, torch.float32], @@ -303,11 +303,11 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -385,11 +385,11 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 108aa0c9a..5283df93e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -81,7 +81,7 @@ def get_available_cuda_binary_versions() -> list[str]: lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}" versions = [] for lib in Path(__file__).parent.glob(lib_pattern): - pattern = r"{}(\d+)".format(BNB_BACKEND.lower()) + pattern = rf"{BNB_BACKEND.lower()}(\d+)" match = re.search(pattern, lib.name) if match: ver_code = int(match.group(1)) @@ -199,18 +199,16 @@ def _format_lib_error_message( ) compile_instructions = ( - ( - "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n" - ) if not no_cuda_lib_found - else - ( + ("COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n") + if not no_cuda_lib_found + else ( "You have two options:\n" "1. COMPILE FROM SOURCE (required if no binary exists):\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n" "2. Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n" - ) if not HIP_ENVIRONMENT - else - ( + ) + if not HIP_ENVIRONMENT + else ( "You can COMPILE FROM SOURCE as mentioned here:\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n" ) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 61d03083c..32563a159 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,8 +1,8 @@ import dataclasses -import logging -import re -import subprocess from functools import lru_cache +import logging +import re +import subprocess from typing import Optional import torch @@ -78,25 +78,25 @@ def get_cuda_specs() -> Optional[CUDASpecs]: return None -def get_rocm_gpu_arch() -> str: - """Get ROCm GPU architecture.""" - logger = logging.getLogger(__name__) - try: - if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) - match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) - if match: - return "gfx" + match.group(1) - else: - return "unknown" - else: - return "unknown" - except Exception as e: - logger.error(f"Could not detect ROCm GPU architecture: {e}") - if torch.cuda.is_available(): - logger.warning( - """ -ROCm GPU architecture detection failed despite ROCm being available. - """, - ) - return "unknown" +def get_rocm_gpu_arch() -> str: + """Get ROCm GPU architecture.""" + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. + """, + ) + return "unknown" diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index b9de27fd7..b9db101ab 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -33,11 +33,13 @@ } CUDA_RUNTIME_LIB_PATTERNS = ( - "libamdhip64.so*", -) if HIP_ENVIRONMENT else ( - "cudart64*.dll", # Windows - "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. - "nvcuda*.dll", # Windows + ("libamdhip64.so*",) + if HIP_ENVIRONMENT + else ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. + "nvcuda*.dll", # Windows + ) ) logger = logging.getLogger(__name__) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 8e2bc2a7b..bf31d7978 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -43,7 +43,8 @@ def main(): print(f"{BNB_BACKEND} specs:{cuda_specs}") if not torch.cuda.is_available(): print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") - if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") + if not HIP_ENVIRONMENT: + print(f"- {BNB_BACKEND} driver not installed") print(f"- {BNB_BACKEND} not installed") print(f"- You have multiple conflicting {BNB_BACKEND} libraries") if cuda_specs: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 03f6c323d..9b7ce2da9 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib, HIP_ENVIRONMENT +from .cextension import HIP_ENVIRONMENT, lib name2qmap = {} @@ -1007,10 +1007,10 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1114,10 +1114,10 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. """ - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + if quant_state is None: assert absmax is not None and out is not None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 2383f2c10..a2facac28 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -222,10 +222,10 @@ def __new__( ) -> "Params4bit": if data is None: data = torch.empty(0) - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize self.compress_statistics = compress_statistics diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh index e7fc4eb81..105179535 100644 --- a/csrc/common_hip.cuh +++ b/csrc/common_hip.cuh @@ -1,6 +1,6 @@ #pragma once -#define BNB_WARP_SIZE warpSize +#define BNB_WARP_SIZE warpSize // These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs #define BNB_MAX_THREADS_PER_SM 2048 diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 368788f39..56e1d54db 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -532,7 +532,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float absmax[i / BLOCK_SIZE] = local_abs_max; } __syncthreads(); - + local_abs_max = smem_absmax_value[0]; if(STOCHASTIC) @@ -610,7 +610,7 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs valid_items_load = min(TILE_SIZE, n - i); valid_items_store = valid_items_load; } - + // Since blocksize will always be a power-of-2, we avoid more expensive // division by the blocksize and instead use a shift operation. // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. @@ -811,7 +811,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p, LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); __syncthreads(); Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - + // Load additional state1 data for AdEMAMix // TODO: Make constexpr after updating min compiler if (OPTIMIZER == ADEMAMIX) { @@ -1607,7 +1607,7 @@ kOptimizerStatic8bit2StateBlockwise( unsigned char c1s[N_PER_TH]; unsigned char c2s[N_PER_TH]; unsigned char c3s[N_PER_TH]; - + T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; typedef hipcub::BlockLoad LoadT; @@ -1712,7 +1712,7 @@ kOptimizerStatic8bit2StateBlockwise( new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); - + if (OPTIMIZER == ADEMAMIX) { new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); } @@ -1776,7 +1776,7 @@ kOptimizerStatic8bit2StateBlockwise( } else { p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); } - + if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } @@ -2148,27 +2148,27 @@ __global__ void kdequant_mm_int32_fp16( int local_values[ITEMS_PER_THREAD]; half local_output[ITEMS_PER_THREAD]; - + float local_rowStats[ITEMS_PER_THREAD]; float local_colStats[ITEMS_PER_THREAD]; float local_biasValue[ITEMS_PER_THREAD]; typedef hipcub::BlockLoad LoadInt32; __shared__ typename LoadInt32::TempStorage loadint32; - + int row_idx, col_idx; - + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) { row_idx = (block_offset + thread_offset + j) / numCols; col_idx = (block_offset + thread_offset + j) % numCols; - + local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; - local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; local_biasValue[j] = ((bias == nullptr) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); } - + // Each block loads THREADS * ITEMS_PER_THREAD values from A int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD @@ -2188,7 +2188,7 @@ __global__ void kdequant_mm_int32_fp16( if (outIdx < n_out) { out[outIdx] = local_output[j]; } - } + } } #define DENORM 1.0f/127.0f diff --git a/csrc/ops.hip b/csrc/ops.hip index 4d077d19a..eef616d48 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -199,10 +199,10 @@ template void optimizerStatic8bit(T* p, T* g, } } -#define BLOCKSIZE_2STATE 256 -#define NUM_2STATE 1 -#define BLOCKSIZE_1STATE 256 -#define NUM_1STATE 1 +#define BLOCKSIZE_2STATE 256 +#define NUM_2STATE 1 +#define BLOCKSIZE_1STATE 256 +#define NUM_1STATE 1 template void optimizerStatic8bitBlockwise( T* p, @@ -443,7 +443,7 @@ static std::string hipError_to_string(const hipError_t ret) } template int igemmlt( - hipblasLtHandle_t ltHandle, + hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 1b2ea85db..3d8b688ee 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -12,11 +12,13 @@ def cuda120_spec() -> CUDASpecs: cuda_version_tuple=(12, 0), ) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") diff --git a/tests/test_functional.py b/tests/test_functional.py index 5f5ee488c..a2964c733 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,8 +8,8 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from bitsandbytes import functional as F +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -92,7 +92,10 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128] ) + @pytest.mark.parametrize( + "blocksize", + [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128], + ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 @@ -796,6 +799,7 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): A[:, outlier_cols] = 0 torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestSpMMFunctional: @@ -1106,7 +1110,10 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize( + "blocksize", + [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], + ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1205,7 +1212,7 @@ def test_bench_4bit_dequant(self, quant_type): # torch.matmul(b, a.t()) # torch.cuda.synchronize() # print((time.time()-t0)/iters*1e6) - + @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 1b7a7722c..60c163477 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -17,6 +17,7 @@ "float32": torch.float32, } + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) diff --git a/tests/test_ops.py b/tests/test_ops.py index a99d080b3..a433a0c4b 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,8 +4,8 @@ import torch import bitsandbytes -from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter class TestLLMInt8Ops: