|
| 1 | +/** |
| 2 | + * This program is free software, you can redistribute it and/or modify it. |
| 3 | + * Copyright (c) 2026 Huawei Technologies Co., Ltd. |
| 4 | + * This file is a part of the CANN Open Software. |
| 5 | + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). |
| 6 | + * Please refer to the License for details. You may not use this file except in compliance with the License. |
| 7 | + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING |
| 8 | + * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of |
| 9 | + * the software repository for the full text of the License. |
| 10 | + */ |
| 11 | + |
| 12 | +/*! |
| 13 | + * \file apply_top_k_top_p_min_p_tiling.cpp |
| 14 | + * \brief |
| 15 | + */ |
| 16 | + |
| 17 | +#include "apply_top_k_top_p_min_p_tiling.h" |
| 18 | + |
| 19 | +using namespace ge; |
| 20 | +using namespace AscendC; |
| 21 | +using std::map; |
| 22 | +using std::string; |
| 23 | +namespace sglang::ATKTPMPHost { |
| 24 | + |
| 25 | +// --------------------------ApplyTopKTopPMinPTiling类成员函数定义----------------------- |
| 26 | +ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype() |
| 27 | +{ |
| 28 | + TORCH_CHECK((tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT16) || |
| 29 | + (tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) || |
| 30 | + (tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT), |
| 31 | + "The data types of probs, p and sampled_res must be float16, bfloat16 or float."); |
| 32 | + |
| 33 | + TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.p.dtype, |
| 34 | + "The data types of probs and p must be the same."); |
| 35 | + TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.sampledRes.dtype, |
| 36 | + "The data types of probs and sampled_res must be the same."); |
| 37 | + |
| 38 | + TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32, |
| 39 | + "The data types of the input k must be int32."); |
| 40 | + |
| 41 | + return ge::GRAPH_SUCCESS; |
| 42 | +} |
| 43 | + |
| 44 | +ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape() |
| 45 | +{ |
| 46 | + TORCH_CHECK(tilingInfo_->opParamInfo.probs.shape.size() == DIM_NUM_TWO, |
| 47 | + "ApplyTopKTopPMinP: the dimNum of probs should be ", DIM_NUM_TWO, ", but now is ", |
| 48 | + tilingInfo_->opParamInfo.probs.shape.size(), "."); |
| 49 | + tilingData_.batchSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ZERO]; |
| 50 | + tilingData_.vocabSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ONE]; |
| 51 | + |
| 52 | + TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE, |
| 53 | + "ApplyTopKTopPMinP: the dimNum of k should be ", DIM_NUM_ONE, ", but now is ", |
| 54 | + tilingInfo_->opParamInfo.k.shape.size(), "."); |
| 55 | + int64_t kSize = tilingInfo_->opParamInfo.k.shape[DIM_IDX_ZERO]; |
| 56 | + TORCH_CHECK(kSize == tilingData_.batchSize, |
| 57 | + "ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize, "], but now is [", kSize, "]."); |
| 58 | + |
| 59 | + TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE, |
| 60 | + "ApplyTopKTopPMinP: the dimNum of p should be ", DIM_NUM_ONE, ", but now is ", |
| 61 | + tilingInfo_->opParamInfo.p.shape.size(), "."); |
| 62 | + int64_t pSize = tilingInfo_->opParamInfo.p.shape[DIM_IDX_ZERO]; |
| 63 | + TORCH_CHECK(pSize == tilingData_.batchSize, |
| 64 | + "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize, "], but now is [", pSize, "]."); |
| 65 | + |
| 66 | + if (tilingInfo_->opParamInfo.minP.shape.size() != DIM_NUM_ZERO) { |
| 67 | + int64_t minPSize = tilingInfo_->opParamInfo.minP.shape[DIM_IDX_ZERO]; |
| 68 | + TORCH_CHECK(minPSize == tilingData_.batchSize, ": the shape of p should be [", tilingData_.batchSize, |
| 69 | + "], but now is [", minPSize, "]."); |
| 70 | + tilingInfo_->needMinPSample = 1; |
| 71 | + } |
| 72 | + |
| 73 | + TORCH_CHECK(tilingInfo_->opParamInfo.sampledRes.shape.size() == DIM_NUM_TWO, |
| 74 | + "ApplyTopKTopPMinP: the dimNum of sampled_res should be ", DIM_NUM_TWO, ", but now is ", |
| 75 | + tilingInfo_->opParamInfo.sampledRes.shape.size(), "."); |
| 76 | + int64_t sampledResSize0 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ZERO]; |
| 77 | + int64_t sampledResSize1 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ONE]; |
| 78 | + TORCH_CHECK(sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize, |
| 79 | + "ApplyTopKTopPMinP: the size of sampledRes should be [", |
| 80 | + tilingData_.batchSize, ", ", tilingData_.vocabSize, |
| 81 | + "], but now is [", sampledResSize0, ", ", sampledResSize1, "]."); |
| 82 | + return ge::GRAPH_SUCCESS; |
| 83 | +} |
| 84 | + |
| 85 | +void ApplyTopKTopPMinPTiling::SplitTask() |
| 86 | +{ |
| 87 | + tilingData_.loopDataNum = tilingData_.ubSize / BYTES_B32 / LOCAL_TENSOR_NUM / BYTES_PER_REPEAT * BYTES_PER_REPEAT; |
| 88 | + tilingData_.coreNum = tilingData_.batchSize > tilingData_.coreNum ? tilingData_.coreNum : tilingData_.batchSize; |
| 89 | + tilingData_.batchPerCore = tilingData_.batchSize / std::max(tilingData_.coreNum, static_cast<int64_t>(1)); |
| 90 | + tilingData_.batchTailCore = tilingData_.batchSize - tilingData_.batchPerCore * tilingData_.coreNum; |
| 91 | +} |
| 92 | + |
| 93 | +ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling() |
| 94 | +{ |
| 95 | + if (CheckDtype() != ge::GRAPH_SUCCESS) { |
| 96 | + return ge::GRAPH_FAILED; |
| 97 | + } |
| 98 | + if (CheckShape() != ge::GRAPH_SUCCESS) { |
| 99 | + return ge::GRAPH_FAILED; |
| 100 | + } |
| 101 | + |
| 102 | + auto ascendcPlatform = *platform_ascendc::PlatformAscendCManager::GetInstance(); |
| 103 | + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); |
| 104 | + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); |
| 105 | + TORCH_CHECK(aivNum != 0 && aivNum != 0, "num of core obtained is 0"); |
| 106 | + tilingData_.coreNum = static_cast<int64_t>(aivNum); |
| 107 | + |
| 108 | + uint64_t ubSize = 0; |
| 109 | + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); |
| 110 | + tilingData_.ubSize = static_cast<int64_t>(ubSize) - SELECT_MODE_BYTES; |
| 111 | + |
| 112 | + auto socVersion = ascendcPlatform.GetSocVersion(); |
| 113 | + TORCH_CHECK(socVersion == platform_ascendc::SocVersion::ASCEND910B || |
| 114 | + socVersion == platform_ascendc::SocVersion::ASCEND910_93, |
| 115 | + "soc version does not support ", (int32_t)socVersion); |
| 116 | + |
| 117 | + SplitTask(); |
| 118 | + |
| 119 | + // -------------set workspacesize----------------- |
| 120 | + tilingInfo_->workspaceSize = static_cast<int64_t>(ascendcPlatform.GetLibApiWorkSpaceSize()) + |
| 121 | + tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32; |
| 122 | + |
| 123 | + // -------------set tilingkey----------------- |
| 124 | + tilingData_.tilingKey = G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN + |
| 125 | + tilingInfo_->needMinPSample; |
| 126 | + |
| 127 | + return ge::GRAPH_SUCCESS; |
| 128 | +} |
| 129 | + |
| 130 | +const ApplyTopKTopPMinPTilingData &ApplyTopKTopPMinPTiling::GetTilingData() const |
| 131 | +{ |
| 132 | + return tilingData_; |
| 133 | +} |
| 134 | +} // namespace sglang::ATKTPMPHost |
0 commit comments