Skip to content

Commit dc253c9

Browse files
author
zhangyue
committed
feat(ascend): add custom AscendC kernels for RmsNorm and AddRmsNorm
Standalone AscendC kernel project with CMake build system. Includes op_host tiling, op_kernel device code, precision tests, and msprof benchmarks for both operators.
1 parent 33308a4 commit dc253c9

26 files changed

Lines changed: 2581 additions & 0 deletions
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Generated directories — build tree, packaged output, and Python bindings.
build/
output/
python/
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
cmake_minimum_required(VERSION 3.20 FATAL_ERROR)
project(ascend-kernel LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Default to a release build when the user did not pick a build type.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE RELEASE)
endif()

# Warning set applied to every target in the project.
add_compile_options(-Wunused-value -Wcast-align -Wcast-qual -Wwrite-strings
                    -Wsign-compare -Wextra)

# Quoted STREQUAL instead of unquoted MATCHES: the unquoted form
# re-dereferences and substring-matches, and misbehaves on an empty value.
if("${CMAKE_BUILD_TYPE}" STREQUAL "RELEASE")
  add_compile_options(-O3 -fvisibility=hidden -fvisibility-inlines-hidden
                      -fstack-protector-strong -fPIE -fPIC)
  message(STATUS "build type set to RELEASE")
else()
  # -rdynamic is a linker option, not a compile option; route it to the
  # link line so debug builds export symbols for backtraces.
  add_compile_options(-g)
  add_link_options(-rdynamic)
endif()

# Project layout.
set(PROJECT_OP_SRC_BASE ${PROJECT_SOURCE_DIR}/csrc)    # operator sources
set(PROJECT_BUILD_PATH ${PROJECT_SOURCE_DIR}/build)    # build tree
set(PROJECT_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/output)  # final .so location

include(cmake/config_envs.cmake)   # python / torch / torch_npu discovery
include(cmake/config_ascend.cmake) # CANN toolkit discovery

# Use ccache to speed up rebuilds when it is installed.
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
  message(STATUS "Found ccache: ${CCACHE_PROGRAM}")
  set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
  set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()

add_subdirectory(csrc)

src/ascend/custom_kernel/build.sh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/bin/bash
# Build custom AscendC kernels into libascend_kernel.so.
#
# Usage: ./build.sh [SOC_VERSION]   (default: Ascend910_9382)
set -euo pipefail

SOC_VERSION="${1:-Ascend910_9382}"

# Detect the CANN toolkit path from the installer's info file.
# Validate explicitly: a failed grep inside $() is NOT caught by `set -e`,
# and sourcing an empty path produces a confusing error.
ASCEND_INSTALL_INFO=/etc/Ascend/ascend_cann_install.info
if [[ ! -r "${ASCEND_INSTALL_INFO}" ]]; then
    echo "error: ${ASCEND_INSTALL_INFO} not found; is the CANN toolkit installed?" >&2
    exit 1
fi
_CANN_TOOLKIT_INSTALL_PATH=$(awk -F'=' '/Toolkit_InstallPath/{print $2}' "${ASCEND_INSTALL_INFO}")
if [[ -z "${_CANN_TOOLKIT_INSTALL_PATH}" ]]; then
    echo "error: Toolkit_InstallPath not present in ${ASCEND_INSTALL_INFO}" >&2
    exit 1
fi
# set_env.sh exports ASCEND_TOOLKIT_HOME (and normally ASCEND_HOME_PATH).
# shellcheck disable=SC1091
source "${_CANN_TOOLKIT_INSTALL_PATH}/set_env.sh"
echo "CANN: ${ASCEND_TOOLKIT_HOME}"

ASCEND_INCLUDE_DIR=${ASCEND_TOOLKIT_HOME}/$(arch)-linux/include
CURRENT_DIR=$(pwd)
OUTPUT_DIR=${CURRENT_DIR}/output
mkdir -p "${OUTPUT_DIR}"

# Always configure from a clean build tree.
BUILD_DIR=build
rm -rf "${BUILD_DIR}"
mkdir -p "${BUILD_DIR}"

