11#include "equal_metax.h"
22
33#include "../../../elementwise/metax/elementwise_metax.h"
4- #include <type_traits>
54
6- namespace op::equal::metax {
7-
8- struct EqualOp {
9- static constexpr size_t num_inputs = 2;
5+ #include "../cuda/kernel.cuh"
106
11- template <typename Tout, typename Tin0, typename Tin1>
12- __device__ __forceinline__ bool operator()(const Tin0 &a, const Tin1 &b) const {
13- if constexpr (std::is_same_v<Tin0, Tin1>) {
14- return static_cast<Tout>(a == b);
15- } else {
16- return false;
17- }
18- }
19- };
7+ namespace op::equal::metax {
208
219Descriptor::~Descriptor() = default;
2210
@@ -25,54 +13,44 @@ infiniStatus_t Descriptor::create(
2513 Descriptor **desc_ptr,
2614 infiniopTensorDescriptor_t out_desc,
2715 std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
28-
2916 auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
30-
3117 const auto &a_desc = input_desc_vec.at(0);
3218 auto compute_dtype = a_desc->dtype();
3319 auto out_dtype = out_desc->dtype();
34-
3520 const auto &b_desc = input_desc_vec.at(1);
3621 const auto &c_shape = out_desc->shape();
3722 const auto &a_shape = a_desc->shape();
3823 const auto &b_shape = b_desc->shape();
39-
4024 CHECK_DTYPE(compute_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16,
4125 INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_F64);
42-
4326 CHECK_DTYPE(out_dtype, INFINI_DTYPE_BOOL);
44-
4527 CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
46-
4728 CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, compute_dtype, out_desc, input_desc_vec)
48-
4929 return INFINI_STATUS_SUCCESS;
5030}
51-
5231infiniStatus_t Descriptor::calculate(
5332 void *workspace,
5433 size_t workspace_size,
5534 void *output,
5635 std::vector<const void *> inputs,
5736 void *stream) const {
58-
5937 if (workspace_size < _workspace_size) {
6038 return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
6139 }
6240
6341 switch (_dtype) {
6442 case INFINI_DTYPE_F16:
65- return _device_info->calculate<256, EqualOp, bool, half, half>(_info, workspace, output, inputs, stream);
43+ return _device_info->calculate<256, cuda:: EqualOp, bool, half, half>(_info, workspace, output, inputs, stream);
6644 case INFINI_DTYPE_BF16:
67- return _device_info->calculate<256, EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream);
45+ return _device_info->calculate<256, cuda:: EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream);
6846 case INFINI_DTYPE_F32:
69- return _device_info->calculate<256, EqualOp, bool, float, float>(_info, workspace, output, inputs, stream);
47+ return _device_info->calculate<256, cuda:: EqualOp, bool, float, float>(_info, workspace, output, inputs, stream);
7048 case INFINI_DTYPE_I32:
71- return _device_info->calculate<256, EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream);
49+ return _device_info->calculate<256, cuda:: EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream);
7250 case INFINI_DTYPE_I64:
73- return _device_info->calculate<256, EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream);
51+ return _device_info->calculate<256, cuda:: EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream);
7452 case INFINI_DTYPE_F64:
75- return _device_info->calculate<256, EqualOp, bool, double, double>(_info, workspace, output, inputs, stream);
53+ return _device_info->calculate<256, cuda:: EqualOp, bool, double, double>(_info, workspace, output, inputs, stream);
7654 default:
7755 return INFINI_STATUS_BAD_TENSOR_DTYPE;
7856 }
0 commit comments