Skip to content

Commit 444ff10

Browse files
committed
issue/1143 - Add alpha scaling parameter in linear operation
1 parent 73fb6a8 commit 444ff10

9 files changed

Lines changed: 130 additions & 18 deletions

File tree

include/infinicore/nn/linear.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class BaseLinear : public Module {
2828
size_t out_features() const { return out_features_; }
2929
bool has_bias() const { return has_bias_; }
3030
DataType dtype() const { return dtype_; }
31+
float alpha() const { return alpha_; }
32+
void set_alpha(float alpha) { alpha_ = alpha; }
3133

3234
// Accessors for parameters
3335
Tensor weight() const { return weight_; }
@@ -56,6 +58,7 @@ class BaseLinear : public Module {
5658
size_t out_features_;
5759
bool has_bias_;
5860
DataType dtype_;
61+
float alpha_ = 1.0f;
5962
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_ = std::make_shared<infinicore::quantization::NoneQuantization>(nullptr);
6063
};
6164

include/infinicore/ops/linear.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
namespace infinicore::op {
77

8-
Tensor linear(Tensor input, Tensor weight, std::optional<Tensor> bias);
8+
Tensor linear(Tensor input, Tensor weight, std::optional<Tensor> bias, float alpha = 1.0f);
99

10-
void linear_(Tensor out, Tensor input, Tensor weight, std::optional<Tensor> bias);
10+
void linear_(Tensor out, Tensor input, Tensor weight, std::optional<Tensor> bias, float alpha = 1.0f);
1111

1212
} // namespace infinicore::op

python/infinicore/nn/functional/linear.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,23 @@
44
__all__ = ["linear"]
55

66

7-
def linear(input: Tensor, weight: Tensor, bias=None, *, out=None) -> Tensor:
8-
r"""Applies a linear transformation to the incoming data: y=xA^T+b."""
7+
def linear(
8+
input: Tensor,
9+
weight: Tensor,
10+
bias=None,
11+
*,
12+
alpha: float = 1.0,
13+
out=None,
14+
) -> Tensor:
15+
r"""Applies a linear transformation to the incoming data: y=alpha*xA^T+b."""
916

1017
if out is None:
1118
return Tensor(
1219
_infinicore.linear(
1320
input._underlying,
1421
weight._underlying,
1522
None if bias is None else bias._underlying,
23+
alpha,
1624
)
1725
)
1826

@@ -21,5 +29,6 @@ def linear(input: Tensor, weight: Tensor, bias=None, *, out=None) -> Tensor:
2129
input._underlying,
2230
weight._underlying,
2331
None if bias is None else bias._underlying,
32+
alpha,
2433
)
2534
return out

python/infinicore/nn/modules/linear.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
super().__init__()
4646
self.in_features = in_features
4747
self.out_features = out_features
48+
self._alpha = 1.0
4849
self.weight = Parameter(
4950
infinicore.empty([out_features, in_features], **factory_kwargs)
5051
)
@@ -55,7 +56,15 @@ def __init__(
5556
self.register_parameter("bias", None)
5657

5758
def forward(self, input: Tensor) -> Tensor:
58-
return F.linear(input, self.weight, self.bias)
59+
return F.linear(input, self.weight, self.bias, alpha=self._alpha)
5960

6061
def extra_repr(self) -> str:
6162
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
63+
64+
@property
65+
def alpha(self) -> float:
66+
return self._alpha
67+
68+
@alpha.setter
69+
def alpha(self, value: float):
70+
self._alpha = value

src/infinicore/nn/linear.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ Tensor BaseLinear::compute_linear(Tensor &input) const {
7878
Tensor weight_tensor = static_cast<const Tensor &>(weight_);
7979
std::optional<Tensor> bias_opt = has_bias_ ? std::make_optional<Tensor>(static_cast<const Tensor &>(bias_)) : std::nullopt;
8080

81-
auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt);
81+
auto output = infinicore::op::linear(input_contiguous->contiguous(), weight_tensor->contiguous(), bias_opt, alpha_);
8282
return output;
8383
}
8484
}

