Skip to content

Commit 2199f31

Browse files
author
zhengfaan
committed
fix codelint
1 parent a8e9ba3 commit 2199f31

11 files changed

Lines changed: 82 additions & 70 deletions

csrc/deepep/deep_ep.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -405,8 +405,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
405405
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> Buffer::intranode_combine(
406406
const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional<torch::Tensor> &topk_weights,
407407
const torch::Tensor &src_idx, const torch::Tensor &send_head, const torch::Tensor &put_offset,
408-
const torch::Tensor &balance_matrix,
409-
const std::optional<at::Tensor> &combine_send_cost_stats)
408+
const torch::Tensor &balance_matrix, const std::optional<at::Tensor> &combine_send_cost_stats)
410409
{
411410
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
412411
at::Tensor recv_x = x;
@@ -473,13 +472,14 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
473472
moe_expert_number = put_offset.size(0);
474473
ext_info = reinterpret_cast<uint64_t>(shmem_ptr);
475474

476-
// printf("[deepep] rank:%d, is_padding:%d, recv_x %p expand_ids %p expert_scales %p send_token_idx %p ep_send_counts %p combined_x %p\n",
477-
// rank, is_padding, recv_x.data_ptr(), expand_ids.data_ptr(), expert_scales.data_ptr(), send_token_idx_small.data_ptr(),
478-
// ep_send_counts.data_ptr(), combined_x.data_ptr());
475+
// printf("[deepep] rank:%d, is_padding:%d, recv_x %p expand_ids %p expert_scales %p send_token_idx %p
476+
// ep_send_counts %p combined_x %p\n",
477+
// rank, is_padding, recv_x.data_ptr(), expand_ids.data_ptr(), expert_scales.data_ptr(),
478+
// send_token_idx_small.data_ptr(), ep_send_counts.data_ptr(), combined_x.data_ptr());
479479

480480
EXEC_NPU_CMD(aclnnShmemMoeCombineNormal, recv_x, ep_send_counts, expert_scales, expand_ids,
481-
this->send_token_idx_small, balance_matrix, ext_info, num_ranks, rank, tp_world_size, tp_rankId, moe_expert_number,
482-
global_bs, combined_x, combine_send_cost_stats_out);
481+
this->send_token_idx_small, balance_matrix, ext_info, num_ranks, rank, tp_world_size, tp_rankId,
482+
moe_expert_number, global_bs, combined_x, combine_send_cost_stats_out);
483483
} else {
484484
ep_send_counts = send_head;
485485
moe_expert_number = send_head.size(0);

csrc/deepep/deep_ep.hpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,7 @@ struct Buffer {
8383
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> intranode_combine(
8484
const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional<torch::Tensor> &topk_weights,
8585
const torch::Tensor &src_idx, const torch::Tensor &send_head, const torch::Tensor &put_offset,
86-
const torch::Tensor &balance_matrix,
87-
const std::optional<at::Tensor> &combine_send_cost_stats);
86+
const torch::Tensor &balance_matrix, const std::optional<at::Tensor> &combine_send_cost_stats);
8887

8988
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
9089
std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,

csrc/deepep/ops/op_host/op_api/aclnn_shmem_moe_combine_normal.cpp

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -14,17 +14,15 @@ extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor,
1414
extern "C" {
1515
#endif
1616

17-
aclnnStatus aclnnShmemMoeCombineNormalGetWorkspaceSize(const aclTensor *recvX, const aclTensor *epRecvCounts,
18-
const aclTensor *recvTopkWeights, const aclTensor *topkIdx,
19-
const aclTensor *sendTokenIdx, const aclTensor *balanceMatrix, uint64_t meta_data_ptr,
20-
int64_t epWorldSize, int64_t epRankId, int64_t tpWorldSize,
21-
int64_t tpRankId, int64_t moeExpertNum, int64_t globalBs,
22-
const aclTensor *out, const aclTensor *sendCostStats,
23-
uint64_t *workspaceSize, aclOpExecutor **executor)
17+
aclnnStatus aclnnShmemMoeCombineNormalGetWorkspaceSize(
18+
const aclTensor *recvX, const aclTensor *epRecvCounts, const aclTensor *recvTopkWeights, const aclTensor *topkIdx,
19+
const aclTensor *sendTokenIdx, const aclTensor *balanceMatrix, uint64_t meta_data_ptr, int64_t epWorldSize,
20+
int64_t epRankId, int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t globalBs,
21+
const aclTensor *out, const aclTensor *sendCostStats, uint64_t *workspaceSize, aclOpExecutor **executor)
2422
{
2523
return aclnnInnerShmemMoeCombineNormalGetWorkspaceSize(
26-
recvX, epRecvCounts, recvTopkWeights, topkIdx, sendTokenIdx, balanceMatrix, meta_data_ptr, epWorldSize, epRankId, tpWorldSize,
27-
tpRankId, moeExpertNum, globalBs, out, sendCostStats, workspaceSize, executor);
24+
recvX, epRecvCounts, recvTopkWeights, topkIdx, sendTokenIdx, balanceMatrix, meta_data_ptr, epWorldSize,
25+
epRankId, tpWorldSize, tpRankId, moeExpertNum, globalBs, out, sendCostStats, workspaceSize, executor);
2826
}
2927

3028
aclnnStatus aclnnShmemMoeCombineNormal(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,

csrc/deepep/ops/op_host/op_api/aclnn_shmem_moe_combine_normal.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,9 @@ extern "C" {
99

1010
__attribute__((visibility("default"))) aclnnStatus aclnnShmemMoeCombineNormalGetWorkspaceSize(
1111
const aclTensor *recvX, const aclTensor *epRecvCounts, const aclTensor *recvTopkWeights, const aclTensor *topkIdx,
12-
const aclTensor *sendTokenIdx, const aclTensor *balanceMatrix, uint64_t meta_data_ptr, int64_t epWorldSize, int64_t epRankId, int64_t tpWorldSize,
13-
int64_t tpRankId, int64_t moeExpertNum, int64_t globalBs, const aclTensor *out, const aclTensor *sendCostStats,
14-
uint64_t *workspaceSize, aclOpExecutor **executor);
12+
const aclTensor *sendTokenIdx, const aclTensor *balanceMatrix, uint64_t meta_data_ptr, int64_t epWorldSize,
13+
int64_t epRankId, int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t globalBs,
14+
const aclTensor *out, const aclTensor *sendCostStats, uint64_t *workspaceSize, aclOpExecutor **executor);
1515

1616
__attribute__((visibility("default"))) aclnnStatus aclnnShmemMoeCombineNormal(void *workspace, uint64_t workspaceSize,
1717
aclOpExecutor *executor,

csrc/deepep/ops/op_host/shmem_moe_combine_normal_tiling.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ using CommQuantModeType = std::underlying_type<CommQuantMode>;
7373

7474
namespace optiling {
7575

76-
static int GetFactorEnv(const char* name, int value = 0)
76+
static int GetFactorEnv(const char *name, int value = 0)
7777
{
7878
int defaultValue = value;
7979
if (getenv(name) == nullptr) {

csrc/deepep/ops/op_host/shmem_notify_dispatch_tiling.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ constexpr static int TILING_KEY_A2_TYPE = 100;
5858
} // namespace
5959

6060
namespace optiling {
61-
static float GetFactorEnv(const char* name, float value = 0.0f)
61+
static float GetFactorEnv(const char *name, float value = 0.0f)
6262
{
6363
float defaultValue = value;
6464
if (getenv(name) == nullptr) {

csrc/deepep/ops/op_kernel/shmem_moe_combine_normal.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,8 @@ using namespace AscendC;
66
using namespace ShmemMoeCombineNormalImpl;
77

88
extern "C" __global__ __aicore__ void shmem_moe_combine_normal(GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights,
9-
GM_ADDR topkIdx, GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut,
9+
GM_ADDR topkIdx, GM_ADDR sendTokenIdx,
10+
GM_ADDR balanceMatrix, GM_ADDR XOut,
1011
GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM,
1112
GM_ADDR tilingGM)
1213

@@ -17,8 +18,8 @@ extern "C" __global__ __aicore__ void shmem_moe_combine_normal(GM_ADDR recvX, GM
1718
#if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16)
1819
GET_TILING_DATA_WITH_STRUCT(ShmemMoeCombineNormalTilingData, tilingData, tilingGM);
1920
ShmemMoeCombineNormal<DTYPE_RECV_X, DTYPE_X, int32_t> op;
20-
op.Init(recvX, epRecvCount, topkWeights, topkIdx, sendTokenIdx, balanceMatrix, XOut, sendCostStatsOut, workspaceGM, &pipe,
21-
&tilingData);
21+
op.Init(recvX, epRecvCount, topkWeights, topkIdx, sendTokenIdx, balanceMatrix, XOut, sendCostStatsOut, workspaceGM,
22+
&pipe, &tilingData);
2223
op.Process();
2324
#endif
2425
}

csrc/deepep/ops/op_kernel/shmem_moe_combine_normal.h

Lines changed: 33 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,14 @@ class ShmemMoeCombineNormal
3434
public:
3535
__aicore__ inline ShmemMoeCombineNormal(){};
3636
__aicore__ inline void Init(GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx,
37-
GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM,
38-
TPipe *pipe, const ShmemMoeCombineNormalTilingData *tilingData);
37+
GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut,
38+
GM_ADDR workspaceGM, TPipe *pipe, const ShmemMoeCombineNormalTilingData *tilingData);
3939
__aicore__ inline void Process();
4040

4141
private:
4242
__aicore__ inline void InitGlobalBuffer(GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx,
43-
GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut);
43+
GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut,
44+
GM_ADDR sendCostStatsOut);
4445
__aicore__ inline void InitTilingData(const ShmemMoeCombineNormalTilingData *tilingData);
4546
__aicore__ inline void InitBuffLen();
4647
__aicore__ inline void ResetMetaState();
@@ -49,7 +50,8 @@ class ShmemMoeCombineNormal
4950
__aicore__ inline void WaitSyncFlag(int metaType);
5051
__aicore__ inline void GetShareAddr();
5152
__aicore__ inline void HandleAllRankToken();
52-
__aicore__ inline void ReadAndWriteForTargetRank(uint32_t startId, uint32_t endId, uint32_t tokenCnt, uint32_t tarRankId);
53+
__aicore__ inline void ReadAndWriteForTargetRank(uint32_t startId, uint32_t endId, uint32_t tokenCnt,
54+
uint32_t tarRankId);
5355
__aicore__ inline void ReadTokenFromRemote();
5456
__aicore__ inline void ReadTokenAndWeightedSum(uint32_t tokenIndex, uint32_t tarRankId);
5557

@@ -153,12 +155,12 @@ class ShmemMoeCombineNormal
153155
GM_ADDR epRecvCountGM_;
154156

155157
GM_ADDR gva_gm;
156-
uint64_t shareRecvXAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (recvXGM_)
157-
uint64_t shareTopkIdxAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (topkIdxGM_)
158-
uint64_t shareTopkWeightsAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (topkWeightsGM_)
158+
uint64_t shareRecvXAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (recvXGM_)
159+
uint64_t shareTopkIdxAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (topkIdxGM_)
160+
uint64_t shareTopkWeightsAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (topkWeightsGM_)
159161
uint64_t shareSendTokenIdxAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (sendTokenIdxGM_)
160-
uint64_t shareRecvCountAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (epRecvCountGM_)
161-
uint64_t shareXOutAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (XOutGM_)
162+
uint64_t shareRecvCountAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (epRecvCountGM_)
163+
uint64_t shareXOutAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (XOutGM_)
162164
uint32_t shareAddrNum{6};
163165

164166
LocalTensor<float> tokenFloatLocal;
@@ -173,13 +175,9 @@ class ShmemMoeCombineNormal
173175
};
174176

175177
template <TemplateMC2TypeClass>
176-
__aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::InitGlobalBuffer(GM_ADDR recvX, GM_ADDR epRecvCount,
177-
GM_ADDR topkWeights,
178-
GM_ADDR topkIdx,
179-
GM_ADDR sendTokenIdx,
180-
GM_ADDR balanceMatrix,
181-
GM_ADDR XOut,
182-
GM_ADDR sendCostStatsOut)
178+
__aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::InitGlobalBuffer(
179+
GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx, GM_ADDR sendTokenIdx,
180+
GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut)
183181
{
184182
recvXGT_.SetGlobalBuffer((__gm__ RecvXType *)recvX);
185183
epRecvCountGT_.SetGlobalBuffer((__gm__ int32_t *)epRecvCount); // 放置allReccvCount信息,num_ranks * num_experts
@@ -233,8 +231,9 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::InitBuffLen()
233231

234232
template <TemplateMC2TypeClass>
235233
__aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::Init(
236-
GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx, GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut,
237-
GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, const ShmemMoeCombineNormalTilingData *tilingData)
234+
GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx, GM_ADDR sendTokenIdx,
235+
GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe,
236+
const ShmemMoeCombineNormalTilingData *tilingData)
238237
{
239238
workspaceGM_ = workspaceGM;
240239
recvXGM_ = recvX;
@@ -259,7 +258,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::Init(
259258
SplitCoreCal(epRankSize, rankNumPerBlock, curBlockStartRankId, curBlockEndRankId);
260259

261260
// if (blockIdx == 0) {
262-
// printf("[Init] rank:%d, recvXGM_ %p topkIdxGM_ %p topkWeightsGM_ %p sendTokenIdxGM_ %p epRecvCountGM_ %p XOutGM_ %p\n", epRankId,
261+
// printf("[Init] rank:%d, recvXGM_ %p topkIdxGM_ %p topkWeightsGM_ %p sendTokenIdxGM_ %p epRecvCountGM_ %p
262+
// XOutGM_ %p\n", epRankId,
263263
// recvXGM_, topkIdxGM_, topkWeightsGM_, sendTokenIdxGM_, epRecvCountGM_, XOutGM_);
264264
// }
265265
}
@@ -319,7 +319,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::PutShareAddr(
319319
metaDataGt.SetGlobalBuffer((__gm__ uint64_t *)(remote_meta));
320320
DataCopyPad(metaDataGt, addrTensor_, copyParams);
321321

322-
// printf("[putAddr] rank:%d, recvXAddr %p topkIdxAddr %p topkWeightsAddr %p sendTokenIdxAddr %p epRecvCountAddr %p XOutAddr %p\n", epRankId,
322+
// printf("[putAddr] rank:%d, recvXAddr %p topkIdxAddr %p topkWeightsAddr %p sendTokenIdxAddr %p epRecvCountAddr %p
323+
// XOutAddr %p\n", epRankId,
323324
// recvXAddr, topkIdxAddr, topkWeightsAddr, sendTokenIdxAddr,
324325
// epRecvCountAddr, XOutAddr);
325326

@@ -478,7 +479,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadTokenAndW
478479
dstGT.SetGlobalBuffer((__gm__ XType *)(ptr + hRecvXTypeLen_ * (remoteReadBase + remoteReadOffset)));
479480

480481
// if (tarRankId == 5) {
481-
// printf("[WeightedSum] rank:%d, blockId:%d, tarRankId:%d, tokenIndex:%d, remoteReadBase:%d, remoteReadOffset:%d, expertId:%d, dstRankId:%d\n",
482+
// printf("[WeightedSum] rank:%d, blockId:%d, tarRankId:%d, tokenIndex:%d, remoteReadBase:%d,
483+
// remoteReadOffset:%d, expertId:%d, dstRankId:%d\n",
482484
// epRankId, blockIdx, tarRankId, tokenIndex, remoteReadBase, remoteReadOffset, expertId, dstRankId);
483485
// }
484486

@@ -542,7 +544,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadTokenFrom
542544
xOutGlobal_.SetGlobalBuffer((__gm__ XType *)XOutGM_);
543545

544546
// if (blockIdx == 0) {
545-
// printf("[ReadTokenFromRemote] rank:%d, recvXGM_ %p topkIdxGM_ %p topkWeightsGM_ %p sendTokenIdxGM_ %p epRecvCountGM_ %p XOutGM_ %p\n", epRankId,
547+
// printf("[ReadTokenFromRemote] rank:%d, recvXGM_ %p topkIdxGM_ %p topkWeightsGM_ %p sendTokenIdxGM_ %p
548+
// epRecvCountGM_ %p XOutGM_ %p\n", epRankId,
546549
// recvXGM_, topkIdxGM_, topkWeightsGM_, sendTokenIdxGM_, epRecvCountGM_, XOutGM_);
547550
// AscendC::DumpTensor(topkIdxGT_, 547, 8);
548551
// }
@@ -572,7 +575,10 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadTokenFrom
572575
}
573576

574577
template <TemplateMC2TypeClass>
575-
__aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadAndWriteForTargetRank(uint32_t startId, uint32_t endId, uint32_t tokenCnt, uint32_t tarRankId)
578+
__aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadAndWriteForTargetRank(uint32_t startId,
579+
uint32_t endId,
580+
uint32_t tokenCnt,
581+
uint32_t tarRankId)
576582
{
577583
if (tokenCnt == 0U) {
578584
return;
@@ -586,7 +592,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadAndWriteF
586592
endTokenIndex += startId;
587593

588594
// if (tarRankId == 5) {
589-
// printf("[ReadAndWrite] rank:%d, blockId:%d, startTokenIndex:%d, endTokenIndex:%d, tokenPerBlock:%d, tarRankId:%d\n",
595+
// printf("[ReadAndWrite] rank:%d, blockId:%d, startTokenIndex:%d, endTokenIndex:%d, tokenPerBlock:%d,
596+
// tarRankId:%d\n",
590597
// epRankId, blockIdx, startTokenIndex, endTokenIndex, tokenPerBlock, tarRankId);
591598
// }
592599

@@ -672,8 +679,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::HandleAllRank
672679
continue;
673680
}
674681
if (blockIdx == 0) {
675-
printf("[HandleAllRank] rank:%d, blockId:%d, startId:%d, endId:%d, tokenCnt:%d, tarRankId:%d\n",
676-
epRankId, blockIdx, startId, endId, tokenCnt, tarRankId);
682+
printf("[HandleAllRank] rank:%d, blockId:%d, startId:%d, endId:%d, tokenCnt:%d, tarRankId:%d\n", epRankId,
683+
blockIdx, startId, endId, tokenCnt, tarRankId);
677684
}
678685

679686
ReadAndWriteForTargetRank(startId, endId, tokenCnt, tarRankId);

csrc/deepep/ops/op_kernel/shmem_moe_dispatch_normal.h

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -477,8 +477,9 @@ __aicore__ inline void ShmemMoeDispatchNormal<CamTypeFunc>::InputToDstOutput()
477477

478478
DataCopyExtParams xCopyParams = {1U, static_cast<uint32_t>(h * sizeof(XType)), 0U, 0U, 0U};
479479
DataCopyPadExtParams<XType> tokenCopyPadExtParams{false, 0U, 0U, 0U};
480-
DataCopyExtParams xOutCopyParams = {1U, static_cast<uint32_t>(h * sizeof(ExpandXOutType)), 0U, 0U, 0U}; // 只拷贝hidden_size
481-
DataCopyExtParams scaleCopyParams = {1U, sizeof(float), 0U, 0U, 0U}; // 拷贝dynamicScales
480+
DataCopyExtParams xOutCopyParams = {1U, static_cast<uint32_t>(h * sizeof(ExpandXOutType)), 0U, 0U,
481+
0U}; // 只拷贝hidden_size
482+
DataCopyExtParams scaleCopyParams = {1U, sizeof(float), 0U, 0U, 0U}; // 拷贝dynamicScales
482483

483484
for (int32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) {
484485
uint32_t dstExpertId = expertIdsTensor(tokenIndex - startTokenId);
@@ -509,7 +510,8 @@ __aicore__ inline void ShmemMoeDispatchNormal<CamTypeFunc>::InputToDstOutput()
509510
DataCopyPad(dstGT, xOutTensor, xOutCopyParams); // 拷贝token
510511

511512
LocalTensor<float> xOutFp32Tensor = xOutTensor.template ReinterpretCast<float>();
512-
DataCopyPad(dstScaleOutGT[dstExpertOffset + curExpertIdx], xOutFp32Tensor[hUBAlignSize / sizeof(float)], scaleCopyParams);
513+
DataCopyPad(dstScaleOutGT[dstExpertOffset + curExpertIdx], xOutFp32Tensor[hUBAlignSize / sizeof(float)],
514+
scaleCopyParams);
513515

514516
xOutQueue.FreeTensor(xOutTensor);
515517
} else {

0 commit comments

Comments
 (0)