Skip to content

Commit 6d489cc

Browse files
C++ restructure WIP
1 parent 4d19869 commit 6d489cc

File tree

13 files changed

+1662
-1757
lines changed

13 files changed

+1662
-1757
lines changed

.clang-format

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ StatementMacros:
2121
- 'MAKE_optimizerStatic8bit2State'
2222
- 'MAKE_OptimizerStatic8bit1StateBlockwise'
2323
- 'MAKE_OptimizerStatic8bit2StateBlockwise'
24+
- 'MAKE_optimizerStatic8bit'
25+
- 'MAKE_optimizerStatic8bitBlockwise'
26+
- 'MAKE_optimizer32bit'
2427
- 'MAKE_kQuantizeBlockwise'
2528
- 'MAKE_BLOCKWISE8'
2629
- 'MAKE_ELEMENTWISE_FUNC'

CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ endif()
2424

2525
# Define included source files
2626
set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
27-
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
27+
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu csrc/cuda/blockwise_quantization.cu csrc/cuda/int8.cu csrc/cuda/optimizers.cu)
2828
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
2929
set(MPS_FILES csrc/mps_ops.mm)
3030
set(METAL_FILES csrc/mps_kernels.metal)
@@ -312,7 +312,9 @@ if(BUILD_CUDA)
312312
set_target_properties(bitsandbytes
    PROPERTIES
        CUDA_SEPARABLE_COMPILATION ON
        # The target property is INTERPROCEDURAL_OPTIMIZATION; the
        # CMAKE_-prefixed name is the *variable* that initializes it.
        # set_target_properties() accepts arbitrary property names, so the
        # previous CMAKE_INTERPROCEDURAL_OPTIMIZATION spelling was silently
        # ignored and LTO/IPO was never enabled.
        # NOTE(review): consider guarding with check_ipo_supported() from
        # the CheckIPOSupported module if any supported toolchain lacks IPO.
        INTERPROCEDURAL_OPTIMIZATION TRUE
)
317+
316318
endif()
317319
if(BUILD_HIP)
318320
if(NOT DEFINED ENV{ROCM_PATH})
csrc/cuda/blockwise_quantization.cu (new file — filename inferred from the CMakeLists.txt change above; the diff section header was lost in extraction)

Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
#include "common.cuh"
2+
#include "ops.cuh" // For CUDA_CHECK_RETURN, some typedefs
3+
#include <cub/cub.cuh>
4+
#include <cuda_fp16.h>
5+
6+
// from kernels.cu
7+
// TODO move somewhere like common.cuh or cub_utils.cuh etc
8+
#if CCCL_VERSION >= 2008002
9+
#include <cuda/std/functional>
10+
#define CUB_REDUCTIONOP_MAX \
11+
cuda::maximum<> {}
12+
#else
13+
#define CUB_REDUCTIONOP_MAX cub::Max()
14+
#endif
15+
16+
// copied from kernels.cu, todo
17+
#define NUM 4
18+
#define NUM_BLOCK 4096
19+
20+
// helper. todo: maybe move elsewhere. copied from kernels.cu
21+
// it is needed in deprecated optimizers too
22+
// Map a value x onto the nearest entry of a 256-entry codebook held in
// shared memory (assumed sorted ascending) and return that entry's index.
//
// A 7-step binary search narrows to a candidate index `pivot`; the true
// nearest code is then either that candidate or its immediate neighbor on
// the side x fell on.
//
// STOCHASTIC == 0: deterministic round-to-nearest using the midpoint
// between the two candidates. STOCHASTIC != 0: stochastic rounding — `rand`
// (presumably uniform in [0, 1); confirm at call site) selects between the
// candidates with probability proportional to proximity.
template <int STOCHASTIC> __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) {
    int pivot = 127;  // current binary-search position
    int hi_idx = 255; // closest index found on the >= side so far
    int lo_idx = 0;   // closest index found on the <= side so far

    float lo_val = -1.0f;
    float hi_val = 1.0f;

    float val = smem_code[pivot];
    // Step sizes 64, 32, 16, 8, 4, 2, 1 — binary search over [0, 255].
    for (int step = 64; step > 0; step >>= 1) {
        if (x > val) {
            lo_idx = pivot;
            lo_val = val;
            pivot += step;
        } else {
            hi_idx = pivot;
            hi_val = val;
            pivot -= step;
        }
        val = smem_code[pivot];
    }

    // If the search never moved off an endpoint, read the true endpoint
    // value instead of keeping the +/-1.0f initializers.
    if (hi_idx == 255)
        hi_val = smem_code[hi_idx];
    if (lo_idx == 0)
        lo_val = smem_code[lo_idx];

    if (!STOCHASTIC) {
        // Round to nearest: compare x against the midpoint between the
        // final pivot value and its neighbor on the side x landed on.
        if (x > val) {
            float mid = (hi_val + val) * 0.5f;
            return x > mid ? hi_idx : pivot;
        }
        float mid = (lo_val + val) * 0.5f;
        return x < mid ? lo_idx : pivot;
    }

    // Stochastic rounding: pick the neighbor with probability equal to the
    // fraction of the gap between the candidates that x has crossed.
    if (x > val) {
        float dist_hi = fabsf(hi_val - x);
        float span = hi_val - val;
        return (rand >= dist_hi / span) ? hi_idx : pivot;
    }
    float dist_lo = fabsf(lo_val - x);
    float span = val - lo_val;
    return (rand >= dist_lo / span) ? lo_idx : pivot;
}
82+
83+
// helper. maybe move elsewhere TODO
84+
// Quantize x to a 4-bit FP4 code: 1 sign bit (bit 3) + 3 magnitude bits
// with exponent bias 3. Representable magnitudes (including subnormals):
//   0b000 = 0      0b001 = 0.0625  0b110 = 2   0b111 = 3
//   0b100 = 4      0b101 = 6       0b010 = 8   0b011 = 12
// Inputs are assumed pre-normalized into [-1, 1], so every comparison
// constant below is the midpoint between two adjacent magnitudes divided by
// 12 (the FP4 absmax). Be careful editing these constants — a single extra
// zero is an easy, hard-to-spot mistake.
__device__ unsigned char dQuantizeFP4(float x) {
    const unsigned char sign = x < 0 ? 0b1000 : 0b0000;
    x = fabsf(x);

    // Descend the midpoint tree from the middle of the magnitude range.
    if (x > 0.29166667f) {                                     // (4+3)/2 / 12
        if (x > 0.583333f) {                                   // (8+6)/2 / 12
            return (x > 0.8333333f ? 0b0011 : 0b0010) | sign;  // 12 : 8
        }
        return (x > 0.4166667f ? 0b0101 : 0b0100) | sign;      // 6 : 4
    }
    if (x > 0.0859375f) {                                      // (2+0.0625)/2 / 12
        return (x > 0.20833333f ? 0b0111 : 0b0110) | sign;     // 3 : 2
    }
    return (x > 0.00260417f ? 0b0001 : 0b0000) | sign;         // 0.0625 : 0
}
127+
128+
// helper. maybe move elsewhere TODO
129+
// Quantize x (assumed normalized into roughly [-1, 1]) to a 4-bit NF4 code
// ("NormalFloat": 16 levels placed at quantiles of a standard normal
// distribution). Each comparison constant is the decision boundary between
// two adjacent NF4 levels, so this tree is a nearest-level lookup.
// The inline comments give the code-bit prefix decided at each branch.
__device__ unsigned char dQuantizeNF4(float x) {

    // The threshold values for this tree were generated by
    // test_normal_map_tree in tests/test_functional.py — do not edit by hand.
    if (x > 0.03979014977812767f)
        if (x > 0.3893125355243683f) // 1
            if (x > 0.6427869200706482f) // 11
                if (x > 0.8614784181118011f) // 111
                    return 0b1111;
                else
                    return 0b1110;
            else if (x > 0.5016634166240692f) // 110
                return 0b1101;
            else
                return 0b1100;
        else if (x > 0.2035212516784668f) // 10
            if (x > 0.2920137718319893f) // 101
                return 0b1011;
            else
                return 0b1010;
        else if (x > 0.1202552504837513f) // 100
            return 0b1001;
        else
            return 0b1000;
    else if (x > -0.33967943489551544f) // 0
        if (x > -0.13791173323988914f) // 01
            if (x > -0.045525018125772476f) // 011
                return 0b0111;
            else
                return 0b0110;
        else if (x > -0.23460740596055984f) // 010
            return 0b0101;
        else
            return 0b0100;
    else if (x > -0.6106329262256622f) // 00
        if (x > -0.4599952697753906f) // 001
            return 0b0011;
        else
            return 0b0010;
    else if (x > -0.8480964004993439f) // 000
        return 0b0001;
    else
        return 0b0000;
}
173+
174+
// Block-wise quantization kernel.
//
// Each chunk of BLOCK_SIZE input elements is scaled by the reciprocal of its
// absolute maximum, then quantized to 8 bits (General8bit, via the 256-entry
// codebook in `code`) or to packed 4 bits (FP4 / NF4, two values per output
// byte). One absmax per chunk is written to `absmax` for later dequantization.
//
// Launch: gridDim.x blocks of (BLOCK_SIZE / NUM_PER_TH) threads; each thread
// handles NUM_PER_TH elements. A grid-stride loop covers all n elements, so
// the grid need not cover the input exactly.
//
// STOCHASTIC != 0 enables stochastic rounding (used with General8bit only)
// driven by `rand`. NOTE(review): the rand indexing below assumes roughly
// 1024 usable floats in `rand` while LoadFloat reads BLOCK_SIZE of them —
// confirm the host-side buffer size makes this safe.
template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
) {
    // Compute the grid coverage in 64-bit first: gridDim.x * BLOCK_SIZE is
    // 32-bit unsigned arithmetic and can wrap *before* any clamp is applied.
    // Clamp to INT32_MAX — we won't have more elements than this.
    const int64_t grid_span = static_cast<int64_t>(gridDim.x) * BLOCK_SIZE;
    const int n_full = grid_span > INT32_MAX ? INT32_MAX : static_cast<int>(grid_span);

    const int64_t base_idx = static_cast<int64_t>(blockIdx.x) * BLOCK_SIZE;
    int valid_items = 0;

    T vals[NUM_PER_TH];
    float rand_vals[NUM_PER_TH];
    // 4-bit data types (DATA_TYPE > 0) pack two quantized values per byte.
    unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH];

    float local_abs_max = 0.0f;
    int local_rand_idx = 0;

    typedef cub::BlockLoad<T, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockStore<
        unsigned char, BLOCK_SIZE / NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH,
        cub::BLOCK_STORE_WARP_TRANSPOSE>
        StoreChar;
    typedef cub::BlockReduce<float, BLOCK_SIZE / NUM_PER_TH> BlockReduce;
    typedef cub::BlockLoad<float, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

    __shared__ typename LoadT::TempStorage loadt;
    __shared__ typename LoadFloat::TempStorage loadf;
    __shared__ typename StoreChar::TempStorage storec;
    __shared__ typename BlockReduce::TempStorage reduce;
    __shared__ float smem_code[256];
    __shared__ float smem_absmax_value[1];

    // Only the 8-bit path uses the codebook; stage it into shared memory
    // once. The __syncthreads() at the top of the loop below covers this.
    if (DATA_TYPE == General8bit)
        for (int i = threadIdx.x; i < 256; i += blockDim.x)
            smem_code[i] = code[i];

    // Grid-stride loop; 64-bit index and stride so huge grids cannot wrap.
    for (int64_t i = base_idx; i < n_full; i += grid_span) {
        valid_items = min(BLOCK_SIZE, static_cast<int>(n - i));
        local_abs_max = -FLT_MAX;

        __syncthreads();
        LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);

        // 1. compute local max
        // 2. broadcast local max
        // 3. normalize inputs and quantize

