Skip to content

Commit 1a27b48

Browse files
committed
issue/931 - ninetoothed swiglu
1 parent 6f8a443 commit 1a27b48

4 files changed

Lines changed: 172 additions & 0 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import ninetoothed
2+
import swiglu
3+
4+
import infiniop.ninetoothed.build
5+
6+
7+
def build():
8+
MAX_NDIM = 5
9+
10+
ndim_values = range(1, MAX_NDIM + 1)
11+
dtype_values = (
12+
ninetoothed.float16,
13+
ninetoothed.bfloat16,
14+
ninetoothed.float32,
15+
)
16+
17+
constexpr_param_grid = {
18+
"ndim": ndim_values,
19+
"dtype": dtype_values,
20+
"block_size": (1024,),
21+
}
22+
23+
infiniop.ninetoothed.build.build(
24+
swiglu.premake,
25+
constexpr_param_grid,
26+
caller="cuda",
27+
op_name="swiglu",
28+
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
29+
)
30+
31+
32+
build()
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#ifndef SWIGLU_H
2+
#define SWIGLU_H
3+
4+
#include "../../../handle.h"
5+
#include "../../../operator.h"
6+
#include "../../../tensor.h"
7+
8+
#include "../../../../../build/ninetoothed/swiglu.h"
9+
#include "../../../ninetoothed/utils.h"
10+
11+
namespace op::swiglu::ninetoothed {
12+
class Descriptor final : public InfiniopDescriptor {
13+
14+
public:
15+
Descriptor(
16+
infiniopHandle_t handle,
17+
infiniopTensorDescriptor_t out_desc,
18+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id},
19+
out_shape_{out_desc->shape()},
20+
out_strides_{out_desc->strides()},
21+
up_shape_{input_desc_vec[0]->shape()},
22+
up_strides_{input_desc_vec[0]->strides()},
23+
gate_shape_{input_desc_vec[1]->shape()},
24+
gate_strides_{input_desc_vec[1]->strides()},
25+
dtype_{out_desc->dtype()} {}
26+
27+
~Descriptor() = default;
28+
29+
size_t workspaceSize() const {
30+
return 0;
31+
}
32+
33+
static infiniStatus_t create(
34+
infiniopHandle_t handle,
35+
Descriptor **desc_ptr,
36+
infiniopTensorDescriptor_t out_desc,
37+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
38+
*desc_ptr = new Descriptor(handle, out_desc, input_desc_vec);
39+
return INFINI_STATUS_SUCCESS;
40+
}
41+
42+
infiniStatus_t calculate(
43+
void *workspace,
44+
size_t workspace_size,
45+
void *output,
46+
std::vector<const void *> inputs,
47+
void *stream) const {
48+
auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)};
49+
auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)};
50+
auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)};
51+
52+
if (launch_swiglu(stream,
53+
out_nt,
54+
up_nt,
55+
gate_nt,
56+
out_shape_.size(),
57+
dtype_,
58+
1024)) {
59+
return INFINI_STATUS_NOT_IMPLEMENTED;
60+
}
61+
62+
return INFINI_STATUS_SUCCESS;
63+
}
64+
65+
private:
66+
using Size = ::ninetoothed::Tensor<>::Size;
67+
using Stride = ::ninetoothed::Tensor<>::Stride;
68+
69+
std::vector<Size> out_shape_;
70+
std::vector<Stride> out_strides_;
71+
72+
std::vector<Size> up_shape_;
73+
std::vector<Stride> up_strides_;
74+
75+
std::vector<Size> gate_shape_;
76+
std::vector<Stride> gate_strides_;
77+
78+
infiniDtype_t dtype_;
79+
};
80+
} // namespace op::swiglu::ninetoothed
81+
82+
#endif // SWIGLU_H
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import functools
2+
3+
import ninetoothed.language as ntl
4+
from ninetoothed import Tensor
5+
6+
from ntops.kernels.element_wise import arrangement
7+
8+
9+
def application(output, up, gate):
10+
output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up # noqa: F841
11+
12+
13+
def premake(ndim, dtype=None, block_size=None):
14+
arrangement_ = functools.partial(arrangement, block_size=block_size)
15+
16+
tensors = (
17+
Tensor(ndim, dtype=dtype),
18+
Tensor(ndim, dtype=dtype),
19+
Tensor(ndim, dtype=dtype),
20+
)
21+
22+
return arrangement_, application, tensors

