Skip to content

Commit 5e2bf7c

Browse files
committed
Make weight and bias fields of the class and remove duplicate code
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 63917ca commit 5e2bf7c

2 files changed

Lines changed: 31 additions & 77 deletions

File tree

include/infinicore/nn/linear.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class Linear : public Module {
1717
Tensor forward(const Tensor &input, const Tensor &residual) const;
1818

1919
// Accessors for parameters
20-
Tensor weight() const;
21-
Tensor bias() const;
20+
Tensor weight() const { return weight_; }
21+
Tensor bias() const { return bias_; }
2222

2323
// Module information
2424
size_t in_features() const { return in_features_; }
@@ -28,7 +28,14 @@ class Linear : public Module {
2828
// String representation
2929
std::string extra_repr() const;
3030

31+
// Direct access to parameters as fields
32+
Parameter weight_;
33+
Parameter bias_;
34+
3135
private:
36+
// Helper method for common forward computation
37+
Tensor compute_linear(const Tensor &input) const;
38+
3239
size_t in_features_;
3340
size_t out_features_;
3441
bool has_bias_;

src/infinicore/nn/linear.cc

Lines changed: 22 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,41 @@
55
namespace infinicore::nn {
66

77
Linear::Linear(size_t in_features, size_t out_features, bool bias, const Device &device)
8-
: in_features_(in_features), out_features_(out_features), has_bias_(bias) {
8+
: weight_(Parameter({out_features, in_features}, DataType::F32, device)),
9+
bias_(bias ? Parameter({out_features}, DataType::F32, device) : Parameter()),
10+
in_features_(in_features),
11+
out_features_(out_features),
12+
has_bias_(bias) {
913

1014
device_ = device;
1115

12-
// Register weight parameter: [out_features, in_features]
13-
register_parameter("weight", Parameter({out_features, in_features}, DataType::F32, device));
16+
// Register weight parameter in state dict
17+
register_parameter("weight", weight_);
1418

15-
// Register bias parameter if requested: [out_features]
19+
// Register bias parameter if requested
1620
if (bias) {
17-
register_parameter("bias", Parameter({out_features}, DataType::F32, device));
21+
register_parameter("bias", bias_);
1822
}
1923

2024
spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}",
2125
in_features, out_features, bias);
2226
}
2327

24-
Tensor Linear::forward(const Tensor &input) const {
25-
auto sd = state_dict();
26-
auto weight = sd.at("weight");
27-
auto bias_it = sd.find("bias");
28-
28+
Tensor Linear::compute_linear(const Tensor &input) const {
2929
// Create output tensor with shape [batch_size, out_features]
3030
auto output_shape = input->shape();
3131
output_shape[output_shape.size() - 1] = out_features_;
3232
auto output = Tensor::empty(output_shape, input->dtype(), input->device());
3333

3434
// Transpose weight: [out_features, in_features] -> [in_features, out_features]
35-
auto weight_t = weight->permute({1, 0});
36-
37-
// InfiniLM-style linear computation: output = input @ weight_t + bias
38-
// Handle bias broadcasting similar to InferenceContext::linear
39-
if (bias_it != sd.end()) {
40-
auto bias = bias_it->second;
35+
auto weight_t = weight_->permute({1, 0});
4136

42-
// Broadcast bias to output shape (similar to InfiniLM's bias handling)
37+
if (has_bias_) {
38+
// Broadcast bias to output shape
4339
size_t ndim_diff = output->ndim() - 1;
4440
std::vector<Stride> strides(ndim_diff, 0);
45-
strides.push_back(bias->stride(0));
46-
auto bias_view = bias->as_strided(output->shape(), strides);
41+
strides.push_back(bias_->stride(0));
42+
auto bias_view = bias_->as_strided(output->shape(), strides);
4743

4844
// First set output to bias (broadcasted)
4945
infinicore::op::rearrange_(output, bias_view);
@@ -59,68 +55,19 @@ Tensor Linear::forward(const Tensor &input) const {
5955
return output;
6056
}
6157

62-
Tensor Linear::forward(const Tensor &input, const Tensor &residual) const {
63-
auto sd = state_dict();
64-
auto weight = sd.at("weight");
65-
auto bias_it = sd.find("bias");
66-
67-
// Create output tensor with shape [batch_size, out_features]
68-
auto output_shape = input->shape();
69-
output_shape[output_shape.size() - 1] = out_features_;
70-
auto output = Tensor::empty(output_shape, input->dtype(), input->device());
71-
72-
// Transpose weight: [out_features, in_features] -> [in_features, out_features]
73-
auto weight_t = weight->permute({1, 0});
74-
75-
// InfiniLM-style computation with residual: output = input @ weight_t + bias + residual
76-
if (bias_it != sd.end()) {
77-
auto bias = bias_it->second;
78-
79-
// Broadcast bias to output shape
80-
size_t ndim_diff = output->ndim() - 1;
81-
std::vector<Stride> strides(ndim_diff, 0);
82-
strides.push_back(bias->stride(0));
83-
auto bias_view = bias->as_strided(output->shape(), strides);
84-
85-
// First set output to bias (broadcasted)
86-
infinicore::op::rearrange_(output, bias_view);
58+
Tensor Linear::forward(const Tensor &input) const {
59+
return compute_linear(input);
60+
}
8761

88-
// Compute matmul result separately, then add to output
89-
auto matmul_result = infinicore::op::matmul(input, weight_t);
90-
infinicore::op::add_(output, output, matmul_result);
62+
Tensor Linear::forward(const Tensor &input, const Tensor &residual) const {
63+
auto output = compute_linear(input);
9164

92-
// Add residual: output = output + residual
93-
infinicore::op::add_(output, output, residual);
94-
} else {
95-
// No bias: compute output = input @ weight_t + residual
96-
infinicore::op::matmul_(output, input, weight_t);
97-
infinicore::op::add_(output, output, residual);
98-
}
65+
// Add residual: output = output + residual
66+
infinicore::op::add_(output, output, residual);
9967

10068
return output;
10169
}
10270

103-
Tensor Linear::weight() const {
104-
auto sd = state_dict();
105-
auto it = sd.find("weight");
106-
if (it != sd.end()) {
107-
return it->second;
108-
}
109-
throw std::runtime_error("Weight parameter not found");
110-
}
111-
112-
Tensor Linear::bias() const {
113-
if (!has_bias_) {
114-
throw std::runtime_error("Linear module does not have bias");
115-
}
116-
auto sd = state_dict();
117-
auto it = sd.find("bias");
118-
if (it != sd.end()) {
119-
return it->second;
120-
}
121-
throw std::runtime_error("Bias parameter not found");
122-
}
123-
12471
std::string Linear::extra_repr() const {
12572
return "in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false");
12673
}

0 commit comments

Comments
 (0)