Commit 1a7e74f

issue/1118: qyblas_gptq_w4a16_gemm
1 parent f2418ae commit 1a7e74f

File tree: 20 files changed, +464 -21 lines


include/infinicore/nn/linear.hpp

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,7 @@ class BaseLinear : public Module {
     Tensor bias() const { return bias_; }
     Tensor weight_scale() const { return weight_scale_; }
     Tensor weight_zeros() const { return weight_zeros_; }
+    Tensor gidx() const { return gidx_; }

     std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization() const { return quantization_; }

@@ -45,6 +46,8 @@ class BaseLinear : public Module {
     INFINICORE_NN_PARAMETER(weight_scale);
     INFINICORE_NN_PARAMETER(weight_zeros);

+    INFINICORE_NN_PARAMETER(gidx);
+
 protected:
     // Helper method for common forward computation
     Tensor compute_linear(Tensor &input) const;
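
The new gidx accessor and parameter carry GPTQ's per-input-channel group index. As a hedged aside (not part of this commit): in GPTQ checkpoints quantized without activation reordering, g_idx typically maps channel i to group i / group_size, as in this sketch:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hedged sketch, not from this commit: sequential GPTQ group indices,
    // where input channel i belongs to quantization group i / group_size.
    std::vector<int32_t> make_sequential_gidx(std::size_t in_features, std::size_t group_size) {
        std::vector<int32_t> g_idx(in_features);
        for (std::size_t i = 0; i < in_features; ++i) {
            g_idx[i] = static_cast<int32_t>(i / group_size);
        }
        return g_idx;
    }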
include/infinicore/ops/linear_w4a16_gptq_qy.hpp

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+#pragma once
+
+#include "common/op.hpp"
+#include <optional>
+
+namespace infinicore::op {
+
+Tensor linear_w4a16_gptq_qy(Tensor in, Tensor qweight, Tensor qzeros, Tensor scales, int64_t quant_type, int64_t bit);
+
+void linear_w4a16_gptq_qy_(Tensor out, Tensor in, Tensor qweights, Tensor scales, Tensor qzeros, int64_t quant_type, int64_t bit);
+
+} // namespace infinicore::op
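
For orientation, a hedged sketch of calling the out-of-place entry point; it mirrors the call site this commit adds in src/infinicore/nn/linear.cc, which passes quant_type = 0 and bit = 4 for GPTQ W4A16 (the Tensor type is assumed to come from common/op.hpp):

    #include "infinicore/ops/linear_w4a16_gptq_qy.hpp"

    // Hedged wrapper around the declaration above; 0 and 4 are the
    // quant_type/bit constants this commit uses at its call site.
    infinicore::Tensor gptq_linear(infinicore::Tensor input, infinicore::Tensor qweight,
                                   infinicore::Tensor qzeros, infinicore::Tensor scales) {
        return infinicore::op::linear_w4a16_gptq_qy(
            input->contiguous(), qweight, qzeros, scales,
            /*quant_type=*/0, /*bit=*/4);
    }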
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "../device.hpp"
+#include "../graph/graph.hpp"
+#include "common/op.hpp"
+#include <optional>
+
+namespace infinicore::op {
+
+INFINICORE_GRAPH_OP_CLASS(GptqQyblasGemm, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t);
+
+void scaled_mm_w4a16_gptq_qy_(Tensor out, const Tensor &in, const Tensor &qweight, const Tensor &scales, const Tensor &qzeros, int64_t quant_type, int64_t bit);
+} // namespace infinicore::op
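
A hedged sketch of using the in-place variant declared above, assuming the caller preallocates out with the GEMM result shape:

    // Hedged sketch: writes the W4A16 GEMM result into a caller-preallocated `out`;
    // 0/4 match the quant_type/bit constants used elsewhere in this commit.
    void gptq_scaled_mm(infinicore::Tensor out, const infinicore::Tensor &in,
                        const infinicore::Tensor &qweight, const infinicore::Tensor &scales,
                        const infinicore::Tensor &qzeros) {
        infinicore::op::scaled_mm_w4a16_gptq_qy_(out, in, qweight, scales, qzeros,
                                                 /*quant_type=*/0, /*bit=*/4);
    }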

include/infinicore/quantization.hpp

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@
 #include "quantization/awq.hpp"
 #include "quantization/base_quantization.hpp"
 #include "quantization/compressed_tensors.hpp"
+#include "quantization/gptq_qy.hpp"
 #include "quantization/none_quantizaiton.hpp"
 #include "quantization/quantization_scheme.hpp"

include/infinicore/quantization/base_quantization.hpp

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ namespace infinicore::quantization {
 class BaseQuantization {
     // Base class for quantization schemes. Intended to be extended to support various quantization methods.
 public:
-    explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config){};
+    explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config) {};
     virtual ~BaseQuantization() = default;

     virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0;
include/infinicore/quantization/gptq_qy.hpp

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+#pragma once
+#include "base_quantization.hpp"
+namespace infinicore::quantization {
+
+class GPTQ_QY : public BaseQuantization {
+    // This is a temporary class that currently only returns GPTQ W4A16.
+    // Future enhancements should parse quant_config to extract detailed quantization
+    // information and support multiple quantization schemes.
+public:
+    explicit GPTQ_QY(const nlohmann::json &quant_config)
+        : BaseQuantization(quant_config) {};
+
+    infinicore::quantization::QuantScheme
+    get_quant_scheme() const override {
+        return infinicore::quantization::QuantScheme::GPTQ_W4A16_QY;
+    };
+
+    int get_packing_num() const {
+        // For GPTQ, we pack 8 int4 weights into a single int32 value.
+        return 32 / this->get_or<int>("bits", 4); // Default to 8 if not specified in config
+    }
+
+    int get_group_size() const {
+        // For simplicity, we return a fixed group size here. In a more complete implementation,
+        // this could be extracted from quant_config_ to support different group sizes.
+        return this->get_or<int>("group_size", 128); // Standard GPTQ group size
+    }
+};
+
+} // namespace infinicore::quantization
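
get_packing_num() returns 32 / bits packed values per int32 container, i.e. 8 for 4-bit weights. A self-contained sketch illustrates that packing arithmetic (the nibble order actually used by the QYBLAS kernels is not specified by this commit):

    #include <array>
    #include <cstdint>

    // Hedged illustration: with bits = 4, 32 / 4 = 8 quantized values share one
    // uint32_t container. The nibble order here is illustrative, not the kernel's.
    uint32_t pack_int4x8(const std::array<uint8_t, 8> &vals) {
        uint32_t packed = 0;
        for (int i = 0; i < 8; ++i) {
            packed |= static_cast<uint32_t>(vals[i] & 0xF) << (4 * i);
        }
        return packed;
    }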

include/infinicore/quantization/quantization_scheme.hpp

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ enum class QuantScheme {
     NONE,
     COMPRESSED_TENSOR_W8A8I8,
     AWQ_W4A16,
+    GPTQ_W4A16_QY,
 };

 enum class KVQuantAlgo {

include/infiniop.h

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@
 #include "infiniop/ops/fmod.h"
 #include "infiniop/ops/gelu.h"
 #include "infiniop/ops/gemm.h"
+#include "infiniop/ops/gptq_qyblas_gemm.h"
 #include "infiniop/ops/hardswish.h"
 #include "infiniop/ops/hardtanh.h"
 #include "infiniop/ops/hinge_embedding_loss.h"

include/infiniop/ops/gemm.h

Lines changed: 12 additions & 12 deletions (whitespace-only realignment; the deleted and re-added lines are otherwise identical, and this capture did not preserve the indentation, so the declarations are shown once as context)

@@ -6,22 +6,22 @@
 typedef struct InfiniopDescriptor *infiniopGemmDescriptor_t;

 __INFINI_C __export infiniStatus_t infiniopCreateGemmDescriptor(infiniopHandle_t handle,
                                                                 infiniopGemmDescriptor_t *desc_ptr,
                                                                 infiniopTensorDescriptor_t c_desc,
                                                                 infiniopTensorDescriptor_t a_desc,
                                                                 infiniopTensorDescriptor_t b_desc);

 __INFINI_C __export infiniStatus_t infiniopGetGemmWorkspaceSize(infiniopGemmDescriptor_t desc, size_t *size);

 __INFINI_C __export infiniStatus_t infiniopGemm(infiniopGemmDescriptor_t desc,
                                                 void *workspace,
                                                 size_t workspace_size,
                                                 void *c,
                                                 void const *a,
                                                 void const *b,
                                                 float alpha,
                                                 float beta,
                                                 void *stream);

 __INFINI_C __export infiniStatus_t infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc);

src/infinicore/nn/linear.cc

Lines changed: 80 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include "infinicore/ops/distributed/allreduce.hpp"
 #include "infinicore/ops/linear.hpp"
 #include "infinicore/ops/linear_w4a16_awq.hpp"
+#include "infinicore/ops/linear_w4a16_gptq_qy.hpp"
 #include "infinicore/ops/linear_w8a8i8.hpp"
 #include <optional>
 #include <spdlog/spdlog.h>

@@ -53,6 +54,19 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
         auto output = infinicore::op::linear_w4a16_awq(input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt);
         return output;
     }
+    case infinicore::quantization::QuantScheme::GPTQ_W4A16_QY: {
+        Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous();
+        Tensor qweight = static_cast<const Tensor &>(weight_);
+        Tensor qzeros = static_cast<const Tensor &>(weight_zeros_);
+        Tensor scales = static_cast<const Tensor &>(weight_scale_);
+        Tensor g_idx = static_cast<const Tensor &>(gidx_);
+        std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
+        auto output = infinicore::op::linear_w4a16_gptq_qy(input_contiguous->contiguous(), qweight, qzeros, scales, 0, 4);
+        if (bias_opt.has_value()) {
+            infinicore::op::add_(output, output, bias_opt.value()->as_strided(output->shape(), {0, 0, 1}));
+        }
+        return output;
+    }
     default: {
         // Ensure input is contiguous before creating views (required for matmul)
         // This prevents hanging when input tensor has non-contiguous memory layout

@@ -140,6 +154,23 @@ Linear::Linear(size_t in_features, size_t out_features,
         }
         break;
     }
+    case infinicore::quantization::QuantScheme::GPTQ_W4A16_QY: {
+        weight_ = infinicore::nn::Parameter({in_features / 2, out_features}, infinicore::DataType::U8, device);
+        this->register_parameter("qweight", weight_);
+        weight_zeros_ = infinicore::nn::Parameter({in_features / 128, out_features}, dtype_, device);
+        this->register_parameter("qzeros", weight_zeros_);
+        weight_scale_ = infinicore::nn::Parameter({in_features / 128, out_features}, dtype_, device);
+        this->register_parameter("scales", weight_scale_);
+
+        gidx_ = infinicore::nn::Parameter({in_features}, infinicore::DataType::I32, device);
+        this->register_parameter("g_idx", gidx_);
+        if (bias) {
+            INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
+        } else {
+            bias_ = Parameter();
+        }
+        break;
+    }
     default: {
         // Initialize parameters using macro
         INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));

@@ -247,6 +278,27 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
         }
         break;
     }
+    case infinicore::quantization::QuantScheme::GPTQ_W4A16_QY: {
+        auto gptq_ptr = std::static_pointer_cast<infinicore::quantization::GPTQ_QY>(this->quantization_);
+        int group_size = gptq_ptr->get_group_size();
+        int packing_num = gptq_ptr->get_packing_num();
+        weight_ = infinicore::nn::Parameter({in_features / 2, out_features}, infinicore::DataType::U8, device, 1, tp_rank_, tp_size_);
+        this->register_parameter("qweight", weight_);
+        weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features}, dtype_, device, 1, tp_rank_, tp_size_);
+        this->register_parameter("qzeros", weight_zeros_);
+        weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features}, dtype_, device, 1, tp_rank_, tp_size_);
+        this->register_parameter("scales", weight_scale_);
+        gidx_ = infinicore::nn::Parameter({in_features},
+                                          infinicore::DataType::I32,
+                                          device, 0, tp_rank_, tp_size_);
+        this->register_parameter("g_idx", gidx_);
+        if (bias) {
+            INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, tp_rank_, tp_size_));
+        } else {
+            bias_ = Parameter();
+        }
+        break;
+    }
     default: {
         // Initialize parameters using macro
         INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,

@@ -356,6 +408,34 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, st
         }
         break;
     }
+    case infinicore::quantization::QuantScheme::GPTQ_W4A16_QY: {
+        // GPTQ W4A16 QY for RowParallelLinear: the partition dimension is in_features (dimension 1 of the weight matrix)
+        // - Weight: packed int4 in U8 containers (2 int4 values per U8)
+        // - Group-wise quantization with group_size=128
+        // - Scale and zero points stored per group along in_features dimension
+
+        auto gptq_ptr = std::static_pointer_cast<infinicore::quantization::GPTQ_QY>(this->quantization_);
+        int group_size = gptq_ptr->get_group_size();
+        int packing_num = gptq_ptr->get_packing_num();
+
+        weight_ = infinicore::nn::Parameter({in_features / 2, out_features}, infinicore::DataType::U8, device, 0, tp_rank_, tp_size_);
+        this->register_parameter("qweight", weight_);
+        weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features}, dtype_, device, 0, tp_rank_, tp_size_);
+        this->register_parameter("qzeros", weight_zeros_);
+        weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features}, dtype_, device, 0, tp_rank_, tp_size_);
+        this->register_parameter("scales", weight_scale_);
+
+        gidx_ = infinicore::nn::Parameter({in_features},
+                                          infinicore::DataType::I32,
+                                          device, 0, tp_rank_, tp_size_);
+        this->register_parameter("g_idx", gidx_);
+        if (bias && (0 == tp_rank_)) {
+            INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
+        } else {
+            bias_ = Parameter();
+        }
+        break;
+    }
     default: {
         // Initialize parameters using macro
         INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
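
To make the parameter shapes in these GPTQ_W4A16_QY branches concrete, a compile-time sanity check with illustrative dimensions (4096 and 11008 are hypothetical; 128 is the group size defaulted by GPTQ_QY):

    #include <cstddef>

    // Illustrative shape check for the GPTQ_W4A16_QY parameters; dimensions are made up.
    constexpr std::size_t in_features = 4096, out_features = 11008, group_size = 128;
    static_assert(in_features / 2 == 2048, "qweight rows: two int4 values per U8 byte");
    static_assert(in_features / group_size == 32, "qzeros/scales rows: one entry per group");
    // Resulting shapes: qweight {2048, 11008} U8, qzeros/scales {32, 11008}, g_idx {4096} I32.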
