Skip to content

Commit 476c2a1

Browse files
committed
add apply_top_k_top_p_min_p op
1 parent 640d5a7 commit 476c2a1

11 files changed

Lines changed: 1021 additions & 0 deletions

csrc/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ FILE(GLOB OP_SRCS
2222
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp
2323
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp
2424
${PROJECT_OP_SRC_BASE}/tri_inv/op_host/tri_inv.cpp
25+
${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp
26+
${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp
2527
)
2628
if(BUILD_CATLASS_MODULE)
2729
list(APPEND OP_SRCS
@@ -53,6 +55,7 @@ set(WORKSPACE_KERNEL_SRCS
5355
${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp
5456
${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp
5557
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp
58+
${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp
5659
)
5760
if(BUILD_CATLASS_MODULE)
5861
list(APPEND WORKSPACE_KERNEL_SRCS
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
## Introduction
2+
A top-k, top-p and min-p sampling implementation for Ascend.
3+
4+
## Sheet 1: Parameters
5+
| Parameter | Dimension | Data Type | Format | Description |
6+
|--------------|--------------------------|----------------------|--------|--------------------------------------------------|
7+
| probs | [batch_size, vocab_size] | float32/float16/bf16 | ND | Probabilities for sampling.<br>The probabilities should be sorted in descending order. |
8+
| k | [batch_size] | int32 | ND | Representing the threshold for top-k sampling. |
9+
| p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for top-p sampling. |
10+
| min_p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for min-p sampling.<br>When min_p is nullptr, the min-p sampling will be skipped. |
11+
| sampled_res | [batch_size, vocab_size] | float32/float16/bf16 | ND | The result after sampling.<br>The DataType of sampled_res should be the same as probs. |
12+
13+
## Calculation Formula
14+
$$
15+
sampled\_res[b][v] =
16+
\begin{cases}
17+
0 & \text{v >= k[b]} \\
18+
probs[b][v] & \text{v < k[b]}
19+
\end{cases}
20+
$$
21+
$$probs\_sum = cumsum(sampled\_res, dim=-1)$$
22+
$$top\_p\_mask[b][v] = probs\_sum[b][v] - sampled\_res[b][v] > p[b]$$
23+
$$
24+
sampled\_res[b][v] =
25+
\begin{cases}
26+
0 & \text{top\_p\_mask = True} \\
27+
sampled\_res[b][v] & \text{top\_p\_mask = False}
28+
\end{cases}
29+
$$
30+
$$min\_p\_mask[b][v] = sampled\_res[b][v] < sampled\_res[b][0] * min\_p[b]$$
31+
$$
32+
sampled\_res[b][v] =
33+
\begin{cases}
34+
0 & \text{min\_p\_mask = True} \\
35+
sampled\_res[b][v] & \text{min\_p\_mask = False}
36+
\end{cases}
37+
$$
38+
Where $0 \le b \lt batch\_size$, and $0 \le v \lt vocab\_size$.
39+
40+
## Restrictions
41+
1. Only support Ascend A2/A3.
42+
2. $0 \lt k[b] \le vocab\_size$, where $0 \le b \lt batch\_size$.
43+
3. $0 \le p[b] \le 1$, where $0 \le b \lt batch\_size$.
44+
45+
## Sample Code
46+
```python
47+
import numpy as np
48+
import torch
49+
import torch_npu
50+
import sgl_kernel_npu
51+
52+
dtype = torch.float16
53+
batch_size = 4
54+
vocab_size = 128
55+
56+
logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype).npu()
57+
k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(torch.int32).npu()
58+
p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu()
59+
min_p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu()
60+
61+
probs = torch.softmax(logits, dim=-1)
62+
probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True)
63+
64+
torch.ops.npu.apply_top_k_top_p_min_p(probs_sort, k, p, min_p=min_p)
65+
```
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#include <cstdio>
2+
#include <string>
3+
#include "acl/acl.h"
4+
#include "kernel_tiling/kernel_tiling.h"
5+
#include "tiling/platform/platform_ascendc.h"
6+
#include "tiling/apply_top_k_top_p_min_p_tiling.h"
7+
#include "defines.h"
8+
#include "torch_helper.h"
9+
#include "ge_helper.h"
10+
#include "common_tiling.h"
11+
#include "apply_top_k_top_p_min_p_def.h"
12+
#include "common.h"
13+
#include "aclrtlaunch_apply_top_k_top_p_min_p.h"
14+
15+
namespace sglang::ATKTPMPHost {
16+
17+
using namespace ge_helper;
18+
constexpr uint32_t PADDING_BYTE = 32U;
19+
20+
inline at::Tensor ConstructApplyTopKTopPMinPOutputTensor(const at::Tensor &probs)
21+
{
22+
for (size_t i = 0; i < probs.sizes().size(); i++) {
23+
TORCH_CHECK(probs.size(i) > 0,
24+
"All values within probs's shape should be greater "
25+
"than 0, but shape[",
26+
i, "] is ", probs.size(i));
27+
}
28+
at::Tensor output = at::empty_like(probs);
29+
return output;
30+
}
31+
} // namespace sglang::ATKTPMPHost
32+
33+
namespace sglang {
34+
namespace npu_kernel {
35+
HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::Tensor &k, const at::Tensor &p,
36+
const c10::optional<at::Tensor> &min_p)
37+
{
38+
using namespace ATKTPMPHost;
39+
at::Tensor sampledRes = ConstructApplyTopKTopPMinPOutputTensor(probs);
40+
41+
auto probsType = probs.scalar_type();
42+
43+
at::Tensor minP =
44+
min_p.has_value()
45+
? min_p.value()
46+
: at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device()));
47+
48+
ApplyTopKTopPMinPTilingInfo applyTopKTopPMinPInfo;
49+
applyTopKTopPMinPInfo.opParamInfo.probs.dtype = SCALAR_TYPE_TO_GE_DATATYPE(probsType);
50+
applyTopKTopPMinPInfo.opParamInfo.probs.shape = probs.sizes();
51+
applyTopKTopPMinPInfo.opParamInfo.k.dtype = SCALAR_TYPE_TO_GE_DATATYPE(k.scalar_type());
52+
applyTopKTopPMinPInfo.opParamInfo.k.shape = k.sizes();
53+
applyTopKTopPMinPInfo.opParamInfo.p.dtype = SCALAR_TYPE_TO_GE_DATATYPE(p.scalar_type());
54+
applyTopKTopPMinPInfo.opParamInfo.p.shape = p.sizes();
55+
if (min_p.has_value()) {
56+
applyTopKTopPMinPInfo.opParamInfo.minP.dtype = SCALAR_TYPE_TO_GE_DATATYPE(minP.scalar_type());
57+
applyTopKTopPMinPInfo.opParamInfo.minP.shape = minP.sizes();
58+
}
59+
applyTopKTopPMinPInfo.opParamInfo.sampledRes.dtype = SCALAR_TYPE_TO_GE_DATATYPE(sampledRes.scalar_type());
60+
applyTopKTopPMinPInfo.opParamInfo.sampledRes.shape = sampledRes.sizes();
61+
62+
ApplyTopKTopPMinPTiling applyTopKTopPMinPTiling(&applyTopKTopPMinPInfo);
63+
TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS,
64+
"apply_top_k_top_p_min_p DoTiling failed")
65+
66+
const auto &tilingData = applyTopKTopPMinPTiling.GetTilingData();
67+
68+
uint32_t tilingSize = (sizeof(ApplyTopKTopPMinPTiling) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;
69+
auto blockDim = tilingData.coreNum;
70+
static auto tilingBuffer =
71+
at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
72+
aclrtMemcpy(tilingBuffer.data_ptr<uint8_t>(), tilingSize, &tilingData, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
73+
at::Tensor tilingTensor = at::from_blob(tilingBuffer.data_ptr<uint8_t>(), tilingSize, at::kByte);
74+
75+
auto workspace = at::empty({applyTopKTopPMinPInfo.workspaceSize},
76+
at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
77+
EXEC_KERNEL_CMD(apply_top_k_top_p_min_p, blockDim, probs, k, p, minP, sampledRes, workspace, tilingTensor);
78+
return sampledRes;
79+
}
80+
} // namespace npu_kernel
81+
} // namespace sglang
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/**
2+
* This program is free software, you can redistribute it and/or modify it.
3+
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
4+
* This file is a part of the CANN Open Software.
5+
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
6+
* Please refer to the License for details. You may not use this file except in compliance with the License.
7+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
8+
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
9+
* the software repository for the full text of the License.
10+
*/
11+
12+
/*!
13+
* \file apply_top_k_top_p_min_p_def.cpp
14+
* \brief
15+
*/
16+
#include <cstdint>
17+
#include "ge_helper.h"
18+
19+
namespace sglang {
20+
namespace ATKTPMPHost {
21+
using namespace ge_helper;
22+
class ApplyTopKTopPMinP : public OpDef
23+
{
24+
public:
25+
explicit ApplyTopKTopPMinP(const char *name) : OpDef(name)
26+
{
27+
this->Input("probs")
28+
.ParamType(REQUIRED)
29+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
30+
.FormatList({ge::FORMAT_ND})
31+
.AutoContiguous();
32+
this->Input("k")
33+
.ParamType(REQUIRED)
34+
.DataTypeList({ge::DT_INT32})
35+
.FormatList({ge::FORMAT_ND})
36+
.AutoContiguous();
37+
this->Input("p")
38+
.ParamType(REQUIRED)
39+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
40+
.FormatList({ge::FORMAT_ND})
41+
.AutoContiguous();
42+
this->Input("min_p")
43+
.ParamType(OPTIONAL)
44+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
45+
.FormatList({ge::FORMAT_ND})
46+
.AutoContiguous();
47+
this->Output("sampled_res")
48+
.ParamType(REQUIRED)
49+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
50+
.FormatList({ge::FORMAT_ND});
51+
}
52+
};
53+
} // namespace ATKTPMPHost
54+
} // namespace sglang
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/**
2+
* This program is free software, you can redistribute it and/or modify it.
3+
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
4+
* This file is a part of the CANN Open Software.
5+
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
6+
* Please refer to the License for details. You may not use this file except in compliance with the License.
7+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
8+
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
9+
* the software repository for the full text of the License.
10+
*/
11+
12+
/*!
13+
* \file apply_top_k_top_p_min_p_tiling.cpp
14+
* \brief
15+
*/
16+
17+
#include "apply_top_k_top_p_min_p_tiling.h"
18+
19+
using namespace ge;
20+
using namespace AscendC;
21+
using std::map;
22+
using std::string;
23+
namespace sglang::ATKTPMPHost {
24+
25+
// --------------------------ApplyTopKTopPMinPTiling类成员函数定义-----------------------
26+
ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype()
27+
{
28+
TORCH_CHECK((tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT16) ||
29+
(tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) ||
30+
(tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT),
31+
"The data types of probs, p and sampled_res must be float16, bfloat16 or float.");
32+
33+
TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.p.dtype,
34+
"The data types of probs and p must be the same.");
35+
TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.sampledRes.dtype,
36+
"The data types of probs and sampled_res must be the same.");
37+
38+
TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32,
39+
"The data types of the input k must be int32.");
40+
41+
return ge::GRAPH_SUCCESS;
42+
}
43+
44+
// Validates all tensor shapes against probs' [batch_size, vocab_size] and
// records batchSize/vocabSize (and the min-p enable flag) in the tiling data.
ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape()
{
    const auto &param = tilingInfo_->opParamInfo;

    TORCH_CHECK(param.probs.shape.size() == DIM_NUM_TWO,
                "ApplyTopKTopPMinP: the dimNum of probs should be ", DIM_NUM_TWO, ", but now is ",
                param.probs.shape.size(), ".");
    tilingData_.batchSize = param.probs.shape[DIM_IDX_ZERO];
    tilingData_.vocabSize = param.probs.shape[DIM_IDX_ONE];

    TORCH_CHECK(param.k.shape.size() == DIM_NUM_ONE,
                "ApplyTopKTopPMinP: the dimNum of k should be ", DIM_NUM_ONE, ", but now is ",
                param.k.shape.size(), ".");
    int64_t kSize = param.k.shape[DIM_IDX_ZERO];
    TORCH_CHECK(kSize == tilingData_.batchSize,
                "ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize, "], but now is [", kSize, "].");

    TORCH_CHECK(param.p.shape.size() == DIM_NUM_ONE,
                "ApplyTopKTopPMinP: the dimNum of p should be ", DIM_NUM_ONE, ", but now is ",
                param.p.shape.size(), ".");
    int64_t pSize = param.p.shape[DIM_IDX_ZERO];
    TORCH_CHECK(pSize == tilingData_.batchSize,
                "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize, "], but now is [", pSize, "].");

    // min_p is optional: an empty shape means the caller skipped min-p sampling.
    if (param.minP.shape.size() != DIM_NUM_ZERO) {
        int64_t minPSize = param.minP.shape[DIM_IDX_ZERO];
        // Fixed message: it previously reported the wrong tensor name ("p")
        // and lacked the operator-name prefix used by every sibling check.
        TORCH_CHECK(minPSize == tilingData_.batchSize,
                    "ApplyTopKTopPMinP: the shape of min_p should be [", tilingData_.batchSize,
                    "], but now is [", minPSize, "].");
        tilingInfo_->needMinPSample = 1;
    }

    TORCH_CHECK(param.sampledRes.shape.size() == DIM_NUM_TWO,
                "ApplyTopKTopPMinP: the dimNum of sampled_res should be ", DIM_NUM_TWO, ", but now is ",
                param.sampledRes.shape.size(), ".");
    int64_t sampledResSize0 = param.sampledRes.shape[DIM_IDX_ZERO];
    int64_t sampledResSize1 = param.sampledRes.shape[DIM_IDX_ONE];
    TORCH_CHECK(sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize,
                "ApplyTopKTopPMinP: the size of sampledRes should be [",
                tilingData_.batchSize, ", ", tilingData_.vocabSize,
                "], but now is [", sampledResSize0, ", ", sampledResSize1, "].");
    return ge::GRAPH_SUCCESS;
}
84+
85+
void ApplyTopKTopPMinPTiling::SplitTask()
86+
{
87+
tilingData_.loopDataNum = tilingData_.ubSize / BYTES_B32 / LOCAL_TENSOR_NUM / BYTES_PER_REPEAT * BYTES_PER_REPEAT;
88+
tilingData_.coreNum = tilingData_.batchSize > tilingData_.coreNum ? tilingData_.coreNum : tilingData_.batchSize;
89+
tilingData_.batchPerCore = tilingData_.batchSize / std::max(tilingData_.coreNum, static_cast<int64_t>(1));
90+
tilingData_.batchTailCore = tilingData_.batchSize - tilingData_.batchPerCore * tilingData_.coreNum;
91+
}
92+
93+
// Tiling entry point: validates dtypes and shapes, queries platform
// resources (core counts, UB size, SoC version), splits work across cores,
// and derives the workspace size and tiling key.
ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling()
{
    if (CheckDtype() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (CheckShape() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    auto ascendcPlatform = *platform_ascendc::PlatformAscendCManager::GetInstance();
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    // Fixed condition: the original tested `aivNum != 0` twice and never
    // validated aicNum, which was fetched for exactly this check.
    TORCH_CHECK(aivNum != 0 && aicNum != 0, "num of core obtained is 0");
    tilingData_.coreNum = static_cast<int64_t>(aivNum);

    uint64_t ubSize = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    // Reserve SELECT_MODE_BYTES of UB; the remainder is usable for tiles.
    tilingData_.ubSize = static_cast<int64_t>(ubSize) - SELECT_MODE_BYTES;

    auto socVersion = ascendcPlatform.GetSocVersion();
    TORCH_CHECK(socVersion == platform_ascendc::SocVersion::ASCEND910B ||
                    socVersion == platform_ascendc::SocVersion::ASCEND910_93,
                "soc version does not support ", static_cast<int32_t>(socVersion));

    SplitTask();

    // -------------set workspacesize-----------------
    // Library workspace plus a [batchSize, vocabSize] float32 buffer —
    // presumably the intermediate for the cumulative sum; confirm in kernel.
    tilingInfo_->workspaceSize = static_cast<int64_t>(ascendcPlatform.GetLibApiWorkSpaceSize()) +
                                 tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32;

    // -------------set tilingkey-----------------
    // Key encodes dtype (tens digit) and the min-p enable flag (units digit).
    tilingData_.tilingKey = G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN +
                            tilingInfo_->needMinPSample;

    return ge::GRAPH_SUCCESS;
}
129+
130+
// Read-only accessor for the tiling result computed by DoTiling().
const ApplyTopKTopPMinPTilingData &ApplyTopKTopPMinPTiling::GetTilingData() const
{
    return tilingData_;
}
} // namespace sglang::ATKTPMPHost

0 commit comments

Comments
 (0)