Skip to content

Commit 607361a

Browse files
committed
feat: add cambricon swiglu op
1 parent a334495 commit 607361a

File tree

2 files changed

+278
-0
lines changed

2 files changed

+278
-0
lines changed

src/cambricon/swiglu/kernel.mlu

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
#include "swiglu.h"
2+
3+
// Per-core on-chip NRAM scratch buffer; carved into input/gate/output tiles
// by SwigluKernel below. NRAM_MAX_SIZE is provided by the included header.
__nram__ char nram_buffer[NRAM_MAX_SIZE];
4+
5+
namespace infini::ops {
6+
7+
// Computes SwiGLU on `n` elements already staged in NRAM:
//   output[i] = input[i] * gate[i] * sigmoid(gate[i])
// float  : scalar loop using expf (no vector intrinsic used on this path).
// __half : vectorized via __bang_active_sigmoid / __bang_mul on `half` views.
// other T: same vector sequence directly on T (used for __bang_bfloat16).
// NOTE(review): the __bang_* calls are issued with arbitrary `n` (the caller's
// tail chunk may not be a multiple of 64) — confirm the BANG ops accept
// unaligned element counts on the target arch.
template <typename T>
__mlu_device__ void ComputeSwiglu(const T *input, const T *gate, T *output,
                                  size_t n) {
  if constexpr (std::is_same_v<T, float>) {
    // Scalar SiLU-gated product: x * g * sigmoid(g) written as x * g / (1 + e^-g).
    for (size_t i = 0; i < n; ++i) {
      float g = gate[i];
      output[i] = input[i] * g / (1.0f + expf(-g));
    }
  } else if constexpr (std::is_same_v<T, __half>) {
    // Reinterpret as the BANG `half` type expected by the vector intrinsics.
    auto *out_h = reinterpret_cast<half *>(output);
    auto *in_h = reinterpret_cast<const half *>(input);
    auto *gate_h = reinterpret_cast<const half *>(gate);
    // out = sigmoid(gate); out *= gate; out *= input  (in-place accumulation).
    __bang_active_sigmoid(out_h, gate_h, n);
    __bang_mul(out_h, out_h, gate_h, n);
    __bang_mul(out_h, out_h, in_h, n);
  } else {
    // Generic vector path (e.g. bfloat16): same three-step in-place sequence.
    __bang_active_sigmoid(output, gate, n);
    __bang_mul(output, output, gate, n);
    __bang_mul(output, output, input, n);
  }
}
28+
29+
// Entry kernel: each task (taskId in [0, taskDim)) processes a contiguous
// range of flattened output indices, tiling through NRAM in `block_size`
// chunks. Two paths:
//   fast_path      : all tensors contiguous with identical shapes — bulk
//                    GDRAM<->NRAM memcpys.
//   general path   : per-element strided gathers that also implement
//                    broadcasting (a dim of size 1 contributes coordinate 0).
// Shape/stride arrays live in device memory (staged by SwigluUnion) and have
// `ndim` entries each, ordered outermost-first.
template <typename T>
__mlu_global__ void SwigluKernel(const T *input, const T *gate, T *output,
                                 const size_t *out_shape,
                                 const ptrdiff_t *out_strides,
                                 const size_t *input_shape,
                                 const ptrdiff_t *input_strides,
                                 const size_t *gate_shape,
                                 const ptrdiff_t *gate_strides,
                                 size_t output_size, int ndim, bool fast_path,
                                 bool out_contiguous) {
  // Even static split of the flat output range across tasks; the last task's
  // range is clamped, and over-provisioned tasks exit early.
  size_t elements_per_task = (output_size + taskDim - 1) / taskDim;
  size_t start = taskId * elements_per_task;
  size_t end = start + elements_per_task;
  if (end > output_size) end = output_size;
  size_t num_elements = end > start ? end - start : 0;
  if (num_elements == 0) return;

  // Reserve a small safety margin, then split the rest three ways
  // (input/gate/output tiles), rounded down to a 64-element multiple.
  size_t nram_usable = NRAM_MAX_SIZE - 256;
  size_t block_size = nram_usable / (3 * sizeof(T));
  block_size = (block_size / 64) * 64;
  // NOTE(review): this fallback can overflow NRAM if 3*64*sizeof(T) exceeds
  // nram_usable — presumably unreachable on real hardware; confirm.
  if (block_size == 0) block_size = 64;

  // Carve the shared NRAM buffer into three equal tiles.
  T *input_buf = reinterpret_cast<T *>(nram_buffer);
  T *gate_buf = input_buf + block_size;
  T *output_buf = gate_buf + block_size;

  size_t processed = 0;

  if (fast_path) {
    // Contiguous, same-shape case: flat index == memory offset for all three
    // tensors, so plain block copies suffice.
    while (processed < num_elements) {
      size_t curr = block_size;
      if (curr > num_elements - processed) curr = num_elements - processed;

      __memcpy(input_buf, input + start + processed,
               curr * sizeof(T), GDRAM2NRAM);
      __memcpy(gate_buf, gate + start + processed,
               curr * sizeof(T), GDRAM2NRAM);
      ComputeSwiglu<T>(input_buf, gate_buf, output_buf, curr);
      __memcpy(output + start + processed, output_buf,
               curr * sizeof(T), NRAM2GDRAM);

      processed += curr;
    }
    return;
  }

  // General path: handle non-contiguous tensors and broadcasting.
  while (processed < num_elements) {
    size_t curr = block_size;
    if (curr > num_elements - processed) curr = num_elements - processed;

    // Gather one tile: decode each flat output index into per-dim coordinates
    // (innermost dim first via repeated mod/div), then map through each
    // operand's strides. A source dim smaller than the output dim (broadcast,
    // i.e. size 1) is pinned to coordinate 0.
    for (size_t i = 0; i < curr; ++i) {
      size_t flat_idx = start + processed + i;

      // Compute `input` offset.
      {
        size_t tmp = flat_idx;
        ptrdiff_t offset = 0;
        for (int d = ndim - 1; d >= 0; --d) {
          size_t coord = tmp % out_shape[d];
          tmp /= out_shape[d];
          size_t c = coord < input_shape[d] ? coord : 0;
          offset += static_cast<ptrdiff_t>(c) * input_strides[d];
        }
        input_buf[i] = input[offset];
      }

      // Compute `gate` offset.
      {
        size_t tmp = flat_idx;
        ptrdiff_t offset = 0;
        for (int d = ndim - 1; d >= 0; --d) {
          size_t coord = tmp % out_shape[d];
          tmp /= out_shape[d];
          size_t c = coord < gate_shape[d] ? coord : 0;
          offset += static_cast<ptrdiff_t>(c) * gate_strides[d];
        }
        gate_buf[i] = gate[offset];
      }
    }

    ComputeSwiglu<T>(input_buf, gate_buf, output_buf, curr);

    if (out_contiguous) {
      // Output laid out flat: one bulk store for the whole tile.
      __memcpy(output + start + processed, output_buf,
               curr * sizeof(T), NRAM2GDRAM);
    } else {
      // Strided output: scatter element-by-element via the output strides
      // (no broadcasting on the output side — full coordinates are used).
      for (size_t i = 0; i < curr; ++i) {
        size_t flat_idx = start + processed + i;
        size_t tmp = flat_idx;
        ptrdiff_t offset = 0;
        for (int d = ndim - 1; d >= 0; --d) {
          size_t coord = tmp % out_shape[d];
          offset += static_cast<ptrdiff_t>(coord) * out_strides[d];
          tmp /= out_shape[d];
        }
        output[offset] = output_buf[i];
      }
    }

    processed += curr;
  }
}
132+
133+
template <typename T>
134+
void SwigluUnion(void *workspace, int core_per_cluster, int cluster_count,
135+
cnrtQueue_t queue, void *out, const void *input,
136+
const void *gate, const size_t *out_shape,
137+
const ptrdiff_t *out_strides, const size_t *input_shape,
138+
const ptrdiff_t *input_strides, const size_t *gate_shape,
139+
const ptrdiff_t *gate_strides, size_t output_size, int ndim,
140+
bool fast_path, bool out_contiguous) {
141+
cnrtDim3_t kernel_dim;
142+
cnrtFunctionType_t kernel_type;
143+
144+
kernel_dim.x = core_per_cluster;
145+
kernel_dim.y = cluster_count;
146+
kernel_dim.z = 1;
147+
kernel_type = cnrtFuncTypeUnion1;
148+
149+
auto out_ = reinterpret_cast<T *>(out);
150+
auto input_ = reinterpret_cast<const T *>(input);
151+
auto gate_ = reinterpret_cast<const T *>(gate);
152+
153+
char *tmp = reinterpret_cast<char *>(workspace);
154+
size_t *mlu_out_shape = reinterpret_cast<size_t *>(tmp);
155+
size_t *mlu_input_shape = mlu_out_shape + ndim;
156+
size_t *mlu_gate_shape = mlu_input_shape + ndim;
157+
ptrdiff_t *mlu_out_strides =
158+
reinterpret_cast<ptrdiff_t *>(mlu_gate_shape + ndim);
159+
ptrdiff_t *mlu_input_strides = mlu_out_strides + ndim;
160+
ptrdiff_t *mlu_gate_strides = mlu_input_strides + ndim;
161+
162+
CNRT_CHECK(cnrtMemcpyAsync(mlu_out_shape, const_cast<size_t *>(out_shape),
163+
ndim * sizeof(size_t), queue,
164+
cnrtMemcpyHostToDev));
165+
CNRT_CHECK(cnrtMemcpyAsync(mlu_input_shape, const_cast<size_t *>(input_shape),
166+
ndim * sizeof(size_t), queue,
167+
cnrtMemcpyHostToDev));
168+
CNRT_CHECK(cnrtMemcpyAsync(mlu_gate_shape, const_cast<size_t *>(gate_shape),
169+
ndim * sizeof(size_t), queue,
170+
cnrtMemcpyHostToDev));
171+
CNRT_CHECK(cnrtMemcpyAsync(mlu_out_strides,
172+
const_cast<ptrdiff_t *>(out_strides),
173+
ndim * sizeof(ptrdiff_t), queue,
174+
cnrtMemcpyHostToDev));
175+
CNRT_CHECK(cnrtMemcpyAsync(mlu_input_strides,
176+
const_cast<ptrdiff_t *>(input_strides),
177+
ndim * sizeof(ptrdiff_t), queue,
178+
cnrtMemcpyHostToDev));
179+
CNRT_CHECK(cnrtMemcpyAsync(mlu_gate_strides,
180+
const_cast<ptrdiff_t *>(gate_strides),
181+
ndim * sizeof(ptrdiff_t), queue,
182+
cnrtMemcpyHostToDev));
183+
184+
SwigluKernel<T><<<kernel_dim, kernel_type, queue>>>(
185+
input_, gate_, out_, mlu_out_shape, mlu_out_strides, mlu_input_shape,
186+
mlu_input_strides, mlu_gate_shape, mlu_gate_strides, output_size, ndim,
187+
fast_path, out_contiguous);
188+
189+
cnrtQueueSync(queue);
190+
}
191+
192+
// Explicit instantiations for the three supported element types
// (dispatched from the header: kFloat16, kBFloat16, kFloat32).
template void SwigluUnion<__half>(void *, int, int, cnrtQueue_t, void *,
                                  const void *, const void *, const size_t *,
                                  const ptrdiff_t *, const size_t *,
                                  const ptrdiff_t *, const size_t *,
                                  const ptrdiff_t *, size_t, int, bool, bool);

template void SwigluUnion<__bang_bfloat16>(void *, int, int, cnrtQueue_t, void *,
                                           const void *, const void *,
                                           const size_t *, const ptrdiff_t *,
                                           const size_t *, const ptrdiff_t *,
                                           const size_t *, const ptrdiff_t *,
                                           size_t, int, bool, bool);

template void SwigluUnion<float>(void *, int, int, cnrtQueue_t, void *,
                                 const void *, const void *, const size_t *,
                                 const ptrdiff_t *, const size_t *,
                                 const ptrdiff_t *, const size_t *,
                                 const ptrdiff_t *, size_t, int, bool, bool);
210+
211+
} // namespace infini::ops

