From 8ff3939da28e975be94e266378ece1f96c7f7552 Mon Sep 17 00:00:00 2001 From: wangxinwei21 Date: Mon, 19 Jan 2026 10:59:07 +0800 Subject: [PATCH 1/3] add apply_top_k_top_p_min_p op --- csrc/CMakeLists.txt | 3 + csrc/apply_top_k_top_p_min_p/README.md | 65 +++ .../op_host/apply_top_k_top_p_min_p.cpp | 81 ++++ .../op_host/apply_top_k_top_p_min_p_def.h | 54 +++ .../tiling/apply_top_k_top_p_min_p_tiling.cpp | 134 ++++++ .../tiling/apply_top_k_top_p_min_p_tiling.h | 86 ++++ .../apply_top_k_top_p_min_p_tiling_data.h | 23 + .../apply_top_k_top_p_min_p_kernel.cpp | 435 ++++++++++++++++++ csrc/pytorch_extensions.cpp | 4 + include/sgl_kenel_npu_ops.h | 4 + .../test_apply_top_k_top_p_min_p.py | 132 ++++++ 11 files changed, 1021 insertions(+) create mode 100644 csrc/apply_top_k_top_p_min_p/README.md create mode 100644 csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp create mode 100644 csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h create mode 100644 csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp create mode 100644 csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h create mode 100644 csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h create mode 100644 csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp create mode 100644 tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 6b3236f8b..68826e8d2 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -22,6 +22,8 @@ FILE(GLOB OP_SRCS ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp ${PROJECT_OP_SRC_BASE}/tri_inv/op_host/tri_inv.cpp + ${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp + 
${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp ) if(BUILD_CATLASS_MODULE) list(APPEND OP_SRCS @@ -53,6 +55,7 @@ set(WORKSPACE_KERNEL_SRCS ${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp ${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp + ${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp ) if(BUILD_CATLASS_MODULE) list(APPEND WORKSPACE_KERNEL_SRCS diff --git a/csrc/apply_top_k_top_p_min_p/README.md b/csrc/apply_top_k_top_p_min_p/README.md new file mode 100644 index 000000000..03a18b808 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/README.md @@ -0,0 +1,65 @@ +## Introduction +A top-k, top-p and min-p sampling implementation for ascend. + +## Sheet 1: Parameters +| Parameter | Dimension | Data Type | Format | Description | +|--------------|--------------------------|----------------------|--------|--------------------------------------------------| +| probs | [batch_size, vocab_size] | float32/float16/bf16 | ND | Probabilities for sampling.
The probabilities should be sorted in descending order. | +| k | [batch_size] | int32 | ND | Representing the threshold for top-k sampling. | +| p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for top-p sampling. | +| min_p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for min-p sampling.
When min_p is nullptr, the min-p sampling will be skipped. | +| sampled_res | [batch_size, vocab_size] | float32/float16/bf16 | ND | The result after sampling.
The DataType of sampled_res should be the same as probs. | + +## Calculation Formula +$$ +sampled\_res[b][v] = +\begin{cases} +0 & \text{v >= k[b]} \\ +probs[b][v] & \text{v < k[b]} +\end{cases} +$$ +$$probs\_sum = cumsum(sampled\_res, dim=-1)$$ +$$top\_p\_mask[b][v] = probs\_sum[b][v] - sampled\_res[b][v] > p[b]$$ +$$ +sampled\_res[b][v] = +\begin{cases} +0 & \text{top\_p\_mask = True} \\ +sampled\_res[b][v] & \text{top\_p\_mask = False} +\end{cases} +$$ +$$min\_p\_mask[b][v] = sampled\_res[b][v] < sampled\_res[b][0] * min\_p[b]$$ +$$ +sampled\_res[b][v] = +\begin{cases} +0 & \text{min\_p\_mask = True} \\ +sampled\_res[b][v] & \text{min\_p\_mask = False} +\end{cases} +$$ +Where $0 \le b \lt batch\_size$, and $0 \le v \lt vocab\_size$. + +## Restrictions +1. Only supports Ascend A2/A3. +2. $0 \lt k[b] \le vocab\_size$, where $0 \le b \lt batch\_size$, if $k[b] \lt 0$ or $k[b] \gt vocab\_size$, the $k[b]$ will be regarded as vocab\_size. +3. $0 \le p[b] \le 1$, where $0 \le b \lt batch\_size$. + +## Sample Code +```python +import numpy as np +import torch +import torch_npu +import sgl_kernel_npu + +dtype = torch.float16 +batch_size = 4 +vocab_size = 128 + +logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype).npu() +k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(torch.int32).npu() +p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu() +min_p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu() + +probs = torch.softmax(logits, dim=-1) +probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True) + +torch.ops.npu.apply_top_k_top_p_min_p(probs_sort, k, p, min_p=min_p) +``` diff --git a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp new file mode 100644 index 000000000..18da6acb4 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp @@ -0,0 +1,81 @@ +#include
+#include +#include "acl/acl.h" +#include "kernel_tiling/kernel_tiling.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/apply_top_k_top_p_min_p_tiling.h" +#include "defines.h" +#include "torch_helper.h" +#include "ge_helper.h" +#include "common_tiling.h" +#include "apply_top_k_top_p_min_p_def.h" +#include "common.h" +#include "aclrtlaunch_apply_top_k_top_p_min_p.h" + +namespace sglang::ATKTPMPHost { + +using namespace ge_helper; +constexpr uint32_t PADDING_BYTE = 32U; + +inline at::Tensor ConstructApplyTopKTopPMinPOutputTensor(const at::Tensor &probs) +{ + for (size_t i = 0; i < probs.sizes().size(); i++) { + TORCH_CHECK(probs.size(i) > 0, + "All values within probs's shape should be greater " + "than 0, but shape[", + i, "] is ", probs.size(i)); + } + at::Tensor output = at::empty_like(probs); + return output; +} +} // namespace sglang::ATKTPMPHost + +namespace sglang { +namespace npu_kernel { +HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::Tensor &k, const at::Tensor &p, + const c10::optional &min_p) +{ + using namespace ATKTPMPHost; + at::Tensor sampledRes = ConstructApplyTopKTopPMinPOutputTensor(probs); + + auto probsType = probs.scalar_type(); + + at::Tensor minP = + min_p.has_value() + ? 
min_p.value() + : at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device())); + + ApplyTopKTopPMinPTilingInfo applyTopKTopPMinPInfo; + applyTopKTopPMinPInfo.opParamInfo.probs.dtype = SCALAR_TYPE_TO_GE_DATATYPE(probsType); + applyTopKTopPMinPInfo.opParamInfo.probs.shape = probs.sizes(); + applyTopKTopPMinPInfo.opParamInfo.k.dtype = SCALAR_TYPE_TO_GE_DATATYPE(k.scalar_type()); + applyTopKTopPMinPInfo.opParamInfo.k.shape = k.sizes(); + applyTopKTopPMinPInfo.opParamInfo.p.dtype = SCALAR_TYPE_TO_GE_DATATYPE(p.scalar_type()); + applyTopKTopPMinPInfo.opParamInfo.p.shape = p.sizes(); + if (min_p.has_value()) { + applyTopKTopPMinPInfo.opParamInfo.minP.dtype = SCALAR_TYPE_TO_GE_DATATYPE(minP.scalar_type()); + applyTopKTopPMinPInfo.opParamInfo.minP.shape = minP.sizes(); + } + applyTopKTopPMinPInfo.opParamInfo.sampledRes.dtype = SCALAR_TYPE_TO_GE_DATATYPE(sampledRes.scalar_type()); + applyTopKTopPMinPInfo.opParamInfo.sampledRes.shape = sampledRes.sizes(); + + ApplyTopKTopPMinPTiling applyTopKTopPMinPTiling(&applyTopKTopPMinPInfo); + TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS, + "apply_top_k_top_p_min_p DoTiling failed") + + const auto &tilingData = applyTopKTopPMinPTiling.GetTilingData(); + + uint32_t tilingSize = (sizeof(ApplyTopKTopPMinPTiling) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE; + auto blockDim = tilingData.coreNum; + static auto tilingBuffer = + at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(probs.options().device())); + aclrtMemcpy(tilingBuffer.data_ptr(), tilingSize, &tilingData, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE); + at::Tensor tilingTensor = at::from_blob(tilingBuffer.data_ptr(), tilingSize, at::kByte); + + auto workspace = at::empty({applyTopKTopPMinPInfo.workspaceSize}, + at::TensorOptions().dtype(at::kByte).device(probs.options().device())); + EXEC_KERNEL_CMD(apply_top_k_top_p_min_p, blockDim, probs, k, p, minP, sampledRes, workspace, tilingTensor); + return sampledRes; +} +} 
// namespace npu_kernel +} // namespace sglang diff --git a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h new file mode 100644 index 000000000..8b82c9de3 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h @@ -0,0 +1,54 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of + * the software repository for the full text of the License. + */ + +/*! 
+ * \file apply_top_k_top_p_min_p_def.cpp + * \brief + */ +#include +#include "ge_helper.h" + +namespace sglang { +namespace ATKTPMPHost { +using namespace ge_helper; +class ApplyTopKTopPMinP : public OpDef +{ +public: + explicit ApplyTopKTopPMinP(const char *name) : OpDef(name) + { + this->Input("probs") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("k") + .ParamType(REQUIRED) + .DataTypeList({ge::DT_INT32}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("p") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("min_p") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Output("sampled_res") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .FormatList({ge::FORMAT_ND}); + } +}; +} // namespace ATKTPMPHost +} // namespace sglang diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp new file mode 100644 index 000000000..769f54847 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp @@ -0,0 +1,134 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of + * the software repository for the full text of the License. + */ + +/*! + * \file apply_top_k_top_p_min_p_tiling.cpp + * \brief + */ + +#include "apply_top_k_top_p_min_p_tiling.h" + +using namespace ge; +using namespace AscendC; +using std::map; +using std::string; +namespace sglang::ATKTPMPHost { + +// --------------------------ApplyTopKTopPMinPTiling类成员函数定义----------------------- +ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype() +{ + TORCH_CHECK((tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT16) || + (tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) || + (tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT), + "The data types of probs, p and sampled_res must be float16, bfloat16 or float."); + + TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.p.dtype, + "The data types of probs and p must be the same."); + TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.sampledRes.dtype, + "The data types of probs and sampled_res must be the same."); + + TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32, + "The data types of the input k must be int32."); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape() +{ + TORCH_CHECK(tilingInfo_->opParamInfo.probs.shape.size() == DIM_NUM_TWO, + "ApplyTopKTopPMinP: the dimNum of probs should be ", DIM_NUM_TWO, ", but now is ", + tilingInfo_->opParamInfo.probs.shape.size(), "."); + tilingData_.batchSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ZERO]; + tilingData_.vocabSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ONE]; + + TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE, + "ApplyTopKTopPMinP: the dimNum of k should be ", DIM_NUM_ONE, ", but now is ", + 
tilingInfo_->opParamInfo.k.shape.size(), "."); + int64_t kSize = tilingInfo_->opParamInfo.k.shape[DIM_IDX_ZERO]; + TORCH_CHECK(kSize == tilingData_.batchSize, + "ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize, "], but now is [", kSize, "]."); + + TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE, + "ApplyTopKTopPMinP: the dimNum of p should be ", DIM_NUM_ONE, ", but now is ", + tilingInfo_->opParamInfo.p.shape.size(), "."); + int64_t pSize = tilingInfo_->opParamInfo.p.shape[DIM_IDX_ZERO]; + TORCH_CHECK(pSize == tilingData_.batchSize, + "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize, "], but now is [", pSize, "]."); + + if (tilingInfo_->opParamInfo.minP.shape.size() != DIM_NUM_ZERO) { + int64_t minPSize = tilingInfo_->opParamInfo.minP.shape[DIM_IDX_ZERO]; + TORCH_CHECK(minPSize == tilingData_.batchSize, ": the shape of p should be [", tilingData_.batchSize, + "], but now is [", minPSize, "]."); + tilingInfo_->needMinPSample = 1; + } + + TORCH_CHECK(tilingInfo_->opParamInfo.sampledRes.shape.size() == DIM_NUM_TWO, + "ApplyTopKTopPMinP: the dimNum of sampled_res should be ", DIM_NUM_TWO, ", but now is ", + tilingInfo_->opParamInfo.sampledRes.shape.size(), "."); + int64_t sampledResSize0 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ZERO]; + int64_t sampledResSize1 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ONE]; + TORCH_CHECK(sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize, + "ApplyTopKTopPMinP: the size of sampledRes should be [", + tilingData_.batchSize, ", ", tilingData_.vocabSize, + "], but now is [", sampledResSize0, ", ", sampledResSize1, "]."); + return ge::GRAPH_SUCCESS; +} + +void ApplyTopKTopPMinPTiling::SplitTask() +{ + tilingData_.loopDataNum = tilingData_.ubSize / BYTES_B32 / LOCAL_TENSOR_NUM / BYTES_PER_REPEAT * BYTES_PER_REPEAT; + tilingData_.coreNum = tilingData_.batchSize > tilingData_.coreNum ? 
tilingData_.coreNum : tilingData_.batchSize; + tilingData_.batchPerCore = tilingData_.batchSize / std::max(tilingData_.coreNum, static_cast(1)); + tilingData_.batchTailCore = tilingData_.batchSize - tilingData_.batchPerCore * tilingData_.coreNum; +} + +ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling() +{ + if (CheckDtype() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (CheckShape() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + auto ascendcPlatform = *platform_ascendc::PlatformAscendCManager::GetInstance(); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + TORCH_CHECK(aivNum != 0 && aivNum != 0, "num of core obtained is 0"); + tilingData_.coreNum = static_cast(aivNum); + + uint64_t ubSize = 0; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + tilingData_.ubSize = static_cast(ubSize) - SELECT_MODE_BYTES; + + auto socVersion = ascendcPlatform.GetSocVersion(); + TORCH_CHECK(socVersion == platform_ascendc::SocVersion::ASCEND910B || + socVersion == platform_ascendc::SocVersion::ASCEND910_93, + "soc version does not support ", (int32_t)socVersion); + + SplitTask(); + + // -------------set workspacesize----------------- + tilingInfo_->workspaceSize = static_cast(ascendcPlatform.GetLibApiWorkSpaceSize()) + + tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32; + + // -------------set tilingkey----------------- + tilingData_.tilingKey = G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN + + tilingInfo_->needMinPSample; + + return ge::GRAPH_SUCCESS; +} + +const ApplyTopKTopPMinPTilingData &ApplyTopKTopPMinPTiling::GetTilingData() const +{ + return tilingData_; +} +} // namespace sglang::ATKTPMPHost diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h new file mode 100644 index 000000000..5856a7470 --- /dev/null +++ 
b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h @@ -0,0 +1,86 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of + * the software repository for the full text of the License. + */ + +/*! + * \file apply_top_k_top_p_min_p_tiling.h + * \brief + */ + +#ifndef APPLY_TOP_K_TOP_P_MIN_P_TILING_H_ +#define APPLY_TOP_K_TOP_P_MIN_P_TILING_H_ + +#include "register/op_def_registry.h" +#include "register/tilingdata_base.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/tiling_api.h" +#include "apply_top_k_top_p_min_p_tiling_data.h" +#include "ge_helper.h" + +namespace sglang::ATKTPMPHost { +struct TensorParaInfo { + ge::DataType dtype; + c10::ArrayRef shape; +}; + +const std::map G_DTYPE_MAP = {{ge::DT_FLOAT, 1}, + {ge::DT_FLOAT16, 2}, + {ge::DT_BF16, 3}}; +// ------------------算子原型索引常量定义---------------- +// Dim Index +constexpr uint32_t DIM_IDX_ZERO = 0; +constexpr uint32_t DIM_IDX_ONE = 1; +// Dim Num +constexpr uint32_t DIM_NUM_ZERO = 0; +constexpr uint32_t DIM_NUM_ONE = 1; +constexpr uint32_t DIM_NUM_TWO = 2; + +constexpr int64_t COEF_TEN = 10; +constexpr int64_t BYTES_B32 = 4; +constexpr int64_t LOCAL_TENSOR_NUM = 4; +constexpr int64_t SELECT_MODE_BYTES = 8192; +constexpr int64_t BYTES_PER_REPEAT = 256; + +// -----------算子Tiling入参结构体定义--------------- +struct ApplyTopKTopPMinPParaInfo { + TensorParaInfo probs = {ge::DT_FLOAT, c10::ArrayRef{}}; + TensorParaInfo k 
= {ge::DT_INT32, c10::ArrayRef{}}; + TensorParaInfo p = {ge::DT_FLOAT, c10::ArrayRef{}}; + TensorParaInfo minP = {ge::DT_FLOAT, c10::ArrayRef{}}; + TensorParaInfo sampledRes = {ge::DT_FLOAT, c10::ArrayRef{}}; +}; + +// -----------算子Tiling入参信息类--------------- +class ApplyTopKTopPMinPTilingInfo +{ +public: + ApplyTopKTopPMinPParaInfo opParamInfo; + int64_t workspaceSize = 0; + int64_t needMinPSample = 0; +}; + +// ---------------算子Tiling类--------------- +class ApplyTopKTopPMinPTiling +{ +public: + explicit ApplyTopKTopPMinPTiling(ApplyTopKTopPMinPTilingInfo *tilingInfo) : tilingInfo_(tilingInfo) {}; + ge::graphStatus CheckDtype(); + ge::graphStatus CheckShape(); + void SplitTask(); + ge::graphStatus DoTiling(); + const ApplyTopKTopPMinPTilingData &GetTilingData() const; + +private: + ApplyTopKTopPMinPTilingInfo *tilingInfo_ = nullptr; + ApplyTopKTopPMinPTilingData tilingData_; +}; + +} // namespace sglang::ATKTPMPHost +#endif // APPLY_TOP_K_TOP_P_MIN_P_TILING_H_ diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h new file mode 100644 index 000000000..4d0e15817 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h @@ -0,0 +1,23 @@ +#ifndef APPLY_TOP_K_TOP_P_MIN_P_TILING_DATA_H +#define APPLY_TOP_K_TOP_P_MIN_P_TILING_DATA_H +#include + +namespace sglang { +namespace ATKTPMPHost { + +// -----------算子TilingData定义--------------- +#pragma pack(push, 1) +struct ApplyTopKTopPMinPTilingData { + int64_t batchSize = 0; + int64_t vocabSize = 0; + int64_t batchPerCore = 0; + int64_t batchTailCore = 0; + int64_t ubSize = 0; + int64_t coreNum = 0; + int64_t loopDataNum = 0; + int64_t tilingKey = 0; +}; +#pragma pack(pop) +} // namespace ATKTPMPHost +} // namespace sglang +#endif diff --git a/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp 
b/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp new file mode 100644 index 000000000..a5df21296 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp @@ -0,0 +1,435 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of + * the software repository for the full text of the License. + */ + +/*! + * \file apply_top_k_top_p_min_p_kernel.cpp + * \brief + */ + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h" + +using namespace AscendC; +using namespace sglang::ATKTPMPHost; +#define TTILING_FP32_WITHOUT_MIN_P 10 +#define TTILING_FP16_WITHOUT_MIN_P 20 +#define TTILING_BF16_WITHOUT_MIN_P 30 +#define TTILING_FP32_MIN_P 11 +#define TTILING_FP16_MIN_P 21 +#define TTILING_BF16_MIN_P 31 + +namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel { +template +__aicore__ inline void SetWaitFlag(HardEvent evt) +{ + event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); + SetFlag(eventId); + WaitFlag(eventId); +} + +template +class ApplyTopKTopPMinP { +public: + __aicore__ inline ApplyTopKTopPMinP(){}; + __aicore__ inline void Init( + const __gm__ ApplyTopKTopPMinPTilingData *tilingData, GM_ADDR probs, GM_ADDR k, + GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe); + __aicore__ inline void Process(); +private: + __aicore__ inline 
void CopyInWithCast( + LocalTensor& localTensor, GlobalTensor& globalTensor, int64_t gmOffset, int64_t copyNum); + __aicore__ inline void InitWorkspace(LocalTensor& workLocal); + __aicore__ inline void GetFloatValue(GlobalTensor& globalTensor, int64_t offset, float& value); + __aicore__ inline void CumSumImpl(LocalTensor& cumSumInput1Local, LocalTensor& cumSumInput2Local); + __aicore__ inline void CopyOutWithCast( + GlobalTensor& globalTensor, LocalTensor& localTensor, int64_t gmOffset, int64_t copyNum); + __aicore__ inline void TopPProcess( + LocalTensor& cumSumInput1Local, LocalTensor& cumSumInput2Local, LocalTensor& zeroLocal, + LocalTensor& maskLocal); + __aicore__ inline void MinPProcess( + LocalTensor& workLocal, LocalTensor& zeroLocal, LocalTensor& maskLocal); +private: + TBuf calBuf_; + + // tilingData + int64_t batchSize_ = 0; + int64_t vocabSize_ = 0; + int64_t batchPerCore_ = 0; + int64_t batchTailCore_ = 0; + uint32_t blockIdx_ = 0; + int64_t coreBatch_ = 0; + int64_t batchOffset_ = 0; + int64_t coreNum_ = 0; + int64_t iterateTimes_ = 0; + int64_t baseGmOffset_ = 0; + int64_t probsGmOffset_ = 0; + + int32_t kValue_ = 0; + float pValue_ = 0; + float minPValue_ = 0; + float maxValue_ = 0; + float minPThresholds_ = 0; + int64_t lastIndex_ = 0; + int64_t kLoopNum_ = 0; + int64_t kTailNum_ = 0; + int64_t loopDataNum_ = 0; + + GlobalTensor gmProbs_; + GlobalTensor gmK_; + GlobalTensor gmP_; + GlobalTensor gmMinP_; + GlobalTensor gmSampledRes_; + + GlobalTensor gmWk_; +}; + +template +__aicore__ inline void ApplyTopKTopPMinP::Init( + const __gm__ ApplyTopKTopPMinPTilingData *tilingData, GM_ADDR probs, GM_ADDR k, + GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe) +{ + batchSize_ = tilingData->batchSize; + vocabSize_ = tilingData->vocabSize; + batchPerCore_ = tilingData->batchPerCore; + batchTailCore_ = tilingData->batchTailCore; + coreNum_ = tilingData->coreNum; + uint32_t ubSize = static_cast(tilingData->ubSize); + loopDataNum_ = 
tilingData->loopDataNum; + + blockIdx_ = GetBlockIdx(); + if (blockIdx_ >= coreNum_) { + return; + } + + if (blockIdx_ < batchTailCore_) { + coreBatch_ = batchPerCore_ + 1; + batchOffset_ = blockIdx_ * coreBatch_; + } else { + coreBatch_ = batchPerCore_; + batchOffset_ = blockIdx_ * batchPerCore_ + batchTailCore_; + } + baseGmOffset_ = batchOffset_ * vocabSize_; + + if (coreBatch_ == 0) { + return; + } + + gmProbs_.SetGlobalBuffer((__gm__ T *)probs + baseGmOffset_); + gmK_.SetGlobalBuffer((__gm__ int32_t *)k + batchOffset_); + gmP_.SetGlobalBuffer((__gm__ T *)p + batchOffset_); + if constexpr (IsMinPSampling == 1) { + gmMinP_.SetGlobalBuffer((__gm__ T *)min_p + batchOffset_); + } + gmSampledRes_.SetGlobalBuffer((__gm__ T *)sampled_res + baseGmOffset_); + gmWk_.SetGlobalBuffer((__gm__ float *)workspace + baseGmOffset_); + InitGlobalMemory(gmSampledRes_, coreBatch_ * vocabSize_, T(0)); + InitGlobalMemory(gmWk_, coreBatch_ * vocabSize_, float(0)); + tPipe->InitBuffer(calBuf_, ubSize); +} + +template +__aicore__ inline void ApplyTopKTopPMinP::GetFloatValue( + GlobalTensor& globalTensor, int64_t offset, float& value) +{ + if constexpr (IsSameType::value) { + value = globalTensor.GetValue(offset); + } else if constexpr (IsSameType::value) { + value = static_cast(globalTensor.GetValue(offset)); + } else { + value = ToFloat(globalTensor.GetValue(offset)); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::Process() +{ + if (blockIdx_ >= coreNum_) { + return; + } + uint32_t bufOffset = 0; + LocalTensor maskLocal = calBuf_.GetWithOffset(loopDataNum_ / 8, bufOffset); + bufOffset += loopDataNum_ / 8 * sizeof(uint8_t); + LocalTensor zeroLocal = calBuf_.GetWithOffset(loopDataNum_, bufOffset); + bufOffset += loopDataNum_ * sizeof(float); + LocalTensor cumSumInput1Local = calBuf_.GetWithOffset(loopDataNum_, bufOffset); + bufOffset += loopDataNum_ * sizeof(float); + LocalTensor cumSumInput2Local = calBuf_.GetWithOffset(loopDataNum_, bufOffset); + bufOffset += 
loopDataNum_ * sizeof(float); + + Duplicate(zeroLocal, float(0), loopDataNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + for (int64_t batchLoop = 0; batchLoop < coreBatch_; batchLoop++) { + probsGmOffset_ = batchLoop * vocabSize_; + + kValue_ = gmK_.GetValue(batchLoop); + GetFloatValue(gmP_, batchLoop, pValue_); + if constexpr (IsMinPSampling == 1) { + GetFloatValue(gmMinP_, batchLoop, minPValue_); + GetFloatValue(gmProbs_, probsGmOffset_, maxValue_); + minPThresholds_ = maxValue_ * minPValue_; + } + if (kValue_ > vocabSize_ || kValue_ < 0) { + lastIndex_ = vocabSize_; + } else { + lastIndex_ = kValue_; + } + kLoopNum_ = lastIndex_ / loopDataNum_; + kTailNum_ = lastIndex_ - kLoopNum_ * loopDataNum_; + + InitWorkspace(cumSumInput1Local); + + TopPProcess(cumSumInput1Local, cumSumInput2Local, zeroLocal, maskLocal); + + if constexpr (IsMinPSampling == 1) { + MinPProcess(cumSumInput1Local, zeroLocal, maskLocal); + } + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::CopyInWithCast( + LocalTensor& localTensor, GlobalTensor& globalTensor, int64_t gmOffset, int64_t copyNum) +{ + if constexpr (IsSameType::value) { + DataCopyPad(localTensor, globalTensor[gmOffset], + {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}, {false, 0, 0, 0}); + } else { + DataCopyPad(localTensor.ReinterpretCast()[loopDataNum_], globalTensor[gmOffset], + {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}, {false, 0, 0, 0}); + SetWaitFlag(HardEvent::MTE2_V); + Cast(localTensor, localTensor.ReinterpretCast()[loopDataNum_], RoundMode::CAST_NONE, copyNum); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::InitWorkspace(LocalTensor& workLocal) +{ + for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { + CopyInWithCast(workLocal, gmProbs_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_MTE3); + } else { + SetWaitFlag(HardEvent::V_MTE3); + } + DataCopyPad(gmWk_[probsGmOffset_ + vocabLoop * loopDataNum_], 
workLocal, + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + if (kTailNum_ > 0) { + CopyInWithCast(workLocal, gmProbs_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_MTE3); + } else { + SetWaitFlag(HardEvent::V_MTE3); + } + DataCopyPad(gmWk_[probsGmOffset_ + kLoopNum_ * loopDataNum_], workLocal, + {1, static_cast(kTailNum_ * sizeof(float)), 0, 0, 0}); + SetWaitFlag(HardEvent::MTE3_MTE2); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::CumSumImpl( + LocalTensor& cumSumInput1Local, LocalTensor& cumSumInput2Local) +{ + int64_t tmpValue = 1; + iterateTimes_ = 0; + while (tmpValue < lastIndex_) { + tmpValue <<= 1; + iterateTimes_++; + } + for (int64_t iterateTime = 0; iterateTime < iterateTimes_; iterateTime++) { + int64_t iteratOffset = 1; + for (int64_t powerIdx = 0; powerIdx < iterateTime; powerIdx++) { + iteratOffset = iteratOffset * 2; + } + int64_t addLength = lastIndex_ - iteratOffset; + int64_t innerLoopNum = addLength / loopDataNum_; + int64_t dataTail = addLength - innerLoopNum * loopDataNum_; + for (int64_t innerLoop = 0; innerLoop < innerLoopNum; innerLoop++) { + int64_t innerLoopOffset = dataTail + (innerLoopNum - 1 - innerLoop) * loopDataNum_; + DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + innerLoopOffset], + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(cumSumInput2Local, gmWk_[probsGmOffset_ + innerLoopOffset + iteratOffset], + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + SetWaitFlag(HardEvent::MTE2_V); + Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyPad(gmWk_[probsGmOffset_ + innerLoopOffset + iteratOffset], cumSumInput1Local, + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + if (dataTail > 0) { + 
DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_], + {1, static_cast(dataTail * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(cumSumInput2Local, gmWk_[probsGmOffset_ + iteratOffset], + {1, static_cast(dataTail * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + SetWaitFlag(HardEvent::MTE2_V); + Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, dataTail); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyPad(gmWk_[probsGmOffset_ + iteratOffset], cumSumInput1Local, + {1, static_cast(dataTail * sizeof(float)), 0, 0, 0}); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::CopyOutWithCast( + GlobalTensor& globalTensor, LocalTensor& localTensor, int64_t gmOffset, int64_t copyNum) +{ + if constexpr (IsSameType::value) { + DataCopyPad(globalTensor[gmOffset], localTensor, {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}); + } else { + Cast(localTensor.ReinterpretCast(), localTensor, RoundMode::CAST_ROUND, copyNum); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyPad(globalTensor[gmOffset], localTensor.ReinterpretCast(), + {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::TopPProcess( + LocalTensor& cumSumInput1Local, LocalTensor& cumSumInput2Local, LocalTensor& zeroLocal, + LocalTensor& maskLocal) +{ + CumSumImpl(cumSumInput1Local, cumSumInput2Local); + for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { + DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + vocabLoop * loopDataNum_], + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + CopyInWithCast(cumSumInput2Local, gmProbs_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_V); + } else { + PipeBarrier(); + } + Sub(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_); + PipeBarrier(); + CompareScalar(maskLocal, cumSumInput1Local, pValue_, CMPMODE::GT, 
loopDataNum_); + PipeBarrier(); + Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::V_MTE3); + } else { + PipeBarrier(); + } + CopyOutWithCast(gmSampledRes_, cumSumInput2Local, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + if (kTailNum_ > 0) { + DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + kLoopNum_ * loopDataNum_], + {1, static_cast(kTailNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + CopyInWithCast(cumSumInput2Local, gmProbs_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_V); + } else { + PipeBarrier(); + } + Sub(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_); + PipeBarrier(); + CompareScalar(maskLocal, cumSumInput1Local, pValue_, CMPMODE::GT, loopDataNum_); + PipeBarrier(); + Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::V_MTE3); + } else { + PipeBarrier(); + } + CopyOutWithCast(gmSampledRes_, cumSumInput2Local, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::MinPProcess( + LocalTensor& workLocal, LocalTensor& zeroLocal, LocalTensor& maskLocal) +{ + for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { + CopyInWithCast(workLocal, gmSampledRes_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_V); + } else { + PipeBarrier(); + } + CompareScalar(maskLocal, workLocal, minPThresholds_, CMPMODE::LT, loopDataNum_); + PipeBarrier(); + Select(workLocal, maskLocal, zeroLocal, workLocal, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_); + if constexpr 
(IsSameType::value) { + SetWaitFlag(HardEvent::V_MTE3); + } else { + PipeBarrier(); + } + CopyOutWithCast(gmSampledRes_, workLocal, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + if (kTailNum_ > 0) { + CopyInWithCast(workLocal, gmSampledRes_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_V); + } else { + PipeBarrier(); + } + CompareScalar(maskLocal, workLocal, minPThresholds_, CMPMODE::LT, loopDataNum_); + PipeBarrier(); + Select(workLocal, maskLocal, zeroLocal, workLocal, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::V_MTE3); + } else { + PipeBarrier(); + } + CopyOutWithCast(gmSampledRes_, workLocal, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + } +} +} // namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel + +__global__ __aicore__ void apply_top_k_top_p_min_p( + GM_ADDR probs, GM_ADDR k, GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, + GM_ADDR workspace, GM_ADDR tiling) +{ +#define INIT_AND_PROCESS \ + op.Init(tilingData, probs, k, p, min_p, sampled_res, userWS, &tPipe); \ + op.Process(); + + AscendC::TPipe tPipe; + using namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel; + + // __gm__ uint8_t *userWorkspace = workspace; + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY); + + auto tilingData = reinterpret_cast<__gm__ sglang::ATKTPMPHost::ApplyTopKTopPMinPTilingData *>(tiling); + auto tilingKey = tilingData->tilingKey; + GM_ADDR userWS = GetUserWorkspace(workspace); + if (userWS == nullptr) { + return; + } + + if (tilingKey == TTILING_FP32_WITHOUT_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == TTILING_FP16_WITHOUT_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == TTILING_BF16_WITHOUT_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == 
TTILING_FP32_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == TTILING_FP16_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == TTILING_BF16_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } +} diff --git a/csrc/pytorch_extensions.cpp b/csrc/pytorch_extensions.cpp index a7d093730..6916c284c 100644 --- a/csrc/pytorch_extensions.cpp +++ b/csrc/pytorch_extensions.cpp @@ -98,6 +98,8 @@ TORCH_LIBRARY_FRAGMENT(npu, m) "int? sparse_count=None, int? sparse_mode=None) -> Tensor"); m.def("triangular_inverse(Tensor x) -> Tensor"); + + m.def("apply_top_k_top_p_min_p(Tensor logits, Tensor k, Tensor p, Tensor? min_p=None) -> Tensor"); } } // namespace @@ -141,5 +143,7 @@ TORCH_LIBRARY_IMPL(npu, PrivateUse1, m) m.impl("lightning_indexer", TORCH_FN(sglang::npu_kernel::lightning_indexer)); m.impl("triangular_inverse", TORCH_FN(sglang::npu_kernel::tri_inv_col_sweep)); + + m.impl("apply_top_k_top_p_min_p", TORCH_FN(sglang::npu_kernel::apply_top_k_top_p_min_p)); } } // namespace diff --git a/include/sgl_kenel_npu_ops.h b/include/sgl_kenel_npu_ops.h index e2f5a6804..22517df12 100644 --- a/include/sgl_kenel_npu_ops.h +++ b/include/sgl_kenel_npu_ops.h @@ -123,6 +123,10 @@ at::Tensor lightning_indexer( * is inversed. */ at::Tensor tri_inv_col_sweep(const at::Tensor &tensor_in); + +at::Tensor apply_top_k_top_p_min_p( + const at::Tensor &logits, const at::Tensor &k, const at::Tensor &p, + const c10::optional &min_p); } // namespace npu_kernel } // namespace sglang diff --git a/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py new file mode 100644 index 000000000..9b1dc1f62 --- /dev/null +++ b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py @@ -0,0 +1,132 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. 
+# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import unittest + +import numpy as np +import sgl_kernel_npu +import torch +import torch.nn as nn +import torch_npu +import torchair + +DEVICE_ID = 0 +torch_npu.npu.set_device(int(DEVICE_ID)) +DTYPE_ATOL = {torch.float32: 0.0001, torch.float16: 0.001, torch.bfloat16: 0.004} + + +def _apply_top_k_top_p_min_p( + probs_sort, + k, + p, + min_p=None, +): + probs_sort_out = probs_sort.clone().to(torch.float32) + top_k_mask = torch.arange(0, probs_sort.shape[-1], device=probs_sort.device).view(1, -1) >= k.view(-1, 1) + probs_sort_out.masked_fill_(top_k_mask, 0.0) + + probs_sum = torch.cumsum(probs_sort_out, dim=-1) + top_p_mask = probs_sum - probs_sort_out > p.view(-1, 1) + probs_sort_out.masked_fill_(top_p_mask, 0.0) + + if min_p is not None: + min_p_thresholds = probs_sort_out[:, 0] * min_p + min_p_mask = probs_sort_out < min_p_thresholds.view(-1, 1) + probs_sort_out.masked_fill_(min_p_mask, 0.0) + return probs_sort_out.to(probs_sort.dtype) + + +class TestCustomApplyTopKTopPMinP(unittest.TestCase): + def test_apply_top_k_top_p_min_p_eager(self): + batch_size = 4 + vocab_size = 131072 + + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + np.random.seed(3) + logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype) + k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(torch.int32) + p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype) + min_p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype) + probs = 
torch.softmax(logits, dim=-1) + probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True) + cpu_out = _apply_top_k_top_p_min_p( + probs_sort, + k, + p, + min_p + ) + + torch_npu.npu.set_device(int(DEVICE_ID)) + probs_sort = probs_sort.to("npu:%s" % DEVICE_ID) + k = k.to("npu:%s" % DEVICE_ID) + p = p.to("npu:%s" % DEVICE_ID) + min_p = min_p.to("npu:%s" % DEVICE_ID) + + npu_out = torch.ops.npu.apply_top_k_top_p_min_p( + probs_sort, + k, + p, + min_p=min_p + ) + + # compare result + npu_out = npu_out.cpu() + cpu_out = cpu_out.cpu() + tol = DTYPE_ATOL[dtype] + assert torch.allclose( + cpu_out.to(torch.float32), + npu_out.to(torch.float32), + atol=tol, + rtol=tol, + ) + + + def test_apply_top_k_top_p_min_p_eager_without_min_p(self): + batch_size = 4 + vocab_size = 131072 + + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + np.random.seed(4) + logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype) + k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(torch.int32) + p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype) + probs = torch.softmax(logits, dim=-1) + probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True) + cpu_out = _apply_top_k_top_p_min_p( + probs_sort, + k, + p, + min_p=None + ) + + torch_npu.npu.set_device(int(DEVICE_ID)) + probs_sort = probs_sort.to("npu:%s" % DEVICE_ID) + k = k.to("npu:%s" % DEVICE_ID) + p = p.to("npu:%s" % DEVICE_ID) + + npu_out = torch.ops.npu.apply_top_k_top_p_min_p( + probs_sort, + k, + p, + min_p=None + ) + + # compare result + npu_out = npu_out.cpu() + cpu_out = cpu_out.cpu() + tol = DTYPE_ATOL[dtype] + assert torch.allclose( + cpu_out.to(torch.float32), + npu_out.to(torch.float32), + atol=tol, + rtol=tol, + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 9211cf8e367842c19fd43c22a7ec77099aae71f0 Mon Sep 17 00:00:00 2001 From: wangxinwei328 Date: Mon, 26 Jan 2026 15:09:30 +0800 Subject: [PATCH 2/3] fix lint --- 
csrc/apply_top_k_top_p_min_p/README.md | 6 +- .../op_host/apply_top_k_top_p_min_p.cpp | 16 ++- .../op_host/apply_top_k_top_p_min_p_def.h | 6 +- .../tiling/apply_top_k_top_p_min_p_tiling.cpp | 44 ++++--- .../tiling/apply_top_k_top_p_min_p_tiling.h | 4 +- .../apply_top_k_top_p_min_p_kernel.cpp | 113 +++++++++--------- csrc/pytorch_extensions.cpp | 2 +- include/sgl_kenel_npu_ops.h | 6 +- .../test_apply_top_k_top_p_min_p.py | 64 +++++----- 9 files changed, 124 insertions(+), 137 deletions(-) diff --git a/csrc/apply_top_k_top_p_min_p/README.md b/csrc/apply_top_k_top_p_min_p/README.md index 03a18b808..f452ea4a2 100644 --- a/csrc/apply_top_k_top_p_min_p/README.md +++ b/csrc/apply_top_k_top_p_min_p/README.md @@ -12,7 +12,7 @@ A top-k, top-p and min-p sampling implementation for ascend. ## Calculation Formula $$ -sampled\_res[b][v] = +sampled\_res[b][v] = \begin{cases} 0 & \text{v >= k[b]} \\ probs[b][v] & \text{v < k[b]} @@ -21,7 +21,7 @@ $$ $$probs\_sum = cumsum(sampled\_res, dim=-1)$$ $$top\_p\_mask[b][v] = probs\_sum[b][v] - sampled\_res[b][v] > p[b]$$ $$ -sampled\_res[b][v] = +sampled\_res[b][v] = \begin{cases} 0 & \text{top\_p\_mask = True} \\ sampled\_res[b][v] & \text{top\_p\_mask = False} @@ -29,7 +29,7 @@ sampled\_res[b][v] & \text{top\_p\_mask = False} $$ $$min\_p\_mask[b][v] = sampled\_res[b][v] < sampled\_res[b][0] * min\_p[b]$$ $$ -sampled\_res[b][v] = +sampled\_res[b][v] = \begin{cases} 0 & \text{min\_p\_mask = True} \\ sampled\_res[b][v] & \text{min\_p\_mask = False} diff --git a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp index 18da6acb4..e543795f9 100644 --- a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp +++ b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp @@ -40,10 +40,9 @@ HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::T auto probsType = probs.scalar_type(); - at::Tensor minP = - min_p.has_value() - 
? min_p.value() - : at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device())); + at::Tensor minP = min_p.has_value() + ? min_p.value() + : at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device())); ApplyTopKTopPMinPTilingInfo applyTopKTopPMinPInfo; applyTopKTopPMinPInfo.opParamInfo.probs.dtype = SCALAR_TYPE_TO_GE_DATATYPE(probsType); @@ -60,9 +59,8 @@ HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::T applyTopKTopPMinPInfo.opParamInfo.sampledRes.shape = sampledRes.sizes(); ApplyTopKTopPMinPTiling applyTopKTopPMinPTiling(&applyTopKTopPMinPInfo); - TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS, - "apply_top_k_top_p_min_p DoTiling failed") - + TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS, "apply_top_k_top_p_min_p DoTiling failed"); + const auto &tilingData = applyTopKTopPMinPTiling.GetTilingData(); uint32_t tilingSize = (sizeof(ApplyTopKTopPMinPTiling) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE; @@ -71,9 +69,9 @@ HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::T at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(probs.options().device())); aclrtMemcpy(tilingBuffer.data_ptr(), tilingSize, &tilingData, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE); at::Tensor tilingTensor = at::from_blob(tilingBuffer.data_ptr(), tilingSize, at::kByte); - + auto workspace = at::empty({applyTopKTopPMinPInfo.workspaceSize}, - at::TensorOptions().dtype(at::kByte).device(probs.options().device())); + at::TensorOptions().dtype(at::kByte).device(probs.options().device())); EXEC_KERNEL_CMD(apply_top_k_top_p_min_p, blockDim, probs, k, p, minP, sampledRes, workspace, tilingTensor); return sampledRes; } diff --git a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h index 8b82c9de3..7fb248d48 100644 --- 
a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h +++ b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h @@ -29,11 +29,7 @@ class ApplyTopKTopPMinP : public OpDef .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) .FormatList({ge::FORMAT_ND}) .AutoContiguous(); - this->Input("k") - .ParamType(REQUIRED) - .DataTypeList({ge::DT_INT32}) - .FormatList({ge::FORMAT_ND}) - .AutoContiguous(); + this->Input("k").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND}).AutoContiguous(); this->Input("p") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp index 769f54847..8f38d1c0a 100644 --- a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp +++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp @@ -26,17 +26,16 @@ namespace sglang::ATKTPMPHost { ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype() { TORCH_CHECK((tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT16) || - (tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) || - (tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT), + (tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) || + (tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT), "The data types of probs, p and sampled_res must be float16, bfloat16 or float."); TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.p.dtype, "The data types of probs and p must be the same."); TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.sampledRes.dtype, "The data types of probs and sampled_res must be the same."); - - TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32, - "The data types of the input k must be int32."); + + TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32, "The data 
types of the input k must be int32."); return ge::GRAPH_SUCCESS; } @@ -49,19 +48,17 @@ ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape() tilingData_.batchSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ZERO]; tilingData_.vocabSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ONE]; - TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE, - "ApplyTopKTopPMinP: the dimNum of k should be ", DIM_NUM_ONE, ", but now is ", - tilingInfo_->opParamInfo.k.shape.size(), "."); + TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of k should be ", + DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.k.shape.size(), "."); int64_t kSize = tilingInfo_->opParamInfo.k.shape[DIM_IDX_ZERO]; - TORCH_CHECK(kSize == tilingData_.batchSize, - "ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize, "], but now is [", kSize, "]."); - - TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE, - "ApplyTopKTopPMinP: the dimNum of p should be ", DIM_NUM_ONE, ", but now is ", - tilingInfo_->opParamInfo.p.shape.size(), "."); + TORCH_CHECK(kSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize, + "], but now is [", kSize, "]."); + + TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of p should be ", + DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.p.shape.size(), "."); int64_t pSize = tilingInfo_->opParamInfo.p.shape[DIM_IDX_ZERO]; - TORCH_CHECK(pSize == tilingData_.batchSize, - "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize, "], but now is [", pSize, "]."); + TORCH_CHECK(pSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize, + "], but now is [", pSize, "]."); if (tilingInfo_->opParamInfo.minP.shape.size() != DIM_NUM_ZERO) { int64_t minPSize = tilingInfo_->opParamInfo.minP.shape[DIM_IDX_ZERO]; @@ -76,9 +73,8 @@ 
ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape() int64_t sampledResSize0 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ZERO]; int64_t sampledResSize1 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ONE]; TORCH_CHECK(sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize, - "ApplyTopKTopPMinP: the size of sampledRes should be [", - tilingData_.batchSize, ", ", tilingData_.vocabSize, - "], but now is [", sampledResSize0, ", ", sampledResSize1, "]."); + "ApplyTopKTopPMinP: the size of sampledRes should be [", tilingData_.batchSize, ", ", + tilingData_.vocabSize, "], but now is [", sampledResSize0, ", ", sampledResSize1, "]."); return ge::GRAPH_SUCCESS; } @@ -111,18 +107,18 @@ ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling() auto socVersion = ascendcPlatform.GetSocVersion(); TORCH_CHECK(socVersion == platform_ascendc::SocVersion::ASCEND910B || - socVersion == platform_ascendc::SocVersion::ASCEND910_93, + socVersion == platform_ascendc::SocVersion::ASCEND910_93, "soc version does not support ", (int32_t)socVersion); - + SplitTask(); - + // -------------set workspacesize----------------- tilingInfo_->workspaceSize = static_cast(ascendcPlatform.GetLibApiWorkSpaceSize()) + tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32; // -------------set tilingkey----------------- - tilingData_.tilingKey = G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN + - tilingInfo_->needMinPSample; + tilingData_.tilingKey = + G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN + tilingInfo_->needMinPSample; return ge::GRAPH_SUCCESS; } diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h index 5856a7470..1b7cda228 100644 --- a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h +++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h @@ -30,9 +30,7 @@ 
struct TensorParaInfo { c10::ArrayRef shape; }; -const std::map G_DTYPE_MAP = {{ge::DT_FLOAT, 1}, - {ge::DT_FLOAT16, 2}, - {ge::DT_BF16, 3}}; +const std::map G_DTYPE_MAP = {{ge::DT_FLOAT, 1}, {ge::DT_FLOAT16, 2}, {ge::DT_BF16, 3}}; // ------------------算子原型索引常量定义---------------- // Dim Index constexpr uint32_t DIM_IDX_ZERO = 0; diff --git a/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp b/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp index a5df21296..e44b615d2 100644 --- a/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp +++ b/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp @@ -37,26 +37,27 @@ __aicore__ inline void SetWaitFlag(HardEvent evt) } template -class ApplyTopKTopPMinP { +class ApplyTopKTopPMinP +{ public: __aicore__ inline ApplyTopKTopPMinP(){}; - __aicore__ inline void Init( - const __gm__ ApplyTopKTopPMinPTilingData *tilingData, GM_ADDR probs, GM_ADDR k, - GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe); + __aicore__ inline void Init(const __gm__ ApplyTopKTopPMinPTilingData *tilingData, GM_ADDR probs, GM_ADDR k, + GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe); __aicore__ inline void Process(); + private: - __aicore__ inline void CopyInWithCast( - LocalTensor& localTensor, GlobalTensor& globalTensor, int64_t gmOffset, int64_t copyNum); - __aicore__ inline void InitWorkspace(LocalTensor& workLocal); - __aicore__ inline void GetFloatValue(GlobalTensor& globalTensor, int64_t offset, float& value); - __aicore__ inline void CumSumImpl(LocalTensor& cumSumInput1Local, LocalTensor& cumSumInput2Local); - __aicore__ inline void CopyOutWithCast( - GlobalTensor& globalTensor, LocalTensor& localTensor, int64_t gmOffset, int64_t copyNum); - __aicore__ inline void TopPProcess( - LocalTensor& cumSumInput1Local, LocalTensor& cumSumInput2Local, LocalTensor& zeroLocal, - LocalTensor& maskLocal); - 
__aicore__ inline void MinPProcess( - LocalTensor& workLocal, LocalTensor& zeroLocal, LocalTensor& maskLocal); + __aicore__ inline void CopyInWithCast(LocalTensor &localTensor, GlobalTensor &globalTensor, + int64_t gmOffset, int64_t copyNum); + __aicore__ inline void InitWorkspace(LocalTensor &workLocal); + __aicore__ inline void GetFloatValue(GlobalTensor &globalTensor, int64_t offset, float &value); + __aicore__ inline void CumSumImpl(LocalTensor &cumSumInput1Local, LocalTensor &cumSumInput2Local); + __aicore__ inline void CopyOutWithCast(GlobalTensor &globalTensor, LocalTensor &localTensor, + int64_t gmOffset, int64_t copyNum); + __aicore__ inline void TopPProcess(LocalTensor &cumSumInput1Local, LocalTensor &cumSumInput2Local, + LocalTensor &zeroLocal, LocalTensor &maskLocal); + __aicore__ inline void MinPProcess(LocalTensor &workLocal, LocalTensor &zeroLocal, + LocalTensor &maskLocal); + private: TBuf calBuf_; @@ -93,9 +94,9 @@ class ApplyTopKTopPMinP { }; template -__aicore__ inline void ApplyTopKTopPMinP::Init( - const __gm__ ApplyTopKTopPMinPTilingData *tilingData, GM_ADDR probs, GM_ADDR k, - GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe) +__aicore__ inline void ApplyTopKTopPMinP::Init(const __gm__ ApplyTopKTopPMinPTilingData *tilingData, + GM_ADDR probs, GM_ADDR k, GM_ADDR p, GM_ADDR min_p, + GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe) { batchSize_ = tilingData->batchSize; vocabSize_ = tilingData->vocabSize; @@ -137,8 +138,8 @@ __aicore__ inline void ApplyTopKTopPMinP::Init( } template -__aicore__ inline void ApplyTopKTopPMinP::GetFloatValue( - GlobalTensor& globalTensor, int64_t offset, float& value) +__aicore__ inline void ApplyTopKTopPMinP::GetFloatValue(GlobalTensor &globalTensor, + int64_t offset, float &value) { if constexpr (IsSameType::value) { value = globalTensor.GetValue(offset); @@ -196,12 +197,13 @@ __aicore__ inline void ApplyTopKTopPMinP::Process() } template -__aicore__ inline void 
ApplyTopKTopPMinP::CopyInWithCast( - LocalTensor& localTensor, GlobalTensor& globalTensor, int64_t gmOffset, int64_t copyNum) +__aicore__ inline void ApplyTopKTopPMinP::CopyInWithCast(LocalTensor &localTensor, + GlobalTensor &globalTensor, + int64_t gmOffset, int64_t copyNum) { if constexpr (IsSameType::value) { - DataCopyPad(localTensor, globalTensor[gmOffset], - {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(localTensor, globalTensor[gmOffset], {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}, + {false, 0, 0, 0}); } else { DataCopyPad(localTensor.ReinterpretCast()[loopDataNum_], globalTensor[gmOffset], {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}, {false, 0, 0, 0}); @@ -211,7 +213,7 @@ __aicore__ inline void ApplyTopKTopPMinP::CopyInWithCast( } template -__aicore__ inline void ApplyTopKTopPMinP::InitWorkspace(LocalTensor& workLocal) +__aicore__ inline void ApplyTopKTopPMinP::InitWorkspace(LocalTensor &workLocal) { for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { CopyInWithCast(workLocal, gmProbs_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); @@ -238,8 +240,8 @@ __aicore__ inline void ApplyTopKTopPMinP::InitWorkspace(Local } template -__aicore__ inline void ApplyTopKTopPMinP::CumSumImpl( - LocalTensor& cumSumInput1Local, LocalTensor& cumSumInput2Local) +__aicore__ inline void ApplyTopKTopPMinP::CumSumImpl(LocalTensor &cumSumInput1Local, + LocalTensor &cumSumInput2Local) { int64_t tmpValue = 1; iterateTimes_ = 0; @@ -284,8 +286,9 @@ __aicore__ inline void ApplyTopKTopPMinP::CumSumImpl( } template -__aicore__ inline void ApplyTopKTopPMinP::CopyOutWithCast( - GlobalTensor& globalTensor, LocalTensor& localTensor, int64_t gmOffset, int64_t copyNum) +__aicore__ inline void ApplyTopKTopPMinP::CopyOutWithCast(GlobalTensor &globalTensor, + LocalTensor &localTensor, + int64_t gmOffset, int64_t copyNum) { if constexpr (IsSameType::value) { DataCopyPad(globalTensor[gmOffset], localTensor, {1, static_cast(copyNum 
* sizeof(T)), 0, 0, 0}); @@ -298,9 +301,10 @@ __aicore__ inline void ApplyTopKTopPMinP::CopyOutWithCast( } template -__aicore__ inline void ApplyTopKTopPMinP::TopPProcess( - LocalTensor& cumSumInput1Local, LocalTensor& cumSumInput2Local, LocalTensor& zeroLocal, - LocalTensor& maskLocal) +__aicore__ inline void ApplyTopKTopPMinP::TopPProcess(LocalTensor &cumSumInput1Local, + LocalTensor &cumSumInput2Local, + LocalTensor &zeroLocal, + LocalTensor &maskLocal) { CumSumImpl(cumSumInput1Local, cumSumInput2Local); for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { @@ -316,7 +320,8 @@ __aicore__ inline void ApplyTopKTopPMinP::TopPProcess( PipeBarrier(); CompareScalar(maskLocal, cumSumInput1Local, pValue_, CMPMODE::GT, loopDataNum_); PipeBarrier(); - Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_); + Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, + loopDataNum_); if constexpr (IsSameType::value) { SetWaitFlag(HardEvent::V_MTE3); } else { @@ -338,7 +343,8 @@ __aicore__ inline void ApplyTopKTopPMinP::TopPProcess( PipeBarrier(); CompareScalar(maskLocal, cumSumInput1Local, pValue_, CMPMODE::GT, loopDataNum_); PipeBarrier(); - Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_); + Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, + loopDataNum_); if constexpr (IsSameType::value) { SetWaitFlag(HardEvent::V_MTE3); } else { @@ -350,8 +356,9 @@ __aicore__ inline void ApplyTopKTopPMinP::TopPProcess( } template -__aicore__ inline void ApplyTopKTopPMinP::MinPProcess( - LocalTensor& workLocal, LocalTensor& zeroLocal, LocalTensor& maskLocal) +__aicore__ inline void ApplyTopKTopPMinP::MinPProcess(LocalTensor &workLocal, + LocalTensor &zeroLocal, + LocalTensor &maskLocal) { for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { 
CopyInWithCast(workLocal, gmSampledRes_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); @@ -392,18 +399,16 @@ __aicore__ inline void ApplyTopKTopPMinP::MinPProcess( } } // namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel -__global__ __aicore__ void apply_top_k_top_p_min_p( - GM_ADDR probs, GM_ADDR k, GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, - GM_ADDR workspace, GM_ADDR tiling) +__global__ __aicore__ void apply_top_k_top_p_min_p(GM_ADDR probs, GM_ADDR k, GM_ADDR p, GM_ADDR min_p, + GM_ADDR sampled_res, GM_ADDR workspace, GM_ADDR tiling) { -#define INIT_AND_PROCESS \ +#define INIT_AND_PROCESS \ op.Init(tilingData, probs, k, p, min_p, sampled_res, userWS, &tPipe); \ op.Process(); AscendC::TPipe tPipe; using namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel; - // __gm__ uint8_t *userWorkspace = workspace; KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY); auto tilingData = reinterpret_cast<__gm__ sglang::ATKTPMPHost::ApplyTopKTopPMinPTilingData *>(tiling); @@ -414,22 +419,22 @@ __global__ __aicore__ void apply_top_k_top_p_min_p( } if (tilingKey == TTILING_FP32_WITHOUT_MIN_P) { - ApplyTopKTopPMinP op; - INIT_AND_PROCESS; + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; } else if (tilingKey == TTILING_FP16_WITHOUT_MIN_P) { - ApplyTopKTopPMinP op; - INIT_AND_PROCESS; + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; } else if (tilingKey == TTILING_BF16_WITHOUT_MIN_P) { - ApplyTopKTopPMinP op; - INIT_AND_PROCESS; + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; } else if (tilingKey == TTILING_FP32_MIN_P) { - ApplyTopKTopPMinP op; - INIT_AND_PROCESS; + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; } else if (tilingKey == TTILING_FP16_MIN_P) { - ApplyTopKTopPMinP op; - INIT_AND_PROCESS; + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; } else if (tilingKey == TTILING_BF16_MIN_P) { - ApplyTopKTopPMinP op; - INIT_AND_PROCESS; + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; } } diff --git a/csrc/pytorch_extensions.cpp b/csrc/pytorch_extensions.cpp index 6916c284c..5d88dc1df 100644 --- 
a/csrc/pytorch_extensions.cpp +++ b/csrc/pytorch_extensions.cpp @@ -98,7 +98,7 @@ TORCH_LIBRARY_FRAGMENT(npu, m) "int? sparse_count=None, int? sparse_mode=None) -> Tensor"); m.def("triangular_inverse(Tensor x) -> Tensor"); - + m.def("apply_top_k_top_p_min_p(Tensor logits, Tensor k, Tensor p, Tensor? min_p=None) -> Tensor"); } } // namespace diff --git a/include/sgl_kenel_npu_ops.h b/include/sgl_kenel_npu_ops.h index 22517df12..e3aeda613 100644 --- a/include/sgl_kenel_npu_ops.h +++ b/include/sgl_kenel_npu_ops.h @@ -124,9 +124,9 @@ at::Tensor lightning_indexer( */ at::Tensor tri_inv_col_sweep(const at::Tensor &tensor_in); -at::Tensor apply_top_k_top_p_min_p( - const at::Tensor &logits, const at::Tensor &k, const at::Tensor &p, - const c10::optional &min_p); +at::Tensor apply_top_k_top_p_min_p(const at::Tensor &logits, + const at::Tensor &k, const at::Tensor &p, + const c10::optional &min_p); } // namespace npu_kernel } // namespace sglang diff --git a/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py index 9b1dc1f62..191b15104 100644 --- a/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py +++ b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py @@ -27,7 +27,9 @@ def _apply_top_k_top_p_min_p( min_p=None, ): probs_sort_out = probs_sort.clone().to(torch.float32) - top_k_mask = torch.arange(0, probs_sort.shape[-1], device=probs_sort.device).view(1, -1) >= k.view(-1, 1) + top_k_mask = torch.arange(0, probs_sort.shape[-1], device=probs_sort.device).view( + 1, -1 + ) >= k.view(-1, 1) probs_sort_out.masked_fill_(top_k_mask, 0.0) probs_sum = torch.cumsum(probs_sort_out, dim=-1) @@ -48,18 +50,17 @@ def test_apply_top_k_top_p_min_p_eager(self): for dtype in [torch.bfloat16, torch.float16, torch.float32]: np.random.seed(3) - logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype) - k = torch.tensor(np.random.randint(1, vocab_size, 
(batch_size))).to(torch.int32) + logits = torch.tensor( + np.random.uniform(-10, 10, (batch_size, vocab_size)) + ).to(dtype) + k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to( + torch.int32 + ) p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype) min_p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype) probs = torch.softmax(logits, dim=-1) probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True) - cpu_out = _apply_top_k_top_p_min_p( - probs_sort, - k, - p, - min_p - ) + cpu_out = _apply_top_k_top_p_min_p(probs_sort, k, p, min_p) torch_npu.npu.set_device(int(DEVICE_ID)) probs_sort = probs_sort.to("npu:%s" % DEVICE_ID) @@ -68,10 +69,7 @@ def test_apply_top_k_top_p_min_p_eager(self): min_p = min_p.to("npu:%s" % DEVICE_ID) npu_out = torch.ops.npu.apply_top_k_top_p_min_p( - probs_sort, - k, - p, - min_p=min_p + probs_sort, k, p, min_p=min_p ) # compare result @@ -79,11 +77,11 @@ def test_apply_top_k_top_p_min_p_eager(self): cpu_out = cpu_out.cpu() tol = DTYPE_ATOL[dtype] assert torch.allclose( - cpu_out.to(torch.float32), - npu_out.to(torch.float32), - atol=tol, - rtol=tol, - ) + cpu_out.to(torch.float32), + npu_out.to(torch.float32), + atol=tol, + rtol=tol, + ) def test_apply_top_k_top_p_min_p_eager_without_min_p(self): @@ -92,17 +90,16 @@ def test_apply_top_k_top_p_min_p_eager_without_min_p(self): for dtype in [torch.bfloat16, torch.float16, torch.float32]: np.random.seed(4) - logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype) - k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(torch.int32) + logits = torch.tensor( + np.random.uniform(-10, 10, (batch_size, vocab_size)) + ).to(dtype) + k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to( + torch.int32 + ) p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype) probs = torch.softmax(logits, dim=-1) probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True) - cpu_out = 
_apply_top_k_top_p_min_p( - probs_sort, - k, - p, - min_p=None - ) + cpu_out = _apply_top_k_top_p_min_p(probs_sort, k, p, min_p=None) torch_npu.npu.set_device(int(DEVICE_ID)) probs_sort = probs_sort.to("npu:%s" % DEVICE_ID) @@ -110,10 +107,7 @@ def test_apply_top_k_top_p_min_p_eager_without_min_p(self): p = p.to("npu:%s" % DEVICE_ID) npu_out = torch.ops.npu.apply_top_k_top_p_min_p( - probs_sort, - k, - p, - min_p=None + probs_sort, k, p, min_p=None ) # compare result @@ -121,11 +115,11 @@ def test_apply_top_k_top_p_min_p_eager_without_min_p(self): cpu_out = cpu_out.cpu() tol = DTYPE_ATOL[dtype] assert torch.allclose( - cpu_out.to(torch.float32), - npu_out.to(torch.float32), - atol=tol, - rtol=tol, - ) + cpu_out.to(torch.float32), + npu_out.to(torch.float32), + atol=tol, + rtol=tol, + ) if __name__ == "__main__": From f8e7926f484d9a6c63e15aa43465454caab00fe6 Mon Sep 17 00:00:00 2001 From: wangxinwei328 Date: Mon, 26 Jan 2026 15:18:24 +0800 Subject: [PATCH 3/3] fix lint --- tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py index 191b15104..ef1929ea8 100644 --- a/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py +++ b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py @@ -83,7 +83,6 @@ def test_apply_top_k_top_p_min_p_eager(self): rtol=tol, ) - def test_apply_top_k_top_p_min_p_eager_without_min_p(self): batch_size = 4 vocab_size = 131072