diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fb44141e9..88456d345 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -86,7 +86,7 @@ jobs: cuda: ${{ matrix.cuda_version }} method: "network" # The "crt" "nvvm" and "nvptxcompiler" components are added for CUDA 13. - sub-packages: ${{ format('["nvcc"{0},"cudart","cusparse","cublas","thrust","cublas_dev","cusparse_dev"]', startsWith(matrix.cuda_version, '13.') && ',"crt","nvvm","nvptxcompiler"' || '') }} + sub-packages: ${{ format('["nvcc"{0},"cudart","cublas","thrust","cublas_dev"]', startsWith(matrix.cuda_version, '13.') && ',"crt","nvvm","nvptxcompiler"' || '') }} use-github-cache: false use-local-cache: false log-file-suffix: ${{matrix.os}}-${{matrix.cuda_version}}.txt diff --git a/.github/workflows/test-runner.yml b/.github/workflows/test-runner.yml index ad5415a8c..9a9c42fda 100644 --- a/.github/workflows/test-runner.yml +++ b/.github/workflows/test-runner.yml @@ -148,7 +148,7 @@ jobs: with: cuda: ${{ inputs.cuda_version }} method: "network" - sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' + sub-packages: '["nvcc","cudart","cublas","thrust","nvrtc_dev","cublas_dev"]' use-github-cache: false # Windows: Setup MSVC (needed for both CPU and CUDA builds) diff --git a/CMakeLists.txt b/CMakeLists.txt index da592203c..19bd7b623 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -349,7 +349,7 @@ endif() if(BUILD_CUDA) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse) + target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt) set_target_properties(bitsandbytes PROPERTIES CUDA_SEPARABLE_COMPILATION ON @@ -369,7 +369,6 @@ if(BUILD_HIP) endmacro() find_package_and_print_version(hipblas REQUIRED) find_package_and_print_version(hiprand REQUIRED) - find_package_and_print_version(hipsparse REQUIRED) ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies) ## On Windows, we need to link amdhip64 explicitly @@ -381,7 +380,7 @@ if(BUILD_HIP) target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) - target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) + target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand) # On Windows, rocblas is not pulled in transitively by roc::hipblas # and is needed because ops_hip.cuh uses rocblas_handle directly. diff --git a/agents/api_surface.md b/agents/api_surface.md index 32470baff..0e2e2552e 100644 --- a/agents/api_surface.md +++ b/agents/api_surface.md @@ -860,57 +860,7 @@ F.batched_igemm( Batched int8 matrix multiplication. **Stability:** Stable (internal). -### 4.9 Sparse Operations - -#### `COOSparseTensor` - -```python -class F.COOSparseTensor: - def __init__(self, rows, cols, nnz, rowidx, colidx, values): ... -``` - -**Stability:** Legacy — used internally for sparse decomposition. - -#### `CSRSparseTensor` / `CSCSparseTensor` - -Similar sparse tensor containers. -**Stability:** Legacy. - -#### `coo_zeros` - -```python -F.coo_zeros(rows, cols, nnz, device, dtype=torch.half) -> COOSparseTensor -``` - -#### `coo2csr` / `coo2csc` - -```python -F.coo2csr(cooA: COOSparseTensor) -> CSRSparseTensor -F.coo2csc(cooA: COOSparseTensor) -> CSCSparseTensor -``` - -#### `spmm_coo` - -```python -F.spmm_coo( - cooA: COOSparseTensor, B: torch.Tensor, - out: Optional[torch.Tensor] = None, -) -> torch.Tensor -``` - -Sparse matrix-dense matrix multiply using cusparse. -**Stability:** Legacy. - -#### `spmm_coo_very_sparse` - -```python -F.spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None) -> torch.Tensor -``` - -Optimized for very sparse matrices with custom kernel. -**Stability:** Legacy. - -### 4.10 Paged Memory +### 4.9 Paged Memory #### `get_paged` @@ -930,7 +880,7 @@ F.prefetch_tensor(A: torch.Tensor, to_cpu: bool = False) -> None Prefetch a paged tensor to GPU or CPU. **Stability:** Stable (internal). -### 4.11 CPU-Specific Functions +### 4.10 CPU-Specific Functions #### `_convert_weight_packed_for_cpu` @@ -963,7 +913,7 @@ F.has_avx512bf16() -> bool Detects AVX512BF16 CPU support. **Stability:** Internal but may be useful externally. -### 4.12 Utility Functions +### 4.11 Utility Functions #### `is_on_gpu` @@ -983,7 +933,7 @@ F.get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p] Gets the data pointer of a tensor for ctypes calls. **Stability:** Internal. -### 4.13 Singleton Managers +### 4.12 Singleton Managers #### `GlobalPageManager` @@ -1003,15 +953,6 @@ F.CUBLAS_Context.get_instance() -> CUBLAS_Context Manages cuBLAS context handles per device. **Stability:** Internal. -#### `Cusparse_Context` - -```python -F.Cusparse_Context.get_instance() -> Cusparse_Context -``` - -Manages cusparse context handle. -**Stability:** Internal. - --- ## 5. Autograd Functions @@ -1234,7 +1175,7 @@ bitsandbytes.utils.replace_linear( | Class | Description | |-------|-------------| | `BNBNativeLibrary` | Base wrapper for the ctypes-loaded native library | -| `CudaBNBNativeLibrary` | CUDA-specific subclass (sets up context/cusparse/managed ptr) | +| `CudaBNBNativeLibrary` | CUDA-specific subclass (sets up context/managed ptr) | | `ErrorHandlerMockBNBNativeLibrary` | Fallback mock that defers error messages to call time | ### Module-level symbols @@ -1396,11 +1337,9 @@ A PR that changes any of these symbols MUST consider downstream impact: - `bitsandbytes.cextension.*` (native library loading) - `bitsandbytes.functional.get_ptr`, `is_on_gpu`, `_get_tensor_stream` -- `bitsandbytes.functional.GlobalPageManager`, `CUBLAS_Context`, `Cusparse_Context` +- `bitsandbytes.functional.GlobalPageManager`, `CUBLAS_Context` - `bitsandbytes.functional._convert_weight_packed_for_cpu*` - `bitsandbytes.functional.check_matmul`, `elementwise_func`, `fill`, `_mul` -- `bitsandbytes.functional.spmm_coo`, `spmm_coo_very_sparse` -- `bitsandbytes.functional.COOSparseTensor`, `CSRSparseTensor`, `CSCSparseTensor` - `bitsandbytes.utils.pack_dict_to_tensor`, `unpack_tensor_to_dict` - `bitsandbytes.utils.execute_and_return`, `sync_gpu` - `bitsandbytes.optim.optimizer.MockArgs` diff --git a/agents/architecture_guide.md b/agents/architecture_guide.md index d351d03be..f67885266 100644 --- a/agents/architecture_guide.md +++ b/agents/architecture_guide.md @@ -962,8 +962,8 @@ The `COMPUTE_BACKEND` CMake variable selects the target: | Backend | Library name | Languages | Dependencies | |---|---|---|---| | `cpu` | `libbitsandbytes_cpu.so` | C++17 | OpenMP (optional) | -| `cuda` | `libbitsandbytes_cuda{VER}.so` | C++17 + CUDA | cudart, cublas, cublasLt, cusparse | -| `hip` | `libbitsandbytes_rocm{VER}.so` | C++17 + HIP | hipblas, hiprand, hipsparse | +| `cuda` | `libbitsandbytes_cuda{VER}.so` | C++17 + CUDA | cudart, cublas, cublasLt | +| `hip` | `libbitsandbytes_rocm{VER}.so` | C++17 + HIP | hipblas, hiprand | | `mps` | `libbitsandbytes_mps.dylib` | C++17 + ObjC++ | Metal framework | | `xpu` | `libbitsandbytes_xpu.so` | C++20 + SYCL | Intel oneAPI | diff --git a/agents/code_standards.md b/agents/code_standards.md index 34420183d..27c6e2b3c 100644 --- a/agents/code_standards.md +++ b/agents/code_standards.md @@ -152,7 +152,7 @@ class GlobalOptimManager: ``` This pattern is used by: `GlobalOptimManager`, `GlobalPageManager`, `CUBLAS_Context`, -`Cusparse_Context`, `GlobalOutlierPooler`, `OutlierTracer`. +`GlobalOutlierPooler`, `OutlierTracer`. --- @@ -867,7 +867,6 @@ Use the project's error checking macros: ```cpp CUDA_CHECK_RETURN(cudaMemcpy(...)); -CHECK_CUSPARSE(cusparseCreate(...)); ``` The `checkCublasStatus` function returns an error code rather than throwing — the Python diff --git a/agents/issue_patterns.md b/agents/issue_patterns.md index 41447de89..92ccfadaf 100644 --- a/agents/issue_patterns.md +++ b/agents/issue_patterns.md @@ -34,9 +34,9 @@ These are the single largest category of issues. Most are environment problems o > > If you're still hitting problems on the **latest** bitsandbytes (v0.45+), please open a new issue with the output of `python -m bitsandbytes` and your environment details. -### Missing `libcusparse.so.11` / shared library mismatch +### Missing shared CUDA library / shared library mismatch -**How to identify:** `OSError: libcusparse.so.11: cannot open shared object file: No such file or directory`. Or similar errors for `libcusparse.so.12`, `libcublasLt.so.11`, etc. +**How to identify:** `OSError: libcublasLt.so.11: cannot open shared object file: No such file or directory`. Or similar errors for `libcudart`, `libcublas`, etc. **What happened:** The bnb binary was compiled against one CUDA version (e.g., 11.x) but the system only has another (e.g., 12.x). The shared library dependencies don't exist. Modern releases ship platform-specific wheels with better CUDA version detection and multiple binary variants. diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 7e1f59276..ab0ffc309 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -171,7 +171,7 @@ def _( A: torch.Tensor, threshold=0.0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor + # Use CUDA kernel for rowwise quant and outlier column detection quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( A, threshold=threshold, diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 11a5cffb7..c3cec7281 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -90,7 +90,6 @@ class CudaBNBNativeLibrary(BNBNativeLibrary): def __init__(self, lib: ct.CDLL): super().__init__(lib) lib.get_context.restype = ct.c_void_p - lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 65fa0a442..6ce846277 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -70,23 +70,6 @@ def get_context(self, device): return self.context[device.index] -class Cusparse_Context: - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def initialize(self): - self.context = ct.c_void_p(lib.get_cusparse()) - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls.__new__(cls) - cls._instance.initialize() - return cls._instance - - FIRST_CUDA_DEVICE = torch.device("cuda", index=0) # When multiple GPUs are present, we use a context manager to @@ -1557,87 +1540,6 @@ def int8_mm_dequant( return result -class COOSparseTensor: - def __init__( - self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor - ): - assert rowidx.dtype == torch.int32 - assert colidx.dtype == torch.int32 - assert values.dtype == torch.float16 - assert values.numel() == nnz - assert rowidx.numel() == nnz - assert colidx.numel() == nnz - - self.rows = rows - self.cols = cols - self.nnz = nnz - self.rowidx = rowidx - self.colidx = colidx - self.values = values - - -class CSRSparseTensor: - def __init__(self, rows, cols, nnz, rowptr, colidx, values): - assert rowptr.dtype == torch.int32 - assert colidx.dtype == torch.int32 - assert values.dtype == torch.float16 - assert values.numel() == nnz - assert colidx.numel() == nnz - assert rowptr.numel() == rows + 1 - - self.rows = rows - self.cols = cols - self.nnz = nnz - self.rowptr = rowptr - self.colidx = colidx - self.values = values - - -class CSCSparseTensor: - def __init__(self, rows, cols, nnz, colptr, rowidx, values): - assert colptr.dtype == torch.int32 - assert rowidx.dtype == torch.int32 - assert values.dtype == torch.float16 - assert values.numel() == nnz - assert rowidx.numel() == nnz - assert colptr.numel() == cols + 1 - - self.rows = rows - self.cols = cols - self.nnz = nnz - self.colptr = colptr - self.rowidx = rowidx - self.values = values - - -def coo2csr(cooA): - values, counts = torch.unique(cooA.rowidx, return_counts=True) - values.add_(1) - rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) - rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) - rowptr.cumsum_(0) - return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) - - -def coo2csc(cooA): - val, col2rowidx = torch.sort(cooA.colidx) - rowidx = cooA.rowidx[col2rowidx] - values = cooA.values[col2rowidx] - colvalues, counts = torch.unique(val, return_counts=True) - colvalues.add_(1) - colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) - colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) - colptr.cumsum_(0) - return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) - - -def coo_zeros(rows, cols, nnz, device, dtype=torch.half): - rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) - colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) - values = torch.zeros((nnz,), dtype=dtype, device=device) - return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) - - def int8_double_quant( A: torch.Tensor, col_stats: Optional[torch.Tensor] = None, @@ -1724,147 +1626,6 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): return torch.ops.bitsandbytes.int8_vectorwise_quant.default(A, threshold) -def spmm_coo( - cooA: COOSparseTensor | torch.Tensor, - B: torch.Tensor, - out: Optional[torch.Tensor] = None, -): - if not isinstance(cooA, COOSparseTensor): - assert cooA.is_sparse and cooA.layout == torch.sparse_coo, ( - "Tensor must be `COOSparseTensor or a PyTorch COO tensor." - ) - - # Convert to custom COOSparseTensor - cooA = COOSparseTensor( - rows=cooA.shape[0], - cols=cooA.shape[1], - nnz=cooA._nnz(), - rowidx=cooA.indices()[0].int(), - colidx=cooA.indices()[1].int(), - values=cooA.values(), - ) - - if out is None: - out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) - nnz = cooA.nnz - assert cooA.rowidx.numel() == nnz - assert cooA.colidx.numel() == nnz - assert cooA.values.numel() == nnz - assert cooA.cols == B.shape[0] - - transposed_B = not B.is_contiguous() - - ldb = B.stride()[(1 if transposed_B else 0)] - ldc = B.shape[1] - - ptr = Cusparse_Context.get_instance().context - - ptrRowidx = get_ptr(cooA.rowidx) - ptrColidx = get_ptr(cooA.colidx) - ptrValues = get_ptr(cooA.values) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - cnnz = ct.c_int32(cooA.nnz) - crowsA = ct.c_int32(cooA.rows) - ccolsA = ct.c_int32(cooA.cols) - ccolsB = ct.c_int32(B.shape[1]) - cldb = ct.c_int32(ldb) - cldc = ct.c_int32(ldc) - - is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo( - ptr, - ptrRowidx, - ptrColidx, - ptrValues, - cnnz, - crowsA, - ccolsA, - ccolsB, - cldb, - ptrB, - cldc, - ptrC, - ct.c_bool(transposed_B), - ) - - return out - - -def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): - if out is None: - out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) - nnz = cooA.nnz - - assert cooA.rowidx.numel() == nnz - assert cooA.colidx.numel() == nnz - assert cooA.values.numel() == nnz - assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}" - - _, counts = torch.unique(cooA.rowidx, return_counts=True) - offset = counts.cumsum(0).int() - max_count, max_idx = torch.sort(counts, descending=True) - max_idx = max_idx.int() - max_count = max_count.int() - assert max_count[0] <= 32, f"Current max count per row is 8 but found {max_count[0]}." - assert B.dtype in [torch.float16, torch.int8] - ptrOffset = get_ptr(offset) - ptrMaxCount = get_ptr(max_count) - ptrMaxIdx = get_ptr(max_idx) - - ptrRowidx = get_ptr(cooA.rowidx) - ptrColidx = get_ptr(cooA.colidx) - ptrValues = get_ptr(cooA.values) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrDequantStats = get_ptr(dequant_stats) - cnnz_rows = ct.c_int32(counts.numel()) - cnnz = ct.c_int32(cooA.nnz) - crowsA = ct.c_int32(cooA.rows) - crowsB = ct.c_int32(B.shape[1]) - ccolsB = ct.c_int32(B.shape[1]) - - with _cuda_device_of(B): - is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) - if B.dtype == torch.float16: - lib.cspmm_coo_very_sparse_naive_fp16( - ptrMaxCount, - ptrMaxIdx, - ptrOffset, - ptrRowidx, - ptrColidx, - ptrValues, - ptrB, - ptrC, - ptrDequantStats, - cnnz_rows, - cnnz, - crowsA, - crowsB, - ccolsB, - ) - elif B.dtype == torch.int8: - lib.cspmm_coo_very_sparse_naive_int8( - ptrMaxCount, - ptrMaxIdx, - ptrOffset, - ptrRowidx, - ptrColidx, - ptrValues, - ptrB, - ptrC, - ptrDequantStats, - cnnz_rows, - cnnz, - crowsA, - crowsB, - ccolsB, - ) - # else: assertion error - - return out - - def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32): """ qweight: (K * N / 2) uint8 diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1595a036c..dac6a2dc4 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1437,149 +1437,6 @@ __global__ void kdequant_mm_int32_fp16( } } -#define DENORM 1.0f / 127.0f -#define MAX_SPARSE_COUNT 32 -#define SMEM_SIZE 8 * 256 - -template -__global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -) { - - // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block - // If a block finishes, the next one is scheduled. Since the last blocks like have fewer - // elements they finish faster "fillin up" the gaps left by larger blocks - - // without tensor cores - // 1. use rowidx_length to find what to load (as many blocks as there are rows) - // 2. Load A into registers - // 3. each warp loads all required rows of B but each warp is offset by k - // 4. Do mma operations that accumulate into registers - // 5. Each warp stores its output row into matrix C - - const int count = max_count[blockIdx.x]; - const int local_max_idx = max_idx[blockIdx.x]; - const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx - 1]; - const int local_row_idx = rowidx[offset]; - - const int warp_id = threadIdx.x / 32; - const int warp_idx = threadIdx.x % 32; - const int warp_offset = (warp_id * 32) * SPMM_ITEMS; - const int num_items = BITS == 8 ? 8 : 8; - int idx_col_B = warp_offset; - int local_idx_col_B_offset = 0; - - half local_valA[MAX_SPARSE_COUNT]; - int local_colidxA[MAX_SPARSE_COUNT]; - half local_valC[SPMM_ITEMS]; - T local_valsB[num_items]; - half local_valOut[num_items]; - // 128 byte loads per warp == 4 bytes per thread - - // 2. Load A into registers - for (int j = 0; j < MAX_SPARSE_COUNT; j++) { - local_valA[j] = j < count ? values[offset + j] : __float2half(0.0f); - local_colidxA[j] = j < count ? colidx[offset + j] : 0; - } - - // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 - // we expect each warp to be SPMM_ITEMS*32 apart - // we have a total of 128 bytes for the bank with a bank size of 4 bytes - // added 3 bytes = 6 values between warps should reduce bank conflicts - __shared__ half smem_dequant_stats[SMEM_SIZE]; - - while (idx_col_B < colsB) { - - if (dequant_stats != NULL) { - for (int i = threadIdx.x; i < SMEM_SIZE; i += blockDim.x) - if ((idx_col_B + i - local_idx_col_B_offset) < colsB) - smem_dequant_stats[i] = dequant_stats[idx_col_B + i - local_idx_col_B_offset]; - - __syncthreads(); - } - -#pragma unroll SPMM_ITEMS - for (int j = 0; j < SPMM_ITEMS; j++) - local_valC[j] = 0.0f; - -#pragma unroll - for (int i = 0; i < count; i++) { - // 3. each warp loads all required rows of B but each warp is offset by k - int row_offset = colsB * local_colidxA[i]; - -#pragma unroll SPMM_ITEMS - for (int j = 0; j < SPMM_ITEMS; j += num_items) { - // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached - int idx = idx_col_B + (warp_idx * SPMM_ITEMS) + j; - if (idx >= colsB) { - break; - } - if ((idx + num_items < colsB)) { - if (BITS == 8) - reinterpret_cast(local_valsB)[0] = - reinterpret_cast(B)[(row_offset + idx) / num_items]; - else - reinterpret_cast(local_valsB)[0] = - reinterpret_cast(B)[(row_offset + idx) / num_items]; - } else { -#pragma unroll num_items - for (int k = 0; k < num_items; k++) - if (idx + k < colsB) - local_valsB[k] = B[row_offset + idx + k]; - else - local_valsB[k] = 0.0f; - } -#pragma unroll num_items - for (int k = 0; k < num_items; k++) { - if (BITS == 8 && dequant_stats != NULL) - // we do texture cache reads (__ldg) on dequant_stats which should be super fast - { - float valB = local_valsB[k]; - float valA = local_valA[i]; - if (valB != 0.0 && valA != 0.0) - local_valC[j + k] = - (float)local_valC[j + k] + - ((float)smem_dequant_stats[idx + k - local_idx_col_B_offset]) * DENORM * valB * valA; - } else - local_valC[j + k] = (float)local_valC[j + k] + (float)local_valsB[k] * (float)local_valA[i]; - } - } - } - - int idx_row_C = (colsB * local_row_idx); - -#pragma unroll SPMM_ITEMS - for (int j = 0; j < SPMM_ITEMS; j += num_items) { - // int idx_col_C = idx_col_B + (32*j) + warp_idx; - int idx_col_C = idx_col_B + warp_idx * SPMM_ITEMS + j; - int idx_val = idx_col_C + idx_row_C; - - if (idx_col_C + num_items < colsB) { - - // load outputs to do inplace addition - reinterpret_cast(local_valOut)[0] = - reinterpret_cast(out)[idx_val / num_items]; - -#pragma unroll num_items - for (int k = 0; k < num_items; k++) - local_valC[(j / num_items) + k] = (float)local_valC[(j / num_items) + k] + (float)local_valOut[k]; - - reinterpret_cast(out)[idx_val / num_items] = - reinterpret_cast(local_valC)[j / num_items]; - } else { -#pragma unroll num_items - for (int k = 0; k < num_items; k++) - if (idx_col_C + k < colsB) - out[idx_val + k] = (float)out[idx_val + k] + (float)local_valC[j + k]; - } - } - - idx_col_B += blockDim.x * SPMM_ITEMS; - local_idx_col_B_offset += blockDim.x * SPMM_ITEMS; - } -} - #define num_values_4bit 32 template @@ -1737,31 +1594,6 @@ template __global__ void kgemm_4bit_inference_naive( float* out, int lda, int ldb, int ldc, int blocksize ); -template __global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -); -template __global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -); -template __global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -); -template __global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -); -template __global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -); -template __global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -); - template __global__ void kdequant_mm_int32_fp16<4, 512>( int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, half* __restrict__ const bias, const int numRows, const int numCols, const int n diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 3d5237a46..6de55f2e8 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -65,12 +65,6 @@ __global__ void kOptimizerStatic8bit1StateBlockwise( const float gnorm_scale, const bool skip_zeros, const int n ); -template -__global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -); - template __global__ void kdequant_mm_int32_fp16( int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 59cff0028..691f6e07c 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -1532,152 +1532,6 @@ __global__ void kdequant_mm_int32_fp16( } } -#define DENORM 1.0f/127.0f -#define MAX_SPARSE_COUNT 32 -#define SMEM_SIZE 8*256 -template -__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) -{ - - // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block - // If a block finishes, the next one is scheduled. Since the last blocks like have fewer - // elements they finish faster "fillin up" the gaps left by larger blocks - - // without tensor cores - // 1. use rowidx_length to find what to load (as many blocks as there are rows) - // 2. Load A into registers - // 3. each warp loads all required rows of B but each warp is offset by k - // 4. Do mma operations that accumulate into registers - // 5. Each warp stores its output row into matrix C - - const int count = max_count[blockIdx.x]; - const int local_max_idx = max_idx[blockIdx.x]; - const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; - const int local_row_idx = rowidx[offset]; - - const int warp_id = threadIdx.x / BNB_WARP_SIZE; - const int warp_idx = threadIdx.x % BNB_WARP_SIZE; - const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS; - const int num_items = BITS == 8 ? 8 : 8; - int idx_col_B = warp_offset; - int local_idx_col_B_offset = 0; - - half local_valA[MAX_SPARSE_COUNT]; - int local_colidxA[MAX_SPARSE_COUNT]; - half local_valC[SPMM_ITEMS]; - T local_valsB[num_items]; - half local_valOut[num_items]; - // 128 byte loads per warp == 4 bytes per thread - - // 2. Load A into registers - for(int j = 0; j < MAX_SPARSE_COUNT; j++) - { - local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); - local_colidxA[j] = j < count ? colidx[offset+j] : 0; - } - - // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 - // we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart - // we have a total of 128 bytes for the bank with a bank size of 4 bytes - // added 3 bytes = 6 values between warps should reduce bank conflicts - __shared__ half smem_dequant_stats[SMEM_SIZE]; - - - while(idx_col_B < colsB) - { - - if(dequant_stats != NULL) - { - for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) - if((idx_col_B+i-local_idx_col_B_offset) < colsB) - smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; - - __syncthreads(); - } - - #pragma unroll SPMM_ITEMS - for(int j = 0; j < SPMM_ITEMS; j++) - local_valC[j] = 0.0f; - - #pragma unroll - for(int i = 0; i < count; i++) - { - // 3. each warp loads all required rows of B but each warp is offset by k - int row_offset = colsB*local_colidxA[i]; - - #pragma unroll SPMM_ITEMS - for(int j = 0; j < SPMM_ITEMS; j+=num_items) - { - // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached - int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; - if(idx >= colsB){ break; } - if((idx+num_items < colsB)) - { - if(BITS == 8) - reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; - else - reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; - } - else - { - #pragma unroll num_items - for(int k = 0; k < num_items; k++) - if(idx+k < colsB) - local_valsB[k] = B[row_offset+idx+k]; - else - local_valsB[k] = 0.0f; - } - #pragma unroll num_items - for(int k = 0; k < num_items; k++) - { - if(BITS == 8 && dequant_stats != NULL) - // we do texture cache reads (__ldg) on dequant_stats which should be super fast - { - float valB = local_valsB[k]; - float valA = local_valA[i]; - if(valB != 0.0 && valA != 0.0) - local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; - } - else - local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; - } - } - } - - int idx_row_C = (colsB*local_row_idx); - - #pragma unroll SPMM_ITEMS - for(int j = 0; j < SPMM_ITEMS; j+=num_items) - { - //int idx_col_C = idx_col_B + (32*j) + warp_idx; - int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; - int idx_val = idx_col_C + idx_row_C; - - if(idx_col_C +num_items < colsB) - { - - // load outputs to do inplace addition - reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; - - #pragma unroll num_items - for(int k = 0; k < num_items; k++) - local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; - - reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; - } - else - { - #pragma unroll num_items - for(int k = 0; k < num_items; k++) - if(idx_col_C + k < colsB) - out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; - } - } - - idx_col_B += blockDim.x*SPMM_ITEMS; - local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; - } -} // No of 4bit values processed by each thread #define num_values_4bit 32 @@ -1840,12 +1694,6 @@ template __global__ void kgemm_4bit_inference_naive(int M, int N, template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index 725e9fabe..0e2885693 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -67,12 +67,6 @@ __global__ void kOptimizerStatic8bit1StateBlockwise( const float gnorm_scale, const bool skip_zeros, const int n ); -template -__global__ void kspmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, - float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB -); - template __global__ void kdequant_mm_int32_fp16( int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, diff --git a/csrc/ops.cu b/csrc/ops.cu index 9009a24c9..88bb675a3 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -316,66 +316,6 @@ void int8VectorQuant( CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void spmm_coo( - cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, - int ldb, half* B, int ldc, half* C, bool transposed_B -) { - cusparseSpMatDescr_t descA; - cusparseDnMatDescr_t descB, descC; - - float alpha = 1.0f; - float beta = 0.0f; - void* dBuffer = NULL; - size_t bufferSize = 0; - - CHECK_CUSPARSE(cusparseCreateCoo( - &descA, A_rows, A_cols, A_nnz, A_rowidx, A_colidx, A_vals, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, - CUDA_R_16F - )); - // Create dense matrix C - CHECK_CUSPARSE(cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, CUDA_R_16F, CUSPARSE_ORDER_ROW)); - // Create dense matrix B - if (transposed_B) { - int tmp = A_cols; - A_cols = B_cols; - B_cols = tmp; - } - - CHECK_CUSPARSE(cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, CUDA_R_16F, CUSPARSE_ORDER_ROW)); - // allocate an external buffer if needed - CHECK_CUSPARSE(cusparseSpMM_bufferSize( - handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta, - descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize - )); - CUDA_CHECK_RETURN(cudaMalloc(&dBuffer, bufferSize)); - - // execute SpMM - CHECK_CUSPARSE(cusparseSpMM( - handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, descA, descB, &beta, - descC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, dBuffer - )); - - // destroy matrix/vector descriptors - CHECK_CUSPARSE(cusparseDestroySpMat(descA)); - CHECK_CUSPARSE(cusparseDestroyDnMat(descB)); - CHECK_CUSPARSE(cusparseDestroyDnMat(descC)); - CUDA_CHECK_RETURN(cudaFree(dBuffer)); -} - -template -void spmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -) { - - kspmm_coo_very_sparse_naive<<>>( - max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB - ); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - template void gemm_4bit_inference_naive( int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, @@ -419,15 +359,6 @@ template void gemm_4bit_inference_naive( int ldc, int blocksize, cudaStream_t stream ); -template void spmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -); -template void spmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -); - template int igemmlt<32, 0>( cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 490459c90..4d3af547f 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -16,7 +16,6 @@ #include #include #include -#include #include #include @@ -29,17 +28,6 @@ } \ } -#define CHECK_CUSPARSE(value) \ - { \ - cusparseStatus_t _m_cudaStat = value; \ - if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ - fprintf( \ - stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__ \ - ); \ - exit(1); \ - } \ - } - inline void checkCudaStatus(cudaError_t status) { if (status != cudaSuccess) { printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); @@ -98,17 +86,6 @@ class ContextLt { } }; -class ContextCusparse { - public: - cusparseHandle_t m_handle; - - ContextCusparse() { - cusparseHandle_t handle; - cusparseCreate(&handle); - m_handle = handle; - } -}; - template void quantizeBlockwise( float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n @@ -157,17 +134,6 @@ void int8VectorQuant( half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream ); -void spmm_coo( - cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, - int ldb, half* B, int ldc, half* C, bool transposed_B -); - -template -void spmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -); - template void gemm_4bit_inference_naive( int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, diff --git a/csrc/ops.hip b/csrc/ops.hip index e547b5e7a..937f8f249 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -10,7 +10,6 @@ #include #include #include -#include #ifndef NO_HIPBLASLT #include #endif @@ -454,67 +453,6 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float CUDA_CHECK_RETURN(hipPeekAtLastError()); } -void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) -{ - -#ifdef NO_HIPBLASLT -#else - - hipsparseSpMatDescr_t descA; - hipsparseDnMatDescr_t descB, descC; - - float alpha = 1.0f; - float beta = 0.0f; - void *dBuffer = NULL; - size_t bufferSize = 0; - - CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz, - A_rowidx, A_colidx, A_vals, - HIPSPARSE_INDEX_32I, - HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) ); - // Create dense matrix C - CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, - HIP_R_16F, HIPSPARSE_ORDER_ROW) ); - // Create dense matrix B - if(transposed_B) - { - int tmp = A_cols; - A_cols = B_cols; - B_cols = tmp; - } - - CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, - HIP_R_16F, HIPSPARSE_ORDER_ROW) ); - // allocate an external buffer if needed - CHECK_HIPSPARSE( hipsparseSpMM_bufferSize( - handle, - HIPSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, descA, descB, &beta, descC, HIP_R_32F, - HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); - CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) ); - - // execute SpMM - CHECK_HIPSPARSE( hipsparseSpMM(handle, - HIPSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, descA, descB, &beta, descC, HIP_R_32F, - HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer)); - - // destroy matrix/vector descriptors - CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) ); - CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) ); - CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) ); - CUDA_CHECK_RETURN( hipFree(dBuffer) ); -#endif -} - -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) -{ - - hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream) { @@ -553,8 +491,6 @@ template void gemm_4bit_inference_naive(int m, int n, int k, half * A, template void gemm_4bit_inference_naive(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index 5472463ed..6e884df00 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -25,7 +25,6 @@ #include #include #include -#include #include #include @@ -38,17 +37,6 @@ } \ } -#define CHECK_HIPSPARSE(value) \ - { \ - hipsparseStatus_t _m_hipStat = value; \ - if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ - fprintf( \ - stderr, "Error %s at line %d in file %s\n", hipsparseGetErrorString(_m_hipStat), __LINE__, __FILE__ \ - ); \ - exit(1); \ - } \ - } - inline void checkHipStatus(hipError_t status) { if (status != hipSuccess) { printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status)); @@ -107,17 +95,6 @@ class ContextLt { } }; -class ContextHipsparse { - public: - hipsparseHandle_t m_handle; - - ContextHipsparse() { - hipsparseHandle_t handle; - hipsparseCreate(&handle); - m_handle = handle; - } -}; - template void quantizeBlockwise( float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n @@ -166,17 +143,6 @@ void int8VectorQuant( half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream ); -void spmm_coo( - hipsparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, - int ldb, half* B, int ldc, half* C, bool transposed_B -); - -template -void spmm_coo_very_sparse_naive( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -); - template void gemm_4bit_inference_naive( int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 6a36d1962..aee7a4d25 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -23,8 +23,6 @@ #define cudaStream_t hipStream_t #define __nv_bfloat16 hip_bfloat16 #define cublasLtHandle_t hipblasLtHandle_t -#define ContextCusparse ContextHipsparse -#define cusparseHandle_t hipsparseHandle_t #define cudaMallocManaged hipMallocManaged #define cudaMemAttachHost hipMemAttachHost #define cudaPeekAtLastError hipPeekAtLastError @@ -249,25 +247,6 @@ int igemmlt_8_rowscale( return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } -void spmm_coo_very_sparse_naive_fp16( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -) { - spmm_coo_very_sparse_naive( - max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, - colsB - ); -} - -void spmm_coo_very_sparse_naive_int8( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -) { - spmm_coo_very_sparse_naive( - max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, - colsB - ); -} #endif #if BUILD_XPU @@ -535,8 +514,6 @@ void cbatched_igemm( Context* get_context() { return new Context(); } -ContextCusparse* get_cusparse() { return new ContextCusparse(); } - int cigemmlt_32( Context* context, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream @@ -570,36 +547,6 @@ void cint8_vector_quant( int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream); } -void cspmm_coo( - ContextCusparse* context, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, - int ldb, half* B, int ldc, half* C, bool transposed_B -) { - spmm_coo( - (cusparseHandle_t)context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, - transposed_B - ); -} - -void cspmm_coo_very_sparse_naive_fp16( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -) { - spmm_coo_very_sparse_naive_fp16( - max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, - colsB - ); -} - -void cspmm_coo_very_sparse_naive_int8( - int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, signed char* B, half* out, - float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB -) { - spmm_coo_very_sparse_naive_int8( - max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, - colsB - ); -} - void* cget_managed_ptr(size_t bytes) { void* ptr; CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); diff --git a/tests/test_functional.py b/tests/test_functional.py index 9febe8212..d2f5152bb 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -10,7 +10,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_WARP_SIZE_64 +from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -838,280 +838,6 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") -class TestSpMMFunctional: - @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) - @pytest.mark.parametrize("dim2", [128, 512], ids=id_formatter("dim2")) - @pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B")) - def test_spmm_coo(self, dim1, dim2, transposed_B): - threshold = 1.5 - dim3 = torch.randint(32, 128, size=(1,)).item() - # dim3 = 17 - for i in range(k): - A = torch.randn(dim1, dim2).cuda().half() - if transposed_B: - B = torch.randn(dim3, dim2).cuda().half() - else: - B = torch.randn(dim2, dim3).cuda().half() - - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - - if transposed_B: - out2 = F.spmm_coo(cooA, B.t()) - out1 = torch.matmul(A2, B.t()) - else: - out2 = F.spmm_coo(cooA, B) - out1 = torch.matmul(A2, B) - - assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) - - @pytest.mark.benchmark - def test_spmm_bench(self): - batch = 2 - model = 1024 * 1 - hidden = model * 4 - seq = 1024 - dim1 = batch * seq - dim2 = model - dim3 = hidden - threshold = 4 - A = torch.randn(dim1, dim2, device="cuda").half() - B = torch.randn(dim2, dim3, device="cuda").half() - for i in range(10): - C1 = bnb.matmul(A, B.t()) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(k): - C1 = bnb.matmul(A, B.t()) - torch.cuda.synchronize() - t8 = time.time() - t0 - - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - print(nnz / idx.numel()) - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - - for i in range(10): - out2 = F.spmm_coo(cooA, B) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(k): - out2 = F.spmm_coo(cooA, B) - torch.cuda.synchronize() - tsp = time.time() - t0 - print(tsp, t8) - print(tsp / t8) - - @pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1")) - @pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2")) - @pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) - @pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func")) - def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func): - out_func = getattr(torch, out_func) - - threshold = 3.3 - # threshold = 2.8 - # threshold = 0.0 - A = torch.randn(dim1, dim2, device="cuda").half() - if dtype == torch.float16: - B = torch.randn(dim2, dim2 * 4, device="cuda").half() - torch.nn.init.xavier_uniform_(B) - else: - B = torch.randn(dim2, dim2 * 4, device="cuda").half() - torch.nn.init.xavier_uniform_(B) - - SB = torch.abs(B).max().float() - B = torch.round(B / SB * 127).to(torch.int8) - - print("") - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - out1 = torch.matmul(A2.half(), B.half()) - out = out_func(out1.shape, dtype=torch.float16, device=out1.device) - out1 += out.clone() - out2 = F.spmm_coo_very_sparse(cooA, B, out=out) - # print(B) - # print(out1) - # print(out2) - p = 200 / (2048 * 12288 * 4) - n = out1.numel() - count = math.ceil(p * n) - std = out1.std() - out1 /= std - out2 /= std - assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) - # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) - - idx_col = torch.randint(0, A2.shape[-1], size=(15,)) - - # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001) - - # Bt = torch.randn(dim2*4, dim2, device='cuda').half() - # torch.cuda.synchronize() - # t0 = time.time() - # print(A2.shape, B.shape) - # for i in range(100): - # #out3 = F.spmm_coo(cooA, Bt.t()) - # #out2 = F.spmm_coo(cooA, B) - # #out2 = F.spmm_coo_very_sparse(cooA, B) - # #out1 = torch.matmul(A, Bt.t()) - - # torch.cuda.synchronize() - # print(time.time() - t0) - - @pytest.mark.parametrize("dim1", [1 * 2048]) - @pytest.mark.parametrize("dim2", [2048]) - @pytest.mark.parametrize("dtype", [torch.int8]) - def test_spmm_coo_dequant(self, dim1, dim2, dtype): - threshold = 6.0 - # threshold = 2.8 - # threshold = 0.0 - A = torch.randn(dim1, dim2, device="cuda").half() - B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16) - torch.nn.init.xavier_uniform_(B) - Bt = B.t().contiguous() - - _CB, CBt, _statsB, statsBt, _coo_tensor = F.int8_double_quant(B) - - rowidx = torch.randint(0, A.shape[-1], size=(15,)) - - A[:, rowidx] = 8.0 - - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) - out1 = torch.matmul(A2, B.half()) - out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) - out3 = out3 * statsBt.half() / 127 - - values, counts = torch.unique(cooA.rowidx, return_counts=True) - offset = counts.cumsum(0).int() - max_count, _ = torch.sort(counts, descending=True) - print(torch.median(max_count.float())) - - torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001) - - p = 200 / (2048 * 12288 * 4) - n = out1.numel() - count = math.ceil(p * n) - assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) - - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(100): - # out2 = F.spmm_coo_very_sparse(cooA, B) - # torch.cuda.synchronize() - # print('fp16', time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out2 = F.spmm_coo(cooA, B) - torch.cuda.synchronize() - print("cusparse fp16", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out2 = F.spmm_coo_very_sparse(cooA, CBt) - torch.cuda.synchronize() - print("int8", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) - torch.cuda.synchronize() - print("int8+dequant", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out2 = torch.matmul(A, B) - torch.cuda.synchronize() - print("matmul", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out1 = bnb.matmul(A, Bt) - out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) - out = out1 + out2 - torch.cuda.synchronize() - print("sparse+ matmul", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out1 = bnb.matmul(A, Bt) - torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) - torch.cuda.synchronize() - print("partial matmul", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out1 = bnb.matmul(A, Bt) - torch.cuda.synchronize() - print("partial matmul", time.time() - t0) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") -class TestSparseTensorFunctional: - def test_coo2csr(self): - threshold = 1 - A = torch.randn(128, 128).half().cuda() - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - csrA = F.coo2csr(cooA) - counts = csrA.rowptr[1:] - csrA.rowptr[:-1] - assert counts.numel() == A.shape[0] - - torch.testing.assert_close(counts.long(), (A2 != 0).sum(1)) - idx = A2 != 0 - torch.testing.assert_close(A2[idx], csrA.values) - - def test_coo2csc(self): - threshold = 1 - A = torch.randn(128, 128).half().cuda() - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - cscA = F.coo2csc(cooA) - counts = cscA.colptr[1:] - cscA.colptr[:-1] - assert counts.numel() == A.shape[1] - - torch.testing.assert_close(counts.long(), (A2 != 0).sum(0)) - # torch uses row-major -> use transpose to transfer to col-major - idx = A2.t() != 0 - torch.testing.assert_close(A2.t()[idx], cscA.values) - - class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)