
Commit 8349915

issue/1118: success w4 kernel

1 parent 38ced1a · commit 8349915

3 files changed: +227, -154 lines

src/infiniop/ops/gptq_qyblas_gemm/info.h

Lines changed: 61 additions & 36 deletions
@@ -6,16 +6,37 @@
 #include <optional>
 #include <vector>
 
+inline void prepare_matrix_for_cublas(
+    infiniopTensorDescriptor_t tensor,
+    bool &transpose_tensor) {
+
+    auto strides = tensor->strides();
+    auto sizes = tensor->shape();
+
+    if ((strides[0] == 1) && (strides[1] >= std::max<int64_t>(1, sizes[0]))) {
+        transpose_tensor = false;
+        return;
+    }
+    if ((strides[1] == 1) && (strides[0] >= std::max<int64_t>(1, sizes[1]))) {
+        transpose_tensor = true;
+        return;
+    }
+    transpose_tensor = true;
+}
+
 namespace op::gptq_qyblas_gemm {
 
 class GptqQyblasGemmInfo {
     GptqQyblasGemmInfo() = default;
 
 public:
-    infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype;
+    infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype;
     size_t M, K, N, scales_size_0, scales_size_1;
     ptrdiff_t lda, ldb, result_ld;
-    bool transpose_mat_1, transpose_mat_2, transpose_result;
+    bool transpose_result;
+    char transa, transb;
 
     static utils::Result<GptqQyblasGemmInfo> createGptqQyblasGemmInfo(
         infiniopTensorDescriptor_t out_desc,
@@ -27,17 +48,38 @@ class GptqQyblasGemmInfo {
         auto dtype = a_desc->dtype();
 
         CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
-        CHECK_DTYPE(dtype, out_desc->dtype());
+        auto out_dtype = out_desc->dtype();
+        CHECK_DTYPE(dtype, out_dtype);
 
         const infiniDtype_t weight_dtype = b_desc->dtype();
-        CHECK_DTYPE(weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8);
+        // CHECK_DTYPE(weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8);
 
         const infiniDtype_t scales_dtype = b_scales_desc->dtype();
         const infiniDtype_t zeros_dtype = b_zeros_desc->dtype();
 
-        size_t M = out_desc->shape()[0];
-        size_t N = out_desc->shape()[1];
-        size_t K = a_desc->shape()[1];
+        bool transpose_result = false;
+        bool transpose_mat_1 = false;
+        bool transpose_mat_2 = false;
+
+        prepare_matrix_for_cublas(out_desc, transpose_result);
+
+        auto mata = (transpose_result ? b_desc : a_desc);
+        prepare_matrix_for_cublas(transpose_result ? b_desc : a_desc, transpose_mat_1);
+        auto matb = (transpose_result ? a_desc : b_desc);
+        prepare_matrix_for_cublas(transpose_result ? a_desc : b_desc, transpose_mat_2);
+
+        auto mat1_sizes = a_desc->shape();
+        auto mat2_sizes = b_desc->shape();
+        if (transpose_result) {
+            transpose_mat_1 = !transpose_mat_1;
+            transpose_mat_2 = !transpose_mat_2;
+            mat1_sizes = mata->shape();
+            mat2_sizes = matb->shape();
+        }
+
+        size_t M = mat1_sizes[transpose_result ? 1 : 0];
+        size_t K = mat1_sizes[transpose_result ? 0 : 1];
+        size_t N = mat2_sizes[transpose_result ? 0 : 1];
 
         size_t scales_size_0 = b_scales_desc->shape()[0];
         size_t scales_size_1 = b_scales_desc->shape()[1];
@@ -50,40 +92,23 @@
                       && b_zeros_desc->ndim() == ndim,
                   INFINI_STATUS_BAD_TENSOR_SHAPE);
 
-        bool transpose_result = false;
-        if (out_desc->strides()[0] == 1 && out_desc->strides()[1] >= std::max<int64_t>(1, out_desc->shape()[0])) {
-            transpose_result = true;
-        } else if (out_desc->strides()[1] == 1 && out_desc->strides()[0] >= std::max<int64_t>(1, out_desc->shape()[1])) {
-            transpose_result = false;
-        } else {
-            transpose_result = false;
-        }
-        bool transpose_mat_1 = false;
-        if (a_desc->strides()[0] == 1 && a_desc->strides()[1] >= std::max<int64_t>(1, a_desc->shape()[0])) {
-            transpose_mat_1 = true;
-        } else if (a_desc->strides()[1] == 1 && a_desc->strides()[0] >= std::max<int64_t>(1, a_desc->shape()[1])) {
-            transpose_mat_1 = false;
-        } else {
-            transpose_mat_1 = false;
-        }
-        bool transpose_mat_2 = false;
-        if (b_desc->strides()[0] == 1 && b_desc->strides()[1] >= std::max<int64_t>(1, b_desc->shape()[0])) {
-            transpose_mat_2 = true;
-        } else if (b_desc->strides()[1] == 1 && b_desc->strides()[0] >= std::max<int64_t>(1, b_desc->shape()[1])) {
-            transpose_mat_2 = false;
-        } else {
-            transpose_mat_2 = false;
-        }
+        ptrdiff_t lda = mata->strides()[(transpose_mat_1 == transpose_result)
                                             ? 1
                                             : 0];
+        ptrdiff_t ldb = matb->strides()[(transpose_mat_2 == transpose_result)
                                             ? 1
                                             : 0];
+        ptrdiff_t result_ld = out_desc->strides()[transpose_result ? 0 : 1];
 
-        ptrdiff_t lda = a_desc->strides()[transpose_mat_1 ? 1 : 0];
-        ptrdiff_t ldb = b_desc->strides()[transpose_mat_2 ? 1 : 0];
-        ptrdiff_t result_ld = out_desc->strides()[transpose_result ? 1 : 0];
+        char transa = transpose_mat_1 ? 't' : 'n';
+        char transb = transpose_mat_2 ? 't' : 'n';
 
         return utils::Result<GptqQyblasGemmInfo>(GptqQyblasGemmInfo{
-            dtype, weight_dtype, scales_dtype, zeros_dtype,
+            dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype,
             M, K, N, scales_size_0, scales_size_1,
             lda, ldb, result_ld,
-            transpose_mat_1, transpose_mat_2, transpose_result});
+            transpose_result,
+            transa, transb});
     }
 };
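Background on the new helper: cuBLAS assumes column-major storage, so prepare_matrix_for_cublas classifies each 2-D operand by its strides. A tensor with strides[0] == 1 is column-major and passes through untransposed; one with strides[1] == 1 is row-major and is fed to cuBLAS as its own transpose; anything non-contiguous falls back to the transpose path. Below is a minimal standalone sketch of that same decision rule; the Desc struct, needs_transpose, and the example shapes are hypothetical stand-ins for infiniopTensorDescriptor_t, not code from this commit.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for infiniopTensorDescriptor_t: a 2-D shape
// plus strides, both measured in elements.
struct Desc {
    std::vector<int64_t> shape;
    std::vector<int64_t> strides;
};

// Same decision rule as prepare_matrix_for_cublas in info.h above:
// stride[0] == 1 means column-major, which cuBLAS consumes directly;
// stride[1] == 1 means row-major, which must be passed as its transpose;
// anything else also falls back to the transpose path.
bool needs_transpose(const Desc &t) {
    if (t.strides[0] == 1 && t.strides[1] >= std::max<int64_t>(1, t.shape[0])) {
        return false; // column-major: no transpose
    }
    if (t.strides[1] == 1 && t.strides[0] >= std::max<int64_t>(1, t.shape[1])) {
        return true; // row-major: pass as transposed
    }
    return true;
}

int main() {
    Desc row_major{{4, 8}, {8, 1}}; // C-contiguous 4x8
    Desc col_major{{4, 8}, {1, 4}}; // Fortran-contiguous 4x8
    std::printf("row-major: %d\n", needs_transpose(row_major)); // prints 1
    std::printf("col-major: %d\n", needs_transpose(col_major)); // prints 0
}

This is also why the factory above swaps the A/B roles and flips the transpose flags when the output itself is row-major: computing C^T = B^T * A^T lets cuBLAS write the result directly in the caller's layout.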

src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu

Lines changed: 54 additions & 73 deletions
@@ -3,6 +3,36 @@
 #include "dlblas_ext.h"
 #include "gptq_qyblas_gemm_nvidia.cuh"
 
+inline cudaDataType_t ScalarTypeToCudaDataType(
+    infiniDtype_t scalar_type) {
+    switch (scalar_type) {
+    case INFINI_DTYPE_U8:
+        return CUDA_R_8U;
+    case INFINI_DTYPE_I8:
+        return CUDA_R_8I;
+    case INFINI_DTYPE_I32:
+        return CUDA_R_32I;
+    case INFINI_DTYPE_F16:
+        return CUDA_R_16F;
+    case INFINI_DTYPE_F32:
+        return CUDA_R_32F;
+    case INFINI_DTYPE_F64:
+        return CUDA_R_64F;
+    case INFINI_DTYPE_I16:
+        return CUDA_R_16I;
+    case INFINI_DTYPE_I64:
+        return CUDA_R_64I;
+    case INFINI_DTYPE_BF16:
+        return CUDA_R_16BF;
+    case INFINI_DTYPE_F8:
+        return (cudaDataType_t)CUDA_R_8F_E4M3;
+    default:
+        fprintf(stderr,
+                "Cannot convert ScalarType %d\n",
+                (int)scalar_type);
+        abort();
+    }
+}
 namespace op::gptq_qyblas_gemm::nvidia {
 
 struct Descriptor::Opaque {
@@ -47,17 +77,14 @@ infiniStatus_t Descriptor::calculate(void *workspace,
 
     cudaDataType_t computeType_ = (cudaDataType_t)CUDA_R_32F;
     cudaDataType_t kernel_Atype_, kernel_Btype_, kernel_Ctype_, kernel_Stype_, kernel_Ztype_;
-
-    switch (_info.dtype) {
-    case INFINI_DTYPE_F16:
-        kernel_Atype_ = CUDA_R_16F;
-        break;
-    case INFINI_DTYPE_BF16:
-        kernel_Atype_ = CUDA_R_16BF;
-        break;
-    default:
-        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    auto dtype = _info.dtype;
+    auto weight_dtype = _info.weight_dtype;
+    if (_info.transpose_result) {
+        std::swap(a, b);
+        std::swap(dtype, weight_dtype);
     }
+    kernel_Atype_ = ScalarTypeToCudaDataType(dtype);
+    kernel_Btype_ = ScalarTypeToCudaDataType(weight_dtype);
 
     if (quant_type == 0) {
         if (8 == bit) {
@@ -66,66 +93,21 @@
 
         if (4 == bit) {
             kernel_Atype_ = (cudaDataType_t)CUDA_R_4U;
+            K = K * 2;
         }
     }
 
-    switch (_info.weight_dtype) {
-    case INFINI_DTYPE_F8:
-        kernel_Btype_ = (cudaDataType_t)CUDA_R_8F_E4M3;
-        break;
-    case INFINI_DTYPE_U8:
-        kernel_Btype_ = CUDA_R_8U;
-        break;
-    case INFINI_DTYPE_I8:
-        kernel_Btype_ = CUDA_R_8I;
-        break;
-    default:
-        return INFINI_STATUS_BAD_TENSOR_DTYPE;
-    }
-
-    kernel_Ctype_ = kernel_Atype_;
-
-    switch (_info.scales_dtype) {
-    case INFINI_DTYPE_F32:
-        kernel_Stype_ = CUDA_R_32F;
-        break;
-    case INFINI_DTYPE_F16:
-        kernel_Stype_ = CUDA_R_16F;
-        break;
-    case INFINI_DTYPE_BF16:
-        kernel_Stype_ = CUDA_R_16BF;
-        break;
-    default:
-        return INFINI_STATUS_BAD_TENSOR_DTYPE;
-    }
-
-    switch (_info.zeros_dtype) {
-    case INFINI_DTYPE_F32:
-        kernel_Ztype_ = CUDA_R_32F;
-        break;
-    case INFINI_DTYPE_F16:
-        kernel_Ztype_ = CUDA_R_16F;
-        break;
-    case INFINI_DTYPE_BF16:
-        kernel_Ztype_ = CUDA_R_16BF;
-        break;
-    default:
-        return INFINI_STATUS_BAD_TENSOR_DTYPE;
-    }
+    kernel_Ctype_ = ScalarTypeToCudaDataType(_info.out_dtype);
+    kernel_Stype_ = ScalarTypeToCudaDataType(_info.scales_dtype);
+    kernel_Ztype_ = ScalarTypeToCudaDataType(_info.zeros_dtype);
 
     float alpha = 1.0f;
     float beta = 0.0f;
 
-    bool transpose_mat_1 = _info.transpose_mat_1;
-    bool transpose_mat_2 = _info.transpose_mat_2;
-
     int64_t M = static_cast<int64_t>(_info.M);
     int64_t N = static_cast<int64_t>(_info.N);
     int64_t lda = static_cast<int64_t>(_info.lda);
-    int64_t ldb = ((bit == 4 && transpose_mat_2) ? 2 * static_cast<int64_t>(_info.ldb) : static_cast<int64_t>(_info.ldb));
-
-    cublasOperation_t transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
-    cublasOperation_t transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
+    int64_t ldb = static_cast<int64_t>(_info.ldb);
 
     int64_t scales_size_0 = static_cast<int64_t>(_info.scales_size_0);
     int64_t scales_size_1 = static_cast<int64_t>(_info.scales_size_1);
@@ -135,7 +117,7 @@ infiniStatus_t Descriptor::calculate(void *workspace,
     dlblasExtQuantParametersV2_t extParameters;
 
     if (quant_type == 0) {
-        extParameters.a_group_size_m = N / scales_size_1;
+        extParameters.a_group_size_m = M / scales_size_1;
         extParameters.a_group_size_k = K / scales_size_0;
         extParameters.a_zeropoints_type = kernel_Ztype_;
         extParameters.a_zeropoints = b_zeros;
@@ -151,13 +133,13 @@ infiniStatus_t Descriptor::calculate(void *workspace,
     } else if (quant_type == 2 || quant_type == 3) {
         // calculate block_shape according to weight/scales shape
         int block_shape = 128;
-        while ((N + block_shape - 1) / block_shape < scales_size_0) {
+        while ((M + block_shape - 1) / block_shape < scales_size_0) {
             block_shape /= 2;
             if (block_shape < 32) {
                 fprintf(stderr,
                         "INTERNAL ASSERT FAILED: block_shape >= 32\n"
                         "Invalid fp blockwise linear arguments. Weight: [%d, %d]. Scales: [%d, %d].\n",
-                        (int)N, (int)K, (int)scales_size_0, (int)scales_size_1);
+                        (int)M, (int)K, (int)scales_size_0, (int)scales_size_1);
                 abort();
             }
         }
@@ -172,12 +154,11 @@ infiniStatus_t Descriptor::calculate(void *workspace,
         extParameters.a_zeropoints = nullptr;
         extParameters.a_scales = b_scales;
     }
-    printf("a=%s, b=%s, c=%s\n",
-           _info.transpose_mat_1 ? "true" : "false",
-           _info.transpose_mat_2 ? "true" : "false",
-           _info.transpose_result ? "true" : "false");
-    printf("M-K-N:[%ld, %ld, %ld], lda-ldb-ldc:[%ld, %ld, %ld]\n", M, K, N, lda, ldb, result_ld);
-    printf("quant type:%ld, bit:%ld\n", quant_type, bit);
+    bool transpose_mat_1 = _info.transa == 't';
+    bool transpose_mat_2 = _info.transb == 't';
+    cublasOperation_t transa = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t transb = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
+
     if (_info.dtype == INFINI_DTYPE_F16 || _info.dtype == INFINI_DTYPE_BF16) {
         CHECK_STATUS(_opaque->internal->useCublas(
@@ -186,16 +167,16 @@ infiniStatus_t Descriptor::calculate(void *workspace,
 
         dlblasGemmExV2(handle,
                        transa,
                        transb,
-                       N,
                        M,
+                       N,
                        K,
                        &alpha,
-                       b,
-                       kernel_Btype_,
-                       ldb,
                        a,
                        kernel_Atype_,
                        lda,
+                       b,
+                       kernel_Btype_,
+                       ldb,
                        &beta,
                        out,
                        kernel_Ctype_,
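A note on the w4 path: the new K = K * 2 line (which replaces the old 2 * ldb special case) is consistent with two 4-bit weights being packed into each byte, so a reduction dimension counted in bytes holds twice as many logical values. A minimal sketch of that packing, assuming low-nibble-first order, which this diff does not confirm:

#include <cstdint>
#include <cstdio>

// Pack two unsigned 4-bit weights into one byte, low nibble first.
// Nibble order is an assumption for illustration; the diff only shows
// that the logical K doubles when bit == 4.
uint8_t pack_u4(uint8_t lo, uint8_t hi) {
    return static_cast<uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4));
}

uint8_t unpack_u4(uint8_t packed, int idx) {
    return idx == 0 ? (packed & 0x0F) : (packed >> 4);
}

int main() {
    uint8_t byte = pack_u4(3, 12);
    std::printf("packed=0x%02X lo=%d hi=%d\n", byte,
                unpack_u4(byte, 0), unpack_u4(byte, 1)); // 0xC3, 3, 12
    // A weight buffer whose K was counted in bytes holds 2*K logical
    // 4-bit values along the reduction dimension, hence K = K * 2.
    int64_t K_bytes = 2048;
    std::printf("logical K for w4: %lld\n",
                static_cast<long long>(2 * K_bytes));
}

Under that assumption, doubling K keeps the leading dimension ldb in storage units while the logical reduction length reflects the unpacked element count.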
