From 27a056c41c98d3536569f958c713b01daaf05a2f Mon Sep 17 00:00:00 2001 From: Vladimir Serov Date: Mon, 26 Jan 2026 14:49:52 +0300 Subject: [PATCH 1/5] LoRA: Implementing kernels using CUBE computation unit --- csrc/CMakeLists.txt | 15 ++ csrc/lora/op_host/sgemmc_expand.cpp | 76 ++++++ csrc/lora/op_host/sgemmc_shrink.cpp | 72 ++++++ csrc/lora/op_host/tiling/sgemmc_tiling.cpp | 83 +++++++ csrc/lora/op_host/tiling/sgemmc_tiling.h | 37 +++ csrc/lora/op_host/tiling/sgemmc_tiling_data.h | 48 ++++ csrc/lora/op_kernel/lora_common_kernel.h | 58 +++++ csrc/lora/op_kernel/sgemmc_expand_kernel.cpp | 224 ++++++++++++++++++ csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp | 206 ++++++++++++++++ csrc/lora/op_kernel/sgemmv_expand_kernel.cpp | 27 +-- csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp | 26 +- csrc/lora/op_kernel/sgmv_expand_kernel.cpp | 28 +-- csrc/lora/op_kernel/sgmv_shrink_kernel.cpp | 27 +-- csrc/pytorch_extensions.cpp | 19 +- csrc/utils/common_tiling.h | 23 ++ csrc/utils/kernel/common_tiling_kernel.h | 31 +++ csrc/utils/torch_helper.h | 44 ++++ include/sgl_kenel_npu_ops.h | 9 + 18 files changed, 982 insertions(+), 71 deletions(-) create mode 100644 csrc/lora/op_host/sgemmc_expand.cpp create mode 100644 csrc/lora/op_host/sgemmc_shrink.cpp create mode 100644 csrc/lora/op_host/tiling/sgemmc_tiling.cpp create mode 100644 csrc/lora/op_host/tiling/sgemmc_tiling.h create mode 100644 csrc/lora/op_host/tiling/sgemmc_tiling_data.h create mode 100644 csrc/lora/op_kernel/lora_common_kernel.h create mode 100644 csrc/lora/op_kernel/sgemmc_expand_kernel.cpp create mode 100644 csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp create mode 100644 csrc/utils/kernel/common_tiling_kernel.h diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 6b3236f8b..7ad6112b3 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -19,6 +19,9 @@ FILE(GLOB OP_SRCS ${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_shrink.cpp ${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmv_expand.cpp 
${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmv_shrink.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmc_expand.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmc_shrink.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_host/tiling/sgemmc_tiling.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp ${PROJECT_OP_SRC_BASE}/tri_inv/op_host/tri_inv.cpp @@ -47,12 +50,18 @@ ascendc_library(no_workspace_kernel STATIC ${PROJECT_OP_SRC_BASE}/tri_inv/op_kernel/tri_inv_kernel.cpp ) +ascendc_include_directories(no_workspace_kernel PRIVATE + ${PROJECT_OP_SRC_BASE}/utils/kernel +) + # kernel side files with workspace set(WORKSPACE_KERNEL_SRCS ${PROJECT_OP_SRC_BASE}/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp ${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp ${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_kernel/sgemmc_expand_kernel.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_kernel/sgemmc_shrink_kernel.cpp ) if(BUILD_CATLASS_MODULE) list(APPEND WORKSPACE_KERNEL_SRCS @@ -66,6 +75,10 @@ if(BUILD_CATLASS_MODULE) ) endif() +ascendc_include_directories(workspace_kernel PRIVATE + ${PROJECT_OP_SRC_BASE}/utils/kernel +) + ascendc_compile_definitions(workspace_kernel PRIVATE -DHAVE_WORKSPACE -DHAVE_TILING @@ -77,6 +90,7 @@ add_library(${OP_PLUGIN_NAME} SHARED ${OP_SRCS}) target_link_libraries(${OP_PLUGIN_NAME} PRIVATE workspace_kernel no_workspace_kernel + host_intf_pub torch_npu ascendcl tiling_api @@ -97,6 +111,7 @@ target_include_directories(${OP_PLUGIN_NAME} PRIVATE ${TORCH_DIR}/include ${TORCH_DIR}/include/torch/csrc/api/include ${TORCH_NPU_DIR}/include + ${ASCEND_INCLUDE_DIR} ${ASCEND_INCLUDE_DIR}/external ${ASCEND_INCLUDE_DIR}/experiment/platform ${ASCEND_INCLUDE_DIR}/experiment/runtime diff --git a/csrc/lora/op_host/sgemmc_expand.cpp 
b/csrc/lora/op_host/sgemmc_expand.cpp new file mode 100644 index 000000000..24a0224d4 --- /dev/null +++ b/csrc/lora/op_host/sgemmc_expand.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "defines.h" +#include "tiling/sgemmc_tiling.h" +#include "torch_helper.h" + +#include "aclrtlaunch_sgemmc_expand.h" + +namespace sglang { +namespace npu_kernel { + +HOST_API at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, + at::Tensor &lora_ranks, at::Tensor &slice_offsets, at::Tensor &y) +{ + at::ScalarType scalar_type = y.scalar_type(); + TORCH_CHECK(scalar_type == at::kHalf || scalar_type == at::kBFloat16, "only support half and bf16"); + TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]"); + TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4, + "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]"); + TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]"); + + at::Tensor y_out = y; + void *x_ptr = x.data_ptr(); + void *weight_ptr = weight.data_ptr(); + void *y_ptr = y.data_ptr(); + void *y_out_ptr = y_out.data_ptr(); + + void *lora_indices_ptr = lora_indices.data_ptr(); + int lora_indices_size = lora_indices.size(0); + void *seq_len_ptr = seq_len.data_ptr(); + int seq_len_size = seq_len.size(0); + void *lora_ranks_ptr = 
lora_ranks.data_ptr(); + int lora_ranks_size = lora_ranks.size(0); + void *slice_offsets_ptr = slice_offsets.data_ptr(); + int slice_offsets_size = slice_offsets.size(0); + int slice_count = slice_offsets_size - 1; + int batch_size = x.size(0); + int max_lora_rank = x.size(1) / slice_count; + int output_full_dim = y.size(1); + + uint32_t block_dim; + uint32_t workspace_size; + int64_t num_tokens_per_core = 0; + int input_hidden_token = 0; + + at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, input_hidden_token, max_lora_rank, + TorchNpuHelper::ConvertDataType(scalar_type)); + auto workspace_tensor = + at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(x.options().device())); + + /* launch the kernel function via torch */ + EXEC_KERNEL_CMD(sgemmc_expand, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, + seq_len_size, lora_ranks_ptr, lora_ranks_size, slice_offsets_ptr, slice_offsets_size, y_ptr, + y_out_ptr, batch_size, num_tokens_per_core, max_lora_rank, output_full_dim, workspace_tensor, + tiling_tensor); + + return y_out; +} + +} // namespace npu_kernel +} // namespace sglang diff --git a/csrc/lora/op_host/sgemmc_shrink.cpp b/csrc/lora/op_host/sgemmc_shrink.cpp new file mode 100644 index 000000000..7d50f1aa6 --- /dev/null +++ b/csrc/lora/op_host/sgemmc_shrink.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "defines.h" +#include "tiling/sgemmc_tiling.h" +#include "torch_helper.h" + +#include "aclrtlaunch_sgemmc_shrink.h" + +namespace sglang { +namespace npu_kernel { + +HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, + at::Tensor &lora_ranks, at::Tensor &lora_scales, at::Tensor &y) +{ + at::ScalarType scalar_type = x.scalar_type(); + TORCH_CHECK(scalar_type == at::kHalf || scalar_type == at::kBFloat16, "only support half and bf16"); + TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]"); + TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4, + "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]"); + TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]"); + TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out"); + void *x_ptr = x.data_ptr(); + void *weight_ptr = weight.data_ptr(); + + void *lora_indices_ptr = lora_indices.data_ptr(); + int lora_indices_size = lora_indices.size(0); + void *seq_len_ptr = seq_len.data_ptr(); + int seq_len_size = seq_len.size(0); + void *lora_ranks_ptr = lora_ranks.data_ptr(); + int lora_ranks_size = lora_ranks.size(0); + void *lora_scales_ptr = lora_scales.data_ptr(); + int lora_scales_size = lora_scales.size(0); + + void *y_ptr = y.data_ptr(); + int batch_size = x.size(0); + int input_hidden_token = x.size(1); + uint32_t max_lora_rank = y.size(1); + + uint32_t block_dim; + uint32_t workspace_size; + int64_t total_extend_tokens = 0; + int64_t num_tokens_per_core = 0; + + at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, input_hidden_token, max_lora_rank, + TorchNpuHelper::ConvertDataType(scalar_type)); + + auto workspace_tensor = + at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(x.options().device())); + /* 
launch the kernel function via torch */ + EXEC_KERNEL_CMD(sgemmc_shrink, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, + seq_len_size, lora_ranks_ptr, lora_ranks_size, lora_scales_ptr, lora_scales_size, y_ptr, batch_size, + num_tokens_per_core, input_hidden_token, max_lora_rank, workspace_tensor, tiling_tensor); + return; +} + +} // namespace npu_kernel +} // namespace sglang diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling.cpp b/csrc/lora/op_host/tiling/sgemmc_tiling.cpp new file mode 100644 index 000000000..7ad493507 --- /dev/null +++ b/csrc/lora/op_host/tiling/sgemmc_tiling.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "common.h" +#include "sgemmc_tiling.h" + +namespace sglang { +namespace npu_kernel { + +matmul_tiling::DataType ConvertToMatMulTypes(host_utils::DataType data_type) +{ + switch (data_type) { + case host_utils::DataType::DT_BFLOAT16: + return matmul_tiling::DataType::DT_BFLOAT16; + case host_utils::DataType::DT_FLOAT: + return matmul_tiling::DataType::DT_FLOAT; + case host_utils::DataType::DT_FLOAT16: + return matmul_tiling::DataType::DT_FLOAT16; + } + + return matmul_tiling::DataType::DT_FLOAT16; +} + +at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t hidden_size, + uint32_t max_lora_rank, const host_utils::DataType type) +{ + auto ascendc_platform = *platform_ascendc::PlatformAscendCManager::GetInstance(); + uint32_t aiv_num = ascendcPlatform.GetCoreNumAiv(); + uint32_t aic_num = ascendcPlatform.GetCoreNumAic(); + workspace_size = ascendcPlatform.GetLibApiWorkSpaceSize(); + + auto tilingBuffer = at::empty({sizeof(SGEMMCTilingData)}, at::TensorOptions().dtype(at::kByte).device(at::kCPU)); + SGEMMCTilingData *tiling_data = reinterpret_cast(tilingBuffer.data_ptr()); + + matmul_tiling::MultiCoreMatmulTiling cubeTiling(ascendc_platform); + + uint32_t M = batch_size; + uint32_t N = hidden_size; + uint32_t K = max_lora_rank; + + const matmul_tiling::DataType data_type = ConvertToMatMulTypes(type); + + cubeTiling.EnableBias(false); + cubeTiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::VECTOR, data_type, false); + cubeTiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type, true); + cubeTiling.SetCType(matmul_tiling::TPosition::VECIN, matmul_tiling::CubeFormat::ND, + matmul_tiling::DataType::DT_FLOAT); + cubeTiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type); + + cubeTiling.SetDim(aic_num); + + cubeTiling.SetOrgShape(1, hidden_size, max_lora_rank); + cubeTiling.SetShape(1, hidden_size, max_lora_rank); + 
cubeTiling.SetBufferSpace(-1, -1, -1); + + if (cubeTiling.GetTiling(tiling_data->cubeTiling) == -1) { + TORCH_CHECK(false, "Generate tiling failed."); + return {}; + } + + tiling_data->batch = batch_size; + + block_dim = batch * tiling_data->cubeTiling.usedCoreNum; + + return tilingBuffer; +} + +} // namespace npu_kernel +} // namespace sglang diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling.h b/csrc/lora/op_host/tiling/sgemmc_tiling.h new file mode 100644 index 000000000..075bf6594 --- /dev/null +++ b/csrc/lora/op_host/tiling/sgemmc_tiling.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef SGEMMC_TILING_H +#define SGEMMC_TILING_H + +#include +#include + +#include "torch_helper.h" +#include "common_tiling.h" +#include "sgemmc_tiling_data.h" + +namespace sglang { +namespace npu_kernel { + +at::Tensor GenerateTiling(uint32_t &blockDim, uint32_t &workspace, uint32_t batch, uint32_t hidden_size, uint32_t k, + const host_utils::DataType type); + +} // namespace npu_kernel +} // namespace sglang + +#endif // SGEMMC_TILING_H diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling_data.h b/csrc/lora/op_host/tiling/sgemmc_tiling_data.h new file mode 100644 index 000000000..88c99cd8a --- /dev/null +++ b/csrc/lora/op_host/tiling/sgemmc_tiling_data.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef SGEMMC_TILING_DATA_H +#define SGEMMC_TILING_DATA_H + +#include + +namespace AscendC { +namespace tiling { + +struct TCubeTiling; /* NOTE(review): only forward-declared here, yet SGEMMCTilingData below stores it by value; presumably every includer pulls in the full definition first - confirm. */ + +} // namespace tiling +} // namespace AscendC + +namespace sglang { +namespace npu_kernel { + +#pragma pack(push, 1) +struct SGEMMCTilingData { /* Byte-packed tiling blob: GenerateTiling fills it on the host and the sgemmc kernel entry points copy it out of the tiling GM address. */ + uint32_t dataType; /* the kernel entry points pick one instantiation when this equals 1 and the other otherwise */ + uint32_t batch; /* GenerateTiling stores batch_size here */ + uint32_t hidden; + uint32_t k; + uint32_t slices; + AscendC::tiling::TCubeTiling cubeTiling; /* written by MultiCoreMatmulTiling::GetTiling in GenerateTiling */ +}; +#pragma pack(pop) + +} // namespace npu_kernel +} // namespace sglang + +#endif // SGEMMC_TILING_DATA_H diff --git a/csrc/lora/op_kernel/lora_common_kernel.h b/csrc/lora/op_kernel/lora_common_kernel.h new file mode 100644 index 000000000..34ce9bb59 --- /dev/null +++ b/csrc/lora/op_kernel/lora_common_kernel.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#ifndef SGL_KERNEL_NPU_KERNEL_LORA_COMMON_H +#define SGL_KERNEL_NPU_KERNEL_LORA_COMMON_H + +#include "kernel_operator.h" + +namespace lora_common { + +template +class BlockIterator +{ /* Maps a flat index to the entry of blocks whose cumulative-length range contains it. The last hit (entry and its starting offset) is cached so that the next lookup resumes from there instead of rescanning from entry 0. NOTE(review): the template parameter list appears to have been stripped from this copy of the patch. */ + AscendC::GlobalTensor blocks; + int64_t previous_block; + int64_t previous_offset; + +public: + __aicore__ explicit BlockIterator(AscendC::GlobalTensor &blocks_) + : blocks(blocks_), previous_block(0), previous_offset(0) + {} + __aicore__ inline int64_t GetBlockIdx(int64_t index) + { /* Returns the entry index, or -1 when index lies at or beyond the sum of all entries (the cache is left untouched in that case). The scan only moves forward: an index below the cached offset is answered with the cached entry, so callers are expected to query non-decreasing indices. */ + int64_t current_offset = previous_offset; + uint64_t blockIdx = previous_block; + + for (; blockIdx < blocks.GetSize(); ++blockIdx) { + int64_t blockOffset = blocks.GetValue(blockIdx); /* despite the name this is the length of entry blockIdx; it is added to the running offset */ + if (index >= current_offset + blockOffset) { + current_offset += blockOffset; + } else { + previous_offset = current_offset; + previous_block = blockIdx; + return blockIdx; + } + } + + return -1; + } +}; + +} // namespace lora_common + +#endif // SGL_KERNEL_NPU_KERNEL_LORA_COMMON_H diff --git a/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp b/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp new file mode 100644 index 000000000..4455e110e --- /dev/null +++ b/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp @@ -0,0 +1,224 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#ifndef SGL_KERNEL_NPU_KERNEL_SGEMMC_EXPAND_H +#define SGL_KERNEL_NPU_KERNEL_SGEMMC_EXPAND_H + +#include +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lora_common_kernel.h" +#include "common_tiling_kernel.h" + +#include "../op_host/tiling/sgemmc_tiling_data.h" + +template +class SGEMMCExpand +{ /* Cube-unit LoRA expand: runs x times weight through the AscendC Matmul object, adds the matching tile of yIn in INNER_T precision and writes the sum to yOut. NOTE(review): the template parameter list and every template argument list in this file appear to have been stripped from this copy of the patch; compare against the upstream file before relying on it. */ +public: + using X_T = scalar_t; + using W_T = scalar_t; + using INNER_T = inner_t; + using Y_T = scalar_t; + + using X_MAT_TYPE = AscendC::MatmulType; + using W_MAT_TYPE = AscendC::MatmulType; + using Y_MAT_TYPE = AscendC::MatmulType; + using BIAS_MAT_TYPE = AscendC::MatmulType; + + using MAT_TYPE = AscendC::Matmul; + +public: + __aicore__ explicit SGEMMCExpand(AscendC::TPipe *pipe) : pipe_(pipe) {} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize, + GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, + GM_ADDR sliceOffsets, uint32_t sliceOffsetsSize, GM_ADDR yIn, GM_ADDR yOut, + uint32_t batchSize, uint32_t numBlocksPerCore, uint32_t maxLoRARank, + uint32_t outputFullDim, GM_ADDR workspace, TCubeTiling &tiling) + { /* Stores the scalar arguments, binds every GM address to its GlobalTensor (the four index tensors are read as int32_t) and registers the matmul object with the caller's tiling. */ + this->tiling = tiling; + + batchSize_ = batchSize; + numBlocksPerCore_ = numBlocksPerCore; + maxLoRARank_ = maxLoRARank; + sliceCount_ = sliceOffsetsSize - 1; + outputFullDim_ = outputFullDim; + singleLoRAWeightLen_ = maxLoRARank_ * outputFullDim_; + + xInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ X_T *>(x)); + wInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ W_T *>(weight)); + yInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(yIn)); + yOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(yOut)); + loraIndicesGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraIndices), loraIndicesSize); + seqLenGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(seqLen), seqLenSize); + loraRanksGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraRanks), loraRanksSize); + sliceOffsetsGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t 
*>(sliceOffsets), sliceOffsetsSize); + + workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(workspace)); + + REGIST_MATMUL_OBJ(pipe_, GetSysWorkSpacePtr(), matmulObj, &tiling); + } + + __aicore__ inline void Process() + { /* Resolves the LoRA adapter for batch index 0, then drains the matmul tile by tile: cast yIn to INNER_T, add it to the matmul tile, cast back to Y_T and copy out. NOTE(review): batchIdx is never advanced, so only index 0 is looked up; blocks, startIdx and endIdx are computed but unused; reqLoRAWeightOffset_ is computed but SetTensorB receives the unshifted wInGm_; the locals reqLoRAIndex_, reqLoRAWeightOffset_ and reqLoRARank_ shadow the members of the same names; and the three early returns leave before SetNextTaskStart() - confirm each is intentional for this first patch of the series. */ + int64_t blocks = AscendC::GetBlockNum(); + int64_t blockIdx = AscendC::GetBlockIdx(); + + int64_t startIdx = blockIdx * numBlocksPerCore_; + int64_t endIdx = startIdx + numBlocksPerCore_; + + AscendC::WaitPreTaskEnd(); + + int64_t batchIdx = 0; + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); + requestBlock = blockIterator.GetBlockIdx(batchIdx); + if (requestBlock < 0) { + return; + } + + int32_t reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); + if (reqLoRAIndex_ < 0) { + return; + } + + int64_t reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_; + int32_t reqLoRARank_ = loraRanksGm_.GetValue(reqLoRAIndex_); + + if (reqLoRARank_ == 0) { + return; + } + + matmulObj.SetWorkspace(workspaceGlobal); + matmulObj.SetTensorA(xInGm_); + matmulObj.SetTensorB(wInGm_); + matmulObj.template Iterate(); + + uint32_t baseM = tiling.baseM; + uint32_t baseN = tiling.baseN; + pipe_->InitBuffer(vectorCalcBuf, baseM * baseN * sizeof(INNER_T)); + pipe_->InitBuffer(vectorInQueue, 1, baseM * baseN * sizeof(INNER_T)); + pipe_->InitBuffer(vectorYInQueue, 1, baseM * baseN * sizeof(INNER_T)); + pipe_->InitBuffer(vectorOutQueue, 1, baseM * baseN * sizeof(Y_T)); + + AscendC::DataCopyParams copyParams = {(uint16_t)baseM, + (uint16_t)(baseN * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE), (uint16_t)0, + (uint16_t)((tiling.N - baseN) * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE)}; + uint32_t iterateTimes = AscendC::Ceil(tiling.singleCoreM, baseM) * AscendC::Ceil(tiling.singleCoreN, baseN); + for (uint32_t i = 0; i < iterateTimes; ++i) { + auto cInLocal = vectorInQueue.AllocTensor(); + matmulObj.template GetTensorC(cInLocal); + vectorInQueue.EnQue(cInLocal); + + AscendC::LocalTensor yInLocalCube = vectorYInQueue.AllocTensor(); 
+ DataCopy(yInLocalCube, yInGm_[i], baseM * baseN); /* NOTE(review): the GM offset here and in the copy-out below is the loop counter i itself, not an element offset derived from the tile position - confirm. */ + vectorYInQueue.EnQue(yInLocalCube); + + AscendC::LocalTensor tmpTensor = vectorCalcBuf.Get(); + AscendC::LocalTensor yInLocal = vectorYInQueue.DeQue(); + AscendC::LocalTensor yLocal = vectorInQueue.DeQue(); + Cast(tmpTensor, yInLocal, AscendC::RoundMode::CAST_NONE, baseM * baseN); + pipe_barrier(PIPE_V); + vectorYInQueue.FreeTensor(yInLocal); + + Add(yLocal, yLocal, tmpTensor, baseM * baseN); + pipe_barrier(PIPE_V); + + AscendC::LocalTensor yOutLocal = vectorOutQueue.AllocTensor(); + Cast(yOutLocal, yLocal, AscendC::RoundMode::CAST_RINT, baseM * baseN); + pipe_barrier(PIPE_V); + + vectorOutQueue.EnQue(yOutLocal); + + // copy out + auto cOutLocal = vectorOutQueue.DeQue(); + DataCopy(yOutGm_[i], cOutLocal, copyParams); + vectorOutQueue.FreeTensor(cOutLocal); + } + matmulObj.End(); + AscendC::SetNextTaskStart(); + } + +private: + AscendC::TPipe *pipe_; + MAT_TYPE matmulObj; + + AscendC::GlobalTensor workspaceGlobal; + + TCubeTiling tiling; + AscendC::TQue vectorInQueue; + AscendC::TQue vectorYInQueue; + AscendC::TQue vectorOutQueue; + AscendC::TBuf vectorCalcBuf; + + AscendC::GlobalTensor xInGm_; + AscendC::GlobalTensor wInGm_; + AscendC::GlobalTensor yInGm_; + AscendC::GlobalTensor yOutGm_; + + AscendC::GlobalTensor seqLenGm_; + AscendC::GlobalTensor loraIndicesGm_; + + AscendC::GlobalTensor loraRanksGm_; + AscendC::GlobalTensor sliceOffsetsGm_; + + uint32_t batchSize_; + uint32_t sliceCount_; + uint32_t numBlocksPerCore_; + uint32_t maxLoRARank_; + uint32_t outputHiddenDim_; + uint32_t sliceOffset_; + uint32_t outputFullDim_; + uint32_t singleLoRAWeightLen_; + int64_t reqLoRAIndex_; + int32_t reqLoRARank_; + uint64_t reqLoRAWeightOffset_; + int32_t reqSlice_; + uint32_t numOutputElementsPerInputTile_; + uint32_t numStreamInPerOutputTile_; + uint64_t yOffset_; +}; + +extern "C" __global__ __aicore__ void sgemmc_expand(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, + uint32_t loraIndicesSize, GM_ADDR seqLen, 
uint32_t seqLenSize, + GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR sliceOffsets, + uint32_t sliceOffsetsSize, GM_ADDR yIn, GM_ADDR yOut, + uint32_t batchSize, uint32_t numBlocksPerCore, uint32_t maxLoRARank, + uint32_t outputFullDim, GM_ADDR workspace, GM_ADDR tiling) +{ /* Kernel entry: copies the tiling blob from the tiling address into a local SGEMMCTilingData, then runs one of two SGEMMCExpand instantiations selected by tilingData.dataType == 1. The template arguments that distinguish the two branches are not visible in this copy. */ + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1); + + AscendC::TPipe pipe; + sglang::npu_kernel::SGEMMCTilingData tilingData; + kernel_utils::CopyTiling(&tilingData, tiling); + + if (tilingData.dataType == 1) { + SGEMMCExpand op(&pipe); + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets, + sliceOffsetsSize, yIn, yOut, batchSize, numBlocksPerCore, maxLoRARank, outputFullDim, workspace, + tilingData.cubeTiling); + op.Process(); + } else { + SGEMMCExpand op(&pipe); + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets, + sliceOffsetsSize, yIn, yOut, batchSize, numBlocksPerCore, maxLoRARank, outputFullDim, workspace, + tilingData.cubeTiling); + op.Process(); + } +} + +#endif // SGL_KERNEL_NPU_KERNEL_SGEMMC_EXPAND_H diff --git a/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp b/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp new file mode 100644 index 000000000..0045883f8 --- /dev/null +++ b/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#ifndef SGL_KERNEL_NPU_KERNEL_SGEMMC_SHRINK_H +#define SGL_KERNEL_NPU_KERNEL_SGEMMC_SHRINK_H + +#include +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lora_common_kernel.h" +#include "common_tiling_kernel.h" + +#include "../op_host/tiling/sgemmc_tiling_data.h" + +template +class SGEMMCShrink +{ /* Cube-unit LoRA shrink: runs x times weight through the AscendC Matmul object, multiplies each result tile by the adapter's scale in INNER_T precision, casts it to Y_T and writes it to y. NOTE(review): the template parameter list and every template argument list in this file appear to have been stripped from this copy of the patch; compare against the upstream file before relying on it. */ +public: + using X_T = scalar_t; + using W_T = scalar_t; + using INNER_T = inner_t; + using Y_T = scalar_t; + + using X_MAT_TYPE = AscendC::MatmulType; + using W_MAT_TYPE = AscendC::MatmulType; + using Y_MAT_TYPE = AscendC::MatmulType; + using BIAS_MAT_TYPE = AscendC::MatmulType; + + using MAT_TYPE = AscendC::Matmul; + +public: + __aicore__ explicit SGEMMCShrink(AscendC::TPipe *pipe) : pipe_(pipe) {} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize, + GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, + GM_ADDR loraScales, uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize, + uint32_t numBlocksPerCore, uint32_t inputHiddenDim, uint32_t maxLoRARank, + GM_ADDR workspace, TCubeTiling &tiling) + { /* Stores the scalar arguments, binds every GM address to its GlobalTensor (indices, sequence lengths and ranks are read as int32_t; scales are read as half regardless of the instantiation) and registers the matmul object with the caller's tiling. */ + this->tiling = tiling; + + batchSize_ = batchSize; + numBlocksPerCore_ = numBlocksPerCore; + inputHiddenDim_ = inputHiddenDim; + maxLoRARank_ = maxLoRARank; + singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_; + + xInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ X_T *>(x)); + yOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(y)); + wInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ W_T *>(weight)); + loraIndicesGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraIndices), loraIndicesSize); + seqLenGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(seqLen), seqLenSize); + loraRanksGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraRanks), loraRanksSize); + loraScalesGm_.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(loraScales), loraScalesSize); + + workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(workspace)); + + 
REGIST_MATMUL_OBJ(pipe_, GetSysWorkSpacePtr(), matmulObj, &tiling); + } + + __aicore__ inline void Process() + { /* Resolves the LoRA adapter for batch index 0, then drains the matmul tile by tile: scale with Muls, cast to Y_T and copy out. NOTE(review): batchIdx is never advanced, so only index 0 is looked up; blocks, startIdx and endIdx are computed but unused; reqLoRAWeightOffset_ is computed but SetTensorB receives the unshifted wInGm_; the locals reqLoRAIndex_, reqLoRAWeightOffset_ and reqLoRARank_ shadow the members of the same names; and the three early returns leave before SetNextTaskStart() - confirm each is intentional for this first patch of the series. */ + int64_t blocks = AscendC::GetBlockNum(); + int64_t blockIdx = AscendC::GetBlockIdx(); + + int64_t startIdx = blockIdx * numBlocksPerCore_; + int64_t endIdx = startIdx + numBlocksPerCore_; + + AscendC::WaitPreTaskEnd(); + + int64_t batchIdx = 0; + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); + requestBlock = blockIterator.GetBlockIdx(batchIdx); + if (requestBlock < 0) { + return; + } + + int32_t reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); + if (reqLoRAIndex_ < 0) { + return; + } + + int64_t reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_; + int32_t reqLoRARank_ = loraRanksGm_.GetValue(reqLoRAIndex_); + + if (reqLoRARank_ == 0) { + return; + } + + matmulObj.SetWorkspace(workspaceGlobal); + matmulObj.SetTensorA(xInGm_); + matmulObj.SetTensorB(wInGm_); + matmulObj.template Iterate(); + + half loraScale = loraScalesGm_.GetValue(reqLoRAIndex_); + INNER_T scalar = AscendC::ScalarCast(loraScale); + + uint32_t baseM = this->tiling.baseM; + uint32_t baseN = this->tiling.baseN; + pipe_->InitBuffer(vectorCalcBuf, baseM * baseN * sizeof(INNER_T)); + pipe_->InitBuffer(vectorInQueue, 1, baseM * baseN * sizeof(INNER_T)); + pipe_->InitBuffer(vectorOutQueue, 1, baseM * baseN * sizeof(Y_T)); + + AscendC::DataCopyParams copyParams = { + (uint16_t)baseM, (uint16_t)(baseN * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE), (uint16_t)0, + (uint16_t)((this->tiling.N - baseN) * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE)}; + uint32_t iterateTimes = + AscendC::Ceil(this->tiling.singleCoreM, baseM) * AscendC::Ceil(this->tiling.singleCoreN, baseN); + for (uint32_t i = 0; i < iterateTimes; ++i) { /* NOTE(review): the copy-out at the end of this loop offsets yOutGm_ by the loop counter i itself, not by an element offset derived from the tile position - confirm. */ + // compute + auto cInLocal = vectorInQueue.AllocTensor(); + matmulObj.template GetTensorC(cInLocal); + vectorInQueue.EnQue(cInLocal); + // any vector operator + auto src = vectorInQueue.DeQue(); + auto dst = vectorOutQueue.AllocTensor(); + + 
AscendC::LocalTensor tmpTensor = vectorCalcBuf.Get(); + AscendC::Muls(tmpTensor, src, scalar, baseM * baseN); + AscendC::PipeBarrier(); + AscendC::Cast(dst, tmpTensor, AscendC::RoundMode::CAST_NONE, baseM * baseN); + AscendC::PipeBarrier(); + vectorOutQueue.EnQue(dst); + vectorInQueue.FreeTensor(src); + // copy out + auto cOutLocal = vectorOutQueue.DeQue(); + DataCopy(yOutGm_[i], cOutLocal, copyParams); + vectorOutQueue.FreeTensor(cOutLocal); + } + matmulObj.End(); + AscendC::SetNextTaskStart(); + } + +private: + AscendC::TPipe *pipe_; + MAT_TYPE matmulObj; + + AscendC::GlobalTensor xInGm_; + AscendC::GlobalTensor wInGm_; + AscendC::GlobalTensor yOutGm_; + AscendC::GlobalTensor loraIndicesGm_; + AscendC::GlobalTensor seqLenGm_; + AscendC::GlobalTensor loraRanksGm_; + AscendC::GlobalTensor loraScalesGm_; + + AscendC::GlobalTensor workspaceGlobal; + + TCubeTiling tiling; + AscendC::TQue vectorInQueue; + AscendC::TQue vectorOutQueue; + AscendC::TBuf vectorCalcBuf; + + uint32_t batchSize_; + uint32_t numBlocksPerCore_; + uint32_t inputHiddenDim_; + uint32_t maxLoRARank_; + uint32_t singleLoRAWeightLen_; + + uint64_t reqLoRAWeightOffset_; + int32_t reqLoRAIndex_; + int32_t reqLoRARank_; +}; + +extern "C" __global__ __aicore__ void sgemmc_shrink(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, + uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, + GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR loraScales, + uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize, + uint32_t numBlocksPerCore, uint32_t inputHiddenDim, + uint32_t maxLoRARank, GM_ADDR workspace, GM_ADDR tiling) +{ /* Kernel entry: copies the tiling blob from the tiling address into a local SGEMMCTilingData, then runs one of two SGEMMCShrink instantiations selected by tilingData.dataType == 1. The template arguments that distinguish the two branches are not visible in this copy. */ + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1); + + AscendC::TPipe pipe; + sglang::npu_kernel::SGEMMCTilingData tilingData; + kernel_utils::CopyTiling(&tilingData, tiling); + + if (tilingData.dataType == 1) { + SGEMMCShrink op(&pipe); + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales, + loraScalesSize, y, batchSize, 
numBlocksPerCore, inputHiddenDim, maxLoRARank, workspace, + tilingData.cubeTiling); + op.Process(); + } else { + SGEMMCShrink op(&pipe); + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales, + loraScalesSize, y, batchSize, numBlocksPerCore, inputHiddenDim, maxLoRARank, workspace, + tilingData.cubeTiling); + op.Process(); + } +} + +#endif // SGL_KERNEL_NPU_KERNEL_SGEMMC_SHRINK_H diff --git a/csrc/lora/op_kernel/sgemmv_expand_kernel.cpp b/csrc/lora/op_kernel/sgemmv_expand_kernel.cpp index 284b75ccf..8111ed3a8 100644 --- a/csrc/lora/op_kernel/sgemmv_expand_kernel.cpp +++ b/csrc/lora/op_kernel/sgemmv_expand_kernel.cpp @@ -20,6 +20,7 @@ #define SGL_KERNEL_NPU_KERNEL_SGEMMV_EXPAND_H #include "kernel_operator.h" +#include "lora_common_kernel.h" template class SGEMMVExpand @@ -106,11 +107,19 @@ class SGEMMVExpand if (endIdx > batchSize_) { endIdx = batchSize_; } + + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); for (int64_t idx = startIdx; idx < endIdx; idx++) { yOffset_ = outputFullDim_ * idx + sliceOffset_; + requestBlock = blockIterator.GetBlockIdx(idx); + if (requestBlock < 0) { + continue; + } + // Set up LoRA index - CopyInIndex(idx); + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { continue; } @@ -144,22 +153,6 @@ class SGEMMVExpand } private: - __aicore__ inline void CopyInIndex(const int64_t idx) - { - // Look up the LoRA index - int64_t weightIdx = idx; - uint64_t i = 0; - for (; i < seqLenGm_.GetSize(); i++) { - int64_t repeatValue = seqLenGm_.GetValue(i); - if (weightIdx >= repeatValue) { - weightIdx -= repeatValue; - continue; - } - break; - } - reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? 
loraIndicesGm_.GetValue(i) : -1; - } - __aicore__ inline void ComputeLastIteration() { int32_t remainingY = outputHiddenDim_ % Y_OUT_TILE_NUM_ELEMENTS; diff --git a/csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp b/csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp index a01fa4df7..54dbf27e6 100644 --- a/csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp +++ b/csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp @@ -20,6 +20,7 @@ #define SGL_KERNEL_NPU_KERNEL_SGEMMV_SHRINK_H #include "kernel_operator.h" +#include "lora_common_kernel.h" template class SGEMMVShrink @@ -71,9 +72,16 @@ class SGEMMVShrink if (endIdx > batchSize_) { endIdx = batchSize_; } + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); for (int64_t idx = startIdx; idx < endIdx; idx++) { // set up LoRA index - CopyInIndex(idx); + requestBlock = blockIterator.GetBlockIdx(idx); + if (requestBlock < 0) { + continue; + } + + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { continue; } @@ -126,22 +134,6 @@ class SGEMMVShrink } } - __aicore__ inline void CopyInIndex(const int64_t idx) - { - // look up the LoRA index - int64_t weightIdx = idx; - uint64_t i = 0; - for (; i < seqLenGm_.GetSize(); i++) { - int64_t repeatValue = seqLenGm_.GetValue(i); - if (weightIdx >= repeatValue) { - weightIdx -= repeatValue; - continue; - } - break; - } - reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? 
loraIndicesGm_.GetValue(i) : -1; - } - __aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH) { AscendC::LocalTensor xLocal = inQueueX_.AllocTensor(); diff --git a/csrc/lora/op_kernel/sgmv_expand_kernel.cpp b/csrc/lora/op_kernel/sgmv_expand_kernel.cpp index 9be8b01ce..5bb573990 100644 --- a/csrc/lora/op_kernel/sgmv_expand_kernel.cpp +++ b/csrc/lora/op_kernel/sgmv_expand_kernel.cpp @@ -20,6 +20,7 @@ #define SGL_KERNEL_NPU_KERNEL_SGMV_EXPAND_H #include "kernel_operator.h" +#include "lora_common_kernel.h" template class SGMVExpand @@ -102,11 +103,18 @@ class SGMVExpand if (endIdx > batchSize_) { endIdx = batchSize_; } + + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); for (int64_t idx = startIdx; idx < endIdx; idx++) { yOffset_ = outputFullDim_ * idx + sliceOffset_; - // Set up LoRA index - CopyInIndex(idx); + requestBlock = blockIterator.GetBlockIdx(idx); + if (requestBlock < 0) { + continue; + } + + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { continue; } @@ -128,22 +136,6 @@ class SGMVExpand } private: - __aicore__ inline void CopyInIndex(const int64_t idx) - { - // Look up the LoRA index - int64_t weightIdx = idx; - uint64_t i = 0; - for (; i < seqLenGm_.GetSize(); i++) { - int64_t repeatValue = seqLenGm_.GetValue(i); - if (weightIdx >= repeatValue) { - weightIdx -= repeatValue; - continue; - } - break; - } - reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? 
loraIndicesGm_.GetValue(i) : -1; - } - __aicore__ inline void ComputeLastIteration() { int32_t remainingY = outputHiddenDim_ % Y_OUT_TILE_NUM_ELEMENTS; diff --git a/csrc/lora/op_kernel/sgmv_shrink_kernel.cpp b/csrc/lora/op_kernel/sgmv_shrink_kernel.cpp index 05b7212a8..e483e5ab1 100644 --- a/csrc/lora/op_kernel/sgmv_shrink_kernel.cpp +++ b/csrc/lora/op_kernel/sgmv_shrink_kernel.cpp @@ -20,6 +20,7 @@ #define SGL_KERNEL_NPU_KERNEL_SGMV_SHRINK_H #include "kernel_operator.h" +#include "lora_common_kernel.h" template class SGMVShrink @@ -69,9 +70,17 @@ class SGMVShrink if (endIdx > batchSize_) { endIdx = batchSize_; } + + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); for (int64_t idx = startIdx; idx < endIdx; idx++) { // set up LoRA index - CopyInIndex(idx); + requestBlock = blockIterator.GetBlockIdx(idx); + if (requestBlock < 0) { + continue; + } + + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { continue; } @@ -116,22 +125,6 @@ class SGMVShrink } } - __aicore__ inline void CopyInIndex(const int64_t idx) - { - // look up the LoRA index - int64_t weightIdx = idx; - uint64_t i = 0; - for (; i < seqLenGm_.GetSize(); i++) { - int64_t repeatValue = seqLenGm_.GetValue(i); - if (weightIdx >= repeatValue) { - weightIdx -= repeatValue; - continue; - } - break; - } - reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? loraIndicesGm_.GetValue(i) : -1; - } - __aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH) { AscendC::LocalTensor xLocal = inQueueX_.AllocTensor(); diff --git a/csrc/pytorch_extensions.cpp b/csrc/pytorch_extensions.cpp index a7d093730..42f8e1b1d 100644 --- a/csrc/pytorch_extensions.cpp +++ b/csrc/pytorch_extensions.cpp @@ -70,14 +70,17 @@ TORCH_LIBRARY_FRAGMENT(npu, m) "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y," " int slice_offset, int slice_size) -> Tensor"); - m.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! 
y, float scale) -> ()"); + m.def( + "bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y," + " float scale) -> ()"); m.def( "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y," " int slice_offset, int slice_size) -> Tensor"); m.def( - "sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()"); + "sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y," + " float scale) -> ()"); m.def( "sgemmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! lora_ranks," @@ -87,6 +90,14 @@ TORCH_LIBRARY_FRAGMENT(npu, m) "sgemmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! lora_ranks," " Tensor! lora_scales, Tensor! y) -> ()"); + m.def( + "sgemmc_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! lora_ranks," + " Tensor! sliceOffsets, Tensor! y) -> Tensor"); + + m.def( + "sgemmc_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! lora_ranks," + " Tensor! lora_scales, Tensor! y) -> ()"); + #ifdef BUILD_CATLASS_MODULE m.def("catlass_matmul_basic(Tensor tensor_a, Tensor tensor_b, Tensor(a!) tensor_c, str? 
format_mode=None) -> ()"); #endif @@ -134,6 +145,10 @@ TORCH_LIBRARY_IMPL(npu, PrivateUse1, m) m.impl("sgemmv_shrink", TORCH_FN(sglang::npu_kernel::sgemmv_shrink)); + m.impl("sgemmc_expand", TORCH_FN(sglang::npu_kernel::sgemmc_expand)); + + m.impl("sgemmc_shrink", TORCH_FN(sglang::npu_kernel::sgemmc_shrink)); + #ifdef BUILD_CATLASS_MODULE m.impl("catlass_matmul_basic", TORCH_FN(sglang::npu_kernel::catlass_matmul_basic)); #endif diff --git a/csrc/utils/common_tiling.h b/csrc/utils/common_tiling.h index a7543142c..dd8b22371 100644 --- a/csrc/utils/common_tiling.h +++ b/csrc/utils/common_tiling.h @@ -19,6 +19,27 @@ namespace host_utils { +enum class DataType : int32_t { + DT_FLOAT16 = 0, // fp16 type + DT_BFLOAT16 = 1, // bfloat16 type + DT_FLOAT = 2, // float type + DT_DOUBLE = 3, // double type + DT_BOOL = 4, // bool type + DT_INT8 = 5, // int8 type + DT_INT16, // int16 type + DT_INT32, // int32 type + DT_INT64, // int64 type + DT_UINT8, // unsigned int8 type + DT_UINT16, // unsigned int16 type + DT_UINT32, // unsigned int32 type + DT_UINT64, // unsigned int64 type + DT_COMPLEX64, // complex64 type + DT_COMPLEX128, // complex128 type + + DT_UNDEFINED, + DT_MAX // max type +}; + constexpr uint32_t FP16_SIZE = 2; constexpr uint32_t FP32_SIZE = 4; constexpr uint32_t BLOCK_SIZE = 16; @@ -236,5 +257,7 @@ inline __attribute__((always_inline)) void PpMatmulTilingCheck(const PpTilingDat TORCH_CHECK(tilingData.nLoop > 0, "nLoop is invalid"); TORCH_CHECK(tilingData.blockDim > 0, "nLoop is invalid"); } + } // namespace host_utils + #endif diff --git a/csrc/utils/kernel/common_tiling_kernel.h b/csrc/utils/kernel/common_tiling_kernel.h new file mode 100644 index 000000000..efa8135f4 --- /dev/null +++ b/csrc/utils/kernel/common_tiling_kernel.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COMMMON_TILING_KERNEL_H +#define COMMMON_TILING_KERNEL_H + +#include "kernel_operator.h" + +namespace kernel_utils { + +template +__aicore__ inline void CopyTiling(T *tiling, GM_ADDR tilingGM) +{ + uint32_t *ptr = reinterpret_cast(tiling); + auto tiling32 = reinterpret_cast<__gm__ uint32_t *>(tilingGM); + + for (int i = 0; i < sizeof(T) / sizeof(uint32_t); ++i, ++ptr) { + *ptr = *(tiling32 + i); + } +} + +} // namespace kernel_utils + +#endif // COMMMON_TILING_KERNEL_H diff --git a/csrc/utils/torch_helper.h b/csrc/utils/torch_helper.h index da0b77cb6..b12e38ee0 100644 --- a/csrc/utils/torch_helper.h +++ b/csrc/utils/torch_helper.h @@ -16,6 +16,7 @@ #include "torch_npu/csrc/core/npu/NPUStream.h" #include "torch_npu/csrc/framework/OpCommand.h" +#include "common_tiling.h" namespace sglang { namespace npu_kernel { @@ -43,6 +44,49 @@ class TorchNpuHelper return const_cast(at_tensor.data_ptr()); } + inline static host_utils::DataType ConvertDataType(const at::ScalarType type) + { + switch (type) { + case at::ScalarType::Float: + return host_utils::DataType::DT_FLOAT; + + case at::ScalarType::Half: + return host_utils::DataType::DT_FLOAT16; + + case at::ScalarType::BFloat16: + return host_utils::DataType::DT_BFLOAT16; + + case at::ScalarType::Double: + return host_utils::DataType::DT_DOUBLE; + + case at::ScalarType::Bool: + return host_utils::DataType::DT_BOOL; + + case at::ScalarType::Char: + return host_utils::DataType::DT_INT8; + + case at::ScalarType::Short: + return host_utils::DataType::DT_INT16; + + case at::ScalarType::Int: + return 
host_utils::DataType::DT_INT32; + + case at::ScalarType::Long: + return host_utils::DataType::DT_INT64; + + case at::ScalarType::Byte: + return host_utils::DataType::DT_UINT8; + + case at::ScalarType::ComplexFloat: + return host_utils::DataType::DT_COMPLEX64; + + case at::ScalarType::ComplexDouble: + return host_utils::DataType::DT_COMPLEX128; + } + + return host_utils::DataType::DT_MAX; + } + template inline static T ConvertType(T value) { diff --git a/include/sgl_kenel_npu_ops.h b/include/sgl_kenel_npu_ops.h index e2f5a6804..a30603b0c 100644 --- a/include/sgl_kenel_npu_ops.h +++ b/include/sgl_kenel_npu_ops.h @@ -98,6 +98,15 @@ void sgemmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, at::Tensor &lora_ranks, at::Tensor &lora_scales, at::Tensor &y); +at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, + at::Tensor &lora_indices, at::Tensor &seq_len, + at::Tensor &lora_ranks, at::Tensor &slice_offsets, + at::Tensor &y); + +void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, + at::Tensor &seq_len, at::Tensor &lora_ranks, + at::Tensor &lora_scales, at::Tensor &y); + #ifdef BUILD_CATLASS_MODULE void catlass_matmul_basic(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, From a7455c64cc40063d15d890ceeea6bf091f56a194 Mon Sep 17 00:00:00 2001 From: Vladimir Serov Date: Thu, 2 Apr 2026 07:22:33 +0300 Subject: [PATCH 2/5] Resolving comments --- csrc/lora/op_host/sgemmc_expand.cpp | 7 ++---- csrc/lora/op_host/sgemmc_shrink.cpp | 4 +--- csrc/lora/op_host/tiling/sgemmc_tiling.cpp | 23 +++++++++----------- csrc/lora/op_kernel/sgemmc_expand_kernel.cpp | 22 ++++++------------- csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp | 18 +++++---------- 5 files changed, 25 insertions(+), 49 deletions(-) diff --git a/csrc/lora/op_host/sgemmc_expand.cpp b/csrc/lora/op_host/sgemmc_expand.cpp index 24a0224d4..c9524e311 100644 --- a/csrc/lora/op_host/sgemmc_expand.cpp +++ 
b/csrc/lora/op_host/sgemmc_expand.cpp @@ -55,10 +55,8 @@ HOST_API at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, at::Tensor uint32_t block_dim; uint32_t workspace_size; - int64_t num_tokens_per_core = 0; - int input_hidden_token = 0; - at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, input_hidden_token, max_lora_rank, + at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, max_lora_rank, output_full_dim, TorchNpuHelper::ConvertDataType(scalar_type)); auto workspace_tensor = at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(x.options().device())); @@ -66,8 +64,7 @@ HOST_API at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, at::Tensor /* launch the kernel function via torch */ EXEC_KERNEL_CMD(sgemmc_expand, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, lora_ranks_ptr, lora_ranks_size, slice_offsets_ptr, slice_offsets_size, y_ptr, - y_out_ptr, batch_size, num_tokens_per_core, max_lora_rank, output_full_dim, workspace_tensor, - tiling_tensor); + y_out_ptr, batch_size, max_lora_rank, output_full_dim, workspace_tensor, tiling_tensor); return y_out; } diff --git a/csrc/lora/op_host/sgemmc_shrink.cpp b/csrc/lora/op_host/sgemmc_shrink.cpp index 7d50f1aa6..1ca57e955 100644 --- a/csrc/lora/op_host/sgemmc_shrink.cpp +++ b/csrc/lora/op_host/sgemmc_shrink.cpp @@ -53,8 +53,6 @@ HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_ uint32_t block_dim; uint32_t workspace_size; - int64_t total_extend_tokens = 0; - int64_t num_tokens_per_core = 0; at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, input_hidden_token, max_lora_rank, TorchNpuHelper::ConvertDataType(scalar_type)); @@ -64,7 +62,7 @@ HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_ /* launch the kernel function via torch */ EXEC_KERNEL_CMD(sgemmc_shrink, block_dim, x_ptr, weight_ptr, 
lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, lora_ranks_ptr, lora_ranks_size, lora_scales_ptr, lora_scales_size, y_ptr, batch_size, - num_tokens_per_core, input_hidden_token, max_lora_rank, workspace_tensor, tiling_tensor); + input_hidden_token, max_lora_rank, workspace_tensor, tiling_tensor); return; } diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling.cpp b/csrc/lora/op_host/tiling/sgemmc_tiling.cpp index 7ad493507..69c5d4d3b 100644 --- a/csrc/lora/op_host/tiling/sgemmc_tiling.cpp +++ b/csrc/lora/op_host/tiling/sgemmc_tiling.cpp @@ -35,23 +35,19 @@ matmul_tiling::DataType ConvertToMatMulTypes(host_utils::DataType data_type) return matmul_tiling::DataType::DT_FLOAT16; } -at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t hidden_size, - uint32_t max_lora_rank, const host_utils::DataType type) +at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t inner_size, + uint32_t output_size, const host_utils::DataType type) { auto ascendc_platform = *platform_ascendc::PlatformAscendCManager::GetInstance(); - uint32_t aiv_num = ascendcPlatform.GetCoreNumAiv(); - uint32_t aic_num = ascendcPlatform.GetCoreNumAic(); - workspace_size = ascendcPlatform.GetLibApiWorkSpaceSize(); + uint32_t aiv_num = ascendc_platform.GetCoreNumAiv(); + uint32_t aic_num = ascendc_platform.GetCoreNumAic(); + workspace_size = ascendc_platform.GetLibApiWorkSpaceSize(); auto tilingBuffer = at::empty({sizeof(SGEMMCTilingData)}, at::TensorOptions().dtype(at::kByte).device(at::kCPU)); SGEMMCTilingData *tiling_data = reinterpret_cast(tilingBuffer.data_ptr()); matmul_tiling::MultiCoreMatmulTiling cubeTiling(ascendc_platform); - uint32_t M = batch_size; - uint32_t N = hidden_size; - uint32_t K = max_lora_rank; - const matmul_tiling::DataType data_type = ConvertToMatMulTypes(type); cubeTiling.EnableBias(false); @@ -60,11 +56,11 @@ at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t 
&workspace_size, uint32_ cubeTiling.SetCType(matmul_tiling::TPosition::VECIN, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); cubeTiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type); - + cubeTiling.EnableMultiCoreSplitK(false); cubeTiling.SetDim(aic_num); - cubeTiling.SetOrgShape(1, hidden_size, max_lora_rank); - cubeTiling.SetShape(1, hidden_size, max_lora_rank); + cubeTiling.SetOrgShape(1, inner_size, output_size); + cubeTiling.SetShape(1, inner_size, output_size); cubeTiling.SetBufferSpace(-1, -1, -1); if (cubeTiling.GetTiling(tiling_data->cubeTiling) == -1) { @@ -73,8 +69,9 @@ at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_ } tiling_data->batch = batch_size; + tiling_data->dataType = (type == host_utils::DataType::DT_BFLOAT16); - block_dim = batch * tiling_data->cubeTiling.usedCoreNum; + block_dim = batch_size * tiling_data->cubeTiling.usedCoreNum; return tilingBuffer; } diff --git a/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp b/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp index 4455e110e..32bb42ad9 100644 --- a/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp +++ b/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp @@ -47,13 +47,12 @@ class SGEMMCExpand __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR sliceOffsets, uint32_t sliceOffsetsSize, GM_ADDR yIn, GM_ADDR yOut, - uint32_t batchSize, uint32_t numBlocksPerCore, uint32_t maxLoRARank, - uint32_t outputFullDim, GM_ADDR workspace, TCubeTiling &tiling) + uint32_t batchSize, uint32_t maxLoRARank, uint32_t outputFullDim, GM_ADDR workspace, + TCubeTiling &tiling) { this->tiling = tiling; batchSize_ = batchSize; - numBlocksPerCore_ = numBlocksPerCore; maxLoRARank_ = maxLoRARank; sliceCount_ = sliceOffsetsSize - 1; outputFullDim_ = outputFullDim; @@ -78,15 +77,11 @@ class SGEMMCExpand int64_t 
blocks = AscendC::GetBlockNum(); int64_t blockIdx = AscendC::GetBlockIdx(); - int64_t startIdx = blockIdx * numBlocksPerCore_; - int64_t endIdx = startIdx + numBlocksPerCore_; - AscendC::WaitPreTaskEnd(); - int64_t batchIdx = 0; int64_t requestBlock = 0; lora_common::BlockIterator blockIterator(seqLenGm_); - requestBlock = blockIterator.GetBlockIdx(batchIdx); + requestBlock = blockIterator.GetBlockIdx(blockIdx); if (requestBlock < 0) { return; } @@ -178,7 +173,6 @@ class SGEMMCExpand uint32_t batchSize_; uint32_t sliceCount_; - uint32_t numBlocksPerCore_; uint32_t maxLoRARank_; uint32_t outputHiddenDim_; uint32_t sliceOffset_; @@ -197,8 +191,8 @@ extern "C" __global__ __aicore__ void sgemmc_expand(GM_ADDR x, GM_ADDR weight, G uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR sliceOffsets, uint32_t sliceOffsetsSize, GM_ADDR yIn, GM_ADDR yOut, - uint32_t batchSize, uint32_t numBlocksPerCore, uint32_t maxLoRARank, - uint32_t outputFullDim, GM_ADDR workspace, GM_ADDR tiling) + uint32_t batchSize, uint32_t maxLoRARank, uint32_t outputFullDim, + GM_ADDR workspace, GM_ADDR tiling) { KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1); @@ -209,14 +203,12 @@ extern "C" __global__ __aicore__ void sgemmc_expand(GM_ADDR x, GM_ADDR weight, G if (tilingData.dataType == 1) { SGEMMCExpand op(&pipe); op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets, - sliceOffsetsSize, yIn, yOut, batchSize, numBlocksPerCore, maxLoRARank, outputFullDim, workspace, - tilingData.cubeTiling); + sliceOffsetsSize, yIn, yOut, batchSize, maxLoRARank, outputFullDim, workspace, tilingData.cubeTiling); op.Process(); } else { SGEMMCExpand op(&pipe); op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets, - sliceOffsetsSize, yIn, yOut, batchSize, numBlocksPerCore, maxLoRARank, outputFullDim, workspace, - tilingData.cubeTiling); + 
sliceOffsetsSize, yIn, yOut, batchSize, maxLoRARank, outputFullDim, workspace, tilingData.cubeTiling); op.Process(); } } diff --git a/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp b/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp index 0045883f8..0d7f69bb4 100644 --- a/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp +++ b/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp @@ -47,13 +47,11 @@ class SGEMMCShrink __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR loraScales, uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize, - uint32_t numBlocksPerCore, uint32_t inputHiddenDim, uint32_t maxLoRARank, - GM_ADDR workspace, TCubeTiling &tiling) + uint32_t inputHiddenDim, uint32_t maxLoRARank, GM_ADDR workspace, TCubeTiling &tiling) { this->tiling = tiling; batchSize_ = batchSize; - numBlocksPerCore_ = numBlocksPerCore; inputHiddenDim_ = inputHiddenDim; maxLoRARank_ = maxLoRARank; singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_; @@ -76,9 +74,6 @@ class SGEMMCShrink int64_t blocks = AscendC::GetBlockNum(); int64_t blockIdx = AscendC::GetBlockIdx(); - int64_t startIdx = blockIdx * numBlocksPerCore_; - int64_t endIdx = startIdx + numBlocksPerCore_; - AscendC::WaitPreTaskEnd(); int64_t batchIdx = 0; @@ -165,7 +160,6 @@ class SGEMMCShrink AscendC::TBuf vectorCalcBuf; uint32_t batchSize_; - uint32_t numBlocksPerCore_; uint32_t inputHiddenDim_; uint32_t maxLoRARank_; uint32_t singleLoRAWeightLen_; @@ -179,8 +173,8 @@ extern "C" __global__ __aicore__ void sgemmc_shrink(GM_ADDR x, GM_ADDR weight, G uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR loraScales, uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize, - uint32_t numBlocksPerCore, uint32_t inputHiddenDim, - uint32_t maxLoRARank, GM_ADDR workspace, GM_ADDR tiling) + uint32_t inputHiddenDim, uint32_t maxLoRARank, GM_ADDR 
workspace, + GM_ADDR tiling) { KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1); @@ -191,14 +185,12 @@ extern "C" __global__ __aicore__ void sgemmc_shrink(GM_ADDR x, GM_ADDR weight, G if (tilingData.dataType == 1) { SGEMMCShrink op(&pipe); op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales, - loraScalesSize, y, batchSize, numBlocksPerCore, inputHiddenDim, maxLoRARank, workspace, - tilingData.cubeTiling); + loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, workspace, tilingData.cubeTiling); op.Process(); } else { SGEMMCShrink op(&pipe); op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales, - loraScalesSize, y, batchSize, numBlocksPerCore, inputHiddenDim, maxLoRARank, workspace, - tilingData.cubeTiling); + loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, workspace, tilingData.cubeTiling); op.Process(); } } From a103bd6546945c78794a972473d370ff4f1c34b9 Mon Sep 17 00:00:00 2001 From: Vladimir Serov Date: Mon, 6 Apr 2026 15:04:19 +0300 Subject: [PATCH 3/5] Update pre-commit after merge --- .../mamba/mamba_state_update_triton.py | 8 ++++++-- .../python/sgl_kernel_npu/test_mamba_state_update.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py b/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py index 08542cd44..854910021 100644 --- a/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py +++ b/python/sgl_kernel_npu/sgl_kernel_npu/mamba/mamba_state_update_triton.py @@ -113,8 +113,12 @@ def move_intermediate_cache( dst_layer_stride, dst_size_stride = int(ssm_states.stride()[0]), int( ssm_states.stride()[1] ) - assert len(dst_indices_tensor) == len(last_steps_tensor), "Destination indices lengths must match" - assert len(src_indices_tensor) == len(last_steps_tensor), "Source indices lengths must match" + 
assert len(dst_indices_tensor) == len( + last_steps_tensor + ), "Destination indices lengths must match" + assert len(src_indices_tensor) == len( + last_steps_tensor + ), "Source indices lengths must match" # Grid: one thread per valid index grid = (len(dst_indices_tensor),) diff --git a/tests/python/sgl_kernel_npu/test_mamba_state_update.py b/tests/python/sgl_kernel_npu/test_mamba_state_update.py index e755f9767..ed78f1797 100644 --- a/tests/python/sgl_kernel_npu/test_mamba_state_update.py +++ b/tests/python/sgl_kernel_npu/test_mamba_state_update.py @@ -139,7 +139,9 @@ def test_move_intermediate_cache( valid_indices = random.sample(population, num_valid) last_step_pos = [random.randint(0, D - 1) for _ in range(num_valid)] dst_indices_tensor = torch.tensor(valid_indices, device=device, dtype=torch.int32) - src_indices_tensor = torch.arange(dst_indices_tensor.shape[0], device=device, dtype=torch.int32) + src_indices_tensor = torch.arange( + dst_indices_tensor.shape[0], device=device, dtype=torch.int32 + ) last_steps_tensor = torch.tensor(last_step_pos, device=device, dtype=torch.int32) valid_mask = last_steps_tensor >= 0 @@ -151,6 +153,12 @@ def test_move_intermediate_cache( :, src_state_indices, valid_last_steps ] - move_intermediate_cache(dst_cache_clone, src_cache, dst_indices_tensor, src_indices_tensor, last_steps_tensor) + move_intermediate_cache( + dst_cache_clone, + src_cache, + dst_indices_tensor, + src_indices_tensor, + last_steps_tensor, + ) assert_close("move_cache", dst_cache, dst_cache_clone, 1e-3) From e26f9b894ca532efc11182448febc1baccd99246 Mon Sep 17 00:00:00 2001 From: Vladimir Serov Date: Wed, 8 Apr 2026 18:10:22 +0300 Subject: [PATCH 4/5] Fix for CANN 8.3 --- csrc/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index dbedd6ac1..fdb3938be 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -100,7 +100,6 @@ add_library(${OP_PLUGIN_NAME} SHARED ${OP_SRCS}) 
target_link_libraries(${OP_PLUGIN_NAME} PRIVATE workspace_kernel no_workspace_kernel - host_intf_pub torch_npu ascendcl tiling_api From a25f41cd31863d5e3c93bfeacd230b65c879c11f Mon Sep 17 00:00:00 2001 From: Vladimir Serov Date: Wed, 22 Apr 2026 10:02:02 +0300 Subject: [PATCH 5/5] Resolve gemini comments --- csrc/lora/op_host/sgemmc_expand.cpp | 5 +- csrc/lora/op_host/sgemmc_shrink.cpp | 11 +- csrc/lora/op_host/tiling/sgemmc_tiling.cpp | 20 +-- csrc/lora/op_host/tiling/sgemmc_tiling.h | 2 +- csrc/lora/op_host/tiling/sgemmc_tiling_data.h | 6 +- csrc/lora/op_kernel/sgemmc_expand_kernel.cpp | 103 ++++++++------ csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp | 130 +++++++++++------- csrc/pytorch_extensions.cpp | 2 +- include/sgl_kenel_npu_ops.h | 2 +- 9 files changed, 163 insertions(+), 118 deletions(-) diff --git a/csrc/lora/op_host/sgemmc_expand.cpp b/csrc/lora/op_host/sgemmc_expand.cpp index c9524e311..aefb902d5 100644 --- a/csrc/lora/op_host/sgemmc_expand.cpp +++ b/csrc/lora/op_host/sgemmc_expand.cpp @@ -33,6 +33,9 @@ HOST_API at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, at::Tensor TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4, "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]"); TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]"); + TORCH_CHECK(slice_offsets.dim() == 1 && slice_offsets.size(0) > 1, + "slice_offsets should be a vector of size 2 and more."); + TORCH_CHECK(lora_ranks.dim() == 1, "lora_ranks should be a vector."); at::Tensor y_out = y; void *x_ptr = x.data_ptr(); @@ -57,7 +60,7 @@ HOST_API at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, at::Tensor uint32_t workspace_size; at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, max_lora_rank, output_full_dim, - TorchNpuHelper::ConvertDataType(scalar_type)); + slice_count, TorchNpuHelper::ConvertDataType(scalar_type)); auto workspace_tensor = at::empty({workspace_size}, 
at::TensorOptions().dtype(at::kByte).device(x.options().device())); diff --git a/csrc/lora/op_host/sgemmc_shrink.cpp b/csrc/lora/op_host/sgemmc_shrink.cpp index 1ca57e955..d9bd21270 100644 --- a/csrc/lora/op_host/sgemmc_shrink.cpp +++ b/csrc/lora/op_host/sgemmc_shrink.cpp @@ -25,7 +25,7 @@ namespace sglang { namespace npu_kernel { HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, - at::Tensor &lora_ranks, at::Tensor &lora_scales, at::Tensor &y) + at::Tensor &lora_ranks, at::Tensor &lora_scales, at::Tensor &y, int64_t slices) { at::ScalarType scalar_type = x.scalar_type(); TORCH_CHECK(scalar_type == at::kHalf || scalar_type == at::kBFloat16, "only support half and bf16"); @@ -34,9 +34,9 @@ HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_ "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]"); TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]"); TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out"); + void *x_ptr = x.data_ptr(); void *weight_ptr = weight.data_ptr(); - void *lora_indices_ptr = lora_indices.data_ptr(); int lora_indices_size = lora_indices.size(0); void *seq_len_ptr = seq_len.data_ptr(); @@ -49,20 +49,21 @@ HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_ void *y_ptr = y.data_ptr(); int batch_size = x.size(0); int input_hidden_token = x.size(1); - uint32_t max_lora_rank = y.size(1); + uint32_t max_lora_rank = y.size(1) / slices; + uint32_t slice_count = slices; uint32_t block_dim; uint32_t workspace_size; at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, input_hidden_token, max_lora_rank, - TorchNpuHelper::ConvertDataType(scalar_type)); + slice_count, TorchNpuHelper::ConvertDataType(scalar_type)); auto workspace_tensor = at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(x.options().device())); /* 
launch the kernel function via torch */ EXEC_KERNEL_CMD(sgemmc_shrink, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, lora_ranks_ptr, lora_ranks_size, lora_scales_ptr, lora_scales_size, y_ptr, batch_size, - input_hidden_token, max_lora_rank, workspace_tensor, tiling_tensor); + input_hidden_token, max_lora_rank, slice_count, workspace_tensor, tiling_tensor); return; } diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling.cpp b/csrc/lora/op_host/tiling/sgemmc_tiling.cpp index 69c5d4d3b..d403084bd 100644 --- a/csrc/lora/op_host/tiling/sgemmc_tiling.cpp +++ b/csrc/lora/op_host/tiling/sgemmc_tiling.cpp @@ -36,7 +36,7 @@ matmul_tiling::DataType ConvertToMatMulTypes(host_utils::DataType data_type) } at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t inner_size, - uint32_t output_size, const host_utils::DataType type) + uint32_t output_size, uint32_t slice_count, const host_utils::DataType type) { auto ascendc_platform = *platform_ascendc::PlatformAscendCManager::GetInstance(); uint32_t aiv_num = ascendc_platform.GetCoreNumAiv(); @@ -46,7 +46,7 @@ at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_ auto tilingBuffer = at::empty({sizeof(SGEMMCTilingData)}, at::TensorOptions().dtype(at::kByte).device(at::kCPU)); SGEMMCTilingData *tiling_data = reinterpret_cast(tilingBuffer.data_ptr()); - matmul_tiling::MultiCoreMatmulTiling cubeTiling(ascendc_platform); + matmul_tiling::MatmulApiTiling cubeTiling(ascendc_platform); const matmul_tiling::DataType data_type = ConvertToMatMulTypes(type); @@ -56,11 +56,9 @@ at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_ cubeTiling.SetCType(matmul_tiling::TPosition::VECIN, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); cubeTiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type); - cubeTiling.EnableMultiCoreSplitK(false); - 
cubeTiling.SetDim(aic_num); - cubeTiling.SetOrgShape(1, inner_size, output_size); - cubeTiling.SetShape(1, inner_size, output_size); + cubeTiling.SetOrgShape(1, output_size, inner_size); + cubeTiling.SetShape(1, output_size, inner_size); cubeTiling.SetBufferSpace(-1, -1, -1); if (cubeTiling.GetTiling(tiling_data->cubeTiling) == -1) { @@ -68,12 +66,14 @@ at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_ return {}; } - tiling_data->batch = batch_size; - tiling_data->dataType = (type == host_utils::DataType::DT_BFLOAT16); + tiling_data->tilingKey = (type == host_utils::DataType::DT_BFLOAT16); - block_dim = batch_size * tiling_data->cubeTiling.usedCoreNum; + block_dim = batch_size * slice_count; + workspace_size = ascendc_platform.GetLibApiWorkSpaceSize() + + static_cast(batch_size * tiling_data->cubeTiling.baseM * tiling_data->cubeTiling.baseN * + sizeof(float)); - return tilingBuffer; + return TorchNpuHelper::CopyTensorHostToDevice(tilingBuffer); } } // namespace npu_kernel diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling.h b/csrc/lora/op_host/tiling/sgemmc_tiling.h index 075bf6594..e0accfd67 100644 --- a/csrc/lora/op_host/tiling/sgemmc_tiling.h +++ b/csrc/lora/op_host/tiling/sgemmc_tiling.h @@ -29,7 +29,7 @@ namespace sglang { namespace npu_kernel { at::Tensor GenerateTiling(uint32_t &blockDim, uint32_t &workspace, uint32_t batch, uint32_t hidden_size, uint32_t k, - const host_utils::DataType type); + uint32_t slice_count, const host_utils::DataType type); } // namespace npu_kernel } // namespace sglang diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling_data.h b/csrc/lora/op_host/tiling/sgemmc_tiling_data.h index 88c99cd8a..559608625 100644 --- a/csrc/lora/op_host/tiling/sgemmc_tiling_data.h +++ b/csrc/lora/op_host/tiling/sgemmc_tiling_data.h @@ -33,11 +33,7 @@ namespace npu_kernel { #pragma pack(push, 1) struct SGEMMCTilingData { - uint32_t dataType; - uint32_t batch; - uint32_t hidden; - uint32_t k; - uint32_t slices; + uint32_t 
tilingKey; AscendC::tiling::TCubeTiling cubeTiling; }; #pragma pack(pop) diff --git a/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp b/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp index 32bb42ad9..fd8b571ee 100644 --- a/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp +++ b/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp @@ -38,7 +38,7 @@ class SGEMMCExpand using X_MAT_TYPE = AscendC::MatmulType; using W_MAT_TYPE = AscendC::MatmulType; using Y_MAT_TYPE = AscendC::MatmulType; - using BIAS_MAT_TYPE = AscendC::MatmulType; + using BIAS_MAT_TYPE = AscendC::MatmulType; using MAT_TYPE = AscendC::Matmul; @@ -67,9 +67,7 @@ class SGEMMCExpand loraRanksGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraRanks), loraRanksSize); sliceOffsetsGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sliceOffsets), sliceOffsetsSize); - workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(workspace)); - - REGIST_MATMUL_OBJ(pipe_, GetSysWorkSpacePtr(), matmulObj, &tiling); + workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ INNER_T *>(workspace)); } __aicore__ inline void Process() @@ -77,88 +75,112 @@ class SGEMMCExpand int64_t blocks = AscendC::GetBlockNum(); int64_t blockIdx = AscendC::GetBlockIdx(); - AscendC::WaitPreTaskEnd(); + if ASCEND_IS_AIV { + if (AscendC::GetSubBlockIdx() == 1) { + return; + } + blockIdx /= AscendC::GetSubBlockNum(); + } + + int64_t tokenIdx = blockIdx / sliceCount_; + int64_t sliceIdx = blockIdx % sliceCount_; - int64_t requestBlock = 0; lora_common::BlockIterator blockIterator(seqLenGm_); - requestBlock = blockIterator.GetBlockIdx(blockIdx); + int64_t requestBlock = blockIterator.GetBlockIdx(tokenIdx); if (requestBlock < 0) { return; } - int32_t reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { return; } - int64_t reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_; - int32_t reqLoRARank_ = loraRanksGm_.GetValue(reqLoRAIndex_); + 
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_; + reqLoRARank_ = loraRanksGm_.GetValue(reqLoRAIndex_); if (reqLoRARank_ == 0) { return; } + int32_t beginSlice = sliceOffsetsGm_.GetValue(sliceIdx); + int32_t endSlice = sliceOffsetsGm_.GetValue(sliceIdx + 1); + int32_t slice = endSlice - beginSlice; + uint32_t baseM = min(tiling.baseM, tiling.singleCoreM); + uint32_t baseN = min(tiling.baseN, min(tiling.singleCoreN, slice)); + uint32_t elements = baseM * baseN; + uint32_t maxElements = tiling.baseM * tiling.baseN; + + workspaceGlobal = workspaceGlobal[blockIdx * maxElements]; + + REGIST_MATMUL_OBJ(pipe_, GetSysWorkSpacePtr(), matmulObj, &tiling); + + matmulObj.DisableBias(); matmulObj.SetWorkspace(workspaceGlobal); - matmulObj.SetTensorA(xInGm_); - matmulObj.SetTensorB(wInGm_); + matmulObj.SetOrgShape(tiling.M, tiling.N, tiling.Ka, tiling.Kb); + matmulObj.SetSingleShape(tiling.singleCoreM, slice, reqLoRARank_); + matmulObj.SetTensorA(xInGm_[tokenIdx * sliceCount_ * maxLoRARank_ + sliceIdx * reqLoRARank_], false); + matmulObj.SetTensorB(wInGm_[reqLoRAWeightOffset_ + maxLoRARank_ * beginSlice], true); matmulObj.template Iterate(); - uint32_t baseM = tiling.baseM; - uint32_t baseN = tiling.baseN; - pipe_->InitBuffer(vectorCalcBuf, baseM * baseN * sizeof(INNER_T)); - pipe_->InitBuffer(vectorInQueue, 1, baseM * baseN * sizeof(INNER_T)); - pipe_->InitBuffer(vectorYInQueue, 1, baseM * baseN * sizeof(INNER_T)); - pipe_->InitBuffer(vectorOutQueue, 1, baseM * baseN * sizeof(Y_T)); + pipe_->InitBuffer(calcBuf, maxElements * sizeof(INNER_T)); + pipe_->InitBuffer(matmulQueue, 1, maxElements * sizeof(INNER_T)); + pipe_->InitBuffer(vectorYInQueue, 1, maxElements * sizeof(Y_T)); + pipe_->InitBuffer(vectorOutQueue, 1, maxElements * sizeof(Y_T)); AscendC::DataCopyParams copyParams = {(uint16_t)baseM, (uint16_t)(baseN * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE), (uint16_t)0, (uint16_t)((tiling.N - baseN) * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE)}; - uint32_t iterateTimes = 
AscendC::Ceil(tiling.singleCoreM, baseM) * AscendC::Ceil(tiling.singleCoreN, baseN); + uint32_t iterateTimes = AscendC::Ceil(tiling.singleCoreM, baseM) * AscendC::Ceil(slice, baseN); + uint32_t outputOffset = tokenIdx * tiling.N + beginSlice; for (uint32_t i = 0; i < iterateTimes; ++i) { - auto cInLocal = vectorInQueue.AllocTensor(); + uint32_t offset = outputOffset + i * baseN; + auto cInLocal = matmulQueue.AllocTensor(); matmulObj.template GetTensorC(cInLocal); - vectorInQueue.EnQue(cInLocal); + matmulObj.WaitGetTensorC(); + matmulQueue.EnQue(cInLocal); AscendC::LocalTensor yInLocalCube = vectorYInQueue.AllocTensor(); - DataCopy(yInLocalCube, yInGm_[i], baseM * baseN); + DataCopy(yInLocalCube, yInGm_[offset], elements); vectorYInQueue.EnQue(yInLocalCube); - AscendC::LocalTensor tmpTensor = vectorCalcBuf.Get(); + AscendC::LocalTensor tmpTensor = calcBuf.Get(); AscendC::LocalTensor yInLocal = vectorYInQueue.DeQue(); - AscendC::LocalTensor yLocal = vectorInQueue.DeQue(); - Cast(tmpTensor, yInLocal, AscendC::RoundMode::CAST_NONE, baseM * baseN); - pipe_barrier(PIPE_V); + AscendC::Cast(tmpTensor, yInLocal, AscendC::RoundMode::CAST_NONE, elements); + AscendC::PipeBarrier(); vectorYInQueue.FreeTensor(yInLocal); - Add(yLocal, yLocal, tmpTensor, baseM * baseN); - pipe_barrier(PIPE_V); + AscendC::LocalTensor yLocal = matmulQueue.DeQue(); + AscendC::Add(tmpTensor, tmpTensor, yLocal, elements); + AscendC::PipeBarrier(); AscendC::LocalTensor yOutLocal = vectorOutQueue.AllocTensor(); - Cast(yOutLocal, yLocal, AscendC::RoundMode::CAST_RINT, baseM * baseN); - pipe_barrier(PIPE_V); + AscendC::Cast(yOutLocal, tmpTensor, AscendC::RoundMode::CAST_RINT, elements); + AscendC::PipeBarrier(); vectorOutQueue.EnQue(yOutLocal); + calcBuf.FreeTensor(tmpTensor); + matmulQueue.FreeTensor(yLocal); - // copy out - auto cOutLocal = vectorOutQueue.DeQue(); - DataCopy(yOutGm_[i], cOutLocal, copyParams); - vectorOutQueue.FreeTensor(cOutLocal); + AscendC::LocalTensor outputCopy = 
vectorOutQueue.DeQue(); + DataCopy(yOutGm_[offset], outputCopy, copyParams); + vectorOutQueue.FreeTensor(outputCopy); } matmulObj.End(); - AscendC::SetNextTaskStart(); } private: AscendC::TPipe *pipe_; + MAT_TYPE matmulObj; + TCubeTiling tiling; - AscendC::GlobalTensor workspaceGlobal; + AscendC::GlobalTensor workspaceGlobal; - TCubeTiling tiling; - AscendC::TQue vectorInQueue; + AscendC::TQue matmulQueue; AscendC::TQue vectorYInQueue; AscendC::TQue vectorOutQueue; - AscendC::TBuf vectorCalcBuf; + AscendC::TBuf calcBuf; AscendC::GlobalTensor xInGm_; AscendC::GlobalTensor wInGm_; @@ -167,7 +189,6 @@ class SGEMMCExpand AscendC::GlobalTensor seqLenGm_; AscendC::GlobalTensor loraIndicesGm_; - AscendC::GlobalTensor loraRanksGm_; AscendC::GlobalTensor sliceOffsetsGm_; @@ -200,7 +221,7 @@ extern "C" __global__ __aicore__ void sgemmc_expand(GM_ADDR x, GM_ADDR weight, G sglang::npu_kernel::SGEMMCTilingData tilingData; kernel_utils::CopyTiling(&tilingData, tiling); - if (tilingData.dataType == 1) { + if (tilingData.tilingKey == 1) { SGEMMCExpand op(&pipe); op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets, sliceOffsetsSize, yIn, yOut, batchSize, maxLoRARank, outputFullDim, workspace, tilingData.cubeTiling); diff --git a/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp b/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp index 0d7f69bb4..942f0c2dd 100644 --- a/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp +++ b/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp @@ -26,6 +26,8 @@ #include "../op_host/tiling/sgemmc_tiling_data.h" +constexpr uint32_t BLOCK_SIZE = 16U; + template class SGEMMCShrink { @@ -38,7 +40,7 @@ class SGEMMCShrink using X_MAT_TYPE = AscendC::MatmulType; using W_MAT_TYPE = AscendC::MatmulType; using Y_MAT_TYPE = AscendC::MatmulType; - using BIAS_MAT_TYPE = AscendC::MatmulType; + using BIAS_MAT_TYPE = AscendC::MatmulType; using MAT_TYPE = AscendC::Matmul; @@ -47,14 +49,16 @@ class SGEMMCShrink __aicore__ inline void 
Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR loraScales, uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize, - uint32_t inputHiddenDim, uint32_t maxLoRARank, GM_ADDR workspace, TCubeTiling &tiling) + uint32_t inputHiddenDim, uint32_t maxLoRARank, uint32_t slices, GM_ADDR workspace, + TCubeTiling &tiling) { this->tiling = tiling; + slices_ = slices; batchSize_ = batchSize; inputHiddenDim_ = inputHiddenDim; maxLoRARank_ = maxLoRARank; - singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_; + singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_ * slices_; xInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ X_T *>(x)); yOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(y)); @@ -64,84 +68,103 @@ class SGEMMCShrink loraRanksGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraRanks), loraRanksSize); loraScalesGm_.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(loraScales), loraScalesSize); - workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(workspace)); - - REGIST_MATMUL_OBJ(pipe_, GetSysWorkSpacePtr(), matmulObj, &tiling); + workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ INNER_T *>(workspace)); } __aicore__ inline void Process() { + if (GetSysWorkSpacePtr() == nullptr) { + return; + } + int64_t blocks = AscendC::GetBlockNum(); int64_t blockIdx = AscendC::GetBlockIdx(); - AscendC::WaitPreTaskEnd(); + if ASCEND_IS_AIV { + if (AscendC::GetSubBlockIdx() == 1) { + return; + } + blockIdx /= AscendC::GetSubBlockNum(); + } + + int64_t tokenIdx = blockIdx / slices_; + int64_t sliceIdx = blockIdx % slices_; - int64_t batchIdx = 0; - int64_t requestBlock = 0; lora_common::BlockIterator blockIterator(seqLenGm_); - requestBlock = blockIterator.GetBlockIdx(batchIdx); + int64_t requestBlock = blockIterator.GetBlockIdx(tokenIdx); if (requestBlock < 0) { return; } - int32_t reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); + 
reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { return; } - int64_t reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_; - int32_t reqLoRARank_ = loraRanksGm_.GetValue(reqLoRAIndex_); + reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_; + reqLoRARank_ = loraRanksGm_.GetValue(reqLoRAIndex_); + reqLoRAScale_ = loraScalesGm_.GetValue(reqLoRAIndex_); if (reqLoRARank_ == 0) { return; } + uint32_t baseM = min(tiling.baseM, tiling.singleCoreM); + uint32_t baseN = min(tiling.baseN, min(tiling.singleCoreN, reqLoRARank_)); + uint32_t elements = baseM * baseN; + uint32_t maxElements = tiling.baseM * tiling.baseN; + + workspaceGlobal = workspaceGlobal[blockIdx * maxElements]; + + REGIST_MATMUL_OBJ(pipe_, GetSysWorkSpacePtr(), matmulObj, &tiling); + + matmulObj.DisableBias(); matmulObj.SetWorkspace(workspaceGlobal); - matmulObj.SetTensorA(xInGm_); - matmulObj.SetTensorB(wInGm_); + matmulObj.SetOrgShape(tiling.M, tiling.N, tiling.Ka, tiling.Kb); + matmulObj.SetSingleShape(tiling.singleCoreM, reqLoRARank_, tiling.singleCoreK); + matmulObj.SetTensorA(xInGm_[tokenIdx * inputHiddenDim_], false); + matmulObj.SetTensorB(wInGm_[reqLoRAWeightOffset_ + sliceIdx * inputHiddenDim_ * reqLoRARank_], true); matmulObj.template Iterate(); - half loraScale = loraScalesGm_.GetValue(reqLoRAIndex_); - INNER_T scalar = AscendC::ScalarCast(loraScale); - - uint32_t baseM = this->tiling.baseM; - uint32_t baseN = this->tiling.baseN; - pipe_->InitBuffer(vectorCalcBuf, baseM * baseN * sizeof(INNER_T)); - pipe_->InitBuffer(vectorInQueue, 1, baseM * baseN * sizeof(INNER_T)); - pipe_->InitBuffer(vectorOutQueue, 1, baseM * baseN * sizeof(Y_T)); + pipe_->InitBuffer(calcBuf, maxElements * sizeof(INNER_T)); + pipe_->InitBuffer(matmulQueue, 1, maxElements * sizeof(INNER_T)); + pipe_->InitBuffer(outQueue, 1, maxElements * sizeof(Y_T)); AscendC::DataCopyParams copyParams = { (uint16_t)baseM, (uint16_t)(baseN * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE), (uint16_t)0, - 
(uint16_t)((this->tiling.N - baseN) * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE)}; - uint32_t iterateTimes = - AscendC::Ceil(this->tiling.singleCoreM, baseM) * AscendC::Ceil(this->tiling.singleCoreN, baseN); - for (uint32_t i = 0; i < iterateTimes; ++i) { - // compute - auto cInLocal = vectorInQueue.AllocTensor(); + (uint16_t)((slices_ * tiling.N - baseN) * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE)}; + uint32_t iteratations = AscendC::Ceil(tiling.singleCoreM, baseM) * AscendC::Ceil(reqLoRARank_, baseN); + uint32_t outputOffset = tokenIdx * slices_ * maxLoRARank_ + sliceIdx * reqLoRARank_; + for (uint32_t i = 0; i < iteratations; ++i) { + AscendC::LocalTensor cInLocal = matmulQueue.AllocTensor(); matmulObj.template GetTensorC(cInLocal); - vectorInQueue.EnQue(cInLocal); - // any vector operator - auto src = vectorInQueue.DeQue(); - auto dst = vectorOutQueue.AllocTensor(); + matmulObj.WaitGetTensorC(); + matmulQueue.EnQue(cInLocal); - AscendC::LocalTensor tmpTensor = vectorCalcBuf.Get(); - AscendC::Muls(tmpTensor, src, scalar, baseM * baseN); + AscendC::LocalTensor tmpTensor = calcBuf.Get(); + AscendC::LocalTensor mmResTensor = matmulQueue.DeQue(); + AscendC::LocalTensor output = outQueue.AllocTensor(); + + AscendC::Muls(tmpTensor, mmResTensor, reqLoRAScale_, elements); AscendC::PipeBarrier(); - AscendC::Cast(dst, tmpTensor, AscendC::RoundMode::CAST_NONE, baseM * baseN); + AscendC::Cast(output, tmpTensor, AscendC::RoundMode::CAST_RINT, elements); AscendC::PipeBarrier(); - vectorOutQueue.EnQue(dst); - vectorInQueue.FreeTensor(src); - // copy out - auto cOutLocal = vectorOutQueue.DeQue(); - DataCopy(yOutGm_[i], cOutLocal, copyParams); - vectorOutQueue.FreeTensor(cOutLocal); + + outQueue.EnQue(output); + matmulQueue.FreeTensor(mmResTensor); + calcBuf.FreeTensor(tmpTensor); + + AscendC::LocalTensor outputCopy = outQueue.DeQue(); + DataCopy(yOutGm_[outputOffset + i * baseN], outputCopy, copyParams); + outQueue.FreeTensor(outputCopy); } matmulObj.End(); - 
AscendC::SetNextTaskStart(); } private: AscendC::TPipe *pipe_; + + TCubeTiling tiling; MAT_TYPE matmulObj; AscendC::GlobalTensor xInGm_; @@ -152,13 +175,13 @@ class SGEMMCShrink AscendC::GlobalTensor loraRanksGm_; AscendC::GlobalTensor loraScalesGm_; - AscendC::GlobalTensor workspaceGlobal; + AscendC::GlobalTensor workspaceGlobal; - TCubeTiling tiling; - AscendC::TQue vectorInQueue; - AscendC::TQue vectorOutQueue; - AscendC::TBuf vectorCalcBuf; + AscendC::TBuf calcBuf; + AscendC::TQue matmulQueue; + AscendC::TQue outQueue; + uint32_t slices_; uint32_t batchSize_; uint32_t inputHiddenDim_; uint32_t maxLoRARank_; @@ -167,14 +190,15 @@ class SGEMMCShrink uint64_t reqLoRAWeightOffset_; int32_t reqLoRAIndex_; int32_t reqLoRARank_; + INNER_T reqLoRAScale_; }; extern "C" __global__ __aicore__ void sgemmc_shrink(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR loraScales, uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize, - uint32_t inputHiddenDim, uint32_t maxLoRARank, GM_ADDR workspace, - GM_ADDR tiling) + uint32_t inputHiddenDim, uint32_t maxLoRARank, uint32_t slices, + GM_ADDR workspace, GM_ADDR tiling) { KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1); @@ -182,15 +206,15 @@ extern "C" __global__ __aicore__ void sgemmc_shrink(GM_ADDR x, GM_ADDR weight, G sglang::npu_kernel::SGEMMCTilingData tilingData; kernel_utils::CopyTiling(&tilingData, tiling); - if (tilingData.dataType == 1) { + if (tilingData.tilingKey == 1) { SGEMMCShrink op(&pipe); op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales, - loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, workspace, tilingData.cubeTiling); + loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, slices, workspace, tilingData.cubeTiling); op.Process(); } else { SGEMMCShrink op(&pipe); op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, 
seqLenSize, loraRanks, loraRanksSize, loraScales, - loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, workspace, tilingData.cubeTiling); + loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, slices, workspace, tilingData.cubeTiling); op.Process(); } } diff --git a/csrc/pytorch_extensions.cpp b/csrc/pytorch_extensions.cpp index 4a6810cf6..2ad55bdd8 100644 --- a/csrc/pytorch_extensions.cpp +++ b/csrc/pytorch_extensions.cpp @@ -105,7 +105,7 @@ TORCH_LIBRARY_FRAGMENT(npu, m) m.def( "sgemmc_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! lora_ranks," - " Tensor! lora_scales, Tensor! y) -> ()"); + " Tensor! lora_scales, Tensor! y, int slice_count) -> ()"); #ifdef BUILD_CATLASS_MODULE m.def("catlass_matmul_basic(Tensor tensor_a, Tensor tensor_b, Tensor(a!) tensor_c, str? format_mode=None) -> ()"); diff --git a/include/sgl_kenel_npu_ops.h b/include/sgl_kenel_npu_ops.h index a156d93b1..578c4ef40 100644 --- a/include/sgl_kenel_npu_ops.h +++ b/include/sgl_kenel_npu_ops.h @@ -113,7 +113,7 @@ at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, at::Tensor &lora_ranks, - at::Tensor &lora_scales, at::Tensor &y); + at::Tensor &lora_scales, at::Tensor &y, int64_t slice_count); #ifdef BUILD_CATLASS_MODULE void catlass_matmul_basic(const at::Tensor &tensor_a,