Skip to content

Commit 4b1265a

Browse files
committed
alpha scaling parameter to GEMM-based linear operations
1 parent 73fb6a8 commit 4b1265a

4 files changed

Lines changed: 11 additions & 7 deletions

File tree

include/infinicore/nn/linear.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class BaseLinear : public Module {
2828
size_t out_features() const { return out_features_; }
2929
bool has_bias() const { return has_bias_; }
3030
DataType dtype() const { return dtype_; }
31+
float alpha() const { return alpha_; }
32+
void set_alpha(float alpha) { alpha_ = alpha; }
3133

3234
// Accessors for parameters
3335
Tensor weight() const { return weight_; }
@@ -56,6 +58,7 @@ class BaseLinear : public Module {
5658
size_t out_features_;
5759
bool has_bias_;
5860
DataType dtype_;
61+
float alpha_ = 1.0f;
5962
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_ = std::make_shared<infinicore::quantization::NoneQuantization>(nullptr);
6063
};
6164

include/infinicore/ops/linear.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
namespace infinicore::op {
77

8-
Tensor linear(Tensor input, Tensor weight, std::optional<Tensor> bias);
8+
Tensor linear(Tensor input, Tensor weight, std::optional<Tensor> bias, float alpha = 1.0f);
99

10-
void linear_(Tensor out, Tensor input, Tensor weight, std::optional<Tensor> bias);
10+
void linear_(Tensor out, Tensor input, Tensor weight, std::optional<Tensor> bias, float alpha = 1.0f);
1111

1212
} // namespace infinicore::op

src/infinicore/nn/linear.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
7878
Tensor weight_tensor = static_cast<const Tensor &>(weight_);
7979
std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
8080

81-
auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt);
81+
auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt, alpha_);
8282
return output;
8383
}
8484
}

src/infinicore/ops/linear/linear.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ namespace infinicore::op {
66

77
Tensor linear(Tensor input,
88
Tensor weight,
9-
std::optional<Tensor> bias) {
9+
std::optional<Tensor> bias,
10+
float alpha) {
1011

1112
Size ndim = input->ndim();
1213
Size out_features = weight->shape()[0];
@@ -17,14 +18,15 @@ Tensor linear(Tensor input,
1718
auto out = Tensor::empty(output_shape, input->dtype(), input->device());
1819

1920
// Inplace Calculate
20-
linear_(out, input, weight, bias);
21+
linear_(out, input, weight, bias, alpha);
2122
return out;
2223
}
2324

2425
void linear_(Tensor out,
2526
Tensor input,
2627
Tensor weight,
27-
std::optional<Tensor> bias) {
28+
std::optional<Tensor> bias,
29+
float alpha) {
2830

2931
auto weight_shape = weight->shape();
3032
Size out_features = weight_shape[0];
@@ -43,7 +45,6 @@ void linear_(Tensor out,
4345
// linear transformation
4446
Tensor out_view = out->view({N, out_features});
4547
// Add bias
46-
float alpha = 1.0f;
4748
float beta = 0.0f;
4849
if (bias.has_value()) {
4950
rearrange_(out_view,

0 commit comments

Comments
 (0)