|
1 | 1 | #ifdef ENABLE_NINETOOTHED |
2 | 2 | #include "../../../../../build/ninetoothed/relu.h" |
| 3 | +#include "../../../ninetoothed/utils.h" |
3 | 4 | #endif |
4 | 5 | #include "../../../devices/nvidia/nvidia_common.cuh" |
5 | 6 | #include "../../../elementwise/nvidia/elementwise_nvidia.cuh" |
@@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate( |
45 | 46 | } |
46 | 47 | #ifdef ENABLE_NINETOOTHED |
47 | 48 | const auto &ndim{_info.getNdim()}; |
48 | | - const auto &x_shape_{_info.getInputShape(0)}; |
49 | | - const auto &x_strides_{_info.getInputStrides(0)}; |
50 | | - std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim}; |
51 | | - std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim}; |
52 | | - auto x_data{const_cast<void *>(inputs[0])}; |
53 | | - auto x_shape{x_shape_vec.data()}; |
54 | | - auto x_strides{x_strides_vec.data()}; |
55 | | - const NineToothedTensor x{x_data, x_shape, x_strides}; |
56 | | - const auto &y_shape_{_info.getOutputShape()}; |
57 | | - const auto &y_strides_{_info.getOutputStrides()}; |
58 | | - std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim}; |
59 | | - std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim}; |
60 | | - auto y_data{output}; |
61 | | - auto y_shape{y_shape_vec.data()}; |
62 | | - auto y_strides{y_strides_vec.data()}; |
63 | | - const NineToothedTensor y{y_data, y_shape, y_strides}; |
| 49 | + |
| 50 | + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; |
| 51 | + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; |
| 52 | + |
64 | 53 | constexpr auto block_size{1024}; |
65 | 54 |
|
66 | 55 | switch (_dtype) { |
|
0 commit comments