diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index c5e33117b..fdb3938be 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -19,6 +19,9 @@ FILE(GLOB OP_SRCS ${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_shrink.cpp ${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmv_expand.cpp ${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmv_shrink.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmc_expand.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmc_shrink.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_host/tiling/sgemmc_tiling.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp ${PROJECT_OP_SRC_BASE}/tri_inv/op_host/tri_inv.cpp @@ -51,6 +54,10 @@ ascendc_library(no_workspace_kernel STATIC ${PROJECT_OP_SRC_BASE}/recurrent_gated_delta_rule/op_kernel/recurrent_gated_delta_rule_kernel.cpp ) +ascendc_include_directories(no_workspace_kernel PRIVATE + ${PROJECT_OP_SRC_BASE}/utils/kernel +) + # kernel side files with workspace set(WORKSPACE_KERNEL_SRCS ${PROJECT_OP_SRC_BASE}/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp @@ -59,6 +66,8 @@ set(WORKSPACE_KERNEL_SRCS ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp ${PROJECT_OP_SRC_BASE}/causal_conv1d_update/op_kernel/causal_conv1d_update.cpp ${PROJECT_OP_SRC_BASE}/causal_conv1d/op_kernel/causal_conv1d.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_kernel/sgemmc_expand_kernel.cpp + ${PROJECT_OP_SRC_BASE}/lora/op_kernel/sgemmc_shrink_kernel.cpp ) if(BUILD_CATLASS_MODULE) list(APPEND WORKSPACE_KERNEL_SRCS @@ -76,6 +85,10 @@ if(BUILD_CATLASS_MODULE) ) endif() +ascendc_include_directories(workspace_kernel PRIVATE + ${PROJECT_OP_SRC_BASE}/utils/kernel +) + ascendc_compile_definitions(workspace_kernel PRIVATE -DHAVE_WORKSPACE -DHAVE_TILING @@ -115,6 +128,7 @@ target_include_directories(${OP_PLUGIN_NAME} PRIVATE ${TORCH_DIR}/include ${TORCH_DIR}/include/torch/csrc/api/include ${TORCH_NPU_DIR}/include + ${ASCEND_INCLUDE_DIR} 
${ASCEND_INCLUDE_DIR}/external ${ASCEND_INCLUDE_DIR}/experiment/platform ${ASCEND_INCLUDE_DIR}/experiment/runtime diff --git a/csrc/lora/op_host/sgemmc_expand.cpp b/csrc/lora/op_host/sgemmc_expand.cpp new file mode 100644 index 000000000..aefb902d5 --- /dev/null +++ b/csrc/lora/op_host/sgemmc_expand.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "defines.h" +#include "tiling/sgemmc_tiling.h" +#include "torch_helper.h" + +#include "aclrtlaunch_sgemmc_expand.h" + +namespace sglang { +namespace npu_kernel { + +HOST_API at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, + at::Tensor &lora_ranks, at::Tensor &slice_offsets, at::Tensor &y) +{ + at::ScalarType scalar_type = y.scalar_type(); + TORCH_CHECK(scalar_type == at::kHalf || scalar_type == at::kBFloat16, "only support half and bf16"); + TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]"); + TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4, + "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]"); + TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]"); + TORCH_CHECK(slice_offsets.dim() == 1 && slice_offsets.size(0) > 1, + "slice_offsets should be a vector of size 2 and more."); + TORCH_CHECK(lora_ranks.dim() == 1, "lora_ranks should be a vector."); + + at::Tensor 
y_out = y; + void *x_ptr = x.data_ptr(); + void *weight_ptr = weight.data_ptr(); + void *y_ptr = y.data_ptr(); + void *y_out_ptr = y_out.data_ptr(); + + void *lora_indices_ptr = lora_indices.data_ptr(); + int lora_indices_size = lora_indices.size(0); + void *seq_len_ptr = seq_len.data_ptr(); + int seq_len_size = seq_len.size(0); + void *lora_ranks_ptr = lora_ranks.data_ptr(); + int lora_ranks_size = lora_ranks.size(0); + void *slice_offsets_ptr = slice_offsets.data_ptr(); + int slice_offsets_size = slice_offsets.size(0); + int slice_count = slice_offsets_size - 1; + int batch_size = x.size(0); + int max_lora_rank = x.size(1) / slice_count; + int output_full_dim = y.size(1); + + uint32_t block_dim; + uint32_t workspace_size; + + at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, max_lora_rank, output_full_dim, + slice_count, TorchNpuHelper::ConvertDataType(scalar_type)); + auto workspace_tensor = + at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(x.options().device())); + + /* launch the kernel function via torch */ + EXEC_KERNEL_CMD(sgemmc_expand, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, + seq_len_size, lora_ranks_ptr, lora_ranks_size, slice_offsets_ptr, slice_offsets_size, y_ptr, + y_out_ptr, batch_size, max_lora_rank, output_full_dim, workspace_tensor, tiling_tensor); + + return y_out; +} + +} // namespace npu_kernel +} // namespace sglang diff --git a/csrc/lora/op_host/sgemmc_shrink.cpp b/csrc/lora/op_host/sgemmc_shrink.cpp new file mode 100644 index 000000000..d9bd21270 --- /dev/null +++ b/csrc/lora/op_host/sgemmc_shrink.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "defines.h" +#include "tiling/sgemmc_tiling.h" +#include "torch_helper.h" + +#include "aclrtlaunch_sgemmc_shrink.h" + +namespace sglang { +namespace npu_kernel { + +HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len, + at::Tensor &lora_ranks, at::Tensor &lora_scales, at::Tensor &y, int64_t slices) +{ + at::ScalarType scalar_type = x.scalar_type(); + TORCH_CHECK(scalar_type == at::kHalf || scalar_type == at::kBFloat16, "only support half and bf16"); + TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]"); + TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4, + "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]"); + TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]"); + TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out"); + + void *x_ptr = x.data_ptr(); + void *weight_ptr = weight.data_ptr(); + void *lora_indices_ptr = lora_indices.data_ptr(); + int lora_indices_size = lora_indices.size(0); + void *seq_len_ptr = seq_len.data_ptr(); + int seq_len_size = seq_len.size(0); + void *lora_ranks_ptr = lora_ranks.data_ptr(); + int lora_ranks_size = lora_ranks.size(0); + void *lora_scales_ptr = lora_scales.data_ptr(); + int lora_scales_size = lora_scales.size(0); + + void *y_ptr = y.data_ptr(); + int batch_size = x.size(0); + int input_hidden_token = x.size(1); + uint32_t max_lora_rank = y.size(1) / slices; + uint32_t slice_count = slices; + + uint32_t block_dim; + uint32_t 
workspace_size; + + at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, input_hidden_token, max_lora_rank, + slice_count, TorchNpuHelper::ConvertDataType(scalar_type)); + + auto workspace_tensor = + at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(x.options().device())); + /* launch the kernel function via torch */ + EXEC_KERNEL_CMD(sgemmc_shrink, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, + seq_len_size, lora_ranks_ptr, lora_ranks_size, lora_scales_ptr, lora_scales_size, y_ptr, batch_size, + input_hidden_token, max_lora_rank, slice_count, workspace_tensor, tiling_tensor); + return; +} + +} // namespace npu_kernel +} // namespace sglang diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling.cpp b/csrc/lora/op_host/tiling/sgemmc_tiling.cpp new file mode 100644 index 000000000..d403084bd --- /dev/null +++ b/csrc/lora/op_host/tiling/sgemmc_tiling.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#include "common.h" +#include "sgemmc_tiling.h" + +namespace sglang { +namespace npu_kernel { + +matmul_tiling::DataType ConvertToMatMulTypes(host_utils::DataType data_type) +{ + switch (data_type) { + case host_utils::DataType::DT_BFLOAT16: + return matmul_tiling::DataType::DT_BFLOAT16; + case host_utils::DataType::DT_FLOAT: + return matmul_tiling::DataType::DT_FLOAT; + case host_utils::DataType::DT_FLOAT16: + return matmul_tiling::DataType::DT_FLOAT16; + } + + return matmul_tiling::DataType::DT_FLOAT16; +} + +at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t inner_size, + uint32_t output_size, uint32_t slice_count, const host_utils::DataType type) +{ + auto ascendc_platform = *platform_ascendc::PlatformAscendCManager::GetInstance(); + uint32_t aiv_num = ascendc_platform.GetCoreNumAiv(); + uint32_t aic_num = ascendc_platform.GetCoreNumAic(); + workspace_size = ascendc_platform.GetLibApiWorkSpaceSize(); + + auto tilingBuffer = at::empty({sizeof(SGEMMCTilingData)}, at::TensorOptions().dtype(at::kByte).device(at::kCPU)); + SGEMMCTilingData *tiling_data = reinterpret_cast(tilingBuffer.data_ptr()); + + matmul_tiling::MatmulApiTiling cubeTiling(ascendc_platform); + + const matmul_tiling::DataType data_type = ConvertToMatMulTypes(type); + + cubeTiling.EnableBias(false); + cubeTiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::VECTOR, data_type, false); + cubeTiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type, true); + cubeTiling.SetCType(matmul_tiling::TPosition::VECIN, matmul_tiling::CubeFormat::ND, + matmul_tiling::DataType::DT_FLOAT); + cubeTiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type); + + cubeTiling.SetOrgShape(1, output_size, inner_size); + cubeTiling.SetShape(1, output_size, inner_size); + cubeTiling.SetBufferSpace(-1, -1, -1); + + if (cubeTiling.GetTiling(tiling_data->cubeTiling) == -1) { + 
TORCH_CHECK(false, "Generate tiling failed."); + return {}; + } + + tiling_data->tilingKey = (type == host_utils::DataType::DT_BFLOAT16); + + block_dim = batch_size * slice_count; + workspace_size = ascendc_platform.GetLibApiWorkSpaceSize() + + static_cast(batch_size * tiling_data->cubeTiling.baseM * tiling_data->cubeTiling.baseN * + sizeof(float)); + + return TorchNpuHelper::CopyTensorHostToDevice(tilingBuffer); +} + +} // namespace npu_kernel +} // namespace sglang diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling.h b/csrc/lora/op_host/tiling/sgemmc_tiling.h new file mode 100644 index 000000000..e0accfd67 --- /dev/null +++ b/csrc/lora/op_host/tiling/sgemmc_tiling.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#ifndef SGEMMC_TILING_H +#define SGEMMC_TILING_H + +#include +#include + +#include "torch_helper.h" +#include "common_tiling.h" +#include "sgemmc_tiling_data.h" + +namespace sglang { +namespace npu_kernel { + +at::Tensor GenerateTiling(uint32_t &blockDim, uint32_t &workspace, uint32_t batch, uint32_t hidden_size, uint32_t k, + uint32_t slice_count, const host_utils::DataType type); + +} // namespace npu_kernel +} // namespace sglang + +#endif // SGEMMC_TILING_H diff --git a/csrc/lora/op_host/tiling/sgemmc_tiling_data.h b/csrc/lora/op_host/tiling/sgemmc_tiling_data.h new file mode 100644 index 000000000..559608625 --- /dev/null +++ b/csrc/lora/op_host/tiling/sgemmc_tiling_data.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
// Tiling blob exchanged between the host tiling code and the sgemmc kernels. The host fills it in a CPU byte
// buffer of exactly sizeof(SGEMMCTilingData) and copies it to the device; the kernels copy it back into a local
// instance. It is packed to 1 byte so that no padding is inserted between the two members.
// NOTE(review): this header only forward-declares AscendC::tiling::TCubeTiling, yet the member below is held by
// value and therefore needs the complete type - presumably every includer pulls in the definition first; confirm.
#pragma pack(push, 1)
struct SGEMMCTilingData {
    // 1 when the tiling was generated for bfloat16 operands, 0 otherwise; the kernels branch on it.
    uint32_t tilingKey;
    // Matmul tiling produced by the host-side tiling API.
    AscendC::tiling::TCubeTiling cubeTiling;
};
#pragma pack(pop)
+ * + */ + +#ifndef SGL_KERNEL_NPU_KERNEL_LORA_COMMON_H +#define SGL_KERNEL_NPU_KERNEL_LORA_COMMON_H + +#include "kernel_operator.h" + +namespace lora_common { + +template +class BlockIterator +{ + AscendC::GlobalTensor blocks; + int64_t previous_block; + int64_t previous_offset; + +public: + __aicore__ explicit BlockIterator(AscendC::GlobalTensor &blocks_) + : blocks(blocks_), previous_block(0), previous_offset(0) + {} + __aicore__ inline int64_t GetBlockIdx(int64_t index) + { + int64_t current_offset = previous_offset; + uint64_t blockIdx = previous_block; + + for (; blockIdx < blocks.GetSize(); ++blockIdx) { + int64_t blockOffset = blocks.GetValue(blockIdx); + if (index >= current_offset + blockOffset) { + current_offset += blockOffset; + } else { + previous_offset = current_offset; + previous_block = blockIdx; + return blockIdx; + } + } + + return -1; + } +}; + +} // namespace lora_common + +#endif // SGL_KERNEL_NPU_KERNEL_LORA_COMMON_H diff --git a/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp b/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp new file mode 100644 index 000000000..fd8b571ee --- /dev/null +++ b/csrc/lora/op_kernel/sgemmc_expand_kernel.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +#ifndef SGL_KERNEL_NPU_KERNEL_SGEMMC_EXPAND_H +#define SGL_KERNEL_NPU_KERNEL_SGEMMC_EXPAND_H + +#include +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lora_common_kernel.h" +#include "common_tiling_kernel.h" + +#include "../op_host/tiling/sgemmc_tiling_data.h" + +template +class SGEMMCExpand +{ +public: + using X_T = scalar_t; + using W_T = scalar_t; + using INNER_T = inner_t; + using Y_T = scalar_t; + + using X_MAT_TYPE = AscendC::MatmulType; + using W_MAT_TYPE = AscendC::MatmulType; + using Y_MAT_TYPE = AscendC::MatmulType; + using BIAS_MAT_TYPE = AscendC::MatmulType; + + using MAT_TYPE = AscendC::Matmul; + +public: + __aicore__ explicit SGEMMCExpand(AscendC::TPipe *pipe) : pipe_(pipe) {} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize, + GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, + GM_ADDR sliceOffsets, uint32_t sliceOffsetsSize, GM_ADDR yIn, GM_ADDR yOut, + uint32_t batchSize, uint32_t maxLoRARank, uint32_t outputFullDim, GM_ADDR workspace, + TCubeTiling &tiling) + { + this->tiling = tiling; + + batchSize_ = batchSize; + maxLoRARank_ = maxLoRARank; + sliceCount_ = sliceOffsetsSize - 1; + outputFullDim_ = outputFullDim; + singleLoRAWeightLen_ = maxLoRARank_ * outputFullDim_; + + xInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ X_T *>(x)); + wInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ W_T *>(weight)); + yInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(yIn)); + yOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(yOut)); + loraIndicesGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraIndices), loraIndicesSize); + seqLenGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(seqLen), seqLenSize); + loraRanksGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraRanks), loraRanksSize); + sliceOffsetsGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sliceOffsets), sliceOffsetsSize); + + 
workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ INNER_T *>(workspace)); + } + + __aicore__ inline void Process() + { + int64_t blocks = AscendC::GetBlockNum(); + int64_t blockIdx = AscendC::GetBlockIdx(); + + if ASCEND_IS_AIV { + if (AscendC::GetSubBlockIdx() == 1) { + return; + } + blockIdx /= AscendC::GetSubBlockNum(); + } + + int64_t tokenIdx = blockIdx / sliceCount_; + int64_t sliceIdx = blockIdx % sliceCount_; + + lora_common::BlockIterator blockIterator(seqLenGm_); + int64_t requestBlock = blockIterator.GetBlockIdx(tokenIdx); + if (requestBlock < 0) { + return; + } + + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); + if (reqLoRAIndex_ < 0) { + return; + } + + reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_; + reqLoRARank_ = loraRanksGm_.GetValue(reqLoRAIndex_); + + if (reqLoRARank_ == 0) { + return; + } + + int32_t beginSlice = sliceOffsetsGm_.GetValue(sliceIdx); + int32_t endSlice = sliceOffsetsGm_.GetValue(sliceIdx + 1); + int32_t slice = endSlice - beginSlice; + uint32_t baseM = min(tiling.baseM, tiling.singleCoreM); + uint32_t baseN = min(tiling.baseN, min(tiling.singleCoreN, slice)); + uint32_t elements = baseM * baseN; + uint32_t maxElements = tiling.baseM * tiling.baseN; + + workspaceGlobal = workspaceGlobal[blockIdx * maxElements]; + + REGIST_MATMUL_OBJ(pipe_, GetSysWorkSpacePtr(), matmulObj, &tiling); + + matmulObj.DisableBias(); + matmulObj.SetWorkspace(workspaceGlobal); + matmulObj.SetOrgShape(tiling.M, tiling.N, tiling.Ka, tiling.Kb); + matmulObj.SetSingleShape(tiling.singleCoreM, slice, reqLoRARank_); + matmulObj.SetTensorA(xInGm_[tokenIdx * sliceCount_ * maxLoRARank_ + sliceIdx * reqLoRARank_], false); + matmulObj.SetTensorB(wInGm_[reqLoRAWeightOffset_ + maxLoRARank_ * beginSlice], true); + matmulObj.template Iterate(); + + pipe_->InitBuffer(calcBuf, maxElements * sizeof(INNER_T)); + pipe_->InitBuffer(matmulQueue, 1, maxElements * sizeof(INNER_T)); + pipe_->InitBuffer(vectorYInQueue, 1, maxElements * sizeof(Y_T)); + 
pipe_->InitBuffer(vectorOutQueue, 1, maxElements * sizeof(Y_T)); + + AscendC::DataCopyParams copyParams = {(uint16_t)baseM, + (uint16_t)(baseN * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE), (uint16_t)0, + (uint16_t)((tiling.N - baseN) * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE)}; + uint32_t iterateTimes = AscendC::Ceil(tiling.singleCoreM, baseM) * AscendC::Ceil(slice, baseN); + uint32_t outputOffset = tokenIdx * tiling.N + beginSlice; + for (uint32_t i = 0; i < iterateTimes; ++i) { + uint32_t offset = outputOffset + i * baseN; + auto cInLocal = matmulQueue.AllocTensor(); + matmulObj.template GetTensorC(cInLocal); + matmulObj.WaitGetTensorC(); + matmulQueue.EnQue(cInLocal); + + AscendC::LocalTensor yInLocalCube = vectorYInQueue.AllocTensor(); + DataCopy(yInLocalCube, yInGm_[offset], elements); + vectorYInQueue.EnQue(yInLocalCube); + + AscendC::LocalTensor tmpTensor = calcBuf.Get(); + AscendC::LocalTensor yInLocal = vectorYInQueue.DeQue(); + AscendC::Cast(tmpTensor, yInLocal, AscendC::RoundMode::CAST_NONE, elements); + AscendC::PipeBarrier(); + vectorYInQueue.FreeTensor(yInLocal); + + AscendC::LocalTensor yLocal = matmulQueue.DeQue(); + AscendC::Add(tmpTensor, tmpTensor, yLocal, elements); + AscendC::PipeBarrier(); + + AscendC::LocalTensor yOutLocal = vectorOutQueue.AllocTensor(); + AscendC::Cast(yOutLocal, tmpTensor, AscendC::RoundMode::CAST_RINT, elements); + AscendC::PipeBarrier(); + + vectorOutQueue.EnQue(yOutLocal); + calcBuf.FreeTensor(tmpTensor); + matmulQueue.FreeTensor(yLocal); + + AscendC::LocalTensor outputCopy = vectorOutQueue.DeQue(); + DataCopy(yOutGm_[offset], outputCopy, copyParams); + vectorOutQueue.FreeTensor(outputCopy); + } + matmulObj.End(); + } + +private: + AscendC::TPipe *pipe_; + + MAT_TYPE matmulObj; + TCubeTiling tiling; + + AscendC::GlobalTensor workspaceGlobal; + + AscendC::TQue matmulQueue; + AscendC::TQue vectorYInQueue; + AscendC::TQue vectorOutQueue; + AscendC::TBuf calcBuf; + + AscendC::GlobalTensor xInGm_; + AscendC::GlobalTensor wInGm_; 
+ AscendC::GlobalTensor yInGm_; + AscendC::GlobalTensor yOutGm_; + + AscendC::GlobalTensor seqLenGm_; + AscendC::GlobalTensor loraIndicesGm_; + AscendC::GlobalTensor loraRanksGm_; + AscendC::GlobalTensor sliceOffsetsGm_; + + uint32_t batchSize_; + uint32_t sliceCount_; + uint32_t maxLoRARank_; + uint32_t outputHiddenDim_; + uint32_t sliceOffset_; + uint32_t outputFullDim_; + uint32_t singleLoRAWeightLen_; + int64_t reqLoRAIndex_; + int32_t reqLoRARank_; + uint64_t reqLoRAWeightOffset_; + int32_t reqSlice_; + uint32_t numOutputElementsPerInputTile_; + uint32_t numStreamInPerOutputTile_; + uint64_t yOffset_; +}; + +extern "C" __global__ __aicore__ void sgemmc_expand(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, + uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, + GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR sliceOffsets, + uint32_t sliceOffsetsSize, GM_ADDR yIn, GM_ADDR yOut, + uint32_t batchSize, uint32_t maxLoRARank, uint32_t outputFullDim, + GM_ADDR workspace, GM_ADDR tiling) +{ + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1); + + AscendC::TPipe pipe; + sglang::npu_kernel::SGEMMCTilingData tilingData; + kernel_utils::CopyTiling(&tilingData, tiling); + + if (tilingData.tilingKey == 1) { + SGEMMCExpand op(&pipe); + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets, + sliceOffsetsSize, yIn, yOut, batchSize, maxLoRARank, outputFullDim, workspace, tilingData.cubeTiling); + op.Process(); + } else { + SGEMMCExpand op(&pipe); + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets, + sliceOffsetsSize, yIn, yOut, batchSize, maxLoRARank, outputFullDim, workspace, tilingData.cubeTiling); + op.Process(); + } +} + +#endif // SGL_KERNEL_NPU_KERNEL_SGEMMC_EXPAND_H diff --git a/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp b/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp new file mode 100644 index 000000000..942f0c2dd --- /dev/null +++ 
b/csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp @@ -0,0 +1,222 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef SGL_KERNEL_NPU_KERNEL_SGEMMC_SHRINK_H +#define SGL_KERNEL_NPU_KERNEL_SGEMMC_SHRINK_H + +#include +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lora_common_kernel.h" +#include "common_tiling_kernel.h" + +#include "../op_host/tiling/sgemmc_tiling_data.h" + +constexpr uint32_t BLOCK_SIZE = 16U; + +template +class SGEMMCShrink +{ +public: + using X_T = scalar_t; + using W_T = scalar_t; + using INNER_T = inner_t; + using Y_T = scalar_t; + + using X_MAT_TYPE = AscendC::MatmulType; + using W_MAT_TYPE = AscendC::MatmulType; + using Y_MAT_TYPE = AscendC::MatmulType; + using BIAS_MAT_TYPE = AscendC::MatmulType; + + using MAT_TYPE = AscendC::Matmul; + +public: + __aicore__ explicit SGEMMCShrink(AscendC::TPipe *pipe) : pipe_(pipe) {} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize, + GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize, + GM_ADDR loraScales, uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize, + uint32_t inputHiddenDim, uint32_t maxLoRARank, uint32_t slices, GM_ADDR workspace, + TCubeTiling &tiling) + { + this->tiling = tiling; + + slices_ = slices; + batchSize_ = batchSize; + inputHiddenDim_ = inputHiddenDim; + maxLoRARank_ = 
maxLoRARank; + singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_ * slices_; + + xInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ X_T *>(x)); + yOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ Y_T *>(y)); + wInGm_.SetGlobalBuffer(reinterpret_cast<__gm__ W_T *>(weight)); + loraIndicesGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraIndices), loraIndicesSize); + seqLenGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(seqLen), seqLenSize); + loraRanksGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(loraRanks), loraRanksSize); + loraScalesGm_.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(loraScales), loraScalesSize); + + workspaceGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ INNER_T *>(workspace)); + } + + __aicore__ inline void Process() + { + if (GetSysWorkSpacePtr() == nullptr) { + return; + } + + int64_t blocks = AscendC::GetBlockNum(); + int64_t blockIdx = AscendC::GetBlockIdx(); + + if ASCEND_IS_AIV { + if (AscendC::GetSubBlockIdx() == 1) { + return; + } + blockIdx /= AscendC::GetSubBlockNum(); + } + + int64_t tokenIdx = blockIdx / slices_; + int64_t sliceIdx = blockIdx % slices_; + + lora_common::BlockIterator blockIterator(seqLenGm_); + int64_t requestBlock = blockIterator.GetBlockIdx(tokenIdx); + if (requestBlock < 0) { + return; + } + + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); + if (reqLoRAIndex_ < 0) { + return; + } + + reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_; + reqLoRARank_ = loraRanksGm_.GetValue(reqLoRAIndex_); + reqLoRAScale_ = loraScalesGm_.GetValue(reqLoRAIndex_); + + if (reqLoRARank_ == 0) { + return; + } + + uint32_t baseM = min(tiling.baseM, tiling.singleCoreM); + uint32_t baseN = min(tiling.baseN, min(tiling.singleCoreN, reqLoRARank_)); + uint32_t elements = baseM * baseN; + uint32_t maxElements = tiling.baseM * tiling.baseN; + + workspaceGlobal = workspaceGlobal[blockIdx * maxElements]; + + REGIST_MATMUL_OBJ(pipe_, GetSysWorkSpacePtr(), matmulObj, &tiling); + + matmulObj.DisableBias(); + 
matmulObj.SetWorkspace(workspaceGlobal); + matmulObj.SetOrgShape(tiling.M, tiling.N, tiling.Ka, tiling.Kb); + matmulObj.SetSingleShape(tiling.singleCoreM, reqLoRARank_, tiling.singleCoreK); + matmulObj.SetTensorA(xInGm_[tokenIdx * inputHiddenDim_], false); + matmulObj.SetTensorB(wInGm_[reqLoRAWeightOffset_ + sliceIdx * inputHiddenDim_ * reqLoRARank_], true); + matmulObj.template Iterate(); + + pipe_->InitBuffer(calcBuf, maxElements * sizeof(INNER_T)); + pipe_->InitBuffer(matmulQueue, 1, maxElements * sizeof(INNER_T)); + pipe_->InitBuffer(outQueue, 1, maxElements * sizeof(Y_T)); + + AscendC::DataCopyParams copyParams = { + (uint16_t)baseM, (uint16_t)(baseN * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE), (uint16_t)0, + (uint16_t)((slices_ * tiling.N - baseN) * sizeof(Y_T) / AscendC::DEFAULT_C0_SIZE)}; + uint32_t iteratations = AscendC::Ceil(tiling.singleCoreM, baseM) * AscendC::Ceil(reqLoRARank_, baseN); + uint32_t outputOffset = tokenIdx * slices_ * maxLoRARank_ + sliceIdx * reqLoRARank_; + for (uint32_t i = 0; i < iteratations; ++i) { + AscendC::LocalTensor cInLocal = matmulQueue.AllocTensor(); + matmulObj.template GetTensorC(cInLocal); + matmulObj.WaitGetTensorC(); + matmulQueue.EnQue(cInLocal); + + AscendC::LocalTensor tmpTensor = calcBuf.Get(); + AscendC::LocalTensor mmResTensor = matmulQueue.DeQue(); + AscendC::LocalTensor output = outQueue.AllocTensor(); + + AscendC::Muls(tmpTensor, mmResTensor, reqLoRAScale_, elements); + AscendC::PipeBarrier(); + AscendC::Cast(output, tmpTensor, AscendC::RoundMode::CAST_RINT, elements); + AscendC::PipeBarrier(); + + outQueue.EnQue(output); + matmulQueue.FreeTensor(mmResTensor); + calcBuf.FreeTensor(tmpTensor); + + AscendC::LocalTensor outputCopy = outQueue.DeQue(); + DataCopy(yOutGm_[outputOffset + i * baseN], outputCopy, copyParams); + outQueue.FreeTensor(outputCopy); + } + matmulObj.End(); + } + +private: + AscendC::TPipe *pipe_; + + TCubeTiling tiling; + MAT_TYPE matmulObj; + + AscendC::GlobalTensor xInGm_; + 
AscendC::GlobalTensor wInGm_; + AscendC::GlobalTensor yOutGm_; + AscendC::GlobalTensor loraIndicesGm_; + AscendC::GlobalTensor seqLenGm_; + AscendC::GlobalTensor loraRanksGm_; + AscendC::GlobalTensor loraScalesGm_; + + AscendC::GlobalTensor workspaceGlobal; + + AscendC::TBuf calcBuf; + AscendC::TQue matmulQueue; + AscendC::TQue outQueue; + + uint32_t slices_; + uint32_t batchSize_; + uint32_t inputHiddenDim_; + uint32_t maxLoRARank_; + uint32_t singleLoRAWeightLen_; + + uint64_t reqLoRAWeightOffset_; + int32_t reqLoRAIndex_; + int32_t reqLoRARank_; + INNER_T reqLoRAScale_; +}; + +extern "C" __global__ __aicore__ void sgemmc_shrink(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, + uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize, + GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR loraScales, + uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize, + uint32_t inputHiddenDim, uint32_t maxLoRARank, uint32_t slices, + GM_ADDR workspace, GM_ADDR tiling) +{ + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1); + + AscendC::TPipe pipe; + sglang::npu_kernel::SGEMMCTilingData tilingData; + kernel_utils::CopyTiling(&tilingData, tiling); + + if (tilingData.tilingKey == 1) { + SGEMMCShrink op(&pipe); + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales, + loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, slices, workspace, tilingData.cubeTiling); + op.Process(); + } else { + SGEMMCShrink op(&pipe); + op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales, + loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, slices, workspace, tilingData.cubeTiling); + op.Process(); + } +} + +#endif // SGL_KERNEL_NPU_KERNEL_SGEMMC_SHRINK_H diff --git a/csrc/lora/op_kernel/sgemmv_expand_kernel.cpp b/csrc/lora/op_kernel/sgemmv_expand_kernel.cpp index 284b75ccf..8111ed3a8 100644 --- a/csrc/lora/op_kernel/sgemmv_expand_kernel.cpp +++ 
b/csrc/lora/op_kernel/sgemmv_expand_kernel.cpp @@ -20,6 +20,7 @@ #define SGL_KERNEL_NPU_KERNEL_SGEMMV_EXPAND_H #include "kernel_operator.h" +#include "lora_common_kernel.h" template class SGEMMVExpand @@ -106,11 +107,19 @@ class SGEMMVExpand if (endIdx > batchSize_) { endIdx = batchSize_; } + + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); for (int64_t idx = startIdx; idx < endIdx; idx++) { yOffset_ = outputFullDim_ * idx + sliceOffset_; + requestBlock = blockIterator.GetBlockIdx(idx); + if (requestBlock < 0) { + continue; + } + // Set up LoRA index - CopyInIndex(idx); + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { continue; } @@ -144,22 +153,6 @@ class SGEMMVExpand } private: - __aicore__ inline void CopyInIndex(const int64_t idx) - { - // Look up the LoRA index - int64_t weightIdx = idx; - uint64_t i = 0; - for (; i < seqLenGm_.GetSize(); i++) { - int64_t repeatValue = seqLenGm_.GetValue(i); - if (weightIdx >= repeatValue) { - weightIdx -= repeatValue; - continue; - } - break; - } - reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? 
loraIndicesGm_.GetValue(i) : -1; - } - __aicore__ inline void ComputeLastIteration() { int32_t remainingY = outputHiddenDim_ % Y_OUT_TILE_NUM_ELEMENTS; diff --git a/csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp b/csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp index a01fa4df7..54dbf27e6 100644 --- a/csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp +++ b/csrc/lora/op_kernel/sgemmv_shrink_kernel.cpp @@ -20,6 +20,7 @@ #define SGL_KERNEL_NPU_KERNEL_SGEMMV_SHRINK_H #include "kernel_operator.h" +#include "lora_common_kernel.h" template class SGEMMVShrink @@ -71,9 +72,16 @@ class SGEMMVShrink if (endIdx > batchSize_) { endIdx = batchSize_; } + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); for (int64_t idx = startIdx; idx < endIdx; idx++) { // set up LoRA index - CopyInIndex(idx); + requestBlock = blockIterator.GetBlockIdx(idx); + if (requestBlock < 0) { + continue; + } + + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { continue; } @@ -126,22 +134,6 @@ class SGEMMVShrink } } - __aicore__ inline void CopyInIndex(const int64_t idx) - { - // look up the LoRA index - int64_t weightIdx = idx; - uint64_t i = 0; - for (; i < seqLenGm_.GetSize(); i++) { - int64_t repeatValue = seqLenGm_.GetValue(i); - if (weightIdx >= repeatValue) { - weightIdx -= repeatValue; - continue; - } - break; - } - reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? 
loraIndicesGm_.GetValue(i) : -1; - } - __aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH) { AscendC::LocalTensor xLocal = inQueueX_.AllocTensor(); diff --git a/csrc/lora/op_kernel/sgmv_expand_kernel.cpp b/csrc/lora/op_kernel/sgmv_expand_kernel.cpp index 9be8b01ce..5bb573990 100644 --- a/csrc/lora/op_kernel/sgmv_expand_kernel.cpp +++ b/csrc/lora/op_kernel/sgmv_expand_kernel.cpp @@ -20,6 +20,7 @@ #define SGL_KERNEL_NPU_KERNEL_SGMV_EXPAND_H #include "kernel_operator.h" +#include "lora_common_kernel.h" template class SGMVExpand @@ -102,11 +103,18 @@ class SGMVExpand if (endIdx > batchSize_) { endIdx = batchSize_; } + + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); for (int64_t idx = startIdx; idx < endIdx; idx++) { yOffset_ = outputFullDim_ * idx + sliceOffset_; - // Set up LoRA index - CopyInIndex(idx); + requestBlock = blockIterator.GetBlockIdx(idx); + if (requestBlock < 0) { + continue; + } + + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { continue; } @@ -128,22 +136,6 @@ class SGMVExpand } private: - __aicore__ inline void CopyInIndex(const int64_t idx) - { - // Look up the LoRA index - int64_t weightIdx = idx; - uint64_t i = 0; - for (; i < seqLenGm_.GetSize(); i++) { - int64_t repeatValue = seqLenGm_.GetValue(i); - if (weightIdx >= repeatValue) { - weightIdx -= repeatValue; - continue; - } - break; - } - reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? 
loraIndicesGm_.GetValue(i) : -1; - } - __aicore__ inline void ComputeLastIteration() { int32_t remainingY = outputHiddenDim_ % Y_OUT_TILE_NUM_ELEMENTS; diff --git a/csrc/lora/op_kernel/sgmv_shrink_kernel.cpp b/csrc/lora/op_kernel/sgmv_shrink_kernel.cpp index 05b7212a8..e483e5ab1 100644 --- a/csrc/lora/op_kernel/sgmv_shrink_kernel.cpp +++ b/csrc/lora/op_kernel/sgmv_shrink_kernel.cpp @@ -20,6 +20,7 @@ #define SGL_KERNEL_NPU_KERNEL_SGMV_SHRINK_H #include "kernel_operator.h" +#include "lora_common_kernel.h" template class SGMVShrink @@ -69,9 +70,17 @@ class SGMVShrink if (endIdx > batchSize_) { endIdx = batchSize_; } + + int64_t requestBlock = 0; + lora_common::BlockIterator blockIterator(seqLenGm_); for (int64_t idx = startIdx; idx < endIdx; idx++) { // set up LoRA index - CopyInIndex(idx); + requestBlock = blockIterator.GetBlockIdx(idx); + if (requestBlock < 0) { + continue; + } + + reqLoRAIndex_ = loraIndicesGm_.GetValue(requestBlock); if (reqLoRAIndex_ < 0) { continue; } @@ -116,22 +125,6 @@ class SGMVShrink } } - __aicore__ inline void CopyInIndex(const int64_t idx) - { - // look up the LoRA index - int64_t weightIdx = idx; - uint64_t i = 0; - for (; i < seqLenGm_.GetSize(); i++) { - int64_t repeatValue = seqLenGm_.GetValue(i); - if (weightIdx >= repeatValue) { - weightIdx -= repeatValue; - continue; - } - break; - } - reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? loraIndicesGm_.GetValue(i) : -1; - } - __aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH) { AscendC::LocalTensor xLocal = inQueueX_.AllocTensor(); diff --git a/csrc/pytorch_extensions.cpp b/csrc/pytorch_extensions.cpp index 5c761af07..2ad55bdd8 100644 --- a/csrc/pytorch_extensions.cpp +++ b/csrc/pytorch_extensions.cpp @@ -72,14 +72,17 @@ TORCH_LIBRARY_FRAGMENT(npu, m) "bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y," " int slice_offset, int slice_size) -> Tensor"); - m.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! 
y, float scale) -> ()"); + m.def( + "bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y," + " float scale) -> ()"); m.def( "sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y," " int slice_offset, int slice_size) -> Tensor"); m.def( - "sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()"); + "sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y," + " float scale) -> ()"); m.def( "sgemmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! lora_ranks," @@ -96,6 +99,14 @@ TORCH_LIBRARY_FRAGMENT(npu, m) "Tensor(b!)? intermediate_state=None, Tensor? cache_indices=None, " "Tensor? num_accepted_tokens=None, Tensor? g=None, Tensor? gk=None) -> Tensor"); + m.def( + "sgemmc_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! lora_ranks," + " Tensor! sliceOffsets, Tensor! y) -> Tensor"); + + m.def( + "sgemmc_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! lora_ranks," + " Tensor! lora_scales, Tensor! y, int slice_count) -> ()"); + #ifdef BUILD_CATLASS_MODULE m.def("catlass_matmul_basic(Tensor tensor_a, Tensor tensor_b, Tensor(a!) tensor_c, str? 
format_mode=None) -> ()"); #endif @@ -155,6 +166,10 @@ TORCH_LIBRARY_IMPL(npu, PrivateUse1, m) m.impl("recurrent_gated_delta_rule", TORCH_FN(sglang::npu_kernel::recurrent_gated_delta_rule)); + m.impl("sgemmc_expand", TORCH_FN(sglang::npu_kernel::sgemmc_expand)); + + m.impl("sgemmc_shrink", TORCH_FN(sglang::npu_kernel::sgemmc_shrink)); + #ifdef BUILD_CATLASS_MODULE m.impl("catlass_matmul_basic", TORCH_FN(sglang::npu_kernel::catlass_matmul_basic)); #endif diff --git a/csrc/utils/common_tiling.h b/csrc/utils/common_tiling.h index a7543142c..dd8b22371 100644 --- a/csrc/utils/common_tiling.h +++ b/csrc/utils/common_tiling.h @@ -19,6 +19,27 @@ namespace host_utils { +enum class DataType : int32_t { + DT_FLOAT16 = 0, // fp16 type + DT_BFLOAT16 = 1, // bfloat16 type + DT_FLOAT = 2, // float type + DT_DOUBLE = 3, // double type + DT_BOOL = 4, // bool type + DT_INT8 = 5, // int8 type + DT_INT16, // int16 type + DT_INT32, // int32 type + DT_INT64, // int64 type + DT_UINT8, // unsigned int8 type + DT_UINT16, // unsigned int16 type + DT_UINT32, // unsigned int32 type + DT_UINT64, // unsigned int64 type + DT_COMPLEX64, // complex64 type + DT_COMPLEX128, // complex128 type + + DT_UNDEFINED, + DT_MAX // max type +}; + constexpr uint32_t FP16_SIZE = 2; constexpr uint32_t FP32_SIZE = 4; constexpr uint32_t BLOCK_SIZE = 16; @@ -236,5 +257,7 @@ inline __attribute__((always_inline)) void PpMatmulTilingCheck(const PpTilingDat TORCH_CHECK(tilingData.nLoop > 0, "nLoop is invalid"); TORCH_CHECK(tilingData.blockDim > 0, "nLoop is invalid"); } + } // namespace host_utils + #endif diff --git a/csrc/utils/kernel/common_tiling_kernel.h b/csrc/utils/kernel/common_tiling_kernel.h new file mode 100644 index 000000000..efa8135f4 --- /dev/null +++ b/csrc/utils/kernel/common_tiling_kernel.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COMMON_TILING_KERNEL_H +#define COMMON_TILING_KERNEL_H + +#include "kernel_operator.h" + +namespace kernel_utils { + +template <typename T> +__aicore__ inline void CopyTiling(T *tiling, GM_ADDR tilingGM) +{ + uint32_t *ptr = reinterpret_cast<uint32_t *>(tiling); + auto tiling32 = reinterpret_cast<__gm__ uint32_t *>(tilingGM); + + for (int i = 0; i < sizeof(T) / sizeof(uint32_t); ++i, ++ptr) { + *ptr = *(tiling32 + i); + } +} + +} // namespace kernel_utils + +#endif // COMMON_TILING_KERNEL_H diff --git a/csrc/utils/torch_helper.h b/csrc/utils/torch_helper.h index ab92de038..239f0824c 100644 --- a/csrc/utils/torch_helper.h +++ b/csrc/utils/torch_helper.h @@ -25,6 +25,7 @@ #include "torch_npu/csrc/core/npu/NPUStream.h" #include "torch_npu/csrc/framework/OpCommand.h" +#include "common_tiling.h" namespace sglang { namespace npu_kernel { @@ -52,6 +53,49 @@ class TorchNpuHelper return const_cast(at_tensor.data_ptr()); } + inline static host_utils::DataType ConvertDataType(const at::ScalarType type) + { + switch (type) { + case at::ScalarType::Float: + return host_utils::DataType::DT_FLOAT; + + case at::ScalarType::Half: + return host_utils::DataType::DT_FLOAT16; + + case at::ScalarType::BFloat16: + return host_utils::DataType::DT_BFLOAT16; + + case at::ScalarType::Double: + return host_utils::DataType::DT_DOUBLE; + + case at::ScalarType::Bool: + return host_utils::DataType::DT_BOOL; + + case at::ScalarType::Char: + return host_utils::DataType::DT_INT8; + + case at::ScalarType::Short: + return host_utils::DataType::DT_INT16; + + case at::ScalarType::Int: + return
host_utils::DataType::DT_INT32; + + case at::ScalarType::Long: + return host_utils::DataType::DT_INT64; + + case at::ScalarType::Byte: + return host_utils::DataType::DT_UINT8; + + case at::ScalarType::ComplexFloat: + return host_utils::DataType::DT_COMPLEX64; + + case at::ScalarType::ComplexDouble: + return host_utils::DataType::DT_COMPLEX128; + } + + return host_utils::DataType::DT_MAX; + } + template <typename T> inline static T ConvertType(T value) { diff --git a/include/sgl_kenel_npu_ops.h b/include/sgl_kenel_npu_ops.h index 95460b68a..578c4ef40 100644 --- a/include/sgl_kenel_npu_ops.h +++ b/include/sgl_kenel_npu_ops.h @@ -106,6 +106,15 @@ at::Tensor recurrent_gated_delta_rule( c10::optional<at::Tensor> num_accepted_tokens_opt, c10::optional<at::Tensor> g_opt, c10::optional<at::Tensor> gk_opt); +at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, + at::Tensor &lora_indices, at::Tensor &seq_len, + at::Tensor &lora_ranks, at::Tensor &slice_offsets, + at::Tensor &y); + +void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, + at::Tensor &seq_len, at::Tensor &lora_ranks, + at::Tensor &lora_scales, at::Tensor &y, int64_t slice_count); + #ifdef BUILD_CATLASS_MODULE void catlass_matmul_basic(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,