diff --git a/.ci/images/ascend/Dockerfile b/.ci/images/ascend/Dockerfile index 3ff79e1c..a542b99e 100644 --- a/.ci/images/ascend/Dockerfile +++ b/.ci/images/ascend/Dockerfile @@ -18,4 +18,12 @@ RUN pip install --no-cache-dir --progress off \ pytest-xdist \ ruff +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64:${LD_LIBRARY_PATH} +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH} +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit + WORKDIR /workspace diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f4c5cb4..7c4829d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,8 @@ option(WITH_CAMBRICON "Enable Cambricon backend" OFF) option(WITH_MOORE "Enable Moore backend" OFF) option(WITH_ASCEND "Enable Ascend backend" OFF) +option(BUILD_CUSTOM_KERNEL "Build custom AscendC kernel PyTorch extension (requires torch_npu)" OFF) + option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 0023c7e9..31eb92d6 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -94,7 +94,7 @@ def __init__(self, name, constructors, calls): def _find_optional_tensor_params(op_name): """Return a set of parameter names declared as `std::optional` in - the base header. `libclang` resolves the type to `int` when the STL + the base header. 
libclang resolves the type to ``int`` when the STL headers are not fully available, so we fall back to a regex scan of the source text. """ @@ -103,24 +103,39 @@ def _find_optional_tensor_params(op_name): return set(re.findall(r"std::optional\s+(\w+)", source)) +def _find_vector_tensor_params(op_name): + """Return a set of parameter names declared as `std::vector` in + the base header. + """ + source = (_BASE_DIR / f"{op_name}.h").read_text() + + return set(re.findall(r"std::vector\s+(\w+)", source)) + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) + vector_tensor_params = _find_vector_tensor_params(operator.name) def _is_optional_tensor(arg): if arg.spelling in optional_tensor_params: return True - return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + def _is_vector_tensor(arg): + if arg.spelling in vector_tensor_params: + return True + return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + def _generate_params(node): parts = [] for arg in node.get_arguments(): if arg.spelling == "stream": continue - if _is_optional_tensor(arg): parts.append(f"std::optional {arg.spelling}") + elif _is_vector_tensor(arg): + parts.append(f"std::vector {arg.spelling}") else: param = arg.type.spelling.replace("const Tensor", "py::object").replace( "Tensor", "py::object" @@ -135,9 +150,10 @@ def _generate_arguments(node): for arg in node.get_arguments(): if arg.spelling == "stream": continue - if _is_optional_tensor(arg): args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})") + elif _is_vector_tensor(arg): + args.append(f"VectorTensorFromPybind11Handle({arg.spelling})") elif "Tensor" in arg.type.spelling: args.append(f"TensorFromPybind11Handle({arg.spelling})") else: @@ -167,23 +183,23 @@ def _generate_call(op_name, call, method=True): if not method: params = ( - f"{call_params}, std::uintptr_t stream, std::size_t implementation_index" + f"{call_params}, std::size_t 
implementation_index, std::uintptr_t stream" if call_params - else "std::uintptr_t stream, std::size_t implementation_index" + else "std::size_t implementation_index, std::uintptr_t stream" ) py_args = _generate_py_args(call) py_args_str = f"{py_args}, " if py_args else "" return ( f' m.def("{op_name}", []({params}) {{\n' + f" Config config;\n" + f" config.set_implementation_index(implementation_index);\n" f" Handle handle;\n" f" if (stream) {{\n" f" handle.set_stream(reinterpret_cast(stream));\n" f" }}\n" - f" Config config;\n" - f" config.set_implementation_index(implementation_index);\n" f" return Self::call(handle, config, {call_args});\n" - f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);' + f' }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0, py::arg("stream") = 0);' ) return f""" .def("__call__", [](const Self& self, {call_params}) {{ @@ -442,7 +458,7 @@ def _get_all_ops(devices): nargs="+", default="cpu", type=str, - help="Devices to use. Please pick from `cpu`, `nvidia`, `cambricon`, `ascend`, `metax`, `moore`, `iluvatar`, `kunlun`, `hygon`, and `qy`. (default: `cpu`)", + help="Devices to use. Please pick from cpu, nvidia, cambricon, ascend, metax, moore, iluvatar, kunlun, hygon, and qy. (default: cpu)", ) args = parser.parse_args() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2eb2591d..06313b82 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -178,8 +178,10 @@ if(WITH_ASCEND) "ascend/*.cc" "ascend/*.cpp" ) - # Exclude `kernel_impl.cpp` — AscendC device code, not compiled by the host C++ compiler. + # Exclude kernel_impl.cpp — AscendC device code, not compiled by the host C++ compiler. list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*kernel_impl\\.cpp$") + # Exclude custom_kernel/ — standalone PyTorch extension, built separately. 
+ list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*/custom_kernel/.*") target_compile_definitions(infiniops PUBLIC WITH_ASCEND=1) target_sources(infiniops PRIVATE ${ASCEND_SOURCES}) @@ -215,7 +217,38 @@ if(WITH_ASCEND) "${ASCEND_HOME}/lib64/libopapi.so" "${ASCEND_HAL_LIB}") + # ATB (Ascend Transformer Boost) — provides fused operators like + # PagedAttention and ReshapeAndCache that are graph-capture safe. + set(ATB_HOME_DIR "$ENV{ATB_HOME_PATH}") + if(NOT ATB_HOME_DIR) + # Default search path under CANN nnal directory. + file(GLOB ATB_SEARCH_DIRS "/usr/local/Ascend/nnal/atb/*/atb/cxx_abi_1") + if(ATB_SEARCH_DIRS) + list(SORT ATB_SEARCH_DIRS ORDER DESCENDING) + list(GET ATB_SEARCH_DIRS 0 ATB_HOME_DIR) + endif() + endif() + + if(ATB_HOME_DIR AND EXISTS "${ATB_HOME_DIR}/include/atb/operation.h") + message(STATUS "ATB found: ${ATB_HOME_DIR}") + target_compile_definitions(infiniops PUBLIC INFINI_HAS_ATB=1) + target_include_directories(infiniops PUBLIC "${ATB_HOME_DIR}/include") + target_link_libraries(infiniops PUBLIC "${ATB_HOME_DIR}/lib/libatb.so") + else() + message(STATUS "ATB not found — ATB-based operators disabled") + endif() + list(APPEND DEVICE_LIST "ascend") + + # Custom AscendC kernels (PyTorch extension, requires torch_npu). + if(BUILD_CUSTOM_KERNEL) + add_subdirectory(ascend/custom_kernel) + + # Link the compiled AscendC kernel objects into infiniops so that + # custom kernel implementations (e.g. RmsNorm index 1) can call + # them via the generated launch functions. 
+ target_compile_definitions(infiniops PUBLIC INFINI_HAS_CUSTOM_RMS_NORM=1) + endif() endif() target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) @@ -257,6 +290,17 @@ if(GENERATE_PYTHON_BINDINGS) target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) target_link_libraries(ops PRIVATE infiniops) + # Custom AscendC kernel objects must be linked directly into ops + # because the AscendC toolchain compiles host stubs with hidden + # visibility — `libinfiniops.so` cannot re-export those symbols. + # The `Operator<..., 1>` template instantiations that call + # `aclrtlaunch_*` live in `ops.cc`, so link here with + # `--whole-archive` to ensure all launch functions are available. + if(BUILD_CUSTOM_KERNEL) + target_link_libraries(ops PRIVATE + -Wl,--whole-archive no_workspace_kernel -Wl,--no-whole-archive) + endif() + set_target_properties(infiniops PROPERTIES INSTALL_RPATH "$ORIGIN") set_target_properties(ops PROPERTIES INSTALL_RPATH "$ORIGIN") diff --git a/src/ascend/add/kernel.h b/src/ascend/add/kernel.h new file mode 100644 index 00000000..2c93b5a5 --- /dev/null +++ b/src/ascend/add/kernel.h @@ -0,0 +1,82 @@ +#ifndef INFINI_OPS_ASCEND_ADD_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add(input, other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) { + // `aclCreateScalar` stores the pointer rather than copying the value, so + // `alpha_storage_*` must remain alive for the lifetime of `alpha_`. + // The alpha scalar type must match the tensor dtype: use int64 for integer + // dtypes and float for floating-point dtypes. 
+ if (ascend::isIntegerDtype(input.dtype())) { + alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64); + } else { + alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT); + } + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + aclDestroyScalar(alpha_); + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnAdd(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; + + float alpha_float_storage_ = + 1.0f; // Stable address for `aclCreateScalar` (float). + int64_t alpha_int_storage_ = + 1; // Stable address for `aclCreateScalar` (int). 
+ aclScalar* alpha_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/kernel.h b/src/ascend/add_rms_norm/kernel.h new file mode 100644 index 00000000..7db8a91a --- /dev/null +++ b/src/ascend/add_rms_norm/kernel.h @@ -0,0 +1,134 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "aclnn_rms_norm.h" +#include "ascend/add_rms_norm/registry.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Decomposed implementation: `aclnnAdd` + `aclnnRmsNorm`. +// +// The fused `aclnnAddRmsNorm` API has ~200 us host-side launch overhead that +// dominates small-tensor dispatch. Decomposing into two fast ACLNN calls +// reduces host dispatch from ~224 us to ~56 us (4x faster) with negligible +// NPU-side impact for inference tensor sizes. +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out), + x1_cache_(x1), + x2_cache_(x2), + gamma_cache_(gamma), + y_out_cache_(y_out), + x_out_cache_(x_out) { + // Alpha scalar for `aclnnAdd` (x_out = x1 + 1.0 * x2). + alpha_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT); + + // `aclnnRmsNorm` writes `rstd` as a required side output. + // Size computed here; buffer obtained from pool in `operator()`. 
+ rstd_shape_ = {static_cast(batch_size_), + static_cast(nhead_)}; + rstd_size_ = batch_size_ * nhead_ * sizeof(float); + } + + ~Operator() { + if (add_exec_) aclDestroyAclOpExecutor(add_exec_); + if (norm_exec_) aclDestroyAclOpExecutor(norm_exec_); + aclDestroyScalar(alpha_); + if (rstd_tensor_) aclDestroyTensor(rstd_tensor_); + } + + void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const override { + auto t_x1 = x1_cache_.get(const_cast(x1.data())); + auto t_x2 = x2_cache_.get(const_cast(x2.data())); + auto t_gamma = gamma_cache_.get(const_cast(gamma.data())); + auto t_y_out = y_out_cache_.get(y_out.data()); + auto t_x_out = x_out_cache_.get(x_out.data()); + auto stream = static_cast(stream_); + + // Step 1: x_out = x1 + x2. + if (!add_exec_) { + aclnnAddGetWorkspaceSize(t_x1, t_x2, alpha_, t_x_out, &add_ws_, + &add_exec_); + aclSetAclOpExecutorRepeatable(add_exec_); + } else { + aclSetInputTensorAddr(add_exec_, 0, t_x1, const_cast(x1.data())); + aclSetInputTensorAddr(add_exec_, 1, t_x2, const_cast(x2.data())); + aclSetOutputTensorAddr(add_exec_, 0, t_x_out, x_out.data()); + } + auto& add_arena = ascend::workspacePool().ensure(stream, add_ws_); + aclnnAdd(add_arena.buf, add_ws_, add_exec_, stream); + + // Obtain shared rstd buffer from pool. + auto& rstd_arena = + ascend::workspacePool().ensure(stream, rstd_size_, "temp"); + + // Lazily create rstd tensor descriptor on first call. + if (!rstd_tensor_) { + rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_arena.buf); + } else { + aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf); + } + + // Step 2: y_out = rms_norm(x_out, gamma, eps). 
+ if (!norm_exec_) { + aclnnRmsNormGetWorkspaceSize(t_x_out, t_gamma, eps, t_y_out, rstd_tensor_, + &norm_ws_, &norm_exec_); + aclSetAclOpExecutorRepeatable(norm_exec_); + } else { + aclSetInputTensorAddr(norm_exec_, 0, t_x_out, x_out.data()); + aclSetInputTensorAddr(norm_exec_, 1, t_gamma, + const_cast(gamma.data())); + aclSetOutputTensorAddr(norm_exec_, 0, t_y_out, y_out.data()); + aclSetOutputTensorAddr(norm_exec_, 1, rstd_tensor_, rstd_arena.buf); + } + auto& norm_arena = ascend::workspacePool().ensure(stream, norm_ws_); + aclnnRmsNorm(norm_arena.buf, norm_ws_, norm_exec_, stream); + } + + private: + mutable ascend::AclTensorCache x1_cache_; + + mutable ascend::AclTensorCache x2_cache_; + + mutable ascend::AclTensorCache gamma_cache_; + + mutable ascend::AclTensorCache y_out_cache_; + + mutable ascend::AclTensorCache x_out_cache_; + + float alpha_storage_ = 1.0f; + + aclScalar* alpha_ = nullptr; + + std::vector rstd_shape_; + + uint64_t rstd_size_ = 0; + + mutable aclTensor* rstd_tensor_ = nullptr; + + mutable aclOpExecutor* add_exec_ = nullptr; + + mutable uint64_t add_ws_ = 0; + + mutable aclOpExecutor* norm_exec_ = nullptr; + + mutable uint64_t norm_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/kernel_custom.h b/src/ascend/add_rms_norm/kernel_custom.h new file mode 100644 index 00000000..5e80638a --- /dev/null +++ b/src/ascend/add_rms_norm/kernel_custom.h @@ -0,0 +1,174 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_ + +#ifdef INFINI_HAS_CUSTOM_ADD_RMS_NORM + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/add_rms_norm/registry.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add_rms_norm.h" +#include "operator.h" + +// Forward-declare the generated AscendC kernel launch function. 
+// This symbol is provided by the `no_workspace_kernel` static library +// built from +// `ascend/custom_kernel/csrc/ops/add_rms_norm/op_kernel/add_rms_norm.cpp` via +// `ascendc_library()`. +extern "C" uint32_t aclrtlaunch_add_rms_norm( + uint32_t blockDim, void* stream, void* x1, void* x2, void* weight, void* y, + void* x_out, int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign, + int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps, + int64_t dtypeSize); + +namespace infini::ops { + +// Custom AscendC fused AddRmsNorm kernel (implementation index 2). +// +// A single-kernel implementation that computes x_out = x1 + x2 followed by +// y = rms_norm(x_out, gamma, eps) in one launch, avoiding the decomposed +// `aclnnAdd` + `aclnnRmsNorm` calls (index 0) or the fused `aclnnAddRmsNorm` +// call (index 1). Migrated from the custom RmsNorm kernel (index 1 of +// RmsNorm). +// +// Select via `implementation_index=2` in Python: +// infini.ops.add_rms_norm(x1, x2, gamma, eps, y_out, x_out, +// implementation_index=2, stream=s) +// +// Requirements: +// - Input last dimension must be 32-byte aligned (divisible by 16 for fp16 +// or 8 for fp32). All standard LLM hidden dimensions satisfy this. +// - Weight must have the same dtype as input. +// - The custom kernel binary must be linked (`BUILD_CUSTOM_KERNEL=ON`). +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out) { + // Dtype size in bytes. + dtype_size_ = (x1.dtype() == DataType::kFloat16) ? 2 : 4; + + // Alignment check (32-byte boundary). 
+ int64_t align_elems = 32 / dtype_size_; + dim_length_align_ = + ((static_cast(dim_) + align_elems - 1) / align_elems) * + align_elems; + assert( + static_cast(dim_) == dim_length_align_ && + "custom `AddRmsNorm` kernel requires 32-byte aligned last dimension"); + + total_rows_ = + static_cast(batch_size_) * static_cast(nhead_); + + // For fp16 input, weight needs fp32 conversion because the custom + // kernel always reads weight as fp32. + needs_weight_cast_ = (dtype_size_ == 2); + + if (needs_weight_cast_) { + // Allocate persistent fp32 weight buffer on device. + size_t fp32_bytes = static_cast(dim_) * sizeof(float); + aclrtMalloc(&weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // AclTensorCache for the cast source (fp16 weight descriptor). + weight_src_cache_ = ascend::AclTensorCache({static_cast(dim_)}, + ACL_FLOAT16, nullptr); + + // AclTensorCache for the cast destination (fp32 weight buffer). + weight_dst_cache_ = ascend::AclTensorCache({static_cast(dim_)}, + ACL_FLOAT, weight_fp32_data_); + } + } + + ~Operator() { + if (!ascend::isAclRuntimeAlive()) return; + if (cast_exec_) aclDestroyAclOpExecutor(cast_exec_); + if (weight_fp32_data_) aclrtFree(weight_fp32_data_); + } + + void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const override { + auto stream = static_cast(stream_); + + // Determine fp32 weight pointer. + void* weight_fp32; + + if (needs_weight_cast_) { + // Only re-cast when the weight data pointer changes. Model weights + // are fixed after loading, so this typically runs once on the first + // call and is skipped on all subsequent calls. 
+ const void* cur_weight = gamma.data(); + + if (cur_weight != last_weight_ptr_) { + auto t_src = weight_src_cache_.get(const_cast(cur_weight)); + auto t_dst = weight_dst_cache_.get(weight_fp32_data_); + + if (!cast_exec_) { + aclnnCastGetWorkspaceSize(t_src, ACL_FLOAT, t_dst, &cast_ws_, + &cast_exec_); + aclSetAclOpExecutorRepeatable(cast_exec_); + } else { + aclSetInputTensorAddr(cast_exec_, 0, t_src, + const_cast(cur_weight)); + aclSetOutputTensorAddr(cast_exec_, 0, t_dst, weight_fp32_data_); + } + + auto& arena = ascend::workspacePool().ensure(stream, cast_ws_); + aclnnCast(arena.buf, cast_ws_, cast_exec_, stream); + last_weight_ptr_ = cur_weight; + } + + weight_fp32 = weight_fp32_data_; + } else { + // Input is fp32 — weight is already fp32. + weight_fp32 = const_cast(gamma.data()); + } + + // Block-level tiling: distribute rows across cores. + static constexpr int64_t kMaxBlockDim = 40; + int64_t used_cores = std::min(total_rows_, kMaxBlockDim); + int64_t former_length = (total_rows_ + used_cores - 1) / used_cores; + int64_t tail_length = former_length - 1; + int64_t former_num = total_rows_ - tail_length * used_cores; + uint32_t block_dim = static_cast(used_cores); + + // Launch custom AscendC kernel. 
+ aclrtlaunch_add_rms_norm( + block_dim, stream, const_cast(x1.data()), + const_cast(x2.data()), weight_fp32, y_out.data(), x_out.data(), + total_rows_, static_cast(dim_), dim_length_align_, former_num, + former_length, tail_length, eps, dtype_size_); + } + + private: + int64_t dtype_size_; + + int64_t dim_length_align_; + + int64_t total_rows_; + + bool needs_weight_cast_; + + void* weight_fp32_data_ = nullptr; + + mutable ascend::AclTensorCache weight_src_cache_; + + mutable ascend::AclTensorCache weight_dst_cache_; + + mutable const void* last_weight_ptr_ = nullptr; + + mutable aclOpExecutor* cast_exec_ = nullptr; + + mutable uint64_t cast_ws_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_CUSTOM_ADD_RMS_NORM +#endif // INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_ diff --git a/src/ascend/add_rms_norm/kernel_fused.h b/src/ascend/add_rms_norm/kernel_fused.h new file mode 100644 index 00000000..4d67fa0a --- /dev/null +++ b/src/ascend/add_rms_norm/kernel_fused.h @@ -0,0 +1,121 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_FUSED_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_FUSED_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_add_rms_norm.h" +#include "ascend/add_rms_norm/registry.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Fused implementation via `aclnnAddRmsNorm` (implementation index 1). +// +// Computes x_out = x1 + x2 and y_out = rms_norm(x_out, gamma, eps) in a +// single CANN launch. The fused API has higher host-side launch overhead +// (~200 us) compared to the decomposed `aclnnAdd` + `aclnnRmsNorm` path (~39 +// us), but may offer better NPU-side efficiency for large tensors where kernel +// fusion reduces memory traffic. 
+// +// Select via `implementation_index=1` in Python: +// infini.ops.add_rms_norm(..., implementation_index=1, stream=s) +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out), + x1_cache_(x1), + x2_cache_(x2), + gamma_cache_(gamma), + y_out_cache_(y_out), + x_out_cache_(x_out) { + // `aclnnAddRmsNorm` requires `rstdOut` to have the same ndim as x1, with + // the last gamma.ndim() dimensions set to 1. For example: + // x1 shape(2, 32, 128), gamma shape(128) -> rstdOut shape(2, 32, 1) + // x1 shape(64, 128), gamma shape(128) -> rstdOut shape(64, 1) + fused_rstd_shape_.reserve(ndim_); + for (size_t i = 0; i < ndim_ - gamma.ndim(); ++i) { + fused_rstd_shape_.push_back(static_cast(x1.size(i))); + } + for (size_t i = 0; i < gamma.ndim(); ++i) { + fused_rstd_shape_.push_back(1); + } + + size_t rstd_elems = 1; + for (auto d : fused_rstd_shape_) { + rstd_elems *= static_cast(d); + } + size_t rstd_bytes = rstd_elems * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + rstd_tensor_ = aclCreateTensor( + fused_rstd_shape_.data(), + static_cast(fused_rstd_shape_.size()), ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, fused_rstd_shape_.data(), + static_cast(fused_rstd_shape_.size()), rstd_data_); + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (rstd_tensor_) aclDestroyTensor(rstd_tensor_); + if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const override { + auto t_x1 = x1_cache_.get(const_cast(x1.data())); + auto t_x2 = x2_cache_.get(const_cast(x2.data())); + auto t_gamma = gamma_cache_.get(const_cast(gamma.data())); + auto t_y_out = y_out_cache_.get(y_out.data()); + auto t_x_out = x_out_cache_.get(x_out.data()); + auto stream = 
static_cast(stream_); + + if (!executor_) { + aclnnAddRmsNormGetWorkspaceSize( + t_x1, t_x2, t_gamma, static_cast(eps), t_y_out, rstd_tensor_, + t_x_out, &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_x1, const_cast(x1.data())); + aclSetInputTensorAddr(executor_, 1, t_x2, const_cast(x2.data())); + aclSetInputTensorAddr(executor_, 2, t_gamma, + const_cast(gamma.data())); + aclSetOutputTensorAddr(executor_, 0, t_y_out, y_out.data()); + // rstd at output index 1 has a stable address — no update needed. + aclSetOutputTensorAddr(executor_, 2, t_x_out, x_out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnAddRmsNorm(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache x1_cache_; + + mutable ascend::AclTensorCache x2_cache_; + + mutable ascend::AclTensorCache gamma_cache_; + + mutable ascend::AclTensorCache y_out_cache_; + + mutable ascend::AclTensorCache x_out_cache_; + + std::vector fused_rstd_shape_; + + void* rstd_data_ = nullptr; + + aclTensor* rstd_tensor_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/registry.h b/src/ascend/add_rms_norm/registry.h new file mode 100644 index 00000000..d48de306 --- /dev/null +++ b/src/ascend/add_rms_norm/registry.h @@ -0,0 +1,15 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_REGISTRY_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_REGISTRY_H_ + +#include "base/add_rms_norm.h" + +namespace infini::ops { + +template <> +struct ActiveImplementationsImpl { + using type = List<0, 1>; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/atb_common_.h b/src/ascend/atb_common_.h new file mode 100644 index 00000000..fc1439b8 --- /dev/null +++ b/src/ascend/atb_common_.h @@ -0,0 +1,95 @@ +#ifndef INFINI_OPS_ASCEND_ATB_COMMON__H_ +#define 
INFINI_OPS_ASCEND_ATB_COMMON__H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include +#include + +#include "acl/acl.h" +#include "ascend/data_type_.h" +#include "atb/context.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "tensor.h" + +namespace infini::ops::ascend { + +// Thread-local ATB context. +// +// ATB requires a `Context` for Setup/Execute. Creating one per call is +// expensive (internal tiling buffer allocation), so we cache one per thread. +// `SetExecuteStream` is called before every `Execute` to match the caller's +// stream. +inline atb::Context*& threadLocalAtbContext() { + thread_local atb::Context* ctx = nullptr; + + return ctx; +} + +inline atb::Context* getAtbContext(aclrtStream stream) { + auto*& ctx = threadLocalAtbContext(); + + if (!ctx) { + atb::Status s = atb::CreateContext(&ctx); + assert(s == atb::NO_ERROR && "atb::CreateContext failed"); + } + + atb::Status s = ctx->SetExecuteStream(stream); + assert(s == atb::NO_ERROR && "atb::Context::SetExecuteStream failed"); + + return ctx; +} + +// Build an `atb::Tensor` from an InfiniOps Tensor. +// +// Sets dtype, ND format, shape dimensions, and the device data pointer. +// The caller must keep the InfiniOps Tensor alive for the duration of the +// ATB operation. +inline atb::Tensor toAtbTensor(const Tensor& t) { + atb::Tensor out; + out.desc.dtype = toAclDtype(t.dtype()); + out.desc.format = ACL_FORMAT_ND; + out.desc.shape.dimNum = t.ndim(); + assert(t.ndim() <= atb::MAX_DIM); + + for (uint64_t i = 0; i < t.ndim(); ++i) { + out.desc.shape.dims[i] = static_cast(t.size(i)); + } + + out.deviceData = const_cast(t.data()); + out.dataSize = static_cast(t.numel()) * t.element_size(); + + return out; +} + +// Build an `atb::Tensor` from explicit shape, dtype, and data pointer. +// +// Useful for sub-views of a larger buffer (e.g. K-cache and V-cache halves +// of a fused KV cache tensor). 
+inline atb::Tensor toAtbTensor(const std::vector& shape, + aclDataType dtype, void* data, + uint64_t data_size) { + atb::Tensor out; + out.desc.dtype = dtype; + out.desc.format = ACL_FORMAT_ND; + out.desc.shape.dimNum = shape.size(); + assert(shape.size() <= atb::MAX_DIM); + + for (size_t i = 0; i < shape.size(); ++i) { + out.desc.shape.dims[i] = shape[i]; + } + + out.deviceData = data; + out.dataSize = data_size; + + return out; +} + +} // namespace infini::ops::ascend + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_ATB_COMMON__H_ diff --git a/src/ascend/cast/kernel.h b/src/ascend/cast/kernel.h new file mode 100644 index 00000000..645f05af --- /dev/null +++ b/src/ascend/cast/kernel.h @@ -0,0 +1,60 @@ +#ifndef INFINI_OPS_ASCEND_CAST_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAST_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cast.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) + : Cast(input, out), + in_cache_(input), + out_cache_(out), + acl_out_dtype_(ascend::toAclDtype(out.dtype())) {} + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + } + + void operator()(const Tensor input, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnCastGetWorkspaceSize(t_in, acl_out_dtype_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnCast(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache 
in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + aclDataType acl_out_dtype_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/cat/kernel.h b/src/ascend/cat/kernel.h new file mode 100644 index 00000000..0d3d0976 --- /dev/null +++ b/src/ascend/cat/kernel.h @@ -0,0 +1,94 @@ +#ifndef INFINI_OPS_ASCEND_CAT_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAT_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/acl_meta.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cat.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cat.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat(first_input, rest_inputs, dim, out), out_cache_(out) { + // Build AclTensorCache for each input tensor. + in_caches_.reserve(input_count_); + in_caches_.emplace_back(first_input); + for (const auto& t : rest_inputs) { + in_caches_.emplace_back(t); + } + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (tensor_list_) aclDestroyTensorList(tensor_list_); + } + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t /*dim*/, Tensor out) const override { + auto stream = static_cast(stream_); + + // Collect all input tensors in order. + std::vector inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + // First call: create descriptors, tensor list, and executor. 
+ std::vector acl_tensors(input_count_); + for (size_t i = 0; i < input_count_; ++i) { + acl_tensors[i] = + in_caches_[i].get(const_cast(inputs[i]->data())); + } + + tensor_list_ = + aclCreateTensorList(const_cast(acl_tensors.data()), + static_cast(input_count_)); + + aclnnCatGetWorkspaceSize(tensor_list_, dim_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + // Subsequent calls: update data pointers on cached descriptors via + // `aclSetRawTensorAddr`. The executor holds references to the same + // `aclTensor*` objects inside `tensor_list_`, so updating their data + // pointers is sufficient — no `aclSetInputTensorAddr` needed. + for (size_t i = 0; i < input_count_; ++i) { + in_caches_[i].get(const_cast(inputs[i]->data())); + } + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnCat(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable std::vector in_caches_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclTensorList* tensor_list_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/causal_softmax/kernel.h b/src/ascend/causal_softmax/kernel.h new file mode 100644 index 00000000..1b8c148e --- /dev/null +++ b/src/ascend/causal_softmax/kernel.h @@ -0,0 +1,159 @@ +#ifndef INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnn_masked_fill_scalar.h" +#include "aclnn_softmax.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/causal_softmax.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Implements causal softmax via three ACLNN calls: +// 1. 
`InplaceCopy(temp, input)` — stride-aware copy to contiguous temp +// buffer. +// 2. `InplaceMaskedFillScalar(temp, mask, -inf)` — apply upper-triangle mask. +// 3. `Softmax(temp, dim=-1, out)` — softmax over the last dimension. +// +// The boolean causal mask is pre-computed and uploaded to device once in the +// constructor. Its shape (seq_len, total_seq_len) broadcasts over the batch. +template <> +class Operator : public CausalSoftmax { + public: + Operator(const Tensor input, Tensor out) + : CausalSoftmax(input, out), in_cache_(input), out_cache_(out) { + // Compute temp buffer size — allocated lazily from pool in `operator()`. + size_t n_elems = input.numel(); + size_t elem_bytes = kDataTypeToSize.at(dtype_); + temp_size_ = n_elems * elem_bytes; + + // Build a contiguous Tensor descriptor — data pointer set on first use. + Tensor temp_t{nullptr, input.shape(), input.dtype(), input.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); + + // Causal mask: mask[i][j] = 1 when position j must be masked for query i. + // Shape (seq_len, total_seq_len) – broadcasts over the batch dimension. + size_t mask_elems = seq_len_ * total_seq_len_; + std::vector mask_host(mask_elems, 0); + + for (size_t i = 0; i < seq_len_; ++i) { + auto vis_end = static_cast(total_seq_len_ - seq_len_ + i); + + for (auto j = vis_end + 1; j < static_cast(total_seq_len_); + ++j) { + mask_host[i * total_seq_len_ + j] = 1; + } + } + + aclrtMalloc(&mask_buf_, mask_elems, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(mask_buf_, mask_elems, mask_host.data(), mask_elems, + ACL_MEMCPY_HOST_TO_DEVICE); + + std::vector mshape = {static_cast(seq_len_), + static_cast(total_seq_len_)}; + std::vector mstrides = {static_cast(total_seq_len_), 1}; + mask_tensor_ = aclCreateTensor(mshape.data(), mshape.size(), ACL_BOOL, + mstrides.data(), 0, ACL_FORMAT_ND, + mshape.data(), mshape.size(), mask_buf_); + + // Scalar -inf for the masked-fill step. 
`aclCreateScalar` stores the + // pointer rather than copying, so `neg_inf_storage_` must stay alive with + // the object. + neg_inf_ = aclCreateScalar(&neg_inf_storage_, ACL_FLOAT); + // Workspaces are allocated lazily on first `operator()` call. + } + + ~Operator() { + if (copy_exec_) aclDestroyAclOpExecutor(copy_exec_); + if (fill_exec_) aclDestroyAclOpExecutor(fill_exec_); + if (softmax_exec_) aclDestroyAclOpExecutor(softmax_exec_); + aclrtFree(mask_buf_); + aclDestroyTensor(mask_tensor_); + aclDestroyScalar(neg_inf_); + } + + void operator()(const Tensor input, Tensor out) const override { + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Obtain shared temp buffer from pool. + auto& temp = ascend::workspacePool().ensure(stream, temp_size_, "temp"); + auto t_temp = temp_cache_.get(temp.buf); + + // Step 1: copy input (possibly non-contiguous) into contiguous temp. + if (!copy_exec_) { + aclnnInplaceCopyGetWorkspaceSize(t_temp, t_in, ©_ws_, ©_exec_); + aclSetAclOpExecutorRepeatable(copy_exec_); + } else { + aclSetInputTensorAddr(copy_exec_, 0, t_temp, temp.buf); + aclSetInputTensorAddr(copy_exec_, 1, t_in, + const_cast(input.data())); + } + auto& copy_arena = ascend::workspacePool().ensure(stream, copy_ws_); + aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream); + + // Step 2: mask upper-triangle positions with -inf in-place. + // `mask_tensor_` and `neg_inf_` have stable addresses — first-call only. + if (!fill_exec_) { + aclnnInplaceMaskedFillScalarGetWorkspaceSize( + t_temp, mask_tensor_, neg_inf_, &fill_ws_, &fill_exec_); + aclSetAclOpExecutorRepeatable(fill_exec_); + } + auto& fill_arena = ascend::workspacePool().ensure(stream, fill_ws_); + aclnnInplaceMaskedFillScalar(fill_arena.buf, fill_ws_, fill_exec_, stream); + + // Step 3: softmax over the last dimension -> out. 
+ if (!softmax_exec_) { + constexpr int64_t kLastDim = -1; + aclnnSoftmaxGetWorkspaceSize(t_temp, kLastDim, t_out, &softmax_ws_, + &softmax_exec_); + aclSetAclOpExecutorRepeatable(softmax_exec_); + } else { + aclSetOutputTensorAddr(softmax_exec_, 0, t_out, out.data()); + } + auto& softmax_arena = ascend::workspacePool().ensure(stream, softmax_ws_); + aclnnSoftmax(softmax_arena.buf, softmax_ws_, softmax_exec_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + + float neg_inf_storage_ = -std::numeric_limits::infinity(); + + uint64_t temp_size_ = 0; + + void* mask_buf_ = nullptr; + + aclTensor* mask_tensor_ = nullptr; + + aclScalar* neg_inf_ = nullptr; + + mutable aclOpExecutor* copy_exec_ = nullptr; + + mutable uint64_t copy_ws_ = 0; + + mutable aclOpExecutor* fill_exec_ = nullptr; + + mutable uint64_t fill_ws_ = 0; + + mutable aclOpExecutor* softmax_exec_ = nullptr; + + mutable uint64_t softmax_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/common.h b/src/ascend/common.h index fba4766b..b6a927e5 100644 --- a/src/ascend/common.h +++ b/src/ascend/common.h @@ -11,11 +11,23 @@ namespace infini::ops::ascend { -// Build an `aclTensor` descriptor from an InfiniOps `Tensor`. +// Check whether the ACL runtime is still usable. +// +// During process shutdown the CANN runtime may be torn down before C++ +// static destructors run. Calling `aclrtGetDevice` is the cheapest +// probe — it fails once the runtime is gone. Destructors that call +// ACL/ATB APIs must guard with this to avoid use-after-finalize crashes. +inline bool isAclRuntimeAlive() { + int32_t dev_id = -1; + + return aclrtGetDevice(&dev_id) == ACL_SUCCESS; +} + +// Build an aclTensor descriptor from an InfiniOps Tensor. // // When `transpose_last2` is true the last two dimensions are swapped in the -// descriptor (shape and strides) without copying data. 
This is used by `Gemm` -// and `MatMul` to express a transpose via the view. +// descriptor (shape and strides) without copying data. This is used by GEMM +// and Matmul to express a transpose via the view. inline aclTensor* buildAclTensor(const Tensor& t, bool transpose_last2 = false) { std::vector shape(t.shape().begin(), t.shape().end()); @@ -45,12 +57,131 @@ inline aclTensor* buildAclTensor(const Tensor& t, std::vector storage_shape = {storage_elems}; return aclCreateTensor( - shape.data(), static_cast(shape.size()), ToAclDtype(t.dtype()), + shape.data(), static_cast(shape.size()), toAclDtype(t.dtype()), strides.data(), /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape.data(), static_cast(storage_shape.size()), const_cast(t.data())); } +// Pre-computed tensor metadata for descriptor reuse. +// +// Stores shape, strides, storage_shape, and dtype once (avoiding per-call heap +// allocations). The aclTensor descriptor is created on the first `get()` call +// and its data pointer is updated in-place via `aclSetRawTensorAddr` on +// subsequent calls. +class AclTensorCache { + public: + AclTensorCache() = default; + + // Construct from explicit metadata (for device buffers not wrapped in + // Tensor). Computes contiguous strides from shape. 
+ AclTensorCache(std::vector shape, aclDataType dtype, void* data) + : shape_(std::move(shape)), dtype_(dtype) { + strides_.resize(shape_.size()); + int64_t stride = 1; + for (int i = static_cast(shape_.size()) - 1; i >= 0; --i) { + strides_[i] = stride; + stride *= shape_[i]; + } + storage_shape_ = {stride}; + + if (data) { + tensor_ = aclCreateTensor( + shape_.data(), static_cast(shape_.size()), dtype_, + strides_.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape_.data(), + static_cast(storage_shape_.size()), data); + } + } + + explicit AclTensorCache(const Tensor& t, bool transpose_last2 = false) + : dtype_{toAclDtype(t.dtype())} { + shape_.assign(t.shape().begin(), t.shape().end()); + strides_.assign(t.strides().begin(), t.strides().end()); + + if (transpose_last2 && shape_.size() >= 2) { + auto n = shape_.size(); + std::swap(shape_[n - 2], shape_[n - 1]); + std::swap(strides_[n - 2], strides_[n - 1]); + } + + int64_t storage_elems = 1; + for (size_t i = 0; i < shape_.size(); ++i) { + if (shape_[i] == 0) { + storage_elems = 0; + break; + } + if (strides_[i] > 0 && shape_[i] > 1) { + storage_elems += static_cast(shape_[i] - 1) * strides_[i]; + } + } + storage_shape_ = {storage_elems}; + } + + ~AclTensorCache() { + if (tensor_) { + aclDestroyTensor(tensor_); + } + } + + AclTensorCache(const AclTensorCache&) = delete; + + AclTensorCache& operator=(const AclTensorCache&) = delete; + + AclTensorCache(AclTensorCache&& o) noexcept + : shape_(std::move(o.shape_)), + strides_(std::move(o.strides_)), + storage_shape_(std::move(o.storage_shape_)), + dtype_(o.dtype_), + tensor_(o.tensor_) { + o.tensor_ = nullptr; + } + + AclTensorCache& operator=(AclTensorCache&& o) noexcept { + if (this != &o) { + if (tensor_) { + aclDestroyTensor(tensor_); + } + shape_ = std::move(o.shape_); + strides_ = std::move(o.strides_); + storage_shape_ = std::move(o.storage_shape_); + dtype_ = o.dtype_; + tensor_ = o.tensor_; + o.tensor_ = nullptr; + } + + return *this; + } + + // 
Update the data pointer and return the cached descriptor. + aclTensor* get(void* data) const { + if (tensor_) { + aclSetRawTensorAddr(tensor_, data); + + return tensor_; + } + + tensor_ = aclCreateTensor( + shape_.data(), static_cast(shape_.size()), dtype_, + strides_.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape_.data(), + static_cast(storage_shape_.size()), data); + + return tensor_; + } + + private: + std::vector shape_; + + std::vector strides_; + + std::vector storage_shape_; + + aclDataType dtype_{ACL_DT_UNDEFINED}; + + mutable aclTensor* tensor_ = nullptr; +}; + } // namespace infini::ops::ascend #endif diff --git a/src/ascend/custom_kernel/.gitignore b/src/ascend/custom_kernel/.gitignore new file mode 100644 index 00000000..0c983f0a --- /dev/null +++ b/src/ascend/custom_kernel/.gitignore @@ -0,0 +1,3 @@ +build/ +output/ +python/ diff --git a/src/ascend/custom_kernel/CMakeLists.txt b/src/ascend/custom_kernel/CMakeLists.txt new file mode 100644 index 00000000..64ec8967 --- /dev/null +++ b/src/ascend/custom_kernel/CMakeLists.txt @@ -0,0 +1,35 @@ +cmake_minimum_required(VERSION 3.20 FATAL_ERROR) +project(ascend-kernel LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE RELEASE) +endif() + +add_compile_options(-Wunused-value -Wcast-align -Wcast-qual -Wwrite-strings + -Wsign-compare -Wextra) + +if(${CMAKE_BUILD_TYPE} MATCHES "RELEASE") + add_compile_options(-O3 -fvisibility=hidden -fvisibility-inlines-hidden + -fstack-protector-strong -fPIE -fPIC) + message(STATUS "build type set to RELEASE") +else() + add_compile_options(-g -rdynamic) +endif() + +set(PROJECT_OP_SRC_BASE ${PROJECT_SOURCE_DIR}/csrc) +set(PROJECT_BUILD_PATH ${PROJECT_SOURCE_DIR}/build) +set(PROJECT_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/output) + +include(cmake/config_envs.cmake) +include(cmake/config_ascend.cmake) + +find_program(CCACHE_PROGRAM ccache) +if(CCACHE_PROGRAM) + message(STATUS "Found ccache: ${CCACHE_PROGRAM}") + 
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") + set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") +endif() + +add_subdirectory(csrc) diff --git a/src/ascend/custom_kernel/build.sh b/src/ascend/custom_kernel/build.sh new file mode 100755 index 00000000..76ec445e --- /dev/null +++ b/src/ascend/custom_kernel/build.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Build custom AscendC kernels into libascend_kernel.so. +set -e + +SOC_VERSION="${1:-Ascend910_9382}" + +# Detect CANN toolkit path. +_CANN_TOOLKIT_INSTALL_PATH=$(grep "Toolkit_InstallPath" /etc/Ascend/ascend_cann_install.info | awk -F'=' '{print $2}') +source "${_CANN_TOOLKIT_INSTALL_PATH}/set_env.sh" +echo "CANN: ${ASCEND_TOOLKIT_HOME}" + +ASCEND_INCLUDE_DIR=${ASCEND_TOOLKIT_HOME}/$(arch)-linux/include +CURRENT_DIR=$(pwd) +OUTPUT_DIR=${CURRENT_DIR}/output +mkdir -p "${OUTPUT_DIR}" + +BUILD_DIR=build +rm -rf "${BUILD_DIR}" +mkdir -p "${BUILD_DIR}" + +cmake \ + -DASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ + -DASCEND_INCLUDE_DIR="${ASCEND_INCLUDE_DIR}" \ + -DSOC_VERSION="${SOC_VERSION}" \ + -B "${BUILD_DIR}" \ + -S . + +cmake --build "${BUILD_DIR}" -j 16 + +echo "Build complete. 
Output: ${OUTPUT_DIR}" diff --git a/src/ascend/custom_kernel/cmake/config_ascend.cmake b/src/ascend/custom_kernel/cmake/config_ascend.cmake new file mode 100644 index 00000000..1c3785cd --- /dev/null +++ b/src/ascend/custom_kernel/cmake/config_ascend.cmake @@ -0,0 +1,23 @@ + +if(DEFINED ASCEND_HOME_PATH) +elseif(DEFINED ENV{ASCEND_HOME_PATH}) + set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}" CACHE PATH "ASCEND CANN package installation directory" FORCE) +endif() + +set(ASCEND_CANN_PACKAGE_PATH ${ASCEND_HOME_PATH}) + +if(EXISTS ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake) +elseif(EXISTS ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake) +elseif(EXISTS ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake) +else() + message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the cann package is installed.") +endif() + +include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) + + +message(STATUS "ASCEND_CANN_PACKAGE_PATH = ${ASCEND_CANN_PACKAGE_PATH}") +message(STATUS "ASCEND_HOME_PATH = ${ASCEND_HOME_PATH}") diff --git a/src/ascend/custom_kernel/cmake/config_envs.cmake b/src/ascend/custom_kernel/cmake/config_envs.cmake new file mode 100644 index 00000000..d5373981 --- /dev/null +++ b/src/ascend/custom_kernel/cmake/config_envs.cmake @@ -0,0 +1,83 @@ +# find python binary +find_program(PYTHON_EXECUTABLE NAMES python3) + +if (NOT EXISTS ${PYTHON_EXECUTABLE}) + message(FATAL_ERROR "python3 is not found, install python firstly") +endif () + +# get torch path, torch npu path, pybind11 path via python script +execute_process( + COMMAND ${PYTHON_EXECUTABLE} "-c" + "import torch; import torch_npu; import os; import pybind11; import sysconfig; +torch_dir = os.path.realpath(os.path.dirname(torch.__file__)); 
+torch_npu_dir = os.path.realpath(os.path.dirname(torch_npu.__file__)); +pybind11_dir = os.path.realpath(os.path.dirname(pybind11.__file__)); +abi_enabled=torch.compiled_with_cxx11_abi(); +python_include_dir = sysconfig.get_path('include'); +print(torch_dir, torch_npu_dir, pybind11_dir, abi_enabled, python_include_dir, end=''); +quit(0) + " + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE OUTPUT_ENV_DEFINES) + +# if failed to run the python script +if (NOT ${EXEC_RESULT} EQUAL 0) + message(FATAL_ERROR "failed to get run python script to get ENVS like TORCH_DIR etc") +else () + message(STATUS "run python script successfully, output string is [${OUTPUT_ENV_DEFINES}]") +endif () + +# extract TORCH_DIR and set it +execute_process( + COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $1}'" + OUTPUT_VARIABLE TORCH_DIR + RESULT_VARIABLE EXEC_RESULT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# extract TORCH_NPU_DIR and set it +execute_process( + COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $2}'" + OUTPUT_VARIABLE TORCH_NPU_DIR + RESULT_VARIABLE EXEC_RESULT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# extract PYBIND11_DIR and set it +execute_process( + COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $3}'" + OUTPUT_VARIABLE PYBIND11_DIR + RESULT_VARIABLE EXEC_RESULT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# extract PYTROCH_ABI and set it +execute_process( + COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $4}'" + OUTPUT_VARIABLE TORCH_API_ENABLED + RESULT_VARIABLE EXEC_RESULT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# extract PYTHON_INCLUDE_DIR and set it +execute_process( + COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $5}'" + OUTPUT_VARIABLE PYTHON_INCLUDE_DIR + RESULT_VARIABLE EXEC_RESULT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +message(STATUS "SOC_VERSION=${SOC_VERSION}") +message(STATUS "TORCH_DIR=${TORCH_DIR}") +message(STATUS "TORCH_NPU_DIR=${TORCH_NPU_DIR}") +message(STATUS "PYBIND11_DIR=${PYBIND11_DIR}") 
+message(STATUS "PYTHON_INCLUDE_DIR=${PYTHON_INCLUDE_DIR}") + +# set _GLIBCXX_USE_CXX11_ABI +if (${TORCH_API_ENABLED} STREQUAL "True") + add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=1) + message(STATUS "_GLIBCXX_USE_CXX11_ABI=1") +else () + add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0) + message(STATUS "_GLIBCXX_USE_CXX11_ABI=0") +endif () diff --git a/src/ascend/custom_kernel/csrc/CMakeLists.txt b/src/ascend/custom_kernel/csrc/CMakeLists.txt new file mode 100644 index 00000000..c1b31502 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/CMakeLists.txt @@ -0,0 +1,51 @@ +# Set the library output dir to the project output for linking. +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_OUTPUT_PATH}) + +# Host side files. +file(GLOB OP_SRCS + ${PROJECT_OP_SRC_BASE}/register.cpp + ${PROJECT_OP_SRC_BASE}/ops/rms_norm/op_host/rms_norm.cpp +) + +# Set the shared library name. +set(OP_PLUGIN_NAME ascend_kernel) + +# Kernel side files (device code compiled by AscendC toolchain). +ascendc_library(no_workspace_kernel STATIC + ${PROJECT_OP_SRC_BASE}/ops/rms_norm/op_kernel/rms_norm.cpp +) + +# Create shared library libascend_kernel.so. 
+add_library(${OP_PLUGIN_NAME} SHARED ${OP_SRCS}) + +target_link_libraries(${OP_PLUGIN_NAME} PRIVATE + no_workspace_kernel + torch_npu + ascendcl + tiling_api + nnopbase + opapi + register + platform + ascendalog + dl +) + +target_link_directories(${OP_PLUGIN_NAME} PRIVATE + ${TORCH_DIR}/lib + ${TORCH_NPU_DIR}/lib +) + +target_include_directories(${OP_PLUGIN_NAME} PRIVATE + ${PROJECT_OP_SRC_BASE}/utils + ${PROJECT_SOURCE_DIR}/include + ${TORCH_DIR}/include + ${TORCH_DIR}/include/torch/csrc/api/include + ${TORCH_NPU_DIR}/include/third_party/acl/inc + ${TORCH_NPU_DIR}/include/third_party/hccl/inc + ${TORCH_NPU_DIR}/include + ${PYTHON_INCLUDE_DIR} + ${ASCEND_INCLUDE_DIR}/external + ${ASCEND_INCLUDE_DIR}/experiment/platform + ${ASCEND_INCLUDE_DIR}/experiment/runtime +) diff --git a/src/ascend/custom_kernel/csrc/ops.h b/src/ascend/custom_kernel/csrc/ops.h new file mode 100644 index 00000000..dcb26c7c --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops.h @@ -0,0 +1,21 @@ +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OPS_H +#define OPS_H + +namespace ascend_kernel { + +at::Tensor rms_norm(const at::Tensor& input, const at::Tensor& weight, + double eps); + +} // namespace ascend_kernel + +#endif // OPS_H diff --git a/src/ascend/custom_kernel/csrc/ops/add_rms_norm/CMakeLists.txt b/src/ascend/custom_kernel/csrc/ops/add_rms_norm/CMakeLists.txt new file mode 100644 index 00000000..1748afc0 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/add_rms_norm/CMakeLists.txt @@ -0,0 +1 @@ +ascendc_add_operator(OP_NAME add_rms_norm) diff --git a/src/ascend/custom_kernel/csrc/ops/add_rms_norm/op_host/add_rms_norm.cpp b/src/ascend/custom_kernel/csrc/ops/add_rms_norm/op_host/add_rms_norm.cpp new file mode 100644 index 00000000..122abad1 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/add_rms_norm/op_host/add_rms_norm.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2025, InfiniTensor. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include "aclrtlaunch_add_rms_norm.h" +#include "tiling/platform/platform_ascendc.h" +#include "torch_kernel_helper.h" + +namespace ascend_kernel { + +std::vector add_rms_norm(const at::Tensor& x1, const at::Tensor& x2, + const at::Tensor& weight, double eps) { + // Input validation. 
+ TORCH_CHECK(x1.dim() > 0, "add_rms_norm: x1 must have at least 1 dimension"); + TORCH_CHECK(x1.sizes() == x2.sizes(), + "add_rms_norm: x1 and x2 must have the same shape"); + TORCH_CHECK(x1.scalar_type() == x2.scalar_type(), + "add_rms_norm: x1 and x2 must have the same dtype"); + TORCH_CHECK(x1.scalar_type() == at::kHalf || x1.scalar_type() == at::kFloat, + "add_rms_norm: only float16 and float32 are supported, got ", + x1.scalar_type()); + TORCH_CHECK(weight.dim() == 1, "add_rms_norm: weight must be 1-dimensional"); + TORCH_CHECK(weight.size(0) == x1.size(-1), "add_rms_norm: weight size (", + weight.size(0), ") must match input last dim (", x1.size(-1), + ")"); + + int64_t dimLength = x1.size(-1); + int64_t totalRows = x1.numel() / dimLength; + + if (totalRows == 0 || dimLength == 0) { + return {at::empty_like(x1), at::empty_like(x1)}; + } + + at::Tensor inp1 = x1.contiguous(); + at::Tensor inp2 = x2.contiguous(); + int64_t dtypeSize = inp1.element_size(); + + // Hardware parameters. + auto ascendc_platform = + platform_ascendc::PlatformAscendCManager::GetInstance(); + int64_t coreNum = static_cast(ascendc_platform->GetCoreNumAiv()); + uint64_t ubSize; + ascendc_platform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + int64_t ubSizeLimit = static_cast(ubSize); + + // Alignment (32-byte boundary). + int64_t alignElements = 32 / dtypeSize; + int64_t dimLengthAlign = + ((dimLength + alignElements - 1) / alignElements) * alignElements; + + // UB capacity check. + // fp16: inQ_x1(×2×2) + inQ_x2(×2×2) + outQ_y(×2×2) + outQ_xout(×2×2) + // + fp32Buf1(×4) + fp32Buf2(×4) + weight(×4) = 16 + 12 = 28 + // fp32: inQ_x1(×2×4) + inQ_x2(×2×4) + outQ_y(×2×4) + outQ_xout(×2×4) + // + weight(×4) = 32 + 4 = 36 + int64_t bufferCoefficient = (dtypeSize == 2) ? 
28 : 36; + int64_t maxDimLength = (ubSizeLimit - 1024) / bufferCoefficient; + int64_t fpAlignElements = 32 / 4; + maxDimLength = (maxDimLength / fpAlignElements) * fpAlignElements; + TORCH_CHECK(dimLengthAlign <= maxDimLength, "add_rms_norm: dimLength ", + dimLength, " (aligned ", dimLengthAlign, + ") exceeds UB capacity (max ", maxDimLength, ")"); + + // Padding. + at::Tensor kernelInput1; + at::Tensor kernelInput2; + + if (dimLength != dimLengthAlign) { + kernelInput1 = inp1.reshape({totalRows, dimLength}); + kernelInput1 = + at::constant_pad_nd(kernelInput1, {0, dimLengthAlign - dimLength}, 0.0); + kernelInput1 = kernelInput1.contiguous(); + + kernelInput2 = inp2.reshape({totalRows, dimLength}); + kernelInput2 = + at::constant_pad_nd(kernelInput2, {0, dimLengthAlign - dimLength}, 0.0); + kernelInput2 = kernelInput2.contiguous(); + } else { + kernelInput1 = inp1.reshape({totalRows, dimLengthAlign}).contiguous(); + kernelInput2 = inp2.reshape({totalRows, dimLengthAlign}).contiguous(); + } + + at::Tensor kernelOutputY = at::empty_like(kernelInput1); + at::Tensor kernelOutputXOut = at::empty_like(kernelInput1); + + // Weight: always pass as fp32, padded to `dimLengthAlign`. + at::Tensor weightFloat = weight.contiguous().to(at::kFloat); + + if (dimLength != dimLengthAlign) { + weightFloat = + at::constant_pad_nd(weightFloat, {0, dimLengthAlign - dimLength}, 0.0); + } + + weightFloat = weightFloat.contiguous(); + + // Block-level tiling (distribute rows across cores). + int64_t usedCoreNum = std::min(totalRows, coreNum); + int64_t formerLength = (totalRows + usedCoreNum - 1) / usedCoreNum; + int64_t tailLength = formerLength - 1; + int64_t formerNum = totalRows - tailLength * usedCoreNum; + uint32_t blockDim = static_cast(usedCoreNum); + + // All EXEC_KERNEL_CMD args must be lvalues. 
+ float epsFloat = static_cast(eps); + int64_t dtypeSizeVal = dtypeSize; + + EXEC_KERNEL_CMD(add_rms_norm, blockDim, kernelInput1, kernelInput2, + weightFloat, kernelOutputY, kernelOutputXOut, totalRows, + dimLength, dimLengthAlign, formerNum, formerLength, + tailLength, epsFloat, dtypeSizeVal); + + // Remove padding and reshape back to original shape. + at::Tensor outputY = kernelOutputY; + at::Tensor outputXOut = kernelOutputXOut; + + if (dimLength != dimLengthAlign) { + outputY = outputY.narrow(-1, 0, dimLength).contiguous(); + outputXOut = outputXOut.narrow(-1, 0, dimLength).contiguous(); + } + + outputY = outputY.reshape(x1.sizes()); + outputXOut = outputXOut.reshape(x1.sizes()); + + return {outputY, outputXOut}; +} + +} // namespace ascend_kernel diff --git a/src/ascend/custom_kernel/csrc/ops/add_rms_norm/op_kernel/add_rms_norm.cpp b/src/ascend/custom_kernel/csrc/ops/add_rms_norm/op_kernel/add_rms_norm.cpp new file mode 100644 index 00000000..cd523b52 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/add_rms_norm/op_kernel/add_rms_norm.cpp @@ -0,0 +1,256 @@ +/* + * Copyright (c) 2025, InfiniTensor. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include "kernel_operator.h" + +constexpr int32_t BUFFER_NUM = 2; + +template +class KernelAddRmsNorm { + public: + __aicore__ inline KernelAddRmsNorm() {} + + __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y, + GM_ADDR x_out, int64_t totalRows, + int64_t dimLength, int64_t dimLengthAlign, + int64_t formerNum, int64_t formerLength, + int64_t tailLength, float eps) { + this->dimLength = dimLength; + this->dimLengthAlign = dimLengthAlign; + this->eps = eps; + + // Block-level tiling: determine row range for this core. 
+ int64_t blockIdx = AscendC::GetBlockIdx(); + int64_t rowOffset; + + if (blockIdx < formerNum) { + this->blockRows = formerLength; + rowOffset = formerLength * blockIdx; + } else { + this->blockRows = tailLength; + int64_t tailIdx = blockIdx - formerNum; + rowOffset = formerLength * formerNum + tailLength * tailIdx; + } + + // Global memory pointers. + x1Gm.SetGlobalBuffer((__gm__ T*)x1 + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + x2Gm.SetGlobalBuffer((__gm__ T*)x2 + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + yGm.SetGlobalBuffer((__gm__ T*)y + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + xOutGm.SetGlobalBuffer((__gm__ T*)x_out + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + weightGm.SetGlobalBuffer((__gm__ float*)weight, dimLengthAlign); + + int32_t dimLenAlign = static_cast(this->dimLengthAlign); + + // I/O queues (double-buffered). + pipe.InitBuffer(inQueueX1, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + pipe.InitBuffer(inQueueX2, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + pipe.InitBuffer(outQueueY, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + pipe.InitBuffer(outQueueXOut, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + + // Weight buffer (fp32, loaded once, reused for all rows). + pipe.InitBuffer(weightBuf, + dimLenAlign * static_cast(sizeof(float))); + + // FP16 path needs extra fp32 compute buffers. + // buf1: holds x_out in fp32 (reused from x1_fp32 after Add). + // buf2: holds x2_fp32 initially, then x_out^2, then final result. + if constexpr (sizeof(T) == 2) { + pipe.InitBuffer(fp32Buf1, + dimLenAlign * static_cast(sizeof(float))); + pipe.InitBuffer(fp32Buf2, + dimLenAlign * static_cast(sizeof(float))); + } + + // ReduceSum temporary buffer (size per API formula). 
+ constexpr int32_t ELEMS_PER_REPEAT = 256 / sizeof(float); + constexpr int32_t ELEMS_PER_BLOCK = 32 / sizeof(float); + int32_t firstMaxRepeat = + (dimLenAlign + ELEMS_PER_REPEAT - 1) / ELEMS_PER_REPEAT; + int32_t reduceTmpSize = + ((firstMaxRepeat + ELEMS_PER_BLOCK - 1) / ELEMS_PER_BLOCK) * + ELEMS_PER_BLOCK; + pipe.InitBuffer(reduceTmpBuf, + reduceTmpSize * static_cast(sizeof(float))); + + // Scalar buffer for reduction result (8 floats = 32 bytes). + pipe.InitBuffer(sumBuf, 32); + + // Load weight (fp32) from GM into `weightBuf`. + AscendC::LocalTensor wLocal = weightBuf.Get(); + AscendC::DataCopyExtParams wParams{ + 1, static_cast(dimLenAlign * sizeof(float)), 0, 0, 0}; + AscendC::DataCopyPadExtParams wPad{false, 0, 0, 0.0f}; + AscendC::DataCopyPad(wLocal, weightGm, wParams, wPad); + + // Ensure weight DMA completes before compute. + AscendC::PipeBarrier(); + } + + __aicore__ inline void Process() { + for (int64_t row = 0; row < this->blockRows; ++row) { + CopyIn(row); + Compute(row); + CopyOut(row); + } + } + + private: + __aicore__ inline void CopyIn(int64_t row) { + AscendC::LocalTensor x1Local = inQueueX1.AllocTensor(); + AscendC::LocalTensor x2Local = inQueueX2.AllocTensor(); + AscendC::DataCopyExtParams params{ + 1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0}; + AscendC::DataCopyPadExtParams pad{false, 0, 0, static_cast(0)}; + AscendC::DataCopyPad(x1Local, x1Gm[row * this->dimLengthAlign], params, + pad); + AscendC::DataCopyPad(x2Local, x2Gm[row * this->dimLengthAlign], params, + pad); + inQueueX1.EnQue(x1Local); + inQueueX2.EnQue(x2Local); + } + + __aicore__ inline void Compute(int64_t row) { + AscendC::LocalTensor x1Local = inQueueX1.DeQue(); + AscendC::LocalTensor x2Local = inQueueX2.DeQue(); + AscendC::LocalTensor yLocal = outQueueY.AllocTensor(); + AscendC::LocalTensor xOutLocal = outQueueXOut.AllocTensor(); + + AscendC::LocalTensor wLocal = weightBuf.Get(); + AscendC::LocalTensor rTmp = reduceTmpBuf.Get(); + AscendC::LocalTensor sLocal = 
sumBuf.Get(); + + int32_t dimLen = static_cast(this->dimLength); + int32_t dimLenAlign = static_cast(this->dimLengthAlign); + + if constexpr (sizeof(T) == 4) { + // ---- FP32 path: compute directly. ---- + + // Step 1: x_out = x1 + x2. + AscendC::Add(xOutLocal, x1Local, x2Local, dimLenAlign); + + // Step 2: x_out^2 into yLocal (reuse output buffer temporarily). + AscendC::Mul(yLocal, xOutLocal, xOutLocal, dimLenAlign); + + // Step 3: ReduceSum(x_out^2) -> sLocal[0]. + // ReduceSum may modify yLocal, but we overwrite it below. + AscendC::ReduceSum(sLocal, yLocal, rTmp, dimLenAlign); + + // Step 4-5: scale = 1 / sqrt(mean(x_out^2) + eps). + float sumVal = sLocal.GetValue(0); + float meanVal = sumVal / static_cast(dimLen) + this->eps; + sLocal.SetValue(0, meanVal); + AscendC::Sqrt(sLocal, sLocal, 8); + float scale = 1.0f / sLocal.GetValue(0); + + // Step 6: y = x_out * scale. + AscendC::Muls(yLocal, xOutLocal, scale, dimLenAlign); + + // Step 7: y = y * weight. + AscendC::Mul(yLocal, yLocal, wLocal, dimLenAlign); + + } else { + // ---- FP16 path: cast → fp32 compute → cast back. ---- + AscendC::LocalTensor b1 = fp32Buf1.Get(); + AscendC::LocalTensor b2 = fp32Buf2.Get(); + + // Cast inputs fp16 → fp32. + AscendC::Cast(b1, x1Local, AscendC::RoundMode::CAST_NONE, dimLenAlign); + AscendC::Cast(b2, x2Local, AscendC::RoundMode::CAST_NONE, dimLenAlign); + + // Step 1: x_out = x1 + x2 (fp32), stored in b1. + AscendC::Add(b1, b1, b2, dimLenAlign); + + // Cast x_out fp32 → fp16 for the x_out output. + AscendC::Cast(xOutLocal, b1, AscendC::RoundMode::CAST_ROUND, dimLenAlign); + + // Step 2: x_out^2 in fp32, stored in b2. + AscendC::Mul(b2, b1, b1, dimLenAlign); + + // Step 3: ReduceSum(x_out^2) -> sLocal[0]. + AscendC::ReduceSum(sLocal, b2, rTmp, dimLenAlign); + + // Step 4-5: scale = 1 / sqrt(mean(x_out^2) + eps). 
+ float sumVal = sLocal.GetValue(0); + float meanVal = sumVal / static_cast(dimLen) + this->eps; + sLocal.SetValue(0, meanVal); + AscendC::Sqrt(sLocal, sLocal, 8); + float scale = 1.0f / sLocal.GetValue(0); + + // Step 6: y = x_out * scale (fp32), reuse b2. + AscendC::Muls(b2, b1, scale, dimLenAlign); + + // Step 7: y = y * weight (fp32). + AscendC::Mul(b2, b2, wLocal, dimLenAlign); + + // Cast result fp32 → fp16. + AscendC::Cast(yLocal, b2, AscendC::RoundMode::CAST_ROUND, dimLenAlign); + } + + inQueueX1.FreeTensor(x1Local); + inQueueX2.FreeTensor(x2Local); + outQueueY.EnQue(yLocal); + outQueueXOut.EnQue(xOutLocal); + } + + __aicore__ inline void CopyOut(int64_t row) { + AscendC::LocalTensor yLocal = outQueueY.DeQue(); + AscendC::LocalTensor xOutLocal = outQueueXOut.DeQue(); + AscendC::DataCopyExtParams params{ + 1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0}; + AscendC::DataCopyPad(yGm[row * this->dimLengthAlign], yLocal, params); + AscendC::DataCopyPad(xOutGm[row * this->dimLengthAlign], xOutLocal, params); + outQueueY.FreeTensor(yLocal); + outQueueXOut.FreeTensor(xOutLocal); + } + + private: + AscendC::TPipe pipe; + AscendC::TQue inQueueX1; + AscendC::TQue inQueueX2; + AscendC::TQue outQueueY; + AscendC::TQue outQueueXOut; + + AscendC::TBuf weightBuf; + AscendC::TBuf fp32Buf1; + AscendC::TBuf fp32Buf2; + AscendC::TBuf reduceTmpBuf; + AscendC::TBuf sumBuf; + + AscendC::GlobalTensor x1Gm, x2Gm, yGm, xOutGm; + AscendC::GlobalTensor weightGm; + + int64_t blockRows; + int64_t dimLength; + int64_t dimLengthAlign; + float eps; +}; + +extern "C" __global__ __aicore__ void add_rms_norm( + GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y, GM_ADDR x_out, + int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign, + int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps, + int64_t dtypeSize) { + if (dtypeSize == 2) { + KernelAddRmsNorm op; + op.Init(x1, x2, weight, y, x_out, totalRows, dimLength, dimLengthAlign, + formerNum, formerLength, 
tailLength, eps); + op.Process(); + } else { + KernelAddRmsNorm op; + op.Init(x1, x2, weight, y, x_out, totalRows, dimLength, dimLengthAlign, + formerNum, formerLength, tailLength, eps); + op.Process(); + } +} diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/CMakeLists.txt b/src/ascend/custom_kernel/csrc/ops/rms_norm/CMakeLists.txt new file mode 100644 index 00000000..94ceabaa --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/CMakeLists.txt @@ -0,0 +1 @@ +ascendc_add_operator(OP_NAME rms_norm) diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/README.md b/src/ascend/custom_kernel/csrc/ops/rms_norm/README.md new file mode 100644 index 00000000..39b3cfce --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/README.md @@ -0,0 +1,59 @@ +# `ascend_kernel.ops.rms_norm` + +```python +torch.ops.npu.rms_norm(input, weight, eps=1e-6) → Tensor +``` + +对输入张量的最后一个维度执行 RMS 归一化(Root Mean Square Layer Normalization)。 + +$$y = x \cdot \frac{1}{\sqrt{\mathrm{mean}(x^2) + \varepsilon}} \cdot \text{weight}$$ + +与 LayerNorm 不同,RMSNorm 不减去均值,仅基于均方根进行归一化,计算开销更低。 + +## 参数说明 + +- **input** (`Tensor`) — 输入张量,维度 ≥ 1。归一化沿最后一个维度进行。 +- **weight** (`Tensor`) — 一维权重张量,形状为 `[hidden_dim]`,其中 `hidden_dim = input.shape[-1]`。 +- **eps** (`float`, 可选) — 加在方差上的小常数,防止除零。默认值 `1e-6`。 + +## 支持的数据类型 + +| 数据类型 | 支持 | +|---------|------| +| `torch.float16` | 是 | +| `torch.float32` | 是 | + +`weight` 的数据类型可与 `input` 不同(内部统一转为 `float32` 计算)。 + +## Shape 约束 + +- `input`: 任意维度 ≥ 1 的张量,形状 `[*, hidden_dim]`。 +- `weight`: 一维张量,形状 `[hidden_dim]`,必须满足 `weight.size(0) == input.size(-1)`。 +- 输出与 `input` 同形状、同数据类型。 + +## 约束条件 + +- `hidden_dim`(对齐后)不能超过单核 UB 容量限制。在 Ascend 910B 上,`hidden_dim` 最大约 9600(`float32`)或 9600(`float16`)。 +- `input` 和 `weight` 必须在 NPU 设备上。 + +## 使用示例 + +```python +import torch +import torch_npu +import ascend_kernel + +# 基本用法。 +x = torch.randn(32, 4096, dtype=torch.float16, device="npu") +w = torch.randn(4096, dtype=torch.float16, device="npu") +y = 
torch.ops.npu.rms_norm(x, w, 1e-6) + +# 多维输入(batch × seq_len × hidden_dim)。 +x = torch.randn(4, 128, 4096, dtype=torch.float32, device="npu") +w = torch.randn(4096, dtype=torch.float32, device="npu") +y = torch.ops.npu.rms_norm(x, w) # eps 默认 1e-6 +``` + +## 返回值 + +`Tensor` — 与 `input` 同形状、同数据类型的归一化结果。 diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/design.md b/src/ascend/custom_kernel/csrc/ops/rms_norm/design.md new file mode 100644 index 00000000..6e3d65fa --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/design.md @@ -0,0 +1,381 @@ +# RMSNorm 设计文档 + +## 1. 算子接口 + +### 1.1 函数签名 + +```cpp +at::Tensor rms_norm(const at::Tensor &input, const at::Tensor &weight, double eps); +``` + +### 1.2 参数说明 + +| 参数名 | 类型 | 输入/输出 | 支持的数据类型 | 描述 | 约束条件 | +|--------|------|-----------|---------------|------|----------| +| input | at::Tensor | 输入 | float16/float32 | 输入 tensor,shape `[*, hidden_dim]` | 最后一维为归一化维度 | +| weight | at::Tensor | 输入 | float16/float32 | 权重 tensor,shape `[hidden_dim]` | 与 `input` 最后一维等长 | +| eps | double | 输入 | — | 数值稳定性常量 | 默认 1e-6 | +| output | at::Tensor | 输出 | float16/float32 | 输出 tensor,shape 同 `input` | dtype 同 `input` | + +### 1.3 支持的数据类型 + +- [x] float16 +- [x] float32 + +### 1.4 PyTorch 参考 + +```python +torch.nn.functional.rms_norm(input, normalized_shape, weight, eps) +``` + +InfiniOps 基类:`src/base/rms_norm.h`,成员 `dim_`(hidden_dim)、`batch_size_`、`nhead_`、`eps_`。 + +--- + +## 2. 计算逻辑 + +### 2.1 算法描述 + +RMSNorm 对输入 tensor 的每一行(最后一维)做 Root Mean Square 归一化: + +$$y_i = x_i \cdot \text{rsqrt}\left(\frac{1}{N}\sum_{j=0}^{N-1} x_j^2 + \varepsilon\right) \cdot w_i$$ + +其中 $N$ = `hidden_dim`。 + +分步: +1. 对每行 $x$ 计算元素平方 $x^2$。 +2. 沿行方向归约求和 $\text{sum} = \sum x^2$。 +3. 计算均值 $\text{mean} = \text{sum} / N$。 +4. 加 epsilon 并取 rsqrt:$\text{scale} = \text{rsqrt}(\text{mean} + \varepsilon)$。 +5. 
逐元素乘以 scale 和 weight:$y = x \cdot \text{scale} \cdot w$。 + +### 2.2 AscendC API 调用伪代码 + +```cpp +// 对每行 hidden_dim 个元素(x 已在 UB 中,float32): + +// Step 1: 计算 x²。 +Mul(sqBuf, xBuf, xBuf, hiddenDim); + +// Step 2: 归约求和。 +// ReduceSum 结果存入 sumBuf(至少 32B)。 +WholeReduceSum(sumBuf, sqBuf, hiddenDim, 1, 1, 8); + +// Step 3-5: 标量运算(在 32B 对齐的 sumBuf 上操作)。 +Muls(sumBuf, sumBuf, 1.0f / hiddenDim, 8); // mean = sum / N +Adds(sumBuf, sumBuf, eps, 8); // mean + eps +Rsqrt(sumBuf, sumBuf, 8); // rsqrt(mean + eps) + +// Step 6: 广播乘以 scale。 +float scale = sumBuf.GetValue(0); +Muls(outBuf, xBuf, scale, hiddenDim); // y = x * scale + +// Step 7: 逐元素乘以 weight。 +Mul(outBuf, outBuf, weightBuf, hiddenDim); // y = y * weight +``` + +**FP16 输入时**,在 Step 1 之前插入升精度,在 Step 7 之后插入降精度: + +```cpp +// 升精度:fp16 → fp32 +Cast(xBufFp32, xBufFp16, RoundMode::CAST_NONE, hiddenDim); + +// ... Steps 1-7 在 fp32 上执行 ... + +// 降精度:fp32 → fp16 +Cast(outBufFp16, outBufFp32, RoundMode::CAST_ROUND, hiddenDim); +``` + +### 2.3 实现路径选择 + +- [x] AscendC Kernel(纯 vector 实现) +- [ ] CATLASS 模板库(矩阵乘法类) +- [ ] ACLNN 封装(CANN 内置算子) + +**选择理由**:RMSNorm 是纯 vector 归约 + 逐元素运算,不涉及矩阵乘法。CANN 的 `aclnnRmsNorm` 内部分解为 5 个子算子(Pows + ReduceMean + Add + Rsqrt + Mul),产生 inter-op 调度开销。自定义 AscendC kernel 可以将整个计算融合在单个 kernel 内,消除子算子之间的调度开销并实现 UB 内数据复用。 + +--- + +## 3. Tiling 策略 + +**算子类型**: Row-reduction(沿最后一维归约,输出与输入同形) + +### 核心设计 + +RMSNorm 以**行**为处理单元。每行 `hidden_dim` 个元素必须整体装入 UB 才能完成归约。因此: + +- **Block 级 Tiling**:将总行数分配到多核并行。 +- **UB 级 Tiling**:每次处理一行(`tileLength = hiddenDim`)。核内循环遍历分配给该核的所有行。 + +``` +GM: [row 0] [row 1] ... [row M-1] (M = totalRows) + │ │ │ + ┌─────┘ │ └─────┐ + ▼ ▼ ▼ +Core 0 Core 1 ... Core 39 ← Block 级(行分配) + rows[0..k] rows[k+1..2k] rows[..] 
+ +Core 内: + for each row: + CopyIn(row) ← GM → UB + Compute(row) ← reduction + scale + weight mul + CopyOut(row) ← UB → GM +``` + +### 3.1 Tiling 参数结构体 + +```cpp +struct RmsNormTilingData { + int64_t totalRows; // 总行数 = product(shape[:-1]) + int64_t hiddenDim; // 最后一维长度 N + int64_t hiddenDimAlign; // 32B 对齐后的 N + + int64_t formerNum; // 整核数量 + int64_t formerLength; // 整核处理的行数 + int64_t tailNum; // 尾核数量 + int64_t tailLength; // 尾核处理的行数 + + float eps; // epsilon + int64_t dtypeSize; // 每个元素字节数(2 或 4) +}; +``` + +### 3.2 Block 级 Tiling(核间切分) + +按行数均匀分配到 `CORE_NUM` 个核,使用整核/尾核策略: + +| 参数 | 计算公式 | +|------|----------| +| totalRows | product(input.shape[:-1]) | +| formerNum | totalRows % CORE_NUM(== 0 时取 CORE_NUM) | +| tailNum | CORE_NUM - formerNum | +| formerLength | totalRows / CORE_NUM + 1 | +| tailLength | totalRows / CORE_NUM | + +**验证**:`formerNum * formerLength + tailNum * tailLength == totalRows` + +### 3.3 UB 级 Tiling(核内切分) + +每次处理一行。`tileLength = hiddenDim`(整行装入 UB)。 + +#### 精度处理 + +| 输入类型 | 计算精度 | UB 额外开销 | +|----------|----------|-------------| +| float32 | float32 | 无 | +| float16 | **升精度到 float32** | 需要 fp32 计算 buffer | + +#### UB 分配表 — float32 + +| Buffer 名称 | 大小(字节) | 数量 | 用途 | 总大小 | +|-------------|-------------|------|------|--------| +| inQueueX | hiddenDim × 4 | 2 (double buf) | 输入行 | hiddenDim × 8 | +| outQueueY | hiddenDim × 4 | 2 (double buf) | 输出行 | hiddenDim × 8 | +| tmpBuf | hiddenDim × 4 | 1 | x² 中间结果 | hiddenDim × 4 | +| weightBuf | hiddenDim × 4 | 1 | weight(load once) | hiddenDim × 4 | +| sumBuf | 32 | 1 | 归约标量 | 32 | +| **总计** | | | | **hiddenDim × 24 + 32** | + +**bufferCoefficient (fp32) = 24** + +maxHiddenDim (fp32) = (UB_SIZE_LIMIT − 32) / 24 + +示例:UB = 192 KB → maxHiddenDim = 8191 + +#### UB 分配表 — float16 + +| Buffer 名称 | 大小(字节) | 数量 | 用途 | 总大小 | +|-------------|-------------|------|------|--------| +| inQueueX | hiddenDim × 2 | 2 (double buf) | 输入行 (fp16) | hiddenDim × 4 | +| outQueueY | hiddenDim × 2 | 2 (double buf) | 输出行 (fp16) | 
hiddenDim × 4 | +| xFp32Buf | hiddenDim × 4 | 1 | 升精度后的 x | hiddenDim × 4 | +| tmpFp32Buf | hiddenDim × 4 | 1 | x² 中间结果 | hiddenDim × 4 | +| weightFp32Buf | hiddenDim × 4 | 1 | weight (fp32, load once) | hiddenDim × 4 | +| sumBuf | 32 | 1 | 归约标量 | 32 | +| **总计** | | | | **hiddenDim × 20 + 32** | + +**bufferCoefficient (fp16) = 20** + +maxHiddenDim (fp16) = (UB_SIZE_LIMIT − 32) / 20 + +示例:UB = 192 KB → maxHiddenDim = 9828 + +#### 典型模型 hidden_dim 验证 + +| 模型 | hidden_dim | fp32 UB 使用 | fp16 UB 使用 | 是否 fit | +|------|-----------|-------------|-------------|---------| +| Qwen-7B | 4096 | 98,336 B (50%) | 81,952 B (42%) | ✓ | +| Llama-8B | 4096 | 98,336 B | 81,952 B | ✓ | +| Llama-70B | 8192 | 196,640 B (100.02%) | 163,872 B (83%) | fp16 ✓, fp32 需降为 BUFFER_NUM=1 | + +**注意**:fp32 + hidden_dim=8192 超出 192KB 32 字节。此时 Host 端应检测并降低 BUFFER_NUM 为 1(bufferCoefficient 变为 16,maxHiddenDim = 12287)。 + +#### UB 约束验证 + +- **UB 对齐**:32 字节 +- **hiddenDimAlign**:`((hiddenDim + alignElements − 1) / alignElements) * alignElements`,其中 `alignElements = 32 / dtypeSize` +- **UB 总使用** ≤ UB_SIZE_LIMIT(通过 `AscendC::GetSysWorkSpaceSize()` 运行时获取) + +--- + +## 4. Workspace 需求 + +### 4.1 Workspace 大小 + +```cpp +size_t workspaceSize = sizeof(RmsNormTilingData); +``` + +Tiling 参数通过 workspace 传递给 kernel。 + +--- + +## 5. 性能优化 + +### 5.1 关键优化点 + +1. **单 kernel 融合**:将 CANN 的 5 个子算子(Pows + ReduceMean + Add + Rsqrt + Mul)融合为 1 个 kernel,消除 inter-op 调度开销。 +2. **UB 数据复用**:输入行在 UB 中被读取一次,用于平方和归约,又用于 scale 乘法——无需重复从 GM 加载。 +3. **Weight 一次加载**:weight 向量在 Init 阶段加载到 UB,后续所有行复用。 +4. **Double buffer**:输入/输出使用 BUFFER_NUM=2,隐藏 GM 访存延迟。 + +### 5.2 算子特性 + +- **计算模式**: memory-bound(归约 + 逐元素乘法,计算强度低) +- **访存模式**: 顺序行访问(最后一维连续) +- **并行性**: 高(行间完全独立) + +--- + +## 6. Kernel 端实现要点 + +### 6.1 Init(核内初始化) + +```cpp +__aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR y, + GM_ADDR workspace, GM_ADDR tiling) { + // 1. 从 tiling workspace 读取 RmsNormTilingData。 + // 2. 判断当前 block 是整核还是尾核,计算行偏移和行数。 + // 3. 
设置 xGm / yGm 的 GlobalBuffer。 + // 4. 加载 weight 到 weightBuf(仅一次)。 + // - fp16 输入时:加载 weight_fp16 → cast 到 weightFp32Buf。 + // - fp32 输入时:直接加载到 weightBuf。 + // 5. 初始化 pipe / queue。 +} +``` + +### 6.2 执行流程(核内循环) + +```cpp +__aicore__ inline void Process() { + // coreRows = 当前核分配的行数 + for (int64_t row = 0; row < coreRows; ++row) { + CopyIn(row); + Compute(row); + CopyOut(row); + } +} +``` + +### 6.3 CopyIn + +```cpp +__aicore__ inline void CopyIn(int64_t row) { + LocalTensor xLocal = inQueueX.AllocTensor(); + DataCopy(xLocal, xGm[row * hiddenDim], hiddenDim); + inQueueX.EnQue(xLocal); +} +``` + +### 6.4 Compute + +```cpp +__aicore__ inline void Compute(int64_t row) { + LocalTensor xLocal = inQueueX.DeQue(); + LocalTensor yLocal = outQueueY.AllocTensor(); + + // [fp16 only] Cast x to fp32. + // Cast(xFp32, xLocal, CAST_NONE, hiddenDim); + + // Step 1: x². + Mul(tmpBuf, xFp32, xFp32, hiddenDim); + + // Step 2: ReduceSum → sumBuf. + // 使用 WholeReduceSum 或手动分块归约。 + + // Step 3-5: mean → +eps → rsqrt(在 sumBuf 上操作)。 + Muls(sumBuf, sumBuf, 1.0f / hiddenDim, 8); + Adds(sumBuf, sumBuf, eps, 8); + Rsqrt(sumBuf, sumBuf, 8); + float scale = sumBuf.GetValue(0); + + // Step 6: y = x * scale. + Muls(yFp32, xFp32, scale, hiddenDim); + + // Step 7: y = y * weight. + Mul(yFp32, yFp32, weightBuf, hiddenDim); + + // [fp16 only] Cast back to fp16. + // Cast(yLocal, yFp32, CAST_ROUND, hiddenDim); + + inQueueX.FreeTensor(xLocal); + outQueueY.EnQue(yLocal); +} +``` + +### 6.5 CopyOut + +```cpp +__aicore__ inline void CopyOut(int64_t row) { + LocalTensor yLocal = outQueueY.DeQue(); + DataCopy(yGm[row * hiddenDim], yLocal, hiddenDim); + outQueueY.FreeTensor(yLocal); +} +``` + +--- + +## 7. 
实现检查清单 + +### 7.1 文件结构 + +- [ ] `csrc/ops/rms_norm/CMakeLists.txt` +- [ ] `csrc/ops/rms_norm/op_host/rms_norm.cpp` +- [ ] `csrc/ops/rms_norm/op_kernel/rms_norm.cpp` +- [ ] `csrc/ops.h`(添加声明) +- [ ] `csrc/register.cpp`(添加 `m.def` + `m.impl`) +- [ ] `csrc/CMakeLists.txt`(添加 host + kernel 源文件) + +### 7.2 Host 端实现 + +- [ ] 定义 `RmsNormTilingData` 结构体 +- [ ] 计算 totalRows = product(input.shape[:-1]) +- [ ] Block 级 Tiling 参数(formerNum/tailNum/formerLength/tailLength) +- [ ] 检测 UB 是否能容纳 hiddenDim(超限时降低 BUFFER_NUM) +- [ ] 分配 workspace 并拷贝 tiling data +- [ ] 调用 `EXEC_KERNEL_CMD(rms_norm, ...)` + +### 7.3 Kernel 端实现 + +- [ ] Init:整核/尾核偏移计算,weight 加载 +- [ ] CopyIn:GM → UB 行拷贝 +- [ ] Compute:fp16 升精度 → x² → ReduceSum → rsqrt → scale → weight mul → fp16 降精度 +- [ ] CopyOut:UB → GM 行写回 +- [ ] Process:行循环 + +### 7.4 测试验证 + +- [ ] 小规模:shape `[4, 128]`,fp32/fp16 +- [ ] 中等规模:shape `[32, 4096]`,fp32/fp16 +- [ ] 大规模:shape `[128, 8192]`,fp16 +- [ ] 正确性:与 `torch.nn.functional.rms_norm` 对比 +- [ ] 边界:shape `[1, 128]`(单行)、`[1024, 128]`(多行少列) + +--- + +## 8. 参考实现 + +- **InfiniOps 基类**: `src/base/rms_norm.h` +- **InfiniOps CANN 实现**: `src/ascend/rms_norm/kernel.h`(使用 `aclnnRmsNorm`) +- **PyTorch**: `torch.nn.functional.rms_norm` +- **有效输入范围**: 无限制(任意实数),eps > 0 diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/op_host/rms_norm.cpp b/src/ascend/custom_kernel/csrc/ops/rms_norm/op_host/rms_norm.cpp new file mode 100644 index 00000000..27479c31 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/op_host/rms_norm.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2025, InfiniTensor. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include "aclrtlaunch_rms_norm.h" +#include "tiling/platform/platform_ascendc.h" +#include "torch_kernel_helper.h" + +namespace ascend_kernel { + +at::Tensor rms_norm(const at::Tensor& input, const at::Tensor& weight, + double eps) { + // Input validation. 
+ TORCH_CHECK(input.dim() > 0, + "rms_norm: input must have at least 1 dimension"); + TORCH_CHECK( + input.scalar_type() == at::kHalf || input.scalar_type() == at::kFloat, + "rms_norm: only float16 and float32 are supported, got ", + input.scalar_type()); + TORCH_CHECK(weight.dim() == 1, "rms_norm: weight must be 1-dimensional"); + TORCH_CHECK(weight.size(0) == input.size(-1), "rms_norm: weight size (", + weight.size(0), ") must match input last dim (", input.size(-1), + ")"); + + int64_t dimLength = input.size(-1); + int64_t totalRows = input.numel() / dimLength; + + if (totalRows == 0 || dimLength == 0) { + return at::empty_like(input); + } + + at::Tensor x = input.contiguous(); + int64_t dtypeSize = x.element_size(); + + // Hardware parameters. + auto ascendc_platform = + platform_ascendc::PlatformAscendCManager::GetInstance(); + int64_t coreNum = static_cast(ascendc_platform->GetCoreNumAiv()); + uint64_t ubSize; + ascendc_platform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + int64_t ubSizeLimit = static_cast(ubSize); + + // Alignment (32-byte boundary). + int64_t alignElements = 32 / dtypeSize; + int64_t dimLengthAlign = + ((dimLength + alignElements - 1) / alignElements) * alignElements; + + // UB capacity check. + // fp32: inQ(×2) + outQ(×2) + weight = 5 × dimLenAlign × 4 = coeff 20 + // fp16: inQ(×2) + outQ(×2) + xFp32 + tmpFp32 + weight + // = 2×dimLenAlign×2 ×2 + 3×dimLenAlign×4 = 8 + 12 = coeff 20 + int64_t bufferCoefficient = 20; + int64_t maxDimLength = + (ubSizeLimit - 1024) / bufferCoefficient; // 1024 for reduce bufs. + int64_t fpAlignElements = 32 / 4; // fp32 alignment. + maxDimLength = (maxDimLength / fpAlignElements) * fpAlignElements; + TORCH_CHECK(dimLengthAlign <= maxDimLength, "rms_norm: dimLength ", dimLength, + " (aligned ", dimLengthAlign, ") exceeds UB capacity (max ", + maxDimLength, ")"); + + // Padding. 
+ at::Tensor kernelInput; + + if (dimLength != dimLengthAlign) { + kernelInput = x.reshape({totalRows, dimLength}); + kernelInput = + at::constant_pad_nd(kernelInput, {0, dimLengthAlign - dimLength}, 0.0); + kernelInput = kernelInput.contiguous(); + } else { + kernelInput = x.reshape({totalRows, dimLengthAlign}).contiguous(); + } + + at::Tensor kernelOutput = at::empty_like(kernelInput); + + // Weight: always pass as fp32, padded to `dimLengthAlign`. + at::Tensor weightFloat = weight.contiguous().to(at::kFloat); + + if (dimLength != dimLengthAlign) { + weightFloat = + at::constant_pad_nd(weightFloat, {0, dimLengthAlign - dimLength}, 0.0); + } + + weightFloat = weightFloat.contiguous(); + + // Block-level tiling (distribute rows across cores). + int64_t usedCoreNum = std::min(totalRows, coreNum); + int64_t formerLength = (totalRows + usedCoreNum - 1) / usedCoreNum; + int64_t tailLength = formerLength - 1; + int64_t formerNum = totalRows - tailLength * usedCoreNum; + uint32_t blockDim = static_cast(usedCoreNum); + + // All EXEC_KERNEL_CMD args must be lvalues. + float epsFloat = static_cast(eps); + int64_t dtypeSizeVal = dtypeSize; + + EXEC_KERNEL_CMD(rms_norm, blockDim, kernelInput, weightFloat, kernelOutput, + totalRows, dimLength, dimLengthAlign, formerNum, formerLength, + tailLength, epsFloat, dtypeSizeVal); + + // Remove padding and reshape back to original shape. + at::Tensor output = kernelOutput; + + if (dimLength != dimLengthAlign) { + output = output.narrow(-1, 0, dimLength).contiguous(); + } + + output = output.reshape(input.sizes()); + + return output; +} + +} // namespace ascend_kernel diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/op_kernel/rms_norm.cpp b/src/ascend/custom_kernel/csrc/ops/rms_norm/op_kernel/rms_norm.cpp new file mode 100644 index 00000000..8f2f4b4f --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/op_kernel/rms_norm.cpp @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2025, InfiniTensor. + * All rights reserved. 
+ * + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include "kernel_operator.h" + +constexpr int32_t BUFFER_NUM = 2; + +template +class KernelRmsNorm { + public: + __aicore__ inline KernelRmsNorm() {} + + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR y, + int64_t totalRows, int64_t dimLength, + int64_t dimLengthAlign, int64_t formerNum, + int64_t formerLength, int64_t tailLength, + float eps) { + this->dimLength = dimLength; + this->dimLengthAlign = dimLengthAlign; + this->eps = eps; + + // Block-level tiling: determine row range for this core. + int64_t blockIdx = AscendC::GetBlockIdx(); + int64_t rowOffset; + + if (blockIdx < formerNum) { + this->blockRows = formerLength; + rowOffset = formerLength * blockIdx; + } else { + this->blockRows = tailLength; + int64_t tailIdx = blockIdx - formerNum; + rowOffset = formerLength * formerNum + tailLength * tailIdx; + } + + // Global memory pointers. + xGm.SetGlobalBuffer((__gm__ T*)x + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + yGm.SetGlobalBuffer((__gm__ T*)y + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + weightGm.SetGlobalBuffer((__gm__ float*)weight, dimLengthAlign); + + int32_t dimLenAlign = static_cast(this->dimLengthAlign); + + // I/O queues (double-buffered). + pipe.InitBuffer(inQueueX, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + pipe.InitBuffer(outQueueY, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + + // Weight buffer (fp32, loaded once, reused for all rows). + pipe.InitBuffer(weightBuf, + dimLenAlign * static_cast(sizeof(float))); + + // FP16 path needs extra fp32 compute buffers. + if constexpr (sizeof(T) == 2) { + pipe.InitBuffer(xFp32Buf, + dimLenAlign * static_cast(sizeof(float))); + pipe.InitBuffer(tmpFp32Buf, + dimLenAlign * static_cast(sizeof(float))); + } + + // ReduceSum temporary buffer (size per API formula). 
+ constexpr int32_t ELEMS_PER_REPEAT = 256 / sizeof(float); + constexpr int32_t ELEMS_PER_BLOCK = 32 / sizeof(float); + int32_t firstMaxRepeat = + (dimLenAlign + ELEMS_PER_REPEAT - 1) / ELEMS_PER_REPEAT; + int32_t reduceTmpSize = + ((firstMaxRepeat + ELEMS_PER_BLOCK - 1) / ELEMS_PER_BLOCK) * + ELEMS_PER_BLOCK; + pipe.InitBuffer(reduceTmpBuf, + reduceTmpSize * static_cast(sizeof(float))); + + // Scalar buffer for reduction result (8 floats = 32 bytes). + pipe.InitBuffer(sumBuf, 32); + + // Load weight (fp32) from GM into `weightBuf`. + AscendC::LocalTensor wLocal = weightBuf.Get(); + AscendC::DataCopyExtParams wParams{ + 1, static_cast(dimLenAlign * sizeof(float)), 0, 0, 0}; + AscendC::DataCopyPadExtParams wPad{false, 0, 0, 0.0f}; + AscendC::DataCopyPad(wLocal, weightGm, wParams, wPad); + + // Ensure weight DMA completes before compute. + AscendC::PipeBarrier(); + } + + __aicore__ inline void Process() { + for (int64_t row = 0; row < this->blockRows; ++row) { + CopyIn(row); + Compute(row); + CopyOut(row); + } + } + + private: + __aicore__ inline void CopyIn(int64_t row) { + AscendC::LocalTensor xLocal = inQueueX.AllocTensor(); + AscendC::DataCopyExtParams params{ + 1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0}; + AscendC::DataCopyPadExtParams pad{false, 0, 0, static_cast(0)}; + AscendC::DataCopyPad(xLocal, xGm[row * this->dimLengthAlign], params, pad); + inQueueX.EnQue(xLocal); + } + + __aicore__ inline void Compute(int64_t row) { + AscendC::LocalTensor xLocal = inQueueX.DeQue(); + AscendC::LocalTensor yLocal = outQueueY.AllocTensor(); + + AscendC::LocalTensor wLocal = weightBuf.Get(); + AscendC::LocalTensor rTmp = reduceTmpBuf.Get(); + AscendC::LocalTensor sLocal = sumBuf.Get(); + + int32_t dimLen = static_cast(this->dimLength); + int32_t dimLenAlign = static_cast(this->dimLengthAlign); + + if constexpr (sizeof(T) == 4) { + // ---- FP32 path: compute directly. ---- + + // Step 1: x^2 into yLocal (reuse output buffer temporarily). 
+ AscendC::Mul(yLocal, xLocal, xLocal, dimLenAlign); + + // Step 2: ReduceSum(x^2) -> sLocal[0]. + // ReduceSum may modify src (yLocal), but we overwrite it later. + AscendC::ReduceSum(sLocal, yLocal, rTmp, dimLenAlign); + + // Step 3-5: scale = 1 / sqrt(mean(x^2) + eps). + float sumVal = sLocal.GetValue(0); + float meanVal = sumVal / static_cast(dimLen) + this->eps; + sLocal.SetValue(0, meanVal); + AscendC::Sqrt(sLocal, sLocal, 8); + float scale = 1.0f / sLocal.GetValue(0); + + // Step 6: y = x * scale. + AscendC::Muls(yLocal, xLocal, scale, dimLenAlign); + + // Step 7: y = y * weight. + AscendC::Mul(yLocal, yLocal, wLocal, dimLenAlign); + + } else { + // ---- FP16 path: cast → fp32 compute → cast back. ---- + AscendC::LocalTensor xF32 = xFp32Buf.Get(); + AscendC::LocalTensor tmpF32 = tmpFp32Buf.Get(); + + // Cast input fp16 → fp32. + AscendC::Cast(xF32, xLocal, AscendC::RoundMode::CAST_NONE, dimLenAlign); + + // Step 1: x^2 in fp32. + AscendC::Mul(tmpF32, xF32, xF32, dimLenAlign); + + // Step 2: ReduceSum(x^2) -> sLocal[0]. + AscendC::ReduceSum(sLocal, tmpF32, rTmp, dimLenAlign); + + // Step 3-5: scale = 1 / sqrt(mean(x^2) + eps). + float sumVal = sLocal.GetValue(0); + float meanVal = sumVal / static_cast(dimLen) + this->eps; + sLocal.SetValue(0, meanVal); + AscendC::Sqrt(sLocal, sLocal, 8); + float scale = 1.0f / sLocal.GetValue(0); + + // Step 6: y = x * scale (fp32). + AscendC::Muls(tmpF32, xF32, scale, dimLenAlign); + + // Step 7: y = y * weight (fp32). + AscendC::Mul(tmpF32, tmpF32, wLocal, dimLenAlign); + + // Cast result fp32 → fp16. 
+ AscendC::Cast(yLocal, tmpF32, AscendC::RoundMode::CAST_ROUND, + dimLenAlign); + } + + inQueueX.FreeTensor(xLocal); + outQueueY.EnQue(yLocal); + } + + __aicore__ inline void CopyOut(int64_t row) { + AscendC::LocalTensor yLocal = outQueueY.DeQue(); + AscendC::DataCopyExtParams params{ + 1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0}; + AscendC::DataCopyPad(yGm[row * this->dimLengthAlign], yLocal, params); + outQueueY.FreeTensor(yLocal); + } + + private: + AscendC::TPipe pipe; + AscendC::TQue inQueueX; + AscendC::TQue outQueueY; + + AscendC::TBuf weightBuf; + AscendC::TBuf xFp32Buf; + AscendC::TBuf tmpFp32Buf; + AscendC::TBuf reduceTmpBuf; + AscendC::TBuf sumBuf; + + AscendC::GlobalTensor xGm, yGm; + AscendC::GlobalTensor weightGm; + + int64_t blockRows; + int64_t dimLength; + int64_t dimLengthAlign; + float eps; +}; + +extern "C" __global__ __aicore__ void rms_norm( + GM_ADDR x, GM_ADDR weight, GM_ADDR y, int64_t totalRows, int64_t dimLength, + int64_t dimLengthAlign, int64_t formerNum, int64_t formerLength, + int64_t tailLength, float eps, int64_t dtypeSize) { + if (dtypeSize == 2) { + KernelRmsNorm op; + op.Init(x, weight, y, totalRows, dimLength, dimLengthAlign, formerNum, + formerLength, tailLength, eps); + op.Process(); + } else { + KernelRmsNorm op; + op.Init(x, weight, y, totalRows, dimLength, dimLengthAlign, formerNum, + formerLength, tailLength, eps); + op.Process(); + } +} diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/test/benchmark_rms_norm_msprof.py b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/benchmark_rms_norm_msprof.py new file mode 100644 index 00000000..8a744545 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/benchmark_rms_norm_msprof.py @@ -0,0 +1,216 @@ +"""Performance benchmark orchestrator for RMSNorm using msprof.""" + +import csv +import glob +import json +import os +import subprocess + + +CASES_FILE = os.path.join(os.path.dirname(__file__), "rms_norm_cases.jsonl") +RUNNER_SCRIPT = 
os.path.join(os.path.dirname(__file__), "run_rms_norm_case.py") +MSPROF_BASE = "/tmp/msprof_rms_norm" + +# OP Type keyword for filtering in op_summary CSV. +OP_TYPE_KEYWORD = "rms_norm" + + +def load_cases(): + cases = [] + with open(CASES_FILE) as f: + for line in f: + line = line.strip() + + if line: + cases.append(json.loads(line)) + + return cases + + +def run_msprof(case, output_dir, iters=20, warmup=10): + """Run a single case under msprof profiling.""" + # Write a self-contained wrapper to avoid shell quoting issues. + os.makedirs(os.path.dirname(output_dir + "_") or ".", exist_ok=True) + wrapper = output_dir + "_run.py" + + with open(wrapper, "w") as f: + f.write( + "import json, torch, torch_npu, ascend_kernel\n" + f"case = {json.dumps(case)}\n" + "shape = tuple(case['shape'])\n" + "dtype = getattr(torch, case['dtype'])\n" + "eps = case['eps']\n" + "hidden_dim = shape[-1]\n" + "x = torch.randn(shape, dtype=dtype, device='npu')\n" + "w = torch.randn(hidden_dim, dtype=dtype, device='npu')\n" + f"for _ in range({warmup}):\n" + " _ = torch.ops.npu.rms_norm(x, w, eps)\n" + "torch.npu.synchronize()\n" + f"for _ in range({iters - warmup}):\n" + " _ = torch.ops.npu.rms_norm(x, w, eps)\n" + "torch.npu.synchronize()\n" + ) + + cmd = ( + f"msprof --output={output_dir} --task-time=l1 --runtime-api=on " + f'--application="python3 {wrapper}"' + ) + result = subprocess.run( + cmd, + shell=True, + capture_output=True, + text=True, + timeout=120, + ) + + try: + os.remove(wrapper) + except OSError: + pass + + if result.returncode != 0: + print(f" msprof FAILED for case {case['id']}: {result.stderr[-300:]}") + + return False + + return True + + +def parse_op_summary(output_dir, op_type_keyword): + """Parse msprof op_summary CSV for the target OP Type.""" + # Find the op_summary CSV. 
+ pattern = os.path.join(output_dir, "**", "op_summary_*.csv") + csv_files = glob.glob(pattern, recursive=True) + + if not csv_files: + return None + + csv_file = csv_files[0] + results = [] + + with open(csv_file, newline="") as f: + reader = csv.DictReader(f) + + for row in reader: + op_type = row.get("OP Type", "") + + if op_type_keyword.lower() in op_type.lower(): + results.append(row) + + return results + + +def main(): + cases = load_cases() + print(f"Loaded {len(cases)} benchmark cases") + print("=" * 80) + + all_results = [] + + for case in cases: + case_id = case["id"] + desc = case["desc"] + output_dir = os.path.join(MSPROF_BASE, f"case_{case_id}") + print(f"[Case {case_id}] {desc} shape={case['shape']} dtype={case['dtype']}") + + ok = run_msprof(case, output_dir, iters=20, warmup=10) + + if not ok: + all_results.append( + { + "id": case_id, + "desc": desc, + "shape": str(case["shape"]), + "dtype": case["dtype"], + "status": "FAILED", + } + ) + continue + + rows = parse_op_summary(output_dir, OP_TYPE_KEYWORD) + + if not rows: + print(f" WARNING: No matching OP Type '{OP_TYPE_KEYWORD}' found") + all_results.append( + { + "id": case_id, + "desc": desc, + "shape": str(case["shape"]), + "dtype": case["dtype"], + "status": "NO_MATCH", + } + ) + continue + + # Aggregate Task Duration across matching rows. 
+ durations = [] + + for row in rows: + dur = row.get("Task Duration(us)", "0") + + try: + durations.append(float(dur)) + except ValueError: + pass + + if durations: + avg_dur = sum(durations) / len(durations) + min_dur = min(durations) + max_dur = max(durations) + else: + avg_dur = min_dur = max_dur = 0.0 + + print( + f" Task Duration: avg={avg_dur:.2f}us min={min_dur:.2f}us max={max_dur:.2f}us ({len(durations)} calls)" + ) + + result = { + "id": case_id, + "desc": desc, + "shape": str(case["shape"]), + "dtype": case["dtype"], + "status": "OK", + "avg_duration_us": avg_dur, + "min_duration_us": min_dur, + "max_duration_us": max_dur, + "num_calls": len(durations), + } + + # Extract additional hardware metrics if available. + if rows: + for key in ["Task Wait Time(us)", "Block Dim"]: + val = rows[0].get(key, "") + + if val: + result[key] = val + + all_results.append(result) + + # Save JSON. + json_path = os.path.join(os.path.dirname(__file__), "rms_norm_perf.json") + + with open(json_path, "w") as f: + json.dump({"results": all_results}, f, indent=2) + + print(f"\n{'=' * 80}") + print(f"JSON results saved to: {json_path}") + + # Print summary table. 
+ print( + f"\n{'ID':>3} {'Shape':>20} {'Dtype':>8} {'Avg(us)':>10} {'Min(us)':>10} {'Max(us)':>10} {'Calls':>6}" + ) + print("-" * 75) + + for r in all_results: + if r["status"] == "OK": + print( + f"{r['id']:>3} {r['shape']:>20} {r['dtype']:>8} " + f"{r['avg_duration_us']:>10.2f} {r['min_duration_us']:>10.2f} " + f"{r['max_duration_us']:>10.2f} {r['num_calls']:>6}" + ) + else: + print(f"{r['id']:>3} {r['shape']:>20} {r['dtype']:>8} {r['status']}") + + +if __name__ == "__main__": + main() diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm-test-cases.md b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm-test-cases.md new file mode 100644 index 00000000..ade46795 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm-test-cases.md @@ -0,0 +1,117 @@ +# RMSNorm 用例设计文档 + +## 1. 算子标杆 + +PyTorch 参考实现: +```python +import torch + +def rms_norm_ref(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """CPU 参考实现,使用 float32 精度计算。""" + input_fp32 = input.float() + variance = input_fp32.pow(2).mean(dim=-1, keepdim=True) + hidden_states = input_fp32 * torch.rsqrt(variance + eps) + return (hidden_states * weight.float()).to(input.dtype) +``` + +NPU 调用方式(ascend_kernel 工程算子): +```python +import torch +import ascend_kernel + +# input: [*, hidden_dim], weight: [hidden_dim] +output = ascend_kernel.ops.rms_norm(input.npu(), weight.npu(), eps) +``` + +--- + +## 2. 
用例说明 + +### 2.1 测试配置 + +```python +# 支持的数据类型 +SUPPORTED_DTYPES = [torch.float16, torch.float32] + +# 典型用例 — 模型常见 hidden_dim + batch 组合 +TEST_SHAPES = [ + # (category, description, input_shape, hidden_dim_is_last_dim) + ("2D", "small 32x128", (32, 128)), + ("2D", "medium 64x512", (64, 512)), + ("2D", "medium 128x1024", (128, 1024)), + ("2D", "Qwen/Llama 32x4096", (32, 4096)), + ("2D", "Qwen/Llama 128x4096", (128, 4096)), + ("2D", "Llama-70B 32x8192", (32, 8192)), + ("3D", "multi-head 4x32x128", (4, 32, 128)), + ("3D", "multi-head 8x64x512", (8, 64, 512)), + ("3D", "batch 4x128x4096", (4, 128, 4096)), +] + +# 泛化用例 — 边界和大规模场景 +GENERAL_SHAPES = [ + # 小 shape 场景(边界测试) + ("Small", "single row", (1, 128)), + ("Small", "single row 4096", (1, 4096)), + ("Small", "two rows", (2, 256)), + ("Small", "tiny 3D", (1, 1, 128)), + ("Small", "non-aligned rows 3", (3, 512)), + ("Small", "non-aligned rows 7", (7, 1024)), + + # 大 shape 场景(生产环境) + ("Large", "BERT-base 512x768", (512, 768)), + ("Large", "GPT-2 1024x1024", (1024, 1024)), + ("Large", "Llama batch 256x4096", (256, 4096)), + ("Large", "Llama-70B batch 64x8192", (64, 8192)), + ("Large", "3D large 8x512x4096", (8, 512, 4096)), +] + +# 边界值测试 — eps 和特殊输入 +BOUNDARY_VALUES = [ + ("eps_small", "very small eps", (32, 512), {"eps": 1e-12}), + ("eps_large", "large eps", (32, 512), {"eps": 1e-2}), + ("zeros", "all-zero input", (16, 1024), {"input_fill": 0.0}), + ("ones", "all-one input", (16, 1024), {"input_fill": 1.0}), + ("large_vals", "large input values", (16, 1024), {"input_scale": 100.0}), + ("small_vals", "tiny input values", (16, 1024), {"input_scale": 1e-4}), +] +``` + +### 2.2 用例覆盖统计 + +| 类别 | Shape 数量 | 边界值数量 | dtype 数量 | 总用例数 | +|------|-----------|-----------|-----------|---------| +| 常规形状 (TEST_SHAPES) | 9 | — | 2 | 18 | +| 泛化形状 (GENERAL_SHAPES) | 11 | — | 2 | 22 | +| 边界值 (BOUNDARY_VALUES) | — | 6 | 2 | 12 | +| **总计** | **20** | **6** | **2** | **52** | + +--- + +## 3. 
使用说明 + +### 生成测试数据示例 + +```python +import torch + +def generate_rms_norm_inputs(shape, dtype, eps=1e-6, input_fill=None, input_scale=1.0): + """生成 rms_norm 测试输入。""" + hidden_dim = shape[-1] + weight = torch.randn(hidden_dim, dtype=dtype) + + if input_fill is not None: + input_tensor = torch.full(shape, input_fill, dtype=dtype) + else: + input_tensor = torch.randn(shape, dtype=dtype) * input_scale + + expected = rms_norm_ref(input_tensor, weight, eps) + + return input_tensor, weight, eps, expected +``` + +### 注意事项 + +1. **weight shape**:始终为 `[hidden_dim]`(1D),`hidden_dim = input.shape[-1]`。 +2. **eps 类型**:Python `float`(double),Host 端转 `float` 传给 kernel。 +3. **fp16 精度**:参考实现中先升精度到 float32 计算,结果再降回 float16。测试对比时应考虑 fp16 的精度损失(rtol=1e-3, atol=1e-3)。 +4. **全零输入**:`rsqrt(0 + eps)` 应正常工作,不应产生 nan/inf。 diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm_cases.jsonl b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm_cases.jsonl new file mode 100644 index 00000000..be9bc875 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm_cases.jsonl @@ -0,0 +1,10 @@ +{"id": 1, "shape": [32, 128], "dtype": "float16", "eps": 1e-6, "desc": "small 2D fp16"} +{"id": 2, "shape": [32, 128], "dtype": "float32", "eps": 1e-6, "desc": "small 2D fp32"} +{"id": 3, "shape": [64, 512], "dtype": "float16", "eps": 1e-6, "desc": "medium 2D fp16"} +{"id": 4, "shape": [128, 1024], "dtype": "float16", "eps": 1e-6, "desc": "medium 2D fp16"} +{"id": 5, "shape": [32, 4096], "dtype": "float16", "eps": 1e-6, "desc": "Llama hidden_dim fp16"} +{"id": 6, "shape": [32, 4096], "dtype": "float32", "eps": 1e-6, "desc": "Llama hidden_dim fp32"} +{"id": 7, "shape": [128, 4096], "dtype": "float16", "eps": 1e-6, "desc": "Llama batch fp16"} +{"id": 8, "shape": [32, 8192], "dtype": "float16", "eps": 1e-6, "desc": "Llama-70B fp16"} +{"id": 9, "shape": [256, 4096], "dtype": "float16", "eps": 1e-6, "desc": "large batch fp16"} +{"id": 10, "shape": [512, 768], "dtype": "float16", 
"eps": 1e-6, "desc": "BERT-base fp16"} diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm_perf_report.md b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm_perf_report.md new file mode 100644 index 00000000..876240bf --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/rms_norm_perf_report.md @@ -0,0 +1,35 @@ +# RMSNorm 性能评测报告 + +## 测试环境 + +- 硬件: Ascend 910B (NPU) +- CANN: 8.5.1 +- 采集工具: msprof (`--task-time=l1`) +- 迭代次数: 20 (前 10 次预热) + +## 性能结果 + +| Case | Shape | Dtype | Avg (us) | Min (us) | Max (us) | Calls | +|------|-------|-------|----------|----------|----------|-------| +| 1 | [32, 128] | float16 | 5.40 | 4.62 | 15.02 | 20 | +| 2 | [32, 128] | float32 | 5.65 | 4.96 | 13.22 | 20 | +| 3 | [64, 512] | float16 | 6.79 | 5.84 | 16.20 | 20 | +| 4 | [128, 1024] | float16 | 7.60 | 6.62 | 18.42 | 20 | +| 5 | [32, 4096] | float16 | 6.96 | 6.08 | 14.52 | 20 | +| 6 | [32, 4096] | float32 | 6.96 | 6.12 | 14.12 | 20 | +| 7 | [128, 4096] | float16 | 10.11 | 9.02 | 21.20 | 20 | +| 8 | [32, 8192] | float16 | 7.01 | 6.32 | 13.30 | 20 | +| 9 | [256, 4096] | float16 | 11.41 | 10.26 | 23.28 | 20 | +| 10 | [512, 768] | float16 | 11.40 | 10.36 | 24.06 | 20 | + +## 分析 + +1. **单 kernel 调用延迟极低**: 所有 shape 的平均 Task Duration 在 5-12 us 范围内,fused kernel 相比 CANN `aclnnRmsNorm` 的 5 个子 op (Pows + ReduceMean + Add + Rsqrt + Mul) 消除了 op 间调度开销。 + +2. **fp16 与 fp32 性能相当**: 同 shape 下 fp16 和 fp32 延迟几乎一致 (Case 5 vs 6: 6.96us vs 6.96us),说明瓶颈在内存带宽和调度而非计算。fp16 的 Cast 操作开销可忽略。 + +3. **延迟随 totalRows 线性增长**: 固定 `hidden_dim=4096`,从 32 行 (6.96us) 到 128 行 (10.11us) 到 256 行 (11.41us),增长趋势接近线性。当行数 < AI Core 数 (40) 时,多核并行有效隐藏了单行开销。 + +4. **hidden_dim 对延迟影响较小**: 固定 32 行,从 128 (5.40us) 到 4096 (6.96us) 到 8192 (7.01us),hidden_dim 增大 64 倍仅增加 ~30% 延迟。这是因为单行处理是 memory-bound (GM↔UB 搬运),vector 计算与搬运重叠。 + +5. 
**首次调用有冷启动开销**: max 值普遍是 min 的 2-3 倍,为首次 kernel 启动开销,后续调用稳定在 min 附近。 diff --git a/src/ascend/custom_kernel/csrc/ops/rms_norm/test/run_rms_norm_case.py b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/run_rms_norm_case.py new file mode 100644 index 00000000..d7f9c9f6 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/ops/rms_norm/test/run_rms_norm_case.py @@ -0,0 +1,40 @@ +"""Single-case msprof executor for RMSNorm performance benchmarking.""" + +import argparse +import json +import torch +import torch_npu # noqa: F401 Registers NPU device. +import ascend_kernel # noqa: F401 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--case", type=str, required=True) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--warmup", type=int, default=10) + args = parser.parse_args() + + case = json.loads(args.case) + shape = tuple(case["shape"]) + dtype = getattr(torch, case["dtype"]) + eps = case["eps"] + hidden_dim = shape[-1] + + x = torch.randn(shape, dtype=dtype, device="npu") + w = torch.randn(hidden_dim, dtype=dtype, device="npu") + + # Warmup. + for _ in range(args.warmup): + _ = torch.ops.npu.rms_norm(x, w, eps) + + torch.npu.synchronize() + + # Timed iterations. + for _ in range(args.iters - args.warmup): + _ = torch.ops.npu.rms_norm(x, w, eps) + + torch.npu.synchronize() + + +if __name__ == "__main__": + main() diff --git a/src/ascend/custom_kernel/csrc/register.cpp b/src/ascend/custom_kernel/csrc/register.cpp new file mode 100644 index 00000000..94fe44b3 --- /dev/null +++ b/src/ascend/custom_kernel/csrc/register.cpp @@ -0,0 +1,24 @@ +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include "ops.h"
+
+namespace {
+TORCH_LIBRARY_FRAGMENT(npu, m) {
+  m.def("rms_norm(Tensor input, Tensor weight, float eps=1e-6) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL(npu, PrivateUse1, m) {
+  m.impl("rms_norm", TORCH_FN(ascend_kernel::rms_norm));
+}
+}  // namespace
diff --git a/src/ascend/custom_kernel/csrc/utils/torch_kernel_helper.h b/src/ascend/custom_kernel/csrc/utils/torch_kernel_helper.h
new file mode 100644
index 00000000..f816842f
--- /dev/null
+++ b/src/ascend/custom_kernel/csrc/utils/torch_kernel_helper.h
@@ -0,0 +1,80 @@
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#ifndef TORCH_KERNEL_HELPER_H +#define TORCH_KERNEL_HELPER_H + +#include +#include + +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/OpCommand.h" + +namespace ascend_kernel { + +#define DEVICE_TYPE c10::DeviceType::PrivateUse1 + +class TorchNpuHelper { + public: + inline static at::Tensor CopyTensorHostToDevice( + const at::Tensor& cpu_tensor) { + at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory(); + int deviceIndex = 0; + c10_npu::GetDevice(&deviceIndex); + return cpuPinMemTensor.to(c10::Device(DEVICE_TYPE, deviceIndex), + cpuPinMemTensor.scalar_type(), true, true); + } + + inline static at::Tensor CopyScalarToDevice(const c10::Scalar& cpu_scalar, + at::ScalarType scalar_data_type) { + return CopyTensorHostToDevice( + scalar_to_tensor(cpu_scalar).to(scalar_data_type)); + } + + inline static void* ConvertType(const at::Tensor& at_tensor) { + return const_cast(at_tensor.data_ptr()); + } + + template + inline static T ConvertType(T value) { + return value; + } + + template + inline static constexpr auto ConvertTypes(Ts&... args) { + return std::make_tuple(ConvertType(args)...); + } +}; +} // namespace ascend_kernel + +/** + * @brief Launch real kernel function on NPU + * + * @param kernel_name [in] name of kernel + * @param blockdim [in] dim size of block + */ +#define EXEC_KERNEL_CMD(kernel_name, blockdim, ...) \ + do { \ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ + auto converted_params = \ + ascend_kernel::TorchNpuHelper::ConvertTypes(__VA_ARGS__); \ + auto acl_call = [acl_stream, blockdim, converted_params]() -> int { \ + std::apply( \ + [&](auto&&... 
params) { \ + ACLRT_LAUNCH_KERNEL(kernel_name) \ + (blockdim, acl_stream, params...); \ + }, \ + converted_params); \ + return 0; \ + }; \ + at_npu::native::OpCommand::RunOpApi(#kernel_name, acl_call); \ + } while (false) + +#endif // TORCH_KERNEL_HELPER_H diff --git a/src/ascend/custom_kernel/tests/test_add_rms_norm.py b/src/ascend/custom_kernel/tests/test_add_rms_norm.py new file mode 100644 index 00000000..23f62bed --- /dev/null +++ b/src/ascend/custom_kernel/tests/test_add_rms_norm.py @@ -0,0 +1,91 @@ +"""Correctness tests for custom AscendC add_rms_norm kernel.""" + +import pytest +import torch +import torch_npu # noqa: F401 Registers NPU device. +import ascend_kernel # noqa: F401 Loads `libascend_kernel.so` into `torch.ops.npu`. + + +def _ref_add_rms_norm(x1, x2, weight, eps): + """Reference implementation on CPU (float64 for precision).""" + x1_f64 = x1.double() + x2_f64 = x2.double() + w_f64 = weight.double() + + x_out = x1_f64 + x2_f64 + variance = x_out.pow(2).mean(dim=-1, keepdim=True) + y = x_out * torch.rsqrt(variance + eps) * w_f64 + + return y.to(x1.dtype), x_out.to(x1.dtype) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize( + "shape", + [ + (1, 128), + (4, 256), + (8, 512), + (32, 896), # Qwen 0.5B hidden_dim. + (16, 2048), # Qwen 3B hidden_dim. + (8, 3584), # Qwen 7B hidden_dim. + (1, 4096), # LLaMA hidden_dim. + (64, 896), # Larger batch. + ], +) +def test_add_rms_norm_correctness(dtype, shape): + """Verify custom kernel output matches CPU reference.""" + eps = 1e-6 + rows, dim = shape + + x1 = torch.randn(rows, dim, dtype=dtype, device="npu") + x2 = torch.randn(rows, dim, dtype=dtype, device="npu") + weight = torch.randn(dim, dtype=dtype, device="npu") + + # Run custom kernel. + result = torch.ops.npu.add_rms_norm(x1, x2, weight, eps) + y_npu = result[0] + x_out_npu = result[1] + + # Run CPU reference. 
+ y_ref, x_out_ref = _ref_add_rms_norm(x1.cpu(), x2.cpu(), weight.cpu(), eps) + + # Check x_out = x1 + x2. + rtol_xout = 1e-3 if dtype == torch.float16 else 1e-5 + atol_xout = 1e-3 if dtype == torch.float16 else 1e-5 + assert torch.allclose(x_out_npu.cpu(), x_out_ref, rtol=rtol_xout, atol=atol_xout), ( + f"x_out mismatch: max_diff={(x_out_npu.cpu() - x_out_ref).abs().max().item()}" + ) + + # Check `y = rms_norm(x_out) * weight`. + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-5 + assert torch.allclose(y_npu.cpu(), y_ref, rtol=rtol, atol=atol), ( + f"y mismatch: max_diff={(y_npu.cpu() - y_ref).abs().max().item()}" + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_add_rms_norm_3d(dtype): + """Verify 3D input (batch, nhead, dim) works correctly.""" + eps = 1e-6 + batch, nhead, dim = 4, 8, 128 + + x1 = torch.randn(batch, nhead, dim, dtype=dtype, device="npu") + x2 = torch.randn(batch, nhead, dim, dtype=dtype, device="npu") + weight = torch.randn(dim, dtype=dtype, device="npu") + + result = torch.ops.npu.add_rms_norm(x1, x2, weight, eps) + y_npu = result[0] + x_out_npu = result[1] + + y_ref, x_out_ref = _ref_add_rms_norm(x1.cpu(), x2.cpu(), weight.cpu(), eps) + + rtol = 1e-3 if dtype == torch.float16 else 1e-5 + atol = 1e-3 if dtype == torch.float16 else 1e-5 + assert torch.allclose(x_out_npu.cpu(), x_out_ref, rtol=rtol, atol=atol) + assert torch.allclose(y_npu.cpu(), y_ref, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/src/ascend/custom_kernel/tests/test_rms_norm.py b/src/ascend/custom_kernel/tests/test_rms_norm.py new file mode 100644 index 00000000..f09a6d4f --- /dev/null +++ b/src/ascend/custom_kernel/tests/test_rms_norm.py @@ -0,0 +1,127 @@ +"""Functional and precision tests for the RMSNorm AscendC kernel.""" + +import pytest +import torch +import torch_npu # noqa: F401 Registers NPU device. 
+import ascend_kernel # noqa: F401 Loads `libascend_kernel.so` into `torch.ops.npu`. + + +def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + """CPU reference implementation in float32.""" + x_fp32 = x.float() + variance = x_fp32.pow(2).mean(dim=-1, keepdim=True) + hidden = x_fp32 * torch.rsqrt(variance + eps) + + return (hidden * weight.float()).to(x.dtype) + + +DTYPES = [torch.float16, torch.float32] + +TEST_SHAPES = [ + (32, 128), + (64, 512), + (128, 1024), + (32, 4096), + (128, 4096), + (32, 8192), + (4, 32, 128), + (8, 64, 512), + (4, 128, 4096), +] + +GENERAL_SHAPES = [ + (1, 128), + (1, 4096), + (2, 256), + (1, 1, 128), + (3, 512), + (7, 1024), + (512, 768), + (1024, 1024), + (256, 4096), + (64, 8192), + (8, 512, 4096), +] + + +def _tolerance(dtype): + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-3) + + return dict(rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1]) +@pytest.mark.parametrize( + "shape", TEST_SHAPES + GENERAL_SHAPES, ids=lambda s: "x".join(map(str, s)) +) +def test_rms_norm_shapes(shape, dtype): + eps = 1e-6 + hidden_dim = shape[-1] + x = torch.randn(shape, dtype=dtype) + w = torch.randn(hidden_dim, dtype=dtype) + ref = rms_norm_ref(x, w, eps) + out = torch.ops.npu.rms_norm(x.npu(), w.npu(), eps) + tol = _tolerance(dtype) + assert torch.allclose(out.cpu(), ref, **tol), ( + f"shape={shape} dtype={dtype} " + f"max_abs_err={torch.max(torch.abs(out.cpu() - ref)).item():.6e}" + ) + + +@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1]) +@pytest.mark.parametrize( + "case", + [ + ("eps_small", (32, 512), {"eps": 1e-12}), + ("eps_large", (32, 512), {"eps": 1e-2}), + ("zeros", (16, 1024), {"input_fill": 0.0}), + ("ones", (16, 1024), {"input_fill": 1.0}), + ("large_vals", (16, 1024), {"input_scale": 100.0}), + ("small_vals", (16, 1024), {"input_scale": 1e-4}), + ], + ids=lambda c: c[0], +) +def test_rms_norm_boundary(case, 
dtype): + name, shape, opts = case + eps = opts.get("eps", 1e-6) + hidden_dim = shape[-1] + fill = opts.get("input_fill", None) + scale = opts.get("input_scale", 1.0) + + if fill is not None: + x = torch.full(shape, fill, dtype=dtype) + else: + x = torch.randn(shape, dtype=dtype) * scale + + w = torch.randn(hidden_dim, dtype=dtype) + ref = rms_norm_ref(x, w, eps) + out = torch.ops.npu.rms_norm(x.npu(), w.npu(), eps) + tol = _tolerance(dtype) + assert torch.allclose(out.cpu(), ref, **tol), ( + f"case={name} dtype={dtype} " + f"max_abs_err={torch.max(torch.abs(out.cpu() - ref)).item():.6e}" + ) + + +if __name__ == "__main__": + print("Running quick functional test...") + x = torch.randn(4, 128, dtype=torch.float16) + w = torch.randn(128, dtype=torch.float16) + ref = rms_norm_ref(x, w, 1e-6) + out = torch.ops.npu.rms_norm(x.npu(), w.npu(), 1e-6) + max_err = torch.max(torch.abs(out.cpu() - ref)).item() + print( + f" fp16 (4,128): max_abs_err = {max_err:.6e} ... {'PASS' if max_err < 1e-3 else 'FAIL'}" + ) + + x = torch.randn(4, 128, dtype=torch.float32) + w = torch.randn(128, dtype=torch.float32) + ref = rms_norm_ref(x, w, 1e-6) + out = torch.ops.npu.rms_norm(x.npu(), w.npu(), 1e-6) + max_err = torch.max(torch.abs(out.cpu() - ref)).item() + print( + f" fp32 (4,128): max_abs_err = {max_err:.6e} ... 
{'PASS' if max_err < 1e-5 else 'FAIL'}" + ) + + print("Quick test done.") diff --git a/src/ascend/data_type_.h b/src/ascend/data_type_.h index 9026f515..08b1541b 100644 --- a/src/ascend/data_type_.h +++ b/src/ascend/data_type_.h @@ -9,8 +9,14 @@ namespace infini::ops::ascend { -inline aclDataType ToAclDtype(DataType dt) { +inline aclDataType toAclDtype(DataType dt) { switch (dt) { + case DataType::kFloat16: + return ACL_FLOAT16; + case DataType::kBFloat16: + return ACL_BF16; + case DataType::kFloat32: + return ACL_FLOAT; case DataType::kInt8: return ACL_INT8; case DataType::kInt16: @@ -27,20 +33,14 @@ inline aclDataType ToAclDtype(DataType dt) { return ACL_UINT32; case DataType::kUInt64: return ACL_UINT64; - case DataType::kFloat16: - return ACL_FLOAT16; - case DataType::kBFloat16: - return ACL_BF16; - case DataType::kFloat32: - return ACL_FLOAT; default: - assert(false && "Unsupported dtype for Ascend backend."); + assert(false && "unsupported dtype for Ascend backend"); return ACL_DT_UNDEFINED; } } -// Returns true for integer (signed or unsigned) `DataType` values. -inline bool IsIntegerDtype(DataType dt) { +// Returns true for integer (signed or unsigned) DataType values. +inline bool isIntegerDtype(DataType dt) { switch (dt) { case DataType::kInt8: case DataType::kInt16: diff --git a/src/ascend/device_.h b/src/ascend/device_.h index 1b246ad3..b4ec934d 100644 --- a/src/ascend/device_.h +++ b/src/ascend/device_.h @@ -1,7 +1,10 @@ #ifndef INFINI_OPS_ASCEND_DEVICE__H_ #define INFINI_OPS_ASCEND_DEVICE__H_ -#include "device.h" +// NOTE: Cannot use `#include "device.h"` here — GCC resolves quoted includes +// relative to the current file first, and `src/ascend/` used to contain a +// `device.h`. Use `data_type.h` which transitively pulls in `src/device.h`. 
+#include "data_type.h" namespace infini::ops { diff --git a/src/ascend/flash_attention/kernel.h b/src/ascend/flash_attention/kernel.h new file mode 100644 index 00000000..350f8b4c --- /dev/null +++ b/src/ascend/flash_attention/kernel.h @@ -0,0 +1,360 @@ +#ifndef INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ +#define INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_fused_infer_attention_score_v4.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/flash_attention.h" +#include "operator.h" + +namespace infini::ops { + +namespace detail { + +// Extract cu_seqlens differences to a host aclIntArray. +// cu_seqlens = [0, s1, s1+s2, ...] -> per_seq_lens = [s1, s2, ...]. +// Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths). +// +// When cu_seqlens is a CPU tensor (device type kCpu), the data pointer is +// already on the host and can be read directly — no D2H sync needed. +inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto n = cu_seqlens.numel(); + + const int64_t* cu_host_ptr = nullptr; + std::vector cu_host_buf; + + if (cu_seqlens.device().type() == Device::Type::kCpu) { + cu_host_ptr = static_cast(cu_seqlens.data()); + } else { + cu_host_buf.resize(n); + aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), cu_seqlens.data(), + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + cu_host_ptr = cu_host_buf.data(); + } + + std::vector lengths(n - 1); + for (size_t i = 0; i < lengths.size(); ++i) { + lengths[i] = cu_host_ptr[i + 1] - cu_host_ptr[i]; + } + + return aclCreateIntArray(lengths.data(), + static_cast(lengths.size())); +} + +// Extract cumulative end positions from cu_seqlens to a host aclIntArray. +// cu_seqlens = [0, s1, s1+s2, ...] -> cum_lens = [s1, s1+s2, ...]. 
+// FIA V4 TND varlen uses cumulative end positions, matching the vllm-ascend +// convention for npu_fused_infer_attention_score actual_seq_lengths. +// +// When cu_seqlens is a CPU tensor, reads directly from host memory. +inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto n = cu_seqlens.numel(); + + const int64_t* cu_host_ptr = nullptr; + std::vector cu_host_buf; + + if (cu_seqlens.device().type() == Device::Type::kCpu) { + cu_host_ptr = static_cast(cu_seqlens.data()); + } else { + cu_host_buf.resize(n); + aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), cu_seqlens.data(), + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + cu_host_ptr = cu_host_buf.data(); + } + + // Skip the leading 0; return [s1, s1+s2, ...]. + return aclCreateIntArray(cu_host_ptr + 1, static_cast(n - 1)); +} + +// Allocate a 2048x2048 lower-triangular UINT8 causal mask on device. +// Required for `sparseMode` >= 2. +inline aclTensor* makeCausalMask(void** mask_buf, aclrtStream stream) { + constexpr int64_t kMaskDim = 2048; + const int64_t mask_elems = kMaskDim * kMaskDim; + const size_t mask_bytes = static_cast(mask_elems); // uint8_t + + aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + std::vector host_mask(mask_elems); + for (int64_t r = 0; r < kMaskDim; ++r) { + for (int64_t c = 0; c < kMaskDim; ++c) { + // 1 = masked out (upper triangle); 0 = attend (lower triangle). + host_mask[r * kMaskDim + c] = (c > r) ? 
1 : 0; + } + } + aclrtMemcpyAsync(*mask_buf, mask_bytes, host_mask.data(), mask_bytes, + ACL_MEMCPY_HOST_TO_DEVICE, stream); + aclrtSynchronizeStream(stream); + + std::vector mask_shape = {kMaskDim, kMaskDim}; + std::vector mask_strides = {kMaskDim, 1}; + std::vector mask_storage = {mask_elems}; + return aclCreateTensor(mask_shape.data(), 2, ACL_UINT8, mask_strides.data(), + 0, ACL_FORMAT_ND, mask_storage.data(), 1, *mask_buf); +} + +} // namespace detail + +template <> +class Operator : public FlashAttention { + public: + Operator(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, bool causal, + int64_t window_left, int64_t window_right, int64_t block_size, + Tensor output) + : FlashAttention(query, key, value, cu_seqlens_q, cu_seqlens_kv, + block_table, num_heads, num_kv_heads, head_size, scale, + causal, window_left, window_right, block_size, output) { + paged_ = block_table.has_value() && block_size > 0; + aclDataType acl_dt = ascend::toAclDtype(query.dtype()); + + if (!paged_) { + // Prefill: cache Q and output (TND layout). + prefill_q_cache_ = ascend::AclTensorCache(query); + prefill_out_cache_ = ascend::AclTensorCache(output); + + // Pre-compute causal mask once (sparse_mode >= 2). + if (causal) { + int64_t sm = (window_left >= 0) ? 4 : 3; + if (sm >= 2) { + causal_mask_ = detail::makeCausalMask(&causal_mask_buf_, nullptr); + } + } + } else { + // Decode: cache Q/output (BNSD), block_table. + const int64_t N = query.size(1); + const int64_t D = query.size(2); + const int64_t B = query.size(0); + + decode_q_cache_ = ascend::AclTensorCache({B, N, 1, D}, acl_dt, + const_cast(query.data())); + decode_out_cache_ = + ascend::AclTensorCache({B, N, 1, D}, acl_dt, output.data()); + block_table_cache_ = ascend::AclTensorCache(block_table.value()); + + // Pre-compute KV reshape metadata. 
+ const int64_t nb = key.size(0); + const int64_t bsz = key.size(1); + const int64_t NkvD = key.size(2) * key.size(3); + kv_shape_ = {nb, bsz, NkvD}; + kv_strides_ = {bsz * NkvD, NkvD, 1}; + kv_storage_shape_ = {nb * bsz * NkvD}; + kv_acl_dt_ = acl_dt; + } + } + + ~Operator() { + if (causal_mask_) aclDestroyTensor(causal_mask_); + if (causal_mask_buf_) aclrtFree(causal_mask_buf_); + } + + void operator()(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, int64_t window_right, + int64_t block_size, Tensor output) const override { + auto stream = static_cast(stream_); + const bool paged = paged_; + + int64_t sparse_mode; + int64_t pre_tokens = 2147483647; + int64_t next_tokens = 2147483647; + if (causal) { + if (window_left >= 0) { + sparse_mode = 4; + pre_tokens = window_left; + next_tokens = 0; + } else { + sparse_mode = 3; + next_tokens = 0; + } + } else { + sparse_mode = 0; + if (window_left >= 0) pre_tokens = window_left; + if (window_right >= 0) next_tokens = window_right; + } + + if (!paged) { + // --- Prefill --- + int64_t T = query.size(0); + + // cumSeqLengths / extractSeqLengths automatically skip D2H when + // cu_seqlens is a CPU tensor (see detail:: helpers above). + aclIntArray* seq_q = + cu_seqlens_q.has_value() + ? detail::cumSeqLengths(cu_seqlens_q.value(), stream) + : aclCreateIntArray(&T, 1); + aclIntArray* seq_kv = + cu_seqlens_kv.has_value() + ? detail::cumSeqLengths(cu_seqlens_kv.value(), stream) + : aclCreateIntArray(&T, 1); + + aclTensor* t_q = prefill_q_cache_.get(const_cast(query.data())); + // K/V descriptors go into TensorList which takes ownership — must be + // per-call (cannot cache). 
+ aclTensor* t_k = ascend::buildAclTensor(key); + aclTensor* t_v = ascend::buildAclTensor(value); + aclTensor* t_out = prefill_out_cache_.get(output.data()); + + const aclTensor* k_arr[] = {t_k}; + const aclTensor* v_arr[] = {t_v}; + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_q, key_list, val_list, + nullptr, // pseShift + causal_mask_, // attenMask (pre-computed, or nullptr) + seq_q, // actualSeqLengths + seq_kv, // actualSeqLengthsKv + nullptr, nullptr, nullptr, nullptr, + nullptr, // deqScale1..quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // queryPaddingSize, kvPaddingSize + nullptr, nullptr, nullptr, + nullptr, // key/value antiquant scale/offset + nullptr, nullptr, + nullptr, // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen + nullptr, nullptr, + nullptr, // queryRope, keyRope, keyRopeAntiquantScale + nullptr, nullptr, // dequantScaleQuery, learnableSink + num_heads, scale, pre_tokens, next_tokens, const_cast("TND"), + num_kv_heads, sparse_mode, + 0, // innerPrecise + 0, // blockSize (unused for prefill) + 0, false, // antiquantMode, softmaxLseFlag + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode + t_out, nullptr, &ws_needed, &executor); + assert( + gws == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (prefill)"); + + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclError ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, + executor, stream); + assert(ret == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4 failed (prefill)"); + + // t_q and t_out are owned by caches — do NOT destroy. + // t_k and t_v are owned by TensorLists. 
+ aclDestroyTensorList(key_list); + aclDestroyTensorList(val_list); + aclDestroyIntArray(seq_q); + aclDestroyIntArray(seq_kv); + return; + } + + // --- Paged decode --- + assert(cu_seqlens_kv.has_value() && + "`FlashAttention` paged decode requires `cu_seqlens_kv`"); + + aclTensor* t_query = decode_q_cache_.get(const_cast(query.data())); + aclTensor* t_output = decode_out_cache_.get(output.data()); + + // K/V descriptors go into TensorList which takes ownership — must be + // per-call. Use pre-computed metadata to avoid heap allocs. + aclTensor* t_key = aclCreateTensor( + kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dt_, + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(key.data())); + aclTensor* t_value = aclCreateTensor( + kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dt_, + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(value.data())); + + // extractSeqLengths skips D2H when cu_seqlens_kv is a CPU tensor. 
+ aclIntArray* seq_kv = + detail::extractSeqLengths(cu_seqlens_kv.value(), stream); + aclTensor* t_block_table = + block_table_cache_.get(const_cast(block_table.value().data())); + + const aclTensor* k_arr[] = {t_key}; + const aclTensor* v_arr[] = {t_value}; + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_query, key_list, val_list, + nullptr, // pseShift + nullptr, // attenMask (sparseMode ignored for Q_S=1) + nullptr, // actualSeqLengths (ignored for Q_S=1) + seq_kv, // actualSeqLengthsKv (mandatory for paged) + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + t_block_table, // blockTable + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, num_heads, scale, + static_cast(2147483647), static_cast(2147483647), + const_cast("BNSD"), num_kv_heads, + 0, // sparseMode=0 (ignored for Q_S=1) + 0, // innerPrecise + block_size, // blockSize + 0, false, // antiquantMode, softmaxLseFlag + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode + t_output, nullptr, &ws_needed, &executor); + assert(gws == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (decode)"); + + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclError ret = + aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, executor, stream); + assert(ret == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4 failed (decode)"); + + // t_query, t_output, t_block_table owned by caches — do NOT destroy. + // t_key, t_value owned by TensorLists. 
+ aclDestroyTensorList(key_list); + aclDestroyTensorList(val_list); + aclDestroyIntArray(seq_kv); + } + + private: + bool paged_ = false; + + mutable ascend::AclTensorCache prefill_q_cache_; + + mutable ascend::AclTensorCache prefill_out_cache_; + + mutable ascend::AclTensorCache decode_q_cache_; + + mutable ascend::AclTensorCache decode_out_cache_; + + mutable ascend::AclTensorCache block_table_cache_; + + aclTensor* causal_mask_ = nullptr; + + void* causal_mask_buf_ = nullptr; + + std::vector kv_shape_; + + std::vector kv_strides_; + + std::vector kv_storage_shape_; + + aclDataType kv_acl_dt_ = ACL_DT_UNDEFINED; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/gemm/kernel.h b/src/ascend/gemm/kernel.h index 5f32e272..a59d6249 100644 --- a/src/ascend/gemm/kernel.h +++ b/src/ascend/gemm/kernel.h @@ -21,12 +21,17 @@ class Operator : public Gemm { : Gemm(a, b, alpha, beta, trans_a, trans_b, c), batched_{batch_count_ > 1}, alpha_val_{alpha.value_or(1.0f)}, - beta_val_{beta.value_or(1.0f)} { + beta_val_{beta.value_or(1.0f)}, + self_cache_(c), + a_cache_(a, trans_a_), + b_cache_(b, trans_b_), + out_cache_(c) { alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT); beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT); } ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); aclDestroyScalar(alpha_scalar_); aclDestroyScalar(beta_scalar_); } @@ -36,35 +41,36 @@ class Operator : public Gemm { std::optional trans_b, Tensor c) const override { auto stream = static_cast(stream_); - auto t_self = ascend::buildAclTensor(c); - auto t_a = ascend::buildAclTensor(a, trans_a_); - auto t_b = ascend::buildAclTensor(b, trans_b_); - auto t_out = ascend::buildAclTensor(c); - - uint64_t ws_needed = 0; - aclOpExecutor* executor = nullptr; - - if (batched_) { - aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, - alpha_scalar_, t_out, 0, &ws_needed, - &executor); + auto t_self = self_cache_.get(c.data()); + auto t_a = 
a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(c.data()); + + if (!executor_) { + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } else { + aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } + aclSetAclOpExecutorRepeatable(executor_); } else { - aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_, - t_out, 0, &ws_needed, &executor); + aclSetInputTensorAddr(executor_, 0, t_self, c.data()); + aclSetInputTensorAddr(executor_, 1, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 2, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, c.data()); } - auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); if (batched_) { - aclnnBaddbmm(arena.buf, ws_needed, executor, stream); + aclnnBaddbmm(arena.buf, ws_size_, executor_, stream); } else { - aclnnAddmm(arena.buf, ws_needed, executor, stream); + aclnnAddmm(arena.buf, ws_size_, executor_, stream); } - - aclDestroyTensor(t_self); - aclDestroyTensor(t_a); - aclDestroyTensor(t_b); - aclDestroyTensor(t_out); } private: @@ -77,6 +83,18 @@ class Operator : public Gemm { aclScalar* alpha_scalar_ = nullptr; aclScalar* beta_scalar_ = nullptr; + + mutable ascend::AclTensorCache self_cache_; + + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/linear/kernel.h b/src/ascend/linear/kernel.h new file mode 100644 index 00000000..d8233d84 --- /dev/null +++ b/src/ascend/linear/kernel.h @@ -0,0 +1,118 @@ +#ifndef INFINI_OPS_ASCEND_LINEAR_KERNEL_H_ +#define 
INFINI_OPS_ASCEND_LINEAR_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_addmm.h" +#include "aclnnop/aclnn_baddbmm.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/linear.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Linear { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear(a, b, bias, trans_a, trans_b, out), + batched_{out.ndim() > 2}, + a_cache_(a, trans_a), + b_cache_(b, trans_b), + out_cache_(out) { + if (has_bias_) { + bias_cache_ = ascend::AclTensorCache(*bias); + alpha_scalar_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT); + beta_scalar_ = aclCreateScalar(&beta_storage_, ACL_FLOAT); + } + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (alpha_scalar_) aclDestroyScalar(alpha_scalar_); + if (beta_scalar_) aclDestroyScalar(beta_scalar_); + } + + void operator()(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(out.data()); + + if (has_bias_) { + auto t_bias = bias_cache_.get(const_cast(bias->data())); + + if (!executor_) { + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } else { + aclnnAddmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_bias, + const_cast(bias->data())); + aclSetInputTensorAddr(executor_, 1, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 2, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, 
t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + + if (batched_) { + aclnnBaddbmm(arena.buf, ws_size_, executor_, stream); + } else { + aclnnAddmm(arena.buf, ws_size_, executor_, stream); + } + } else { + if (!executor_) { + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 1, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnMatmul(arena.buf, ws_size_, executor_, stream); + } + } + + private: + bool batched_; + + mutable ascend::AclTensorCache bias_cache_; + + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + float alpha_storage_ = 1.0f; + + float beta_storage_ = 1.0f; + + aclScalar* alpha_scalar_ = nullptr; + + aclScalar* beta_scalar_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/matmul/kernel.h b/src/ascend/matmul/kernel.h new file mode 100644 index 00000000..2d98c23f --- /dev/null +++ b/src/ascend/matmul/kernel.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/matmul.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Matmul { + public: + Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : Matmul(a, b, c, trans_a, trans_b), + a_cache_(a, trans_a), + b_cache_(b, trans_b), + out_cache_(c) {} + 
+ ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + } + + void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const override { + auto stream = static_cast(stream_); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(c.data()); + + if (!executor_) { + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 1, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, c.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnMatmul(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/mul/kernel.h b/src/ascend/mul/kernel.h new file mode 100644 index 00000000..38a09869 --- /dev/null +++ b/src/ascend/mul/kernel.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_ASCEND_MUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/mul.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Mul { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul(input, other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) {} + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto 
stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnMulGetWorkspaceSize(t_in, t_oth, t_out, &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnMul(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/paged_attention/kernel_atb.h b/src/ascend/paged_attention/kernel_atb.h new file mode 100644 index 00000000..8e08e268 --- /dev/null +++ b/src/ascend/paged_attention/kernel_atb.h @@ -0,0 +1,246 @@ +#ifndef INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ +#define INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "ascend/atb_common_.h" +#include "ascend/paged_attention/registry.h" +#include "ascend/workspace_pool_.h" +#include "atb/context.h" +#include "atb/infer_op_params.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "base/paged_attention.h" +#include "operator.h" + +namespace infini::ops { + +// ATB-based paged decode attention (implementation index 0). +// +// Wraps ATB `PagedAttentionParam` with the default `inputLayout` +// (`TYPE_BSND`). 
For decode (single token per request) the S +// dimension is implicitly 1, so query and output use 3D shape +// [batch, num_heads, head_size] matching vLLM's convention. +// +// ATB internally constructs `aclIntArray*` from the `hostData` field +// of `block_table` and `context_lens` tensors. The operator performs +// synchronous D2H copies for these two small tensors in each call. +// All other tensors are device-only. +// +// ATB `VariantPack` layout (BSND with S=1): +// inTensors[0] = query [B, N, D] +// inTensors[1] = key_cache [num_blocks, block_size, Nkv, D] +// inTensors[2] = value_cache [num_blocks, block_size, Nkv, D] +// inTensors[3] = block_table [B, max_num_blocks] (device + host) +// inTensors[4] = context_lens [B] (int32) (device + host) +// outTensors[0] = output [B, N, D] +template <> +class Operator + : public PagedAttention { + public: + Operator(const Tensor query, const Tensor key_cache, const Tensor value_cache, + const Tensor seq_lens, const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output) + : PagedAttention(query, key_cache, value_cache, seq_lens, block_table, + num_heads, num_kv_heads, head_size, scale, block_size, + output) { + int64_t B = static_cast(batch_size_); + int64_t N = num_heads_; + int64_t Nkv = num_kv_heads_; + int64_t D = head_size_; + + // Query/output shapes: 3D [B, N, D] (BSND with S=1 for decode). + query_tnd_shape_ = {B, N, D}; + output_tnd_shape_ = {B, N, D}; + + // KV cache shapes. + int64_t num_blocks = static_cast(key_cache.size(0)); + int64_t bs = static_cast(key_cache.size(1)); + kv_cache_shape_ = {num_blocks, bs, Nkv, D}; + + // Block table and context lens shapes. + int64_t max_blocks = static_cast(block_table.size(1)); + block_table_shape_ = {B, max_blocks}; + context_lens_shape_ = {B}; + + // ACL data types. 
+ acl_dt_ = ascend::toAclDtype(query.dtype()); + bt_dt_ = ascend::toAclDtype(block_table.dtype()); + sl_dt_ = ascend::toAclDtype(seq_lens.dtype()); + + // Element sizes for `dataSize` computation. + elem_size_ = query.element_size(); + bt_elem_size_ = block_table.element_size(); + sl_elem_size_ = seq_lens.element_size(); + + // Pre-allocate pinned host buffers for D2H copies. + // ATB PA reads `hostData` from block_table and context_lens to + // construct internal `aclIntArray*` parameters. + bt_host_bytes_ = static_cast(B * max_blocks) * bt_elem_size_; + sl_host_bytes_ = static_cast(B) * sl_elem_size_; + bt_host_ = std::malloc(bt_host_bytes_); + sl_host_ = std::malloc(sl_host_bytes_); + assert(bt_host_ && sl_host_ && "host buffer allocation failed"); + + // Create the ATB operation (reused across calls). + atb::infer::PagedAttentionParam param; + param.headNum = static_cast(N); + param.kvHeadNum = static_cast(Nkv); + param.qkScale = static_cast(scale_); + + atb::Status s = atb::CreateOperation(param, &op_); + assert(s == atb::NO_ERROR && "atb::CreateOperation(PagedAttention) failed"); + } + + ~Operator() { + if (op_) { + atb::DestroyOperation(op_); + } + + std::free(bt_host_); + std::free(sl_host_); + } + + Operator(const Operator&) = delete; + + Operator& operator=(const Operator&) = delete; + + void operator()(const Tensor query, const Tensor key_cache, + const Tensor value_cache, const Tensor seq_lens, + const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output) const override { + auto stream = static_cast(stream_); + atb::Context* ctx = ascend::getAtbContext(stream); + + // D2H copy for block_table and context_lens. + // ATB reads `hostData` to construct internal `aclIntArray*`. 
+ aclrtMemcpy(bt_host_, bt_host_bytes_, block_table.data(), bt_host_bytes_, + ACL_MEMCPY_DEVICE_TO_HOST); + aclrtMemcpy(sl_host_, sl_host_bytes_, seq_lens.data(), sl_host_bytes_, + ACL_MEMCPY_DEVICE_TO_HOST); + + atb::VariantPack vp = buildVariantPack( + const_cast(query.data()), const_cast(key_cache.data()), + const_cast(value_cache.data()), + const_cast(block_table.data()), + const_cast(seq_lens.data()), output.data()); + + // Setup computes workspace requirements and binds tensor descriptors. + uint64_t ws_size = 0; + atb::Status s = op_->Setup(vp, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Setup(PagedAttention) failed"); + + // Allocate workspace via the shared pool. + uint8_t* ws_ptr = nullptr; + + if (ws_size > 0) { + auto& arena = ascend::workspacePool().ensure(stream, ws_size); + ws_ptr = static_cast(arena.buf); + } + + s = op_->Execute(vp, ws_ptr, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Execute(PagedAttention) failed"); + } + + private: + // Build the ATB `VariantPack`. + // + // Query and output are 3D [B, N, D] (BSND with S=1 for decode). + // Block table and context lens carry both `deviceData` and + // `hostData` because ATB reads the host copy to build internal + // `aclIntArray*` parameters. + atb::VariantPack buildVariantPack(void* query_data, void* key_cache_data, + void* value_cache_data, + void* block_table_data, void* seq_lens_data, + void* output_data) const { + int64_t B = query_tnd_shape_[0]; + int64_t N = query_tnd_shape_[1]; + int64_t D = query_tnd_shape_[2]; + + // Query [B, N, D] — 3D (BSND with S=1). + uint64_t q_bytes = static_cast(B * N * D) * elem_size_; + atb::Tensor t_query = + ascend::toAtbTensor(query_tnd_shape_, acl_dt_, query_data, q_bytes); + + // KV caches [num_blocks, block_size, Nkv, D]. 
+ int64_t nb = kv_cache_shape_[0]; + int64_t bs = kv_cache_shape_[1]; + int64_t Nkv = kv_cache_shape_[2]; + uint64_t kv_bytes = static_cast(nb * bs * Nkv * D) * elem_size_; + atb::Tensor t_key_cache = + ascend::toAtbTensor(kv_cache_shape_, acl_dt_, key_cache_data, kv_bytes); + atb::Tensor t_value_cache = ascend::toAtbTensor(kv_cache_shape_, acl_dt_, + value_cache_data, kv_bytes); + + // Block table [B, max_blocks] — with `hostData` for `aclIntArray*`. + atb::Tensor t_block_table = ascend::toAtbTensor( + block_table_shape_, bt_dt_, block_table_data, bt_host_bytes_); + t_block_table.hostData = bt_host_; + + // Context lens [B] — with `hostData` for `aclIntArray*`. + atb::Tensor t_context_lens = ascend::toAtbTensor( + context_lens_shape_, sl_dt_, seq_lens_data, sl_host_bytes_); + t_context_lens.hostData = sl_host_; + + // Output [B, N, D] — 3D (BSND with S=1). + atb::Tensor t_output = + ascend::toAtbTensor(output_tnd_shape_, acl_dt_, output_data, q_bytes); + + atb::VariantPack vp; + vp.inTensors = {t_query, t_key_cache, t_value_cache, t_block_table, + t_context_lens}; + vp.outTensors = {t_output}; + + return vp; + } + + atb::Operation* op_ = nullptr; + + std::vector query_tnd_shape_; + + std::vector output_tnd_shape_; + + std::vector kv_cache_shape_; + + std::vector block_table_shape_; + + std::vector context_lens_shape_; + + aclDataType acl_dt_ = ACL_DT_UNDEFINED; + + aclDataType bt_dt_ = ACL_DT_UNDEFINED; + + aclDataType sl_dt_ = ACL_DT_UNDEFINED; + + uint64_t elem_size_ = 0; + + uint64_t bt_elem_size_ = 0; + + uint64_t sl_elem_size_ = 0; + + // Host-side buffers for ATB's internal `aclIntArray*` construction. 
+ void* bt_host_ = nullptr; + + void* sl_host_ = nullptr; + + uint64_t bt_host_bytes_ = 0; + + uint64_t sl_host_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_PAGED_ATTENTION_KERNEL_ATB_H_ diff --git a/src/ascend/paged_attention/registry.h b/src/ascend/paged_attention/registry.h new file mode 100644 index 00000000..53c2c836 --- /dev/null +++ b/src/ascend/paged_attention/registry.h @@ -0,0 +1,24 @@ +#ifndef INFINI_OPS_ASCEND_PAGED_ATTENTION_REGISTRY_H_ +#define INFINI_OPS_ASCEND_PAGED_ATTENTION_REGISTRY_H_ + +#include "base/paged_attention.h" + +namespace infini::ops { + +// ATB `PagedAttentionParam` is the only implementation. Unlike +// `FlashAttention`, paged attention exists specifically to provide a +// graph-safe decode path (all parameters are tensor-based, no +// `aclIntArray*`). When ATB is unavailable, fall back to +// `FlashAttention` for decode at the Python layer. +template <> +struct ActiveImplementationsImpl { +#ifdef INFINI_HAS_ATB + using type = List<0>; +#else + using type = List<>; +#endif +}; + +} // namespace infini::ops + +#endif // INFINI_OPS_ASCEND_PAGED_ATTENTION_REGISTRY_H_ diff --git a/src/ascend/reshape_and_cache/kernel.h b/src/ascend/reshape_and_cache/kernel.h new file mode 100644 index 00000000..bc4f1456 --- /dev/null +++ b/src/ascend/reshape_and_cache/kernel.h @@ -0,0 +1,110 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_index_copy.h" +#include "ascend/common.h" +#include "ascend/reshape_and_cache/registry.h" +#include "ascend/workspace_pool_.h" +#include "base/reshape_and_cache.h" +#include "operator.h" + +namespace infini::ops { + +// Device-side scatter via `aclnnInplaceIndexCopy`. 
+// +// The previous implementation copied slot_mapping D2H (aclrtSynchronizeStream), +// then issued per-token D2D memcpy in a host loop. For batch=256, this meant +// ~100 us sync + ~500 us host loop overhead. `aclnnInplaceIndexCopy` performs +// the scatter entirely on the NPU with two ACLNN calls (one for K, one for V), +// eliminating all D2H synchronisation and host-side loops. +// +// Requirement: slot_mapping must contain only non-negative values. Padding +// tokens (slot < 0) must be filtered by the caller before invoking this +// operator. +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out), + key_cache_(key), + value_cache_(value), + slot_cache_(slot_mapping) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t total_slots = num_blocks * bs; + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + + aclDataType acl_dt = ascend::toAclDtype(key.dtype()); + + // Flattened K cache view: [total_slots, num_kv_heads, head_size]. + // K cache is kv_cache_out[0], starting at offset 0. + kv_k_cache_ = ascend::AclTensorCache({total_slots, nkv, hs}, acl_dt, + kv_cache_out.data()); + + // V cache is kv_cache_out[1], offset by stride(0) elements. 
+ v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + kv_v_cache_ = ascend::AclTensorCache( + {total_slots, nkv, hs}, acl_dt, + static_cast(kv_cache_out.data()) + v_offset_bytes_); + } + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + void* kv_k_data = kv_cache_out.data(); + void* kv_v_data = static_cast(kv_cache_out.data()) + v_offset_bytes_; + + auto t_kv_k = kv_k_cache_.get(kv_k_data); + auto t_kv_v = kv_v_cache_.get(kv_v_data); + auto t_key = key_cache_.get(const_cast(key.data())); + auto t_value = value_cache_.get(const_cast(value.data())); + auto t_slot = slot_cache_.get(const_cast(slot_mapping.data())); + + // K cache scatter: kv_k[slot_mapping[i]] = key[i] along dim 0. + // Executor caching is not used here because `aclnnInplaceIndexCopy` is an + // inplace operation where self is both input and output; the executor + // reuse via aclSetInputTensorAddr does not update the output reference. + uint64_t k_ws = 0; + aclOpExecutor* k_exec = nullptr; + aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_k, 0, t_slot, t_key, &k_ws, + &k_exec); + auto& k_arena = ascend::workspacePool().ensure(stream, k_ws); + aclnnInplaceIndexCopy(k_arena.buf, k_ws, k_exec, stream); + + // V cache scatter: kv_v[slot_mapping[i]] = value[i] along dim 0. 
+ uint64_t v_ws = 0; + aclOpExecutor* v_exec = nullptr; + aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_v, 0, t_slot, t_value, &v_ws, + &v_exec); + auto& v_arena = ascend::workspacePool().ensure(stream, v_ws); + aclnnInplaceIndexCopy(v_arena.buf, v_ws, v_exec, stream); + } + + private: + mutable ascend::AclTensorCache kv_k_cache_; + + mutable ascend::AclTensorCache kv_v_cache_; + + mutable ascend::AclTensorCache key_cache_; + + mutable ascend::AclTensorCache value_cache_; + + mutable ascend::AclTensorCache slot_cache_; + + size_t v_offset_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/reshape_and_cache/kernel_atb.h b/src/ascend/reshape_and_cache/kernel_atb.h new file mode 100644 index 00000000..13abfc44 --- /dev/null +++ b/src/ascend/reshape_and_cache/kernel_atb.h @@ -0,0 +1,233 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include + +#include "acl/acl.h" +#include "ascend/atb_common_.h" +#include "ascend/common.h" +#include "ascend/reshape_and_cache/registry.h" +#include "ascend/workspace_pool_.h" +#include "atb/context.h" +#include "atb/infer_op_params.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "base/reshape_and_cache.h" +#include "operator.h" + +namespace infini::ops { + +// ATB-based KV cache scatter via `atb::infer::ReshapeAndCacheParam` +// (implementation index 2). +// +// Handles both K and V in a single fused operation. Profiled at ~9.5 us/call +// on Ascend 910B (256 tokens, fp16) — 3.7x faster than the +// `aclnnInplaceIndexCopy` path (index 0, ~35 us). +// +// The ATB operation is created once in the constructor. Setup is called +// before each `Execute` to bind the `VariantPack`. +// +// NOTE: `ReshapeAndCacheParam` requires int32 `slot_mapping`. 
When the +// caller passes int64 (the default in PyTorch / vLLM), this operator casts +// to int32 via a pre-allocated device buffer — matching the pattern used in +// the ATB rotary_embedding operator. +// +// Input layout: +// key, value : [num_tokens, num_kv_heads, head_size] +// slot_mapping: [num_tokens] (int32 or int64; int64 is cast internally) +// +// KV cache layout: +// kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size] +// Output key_cache = kv_cache[0], value_cache = kv_cache[1], each with +// shape [num_blocks, block_size, num_kv_heads, head_size]. +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + int64_t T = static_cast(num_tokens_); + + // Cache shapes for rebuilding `VariantPack` on each call. + kv_shape_ = {num_blocks, bs, nkv, hs}; + key_shape_ = {T, nkv, hs}; + slot_shape_ = {T}; + acl_dt_ = ascend::toAclDtype(key.dtype()); + + // Compute V-cache byte offset (kv_cache_out[1]). + v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + + // Element sizes for dataSize computation. + elem_size_ = key.element_size(); + + // Pre-allocate int32 device buffer for `slot_mapping`. + // `ReshapeAndCacheParam` requires int32; int64 is silently ignored + // (writes nothing). + slot32_bytes_ = static_cast(T) * sizeof(int32_t); + aclrtMalloc(&slot32_buf_, slot32_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(slot32_buf_ && "aclrtMalloc for slot32_buf_ failed"); + + slot_is_int32_ = (slot_mapping.element_size() == sizeof(int32_t)); + + // Create the ATB operation (reused across calls). 
+ atb::infer::ReshapeAndCacheParam param; + atb::Status s = atb::CreateOperation(param, &op_); + assert(s == atb::NO_ERROR && + "atb::CreateOperation(ReshapeAndCache) failed"); + } + + ~Operator() { + if (!ascend::isAclRuntimeAlive()) return; + if (op_) atb::DestroyOperation(op_); + if (slot32_buf_) aclrtFree(slot32_buf_); + } + + Operator(const Operator&) = delete; + + Operator& operator=(const Operator&) = delete; + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + // `ReshapeAndCacheParam` requires int32 `slot_mapping`. When the + // caller provides int64 (the PyTorch/vLLM default), cast to int32 via + // a pre-allocated device buffer. + void* slot32_ptr; + + if (slot_is_int32_) { + // Already int32 — pass through directly. + slot32_ptr = const_cast(slot_mapping.data()); + } else { + // int64 → int32: D2H, CPU cast, H2D. + auto T = static_cast(num_tokens_); + std::vector i64(T); + aclrtMemcpyAsync(i64.data(), T * sizeof(int64_t), slot_mapping.data(), + T * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + + std::vector i32(T); + + for (size_t i = 0; i < T; ++i) { + i32[i] = static_cast(i64[i]); + } + + aclrtMemcpyAsync(slot32_buf_, slot32_bytes_, i32.data(), slot32_bytes_, + ACL_MEMCPY_HOST_TO_DEVICE, stream); + slot32_ptr = slot32_buf_; + } + + atb::Context* ctx = ascend::getAtbContext(stream); + + atb::VariantPack vp = buildVariantPack(const_cast(key.data()), + const_cast(value.data()), + kv_cache_out.data(), slot32_ptr); + + // `Setup` binds the `VariantPack` and computes workspace requirements. + uint64_t ws_size = 0; + atb::Status s = op_->Setup(vp, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Setup(ReshapeAndCache) failed"); + + // Allocate workspace via the shared pool. 
+ uint8_t* ws_ptr = nullptr; + + if (ws_size > 0) { + auto& arena = ascend::workspacePool().ensure(stream, ws_size); + ws_ptr = static_cast(arena.buf); + } + + s = op_->Execute(vp, ws_ptr, ws_size, ctx); + assert(s == atb::NO_ERROR && + "atb::Operation::Execute(ReshapeAndCache) failed"); + } + + private: + // Build the ATB `VariantPack` for this operation. + // + // ATB `ReshapeAndCache` expects 5 inputs and 2 outputs: + // inTensors[0] = key [num_tokens, num_kv_heads, head_size] + // inTensors[1] = value [num_tokens, num_kv_heads, head_size] + // inTensors[2] = key_cache [num_blocks, block_size, num_kv_heads, + // head_size] inTensors[3] = value_cache [num_blocks, block_size, + // num_kv_heads, head_size] inTensors[4] = slot_mapping [num_tokens] (int32) + // outTensors[0] = key_cache (same buffer, in-place) + // outTensors[1] = value_cache (same buffer, in-place) + atb::VariantPack buildVariantPack(void* key_data, void* value_data, + void* kv_out_data, + void* slot32_data) const { + int64_t num_tokens = key_shape_[0]; + int64_t nkv = key_shape_[1]; + int64_t hs = key_shape_[2]; + uint64_t kv_bytes = + static_cast(num_tokens * nkv * hs) * elem_size_; + + int64_t nb = kv_shape_[0]; + int64_t bs = kv_shape_[1]; + uint64_t cache_bytes = + static_cast(nb * bs * nkv * hs) * elem_size_; + + void* v_out_data = static_cast(kv_out_data) + v_offset_bytes_; + + atb::Tensor t_key = + ascend::toAtbTensor(key_shape_, acl_dt_, key_data, kv_bytes); + + atb::Tensor t_value = + ascend::toAtbTensor(key_shape_, acl_dt_, value_data, kv_bytes); + + atb::Tensor t_kv_k = + ascend::toAtbTensor(kv_shape_, acl_dt_, kv_out_data, cache_bytes); + + atb::Tensor t_kv_v = + ascend::toAtbTensor(kv_shape_, acl_dt_, v_out_data, cache_bytes); + + // Always int32 — the caller's `operator()` has already cast to int32. 
+ atb::Tensor t_slot = + ascend::toAtbTensor(slot_shape_, ACL_INT32, slot32_data, slot32_bytes_); + + atb::VariantPack vp; + vp.inTensors = {t_key, t_value, t_kv_k, t_kv_v, t_slot}; + vp.outTensors = {t_kv_k, t_kv_v}; + + return vp; + } + + atb::Operation* op_ = nullptr; + + std::vector kv_shape_; + + std::vector key_shape_; + + std::vector slot_shape_; + + aclDataType acl_dt_ = ACL_DT_UNDEFINED; + + size_t v_offset_bytes_ = 0; + + uint64_t elem_size_ = 0; + + // Pre-allocated int32 device buffer for `slot_mapping`. + void* slot32_buf_ = nullptr; + + size_t slot32_bytes_ = 0; + + // True if the caller already provides int32 `slot_mapping`. + bool slot_is_int32_ = false; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_ATB_H_ diff --git a/src/ascend/reshape_and_cache/kernel_v2.h b/src/ascend/reshape_and_cache/kernel_v2.h new file mode 100644 index 00000000..b4e59d7a --- /dev/null +++ b/src/ascend/reshape_and_cache/kernel_v2.h @@ -0,0 +1,124 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ + +// WARNING: This implementation is experimental and has strict hardware limits. +// +// Limitations: +// 1. Requires CANN 8.5.1+ (`aclnnScatterPaKvCache` API). +// 2. Only supported on Atlas A5 hardware (SoC 260). NOT supported on +// A2 (Ascend 910B, SoC 220-225) or A3 (SoC 250-255). +// 3. Not yet validated in production workloads. +// +// On unsupported hardware this file compiles to nothing (guarded by +// `__has_include`). Use `implementation_index=0` (the default +// `aclnnInplaceIndexCopy` path) for general-purpose deployment. 
+ +#if __has_include("aclnnop/aclnn_scatter_pa_kv_cache.h") + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_scatter_pa_kv_cache.h" +#include "ascend/common.h" +#include "ascend/reshape_and_cache/registry.h" +#include "ascend/workspace_pool_.h" +#include "base/reshape_and_cache.h" +#include "operator.h" + +namespace infini::ops { + +// Fused KV cache scatter via `aclnnScatterPaKvCache` (implementation index 1). +// +// Handles both K and V scatter in a single CANN kernel launch, replacing two +// separate `aclnnInplaceIndexCopy` calls (index 0). The fused API is +// purpose-built for paged KV cache and avoids the internal decomposition to +// `ScatterElementsV2`. +// +// Requirements: +// - CANN 8.5.1+ (`aclnnop/aclnn_scatter_pa_kv_cache.h`). +// - Atlas A5 hardware (SoC 260). The API is NOT supported on A2 (910B, +// SoC 220-225) or A3 (SoC 250-255). +// +// Select via `implementation_index=1` in Python: +// infini.ops.reshape_and_cache(..., implementation_index=1, stream=s) +template <> +class Operator + : public ReshapeAndCache { + public: + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out), + key_cache_(key), + value_cache_(value), + slot_cache_(slot_mapping) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + + aclDataType acl_dt = ascend::toAclDtype(key.dtype()); + + // 4D K cache view: [num_blocks, block_size, num_kv_heads, head_size]. + // K cache is kv_cache_out[0], starting at offset 0. + kv_k_cache_ = ascend::AclTensorCache({num_blocks, bs, nkv, hs}, acl_dt, + kv_cache_out.data()); + + // V cache is kv_cache_out[1], offset by stride(0) elements. 
+ v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + kv_v_cache_ = ascend::AclTensorCache( + {num_blocks, bs, nkv, hs}, acl_dt, + static_cast(kv_cache_out.data()) + v_offset_bytes_); + } + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + void* kv_k_data = kv_cache_out.data(); + void* kv_v_data = static_cast(kv_cache_out.data()) + v_offset_bytes_; + + auto t_key = key_cache_.get(const_cast(key.data())); + auto t_value = value_cache_.get(const_cast(value.data())); + auto t_slot = slot_cache_.get(const_cast(slot_mapping.data())); + auto t_kv_k = kv_k_cache_.get(kv_k_data); + auto t_kv_v = kv_v_cache_.get(kv_v_data); + + // Single fused scatter for both K and V caches. + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnScatterPaKvCacheGetWorkspaceSize( + t_key, t_kv_k, t_slot, t_value, t_kv_v, + /*compressLensOptional=*/nullptr, + /*compressSeqOffsetOptional=*/nullptr, + /*seqLensOptional=*/nullptr, + /*cacheModeOptional=*/nullptr, + /*scatterModeOptional=*/nullptr, + /*stridesOptional=*/nullptr, + /*offsetsOptional=*/nullptr, &ws, &exec); + auto& arena = ascend::workspacePool().ensure(stream, ws); + aclnnScatterPaKvCache(arena.buf, ws, exec, stream); + } + + private: + mutable ascend::AclTensorCache kv_k_cache_; + + mutable ascend::AclTensorCache kv_v_cache_; + + mutable ascend::AclTensorCache key_cache_; + + mutable ascend::AclTensorCache value_cache_; + + mutable ascend::AclTensorCache slot_cache_; + + size_t v_offset_bytes_ = 0; +}; + +} // namespace infini::ops + +#endif // __has_include("aclnnop/aclnn_scatter_pa_kv_cache.h") + +#endif // INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_V2_H_ diff --git a/src/ascend/reshape_and_cache/registry.h b/src/ascend/reshape_and_cache/registry.h new file mode 100644 index 00000000..c8c0fe48 --- /dev/null +++ 
b/src/ascend/reshape_and_cache/registry.h @@ -0,0 +1,27 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_REGISTRY_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_REGISTRY_H_ + +#include "base/reshape_and_cache.h" + +namespace infini::ops { + +// Implementation 0: `aclnnInplaceIndexCopy` (CANN 8.0+, two calls for K+V). +// Implementation 1: `aclnnScatterPaKvCache` (CANN 8.5.1+, single fused call). +// Implementation 2: ATB `ReshapeAndCacheNdKernel` (fused K+V, graph-safe). +template <> +struct ActiveImplementationsImpl { +#if defined(INFINI_HAS_ATB) && \ + __has_include("aclnnop/aclnn_scatter_pa_kv_cache.h") + using type = List<0, 1, 2>; +#elif defined(INFINI_HAS_ATB) + using type = List<0, 2>; +#elif __has_include("aclnnop/aclnn_scatter_pa_kv_cache.h") + using type = List<0, 1>; +#else + using type = List<0>; +#endif +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/rms_norm/kernel.h b/src/ascend/rms_norm/kernel.h new file mode 100644 index 00000000..d80441f2 --- /dev/null +++ b/src/ascend/rms_norm/kernel.h @@ -0,0 +1,96 @@ +#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_rms_norm.h" +#include "ascend/common.h" +#include "ascend/rms_norm/registry.h" +#include "ascend/workspace_pool_.h" +#include "base/rms_norm.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public RmsNorm { + public: + Operator(const Tensor input, const Tensor weight, float eps, Tensor out) + : RmsNorm(input, weight, eps, out), + in_cache_(input), + weight_cache_(weight), + out_cache_(out) { + // `aclnnRmsNorm` writes `rstd` as a required side output. + // Size computed here; buffer obtained from pool in `operator()`. 
+ rstd_shape_ = {static_cast(batch_size_), + static_cast(nhead_)}; + rstd_size_ = batch_size_ * nhead_ * sizeof(float); + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (rstd_tensor_) aclDestroyTensor(rstd_tensor_); + } + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_weight = weight_cache_.get(const_cast(weight.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Obtain shared rstd buffer from pool. + auto& rstd_arena = + ascend::workspacePool().ensure(stream, rstd_size_, "temp"); + + // Lazily create rstd tensor descriptor on first call. + if (!rstd_tensor_) { + rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_arena.buf); + } else { + aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf); + } + + if (!executor_) { + aclnnRmsNormGetWorkspaceSize(t_in, t_weight, eps, t_out, rstd_tensor_, + &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_weight, + const_cast(weight.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + aclSetOutputTensorAddr(executor_, 1, rstd_tensor_, rstd_arena.buf); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnRmsNorm(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache weight_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; + + std::vector rstd_shape_; + + uint64_t rstd_size_ = 0; + + mutable aclTensor* rstd_tensor_ = nullptr; +}; + +} // namespace infini::ops + +#include "ascend/rms_norm/kernel_custom.h" + 
+#endif diff --git a/src/ascend/rms_norm/kernel_custom.h b/src/ascend/rms_norm/kernel_custom.h new file mode 100644 index 00000000..7c725ecd --- /dev/null +++ b/src/ascend/rms_norm/kernel_custom.h @@ -0,0 +1,161 @@ +#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_ +#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_ + +#ifdef INFINI_HAS_CUSTOM_RMS_NORM + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/common.h" +#include "ascend/rms_norm/registry.h" +#include "ascend/workspace_pool_.h" +#include "base/rms_norm.h" +#include "operator.h" + +// Forward-declare the generated AscendC kernel launch function. +// This symbol is provided by the `no_workspace_kernel` static library +// built from `ascend/custom_kernel/csrc/ops/rms_norm/op_kernel/rms_norm.cpp` +// via `ascendc_library()`. +extern "C" uint32_t aclrtlaunch_rms_norm( + uint32_t blockDim, void* stream, void* x, void* weight, void* y, + int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign, + int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps, + int64_t dtypeSize); + +namespace infini::ops { + +// Custom AscendC fused RmsNorm kernel (implementation index 1). +// +// A single-kernel implementation that computes RMSNorm in one launch, avoiding +// the 5-sub-op decomposition of `aclnnRmsNorm` (index 0). Uses `Sqrt` + +// scalar division instead of `Rsqrt` for higher precision (~1e-7 fp32 error +// vs ~0.2% with `Rsqrt`). +// +// Select via `implementation_index=1` in Python: +// infini.ops.rms_norm(input, weight, eps, out, implementation_index=1, +// stream=s) +// +// Requirements: +// - Input last dimension must be 32-byte aligned (divisible by 16 for fp16 +// or 8 for fp32). All standard LLM hidden dimensions satisfy this. +// - Weight must have the same dtype as input. +// - The custom kernel binary must be linked (`BUILD_CUSTOM_KERNEL=ON`). 
+template <> +class Operator : public RmsNorm { + public: + Operator(const Tensor input, const Tensor weight, float eps, Tensor out) + : RmsNorm(input, weight, eps, out) { + // Dtype size in bytes. + dtype_size_ = (input.dtype() == DataType::kFloat16) ? 2 : 4; + + // Alignment check (32-byte boundary). + int64_t align_elems = 32 / dtype_size_; + dim_length_align_ = + ((static_cast(dim_) + align_elems - 1) / align_elems) * + align_elems; + assert(static_cast(dim_) == dim_length_align_ && + "custom `RmsNorm` kernel requires 32-byte aligned last dimension"); + + total_rows_ = + static_cast(batch_size_) * static_cast(nhead_); + + // For fp16 input, weight needs fp32 conversion because the custom + // kernel always reads weight as fp32. + needs_weight_cast_ = (dtype_size_ == 2); + + if (needs_weight_cast_) { + // Allocate persistent fp32 weight buffer on device. + size_t fp32_bytes = static_cast(dim_) * sizeof(float); + aclrtMalloc(&weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // AclTensorCache for the cast source (fp16 weight descriptor). + weight_src_cache_ = ascend::AclTensorCache({static_cast(dim_)}, + ACL_FLOAT16, nullptr); + + // AclTensorCache for the cast destination (fp32 weight buffer). + weight_dst_cache_ = ascend::AclTensorCache({static_cast(dim_)}, + ACL_FLOAT, weight_fp32_data_); + } + } + + ~Operator() { + if (cast_exec_) aclDestroyAclOpExecutor(cast_exec_); + if (weight_fp32_data_) aclrtFree(weight_fp32_data_); + } + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + auto stream = static_cast(stream_); + + // Determine fp32 weight pointer. + void* weight_fp32; + + if (needs_weight_cast_) { + // Cast weight fp16 -> fp32 using cached ACLNN executor. 
+ auto t_src = weight_src_cache_.get(const_cast(weight.data())); + auto t_dst = weight_dst_cache_.get(weight_fp32_data_); + + if (!cast_exec_) { + aclnnCastGetWorkspaceSize(t_src, ACL_FLOAT, t_dst, &cast_ws_, + &cast_exec_); + aclSetAclOpExecutorRepeatable(cast_exec_); + } else { + aclSetInputTensorAddr(cast_exec_, 0, t_src, + const_cast(weight.data())); + aclSetOutputTensorAddr(cast_exec_, 0, t_dst, weight_fp32_data_); + } + + auto& arena = ascend::workspacePool().ensure(stream, cast_ws_); + aclnnCast(arena.buf, cast_ws_, cast_exec_, stream); + weight_fp32 = weight_fp32_data_; + } else { + // Input is fp32 — weight is already fp32. + weight_fp32 = const_cast(weight.data()); + } + + // Block-level tiling: distribute rows across cores. + // Maximum block dimension covers Ascend 910B (20-40 AIV cores). + // Over-subscribing is safe (runtime multiplexes blocks across cores), + // though slightly sub-optimal due to per-block weight loading. + static constexpr int64_t kMaxBlockDim = 40; + int64_t used_cores = std::min(total_rows_, kMaxBlockDim); + int64_t former_length = (total_rows_ + used_cores - 1) / used_cores; + int64_t tail_length = former_length - 1; + int64_t former_num = total_rows_ - tail_length * used_cores; + uint32_t block_dim = static_cast(used_cores); + + // Launch custom AscendC kernel. 
+ aclrtlaunch_rms_norm( + block_dim, stream, const_cast(input.data()), weight_fp32, + out.data(), total_rows_, static_cast(dim_), dim_length_align_, + former_num, former_length, tail_length, eps, dtype_size_); + } + + private: + int64_t dtype_size_; + + int64_t dim_length_align_; + + int64_t total_rows_; + + bool needs_weight_cast_; + + void* weight_fp32_data_ = nullptr; + + mutable ascend::AclTensorCache weight_src_cache_; + + mutable ascend::AclTensorCache weight_dst_cache_; + + mutable aclOpExecutor* cast_exec_ = nullptr; + + mutable uint64_t cast_ws_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_CUSTOM_RMS_NORM +#endif // INFINI_OPS_ASCEND_RMS_NORM_KERNEL_CUSTOM_H_ diff --git a/src/ascend/rms_norm/registry.h b/src/ascend/rms_norm/registry.h new file mode 100644 index 00000000..5d279fd4 --- /dev/null +++ b/src/ascend/rms_norm/registry.h @@ -0,0 +1,19 @@ +#ifndef INFINI_OPS_ASCEND_RMS_NORM_REGISTRY_H_ +#define INFINI_OPS_ASCEND_RMS_NORM_REGISTRY_H_ + +#include "base/rms_norm.h" + +namespace infini::ops { + +template <> +struct ActiveImplementationsImpl { +#ifdef INFINI_HAS_CUSTOM_RMS_NORM + using type = List<0, 1>; +#else + using type = List<0>; +#endif +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/rotary_embedding/kernel.h b/src/ascend/rotary_embedding/kernel.h new file mode 100644 index 00000000..4b05be31 --- /dev/null +++ b/src/ascend/rotary_embedding/kernel.h @@ -0,0 +1,268 @@ +#ifndef INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_H_ +#define INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_H_ + +#include +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_apply_rotary_pos_emb_v2.h" +#include "aclnnop/aclnn_index_select.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/rotary_embedding.h" +#include "operator.h" + +namespace infini::ops { + +// Rotary position embedding via `aclnnApplyRotaryPosEmbV2`. 
+// +// V2 handles Q and K simultaneously in a single inplace call (layout=4, TND). +// The `rotaryMode` parameter accepts "half", "interleave", or "quarter", but +// CANN currently only supports "half" (neox style). Passing "interleave" or +// "quarter" returns ACLNN_ERR_PARAM_INVALID. +// +// fp16 note: V2 accumulates with ~4 ULP error for float16 (max diff ~0.008), +// which exceeds strict atol=0.001 tests but is acceptable for inference. +// bfloat16 passes with atol=0.005. +// +// Restrictions: +// - rotary_dim must equal head_size (partial rotation not supported). +// - is_neox_style must be true (rotaryMode="half" only). +// All mainstream models (LLaMA, Qwen, Mistral, DeepSeek) satisfy both. +template <> +class Operator + : public RotaryEmbedding { + public: + Operator(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim, + bool is_neox_style, Tensor query_out, Tensor key_out) + : RotaryEmbedding(positions, query, key, cos_sin_cache, head_size, + rotary_dim, is_neox_style, query_out, key_out) { + assert(rotary_dim == head_size && + "ascend `RotaryEmbedding` requires `rotary_dim` == `head_size` " + "(partial rotation not supported)"); + assert(is_neox_style && + "ascend `RotaryEmbedding` requires neox style — " + "`aclnnApplyRotaryPosEmbV2` `rotaryMode` only supports " + "\"half\"; \"interleave\" and \"quarter\" return " + "`ACLNN_ERR_PARAM_INVALID`"); + + const int64_t max_seq_len = cos_sin_cache.size(0); + const int64_t D = head_size_; + const int64_t half_D = D / 2; + const size_t elem_sz = cos_sin_cache.element_size(); + + // One-time: D2H copy cos_sin_cache, split cos/sin, expand, upload. + // cos_sin_cache layout per row: [c0..c_{D/2-1}, s0..s_{D/2-1}]. 
+ size_t table_bytes = static_cast(max_seq_len * D) * elem_sz; + std::vector cache_host(table_bytes); + aclrtMemcpy(cache_host.data(), table_bytes, cos_sin_cache.data(), + table_bytes, ACL_MEMCPY_DEVICE_TO_HOST); + + // Pre-expand into separate cos/sin tables [max_seq_len, D]. + // neox: cos = [c0..c_{hD-1}, c0..c_{hD-1}] (halves duplicated) + // interleave: cos = [c0,c0, c1,c1, ..., c_{hD-1},c_{hD-1}] + std::vector cos_host(table_bytes); + std::vector sin_host(table_bytes); + + for (int64_t p = 0; p < max_seq_len; ++p) { + for (int64_t j = 0; j < half_D; ++j) { + const auto* c_src = + cache_host.data() + static_cast(p * D + j) * elem_sz; + const auto* s_src = cache_host.data() + + static_cast(p * D + half_D + j) * elem_sz; + + // Neox expansion: [c0..c_{hD-1}, c0..c_{hD-1}] (halves duplicated). + std::memcpy(cos_host.data() + static_cast(p * D + j) * elem_sz, + c_src, elem_sz); + std::memcpy( + cos_host.data() + static_cast(p * D + half_D + j) * elem_sz, + c_src, elem_sz); + std::memcpy(sin_host.data() + static_cast(p * D + j) * elem_sz, + s_src, elem_sz); + std::memcpy( + sin_host.data() + static_cast(p * D + half_D + j) * elem_sz, + s_src, elem_sz); + } + } + + // Upload expanded tables to device (one-time). + aclrtMalloc(&cos_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(cos_table_dev_, table_bytes, cos_host.data(), table_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(sin_table_dev_, table_bytes, sin_host.data(), table_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + + const int64_t T = num_tokens_; + const int64_t Nq = num_heads_; + const int64_t Nkv = num_kv_heads_; + aclDataType acl_dt = ascend::toAclDtype(query.dtype()); + + // Gathered cos/sin buffers [T, D] — filled by `aclnnIndexSelect` each call. 
+ size_t gathered_bytes = static_cast(T * D) * elem_sz; + aclrtMalloc(&cos_dev_, gathered_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_dev_, gathered_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // IndexSelect descriptors: table ptrs stable, positions ptr varies. + cos_table_cache_ = + ascend::AclTensorCache({max_seq_len, D}, acl_dt, cos_table_dev_); + sin_table_cache_ = + ascend::AclTensorCache({max_seq_len, D}, acl_dt, sin_table_dev_); + idx_cache_ = ascend::AclTensorCache({T}, ACL_INT64, + const_cast(positions.data())); + cos_out_cache_ = ascend::AclTensorCache({T, D}, acl_dt, cos_dev_); + sin_out_cache_ = ascend::AclTensorCache({T, D}, acl_dt, sin_dev_); + + // V2 descriptors: cos/sin [T, 1, D], Q [T, Nq, D], K [T, Nkv, D]. + cos_v2_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt, cos_dev_); + sin_v2_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt, sin_dev_); + q_cache_ = ascend::AclTensorCache({T, Nq, D}, acl_dt, + const_cast(query_out.data())); + k_cache_ = ascend::AclTensorCache({T, Nkv, D}, acl_dt, + const_cast(key_out.data())); + } + + ~Operator() { + if (idx_cos_exec_) aclDestroyAclOpExecutor(idx_cos_exec_); + if (idx_sin_exec_) aclDestroyAclOpExecutor(idx_sin_exec_); + if (v2_exec_) aclDestroyAclOpExecutor(v2_exec_); + + if (cos_table_dev_) aclrtFree(cos_table_dev_); + if (sin_table_dev_) aclrtFree(sin_table_dev_); + if (cos_dev_) aclrtFree(cos_dev_); + if (sin_dev_) aclrtFree(sin_dev_); + } + + void operator()(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) const override { + auto stream = static_cast(stream_); + + const int64_t T = query.size(0); + const int64_t Nq = query.size(1); + const int64_t Nkv = key.size(1); + const int64_t D = head_size; + + // Step 1: Gather cos/sin by positions via `aclnnIndexSelect` (async). 
+ { + auto t_cos_table = cos_table_cache_.get(cos_table_dev_); + auto t_sin_table = sin_table_cache_.get(sin_table_dev_); + auto t_idx = idx_cache_.get(const_cast(positions.data())); + auto t_cos_out = cos_out_cache_.get(cos_dev_); + auto t_sin_out = sin_out_cache_.get(sin_dev_); + + if (!idx_cos_exec_) { + aclnnIndexSelectGetWorkspaceSize(t_cos_table, 0, t_idx, t_cos_out, + &idx_cos_ws_, &idx_cos_exec_); + aclSetAclOpExecutorRepeatable(idx_cos_exec_); + } else { + aclSetInputTensorAddr(idx_cos_exec_, 1, t_idx, + const_cast(positions.data())); + } + + if (!idx_sin_exec_) { + aclnnIndexSelectGetWorkspaceSize(t_sin_table, 0, t_idx, t_sin_out, + &idx_sin_ws_, &idx_sin_exec_); + aclSetAclOpExecutorRepeatable(idx_sin_exec_); + } else { + aclSetInputTensorAddr(idx_sin_exec_, 1, t_idx, + const_cast(positions.data())); + } + + uint64_t ws_max = idx_cos_ws_ > idx_sin_ws_ ? idx_cos_ws_ : idx_sin_ws_; + auto& arena = ascend::workspacePool().ensure(stream, ws_max); + + aclnnIndexSelect(arena.buf, idx_cos_ws_, idx_cos_exec_, stream); + aclnnIndexSelect(arena.buf, idx_sin_ws_, idx_sin_exec_, stream); + } + + // Step 2: Copy q→q_out, k→k_out if not inplace (V2 operates inplace). + size_t elem_sz = query.element_size(); + + if (query.data() != query_out.data()) { + aclrtMemcpyAsync(query_out.data(), + static_cast(T * Nq * D) * elem_sz, query.data(), + static_cast(T * Nq * D) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + + if (key.data() != key_out.data()) { + aclrtMemcpyAsync(key_out.data(), + static_cast(T * Nkv * D) * elem_sz, key.data(), + static_cast(T * Nkv * D) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + + // Step 3: Apply V2 RoPE inplace on q_out and k_out. 
+ auto t_cos = cos_v2_cache_.get(cos_dev_); + auto t_sin = sin_v2_cache_.get(sin_dev_); + auto t_q = q_cache_.get(query_out.data()); + auto t_k = k_cache_.get(key_out.data()); + + if (!v2_exec_) { + aclnnApplyRotaryPosEmbV2GetWorkspaceSize( + t_q, t_k, t_cos, t_sin, /*layout=*/4, const_cast("half"), + &v2_ws_, &v2_exec_); + aclSetAclOpExecutorRepeatable(v2_exec_); + } else { + aclSetInputTensorAddr(v2_exec_, 0, t_q, query_out.data()); + aclSetInputTensorAddr(v2_exec_, 1, t_k, key_out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, v2_ws_); + aclnnApplyRotaryPosEmbV2(arena.buf, v2_ws_, v2_exec_, stream); + } + + private: + // Pre-expanded cos/sin tables on device: [max_seq_len, D]. + void* cos_table_dev_ = nullptr; + + void* sin_table_dev_ = nullptr; + + // Device buffers for gathered [T, D] cos/sin. + void* cos_dev_ = nullptr; + + void* sin_dev_ = nullptr; + + // IndexSelect descriptors. + mutable ascend::AclTensorCache cos_table_cache_; + + mutable ascend::AclTensorCache sin_table_cache_; + + mutable ascend::AclTensorCache idx_cache_; + + mutable ascend::AclTensorCache cos_out_cache_; + + mutable ascend::AclTensorCache sin_out_cache_; + + // V2 descriptors. + mutable ascend::AclTensorCache cos_v2_cache_; + + mutable ascend::AclTensorCache sin_v2_cache_; + + mutable ascend::AclTensorCache q_cache_; + + mutable ascend::AclTensorCache k_cache_; + + // Cached executors. 
+ mutable aclOpExecutor* idx_cos_exec_ = nullptr; + + mutable uint64_t idx_cos_ws_ = 0; + + mutable aclOpExecutor* idx_sin_exec_ = nullptr; + + mutable uint64_t idx_sin_ws_ = 0; + + mutable aclOpExecutor* v2_exec_ = nullptr; + + mutable uint64_t v2_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/rotary_embedding/kernel_atb.h b/src/ascend/rotary_embedding/kernel_atb.h new file mode 100644 index 00000000..82b2ced1 --- /dev/null +++ b/src/ascend/rotary_embedding/kernel_atb.h @@ -0,0 +1,283 @@ +#ifndef INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_ATB_H_ +#define INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_ATB_H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "ascend/atb_common_.h" +#include "ascend/common.h" +#include "ascend/rotary_embedding/registry.h" +#include "ascend/workspace_pool_.h" +#include "atb/context.h" +#include "atb/infer_op_params.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "base/rotary_embedding.h" +#include "operator.h" + +namespace infini::ops { + +// ATB-based rotary position embedding (implementation index 1). +// +// Wraps ATB `RopeParam` which applies rotary embedding in a single fused +// kernel. ATB Rope handles position gathering internally, eliminating +// the 2x `aclnnIndexSelect` calls that produce ~62k GatherV3+Slice +// kernels per inference step in the CANN path (index=0). +// +// ATB Rope expects 5 inputs and 2 outputs: +// inTensors[0] = query [num_tokens, hiddenSizeQ] +// inTensors[1] = key [num_tokens, hiddenSizeK] +// inTensors[2] = cos_table [max_seq_len, headDim] +// inTensors[3] = sin_table [max_seq_len, headDim] +// inTensors[4] = seq_len [num_tokens] (int32, position indices) +// outTensors[0] = query_out [num_tokens, hiddenSizeQ] +// outTensors[1] = key_out [num_tokens, hiddenSizeK] +// +// The constructor splits the cos_sin_cache into separate cos/sin +// device tables [max_seq_len, headDim] with neox expansion. 
+// +// Restrictions: +// - rotary_dim must equal head_size (full rotation only). +// - is_neox_style must be true (`rotaryCoeff`=2). +// - fp16 only (ATB inference constraint). +template <> +class Operator + : public RotaryEmbedding { + public: + Operator(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim, + bool is_neox_style, Tensor query_out, Tensor key_out) + : RotaryEmbedding(positions, query, key, cos_sin_cache, head_size, + rotary_dim, is_neox_style, query_out, key_out) { + assert(rotary_dim == head_size && + "ATB `RotaryEmbedding` requires `rotary_dim` == `head_size`"); + assert(is_neox_style && + "ATB `RotaryEmbedding` requires neox style (`rotaryCoeff`=2)"); + + const int64_t max_seq_len = cos_sin_cache.size(0); + const int64_t D = head_size_; + const int64_t half_D = D / 2; + const size_t elem_sz = cos_sin_cache.element_size(); + + // One-time: D2H copy cos_sin_cache, split into cos/sin, upload. + // cos_sin_cache layout per row: [c0..c_{hD-1}, s0..s_{hD-1}]. + size_t row_bytes = static_cast(D) * elem_sz; + size_t table_bytes = static_cast(max_seq_len) * row_bytes; + + std::vector cache_host(table_bytes); + aclrtMemcpy(cache_host.data(), table_bytes, cos_sin_cache.data(), + table_bytes, ACL_MEMCPY_DEVICE_TO_HOST); + + // ATB Rope with `rotaryCoeff`=2 expects cos/sin of shape [S, D]. + // Neox-style expansion: [c0..c_{hD-1}, c0..c_{hD-1}]. 
+ std::vector cos_host(table_bytes); + std::vector sin_host(table_bytes); + + for (int64_t p = 0; p < max_seq_len; ++p) { + for (int64_t j = 0; j < half_D; ++j) { + const auto* c_src = + cache_host.data() + static_cast(p * D + j) * elem_sz; + const auto* s_src = cache_host.data() + + static_cast(p * D + half_D + j) * elem_sz; + + std::memcpy(cos_host.data() + static_cast(p * D + j) * elem_sz, + c_src, elem_sz); + std::memcpy( + cos_host.data() + static_cast(p * D + half_D + j) * elem_sz, + c_src, elem_sz); + std::memcpy(sin_host.data() + static_cast(p * D + j) * elem_sz, + s_src, elem_sz); + std::memcpy( + sin_host.data() + static_cast(p * D + half_D + j) * elem_sz, + s_src, elem_sz); + } + } + + // Upload expanded tables to device (persistent, reused across calls). + aclrtMalloc(&cos_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(cos_table_dev_, table_bytes, cos_host.data(), table_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(sin_table_dev_, table_bytes, sin_host.data(), table_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + + // Cache shapes and metadata. + // Query/key may be 2D [T, N*D] or 3D [T, N, D]. Derive the total hidden + // size directly from the tensor to handle both layouts. + const int64_t T = num_tokens_; + int64_t hiddenQ = static_cast(query.numel()) / T; + int64_t hiddenK = static_cast(key.numel()) / T; + q_2d_shape_ = {T, hiddenQ}; + k_2d_shape_ = {T, hiddenK}; + cos_sin_table_shape_ = {max_seq_len, D}; + pos_shape_ = {T}; + acl_dt_ = ascend::toAclDtype(query.dtype()); + elem_size_ = static_cast(elem_sz); + max_seq_len_ = max_seq_len; + + // Create the ATB Rope operation. + atb::infer::RopeParam param; + param.rotaryCoeff = 2; // Neox half-rotation. + param.cosFormat = 0; // Inference mode. 
+ atb::Status s = atb::CreateOperation(param, &op_); + + assert(s == atb::NO_ERROR && "atb::CreateOperation(Rope) failed"); + } + + ~Operator() { + if (!ascend::isAclRuntimeAlive()) return; + if (op_) atb::DestroyOperation(op_); + if (cos_table_dev_) aclrtFree(cos_table_dev_); + if (sin_table_dev_) aclrtFree(sin_table_dev_); + if (pos_buf_dev_) aclrtFree(pos_buf_dev_); + } + + Operator(const Operator&) = delete; + + Operator& operator=(const Operator&) = delete; + + void operator()(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) const override { + auto stream = static_cast(stream_); + + int64_t T = query.size(0); + int64_t D = head_size; + + // Query/key may be 2D [T, N*D] or 3D [T, N, D]. Compute total hidden + // sizes from the tensor element count to handle both layouts. + int64_t hiddenQ = static_cast(query.numel()) / T; + int64_t hiddenK = static_cast(key.numel()) / T; + + // Copy q→q_out, k→k_out if not in-place. + size_t elem_sz = query.element_size(); + + if (query.data() != query_out.data()) { + aclrtMemcpyAsync(query_out.data(), + static_cast(T * hiddenQ) * elem_sz, query.data(), + static_cast(T * hiddenQ) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + + if (key.data() != key_out.data()) { + aclrtMemcpyAsync(key_out.data(), + static_cast(T * hiddenK) * elem_sz, key.data(), + static_cast(T * hiddenK) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + + // Provide int32 positions to ATB. When the caller pre-casts to int32 + // (required for NPU graph capture), a device-to-device copy suffices. + // The D2H+sync fallback remains for standalone tests with int64 positions. 
+ size_t pos32_bytes = static_cast(T) * sizeof(int32_t); + + if (pos32_bytes > pos_buf_size_) { + if (pos_buf_dev_) aclrtFree(pos_buf_dev_); + aclrtMalloc(&pos_buf_dev_, pos32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + pos_buf_size_ = pos32_bytes; + } + + if (positions.element_size() == sizeof(int32_t)) { + // Already int32 — async D2D copy, graph-compatible. + aclrtMemcpyAsync(pos_buf_dev_, pos32_bytes, positions.data(), pos32_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } else { + // int64 fallback — D2H, CPU cast, H2D (not graph-compatible). + std::vector pos_i64(static_cast(T)); + aclrtMemcpyAsync(pos_i64.data(), static_cast(T) * sizeof(int64_t), + positions.data(), + static_cast(T) * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + + std::vector pos_i32(static_cast(T)); + + for (int64_t i = 0; i < T; ++i) { + pos_i32[static_cast(i)] = + static_cast(pos_i64[static_cast(i)]); + } + + aclrtMemcpyAsync(pos_buf_dev_, pos32_bytes, pos_i32.data(), pos32_bytes, + ACL_MEMCPY_HOST_TO_DEVICE, stream); + } + + // Build ATB `VariantPack` with 5 inputs + 2 outputs. 
+ atb::Context* ctx = ascend::getAtbContext(stream); + + uint64_t q_bytes = static_cast(T * hiddenQ) * elem_size_; + uint64_t k_bytes = static_cast(T * hiddenK) * elem_size_; + uint64_t table_bytes = static_cast(max_seq_len_ * D) * elem_size_; + + atb::Tensor t_q = + ascend::toAtbTensor(q_2d_shape_, acl_dt_, query_out.data(), q_bytes); + atb::Tensor t_k = + ascend::toAtbTensor(k_2d_shape_, acl_dt_, key_out.data(), k_bytes); + atb::Tensor t_cos = ascend::toAtbTensor(cos_sin_table_shape_, acl_dt_, + cos_table_dev_, table_bytes); + atb::Tensor t_sin = ascend::toAtbTensor(cos_sin_table_shape_, acl_dt_, + sin_table_dev_, table_bytes); + atb::Tensor t_pos = + ascend::toAtbTensor(pos_shape_, ACL_INT32, pos_buf_dev_, pos32_bytes); + + atb::VariantPack vp; + vp.inTensors = {t_q, t_k, t_cos, t_sin, t_pos}; + vp.outTensors = {t_q, t_k}; + + uint64_t ws_size = 0; + atb::Status s = op_->Setup(vp, ws_size, ctx); + + assert(s == atb::NO_ERROR && "ATB rope setup failed"); + + uint8_t* ws_ptr = nullptr; + + if (ws_size > 0) { + auto& arena = ascend::workspacePool().ensure(stream, ws_size); + ws_ptr = static_cast(arena.buf); + } + + s = op_->Execute(vp, ws_ptr, ws_size, ctx); + + assert(s == atb::NO_ERROR && "ATB rope execute failed"); + } + + private: + atb::Operation* op_ = nullptr; + + // Neox-expanded cos/sin tables on device: [max_seq_len, D]. + void* cos_table_dev_ = nullptr; + + void* sin_table_dev_ = nullptr; + + // Reusable int32 positions buffer on device. + mutable void* pos_buf_dev_ = nullptr; + + mutable size_t pos_buf_size_ = 0; + + // Cached shapes for ATB `VariantPack`. 
+ std::vector q_2d_shape_; + + std::vector k_2d_shape_; + + std::vector cos_sin_table_shape_; + + std::vector pos_shape_; + + aclDataType acl_dt_ = ACL_DT_UNDEFINED; + + uint64_t elem_size_ = 0; + + int64_t max_seq_len_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_ATB_H_ diff --git a/src/ascend/rotary_embedding/registry.h b/src/ascend/rotary_embedding/registry.h new file mode 100644 index 00000000..6055aa79 --- /dev/null +++ b/src/ascend/rotary_embedding/registry.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_ASCEND_ROTARY_EMBEDDING_REGISTRY_H_ +#define INFINI_OPS_ASCEND_ROTARY_EMBEDDING_REGISTRY_H_ + +#include "base/rotary_embedding.h" + +namespace infini::ops { + +// Implementation 0: `aclnnApplyRotaryPosEmbV2` (CANN, 2× IndexSelect + V2). +// Implementation 1: ATB `Rope` (fused kernel, eliminates GatherV3+Slice). +template <> +struct ActiveImplementationsImpl { +#if defined(INFINI_HAS_ATB) + using type = List<0, 1>; +#else + using type = List<0>; +#endif +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/silu_and_mul/kernel.h b/src/ascend/silu_and_mul/kernel.h new file mode 100644 index 00000000..816cb544 --- /dev/null +++ b/src/ascend/silu_and_mul/kernel.h @@ -0,0 +1,117 @@ +#ifndef INFINI_OPS_ASCEND_SILU_AND_MUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_SILU_AND_MUL_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnnop/aclnn_swi_glu.h" +#include "ascend/common.h" +#include "ascend/silu_and_mul/registry.h" +#include "ascend/workspace_pool_.h" +#include "base/silu_and_mul.h" +#include "operator.h" + +namespace infini::ops { + +// Calls `aclnnSwiGlu` directly on the concatenated `x = [gate, up]` tensor. +// +// `aclnnSwiGlu` splits `x` along `dim` into `[first_half, second_half]` and +// computes `second_half * silu(first_half)`, i.e. `up * silu(gate)`. 
+// +// `aclnnSwiGlu` ignores output strides and writes contiguously. When the +// output is non-contiguous, a contiguous staging buffer is used and the +// result is copied back via `aclnnInplaceCopy`. +template <> +class Operator : public SiluAndMul { + public: + Operator(const Tensor x, int64_t dim, Tensor out) + : SiluAndMul(x, dim, out), x_cache_(x), out_cache_(out) { + needs_copy_ = !is_out_contiguous_; + + if (needs_copy_) { + out_staging_size_ = out.numel() * kDataTypeToSize.at(out.dtype()); + } + } + + ~Operator() { + if (swiglu_exec_) aclDestroyAclOpExecutor(swiglu_exec_); + if (copy_exec_) aclDestroyAclOpExecutor(copy_exec_); + } + + void operator()(const Tensor x, int64_t dim, Tensor out) const override { + auto t_x = x_cache_.get(const_cast(x.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Determine effective output target. + aclTensor* t_swiglu_out = t_out; + void* swiglu_out_data = out.data(); + + if (needs_copy_) { + auto& staging = + ascend::workspacePool().ensure(stream, out_staging_size_, "staging"); + + if (!out_staging_cache_) { + std::vector out_shape(out_shape_.begin(), out_shape_.end()); + out_staging_cache_.emplace(out_shape, ascend::toAclDtype(out_dtype_), + staging.buf); + } + + t_swiglu_out = out_staging_cache_->get(staging.buf); + swiglu_out_data = staging.buf; + } + + // Call `aclnnSwiGlu`. + if (!swiglu_exec_) { + aclnnSwiGluGetWorkspaceSize(t_x, dim_, t_swiglu_out, &swiglu_ws_, + &swiglu_exec_); + aclSetAclOpExecutorRepeatable(swiglu_exec_); + } else { + aclSetInputTensorAddr(swiglu_exec_, 0, t_x, const_cast(x.data())); + aclSetOutputTensorAddr(swiglu_exec_, 0, t_swiglu_out, swiglu_out_data); + } + + auto& arena = ascend::workspacePool().ensure(stream, swiglu_ws_); + aclnnSwiGlu(arena.buf, swiglu_ws_, swiglu_exec_, stream); + + // Copy staging buffer back to non-contiguous output if needed. 
+ if (needs_copy_) { + if (!copy_exec_) { + aclnnInplaceCopyGetWorkspaceSize(t_out, t_swiglu_out, &copy_ws_, + &copy_exec_); + aclSetAclOpExecutorRepeatable(copy_exec_); + } else { + aclSetInputTensorAddr(copy_exec_, 0, t_out, out.data()); + aclSetInputTensorAddr(copy_exec_, 1, t_swiglu_out, swiglu_out_data); + } + + auto& copy_arena = ascend::workspacePool().ensure(stream, copy_ws_); + aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream); + } + } + + private: + mutable ascend::AclTensorCache x_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable std::optional out_staging_cache_; + + bool needs_copy_ = false; + + uint64_t out_staging_size_ = 0; + + mutable aclOpExecutor* swiglu_exec_ = nullptr; + + mutable uint64_t swiglu_ws_ = 0; + + mutable aclOpExecutor* copy_exec_ = nullptr; + + mutable uint64_t copy_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/silu_and_mul/registry.h b/src/ascend/silu_and_mul/registry.h new file mode 100644 index 00000000..5718b882 --- /dev/null +++ b/src/ascend/silu_and_mul/registry.h @@ -0,0 +1,15 @@ +#ifndef INFINI_OPS_ASCEND_SILU_AND_MUL_REGISTRY_H_ +#define INFINI_OPS_ASCEND_SILU_AND_MUL_REGISTRY_H_ + +#include "base/silu_and_mul.h" + +namespace infini::ops { + +template <> +struct ActiveImplementationsImpl { + using type = List<0>; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/swiglu/kernel.h b/src/ascend/swiglu/kernel.h new file mode 100644 index 00000000..74d7044f --- /dev/null +++ b/src/ascend/swiglu/kernel.h @@ -0,0 +1,104 @@ +#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ +#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "aclnn_silu.h" +#include "ascend/common.h" +#include "ascend/swiglu/registry.h" +#include "ascend/workspace_pool_.h" +#include "base/swiglu.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Implements SwiGLU as two ACLNN calls: 
silu(gate) into a temp buffer, +// then elementwise mul(input, temp) into out. +// `aclnnSiluMul` was not used because it fuses silu_AND_mul on the same +// tensor (x * silu(x)), whereas SwiGLU requires input * silu(gate) — +// two distinct inputs. +template <> +class Operator : public Swiglu { + public: + Operator(const Tensor input, const Tensor gate, Tensor out) + : Swiglu(input, gate, out), + in_cache_(input), + gate_cache_(gate), + out_cache_(out) { + temp_size_ = input.numel() * kDataTypeToSize.at(input.dtype()); + + // Build temp cache from gate geometry (contiguous, same shape/dtype). + // No data pointer yet — will be set on first `get()` call. + Tensor temp_t{nullptr, gate.shape(), gate.dtype(), gate.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); + } + + ~Operator() { + if (silu_exec_) aclDestroyAclOpExecutor(silu_exec_); + if (mul_exec_) aclDestroyAclOpExecutor(mul_exec_); + } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_gate = gate_cache_.get(const_cast(gate.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Obtain shared temp buffer from pool. + auto& temp = ascend::workspacePool().ensure(stream, temp_size_, "temp"); + auto t_temp = temp_cache_.get(temp.buf); + + // Step 1: silu(gate) -> temp. + if (!silu_exec_) { + aclnnSiluGetWorkspaceSize(t_gate, t_temp, &silu_ws_, &silu_exec_); + aclSetAclOpExecutorRepeatable(silu_exec_); + } else { + aclSetInputTensorAddr(silu_exec_, 0, t_gate, + const_cast(gate.data())); + aclSetOutputTensorAddr(silu_exec_, 0, t_temp, temp.buf); + } + auto& silu_arena = ascend::workspacePool().ensure(stream, silu_ws_); + aclnnSilu(silu_arena.buf, silu_ws_, silu_exec_, stream); + + // Step 2: mul(input, temp) -> out. 
+ if (!mul_exec_) { + aclnnMulGetWorkspaceSize(t_in, t_temp, t_out, &mul_ws_, &mul_exec_); + aclSetAclOpExecutorRepeatable(mul_exec_); + } else { + aclSetInputTensorAddr(mul_exec_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(mul_exec_, 1, t_temp, temp.buf); + aclSetOutputTensorAddr(mul_exec_, 0, t_out, out.data()); + } + auto& mul_arena = ascend::workspacePool().ensure(stream, mul_ws_); + aclnnMul(mul_arena.buf, mul_ws_, mul_exec_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache gate_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + + uint64_t temp_size_ = 0; + + mutable aclOpExecutor* silu_exec_ = nullptr; + + mutable uint64_t silu_ws_ = 0; + + mutable aclOpExecutor* mul_exec_ = nullptr; + + mutable uint64_t mul_ws_ = 0; +}; + +} // namespace infini::ops + +#include "ascend/swiglu/kernel_fused.h" + +#endif diff --git a/src/ascend/swiglu/kernel_fused.h b/src/ascend/swiglu/kernel_fused.h new file mode 100644 index 00000000..e7653e20 --- /dev/null +++ b/src/ascend/swiglu/kernel_fused.h @@ -0,0 +1,189 @@ +#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_FUSED_H_ +#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_FUSED_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnnop/aclnn_cat.h" +#include "aclnnop/aclnn_swi_glu.h" +#include "ascend/common.h" +#include "ascend/swiglu/registry.h" +#include "ascend/workspace_pool_.h" +#include "base/swiglu.h" +#include "operator.h" + +namespace infini::ops { + +// Fused implementation via `aclnnSwiGlu` (implementation index 1). +// +// Concatenates `[gate, input]` into a temp buffer via `aclnnCat`, then calls +// `aclnnSwiGlu` which computes `second_half * silu(first_half)` in a single +// fused kernel, i.e. `input * silu(gate)`. +// +// This trades an extra `aclnnCat` launch for a single fused SwiGLU kernel +// instead of separate `aclnnSilu` + `aclnnMul`. 
The net benefit is one fewer +// intermediate buffer materialised on-device (the silu temp is eliminated). +// +// `aclnnSwiGlu` requires a contiguous output tensor. When the caller's output +// is non-contiguous, a contiguous temp buffer is used and the result is copied +// back via `aclnnInplaceCopy`. +// +// Select via `implementation_index=1` in Python: +// infini.ops.swiglu(..., implementation_index=1, stream=s) +template <> +class Operator : public Swiglu { + public: + Operator(const Tensor input, const Tensor gate, Tensor out) + : Swiglu(input, gate, out), + gate_cache_(gate), + in_cache_(input), + out_cache_(out) { + // Compute the concatenated shape: same as input but with last dim doubled. + cat_shape_.assign(input.shape().begin(), input.shape().end()); + cat_shape_.back() *= 2; + + uint64_t cat_elems = 1; + + for (auto d : cat_shape_) { + cat_elems *= static_cast(d); + } + + cat_size_ = cat_elems * kDataTypeToSize.at(input.dtype()); + + // `aclnnSwiGlu` ignores output strides and writes contiguously. + // When the output is non-contiguous we need a contiguous staging buffer. + needs_copy_ = !is_out_contiguous_; + + if (needs_copy_) { + out_staging_size_ = output_size_ * kDataTypeToSize.at(out.dtype()); + } + } + + ~Operator() { + if (cat_exec_) aclDestroyAclOpExecutor(cat_exec_); + if (swiglu_exec_) aclDestroyAclOpExecutor(swiglu_exec_); + if (copy_exec_) aclDestroyAclOpExecutor(copy_exec_); + if (cat_tensor_list_) aclDestroyTensorList(cat_tensor_list_); + } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + auto t_gate = gate_cache_.get(const_cast(gate.data())); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Obtain shared temp buffer for the concatenated tensor. + auto& cat_arena = ascend::workspacePool().ensure(stream, cat_size_, "temp"); + + // Lazily build the cat output tensor cache on first call. 
+ if (!cat_out_cache_) { + cat_out_cache_.emplace(cat_shape_, ascend::toAclDtype(input_type_), + cat_arena.buf); + } + + auto t_cat = cat_out_cache_->get(cat_arena.buf); + + // Step 1: cat([gate, input], dim=-1) -> cat_buf. + if (!cat_exec_) { + aclTensor* tensors[2] = {t_gate, t_in}; + cat_tensor_list_ = + aclCreateTensorList(const_cast(tensors), 2); + aclnnCatGetWorkspaceSize(cat_tensor_list_, + static_cast(ndim_ - 1), t_cat, &cat_ws_, + &cat_exec_); + aclSetAclOpExecutorRepeatable(cat_exec_); + } else { + // The tensor list references the same `aclTensor*` objects whose data + // pointers were already updated by `get()` above. + aclSetOutputTensorAddr(cat_exec_, 0, t_cat, cat_arena.buf); + } + + auto& cat_ws_arena = ascend::workspacePool().ensure(stream, cat_ws_); + aclnnCat(cat_ws_arena.buf, cat_ws_, cat_exec_, stream); + + // Step 2: swiglu(cat_buf, dim=-1) -> out (or staging buffer). + aclTensor* t_swiglu_out = t_out; + void* swiglu_out_data = out.data(); + + if (needs_copy_) { + auto& staging = + ascend::workspacePool().ensure(stream, out_staging_size_, "staging"); + + if (!out_staging_cache_) { + std::vector out_shape(out_shape_.begin(), out_shape_.end()); + out_staging_cache_.emplace(out_shape, ascend::toAclDtype(out_type_), + staging.buf); + } + + t_swiglu_out = out_staging_cache_->get(staging.buf); + swiglu_out_data = staging.buf; + } + + if (!swiglu_exec_) { + aclnnSwiGluGetWorkspaceSize(t_cat, static_cast(ndim_ - 1), + t_swiglu_out, &swiglu_ws_, &swiglu_exec_); + aclSetAclOpExecutorRepeatable(swiglu_exec_); + } else { + aclSetInputTensorAddr(swiglu_exec_, 0, t_cat, cat_arena.buf); + aclSetOutputTensorAddr(swiglu_exec_, 0, t_swiglu_out, swiglu_out_data); + } + + auto& swiglu_arena = ascend::workspacePool().ensure(stream, swiglu_ws_); + aclnnSwiGlu(swiglu_arena.buf, swiglu_ws_, swiglu_exec_, stream); + + // Step 3 (non-contiguous output only): copy staging -> out. 
+ if (needs_copy_) { + if (!copy_exec_) { + aclnnInplaceCopyGetWorkspaceSize(t_out, t_swiglu_out, &copy_ws_, + &copy_exec_); + aclSetAclOpExecutorRepeatable(copy_exec_); + } else { + aclSetInputTensorAddr(copy_exec_, 0, t_out, out.data()); + aclSetInputTensorAddr(copy_exec_, 1, t_swiglu_out, swiglu_out_data); + } + + auto& copy_arena = ascend::workspacePool().ensure(stream, copy_ws_); + aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream); + } + } + + private: + mutable ascend::AclTensorCache gate_cache_; + + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable std::optional cat_out_cache_; + + mutable std::optional out_staging_cache_; + + std::vector cat_shape_; + + uint64_t cat_size_ = 0; + + bool needs_copy_ = false; + + uint64_t out_staging_size_ = 0; + + mutable aclTensorList* cat_tensor_list_ = nullptr; + + mutable aclOpExecutor* cat_exec_ = nullptr; + + mutable uint64_t cat_ws_ = 0; + + mutable aclOpExecutor* swiglu_exec_ = nullptr; + + mutable uint64_t swiglu_ws_ = 0; + + mutable aclOpExecutor* copy_exec_ = nullptr; + + mutable uint64_t copy_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/swiglu/registry.h b/src/ascend/swiglu/registry.h new file mode 100644 index 00000000..8c7d6545 --- /dev/null +++ b/src/ascend/swiglu/registry.h @@ -0,0 +1,15 @@ +#ifndef INFINI_OPS_ASCEND_SWIGLU_REGISTRY_H_ +#define INFINI_OPS_ASCEND_SWIGLU_REGISTRY_H_ + +#include "base/swiglu.h" + +namespace infini::ops { + +template <> +struct ActiveImplementationsImpl { + using type = List<0, 1>; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h index ebb670da..bf7452fd 100644 --- a/src/ascend/workspace_pool_.h +++ b/src/ascend/workspace_pool_.h @@ -2,7 +2,11 @@ #define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ #include +#include #include +#include +#include +#include #include +#include @@ -18,36 +22,134 @@ struct WorkspaceArena { class 
WorkspacePool { public: - WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) { + // Ensure the arena for `(stream, slot)` has at least `needed` bytes. + // + // The `slot` parameter defaults to `"workspace"` for backward + // compatibility. Operators needing a separate temp arena pass + // `"temp"`. See the design spec for details: + // docs/superpowers/specs/2026-04-12-workspace-pool-multi-slot-design.md + WorkspaceArena& ensure(aclrtStream stream, uint64_t needed, + const char* slot = "workspace") { + // Thread-local fast path: a small flat array of recently used + // `(stream, slot, arena*)` triples. In practice operators use at + // most 2-3 slots, so linear scan is sufficient — no heap + // allocation on the hot path. + struct CacheEntry { + aclrtStream stream = nullptr; + const char* slot = nullptr; + WorkspaceArena* arena = nullptr; + }; + static constexpr int kCacheSize = 4; + thread_local CacheEntry cache[kCacheSize] = {}; + + for (int i = 0; i < kCacheSize; ++i) { + auto& e = cache[i]; + + if (e.stream == stream && e.slot != nullptr && + std::strcmp(e.slot, slot) == 0 && e.arena != nullptr && + needed <= e.arena->capacity) { + return *e.arena; + } + } + + // Slow path: look up arena in the map under lock. 
+ assert(!capturing_ && + "`WorkspacePool`: `aclrtMalloc` on slow path during graph " + "capture; ensure all operators run at least once during " + "eager warmup"); + std::lock_guard lock(mutex_); - auto& arena = arenas_[stream]; - if (needed <= arena.capacity) return arena; - if (arena.capacity > 0) { - aclrtSynchronizeStream(stream); - aclrtFree(arena.buf); + + SlotKey key{stream, slot}; + auto& owned = arenas_[key]; + + if (!owned) { + owned = std::make_unique(); } - if (needed > 0) { - auto ret = aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); - assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); + + auto* arena = owned.get(); + + if (needed > arena->capacity) { + if (arena->capacity > 0) { + aclrtSynchronizeStream(stream); + aclrtFree(arena->buf); + } + + if (needed > 0) { + auto ret = aclrtMalloc(&arena->buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); + } + + arena->capacity = needed; + } + + // Insert into the thread-local cache (evict oldest). + for (int i = kCacheSize - 1; i > 0; --i) { + cache[i] = cache[i - 1]; } - arena.capacity = needed; - return arena; + cache[0] = {stream, slot, arena}; + + return *arena; } + // Set to true before NPUGraph capture, false after. When true, + // the slow path (which calls `aclrtMalloc`) triggers an assert + // failure — a safety net against accidental device allocations + // being recorded into the graph. + void set_capture_mode(bool capturing) { capturing_ = capturing; } + ~WorkspacePool() { - for (auto& [stream, arena] : arenas_) { - if (arena.capacity > 0) aclrtFree(arena.buf); + for (auto& [key, arena] : arenas_) { + if (arena && arena->capacity > 0) { + // The CANN runtime may already be torn down when this static + // destructor runs. `aclrtGetDevice` fails in that case — + // skip the free to avoid glibc "double free" abort. 
+ int32_t dev_id = -1; + + if (aclrtGetDevice(&dev_id) == ACL_SUCCESS) { + aclrtFree(arena->buf); + } else { + fprintf(stderr, + "[InfiniOps] `WorkspacePool`: CANN runtime already " + "finalized, skipping `aclrtFree` (%" PRIu64 + " bytes leaked).\n", + arena->capacity); + } + } } } private: - std::unordered_map arenas_; + struct SlotKey { + aclrtStream stream; + + std::string slot; + + bool operator==(const SlotKey& o) const { + return stream == o.stream && slot == o.slot; + } + }; + + struct SlotKeyHash { + size_t operator()(const SlotKey& k) const { + auto h1 = std::hash{}(static_cast(k.stream)); + auto h2 = std::hash{}(k.slot); + + return h1 ^ (h2 << 1); + } + }; + + std::unordered_map, SlotKeyHash> + arenas_; std::mutex mutex_; + + bool capturing_ = false; }; inline WorkspacePool& workspacePool() { static WorkspacePool pool; + return pool; } diff --git a/src/base/add_rms_norm.h b/src/base/add_rms_norm.h index 3c888917..8243a53c 100644 --- a/src/base/add_rms_norm.h +++ b/src/base/add_rms_norm.h @@ -11,26 +11,23 @@ namespace infini::ops { class AddRmsNorm : public Operator { public: - // TODO: Make `eps` an `std::optional` with a PyTorch-aligned default. - // Also consider the same change for `RmsNorm`. - AddRmsNorm(const Tensor input, const Tensor other, const Tensor weight, - float eps, Tensor out, Tensor rstd_out) - : input_shape_{input.shape()}, + AddRmsNorm(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : input_shape_{x1.shape()}, eps_{eps}, - dim_{input.size(-1)}, - ndim_{input.ndim()}, - batch_size_{ndim_ == 2 ? input.size(-2) : input.size(-3)}, - nhead_{ndim_ == 2 ? 1 : input.size(-2)}, + dim_{x1.size(-1)}, + ndim_{x1.ndim()}, + batch_size_{ndim_ == 2 ? x1.size(-2) : x1.size(-3)}, + nhead_{ndim_ == 2 ? 
1 : x1.size(-2)}, rstd_shape_{static_cast(batch_size_), static_cast(nhead_)} { - assert(input.dtype() == other.dtype()); - assert(input.dtype() == out.dtype()); - assert(input.dtype() == rstd_out.dtype()); + assert(x1.dtype() == x2.dtype()); + assert(x1.dtype() == y_out.dtype()); + assert(x1.dtype() == x_out.dtype()); } - virtual void operator()(const Tensor input, const Tensor other, - const Tensor weight, float eps, Tensor out, - Tensor rstd_out) const = 0; + virtual void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const = 0; protected: Tensor::Shape input_shape_; diff --git a/src/base/cast.h b/src/base/cast.h new file mode 100644 index 00000000..29f1f40c --- /dev/null +++ b/src/base/cast.h @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_BASE_CAST_H_ +#define INFINI_OPS_BASE_CAST_H_ + +#include "operator.h" + +namespace infini::ops { + +class Cast : public Operator { + public: + Cast(const Tensor input, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_dtype_{input.dtype()}, + out_dtype_{out.dtype()}, + input_shape_{input.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(input.numel() == out.numel() && + "the input and output of `Cast` must have the same number of " + "elements"); + } + + virtual void operator()(const Tensor input, Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_dtype_; + + const DataType out_dtype_; + + Tensor::Shape input_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/cat.h b/src/base/cat.h new file mode 100644 index 00000000..dcb0ba58 --- /dev/null +++ 
b/src/base/cat.h @@ -0,0 +1,35 @@ +#ifndef INFINI_OPS_BASE_CAT_H_ +#define INFINI_OPS_BASE_CAT_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +class Cat : public Operator { + public: + Cat(const Tensor first_input, std::vector rest_inputs, int64_t dim, + Tensor out) + : input_count_{1 + rest_inputs.size()} { + assert(input_count_ >= 2 && "`Cat` requires at least 2 input tensors"); + + auto ndim = static_cast(out.ndim()); + // Normalize negative dim (e.g. -1 means last dimension). + dim_ = dim < 0 ? dim + ndim : dim; + assert(dim_ >= 0 && dim_ < ndim && "`Cat` dim out of range"); + } + + virtual void operator()(const Tensor first_input, + std::vector rest_inputs, int64_t dim, + Tensor out) const = 0; + + protected: + int64_t dim_; + + size_t input_count_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/flash_attention.h b/src/base/flash_attention.h index e5952b51..1e8baad4 100644 --- a/src/base/flash_attention.h +++ b/src/base/flash_attention.h @@ -9,14 +9,6 @@ namespace infini::ops { -// Fused multi-head / grouped-query attention. -// -// Interface follows vLLM v1 `AttentionImpl.forward()`: -// `vllm.v1.attention.backends.abstract.AttentionImpl` -// -// Layout: `query` / `key` / `value` are `[T, N, D]` (TND). -// Prefill uses `cu_seqlens_q` / `cu_seqlens_kv` for variable-length packing. -// Decode uses `block_table` for paged KV cache lookup. 
class FlashAttention : public Operator { public: FlashAttention(const Tensor query, const Tensor key, const Tensor value, @@ -48,7 +40,7 @@ class FlashAttention : public Operator { has_cu_seqlens_kv_{cu_seqlens_kv.has_value()}, has_block_table_{block_table.has_value()} { assert(num_heads % num_kv_heads == 0 && - "`FlashAttention` requires num_heads divisible by num_kv_heads"); + "`FlashAttention` requires `num_heads` divisible by `num_kv_heads`"); assert(query.ndim() == 3 && "`FlashAttention` requires query to be 3D [T, N, D]"); } diff --git a/src/base/linear.h b/src/base/linear.h new file mode 100644 index 00000000..a5276e61 --- /dev/null +++ b/src/base/linear.h @@ -0,0 +1,65 @@ +#ifndef INFINI_OPS_BASE_LINEAR_H_ +#define INFINI_OPS_BASE_LINEAR_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +// Fused linear projection: out = a @ b (+ bias). +// +// When bias is present, computes out = a @ b + bias in a single dispatch. +// When bias is absent, computes out = a @ b (equivalent to Matmul). +// `trans_a` / `trans_b`: If true, transpose the last two dims before +// multiplying. 
+class Linear : public Operator { + public: + Linear(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + out_shape_{out.shape()}, + a_strides_{a.strides()}, + b_strides_{b.strides()}, + out_strides_{out.strides()}, + trans_a_{trans_a}, + trans_b_{trans_b}, + has_bias_{bias.has_value()} { + assert(a.dtype() == b.dtype() && + "operator `Linear` requires a and b to have the same dtype"); + assert(a.dtype() == out.dtype() && + "operator `Linear` requires a and out to have the same dtype"); + if (has_bias_) { + assert(bias->dtype() == out.dtype() && + "operator `Linear` requires bias and out to have the same dtype"); + } + } + + virtual void operator()(const Tensor a, const Tensor b, + std::optional bias, bool trans_a, + bool trans_b, Tensor out) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides a_strides_; + + Tensor::Strides b_strides_; + + Tensor::Strides out_strides_; + + bool trans_a_{false}; + + bool trans_b_{false}; + + bool has_bias_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/mat_mul.h b/src/base/mat_mul.h deleted file mode 100644 index 6180c8bf..00000000 --- a/src/base/mat_mul.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef INFINI_OPS_BASE_MAT_MUL_H_ -#define INFINI_OPS_BASE_MAT_MUL_H_ - -#include "operator.h" -#include "tensor.h" - -namespace infini::ops { - -class MatMul : public Operator { - public: - MatMul(const Tensor input, const Tensor other, Tensor out) - : input_shape_{input.shape()}, - other_shape_{other.shape()}, - out_shape_{out.shape()} { - assert(input.dtype() == other.dtype()); - } - - virtual void operator()(const Tensor input, const Tensor other, - Tensor out) const = 0; - - protected: - Tensor::Shape input_shape_; - - Tensor::Shape other_shape_; - - Tensor::Shape out_shape_; -}; - -} // namespace infini::ops - -#endif diff --git a/src/base/matmul.h 
b/src/base/matmul.h new file mode 100644 index 00000000..071feaea --- /dev/null +++ b/src/base/matmul.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_BASE_MATMUL_H_ +#define INFINI_OPS_BASE_MATMUL_H_ + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class Matmul : public Operator { + public: + // `trans_a` / `trans_b`: If true, transpose the last two dims of `a` / `b` + // before multiplying. These are constructor parameters so the `CacheKey` + // encodes the transposition and distinct descriptors are cached for each + // combination. + Matmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + c_shape_{c.shape()}, + trans_a_{trans_a}, + trans_b_{trans_b} { + assert(a.dtype() == b.dtype()); + } + + virtual void operator()(const Tensor a, const Tensor b, Tensor c, + bool trans_a, bool trans_b) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape c_shape_; + + bool trans_a_{false}; + + bool trans_b_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/mul.h b/src/base/mul.h new file mode 100644 index 00000000..9e7be223 --- /dev/null +++ b/src/base/mul.h @@ -0,0 +1,67 @@ +#ifndef INFINI_OPS_BASE_MUL_H_ +#define INFINI_OPS_BASE_MUL_H_ + +#include "operator.h" + +namespace infini::ops { + +class Mul : public Operator { + public: + Mul(const Tensor input, const Tensor other, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_type_{input.dtype()}, + other_type_{other.dtype()}, + out_type_{out.dtype()}, + input_shape_{input.shape()}, + other_shape_{other.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + other_strides_{other.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_other_contiguous_{other.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(!out.HasBroadcastDim() && + "the output of `Mul` should NOT have 
broadcasted dim!"); + assert(input_type_ == other_type_ && other_type_ == out_type_ && + "operator `Mul` requires all input and output tensors to have the " + "same dtype"); + } + + virtual void operator()(const Tensor input, const Tensor other, + Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_type_; + + const DataType other_type_; + + const DataType out_type_; + + Tensor::Shape input_shape_; + + Tensor::Shape other_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides other_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_other_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/paged_attention.h b/src/base/paged_attention.h new file mode 100644 index 00000000..ede40f4d --- /dev/null +++ b/src/base/paged_attention.h @@ -0,0 +1,105 @@ +#ifndef INFINI_OPS_BASE_PAGED_ATTENTION_H_ +#define INFINI_OPS_BASE_PAGED_ATTENTION_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +// Paged decode attention operator. +// +// Performs multi-head attention over paged KV caches for decode (single-token +// queries per sequence). +// +// Interface follows vLLM's paged attention convention: +// - vLLM CUDA: `torch.ops.vllm.paged_attention_v1` uses the same query +// shape [batch, num_heads, head_size] and seq_lens [batch] int32. +// KV cache differs (5D on CUDA for vectorization, 4D here). +// - vLLM-Ascend: `torch_npu._npu_paged_attention` wraps ATB +// `PagedAttentionParam` with default `inputLayout` (`TYPE_BSND`). +// - ATB `PagedAttentionParam`: `headNum`, `kvHeadNum`, `qkScale`, +// `maskType` (default NORM), `inputLayout` (default `TYPE_BSND`). 
+// +// Input layout (BSND with S=1 for decode): +// query : [batch, num_heads, head_size] +// key_cache : [num_blocks, block_size, num_kv_heads, head_size] +// value_cache : [num_blocks, block_size, num_kv_heads, head_size] +// seq_lens : [batch] int32 — total context length per sequence +// block_table : [batch, max_num_blocks_per_seq] int32 +// +// Output layout: +// output : [batch, num_heads, head_size] +class PagedAttention : public Operator { + public: + PagedAttention(const Tensor query, const Tensor key_cache, + const Tensor value_cache, const Tensor seq_lens, + const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output) + : batch_size_{query.size(0)}, + num_heads_{num_heads}, + num_kv_heads_{num_kv_heads}, + head_size_{head_size}, + scale_{scale}, + block_size_{block_size}, + dtype_{query.dtype()}, + query_shape_{query.shape()}, + key_cache_shape_{key_cache.shape()}, + value_cache_shape_{value_cache.shape()}, + seq_lens_shape_{seq_lens.shape()}, + block_table_shape_{block_table.shape()}, + output_shape_{output.shape()} { + assert(num_heads % num_kv_heads == 0 && + "`PagedAttention` requires `num_heads` divisible by `num_kv_heads`"); + assert(query.ndim() == 3 && + "`PagedAttention` requires query to be 3D [batch, num_heads, " + "head_size]"); + assert(key_cache.ndim() == 4 && + "`PagedAttention` requires key_cache to be 4D [num_blocks, " + "block_size, num_kv_heads, head_size]"); + assert(seq_lens.ndim() == 1 && + "`PagedAttention` requires seq_lens to be 1D [batch]"); + assert(block_table.ndim() == 2 && + "`PagedAttention` requires block_table to be 2D [batch, " + "max_num_blocks]"); + } + + virtual void operator()(const Tensor query, const Tensor key_cache, + const Tensor value_cache, const Tensor seq_lens, + const Tensor block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + int64_t block_size, Tensor output) const = 0; + + protected: + 
Tensor::Size batch_size_{0}; + + int64_t num_heads_{0}; + + int64_t num_kv_heads_{0}; + + int64_t head_size_{0}; + + double scale_{0.0}; + + int64_t block_size_{0}; + + const DataType dtype_; + + Tensor::Shape query_shape_; + + Tensor::Shape key_cache_shape_; + + Tensor::Shape value_cache_shape_; + + Tensor::Shape seq_lens_shape_; + + Tensor::Shape block_table_shape_; + + Tensor::Shape output_shape_; +}; + +} // namespace infini::ops + +#endif // INFINI_OPS_BASE_PAGED_ATTENTION_H_ diff --git a/src/base/reshape_and_cache.h b/src/base/reshape_and_cache.h index 4bbd5db8..5d0adfad 100644 --- a/src/base/reshape_and_cache.h +++ b/src/base/reshape_and_cache.h @@ -8,15 +8,6 @@ namespace infini::ops { -// Scatter `key` / `value` tokens into a paged KV cache. -// -// Interface follows vLLM's `reshape_and_cache` kernel: -// `vllm._custom_ops.reshape_and_cache_flash` -// -// `kv_cache` layout: `[2, num_blocks, block_size, num_kv_heads, head_size]`. -// `slot_mapping`: 1D `[num_tokens]`, each entry is the linear slot index -// into the cache. Padding tokens must be filtered by the caller (no -// negative indices). class ReshapeAndCache : public Operator { public: ReshapeAndCache(const Tensor key, const Tensor value, const Tensor kv_cache, diff --git a/src/base/rotary_embedding.h b/src/base/rotary_embedding.h index 10426ee8..3fc081c6 100644 --- a/src/base/rotary_embedding.h +++ b/src/base/rotary_embedding.h @@ -8,15 +8,6 @@ namespace infini::ops { -// Rotary position embedding (RoPE) applied in-place to Q and K. -// -// Interface follows vLLM's `RotaryEmbedding.forward_oot()`: -// `vllm.model_executor.layers.rotary_embedding.RotaryEmbedding` -// -// `positions`: `[T]` token position indices. -// `cos_sin_cache`: precomputed `[max_seq_len, rotary_dim]` table. -// `query` / `key`: `[T, N, D]` (TND layout), mutated in-place into -// `query_out` / `key_out`. 
class RotaryEmbedding : public Operator { public: RotaryEmbedding(const Tensor positions, const Tensor query, const Tensor key, @@ -43,7 +34,7 @@ class RotaryEmbedding : public Operator { assert(key.ndim() == 3 && "`RotaryEmbedding` requires key to be 3D [T, N_kv, D]"); assert(rotary_dim <= head_size && - "`RotaryEmbedding` requires rotary_dim <= head_size"); + "`RotaryEmbedding` requires `rotary_dim` <= `head_size`"); } virtual void operator()(const Tensor positions, const Tensor query, diff --git a/src/base/silu_and_mul.h b/src/base/silu_and_mul.h new file mode 100644 index 00000000..9258ace1 --- /dev/null +++ b/src/base/silu_and_mul.h @@ -0,0 +1,51 @@ +#ifndef INFINI_OPS_BASE_SILU_AND_MUL_H_ +#define INFINI_OPS_BASE_SILU_AND_MUL_H_ + +#include "operator.h" + +namespace infini::ops { + +class SiluAndMul : public Operator { + public: + SiluAndMul(const Tensor x, int64_t dim, Tensor out) + : x_shape_{x.shape()}, + x_strides_{x.strides()}, + out_shape_{out.shape()}, + out_strides_{out.strides()}, + x_dtype_{x.dtype()}, + out_dtype_{out.dtype()}, + dim_{dim}, + ndim_{x.ndim()}, + is_x_contiguous_{x.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(x_dtype_ == out_dtype_ && + "operator `SiluAndMul` requires x and out to have the same dtype"); + } + + virtual void operator()(const Tensor x, int64_t dim, Tensor out) const = 0; + + protected: + Tensor::Shape x_shape_; + + Tensor::Strides x_strides_; + + Tensor::Shape out_shape_; + + Tensor::Strides out_strides_; + + const DataType x_dtype_; + + const DataType out_dtype_; + + int64_t dim_; + + Tensor::Size ndim_; + + bool is_x_contiguous_; + + bool is_out_contiguous_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cast/cast.h b/src/cpu/cast/cast.h new file mode 100644 index 00000000..67c8367c --- /dev/null +++ b/src/cpu/cast/cast.h @@ -0,0 +1,57 @@ +#ifndef INFINI_OPS_CPU_CAST_CAST_H_ +#define INFINI_OPS_CPU_CAST_CAST_H_ + +#include "base/cast.h" +#include "common/generic_utils.h" 
+#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) : Cast{input, out} {} + + void operator()(const Tensor input, Tensor out) const override { + DispatchFunc( + input_dtype_, + [&](auto in_tag) { + using InT = typename decltype(in_tag)::type; + DispatchFunc( + out_dtype_, + [&](auto out_tag) { + using OutT = typename decltype(out_tag)::type; + Compute(input, out); + }, + "`Operator::operator()` (out)"); + }, + "`Operator::operator()` (in)"); + } + + private: + template + void Compute(const Tensor input, Tensor out) const { + const auto* in_ptr = static_cast(input.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, + const auto* strides) { + return is_contig ? i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto in_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = + Caster::template Cast(in_ptr[in_idx]); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cat/cat.h b/src/cpu/cat/cat.h new file mode 100644 index 00000000..18b45247 --- /dev/null +++ b/src/cpu/cat/cat.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_CPU_CAT_CAT_H_ +#define INFINI_OPS_CPU_CAT_CAT_H_ + +#include +#include + +#include "base/cat.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat{first_input, rest_inputs, dim, out} {} + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t /*dim*/, Tensor out) const override { + // Collect all input tensors. 
+ std::vector inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + // Use normalized `dim_` from base class (handles negative dim). + auto dim = dim_; + auto elem_size = kDataTypeToSize.at(out.dtype()); + auto ndim = out.ndim(); + auto out_shape = out.shape(); + + // Compute outer and inner sizes relative to the cat dimension. + Tensor::Size outer = 1; + for (int64_t i = 0; i < dim; ++i) { + outer *= out_shape[i]; + } + + Tensor::Size inner = 1; + for (size_t i = static_cast(dim) + 1; i < ndim; ++i) { + inner *= out_shape[i]; + } + + auto* out_ptr = static_cast(out.data()); + Tensor::Size out_dim_size = out_shape[dim]; + + // For each outer index, copy slices from each input along the cat dim. + for (Tensor::Size o = 0; o < outer; ++o) { + Tensor::Size offset_in_dim = 0; + + for (size_t t = 0; t < input_count_; ++t) { + auto in_dim = inputs[t]->shape()[dim]; + auto in_ptr = static_cast(inputs[t]->data()); + + auto src_offset = (o * in_dim) * inner * elem_size; + auto dst_offset = + (o * out_dim_size + offset_in_dim) * inner * elem_size; + auto copy_size = in_dim * inner * elem_size; + + std::memcpy(out_ptr + dst_offset, in_ptr + src_offset, copy_size); + offset_in_dim += in_dim; + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/linear/linear.h b/src/cpu/linear/linear.h new file mode 100644 index 00000000..f5323c2f --- /dev/null +++ b/src/cpu/linear/linear.h @@ -0,0 +1,108 @@ +#ifndef INFINI_OPS_CPU_LINEAR_LINEAR_H_ +#define INFINI_OPS_CPU_LINEAR_LINEAR_H_ + +#include + +#include "base/linear.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Linear, + Caster { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear{a, b, bias, trans_a, trans_b, out} {} + + void operator()(const Tensor a, const Tensor 
b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(a, b, bias, trans_a, trans_b, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const { + const auto* A = static_cast(a.data()); + const auto* B = static_cast(b.data()); + auto* Out = static_cast(out.data()); + const T* Bias = bias ? static_cast(bias->data()) : nullptr; + + // Determine M, K, N from shapes and transpose flags. + auto ndim_a = a_shape_.size(); + auto ndim_b = b_shape_.size(); + auto ndim_out = out_shape_.size(); + + Tensor::Size M = out_shape_[ndim_out - 2]; + Tensor::Size N = out_shape_[ndim_out - 1]; + Tensor::Size K = trans_a ? a_shape_[ndim_a - 2] : a_shape_[ndim_a - 1]; + + // Compute strides for the inner matrix dimensions after transpose. + Tensor::Stride stride_a_m = + trans_a ? a_strides_[ndim_a - 1] : a_strides_[ndim_a - 2]; + Tensor::Stride stride_a_k = + trans_a ? a_strides_[ndim_a - 2] : a_strides_[ndim_a - 1]; + Tensor::Stride stride_b_k = + trans_b ? b_strides_[ndim_b - 1] : b_strides_[ndim_b - 2]; + Tensor::Stride stride_b_n = + trans_b ? b_strides_[ndim_b - 2] : b_strides_[ndim_b - 1]; + Tensor::Stride stride_out_m = out_strides_[ndim_out - 2]; + Tensor::Stride stride_out_n = out_strides_[ndim_out - 1]; + + // Batch dimensions. + Tensor::Size batch_count = 1; + for (size_t i = 0; i + 2 < ndim_out; ++i) { + batch_count *= out_shape_[i]; + } + + Tensor::Stride batch_stride_a = ndim_a > 2 ? a_strides_[ndim_a - 3] : 0; + Tensor::Stride batch_stride_b = ndim_b > 2 ? b_strides_[ndim_b - 3] : 0; + Tensor::Stride batch_stride_out = + ndim_out > 2 ? out_strides_[ndim_out - 3] : 0; + + // Bias stride: for 1D bias [N], stride is 1. For batched bias, use last + // stride. 
+ Tensor::Stride bias_stride = 0; + if (Bias && bias) { + auto ndim_bias = bias->shape().size(); + bias_stride = bias->strides()[ndim_bias - 1]; + } + + for (Tensor::Size batch = 0; batch < batch_count; ++batch) { + const auto* A_batch = A + batch * batch_stride_a; + const auto* B_batch = B + batch * batch_stride_b; + auto* Out_batch = Out + batch * batch_stride_out; + + for (Tensor::Size i = 0; i < M; ++i) { + for (Tensor::Size j = 0; j < N; ++j) { + float sum = 0.0f; + + for (Tensor::Size l = 0; l < K; ++l) { + float a_val = Cast(A_batch[i * stride_a_m + l * stride_a_k]); + float b_val = Cast(B_batch[l * stride_b_k + j * stride_b_n]); + sum += a_val * b_val; + } + + if (Bias) { + sum += Cast(Bias[j * bias_stride]); + } + + Out_batch[i * stride_out_m + j * stride_out_n] = Cast(sum); + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/mul/mul.h b/src/cpu/mul/mul.h new file mode 100644 index 00000000..0bdefb96 --- /dev/null +++ b/src/cpu/mul/mul.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_CPU_MUL_MUL_H_ +#define INFINI_OPS_CPU_MUL_MUL_H_ + +#include + +#include "base/mul.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Mul, + Caster { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul{input, other, out} {} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc( + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(input, other, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor input, const Tensor other, Tensor out) const { + using ComputeType = std::conditional_t || + IsFP16, + float, T>; + + const auto* input_ptr = static_cast(input.data()); + const auto* other_ptr = static_cast(other.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, 
+ const auto* strides) { + return is_contig ? i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto other_idx = get_idx(i, is_other_contiguous_, other_shape_.data(), + other_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = Cast(Cast(input_ptr[input_idx]) * + Cast(other_ptr[other_idx])); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/hash.h b/src/hash.h index efb34f75..4721f33f 100644 --- a/src/hash.h +++ b/src/hash.h @@ -2,6 +2,7 @@ #define INFINI_OPS_HASH_H_ #include +#include template inline void HashCombine(std::size_t& seed, const T& v) { @@ -9,4 +10,12 @@ inline void HashCombine(std::size_t& seed, const T& v) { seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +template +inline void HashCombine(std::size_t& seed, const std::vector& v) { + HashCombine(seed, v.size()); + for (const auto& elem : v) { + HashCombine(seed, elem); + } +} + #endif diff --git a/src/operator.h b/src/operator.h index dbe92d7d..93d55822 100644 --- a/src/operator.h +++ b/src/operator.h @@ -37,6 +37,14 @@ struct CacheKey { tensors.push_back(t); } + void Absorb(const std::vector& ts) { + HashCombine(hash, ts.size()); + for (const auto& t : ts) { + HashCombine(hash, t); + tensors.push_back(t); + } + } + template void Absorb(const T& v) { HashCombine(hash, v); @@ -179,7 +187,7 @@ class Operator : public OperatorBase { if (it == cache.end()) { // Pass args as lvalue refs so they remain valid for the `operator()` call // below. Forwarding rvalue temporaries into `make()` would leave the args - // in a moved-from (empty) state before operator() can use them. + // in a moved-from (empty) state before `operator()` can use them. 
it = cache.emplace(std::move(key), make(config, args...)).first; } diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 766b6eab..acbb52b4 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -118,10 +118,20 @@ inline Tensor TensorFromPybind11Handle(py::handle obj) { inline std::optional OptionalTensorFromPybind11Handle( const std::optional& obj) { - if (!obj.has_value()) return std::nullopt; + if (!obj.has_value() || obj->is_none()) return std::nullopt; return TensorFromPybind11Handle(*obj); } +inline std::vector VectorTensorFromPybind11Handle( + const std::vector& objs) { + std::vector result; + result.reserve(objs.size()); + for (const auto& obj : objs) { + result.push_back(TensorFromPybind11Handle(obj)); + } + return result; +} + } // namespace infini::ops #endif diff --git a/tests/conftest.py b/tests/conftest.py index 8a72355e..905e011a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ def pytest_addoption(parser): "--devices", nargs="+", default=None, - help="Device(s) to test on (e.g., `--devices ascend cpu`). Accepts platform names (`nvidia`, `metax`, `iluvatar`, `moore`, `cambricon`, `ascend`) or PyTorch device types (`cuda`, `mlu`, `musa`, `npu`). Defaults to all available devices.", + help="Device(s) to test on (e.g., --devices ascend cpu). Accepts platform names (ascend, nvidia, cambricon, metax, moore, iluvatar) or PyTorch device types (npu, cuda, mlu, musa). Defaults to all available devices.", ) @@ -46,8 +46,7 @@ def set_seed_per_test(request): _NPU_UNSUPPORTED_DTYPES = {torch.float64} -# `torch_npu` does not implement random number generation for -# `uint16`/`uint32`/`uint64`. +# `torch_npu` does not implement random number generation for `uint16`/`uint32`/`uint64`. 
for _bits in (16, 32, 64): _t = getattr(torch, f"uint{_bits}", None) if _t is not None: @@ -55,7 +54,7 @@ def set_seed_per_test(request): @pytest.fixture(autouse=True) -def skip_unsupported_dtypes(request): +def skip_unsupported_dtype(request): if not hasattr(request.node, "callspec"): return @@ -72,16 +71,16 @@ def _set_random_seed(seed): _PLATFORM_TO_TORCH_DEVICE = { "nvidia": "cuda", - "metax": "cuda", "iluvatar": "cuda", - "moore": "musa", + "metax": "cuda", "cambricon": "mlu", + "moore": "musa", "ascend": "npu", } def _resolve_device(name): - """Map a platform name (e.g., `ascend`) to a PyTorch device type (e.g., `npu`).""" + """Map a platform name (e.g., ``ascend``) to a PyTorch device type (e.g., ``npu``).""" return _PLATFORM_TO_TORCH_DEVICE.get(name, name) diff --git a/tests/test_add.py b/tests/test_add.py index 8b8166c3..f5604355 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -2,7 +2,13 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randint_strided, randn_strided +from tests.utils import ( + Payload, + empty_strided, + get_npu_stream, + randint_strided, + randn_strided, +) _INT_DTYPES = (torch.int16, torch.int32, torch.int64) @@ -63,7 +69,10 @@ def test_add( def _add(input, other, out): - infini.ops.add(input, other, out) + if input.device.type == "npu": + infini.ops.add(input, other, out, stream=get_npu_stream(input)) + else: + infini.ops.add(input, other, out) return out diff --git a/tests/test_add_rms_norm.py b/tests/test_add_rms_norm.py new file mode 100644 index 00000000..d047641f --- /dev/null +++ b/tests/test_add_rms_norm.py @@ -0,0 +1,107 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, strides", + ( + ((1, 64), None), + ((2, 128), None), + ((4, 48, 64), None), + ((2, 4, 2048), None), + ((1, 64), (64, 1)), + ((4, 48, 64), (3072, 64, 1)), + ), +) 
+@pytest.mark.parametrize("eps", (1e-6, 1e-5)) +@pytest.mark.parametrize("implementation_index", (0, 1)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-4, 1e-4), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 2e-2, 1e-2), + ), +) +def test_add_rms_norm( + shape, + strides, + eps, + implementation_index, + dtype, + device, + rtol, + atol, +): + active_indices = infini.ops.AddRmsNorm.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + weight_shape = (shape[-1],) + x1 = randn_strided(shape, strides, dtype=dtype, device=device) + x2 = randn_strided(shape, strides, dtype=dtype, device=device) + gamma = randn_strided(weight_shape, None, dtype=dtype, device=device) + y_out = empty_strided(shape, strides, dtype=dtype, device=device) + x_out = empty_strided(shape, strides, dtype=dtype, device=device) + + return Payload( + lambda *args, **kwargs: _add_rms_norm( + *args, **kwargs, implementation_index=implementation_index + ), + _torch_add_rms_norm, + (x1, x2, gamma), + {"eps": eps, "y_out": y_out, "x_out": x_out}, + rtol=rtol, + atol=atol, + ) + + +def _add_rms_norm( + x1, x2, gamma, *, eps=1e-6, y_out=None, x_out=None, implementation_index=0 +): + if x1.device.type == "npu": + infini.ops.add_rms_norm( + x1, + x2, + gamma, + eps, + y_out, + x_out, + implementation_index=implementation_index, + stream=get_npu_stream(x1), + ) + else: + infini.ops.add_rms_norm( + x1, + x2, + gamma, + eps, + y_out, + x_out, + implementation_index=implementation_index, + ) + + # Concatenate both outputs into a single flat tensor for `allclose` comparison. 
+ return torch.cat([y_out.contiguous().flatten(), x_out.contiguous().flatten()]) + + +def _torch_add_rms_norm(x1, x2, gamma, *, eps=1e-6, y_out=None, x_out=None): + x_sum = x1 + x2 + + if x_out is not None: + x_out.copy_(x_sum) + + rms = torch.sqrt( + torch.mean(x_sum.float() * x_sum.float(), dim=-1, keepdim=True) + eps + ) + y = (x_sum.float() / rms * gamma.float()).to(x1.dtype) + + if y_out is not None: + y_out.copy_(y) + + return torch.cat([y_out.contiguous().flatten(), x_out.contiguous().flatten()]) diff --git a/tests/test_cast.py b/tests/test_cast.py new file mode 100644 index 00000000..24b50ee9 --- /dev/null +++ b/tests/test_cast.py @@ -0,0 +1,65 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, out_strides", + ( + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4, 4), None, None), + ((16, 5632), None, None), + ((4, 4, 5632), None, None), + ), +) +@pytest.mark.parametrize( + ("input_dtype", "out_dtype", "rtol", "atol"), + ( + (torch.float16, torch.float32, 1e-3, 1e-3), + (torch.float32, torch.float16, 1e-3, 1e-3), + (torch.bfloat16, torch.float32, 1e-2, 5e-3), + (torch.float32, torch.bfloat16, 1e-2, 5e-3), + (torch.float16, torch.bfloat16, 1e-2, 5e-3), + (torch.bfloat16, torch.float16, 1e-2, 5e-3), + ), +) +def test_cast( + shape, + input_strides, + out_strides, + input_dtype, + out_dtype, + device, + rtol, + atol, +): + input = randn_strided(shape, input_strides, dtype=input_dtype, device=device) + out = empty_strided(shape, out_strides, dtype=out_dtype, device=device) + + return Payload( + _cast, + _torch_cast, + (input, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _cast(input, out): + if input.device.type == "npu": + infini.ops.cast(input, out, stream=get_npu_stream(input)) + else: + infini.ops.cast(input, out) + + return out + + +def _torch_cast(input, out): + 
out.copy_(input.to(out.dtype)) + + return out diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 00000000..9bbb398c --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,72 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shapes, dim, out_shape", + ( + # 2 inputs, dim=0 + (((4, 64), (4, 64)), 0, (8, 64)), + # 2 inputs, dim=1 + (((4, 32), (4, 64)), 1, (4, 96)), + # 2 inputs, dim=-1 (negative dim) + (((4, 32), (4, 64)), -1, (4, 96)), + # 3 inputs, dim=1 + (((4, 16), (4, 32), (4, 16)), 1, (4, 64)), + # 2 inputs, dim=0, 3D + (((2, 4, 64), (2, 4, 64)), 0, (4, 4, 64)), + # 2 inputs, dim=2, 3D + (((2, 4, 32), (2, 4, 64)), 2, (2, 4, 96)), + # 4 inputs, dim=1 + (((1, 1024), (1, 1024), (1, 1024), (1, 1024)), 1, (1, 4096)), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_cat(shapes, dim, out_shape, dtype, device, rtol, atol): + inputs = [randn_strided(s, None, dtype=dtype, device=device) for s in shapes] + out = empty_strided(out_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _cat(*args, dim=dim), + lambda *args: _torch_cat(*args, dim=dim), + (*inputs, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _cat(*args, dim): + inputs = list(args[:-1]) + out = args[-1] + + first = inputs[0] + rest = inputs[1:] + + if first.device.type == "npu": + infini.ops.cat(first, rest, dim, out, stream=get_npu_stream(first)) + else: + infini.ops.cat(first, rest, dim, out) + + return out + + +def _torch_cat(*args, dim): + inputs = list(args[:-1]) + out = args[-1] + + result = torch.cat(inputs, dim=dim) + out.copy_(result) + + return out diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index 8b35457a..df4894c3 100644 --- 
a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -40,7 +40,10 @@ def test_causal_softmax(shape, input_strides, out_strides, dtype, device, rtol, def _causal_softmax(input, out): - infini.ops.causal_softmax(input, out) + if input.device.type == "npu": + infini.ops.causal_softmax(input, out, stream=get_npu_stream(input)) + else: + infini.ops.causal_softmax(input, out) return out @@ -48,7 +51,7 @@ def _causal_softmax(input, out): def _torch_causal_softmax(input, out): mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) - result = torch.nn.functional.softmax(masked, dim=-1, dtype=input.dtype) + result = torch.nn.functional.softmax(masked, dim=-1) out.copy_(result) return out diff --git a/tests/test_e2e_layer.py b/tests/test_e2e_layer.py new file mode 100644 index 00000000..92df9a2c --- /dev/null +++ b/tests/test_e2e_layer.py @@ -0,0 +1,418 @@ +import infini.ops +import pytest +import torch + +from tests.utils import get_npu_stream, randn_strided, randint_strided + + +def _stream_kw(tensor): + if tensor.device.type == "npu": + return {"stream": get_npu_stream(tensor)} + + return {} + + +def _ref_rms_norm(x, weight, eps): + rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps) + + return (x / rms) * weight + + +def _ref_rope( + positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style +): + T = query.size(0) + R = rotary_dim + half_R = R // 2 + cos_half = cos_sin_cache[:, :half_R] + sin_half = cos_sin_cache[:, half_R:] + + def apply_rope(x): + out = x.clone() + + for t in range(T): + p = positions[t].item() + c = cos_half[p] + s = sin_half[p] + + if is_neox_style: + x1 = x[t, :, :half_R] + x2 = x[t, :, half_R:R] + out[t, 
:, :half_R] = c * x1 - s * x2 + out[t, :, half_R:R] = c * x2 + s * x1 + else: + x1 = x[t, :, 0::2] + x2 = x[t, :, 1::2] + out[t, :, 0::2] = c * x1 - s * x2 + out[t, :, 1::2] = c * x2 + s * x1 + + return out + + return apply_rope(query), apply_rope(key) + + +def _ref_sdpa(query, key, value, num_heads, num_kv_heads, head_size, scale, causal): + q = query.transpose(0, 1).float() + k = key.transpose(0, 1).float() + v = value.transpose(0, 1).float() + + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k = k.repeat_interleave(ratio, dim=0) + v = v.repeat_interleave(ratio, dim=0) + + out = torch.nn.functional.scaled_dot_product_attention( + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + scale=scale, + is_causal=causal, + ) + + return out.squeeze(0).transpose(0, 1) + + +def _infiniops_layer( + hidden, + positions, + cos_sin_cache, + input_norm_w, + qkv_proj_w, + o_proj_w, + gate_proj_w, + up_proj_w, + down_proj_w, + post_norm_w, + num_heads, + num_kv_heads, + head_size, + rotary_dim, + intermediate_size, + is_neox_style, + eps, + scale, + num_tokens, +): + """Run one LLaMA decoder layer using InfiniOps kernels.""" + kw = _stream_kw(hidden) + dtype = hidden.dtype + device = hidden.device + hidden_size = hidden.size(-1) + + # Save residual. + residual = hidden.clone() + + # 1. Input RMSNorm. + normed = torch.empty_like(hidden) + infini.ops.rms_norm(hidden, input_norm_w, eps, normed, **kw) + + # 2. QKV projection: [T, D] @ [D, (N+2*Nkv)*H] -> [T, (N+2*Nkv)*H]. + qkv_dim = (num_heads + 2 * num_kv_heads) * head_size + qkv = torch.empty(num_tokens, qkv_dim, dtype=dtype, device=device) + infini.ops.gemm(normed, qkv_proj_w, 1.0, 0.0, False, False, qkv, **kw) + + # Split Q, K, V. 
+ q = ( + qkv[:, : num_heads * head_size] + .reshape( + num_tokens, + num_heads, + head_size, + ) + .contiguous() + ) + k = ( + qkv[:, num_heads * head_size : (num_heads + num_kv_heads) * head_size] + .reshape( + num_tokens, + num_kv_heads, + head_size, + ) + .contiguous() + ) + v = ( + qkv[:, (num_heads + num_kv_heads) * head_size :] + .reshape( + num_tokens, + num_kv_heads, + head_size, + ) + .contiguous() + ) + + # 3. RoPE. + q_rot = torch.empty_like(q) + k_rot = torch.empty_like(k) + infini.ops.rotary_embedding( + positions, + q, + k, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + q_rot, + k_rot, + **kw, + ) + + # 4. Flash attention (single-sequence prefill, causal). + attn_out = torch.empty( + num_tokens, + num_heads, + head_size, + dtype=dtype, + device=device, + ) + infini.ops.flash_attention( + q_rot, + k_rot, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + attn_out, + **kw, + ) + + # 5. O projection: [T, N*H] @ [N*H, D] -> [T, D]. + attn_2d = attn_out.reshape(num_tokens, num_heads * head_size) + o_out = torch.empty(num_tokens, hidden_size, dtype=dtype, device=device) + infini.ops.gemm(attn_2d, o_proj_w, 1.0, 0.0, False, False, o_out, **kw) + + # 6. Residual add. + after_attn = torch.empty_like(residual) + infini.ops.add(residual, o_out, after_attn, **kw) + + # 7. Post-attention RMSNorm. + residual2 = after_attn.clone() + normed2 = torch.empty_like(after_attn) + infini.ops.rms_norm(after_attn, post_norm_w, eps, normed2, **kw) + + # 8. Gate + up projections. + gate = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + up = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + infini.ops.gemm(normed2, gate_proj_w, 1.0, 0.0, False, False, gate, **kw) + infini.ops.gemm(normed2, up_proj_w, 1.0, 0.0, False, False, up, **kw) + + # 9. SwiGLU: ``up * silu(gate)``. 
+ ffn = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + infini.ops.swiglu(up, gate, ffn, **kw) + + # 10. Down projection: [T, FFN] @ [FFN, D] -> [T, D]. + down = torch.empty(num_tokens, hidden_size, dtype=dtype, device=device) + infini.ops.gemm(ffn, down_proj_w, 1.0, 0.0, False, False, down, **kw) + + # 11. Second residual add. + output = torch.empty_like(residual2) + infini.ops.add(residual2, down, output, **kw) + + return output + + +def _reference_layer( + hidden, + positions, + cos_sin_cache, + input_norm_w, + qkv_proj_w, + o_proj_w, + gate_proj_w, + up_proj_w, + down_proj_w, + post_norm_w, + num_heads, + num_kv_heads, + head_size, + rotary_dim, + intermediate_size, + is_neox_style, + eps, + scale, + num_tokens, +): + """PyTorch float32 reference for one LLaMA decoder layer.""" + # Compute in float32 on CPU for accuracy. + h = hidden.float().cpu() + pos = positions.cpu() + csc = cos_sin_cache.float().cpu() + inw = input_norm_w.float().cpu() + qkvw = qkv_proj_w.float().cpu() + ow = o_proj_w.float().cpu() + gw = gate_proj_w.float().cpu() + uw = up_proj_w.float().cpu() + dw = down_proj_w.float().cpu() + pnw = post_norm_w.float().cpu() + + # 1. Input RMSNorm. + residual = h.clone() + normed = _ref_rms_norm(h, inw, eps) + + # 2. QKV projection. + qkv = normed @ qkvw + + q = qkv[:, : num_heads * head_size].reshape(num_tokens, num_heads, head_size) + k = qkv[:, num_heads * head_size : (num_heads + num_kv_heads) * head_size].reshape( + num_tokens, + num_kv_heads, + head_size, + ) + v = qkv[:, (num_heads + num_kv_heads) * head_size :].reshape( + num_tokens, + num_kv_heads, + head_size, + ) + + # 3. RoPE. + q_rot, k_rot = _ref_rope( + pos, + q, + k, + csc, + head_size, + rotary_dim, + is_neox_style, + ) + + # 4. SDPA. + attn_out = _ref_sdpa( + q_rot, k_rot, v, num_heads, num_kv_heads, head_size, scale, causal=True + ) + + # 5. O projection. + attn_2d = attn_out.reshape(num_tokens, num_heads * head_size) + o_out = attn_2d @ ow + + # 6. 
Residual add. + after_attn = residual + o_out + + # 7. Post-attention RMSNorm. + residual2 = after_attn.clone() + normed2 = _ref_rms_norm(after_attn, pnw, eps) + + # 8. Gate + up projections. + gate = normed2 @ gw + up = normed2 @ uw + + # 9. SwiGLU: ``up * silu(gate)``. + ffn = up * (gate * torch.sigmoid(gate)) + + # 10. Down projection. + down = ffn @ dw + + # 11. Second residual add. + output = residual2 + down + + return output.to(hidden.dtype).to(hidden.device) + + +def _make_rope_cache(max_seq_len, rotary_dim, dtype, device): + """Build a proper RoPE cos/sin cache (bounded to [-1, 1]).""" + freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) + t = torch.arange(max_seq_len, dtype=torch.float32) + angles = torch.outer(t, freq) # [max_seq_len, half_dim] + cos_half = torch.cos(angles).to(dtype=dtype, device=device) + sin_half = torch.sin(angles).to(dtype=dtype, device=device) + + return torch.cat([cos_half, sin_half], dim=-1) + + +@pytest.mark.parametrize("device", ("npu",)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 5e-3, 5e-3), + (torch.bfloat16, 1e-2, 2e-2), + ), +) +def test_llama_layer(device, dtype, rtol, atol): + """End-to-end test of a LLaMA decoder layer using InfiniOps kernels.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + # Small LLaMA-like model config. + hidden_size = 512 + num_heads = 8 + num_kv_heads = 2 + head_size = hidden_size // num_heads + intermediate_size = 1024 + num_tokens = 1 + max_seq_len = 16 + rotary_dim = head_size + is_neox_style = True + eps = 1e-6 + scale = 1.0 / head_size**0.5 + + def _scaled_weight(*shape): + return randn_strided(shape, None, dtype=dtype, device=device) / shape[0] ** 0.5 + + # Random weights (stored as [in_features, out_features], Xavier-scaled). 
+ qkv_proj_w = _scaled_weight( + hidden_size, + (num_heads + 2 * num_kv_heads) * head_size, + ) + o_proj_w = _scaled_weight(num_heads * head_size, hidden_size) + gate_proj_w = _scaled_weight(hidden_size, intermediate_size) + up_proj_w = _scaled_weight(hidden_size, intermediate_size) + down_proj_w = _scaled_weight(intermediate_size, hidden_size) + input_norm_w = torch.ones(hidden_size, dtype=dtype, device=device) + post_norm_w = torch.ones(hidden_size, dtype=dtype, device=device) + + # Proper cos/sin cache from frequency decomposition (bounded [-1, 1]). + cos_sin_cache = _make_rope_cache(max_seq_len, rotary_dim, dtype, device) + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + + # Input hidden states scaled to prevent value explosion through layers. + hidden = ( + randn_strided( + (num_tokens, hidden_size), + None, + dtype=dtype, + device=device, + ) + / hidden_size**0.5 + ) + + common = dict( + positions=positions, + cos_sin_cache=cos_sin_cache, + input_norm_w=input_norm_w, + qkv_proj_w=qkv_proj_w, + o_proj_w=o_proj_w, + gate_proj_w=gate_proj_w, + up_proj_w=up_proj_w, + down_proj_w=down_proj_w, + post_norm_w=post_norm_w, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + rotary_dim=rotary_dim, + intermediate_size=intermediate_size, + is_neox_style=is_neox_style, + eps=eps, + scale=scale, + num_tokens=num_tokens, + ) + + infini_out = _infiniops_layer(hidden, **common) + ref_out = _reference_layer(hidden, **common) + + max_diff = (infini_out.float() - ref_out.float()).abs().max().item() + assert torch.allclose(infini_out, ref_out, rtol=rtol, atol=atol), ( + f"Max diff: {max_diff}" + ) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py new file mode 100644 index 00000000..b016020b --- /dev/null +++ b/tests/test_flash_attention.py @@ -0,0 +1,444 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_npu_stream, 
randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ( + (32, 32, 128), # MHA + (32, 8, 128), # GQA (4x) + (16, 4, 64), # GQA (4x), smaller + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_single( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Single sequence prefill (no block table).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_tokens = 16 + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention( + q, + k, + v, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ((32, 8, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_multi( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Multi-sequence prefill with cu_seqlens.""" + if device == 
"npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + seq_lens = [8, 12, 4] + num_tokens = sum(seq_lens) + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + cu_seqlens_q = torch.tensor( + [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))], + dtype=torch.int64, + device=device, + ) + cu_seqlens_kv = cu_seqlens_q.clone() + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention_multi( + q, + k, + v, + seq_lens, + seq_lens, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ( + (32, 8, 128, 128), + (16, 4, 64, 128), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_decode( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Decode phase: single token per request with paged KV cache.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 3 + kv_len = 16 # Total KV length per request. 
+ num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + # Paged KV cache: vLLM standard layout [num_blocks, block_size, KV_N, D]. + kv_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: request i uses blocks [i*num_blocks_per_req, ...]. + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + cu_seqlens_q = torch.arange(0, num_reqs + 1, dtype=torch.int64, device=device) + cu_seqlens_kv = torch.tensor( + [i * kv_len for i in range(num_reqs + 1)], dtype=torch.int64, device=device + ) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + block_size, + o, + ), + lambda q, k, v, o: _ref_flash_attention_paged( + q, + k, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, + ), + (query, kv_cache, kv_cache, output), + {}, + rtol=rtol, + atol=atol, + ) + + +def _flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, +): + if query.device.type == "npu": + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, + stream=get_npu_stream(query), + ) + 
else: + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, + ) + + return output + + +def _ref_flash_attention( + query, key, value, num_heads, num_kv_heads, head_size, scale, causal=True +): + """PyTorch SDPA reference for single-sequence prefill.""" + # [T, N, D] -> [N, T, D] + q = query.transpose(0, 1).float() + k = key.transpose(0, 1).float() + v = value.transpose(0, 1).float() + + # GQA: expand K/V to match num_heads. + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k = k.repeat_interleave(ratio, dim=0) + v = v.repeat_interleave(ratio, dim=0) + + # [N, T, D] -> [1, N, T, D] for scaled_dot_product_attention. + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=scale, is_causal=causal + ) + + # [1, N, T, D] -> [T, N, D] -> original dtype. 
+    return out.squeeze(0).transpose(0, 1).to(query.dtype)
+
+
+def _ref_flash_attention_multi(
+    query,
+    key,
+    value,
+    seq_lens_q,
+    seq_lens_kv,
+    num_heads,
+    num_kv_heads,
+    head_size,
+    scale,
+    causal=True,
+):
+    """PyTorch SDPA reference for multi-sequence prefill."""
+    outputs = []
+    offset = 0
+    for sq, sk in zip(seq_lens_q, seq_lens_kv):
+        q = query[offset : offset + sq]
+        k = key[offset : offset + sq]
+        v = value[offset : offset + sq]
+        out = _ref_flash_attention(
+            q, k, v, num_heads, num_kv_heads, head_size, scale, causal
+        )
+        outputs.append(out)
+        offset += sq
+
+    return torch.cat(outputs, dim=0)
+
+
+def _ref_flash_attention_paged(
+    query,
+    kv_cache_arg,
+    block_table,
+    cu_seqlens_q,
+    cu_seqlens_kv,
+    num_heads,
+    num_kv_heads,
+    head_size,
+    block_size,
+    scale,
+    causal=True,
+):
+    """PyTorch SDPA reference for decode with paged KV cache."""
+    cu_kv = cu_seqlens_kv.cpu()
+    bt = block_table.cpu()
+    cache = kv_cache_arg.cpu()
+    q_cpu = query.cpu()
+    num_reqs = bt.size(0)
+    outputs = []
+
+    for i in range(num_reqs):
+        q = q_cpu[i : i + 1]  # [1, N, D]
+        kv_len = int(cu_kv[i + 1] - cu_kv[i])
+
+        # Gather KV from paged cache.
+        # cache: [num_blocks, block_size, KV_N, D]
+        blocks = bt[i]
+        k_pages = []
+        v_pages = []
+        remaining = kv_len
+
+        for b in blocks:
+            if remaining <= 0:
+                break
+
+            take = min(remaining, block_size)
+            # cache layout: [num_blocks, block_size, KV_N, D]
+            # Slice [take, KV_N, D], transpose to [KV_N, take, D] for cat.
+            k_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1))
+            v_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1))
+            remaining -= take
+        k = torch.cat(k_pages, dim=1)  # [KV_N, kv_len, D]
+        v = torch.cat(v_pages, dim=1)
+
+        # Decode: Q_S=1 attends to all past KV positions; causal masking is
+        # not applicable here (it would mask everything beyond position 0).
+ out = _ref_flash_attention( + q, # [1, N, D] - already TND format + k.transpose(0, 1), # [KV_N, kv_len, D] -> [kv_len, KV_N, D] + v.transpose(0, 1), + num_heads, + num_kv_heads, + head_size, + scale, + causal=False, + ) + outputs.append(out) + + return torch.cat(outputs, dim=0).to(query.device) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index d3b26884..3f48562f 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, get_stream, randn_strided +from tests.utils import Payload, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -84,17 +84,28 @@ def test_gemm( def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): - infini.ops.gemm( - a, - b, - alpha, - beta, - trans_a, - trans_b, - c, - stream=get_stream(a.device), - implementation_index=implementation_index, - ) + if a.device.type == "npu": + infini.ops.gemm( + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, + stream=get_npu_stream(a), + ) + else: + infini.ops.gemm( + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, + implementation_index=implementation_index, + ) return c diff --git a/tests/test_linear.py b/tests/test_linear.py new file mode 100644 index 00000000..d08bf20e --- /dev/null +++ b/tests/test_linear.py @@ -0,0 +1,93 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "a_shape, b_shape, out_shape", + ( + ((4, 64), (64, 32), (4, 32)), + ((2, 128), (128, 256), (2, 256)), + ((1, 4096), (4096, 4096), (1, 4096)), + ((2, 4, 64), (2, 64, 32), (2, 4, 32)), + ((4, 8, 128), (4, 128, 64), (4, 8, 64)), + ), +) +@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("trans_b", (False, True)) +@pytest.mark.parametrize("has_bias", (False, True)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 
1e-2, 5e-2), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_linear( + a_shape, + b_shape, + out_shape, + trans_a, + trans_b, + has_bias, + dtype, + device, + rtol, + atol, +): + a = randn_strided(a_shape, None, dtype=dtype, device=device) + b = randn_strided(b_shape, None, dtype=dtype, device=device) + + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + # Bias shape is [N], the last dim of the output. + bias = None + + if has_bias: + N = out_shape[-1] + bias = randn_strided((N,), None, dtype=dtype, device=device) + + out = empty_strided(out_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _linear(*args, trans_a=trans_a, trans_b=trans_b), + lambda *args: _torch_linear(*args, trans_a=trans_a, trans_b=trans_b), + (a, b, bias, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _linear(a, b, bias, out, trans_a=False, trans_b=False): + if a.device.type == "npu": + infini.ops.linear(a, b, bias, trans_a, trans_b, out, stream=get_npu_stream(a)) + else: + infini.ops.linear(a, b, bias, trans_a, trans_b, out) + + return out + + +def _torch_linear(a, b, bias, out, trans_a=False, trans_b=False): + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + result = torch.matmul(a.float(), b.float()) + + if bias is not None: + result = result + bias.float() + + out.copy_(result.to(out.dtype)) + + return out diff --git a/tests/test_matmul.py b/tests/test_matmul.py new file mode 100644 index 00000000..dae3961b --- /dev/null +++ b/tests/test_matmul.py @@ -0,0 +1,79 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "a_shape, b_shape, c_shape", + ( + ((4, 64), (64, 32), (4, 32)), + ((2, 128), (128, 256), (2, 256)), + ((2, 4, 64), (2, 64, 32), (2, 4, 32)), + ((4, 8, 128), (4, 128, 64), (4, 8, 64)), + ), +) 
+@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("trans_b", (False, True)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-2, 1e-2), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_matmul( + a_shape, + b_shape, + c_shape, + trans_a, + trans_b, + dtype, + device, + rtol, + atol, +): + a = randn_strided(a_shape, None, dtype=dtype, device=device) + b = randn_strided(b_shape, None, dtype=dtype, device=device) + + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + c = empty_strided(c_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _matmul(*args, trans_a=trans_a, trans_b=trans_b), + lambda *args: _torch_matmul(*args, trans_a=trans_a, trans_b=trans_b), + (a, b, c), + {}, + rtol=rtol, + atol=atol, + ) + + +def _matmul(a, b, c, trans_a=False, trans_b=False): + if a.device.type == "npu": + infini.ops.matmul(a, b, c, trans_a, trans_b, stream=get_npu_stream(a)) + else: + infini.ops.matmul(a, b, c, trans_a, trans_b) + + return c + + +def _torch_matmul(a, b, c, trans_a=False, trans_b=False): + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + result = torch.matmul(a.float(), b.float()).to(c.dtype) + c.copy_(result) + + return c diff --git a/tests/test_mul.py b/tests/test_mul.py new file mode 100644 index 00000000..ea7f9180 --- /dev/null +++ b/tests/test_mul.py @@ -0,0 +1,90 @@ +import infini.ops +import pytest +import torch + +from tests.utils import ( + Payload, + empty_strided, + get_npu_stream, + randint_strided, + randn_strided, +) + +_INT_DTYPES = (torch.int16, torch.int32, torch.int64) + +_UINT_DTYPES = tuple( + filter(None, (getattr(torch, f"uint{bits}", None) for bits in (16, 32, 64))) +) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, other_strides, out_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + 
((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), + ((16, 5632), None, None, None), + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), + ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)), + ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)), + ((4, 4, 5632), None, None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ) + + tuple((dtype, 0, 0) for dtype in _INT_DTYPES + _UINT_DTYPES), +) +def test_mul( + shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol +): + if device == "musa" and dtype in _UINT_DTYPES: + pytest.skip( + "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`." + ) + + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: + input = randint_strided( + 0, 100, shape, input_strides, dtype=dtype, device=device + ) + other = randint_strided( + 0, 100, shape, other_strides, dtype=dtype, device=device + ) + else: + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload(_mul, _torch_mul, (input, other, out), {}, rtol=rtol, atol=atol) + + +def _mul(input, other, out): + if input.device.type == "npu": + infini.ops.mul(input, other, out, stream=get_npu_stream(input)) + else: + infini.ops.mul(input, other, out) + + return out + + +def _torch_mul(input, other, out): + if input.dtype in _UINT_DTYPES: + input = input.to(torch.int64) + + if other.dtype in _UINT_DTYPES: + other = other.to(torch.int64) + + res = torch.mul(input, other) + out.copy_(res.to(out.dtype)) + + return out diff --git a/tests/test_paged_attention.py 
b/tests/test_paged_attention.py new file mode 100644 index 00000000..17ab0bf0 --- /dev/null +++ b/tests/test_paged_attention.py @@ -0,0 +1,461 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_npu_stream, randn_strided + + +def _atb_pa_available(): + """Check whether ATB PagedAttention works on the current hardware. + + ATB PA is known to crash during `Setup` on Ascend 910B (CANN 8.5.x). + Returns True only when a minimal smoke call succeeds. + """ + if not (hasattr(torch, "npu") and torch.npu.is_available()): + return False + + if not infini.ops.PagedAttention.active_implementation_indices("ascend"): + return False + + try: + B, N, Nkv, D, bs = 1, 4, 4, 64, 16 + q = torch.randn(B, N, D, dtype=torch.float16, device="npu") + kc = torch.randn(1, bs, Nkv, D, dtype=torch.float16, device="npu") + vc = torch.randn(1, bs, Nkv, D, dtype=torch.float16, device="npu") + bt = torch.zeros(B, 1, dtype=torch.int32, device="npu") + sl = torch.tensor([bs], dtype=torch.int32, device="npu") + o = torch.zeros(B, N, D, dtype=torch.float16, device="npu") + infini.ops.paged_attention( + q, + kc, + vc, + sl, + bt, + N, + Nkv, + D, + 1.0 / D**0.5, + bs, + o, + stream=get_npu_stream(q), + ) + torch.npu.synchronize() + + return True + except Exception: + return False + + +_skip_no_atb_pa = pytest.mark.skipif( + not _atb_pa_available(), + reason="ATB PagedAttention not supported on this hardware", +) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ( + (32, 8, 128, 128), + (16, 4, 64, 128), + (32, 32, 128, 128), # MHA + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_basic( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Basic paged decode attention with 
contiguous block assignments.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 4 + kv_len = 16 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: request i uses blocks [i*num_blocks_per_req, ...]. + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + # Context lengths (total KV length per request). 
+ seq_lens = torch.full((num_reqs,), kv_len, dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_variable_seq_lens( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Paged decode attention where each request has a different KV length.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + kv_lens = [8, 32, 16, 128] + num_reqs = len(kv_lens) + max_blocks_per_req = max((kv + block_size - 1) // block_size for kv in kv_lens) + num_blocks = sum((kv + block_size - 1) // block_size for kv in kv_lens) + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: assign blocks sequentially. 
+ block_table = torch.zeros( + (num_reqs, max_blocks_per_req), dtype=torch.int32, device=device + ) + block_idx = 0 + + for i in range(num_reqs): + n_blocks = (kv_lens[i] + block_size - 1) // block_size + + for j in range(n_blocks): + block_table[i, j] = block_idx + block_idx += 1 + + seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@_skip_no_atb_pa +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ((32, 8, 128, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ((torch.float16, 1e-3, 1e-3),), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_paged_attention_single_request( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Single request decode (batch_size=1).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 1 + kv_len = 64 + num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + key_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + value_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + block_table = 
torch.arange( + num_blocks_per_req, dtype=torch.int32, device=device + ).unsqueeze(0) + + seq_lens = torch.tensor([kv_len], dtype=torch.int32, device=device) + + return Payload( + lambda q, kc, vc, sl, bt, o: _paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + o, + ), + lambda q, kc, vc, sl, bt, o: _ref_paged_attention( + q, + kc, + vc, + sl, + bt, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + ), + (query, key_cache, value_cache, seq_lens, block_table, output), + {}, + rtol=rtol, + atol=atol, + ) + + +def _paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, +): + if query.device.type == "npu": + infini.ops.paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, + stream=get_npu_stream(query), + ) + else: + infini.ops.paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, + output, + ) + + return output + + +def _ref_paged_attention( + query, + key_cache, + value_cache, + seq_lens, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + block_size, +): + """PyTorch SDPA reference for paged decode attention.""" + sl = seq_lens.cpu() + bt = block_table.cpu() + kc = key_cache.cpu().float() + vc = value_cache.cpu().float() + q_cpu = query.cpu().float() + num_reqs = bt.size(0) + outputs = [] + + for i in range(num_reqs): + q = q_cpu[i : i + 1] # [1, N, D] + kv_len = int(sl[i].item()) + + # Gather K and V from paged cache. + # Cache layout: [num_blocks, block_size, Nkv, D]. 
+ blocks = bt[i] + k_pages = [] + v_pages = [] + remaining = kv_len + + for b in blocks: + if remaining <= 0: + break + + take = min(remaining, block_size) + k_pages.append(kc[int(b.item()), :take, :, :]) + v_pages.append(vc[int(b.item()), :take, :, :]) + remaining -= take + + # [kv_len, Nkv, D] + k = torch.cat(k_pages, dim=0) + v = torch.cat(v_pages, dim=0) + + # SDPA reference with GQA expansion. + # q: [1, N, D] -> [N, 1, D] + q_t = q.transpose(0, 1) + # k, v: [kv_len, Nkv, D] -> [Nkv, kv_len, D] + k_t = k.transpose(0, 1) + v_t = v.transpose(0, 1) + + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k_t = k_t.repeat_interleave(ratio, dim=0) + v_t = v_t.repeat_interleave(ratio, dim=0) + + # [N, 1, D] and [N, kv_len, D] -> [1, N, 1, D] and [1, N, kv_len, D] + q_4d = q_t.unsqueeze(0) + k_4d = k_t.unsqueeze(0) + v_4d = v_t.unsqueeze(0) + + # Decode: query attends to all past KV (no causal mask). + out = torch.nn.functional.scaled_dot_product_attention( + q_4d, + k_4d, + v_4d, + scale=scale, + is_causal=False, + ) + + # [1, N, 1, D] -> [1, N, D] + outputs.append(out.squeeze(0).transpose(0, 1).squeeze(0).unsqueeze(0)) + + return torch.cat(outputs, dim=0).to(query.dtype).to(query.device) diff --git a/tests/test_reshape_and_cache.py b/tests/test_reshape_and_cache.py new file mode 100644 index 00000000..de234e2a --- /dev/null +++ b/tests/test_reshape_and_cache.py @@ -0,0 +1,152 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_npu_stream, randn_strided + +# `ReshapeAndCache` only works on NPU (`aclrtMemcpy`-based), so tests only +# parametrize on `float16`/`bfloat16` and use explicit device parametrization. 
+ + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (1, 8, 128, 4, 16), + (4, 8, 128, 4, 16), + (8, 4, 64, 8, 32), + (16, 2, 128, 8, 64), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_contiguous( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + dtype, + rtol, + atol, + device, +): + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + # Layout: [2, num_blocks, block_size, num_kv_heads, head_size] + # Index 0 = key cache, index 1 = value cache. + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + # Contiguous slot mapping: token i -> slot i. 
+ slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + return Payload( + _reshape_and_cache, + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (4, 8, 128, 4, 16), + (8, 4, 64, 8, 32), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_noncontiguous_slots( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + dtype, + rtol, + atol, + device, +): + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + # Non-contiguous slots: skip every other slot. 
+ slot_mapping = torch.tensor( + [i * 2 for i in range(num_tokens)], dtype=torch.int64, device=device + ) + + return Payload( + _reshape_and_cache, + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +def _reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + if key.device.type == "npu": + infini.ops.reshape_and_cache( + key, value, kv_cache, slot_mapping, kv_cache_out, stream=get_npu_stream(key) + ) + else: + infini.ops.reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out) + + return kv_cache_out + + +def _ref_reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + kv_cache_out = kv_cache_out.clone() + slots = slot_mapping.cpu() + block_size = kv_cache_out.size(2) + + for i in range(key.size(0)): + slot = int(slots[i].item()) + + if slot < 0: + continue + + block_idx = slot // block_size + offset = slot % block_size + kv_cache_out[0, block_idx, offset, :, :] = key[i, :, :] + kv_cache_out[1, block_idx, offset, :, :] = value[i, :, :] + + return kv_cache_out diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index d6d4dff1..ba540a95 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -53,7 +53,10 @@ def test_rms_norm( def _rms_norm(input, weight, *, eps=1e-6, out=None): - infini.ops.rms_norm(input, weight, eps, out) + if input.device.type == "npu": + infini.ops.rms_norm(input, weight, eps, out, stream=get_npu_stream(input)) + else: + infini.ops.rms_norm(input, weight, eps, out) return out diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py new file mode 100644 index 00000000..823532a1 --- /dev/null +++ b/tests/test_rotary_embedding.py @@ -0,0 +1,279 @@ +import infini.ops +import pytest +import 
torch + +from tests.utils import get_npu_stream, randn_strided, randint_strided + + +def _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, +): + if device == "npu": + infini.ops.rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + stream=get_npu_stream(query), + ) + else: + infini.ops.rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + ) + + return query_out, key_out + + +def _ref_rotary_embedding( + positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style +): + """PyTorch reference for RoPE. + + ``cos_sin_cache`` layout: ``[max_seq_len, rotary_dim]`` where the first + ``rotary_dim // 2`` columns are cos and the rest are sin. + """ + T = query.size(0) + R = rotary_dim + half_R = R // 2 + + cos_sin = cos_sin_cache.float() + cos_half = cos_sin[:, :half_R] + sin_half = cos_sin[:, half_R:] + + def apply_rope(x): + out = x.float().clone() + + for t in range(T): + p = positions[t].item() + c = cos_half[p] + s = sin_half[p] + + if is_neox_style: + x1 = x[t, :, :half_R].float() + x2 = x[t, :, half_R:R].float() + out[t, :, :half_R] = c * x1 - s * x2 + out[t, :, half_R:R] = c * x2 + s * x1 + else: + x1 = x[t, :, 0::2].float() + x2 = x[t, :, 1::2].float() + out[t, :, 0::2] = c * x1 - s * x2 + out[t, :, 1::2] = c * x2 + s * x1 + + return out.to(x.dtype) + + return apply_rope(query), apply_rope(key) + + +def _assert_close(actual, expected, rtol, atol): + assert torch.allclose(actual, expected, rtol=rtol, atol=atol), ( + f"Max diff: {(actual.float() - expected.float()).abs().max().item()}" + ) + + +@pytest.mark.parametrize( + "num_heads, head_size", + ( + (32, 128), + (8, 64), + ), +) +@pytest.mark.parametrize("is_neox_style", (True, False)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + 
(torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_rotary_embedding_full( + num_heads, head_size, is_neox_style, dtype, rtol, atol, device +): + """Full rotary: ``rotary_dim == head_size``.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + if device == "npu" and not is_neox_style: + pytest.skip( + "Ascend aclnnApplyRotaryPosEmbV2 only supports neox style " + "(rotaryMode='half')" + ) + + # `aclnnApplyRotaryPosEmbV2` accumulates with ~4 ULP error for `float16`. + if device == "npu" and dtype == torch.float16: + atol = 0.01 + + num_kv_heads = num_heads + rotary_dim = head_size + num_tokens = 16 + max_seq_len = 64 + + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + query = randn_strided( + (num_tokens, num_heads, head_size), + None, + dtype=dtype, + device=device, + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + cos_sin_cache = randn_strided( + (max_seq_len, rotary_dim), + None, + dtype=dtype, + device=device, + ) + query_out = torch.empty_like(query) + key_out = torch.empty_like(key) + + q_out, k_out = _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, + ) + + ref_q, ref_k = _ref_rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + ) + + _assert_close(q_out, ref_q, rtol, atol) + _assert_close(k_out, ref_k, rtol, atol) + + +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, rotary_dim", + ( + (32, 8, 128, 64), + (16, 4, 64, 32), + ), +) +@pytest.mark.parametrize("is_neox_style", (True,)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) 
+@pytest.mark.parametrize("device", ("npu",)) +def test_rotary_embedding_partial( + num_heads, + num_kv_heads, + head_size, + rotary_dim, + is_neox_style, + dtype, + rtol, + atol, + device, +): + """Partial rotary: ``rotary_dim < head_size``.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + if device == "npu": + pytest.skip("Ascend aclnnApplyRotaryPosEmbV2 requires rotary_dim == head_size") + + num_tokens = 16 + max_seq_len = 64 + + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + query = randn_strided( + (num_tokens, num_heads, head_size), + None, + dtype=dtype, + device=device, + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + cos_sin_cache = randn_strided( + (max_seq_len, rotary_dim), + None, + dtype=dtype, + device=device, + ) + query_out = torch.empty_like(query) + key_out = torch.empty_like(key) + + q_out, k_out = _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, + ) + + ref_q, ref_k = _ref_rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + ) + + _assert_close(q_out, ref_q, rtol, atol) + _assert_close(k_out, ref_k, rtol, atol) diff --git a/tests/test_silu_and_mul.py b/tests/test_silu_and_mul.py new file mode 100644 index 00000000..76d99464 --- /dev/null +++ b/tests/test_silu_and_mul.py @@ -0,0 +1,63 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, rand_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, x_strides, out_strides", + ( + ((13, 8), None, None), + ((16, 11264), None, None), + ((4, 4, 11264), None, None), + ((1, 8), None, None), + ((32, 5632), None, None), + ), +) +@pytest.mark.parametrize( + 
("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_silu_and_mul(shape, x_strides, out_strides, dtype, device, rtol, atol): + x = rand_strided(shape, x_strides, dtype=dtype, device=device) + d = shape[-1] // 2 + out_shape = (*shape[:-1], d) + out = empty_strided(out_shape, out_strides, dtype=dtype, device=device) + + return Payload( + _silu_and_mul, + _torch_silu_and_mul, + (x, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _silu_and_mul(x, out): + if x.device.type == "npu": + infini.ops.silu_and_mul( + x, + -1, + out, + stream=get_npu_stream(x), + ) + else: + infini.ops.silu_and_mul(x, -1, out) + + return out + + +def _torch_silu_and_mul(x, out): + d = x.shape[-1] // 2 + gate = x[..., :d] + up = x[..., d:] + result = up * torch.sigmoid(gate) * gate + + return result.to(out.dtype) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index 89c95f77..2419b10a 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, rand_strided +from tests.utils import Payload, empty_strided, get_npu_stream, rand_strided @pytest.mark.auto_act_and_assert @@ -19,6 +19,7 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) +@pytest.mark.parametrize("implementation_index", (0, 1)) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -28,17 +29,53 @@ ), ) def test_swiglu( - shape, input_strides, gate_strides, out_strides, dtype, device, rtol, atol + shape, + input_strides, + gate_strides, + out_strides, + implementation_index, + dtype, + device, + rtol, + atol, ): + active_indices = infini.ops.Swiglu.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + input = rand_strided(shape, input_strides, dtype=dtype, device=device) gate = rand_strided(shape, 
gate_strides, dtype=dtype, device=device) out = empty_strided(shape, out_strides, dtype=dtype, device=device) - return Payload(_swiglu, _torch_swiglu, (input, gate, out), {}, rtol=rtol, atol=atol) + return Payload( + lambda *args, **kwargs: _swiglu( + *args, **kwargs, implementation_index=implementation_index + ), + _torch_swiglu, + (input, gate, out), + {}, + rtol=rtol, + atol=atol, + ) -def _swiglu(input, gate, out): - infini.ops.swiglu(input, gate, out) +def _swiglu(input, gate, out, implementation_index=0): + if input.device.type == "npu": + infini.ops.swiglu( + input, + gate, + out, + implementation_index=implementation_index, + stream=get_npu_stream(input), + ) + else: + infini.ops.swiglu( + input, + gate, + out, + implementation_index=implementation_index, + ) return out diff --git a/tests/utils.py b/tests/utils.py index 8f9532aa..8412cd61 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -82,47 +82,12 @@ def randint_strided(low, high, shape, strides, *, dtype=None, device=None): return output -def get_stream(device): - """Return the raw stream handle for `device`, or 0 for CPU. - - Uses `torch.accelerator.current_stream` when available, falling back to - device-specific APIs for older PyTorch versions. - """ - if isinstance(device, torch.device): - device = device.type - - if isinstance(device, str) and ":" in device: - device = device.split(":")[0] - - if device == "cpu": - return 0 - - if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "current_stream"): - stream = torch.accelerator.current_stream() - - # Each backend exposes the raw handle under a different attribute name. - for attr in ("npu_stream", "cuda_stream", "mlu_stream", "musa_stream"): - if hasattr(stream, attr): - return getattr(stream, attr) - +def get_npu_stream(tensor): + """Return the current NPU stream handle for `tensor`, or 0 on other devices.""" + if tensor.device.type != "npu": return 0 - # Fallback for older PyTorch builds without `torch.accelerator`. 
- _STREAM_ACCESSORS = { - "npu": ("npu", "npu_stream"), - "cuda": ("cuda", "cuda_stream"), - "mlu": ("mlu", "mlu_stream"), - "musa": ("musa", "musa_stream"), - } - - if device in _STREAM_ACCESSORS: - mod_name, attr = _STREAM_ACCESSORS[device] - mod = getattr(torch, mod_name, None) - - if mod is not None and hasattr(mod, "current_stream"): - return getattr(mod.current_stream(), attr) - - return 0 + return torch.npu.current_stream().npu_stream def clone_strided(input):