Skip to content

Commit a995e64

Browse files
committed
implement linear module
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 80c9530 commit a995e64

9 files changed

Lines changed: 513 additions & 52 deletions

File tree

include/infinicore.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
22

3+
#include "infinicore/nn.hpp"
34
#include "infinicore/ops.hpp"
45
#include "infinicore/tensor.hpp"

include/infinicore/nn.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#pragma once
2+
3+
#include "nn/linear.hpp"

include/infinicore/nn/linear.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma once
2+
3+
#include "module.hpp"
4+
#include "../ops.hpp"
5+
6+
namespace infinicore::nn {
7+
8+
class Linear : public Module {
9+
public:
10+
Linear(size_t in_features, size_t out_features, bool bias = true, const Device &device = Device());
11+
12+
// Forward pass: output = input @ weight.T + bias
13+
Tensor forward(const Tensor &input) const;
14+
15+
// Forward pass with residual connection (InfiniLM-style)
16+
// output = input @ weight.T + bias + residual
17+
Tensor forward(const Tensor &input, const Tensor &residual) const;
18+
19+
// Accessors for parameters
20+
Tensor weight() const;
21+
Tensor bias() const;
22+
23+
// Module information
24+
size_t in_features() const { return in_features_; }
25+
size_t out_features() const { return out_features_; }
26+
bool has_bias() const { return has_bias_; }
27+
28+
// String representation
29+
std::string extra_repr() const;
30+
31+
private:
32+
size_t in_features_;
33+
size_t out_features_;
34+
bool has_bias_;
35+
};
36+
37+
} // namespace infinicore::nn