# Fall back to ASCEND_TOOLKIT_HOME if set_env.sh did not export
# ASCEND_HOME_PATH (guards against `set -u` aborting on an unset var).
cmake \
    -DASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${ASCEND_TOOLKIT_HOME}}" \
    -DASCEND_INCLUDE_DIR="${ASCEND_INCLUDE_DIR}" \
    -DSOC_VERSION="${SOC_VERSION}" \
    -B "${BUILD_DIR}" \
    -S .

cmake --build "${BUILD_DIR}" -j 16

echo "Build complete. Output: ${OUTPUT_DIR}"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
2+
# Locate the Ascend CANN toolkit and pull in the AscendC kernel cmake module.
#
# ASCEND_HOME_PATH may come from the cmake command line
# (-DASCEND_HOME_PATH=...) or from the environment; the command-line value
# wins. Fails fast with a clear message when neither is set, instead of the
# confusing path-probe failure the empty-branch original produced.
if(NOT DEFINED ASCEND_HOME_PATH)
  if(DEFINED ENV{ASCEND_HOME_PATH})
    # No FORCE: never stomp a user-provided cache value.
    set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}" CACHE PATH "ASCEND CANN package installation directory")
  else()
    message(FATAL_ERROR "ASCEND_HOME_PATH is not set; pass -DASCEND_HOME_PATH=<path> or source the CANN set_env.sh first.")
  endif()
endif()

set(ASCEND_CANN_PACKAGE_PATH ${ASCEND_HOME_PATH})

# The ascendc kernel cmake module location differs between CANN releases;
# probe the known layouts in order.
if(EXISTS ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
  set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
  set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
  set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
else()
  message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the cann package is installed.")
endif()

include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)

message(STATUS "ASCEND_CANN_PACKAGE_PATH = ${ASCEND_CANN_PACKAGE_PATH}")
message(STATUS "ASCEND_HOME_PATH = ${ASCEND_HOME_PATH}")
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Discover the python toolchain and derive torch / torch_npu / pybind11 paths.
#
# Exports (consumed by csrc/CMakeLists.txt):
#   TORCH_DIR, TORCH_NPU_DIR, PYBIND11_DIR, PYTHON_INCLUDE_DIR
# and sets -D_GLIBCXX_USE_CXX11_ABI to match the installed torch wheel.

# find python binary
find_program(PYTHON_EXECUTABLE NAMES python3)

if(NOT PYTHON_EXECUTABLE)
  message(FATAL_ERROR "python3 is not found, install python firstly")
endif()

# One python invocation prints all five values space-separated.
execute_process(
  COMMAND ${PYTHON_EXECUTABLE} "-c"
  "import torch; import torch_npu; import os; import pybind11; import sysconfig;
torch_dir = os.path.realpath(os.path.dirname(torch.__file__));
torch_npu_dir = os.path.realpath(os.path.dirname(torch_npu.__file__));
pybind11_dir = os.path.realpath(os.path.dirname(pybind11.__file__));
abi_enabled=torch.compiled_with_cxx11_abi();
python_include_dir = sysconfig.get_path('include');
print(torch_dir, torch_npu_dir, pybind11_dir, abi_enabled, python_include_dir, end='');
quit(0)
"
  RESULT_VARIABLE EXEC_RESULT
  OUTPUT_VARIABLE OUTPUT_ENV_DEFINES)

# if failed to run the python script
if(NOT EXEC_RESULT EQUAL 0)
  message(FATAL_ERROR "failed to run the python script that probes TORCH_DIR and related paths")
else()
  message(STATUS "run python script successfully, output string is [${OUTPUT_ENV_DEFINES}]")
endif()

# Split the space-separated probe output with native CMake list operations
# instead of five `sh -c "echo | awk"` subprocesses (portable, no /bin/sh
# dependency, same behavior for space-free paths).
separate_arguments(_env_values UNIX_COMMAND "${OUTPUT_ENV_DEFINES}")
list(LENGTH _env_values _env_count)
if(NOT _env_count EQUAL 5)
  message(FATAL_ERROR "unexpected python probe output: [${OUTPUT_ENV_DEFINES}]")
