Skip to content

Commit 8c92b2e

Browse files
bitzyzvolt and jia authored
feat: add Cambricon RMSNorm (#19)
* feat: Add RMSNorm op in cambricon backend.
* refactor: make `Cast` utility use `Device::Type` template parameter
* refactor: add `Caster` mixin
* refactor: rename `cast**` to `caster**`
* fix: fix the mlu naming to google c++ naming style
* chore: format files with `clang-format`
* refactor: update CUDA kernels to use `Caster`
* fix: fix rmsnorm dispatch to use one dispatch

---------

Co-authored-by: Jiacheng Huang <huangjiacheng0709@outlook.com>
1 parent 56f3330 commit 8c92b2e

File tree

37 files changed

+801
-209
lines changed

37 files changed

+801
-209
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ if(WITH_CAMBRICON)
179179
endif()
180180

181181
# If all other platforms are not enabled, CPU is enabled by default.
182-
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE)
182+
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON)
183183
add_compile_definitions(WITH_CPU=1)
184184
endif()
185185

src/CMakeLists.txt

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,48 @@ if(WITH_MOORE)
127127
endif()
128128

129129
if(WITH_CAMBRICON)
130-
target_compile_definitions(infiniops PUBLIC WITH_CAMBRICON=1)
130+
file(GLOB_RECURSE CAMBRICON_MLU_SOURCES CONFIGURE_DEPENDS "cambricon/*/*.mlu")
131+
find_program(CNCC_COMPILER cncc HINTS "${NEUWARE_HOME}/bin" "$ENV{NEUWARE_HOME}/bin" /usr/local/neuware/bin)
132+
if(CNCC_COMPILER)
133+
message(STATUS "Found cncc: ${CNCC_COMPILER}")
134+
set(MLU_COMPILE_OPTS
135+
-c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread
136+
-I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include
137+
-idirafter /usr/local/neuware/lib/clang/11.1.0/include
138+
)
139+
function(compile_mlu_file src_file)
140+
get_filename_component(name ${src_file} NAME_WE)
141+
get_filename_component(path ${src_file} DIRECTORY)
142+
set(out_file "${CMAKE_CURRENT_BINARY_DIR}/${path}/${name}.o")
143+
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${path}")
144+
add_custom_command(OUTPUT ${out_file}
145+
COMMAND ${CNCC_COMPILER} ${MLU_COMPILE_OPTS} -c ${src_file} -o ${out_file}
146+
DEPENDS ${src_file}
147+
COMMENT "Building MLU kernel: ${src_file}"
148+
)
149+
set_property(DIRECTORY APPEND PROPERTY CAMBRICON_OBJECTS ${out_file})
150+
endfunction()
151+
foreach(src ${CAMBRICON_MLU_SOURCES})
152+
compile_mlu_file(${src})
153+
endforeach()
154+
get_directory_property(CAMBRICON_OBJECT_FILES CAMBRICON_OBJECTS)
155+
if(CAMBRICON_OBJECT_FILES)
156+
target_sources(infiniops PRIVATE ${CAMBRICON_OBJECT_FILES})
157+
endif()
158+
else()
159+
message(WARNING "cncc compiler not found. MLU kernels will not be compiled.")
160+
endif()
161+
target_compile_definitions(infiniops PRIVATE WITH_CAMBRICON=1)
131162

132163
target_include_directories(infiniops PUBLIC "${NEUWARE_HOME}/include")
133164
target_link_libraries(infiniops PUBLIC ${CAMBRICON_RUNTIME_LIB} ${CAMBRICON_CNNL_LIB} ${CAMBRICON_CNNL_EXTRA_LIB} ${CAMBRICON_PAPI_LIB})
134165

166+
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
167+
target_compile_options(infiniops PUBLIC
168+
"$<$<COMPILE_LANGUAGE:CXX>:SHELL:-idirafter /usr/local/neuware/lib/clang/11.1.0/include>"
169+
)
170+
endif()
171+
135172
list(APPEND DEVICE_LIST "cambricon")
136173
endif()
137174

