Skip to content

Commit 4c0e73a

Browse files
hyoon1 and assistant-librarian[bot]
authored and committed
[rocm-libraries] ROCm/rocm-libraries#6156 (commit 367565a)
[CK_TILE] Optimize FMHA head-dim padded path on gfx11/gfx12 (#6156) ## Motivation On gfx11/gfx12, FMHA forward kernels that require head-dim padding show a large performance drop compared to the exact-head-dim path. In practice, padded cases such as `HDIM=72` and `HDIM=80` were falling too far off the fast path. This PR improves padded-head-dim FMHA performance on gfx11/gfx12 while keeping the behavior for other GPUs unchanged. ## Technical Details - Add/scope a dedicated padded-head-dim (`qr_hpad`) FMHA forward path for gfx11/gfx12. - For `receipt=0`, keep support conservative and only enable the padded fast path for vector-safe cases (`head_dim % 8 == 0`), matching the existing assumption used on other GPUs. - Move `v_prefetch` later only for the head-dim-padded path on gfx11/gfx12. This reduces live ranges and removes the register-spill behavior seen in the earlier scheduling. - Enable the buffer-load OOB check offset trick for the padded path on gfx11/gfx12. ## Test Plan ./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 -d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1} ## Test Result Observed padded-head-dim performance improvements for HDIM=72/80: - gfx11: about ~3.5x - gfx1151: about ~2.0x - gfx12: about ~1.3x ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent 7d6c8e5 commit 4c0e73a

4 files changed

Lines changed: 144 additions & 26 deletions

File tree

example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def get_mask_cpp_check_expr(mask: str) -> str:
139139

140140
PIPELINE_MAP = {
141141
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
142+
"qr_hpad": "ck_tile::BlockFmhaPipelineQRKSVSHpad",
142143
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
143144
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
144145
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
@@ -147,6 +148,7 @@ def get_mask_cpp_check_expr(mask: str) -> str:
147148

148149
PIPELINE_ENUM_MAP = {
149150
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
151+
"qr_hpad": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_HPAD",
150152
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
151153
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
152154
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@
6060
#include "fmha_fwd.hpp"
6161
"""
6262

63+
FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT
64+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
65+
// auto generated by generate.py
66+
#if defined(__HIP_DEVICE_COMPILE__) && \
67+
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
68+
defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
69+
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \
70+
defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__))
71+
#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK)
72+
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
73+
#endif
74+
#endif
75+
#include "ck_tile/ops/fmha/block/variants.hpp"
76+
#include "fmha_fwd.hpp"
77+
"""
78+
6379
FMHA_FWD_KERNEL_BODY_TEMPLATE = """
6480
#include <iostream>
6581
@@ -300,7 +316,7 @@ def scheck(self) -> str:
300316
return "true" # always support
301317
else:
302318
return "true"
303-
elif self.pipeline_tag in ["qr", "qs"]:
319+
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
304320
if self.spad == "t":
305321
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
306322
else:
@@ -323,7 +339,7 @@ def skcheck(self) -> str:
323339
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
324340
else:
325341
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
326-
elif self.pipeline_tag in ["qr", "qs"]:
342+
elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
327343
if self.skpad == "t":
328344
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
329345
else:
@@ -344,6 +360,11 @@ def dcheck(self) -> str:
344360
return f"a.hdim_q % {vec} == 0"
345361
else:
346362
assert False
363+
elif self.pipeline_tag == "qr_hpad":
364+
if self.dpad == "t":
365+
return "a.hdim_q % 8 == 0"
366+
else:
367+
assert False
347368
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
348369
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
349370
if self.dpad == "t":
@@ -361,6 +382,11 @@ def dvcheck(self) -> str:
361382
return f"a.hdim_v % {vec} == 0"
362383
else:
363384
assert False
385+
elif self.pipeline_tag == "qr_hpad":
386+
if self.dvpad == "t":
387+
return "a.hdim_v % 8 == 0"
388+
else:
389+
assert False
364390
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
365391
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
366392
if self.dvpad == "t":
@@ -634,6 +660,7 @@ class FmhaFwdKernel:
634660
F_pipeline: FmhaFwdPipeline
635661

636662
_KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER
663+
_KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD
637664
_KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE
638665

639666
@classmethod
@@ -643,6 +670,12 @@ def _get_cpp_kernel_class_name(cls, pipeline_tag):
643670
else:
644671
return "ck_tile::FmhaFwdKernel"
645672

673+
@classmethod
674+
def _get_kernel_header(cls, pipeline_tag):
675+
if pipeline_tag == "qr_hpad":
676+
return cls._KERNEL_HEADER_QR_HPAD
677+
return cls._KERNEL_HEADER
678+
646679
@classmethod
647680
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
648681
if pipeline_tag == "qr_async_trload_v3":
@@ -651,7 +684,9 @@ def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
651684
return "fmha_fwd_create_kargs_and_grids"
652685

653686
def render(self) -> str:
654-
return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
687+
return type(self)._get_kernel_header(self.F_pipeline.tag) + type(
688+
self
689+
)._KERNEL_BODY_TEMPLATE.format(
655690
F_kname=self.name,
656691
F_arch=self.F_arch,
657692
F_hdim=self.F_hdim,
@@ -1144,6 +1179,37 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
11441179
def supported_dtypes(cls) -> Tuple[str]:
11451180
return cls._DT_FP16_BF16
11461181

1182+
@classmethod
1183+
def get_rules(cls) -> List[CompatibilityRule]:
1184+
rules = super().get_rules()
1185+
1186+
# For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile:
1187+
# the exact-hdim variant (dpad=dvpad=f) is much slower here.
1188+
def check_d128_tile_pipeline(
1189+
problem_ctx: ProblemContext, kernel_ctx: KernelContext
1190+
) -> bool:
1191+
if problem_ctx.dtype not in cls._DT_FP16_BF16:
1192+
return True
1193+
1194+
if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
1195+
return True
1196+
1197+
is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32
1198+
pads_hdim = (
1199+
kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t"
1200+
)
1201+
exact_hdim = (
1202+
kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f"
1203+
)
1204+
1205+
if is_64x32_tile:
1206+
return pads_hdim
1207+
1208+
return exact_hdim
1209+
1210+
rules.append(check_d128_tile_pipeline)
1211+
return rules
1212+
11471213
@classmethod
11481214
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
11491215
if dtype in cls._DT_FP16_BF16:
@@ -1152,7 +1218,8 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
11521218
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
11531219
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
11541220
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
1155-
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
1221+
(128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")),
1222+
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
11561223
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
11571224
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
11581225
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
@@ -1179,7 +1246,9 @@ def get_pipelines(
11791246
# Keep only ttff/tttt for gfx11: ffff path is often similar or worse
11801247
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
11811248
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
1182-
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
1249+
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
1250+
if receipt == 1:
1251+
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
11831252
return pipelines
11841253

11851254

@@ -1251,7 +1320,9 @@ def get_pipelines(
12511320
):
12521321
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
12531322
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
1254-
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
1323+
pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
1324+
if receipt == 1:
1325+
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
12551326
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
12561327
# no need lse/dropout kernels
12571328
for logits, qscale, mask, bias in itertools.product(

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum
1313
QSKSVS,
1414
QRKSVS_ASYNC_TRLOAD,
1515
QRKSVS_ASYNC_TRLOAD_V3,
16+
QRKSVS_HPAD,
1617
};
1718

1819
template <BlockFmhaPipelineEnum>
@@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
4041
static constexpr const char* name = "qr_async_trload";
4142
};
4243

44+
template <>
45+
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_HPAD>
46+
{
47+
static constexpr const char* name = "qr_hpad";
48+
};
49+
4350
} // namespace ck_tile

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
namespace ck_tile {
1515

1616
// This pipeline is qkv all located in LDS
17-
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
17+
template <typename Problem_,
18+
typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy,
19+
bool PaddedVecLoadStore_ = false>
1820
struct BlockFmhaPipelineQRKSVS
1921
{
2022
using Problem = remove_cvref_t<Problem_>;
@@ -54,17 +56,18 @@ struct BlockFmhaPipelineQRKSVS
5456

5557
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
5658

57-
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
58-
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
59-
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
60-
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
61-
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
62-
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
63-
static constexpr auto BiasEnum = Problem::BiasEnum;
64-
static constexpr bool kStoreLSE = Problem::kStoreLSE;
65-
static constexpr bool kHasDropout = Problem::kHasDropout;
66-
static constexpr auto QScaleEnum = Problem::QScaleEnum;
67-
static constexpr bool kHasSink = Problem::kHasSink;
59+
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
60+
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
61+
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
62+
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
63+
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
64+
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
65+
static constexpr auto BiasEnum = Problem::BiasEnum;
66+
static constexpr bool kStoreLSE = Problem::kStoreLSE;
67+
static constexpr bool kHasDropout = Problem::kHasDropout;
68+
static constexpr auto QScaleEnum = Problem::QScaleEnum;
69+
static constexpr bool kHasSink = Problem::kHasSink;
70+
static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
6871

6972
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
7073
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
@@ -80,23 +83,29 @@ struct BlockFmhaPipelineQRKSVS
8083
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
8184
!kHasLogitsSoftCap)) ||
8285
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
86+
static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV),
87+
"padded vector load/store fast path only applies to padded head-dim kernels");
8388

8489
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
8590
// ... together with tensor distribution. tensor dist should able to overwrite this
86-
static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits<QDataType>::PackedSize
87-
: Policy::template GetAlignmentQ<Problem>();
88-
static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits<KDataType>::PackedSize
89-
: Policy::template GetAlignmentK<Problem>();
91+
static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore)
92+
? numeric_traits<QDataType>::PackedSize
93+
: Policy::template GetAlignmentQ<Problem>();
94+
static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore)
95+
? numeric_traits<KDataType>::PackedSize
96+
: Policy::template GetAlignmentK<Problem>();
9097
static constexpr index_t kAlignmentV = []() {
9198
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
92-
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
99+
return (kPadHeadDimV && !kPaddedVecLoadStore)
100+
? 1
101+
: Policy::template GetAlignmentV<Problem>();
93102
else
94103
return kPadSeqLenK ? numeric_traits<VDataType>::PackedSize
95104
: Policy::template GetAlignmentV<Problem>();
96105
}();
97106

98107
static constexpr index_t kAlignmentO =
99-
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
108+
(kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO<Problem>();
100109
static constexpr index_t kAlignmentBias =
101110
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
102111
static constexpr index_t kAlignmentRandVal =
@@ -548,8 +557,25 @@ struct BlockFmhaPipelineQRKSVS
548557
});
549558
}
550559

551-
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
552-
{ // tail
560+
auto v_prefetch = decltype(load_tile(v_dram_window)){};
561+
enum class VPrefetchPoint
562+
{
563+
BeforeGemm0Tail,
564+
AfterGemm0Tail,
565+
AfterSoftmax
566+
};
567+
568+
#if defined(__gfx11__) || defined(__gfx12__)
569+
constexpr auto kVPrefetch =
570+
kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail;
571+
#else
572+
constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail;
573+
#endif
574+
if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail)
575+
{
576+
load_tile(v_prefetch, v_dram_window); // prefetch load v tile
577+
}
578+
{ // tail
553579
block_sync_lds();
554580
run_gemm_0(number<k0_loops - 2>{});
555581
block_sync_lds();
@@ -562,6 +588,10 @@ struct BlockFmhaPipelineQRKSVS
562588

563589
run_gemm_0(number<k0_loops - 1>{});
564590
}
591+
if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
592+
{
593+
load_tile(v_prefetch, v_dram_window);
594+
}
565595
// dequant
566596
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
567597
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
@@ -819,6 +849,11 @@ struct BlockFmhaPipelineQRKSVS
819849
randval_ptr, seq_offset, p_compute, randval_dram_window);
820850
}
821851

852+
if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax)
853+
{
854+
load_tile(v_prefetch, v_dram_window);
855+
}
856+
822857
block_sync_lds();
823858
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
824859
{
@@ -1098,4 +1133,7 @@ struct BlockFmhaPipelineQRKSVS
10981133
}
10991134
};
11001135

1136+
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
1137+
using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS<Problem_, Policy_, true>;
1138+
11011139
} // namespace ck_tile

0 commit comments

Comments (0)