endif()
list(GET _env_values 0 TORCH_DIR)
list(GET _env_values 1 TORCH_NPU_DIR)
list(GET _env_values 2 PYBIND11_DIR)
list(GET _env_values 3 TORCH_API_ENABLED)  # torch.compiled_with_cxx11_abi(): "True"/"False"
list(GET _env_values 4 PYTHON_INCLUDE_DIR)
unset(_env_values)
unset(_env_count)

message(STATUS "SOC_VERSION=${SOC_VERSION}")
message(STATUS "TORCH_DIR=${TORCH_DIR}")
message(STATUS "TORCH_NPU_DIR=${TORCH_NPU_DIR}")
message(STATUS "PYBIND11_DIR=${PYBIND11_DIR}")
message(STATUS "PYTHON_INCLUDE_DIR=${PYTHON_INCLUDE_DIR}")

# Compile with the same C++ ABI as the installed torch wheel.
if("${TORCH_API_ENABLED}" STREQUAL "True")
  add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=1)
  message(STATUS "_GLIBCXX_USE_CXX11_ABI=1")
else()
  add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0)
  message(STATUS "_GLIBCXX_USE_CXX11_ABI=0")
endif()
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Build libascend_kernel.so: host-side sources linked against the AscendC
# device-kernel static library and the torch / CANN runtime libraries.

# Set the library output dir to the project output for linking.
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_OUTPUT_PATH})

# Host side files — listed explicitly. file(GLOB) over fixed names adds
# nothing and silently drops files that do not exist.
set(OP_SRCS
    ${PROJECT_OP_SRC_BASE}/register.cpp
    ${PROJECT_OP_SRC_BASE}/ops/rms_norm/op_host/rms_norm.cpp
)

# Set the shared library name (produces libascend_kernel.so).
set(OP_PLUGIN_NAME ascend_kernel)

# Kernel side files (device code compiled by AscendC toolchain).
ascendc_library(no_workspace_kernel STATIC
    ${PROJECT_OP_SRC_BASE}/ops/rms_norm/op_kernel/rms_norm.cpp
)

# Create shared library libascend_kernel.so.
add_library(${OP_PLUGIN_NAME} SHARED ${OP_SRCS})

target_link_libraries(${OP_PLUGIN_NAME} PRIVATE
    no_workspace_kernel
    torch_npu
    ascendcl
    tiling_api
    nnopbase
    opapi
    register
    platform
    ascendalog
    ${CMAKE_DL_LIBS}  # portable spelling of -ldl
)

target_link_directories(${OP_PLUGIN_NAME} PRIVATE
    ${TORCH_DIR}/lib
    ${TORCH_NPU_DIR}/lib
)

target_include_directories(${OP_PLUGIN_NAME} PRIVATE
    ${PROJECT_OP_SRC_BASE}/utils
    ${PROJECT_SOURCE_DIR}/include
    ${TORCH_DIR}/include
    ${TORCH_DIR}/include/torch/csrc/api/include
    ${TORCH_NPU_DIR}/include/third_party/acl/inc
    ${TORCH_NPU_DIR}/include/third_party/hccl/inc
    ${TORCH_NPU_DIR}/include
    ${PYTHON_INCLUDE_DIR}
    ${ASCEND_INCLUDE_DIR}/external
    ${ASCEND_INCLUDE_DIR}/experiment/platform
    ${ASCEND_INCLUDE_DIR}/experiment/runtime
)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Licensed under the BSD 3-Clause License (the "License");
2+
// you may not use this file except in compliance with the License.
3+
// You may obtain a copy of the License at
4+
//
5+
// Unless required by applicable law or agreed to in writing, software
6+
// distributed under the License is distributed on an "AS IS" BASIS,
7+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8+
// See the License for the specific language governing permissions and
9+
// limitations under the License.
10+
11+
#ifndef OPS_H
12+
#define OPS_H
13+
14+
namespace ascend_kernel {
15+
16+
at::Tensor rms_norm(const at::Tensor &input, const at::Tensor &weight,
17+
double eps);
18+
19+
} // namespace ascend_kernel
20+
21+
#endif // OPS_H
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ascendc_add_operator(OP_NAME add_rms_norm)
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
 * Copyright (c) 2025, InfiniTensor.
 * All rights reserved.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

