Skip to content

Commit cac0e82

Browse files
committed
feat: add nll_loss2d_forward base
1 parent 78424f7 commit cac0e82

1 file changed

Lines changed: 64 additions & 0 deletions

File tree

src/base/nll_loss2d_forward.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// AUTO-GENERATED by `scripts/generate_torch_ops.py` — DO NOT EDIT.
2+
#ifndef INFINI_OPS_BASE_NLL_LOSS2D_FORWARD_H_
3+
#define INFINI_OPS_BASE_NLL_LOSS2D_FORWARD_H_
4+
5+
#include "operator.h"
6+
7+
namespace infini::ops {
8+
9+
class NllLoss2dForward : public Operator<NllLoss2dForward> {
10+
public:
11+
NllLoss2dForward(const Tensor input, const Tensor target, const int64_t reduction, const int64_t ignore_index, Tensor output, Tensor total_weight)
12+
: input_shape_{input.shape()},
13+
input_strides_{input.strides()},
14+
input_type_{input.dtype()},
15+
target_shape_{target.shape()},
16+
target_strides_{target.strides()},
17+
target_type_{target.dtype()},
18+
output_shape_{output.shape()},
19+
output_strides_{output.strides()},
20+
output_type_{output.dtype()},
21+
total_weight_shape_{total_weight.shape()},
22+
total_weight_strides_{total_weight.strides()},
23+
total_weight_type_{total_weight.dtype()},
24+
reduction_{reduction},
25+
ignore_index_{ignore_index},
26+
device_index_{output.device().index()} {}
27+
28+
virtual void operator()(const Tensor input, const Tensor target, const int64_t reduction, const int64_t ignore_index, Tensor output, Tensor total_weight) const = 0;
29+
30+
protected:
31+
Tensor::Shape input_shape_;
32+
33+
Tensor::Strides input_strides_;
34+
35+
DataType input_type_;
36+
37+
Tensor::Shape target_shape_;
38+
39+
Tensor::Strides target_strides_;
40+
41+
DataType target_type_;
42+
43+
Tensor::Shape output_shape_;
44+
45+
Tensor::Strides output_strides_;
46+
47+
DataType output_type_;
48+
49+
Tensor::Shape total_weight_shape_;
50+
51+
Tensor::Strides total_weight_strides_;
52+
53+
DataType total_weight_type_;
54+
55+
int64_t reduction_{};
56+
57+
int64_t ignore_index_{};
58+
59+
int device_index_{0};
60+
};
61+
62+
} // namespace infini::ops
63+
64+
#endif

0 commit comments

Comments
 (0)