Skip to content

Commit d8ea2a7

Browse files
committed
feat(cambricon): add Swiglu op in Cambricon
1 parent c4141be commit d8ea2a7

2 files changed

Lines changed: 271 additions & 0 deletions

File tree

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
#include "swiglu.h"
2+
3+
// Scratch storage declared with the `__nram__` qualifier. `SwigluKernel` carves
// it into three equally sized staging buffers (input, gate, output).
__nram__ char nram_buffer[NRAM_MAX_SIZE];
4+
5+
namespace infini::ops {
6+
7+
template <typename T>
8+
__mlu_device__ void ComputeSwiglu(const T* input, const T* gate, T* output,
9+
size_t n) {
10+
if constexpr (std::is_same_v<T, float>) {
11+
for (size_t i = 0; i < n; ++i) {
12+
float g = gate[i];
13+
output[i] = input[i] * g / (1.0f + expf(-g));
14+
}
15+
} else if constexpr (std::is_same_v<T, __half>) {
16+
auto* out_h = reinterpret_cast<half*>(output);
17+
auto* in_h = reinterpret_cast<const half*>(input);
18+
auto* gate_h = reinterpret_cast<const half*>(gate);
19+
__bang_active_sigmoid(out_h, gate_h, n);
20+
__bang_mul(out_h, out_h, gate_h, n);
21+
__bang_mul(out_h, out_h, in_h, n);
22+
} else {
23+
__bang_active_sigmoid(output, gate, n);
24+
__bang_mul(output, output, gate, n);
25+
__bang_mul(output, output, input, n);
26+
}
27+
}
28+
29+
template <typename T>
30+
__mlu_global__ void SwigluKernel(
31+
const T* input, const T* gate, T* output, const size_t* out_shape,
32+
const ptrdiff_t* out_strides, const size_t* input_shape,
33+
const ptrdiff_t* input_strides, const size_t* gate_shape,
34+
const ptrdiff_t* gate_strides, size_t output_size, int ndim, bool fast_path,
35+
bool out_contiguous) {
36+
size_t elements_per_task = (output_size + taskDim - 1) / taskDim;
37+
size_t start = taskId * elements_per_task;
38+
size_t end = start + elements_per_task;
39+
if (end > output_size) end = output_size;
40+
size_t num_elements = end > start ? end - start : 0;
41+
if (num_elements == 0) return;
42+
43+
size_t nram_usable = NRAM_MAX_SIZE - 256;
44+
size_t block_size = nram_usable / (3 * sizeof(T));
45+
block_size = (block_size / 64) * 64;
46+
if (block_size == 0) block_size = 64;
47+
48+
T* input_buf = reinterpret_cast<T*>(nram_buffer);
49+
T* gate_buf = input_buf + block_size;
50+
T* output_buf = gate_buf + block_size;
51+
52+
size_t processed = 0;
53+
54+
if (fast_path) {
55+
while (processed < num_elements) {
56+
size_t curr = block_size;
57+
if (curr > num_elements - processed) curr = num_elements - processed;
58+
59+
__memcpy(input_buf, input + start + processed, curr * sizeof(T),
60+
GDRAM2NRAM);
61+
__memcpy(gate_buf, gate + start + processed, curr * sizeof(T),
62+
GDRAM2NRAM);
63+
ComputeSwiglu<T>(input_buf, gate_buf, output_buf, curr);
64+
__memcpy(output + start + processed, output_buf, curr * sizeof(T),
65+
NRAM2GDRAM);
66+
67+
processed += curr;
68+
}
69+
return;
70+
}
71+
72+
// General path: handle non-contiguous tensors and broadcasting.
73+
while (processed < num_elements) {
74+
size_t curr = block_size;
75+
if (curr > num_elements - processed) curr = num_elements - processed;
76+
77+
for (size_t i = 0; i < curr; ++i) {
78+
size_t flat_idx = start + processed + i;
79+
80+
// Compute `input` offset.
81+
{
82+
size_t tmp = flat_idx;
83+
ptrdiff_t offset = 0;
84+
for (int d = ndim - 1; d >= 0; --d) {
85+
size_t coord = tmp % out_shape[d];
86+
tmp /= out_shape[d];
87+
size_t c = coord < input_shape[d] ? coord : 0;
88+
offset += static_cast<ptrdiff_t>(c) * input_strides[d];
89+
}
90+
input_buf[i] = input[offset];
91+
}
92+
93+
// Compute `gate` offset.
94+
{
95+
size_t tmp = flat_idx;
96+
ptrdiff_t offset = 0;
97+
for (int d = ndim - 1; d >= 0; --d) {
98+
size_t coord = tmp % out_shape[d];
99+
tmp /= out_shape[d];
100+
size_t c = coord < gate_shape[d] ? coord : 0;
101+
offset += static_cast<ptrdiff_t>(c) * gate_strides[d];
102+
}
103+
gate_buf[i] = gate[offset];
104+
}
105+
}
106+
107+
ComputeSwiglu<T>(input_buf, gate_buf, output_buf, curr);
108+
109+
if (out_contiguous) {
110+
__memcpy(output + start + processed, output_buf, curr * sizeof(T),
111+
NRAM2GDRAM);
112+
} else {
113+
for (size_t i = 0; i < curr; ++i) {
114+
size_t flat_idx = start + processed + i;
115+
size_t tmp = flat_idx;
116+
ptrdiff_t offset = 0;
117+
for (int d = ndim - 1; d >= 0; --d) {
118+
size_t coord = tmp % out_shape[d];
119+
offset += static_cast<ptrdiff_t>(coord) * out_strides[d];
120+
tmp /= out_shape[d];
121+
}
122+
output[offset] = output_buf[i];
123+
}
124+
}
125+
126+
processed += curr;
127+
}
128+
}
129+
130+
template <typename T>
131+
void SwigluUnion(void* workspace, int core_per_cluster, int cluster_count,
132+
cnrtQueue_t queue, void* out, const void* input,
133+
const void* gate, const size_t* out_shape,
134+
const ptrdiff_t* out_strides, const size_t* input_shape,
135+
const ptrdiff_t* input_strides, const size_t* gate_shape,
136+
const ptrdiff_t* gate_strides, size_t output_size, int ndim,
137+
bool fast_path, bool out_contiguous) {
138+
cnrtDim3_t kernel_dim;
139+
cnrtFunctionType_t kernel_type;
140+
141+
kernel_dim.x = core_per_cluster;
142+
kernel_dim.y = cluster_count;
143+
kernel_dim.z = 1;
144+
kernel_type = cnrtFuncTypeUnion1;
145+
146+
auto out_ = reinterpret_cast<T*>(out);
147+
auto input_ = reinterpret_cast<const T*>(input);
148+
auto gate_ = reinterpret_cast<const T*>(gate);
149+
150+
char* tmp = reinterpret_cast<char*>(workspace);
151+
size_t* mlu_out_shape = reinterpret_cast<size_t*>(tmp);
152+
size_t* mlu_input_shape = mlu_out_shape + ndim;
153+
size_t* mlu_gate_shape = mlu_input_shape + ndim;
154+
ptrdiff_t* mlu_out_strides =
155+
reinterpret_cast<ptrdiff_t*>(mlu_gate_shape + ndim);
156+
ptrdiff_t* mlu_input_strides = mlu_out_strides + ndim;
157+
ptrdiff_t* mlu_gate_strides = mlu_input_strides + ndim;
158+
159+
CNRT_CHECK(cnrtMemcpyAsync(mlu_out_shape, const_cast<size_t*>(out_shape),
160+
ndim * sizeof(size_t), queue,
161+
cnrtMemcpyHostToDev));
162+
CNRT_CHECK(cnrtMemcpyAsync(mlu_input_shape, const_cast<size_t*>(input_shape),
163+
ndim * sizeof(size_t), queue,
164+
cnrtMemcpyHostToDev));
165+
CNRT_CHECK(cnrtMemcpyAsync(mlu_gate_shape, const_cast<size_t*>(gate_shape),
166+
ndim * sizeof(size_t), queue,
167+
cnrtMemcpyHostToDev));
168+
CNRT_CHECK(
169+
cnrtMemcpyAsync(mlu_out_strides, const_cast<ptrdiff_t*>(out_strides),
170+
ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
171+
CNRT_CHECK(
172+
cnrtMemcpyAsync(mlu_input_strides, const_cast<ptrdiff_t*>(input_strides),
173+
ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
174+
CNRT_CHECK(
175+
cnrtMemcpyAsync(mlu_gate_strides, const_cast<ptrdiff_t*>(gate_strides),
176+
ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
177+
178+
SwigluKernel<T><<<kernel_dim, kernel_type, queue>>>(
179+
input_, gate_, out_, mlu_out_shape, mlu_out_strides, mlu_input_shape,
180+
mlu_input_strides, mlu_gate_shape, mlu_gate_strides, output_size, ndim,
181+
fast_path, out_contiguous);
182+
183+
cnrtQueueSync(queue);
184+
}
185+
186+
// Explicit instantiations of the launcher, one per supported element type, so
// that code which only sees the declaration of `SwigluUnion` can link.
template void SwigluUnion<float>(
    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
    const size_t*, const ptrdiff_t*, const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*, size_t, int, bool, bool);

template void SwigluUnion<__half>(
    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
    const size_t*, const ptrdiff_t*, const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*, size_t, int, bool, bool);

template void SwigluUnion<__bang_bfloat16>(
    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
    const size_t*, const ptrdiff_t*, const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*, size_t, int, bool, bool);
204+
205+
} // namespace infini::ops
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef INFINI_OPS_CAMBRICON_SWIGLU_H_
2+
#define INFINI_OPS_CAMBRICON_SWIGLU_H_
3+
4+
#include "base/swiglu.h"
5+
#include "native/cambricon/common.h"
6+
#include "native/cambricon/data_type_.h"
7+
8+
namespace infini::ops {
9+
10+
template <typename T>
11+
void SwigluUnion(void* workspace, int core_per_cluster, int cluster_count,
12+
cnrtQueue_t queue, void* out, const void* input,
13+
const void* gate, const size_t* out_shape,
14+
const ptrdiff_t* out_strides, const size_t* input_shape,
15+
const ptrdiff_t* input_strides, const size_t* gate_shape,
16+
const ptrdiff_t* gate_strides, size_t output_size, int ndim,
17+
bool fast_path, bool out_contiguous);
18+
19+
template <>
20+
class Operator<Swiglu, Device::Type::kCambricon> : public Swiglu {
21+
public:
22+
Operator(const Tensor input, const Tensor gate, Tensor out)
23+
: Swiglu{input, gate, out} {
24+
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
25+
&cluster_count);
26+
cnrtMalloc(&default_workspace_, workspace_size_in_bytes());
27+
}
28+
29+
void operator()(const Tensor input, const Tensor gate,
30+
Tensor out) const override {
31+
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
32+
auto workspace{workspace_ ? workspace_ : default_workspace_};
33+
34+
bool fast_path = is_input_contiguous_ && is_gate_contiguous_ &&
35+
is_out_contiguous_ && input_shape_ == out_shape_ &&
36+
gate_shape_ == out_shape_;
37+
38+
DispatchFunc<
39+
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>(
40+
{static_cast<int64_t>(out_type_)},
41+
[&](auto tag) {
42+
using T = TypeMapType<Device::Type::kCambricon, ListGet<0>(tag)>;
43+
SwigluUnion<T>(workspace, core_per_cluster, cluster_count, queue,
44+
out.data(), input.data(), gate.data(),
45+
out_shape_.data(), out_strides_.data(),
46+
input_shape_.data(), input_strides_.data(),
47+
gate_shape_.data(), gate_strides_.data(), output_size_,
48+
ndim_, fast_path, is_out_contiguous_);
49+
},
50+
"CambriconSwiglu::operator() - output dispatch");
51+
}
52+
53+
~Operator() { cnrtFree(default_workspace_); }
54+
55+
std::size_t workspace_size_in_bytes() const override {
56+
return ndim_ * (3 * sizeof(size_t) + 3 * sizeof(ptrdiff_t));
57+
}
58+
59+
void* default_workspace_{nullptr};
60+
int core_per_cluster = 0;
61+
int cluster_count = 0;
62+
};
63+
64+
} // namespace infini::ops
65+
66+
#endif

0 commit comments

Comments
 (0)