Skip to content

Commit 9d9c97c

Browse files
committed
issue/1118: qyblas error
1 parent 6e88052 commit 9d9c97c

File tree

8 files changed

+907
-0
lines changed

8 files changed

+907
-0
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifndef __INFINIOP_GPTQ_QYBLAS_GEMM_API_H__
2+
#define __INFINIOP_GPTQ_QYBLAS_GEMM_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
#include <cstdint>
6+
7+
// Handle type used by the GPTQ QyBLAS GEMM C API in this header:
// a pointer to `struct InfiniopDescriptor` (the struct itself is declared elsewhere).
typedef struct InfiniopDescriptor *infiniopGptqQyblasGemmDescriptor_t;
8+
9+
/**
 * Exported entry point that creates a GPTQ QyBLAS GEMM descriptor.
 *
 * @param handle         Library handle.
 * @param desc_ptr       Location that receives the descriptor handle
 *                       (presumably written only on success; confirm against the implementation).
 * @param out_desc       Tensor descriptor of the output.
 * @param a_desc         Tensor descriptor of operand `a`.
 * @param b_desc         Tensor descriptor of operand `b`
 *                       (presumably the quantized weight; confirm against the implementation).
 * @param b_scales_desc  Tensor descriptor of the scales associated with `b`.
 * @param b_zeros_desc   Tensor descriptor of the zero points associated with `b`.
 * @return An `infiniStatus_t` status code.
 */
__INFINI_C __export infiniStatus_t infiniopCreateGptqQyblasGemmDescriptor(
10+
infiniopHandle_t handle,
11+
infiniopGptqQyblasGemmDescriptor_t *desc_ptr,
12+
infiniopTensorDescriptor_t out_desc,
13+
infiniopTensorDescriptor_t a_desc,
14+
infiniopTensorDescriptor_t b_desc,
15+
infiniopTensorDescriptor_t b_scales_desc,
16+
infiniopTensorDescriptor_t b_zeros_desc);
17+
18+
/**
 * Exported entry point that reports the workspace size for a descriptor.
 *
 * @param desc  Descriptor handle to query.
 * @param size  Location that receives the workspace size
 *              (presumably in bytes; confirm against the implementation).
 * @return An `infiniStatus_t` status code.
 */
__INFINI_C __export infiniStatus_t infiniopGetGptqQyblasGemmWorkspaceSize(
19+
infiniopGptqQyblasGemmDescriptor_t desc,
20+
size_t *size);
21+
22+
/**
 * Exported entry point that runs the GPTQ QyBLAS GEMM described by `desc`.
 *
 * @param desc            Descriptor handle.
 * @param workspace       Caller-provided scratch buffer.
 * @param workspace_size  Size of `workspace`.
 * @param out             Output buffer.
 * @param a               Read-only buffer for operand `a`.
 * @param b               Read-only buffer for operand `b`.
 * @param b_scale         Buffer holding the scales for `b` (note: declared non-const).
 * @param b_zero          Buffer holding the zero points for `b` (note: declared non-const).
 * @param quant_type      Quantization type selector; valid values are not defined in this header.
 * @param bit             Bit-width selector; valid values are not defined in this header.
 * @param stream          Opaque stream pointer (presumably a device stream; confirm against the backend).
 * @return An `infiniStatus_t` status code.
 */
__INFINI_C __export infiniStatus_t infiniopGptqQyblasGemm(
23+
infiniopGptqQyblasGemmDescriptor_t desc,
24+
void *workspace,
25+
size_t workspace_size,
26+
void *out,
27+
const void *a,
28+
const void *b,
29+
void *b_scale,
30+
void *b_zero,
31+
int64_t quant_type,
32+
int64_t bit,
33+
void *stream);
34+
35+
/**
 * Exported entry point that destroys a GPTQ QyBLAS GEMM descriptor.
 *
 * @param desc  Descriptor handle to destroy.
 * @return An `infiniStatus_t` status code.
 */
