@@ -26,17 +26,16 @@ namespace sglang::ATKTPMPHost {
// Validates the dtypes recorded in tilingInfo_->opParamInfo before tiling:
//   - probs must be float16, bfloat16 or float;
//   - p and sampledRes must have the same dtype as probs;
//   - k must be int32.
// Returns ge::GRAPH_SUCCESS once every check has passed.
// NOTE(review): a failed check is reported through TORCH_CHECK, which presumably
// throws instead of returning an error status — confirm against the macro definition.
2626ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype ()
2727{
// probs must use one of the three supported floating-point dtypes.
2828 TORCH_CHECK ((tilingInfo_->opParamInfo .probs .dtype == ge::DT_FLOAT16) ||
29- (tilingInfo_->opParamInfo .probs .dtype == ge::DT_BF16) ||
30- (tilingInfo_->opParamInfo .probs .dtype == ge::DT_FLOAT),
29+ (tilingInfo_->opParamInfo .probs .dtype == ge::DT_BF16) ||
30+ (tilingInfo_->opParamInfo .probs .dtype == ge::DT_FLOAT),
3131 " The data types of probs, p and sampled_res must be float16, bfloat16 or float." );
3232
// p and sampledRes are each compared against probs' dtype, so all three must agree.
3333 TORCH_CHECK (tilingInfo_->opParamInfo .probs .dtype == tilingInfo_->opParamInfo .p .dtype ,
3434 " The data types of probs and p must be the same." );
3535 TORCH_CHECK (tilingInfo_->opParamInfo .probs .dtype == tilingInfo_->opParamInfo .sampledRes .dtype ,
3636 " The data types of probs and sampled_res must be the same." );
// k is checked independently of probs: it must be int32.
37-
38- TORCH_CHECK (tilingInfo_->opParamInfo .k .dtype == ge::DT_INT32,
39- " The data types of the input k must be int32." );
37+
38+ TORCH_CHECK (tilingInfo_->opParamInfo .k .dtype == ge::DT_INT32, " The data types of the input k must be int32." );
4039
4140 return ge::GRAPH_SUCCESS;
4241}
@@ -49,19 +48,17 @@ ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape()
4948 tilingData_.batchSize = tilingInfo_->opParamInfo .probs .shape [DIM_IDX_ZERO];
5049 tilingData_.vocabSize = tilingInfo_->opParamInfo .probs .shape [DIM_IDX_ONE];
5150
52- TORCH_CHECK (tilingInfo_->opParamInfo .k .shape .size () == DIM_NUM_ONE,
53- " ApplyTopKTopPMinP: the dimNum of k should be " , DIM_NUM_ONE, " , but now is " ,
54- tilingInfo_->opParamInfo .k .shape .size (), " ." );
51+ TORCH_CHECK (tilingInfo_->opParamInfo .k .shape .size () == DIM_NUM_ONE, " ApplyTopKTopPMinP: the dimNum of k should be " ,
52+ DIM_NUM_ONE, " , but now is " , tilingInfo_->opParamInfo .k .shape .size (), " ." );
5553 int64_t kSize = tilingInfo_->opParamInfo .k .shape [DIM_IDX_ZERO];
56- TORCH_CHECK (kSize == tilingData_.batchSize ,
57- " ApplyTopKTopPMinP: the shape of k should be [" , tilingData_.batchSize , " ], but now is [" , kSize , " ]." );
58-
59- TORCH_CHECK (tilingInfo_->opParamInfo .p .shape .size () == DIM_NUM_ONE,
60- " ApplyTopKTopPMinP: the dimNum of p should be " , DIM_NUM_ONE, " , but now is " ,
61- tilingInfo_->opParamInfo .p .shape .size (), " ." );
54+ TORCH_CHECK (kSize == tilingData_.batchSize , " ApplyTopKTopPMinP: the shape of k should be [" , tilingData_.batchSize ,
55+ " ], but now is [" , kSize , " ]." );
56+
57+ TORCH_CHECK (tilingInfo_->opParamInfo .p .shape .size () == DIM_NUM_ONE, " ApplyTopKTopPMinP: the dimNum of p should be " ,
58+ DIM_NUM_ONE, " , but now is " , tilingInfo_->opParamInfo .p .shape .size (), " ." );
6259 int64_t pSize = tilingInfo_->opParamInfo .p .shape [DIM_IDX_ZERO];
63- TORCH_CHECK (pSize == tilingData_.batchSize ,
64- " ApplyTopKTopPMinP: the shape of p should be [ " , tilingData_. batchSize , " ], but now is [" , pSize, " ]." );
60+ TORCH_CHECK (pSize == tilingData_.batchSize , " ApplyTopKTopPMinP: the shape of p should be [ " , tilingData_. batchSize ,
61+ " ], but now is [" , pSize, " ]." );
6562
6663 if (tilingInfo_->opParamInfo .minP .shape .size () != DIM_NUM_ZERO) {
6764 int64_t minPSize = tilingInfo_->opParamInfo .minP .shape [DIM_IDX_ZERO];
@@ -76,9 +73,8 @@ ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape()
7673 int64_t sampledResSize0 = tilingInfo_->opParamInfo .sampledRes .shape [DIM_IDX_ZERO];
7774 int64_t sampledResSize1 = tilingInfo_->opParamInfo .sampledRes .shape [DIM_IDX_ONE];
7875 TORCH_CHECK (sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize ,
79- " ApplyTopKTopPMinP: the size of sampledRes should be [" ,
80- tilingData_.batchSize , " , " , tilingData_.vocabSize ,
81- " ], but now is [" , sampledResSize0, " , " , sampledResSize1, " ]." );
76+ " ApplyTopKTopPMinP: the size of sampledRes should be [" , tilingData_.batchSize , " , " ,
77+ tilingData_.vocabSize , " ], but now is [" , sampledResSize0, " , " , sampledResSize1, " ]." );
8278 return ge::GRAPH_SUCCESS;
8379}
8480
@@ -111,18 +107,18 @@ ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling()
111107
112108 auto socVersion = ascendcPlatform.GetSocVersion ();
113109 TORCH_CHECK (socVersion == platform_ascendc::SocVersion::ASCEND910B ||
114- socVersion == platform_ascendc::SocVersion::ASCEND910_93,
110+ socVersion == platform_ascendc::SocVersion::ASCEND910_93,
115111 " soc version does not support " , (int32_t )socVersion);
116-
112+
117113 SplitTask ();
118-
114+
119115 // -------------set workspacesize-----------------
120116 tilingInfo_->workspaceSize = static_cast <int64_t >(ascendcPlatform.GetLibApiWorkSpaceSize ()) +
121117 tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32;
122118
123119 // -------------set tilingkey-----------------
124- tilingData_.tilingKey = G_DTYPE_MAP. at (tilingInfo_-> opParamInfo . probs . dtype ) * COEF_TEN +
125- tilingInfo_->needMinPSample ;
120+ tilingData_.tilingKey =
121+ G_DTYPE_MAP. at (tilingInfo_-> opParamInfo . probs . dtype ) * COEF_TEN + tilingInfo_->needMinPSample ;
126122
127123 return ge::GRAPH_SUCCESS;
128124}
0 commit comments