Skip to content

Commit b36b844

Browse files
committed
feat(cambricon): add AddOp in Cambricon
1 parent c4141be commit b36b844

3 files changed

Lines changed: 276 additions & 0 deletions

File tree

src/native/cambricon/ops/add/add.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef INFINI_OPS_CAMBRICON_ADD_H_
2+
#define INFINI_OPS_CAMBRICON_ADD_H_
3+
4+
#include "base/add.h"
5+
#include "native/cambricon/common.h"
6+
#include "native/cambricon/data_type_.h"
7+
8+
namespace infini::ops {
9+
10+
template <typename T>
11+
void AddUnion(void* workspace, int core_per_cluster, int cluster_count,
12+
cnrtQueue_t queue, void* out, const void* input,
13+
const void* other, const size_t* out_shape,
14+
const ptrdiff_t* out_strides, const size_t* input_shape,
15+
const ptrdiff_t* input_strides, const size_t* other_shape,
16+
const ptrdiff_t* other_strides, size_t output_size, int ndim,
17+
bool fast_path, bool out_contiguous);
18+
19+
template <>
20+
class Operator<Add, Device::Type::kCambricon> : public Add {
21+
public:
22+
Operator(const Tensor input, const Tensor other, Tensor out)
23+
: Add{input, other, out} {
24+
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
25+
&cluster_count);
26+
cnrtMalloc(&default_workspace_, workspace_size_in_bytes());
27+
}
28+
29+
void operator()(const Tensor input, const Tensor other,
30+
Tensor out) const override {
31+
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
32+
auto workspace{workspace_ ? workspace_ : default_workspace_};
33+
34+
bool fast_path = is_input_contiguous_ && is_other_contiguous_ &&
35+
is_out_contiguous_ && input_shape_ == out_shape_ &&
36+
other_shape_ == out_shape_;
37+
38+
DispatchFunc<List<DataType::kFloat16, DataType::kBFloat16,
39+
DataType::kFloat32, DataType::kInt32, DataType::kInt64>>(
40+
{static_cast<int64_t>(out_type_)},
41+
[&](auto tag) {
42+
using T = TypeMapType<Device::Type::kCambricon, ListGet<0>(tag)>;
43+
AddUnion<T>(workspace, core_per_cluster, cluster_count, queue,
44+
out.data(), input.data(), other.data(), out_shape_.data(),
45+
out_strides_.data(), input_shape_.data(),
46+
input_strides_.data(), other_shape_.data(),
47+
other_strides_.data(), output_size_, ndim_, fast_path,
48+
is_out_contiguous_);
49+
},
50+
"CambriconAdd::operator() - output dispatch");
51+
}
52+
53+
~Operator() { cnrtFree(default_workspace_); }
54+
55+
std::size_t workspace_size_in_bytes() const override {
56+
return ndim_ * (3 * sizeof(size_t) + 3 * sizeof(ptrdiff_t));
57+
}
58+
59+
void* default_workspace_{nullptr};
60+
int core_per_cluster = 0;
61+
int cluster_count = 0;
62+
};
63+
64+
} // namespace infini::ops
65+
66+
#endif
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
#include "add.h"
2+
3+
// `__nram__` scratch buffer of `NRAM_MAX_SIZE` bytes. `AddKernel` carves it
// into three equally sized tiles (input, other, output).
__nram__ char nram_buffer[NRAM_MAX_SIZE];
4+
5+
namespace infini::ops {
6+
7+
template <typename T>
8+
__mlu_device__ void BangAdd(const T* src1, const T* src2, T* dst, size_t n) {
9+
if constexpr (std::is_same_v<T, __half>) {
10+
__bang_add(reinterpret_cast<half*>(dst),
11+
reinterpret_cast<const half*>(src1),
12+
reinterpret_cast<const half*>(src2), n);
13+
} else {
14+
__bang_add(dst, src1, src2, n);
15+
}
16+
}
17+
18+
template <typename T>
19+
__mlu_global__ void AddKernel(
20+
const T* input, const T* other, T* output, const size_t* out_shape,
21+
const ptrdiff_t* out_strides, const size_t* input_shape,
22+
const ptrdiff_t* input_strides, const size_t* other_shape,
23+
const ptrdiff_t* other_strides, size_t output_size, int ndim,
24+
bool fast_path, bool out_contiguous) {
25+
size_t elements_per_task = (output_size + taskDim - 1) / taskDim;
26+
size_t start = taskId * elements_per_task;
27+
size_t end = start + elements_per_task;
28+
if (end > output_size) end = output_size;
29+
size_t num_elements = end > start ? end - start : 0;
30+
if (num_elements == 0) return;
31+
32+
size_t nram_usable = NRAM_MAX_SIZE - 256;
33+
size_t block_size = nram_usable / (3 * sizeof(T));
34+
block_size = (block_size / 64) * 64; // Align to 64 elements
35+
if (block_size == 0) block_size = 64;
36+
37+
T* input_buf = reinterpret_cast<T*>(nram_buffer);
38+
T* other_buf = input_buf + block_size;
39+
T* output_buf = other_buf + block_size;
40+
41+
size_t processed = 0;
42+
43+
if (fast_path) {
44+
// Fast path: all tensors contiguous with matching shapes (no broadcast).
45+
while (processed < num_elements) {
46+
size_t curr = block_size;
47+
if (curr > num_elements - processed) curr = num_elements - processed;
48+
49+
__memcpy(input_buf, input + start + processed, curr * sizeof(T),
50+
GDRAM2NRAM);
51+
__memcpy(other_buf, other + start + processed, curr * sizeof(T),
52+
GDRAM2NRAM);
53+
BangAdd(input_buf, other_buf, output_buf, curr);
54+
__memcpy(output + start + processed, output_buf, curr * sizeof(T),
55+
NRAM2GDRAM);
56+
57+
processed += curr;
58+
}
59+
return;
60+
}
61+
62+
// General path: handle non-contiguous tensors and broadcasting.
63+
while (processed < num_elements) {
64+
size_t curr = block_size;
65+
if (curr > num_elements - processed) curr = num_elements - processed;
66+
67+
for (size_t i = 0; i < curr; ++i) {
68+
size_t flat_idx = start + processed + i;
69+
70+
// Compute `input` offset.
71+
{
72+
size_t tmp = flat_idx;
73+
ptrdiff_t offset = 0;
74+
for (int d = ndim - 1; d >= 0; --d) {
75+
size_t coord = tmp % out_shape[d];
76+
tmp /= out_shape[d];
77+
size_t c = coord < input_shape[d] ? coord : 0;
78+
offset += static_cast<ptrdiff_t>(c) * input_strides[d];
79+
}
80+
input_buf[i] = input[offset];
81+
}
82+
83+
// Compute `other` offset.
84+
{
85+
size_t tmp = flat_idx;
86+
ptrdiff_t offset = 0;
87+
for (int d = ndim - 1; d >= 0; --d) {
88+
size_t coord = tmp % out_shape[d];
89+
tmp /= out_shape[d];
90+
size_t c = coord < other_shape[d] ? coord : 0;
91+
offset += static_cast<ptrdiff_t>(c) * other_strides[d];
92+
}
93+
other_buf[i] = other[offset];
94+
}
95+
}
96+
97+
BangAdd(input_buf, other_buf, output_buf, curr);
98+
99+
if (out_contiguous) {
100+
__memcpy(output + start + processed, output_buf, curr * sizeof(T),
101+
NRAM2GDRAM);
102+
} else {
103+
for (size_t i = 0; i < curr; ++i) {
104+
size_t flat_idx = start + processed + i;
105+
size_t tmp = flat_idx;
106+
ptrdiff_t offset = 0;
107+
for (int d = ndim - 1; d >= 0; --d) {
108+
size_t coord = tmp % out_shape[d];
109+
offset += static_cast<ptrdiff_t>(coord) * out_strides[d];
110+
tmp /= out_shape[d];
111+
}
112+
output[offset] = output_buf[i];
113+
}
114+
}
115+
116+
processed += curr;
117+
}
118+
}
119+
120+
template <typename T>
121+
void AddUnion(void* workspace, int core_per_cluster, int cluster_count,
122+
cnrtQueue_t queue, void* out, const void* input,
123+
const void* other, const size_t* out_shape,
124+
const ptrdiff_t* out_strides, const size_t* input_shape,
125+
const ptrdiff_t* input_strides, const size_t* other_shape,
126+
const ptrdiff_t* other_strides, size_t output_size, int ndim,
127+
bool fast_path, bool out_contiguous) {
128+
cnrtDim3_t kernel_dim;
129+
cnrtFunctionType_t kernel_type;
130+
131+
kernel_dim.x = core_per_cluster;
132+
kernel_dim.y = cluster_count;
133+
kernel_dim.z = 1;
134+
kernel_type = cnrtFuncTypeUnion1;
135+
136+
auto out_ = reinterpret_cast<T*>(out);
137+
auto input_ = reinterpret_cast<const T*>(input);
138+
auto other_ = reinterpret_cast<const T*>(other);
139+
140+
char* tmp = reinterpret_cast<char*>(workspace);
141+
size_t* mlu_out_shape = reinterpret_cast<size_t*>(tmp);
142+
size_t* mlu_input_shape = mlu_out_shape + ndim;
143+
size_t* mlu_other_shape = mlu_input_shape + ndim;
144+
ptrdiff_t* mlu_out_strides =
145+
reinterpret_cast<ptrdiff_t*>(mlu_other_shape + ndim);
146+
ptrdiff_t* mlu_input_strides = mlu_out_strides + ndim;
147+
ptrdiff_t* mlu_other_strides = mlu_input_strides + ndim;
148+
149+
CNRT_CHECK(cnrtMemcpyAsync(mlu_out_shape, const_cast<size_t*>(out_shape),
150+
ndim * sizeof(size_t), queue,
151+
cnrtMemcpyHostToDev));
152+
CNRT_CHECK(cnrtMemcpyAsync(mlu_input_shape, const_cast<size_t*>(input_shape),
153+
ndim * sizeof(size_t), queue,
154+
cnrtMemcpyHostToDev));
155+
CNRT_CHECK(cnrtMemcpyAsync(mlu_other_shape, const_cast<size_t*>(other_shape),
156+
ndim * sizeof(size_t), queue,
157+
cnrtMemcpyHostToDev));
158+
CNRT_CHECK(
159+
cnrtMemcpyAsync(mlu_out_strides, const_cast<ptrdiff_t*>(out_strides),
160+
ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
161+
CNRT_CHECK(
162+
cnrtMemcpyAsync(mlu_input_strides, const_cast<ptrdiff_t*>(input_strides),
163+
ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
164+
CNRT_CHECK(
165+
cnrtMemcpyAsync(mlu_other_strides, const_cast<ptrdiff_t*>(other_strides),
166+
ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev));
167+
168+
AddKernel<T><<<kernel_dim, kernel_type, queue>>>(
169+
input_, other_, out_, mlu_out_shape, mlu_out_strides, mlu_input_shape,
170+
mlu_input_strides, mlu_other_shape, mlu_other_strides, output_size, ndim,
171+
fast_path, out_contiguous);
172+
173+
cnrtQueueSync(queue);
174+
}
175+
176+
// Explicit instantiations of `AddUnion`, one per element type. The template is
// defined in this file, so these are what other translation units link
// against. Argument types are grouped as: workspace, launch configuration and
// queue / data pointers / output, input and other shape+stride pairs /
// `output_size`, `ndim`, `fast_path`, `out_contiguous`.
template void AddUnion<__half>(
    void*, int, int, cnrtQueue_t,
    void*, const void*, const void*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    size_t, int, bool, bool);

template void AddUnion<__bang_bfloat16>(
    void*, int, int, cnrtQueue_t,
    void*, const void*, const void*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    size_t, int, bool, bool);

template void AddUnion<float>(
    void*, int, int, cnrtQueue_t,
    void*, const void*, const void*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    size_t, int, bool, bool);

template void AddUnion<int32_t>(
    void*, int, int, cnrtQueue_t,
    void*, const void*, const void*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    size_t, int, bool, bool);

template void AddUnion<int64_t>(
    void*, int, int, cnrtQueue_t,
    void*, const void*, const void*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    const size_t*, const ptrdiff_t*,
    size_t, int, bool, bool);
204+
205+
} // namespace infini::ops

tests/test_add.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def test_add(
6060
"The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`."
6161
)
6262

63+
if device == "mlu" and (dtype in _UINT_DTYPES or dtype == torch.int16):
64+
pytest.skip(
65+
"The `torch.mlu` test cloning path does not support `int16`, `uint16`, `uint32`, or `uint64`."
66+
)
67+
6368
if implementation_index == 1 and dtype in _UINT_DTYPES:
6469
pytest.skip("ATen `add` does not support unsigned integer types")
6570

0 commit comments

Comments
 (0)