Skip to content

Commit dbca812

Browse files
committed
feat: add cambricon causal softmax op
1 parent 8c92b2e commit dbca812

File tree

6 files changed

+401
-11
lines changed

6 files changed

+401
-11
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ if(WITH_CAMBRICON)
133133
message(STATUS "Found cncc: ${CNCC_COMPILER}")
134134
set(MLU_COMPILE_OPTS
135135
-c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread
136+
-DWITH_CAMBRICON=1
136137
-I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include
137138
-idirafter /usr/local/neuware/lib/clang/11.1.0/include
138139
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#ifndef INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H
2+
#define INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H
3+
4+
#include "base/causal_softmax.h"
5+
#include "cambricon/common.h"
6+
7+
namespace infini::ops {
8+
9+
// Forward declaration - implementation is in kernel.mlu
// Launches the causal-softmax BANG kernel on `queue` and waits for it to
// finish (the .mlu side calls cnrtQueueSync before returning).
// `y`/`x` point to device memory; strides are in elements, not bytes.
// NOTE(review): `workspace` appears to be unused by the kernel and
// workspace_size_in_bytes() is 0 — confirm against kernel.mlu.
template <typename T>
void CausalSoftmaxUnion(void *workspace, int core_per_cluster,
                        int cluster_count, cnrtQueue_t queue, void *y,
                        const void *x, size_t batch_size_, size_t seq_len_,
                        size_t total_seq_len_, ptrdiff_t y_stride_b,
                        ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
                        ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
                        ptrdiff_t x_stride_j);
18+
19+
template <>
20+
class Operator<CausalSoftmax, Device::Type::kCambricon> : public CausalSoftmax {
21+
public:
22+
Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {
23+
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
24+
&cluster_count);
25+
}
26+
void operator()(const Tensor input, Tensor out) const override {
27+
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
28+
auto workspace{workspace_ ? workspace_ : default_workspace_};
29+
ptrdiff_t y_stride_b = ndim_ == 3 ? out_strides_[0] : 1;
30+
ptrdiff_t y_stride_i = ndim_ == 3 ? out_strides_[1] : out_strides_[0];
31+
ptrdiff_t y_stride_j = ndim_ == 3 ? out_strides_[2] : out_strides_[1];
32+
ptrdiff_t x_stride_b = ndim_ == 3 ? input_strides_[0] : 1;
33+
ptrdiff_t x_stride_i = ndim_ == 3 ? input_strides_[1] : input_strides_[0];
34+
ptrdiff_t x_stride_j = ndim_ == 3 ? input_strides_[2] : input_strides_[1];
35+
36+
DispatchFunc<
37+
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>(
38+
{input.dtype()},
39+
[&](auto input_tag) {
40+
using InputT = typename decltype(input_tag)::type;
41+
CausalSoftmaxUnion<InputT>(
42+
workspace, core_per_cluster, cluster_count, queue, out.data(),
43+
input.data(), batch_size_, seq_len_, total_seq_len_, y_stride_b,
44+
y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j);
45+
},
46+
"CambriconCausalSoftmax::operator() - output dispatch");
47+
}
48+
49+
std::size_t workspace_size_in_bytes() const override { return 0; }
50+
51+
~Operator() {}
52+
53+
void *default_workspace_{nullptr};
54+
int core_per_cluster = 0;
55+
int cluster_count = 0;
56+
};
57+
58+
} // namespace infini::ops
59+
60+
#endif
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#include "causal_softmax.h"
2+
3+
// On-chip NRAM scratch buffer shared by all device functions in this
// translation unit (used sequentially, never concurrently within one task).
__nram__ char nram_buffer[NRAM_MAX_SIZE];
// Byte budget for a single staging region inside nram_buffer; the remaining
// space is left for type-conversion scratch and the reduction buffers.
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;
5+
6+
namespace infini::ops {
7+
8+
// One element-wise pass of the softmax pipeline over a (possibly strided)
// row of `num_elements` values, processed in NRAM-sized chunks:
//   - exp phase (is_exp_phase == true): output[j] = exp(input[j] - scalar),
//     where `scalar` is the row maximum;
//   - normalize phase (is_exp_phase == false): output[j] *= scalar, where
//     `scalar` is 1.0f / row_sum (input is ignored; data is read from
//     `output`).
// Half/bfloat16 data is staged through a T-typed buffer and converted to
// float for the arithmetic, then converted back before scattering.
// `stride` is the element stride between consecutive row entries.
template <typename T>
__mlu_func__ void ProcessSoftmaxStep(const T *input, T *output, float scalar,
                                     int num_elements, int stride,
                                     bool is_exp_phase) {
  // Calculate buffer sizes (split between float and T buffers)
  constexpr bool is_half = std::is_same_v<T, __half>;
  constexpr bool is_bfloat16 = std::is_same_v<T, __bang_bfloat16>;
  constexpr bool is_float = !is_half && !is_bfloat16;

  // For 16-bit types half of the budget is reserved for the raw-T staging
  // buffer, so the float chunk is correspondingly smaller.
  const int chunk_size =
      SRC_MAX_SIZE /
      ((is_half || is_bfloat16) ? (2 * sizeof(float)) : sizeof(float));
  float *float_buffer = (float *)nram_buffer;
  T *temp_buffer =
      is_float ? nullptr : (T *)(nram_buffer + chunk_size * sizeof(float));

  // Common stride configurations
  const int src_stride = stride * sizeof(T);
  const int dst_stride = stride * sizeof(T);

  int processed = 0;
  while (processed < num_elements) {
    int curr_batch = std::min(chunk_size, num_elements - processed);

    // Gather input elements using 2D memcpy.
    // In the normalize phase the source is `output` (in-place update).
    // NOTE(review): the last __memcpy argument is `curr_batch - 1`; this
    // assumes the BANG 2D-memcpy segment count is "additional segments"
    // (N+1 total) — confirm against the BANG C __memcpy reference.
    if constexpr (is_float) {
      __memcpy(
          float_buffer, (is_exp_phase ? input : output) + processed * stride,
          sizeof(float), GDRAM2NRAM, sizeof(float), src_stride, curr_batch - 1);
    } else {
      __memcpy(temp_buffer,
               (is_exp_phase ? input : output) + processed * stride, sizeof(T),
               GDRAM2NRAM, sizeof(T), src_stride, curr_batch - 1);

      // Convert to float
      if constexpr (is_half) {
        __bang_half2float(float_buffer, reinterpret_cast<half *>(temp_buffer),
                          curr_batch);
      } else if constexpr (is_bfloat16) {
        __bang_bfloat162float(float_buffer, temp_buffer, curr_batch);
      }
    }

    // Common processing for all types
    if (is_exp_phase) {
      __bang_sub_scalar(float_buffer, float_buffer, scalar,
                        curr_batch); // scalar is max_val
      __bang_active_exphp(float_buffer, float_buffer, curr_batch);
    } else {
      __bang_mul_scalar(float_buffer, float_buffer, scalar,
                        curr_batch); // scalar is 1.0f/sum_val
    }

    // Convert back and scatter output using 2D memcpy
    if constexpr (is_float) {
      __memcpy(output + processed * stride, float_buffer, sizeof(float),
               NRAM2GDRAM, dst_stride, sizeof(float), curr_batch - 1);
    } else {
      // Convert back to original type
      if constexpr (is_half) {
        __bang_float2half(reinterpret_cast<half *>(temp_buffer), float_buffer,
                          curr_batch);
      } else if constexpr (is_bfloat16) {
        __bang_float2bfloat16(temp_buffer, float_buffer, curr_batch);
      }

      // Scatter output
      __memcpy(output + processed * stride, temp_buffer, sizeof(T), NRAM2GDRAM,
               dst_stride, sizeof(T), curr_batch - 1);
    }

    processed += curr_batch;
  }
}
82+
83+
// Device entry point: causal softmax over the rows of a
// (batch, seq, total_seq) tensor. Each (batch, i) row is processed
// independently: positions beyond the causal boundary are zeroed, the rest
// get the numerically-stable softmax exp(x - max) / sum.
// Rows are statically partitioned across the launched tasks.
template <typename T>
__mlu_global__ void CausalSoftmax(T *y, const T *x, size_t batch_size,
                                  size_t seq_len, size_t total_seq_len,
                                  ptrdiff_t y_stride_b, ptrdiff_t y_stride_i,
                                  ptrdiff_t y_stride_j, ptrdiff_t x_stride_b,
                                  ptrdiff_t x_stride_i, ptrdiff_t x_stride_j) {
  size_t task_id = taskId;
  // NOTE(review): the host launches with dim = (x, y, 1), so
  // taskDimX * taskDimY equals the total task count here; if a z-dimension
  // were ever used this would under-count — consider taskDim instead.
  size_t task_num = taskDimX * taskDimY;

  // Static, contiguous partition of all rows (batch * seq) across tasks.
  size_t total_tasks = batch_size * seq_len;
  size_t tasks_per_core = (total_tasks + task_num - 1) / task_num;
  size_t start = task_id * tasks_per_core;
  size_t end = std::min(start + tasks_per_core, total_tasks);

  // Split nram_buffer: `src` stages raw T elements, `dst` is the float
  // scratch used by the MaxBatched/SumBatched reductions.
  const int max_batch = SRC_MAX_SIZE / sizeof(T);
  T *src = (T *)nram_buffer;
  float *dst = (float *)(nram_buffer + max_batch * sizeof(T));

  for (size_t index = start; index < end; index++) {
    size_t batch = index / seq_len;
    size_t i = (index % seq_len);
    ptrdiff_t y_offset = batch * y_stride_b + i * y_stride_i;
    ptrdiff_t x_offset = batch * x_stride_b + i * x_stride_i;
    T *y_ = y + y_offset;
    const T *x_ = x + x_offset;

    // Calculate the valid sequence length for this position.
    size_t valid_len = total_seq_len - seq_len + i + 1;

    // Zero out future positions (one scalar GDRAM store per element).
    for (size_t j = valid_len; j < total_seq_len; j++) {
      y_[j * y_stride_j] = (T)0.0f;
    }

    // Calculate max value using optimized reduction.
    // NOTE(review): MaxBatched/SumBatched take no element stride — this is
    // only correct if the innermost stride (x_stride_j / y_stride_j) is 1 or
    // the helpers handle strided rows internally; confirm against the reduce
    // helper implementation.
    float max_val =
        infini::ops::reduce::MaxBatched(x_, src, dst, valid_len, max_batch);

    // Compute exp(x - max).
    ProcessSoftmaxStep(x_, y_, max_val, valid_len, x_stride_j, true);

    // Calculate sum of exponentials.
    float sum_val =
        infini::ops::reduce::SumBatched(y_, src, dst, valid_len, max_batch);

    // Normalize by sum (in place on y_).
    ProcessSoftmaxStep(y_, y_, 1.0f / sum_val, valid_len, y_stride_j, false);
  }
}
132+
133+
template <typename T>
134+
void CausalSoftmaxUnion(void *workspace, int core_per_cluster,
135+
int cluster_count, cnrtQueue_t queue, void *y,
136+
const void *x, size_t batch_size_, size_t seq_len_,
137+
size_t total_seq_len_, ptrdiff_t y_stride_b,
138+
ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
139+
ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
140+
ptrdiff_t x_stride_j) {
141+
cnrtDim3_t kernel_dim;
142+
cnrtFunctionType_t kernel_type;
143+
144+
kernel_dim.x = core_per_cluster;
145+
kernel_dim.y = cluster_count;
146+
kernel_dim.z = 1;
147+
kernel_type = cnrtFuncTypeUnion1;
148+
149+
CausalSoftmax<T><<<kernel_dim, kernel_type, queue>>>(
150+
(T *)y, (const T *)x, batch_size_, seq_len_, total_seq_len_, y_stride_b,
151+
y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j);
152+
153+
cnrtQueueSync(queue);
154+
}
155+
156+
// Explicit instantiations for the dtypes dispatched from the host-side
// operator (fp16, bf16, fp32); keeps the template definitions private to
// this translation unit.
template void CausalSoftmaxUnion<__half>(void *, int, int, cnrtQueue_t, void *,
                                         const void *, size_t, size_t, size_t,
                                         ptrdiff_t, ptrdiff_t, ptrdiff_t,
                                         ptrdiff_t, ptrdiff_t, ptrdiff_t);

template void CausalSoftmaxUnion<__bang_bfloat16>(
    void *, int, int, cnrtQueue_t, void *, const void *, size_t, size_t, size_t,
    ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t);

template void CausalSoftmaxUnion<float>(void *, int, int, cnrtQueue_t, void *,
                                        const void *, size_t, size_t, size_t,
                                        ptrdiff_t, ptrdiff_t, ptrdiff_t,
                                        ptrdiff_t, ptrdiff_t, ptrdiff_t);
169+
170+
} // namespace infini::ops

0 commit comments

Comments
 (0)