Skip to content

Commit a7455c6

Browse files
committed
Resolving comments
1 parent 04c5fa2 commit a7455c6

5 files changed

Lines changed: 25 additions & 49 deletions

File tree

csrc/lora/op_host/sgemmc_expand.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,16 @@ HOST_API at::Tensor sgemmc_expand(at::Tensor &x, at::Tensor &weight, at::Tensor
5555

5656
uint32_t block_dim;
5757
uint32_t workspace_size;
58-
int64_t num_tokens_per_core = 0;
59-
int input_hidden_token = 0;
6058

61-
at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, input_hidden_token, max_lora_rank,
59+
at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, max_lora_rank, output_full_dim,
6260
TorchNpuHelper::ConvertDataType(scalar_type));
6361
auto workspace_tensor =
6462
at::empty({workspace_size}, at::TensorOptions().dtype(at::kByte).device(x.options().device()));
6563

6664
/* launch the kernel function via torch */
6765
EXEC_KERNEL_CMD(sgemmc_expand, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr,
6866
seq_len_size, lora_ranks_ptr, lora_ranks_size, slice_offsets_ptr, slice_offsets_size, y_ptr,
69-
y_out_ptr, batch_size, num_tokens_per_core, max_lora_rank, output_full_dim, workspace_tensor,
70-
tiling_tensor);
67+
y_out_ptr, batch_size, max_lora_rank, output_full_dim, workspace_tensor, tiling_tensor);
7168

7269
return y_out;
7370
}

csrc/lora/op_host/sgemmc_shrink.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
5353

5454
uint32_t block_dim;
5555
uint32_t workspace_size;
56-
int64_t total_extend_tokens = 0;
57-
int64_t num_tokens_per_core = 0;
5856

5957
at::Tensor tiling_tensor = GenerateTiling(block_dim, workspace_size, batch_size, input_hidden_token, max_lora_rank,
6058
TorchNpuHelper::ConvertDataType(scalar_type));
@@ -64,7 +62,7 @@ HOST_API void sgemmc_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
6462
/* launch the kernel function via torch */
6563
EXEC_KERNEL_CMD(sgemmc_shrink, block_dim, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr,
6664
seq_len_size, lora_ranks_ptr, lora_ranks_size, lora_scales_ptr, lora_scales_size, y_ptr, batch_size,
67-
num_tokens_per_core, input_hidden_token, max_lora_rank, workspace_tensor, tiling_tensor);
65+
input_hidden_token, max_lora_rank, workspace_tensor, tiling_tensor);
6866
return;
6967
}
7068

csrc/lora/op_host/tiling/sgemmc_tiling.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,19 @@ matmul_tiling::DataType ConvertToMatMulTypes(host_utils::DataType data_type)
3535
return matmul_tiling::DataType::DT_FLOAT16;
3636
}
3737

38-
at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t hidden_size,
39-
uint32_t max_lora_rank, const host_utils::DataType type)
38+
at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_t batch_size, uint32_t inner_size,
39+
uint32_t output_size, const host_utils::DataType type)
4040
{
4141
auto ascendc_platform = *platform_ascendc::PlatformAscendCManager::GetInstance();
42-
uint32_t aiv_num = ascendcPlatform.GetCoreNumAiv();
43-
uint32_t aic_num = ascendcPlatform.GetCoreNumAic();
44-
workspace_size = ascendcPlatform.GetLibApiWorkSpaceSize();
42+
uint32_t aiv_num = ascendc_platform.GetCoreNumAiv();
43+
uint32_t aic_num = ascendc_platform.GetCoreNumAic();
44+
workspace_size = ascendc_platform.GetLibApiWorkSpaceSize();
4545

4646
auto tilingBuffer = at::empty({sizeof(SGEMMCTilingData)}, at::TensorOptions().dtype(at::kByte).device(at::kCPU));
4747
SGEMMCTilingData *tiling_data = reinterpret_cast<SGEMMCTilingData *>(tilingBuffer.data_ptr());
4848

4949
matmul_tiling::MultiCoreMatmulTiling cubeTiling(ascendc_platform);
5050

51-
uint32_t M = batch_size;
52-
uint32_t N = hidden_size;
53-
uint32_t K = max_lora_rank;
54-
5551
const matmul_tiling::DataType data_type = ConvertToMatMulTypes(type);
5652

5753
cubeTiling.EnableBias(false);
@@ -60,11 +56,11 @@ at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_
6056
cubeTiling.SetCType(matmul_tiling::TPosition::VECIN, matmul_tiling::CubeFormat::ND,
6157
matmul_tiling::DataType::DT_FLOAT);
6258
cubeTiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, data_type);
63-
59+
cubeTiling.EnableMultiCoreSplitK(false);
6460
cubeTiling.SetDim(aic_num);
6561

