@@ -100,18 +100,18 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
100100 tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB,
101101 gemm::EltwiseActType eltwiseActType, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
102102 bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN,
103- bool fuseUtccpWithUtcmma, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB ,
104- bool gridWaitForPrimaryEarlyExit , bool gridWaitForPrimaryA , bool gridWaitForPrimaryB, bool hoistLoadTaskInit ,
105- bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA ,
106- gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind , int mmaM , int mmaN, bool mockAllReduce, int n ,
107- int numEpilogueWarps , int numRegsCastAWarps, int numRegsCopySfLdsSttm , int numRegsCopySparsityInfo ,
108- int numRegsPerThreadEpilogueWarp , int numRegsPerThreadNonEpilogueWarp , int numSlicesForSplitK ,
109- int numSlicesForSliceK , int numStages , int numStagesMma , int numStagesMmaWithinWorkTile ,
110- int numStagesMmaAcrossWorkTile , int numStagesWorkId, bool outputDebugTensors, bool patchF2fp ,
111- int32_t sfBlockSizeA, int32_t sfBlockSizeB, int32_t sfBlockSizeC, tg::SfLayout sfLayoutA ,
112- tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, tg::Sparsity sparsityA ,
113- gemm::SplitK splitK, int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler,
114- bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8,
103+ int fallbackClusterDimX, int fallbackClusterDimY, int fallbackClusterDimZ, bool fuseUtccpWithUtcmma ,
104+ bool gridTriggerSecondaryA , bool gridTriggerSecondaryB , bool gridWaitForPrimaryEarlyExit ,
105+ bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k ,
106+ gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB , int m , int mmaK ,
107+ tg::MmaKind mmaKind, int mmaM , int mmaN, bool mockAllReduce, int n , int numEpilogueWarps, int numRegsCastAWarps ,
108+ int numRegsCopySfLdsSttm , int numRegsCopySparsityInfo , int numRegsPerThreadEpilogueWarp ,
109+ int numRegsPerThreadNonEpilogueWarp , int numSlicesForSplitK , int numSlicesForSliceK , int numStages ,
110+ int numStagesMma , int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId ,
111+ bool outputDebugTensors, bool patchF2fp, int32_t sfBlockSizeA, int32_t sfBlockSizeB, int32_t sfBlockSizeC,
112+ tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK,
113+ tg::Sparsity sparsityA, gemm::SplitK splitK, int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler,
114+ bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useFlexibleClusterDims,
115115 bool useHoistTryWaitForCustomMmaSchedule, bool useMaxTmemOverlap, bool usePerTokenSfA, bool usePerTokenSfB,
116116 bool useShuffledMatrix, bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps,
117117 bool useUnrollLoop2xForMma, int validM, int validN, int validK, int worldSize,
@@ -127,17 +127,18 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
127127 gemm::GemmOptions (allReduceAlgo, biasType, blockK, clcFastDrain, clusterDimX, clusterDimY, clusterDimZ,
128128 ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, eltwiseActType,
129129 enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits,
130- epilogueTileM, epilogueTileN, fuseUtccpWithUtcmma, gridTriggerSecondaryA, gridTriggerSecondaryB,
131- gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit,
132- hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n,
133- numEpilogueWarps, numRegsCastAWarps, numRegsCopySfLdsSttm, numRegsCopySparsityInfo,
134- numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK,
135- numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
136- outputDebugTensors, patchF2fp, sfBlockSizeA, sfBlockSizeB, sfBlockSizeC, sfLayoutA, sfLayoutB,
137- sfLayoutC, sfReshapeFactor, sliceK, sparsityA, splitK, tileK, tileM, tileN, tileScheduler,
138- transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, useHoistTryWaitForCustomMmaSchedule,
139- useMaxTmemOverlap, usePerTokenSfA, usePerTokenSfB, useShuffledMatrix, useTmaStore, useTwoTmaLoadWarps,
140- useTwoMmaWarps, useUnrollLoop2xForMma, validM, validN, validK, worldSize),
130+ epilogueTileM, epilogueTileN, fallbackClusterDimX, fallbackClusterDimY, fallbackClusterDimZ,
131+ fuseUtccpWithUtcmma, gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit,
132+ gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits,
133+ layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numEpilogueWarps, numRegsCastAWarps,
134+ numRegsCopySfLdsSttm, numRegsCopySparsityInfo, numRegsPerThreadEpilogueWarp,
135+ numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
136+ numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp,
137+ sfBlockSizeA, sfBlockSizeB, sfBlockSizeC, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK,
138+ sparsityA, splitK, tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule,
139+ useDeepSeekFp8, useFlexibleClusterDims, useHoistTryWaitForCustomMmaSchedule, useMaxTmemOverlap,
140+ usePerTokenSfA, usePerTokenSfB, useShuffledMatrix, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
141+ useUnrollLoop2xForMma, validM, validN, validK, worldSize),
141142 actType, clampBeforeAct)
142143 , mBatchedM(batchedM)
143144 , mBatchedN(batchedN)
@@ -310,7 +311,7 @@ inline bool checkAndUpdateBatchedGemmOptions(
310311 TLLM_CHECK_ERROR ((options.mRouteSfsImpl .value () == RouteImpl::Ldgsts
311312 || options.mRouteSfsImpl .value () == RouteImpl::LdgPlusSts)
312313 && options.mRouteImpl == RouteImpl::Tma,
313- " RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts, when RouteImpl is Tma" );
314+ " RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts when RouteImpl is Tma" );
314315 }
315316 else if (!options.mRouteSfsImpl .has_value ())
316317 {
@@ -379,8 +380,6 @@ inline bool checkAndUpdateBatchedGemmOptions(
379380
380381 if (doesRouteImplUseTma (options.mRouteSfsImpl .value ()))
381382 {
382- TLLM_CHECK_ERROR (!batchM, " UTMALDG.GATHER4 only supported for batch N." );
383-
384383 if (tg::mmaKindIsBlockFmt (options.mMmaKind ))
385384 {
386385 int const numEltsPerSfRoute = batchM ? options.mSfBlockSizeA : options.mSfBlockSizeB ;
@@ -392,8 +391,9 @@ inline bool checkAndUpdateBatchedGemmOptions(
392391
393392 if (!batchM || doesRouteImplUseNoRoute (options.mRouteImpl ))
394393 {
395- TLLM_CHECK_ERROR (options.mSfLayoutA == tg::SfLayout::R128c4,
396- " options.mSfLayoutA has to be tg::SfLayout::R128c4 when not being routed" );
394+ bool isSupportedSfLayoutA = options.mSfLayoutA == tg::SfLayout::R128c4;
395+ TLLM_CHECK_ERROR (isSupportedSfLayoutA, " options.mSfLayoutA has to be R128cX when not batch M or not routed" ,
396+ tg::sfLayoutToString (options.mSfLayoutA ));
397397 }
398398 }
399399
@@ -422,12 +422,6 @@ inline bool checkAndUpdateBatchedGemmOptions(
422422 options.mK % options.mTileK == 0 , " K must be a multiple of tileK when using Ldg based SF routing" );
423423 }
424424
425- if (options.mClusterDimX > 1 && batchM && options.mRouteSfsImpl .has_value ())
426- {
427- TLLM_CHECK_ERROR (options.mRouteSfsImpl .value () != RouteImpl::Tma,
428- " 2CTA BatchedGemm does not support routing Sf along M dimension with TMA." );
429- }
430-
431425 // Check if all elements in mBatchedM or mBatchedN are the same (uniform tokens per batch) and
432426 // set mIsUniformNumTokensPerBatch and mBatchStride.
433427 if (options.mIsUniformNumTokensPerBatch )
0 commit comments