Skip to content

Commit f14756f

Browse files
tianyuxbear authored and gongchensu committed
issue/456/feat: add silu operator
1 parent 3959c94 commit f14756f

15 files changed

Lines changed: 725 additions & 1 deletion

File tree

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "infiniop/ops/relu.h"
1616
#include "infiniop/ops/rms_norm.h"
1717
#include "infiniop/ops/rope.h"
18+
#include "infiniop/ops/silu.h"
1819
#include "infiniop/ops/softplus.h"
1920
#include "infiniop/ops/sub.h"
2021
#include "infiniop/ops/swiglu.h"

include/infiniop/ops/silu.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef __INFINIOP_SILU_API_H__
#define __INFINIOP_SILU_API_H__

#include "../operator_descriptor.h"

/** Opaque handle to a SiLU operator descriptor. */
typedef struct InfiniopDescriptor *infiniopSiluDescriptor_t;

/**
 * Creates a SiLU operator descriptor.
 *
 * @param handle   Library handle the descriptor is created against.
 * @param desc_ptr Receives the new descriptor on success.
 * @param output   Descriptor of the tensor the result is written to.
 * @param input    Descriptor of the tensor the operator reads.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopCreateSiluDescriptor(infiniopHandle_t handle,
                                                         infiniopSiluDescriptor_t *desc_ptr,
                                                         infiniopTensorDescriptor_t output,
                                                         infiniopTensorDescriptor_t input);

/**
 * Queries the workspace size the descriptor needs.
 *
 * @param desc Descriptor returned by infiniopCreateSiluDescriptor.
 * @param size Receives the required workspace size (presumably in bytes -- confirm).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopGetSiluWorkspaceSize(infiniopSiluDescriptor_t desc, size_t *size);

/**
 * Runs the SiLU operator.
 *
 * @param desc           Descriptor returned by infiniopCreateSiluDescriptor.
 * @param workspace      Scratch buffer handed to the backend.
 * @param workspace_size Size of the buffer behind `workspace`.
 * @param output         Destination data pointer.
 * @param input          Source data pointer.
 * @param stream         Backend stream/queue to run on; the test driver passes nullptr.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopSilu(infiniopSiluDescriptor_t desc,
                                         void *workspace,
                                         size_t workspace_size,
                                         void *output,
                                         const void *input,
                                         void *stream);

/**
 * Destroys a descriptor created by infiniopCreateSiluDescriptor.
 *
 * @param desc Descriptor to destroy.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopDestroySiluDescriptor(infiniopSiluDescriptor_t desc);

#endif

src/infiniop-test/include/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ DECLARE_INFINIOP_TEST(swiglu)
1515
DECLARE_INFINIOP_TEST(add)
1616
DECLARE_INFINIOP_TEST(causal_softmax)
1717
DECLARE_INFINIOP_TEST(rearrange)
18+
DECLARE_INFINIOP_TEST(silu)
1819
DECLARE_INFINIOP_TEST(sub)
1920

2021
#define REGISTER_INFINIOP_TEST(name) \
@@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub)
4344
REGISTER_INFINIOP_TEST(causal_softmax) \
4445
REGISTER_INFINIOP_TEST(rearrange) \
4546
REGISTER_INFINIOP_TEST(sub) \
47+
REGISTER_INFINIOP_TEST(silu) \
4648
}
4749

4850
namespace infiniop_test {

src/infiniop-test/src/ops/silu.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#include "ops.hpp"
2+
#include "utils.hpp"
3+
#include <infinirt.h>
4+
#include <iomanip>
5+
#include <iostream>
6+
7+
namespace infiniop_test::silu {
8+
// Tensors a single SiLU test case works with.
struct Test::Attributes {
    using TensorPtr = std::shared_ptr<Tensor>;

    TensorPtr input;  // tensor fed to the operator
    TensorPtr output; // tensor the operator writes into
    TensorPtr ans;    // reference result the output is compared against
};
13+
14+
std::shared_ptr<Test> Test::build(
15+
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
16+
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
17+
double rtol, double atol) {
18+
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
19+
test->_attributes = new Attributes();
20+
if (tensors.find("input") == tensors.end()
21+
|| tensors.find("output") == tensors.end()
22+
|| tensors.find("ans") == tensors.end()) {
23+
throw std::runtime_error("Invalid Test");
24+
}
25+
26+
test->_attributes->input = tensors["input"];
27+
test->_attributes->output = tensors["output"];
28+
test->_attributes->ans = tensors["ans"];
29+
30+
return test;
31+
}
32+
33+
// Runs the SiLU operator once, checks the result against `ans`, then
// benchmarks it.
//
// Steps: move input/output to the target device, create the descriptor, query
// and allocate the workspace, execute once, compare with allClose, and time
// `iterations` runs after `warm_ups` warm-up runs.
//
// Returns TEST_FAILED(...) with the matching failure code if any step fails,
// otherwise TEST_PASSED(elapsed_time).
std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
    infiniopSiluDescriptor_t op_desc = nullptr;
    void *workspace = nullptr;

    // Scope guard: releases the workspace and the descriptor on every return
    // path (including the early TEST_FAILED returns), so repeated test runs do
    // not leak device memory or descriptors. It holds references, so it sees
    // whatever the create/malloc calls below store into the two variables.
    // NOTE(review): relies on a failed create leaving op_desc as nullptr.
    struct Releaser {
        infiniopSiluDescriptor_t &desc;
        void *&buffer;
        ~Releaser() {
            if (buffer != nullptr) {
                infinirtFree(buffer);
            }
            if (desc != nullptr) {
                infiniopDestroySiluDescriptor(desc);
            }
        }
    } releaser{op_desc, workspace};

    auto input = _attributes->input->to(device, device_id);
    auto output = _attributes->output->to(device, device_id);
    CHECK_OR(infiniopCreateSiluDescriptor(handle, &op_desc,
                                          output->desc(),
                                          input->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));

    // Initialised so a backend that does not write the size cannot hand an
    // indeterminate value to infinirtMalloc.
    size_t workspace_size = 0;
    CHECK_OR(infiniopGetSiluWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
    CHECK_OR(infinirtMalloc(&workspace, workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
    CHECK_OR(infiniopSilu(op_desc, workspace, workspace_size,
                          output->data(),
                          input->data(),
                          nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));

    // allClose reports a mismatch by throwing; turn that into a failed result.
    try {
        allClose(output, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    double elapsed_time = 0.;

    elapsed_time = benchmark(
        [=]() {
            infiniopSilu(
                op_desc, workspace, workspace_size,
                output->data(),
                input->data(),
                nullptr);
        },
        warm_ups, iterations);

    return TEST_PASSED(elapsed_time);
}
74+
75+
// The SiLU test takes no scalar attributes, so the list is empty.
std::vector<std::string> Test::attribute_names() {
    return std::vector<std::string>{};
}
78+
79+
// Names of the tensors a SiLU test case is built from (see Test::build).
std::vector<std::string> Test::tensor_names() {
    std::vector<std::string> names;
    names.reserve(3);
    names.emplace_back("input");
    names.emplace_back("output");
    names.emplace_back("ans");
    return names;
}
82+
83+
// Names of the tensors the operator writes: only "output".
std::vector<std::string> Test::output_names() {
    return std::vector<std::string>(1, "output");
}
86+
87+
// Renders a multi-line description of the case: operator name, input and
// output tensor info, and the tolerances in scientific notation with two
// digits of precision.
std::string Test::toString() const {
    std::ostringstream text;
    text << op_name() << '\n'
         << "- input: " << _attributes->input->info() << '\n'
         << "- output: " << _attributes->output->info() << '\n'
         << std::scientific << std::setprecision(2)
         << "- rtol=" << _rtol << ", atol=" << _atol << '\n';
    return text.str();
}
96+
97+
// Frees the Attributes object allocated with `new` in Test::build.
Test::~Test() { delete _attributes; }
100+
101+
} // namespace infiniop_test::silu
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "silu_cpu.h"
2+
3+
namespace op::silu::cpu {
4+
5+
// Defaulted destructor: the CPU SiLU descriptor adds no cleanup of its own.
Descriptor::~Descriptor() = default;
6+
7+
// Validates the tensor descriptors and builds the CPU SiLU descriptor.
//
// handle_        : library handle, reinterpreted as device::cpu::Handle.
// desc_ptr       : output location for the new descriptor (presumably written
//                  by CREATE_ELEMENTWISE_CPU_DESCRIPTOR -- the macro is defined
//                  elsewhere).
// out_desc       : descriptor of the output tensor; its dtype selects the kernel.
// input_desc_vec : input descriptors; only element 0 is inspected here.
//
// Returns INFINI_STATUS_SUCCESS on success. Returns
// INFINI_STATUS_BAD_TENSOR_DTYPE when the input dtype differs from the output
// dtype; the CHECK_* macros and the descriptor macro may return early with
// their own status codes.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &input_desc = input_desc_vec.at(0);
    const auto &output_shape = out_desc->shape();
    const auto &input_shape = input_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // calculate() instantiates the kernel on the output dtype only, so an input
    // of another dtype would be read with the wrong element type.
    if (input_desc->dtype() != dtype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    CHECK_SAME_SHAPE(output_shape, input_shape);

    // create CPU elementwise descriptor
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}
29+
30+
// Runs SiLU on the CPU by dispatching on the descriptor's dtype to the
// elementwise implementation instantiated with SiluOp.
//
// `workspace` and `workspace_size` are accepted for interface uniformity but
// are not used by this backend. Returns the status of the elementwise call, or
// INFINI_STATUS_BAD_TENSOR_DTYPE for a dtype outside BF16/F16/F32/F64.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    switch (_dtype) {
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<SiluOp, bf16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<SiluOp, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<SiluOp, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<SiluOp, double>(_info, output, inputs, stream);
    default:
        break;
    }

    // Only reached for a dtype none of the cases above handle.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
52+
} // namespace op::silu::cpu
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef __SILU_CPU_H__
2+
#define __SILU_CPU_H__
3+
4+
#include "../../../elementwise/cpu/elementwise_cpu.h"
5+
6+
ELEMENTWISE_DESCRIPTOR(silu, cpu)
7+
8+
#include <cmath>
9+
10+
namespace op::silu::cpu {
11+
// Element-wise SiLU functor for the CPU backend: x / (1 + exp(-x)).
struct SiluOp {
    // The operator takes exactly one input tensor.
    static constexpr size_t num_inputs = 1;

    // Applies SiLU to one element. T must support unary minus, std::exp,
    // addition and division.
    template <typename T>
    T operator()(const T &x) const {
        const auto denominator = static_cast<T>(1) + std::exp(-x);
        return x / denominator;
    }
};
20+
21+
} // namespace op::silu::cpu
22+
23+
#endif // __SILU_CPU_H__
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifndef __SILU_CUDA_H__
2+
#define __SILU_CUDA_H__
3+
4+
#include <cmath>
5+
6+
namespace op::silu::cuda {
7+
8+
typedef struct SiluOp {
9+
public:
10+
static constexpr size_t num_inputs = 1;
11+
template <typename T>
12+
__device__ __forceinline__ T operator()(const T &x) const {
13+
if constexpr (std::is_same_v<T, half2>) {
14+
// half2向量化优化
15+
return __hmul2(x, __h2div(__float2half2_rn(1.0f),
16+
__hadd2(__float2half2_rn(1.0f), h2exp(__hneg2(x)))));
17+
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
18+
// BF16
19+
const float x_f = __bfloat162float(x);
20+
return __float2bfloat16(x_f / (1.0f + __expf(-x_f)));
21+
} else if constexpr (std::is_same_v<T, half>) {
22+
// FP16
23+
const float x_f = __half2float(x);
24+
return __float2half(x_f / (1.0f + __expf(-x_f)));
25+
} else if constexpr (std::is_same_v<T, float>) {
26+
// FP32
27+
return x * (1.0f / (1.0f + __expf(-x)));
28+
} else if constexpr (std::is_same_v<T, double>) {
29+
// FP64
30+
return x / (1.0 + exp(-x));
31+
}
32+
}
33+
} SiluOp;
34+
35+
} // namespace op::silu::cuda
36+
37+
#endif // __SILU_CUDA_H__
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __SILU_METAX_API_H__
2+
#define __SILU_METAX_API_H__
3+
4+
#include "../../../elementwise/metax/elementwise_metax_api.h"
5+
6+
ELEMENTWISE_DESCRIPTOR(silu, metax)
7+
8+
#endif // __SILU_METAX_API_H__
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "silu_metax.h"
2+
3+
#include "../../../elementwise/metax/elementwise_metax.h"
4+
5+
#include "../cuda/kernel.cuh"
6+
7+
namespace op::silu::metax {
8+
9+
// Defaulted destructor: the MetaX SiLU descriptor adds no cleanup of its own.
Descriptor::~Descriptor() = default;
10+
11+
// Validates the tensor descriptors and builds the MetaX SiLU descriptor.
//
// handle_        : library handle, reinterpreted as device::metax::Handle.
// desc_ptr       : output location for the new descriptor (presumably written
//                  by CREATE_ELEMENTWISE_METAX_DESCRIPTOR -- the macro is
//                  defined elsewhere).
// out_desc       : descriptor of the output tensor; its dtype selects the kernel.
// input_desc_vec : input descriptors; only element 0 is inspected here.
//
// Returns INFINI_STATUS_SUCCESS on success. Returns
// INFINI_STATUS_BAD_TENSOR_DTYPE when the input dtype differs from the output
// dtype; the CHECK_* macros and the descriptor macro may return early with
// their own status codes.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &input_desc = input_desc_vec.at(0);
    const auto &output_shape = out_desc->shape();
    const auto &input_shape = input_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // calculate() instantiates the kernel on the output dtype only, so an input
    // of another dtype would be read with the wrong element type.
    if (input_desc->dtype() != dtype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    CHECK_SAME_SHAPE(output_shape, input_shape);

    // create METAX elementwise descriptor
    CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}
33+
34+
// Runs SiLU on a MetaX device by dispatching on the descriptor's dtype to the
// elementwise implementation instantiated with cuda::SiluOp.
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when the caller's workspace is
// smaller than the descriptor requires, INFINI_STATUS_BAD_TENSOR_DTYPE for a
// dtype outside BF16/F16/F32/F64, and otherwise the status of the elementwise
// call. The leading template argument 256 is presumably the thread-block size
// (confirm against the elementwise implementation).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    if (_workspace_size > workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::SiluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::SiluOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::SiluOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::SiluOp, double>(_info, workspace, output, inputs, stream);
    default:
        break;
    }

    // Only reached for a dtype none of the cases above handle.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
60+
} // namespace op::silu::metax
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
2+
3+
#include "../cuda/kernel.cuh"
4+
#include "silu_nvidia.cuh"
5+
6+
namespace op::silu::nvidia {
7+
8+
// Defaulted destructor: the NVIDIA SiLU descriptor adds no cleanup of its own.
Descriptor::~Descriptor() = default;
9+
10+
// Validates the tensor descriptors and builds the NVIDIA SiLU descriptor.
//
// handle_        : library handle, reinterpreted as device::nvidia::Handle.
// desc_ptr       : output location for the new descriptor (presumably written
//                  by CREATE_ELEMENTWISE_CUDA_DESCRIPTOR -- the macro is
//                  defined elsewhere).
// out_desc       : descriptor of the output tensor; its dtype selects the kernel.
// input_desc_vec : input descriptors; only element 0 is inspected here.
//
// Returns INFINI_STATUS_SUCCESS on success. Returns
// INFINI_STATUS_BAD_TENSOR_DTYPE when the input dtype differs from the output
// dtype; the CHECK_* macros and the descriptor macro may return early with
// their own status codes.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &input_desc = input_desc_vec.at(0);
    const auto &output_shape = out_desc->shape();
    const auto &input_shape = input_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // calculate() instantiates the kernel on the output dtype only, so an input
    // of another dtype would be read with the wrong element type.
    if (input_desc->dtype() != dtype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    CHECK_SAME_SHAPE(output_shape, input_shape);

    // create CUDA elementwise descriptor
    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}
32+
33+
// Runs SiLU on an NVIDIA device by dispatching on the descriptor's dtype to
// the elementwise implementation instantiated with cuda::SiluOp.
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when the caller's workspace is
// smaller than the descriptor requires, INFINI_STATUS_BAD_TENSOR_DTYPE for a
// dtype outside BF16/F16/F32/F64, and otherwise the status of the elementwise
// call. The leading template argument 256 is presumably the thread-block size
// (confirm against the elementwise implementation).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    if (_workspace_size > workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::SiluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::SiluOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::SiluOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::SiluOp, double>(_info, workspace, output, inputs, stream);
    default:
        break;
    }

    // Only reached for a dtype none of the cases above handle.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
59+
} // namespace op::silu::nvidia

0 commit comments

Comments
 (0)