diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 6b3236f8b..68826e8d2 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -22,6 +22,8 @@ FILE(GLOB OP_SRCS
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp
${PROJECT_OP_SRC_BASE}/tri_inv/op_host/tri_inv.cpp
+ ${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp
+ ${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp
)
if(BUILD_CATLASS_MODULE)
list(APPEND OP_SRCS
@@ -53,6 +55,7 @@ set(WORKSPACE_KERNEL_SRCS
${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp
${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp
+ ${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp
)
if(BUILD_CATLASS_MODULE)
list(APPEND WORKSPACE_KERNEL_SRCS
diff --git a/csrc/apply_top_k_top_p_min_p/README.md b/csrc/apply_top_k_top_p_min_p/README.md
new file mode 100644
index 000000000..f452ea4a2
--- /dev/null
+++ b/csrc/apply_top_k_top_p_min_p/README.md
@@ -0,0 +1,65 @@
+## Introduction
+A top-k, top-p and min-p sampling implementation for ascend.
+
+## Sheet 1: Parameters
+| Parameter | Dimension | Data Type | Format | Description |
+|--------------|--------------------------|----------------------|--------|--------------------------------------------------|
+| probs        | [batch_size, vocab_size] | float32/float16/bf16 | ND     | Probabilities for sampling. The probabilities should be sorted in descending order. |
+| k | [batch_size] | int32 | ND | Representing the threshold for top-k sampling. |
+| p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for top-p sampling. |
+| min_p        | [batch_size]             | float32/float16/bf16 | ND     | Representing the threshold for min-p sampling. When min_p is nullptr, the min-p sampling will be skipped. |
+| sampled_res  | [batch_size, vocab_size] | float32/float16/bf16 | ND     | The result after sampling. The DataType of sampled_res should be same as probs. |
+
+## Calculation Formula
+$$
+sampled\_res[b][v] =
+\begin{cases}
+0 & \text{v >= k[b]} \\
+probs[b][v] & \text{v < k[b]}
+\end{cases}
+$$
+$$probs\_sum = cumsum(sampled\_res, dim=-1)$$
+$$top\_p\_mask[b][v] = probs\_sum[b][v] - sampled\_res[b][v] > p[b]$$
+$$
+sampled\_res[b][v] =
+\begin{cases}
+0 & \text{top\_p\_mask = True} \\
+sampled\_res[b][v] & \text{top\_p\_mask = False}
+\end{cases}
+$$
+$$min\_p\_mask[b][v] = sampled\_res[b][v] < sampled\_res[b][0] * min\_p[b]$$
+$$
+sampled\_res[b][v] =
+\begin{cases}
+0 & \text{min\_p\_mask = True} \\
+sampled\_res[b][v] & \text{min\_p\_mask = False}
+\end{cases}
+$$
+Where $0 \le b \lt batch\_size$, and $0 \le v \lt vocab\_size$.
+
+## Restrictions
+1. Only support Ascend A2/A3.
+2. $0 \lt k[b] \le vocab\_size$, where $0 \le b \lt batch\_size$; if $k[b] \lt 0$ or $k[b] \gt vocab\_size$, then $k[b]$ will be regarded as $vocab\_size$.
+3. $0 \le p[b] \le 1$, where $0 \le b \lt batch\_size$.
+
+## Sample Code
+```python
+import numpy as np
+import torch
+import torch_npu
+import sgl_kernel_npu
+
+dtype = torch.float16
+batch_size = 4
+vocab_size = 128
+
+logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype).npu()
+k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(torch.int32).npu()
+p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu()
+min_p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu()
+
+probs = torch.softmax(logits, dim=-1)
+probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True)
+
+torch.ops.npu.apply_top_k_top_p_min_p(probs_sort, k, p, min_p=min_p)
+```
diff --git a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp
new file mode 100644
index 000000000..e543795f9
--- /dev/null
+++ b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp
@@ -0,0 +1,79 @@
+#include <cstdint>
+#include <string>
+#include "acl/acl.h"
+#include "kernel_tiling/kernel_tiling.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "tiling/apply_top_k_top_p_min_p_tiling.h"
+#include "defines.h"
+#include "torch_helper.h"
+#include "ge_helper.h"
+#include "common_tiling.h"
+#include "apply_top_k_top_p_min_p_def.h"
+#include "common.h"
+#include "aclrtlaunch_apply_top_k_top_p_min_p.h"
+
+namespace sglang::ATKTPMPHost {
+
+using namespace ge_helper;
+constexpr uint32_t PADDING_BYTE = 32U;
+
+inline at::Tensor ConstructApplyTopKTopPMinPOutputTensor(const at::Tensor &probs)
+{
+ for (size_t i = 0; i < probs.sizes().size(); i++) {
+ TORCH_CHECK(probs.size(i) > 0,
+ "All values within probs's shape should be greater "
+ "than 0, but shape[",
+ i, "] is ", probs.size(i));
+ }
+ at::Tensor output = at::empty_like(probs);
+ return output;
+}
+} // namespace sglang::ATKTPMPHost
+
+namespace sglang {
+namespace npu_kernel {
+HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::Tensor &k, const at::Tensor &p,
+                                            const c10::optional<at::Tensor> &min_p)
+{
+ using namespace ATKTPMPHost;
+ at::Tensor sampledRes = ConstructApplyTopKTopPMinPOutputTensor(probs);
+
+ auto probsType = probs.scalar_type();
+
+ at::Tensor minP = min_p.has_value()
+ ? min_p.value()
+ : at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device()));
+
+ ApplyTopKTopPMinPTilingInfo applyTopKTopPMinPInfo;
+ applyTopKTopPMinPInfo.opParamInfo.probs.dtype = SCALAR_TYPE_TO_GE_DATATYPE(probsType);
+ applyTopKTopPMinPInfo.opParamInfo.probs.shape = probs.sizes();
+ applyTopKTopPMinPInfo.opParamInfo.k.dtype = SCALAR_TYPE_TO_GE_DATATYPE(k.scalar_type());
+ applyTopKTopPMinPInfo.opParamInfo.k.shape = k.sizes();
+ applyTopKTopPMinPInfo.opParamInfo.p.dtype = SCALAR_TYPE_TO_GE_DATATYPE(p.scalar_type());
+ applyTopKTopPMinPInfo.opParamInfo.p.shape = p.sizes();
+ if (min_p.has_value()) {
+ applyTopKTopPMinPInfo.opParamInfo.minP.dtype = SCALAR_TYPE_TO_GE_DATATYPE(minP.scalar_type());
+ applyTopKTopPMinPInfo.opParamInfo.minP.shape = minP.sizes();
+ }
+ applyTopKTopPMinPInfo.opParamInfo.sampledRes.dtype = SCALAR_TYPE_TO_GE_DATATYPE(sampledRes.scalar_type());
+ applyTopKTopPMinPInfo.opParamInfo.sampledRes.shape = sampledRes.sizes();
+
+ ApplyTopKTopPMinPTiling applyTopKTopPMinPTiling(&applyTopKTopPMinPInfo);
+ TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS, "apply_top_k_top_p_min_p DoTiling failed");
+
+ const auto &tilingData = applyTopKTopPMinPTiling.GetTilingData();
+
+    uint32_t tilingSize = (sizeof(ApplyTopKTopPMinPTilingData) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;
+ auto blockDim = tilingData.coreNum;
+ static auto tilingBuffer =
+ at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
+ aclrtMemcpy(tilingBuffer.data_ptr(), tilingSize, &tilingData, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
+ at::Tensor tilingTensor = at::from_blob(tilingBuffer.data_ptr(), tilingSize, at::kByte);
+
+ auto workspace = at::empty({applyTopKTopPMinPInfo.workspaceSize},
+ at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
+ EXEC_KERNEL_CMD(apply_top_k_top_p_min_p, blockDim, probs, k, p, minP, sampledRes, workspace, tilingTensor);
+ return sampledRes;
+}
+} // namespace npu_kernel
+} // namespace sglang
diff --git a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h
new file mode 100644
index 000000000..7fb248d48
--- /dev/null
+++ b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h
@@ -0,0 +1,50 @@
+/**
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2026 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
+ * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
+ * the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_min_p_def.h
+ * \brief
+ */
+#include <cstdint>
+#include "ge_helper.h"
+
+namespace sglang {
+namespace ATKTPMPHost {
+using namespace ge_helper;
+class ApplyTopKTopPMinP : public OpDef
+{
+public:
+ explicit ApplyTopKTopPMinP(const char *name) : OpDef(name)
+ {
+ this->Input("probs")
+ .ParamType(REQUIRED)
+ .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
+ .FormatList({ge::FORMAT_ND})
+ .AutoContiguous();
+ this->Input("k").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND}).AutoContiguous();
+ this->Input("p")
+ .ParamType(REQUIRED)
+ .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
+ .FormatList({ge::FORMAT_ND})
+ .AutoContiguous();
+ this->Input("min_p")
+ .ParamType(OPTIONAL)
+ .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
+ .FormatList({ge::FORMAT_ND})
+ .AutoContiguous();
+ this->Output("sampled_res")
+ .ParamType(REQUIRED)
+ .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
+ .FormatList({ge::FORMAT_ND});
+ }
+};
+} // namespace ATKTPMPHost
+} // namespace sglang
diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp
new file mode 100644
index 000000000..8f38d1c0a
--- /dev/null
+++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp
@@ -0,0 +1,130 @@
+/**
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2026 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
+ * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
+ * the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_min_p_tiling.cpp
+ * \brief
+ */
+
+#include "apply_top_k_top_p_min_p_tiling.h"
+
+using namespace ge;
+using namespace AscendC;
+using std::map;
+using std::string;
+namespace sglang::ATKTPMPHost {
+
+// --------------------------ApplyTopKTopPMinPTiling类成员函数定义-----------------------
+ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype()
+{
+ TORCH_CHECK((tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT16) ||
+ (tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) ||
+ (tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT),
+ "The data types of probs, p and sampled_res must be float16, bfloat16 or float.");
+
+ TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.p.dtype,
+ "The data types of probs and p must be the same.");
+ TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.sampledRes.dtype,
+ "The data types of probs and sampled_res must be the same.");
+
+ TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32, "The data types of the input k must be int32.");
+
+ return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape()
+{
+ TORCH_CHECK(tilingInfo_->opParamInfo.probs.shape.size() == DIM_NUM_TWO,
+ "ApplyTopKTopPMinP: the dimNum of probs should be ", DIM_NUM_TWO, ", but now is ",
+ tilingInfo_->opParamInfo.probs.shape.size(), ".");
+ tilingData_.batchSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ZERO];
+ tilingData_.vocabSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ONE];
+
+ TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of k should be ",
+ DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.k.shape.size(), ".");
+ int64_t kSize = tilingInfo_->opParamInfo.k.shape[DIM_IDX_ZERO];
+ TORCH_CHECK(kSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize,
+ "], but now is [", kSize, "].");
+
+ TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of p should be ",
+ DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.p.shape.size(), ".");
+ int64_t pSize = tilingInfo_->opParamInfo.p.shape[DIM_IDX_ZERO];
+ TORCH_CHECK(pSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize,
+ "], but now is [", pSize, "].");
+
+ if (tilingInfo_->opParamInfo.minP.shape.size() != DIM_NUM_ZERO) {
+ int64_t minPSize = tilingInfo_->opParamInfo.minP.shape[DIM_IDX_ZERO];
+        TORCH_CHECK(minPSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of min_p should be [",
+                    tilingData_.batchSize, "], but now is [", minPSize, "].");
+ tilingInfo_->needMinPSample = 1;
+ }
+
+ TORCH_CHECK(tilingInfo_->opParamInfo.sampledRes.shape.size() == DIM_NUM_TWO,
+ "ApplyTopKTopPMinP: the dimNum of sampled_res should be ", DIM_NUM_TWO, ", but now is ",
+ tilingInfo_->opParamInfo.sampledRes.shape.size(), ".");
+ int64_t sampledResSize0 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ZERO];
+ int64_t sampledResSize1 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ONE];
+ TORCH_CHECK(sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize,
+ "ApplyTopKTopPMinP: the size of sampledRes should be [", tilingData_.batchSize, ", ",
+ tilingData_.vocabSize, "], but now is [", sampledResSize0, ", ", sampledResSize1, "].");
+ return ge::GRAPH_SUCCESS;
+}
+
+void ApplyTopKTopPMinPTiling::SplitTask()
+{
+ tilingData_.loopDataNum = tilingData_.ubSize / BYTES_B32 / LOCAL_TENSOR_NUM / BYTES_PER_REPEAT * BYTES_PER_REPEAT;
+ tilingData_.coreNum = tilingData_.batchSize > tilingData_.coreNum ? tilingData_.coreNum : tilingData_.batchSize;
+    tilingData_.batchPerCore = tilingData_.batchSize / std::max(tilingData_.coreNum, static_cast<int64_t>(1));
+ tilingData_.batchTailCore = tilingData_.batchSize - tilingData_.batchPerCore * tilingData_.coreNum;
+}
+
+ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling()
+{
+ if (CheckDtype() != ge::GRAPH_SUCCESS) {
+ return ge::GRAPH_FAILED;
+ }
+ if (CheckShape() != ge::GRAPH_SUCCESS) {
+ return ge::GRAPH_FAILED;
+ }
+
+ auto ascendcPlatform = *platform_ascendc::PlatformAscendCManager::GetInstance();
+ uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
+ uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
+    TORCH_CHECK(aivNum != 0 && aicNum != 0, "num of core obtained is 0");
+    tilingData_.coreNum = static_cast<int64_t>(aivNum);
+
+ uint64_t ubSize = 0;
+ ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
+    tilingData_.ubSize = static_cast<int64_t>(ubSize) - SELECT_MODE_BYTES;
+
+ auto socVersion = ascendcPlatform.GetSocVersion();
+ TORCH_CHECK(socVersion == platform_ascendc::SocVersion::ASCEND910B ||
+ socVersion == platform_ascendc::SocVersion::ASCEND910_93,
+ "soc version does not support ", (int32_t)socVersion);
+
+ SplitTask();
+
+ // -------------set workspacesize-----------------
+    tilingInfo_->workspaceSize = static_cast<int64_t>(ascendcPlatform.GetLibApiWorkSpaceSize()) +
+                                 tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32;
+
+ // -------------set tilingkey-----------------
+ tilingData_.tilingKey =
+ G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN + tilingInfo_->needMinPSample;
+
+ return ge::GRAPH_SUCCESS;
+}
+
+const ApplyTopKTopPMinPTilingData &ApplyTopKTopPMinPTiling::GetTilingData() const
+{
+ return tilingData_;
+}
+} // namespace sglang::ATKTPMPHost
diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h
new file mode 100644
index 000000000..1b7cda228
--- /dev/null
+++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h
@@ -0,0 +1,84 @@
+/**
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2026 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
+ * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
+ * the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_min_p_tiling.h
+ * \brief
+ */
+
+#ifndef APPLY_TOP_K_TOP_P_MIN_P_TILING_H_
+#define APPLY_TOP_K_TOP_P_MIN_P_TILING_H_
+
+#include "register/op_def_registry.h"
+#include "register/tilingdata_base.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "tiling/tiling_api.h"
+#include "apply_top_k_top_p_min_p_tiling_data.h"
+#include "ge_helper.h"
+
+namespace sglang::ATKTPMPHost {
+struct TensorParaInfo {
+ ge::DataType dtype;
+    c10::ArrayRef<int64_t> shape;
+};
+
+const std::map<ge::DataType, int64_t> G_DTYPE_MAP = {{ge::DT_FLOAT, 1}, {ge::DT_FLOAT16, 2}, {ge::DT_BF16, 3}};
+// ------------------算子原型索引常量定义----------------
+// Dim Index
+constexpr uint32_t DIM_IDX_ZERO = 0;
+constexpr uint32_t DIM_IDX_ONE = 1;
+// Dim Num
+constexpr uint32_t DIM_NUM_ZERO = 0;
+constexpr uint32_t DIM_NUM_ONE = 1;
+constexpr uint32_t DIM_NUM_TWO = 2;
+
+constexpr int64_t COEF_TEN = 10;
+constexpr int64_t BYTES_B32 = 4;
+constexpr int64_t LOCAL_TENSOR_NUM = 4;
+constexpr int64_t SELECT_MODE_BYTES = 8192;
+constexpr int64_t BYTES_PER_REPEAT = 256;
+
+// -----------算子Tiling入参结构体定义---------------
+struct ApplyTopKTopPMinPParaInfo {
+    TensorParaInfo probs = {ge::DT_FLOAT, c10::ArrayRef<int64_t>{}};
+    TensorParaInfo k = {ge::DT_INT32, c10::ArrayRef<int64_t>{}};
+    TensorParaInfo p = {ge::DT_FLOAT, c10::ArrayRef<int64_t>{}};
+    TensorParaInfo minP = {ge::DT_FLOAT, c10::ArrayRef<int64_t>{}};
+    TensorParaInfo sampledRes = {ge::DT_FLOAT, c10::ArrayRef<int64_t>{}};
+};
+
+// -----------算子Tiling入参信息类---------------
+class ApplyTopKTopPMinPTilingInfo
+{
+public:
+ ApplyTopKTopPMinPParaInfo opParamInfo;
+ int64_t workspaceSize = 0;
+ int64_t needMinPSample = 0;
+};
+
+// ---------------算子Tiling类---------------
+class ApplyTopKTopPMinPTiling
+{
+public:
+ explicit ApplyTopKTopPMinPTiling(ApplyTopKTopPMinPTilingInfo *tilingInfo) : tilingInfo_(tilingInfo) {};
+ ge::graphStatus CheckDtype();
+ ge::graphStatus CheckShape();
+ void SplitTask();
+ ge::graphStatus DoTiling();
+ const ApplyTopKTopPMinPTilingData &GetTilingData() const;
+
+private:
+ ApplyTopKTopPMinPTilingInfo *tilingInfo_ = nullptr;
+ ApplyTopKTopPMinPTilingData tilingData_;
+};
+
+} // namespace sglang::ATKTPMPHost
+#endif // APPLY_TOP_K_TOP_P_MIN_P_TILING_H_
diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h
new file mode 100644
index 000000000..4d0e15817
--- /dev/null
+++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h
@@ -0,0 +1,23 @@
+#ifndef APPLY_TOP_K_TOP_P_MIN_P_TILING_DATA_H
+#define APPLY_TOP_K_TOP_P_MIN_P_TILING_DATA_H
+#include <cstdint>
+
+namespace sglang {
+namespace ATKTPMPHost {
+
+// -----------算子TilingData定义---------------
+#pragma pack(push, 1)
+struct ApplyTopKTopPMinPTilingData {
+ int64_t batchSize = 0;
+ int64_t vocabSize = 0;
+ int64_t batchPerCore = 0;
+ int64_t batchTailCore = 0;
+ int64_t ubSize = 0;
+ int64_t coreNum = 0;
+ int64_t loopDataNum = 0;
+ int64_t tilingKey = 0;
+};
+#pragma pack(pop)
+} // namespace ATKTPMPHost
+} // namespace sglang
+#endif
diff --git a/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp b/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp
new file mode 100644
index 000000000..e44b615d2
--- /dev/null
+++ b/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp
@@ -0,0 +1,440 @@
+/**
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2026 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
+ * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
+ * the software repository for the full text of the License.
+ */
+
+/*!
+ * \file apply_top_k_top_p_min_p_kernel.cpp
+ * \brief
+ */
+
+#include "kernel_operator.h"
+#include "kernel_tiling/kernel_tiling.h"
+#include "../op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h"
+
+using namespace AscendC;
+using namespace sglang::ATKTPMPHost;
+#define TTILING_FP32_WITHOUT_MIN_P 10
+#define TTILING_FP16_WITHOUT_MIN_P 20
+#define TTILING_BF16_WITHOUT_MIN_P 30
+#define TTILING_FP32_MIN_P 11
+#define TTILING_FP16_MIN_P 21
+#define TTILING_BF16_MIN_P 31
+
+namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel {
+template
+__aicore__ inline void SetWaitFlag(HardEvent evt)
+{
+ event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt));
+ SetFlag(eventId);
+ WaitFlag(eventId);
+}
+
+template <typename T, int32_t IsMinPSampling>
+class ApplyTopKTopPMinP
+{
+public:
+ __aicore__ inline ApplyTopKTopPMinP(){};
+ __aicore__ inline void Init(const __gm__ ApplyTopKTopPMinPTilingData *tilingData, GM_ADDR probs, GM_ADDR k,
+ GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe);
+ __aicore__ inline void Process();
+
+private:
+    __aicore__ inline void CopyInWithCast(LocalTensor<float> &localTensor, GlobalTensor<T> &globalTensor,
+                                          int64_t gmOffset, int64_t copyNum);
+    __aicore__ inline void InitWorkspace(LocalTensor<float> &workLocal);
+    __aicore__ inline void GetFloatValue(GlobalTensor<T> &globalTensor, int64_t offset, float &value);
+    __aicore__ inline void CumSumImpl(LocalTensor<float> &cumSumInput1Local, LocalTensor<float> &cumSumInput2Local);
+    __aicore__ inline void CopyOutWithCast(GlobalTensor<T> &globalTensor, LocalTensor<float> &localTensor,
+                                           int64_t gmOffset, int64_t copyNum);
+    __aicore__ inline void TopPProcess(LocalTensor<float> &cumSumInput1Local, LocalTensor<float> &cumSumInput2Local,
+                                       LocalTensor<float> &zeroLocal, LocalTensor<uint8_t> &maskLocal);
+    __aicore__ inline void MinPProcess(LocalTensor<float> &workLocal, LocalTensor<float> &zeroLocal,
+                                       LocalTensor<uint8_t> &maskLocal);
+
+private:
+    TBuf<TPosition::VECCALC> calBuf_;
+
+ // tilingData
+ int64_t batchSize_ = 0;
+ int64_t vocabSize_ = 0;
+ int64_t batchPerCore_ = 0;
+ int64_t batchTailCore_ = 0;
+ uint32_t blockIdx_ = 0;
+ int64_t coreBatch_ = 0;
+ int64_t batchOffset_ = 0;
+ int64_t coreNum_ = 0;
+ int64_t iterateTimes_ = 0;
+ int64_t baseGmOffset_ = 0;
+ int64_t probsGmOffset_ = 0;
+
+ int32_t kValue_ = 0;
+ float pValue_ = 0;
+ float minPValue_ = 0;
+ float maxValue_ = 0;
+ float minPThresholds_ = 0;
+ int64_t lastIndex_ = 0;
+ int64_t kLoopNum_ = 0;
+ int64_t kTailNum_ = 0;
+ int64_t loopDataNum_ = 0;
+
+    GlobalTensor<T> gmProbs_;
+    GlobalTensor<int32_t> gmK_;
+    GlobalTensor<T> gmP_;
+    GlobalTensor<T> gmMinP_;
+    GlobalTensor<T> gmSampledRes_;
+
+    GlobalTensor<float> gmWk_;
+};
+
+template <typename T, int32_t IsMinPSampling>
+__aicore__ inline void ApplyTopKTopPMinP<T, IsMinPSampling>::Init(const __gm__ ApplyTopKTopPMinPTilingData *tilingData,
+                                                                  GM_ADDR probs, GM_ADDR k, GM_ADDR p, GM_ADDR min_p,
+                                                                  GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe)
+{
+ batchSize_ = tilingData->batchSize;
+ vocabSize_ = tilingData->vocabSize;
+ batchPerCore_ = tilingData->batchPerCore;
+ batchTailCore_ = tilingData->batchTailCore;
+ coreNum_ = tilingData->coreNum;
+ uint32_t ubSize = static_cast(tilingData->ubSize);
+ loopDataNum_ = tilingData->loopDataNum;
+
+ blockIdx_ = GetBlockIdx();
+ if (blockIdx_ >= coreNum_) {
+ return;
+ }
+
+ if (blockIdx_ < batchTailCore_) {
+ coreBatch_ = batchPerCore_ + 1;
+ batchOffset_ = blockIdx_ * coreBatch_;
+ } else {
+ coreBatch_ = batchPerCore_;
+ batchOffset_ = blockIdx_ * batchPerCore_ + batchTailCore_;
+ }
+ baseGmOffset_ = batchOffset_ * vocabSize_;
+
+ if (coreBatch_ == 0) {
+ return;
+ }
+
+ gmProbs_.SetGlobalBuffer((__gm__ T *)probs + baseGmOffset_);
+ gmK_.SetGlobalBuffer((__gm__ int32_t *)k + batchOffset_);
+ gmP_.SetGlobalBuffer((__gm__ T *)p + batchOffset_);
+ if constexpr (IsMinPSampling == 1) {
+ gmMinP_.SetGlobalBuffer((__gm__ T *)min_p + batchOffset_);
+ }
+ gmSampledRes_.SetGlobalBuffer((__gm__ T *)sampled_res + baseGmOffset_);
+ gmWk_.SetGlobalBuffer((__gm__ float *)workspace + baseGmOffset_);
+ InitGlobalMemory(gmSampledRes_, coreBatch_ * vocabSize_, T(0));
+ InitGlobalMemory(gmWk_, coreBatch_ * vocabSize_, float(0));
+ tPipe->InitBuffer(calBuf_, ubSize);
+}
+
+template <typename T, int32_t IsMinPSampling>
+__aicore__ inline void ApplyTopKTopPMinP<T, IsMinPSampling>::GetFloatValue(GlobalTensor<T> &globalTensor,
+                                                                           int64_t offset, float &value)
+{
+    if constexpr (IsSameType<T, float>::value) {
+        value = globalTensor.GetValue(offset);
+    } else if constexpr (IsSameType<T, half>::value) {
+        value = static_cast<float>(globalTensor.GetValue(offset));
+    } else {
+        value = ToFloat(globalTensor.GetValue(offset));
+    }
+}
+
+template <typename T, int32_t IsMinPSampling>
+__aicore__ inline void ApplyTopKTopPMinP<T, IsMinPSampling>::Process()
+{
+ if (blockIdx_ >= coreNum_) {
+ return;
+ }
+    uint32_t bufOffset = 0;
+    LocalTensor<uint8_t> maskLocal = calBuf_.GetWithOffset<uint8_t>(loopDataNum_ / 8, bufOffset);
+    bufOffset += loopDataNum_ / 8 * sizeof(uint8_t);
+    LocalTensor<float> zeroLocal = calBuf_.GetWithOffset<float>(loopDataNum_, bufOffset);
+    bufOffset += loopDataNum_ * sizeof(float);
+    LocalTensor<float> cumSumInput1Local = calBuf_.GetWithOffset<float>(loopDataNum_, bufOffset);
+    bufOffset += loopDataNum_ * sizeof(float);
+    LocalTensor<float> cumSumInput2Local = calBuf_.GetWithOffset<float>(loopDataNum_, bufOffset);
+    bufOffset += loopDataNum_ * sizeof(float);
+
+ Duplicate(zeroLocal, float(0), loopDataNum_);
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ for (int64_t batchLoop = 0; batchLoop < coreBatch_; batchLoop++) {
+ probsGmOffset_ = batchLoop * vocabSize_;
+
+ kValue_ = gmK_.GetValue(batchLoop);
+ GetFloatValue(gmP_, batchLoop, pValue_);
+ if constexpr (IsMinPSampling == 1) {
+ GetFloatValue(gmMinP_, batchLoop, minPValue_);
+ GetFloatValue(gmProbs_, probsGmOffset_, maxValue_);
+ minPThresholds_ = maxValue_ * minPValue_;
+ }
+ if (kValue_ > vocabSize_ || kValue_ < 0) {
+ lastIndex_ = vocabSize_;
+ } else {
+ lastIndex_ = kValue_;
+ }
+ kLoopNum_ = lastIndex_ / loopDataNum_;
+ kTailNum_ = lastIndex_ - kLoopNum_ * loopDataNum_;
+
+ InitWorkspace(cumSumInput1Local);
+
+ TopPProcess(cumSumInput1Local, cumSumInput2Local, zeroLocal, maskLocal);
+
+ if constexpr (IsMinPSampling == 1) {
+ MinPProcess(cumSumInput1Local, zeroLocal, maskLocal);
+ }
+ }
+}
+
+template <typename T, int32_t IsMinPSampling>
+__aicore__ inline void ApplyTopKTopPMinP<T, IsMinPSampling>::CopyInWithCast(LocalTensor<float> &localTensor,
+                                                                            GlobalTensor<T> &globalTensor,
+                                                                            int64_t gmOffset, int64_t copyNum)
+{
+    if constexpr (IsSameType<T, float>::value) {
+        DataCopyPad(localTensor, globalTensor[gmOffset], {1, static_cast<uint32_t>(copyNum * sizeof(T)), 0, 0, 0},
+                    {false, 0, 0, 0});
+    } else {
+        DataCopyPad(localTensor.ReinterpretCast<T>()[loopDataNum_], globalTensor[gmOffset],
+                    {1, static_cast<uint32_t>(copyNum * sizeof(T)), 0, 0, 0}, {false, 0, 0, 0});
+        SetWaitFlag(HardEvent::MTE2_V);
+        Cast(localTensor, localTensor.ReinterpretCast<T>()[loopDataNum_], RoundMode::CAST_NONE, copyNum);
+    }
+}
+
+template <typename T, int32_t IsMinPSampling>
+__aicore__ inline void ApplyTopKTopPMinP<T, IsMinPSampling>::InitWorkspace(LocalTensor<float> &workLocal)
+{
+ for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) {
+ CopyInWithCast(workLocal, gmProbs_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::MTE2_MTE3);
+ } else {
+ SetWaitFlag(HardEvent::V_MTE3);
+ }
+ DataCopyPad(gmWk_[probsGmOffset_ + vocabLoop * loopDataNum_], workLocal,
+ {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0});
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ }
+ if (kTailNum_ > 0) {
+ CopyInWithCast(workLocal, gmProbs_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::MTE2_MTE3);
+ } else {
+ SetWaitFlag(HardEvent::V_MTE3);
+ }
+ DataCopyPad(gmWk_[probsGmOffset_ + kLoopNum_ * loopDataNum_], workLocal,
+ {1, static_cast(kTailNum_ * sizeof(float)), 0, 0, 0});
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ }
+}
+
+template <typename T, int32_t IsMinPSampling>
+__aicore__ inline void ApplyTopKTopPMinP<T, IsMinPSampling>::CumSumImpl(LocalTensor<float> &cumSumInput1Local,
+                                                                        LocalTensor<float> &cumSumInput2Local)
+{
+ int64_t tmpValue = 1;
+ iterateTimes_ = 0;
+ while (tmpValue < lastIndex_) {
+ tmpValue <<= 1;
+ iterateTimes_++;
+ }
+ for (int64_t iterateTime = 0; iterateTime < iterateTimes_; iterateTime++) {
+ int64_t iteratOffset = 1;
+ for (int64_t powerIdx = 0; powerIdx < iterateTime; powerIdx++) {
+ iteratOffset = iteratOffset * 2;
+ }
+ int64_t addLength = lastIndex_ - iteratOffset;
+ int64_t innerLoopNum = addLength / loopDataNum_;
+ int64_t dataTail = addLength - innerLoopNum * loopDataNum_;
+ for (int64_t innerLoop = 0; innerLoop < innerLoopNum; innerLoop++) {
+ int64_t innerLoopOffset = dataTail + (innerLoopNum - 1 - innerLoop) * loopDataNum_;
+ DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + innerLoopOffset],
+ {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
+ DataCopyPad(cumSumInput2Local, gmWk_[probsGmOffset_ + innerLoopOffset + iteratOffset],
+ {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
+ SetWaitFlag(HardEvent::MTE2_V);
+ Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_);
+ SetWaitFlag(HardEvent::V_MTE3);
+ DataCopyPad(gmWk_[probsGmOffset_ + innerLoopOffset + iteratOffset], cumSumInput1Local,
+ {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0});
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ }
+ if (dataTail > 0) {
+ DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_],
+ {1, static_cast(dataTail * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
+ DataCopyPad(cumSumInput2Local, gmWk_[probsGmOffset_ + iteratOffset],
+ {1, static_cast(dataTail * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
+ SetWaitFlag(HardEvent::MTE2_V);
+ Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, dataTail);
+ SetWaitFlag(HardEvent::V_MTE3);
+ DataCopyPad(gmWk_[probsGmOffset_ + iteratOffset], cumSumInput1Local,
+ {1, static_cast(dataTail * sizeof(float)), 0, 0, 0});
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ }
+ }
+}
+
+template <typename T, int32_t IsMinPSampling>
+__aicore__ inline void ApplyTopKTopPMinP<T, IsMinPSampling>::CopyOutWithCast(GlobalTensor<T> &globalTensor,
+                                                                             LocalTensor<float> &localTensor,
+                                                                             int64_t gmOffset, int64_t copyNum)
+{
+ if constexpr (IsSameType::value) {
+ DataCopyPad(globalTensor[gmOffset], localTensor, {1, static_cast(copyNum * sizeof(T)), 0, 0, 0});
+ } else {
+ Cast(localTensor.ReinterpretCast(), localTensor, RoundMode::CAST_ROUND, copyNum);
+ SetWaitFlag(HardEvent::V_MTE3);
+ DataCopyPad(globalTensor[gmOffset], localTensor.ReinterpretCast(),
+ {1, static_cast(copyNum * sizeof(T)), 0, 0, 0});
+ }
+}
+
+template <typename T, int32_t IsMinPSampling>
+__aicore__ inline void ApplyTopKTopPMinP<T, IsMinPSampling>::TopPProcess(LocalTensor<float> &cumSumInput1Local,
+                                                                         LocalTensor<float> &cumSumInput2Local,
+                                                                         LocalTensor<float> &zeroLocal,
+                                                                         LocalTensor<uint8_t> &maskLocal)
+{
+ CumSumImpl(cumSumInput1Local, cumSumInput2Local);
+ for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) {
+ DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + vocabLoop * loopDataNum_],
+ {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
+ CopyInWithCast(cumSumInput2Local, gmProbs_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::MTE2_V);
+ } else {
+ PipeBarrier();
+ }
+ Sub(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_);
+ PipeBarrier();
+ CompareScalar(maskLocal, cumSumInput1Local, pValue_, CMPMODE::GT, loopDataNum_);
+ PipeBarrier();
+ Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE,
+ loopDataNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::V_MTE3);
+ } else {
+ PipeBarrier();
+ }
+ CopyOutWithCast(gmSampledRes_, cumSumInput2Local, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_);
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ }
+ if (kTailNum_ > 0) {
+ DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + kLoopNum_ * loopDataNum_],
+ {1, static_cast(kTailNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0});
+ CopyInWithCast(cumSumInput2Local, gmProbs_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::MTE2_V);
+ } else {
+ PipeBarrier();
+ }
+ Sub(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_);
+ PipeBarrier();
+ CompareScalar(maskLocal, cumSumInput1Local, pValue_, CMPMODE::GT, loopDataNum_);
+ PipeBarrier();
+ Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE,
+ loopDataNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::V_MTE3);
+ } else {
+ PipeBarrier();
+ }
+ CopyOutWithCast(gmSampledRes_, cumSumInput2Local, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_);
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ }
+}
+
+template
+__aicore__ inline void ApplyTopKTopPMinP::MinPProcess(LocalTensor &workLocal,
+ LocalTensor &zeroLocal,
+ LocalTensor &maskLocal)
+{
+ for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) {
+ CopyInWithCast(workLocal, gmSampledRes_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::MTE2_V);
+ } else {
+ PipeBarrier();
+ }
+ CompareScalar(maskLocal, workLocal, minPThresholds_, CMPMODE::LT, loopDataNum_);
+ PipeBarrier();
+ Select(workLocal, maskLocal, zeroLocal, workLocal, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::V_MTE3);
+ } else {
+ PipeBarrier();
+ }
+ CopyOutWithCast(gmSampledRes_, workLocal, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_);
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ }
+ if (kTailNum_ > 0) {
+ CopyInWithCast(workLocal, gmSampledRes_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::MTE2_V);
+ } else {
+ PipeBarrier();
+ }
+ CompareScalar(maskLocal, workLocal, minPThresholds_, CMPMODE::LT, loopDataNum_);
+ PipeBarrier();
+ Select(workLocal, maskLocal, zeroLocal, workLocal, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_);
+ if constexpr (IsSameType::value) {
+ SetWaitFlag(HardEvent::V_MTE3);
+ } else {
+ PipeBarrier();
+ }
+ CopyOutWithCast(gmSampledRes_, workLocal, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_);
+ SetWaitFlag(HardEvent::MTE3_MTE2);
+ }
+}
+} // namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel
+
+__global__ __aicore__ void apply_top_k_top_p_min_p(GM_ADDR probs, GM_ADDR k, GM_ADDR p, GM_ADDR min_p,
+ GM_ADDR sampled_res, GM_ADDR workspace, GM_ADDR tiling)
+{
+#define INIT_AND_PROCESS \
+ op.Init(tilingData, probs, k, p, min_p, sampled_res, userWS, &tPipe); \
+ op.Process();
+
+ AscendC::TPipe tPipe;
+ using namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel;
+
+ KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
+
+ auto tilingData = reinterpret_cast<__gm__ sglang::ATKTPMPHost::ApplyTopKTopPMinPTilingData *>(tiling);
+ auto tilingKey = tilingData->tilingKey;
+ GM_ADDR userWS = GetUserWorkspace(workspace);
+ if (userWS == nullptr) {
+ return;
+ }
+
+ if (tilingKey == TTILING_FP32_WITHOUT_MIN_P) {
+ ApplyTopKTopPMinP op;
+ INIT_AND_PROCESS;
+ } else if (tilingKey == TTILING_FP16_WITHOUT_MIN_P) {
+ ApplyTopKTopPMinP op;
+ INIT_AND_PROCESS;
+ } else if (tilingKey == TTILING_BF16_WITHOUT_MIN_P) {
+ ApplyTopKTopPMinP op;
+ INIT_AND_PROCESS;
+ } else if (tilingKey == TTILING_FP32_MIN_P) {
+ ApplyTopKTopPMinP op;
+ INIT_AND_PROCESS;
+ } else if (tilingKey == TTILING_FP16_MIN_P) {
+ ApplyTopKTopPMinP op;
+ INIT_AND_PROCESS;
+ } else if (tilingKey == TTILING_BF16_MIN_P) {
+ ApplyTopKTopPMinP op;
+ INIT_AND_PROCESS;
+ }
+}
diff --git a/csrc/pytorch_extensions.cpp b/csrc/pytorch_extensions.cpp
index a7d093730..5d88dc1df 100644
--- a/csrc/pytorch_extensions.cpp
+++ b/csrc/pytorch_extensions.cpp
@@ -98,6 +98,8 @@ TORCH_LIBRARY_FRAGMENT(npu, m)
"int? sparse_count=None, int? sparse_mode=None) -> Tensor");
m.def("triangular_inverse(Tensor x) -> Tensor");
+
+ m.def("apply_top_k_top_p_min_p(Tensor logits, Tensor k, Tensor p, Tensor? min_p=None) -> Tensor");
}
} // namespace
@@ -141,5 +143,7 @@ TORCH_LIBRARY_IMPL(npu, PrivateUse1, m)
m.impl("lightning_indexer", TORCH_FN(sglang::npu_kernel::lightning_indexer));
m.impl("triangular_inverse", TORCH_FN(sglang::npu_kernel::tri_inv_col_sweep));
+
+ m.impl("apply_top_k_top_p_min_p", TORCH_FN(sglang::npu_kernel::apply_top_k_top_p_min_p));
}
} // namespace
diff --git a/include/sgl_kenel_npu_ops.h b/include/sgl_kenel_npu_ops.h
index e2f5a6804..e3aeda613 100644
--- a/include/sgl_kenel_npu_ops.h
+++ b/include/sgl_kenel_npu_ops.h
@@ -123,6 +123,10 @@ at::Tensor lightning_indexer(
* is inversed.
*/
at::Tensor tri_inv_col_sweep(const at::Tensor &tensor_in);
+
+at::Tensor apply_top_k_top_p_min_p(const at::Tensor &logits,
+                                   const at::Tensor &k, const at::Tensor &p,
+                                   const c10::optional<at::Tensor> &min_p);
} // namespace npu_kernel
} // namespace sglang
diff --git a/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py
new file mode 100644
index 000000000..ef1929ea8
--- /dev/null
+++ b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py
@@ -0,0 +1,125 @@
+# This program is free software, you can redistribute it and/or modify it.
+# Copyright (c) 2025 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+
+import unittest
+
+import numpy as np
+import sgl_kernel_npu
+import torch
+import torch.nn as nn
+import torch_npu
+import torchair
+
+DEVICE_ID = 0
+torch_npu.npu.set_device(int(DEVICE_ID))
+DTYPE_ATOL = {torch.float32: 0.0001, torch.float16: 0.001, torch.bfloat16: 0.004}
+
+
+def _apply_top_k_top_p_min_p(
+ probs_sort,
+ k,
+ p,
+ min_p=None,
+):
+ probs_sort_out = probs_sort.clone().to(torch.float32)
+ top_k_mask = torch.arange(0, probs_sort.shape[-1], device=probs_sort.device).view(
+ 1, -1
+ ) >= k.view(-1, 1)
+ probs_sort_out.masked_fill_(top_k_mask, 0.0)
+
+ probs_sum = torch.cumsum(probs_sort_out, dim=-1)
+ top_p_mask = probs_sum - probs_sort_out > p.view(-1, 1)
+ probs_sort_out.masked_fill_(top_p_mask, 0.0)
+
+ if min_p is not None:
+ min_p_thresholds = probs_sort_out[:, 0] * min_p
+ min_p_mask = probs_sort_out < min_p_thresholds.view(-1, 1)
+ probs_sort_out.masked_fill_(min_p_mask, 0.0)
+ return probs_sort_out.to(probs_sort.dtype)
+
+
+class TestCustomApplyTopKTopPMinP(unittest.TestCase):
+    """Compares the NPU apply_top_k_top_p_min_p kernel against the CPU
+    reference implementation, with and without the optional min_p input.
+    Requires an Ascend NPU device (torch_npu)."""
+
+    def test_apply_top_k_top_p_min_p_eager(self):
+        """Eager-mode check across fp32/fp16/bf16 with min_p supplied."""
+        batch_size = 4
+        vocab_size = 131072
+
+        for dtype in [torch.bfloat16, torch.float16, torch.float32]:
+            # Fixed seed so the comparison is reproducible across runs.
+            np.random.seed(3)
+            logits = torch.tensor(
+                np.random.uniform(-10, 10, (batch_size, vocab_size))
+            ).to(dtype)
+            k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(
+                torch.int32
+            )
+            p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype)
+            min_p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype)
+            # The kernel requires probabilities sorted in descending order.
+            probs = torch.softmax(logits, dim=-1)
+            probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True)
+            cpu_out = _apply_top_k_top_p_min_p(probs_sort, k, p, min_p)
+
+            torch_npu.npu.set_device(int(DEVICE_ID))
+            probs_sort = probs_sort.to("npu:%s" % DEVICE_ID)
+            k = k.to("npu:%s" % DEVICE_ID)
+            p = p.to("npu:%s" % DEVICE_ID)
+            min_p = min_p.to("npu:%s" % DEVICE_ID)
+
+            npu_out = torch.ops.npu.apply_top_k_top_p_min_p(
+                probs_sort, k, p, min_p=min_p
+            )
+
+            # compare result
+            npu_out = npu_out.cpu()
+            cpu_out = cpu_out.cpu()
+            # Per-dtype tolerance: lower-precision dtypes get a looser bound.
+            tol = DTYPE_ATOL[dtype]
+            assert torch.allclose(
+                cpu_out.to(torch.float32),
+                npu_out.to(torch.float32),
+                atol=tol,
+                rtol=tol,
+            )
+
+    def test_apply_top_k_top_p_min_p_eager_without_min_p(self):
+        """Eager-mode check with min_p omitted (min-p stage skipped)."""
+        batch_size = 4
+        vocab_size = 131072
+
+        for dtype in [torch.bfloat16, torch.float16, torch.float32]:
+            # Different seed from the min_p test so the cases are independent.
+            np.random.seed(4)
+            logits = torch.tensor(
+                np.random.uniform(-10, 10, (batch_size, vocab_size))
+            ).to(dtype)
+            k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(
+                torch.int32
+            )
+            p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype)
+            # The kernel requires probabilities sorted in descending order.
+            probs = torch.softmax(logits, dim=-1)
+            probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True)
+            cpu_out = _apply_top_k_top_p_min_p(probs_sort, k, p, min_p=None)
+
+            torch_npu.npu.set_device(int(DEVICE_ID))
+            probs_sort = probs_sort.to("npu:%s" % DEVICE_ID)
+            k = k.to("npu:%s" % DEVICE_ID)
+            p = p.to("npu:%s" % DEVICE_ID)
+
+            npu_out = torch.ops.npu.apply_top_k_top_p_min_p(
+                probs_sort, k, p, min_p=None
+            )
+
+            # compare result
+            npu_out = npu_out.cpu()
+            cpu_out = cpu_out.cpu()
+            # Per-dtype tolerance: lower-precision dtypes get a looser bound.
+            tol = DTYPE_ATOL[dtype]
+            assert torch.allclose(
+                cpu_out.to(torch.float32),
+                npu_out.to(torch.float32),
+                atol=tol,
+                rtol=tol,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)