Skip to content

Commit a7ed94f

Browse files
illsilin and ex-rzr authored
[CK_TILE] FMHA Reduce register spilling in fwd with dropout (workaround for CI failures with clang-22) (#3221) (#3372)
* Use vectorized stores for dropout randvals

  With no kPadSeqLenK the kernel uses 2 buffer_store_dwordx2 instead of 16 buffer_store_byte. This requires fewer registers and reduces spilling.

* Calculate dropout randvals for storing and applying only once

  Even though it may add a small overhead when storing is not required, it uses significantly fewer registers and hence causes no spilling.

Co-authored-by: Anton Gorenko <anton@streamhpc.com>
1 parent 6e9f0e9 commit a7ed94f

11 files changed

Lines changed: 37 additions & 16 deletions

include/ck_tile/ops/fmha/block/block_dropout.hpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -333,23 +333,15 @@ struct BlockDropout
333333
return randval;
334334
};
335335

336-
if(is_store_randval)
337-
{
338-
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
339-
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
340-
const auto randval = generate_randval(i_m0, i_n0);
341-
// save to Global
342-
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
343-
store_tile(randval_dram_window, randval_store);
344-
move_tile_window(randval_dram_window, {0, kNPerStep});
345-
});
346-
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
347-
});
348-
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
349-
}
350336
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
351337
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
352338
const auto randval = generate_randval(i_m0, i_n0);
339+
if(is_store_randval)
340+
{
341+
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
342+
store_tile(randval_dram_window, randval_store);
343+
}
344+
move_tile_window(randval_dram_window, {0, kNPerStep});
353345
// Drop values of P based on the generated probabilities
354346
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
355347
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
@@ -369,7 +361,9 @@ struct BlockDropout
369361
});
370362
});
371363
});
364+
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
372365
});
366+
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
373367
}
374368

375369
const unsigned long long ph_seed;

include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
10051005
rand_val_ptr,
10061006
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
10071007
make_tuple(kargs.stride_randval, 1),
1008-
number<1>{},
1008+
number<FmhaPipeline::kAlignmentRandVal>{},
10091009
number<1>{});
10101010

10111011
return pad_tensor_view(randval_dram_naive,

include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1450,7 +1450,7 @@ struct FmhaFwdKernel
14501450
rand_val_ptr,
14511451
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
14521452
make_tuple(kargs.stride_randval, 1),
1453-
number<1>{},
1453+
number<FmhaPipeline::kAlignmentRandVal>{},
14541454
number<1>{});
14551455

14561456
return pad_tensor_view(randval_dram_naive,

include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
8080
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
8181
static constexpr index_t kAlignmentBias =
8282
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
83+
static constexpr index_t kAlignmentRandVal =
84+
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
8385

8486
static constexpr index_t kBlockPerCu = []() {
8587
if constexpr(Problem::kBlockPerCu != -1)

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ struct BlockFmhaPipelineQRKSVS
8383
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
8484
static constexpr index_t kAlignmentBias =
8585
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
86+
static constexpr index_t kAlignmentRandVal =
87+
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
8688

8789
static constexpr index_t kBlockPerCu = []() {
8890
if constexpr(Problem::kBlockPerCu != -1)

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ struct BlockFmhaPipelineQRKSVSAsync
8181
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
8282
static constexpr index_t kAlignmentBias =
8383
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
84+
static constexpr index_t kAlignmentRandVal =
85+
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
8486

8587
#if CK_TILE_FMHA_FWD_FAST_EXP2
8688
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
9090

9191
static constexpr index_t kAlignmentBias =
9292
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
93+
static constexpr index_t kAlignmentRandVal =
94+
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
9395

9496
static constexpr index_t kBlockPerCu = []() {
9597
if constexpr(Problem::kBlockPerCu != -1)

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
6969
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
7070
static constexpr index_t kAlignmentBias =
7171
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
72+
static constexpr index_t kAlignmentRandVal =
73+
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
7274

7375
static constexpr index_t kBlockPerCu = []() {
7476
if constexpr(Problem::kBlockPerCu != -1)

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
7474
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
7575
static constexpr index_t kAlignmentBias =
7676
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
77+
static constexpr index_t kAlignmentRandVal =
78+
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
7779

7880
static constexpr index_t kBlockPerCu = []() {
7981
if constexpr(Problem::kBlockPerCu != -1)

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ struct BlockFmhaPipelineQSKSVS
7878
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
7979
static constexpr index_t kAlignmentBias =
8080
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
81+
static constexpr index_t kAlignmentRandVal =
82+
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
8183

8284
static constexpr index_t kBlockPerCu = []() {
8385
if constexpr(Problem::kBlockPerCu != -1)

0 commit comments

Comments
 (0)