Skip to content

Commit 7b93364

Browse files
authored
LoRA: Implementing kernels using CUBE computation unit (#384)
* LoRA: Implementing kernels using CUBE computation unit * Resolving comments * Update pre-commit after merge
1 parent 135e62b commit 7b93364

20 files changed

Lines changed: 974 additions & 75 deletions

csrc/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ FILE(GLOB OP_SRCS
1919
${PROJECT_OP_SRC_BASE}/lora/op_host/sgmv_shrink.cpp
2020
${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmv_expand.cpp
2121
${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmv_shrink.cpp
22+
${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmc_expand.cpp
23+
${PROJECT_OP_SRC_BASE}/lora/op_host/sgemmc_shrink.cpp
24+
${PROJECT_OP_SRC_BASE}/lora/op_host/tiling/sgemmc_tiling.cpp
2225
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp
2326
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp
2427
${PROJECT_OP_SRC_BASE}/tri_inv/op_host/tri_inv.cpp
@@ -51,6 +54,10 @@ ascendc_library(no_workspace_kernel STATIC
5154
${PROJECT_OP_SRC_BASE}/recurrent_gated_delta_rule/op_kernel/recurrent_gated_delta_rule_kernel.cpp
5255
)
5356

57+
ascendc_include_directories(no_workspace_kernel PRIVATE
58+
${PROJECT_OP_SRC_BASE}/utils/kernel
59+
)
60+
5461
# kernel side files with workspace
5562
set(WORKSPACE_KERNEL_SRCS
5663
${PROJECT_OP_SRC_BASE}/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
@@ -59,6 +66,8 @@ set(WORKSPACE_KERNEL_SRCS
5966
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp
6067
${PROJECT_OP_SRC_BASE}/causal_conv1d_update/op_kernel/causal_conv1d_update.cpp
6168
${PROJECT_OP_SRC_BASE}/causal_conv1d/op_kernel/causal_conv1d.cpp
69+
${PROJECT_OP_SRC_BASE}/lora/op_kernel/sgemmc_expand_kernel.cpp
70+
${PROJECT_OP_SRC_BASE}/lora/op_kernel/sgemmc_shrink_kernel.cpp
6271
)
6372
if(BUILD_CATLASS_MODULE)
6473
list(APPEND WORKSPACE_KERNEL_SRCS
@@ -76,6 +85,10 @@ if(BUILD_CATLASS_MODULE)
7685
)
7786
endif()
7887

88+
ascendc_include_directories(workspace_kernel PRIVATE
89+
${PROJECT_OP_SRC_BASE}/utils/kernel
90+
)
91+
7992
ascendc_compile_definitions(workspace_kernel PRIVATE
8093
-DHAVE_WORKSPACE
8194
-DHAVE_TILING
@@ -87,6 +100,7 @@ add_library(${OP_PLUGIN_NAME} SHARED ${OP_SRCS})
87100
target_link_libraries(${OP_PLUGIN_NAME} PRIVATE
88101
workspace_kernel
89102
no_workspace_kernel
103+
host_intf_pub
90104
torch_npu
91105
ascendcl
92106
tiling_api
@@ -115,6 +129,7 @@ target_include_directories(${OP_PLUGIN_NAME} PRIVATE
115129
${TORCH_DIR}/include
116130
${TORCH_DIR}/include/torch/csrc/api/include
117131
${TORCH_NPU_DIR}/include
132+
${ASCEND_INCLUDE_DIR}
118133
${ASCEND_INCLUDE_DIR}/external
119134
${ASCEND_INCLUDE_DIR}/experiment/platform
120135
${ASCEND_INCLUDE_DIR}/experiment/runtime
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
18+
#include "defines.h"
19+
#include "tiling/sgemmc_tiling.h"
20+
#include "torch_helper.h"
21+
22+
#include "aclrtlaunch_sgemmc_expand.h"
23+
24+
namespace sglang {
25+
namespace npu_kernel {
26+
27+
/**
 * Host-side launcher for the `sgemmc_expand` kernel.
 *
 * Validates the arguments, derives the scalar launch parameters, asks GenerateTiling() for the tiling blob and
 * block dimension, allocates the workspace on the device of `x`, and enqueues the kernel through EXEC_KERNEL_CMD.
 *
 * @param x             2-D tensor; x.size(0) is used as the batch size and x.size(1) / slice_count as the
 *                      maximum LoRA rank handed to the kernel.
 * @param weight        3-D or 4-D tensor; only its data pointer is forwarded.
 * @param lora_indices  Forwarded as data pointer plus size(0).
 * @param seq_len       Forwarded as data pointer plus size(0).
 * @param lora_ranks    Forwarded as data pointer plus size(0).
 * @param slice_offsets Forwarded as data pointer plus size(0); must hold at least two entries because
 *                      slice_count is computed as size(0) - 1 and used as a divisor.
 * @param y             2-D half/bfloat16 tensor; y.size(1) is the full output dimension. Its data pointer is
 *                      passed to the kernel twice (as `y_ptr` and as `y_out_ptr`).
 * @return A tensor handle copied from `y`.
 */
HOST_API at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
                                  at::Tensor &lora_ranks, at::Tensor &slice_offsets, at::Tensor &y)
{
    at::ScalarType scalar_type = y.scalar_type();
    TORCH_CHECK(scalar_type == at::kHalf || scalar_type == at::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    /* slice_count = size(0) - 1 is a divisor below: fewer than two entries would divide by zero (or by -1). */
    TORCH_CHECK(slice_offsets.dim() >= 1 && slice_offsets.size(0) >= 2,
                "slice_offsets should hold at least two entries (slice_count + 1)");

    /* y_out is a copy of the handle y, so y_ptr and y_out_ptr come from the same tensor. */
    at::Tensor y_out = y;
    void *x_ptr = x.data_ptr();
    void *weight_ptr = weight.data_ptr();
    void *y_ptr = y.data_ptr();
    void *y_out_ptr = y_out.data_ptr();

    void *lora_indices_ptr = lora_indices.data_ptr();
    int lora_indices_size = lora_indices.size(0);
    void *seq_len_ptr = seq_len.data_ptr();
    int seq_len_size = seq_len.size(0);
    void *lora_ranks_ptr = lora_ranks.data_ptr();
    int lora_ranks_size = lora_ranks.size(0);
    void *slice_offsets_ptr = slice_offsets.data_ptr();
    int slice_offsets_size = slice_offsets.size(0);
    int slice_count = slice_offsets_size - 1;
    int batch_size = x.size(0);
    /* Integer division: any remainder of x.size(1) / slice_count is discarded. */
    int max_lora_rank = x.size(1) / slice_count;
    int output_full_dim = y.size(1);

    /* Both values are filled in by GenerateTiling(); start from a defined state. */
    uint32_t block_dim = 0;
    uint32_t workspace_size = 0;

    at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, max_lora_rank, output_full_dim,
                                              TorchNpuHelper::ConvertDataType(scalar_type));
    auto workspace_tensor =
        at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(x.options().device()));

    /* launch the kernel function via torch */
    EXEC_KERNEL_CMD(sgemmc_expand, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr,
                    seq_len_size, lora_ranks_ptr, lora_ranks_size, slice_offsets_ptr, slice_offsets_size, y_ptr,
                    y_out_ptr, batch_size, max_lora_rank, output_full_dim, workspace_tensor, tiling_tensor);

    return y_out;
}
71+
72+
} // namespace npu_kernel
73+
} // namespace sglang
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
18+
#include "defines.h"
19+
#include "tiling/sgemmc_tiling.h"
20+
#include "torch_helper.h"
21+
22+
#include "aclrtlaunch_sgemmc_shrink.h"
23+
24+
namespace sglang {
25+
namespace npu_kernel {
26+
27+
/**
 * Host-side launcher for the `sgemmc_shrink` kernel.
 *
 * Checks that `x` is half/bfloat16, that `x` and `y` are 2-D, that `weight` is 3-D or 4-D and that
 * x.size(1) > y.size(1); then builds the tiling blob, allocates the workspace on the device of `x`
 * and enqueues the kernel through EXEC_KERNEL_CMD. Nothing is returned.
 */
HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
                            at::Tensor &lora_ranks, at::Tensor &lora_scales, at::Tensor &y)
{
    /* ---- argument validation ---- */
    at::ScalarType in_dtype = x.scalar_type();
    TORCH_CHECK(scalar_type_ok(in_dtype), "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");

    /* ---- scalar launch parameters taken from the shapes of x and y ---- */
    int n_batch = x.size(0);
    int hidden_in = x.size(1);
    uint32_t rank_out = y.size(1);

    /* ---- raw buffers handed to the kernel ---- */
    void *x_data = x.data_ptr();
    void *weight_data = weight.data_ptr();
    void *y_data = y.data_ptr();

    /* ---- per-request metadata: data pointer plus length of dimension 0 ---- */
    void *indices_data = lora_indices.data_ptr();
    int indices_len = lora_indices.size(0);
    void *seq_data = seq_len.data_ptr();
    int seq_len_count = seq_len.size(0);
    void *ranks_data = lora_ranks.data_ptr();
    int ranks_len = lora_ranks.size(0);
    void *scales_data = lora_scales.data_ptr();
    int scales_len = lora_scales.size(0);

    /* ---- tiling blob and workspace; GenerateTiling fills both out-params ---- */
    uint32_t launch_blocks;
    uint32_t workspace_bytes;
    at::Tensor tiling_blob = GenerateTiling(launch_blocks, workspace_bytes, n_batch, hidden_in, rank_out,
                                            TorchNpuHelper::ConvertDataType(in_dtype));

    auto byte_options = at::TensorOptions().dtype(at::kByte).device(x.options().device());
    auto workspace_buf = at::empty({workspace_bytes}, byte_options);

    /* launch the kernel function via torch */
    EXEC_KERNEL_CMD(sgemmc_shrink, launch_blocks, x_data, weight_data, indices_data, indices_len, seq_data,
                    seq_len_count, ranks_data, ranks_len, scales_data, scales_len, y_data, n_batch, hidden_in,
                    rank_out, workspace_buf, tiling_blob);
}
68+
69+
} // namespace npu_kernel
70+
} // namespace sglang
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
18+
#include "common.h"
19+
#include "sgemmc_tiling.h"
20+
21+
namespace sglang {
22+
namespace npu_kernel {
23+
24+
/**
 * Translate a host_utils data type into the matmul_tiling data type used for cube tiling.
 *
 * DT_BFLOAT16 and DT_FLOAT map to their namesakes; DT_FLOAT16 and every other value map to DT_FLOAT16.
 */
matmul_tiling::DataType ConvertToMatMulTypes(host_utils::DataType data_type)
{
    if (data_type == host_utils::DataType::DT_BFLOAT16) {
        return matmul_tiling::DataType::DT_BFLOAT16;
    }
    if (data_type == host_utils::DataType::DT_FLOAT) {
        return matmul_tiling::DataType::DT_FLOAT;
    }
    /* DT_FLOAT16 itself, plus the fallback for anything not listed above. */
    return matmul_tiling::DataType::DT_FLOAT16;
}
37+
38+
/**
 * Build the SGEMMCTilingData blob for the sgemmc kernels and report the launch geometry.
 *
 * @param block_dim      [out] batch_size multiplied by the core count chosen by the cube tiling (usedCoreNum).
 * @param workspace_size [out] library API workspace size reported by the platform.
 * @param batch_size     Stored in the tiling blob as `batch` and used to scale block_dim.
 * @param inner_size     Second shape argument given to SetOrgShape/SetShape.
 * @param output_size    Third shape argument given to SetOrgShape/SetShape.
 * @param type           Element type of the A and B operands; also selects the `dataType` flag.
 * @return CPU byte tensor of sizeof(SGEMMCTilingData) bytes holding the tiling blob.
 *
 * Raises (via TORCH_CHECK) when the platform instance is unavailable or tiling generation fails.
 */
at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t inner_size,
                          uint32_t output_size, const host_utils::DataType type)
{
    /* GetInstance() hands back a pointer that is dereferenced below; fail loudly instead of crashing on null. */
    auto platform_instance = platform_ascendc::PlatformAscendCManager::GetInstance();
    TORCH_CHECK(platform_instance != nullptr, "Failed to get the AscendC platform instance.");
    auto ascendc_platform = *platform_instance;
    uint32_t aic_num = ascendc_platform.GetCoreNumAic();
    workspace_size = ascendc_platform.GetLibApiWorkSpaceSize();

    auto tilingBuffer = at::empty({static_cast<int64_t>(sizeof(SGEMMCTilingData))},
                                  at::TensorOptions().dtype(at::kByte).device(at::kCPU));
    /*
     * at::empty leaves the bytes indeterminate and only some fields are assigned below, so clear the blob first.
     * NOTE(review): hidden, k and slices are not populated here - confirm the kernels do not rely on them.
     */
    tilingBuffer.zero_();
    SGEMMCTilingData *tiling_data = reinterpret_cast<SGEMMCTilingData *>(tilingBuffer.data_ptr());

    matmul_tiling::MultiCoreMatmulTiling cubeTiling(ascendc_platform);

    const matmul_tiling::DataType data_type = ConvertToMatMulTypes(type);

    cubeTiling.EnableBias(false);
    cubeTiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::VECTOR, data_type, false);
    cubeTiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type, true);
    cubeTiling.SetCType(matmul_tiling::TPosition::VECIN, matmul_tiling::CubeFormat::ND,
                        matmul_tiling::DataType::DT_FLOAT);
    cubeTiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type);
    cubeTiling.EnableMultiCoreSplitK(false);
    cubeTiling.SetDim(aic_num);

    /* The first shape argument is fixed at 1; batch_size only scales block_dim further down. */
    cubeTiling.SetOrgShape(1, inner_size, output_size);
    cubeTiling.SetShape(1, inner_size, output_size);
    cubeTiling.SetBufferSpace(-1, -1, -1);

    TORCH_CHECK(cubeTiling.GetTiling(tiling_data->cubeTiling) != -1, "Generate tiling failed.");

    tiling_data->batch = batch_size;
    /* Flag value: 1 when the element type is bfloat16, 0 otherwise. */
    tiling_data->dataType = (type == host_utils::DataType::DT_BFLOAT16);

    block_dim = batch_size * tiling_data->cubeTiling.usedCoreNum;

    return tilingBuffer;
}
78+
79+
} // namespace npu_kernel
80+
} // namespace sglang
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
18+
#ifndef SGEMMC_TILING_H
19+
#define SGEMMC_TILING_H
20+
21+
#include <register/tilingdata_base.h>
22+
#include <tiling/tiling_api.h>
23+
24+
#include "torch_helper.h"
25+
#include "common_tiling.h"
26+
#include "sgemmc_tiling_data.h"
27+
28+
namespace sglang {
29+
namespace npu_kernel {
30+
31+
/**
 * Build the SGEMMCTilingData blob shared by the sgemmc_shrink / sgemmc_expand host launchers.
 *
 * Parameter names follow the definition in sgemmc_tiling.cpp.
 *
 * @param block_dim      [out] batch_size multiplied by the core count chosen by the cube tiling.
 * @param workspace_size [out] library API workspace size reported by the platform.
 * @param batch_size     Stored in the tiling blob as `batch`.
 * @param inner_size     Second shape argument of the cube tiling (shrink passes x.size(1), expand passes the
 *                       max LoRA rank).
 * @param output_size    Third shape argument of the cube tiling (shrink passes y.size(1) as the max LoRA rank,
 *                       expand passes y.size(1) as the full output dimension).
 * @param type           Element type of the matmul inputs.
 * @return CPU byte tensor of sizeof(SGEMMCTilingData) bytes.
 */
at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t inner_size,
                          uint32_t output_size, const host_utils::DataType type);
33+
34+
} // namespace npu_kernel
35+
} // namespace sglang
36+
37+
#endif // SGEMMC_TILING_H
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
18+
#ifndef SGEMMC_TILING_DATA_H
19+
#define SGEMMC_TILING_DATA_H
20+
21+
#include <cstdint>
22+
23+
namespace AscendC {
24+
namespace tiling {
25+
26+
struct TCubeTiling;
27+
28+
} // namespace tiling
29+
} // namespace AscendC
30+
31+
namespace sglang {
32+
namespace npu_kernel {
33+
34+
#pragma pack(push, 1)
35+
/*
 * Tiling blob built by GenerateTiling() on the host: that function reinterprets a raw byte tensor of
 * sizeof(SGEMMCTilingData) bytes as this struct. Packed to 1-byte alignment by the surrounding
 * #pragma pack, so no padding is inserted between members.
 *
 * cubeTiling is held by value while this header only forward-declares AscendC::tiling::TCubeTiling,
 * so the complete definition has to be visible before this header is included.
 */
struct SGEMMCTilingData {
36+
uint32_t dataType;  // set by GenerateTiling to (type == DT_BFLOAT16): 1 for bfloat16, 0 otherwise
37+
uint32_t batch;  // set by GenerateTiling to its batch_size argument
38+
uint32_t hidden;  // NOTE(review): GenerateTiling stores no meaningful value here - confirm whether kernels read it
39+
uint32_t k;  // NOTE(review): GenerateTiling stores no meaningful value here - confirm whether kernels read it
40+
uint32_t slices;  // NOTE(review): GenerateTiling stores no meaningful value here - confirm whether kernels read it
41+
AscendC::tiling::TCubeTiling cubeTiling;  // filled by MultiCoreMatmulTiling::GetTiling in GenerateTiling
42+
};
43+
#pragma pack(pop)
44+
45+
} // namespace npu_kernel
46+
} // namespace sglang
47+
48+
#endif // SGEMMC_TILING_DATA_H

0 commit comments

Comments
 (0)