@@ -375,25 +375,25 @@ void TReference<AReal>::Downsample(TMatrixT<AReal> &A, TMatrixT<AReal> &B, const
375375
376376// ______________________________________________________________________________
377377template <typename AReal>
378- void TReference<AReal>::MaxPoolLayerBackward(std::vector<TMatrixT<AReal>> &activationGradientsBackward,
379- const std::vector<TMatrixT<AReal>> &activationGradients,
380- const std::vector<TMatrixT<AReal>> &indexMatrix, size_t batchSize,
381- size_t depth, size_t nLocalViews)
378+ void TReference<AReal>::MaxPoolLayerBackward(TMatrixT<AReal> &activationGradientsBackward,
379+ const TMatrixT<AReal> &activationGradients,
380+ const TMatrixT<AReal> &indexMatrix,
381+ size_t imgHeight, size_t imgWidth, size_t fltHeight,
382+ size_t fltWidth, size_t strideRows, size_t strideCols, size_t nLocalViews)
382383{
383- for (size_t i = 0 ; i < batchSize; i++) {
384- for (size_t j = 0 ; j < depth; j++) {
384+ size_t depth = activationGradientsBackward.GetNrows ();
385385
386- // initialize to zeros
387- for (size_t t = 0 ; t < (size_t )activationGradientsBackward[i].GetNcols (); t++) {
388- activationGradientsBackward[i][j][t] = 0 ;
389- }
386+ for (size_t j = 0 ; j < depth; j++) {
387+ // initialize to zeros
388+ for (size_t t = 0 ; t < (size_t )activationGradientsBackward.GetNcols (); t++) {
389+ activationGradientsBackward[j][t] = 0 ;
390+ }
390391
391- // set values
392- for (size_t k = 0 ; k < nLocalViews; k++) {
393- AReal grad = activationGradients[i][j][k];
394- size_t winningIdx = indexMatrix[i][j][k];
395- activationGradientsBackward[i][j][winningIdx] = grad;
396- }
392+ // set values
393+ for (size_t k = 0 ; k < nLocalViews; k++) {
394+ AReal grad = activationGradients[j][k];
395+ size_t winningIdx = indexMatrix[j][k];
396+ activationGradientsBackward[j][winningIdx] += grad;
397397 }
398398 }
399399}
0 commit comments