src/infinicore/ops/linear/linear.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ namespace infinicore::op {
66

77
Tensor linear(Tensor input,
88
Tensor weight,
9-
std::optional<Tensor> bias) {
9+
std::optional<Tensor> bias,
10+
float alpha) {
1011

1112
Size ndim = input->ndim();
1213
Size out_features = weight->shape()[0];
@@ -17,14 +18,15 @@ Tensor linear(Tensor input,
1718
auto out = Tensor::empty(output_shape, input->dtype(), input->device());
1819

1920
// Inplace Calculate
20-
linear_(out, input, weight, bias);
21+
linear_(out, input, weight, bias, alpha);
2122
return out;
2223
}
2324

2425
void linear_(Tensor out,
2526
Tensor input,
2627
Tensor weight,
27-
std::optional<Tensor> bias) {
28+
std::optional<Tensor> bias,
29+
float alpha) {
2830

2931
auto weight_shape = weight->shape();
3032
Size out_features = weight_shape[0];
@@ -43,7 +45,6 @@ void linear_(Tensor out,
4345
// linear transformation
4446
Tensor out_view = out->view({N, out_features});
4547
// Add bias
46-
float alpha = 1.0f;
4748
float beta = 0.0f;
4849
if (bias.has_value()) {
4950
rearrange_(out_view,

src/infinicore/pybind11/ops/linear.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,27 @@ namespace infinicore::ops {
1010

1111
Tensor py_linear(Tensor input,
1212
Tensor weight,
13-
pybind11::object bias) {
13+
pybind11::object bias,
14+
float alpha = 1.0f) {
1415
std::optional<Tensor> bias_tensor = std::nullopt;
1516
if (!bias.is_none()) {
1617
bias_tensor = bias.cast<Tensor>();
1718
}
18-
return op::linear(input, weight, bias_tensor);
19+
return op::linear(input, weight, bias_tensor, alpha);
1920
}
2021

2122
void py_linear_(Tensor out,
2223
Tensor input,
2324
Tensor weight,
24-
pybind11::object bias) {
25+
pybind11::object bias,
26+
float alpha = 1.0f) {
2527

2628
std::optional<Tensor> bias_tensor = std::nullopt;
2729
if (!bias.is_none()) {
2830
bias_tensor = bias.cast<Tensor>();
2931
}
3032

31-
op::linear_(out, input, weight, bias_tensor);
33+
op::linear_(out, input, weight, bias_tensor, alpha);
3234
}
3335

3436
inline void bind_linear(py::module &m) {
@@ -38,15 +40,17 @@ inline void bind_linear(py::module &m) {
3840
py::arg("input"),
3941
py::arg("weight"),
4042
py::arg("bias") = py::none(),
41-
R"doc(Applies a linear transformation to the incoming data: y=xA^T+b.)doc");
43+
py::arg("alpha") = 1.0f,
44+
R"doc(Applies a linear transformation to the incoming data: y=alpha*xA^T+b.)doc");
4245

4346
m.def("linear_",
4447
&ops::py_linear_,
4548
py::arg("out"),
4649
py::arg("input"),
4750
py::arg("weight"),
4851
py::arg("bias") = py::none(),
49-
R"doc(In-place, applies a linear transformation to the incoming data: y=xA^T+b.)doc");
52+
py::arg("alpha") = 1.0f,
53+
R"doc(In-place, applies a linear transformation to the incoming data: y=alpha*xA^T+b.)doc");
5054
}
5155

5256
} // namespace infinicore::ops

test/infinicore/nn/linear.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@
3030
((10, 5, 1024), (3072, 1024), (3072,), False),
3131
]
3232

33+
# Alpha test cases: (x_shape, weight_shape, bias_shape, bias, alpha)
34+
_ALPHA_TEST_CASES_DATA = [
35+
((2, 5, 256), (512, 256), (512,), True, 2.5),
36+
((2, 5, 256), (512, 256), (512,), False, 0.5),
37+
((1, 1024), (3072, 1024), (3072,), True, 1.0),
38+
]
39+
3340
# Tolerance configuration
3441
_TOLERANCE_MAP = {
3542
infinicore.float16: {"atol": 0, "rtol": 1e-2},
@@ -74,6 +81,25 @@ def parse_test_cases():
7481
)
7582
)
7683

84+
# Alpha test cases
85+
for x_shape, weight_shape, bias_shape, has_bias, alpha in _ALPHA_TEST_CASES_DATA:
86+
for dtype in _TENSOR_DTYPES:
87+
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
88+
x_spec = TensorSpec.from_tensor(x_shape, None, dtype, name="x")
89+
weight_spec = TensorSpec.from_tensor(weight_shape, None, dtype, name="weight")
90+
bias_spec = TensorSpec.from_tensor(bias_shape, None, dtype, name="bias")
91+
92+
test_cases.append(
93+
TestCase(
94+
inputs=[x_spec, weight_spec, bias_spec],
95+
kwargs={"has_bias": has_bias, "alpha": alpha},
96+
output_spec=None,
97+
comparison_target=None,
98+
tolerance=tolerance,
99+
description=f"nn.Linear - ALPHA={alpha}",
100+
)
101+
)
102+
77103
return test_cases
78104

79105

@@ -123,7 +149,7 @@ def __init__(self):
123149
def get_test_cases(self):
124150
return parse_test_cases()
125151

126-
def torch_operator(self, x, weight, bias, has_bias):
152+
def torch_operator(self, x, weight, bias, has_bias, alpha=None):
127153
"""PyTorch nn.Linear implementation"""
128154
out_features, in_features = weight.shape
129155
params_dict = {"l.weight": weight}
@@ -141,9 +167,13 @@ def torch_operator(self, x, weight, bias, has_bias):
141167

142168
with torch.no_grad():
143169
y = model(x)
170+
if alpha is not None:
171+
# alpha scales only matmul, not bias: alpha * (x @ W^T) + b
172+
y_matmul = torch.nn.functional.linear(x, weight)
173+
y = alpha * y_matmul + (bias if has_bias else 0)
144174
return y
145175

146-
def infinicore_operator(self, x, weight, bias, has_bias):
176+
def infinicore_operator(self, x, weight, bias, has_bias, alpha=None):
147177
"""InfiniCore nn.Linear implementation"""
148178

149179
out_features, in_features = weight.shape
@@ -158,6 +188,8 @@ def infinicore_operator(self, x, weight, bias, has_bias):
158188
device=weight.device,
159189
dtype=weight.dtype,
160190
)
191+
if alpha is not None:
192+
model.l.alpha = alpha
161193
model.load_state_dict(params_dict)
162194

163195
y = model(x)

test/infinicore/ops/linear.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
(None, 1, 2048, 5632, True, None, None, None),
2727
]
2828

