Skip to content

Commit 0668540

Browse files
authored
feat: add mse_loss base (#344)
1 parent 46b4f81 commit 0668540

1 file changed

Lines changed: 53 additions & 0 deletions

File tree

src/base/mse_loss.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#ifndef INFINI_OPS_BASE_MSE_LOSS_H_
2+
#define INFINI_OPS_BASE_MSE_LOSS_H_
3+
4+
#include "operator.h"
5+
6+
namespace infini::ops {
7+
8+
class MseLoss : public Operator<MseLoss> {
9+
public:
10+
MseLoss(const Tensor input, const Tensor target, const int64_t reduction,
11+
Tensor out)
12+
: input_shape_{input.shape()},
13+
input_strides_{input.strides()},
14+
input_type_{input.dtype()},
15+
target_shape_{target.shape()},
16+
target_strides_{target.strides()},
17+
target_type_{target.dtype()},
18+
out_shape_{out.shape()},
19+
out_strides_{out.strides()},
20+
out_type_{out.dtype()},
21+
reduction_{reduction},
22+
device_index_{out.device().index()} {}
23+
24+
virtual void operator()(const Tensor input, const Tensor target,
25+
const int64_t reduction, Tensor out) const = 0;
26+
27+
protected:
28+
Tensor::Shape input_shape_;
29+
30+
Tensor::Strides input_strides_;
31+
32+
DataType input_type_;
33+
34+
Tensor::Shape target_shape_;
35+
36+
Tensor::Strides target_strides_;
37+
38+
DataType target_type_;
39+
40+
Tensor::Shape out_shape_;
41+
42+
Tensor::Strides out_strides_;
43+
44+
DataType out_type_;
45+
46+
int64_t reduction_{};
47+
48+
int device_index_{0};
49+
};
50+
51+
} // namespace infini::ops
52+
53+
#endif

0 commit comments

Comments
 (0)