Skip to content

Commit 3efc214

Browse files
voltjia and wooway777
authored and committed
使用 ninetoothed::Tensor 接入九齿的 ReLU 算子
1 parent d1821a5 commit 3efc214

2 files changed

Lines changed: 10 additions & 32 deletions

File tree

src/infiniop/ops/relu/metax/relu_metax.maca

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
#include "../../../../../build/ninetoothed/relu.h"
44
#include "../../../devices/metax/metax_common.h"
5+
#include "../../../ninetoothed/utils.h"
56
#include "relu_metax.h"
67

78
namespace op::relu::metax {
@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate(
4243
}
4344

4445
const auto &ndim{_info.getNdim()};
45-
const auto &x_shape_{_info.getInputShape(0)};
46-
const auto &x_strides_{_info.getInputStrides(0)};
47-
std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim};
48-
std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
49-
auto x_data{const_cast<void *>(inputs[0])};
50-
auto x_shape{x_shape_vec.data()};
51-
auto x_strides{x_strides_vec.data()};
52-
const NineToothedTensor x{x_data, x_shape, x_strides};
53-
const auto &y_shape_{_info.getOutputShape()};
54-
const auto &y_strides_{_info.getOutputStrides()};
55-
std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
56-
std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
57-
auto y_data{output};
58-
auto y_shape{y_shape_vec.data()};
59-
auto y_strides{y_strides_vec.data()};
60-
const NineToothedTensor y{y_data, y_shape, y_strides};
46+
47+
auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
48+
auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
49+
6150
constexpr auto block_size{1024};
6251

6352
switch (_dtype) {

src/infiniop/ops/relu/nvidia/relu_nvidia.cu

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#ifdef ENABLE_NINETOOTHED
22
#include "../../../../../build/ninetoothed/relu.h"
3+
#include "../../../ninetoothed/utils.h"
34
#endif
45
#include "../../../devices/nvidia/nvidia_common.cuh"
56
#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
@@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate(
4546
}
4647
#ifdef ENABLE_NINETOOTHED
4748
const auto &ndim{_info.getNdim()};
48-
const auto &x_shape_{_info.getInputShape(0)};
49-
const auto &x_strides_{_info.getInputStrides(0)};
50-
std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim};
51-
std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
52-
auto x_data{const_cast<void *>(inputs[0])};
53-
auto x_shape{x_shape_vec.data()};
54-
auto x_strides{x_strides_vec.data()};
55-
const NineToothedTensor x{x_data, x_shape, x_strides};
56-
const auto &y_shape_{_info.getOutputShape()};
57-
const auto &y_strides_{_info.getOutputStrides()};
58-
std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
59-
std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
60-
auto y_data{output};
61-
auto y_shape{y_shape_vec.data()};
62-
auto y_strides{y_strides_vec.data()};
63-
const NineToothedTensor y{y_data, y_shape, y_strides};
49+
50+
auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
51+
auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
52+
6453
constexpr auto block_size{1024};
6554

6655
switch (_dtype) {

0 commit comments

Comments
 (0)