Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions bitsandbytes/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ def _(
A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
Expand Down Expand Up @@ -312,7 +311,6 @@ def _(
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
Expand Down
7 changes: 6 additions & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ def matmul(
return MatMul8bitLt.apply(A, B, out, bias, state)


# Maximum number of rows in A for which matmul_4bit uses the fused 4-bit
# gemv inference kernel. Above this limit, inference falls back to the
# dequantize + GEMM path.
FUSED_4BIT_DEQUANT_LIMIT = 8


def matmul_4bit(
A: torch.Tensor,
B: torch.Tensor,
Expand All @@ -391,7 +395,8 @@ def matmul_4bit(
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
num_a_rows = A.numel() // A.shape[-1]
if num_a_rows <= FUSED_4BIT_DEQUANT_LIMIT and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
Expand Down
5 changes: 3 additions & 2 deletions bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,11 @@ def _gemv_4bit_impl(
# torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

m = ct.c_int32(shapeB[0])
n = ct.c_int32(1)
num_a_rows = A.numel() // A.shape[-1]
n = ct.c_int32(num_a_rows)
k = ct.c_int32(shapeB[1])

lda = m
lda = ct.c_int32(A.shape[-1])
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
ldc = m

Expand Down
221 changes: 136 additions & 85 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
// Naive batched GEMV/GEMM for 4-bit quantized B against a small number of
// full-precision rows of A.
//
// Layout / launch contract (see host launcher in ops.cu):
//   A:   [N, K]  activations, row stride `lda` elements (T = half/bf16/float)
//   B:   [M, K]  4-bit packed weights, row stride `ldb` bytes (2 values/byte)
//   out: [N, M]  outputs, row stride `ldc` elements
// One warp computes one output row of B (`row_B`) and iterates over the N
// rows of A; each block carries THREADS / BNB_WARP_SIZE warps.
//
// NOTE(review): the vectorized fast path below casts A_row to const int4*,
// which assumes A_row (= A + n_idx * lda) is 16-byte aligned, i.e. lda is a
// multiple of 16 / sizeof(T) — TODO confirm callers guarantee this for N > 1.
__global__ void kgemm_4bit_inference_naive(
    int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
    int lda, int ldb, int ldc, int blocksize
) {
    typedef bnb_cub::WarpReduce<float> WarpReduce;
    __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / BNB_WARP_SIZE];

    const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
    const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
    const int row_B = (THREADS / BNB_WARP_SIZE) * blockIdx.x + warp_idx;
    const int offset_B = ldb * row_B;
    constexpr int num_values_8bit = num_values_4bit / 2;

    // Four independent accumulators to expose ILP; summed before the warp reduce.
    float local_C0 = 0.0f;
    float local_C1 = 0.0f;
    float local_C2 = 0.0f;
    float local_C3 = 0.0f;

    unsigned char local_B_4bit[num_values_8bit];

    // The 16-entry code book is replicated into both halves of a 32-entry
    // shared array; even/odd lanes index different halves (via qm_offset) to
    // reduce shared-memory bank conflicts on the dequant lookups.
    __shared__ float quant_map[32];
    float local_absmax = 0.0f;

    if (threadIdx.x < 16) {
        float val = __ldg(&datatype[threadIdx.x]);
        quant_map[threadIdx.x] = val;
        quant_map[threadIdx.x + 16] = val;
    }
    __syncthreads();

    // All lanes of a warp share row_B, so whole warps exit together and the
    // WarpReduce below always sees a full warp.
    if (row_B >= M) return;

    const int stride = BNB_WARP_SIZE * num_values_4bit;
    // blocksize is always a power of two: (x >> clz_blocksize) == x / blocksize.
    const int clz_blocksize = 31 - __clz(blocksize);
    const int base_absidx = 2 * offset_B;
    const int qm_offset = (warp_lane & 1) << 4;

    for (int n_idx = 0; n_idx < N; n_idx++) {
        const T* __restrict__ A_row = A + n_idx * lda;

        local_C0 = 0.0f;
        local_C1 = 0.0f;
        local_C2 = 0.0f;
        local_C3 = 0.0f;

        int inner_idx = warp_lane * num_values_4bit;
        int inner_idx_halved = inner_idx >> 1;
        int4 prefetch_B;
        float prefetch_absmax;

        // Software prefetch of the first B chunk and its absmax scale; both
        // are only consumed inside the loop, which is entered iff inner_idx < K.
        if (inner_idx < K) {
            prefetch_absmax = __ldg(&absmax[(base_absidx + inner_idx) >> clz_blocksize]);
            if ((inner_idx_halved + num_values_8bit) < (K >> 1))
                prefetch_B = reinterpret_cast<int4*>(B)[(offset_B + inner_idx_halved) / num_values_8bit];
        }

        for (; inner_idx < K; inner_idx += stride) {
            inner_idx_halved = inner_idx >> 1;

            local_absmax = prefetch_absmax;

            if (__builtin_expect((inner_idx_halved + num_values_8bit) < (K >> 1), 1)) {
                reinterpret_cast<int4&>(local_B_4bit[0]) = prefetch_B;
            } else {
                // Ragged tail of B: byte-wise loads. The 0x77 pad value never
                // reaches the output — the matching A elements are zeroed in
                // the scalar tail path below.
#pragma unroll
                for (int j = 0; j < num_values_8bit; j++)
                    local_B_4bit[j] = ((inner_idx_halved + j) < (K >> 1)) ? B[offset_B + inner_idx_halved + j] : 0x77;
            }

            // Prefetch the next iteration's B chunk and absmax.
            int next_inner_idx = inner_idx + stride;
            int next_inner_idx_halved = next_inner_idx >> 1;
            if (next_inner_idx < K) {
                prefetch_absmax = __ldg(&absmax[(base_absidx + next_inner_idx) >> clz_blocksize]);
                if ((next_inner_idx_halved + num_values_8bit) < (K >> 1))
                    prefetch_B = reinterpret_cast<int4*>(B)[(offset_B + next_inner_idx_halved) / num_values_8bit];
            }

            // Dequantize all B values of this chunk. Accumulation is done in
            // float: small performance hit on Ampere, but lower output error.
            float b_vals[num_values_4bit];
#pragma unroll
            for (int j = 0; j < num_values_8bit; j++) {
                b_vals[2 * j] = quant_map[qm_offset + (local_B_4bit[j] >> 4)] * local_absmax;
                b_vals[2 * j + 1] = quant_map[qm_offset + (local_B_4bit[j] & 0xF)] * local_absmax;
            }

            if (__builtin_expect(inner_idx + num_values_4bit <= K, 1)) {
                // Fast path: vectorized 16-byte loads of A. An int4 carries
                // 16/sizeof(T) elements, so the indexing is correct for both
                // 16-bit and 32-bit T. (The previous revision hard-coded the
                // 16-bit layout — `[inner_idx / 8]` with 8 T per vector —
                // which read past the loaded vectors for T == float.)
                constexpr int elems_per_vec = 16 / (int)sizeof(T);
                T local_A[num_values_4bit];
#pragma unroll
                for (int v = 0; v < num_values_4bit / elems_per_vec; v++)
                    reinterpret_cast<int4*>(local_A)[v] =
                        reinterpret_cast<const int4*>(A_row)[inner_idx / elems_per_vec + v];
#pragma unroll
                for (int k = 0; k < num_values_4bit; k += 4) {
                    local_C0 += (float)local_A[k + 0] * b_vals[k + 0];
                    local_C1 += (float)local_A[k + 1] * b_vals[k + 1];
                    local_C2 += (float)local_A[k + 2] * b_vals[k + 2];
                    local_C3 += (float)local_A[k + 3] * b_vals[k + 3];
                }
            } else {
                // Ragged tail of A: scalar loads, zero-padded past K.
#pragma unroll
                for (int k = 0; k < num_values_4bit; k++) {
                    float a_val = (inner_idx + k < K) ? (float)A_row[inner_idx + k] : 0.0f;
                    local_C0 += a_val * b_vals[k];
                }
            }
        }

        float local_C = local_C0 + local_C1 + local_C2 + local_C3;
        local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);

        if (warp_lane == 0)
            out[n_idx * ldc + row_B] = T(local_C);
    }
}

template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n) {
Expand Down Expand Up @@ -1595,6 +1634,18 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(
int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
float* out, int lda, int ldb, int ldc, int blocksize
);
// 64-thread instantiations of the naive 4-bit GEMM kernel. These are used by
// the HIP launch path in ops.cu, which launches 64 threads per block when the
// device warp size is 32 (the 128-thread variants above cover the other
// configurations).
template __global__ void kgemm_4bit_inference_naive<half, 64, 16>(
    int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out,
    int lda, int ldb, int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference_naive<bnb_bfloat16, 64, 16>(
    int M, int N, int K, bnb_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
    bnb_bfloat16* out, int lda, int ldb, int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference_naive<float, 64, 32>(
    int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype,
    float* out, int lda, int ldb, int ldc, int blocksize
);

template __global__ void kdequant_mm_int32_fp16<4, 512>(
int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
Expand Down
16 changes: 11 additions & 5 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -421,15 +421,21 @@ void gemm_4bit_inference_naive(
int blocksize, bnb_stream_t stream
) {

int num_blocks = (m + 3) / 4;
#if BNB_HIP
if (bnb_host_warp_size() == 64) {
num_blocks = (m + 1) / 2;
const int ws = bnb_host_warp_size();
int num_blocks = (m + 1) / 2;
if (ws == 32) {
kgemm_4bit_inference_naive<T, 64, BITS>
<<<num_blocks, 64, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
} else {
kgemm_4bit_inference_naive<T, 128, BITS>
<<<num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
}
#endif

#else
int num_blocks = (m + 3) / 4;
kgemm_4bit_inference_naive<T, 128, BITS>
<<<num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
#endif
BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR());
}

Expand Down