66-
cubeTiling.SetOrgShape(1, hidden_size, max_lora_rank);
67-
cubeTiling.SetShape(1, hidden_size, max_lora_rank);
62+
cubeTiling.SetOrgShape(1, inner_size, output_size);
63+
cubeTiling.SetShape(1, inner_size, output_size);
6864
cubeTiling.SetBufferSpace(-1, -1, -1);
6965

7066
if (cubeTiling.GetTiling(tiling_data->cubeTiling) == -1) {
@@ -73,8 +69,9 @@ at::Tensor GenerateTiling(uint32_t &block_dim, uint32_t &workspace_size, uint32_
7369
}
7470

7571
tiling_data->batch = batch_size;
72+
tiling_data->dataType = (type == host_utils::DataType::DT_BFLOAT16);
7673

77-
block_dim = batch * tiling_data->cubeTiling.usedCoreNum;
74+
block_dim = batch_size * tiling_data->cubeTiling.usedCoreNum;
7875

7976
return tilingBuffer;
8077
}

csrc/lora/op_kernel/sgemmc_expand_kernel.cpp

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,12 @@ class SGEMMCExpand
4747
__aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize,
4848
GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize,
4949
GM_ADDR sliceOffsets, uint32_t sliceOffsetsSize, GM_ADDR yIn, GM_ADDR yOut,
50-
uint32_t batchSize, uint32_t numBlocksPerCore, uint32_t maxLoRARank,
51-
uint32_t outputFullDim, GM_ADDR workspace, TCubeTiling &tiling)
50+
uint32_t batchSize, uint32_t maxLoRARank, uint32_t outputFullDim, GM_ADDR workspace,
51+
TCubeTiling &tiling)
5252
{
5353
this->tiling = tiling;
5454

5555
batchSize_ = batchSize;
56-
numBlocksPerCore_ = numBlocksPerCore;
5756
maxLoRARank_ = maxLoRARank;
5857
sliceCount_ = sliceOffsetsSize - 1;
5958
outputFullDim_ = outputFullDim;
@@ -78,15 +77,11 @@ class SGEMMCExpand
7877
int64_t blocks = AscendC::GetBlockNum();
7978
int64_t blockIdx = AscendC::GetBlockIdx();
8079

81-
int64_t startIdx = blockIdx * numBlocksPerCore_;
82-
int64_t endIdx = startIdx + numBlocksPerCore_;
83-
8480
AscendC::WaitPreTaskEnd();
8581

86-
int64_t batchIdx = 0;
8782
int64_t requestBlock = 0;
8883
lora_common::BlockIterator blockIterator(seqLenGm_);
89-
requestBlock = blockIterator.GetBlockIdx(batchIdx);
84+
requestBlock = blockIterator.GetBlockIdx(blockIdx);
9085
if (requestBlock < 0) {
9186
return;
9287
}
@@ -178,7 +173,6 @@ class SGEMMCExpand
178173

179174
uint32_t batchSize_;
180175
uint32_t sliceCount_;
181-
uint32_t numBlocksPerCore_;
182176
uint32_t maxLoRARank_;
183177
uint32_t outputHiddenDim_;
184178
uint32_t sliceOffset_;
@@ -197,8 +191,8 @@ extern "C" __global__ __aicore__ void sgemmc_expand(GM_ADDR x, GM_ADDR weight, G
197191
uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize,
198192
GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR sliceOffsets,
199193
uint32_t sliceOffsetsSize, GM_ADDR yIn, GM_ADDR yOut,
200-
uint32_t batchSize, uint32_t numBlocksPerCore, uint32_t maxLoRARank,
201-
uint32_t outputFullDim, GM_ADDR workspace, GM_ADDR tiling)
194+
uint32_t batchSize, uint32_t maxLoRARank, uint32_t outputFullDim,
195+
GM_ADDR workspace, GM_ADDR tiling)
202196
{
203197
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1);
204198

