Skip to content

Commit 9211cf8

Browse files
committed
fix lint
1 parent 8ff3939 commit 9211cf8

9 files changed

Lines changed: 124 additions & 137 deletions

File tree

csrc/apply_top_k_top_p_min_p/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ A top-k, top-p and min-p sampling implementation for ascend.
1212

1313
## Calculation Formula
1414
$$
15-
sampled\_res[b][v] =
15+
sampled\_res[b][v] =
1616
\begin{cases}
1717
0 & \text{v >= k[b]} \\
1818
probs[b][v] & \text{v < k[b]}
@@ -21,15 +21,15 @@ $$
2121
$$probs\_sum = cumsum(sampled\_res, dim=-1)$$
2222
$$top\_p\_mask[b][v] = probs\_sum[b][v] - sampled\_res[b][v] > p[b]$$
2323
$$
sampled\_res[b][v] =
\begin{cases}
0 & \text{top\_p\_mask = True} \\
sampled\_res[b][v] & \text{top\_p\_mask = False}
\end{cases}
$$
3030
$$min\_p\_mask[b][v] = sampled\_res[b][v] < sampled\_res[b][0] * min\_p[b]$$
3131
$$
32-
sampled\_res[b][v] =
32+
sampled\_res[b][v] =
3333
\begin{cases}
3434
0 & \text{min\_p\_mask = True} \\
3535
sampled\_res[b][v] & \text{min\_p\_mask = False}

csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,9 @@ HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::T
4040

4141
auto probsType = probs.scalar_type();
4242

43-
at::Tensor minP =
44-
min_p.has_value()
45-
? min_p.value()
46-
: at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device()));
43+
at::Tensor minP = min_p.has_value()
44+
? min_p.value()
45+
: at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device()));
4746

4847
ApplyTopKTopPMinPTilingInfo applyTopKTopPMinPInfo;
4948
applyTopKTopPMinPInfo.opParamInfo.probs.dtype = SCALAR_TYPE_TO_GE_DATATYPE(probsType);
@@ -60,9 +59,8 @@ HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::T
6059
applyTopKTopPMinPInfo.opParamInfo.sampledRes.shape = sampledRes.sizes();
6160

6261
ApplyTopKTopPMinPTiling applyTopKTopPMinPTiling(&applyTopKTopPMinPInfo);
63-
TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS,
64-
"apply_top_k_top_p_min_p DoTiling failed")
65-
62+
TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS, "apply_top_k_top_p_min_p DoTiling failed");
63+
6664
const auto &tilingData = applyTopKTopPMinPTiling.GetTilingData();
6765

6866
uint32_t tilingSize = (sizeof(ApplyTopKTopPMinPTiling) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;
@@ -71,9 +69,9 @@ HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::T
7169
at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
7270
aclrtMemcpy(tilingBuffer.data_ptr<uint8_t>(), tilingSize, &tilingData, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
7371
at::Tensor tilingTensor = at::from_blob(tilingBuffer.data_ptr<uint8_t>(), tilingSize, at::kByte);
74-
72+
7573
auto workspace = at::empty({applyTopKTopPMinPInfo.workspaceSize},
76-
at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
74+
at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
7775
EXEC_KERNEL_CMD(apply_top_k_top_p_min_p, blockDim, probs, k, p, minP, sampledRes, workspace, tilingTensor);
7876
return sampledRes;
7977
}

csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,7 @@ class ApplyTopKTopPMinP : public OpDef
2929
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
3030
.FormatList({ge::FORMAT_ND})
3131
.AutoContiguous();
32-
this->Input("k")
33-
.ParamType(REQUIRED)
34-
.DataTypeList({ge::DT_INT32})
35-
.FormatList({ge::FORMAT_ND})
36-
.AutoContiguous();
32+
this->Input("k").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND}).AutoContiguous();
3733
this->Input("p")
3834
.ParamType(REQUIRED)
3935
.DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})

csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,16 @@ namespace sglang::ATKTPMPHost {
2626
ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype()
2727
{
2828
TORCH_CHECK((tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT16) ||
29-
(tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) ||
30-
(tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT),
29+
(tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) ||
30+
(tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT),
3131
"The data types of probs, p and sampled_res must be float16, bfloat16 or float.");
3232

3333
TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.p.dtype,
3434
"The data types of probs and p must be the same.");
3535
TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.sampledRes.dtype,
3636
"The data types of probs and sampled_res must be the same.");
37-
38-
TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32,
39-
"The data types of the input k must be int32.");
37+
38+
TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32, "The data types of the input k must be int32.");
4039