// Host-side launcher for the fused add + RMS-norm AscendC device kernel.

#include "torch_kernel_helper.h"
#include "tiling/platform/platform_ascendc.h"
#include "aclrtlaunch_add_rms_norm.h"

namespace ascend_kernel {

// Launch the fused add_rms_norm kernel.
//
// Treats the inputs as a (totalRows, dimLength) matrix where
// dimLength = x1.size(-1), pads rows to a 32-byte-aligned length, splits
// rows across AI-vector cores, and launches the device kernel.
// Returns {y, xout}, both shaped and typed like x1.
// NOTE(review): the per-element math (presumably xout = x1 + x2 and
// y = rmsnorm(xout) * weight) lives in the device kernel — confirm against
// the op_kernel source.
std::vector<at::Tensor> add_rms_norm(const at::Tensor &x1,
                                     const at::Tensor &x2,
                                     const at::Tensor &weight, double eps) {
    // Input validation.
    TORCH_CHECK(x1.dim() > 0,
                "add_rms_norm: x1 must have at least 1 dimension");
    TORCH_CHECK(x1.sizes() == x2.sizes(),
                "add_rms_norm: x1 and x2 must have the same shape");
    TORCH_CHECK(x1.scalar_type() == x2.scalar_type(),
                "add_rms_norm: x1 and x2 must have the same dtype");
    TORCH_CHECK(x1.scalar_type() == at::kHalf ||
                    x1.scalar_type() == at::kFloat,
                "add_rms_norm: only float16 and float32 are supported, got ",
                x1.scalar_type());
    TORCH_CHECK(weight.dim() == 1,
                "add_rms_norm: weight must be 1-dimensional");
    TORCH_CHECK(weight.size(0) == x1.size(-1),
                "add_rms_norm: weight size (", weight.size(0),
                ") must match input last dim (", x1.size(-1), ")");

    // Flatten to a 2-D view: one normalization per row.
    int64_t dimLength = x1.size(-1);
    int64_t totalRows = x1.numel() / dimLength;

    // Degenerate shapes: nothing to launch.
    if (totalRows == 0 || dimLength == 0) {
        return {at::empty_like(x1), at::empty_like(x1)};
    }

    at::Tensor inp1 = x1.contiguous();
    at::Tensor inp2 = x2.contiguous();
    int64_t dtypeSize = inp1.element_size();  // 2 (fp16) or 4 (fp32), per check above

    // Hardware parameters: vector-core count and unified-buffer (UB) size.
    auto ascendc_platform =
        platform_ascendc::PlatformAscendCManager::GetInstance();
    int64_t coreNum =
        static_cast<int64_t>(ascendc_platform->GetCoreNumAiv());
    uint64_t ubSize;
    ascendc_platform->GetCoreMemSize(platform_ascendc::CoreMemType::UB,
                                     ubSize);
    int64_t ubSizeLimit = static_cast<int64_t>(ubSize);

    // Alignment (32-byte boundary): round the row length up to a whole
    // number of 32-byte units of the element type.
    int64_t alignElements = 32 / dtypeSize;
    int64_t dimLengthAlign =
        ((dimLength + alignElements - 1) / alignElements) * alignElements;

    // UB capacity check. `bufferCoefficient` = UB bytes consumed per row
    // element by the kernel's queues/buffers (double-buffered queues ×2):
    // fp16: inQ_x1(×2×2) + inQ_x2(×2×2) + outQ_y(×2×2) + outQ_xout(×2×2)
    //       + fp32Buf1(×4) + fp32Buf2(×4) + weight(×4) = 16 + 12 = 28
    // fp32: inQ_x1(×2×4) + inQ_x2(×2×4) + outQ_y(×2×4) + outQ_xout(×2×4)
    //       + weight(×4) = 32 + 4 = 36
    int64_t bufferCoefficient = (dtypeSize == 2) ? 28 : 36;
    // 1024 bytes of UB kept as headroom; result rounded down to fp32
    // (8-element) alignment.
    int64_t maxDimLength =
        (ubSizeLimit - 1024) / bufferCoefficient;
    int64_t fpAlignElements = 32 / 4;
    maxDimLength =
        (maxDimLength / fpAlignElements) * fpAlignElements;
    TORCH_CHECK(dimLengthAlign <= maxDimLength,
                "add_rms_norm: dimLength ", dimLength,
                " (aligned ", dimLengthAlign,
                ") exceeds UB capacity (max ", maxDimLength, ")");

    // Zero-pad each row to `dimLengthAlign` when the last dim is not
    // already aligned; otherwise just reshape to the 2-D layout.
    at::Tensor kernelInput1;
    at::Tensor kernelInput2;

    if (dimLength != dimLengthAlign) {
        kernelInput1 = inp1.reshape({totalRows, dimLength});
        kernelInput1 = at::constant_pad_nd(
            kernelInput1, {0, dimLengthAlign - dimLength}, 0.0);
        kernelInput1 = kernelInput1.contiguous();

        kernelInput2 = inp2.reshape({totalRows, dimLength});
        kernelInput2 = at::constant_pad_nd(
            kernelInput2, {0, dimLengthAlign - dimLength}, 0.0);
        kernelInput2 = kernelInput2.contiguous();
    } else {
        kernelInput1 =
            inp1.reshape({totalRows, dimLengthAlign}).contiguous();
        kernelInput2 =
            inp2.reshape({totalRows, dimLengthAlign}).contiguous();
    }

    // Outputs in the padded kernel layout.
    at::Tensor kernelOutputY = at::empty_like(kernelInput1);
    at::Tensor kernelOutputXOut = at::empty_like(kernelInput1);

    // Weight: always pass as fp32, padded to `dimLengthAlign`.
    at::Tensor weightFloat = weight.contiguous().to(at::kFloat);

    if (dimLength != dimLengthAlign) {
        weightFloat = at::constant_pad_nd(
            weightFloat, {0, dimLengthAlign - dimLength}, 0.0);
    }

    weightFloat = weightFloat.contiguous();

    // Block-level tiling (distribute rows across cores): the first
    // `formerNum` cores take `formerLength` rows each, the rest take
    // `tailLength` rows each. Invariant:
    //   formerNum*formerLength + (usedCoreNum-formerNum)*tailLength == totalRows
    int64_t usedCoreNum = std::min(totalRows, coreNum);
    int64_t formerLength =
        (totalRows + usedCoreNum - 1) / usedCoreNum;  // ceil division
    int64_t tailLength = formerLength - 1;
    int64_t formerNum = totalRows - tailLength * usedCoreNum;
    uint32_t blockDim = static_cast<uint32_t>(usedCoreNum);

    // All EXEC_KERNEL_CMD args must be lvalues.
    float epsFloat = static_cast<float>(eps);
    int64_t dtypeSizeVal = dtypeSize;

    EXEC_KERNEL_CMD(add_rms_norm, blockDim,
                    kernelInput1, kernelInput2, weightFloat,
                    kernelOutputY, kernelOutputXOut,
                    totalRows, dimLength, dimLengthAlign,
                    formerNum, formerLength, tailLength,
                    epsFloat, dtypeSizeVal);

    // Remove padding and reshape back to original shape.
    at::Tensor outputY = kernelOutputY;
    at::Tensor outputXOut = kernelOutputXOut;

    if (dimLength != dimLengthAlign) {
        outputY = outputY.narrow(-1, 0, dimLength).contiguous();
        outputXOut = outputXOut.narrow(-1, 0, dimLength).contiguous();
    }

    outputY = outputY.reshape(x1.sizes());
    outputXOut = outputXOut.reshape(x1.sizes());

    return {outputY, outputXOut};
}

} // namespace ascend_kernel

0 commit comments

Comments
 (0)