Skip to content

Commit c79a83f

Browse files
committed
feat: add cambricon add op
1 parent a334495 commit c79a83f

File tree

3 files changed

+285
-0
lines changed

3 files changed

+285
-0
lines changed

src/cambricon/add/add.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef INFINI_OPS_CAMBRICON_ADD_ADD_H_
2+
#define INFINI_OPS_CAMBRICON_ADD_ADD_H_
3+
4+
#include "cambricon/common.h"
5+
#include "base/add.h"
6+
#include "cambricon/data_type_.h"
7+
8+
namespace infini::ops {
9+
10+
// Host-side launcher for the Cambricon element-wise add kernel.
// Declared here; defined in kernel.mlu (compiled by the BANG toolchain).
//
// workspace        device buffer that receives the six shape/stride arrays
//                  (sized by Operator::workspace_size_in_bytes below)
// core_per_cluster / cluster_count  launch geometry for the Union task
// queue            CNRT queue the copies and the kernel are enqueued on
// fast_path        all three tensors contiguous with identical shapes, so
//                  the kernel may skip all coordinate/stride arithmetic
// out_contiguous   output may be written back with a single linear copy
template <typename T>
void AddUnion(void *workspace, int core_per_cluster, int cluster_count,
              cnrtQueue_t queue, void *out, const void *input,
              const void *other, const size_t *out_shape,
              const ptrdiff_t *out_strides, const size_t *input_shape,
              const ptrdiff_t *input_strides, const size_t *other_shape,
              const ptrdiff_t *other_strides, size_t output_size, int ndim,
              bool fast_path, bool out_contiguous);
18+
19+
template <>
20+
class Operator<Add, Device::Type::kCambricon> : public Add {
21+
public:
22+
Operator(const Tensor input, const Tensor other, Tensor out)
23+
: Add{input, other, out} {
24+
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
25+
&cluster_count);
26+
cnrtMalloc(&default_workspace_, workspace_size_in_bytes());
27+
}
28+
29+
void operator()(const Tensor input, const Tensor other,
30+
Tensor out) const override {
31+
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
32+
auto workspace{workspace_ ? workspace_ : default_workspace_};
33+
34+
bool fast_path = is_input_contiguous_ && is_other_contiguous_ &&
35+
is_out_contiguous_ && input_shape_ == out_shape_ &&
36+
other_shape_ == out_shape_;
37+
38+
DispatchFunc<List<DataType::kFloat16, DataType::kBFloat16,
39+
DataType::kFloat32, DataType::kInt32, DataType::kInt64>>(
40+
{static_cast<int64_t>(out_type_)},
41+
[&](auto tag) {
42+
using T = TypeMapType<Device::Type::kCambricon, ListGet<0>(tag)>;
43+
AddUnion<T>(workspace, core_per_cluster, cluster_count, queue,
44+
out.data(), input.data(), other.data(), out_shape_.data(),
45+
out_strides_.data(), input_shape_.data(),
46+
input_strides_.data(), other_shape_.data(),
47+
other_strides_.data(), output_size_, ndim_, fast_path,
48+
is_out_contiguous_);
49+
},
50+
"CambriconAdd::operator() - output dispatch");
51+
}
52+
53+
~Operator() { cnrtFree(default_workspace_); }
54+
55+
std::size_t workspace_size_in_bytes() const override {
56+
return ndim_ * (3 * sizeof(size_t) + 3 * sizeof(ptrdiff_t));
57+
}
58+
59+
void *default_workspace_{nullptr};
60+
int core_per_cluster = 0;
61+
int cluster_count = 0;
62+
};
63+
64+
} // namespace infini::ops
65+
66+
#endif