__INFINI_C __export infiniStatus_t infiniopDestroyGptqQyblasGemmDescriptor(
36+
infiniopGptqQyblasGemmDescriptor_t desc);
37+
#endif
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#ifndef GPTQ_QYBLAS_GEMM_H
2+
#define GPTQ_QYBLAS_GEMM_H
3+
4+
#include "../../operator.h"
5+
#include "info.h"
6+
7+
// DESCRIPTOR(NAMESPACE) expands to namespace op::gptq_qyblas_gemm::NAMESPACE
// containing `class Descriptor final : public InfiniopDescriptor`.
//
// The generated class has:
//   - private state: a pointer to a forward-declared nested `Opaque` struct,
//     a `GptqQyblasGemmInfo` copy, and the workspace size;
//   - a private constructor (it precedes `public:`) that forwards device type/id
//     to the InfiniopDescriptor base and stores the remaining arguments;
//   - a public destructor that is declared only (its definition is not part of
//     this macro);
//   - `workspaceSize()`, which returns the stored workspace size;
//   - a static `create(...)` and a const `calculate(...)`, both declared only.
//     Their parameter lists mirror the C API's create/run entry points without
//     the descriptor argument.
#define DESCRIPTOR(NAMESPACE) \
8+
\
9+
namespace op::gptq_qyblas_gemm::NAMESPACE { \
10+
class Descriptor final : public InfiniopDescriptor { \
11+
struct Opaque; \
12+
Opaque *_opaque; \
13+
GptqQyblasGemmInfo _info; \
14+
size_t _workspace_size; \
15+
\
16+
Descriptor( \
17+
Opaque *opaque, \
18+
GptqQyblasGemmInfo info, \
19+
size_t workspace_size, \
20+
infiniDevice_t device_type, \
21+
int device_id) \
22+
: InfiniopDescriptor{device_type, device_id}, \
23+
_opaque(opaque), \
24+
_info(info), \
25+
_workspace_size(workspace_size) {} \
26+
\
27+
public: \
28+
~Descriptor(); \
29+
\
30+
size_t workspaceSize() const { return _workspace_size; } \
31+
\
32+
static infiniStatus_t create( \
33+
infiniopHandle_t handle, \
34+
Descriptor **desc_ptr, \
35+
infiniopTensorDescriptor_t out_desc, \
36+
infiniopTensorDescriptor_t a_desc, \
37+
infiniopTensorDescriptor_t b_desc, \
38+
infiniopTensorDescriptor_t b_scales_desc, \
39+
infiniopTensorDescriptor_t b_zeros_desc); \
40+
\
41+
infiniStatus_t calculate( \
42+
void *workspace, size_t workspace_size, \
43+
void *out, \
44+
const void *a, const void *b, void *b_scale, void *b_zero, int64_t quant_type, int64_t bit, \
45+
void *stream) const; \
46+
}; \
47+
}
48+
49+
#endif // GPTQ_QYBLAS_GEMM_H
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#ifndef __GPTQ_QYBLAS_GEMM_INFO_H__
2+
#define __GPTQ_QYBLAS_GEMM_INFO_H__
3+
4+
#include "../../../utils.h"
5+
#include "../../tensor.h"
6+
#include <optional>
7+
#include <vector>
8+
9+
namespace op::gptq_qyblas_gemm {
10+
11+
class GptqQyblasGemmInfo {
12+
GptqQyblasGemmInfo() = default;
13+
14+
public:
15+
infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype;
16+
size_t M, K, N, scales_size_0, scales_size_1;
17+
ptrdiff_t lda, ldb, result_ld;
18+
bool transpose_mat_1, transpose_mat_2, transpose_result;
19+
20+
static utils::Result<GptqQyblasGemmInfo> createGptqQyblasGemmInfo(
21+
infiniopTensorDescriptor_t out_desc,
22+
infiniopTensorDescriptor_t a_desc,
23+
infiniopTensorDescriptor_t b_desc,
24+
infiniopTensorDescriptor_t b_scales_desc,
25+
infiniopTensorDescriptor_t b_zeros_desc) {
26+
27+
auto dtype = a_desc->dtype();
28+
29+
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
30+
31+
const infiniDtype_t weight_dtype = b_desc->dtype();
32+
CHECK_DTYPE(weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8);
33+
34+
const infiniDtype_t scales_dtype = b_scales_desc->dtype();
35+
const infiniDtype_t zeros_dtype = b_zeros_desc->dtype();
36+
const infiniDtype_t out_dtype = out_desc->dtype();
37+
38+
size_t M = out_desc->shape()[0];
39+
size_t N = out_desc->shape()[1];
40+
size_t K = a_desc->shape()[1];
41+
42+
size_t scales_size_0 = b_scales_desc->shape()[0];
43+
size_t scales_size_1 = b_scales_desc->shape()[1];
44+
45+
auto ndim = out_desc->ndim();
46+
CHECK_OR_RETURN(ndim == 2
47+
&& a_desc->ndim() == ndim
48+
&& b_desc->ndim() == ndim
49+
&& b_scales_desc->ndim() == ndim
50+
&& b_zeros_desc->ndim() == ndim,
51+
INFINI_STATUS_BAD_TENSOR_SHAPE);
52+
53+
bool transpose_result = false;
54+
if (out_desc->strides()[0] == 1 && out_desc->strides()[1] >= std::max<int64_t>(1, out_desc->shape()[0])) {
55+
transpose_result = true;
56+
} else if (out_desc->strides()[1] == 1 && out_desc->strides()[0] >= std::max<int64_t>(1, out_desc->shape()[1])) {
57+
transpose_result = false;
58+
} else {
59+
transpose_result = false;
60+
}
61+
bool transpose_mat_1 = false;
62+
if (a_desc->strides()[0] == 1 && a_desc->strides()[1] >= std::max<int64_t>(1, a_desc->shape()[0])) {
63+
transpose_mat_1 = true;
64+
} else if (a_desc->strides()[1] == 1 && a_desc->strides()[0] >= std::max<int64_t>(1, a_desc->shape()[1])) {
65+
transpose_mat_1 = false;
66+
} else {
67+
transpose_mat_1 = false;
68+
}
69+
bool transpose_mat_2 = false;
70+
if (b_desc->strides()[0] == 1 && b_desc->strides()[1] >= std::max<int64_t>(1, b_desc->shape()[0])) {
71+
transpose_mat_2 = true;
72+
} else if (b_desc->strides()[1] == 1 && b_desc->strides()[0] >= std::max<int64_t>(1, b_desc->shape()[1])) {
73+
transpose_mat_2 = false;
74+
} else {
75+
transpose_mat_2 = false;
76+
}
77+
78+
ptrdiff_t lda = a_desc->strides()[transpose_mat_1 ? 1 : 0];
79+
ptrdiff_t ldb = b_desc->strides()[transpose_mat_2 ? 1 : 0];
80+
ptrdiff_t result_ld = out_desc->strides()[transpose_result ? 1 : 0];
81+
82+
return utils::Result<GptqQyblasGemmInfo>(GptqQyblasGemmInfo{
83+
dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype,
84+
M, K, N, scales_size_0, scales_size_1,
85+
lda, ldb, result_ld,
86+
transpose_mat_1, transpose_mat_2, transpose_result});
87+
}
88+
};
89+
90+
} // namespace op::gptq_qyblas_gemm
91+
92+
#endif // __GPTQ_QYBLAS_GEMM_INFO_H__

0 commit comments

Comments
 (0)