4140
return ge::GRAPH_SUCCESS;
4241
}
@@ -49,19 +48,17 @@ ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape()
4948
tilingData_.batchSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ZERO];
5049
tilingData_.vocabSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ONE];
5150

52-
TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE,
53-
"ApplyTopKTopPMinP: the dimNum of k should be ", DIM_NUM_ONE, ", but now is ",
54-
tilingInfo_->opParamInfo.k.shape.size(), ".");
51+
TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of k should be ",
52+
DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.k.shape.size(), ".");
5553
int64_t kSize = tilingInfo_->opParamInfo.k.shape[DIM_IDX_ZERO];
56-
TORCH_CHECK(kSize == tilingData_.batchSize,
57-
"ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize, "], but now is [", kSize, "].");
58-
59-
TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE,
60-
"ApplyTopKTopPMinP: the dimNum of p should be ", DIM_NUM_ONE, ", but now is ",
61-
tilingInfo_->opParamInfo.p.shape.size(), ".");
54+
TORCH_CHECK(kSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize,
55+
"], but now is [", kSize, "].");
56+
57+
TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of p should be ",
58+
DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.p.shape.size(), ".");
6259
int64_t pSize = tilingInfo_->opParamInfo.p.shape[DIM_IDX_ZERO];
63-
TORCH_CHECK(pSize == tilingData_.batchSize,
64-
"ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize, "], but now is [", pSize, "].");
60+
TORCH_CHECK(pSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize,
61+
"], but now is [", pSize, "].");
6562

6663
if (tilingInfo_->opParamInfo.minP.shape.size() != DIM_NUM_ZERO) {
6764
int64_t minPSize = tilingInfo_->opParamInfo.minP.shape[DIM_IDX_ZERO];
@@ -76,9 +73,8 @@ ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape()
7673
int64_t sampledResSize0 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ZERO];
7774
int64_t sampledResSize1 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ONE];
7875
TORCH_CHECK(sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize,
79-
"ApplyTopKTopPMinP: the size of sampledRes should be [",
80-
tilingData_.batchSize, ", ", tilingData_.vocabSize,
81-
"], but now is [", sampledResSize0, ", ", sampledResSize1, "].");
76+
"ApplyTopKTopPMinP: the size of sampledRes should be [", tilingData_.batchSize, ", ",
77+
tilingData_.vocabSize, "], but now is [", sampledResSize0, ", ", sampledResSize1, "].");
8278
return ge::GRAPH_SUCCESS;
8379
}
8480

@@ -111,18 +107,18 @@ ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling()
111107

112108
auto socVersion = ascendcPlatform.GetSocVersion();
113109
TORCH_CHECK(socVersion == platform_ascendc::SocVersion::ASCEND910B ||
114-
socVersion == platform_ascendc::SocVersion::ASCEND910_93,
110+
socVersion == platform_ascendc::SocVersion::ASCEND910_93,
115111
"soc version does not support ", (int32_t)socVersion);
116-
112+
117113
SplitTask();
118-
114+
119115
// -------------set workspacesize-----------------
120116
tilingInfo_->workspaceSize = static_cast<int64_t>(ascendcPlatform.GetLibApiWorkSpaceSize()) +
121117
tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32;
122118

123119
// -------------set tilingkey-----------------
124-
tilingData_.tilingKey = G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN +
125-
tilingInfo_->needMinPSample;
120+
tilingData_.tilingKey =
121+
G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN + tilingInfo_->needMinPSample;
126122

127123
return ge::GRAPH_SUCCESS;
128124
}

csrc/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ struct TensorParaInfo {
3030
c10::ArrayRef<int64_t> shape;
3131
};
3232

33-
const std::map<ge::DataType, int64_t> G_DTYPE_MAP = {{ge::DT_FLOAT, 1},
34-
{ge::DT_FLOAT16, 2},
35-
{ge::DT_BF16, 3}};
33+
const std::map<ge::DataType, int64_t> G_DTYPE_MAP = {{ge::DT_FLOAT, 1}, {ge::DT_FLOAT16, 2}, {ge::DT_BF16, 3}};
3634
// ------------------算子原型索引常量定义----------------
3735
// Dim Index
3836
constexpr uint32_t DIM_IDX_ZERO = 0;

0 commit comments

Comments
 (0)