src/base/rms_norm.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,17 @@ namespace infini::ops {
1212
class RmsNorm : public Operator<RmsNorm> {
1313
public:
1414
RmsNorm(const Tensor input, const Tensor weight, float eps, Tensor out)
15-
: eps_{eps},
15+
: input_shape_{input.shape()},
1616
out_shape_{out.shape()},
17-
input_shape_{input.shape()},
18-
out_strides_{out.strides()},
1917
input_strides_{input.strides()},
18+
out_strides_{out.strides()},
19+
eps_{eps},
2020
dim_{out.size(-1)},
2121
ndim_{out.ndim()},
2222
batch_size_{ndim_ == 2 ? out.size(-2) : out.size(-3)},
23-
nhead_{ndim_ == 2 ? 1 : out.size(-2)} {}
23+
nhead_{ndim_ == 2 ? 1 : out.size(-2)} {
24+
assert(input.dtype() == out.dtype());
25+
}
2426

2527
RmsNorm(const Tensor input, const Tensor weight, Tensor out)
2628
: RmsNorm{input, weight, 1e-6f, out} {}

src/cambricon/common.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,74 @@
22
#define INFINI_OPS_CAMBRICON_COMMON_H_
33

44
#include <cnnl.h>
5+
#include <cnrt.h>
56

67
#include "data_type.h"
8+
#include "device.h"
9+
10+
#define NRAM_MAX_SIZE (1024 * 240)
11+
12+
#ifdef __BANG__
13+
14+
namespace infini::ops::reduce {
15+
16+
constexpr int batch_size = 128 / sizeof(float);
17+
18+
__mlu_func__ void SumInternal(float* dst, float* src, int max_batch) {
19+
const int width = max_batch / batch_size;
20+
21+
if (width >= 4) {
22+
__bang_sumpool(dst, src, batch_size, 1, width, 1, width, 1, 1);
23+
__bang_reduce_sum(dst, dst, batch_size);
24+
} else {
25+
float sum = 0.0f;
26+
for (int i = 0; i < max_batch; ++i) {
27+
sum += src[i];
28+
}
29+
dst[0] = sum;
30+
}
31+
}
32+
33+
} // namespace infini::ops::reduce
34+
35+
#endif // __BANG__
736

837
namespace infini::ops::cnnl_utils {
938

1039
inline cnnlDataType_t GetDataType(DataType dtype) {
1140
switch (dtype) {
41+
case DataType::kInt8:
42+
return CNNL_DTYPE_INT8;
43+
case DataType::kUInt8:
44+
return CNNL_DTYPE_UINT8;
1245
case DataType::kInt32:
1346
return CNNL_DTYPE_INT32;
47+
case DataType::kInt64:
48+
return CNNL_DTYPE_INT64;
1449
case DataType::kFloat16:
1550
return CNNL_DTYPE_HALF;
1651
case DataType::kFloat32:
1752
return CNNL_DTYPE_FLOAT;
53+
case DataType::kBFloat16:
54+
return CNNL_DTYPE_BFLOAT16;
55+
case DataType::kFloat64:
56+
return CNNL_DTYPE_DOUBLE;
1857
default:
1958
return CNNL_DTYPE_INVALID;
2059
}
2160
}
2261

2362
} // namespace infini::ops::cnnl_utils
2463

64+
namespace infini::ops::cnrt_utils {
65+
66+
inline void GetLaunchConfig(const Device& device, int* core_per_cluster,
67+
int* cluster_count) {
68+
int device_id = device.index();
69+
cnrtDeviceGetAttribute(cluster_count, cnrtAttrClusterCount, device_id);
70+
cnrtDeviceGetAttribute(core_per_cluster, cnrtAttrMcorePerCluster, device_id);
71+
}
72+
73+
} // namespace infini::ops::cnrt_utils
74+
2575
#endif

0 commit comments

Comments (0)