
Commit e60985d

Merge pull request #1040 from InfiniTensor/Issue/1030
Issue/1030: NVIDIA support for W4A16 inference
2 parents 5877121 + 63233f9 commit e60985d

4 files changed

Lines changed: 144 additions & 10 deletions


include/infinicore/quantization/awq.hpp

Lines changed: 12 additions & 1 deletion
@@ -8,12 +8,23 @@ class AWQ : public BaseQuantization {
     // information and support multiple quantization schemes.
 public:
     explicit AWQ(const nlohmann::json &quant_config)
-        : BaseQuantization(quant_config) {};
+        : BaseQuantization(quant_config){};
 
     infinicore::quantization::QuantScheme
     get_quant_scheme() const override {
         return infinicore::quantization::QuantScheme::AWQ_W4A16;
     };
+
+    int get_packing_num() const {
+        // For AWQ, we pack 8 int4 weights into a single int32 value.
+        return 32 / this->get_or<int>("bits", 4); // 32 / 4 = 8 when "bits" is absent from the config
+    }
+
+    int get_group_size() const {
+        // Read the group size from quant_config_, defaulting to the
+        // standard AWQ group size of 128 when the key is absent.
+        return this->get_or<int>("group_size", 128);
+    }
 };
 
 } // namespace infinicore::quantization
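
The arithmetic behind get_packing_num() is easy to verify in isolation. Below is a minimal standalone sketch (not part of this commit; it packs nibbles in sequential order, whereas AutoAWQ-style kernels interleave them) showing eight 4-bit values round-tripping through one 32-bit word:

#include <cstdint>
#include <cstdio>

int main() {
    const int bits = 4;
    const int packing_num = 32 / bits; // 8, matching AWQ::get_packing_num()

    // Pack eight 4-bit values into one uint32_t, lowest nibble first.
    const uint8_t vals[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    uint32_t packed = 0;
    for (int i = 0; i < packing_num; ++i) {
        packed |= static_cast<uint32_t>(vals[i] & 0xF) << (i * bits);
    }

    // Unpack by shifting and masking; prints "1 2 3 4 5 6 7 8".
    for (int i = 0; i < packing_num; ++i) {
        printf("%u ", static_cast<unsigned>((packed >> (i * bits)) & 0xF));
    }
    printf("\n");
    return 0;
}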

include/infinicore/quantization/base_quantization.hpp

Lines changed: 25 additions & 1 deletion
@@ -6,10 +6,34 @@ namespace infinicore::quantization {
 class BaseQuantization {
     // Base class for quantization schemes. Intended to be extended to support various quantization methods.
 public:
-    explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config) {};
+    explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config){};
     virtual ~BaseQuantization() = default;
 
     virtual infinicore::quantization::QuantScheme get_quant_scheme() const = 0;
+    template <typename T>
+    T get(const std::string &key) const {
+        if (!quant_config_.contains(key)) {
+            throw std::out_of_range("Key '" + key + "' not found in config.");
+        }
+        try {
+            return quant_config_.at(key).get<T>();
+        } catch (const nlohmann::json::type_error &e) {
+            throw std::runtime_error("Type conversion failed for key '" + key + "': " + std::string(e.what()));
+        }
+    }
+
+    template <typename T>
+    T get_or(const std::string &key, const T &default_value) const {
+        if (!quant_config_.contains(key) || quant_config_.at(key).is_null()) {
+            return default_value;
+        }
+        try {
+            return quant_config_.at(key).get<T>();
+        } catch (const nlohmann::json::type_error &) {
+            // If type conversion fails, return the default value
+            return default_value;
+        }
+    }
 
 protected:
     nlohmann::json quant_config_;
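
For readers unfamiliar with nlohmann::json, the behavior these accessors wrap can be sketched directly against the library (a standalone example with hypothetical config values):

#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    // A hypothetical AWQ-style quantization config.
    nlohmann::json quant_config = {{"bits", 4}, {"group_size", 128}};

    // What get<T>() does: .at() throws if the key is missing, and
    // .get<T>() throws nlohmann::json::type_error on a bad conversion.
    int bits = quant_config.at("bits").get<int>();

    // What get_or<T>() approximates: fall back to a default when absent.
    int group_size = quant_config.contains("group_size")
                         ? quant_config.at("group_size").get<int>()
                         : 128;

    std::cout << "bits=" << bits << " group_size=" << group_size
              << " packing_num=" << (32 / bits) << std::endl;
    return 0;
}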

src/infinicore/nn/linear.cc

Lines changed: 95 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include "infinicore/ops.hpp"
 #include "infinicore/ops/distributed/allreduce.hpp"
 #include "infinicore/ops/linear.hpp"
+#include "infinicore/ops/linear_w4a16_awq.hpp"
 #include "infinicore/ops/linear_w8a8i8.hpp"
 #include <optional>
 #include <spdlog/spdlog.h>
@@ -43,6 +44,15 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
         auto output = infinicore::op::linear_w8a8i8(input_contiguous->contiguous(), weight_packed_tensor, weight_scale_tensor, bias_opt);
         return output;
     }
+    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
+        Tensor input_contiguous = input->is_contiguous() ? input : input->contiguous();
+        Tensor qweight = static_cast<const Tensor &>(weight_);
+        Tensor qzeros = static_cast<const Tensor &>(weight_zeros_);
+        Tensor scales = static_cast<const Tensor &>(weight_scale_);
+        std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
+        auto output = infinicore::op::linear_w4a16_awq(input_contiguous->contiguous(), qweight, scales, qzeros, bias_opt);
+        return output;
+    }
     default: {
         // Ensure input is contiguous before creating views (required for matmul)
         // This prevents hanging when input tensor has non-contiguous memory layout
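
Conceptually, linear_w4a16_awq dequantizes the packed weights group by group and then runs a plain GEMM. A scalar reference of that math (a sketch under assumed layouts: qweight [K, N/8] and qzeros [K/G, N/8] with sequential nibble order, scales [K/G, N]; real AWQ kernels interleave the nibbles):

#include <cstdint>
#include <vector>

// Scalar reference for a W4A16 AWQ linear: y = x * dequant(qweight) + bias.
void linear_w4a16_ref(const std::vector<std::vector<float>> &x,          // [M][K]
                      const std::vector<std::vector<uint32_t>> &qweight, // [K][N/8]
                      const std::vector<std::vector<float>> &scales,     // [K/G][N]
                      const std::vector<std::vector<uint32_t>> &qzeros,  // [K/G][N/8]
                      const std::vector<float> &bias,                    // [N] (may be empty)
                      std::vector<std::vector<float>> &y,                // [M][N]
                      int M, int K, int N, int G) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = bias.empty() ? 0.0f : bias[n];
            for (int k = 0; k < K; ++k) {
                int g = k / G; // quantization group index along K
                uint32_t q = (qweight[k][n / 8] >> ((n % 8) * 4)) & 0xF;
                uint32_t z = (qzeros[g][n / 8] >> ((n % 8) * 4)) & 0xF;
                // Per-group affine dequantization: w = (q - z) * scale
                float w = (float(q) - float(z)) * scales[g][n];
                acc += x[m][k] * w;
            }
            y[m][n] = acc;
        }
    }
}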
@@ -116,6 +126,20 @@ Linear::Linear(size_t in_features, size_t out_features,
         }
         break;
     }
+    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
+        weight_ = infinicore::nn::Parameter({out_features, in_features}, infinicore::DataType::I32, device);
+        this->register_parameter("qweight", weight_);
+        weight_zeros_ = infinicore::nn::Parameter({out_features, in_features}, infinicore::DataType::I32, device);
+        this->register_parameter("qzeros", weight_zeros_);
+        weight_scale_ = infinicore::nn::Parameter({out_features, in_features}, dtype_, device);
+        this->register_parameter("scales", weight_scale_);
+        if (bias) {
+            INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
+        } else {
+            bias_ = Parameter();
+        }
+        break;
+    }
     default: {
         // Initialize parameters using macro
         INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
@@ -190,6 +214,39 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
         }
         break;
     }
+    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
+        auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
+        int group_size = awq_ptr->get_group_size();
+        int packing_num = awq_ptr->get_packing_num();
+
+        weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num},
+                                            infinicore::DataType::I32,
+                                            device, 1, tp_rank_, tp_size_);
+        this->register_parameter("qweight", weight_);
+
+        // Weight scale: [in_features / group_size, out_features]
+        // One scale per group of group_size weights along in_features
+        weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features},
+                                                  dtype_,
+                                                  device, 1, tp_rank_, tp_size_);
+        this->register_parameter("scales", weight_scale_);
+
+        // Weight zeros (zero points): [in_features / group_size, out_features / packing_num]
+        // AWQ implementations (e.g., AutoAWQ) store zero points packed as I32,
+        // mirroring the qweight layout along the output dimension
+        weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num},
+                                                  infinicore::DataType::I32,
+                                                  device, 1, tp_rank_, tp_size_);
+        this->register_parameter("qzeros", weight_zeros_);
+        if (bias) {
+            INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
+        } else {
+            bias_ = Parameter();
+        }
+        break;
+    }
     default: {
         // Initialize parameters using macro
         INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
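
To make the column-parallel sharding concrete, here is the per-rank shape arithmetic for a hypothetical 4096x4096 layer at tp_size = 2 (illustrative numbers, not from the commit):

#include <cstdio>

int main() {
    // Hypothetical layer: in_features = out_features = 4096, tp_size = 2.
    const int in_features = 4096, out_features = 4096;
    const int packing_num = 8, group_size = 128, tp_size = 2;

    // ColumnParallelLinear splits the output dimension (dim 1) across ranks.
    printf("qweight per rank: [%d, %d]\n",
           in_features, out_features / packing_num / tp_size);              // [4096, 256]
    printf("scales  per rank: [%d, %d]\n",
           in_features / group_size, out_features / tp_size);               // [32, 2048]
    printf("qzeros  per rank: [%d, %d]\n",
           in_features / group_size, out_features / packing_num / tp_size); // [32, 256]
    return 0;
}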
@@ -261,6 +318,44 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, st
         }
         break;
     }
+    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
+        // AWQ W4A16 for RowParallelLinear: the split dimension is in_features
+        // (dim 0 of the AWQ weight layout)
+        // - Weight: packed int4 in I32 containers (8 int4 per I32)
+        // - Group-wise quantization with group_size (typically 128)
+        // - Scales and zero points stored per group along the in_features dimension
+
+        auto awq_ptr = std::static_pointer_cast<infinicore::quantization::AWQ>(this->quantization_);
+        int group_size = awq_ptr->get_group_size();
+        int packing_num = awq_ptr->get_packing_num();
+
+        // Packed weight: [in_features, out_features / packing_num]
+        weight_ = infinicore::nn::Parameter({in_features, out_features / packing_num},
+                                            infinicore::DataType::I32,
+                                            device, 0, tp_rank_, tp_size_);
+        this->register_parameter("qweight", weight_);
+
+        // Weight scale: [in_features / group_size, out_features]
+        weight_scale_ = infinicore::nn::Parameter({in_features / group_size, out_features},
+                                                  dtype_,
+                                                  device, 0, tp_rank_, tp_size_);
+        this->register_parameter("scales", weight_scale_);
+        // Weight zeros (zero points): [in_features / group_size, out_features / packing_num]
+        weight_zeros_ = infinicore::nn::Parameter({in_features / group_size, out_features / packing_num},
+                                                  infinicore::DataType::I32,
+                                                  device, 0, tp_rank_, tp_size_);
+        this->register_parameter("qzeros", weight_zeros_);
+
+        // Bias handling in RowParallelLinear:
+        // - Only rank 0 holds the full bias (added after the all-reduce on the output)
+        // - Other ranks hold an empty bias parameter
+        if (bias && (0 == tp_rank_)) {
+            INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device, 0, 0, 1));
+        } else {
+            bias_ = Parameter();
+        }
+        break;
+    }
     default: {
         // Initialize parameters using macro
         INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device,
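
The rank-0-only bias follows from the all-reduce: each rank computes a partial product over its slice of in_features and the partials are summed across ranks, so a bias added on every rank would be counted tp_size times. A toy sketch of the decomposition (scalars standing in for tensor slices):

#include <cstdio>

int main() {
    // Row-parallel y = x0*w0 + x1*w1 + b, with rank r holding (xr, wr).
    const float x0 = 1.0f, w0 = 2.0f; // rank 0 slice of in_features
    const float x1 = 3.0f, w1 = 4.0f; // rank 1 slice of in_features
    const float b = 0.5f;

    // Each rank contributes a partial product; only rank 0 adds the bias.
    float partial_rank0 = x0 * w0 + b; // 2.5
    float partial_rank1 = x1 * w1;     // 12.0, no bias on this rank

    // The all-reduce sums the partials, so the bias is counted exactly once.
    printf("y = %.1f\n", partial_rank0 + partial_rank1); // y = 14.5
    return 0;
}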

src/infinicore/ops/linear_w4a16_awq/linear_w4a16_awq.cc

Lines changed: 12 additions & 8 deletions
@@ -1,7 +1,7 @@
 #include "infinicore/ops/linear_w4a16_awq.hpp"
 #include "infinicore/ops/dequantize_awq.hpp"
 #include "infinicore/ops/gemm.hpp"
-
+#include "infinicore/ops/rearrange.hpp"
 namespace infinicore::op {
 
 Tensor linear_w4a16_awq(Tensor input,
@@ -12,7 +12,8 @@ Tensor linear_w4a16_awq(Tensor input,
 
     // Input is of shape [M, K]; weight_packed is of shape [K, N / 8], strides [N / 8, 1]
     Size ndim = input->ndim();
-    Size out_features = weight_packed->shape()[0];
+    Size element_size = weight_packed->element_size();
+    Size out_features = weight_packed->shape()[1] * element_size * 2;
 
     // Allocate memory for the output tensor
     auto output_shape = input->shape();
@@ -33,7 +34,7 @@ void linear_w4a16_awq_(Tensor out,
 
     auto weight_packed_shape = weight_packed->shape();
     Size out_features = weight_packed_shape[0];
-    Size in_features = weight_packed_shape[1];
+    Size in_features = weight_packed_shape[1] * 8;
 
     Size ndim = input->ndim();
     assert(out->ndim() == ndim);
@@ -43,18 +44,21 @@ void linear_w4a16_awq_(Tensor out,
     for (size_t i = 0; i < ndim - 1; ++i) {
         N *= input_shape[i];
     }
-
     auto weight = Tensor::empty(
         {out_features, in_features},
         out->dtype(),
         weight_packed->device());
     float alpha = 1.0f;
     float beta = 0.0f;
     op::dequantize_awq_(weight, weight_packed, weight_scale, weight_zeros);
-    bias = std::make_optional(bias.value()->as_strided({N, out_features}, {0, 1}));
-    gemm_(out->view({N, out_features}),
-          input->view({N, in_features}),
-          weight->permute({1, 0}), alpha, beta);
+    if (bias.has_value()) {
+        rearrange_(out,
+                   bias.value()->as_strided({N, in_features}, {0, 1}));
+        beta = 1.0f;
+    }
+    gemm_(out,
+          input->view({N, out_features}),
+          weight, alpha, beta);
 }
 
 } // namespace infinicore::op
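
The out_features recovery in the first overload generalizes the hard-coded x8: an I32 container holds element_size bytes and each byte holds two int4 nibbles, so the packed dimension expands by element_size * 2. A quick standalone check of that arithmetic (assuming a 4-byte container, as for DataType::I32):

#include <cassert>
#include <cstddef>
#include <cstdio>

int main() {
    // 4 bytes per element * 2 nibbles per byte = 8 int4 values per I32.
    const size_t element_size = 4;  // stands in for weight_packed->element_size()
    const size_t packed_cols = 512; // stands in for weight_packed->shape()[1]

    size_t out_features = packed_cols * element_size * 2; // 512 * 8 = 4096
    assert(out_features == packed_cols * 8);
    printf("out_features = %zu\n", out_features);
    return 0;
}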
