@@ -194,83 +194,6 @@ struct BlockUniversalGemmAsBsCr
194194 {
195195 };
196196
197- template <typename GemmTraits>
198- struct BlockGemmImpl <GemmPipelineScheduler::Default, GemmTraits>
199- {
200- static constexpr auto ALdsTileDistr =
201- decltype (make_static_tile_distribution(MakeABlockDistributionEncode())){};
202- static constexpr auto BLdsTileDistr =
203- decltype (make_static_tile_distribution(MakeBBlockDistributionEncode())){};
204-
205- using ALdsTile = decltype (make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
206- using BLdsTile = decltype (make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
207-
208- ALdsTile a_warp_tile_;
209- BLdsTile b_warp_tile_;
210-
211- // C += A * B
212- template <typename CBlockTensor,
213- typename ASmemBlockWindow,
214- typename BSmemBlockWindow,
215- bool ALoadTranspose = false ,
216- bool BLoadTranspose = false >
217- CK_TILE_DEVICE void operator ()(CBlockTensor& c_block_tensor,
218- const ASmemBlockWindow& a_block_window,
219- const BSmemBlockWindow& b_block_window,
220- bool_constant<ALoadTranspose> = {},
221- bool_constant<BLoadTranspose> = {})
222- {
223- static_assert (std::is_same_v<CDataType, typename CBlockTensor::DataType>,
224- " The CDataType as defined in traits should be the same as correspoinding "
225- " C block tensor data type!" );
226- static_assert (std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
227- std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
228- " The ADataType and BDataType as defined in "
229- " traits should be the same as correspoinding block window data type!" );
230-
231- load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
232- a_block_window);
233- load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
234- b_block_window);
235- // hot loop:
236- static_for<0 , GemmTraits::KIterPerWarp, 1 >{}([&](auto kIter ) {
237- static_for<0 , MIterPerWarp, 1 >{}([&](auto mIter ) {
238- // read A warp tensor from A block tensor
239- AWarpTensor a_warp_tensor;
240-
241- a_warp_tensor.get_thread_buffer () = a_warp_tile_.get_y_sliced_thread_data (
242- merge_sequences (sequence<mIter , kIter >{}, a_warp_y_index_zeros),
243- merge_sequences (sequence<1 , 1 >{}, a_warp_y_lengths));
244-
245- static_for<0 , NIterPerWarp, 1 >{}([&](auto nIter) {
246- // read B warp tensor from B block tensor
247- BWarpTensor b_warp_tensor;
248-
249- b_warp_tensor.get_thread_buffer () = b_warp_tile_.get_y_sliced_thread_data (
250- merge_sequences (sequence<nIter, kIter >{}, b_warp_y_index_zeros),
251- merge_sequences (sequence<1 , 1 >{}, b_warp_y_lengths));
252-
253- // read C warp tensor from C block tensor-
254- CWarpTensor c_warp_tensor;
255-
256- c_warp_tensor.get_thread_buffer () = c_block_tensor.get_y_sliced_thread_data (
257- merge_sequences (sequence<mIter , nIter>{}, c_warp_y_index_zeros),
258- merge_sequences (sequence<1 , 1 >{}, c_warp_y_lengths));
259-
260- // warp GEMM
261- WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
262-
263- // write C warp tensor into C block tensor
264- c_block_tensor.set_y_sliced_thread_data (
265- merge_sequences (sequence<mIter , nIter>{}, c_warp_y_index_zeros),
266- merge_sequences (sequence<1 , 1 >{}, c_warp_y_lengths),
267- c_warp_tensor.get_thread_buffer ());
268- });
269- });
270- });
271- }
272- };
273-
274197 template <typename GemmTraits>
275198 struct BlockGemmImpl <GemmPipelineScheduler::Intrawave, GemmTraits>
276199 {
@@ -450,7 +373,9 @@ struct BlockUniversalGemmAsBsCr
450373 // hot loop:
451374 static_for<0 , KRepeat, 1 >{}([&](auto kIter ) {
452375 LocalPrefetch<kIter .value >(a_block_window, b_block_window, a_load_tr, b_load_tr);
453- __builtin_amdgcn_sched_barrier (0 );
376+ __builtin_amdgcn_sched_barrier (
377+ 0 ); // Mask 0: no instruction may be scheduled across this point
378+
454379 // NOTE: Synchronize threads in a workgroup at the start of each MAC
455380 // cluster, except the first, as we can shorten the non-MAC cluster a bit
456381 // and there's no observable negative impact. The desired effect is waves in
@@ -460,8 +385,14 @@ struct BlockUniversalGemmAsBsCr
460385 // sync point.
461386 if constexpr (kIter .value != 0 || KRepeat == 1 )
462387 {
463- __builtin_amdgcn_s_barrier ();
464- __builtin_amdgcn_sched_barrier (0 );
388+ // This pattern ensures:
389+ // At runtime: All waves synchronize (hardware barrier)
390+ // At compile-time: No instruction is moved across the barrier in either
391+ // direction (scheduling barrier)
392+ __builtin_amdgcn_s_barrier (); // Blocks each wave until every wave in the
393+ // workgroup has reached this point
394+ __builtin_amdgcn_sched_barrier (
395+ 0 ); // Prevents instruction reordering across this boundary
465396 }
466397
467398 static_for<0 , KInnerLoopIter, 1 >{}([&](auto kInnerIter ) {
0 commit comments