@@ -209,14 +203,12 @@ extern "C" __global__ __aicore__ void sgemmc_expand(GM_ADDR x, GM_ADDR weight, G
209203
if (tilingData.dataType == 1) {
210204
SGEMMCExpand<bfloat16_t, float> op(&pipe);
211205
op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets,
212-
sliceOffsetsSize, yIn, yOut, batchSize, numBlocksPerCore, maxLoRARank, outputFullDim, workspace,
213-
tilingData.cubeTiling);
206+
sliceOffsetsSize, yIn, yOut, batchSize, maxLoRARank, outputFullDim, workspace, tilingData.cubeTiling);
214207
op.Process();
215208
} else {
216209
SGEMMCExpand<half, float> op(&pipe);
217210
op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, sliceOffsets,
218-
sliceOffsetsSize, yIn, yOut, batchSize, numBlocksPerCore, maxLoRARank, outputFullDim, workspace,
219-
tilingData.cubeTiling);
211+
sliceOffsetsSize, yIn, yOut, batchSize, maxLoRARank, outputFullDim, workspace, tilingData.cubeTiling);
220212
op.Process();
221213
}
222214
}

csrc/lora/op_kernel/sgemmc_shrink_kernel.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,11 @@ class SGEMMCShrink
4747
__aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR loraIndices, uint32_t loraIndicesSize,
4848
GM_ADDR seqLen, uint32_t seqLenSize, GM_ADDR loraRanks, uint32_t loraRanksSize,
4949
GM_ADDR loraScales, uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize,
50-
uint32_t numBlocksPerCore, uint32_t inputHiddenDim, uint32_t maxLoRARank,
51-
GM_ADDR workspace, TCubeTiling &tiling)
50+
uint32_t inputHiddenDim, uint32_t maxLoRARank, GM_ADDR workspace, TCubeTiling &tiling)
5251
{
5352
this->tiling = tiling;
5453

5554
batchSize_ = batchSize;
56-
numBlocksPerCore_ = numBlocksPerCore;
5755
inputHiddenDim_ = inputHiddenDim;
5856
maxLoRARank_ = maxLoRARank;
5957
singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_;
@@ -76,9 +74,6 @@ class SGEMMCShrink
7674
int64_t blocks = AscendC::GetBlockNum();
7775
int64_t blockIdx = AscendC::GetBlockIdx();
7876

79-
int64_t startIdx = blockIdx * numBlocksPerCore_;
80-
int64_t endIdx = startIdx + numBlocksPerCore_;
81-
8277
AscendC::WaitPreTaskEnd();
8378

8479
int64_t batchIdx = 0;
@@ -165,7 +160,6 @@ class SGEMMCShrink
165160
AscendC::TBuf<AscendC::QuePosition::VECCALC> vectorCalcBuf;
166161

167162
uint32_t batchSize_;
168-
uint32_t numBlocksPerCore_;
169163
uint32_t inputHiddenDim_;
170164
uint32_t maxLoRARank_;
171165
uint32_t singleLoRAWeightLen_;
@@ -179,8 +173,8 @@ extern "C" __global__ __aicore__ void sgemmc_shrink(GM_ADDR x, GM_ADDR weight, G
179173
uint32_t loraIndicesSize, GM_ADDR seqLen, uint32_t seqLenSize,
180174
GM_ADDR loraRanks, uint32_t loraRanksSize, GM_ADDR loraScales,
181175
uint32_t loraScalesSize, GM_ADDR y, uint32_t batchSize,
182-
uint32_t numBlocksPerCore, uint32_t inputHiddenDim,
183-
uint32_t maxLoRARank, GM_ADDR workspace, GM_ADDR tiling)
176+
uint32_t inputHiddenDim, uint32_t maxLoRARank, GM_ADDR workspace,
177+
GM_ADDR tiling)
184178
{
185179
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_1);
186180

@@ -191,14 +185,12 @@ extern "C" __global__ __aicore__ void sgemmc_shrink(GM_ADDR x, GM_ADDR weight, G
191185
if (tilingData.dataType == 1) {
192186
SGEMMCShrink<bfloat16_t, float> op(&pipe);
193187
op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales,
194-
loraScalesSize, y, batchSize, numBlocksPerCore, inputHiddenDim, maxLoRARank, workspace,
195-
tilingData.cubeTiling);
188+
loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, workspace, tilingData.cubeTiling);
196189
op.Process();
197190
} else {
198191
SGEMMCShrink<half, float> op(&pipe);
199192
op.Init(x, weight, loraIndices, loraIndicesSize, seqLen, seqLenSize, loraRanks, loraRanksSize, loraScales,
200-
loraScalesSize, y, batchSize, numBlocksPerCore, inputHiddenDim, maxLoRARank, workspace,
201-
tilingData.cubeTiling);
193+
loraScalesSize, y, batchSize, inputHiddenDim, maxLoRARank, workspace, tilingData.cubeTiling);
202194
op.Process();
203195
}
204196
}

0 commit comments

Comments
 (0)