29+
# Alpha test cases: (bs, n, in_features, out_features, bias, input_strides, weight_strides, out_strides, alpha)
30+
_ALPHA_TEST_CASES_DATA = [
31+
(2, 5, 256, 512, True, None, None, None, 2.5),
32+
(2, 5, 256, 512, False, None, None, None, 0.5),
33+
(1, 10, 256, 512, True, None, None, None, 0.0),
34+
]
35+
2936
# Tolerance configuration
3037
_TOLERANCE_MAP = {
3138
infinicore.float16: {"atol": 0, "rtol": 1e-2},
@@ -109,6 +116,40 @@ def parse_test_cases():
109116
)
110117
)
111118

119+
# Alpha test cases
120+
for data in _ALPHA_TEST_CASES_DATA:
121+
bs = data[0]
122+
n, in_features, out_features = data[1], data[2], data[3]
123+
bias = data[4]
124+
input_strides = data[5] if len(data) > 5 else None
125+
weight_strides = data[6] if len(data) > 6 else None
126+
out_strides = data[7] if len(data) > 7 else None
127+
alpha = data[8]
128+
129+
if bs is None:
130+
input_shape = (n, in_features)
131+
else:
132+
input_shape = (bs, n, in_features)
133+
weight_shape = (out_features, in_features)
134+
bias_shape = (out_features,) if bias else None
135+
136+
for dtype in _TENSOR_DTYPES:
137+
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
138+
input_spec = TensorSpec.from_tensor(input_shape, input_strides, dtype)
139+
weight_spec = TensorSpec.from_tensor(weight_shape, weight_strides, dtype)
140+
bias_spec = TensorSpec.from_tensor(bias_shape, None, dtype) if bias_shape else None
141+
142+
test_cases.append(
143+
TestCase(
144+
inputs=[input_spec, weight_spec, bias_spec],
145+
kwargs={"alpha": alpha},
146+
output_spec=None,
147+
comparison_target=None,
148+
tolerance=tolerance,
149+
description=f"Linear - ALPHA={alpha}",
150+
)
151+
)
152+
112153
return test_cases
113154

114155

@@ -123,6 +164,19 @@ def get_test_cases(self):
123164

124165
def torch_operator(self, *args, **kwargs):
125166
"""PyTorch linear implementation"""
167+
alpha = kwargs.pop("alpha", 1.0)
168+
if alpha != 1.0:
169+
input_tensor = args[0]
170+
weight = args[1]
171+
bias = args[2] if len(args) > 2 else None
172+
out = kwargs.get("out")
173+
matmul_result = torch.nn.functional.linear(input_tensor, weight)
174+
bias_value = bias if bias is not None else 0
175+
result = alpha * matmul_result + bias_value
176+
if out is not None:
177+
out.copy_(result)
178+
return out
179+
return result
126180
return torch.nn.functional.linear(*args, **kwargs)
127181

128182
def infinicore_operator(self, *args, **kwargs):

0 commit comments

Comments (0)