diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 6b3236f8b..68826e8d2 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -22,6 +22,8 @@ FILE(GLOB OP_SRCS ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp ${PROJECT_OP_SRC_BASE}/tri_inv/op_host/tri_inv.cpp + ${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp + ${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp ) if(BUILD_CATLASS_MODULE) list(APPEND OP_SRCS @@ -53,6 +55,7 @@ set(WORKSPACE_KERNEL_SRCS ${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp ${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp ${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp + ${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp ) if(BUILD_CATLASS_MODULE) list(APPEND WORKSPACE_KERNEL_SRCS diff --git a/csrc/apply_top_k_top_p_min_p/README.md b/csrc/apply_top_k_top_p_min_p/README.md new file mode 100644 index 000000000..f452ea4a2 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/README.md @@ -0,0 +1,65 @@ +## Introduction +A top-k, top-p and min-p sampling implementation for ascend. + +## Sheet 1: Parameters +| Parameter | Dimension | Data Type | Format | Description | +|--------------|--------------------------|----------------------|--------|--------------------------------------------------| +| probs | [batch_size, vocab_size] | float32/float16/bf16 | ND | Probabilities for sampling.
The probabilities should be sorted in descending order. | +| k | [batch_size] | int32 | ND | Representing the threshold for top-k sampling. | +| p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for top-p sampling. | +| min_p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for min-p sampling.
When min_p is nullptr, the min-p sampling will be skipped. | +| sampled_res | [batch_size, vocab_size] | float32/float16/bf16 | ND | The result after sampling.
The DataType of sampled_res should be the same as probs. |

## Calculation Formula
$$
sampled\_res[b][v] =
\begin{cases}
0 & \text{v >= k[b]} \\
probs[b][v] & \text{v < k[b]}
\end{cases}
$$
$$probs\_sum = cumsum(sampled\_res, dim=-1)$$
$$top\_p\_mask[b][v] = probs\_sum[b][v] - sampled\_res[b][v] > p[b]$$
$$
sampled\_res[b][v] =
\begin{cases}
0 & \text{top\_p\_mask = True} \\
sampled\_res[b][v] & \text{top\_p\_mask = False}
\end{cases}
$$
$$min\_p\_mask[b][v] = sampled\_res[b][v] < sampled\_res[b][0] * min\_p[b]$$
$$
sampled\_res[b][v] =
\begin{cases}
0 & \text{min\_p\_mask = True} \\
sampled\_res[b][v] & \text{min\_p\_mask = False}
\end{cases}
$$
Where $0 \le b \lt batch\_size$, and $0 \le v \lt vocab\_size$.

## Restrictions
1. Only supports Ascend A2/A3.
2. $0 \lt k[b] \le vocab\_size$, where $0 \le b \lt batch\_size$; if $k[b] \lt 0$ or $k[b] \gt vocab\_size$, then $k[b]$ will be regarded as $vocab\_size$.
3. $0 \le p[b] \le 1$, where $0 \le b \lt batch\_size$.

## Sample Code
```python
import numpy as np
import torch
import torch_npu
import sgl_kernel_npu

dtype = torch.float16
batch_size = 4
vocab_size = 128

logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype).npu()
k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(torch.int32).npu()
p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu()
min_p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu()

probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True)

torch.ops.npu.apply_top_k_top_p_min_p(probs_sort, k, p, min_p=min_p)
```
diff --git a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp
new file mode 100644
index 000000000..e543795f9
--- /dev/null
+++ b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp
@@ -0,0 +1,79 @@
#include
+#include +#include "acl/acl.h" +#include "kernel_tiling/kernel_tiling.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/apply_top_k_top_p_min_p_tiling.h" +#include "defines.h" +#include "torch_helper.h" +#include "ge_helper.h" +#include "common_tiling.h" +#include "apply_top_k_top_p_min_p_def.h" +#include "common.h" +#include "aclrtlaunch_apply_top_k_top_p_min_p.h" + +namespace sglang::ATKTPMPHost { + +using namespace ge_helper; +constexpr uint32_t PADDING_BYTE = 32U; + +inline at::Tensor ConstructApplyTopKTopPMinPOutputTensor(const at::Tensor &probs) +{ + for (size_t i = 0; i < probs.sizes().size(); i++) { + TORCH_CHECK(probs.size(i) > 0, + "All values within probs's shape should be greater " + "than 0, but shape[", + i, "] is ", probs.size(i)); + } + at::Tensor output = at::empty_like(probs); + return output; +} +} // namespace sglang::ATKTPMPHost + +namespace sglang { +namespace npu_kernel { +HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::Tensor &k, const at::Tensor &p, + const c10::optional &min_p) +{ + using namespace ATKTPMPHost; + at::Tensor sampledRes = ConstructApplyTopKTopPMinPOutputTensor(probs); + + auto probsType = probs.scalar_type(); + + at::Tensor minP = min_p.has_value() + ? 
min_p.value() + : at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device())); + + ApplyTopKTopPMinPTilingInfo applyTopKTopPMinPInfo; + applyTopKTopPMinPInfo.opParamInfo.probs.dtype = SCALAR_TYPE_TO_GE_DATATYPE(probsType); + applyTopKTopPMinPInfo.opParamInfo.probs.shape = probs.sizes(); + applyTopKTopPMinPInfo.opParamInfo.k.dtype = SCALAR_TYPE_TO_GE_DATATYPE(k.scalar_type()); + applyTopKTopPMinPInfo.opParamInfo.k.shape = k.sizes(); + applyTopKTopPMinPInfo.opParamInfo.p.dtype = SCALAR_TYPE_TO_GE_DATATYPE(p.scalar_type()); + applyTopKTopPMinPInfo.opParamInfo.p.shape = p.sizes(); + if (min_p.has_value()) { + applyTopKTopPMinPInfo.opParamInfo.minP.dtype = SCALAR_TYPE_TO_GE_DATATYPE(minP.scalar_type()); + applyTopKTopPMinPInfo.opParamInfo.minP.shape = minP.sizes(); + } + applyTopKTopPMinPInfo.opParamInfo.sampledRes.dtype = SCALAR_TYPE_TO_GE_DATATYPE(sampledRes.scalar_type()); + applyTopKTopPMinPInfo.opParamInfo.sampledRes.shape = sampledRes.sizes(); + + ApplyTopKTopPMinPTiling applyTopKTopPMinPTiling(&applyTopKTopPMinPInfo); + TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS, "apply_top_k_top_p_min_p DoTiling failed"); + + const auto &tilingData = applyTopKTopPMinPTiling.GetTilingData(); + + uint32_t tilingSize = (sizeof(ApplyTopKTopPMinPTiling) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE; + auto blockDim = tilingData.coreNum; + static auto tilingBuffer = + at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(probs.options().device())); + aclrtMemcpy(tilingBuffer.data_ptr(), tilingSize, &tilingData, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE); + at::Tensor tilingTensor = at::from_blob(tilingBuffer.data_ptr(), tilingSize, at::kByte); + + auto workspace = at::empty({applyTopKTopPMinPInfo.workspaceSize}, + at::TensorOptions().dtype(at::kByte).device(probs.options().device())); + EXEC_KERNEL_CMD(apply_top_k_top_p_min_p, blockDim, probs, k, p, minP, sampledRes, workspace, tilingTensor); + return sampledRes; +} +} 
// namespace npu_kernel +} // namespace sglang diff --git a/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h new file mode 100644 index 000000000..7fb248d48 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h @@ -0,0 +1,50 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of + * the software repository for the full text of the License. + */ + +/*! 
+ * \file apply_top_k_top_p_min_p_def.cpp + * \brief + */ +#include +#include "ge_helper.h" + +namespace sglang { +namespace ATKTPMPHost { +using namespace ge_helper; +class ApplyTopKTopPMinP : public OpDef +{ +public: + explicit ApplyTopKTopPMinP(const char *name) : OpDef(name) + { + this->Input("probs") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("k").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND}).AutoContiguous(); + this->Input("p") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("min_p") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Output("sampled_res") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .FormatList({ge::FORMAT_ND}); + } +}; +} // namespace ATKTPMPHost +} // namespace sglang diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp new file mode 100644 index 000000000..8f38d1c0a --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp @@ -0,0 +1,130 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
See LICENSE in the root of + * the software repository for the full text of the License. + */ + +/*! + * \file apply_top_k_top_p_min_p_tiling.cpp + * \brief + */ + +#include "apply_top_k_top_p_min_p_tiling.h" + +using namespace ge; +using namespace AscendC; +using std::map; +using std::string; +namespace sglang::ATKTPMPHost { + +// --------------------------ApplyTopKTopPMinPTiling类成员函数定义----------------------- +ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype() +{ + TORCH_CHECK((tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT16) || + (tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) || + (tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT), + "The data types of probs, p and sampled_res must be float16, bfloat16 or float."); + + TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.p.dtype, + "The data types of probs and p must be the same."); + TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.sampledRes.dtype, + "The data types of probs and sampled_res must be the same."); + + TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32, "The data types of the input k must be int32."); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape() +{ + TORCH_CHECK(tilingInfo_->opParamInfo.probs.shape.size() == DIM_NUM_TWO, + "ApplyTopKTopPMinP: the dimNum of probs should be ", DIM_NUM_TWO, ", but now is ", + tilingInfo_->opParamInfo.probs.shape.size(), "."); + tilingData_.batchSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ZERO]; + tilingData_.vocabSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ONE]; + + TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of k should be ", + DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.k.shape.size(), "."); + int64_t kSize = tilingInfo_->opParamInfo.k.shape[DIM_IDX_ZERO]; + TORCH_CHECK(kSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of k should be [", 
tilingData_.batchSize, + "], but now is [", kSize, "]."); + + TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of p should be ", + DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.p.shape.size(), "."); + int64_t pSize = tilingInfo_->opParamInfo.p.shape[DIM_IDX_ZERO]; + TORCH_CHECK(pSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize, + "], but now is [", pSize, "]."); + + if (tilingInfo_->opParamInfo.minP.shape.size() != DIM_NUM_ZERO) { + int64_t minPSize = tilingInfo_->opParamInfo.minP.shape[DIM_IDX_ZERO]; + TORCH_CHECK(minPSize == tilingData_.batchSize, ": the shape of p should be [", tilingData_.batchSize, + "], but now is [", minPSize, "]."); + tilingInfo_->needMinPSample = 1; + } + + TORCH_CHECK(tilingInfo_->opParamInfo.sampledRes.shape.size() == DIM_NUM_TWO, + "ApplyTopKTopPMinP: the dimNum of sampled_res should be ", DIM_NUM_TWO, ", but now is ", + tilingInfo_->opParamInfo.sampledRes.shape.size(), "."); + int64_t sampledResSize0 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ZERO]; + int64_t sampledResSize1 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ONE]; + TORCH_CHECK(sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize, + "ApplyTopKTopPMinP: the size of sampledRes should be [", tilingData_.batchSize, ", ", + tilingData_.vocabSize, "], but now is [", sampledResSize0, ", ", sampledResSize1, "]."); + return ge::GRAPH_SUCCESS; +} + +void ApplyTopKTopPMinPTiling::SplitTask() +{ + tilingData_.loopDataNum = tilingData_.ubSize / BYTES_B32 / LOCAL_TENSOR_NUM / BYTES_PER_REPEAT * BYTES_PER_REPEAT; + tilingData_.coreNum = tilingData_.batchSize > tilingData_.coreNum ? 
tilingData_.coreNum : tilingData_.batchSize; + tilingData_.batchPerCore = tilingData_.batchSize / std::max(tilingData_.coreNum, static_cast(1)); + tilingData_.batchTailCore = tilingData_.batchSize - tilingData_.batchPerCore * tilingData_.coreNum; +} + +ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling() +{ + if (CheckDtype() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (CheckShape() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + auto ascendcPlatform = *platform_ascendc::PlatformAscendCManager::GetInstance(); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + TORCH_CHECK(aivNum != 0 && aivNum != 0, "num of core obtained is 0"); + tilingData_.coreNum = static_cast(aivNum); + + uint64_t ubSize = 0; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + tilingData_.ubSize = static_cast(ubSize) - SELECT_MODE_BYTES; + + auto socVersion = ascendcPlatform.GetSocVersion(); + TORCH_CHECK(socVersion == platform_ascendc::SocVersion::ASCEND910B || + socVersion == platform_ascendc::SocVersion::ASCEND910_93, + "soc version does not support ", (int32_t)socVersion); + + SplitTask(); + + // -------------set workspacesize----------------- + tilingInfo_->workspaceSize = static_cast(ascendcPlatform.GetLibApiWorkSpaceSize()) + + tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32; + + // -------------set tilingkey----------------- + tilingData_.tilingKey = + G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN + tilingInfo_->needMinPSample; + + return ge::GRAPH_SUCCESS; +} + +const ApplyTopKTopPMinPTilingData &ApplyTopKTopPMinPTiling::GetTilingData() const +{ + return tilingData_; +} +} // namespace sglang::ATKTPMPHost diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h new file mode 100644 index 000000000..1b7cda228 --- /dev/null +++ 
b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h @@ -0,0 +1,84 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of + * the software repository for the full text of the License. + */ + +/*! + * \file apply_top_k_top_p_min_p_tiling.h + * \brief + */ + +#ifndef APPLY_TOP_K_TOP_P_MIN_P_TILING_H_ +#define APPLY_TOP_K_TOP_P_MIN_P_TILING_H_ + +#include "register/op_def_registry.h" +#include "register/tilingdata_base.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/tiling_api.h" +#include "apply_top_k_top_p_min_p_tiling_data.h" +#include "ge_helper.h" + +namespace sglang::ATKTPMPHost { +struct TensorParaInfo { + ge::DataType dtype; + c10::ArrayRef shape; +}; + +const std::map G_DTYPE_MAP = {{ge::DT_FLOAT, 1}, {ge::DT_FLOAT16, 2}, {ge::DT_BF16, 3}}; +// ------------------算子原型索引常量定义---------------- +// Dim Index +constexpr uint32_t DIM_IDX_ZERO = 0; +constexpr uint32_t DIM_IDX_ONE = 1; +// Dim Num +constexpr uint32_t DIM_NUM_ZERO = 0; +constexpr uint32_t DIM_NUM_ONE = 1; +constexpr uint32_t DIM_NUM_TWO = 2; + +constexpr int64_t COEF_TEN = 10; +constexpr int64_t BYTES_B32 = 4; +constexpr int64_t LOCAL_TENSOR_NUM = 4; +constexpr int64_t SELECT_MODE_BYTES = 8192; +constexpr int64_t BYTES_PER_REPEAT = 256; + +// -----------算子Tiling入参结构体定义--------------- +struct ApplyTopKTopPMinPParaInfo { + TensorParaInfo probs = {ge::DT_FLOAT, c10::ArrayRef{}}; + TensorParaInfo k = 
{ge::DT_INT32, c10::ArrayRef{}}; + TensorParaInfo p = {ge::DT_FLOAT, c10::ArrayRef{}}; + TensorParaInfo minP = {ge::DT_FLOAT, c10::ArrayRef{}}; + TensorParaInfo sampledRes = {ge::DT_FLOAT, c10::ArrayRef{}}; +}; + +// -----------算子Tiling入参信息类--------------- +class ApplyTopKTopPMinPTilingInfo +{ +public: + ApplyTopKTopPMinPParaInfo opParamInfo; + int64_t workspaceSize = 0; + int64_t needMinPSample = 0; +}; + +// ---------------算子Tiling类--------------- +class ApplyTopKTopPMinPTiling +{ +public: + explicit ApplyTopKTopPMinPTiling(ApplyTopKTopPMinPTilingInfo *tilingInfo) : tilingInfo_(tilingInfo) {}; + ge::graphStatus CheckDtype(); + ge::graphStatus CheckShape(); + void SplitTask(); + ge::graphStatus DoTiling(); + const ApplyTopKTopPMinPTilingData &GetTilingData() const; + +private: + ApplyTopKTopPMinPTilingInfo *tilingInfo_ = nullptr; + ApplyTopKTopPMinPTilingData tilingData_; +}; + +} // namespace sglang::ATKTPMPHost +#endif // APPLY_TOP_K_TOP_P_MIN_P_TILING_H_ diff --git a/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h new file mode 100644 index 000000000..4d0e15817 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h @@ -0,0 +1,23 @@ +#ifndef APPLY_TOP_K_TOP_P_MIN_P_TILING_DATA_H +#define APPLY_TOP_K_TOP_P_MIN_P_TILING_DATA_H +#include + +namespace sglang { +namespace ATKTPMPHost { + +// -----------算子TilingData定义--------------- +#pragma pack(push, 1) +struct ApplyTopKTopPMinPTilingData { + int64_t batchSize = 0; + int64_t vocabSize = 0; + int64_t batchPerCore = 0; + int64_t batchTailCore = 0; + int64_t ubSize = 0; + int64_t coreNum = 0; + int64_t loopDataNum = 0; + int64_t tilingKey = 0; +}; +#pragma pack(pop) +} // namespace ATKTPMPHost +} // namespace sglang +#endif diff --git a/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp 
b/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp new file mode 100644 index 000000000..e44b615d2 --- /dev/null +++ b/csrc/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp @@ -0,0 +1,440 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of + * the software repository for the full text of the License. + */ + +/*! + * \file apply_top_k_top_p_min_p_kernel.cpp + * \brief + */ + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../op_host/tiling/apply_top_k_top_p_min_p_tiling_data.h" + +using namespace AscendC; +using namespace sglang::ATKTPMPHost; +#define TTILING_FP32_WITHOUT_MIN_P 10 +#define TTILING_FP16_WITHOUT_MIN_P 20 +#define TTILING_BF16_WITHOUT_MIN_P 30 +#define TTILING_FP32_MIN_P 11 +#define TTILING_FP16_MIN_P 21 +#define TTILING_BF16_MIN_P 31 + +namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel { +template +__aicore__ inline void SetWaitFlag(HardEvent evt) +{ + event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); + SetFlag(eventId); + WaitFlag(eventId); +} + +template +class ApplyTopKTopPMinP +{ +public: + __aicore__ inline ApplyTopKTopPMinP(){}; + __aicore__ inline void Init(const __gm__ ApplyTopKTopPMinPTilingData *tilingData, GM_ADDR probs, GM_ADDR k, + GM_ADDR p, GM_ADDR min_p, GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline 
void CopyInWithCast(LocalTensor &localTensor, GlobalTensor &globalTensor, + int64_t gmOffset, int64_t copyNum); + __aicore__ inline void InitWorkspace(LocalTensor &workLocal); + __aicore__ inline void GetFloatValue(GlobalTensor &globalTensor, int64_t offset, float &value); + __aicore__ inline void CumSumImpl(LocalTensor &cumSumInput1Local, LocalTensor &cumSumInput2Local); + __aicore__ inline void CopyOutWithCast(GlobalTensor &globalTensor, LocalTensor &localTensor, + int64_t gmOffset, int64_t copyNum); + __aicore__ inline void TopPProcess(LocalTensor &cumSumInput1Local, LocalTensor &cumSumInput2Local, + LocalTensor &zeroLocal, LocalTensor &maskLocal); + __aicore__ inline void MinPProcess(LocalTensor &workLocal, LocalTensor &zeroLocal, + LocalTensor &maskLocal); + +private: + TBuf calBuf_; + + // tilingData + int64_t batchSize_ = 0; + int64_t vocabSize_ = 0; + int64_t batchPerCore_ = 0; + int64_t batchTailCore_ = 0; + uint32_t blockIdx_ = 0; + int64_t coreBatch_ = 0; + int64_t batchOffset_ = 0; + int64_t coreNum_ = 0; + int64_t iterateTimes_ = 0; + int64_t baseGmOffset_ = 0; + int64_t probsGmOffset_ = 0; + + int32_t kValue_ = 0; + float pValue_ = 0; + float minPValue_ = 0; + float maxValue_ = 0; + float minPThresholds_ = 0; + int64_t lastIndex_ = 0; + int64_t kLoopNum_ = 0; + int64_t kTailNum_ = 0; + int64_t loopDataNum_ = 0; + + GlobalTensor gmProbs_; + GlobalTensor gmK_; + GlobalTensor gmP_; + GlobalTensor gmMinP_; + GlobalTensor gmSampledRes_; + + GlobalTensor gmWk_; +}; + +template +__aicore__ inline void ApplyTopKTopPMinP::Init(const __gm__ ApplyTopKTopPMinPTilingData *tilingData, + GM_ADDR probs, GM_ADDR k, GM_ADDR p, GM_ADDR min_p, + GM_ADDR sampled_res, GM_ADDR workspace, TPipe *tPipe) +{ + batchSize_ = tilingData->batchSize; + vocabSize_ = tilingData->vocabSize; + batchPerCore_ = tilingData->batchPerCore; + batchTailCore_ = tilingData->batchTailCore; + coreNum_ = tilingData->coreNum; + uint32_t ubSize = static_cast(tilingData->ubSize); + loopDataNum_ = 
tilingData->loopDataNum; + + blockIdx_ = GetBlockIdx(); + if (blockIdx_ >= coreNum_) { + return; + } + + if (blockIdx_ < batchTailCore_) { + coreBatch_ = batchPerCore_ + 1; + batchOffset_ = blockIdx_ * coreBatch_; + } else { + coreBatch_ = batchPerCore_; + batchOffset_ = blockIdx_ * batchPerCore_ + batchTailCore_; + } + baseGmOffset_ = batchOffset_ * vocabSize_; + + if (coreBatch_ == 0) { + return; + } + + gmProbs_.SetGlobalBuffer((__gm__ T *)probs + baseGmOffset_); + gmK_.SetGlobalBuffer((__gm__ int32_t *)k + batchOffset_); + gmP_.SetGlobalBuffer((__gm__ T *)p + batchOffset_); + if constexpr (IsMinPSampling == 1) { + gmMinP_.SetGlobalBuffer((__gm__ T *)min_p + batchOffset_); + } + gmSampledRes_.SetGlobalBuffer((__gm__ T *)sampled_res + baseGmOffset_); + gmWk_.SetGlobalBuffer((__gm__ float *)workspace + baseGmOffset_); + InitGlobalMemory(gmSampledRes_, coreBatch_ * vocabSize_, T(0)); + InitGlobalMemory(gmWk_, coreBatch_ * vocabSize_, float(0)); + tPipe->InitBuffer(calBuf_, ubSize); +} + +template +__aicore__ inline void ApplyTopKTopPMinP::GetFloatValue(GlobalTensor &globalTensor, + int64_t offset, float &value) +{ + if constexpr (IsSameType::value) { + value = globalTensor.GetValue(offset); + } else if constexpr (IsSameType::value) { + value = static_cast(globalTensor.GetValue(offset)); + } else { + value = ToFloat(globalTensor.GetValue(offset)); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::Process() +{ + if (blockIdx_ >= coreNum_) { + return; + } + uint32_t bufOffset = 0; + LocalTensor maskLocal = calBuf_.GetWithOffset(loopDataNum_ / 8, bufOffset); + bufOffset += loopDataNum_ / 8 * sizeof(uint8_t); + LocalTensor zeroLocal = calBuf_.GetWithOffset(loopDataNum_, bufOffset); + bufOffset += loopDataNum_ * sizeof(float); + LocalTensor cumSumInput1Local = calBuf_.GetWithOffset(loopDataNum_, bufOffset); + bufOffset += loopDataNum_ * sizeof(float); + LocalTensor cumSumInput2Local = calBuf_.GetWithOffset(loopDataNum_, bufOffset); + bufOffset += loopDataNum_ 
* sizeof(float); + + Duplicate(zeroLocal, float(0), loopDataNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + for (int64_t batchLoop = 0; batchLoop < coreBatch_; batchLoop++) { + probsGmOffset_ = batchLoop * vocabSize_; + + kValue_ = gmK_.GetValue(batchLoop); + GetFloatValue(gmP_, batchLoop, pValue_); + if constexpr (IsMinPSampling == 1) { + GetFloatValue(gmMinP_, batchLoop, minPValue_); + GetFloatValue(gmProbs_, probsGmOffset_, maxValue_); + minPThresholds_ = maxValue_ * minPValue_; + } + if (kValue_ > vocabSize_ || kValue_ < 0) { + lastIndex_ = vocabSize_; + } else { + lastIndex_ = kValue_; + } + kLoopNum_ = lastIndex_ / loopDataNum_; + kTailNum_ = lastIndex_ - kLoopNum_ * loopDataNum_; + + InitWorkspace(cumSumInput1Local); + + TopPProcess(cumSumInput1Local, cumSumInput2Local, zeroLocal, maskLocal); + + if constexpr (IsMinPSampling == 1) { + MinPProcess(cumSumInput1Local, zeroLocal, maskLocal); + } + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::CopyInWithCast(LocalTensor &localTensor, + GlobalTensor &globalTensor, + int64_t gmOffset, int64_t copyNum) +{ + if constexpr (IsSameType::value) { + DataCopyPad(localTensor, globalTensor[gmOffset], {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}, + {false, 0, 0, 0}); + } else { + DataCopyPad(localTensor.ReinterpretCast()[loopDataNum_], globalTensor[gmOffset], + {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}, {false, 0, 0, 0}); + SetWaitFlag(HardEvent::MTE2_V); + Cast(localTensor, localTensor.ReinterpretCast()[loopDataNum_], RoundMode::CAST_NONE, copyNum); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::InitWorkspace(LocalTensor &workLocal) +{ + for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { + CopyInWithCast(workLocal, gmProbs_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_MTE3); + } else { + SetWaitFlag(HardEvent::V_MTE3); + } + DataCopyPad(gmWk_[probsGmOffset_ + vocabLoop * loopDataNum_], workLocal, 
+ {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + if (kTailNum_ > 0) { + CopyInWithCast(workLocal, gmProbs_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_MTE3); + } else { + SetWaitFlag(HardEvent::V_MTE3); + } + DataCopyPad(gmWk_[probsGmOffset_ + kLoopNum_ * loopDataNum_], workLocal, + {1, static_cast(kTailNum_ * sizeof(float)), 0, 0, 0}); + SetWaitFlag(HardEvent::MTE3_MTE2); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::CumSumImpl(LocalTensor &cumSumInput1Local, + LocalTensor &cumSumInput2Local) +{ + int64_t tmpValue = 1; + iterateTimes_ = 0; + while (tmpValue < lastIndex_) { + tmpValue <<= 1; + iterateTimes_++; + } + for (int64_t iterateTime = 0; iterateTime < iterateTimes_; iterateTime++) { + int64_t iteratOffset = 1; + for (int64_t powerIdx = 0; powerIdx < iterateTime; powerIdx++) { + iteratOffset = iteratOffset * 2; + } + int64_t addLength = lastIndex_ - iteratOffset; + int64_t innerLoopNum = addLength / loopDataNum_; + int64_t dataTail = addLength - innerLoopNum * loopDataNum_; + for (int64_t innerLoop = 0; innerLoop < innerLoopNum; innerLoop++) { + int64_t innerLoopOffset = dataTail + (innerLoopNum - 1 - innerLoop) * loopDataNum_; + DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + innerLoopOffset], + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(cumSumInput2Local, gmWk_[probsGmOffset_ + innerLoopOffset + iteratOffset], + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + SetWaitFlag(HardEvent::MTE2_V); + Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyPad(gmWk_[probsGmOffset_ + innerLoopOffset + iteratOffset], cumSumInput1Local, + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + if (dataTail > 0) { + 
DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_], + {1, static_cast(dataTail * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + DataCopyPad(cumSumInput2Local, gmWk_[probsGmOffset_ + iteratOffset], + {1, static_cast(dataTail * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + SetWaitFlag(HardEvent::MTE2_V); + Add(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, dataTail); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyPad(gmWk_[probsGmOffset_ + iteratOffset], cumSumInput1Local, + {1, static_cast(dataTail * sizeof(float)), 0, 0, 0}); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::CopyOutWithCast(GlobalTensor &globalTensor, + LocalTensor &localTensor, + int64_t gmOffset, int64_t copyNum) +{ + if constexpr (IsSameType::value) { + DataCopyPad(globalTensor[gmOffset], localTensor, {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}); + } else { + Cast(localTensor.ReinterpretCast(), localTensor, RoundMode::CAST_ROUND, copyNum); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyPad(globalTensor[gmOffset], localTensor.ReinterpretCast(), + {1, static_cast(copyNum * sizeof(T)), 0, 0, 0}); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::TopPProcess(LocalTensor &cumSumInput1Local, + LocalTensor &cumSumInput2Local, + LocalTensor &zeroLocal, + LocalTensor &maskLocal) +{ + CumSumImpl(cumSumInput1Local, cumSumInput2Local); + for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { + DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + vocabLoop * loopDataNum_], + {1, static_cast(loopDataNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + CopyInWithCast(cumSumInput2Local, gmProbs_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_V); + } else { + PipeBarrier(); + } + Sub(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_); + PipeBarrier(); + CompareScalar(maskLocal, cumSumInput1Local, pValue_, CMPMODE::GT, 
loopDataNum_); + PipeBarrier(); + Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, + loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::V_MTE3); + } else { + PipeBarrier(); + } + CopyOutWithCast(gmSampledRes_, cumSumInput2Local, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + if (kTailNum_ > 0) { + DataCopyPad(cumSumInput1Local, gmWk_[probsGmOffset_ + kLoopNum_ * loopDataNum_], + {1, static_cast(kTailNum_ * sizeof(float)), 0, 0, 0}, {false, 0, 0, 0}); + CopyInWithCast(cumSumInput2Local, gmProbs_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_V); + } else { + PipeBarrier(); + } + Sub(cumSumInput1Local, cumSumInput1Local, cumSumInput2Local, loopDataNum_); + PipeBarrier(); + CompareScalar(maskLocal, cumSumInput1Local, pValue_, CMPMODE::GT, loopDataNum_); + PipeBarrier(); + Select(cumSumInput2Local, maskLocal, zeroLocal, cumSumInput2Local, SELMODE::VSEL_TENSOR_TENSOR_MODE, + loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::V_MTE3); + } else { + PipeBarrier(); + } + CopyOutWithCast(gmSampledRes_, cumSumInput2Local, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + } +} + +template +__aicore__ inline void ApplyTopKTopPMinP::MinPProcess(LocalTensor &workLocal, + LocalTensor &zeroLocal, + LocalTensor &maskLocal) +{ + for (int64_t vocabLoop = 0; vocabLoop < kLoopNum_; vocabLoop++) { + CopyInWithCast(workLocal, gmSampledRes_, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_V); + } else { + PipeBarrier(); + } + CompareScalar(maskLocal, workLocal, minPThresholds_, CMPMODE::LT, loopDataNum_); + PipeBarrier(); + Select(workLocal, maskLocal, zeroLocal, workLocal, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_); + if constexpr 
(IsSameType::value) { + SetWaitFlag(HardEvent::V_MTE3); + } else { + PipeBarrier(); + } + CopyOutWithCast(gmSampledRes_, workLocal, probsGmOffset_ + vocabLoop * loopDataNum_, loopDataNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + if (kTailNum_ > 0) { + CopyInWithCast(workLocal, gmSampledRes_, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::MTE2_V); + } else { + PipeBarrier(); + } + CompareScalar(maskLocal, workLocal, minPThresholds_, CMPMODE::LT, loopDataNum_); + PipeBarrier(); + Select(workLocal, maskLocal, zeroLocal, workLocal, SELMODE::VSEL_TENSOR_TENSOR_MODE, loopDataNum_); + if constexpr (IsSameType::value) { + SetWaitFlag(HardEvent::V_MTE3); + } else { + PipeBarrier(); + } + CopyOutWithCast(gmSampledRes_, workLocal, probsGmOffset_ + kLoopNum_ * loopDataNum_, kTailNum_); + SetWaitFlag(HardEvent::MTE3_MTE2); + } +} +} // namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel + +__global__ __aicore__ void apply_top_k_top_p_min_p(GM_ADDR probs, GM_ADDR k, GM_ADDR p, GM_ADDR min_p, + GM_ADDR sampled_res, GM_ADDR workspace, GM_ADDR tiling) +{ +#define INIT_AND_PROCESS \ + op.Init(tilingData, probs, k, p, min_p, sampled_res, userWS, &tPipe); \ + op.Process(); + + AscendC::TPipe tPipe; + using namespace sglang::npu_kernel::ApplyTopKTopPMinPKernel; + + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY); + + auto tilingData = reinterpret_cast<__gm__ sglang::ATKTPMPHost::ApplyTopKTopPMinPTilingData *>(tiling); + auto tilingKey = tilingData->tilingKey; + GM_ADDR userWS = GetUserWorkspace(workspace); + if (userWS == nullptr) { + return; + } + + if (tilingKey == TTILING_FP32_WITHOUT_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == TTILING_FP16_WITHOUT_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == TTILING_BF16_WITHOUT_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == TTILING_FP32_MIN_P) { + ApplyTopKTopPMinP op; + 
INIT_AND_PROCESS; + } else if (tilingKey == TTILING_FP16_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } else if (tilingKey == TTILING_BF16_MIN_P) { + ApplyTopKTopPMinP op; + INIT_AND_PROCESS; + } +} diff --git a/csrc/pytorch_extensions.cpp b/csrc/pytorch_extensions.cpp index a7d093730..5d88dc1df 100644 --- a/csrc/pytorch_extensions.cpp +++ b/csrc/pytorch_extensions.cpp @@ -98,6 +98,8 @@ TORCH_LIBRARY_FRAGMENT(npu, m) "int? sparse_count=None, int? sparse_mode=None) -> Tensor"); m.def("triangular_inverse(Tensor x) -> Tensor"); + + m.def("apply_top_k_top_p_min_p(Tensor logits, Tensor k, Tensor p, Tensor? min_p=None) -> Tensor"); } } // namespace @@ -141,5 +143,7 @@ TORCH_LIBRARY_IMPL(npu, PrivateUse1, m) m.impl("lightning_indexer", TORCH_FN(sglang::npu_kernel::lightning_indexer)); m.impl("triangular_inverse", TORCH_FN(sglang::npu_kernel::tri_inv_col_sweep)); + + m.impl("apply_top_k_top_p_min_p", TORCH_FN(sglang::npu_kernel::apply_top_k_top_p_min_p)); } } // namespace diff --git a/include/sgl_kenel_npu_ops.h b/include/sgl_kenel_npu_ops.h index e2f5a6804..e3aeda613 100644 --- a/include/sgl_kenel_npu_ops.h +++ b/include/sgl_kenel_npu_ops.h @@ -123,6 +123,10 @@ at::Tensor lightning_indexer( * is inversed. */ at::Tensor tri_inv_col_sweep(const at::Tensor &tensor_in); + +at::Tensor apply_top_k_top_p_min_p(const at::Tensor &logits, + const at::Tensor &k, const at::Tensor &p, + const c10::optional &min_p); } // namespace npu_kernel } // namespace sglang diff --git a/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py new file mode 100644 index 000000000..ef1929ea8 --- /dev/null +++ b/tests/python/sgl_kernel_npu/test_apply_top_k_top_p_min_p.py @@ -0,0 +1,125 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. 
+# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import unittest + +import numpy as np +import sgl_kernel_npu +import torch +import torch.nn as nn +import torch_npu +import torchair + +DEVICE_ID = 0 +torch_npu.npu.set_device(int(DEVICE_ID)) +DTYPE_ATOL = {torch.float32: 0.0001, torch.float16: 0.001, torch.bfloat16: 0.004} + + +def _apply_top_k_top_p_min_p( + probs_sort, + k, + p, + min_p=None, +): + probs_sort_out = probs_sort.clone().to(torch.float32) + top_k_mask = torch.arange(0, probs_sort.shape[-1], device=probs_sort.device).view( + 1, -1 + ) >= k.view(-1, 1) + probs_sort_out.masked_fill_(top_k_mask, 0.0) + + probs_sum = torch.cumsum(probs_sort_out, dim=-1) + top_p_mask = probs_sum - probs_sort_out > p.view(-1, 1) + probs_sort_out.masked_fill_(top_p_mask, 0.0) + + if min_p is not None: + min_p_thresholds = probs_sort_out[:, 0] * min_p + min_p_mask = probs_sort_out < min_p_thresholds.view(-1, 1) + probs_sort_out.masked_fill_(min_p_mask, 0.0) + return probs_sort_out.to(probs_sort.dtype) + + +class TestCustomApplyTopKTopPMinP(unittest.TestCase): + def test_apply_top_k_top_p_min_p_eager(self): + batch_size = 4 + vocab_size = 131072 + + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + np.random.seed(3) + logits = torch.tensor( + np.random.uniform(-10, 10, (batch_size, vocab_size)) + ).to(dtype) + k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to( + torch.int32 + ) + p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype) + min_p = torch.tensor(np.random.uniform(0, 1, 
(batch_size))).to(dtype) + probs = torch.softmax(logits, dim=-1) + probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True) + cpu_out = _apply_top_k_top_p_min_p(probs_sort, k, p, min_p) + + torch_npu.npu.set_device(int(DEVICE_ID)) + probs_sort = probs_sort.to("npu:%s" % DEVICE_ID) + k = k.to("npu:%s" % DEVICE_ID) + p = p.to("npu:%s" % DEVICE_ID) + min_p = min_p.to("npu:%s" % DEVICE_ID) + + npu_out = torch.ops.npu.apply_top_k_top_p_min_p( + probs_sort, k, p, min_p=min_p + ) + + # compare result + npu_out = npu_out.cpu() + cpu_out = cpu_out.cpu() + tol = DTYPE_ATOL[dtype] + assert torch.allclose( + cpu_out.to(torch.float32), + npu_out.to(torch.float32), + atol=tol, + rtol=tol, + ) + + def test_apply_top_k_top_p_min_p_eager_without_min_p(self): + batch_size = 4 + vocab_size = 131072 + + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + np.random.seed(4) + logits = torch.tensor( + np.random.uniform(-10, 10, (batch_size, vocab_size)) + ).to(dtype) + k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to( + torch.int32 + ) + p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype) + probs = torch.softmax(logits, dim=-1) + probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True) + cpu_out = _apply_top_k_top_p_min_p(probs_sort, k, p, min_p=None) + + torch_npu.npu.set_device(int(DEVICE_ID)) + probs_sort = probs_sort.to("npu:%s" % DEVICE_ID) + k = k.to("npu:%s" % DEVICE_ID) + p = p.to("npu:%s" % DEVICE_ID) + + npu_out = torch.ops.npu.apply_top_k_top_p_min_p( + probs_sort, k, p, min_p=None + ) + + # compare result + npu_out = npu_out.cpu() + cpu_out = cpu_out.cpu() + tol = DTYPE_ATOL[dtype] + assert torch.allclose( + cpu_out.to(torch.float32), + npu_out.to(torch.float32), + atol=tol, + rtol=tol, + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2)