src/cambricon/add/kernel.mlu

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
#include "add.h"
2+
3+
// Per-core NRAM scratch buffer; AddKernel partitions it at runtime into
// three staging tiles (input / other / output).
__nram__ char nram_buffer[NRAM_MAX_SIZE];
4+
5+
namespace infini::ops {
6+
7+
template <typename T>
8+
__mlu_device__ void BangAdd(const T *src1, const T *src2, T *dst,
9+
size_t n) {
10+
if constexpr (std::is_same_v<T, __half>) {
11+
__bang_add(reinterpret_cast<half *>(dst),
12+
reinterpret_cast<const half *>(src1),
13+
reinterpret_cast<const half *>(src2), n);
14+
} else {
15+
__bang_add(dst, src1, src2, n);
16+
}
17+
}
18+
19+
template <typename T>
20+
__mlu_global__ void AddKernel(const T *input, const T *other, T *output,
21+
const size_t *out_shape,
22+
const ptrdiff_t *out_strides,
23+
const size_t *input_shape,
24+
const ptrdiff_t *input_strides,
25+
const size_t *other_shape,
26+
const ptrdiff_t *other_strides,
27+
size_t output_size, int ndim, bool fast_path,
28+
bool out_contiguous) {
29+
size_t elements_per_task = (output_size + taskDim - 1) / taskDim;
30+
size_t start = taskId * elements_per_task;
31+
size_t end = start + elements_per_task;
32+
if (end > output_size) end = output_size;
33+
size_t num_elements = end > start ? end - start : 0;
34+
if (num_elements == 0) return;
35+
36+
size_t nram_usable = NRAM_MAX_SIZE - 256;
37+
size_t block_size = nram_usable / (3 * sizeof(T));
38+
block_size = (block_size / 64) * 64; // Align to 64 elements
39+
if (block_size == 0) block_size = 64;
40+
41+
T *input_buf = reinterpret_cast<T *>(nram_buffer);
42+
T *other_buf = input_buf + block_size;
43+
T *output_buf = other_buf + block_size;
44+
45+
size_t processed = 0;
46+
47+
if (fast_path) {
48+
// Fast path: all tensors contiguous with matching shapes (no broadcast).
49+
while (processed < num_elements) {
50+
size_t curr = block_size;
51+
if (curr > num_elements - processed) curr = num_elements - processed;
52+
53+
__memcpy(input_buf, input + start + processed,
54+
curr * sizeof(T), GDRAM2NRAM);
55+
__memcpy(other_buf, other + start + processed,
56+
curr * sizeof(T), GDRAM2NRAM);
57+
BangAdd(input_buf, other_buf, output_buf, curr);
58+
__memcpy(output + start + processed, output_buf,
59+
curr * sizeof(T), NRAM2GDRAM);
60+
61+
processed += curr;
62+
}
63+
return;
64+
}
65+
66+
// General path: handle non-contiguous tensors and broadcasting.
67+
while (processed < num_elements) {
68+
size_t curr = block_size;
69+
if (curr > num_elements - processed) curr = num_elements - processed;
70+
71+
for (size_t i = 0; i < curr; ++i) {
72+
size_t flat_idx = start + processed + i;
73+
74+
// Compute `input` offset.
75+
{
76+
size_t tmp = flat_idx;
77+
ptrdiff_t offset = 0;
78+
for (int d = ndim - 1; d >= 0; --d) {
79+
size_t coord = tmp % out_shape[d];
80+
tmp /= out_shape[d];
81+
size_t c = coord < input_shape[d] ? coord : 0;
82+
offset += static_cast<ptrdiff_t>(c) * input_strides[d];
83+
}
84+
input_buf[i] = input[offset];
85+
}
86+
87+
// Compute `other` offset.
88+
{
89+
size_t tmp = flat_idx;
90+
ptrdiff_t offset = 0;
91+
for (int d = ndim - 1; d >= 0; --d) {
92+
size_t coord = tmp % out_shape[d];
93+
tmp /= out_shape[d];
94+
size_t c = coord < other_shape[d] ? coord : 0;
95+
offset += static_cast<ptrdiff_t>(c) * other_strides[d];
96+
}
97+
other_buf[i] = other[offset];
98+
}
99+
}
100+
101+
BangAdd(input_buf, other_buf, output_buf, curr);
102+
103+
if (out_contiguous) {
104+
__memcpy(output + start + processed, output_buf,
105+
curr * sizeof(T), NRAM2GDRAM);
106+
} else {
107+
for (size_t i = 0; i < curr; ++i) {
108+
size_t flat_idx = start + processed + i;
109+
size_t tmp = flat_idx;
110+
ptrdiff_t offset = 0;
111+
for (int d = ndim - 1; d >= 0; --d) {
112+
size_t coord = tmp % out_shape[d];
113+
offset += static_cast<ptrdiff_t>(coord) * out_strides[d];
114+
tmp /= out_shape[d];
115+
}
116+
output[offset] = output_buf[i];
117+
}
118+
}
119+
120+
processed += curr;
121+
}
122+
}
123+
124+
template <typename T>
125+
void AddUnion(void *workspace, int core_per_cluster, int cluster_count,
126+
cnrtQueue_t queue, void *out, const void *input, const void *other,
127+
const size_t *out_shape, const ptrdiff_t *out_strides,
128+
const size_t *input_shape, const ptrdiff_t *input_strides,
129+
const size_t *other_shape, const ptrdiff_t *other_strides,
130+
size_t output_size, int ndim, bool fast_path,
131+
bool out_contiguous) {
132+
cnrtDim3_t kernel_dim;
133+
cnrtFunctionType_t kernel_type;
134+
135+
kernel_dim.x = core_per_cluster;
136+
kernel_dim.y = cluster_count;
137+
kernel_dim.z = 1;
138+
kernel_type = cnrtFuncTypeUnion1;
139+
140+
auto out_ = reinterpret_cast<T *>(out);
141+
auto input_ = reinterpret_cast<const T *>(input);
142+
auto other_ = reinterpret_cast<const T *>(other);
143+
144+
char *tmp = reinterpret_cast<char *>(workspace);
145+
size_t *mlu_out_shape = reinterpret_cast<size_t *>(tmp);
146+
size_t *mlu_input_shape = mlu_out_shape + ndim;
147+
size_t *mlu_other_shape = mlu_input_shape + ndim;
148+
ptrdiff_t *mlu_out_strides =
149+
reinterpret_cast<ptrdiff_t *>(mlu_other_shape + ndim);
150+
ptrdiff_t *mlu_input_strides = mlu_out_strides + ndim;
151+
ptrdiff_t *mlu_other_strides = mlu_input_strides + ndim;
152+
153+
CNRT_CHECK(cnrtMemcpyAsync(mlu_out_shape, const_cast<size_t *>(out_shape),
154+
ndim * sizeof(size_t), queue,
155+
cnrtMemcpyHostToDev));
156+
CNRT_CHECK(cnrtMemcpyAsync(mlu_input_shape, const_cast<size_t *>(input_shape),
157+
ndim * sizeof(size_t), queue,
158+
cnrtMemcpyHostToDev));
159+
CNRT_CHECK(cnrtMemcpyAsync(mlu_other_shape, const_cast<size_t *>(other_shape),
160+
ndim * sizeof(size_t), queue,
161+
cnrtMemcpyHostToDev));
162+
CNRT_CHECK(cnrtMemcpyAsync(mlu_out_strides,
163+
const_cast<ptrdiff_t *>(out_strides),
164+
ndim * sizeof(ptrdiff_t), queue,
165+
cnrtMemcpyHostToDev));
166+
CNRT_CHECK(cnrtMemcpyAsync(mlu_input_strides,
167+
const_cast<ptrdiff_t *>(input_strides),
168+
ndim * sizeof(ptrdiff_t), queue,
169+
cnrtMemcpyHostToDev));
170+
CNRT_CHECK(cnrtMemcpyAsync(mlu_other_strides,
171+
const_cast<ptrdiff_t *>(other_strides),
172+
ndim * sizeof(ptrdiff_t), queue,
173+
cnrtMemcpyHostToDev));
174+
175+
AddKernel<T><<<kernel_dim, kernel_type, queue>>>(
176+
input_, other_, out_, mlu_out_shape, mlu_out_strides, mlu_input_shape,
177+
mlu_input_strides, mlu_other_shape, mlu_other_strides, output_size, ndim,
178+
fast_path, out_contiguous);
179+
180+
cnrtQueueSync(queue);
181+
}
182+
183+
// Explicit instantiations for every dtype the host-side dispatcher in add.h
// can select (f16, bf16, f32, i32, i64), so the definitions compiled by the
// BANG toolchain here are visible to the rest of the project.
template void AddUnion<__half>(void *, int, int, cnrtQueue_t, void *,
                               const void *, const void *, const size_t *,
                               const ptrdiff_t *, const size_t *,
                               const ptrdiff_t *, const size_t *,
                               const ptrdiff_t *, size_t, int, bool, bool);

template void AddUnion<__bang_bfloat16>(void *, int, int, cnrtQueue_t, void *,
                                        const void *, const void *,
                                        const size_t *, const ptrdiff_t *,
                                        const size_t *, const ptrdiff_t *,
                                        const size_t *, const ptrdiff_t *,
                                        size_t, int, bool, bool);

template void AddUnion<float>(void *, int, int, cnrtQueue_t, void *,
                              const void *, const void *, const size_t *,
                              const ptrdiff_t *, const size_t *,
                              const ptrdiff_t *, const size_t *,
                              const ptrdiff_t *, size_t, int, bool, bool);

template void AddUnion<int32_t>(void *, int, int, cnrtQueue_t, void *,
                                const void *, const void *, const size_t *,
                                const ptrdiff_t *, const size_t *,
                                const ptrdiff_t *, const size_t *,
                                const ptrdiff_t *, size_t, int, bool, bool);

template void AddUnion<int64_t>(void *, int, int, cnrtQueue_t, void *,
                                const void *, const void *, const size_t *,
                                const ptrdiff_t *, const size_t *,
                                const ptrdiff_t *, const size_t *,
                                const ptrdiff_t *, size_t, int, bool, bool);
213+
214+
} // namespace infini::ops

tests/test_add.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def test_add(
4545
pytest.skip(
4646
"The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`."
4747
)
48+
49+
if device == "mlu" and ( dtype in _UINT_DTYPES or dtype == torch.int16):
50+
pytest.skip(
51+
"The `torch.mlu` test cloning path does not support `int16`, `uint16`, `uint32`, or `uint64`."
52+
)
4853

4954
if dtype in _INT_DTYPES or dtype in _UINT_DTYPES:
5055
input = randint_strided(

0 commit comments

Comments
 (0)