Skip to content

Commit 3f0a98c

Browse files
xgqdut2016 and qinyiqun
authored
issue/1118: QY机器添加gptq_qyblas_gemm算子 (#1123)
* issue/1118: qyblas error
* issue/1118: success qy int8 test
* issue/1118: bit=4 error
* issue/1118: debug w4
* issue/1118: success w4 kernel
* issue/1118: success quant_type and bit test
* issue/1118: qyblas_gptq_w4a16_gemm

---------

Co-authored-by: qinyiqun <qinyiqun@outlook.com>
1 parent fa2a580 commit 3f0a98c

28 files changed

Lines changed: 1944 additions & 21 deletions

include/infinicore/nn/linear.hpp

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -34,6 +34,7 @@ class BaseLinear : public Module {
3434
Tensor bias() const { return bias_; }
3535
Tensor weight_scale() const { return weight_scale_; }
3636
Tensor weight_zeros() const { return weight_zeros_; }
37+
Tensor gidx() const { return gidx_; }
3738

3839
std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization() const { return quantization_; }
3940

@@ -45,6 +46,8 @@ class BaseLinear : public Module {
4546
INFINICORE_NN_PARAMETER(weight_scale);
4647
INFINICORE_NN_PARAMETER(weight_zeros);
4748

49+
INFINICORE_NN_PARAMETER(gidx);
50+
4851
protected:
4952
// Helper method for common forward computation
5053
Tensor compute_linear(Tensor &input) const;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
#pragma once

#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

/// GPTQ W4A16 linear layer (QY backend): computes in @ dequant(qweight) and
/// returns a freshly allocated output tensor.
///
/// @param qweight    packed quantized weight matrix
/// @param qzeros     packed per-group zero points
/// @param scales     per-group dequantization scales
/// @param quant_type backend-specific quantization variant selector
/// @param bit        weight bit-width (commit tests exercise 4 and 8)
/// NOTE(review): this overload takes (qzeros, scales) while the in-place
/// overload below takes (scales, qzeros) — confirm the asymmetry is intentional.
Tensor linear_w4a16_gptq_qy(Tensor in, Tensor qweight, Tensor qzeros, Tensor scales, int64_t quant_type, int64_t bit);

/// In-place variant: writes the result of in @ dequant(qweight) into `out`.
/// Parameter meanings match the out-of-place overload above (note the
/// scales/qzeros ordering difference).
void linear_w4a16_gptq_qy_(Tensor out, Tensor in, Tensor qweight, Tensor scales, Tensor qzeros, int64_t quant_type, int64_t bit);

} // namespace infinicore::op
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#pragma once
2+
3+
// Graph-op declaration for the QY GPTQ W4A16 quantized GEMM kernel.
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
5+
#include "common/op.hpp"
6+
// NOTE(review): <optional> appears unused by the declarations in this header —
// confirm before removing.
#include <optional>
7+
8+
namespace infinicore::op {
9+
10+
// Declares the GptqQyblasGemm graph-op class: returns a Tensor, takes four
// tensor inputs plus two int64_t scalars (quant_type, bit).
// NOTE(review): tensor-argument meaning inferred from the eager entry point
// below (in, qweight, scales, qzeros) — confirm against the op registration.
INFINICORE_GRAPH_OP_CLASS(GptqQyblasGemm, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t);
11+
12+
// Eager, in-place entry point: computes the GPTQ-quantized GEMM of `in` with
// the packed `qweight` (dequantized via per-group `scales`/`qzeros`) and
// writes the result into `out`. `quant_type` and `bit` select the kernel
// variant — presumably bit = 4 for the W4A16 path; verify against the kernel.
void scaled_mm_w4a16_gptq_qy_(Tensor out, const Tensor &in, const Tensor &qweight, const Tensor &scales, const Tensor &qzeros, int64_t quant_type, int64_t bit);
13+
} // namespace infinicore::op

include/infinicore/quantization.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
#include "quantization/awq.hpp"
44
#include "quantization/base_quantization.hpp"
55
#include "quantization/compressed_tensors.hpp"
6+
#include "quantization/gptq_qy.hpp"
67
#include "quantization/none_quantizaiton.hpp"
78
#include "quantization/quantization_scheme.hpp"

include/infinicore/quantization/base_quantization.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace infinicore::quantization {
66
class BaseQuantization {
77
// Base class for quantization schemes. Intended to be extended to support various quantization methods.
88
public:
9-
explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config){};
9+
explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config) {};
1010
virtual ~BaseQuantization() = default;
1111

1212
virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0;
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
#pragma once
#include "base_quantization.hpp"
namespace infinicore::quantization {

/// GPTQ quantization descriptor for the QY backend.
/// This is a temporary class that currently only returns GPTQ W4A16.
/// Future enhancements should parse quant_config to extract detailed
/// quantization information and support multiple quantization schemes.
class GPTQ_QY : public BaseQuantization {
public:
    /// @param quant_config parsed quantization-config JSON, forwarded to the base.
    explicit GPTQ_QY(const nlohmann::json &quant_config)
        : BaseQuantization(quant_config) {}

    /// Always reports the QY GPTQ W4A16 scheme (see class note above).
    infinicore::quantization::QuantScheme
    get_quant_scheme() const override {
        return infinicore::quantization::QuantScheme::GPTQ_W4A16_QY;
    }

    /// Number of quantized weights packed into a single int32 storage word:
    /// 32 / bits. With the default of 4 bits this is 8.
    int get_packing_num() const {
        return 32 / this->get_or<int>("bits", 4); // "bits" defaults to 4 when absent
    }

    /// Quantization group size, read from the config; defaults to 128,
    /// the standard GPTQ group size.
    int get_group_size() const {
        return this->get_or<int>("group_size", 128);
    }
};

} // namespace infinicore::quantization

include/infinicore/quantization/quantization_scheme.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ enum class QuantScheme {
77
NONE,
88
COMPRESSED_TENSOR_W8A8I8,
99
AWQ_W4A16,
10+
GPTQ_W4A16_QY,
1011
};
1112

1213
enum class KVQuantAlgo {

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "infiniop/ops/fmod.h"
4949
#include "infiniop/ops/gelu.h"
5050
#include "infiniop/ops/gemm.h"
51+
#include "infiniop/ops/gptq_qyblas_gemm.h"
5152
#include "infiniop/ops/hardswish.h"
5253
#include "infiniop/ops/hardtanh.h"
5354
#include "infiniop/ops/hinge_embedding_loss.h"

include/infiniop/ops/gemm.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,22 @@
66
typedef struct InfiniopDescriptor *infiniopGemmDescriptor_t;
77

88
__INFINI_C __export infiniStatus_t infiniopCreateGemmDescriptor(infiniopHandle_t handle,
9-
infiniopGemmDescriptor_t *desc_ptr,
10-
infiniopTensorDescriptor_t c_desc,
11-
infiniopTensorDescriptor_t a_desc,
12-
infiniopTensorDescriptor_t b_desc);
9+
infiniopGemmDescriptor_t *desc_ptr,
10+
infiniopTensorDescriptor_t c_desc,
11+
infiniopTensorDescriptor_t a_desc,
12+
infiniopTensorDescriptor_t b_desc);
1313

1414
__INFINI_C __export infiniStatus_t infiniopGetGemmWorkspaceSize(infiniopGemmDescriptor_t desc, size_t *size);
1515

1616
__INFINI_C __export infiniStatus_t infiniopGemm(infiniopGemmDescriptor_t desc,
17-
void *workspace,
18-
size_t workspace_size,
19-
void *c,
20-
void const *a,
21-
void const *b,
22-
float alpha,
23-
float beta,
24-
void *stream);
17+
void *workspace,
18+
size_t workspace_size,
19+
void *c,
20+
void const *a,
21+
void const *b,
22+
float alpha,
23+
float beta,
24+
void *stream);
2525

2626
__INFINI_C __export infiniStatus_t infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc);
2727

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifndef __INFINIOP_GPTQ_QYBLAS_GEMM_API_H__
2+
#define __INFINIOP_GPTQ_QYBLAS_GEMM_API_H__
3+
4+
// C API for the GPTQ QYBLAS GEMM operator: matrix multiply of activation `a`
// with a GPTQ-quantized weight `b` plus b's per-group scales and zero points.
#include "../operator_descriptor.h"
5+
// NOTE(review): <cstdint> is a C++ header; if this API is also consumed from C
// translation units (__INFINI_C suggests it may be), <stdint.h> is the
// portable spelling — confirm against the other infiniop headers.
#include <cstdint>
6+
7+
// Opaque descriptor handle, sharing the common InfiniopDescriptor layout.
typedef struct InfiniopDescriptor *infiniopGptqQyblasGemmDescriptor_t;
8+
9+
// Creates a descriptor from the tensor descriptors of the output, the
// activation `a`, the packed quantized weight `b`, and b's scales/zeros.
__INFINI_C __export infiniStatus_t infiniopCreateGptqQyblasGemmDescriptor(
10+
infiniopHandle_t handle,
11+
infiniopGptqQyblasGemmDescriptor_t *desc_ptr,
12+
infiniopTensorDescriptor_t out_desc,
13+
infiniopTensorDescriptor_t a_desc,
14+
infiniopTensorDescriptor_t b_desc,
15+
infiniopTensorDescriptor_t b_scales_desc,
16+
infiniopTensorDescriptor_t b_zeros_desc);
17+
18+
// Queries the scratch-workspace size required by infiniopGptqQyblasGemm.
__INFINI_C __export infiniStatus_t infiniopGetGptqQyblasGemmWorkspaceSize(
19+
infiniopGptqQyblasGemmDescriptor_t desc,
20+
size_t *size);
21+
22+
// Executes the quantized GEMM asynchronously on `stream`.
// `quant_type` and `bit` select the kernel variant (commit tests cover bit=4).
// NOTE(review): b_scale/b_zero are non-const void* while a/b are const —
// presumably they are read-only inputs too; consider const-qualifying them.
__INFINI_C __export infiniStatus_t infiniopGptqQyblasGemm(
23+
infiniopGptqQyblasGemmDescriptor_t desc,
24+
void *workspace,
25+
size_t workspace_size,
26+
void *out,
27+
const void *a,
28+
const void *b,
29+
void *b_scale,
30+
void *b_zero,
31+
int64_t quant_type,
32+
int64_t bit,
33+
void *stream);
34+
35+
// Releases the descriptor created by infiniopCreateGptqQyblasGemmDescriptor.
__INFINI_C __export infiniStatus_t infiniopDestroyGptqQyblasGemmDescriptor(
36+
infiniopGptqQyblasGemmDescriptor_t desc);
37+
#endif

0 commit comments

Comments
 (0)