@@ -34,13 +34,14 @@ class ShmemMoeCombineNormal
3434public:
3535 __aicore__ inline ShmemMoeCombineNormal (){};
3636 __aicore__ inline void Init (GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx,
37- GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM,
38- TPipe *pipe, const ShmemMoeCombineNormalTilingData *tilingData);
37+ GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut,
38+ GM_ADDR workspaceGM, TPipe *pipe, const ShmemMoeCombineNormalTilingData *tilingData);
3939 __aicore__ inline void Process ();
4040
4141private:
4242 __aicore__ inline void InitGlobalBuffer (GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx,
43- GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut);
43+ GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut,
44+ GM_ADDR sendCostStatsOut);
4445 __aicore__ inline void InitTilingData (const ShmemMoeCombineNormalTilingData *tilingData);
4546 __aicore__ inline void InitBuffLen ();
4647 __aicore__ inline void ResetMetaState ();
@@ -49,7 +50,8 @@ class ShmemMoeCombineNormal
4950 __aicore__ inline void WaitSyncFlag (int metaType);
5051 __aicore__ inline void GetShareAddr ();
5152 __aicore__ inline void HandleAllRankToken ();
52- __aicore__ inline void ReadAndWriteForTargetRank (uint32_t startId, uint32_t endId, uint32_t tokenCnt, uint32_t tarRankId);
53+ __aicore__ inline void ReadAndWriteForTargetRank (uint32_t startId, uint32_t endId, uint32_t tokenCnt,
54+ uint32_t tarRankId);
5355 __aicore__ inline void ReadTokenFromRemote ();
5456 __aicore__ inline void ReadTokenAndWeightedSum (uint32_t tokenIndex, uint32_t tarRankId);
5557
@@ -153,12 +155,12 @@ class ShmemMoeCombineNormal
153155 GM_ADDR epRecvCountGM_;
154156
155157 GM_ADDR gva_gm;
156- uint64_t shareRecvXAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (recvXGM_)
157- uint64_t shareTopkIdxAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (topkIdxGM_)
158- uint64_t shareTopkWeightsAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (topkWeightsGM_)
158+ uint64_t shareRecvXAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (recvXGM_)
159+ uint64_t shareTopkIdxAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (topkIdxGM_)
160+ uint64_t shareTopkWeightsAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (topkWeightsGM_)
159161 uint64_t shareSendTokenIdxAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (sendTokenIdxGM_)
160- uint64_t shareRecvCountAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (epRecvCountGM_)
161- uint64_t shareXOutAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (XOutGM_)
162+ uint64_t shareRecvCountAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (epRecvCountGM_)
163+ uint64_t shareXOutAddrs[CAM_MAX_RANK_SIZE]; // List of shmem asymmetric output addresses (XOutGM_)
162164 uint32_t shareAddrNum{6 };
163165
164166 LocalTensor<float > tokenFloatLocal;
@@ -173,13 +175,9 @@ class ShmemMoeCombineNormal
173175};
174176
175177template <TemplateMC2TypeClass>
176- __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::InitGlobalBuffer(GM_ADDR recvX, GM_ADDR epRecvCount,
177- GM_ADDR topkWeights,
178- GM_ADDR topkIdx,
179- GM_ADDR sendTokenIdx,
180- GM_ADDR balanceMatrix,
181- GM_ADDR XOut,
182- GM_ADDR sendCostStatsOut)
178+ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::InitGlobalBuffer(
179+ GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx, GM_ADDR sendTokenIdx,
180+ GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut)
183181{
184182 recvXGT_.SetGlobalBuffer ((__gm__ RecvXType *)recvX);
185183 epRecvCountGT_.SetGlobalBuffer ((__gm__ int32_t *)epRecvCount); // 放置allReccvCount信息,num_ranks * num_experts
@@ -233,8 +231,9 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::InitBuffLen()
233231
234232template <TemplateMC2TypeClass>
235233__aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::Init(
236- GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx, GM_ADDR sendTokenIdx, GM_ADDR balanceMatrix, GM_ADDR XOut,
237- GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, const ShmemMoeCombineNormalTilingData *tilingData)
234+ GM_ADDR recvX, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR topkIdx, GM_ADDR sendTokenIdx,
235+ GM_ADDR balanceMatrix, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe,
236+ const ShmemMoeCombineNormalTilingData *tilingData)
238237{
239238 workspaceGM_ = workspaceGM;
240239 recvXGM_ = recvX;
@@ -259,7 +258,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::Init(
259258 SplitCoreCal (epRankSize, rankNumPerBlock, curBlockStartRankId, curBlockEndRankId);
260259
261260 // if (blockIdx == 0) {
262- // printf("[Init] rank:%d, recvXGM_ %p topkIdxGM_ %p topkWeightsGM_ %p sendTokenIdxGM_ %p epRecvCountGM_ %p XOutGM_ %p\n", epRankId,
261+ // printf("[Init] rank:%d, recvXGM_ %p topkIdxGM_ %p topkWeightsGM_ %p sendTokenIdxGM_ %p epRecvCountGM_ %p
262+ // XOutGM_ %p\n", epRankId,
263263 // recvXGM_, topkIdxGM_, topkWeightsGM_, sendTokenIdxGM_, epRecvCountGM_, XOutGM_);
264264 // }
265265}
@@ -319,7 +319,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::PutShareAddr(
319319 metaDataGt.SetGlobalBuffer ((__gm__ uint64_t *)(remote_meta));
320320 DataCopyPad (metaDataGt, addrTensor_, copyParams);
321321
322- // printf("[putAddr] rank:%d, recvXAddr %p topkIdxAddr %p topkWeightsAddr %p sendTokenIdxAddr %p epRecvCountAddr %p XOutAddr %p\n", epRankId,
322+ // printf("[putAddr] rank:%d, recvXAddr %p topkIdxAddr %p topkWeightsAddr %p sendTokenIdxAddr %p epRecvCountAddr %p
323+ // XOutAddr %p\n", epRankId,
323324 // recvXAddr, topkIdxAddr, topkWeightsAddr, sendTokenIdxAddr,
324325 // epRecvCountAddr, XOutAddr);
325326
@@ -478,7 +479,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadTokenAndW
478479 dstGT.SetGlobalBuffer ((__gm__ XType *)(ptr + hRecvXTypeLen_ * (remoteReadBase + remoteReadOffset)));
479480
480481 // if (tarRankId == 5) {
481- // printf("[WeightedSum] rank:%d, blockId:%d, tarRankId:%d, tokenIndex:%d, remoteReadBase:%d, remoteReadOffset:%d, expertId:%d, dstRankId:%d\n",
482+ // printf("[WeightedSum] rank:%d, blockId:%d, tarRankId:%d, tokenIndex:%d, remoteReadBase:%d,
483+ // remoteReadOffset:%d, expertId:%d, dstRankId:%d\n",
482484 // epRankId, blockIdx, tarRankId, tokenIndex, remoteReadBase, remoteReadOffset, expertId, dstRankId);
483485 // }
484486
@@ -542,7 +544,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadTokenFrom
542544 xOutGlobal_.SetGlobalBuffer ((__gm__ XType *)XOutGM_);
543545
544546 // if (blockIdx == 0) {
545- // printf("[ReadTokenFromRemote] rank:%d, recvXGM_ %p topkIdxGM_ %p topkWeightsGM_ %p sendTokenIdxGM_ %p epRecvCountGM_ %p XOutGM_ %p\n", epRankId,
547+ // printf("[ReadTokenFromRemote] rank:%d, recvXGM_ %p topkIdxGM_ %p topkWeightsGM_ %p sendTokenIdxGM_ %p
548+ // epRecvCountGM_ %p XOutGM_ %p\n", epRankId,
546549 // recvXGM_, topkIdxGM_, topkWeightsGM_, sendTokenIdxGM_, epRecvCountGM_, XOutGM_);
547550 // AscendC::DumpTensor(topkIdxGT_, 547, 8);
548551 // }
@@ -572,7 +575,10 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadTokenFrom
572575}
573576
574577template <TemplateMC2TypeClass>
575- __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadAndWriteForTargetRank(uint32_t startId, uint32_t endId, uint32_t tokenCnt, uint32_t tarRankId)
578+ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadAndWriteForTargetRank(uint32_t startId,
579+ uint32_t endId,
580+ uint32_t tokenCnt,
581+ uint32_t tarRankId)
576582{
577583 if (tokenCnt == 0U ) {
578584 return ;
@@ -586,7 +592,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::ReadAndWriteF
586592 endTokenIndex += startId;
587593
588594 // if (tarRankId == 5) {
589- // printf("[ReadAndWrite] rank:%d, blockId:%d, startTokenIndex:%d, endTokenIndex:%d, tokenPerBlock:%d, tarRankId:%d\n",
595+ // printf("[ReadAndWrite] rank:%d, blockId:%d, startTokenIndex:%d, endTokenIndex:%d, tokenPerBlock:%d,
596+ // tarRankId:%d\n",
590597 // epRankId, blockIdx, startTokenIndex, endTokenIndex, tokenPerBlock, tarRankId);
591598 // }
592599
@@ -672,8 +679,8 @@ __aicore__ inline void ShmemMoeCombineNormal<TemplateMC2TypeFunc>::HandleAllRank
672679 continue ;
673680 }
674681 if (blockIdx == 0 ) {
675- printf (" [HandleAllRank] rank:%d, blockId:%d, startId:%d, endId:%d, tokenCnt:%d, tarRankId:%d\n " ,
676- epRankId, blockIdx, startId, endId, tokenCnt, tarRankId);
682+ printf (" [HandleAllRank] rank:%d, blockId:%d, startId:%d, endId:%d, tokenCnt:%d, tarRankId:%d\n " , epRankId,
683+ blockIdx, startId, endId, tokenCnt, tarRankId);
677684 }
678685
679686 ReadAndWriteForTargetRank (startId, endId, tokenCnt, tarRankId);
0 commit comments