@@ -50,38 +50,71 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
5050 return {output};
5151}
5252
53- std::tuple<std:: shared_ptr<Tensor>, std::shared_ptr<Tensor>>
54- MatmulBackward ( const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other ,
55- const std::shared_ptr<Tensor > &grad_output ) {
53+ std::shared_ptr<Tensor> MatmulBackwardInput1 ( const std::shared_ptr<Tensor> &other,
54+ const std::shared_ptr<Tensor> &grad_output ,
55+ const std::vector< int64_t > &input_dims ) {
5656 /*
5757 grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T
58- grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
5958 */
60- const auto &input_dims = input->Dims ();
6159 const auto &other_dims = other->Dims ();
6260 const auto &grad_output_dims = grad_output->Dims ();
6361
62+ CHECK_GE (other_dims.size (), 2 );
63+ CHECK_EQ (other_dims.size (), grad_output_dims.size ());
64+
65+ const int64_t m = grad_output_dims[grad_output_dims.size () - 2 ];
66+ const int64_t k = other_dims[other_dims.size () - 2 ];
67+ const int64_t n = grad_output_dims[grad_output_dims.size () - 1 ];
68+
69+ const int64_t bs
70+ = std::accumulate (grad_output_dims.rbegin () + 2 , grad_output_dims.rend (), 1 , std::multiplies<int64_t >{});
71+ for (int64_t i = 0 ; i < grad_output_dims.size () - 2 ; ++i) {
72+ CHECK_EQ (grad_output_dims[i], other_dims[i]) << " Batch dims must match" ;
73+ }
74+
75+ auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32 );
76+ grad_input->Fill <float >(0 .0f );
77+
78+ for (int64_t b = 0 ; b < bs; ++b) {
79+ for (int64_t i = 0 ; i < m; ++i) {
80+ for (int64_t j = 0 ; j < n; ++j) {
81+ const float grad = static_cast <float *>(grad_output->DataPtr ())[b * m * n + i * n + j];
82+ for (int64_t p = 0 ; p < k; ++p) {
83+ const auto other_idx = b * k * n + p * n + j;
84+ static_cast <float *>(grad_input->DataPtr ())[b * m * k + i * k + p]
85+ += grad * static_cast <const float *>(other->DataPtr ())[other_idx];
86+ }
87+ }
88+ }
89+ }
90+ return grad_input;
91+ }
92+
93+ std::shared_ptr<Tensor> MatmulBackwardInput2 (const std::shared_ptr<Tensor> &input1,
94+ const std::shared_ptr<Tensor> &grad_output,
95+ const std::vector<int64_t > &other_dims) {
96+ /*
97+ grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n]
98+ */
99+ const auto &input_dims = input1->Dims ();
100+ const auto &grad_output_dims = grad_output->Dims ();
101+
64102 CHECK_GE (input_dims.size (), 2 );
65- CHECK_EQ (input_dims.size (), other_dims.size ());
66103 CHECK_EQ (input_dims.size (), grad_output_dims.size ());
67104
68105 const int64_t m = input_dims[input_dims.size () - 2 ];
69106 const int64_t k = input_dims[input_dims.size () - 1 ];
70- CHECK_EQ (k, other_dims[other_dims.size () - 2 ]);
71- const int64_t n = other_dims[other_dims.size () - 1 ];
72-
107+ const int64_t n = grad_output_dims[grad_output_dims.size () - 1 ];
73108 CHECK_EQ (m, grad_output_dims[grad_output_dims.size () - 2 ]);
74- CHECK_EQ (n, grad_output_dims[grad_output_dims .size () - 1 ]);
109+ CHECK_EQ (k, other_dims[other_dims .size () - 2 ]);
75110
76111 const int64_t bs = std::accumulate (input_dims.rbegin () + 2 , input_dims.rend (), 1 , std::multiplies<int64_t >{});
77112 for (int64_t i = 0 ; i < input_dims.size () - 2 ; ++i) {
78- CHECK_EQ (input_dims[i], other_dims[i]) << " Batch dims must match" ;
79113 CHECK_EQ (input_dims[i], grad_output_dims[i]) << " Batch dims must match" ;
114+ CHECK_EQ (input_dims[i], other_dims[i]) << " Batch dims must match" ;
80115 }
81116
82- auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32 );
83117 auto grad_other = std::make_shared<Tensor>(other_dims, DataType::kFLOAT32 );
84- grad_input->Fill <float >(0 .0f );
85118 grad_other->Fill <float >(0 .0f );
86119
87120 for (int64_t b = 0 ; b < bs; ++b) {
@@ -90,16 +123,13 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
90123 const float grad = static_cast <float *>(grad_output->DataPtr ())[b * m * n + i * n + j];
91124 for (int64_t p = 0 ; p < k; ++p) {
92125 const auto input_idx = b * m * k + i * k + p;
93- const auto other_idx = b * k * n + p * n + j;
94- static_cast <float *>(grad_input->DataPtr ())[input_idx]
95- += grad * static_cast <const float *>(other->DataPtr ())[other_idx];
96- static_cast <float *>(grad_other->DataPtr ())[other_idx]
97- += grad * static_cast <const float *>(input->DataPtr ())[input_idx];
126+ static_cast <float *>(grad_other->DataPtr ())[b * k * n + p * n + j]
127+ += grad * static_cast <const float *>(input1->DataPtr ())[input_idx];
98128 }
99129 }
100130 }
101131 }
102- return {grad_input, grad_other} ;
132+ return grad_other;
103133}
104134
105135std::shared_ptr<Tensor> LinearForward (const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight,
@@ -201,7 +231,8 @@ std::shared_ptr<Tensor> LinearBackwardBias(const std::shared_ptr<Tensor> &grad_o
201231 REGISTER_KERNEL (infini_train::Device::DeviceType::kCPU , kernel_name, infini_train::kernels::cpu::kernel_name)
202232
203233REGISTER_CPU_LINEAR_KERNEL(MatmulForward)
204- REGISTER_CPU_LINEAR_KERNEL(MatmulBackward)
234+ REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput1)
235+ REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput2)
205236REGISTER_CPU_LINEAR_KERNEL(LinearForward)
206237REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput)
207238REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight)
0 commit comments