Skip to content

Commit e809861

Browse files
refactor: remove Default scheduler implementation as it is not used anymore (#3542)
* refactor: remove Default scheduler implementation as it is not used anymore * refactor: remove dead code from gemm universal kernel * chore: add descriptive comments about AMD intrinsic hardware sync instructions * fix: label existing memory pipeline for aquant as intrawave
1 parent 18c2ff6 commit e809861

4 files changed

Lines changed: 15 additions & 87 deletions

File tree

include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp

Lines changed: 11 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -194,83 +194,6 @@ struct BlockUniversalGemmAsBsCr
194194
{
195195
};
196196

197-
template <typename GemmTraits>
198-
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
199-
{
200-
static constexpr auto ALdsTileDistr =
201-
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
202-
static constexpr auto BLdsTileDistr =
203-
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
204-
205-
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
206-
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
207-
208-
ALdsTile a_warp_tile_;
209-
BLdsTile b_warp_tile_;
210-
211-
// C += A * B
212-
template <typename CBlockTensor,
213-
typename ASmemBlockWindow,
214-
typename BSmemBlockWindow,
215-
bool ALoadTranspose = false,
216-
bool BLoadTranspose = false>
217-
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
218-
const ASmemBlockWindow& a_block_window,
219-
const BSmemBlockWindow& b_block_window,
220-
bool_constant<ALoadTranspose> = {},
221-
bool_constant<BLoadTranspose> = {})
222-
{
223-
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
224-
"The CDataType as defined in traits should be the same as correspoinding "
225-
"C block tensor data type!");
226-
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
227-
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
228-
"The ADataType and BDataType as defined in "
229-
"traits should be the same as correspoinding block window data type!");
230-
231-
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
232-
a_block_window);
233-
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
234-
b_block_window);
235-
// hot loop:
236-
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
237-
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
238-
// read A warp tensor from A block tensor
239-
AWarpTensor a_warp_tensor;
240-
241-
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
242-
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
243-
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
244-
245-
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
246-
// read B warp tensor from B block tensor
247-
BWarpTensor b_warp_tensor;
248-
249-
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
250-
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
251-
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
252-
253-
// read C warp tensor from C block tensor
254-
CWarpTensor c_warp_tensor;
255-
256-
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
257-
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
258-
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
259-
260-
// warp GEMM
261-
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
262-
263-
// write C warp tensor into C block tensor
264-
c_block_tensor.set_y_sliced_thread_data(
265-
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
266-
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
267-
c_warp_tensor.get_thread_buffer());
268-
});
269-
});
270-
});
271-
}
272-
};
273-
274197
template <typename GemmTraits>
275198
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
276199
{
@@ -450,7 +373,9 @@ struct BlockUniversalGemmAsBsCr
450373
// hot loop:
451374
static_for<0, KRepeat, 1>{}([&](auto kIter) {
452375
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
453-
__builtin_amdgcn_sched_barrier(0);
376+
__builtin_amdgcn_sched_barrier(
377+
0); // Mask 0: the compiler may not schedule any instruction across this point
378+
454379
// NOTE: Synchronize threads in a workgroup at the start of each MAC
455380
// cluster, but except the first, as we can shorten non-MAC cluster a bit
456381
// and there's no observable negative impact. The desired effect is waves in
@@ -460,8 +385,14 @@ struct BlockUniversalGemmAsBsCr
460385
// sync point.
461386
if constexpr(kIter.value != 0 || KRepeat == 1)
462387
{
463-
__builtin_amdgcn_s_barrier();
464-
__builtin_amdgcn_sched_barrier(0);
388+
// This pattern ensures:
389+
// At runtime: All waves synchronize (hardware barrier)
390+
// At compile-time: Instructions after the barrier don't get moved before it
391+
// (scheduling barrier)
392+
__builtin_amdgcn_s_barrier(); // Blocks execution until all waves (wavefronts) in
393+
// the workgroup reach this point
394+
__builtin_amdgcn_sched_barrier(
395+
0); // Prevents instruction reordering across this boundary
465396
}
466397

467398
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {

include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,6 @@ struct UniversalGemmKernel
10351035
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
10361036
*
10371037
*/
1038-
template <bool UseDefaultScheduler = true>
10391038
CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
10401039
const std::array<const BDataType*, NumBTensor>& bs_ptr,
10411040
const std::array<const void*, NumDTensor>& ds_ptr,
@@ -1161,9 +1160,7 @@ struct UniversalGemmKernel
11611160
// allocate LDS
11621161
__shared__ char smem_ptr[GetSmemSize()];
11631162

1164-
constexpr auto scheduler_type =
1165-
GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1);
1166-
RunGemm<scheduler_type>(
1163+
RunGemm(
11671164
as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
11681165
}
11691166

include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct GemmPipelineProblemBase
8080
static constexpr bool kPadK = Traits::kPadK;
8181

8282
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
83-
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
83+
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
8484
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
8585

8686
// In the base situation, the Preshuffle setting should be false.

include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
164164
};
165165

166166
template <>
167-
struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
167+
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
168168
{
169169
using Base = PipelineImplBase;
170170

@@ -491,7 +491,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
491491
void* p_smem,
492492
index_t m = 0) const
493493
{
494-
return PipelineImpl<GemmPipelineScheduler::Interwave>{}
494+
return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
495495
.template operator()<HasHotLoop, TailNum>(
496496
a_dram_block_window_tmp,
497497
[](const BDataType& a) { return a; },

0 commit comments

Comments
 (0)