55namespace infinicore ::nn {
66
77Linear::Linear (size_t in_features, size_t out_features, bool bias, const Device &device)
8- : in_features_(in_features), out_features_(out_features), has_bias_(bias) {
8+ : weight_(Parameter({out_features, in_features}, DataType::F32, device)),
9+ bias_ (bias ? Parameter({out_features}, DataType::F32, device) : Parameter()),
10+ in_features_ (in_features),
11+ out_features_(out_features),
12+ has_bias_(bias) {
913
1014 device_ = device;
1115
12- // Register weight parameter: [out_features, in_features]
13- register_parameter (" weight" , Parameter ({out_features, in_features}, DataType::F32, device) );
16+ // Register weight parameter in state dict
17+ register_parameter (" weight" , weight_ );
1418
15- // Register bias parameter if requested: [out_features]
19+ // Register bias parameter if requested
1620 if (bias) {
17- register_parameter (" bias" , Parameter ({out_features}, DataType::F32, device) );
21+ register_parameter (" bias" , bias_ );
1822 }
1923
2024 spdlog::debug (" Created Linear module: in_features={}, out_features={}, bias={}" ,
2125 in_features, out_features, bias);
2226}
2327
24- Tensor Linear::forward (const Tensor &input) const {
25- auto sd = state_dict ();
26- auto weight = sd.at (" weight" );
27- auto bias_it = sd.find (" bias" );
28-
28+ Tensor Linear::compute_linear (const Tensor &input) const {
2929 // Create output tensor with shape [batch_size, out_features]
3030 auto output_shape = input->shape ();
3131 output_shape[output_shape.size () - 1 ] = out_features_;
3232 auto output = Tensor::empty (output_shape, input->dtype (), input->device ());
3333
3434 // Transpose weight: [out_features, in_features] -> [in_features, out_features]
35- auto weight_t = weight->permute ({1 , 0 });
36-
37- // InfiniLM-style linear computation: output = input @ weight_t + bias
38- // Handle bias broadcasting similar to InferenceContext::linear
39- if (bias_it != sd.end ()) {
40- auto bias = bias_it->second ;
35+ auto weight_t = weight_->permute ({1 , 0 });
4136
42- // Broadcast bias to output shape (similar to InfiniLM's bias handling)
37+ if (has_bias_) {
38+ // Broadcast bias to output shape
4339 size_t ndim_diff = output->ndim () - 1 ;
4440 std::vector<Stride> strides (ndim_diff, 0 );
45- strides.push_back (bias ->stride (0 ));
46- auto bias_view = bias ->as_strided (output->shape (), strides);
41+ strides.push_back (bias_ ->stride (0 ));
42+ auto bias_view = bias_ ->as_strided (output->shape (), strides);
4743
4844 // First set output to bias (broadcasted)
4945 infinicore::op::rearrange_ (output, bias_view);
@@ -59,68 +55,19 @@ Tensor Linear::forward(const Tensor &input) const {
5955 return output;
6056}
6157
62- Tensor Linear::forward (const Tensor &input, const Tensor &residual) const {
63- auto sd = state_dict ();
64- auto weight = sd.at (" weight" );
65- auto bias_it = sd.find (" bias" );
66-
67- // Create output tensor with shape [batch_size, out_features]
68- auto output_shape = input->shape ();
69- output_shape[output_shape.size () - 1 ] = out_features_;
70- auto output = Tensor::empty (output_shape, input->dtype (), input->device ());
71-
72- // Transpose weight: [out_features, in_features] -> [in_features, out_features]
73- auto weight_t = weight->permute ({1 , 0 });
74-
75- // InfiniLM-style computation with residual: output = input @ weight_t + bias + residual
76- if (bias_it != sd.end ()) {
77- auto bias = bias_it->second ;
78-
79- // Broadcast bias to output shape
80- size_t ndim_diff = output->ndim () - 1 ;
81- std::vector<Stride> strides (ndim_diff, 0 );
82- strides.push_back (bias->stride (0 ));
83- auto bias_view = bias->as_strided (output->shape (), strides);
84-
85- // First set output to bias (broadcasted)
86- infinicore::op::rearrange_ (output, bias_view);
58+ Tensor Linear::forward (const Tensor &input) const {
59+ return compute_linear (input);
60+ }
8761
88- // Compute matmul result separately, then add to output
89- auto matmul_result = infinicore::op::matmul (input, weight_t );
90- infinicore::op::add_ (output, output, matmul_result);
62+ Tensor Linear::forward (const Tensor &input, const Tensor &residual) const {
63+ auto output = compute_linear (input);
9164
92- // Add residual: output = output + residual
93- infinicore::op::add_ (output, output, residual);
94- } else {
95- // No bias: compute output = input @ weight_t + residual
96- infinicore::op::matmul_ (output, input, weight_t );
97- infinicore::op::add_ (output, output, residual);
98- }
65+ // Add residual: output = output + residual
66+ infinicore::op::add_ (output, output, residual);
9967
10068 return output;
10169}
10270
103- Tensor Linear::weight () const {
104- auto sd = state_dict ();
105- auto it = sd.find (" weight" );
106- if (it != sd.end ()) {
107- return it->second ;
108- }
109- throw std::runtime_error (" Weight parameter not found" );
110- }
111-
112- Tensor Linear::bias () const {
113- if (!has_bias_) {
114- throw std::runtime_error (" Linear module does not have bias" );
115- }
116- auto sd = state_dict ();
117- auto it = sd.find (" bias" );
118- if (it != sd.end ()) {
119- return it->second ;
120- }
121- throw std::runtime_error (" Bias parameter not found" );
122- }
123-
12471std::string Linear::extra_repr () const {
12572 return " in_features=" + std::to_string (in_features_) + " , out_features=" + std::to_string (out_features_) + " , bias=" + (has_bias_ ? " true" : " false" );
12673}
0 commit comments