Skip to content

Commit 7de4dcf

Browse files
committed
feat: add cambricon causal softmax op
1 parent 8c92b2e commit 7de4dcf

File tree

6 files changed

+392
-9
lines changed

6 files changed

+392
-9
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ if(WITH_CAMBRICON)
133133
message(STATUS "Found cncc: ${CNCC_COMPILER}")
134134
set(MLU_COMPILE_OPTS
135135
-c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread
136+
-DWITH_CAMBRICON=1
136137
-I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include
137138
-idirafter /usr/local/neuware/lib/clang/11.1.0/include
138139
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#ifndef INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H
2+
#define INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H
3+
4+
#include "base/causal_softmax.h"
5+
#include "cambricon/common.h"
6+
7+
namespace infini::ops {
8+
9+
// TODO: Remove forward declaration.
// Host-side launcher for the causal-softmax MLU kernel; it is defined in the
// accompanying .mlu translation unit and declared here so the Operator
// specialization below can dispatch to it per input dtype.
template <typename T>
void CausalSoftmaxUnion(void *workspace, int core_per_cluster,
                        int cluster_count, cnrtQueue_t queue, void *y,
                        const void *x, size_t batch_size_, size_t seq_len_,
                        size_t total_seq_len_, ptrdiff_t y_stride_b,
                        ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
                        ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
                        ptrdiff_t x_stride_j);
18+
19+
template <>
20+
class Operator<CausalSoftmax, Device::Type::kCambricon> : public CausalSoftmax {
21+
public:
22+
Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {
23+
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
24+
&cluster_count);
25+
}
26+
void operator()(const Tensor input, Tensor out) const override {
27+
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
28+
auto workspace{workspace_ ? workspace_ : default_workspace_};
29+
ptrdiff_t y_stride_b = ndim_ == 3 ? out_strides_[0] : 1;
30+
ptrdiff_t y_stride_i = ndim_ == 3 ? out_strides_[1] : out_strides_[0];
31+
ptrdiff_t y_stride_j = ndim_ == 3 ? out_strides_[2] : out_strides_[1];
32+
ptrdiff_t x_stride_b = ndim_ == 3 ? input_strides_[0] : 1;
33+
ptrdiff_t x_stride_i = ndim_ == 3 ? input_strides_[1] : input_strides_[0];
34+
ptrdiff_t x_stride_j = ndim_ == 3 ? input_strides_[2] : input_strides_[1];
35+
36+
DispatchFunc<
37+
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>(
38+
{input.dtype()},
39+
[&](auto input_tag) {
40+
using InputT = typename decltype(input_tag)::type;
41+
CausalSoftmaxUnion<InputT>(
42+
workspace, core_per_cluster, cluster_count, queue, out.data(),
43+
input.data(), batch_size_, seq_len_, total_seq_len_, y_stride_b,
44+
y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j);
45+
},
46+
"CambriconCausalSoftmax::operator() - output dispatch");
47+
}
48+
49+
std::size_t workspace_size_in_bytes() const override { return 0; }
50+
51+
~Operator() {}
52+
53+
void *default_workspace_{nullptr};
54+
int core_per_cluster = 0;
55+
int cluster_count = 0;
56+
};
57+
58+
} // namespace infini::ops
59+
60+
#endif
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#include "causal_softmax.h"
2+
3+
// Per-core on-chip NRAM scratch shared by all kernels in this TU.
__nram__ char nram_buffer[NRAM_MAX_SIZE];
// Byte budget for one streaming chunk — a quarter of NRAM, leaving room for
// the reduce helpers' src/dst buffers carved out of the same arena.
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;
5+
6+
namespace infini::ops {
7+
8+
// Applies one element-wise phase of the softmax to a strided row of
// `num_elements` entries:
//   - is_exp_phase == true:  output[j] = exp(input[j] - scalar)  (scalar = row max)
//   - is_exp_phase == false: output[j] = output[j] * scalar      (scalar = 1/row sum)
// The second phase reads from `output` in place (see the ternaries below).
// Data is streamed through NRAM in chunks; half/bfloat16 values are widened
// to float for the arithmetic and narrowed back on store. `stride` is the
// element stride between consecutive row entries in GDRAM.
template <typename T>
__mlu_func__ void ProcessSoftmaxStep(const T *input, T *output, float scalar,
                                     int num_elements, int stride,
                                     bool is_exp_phase) {
  constexpr bool is_half = std::is_same_v<T, __half>;
  constexpr bool is_bfloat16 = std::is_same_v<T, __bang_bfloat16>;
  constexpr bool is_float = !is_half && !is_bfloat16;

  // Half/bf16 need two NRAM regions (float workspace + narrow staging), so
  // the chunk is halved relative to the pure-float path.
  const int chunk_size =
      SRC_MAX_SIZE /
      ((is_half || is_bfloat16) ? (2 * sizeof(float)) : sizeof(float));
  float *float_buffer = (float *)nram_buffer;
  T *temp_buffer =
      is_float ? nullptr : (T *)(nram_buffer + chunk_size * sizeof(float));

  // Common stride configurations (bytes between row entries in GDRAM).
  const int src_stride = stride * sizeof(T);
  const int dst_stride = stride * sizeof(T);

  int processed = 0;
  while (processed < num_elements) {
    int curr_batch = std::min(chunk_size, num_elements - processed);

    // Gather the next chunk: strided in GDRAM, packed contiguously in NRAM.
    if constexpr (is_float) {
      __memcpy(
          float_buffer, (is_exp_phase ? input : output) + processed * stride,
          sizeof(float), GDRAM2NRAM, sizeof(float), src_stride, curr_batch - 1);
    } else {
      __memcpy(temp_buffer,
               (is_exp_phase ? input : output) + processed * stride, sizeof(T),
               GDRAM2NRAM, sizeof(T), src_stride, curr_batch - 1);

      // Widen to float so both phases run in full precision.
      if constexpr (is_half) {
        __bang_half2float(float_buffer, reinterpret_cast<half *>(temp_buffer),
                          curr_batch);
      } else if constexpr (is_bfloat16) {
        __bang_bfloat162float(float_buffer, temp_buffer, curr_batch);
      }
    }

    // Common processing for all types.
    if (is_exp_phase) {
      __bang_sub_scalar(float_buffer, float_buffer, scalar,
                        curr_batch); // scalar is max_val
      __bang_active_exphp(float_buffer, float_buffer, curr_batch);
    } else {
      __bang_mul_scalar(float_buffer, float_buffer, scalar,
                        curr_batch); // scalar is 1.0f/sum_val
    }

    // Scatter the chunk back: packed in NRAM, strided in GDRAM.
    if constexpr (is_float) {
      __memcpy(output + processed * stride, float_buffer, sizeof(float),
               NRAM2GDRAM, dst_stride, sizeof(float), curr_batch - 1);
    } else {
      if constexpr (is_half) {
        __bang_float2half(reinterpret_cast<half *>(temp_buffer), float_buffer,
                          curr_batch);
      } else if constexpr (is_bfloat16) {
        __bang_float2bfloat16(temp_buffer, float_buffer, curr_batch);
      }

      __memcpy(output + processed * stride, temp_buffer, sizeof(T), NRAM2GDRAM,
               dst_stride, sizeof(T), curr_batch - 1);
    }

    processed += curr_batch;
  }
}
76+
77+
// Causal softmax kernel. Rows (batch, i) are distributed evenly across the
// launched cores; each row i may attend to positions
// [0, total_seq_len - seq_len + i] (KV-cache style offset when
// total_seq_len > seq_len), and later positions are zeroed. The softmax is
// computed in three passes per row: max-reduce, exp(x - max), then
// sum-reduce + normalize.
template <typename T>
__mlu_global__ void CausalSoftmax(T *y, const T *x, size_t batch_size,
                                  size_t seq_len, size_t total_seq_len,
                                  ptrdiff_t y_stride_b, ptrdiff_t y_stride_i,
                                  ptrdiff_t y_stride_j, ptrdiff_t x_stride_b,
                                  ptrdiff_t x_stride_i, ptrdiff_t x_stride_j) {
  size_t task_id = taskId;
  // Launch uses dim = (cores_per_cluster, cluster_count, 1), so x*y is the
  // total core count.
  size_t task_num = taskDimX * taskDimY;

  // Static, contiguous partition of all rows over the cores.
  size_t total_tasks = batch_size * seq_len;
  size_t tasks_per_core = (total_tasks + task_num - 1) / task_num;
  size_t start = task_id * tasks_per_core;
  size_t end = std::min(start + tasks_per_core, total_tasks);

  // Scratch for the reduce helpers, carved from the shared NRAM arena.
  const int max_batch = SRC_MAX_SIZE / sizeof(T);
  T *src = (T *)nram_buffer;
  float *dst = (float *)(nram_buffer + max_batch * sizeof(T));

  for (size_t index = start; index < end; index++) {
    size_t batch = index / seq_len;
    size_t i = (index % seq_len);
    ptrdiff_t y_offset = batch * y_stride_b + i * y_stride_i;
    ptrdiff_t x_offset = batch * x_stride_b + i * x_stride_i;
    T *y_ = y + y_offset;
    const T *x_ = x + x_offset;

    // Calculate the valid sequence length for this position.
    size_t valid_len = total_seq_len - seq_len + i + 1;

    // Zero out future positions.
    for (size_t j = valid_len; j < total_seq_len; j++) {
      y_[j * y_stride_j] = (T)0.0f;
    }

    // Calculate max value using optimized reduction.
    // NOTE(review): MaxBatched/SumBatched are given no element stride — this
    // looks like it assumes the innermost dim is contiguous (stride_j == 1);
    // confirm for non-contiguous inputs.
    float max_val =
        infini::ops::reduce::MaxBatched(x_, src, dst, valid_len, max_batch);

    // Compute `exp(x - max)`.
    // NOTE(review): this call writes y_ using x_stride_j as the store stride;
    // correct only when x and y share the same innermost stride — confirm.
    ProcessSoftmaxStep(x_, y_, max_val, valid_len, x_stride_j, true);

    // Calculate sum of exponentials.
    float sum_val =
        infini::ops::reduce::SumBatched(y_, src, dst, valid_len, max_batch);

    // Normalize by sum.
    ProcessSoftmaxStep(y_, y_, 1.0f / sum_val, valid_len, y_stride_j, false);
  }
}
126+
127+
template <typename T>
128+
void CausalSoftmaxUnion(void *workspace, int core_per_cluster,
129+
int cluster_count, cnrtQueue_t queue, void *y,
130+
const void *x, size_t batch_size_, size_t seq_len_,
131+
size_t total_seq_len_, ptrdiff_t y_stride_b,
132+
ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
133+
ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
134+
ptrdiff_t x_stride_j) {
135+
cnrtDim3_t kernel_dim;
136+
cnrtFunctionType_t kernel_type;
137+
138+
kernel_dim.x = core_per_cluster;
139+
kernel_dim.y = cluster_count;
140+
kernel_dim.z = 1;
141+
kernel_type = cnrtFuncTypeUnion1;
142+
143+
CausalSoftmax<T><<<kernel_dim, kernel_type, queue>>>(
144+
(T *)y, (const T *)x, batch_size_, seq_len_, total_seq_len_, y_stride_b,
145+
y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j);
146+
147+
cnrtQueueSync(queue);
148+
}
149+
150+
// Explicit instantiations for every dtype the host-side Operator dispatches
// (fp16, bf16, fp32) — keeps the kernel definitions out of the header.
template void CausalSoftmaxUnion<__half>(void *, int, int, cnrtQueue_t, void *,
                                         const void *, size_t, size_t, size_t,
                                         ptrdiff_t, ptrdiff_t, ptrdiff_t,
                                         ptrdiff_t, ptrdiff_t, ptrdiff_t);

template void CausalSoftmaxUnion<__bang_bfloat16>(
    void *, int, int, cnrtQueue_t, void *, const void *, size_t, size_t, size_t,
    ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t);

template void CausalSoftmaxUnion<float>(void *, int, int, cnrtQueue_t, void *,
                                        const void *, size_t, size_t, size_t,
                                        ptrdiff_t, ptrdiff_t, ptrdiff_t,
                                        ptrdiff_t, ptrdiff_t, ptrdiff_t);
163+
164+
} // namespace infini::ops

0 commit comments

Comments
 (0)