#pragma unroll NUM_PER_TH
        for (int j = 0; j < NUM_PER_TH; j++)
            local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

        local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX, valid_items);

        if (threadIdx.x == 0) {
            // Broadcast the reciprocal for normalization; persist the raw
            // absmax for dequantization.
            // NOTE(review): an all-zero block gives 1.0f/0 == inf here —
            // confirm callers never feed fully-zero blocks or that the
            // resulting 0*inf products downstream are acceptable.
            smem_absmax_value[0] = 1.0f / local_abs_max;
            absmax[i / BLOCK_SIZE] = local_abs_max;
        }
        __syncthreads();

        local_abs_max = smem_absmax_value[0];

        if (STOCHASTIC) {
            local_rand_idx = ((blockIdx.x * NUM_BLOCK) + (threadIdx.x * NUM) + rand_offset) % (1024 - 4);
            LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
        }

        switch (DATA_TYPE) {
        case General8bit:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH; j++) {
                if (!STOCHASTIC)
                    qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j]) * local_abs_max);
                else
                    qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j]) * local_abs_max);
            }
            break;
        case FP4:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH / 2; j++) {
                // High nibble = even element, low nibble = odd element.
                qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
                qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
            }
            break;
        case NF4:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH / 2; j++) {
                qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
                qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
            }
            break;
        }

        // Ensure all reads of the shared scratch are done before StoreChar
        // reuses it.
        __syncthreads();
        StoreChar(storec).Store(
            &(out[(DATA_TYPE > 0) ? i / 2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items + 1) / 2 : valid_items
        );
    }
}
274+
275+
//// host code
276+
277+
// Host launcher: block-wise quantization of n elements of A into out,
// dispatching to the kernel specialization matching `blocksize` (threads
// per block are chosen so each thread handles 4 elements for
// blocksize >= 1024, else 2).
//
// NOTE(review): the STOCHASTIC template argument is only forwarded for
// blocksize == 4096; every other branch hard-codes the non-stochastic
// kernel. Confirm this is intentional (the explicit instantiations suggest
// stochastic rounding is only used with General8bit).
// NOTE(review): an unsupported blocksize launches nothing at all, and the
// error check below will still pass — consider surfacing an error.
template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(
    float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
) {
    // Ceiling division: one CUDA block per quantization block.
    int num_blocks = n / blocksize;
    num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

    if (blocksize == 4096)
        kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE>
            <<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 2048)
        kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 1024)
        kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 512)
        kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 256)
        kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 128)
        kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 64)
        kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);

    // Peek (not get) so a sticky prior error is reported but not cleared.
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
302+
303+
// Explicit template instantiations of the host launcher needed by code
// linking against this translation unit.
// Note: stochastic (STOCHASTIC == 1) variants exist only for General8bit.
// todo: consider just exposing C API here instead

template void quantizeBlockwise<half, 1, General8bit>(
    float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, General8bit>(
    float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, FP4>(
    float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<half, 0, NF4>(
    float* code, half* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 1, General8bit>(
    float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, General8bit>(
    float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, FP4>(
    float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<float, 0, NF4>(
    float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
    const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
    const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
    const int n
);
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(
    float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize,
    const int n
);
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
// Block-wise quantization of n elements of A into out: each `blocksize`
// chunk is scaled by its absolute maximum (written to `absmax`, one float
// per chunk) and quantized per DATA_TYPE.
//   code        - 256-entry codebook (used by the General8bit path)
//   rand        - random floats for stochastic rounding (STOCHASTIC != 0);
//                 may be unused otherwise
//   rand_offset - offset into the rand buffer
//   blocksize   - quantization block size; the implementation dispatches on
//                 a fixed set of supported sizes (64..4096, powers of two)
template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(
    float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);

0 commit comments

Comments
 (0)