Skip to content

Commit 3db2828

Browse files
committed
refine code
1 parent 9528f07 commit 3db2828

5 files changed

Lines changed: 91 additions & 33 deletions

File tree

bitsandbytes/backends/xpu/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _(
183183
) -> torch.Tensor:
184184
shape = (*A.shape[:-1], shapeB[0])
185185
out = torch.zeros(shape, device=A.device, dtype=torch.float32)
186-
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
186+
_gemv_4bit_impl(A, B, shapeB, absmax.bfloat16(), code, blocksize, out=out)
187187
return out
188188

189189
@register_kernel("bitsandbytes::gemv_4bit.out", "xpu")

csrc/pythonInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ void gemv_4bit_inference_fp16(
381381

382382
#if 1
383383
void gemm_4bit_inference_bf16(
384-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, float * out,
384+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float * out,
385385
int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
386386
) {
387387
gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
@@ -826,7 +826,7 @@ void cgemv_4bit_inference_fp16(
826826

827827
#if 1
828828
void cgemv_4bit_inference_bf16(
829-
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype,
829+
int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, sycl::ext::oneapi::bfloat16 *absmax, float *datatype,
830830
float * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
831831
) {
832832
gemm_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);

csrc/xpu_cutlass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void gemv_4bit_inference_cutlass_cute(int m, int n, int k, T *A, T *B,
109109

110110
template <typename T, int BITS>
111111
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
112-
float *absmax, float *datatype, float *out, int lda,
112+
T *absmax, float *datatype, float *out, int lda,
113113
int ldb, int ldc, int blocksize, sycl::queue *stream);
114114

115115
template <typename T, int BITS>

csrc/xpu_cutlass_fusion.cpp

Lines changed: 83 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "cutlass/detail/layout.hpp"
1515
#include "cutlass/detail/mma.hpp"
1616
#include "cutlass/cuda_host_adapter.hpp"
17+
#include <cutlass/numeric_types.h>
18+
#include <cutlass/bfloat16.h>
#include <cstring>
1719

1820
#include "cutlass/kernel_launch.h"
1921
#if !defined(__CUDACC_RTC__)
@@ -46,7 +48,7 @@ using ElementB = QuantType; //cutlass::gemm::collective::detail::deduce_mixed_wi
4648

4749
using ElementMMA = ElementA;
4850
using ElementQuant = QuantType;
49-
using ElementScale = MmaType;
51+
using ElementScale = sycl::ext::oneapi::bfloat16; //MmaType;
5052

5153
using ElementC = float;
5254
using ElementD = float;
@@ -63,11 +65,11 @@ using ProblemShape = Shape<int, int, int, int>;
6365
// inner_loop_number (Atom numbers per thread): (256/8) * (256/4) * (32/1)
6466
// XE_8x16x16_F32BF16BF16F32_TT -> hardware instruction
6567
// Stride<_4, _1, _0> could be optional?
66-
using TileShape = Shape<_256, _256, _32>;
68+
using TileShape = Shape<_256, _256, _32>; /*TODO: maybe need to adjust me for the small tile shapes*/
6769
//using TileShape = Shape<_32, _32, _32>;
6870
using TiledMma =
6971
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
70-
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
72+
Layout<Shape<_8, _4, _1> /*TODO: maybe need to adjust me for the small tile shapes*/, Stride<_4, _1, _0>>>::TiledMMA;
7173

7274
// Define Mainloop dispatch policy
7375
constexpr int PipelineStages = 1;
@@ -228,7 +230,12 @@ class kgemm_4bit_inference_cutlass_dequant {
228230
return Int<cute::gcd(Cosize, 32 / cute::min(sizeof_bits_v<SrcType>, sizeof_bits_v<DstType>))>{};
229231
}
230232
};
231-
233+
234+
/// Widens a raw bfloat16 bit pattern to the float it encodes.
///
/// The 16 input bits become the upper half of a 32-bit word and the lower
/// half is zero-filled; that word is then reinterpreted as a float.
///
/// @param bf16_bits Raw 16-bit bfloat16 encoding.
/// @return The float whose upper 16 bits equal `bf16_bits`.
float bfloat16_to_float(uint16_t bf16_bits) {
    // Widen before shifting: a uint16_t operand is promoted to signed int, and
    // shifting a pattern with the top bit set left by 16 overflows that int.
    const uint32_t float_bits = static_cast<uint32_t>(bf16_bits) << 16;

    // Copy the bits rather than casting the reference: reading a uint32_t
    // object through a float lvalue violates the strict-aliasing rule.
    float result;
    std::memcpy(&result, &float_bits, sizeof(result));
    return result;
}
238+
232239
/// Utilities to transform A.
233240
template <class EngineIn,
234241
class EngineOut,
@@ -251,20 +258,39 @@ class kgemm_4bit_inference_cutlass_dequant {
251258
static_assert(size_v<LayoutIn> == cosize_v<LayoutIn>);
252259
static_assert(size_v<LayoutOut> == cosize_v<LayoutOut>);
253260

254-
for (int i = 0; i < size(tCrB_src); ++i) {
255-
// uint8_t src_value = tCrB_src(i);
256-
// tCrB_dst(2*i) = static_cast<ElementMMA>(quant_map[(src_value >> 4) & 0x0F]);// * tCrS(i/4) ;
257-
// tCrB_dst(2*i+1) = static_cast<ElementMMA>(quant_map[src_value & 0x0F]);// * tCrS(i/4);
258-
uint8_t packed = tCrB_src(i);
259-
uint8_t high = (packed >> 4) & 0x0F;
260-
uint8_t low = packed & 0x0F;
261-
262-
// 应用缩放因子
263-
float val_high = quant_map[high];// * tCrS(i/4); // 假设每32个元素共享一个scale
264-
float val_low = quant_map[low];// * tCrS(i/4);
265-
266-
tCrB_dst(2*i) = static_cast<ElementMMA>(val_high);
267-
tCrB_dst(2*i+1) = static_cast<ElementMMA>(val_low);
261+
int scale_number = decltype(size(tCrS))::value;
262+
int src_number = decltype(size(tCrB_src))::value;
263+
int src_sub_number = src_number / scale_number;
264+
265+
//float scale_value = 1.0;
266+
267+
if(cute::thread0()) printf("scale_number = %d, src_number= %d, src_sub_number = %d\n", scale_number, src_number, src_sub_number);
268+
for(int i=0; i < scale_number; ++i) {
269+
auto scale_value = tCrS(i);
270+
float scale_value_float = static_cast<float>(scale_value);
271+
//uint16_t scale_bits = reinterpret_cast<uint16_t&>(scale_value);
272+
//uint16_t scale_bits = reinterpret_cast<uint16_t&>(scale_value);
273+
//printf("scale_value = %f, tCrS(%d) raw bits: %d\n",scale_value, i, static_cast<int>(scale_value));
274+
//float scale_value_float = bfloat16_to_float(scale_bits);
275+
printf("scale_value_float = %f\n", scale_value_float);
276+
for (int j = 0; j < src_sub_number; ++j) {
277+
int offset = i * src_sub_number;
278+
uint8_t packed = tCrB_src(offset + j);
279+
uint8_t high = (packed >> 4) & 0x0F;
280+
uint8_t low = packed & 0x0F;
281+
282+
float val_high = quant_map[high];
283+
float val_low = quant_map[low];
284+
285+
float val_high_scaled = val_high * scale_value_float;
286+
float val_low_scaled = val_high * scale_value_float;
287+
288+
//printf("scale value = %f, val_high_scaled = %f, val_low_scaled = %f\n", scale_value, val_high_scaled, val_low_scaled);
289+
tCrB_dst(offset + 2 * j) = static_cast<ElementMMA>(val_high_scaled);// * scale_value;
290+
tCrB_dst(offset + 2 * j + 1) = static_cast<ElementMMA>(val_low_scaled);// * scale_value;
291+
292+
//printf("scale value = %f, val_high = %f, val_low = %f\n", scale_value, val_high, val_low);
293+
}
268294
}
269295
}
270296

@@ -387,7 +413,7 @@ class kgemm_4bit_inference_cutlass_dequant {
387413

388414
Tensor tSgS = [&](){
389415
return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)), // Initial coordinate (n_coord, 0, l_coord): start at n_coord along N and at 0 along K
390-
make_layout(make_shape(_2{}, _2{}, _1{}, k_tile_count), // 迭代器的逻辑形状:[2, 2, 1, k_tile_count],表示每次迭代生成 2x2x1 的坐标块,共 k_tile_count 次
416+
make_layout(make_shape(_2{}, _2{}, _1{}, k_tile_count / 2), // Logical shape of the iterator: [2, 2, 1, k_tile_count / 2]; each step yields a 2x2x1 block of coordinates, k_tile_count / 2 steps in total
391417
make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); // Strides [16, 32, 0, 1]:
392418
}();
393419

@@ -524,12 +550,20 @@ for(int i=0; i<num_Acc; i++) {
524550
}
525551
};
526552

553+
void convert_float_to_bfloat16_host(float* src, cutlass::bfloat16_t* dst, int size) {
554+
for (int i = 0; i < size; ++i) {
555+
dst[i] = static_cast<cutlass::bfloat16_t>(src[i]); // 直接转换
556+
}
557+
}
558+
527559
template <typename T, int BITS>
528560
void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned char *B,
529-
float *absmax_, float *datatype, float *out, int lda,
561+
T *absmax_, float *datatype, float *out, int lda,
530562
int ldb, int ldc, int blocksize, sycl::queue *stream) {
531563
std::cout<<"this is gemm_4bit_inference_cutlass_dequant ......................!!!!!!\n";
532564

565+
566+
533567
sycl::queue q = *stream;
534568
using GemmKernel = kgemm_4bit_inference_cutlass_dequant<T, BITS>;
535569

@@ -539,8 +573,26 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
539573
//TODO(Xiaoli): FIX ME?? auto problem_size = ProblemShape{m, n, k};
540574
auto problem_size = ProblemShape{m, n, k, l};
541575
//TODO(Xiaoli): FIX ME
542-
T* absmax = (T*)absmax_;
543-
576+
// T* absmax = (T*)absmax_;
577+
//T* absmax = static_cast<T*>(absmax_);
578+
//auto absmax = reinterpret_cast<ElementScale*>((T*)absmax_);
579+
//int scale_size = 2048;
580+
//ElementScale* absmax = new ElementScale[scale_size];
581+
//for (int i = 0; i < scale_size; ++i) {
582+
// absmax[i] = static_cast<ElementScale>(absmax_[i]); // 逐元素转换
583+
// std::cout<<"absmax_[i] = "<<absmax_[i]<<" absmax[i] = "<< absmax[i]<<std::endl;
584+
//}
585+
586+
//int N = 2048;
587+
//printf("A[0] = %f\n", A[0]);
588+
//printf("absmax_[0] = %f\n", absmax_[0]);
589+
//cutlass::bfloat16_t* absmax_test = new cutlass::bfloat16_t[N];
590+
//std::cout<<"absmax_[0] ==== "<<absmax_[0]<<std::endl;
591+
//convert_float_to_bfloat16_host(absmax_, absmax, N);
592+
//absmax_test[0] = static_cast<cutlass::bfloat16_t>(absmax[0]);
593+
//printf("absmax_test[0] = %f\n", absmax_test[0]);
594+
//std::cout<<"absmax_test[0] ==== "<<absmax_test[0]<<std::endl;
595+
#if 1
544596
// Init Params
545597
using Params = GemmKernel::Params;
546598
Params params;
@@ -570,12 +622,16 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
570622
params.tiled_copy_a = tiled_copy_a;
571623
params.tiled_copy_b = tiled_copy_b;
572624
params.tiled_copy_b_4bit = tiled_copy_b_4bit;
573-
574-
const int scale_k = cute::ceil_div(k, blocksize);
625+
626+
//float f = 1.0f;
627+
//cutlass::bfloat16_t bf16 = static_cast<cutlass::bfloat16_t>(f);
628+
//std::cout<<"k = "<<k<<" blocksize = "<<blocksize<<std::endl;
629+
const int scale_k = cute::ceil_div(k / 2, blocksize);
630+
std::cout<<"n = "<<n<<" k = "<<k<<" blocksize = "<<blocksize<<" scale_k = "<<scale_k<<std::endl;
575631
const int dq_mn_size = n; //A(m) or B(n)
576632
StrideScale stride_S = cutlass::make_cute_packed_stride(StrideScale{}, cute::make_shape(dq_mn_size, scale_k, l));
577633
auto mScale = make_tensor(
578-
make_gmem_ptr(absmax), //static_cast<NonVoidElementScale *>(absmax)),
634+
make_gmem_ptr(absmax_), //static_cast<ElementScale *>(absmax)),
579635
make_layout(make_shape(dq_mn_size, scale_k, l), stride_S));
580636
Copy_Scale tiled_copy_scale{Copy_Scale{}.with(mScale)};
581637

@@ -627,10 +683,11 @@ void gemm_4bit_inference_cutlass_dequant(int m, int n, int k, T *A, unsigned cha
627683
auto event = syclcompat::experimental::launch<device_kernel<GemmKernel>>(policy, q, params);
628684
EventManager::getInstance().addEvent(event);
629685
//syclcompat::wait();
686+
#endif
630687
}
631688

632689
template void gemm_4bit_inference_cutlass_dequant<sycl::ext::oneapi::bfloat16, 16>(
633690
int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, unsigned char *B,
634-
float *absmax, float *datatype, float *out, int lda,
691+
sycl::ext::oneapi::bfloat16 *absmax, float *datatype, float *out, int lda,
635692
int ldb, int ldc, int blocksize, sycl::queue *stream);
636693

tests/test_xpu.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class TestXPU:
4747
[torch.uint8],
4848
ids=describe_dtype,
4949
)
50-
@pytest.mark.parametrize("dim", [32], ids=id_formatter("dim"))
50+
@pytest.mark.parametrize("dim", [256], ids=id_formatter("dim"))
5151
def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
5252
errs1 = []
5353
errs2 = []
@@ -66,7 +66,7 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
6666
#for i in range(iters):
6767
#pdb.set_trace()
6868
if kind == "fc1":
69-
A = torch.ones(dim, dim, dtype=dtype, device=device)
69+
A = torch.ones(32, dim, dtype=dtype, device=device)
7070
B = torch.ones(dim, dim, dtype=dtype, device=device) # / math.sqrt(dim)
7171
elif kind == "fc2":
7272
A = torch.randn(1, 4 * dim, dtype=dtype, device=device)
@@ -83,9 +83,10 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
8383
quant_type=storage_type,
8484
compress_statistics=double_quant,
8585
quant_storage=quant_storage,
86+
blocksize=32,
8687
)
8788
##pdb.set_trace()
88-
C3 = torch.matmul(A.t(), B)
89+
C3 = torch.matmul(A, B.t())
8990
pdb.set_trace()
9091
C2 = F.gemv_4bit(A, qB.t(), state=state)
9192
pdb.set_trace()

0 commit comments

Comments
 (0)