src/cambricon/swiglu/swiglu.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#ifndef INFINI_OPS_CAMBRICON_SWIGLU_SWIGLU_H_
2+
#define INFINI_OPS_CAMBRICON_SWIGLU_SWIGLU_H_
3+
4+
#include "cambricon/common.h"
5+
#include "base/swiglu.h"
6+
#include "cambricon/data_type_.h"
7+
8+
namespace infini::ops {
9+
10+
// Launches the Cambricon SwiGLU kernel for element type T.
// `workspace` must hold at least ndim * (3*sizeof(size_t) + 3*sizeof(ptrdiff_t))
// bytes of device memory (used to stage shapes/strides). Shape/stride pointers
// are host arrays of `ndim` entries. `fast_path` selects the contiguous
// same-shape bulk-copy path; `out_contiguous` selects bulk output stores on
// the general path. Blocks until the queue drains. Defined in kernel.mlu;
// instantiated for __half, __bang_bfloat16 and float.
template <typename T>
void SwigluUnion(void *workspace, int core_per_cluster, int cluster_count,
                 cnrtQueue_t queue, void *out, const void *input,
                 const void *gate, const size_t *out_shape,
                 const ptrdiff_t *out_strides, const size_t *input_shape,
                 const ptrdiff_t *input_strides, const size_t *gate_shape,
                 const ptrdiff_t *gate_strides, size_t output_size, int ndim,
                 bool fast_path, bool out_contiguous);
18+
19+
template <>
20+
class Operator<Swiglu, Device::Type::kCambricon> : public Swiglu {
21+
public:
22+
Operator(const Tensor input, const Tensor gate, Tensor out)
23+
: Swiglu{input, gate, out} {
24+
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
25+
&cluster_count);
26+
cnrtMalloc(&default_workspace_, workspace_size_in_bytes());
27+
}
28+
29+
void operator()(const Tensor input, const Tensor gate,
30+
Tensor out) const override {
31+
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
32+
auto workspace{workspace_ ? workspace_ : default_workspace_};
33+
34+
bool fast_path = is_input_contiguous_ && is_gate_contiguous_ &&
35+
is_out_contiguous_ && input_shape_ == out_shape_ &&
36+
gate_shape_ == out_shape_;
37+
38+
DispatchFunc<List<DataType::kFloat16, DataType::kBFloat16,
39+
DataType::kFloat32>>(
40+
{static_cast<int64_t>(out_type_)},
41+
[&](auto tag) {
42+
using T = TypeMapType<Device::Type::kCambricon, ListGet<0>(tag)>;
43+
SwigluUnion<T>(workspace, core_per_cluster, cluster_count, queue,
44+
out.data(), input.data(), gate.data(),
45+
out_shape_.data(), out_strides_.data(),
46+
input_shape_.data(), input_strides_.data(),
47+
gate_shape_.data(), gate_strides_.data(),
48+
output_size_, ndim_, fast_path,
49+
is_out_contiguous_);
50+
},
51+
"CambriconSwiglu::operator() - output dispatch");
52+
}
53+
54+
~Operator() { cnrtFree(default_workspace_); }
55+
56+
std::size_t workspace_size_in_bytes() const override {
57+
return ndim_ * (3 * sizeof(size_t) + 3 * sizeof(ptrdiff_t));
58+
}
59+
60+
void *default_workspace_{nullptr};
61+
int core_per_cluster = 0;
62+
int cluster_count = 0;
63+
};
64+
65+
} // namespace infini::ops
66+
67+
#endif

0 commit comments

Comments
 (0)