@@ -8,10 +8,11 @@ namespace ck_tile {
88template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
99StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
1010 index_t m, index_t n, index_t k, index_t max_active_wgs)
11- : max_active_wgs_{max_active_wgs}, n_{n}
11+ : max_active_wgs_{max_active_wgs}, n_{n}, k_{k}
1212{
13- iters_per_tile_ = integer_divide_ceil (k, KPerBlock);
14- num_tiles_ = integer_divide_ceil (m, MPerBlock) * integer_divide_ceil (n_, NPerBlock);
13+ iters_per_tile_ = integer_divide_ceil (k, KPerBlock);
14+ num_tiles_ = integer_divide_ceil (m, MPerBlock) * integer_divide_ceil (n_, NPerBlock);
15+ remainder_along_k_ = k % KPerBlock;
1516
1617 bool big_enough = num_tiles_ > max_active_wgs_;
1718 index_t remainder_tiles = num_tiles_ % max_active_wgs_;
@@ -250,6 +251,21 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() c
250251 return n_;
251252}
252253
254+ template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
255+ CK_TILE_HOST_DEVICE index_t
256+ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_k() const noexcept
257+ {
258+ return k_;
259+ }
260+
261+ template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
262+ CK_TILE_HOST_DEVICE index_t
263+ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_remainder_along_k()
264+ const noexcept
265+ {
266+ return remainder_along_k_;
267+ }
268+
253269template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
254270CK_TILE_HOST index_t
255271StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
@@ -334,6 +350,29 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::remap_xcd
334350 return block_1d_id;
335351}
336352
353+ template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
354+ CK_TILE_DEVICE index_t
355+ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_k_size(
356+ index_t num_macro_tiles, index_t local_iter_end) const noexcept
357+ {
358+ // Determine if this workgroup is responsible for the last macro tile in the K dimension
359+ bool last_tile = get_iters_per_tile () == local_iter_end;
360+ index_t k_size;
361+ // If there is no remainder or if the workgroup was not assigned the last macro tile along K,
362+ // then their k_size will be a multiple of KPerBlock.
363+ if (!remainder_along_k_ || !last_tile)
364+ {
365+ k_size = num_macro_tiles * KPerBlock;
366+ }
367+ // Otherwise, there's a remainder. So, k_size is not a multiple of KPerBlock.
368+ else
369+ {
370+ k_size = (num_macro_tiles - 1 ) * KPerBlock + remainder_along_k_;
371+ }
372+
373+ return k_size;
374+ }
375+
337376template <typename BlockGemmShapeType,
338377 StreamKReductionStrategy ReductionStrategyType,
339378 bool Persistent>
0 commit comments