include/infinicore/nn/module.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ class Module {
1717

1818
Tensor register_parameter(const std::string &name, Parameter param);
1919

20+
// Create a Linear submodule-like parameter set (weight and optional bias)
21+
// Mirrors torch.nn.Linear shapes: weight [out_features, in_features], bias [out_features]
22+
void linear(const std::string &name, size_t in_features, size_t out_features, bool bias = true);
23+
2024
template <typename M>
2125
std::shared_ptr<M> add_module(const std::string &name, std::shared_ptr<M> submodule) {
2226
submodules_[name] = submodule;

include/infinicore/tensor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
146146

147147
/**
148148
* Copy Data from another tensor to this tensor.
149+
* Currently, only contigous tensors of the same dtype and shape are supported.
149150
*
150151
* @param src The source tensor to copy from
151152
*

src/infinicore-test/test_nn_module.cc

Lines changed: 322 additions & 50 deletions
Large diffs are not rendered by default.

src/infinicore-test/test_nn_module.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
namespace infinicore::test {
1313

1414
// Simple test module that mimics torch.nn.Linear
15-
class TestLinearModule : public infinicore::nn::Module {
15+
class MockLinearModule : public infinicore::nn::Module {
1616
public:
17-
TestLinearModule(int input_size, int output_size, const infinicore::Device &device)
17+
MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
1818
: input_size_(input_size), output_size_(output_size), device_(device) {
1919

2020
// Initialize weight parameter (similar to torch.nn.Linear.weight)
@@ -71,6 +71,7 @@ class NNModuleTest : public MemoryTestFramework {
7171
TestResult testModuleHierarchy();
7272
TestResult testParameterLoading();
7373
TestResult testModuleComparison();
74+
TestResult testModuleLinear();
7475
};
7576

7677
} // namespace infinicore::test

src/infinicore/nn/linear.cc

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#include "infinicore/nn/linear.hpp"
2+
#include "infinicore/ops.hpp"
3+
#include <spdlog/spdlog.h>
4+
5+
namespace infinicore::nn {
6+
7+
Linear::Linear(size_t in_features, size_t out_features, bool bias, const Device &device)
8+
: in_features_(in_features), out_features_(out_features), has_bias_(bias) {
9+
10+
device_ = device;
11+
12+
// Register weight parameter: [out_features, in_features]
13+
register_parameter("weight", Parameter({out_features, in_features}, DataType::F32, device));
14+
15+
// Register bias parameter if requested: [out_features]
16+
if (bias) {
17+
register_parameter("bias", Parameter({out_features}, DataType::F32, device));
18+
}
19+
20+
spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}",
21+
in_features, out_features, bias);
22+
}
23+
24+
Tensor Linear::forward(const Tensor &input) const {
25+
auto sd = state_dict();
26+
auto weight = sd.at("weight");
27+
auto bias_it = sd.find("bias");
28+
29+
// Create output tensor with shape [batch_size, out_features]
30+
auto output_shape = input->shape();
31+
output_shape[output_shape.size() - 1] = out_features_;
32+
auto output = Tensor::empty(output_shape, input->dtype(), input->device());
33+
34+
// Transpose weight: [out_features, in_features] -> [in_features, out_features]
35+
auto weight_t = weight->permute({1, 0});
36+
37+
// InfiniLM-style linear computation: output = input @ weight_t + bias
38+
// Handle bias broadcasting similar to InferenceContext::linear
39+
if (bias_it != sd.end()) {
40+
auto bias = bias_it->second;
41+
42+
// Broadcast bias to output shape (similar to InfiniLM's bias handling)
43+
size_t ndim_diff = output->ndim() - 1;
44+
std::vector<Stride> strides(ndim_diff, 0);
45+
strides.push_back(bias->stride(0));
46+
auto bias_view = bias->as_strided(output->shape(), strides);
47+
48+
// First set output to bias (broadcasted)
49+
infinicore::op::rearrange_(output, bias_view);
50+
51+
// Compute matmul result separately, then add to output
52+
auto matmul_result = infinicore::op::matmul(input, weight_t);
53+
infinicore::op::add_(output, output, matmul_result);
54+
} else {
55+
// No bias: just compute output = input @ weight_t
56+
infinicore::op::matmul_(output, input, weight_t);
57+
}
58+
59+
return output;
60+
}
61+
62+
Tensor Linear::forward(const Tensor &input, const Tensor &residual) const {
63+
auto sd = state_dict();
64+
auto weight = sd.at("weight");
65+
auto bias_it = sd.find("bias");
66+
67+
// Create output tensor with shape [batch_size, out_features]
68+
auto output_shape = input->shape();
69+
output_shape[output_shape.size() - 1] = out_features_;
70+
auto output = Tensor::empty(output_shape, input->dtype(), input->device());
71+
72+
// Transpose weight: [out_features, in_features] -> [in_features, out_features]
73+
auto weight_t = weight->permute({1, 0});
74+
75+
// InfiniLM-style computation with residual: output = input @ weight_t + bias + residual
76+
if (bias_it != sd.end()) {
77+
auto bias = bias_it->second;
78+
79+
// Broadcast bias to output shape
80+
size_t ndim_diff = output->ndim() - 1;
81+
std::vector<Stride> strides(ndim_diff, 0);
82+
strides.push_back(bias->stride(0));
83+
auto bias_view = bias->as_strided(output->shape(), strides);
84+
85+
// First set output to bias (broadcasted)
86+
infinicore::op::rearrange_(output, bias_view);
87+
88+
// Compute matmul result separately, then add to output
89+
auto matmul_result = infinicore::op::matmul(input, weight_t);
90+
infinicore::op::add_(output, output, matmul_result);
91+
92+
// Add residual: output = output + residual
93+
infinicore::op::add_(output, output, residual);
94+
} else {
95+
// No bias: compute output = input @ weight_t + residual
96+
infinicore::op::matmul_(output, input, weight_t);
97+
infinicore::op::add_(output, output, residual);
98+
}
99+
100+
return output;
101+
}
102+
103+
Tensor Linear::weight() const {
104+
auto sd = state_dict();
105+
auto it = sd.find("weight");
106+
if (it != sd.end()) {
107+
return it->second;
108+
}
109+
throw std::runtime_error("Weight parameter not found");
110+
}
111+
112+
Tensor Linear::bias() const {
113+
if (!has_bias_) {
114+
throw std::runtime_error("Linear module does not have bias");
115+
}
116+
auto sd = state_dict();
117+
auto it = sd.find("bias");
118+
if (it != sd.end()) {
119+
return it->second;
120+
}
121+
throw std::runtime_error("Bias parameter not found");
122+
}
123+
124+
std::string Linear::extra_repr() const {
125+
return "in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false");
126+
}
127+
128+
} // namespace infinicore::nn

src/infinicore/nn/module.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,17 @@ void Module::collect_all_parameters(const std::string &prefix, std::unordered_ma
5151
}
5252

5353
} // namespace infinicore::nn
54+
55+
namespace infinicore::nn {
56+
void Module::linear(const std::string &name, size_t in_features, size_t out_features, bool bias) {
57+
// Register weight parameter: [out_features, in_features]
58+
register_parameter(name + ".weight",
59+
Parameter({out_features, in_features}, DataType::F32, device_));
60+
61+
// Register optional bias parameter: [out_features]
62+
if (bias) {
63+
register_parameter(name + ".bias",
64+
Parameter({out_features}, DataType::F32, device_));
65+
}
66+
}
67+
} // namespace infinicore::nn

0 commit comments

Comments
 (0)