src/infiniop/ops/swiglu/operator.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
#include "cpu/swiglu_cpu.h"
77
#endif
88
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
9+
#if defined(ENABLE_NINETOOTHED)
10+
#include "ninetoothed/swiglu.h"
11+
#else
912
#include "nvidia/swiglu_nvidia.cuh"
1013
#endif
14+
#endif
1115
#ifdef ENABLE_KUNLUN_API
1216
#include "kunlun/swiglu_kunlun.h"
1317
#endif
@@ -46,11 +50,19 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
4650
CREATE(INFINI_DEVICE_CPU, cpu);
4751
#endif
4852
#ifdef ENABLE_NVIDIA_API
53+
#ifdef ENABLE_NINETOOTHED
54+
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
55+
#else
4956
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
5057
#endif
58+
#endif
5159
#ifdef ENABLE_ILUVATAR_API
60+
#ifdef ENABLE_NINETOOTHED
61+
CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
62+
#else
5263
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
5364
#endif
65+
#endif
5466
#ifdef ENABLE_QY_API
5567
CREATE(INFINI_DEVICE_QY, nvidia);
5668
#endif
@@ -92,11 +104,19 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
92104
GET(INFINI_DEVICE_CPU, cpu);
93105
#endif
94106
#ifdef ENABLE_NVIDIA_API
107+
#ifdef ENABLE_NINETOOTHED
108+
GET(INFINI_DEVICE_NVIDIA, ninetoothed);
109+
#else
95110
GET(INFINI_DEVICE_NVIDIA, nvidia);
96111
#endif
112+
#endif
97113
#ifdef ENABLE_ILUVATAR_API
114+
#ifdef ENABLE_NINETOOTHED
115+
GET(INFINI_DEVICE_ILUVATAR, ninetoothed);
116+
#else
98117
GET(INFINI_DEVICE_ILUVATAR, nvidia);
99118
#endif
119+
#endif
100120
#ifdef ENABLE_QY_API
101121
GET(INFINI_DEVICE_QY, nvidia);
102122
#endif
@@ -145,11 +165,19 @@ __C infiniStatus_t infiniopSwiGLU(
145165
CALCULATE(INFINI_DEVICE_CPU, cpu);
146166
#endif
147167
#ifdef ENABLE_NVIDIA_API
168+
#ifdef ENABLE_NINETOOTHED
169+
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
170+
#else
148171
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
149172
#endif
173+
#endif
150174
#ifdef ENABLE_ILUVATAR_API
175+
#ifdef ENABLE_NINETOOTHED
176+
CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
177+
#else
151178
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
152179
#endif
180+
#endif
153181
#ifdef ENABLE_QY_API
154182
CALCULATE(INFINI_DEVICE_QY, nvidia);
155183
#endif
@@ -193,11 +221,19 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
193221
DELETE(INFINI_DEVICE_CPU, cpu);
194222
#endif
195223
#ifdef ENABLE_NVIDIA_API
224+
#ifdef ENABLE_NINETOOTHED
225+
DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
226+
#else
196227
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
197228
#endif
229+
#endif
198230
#ifdef ENABLE_ILUVATAR_API
231+
#ifdef ENABLE_NINETOOTHED
232+
DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
233+
#else
199234
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
200235
#endif
236+
#endif
201237
#ifdef ENABLE_QY_API
202238
DELETE(INFINI_DEVICE_QY, nvidia);
203239
#endif

0 commit comments

Comments
 (0)