Skip to content

Commit 865ab2b

Browse files
qianfengzassistant-librarian[bot]
authored and committed
[rocm-libraries] ROCm/rocm-libraries#6209 (commit 89c9f3e)
Improve the performance of qr_ks_vs_whole_k_prefetch pipeline (#6209) ## About qr_ks_vs_whole_k_prefetch pipeline This PR updates and enhances the qr_ks_vs_whole_k_prefetch pipeline to improve performance on both MI300 and MI350 GPUs through better MFMA instruction usage, transposed V-loading support, and N0-loop implementation. The pipeline targets scenarios where the number of workgroups is low, enabling better CU occupancy by using smaller MTile sizes (kM0=64 vs 128) while prefetching entire K tiles. ## Changes: - Adds transposed V-loading support (qr_ks_vs_whole_k_prefetch_trload) to avoid using shuffle instructions on MI350 - Implements N0-loop based Gemm0 to reduce tile window movement overhead and eliminate `clear_tile` calls - Adds full support for hdim96/hdim160 without padding requirements - Updates MFMA instruction selection to ensure optimal choices for MI350 ## Performance results 1. For attention shapes which lead to kM0=64, `qr_ks_vs_async_whole_k_prefetch_trload` shows much better performance than `qr_ks_vs_async_trload` on the same case (execution time `41.02ms` by whole_k_prefetch_trload & `58.50ms` by async_load), and `qr_ks_vs_async_whole_k_prefetch_trload` also shows obviously better performance than the recently tuned `qr_ks_vs_async` on the same case (execution time `41.02ms` by whole_k_prefetch_trload & `47.60ms` by qr_ks_vs_async) 2. Also on MI300, for attention shapes which lead to kM0=64, `qr_ks_vs_async_whole_k_prefetch` shows much better performance than the `qr_ks_vs_async` (which is supposed to be highly efficient) on the same case (execution time `64.50ms` by whole_k_prefetch & `80.20ms` by qr_ks_vs_async) 3. For attention shapes which lead to kM0=128, `qr_ks_vs_async_whole_k_prefetch_trload` shows a little bit better performance than `qr_ks_vs_async` on MI350 (execution time `104.50ms` by whole_k_prefetch_trload & `106.50ms` by qr_ks_vs_async). And they show completely on-par performance on MI300 ## Test/Verify 1. 
Use the ROCM xformers branch `test_whole_k_prefetch_n0loop` to test/verify qr_ks_vs_whole_k_prefetch pipeline since this pipeline cannot be used by ck_tile fmha example so far 2. Use the following command-line for building/testing xformers >```bash > #> git clone -b test_whole_k_prefetch_n0loop https://github.com/ROCm/xformers > #> git submodule update --init --recursive > #> pip install --no-build-isolation -e ./ > #> pytest tests/test_mem_eff_attention.py::test_forward >``` 3. Any scripts which can run on xformers can be used to evaluate qr_ks_vs_whole_k_prefetch pipeline. Use the two environment variables below to switch between different pipelines > ```bash > #> export FMHA_DISABLE_SPECIAL_TREATMENT=1 #> to disable using FAV3 and qr_ks_vs_async_trload pipeline > #> export FMHA_ENABLE_ASYNC_PIPELINE=1 #> to disable using qr_ks_vs_async pipeline for comparing > ``` ## Discussion
1 parent b2ea5fd commit 865ab2b

12 files changed

Lines changed: 2875 additions & 799 deletions

include/ck_tile/host/rotating_buffers.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "ck_tile/core/config.hpp"
77
#include "ck_tile/host/hip_check_error.hpp"
88
#include <hip/hip_runtime.h>
9+
#include <iostream>
910

1011
namespace ck_tile {
1112

include/ck_tile/ops/fmha.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
5757
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
5858
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
59+
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
5960
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
6061
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
6162
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"

include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

Lines changed: 159 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,83 @@
3232

3333
namespace ck_tile {
3434

35+
namespace detail {
36+
37+
// A helper struct for detecting n0loop
38+
template <typename T, typename = void>
39+
struct has_n0loop_flag : std::false_type
40+
{
41+
};
42+
43+
template <typename T>
44+
struct has_n0loop_flag<
45+
T,
46+
std::enable_if_t<std::is_convertible_v<decltype(T::kUseN0Loop), bool> && T::kUseN0Loop>>
47+
: std::true_type
48+
{
49+
};
50+
51+
template <typename T>
52+
static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag<T>::value;
53+
54+
// A helper struct for detecting ignore_fast_exp2 flag
55+
template <typename T, typename = void>
56+
struct has_ignore_fast_exp2_flag : std::false_type
57+
{
58+
};
59+
60+
// IgnoreFastExp2 is used by some pipeline which explicitly chooses not to use FAST_EXP2;
61+
// By detecting the kIgnoreFastExp2 from the pipeline, the kernel's MakeKargsImpl() interface
62+
// is able to avoid passing an incorrect scale_s parameter to the kernel layer
63+
template <typename T>
64+
struct has_ignore_fast_exp2_flag<
65+
T,
66+
std::enable_if_t<std::is_convertible_v<decltype(T::kIgnoreFastExp2), bool> &&
67+
T::kIgnoreFastExp2>> : std::true_type
68+
{
69+
};
70+
71+
template <typename T>
72+
static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag<T>::value;
73+
74+
// A helper struct for detecting naive_hdim_load, naive_hdim_load means load tiles of
75+
// hdim96/hdim160/hdim192 without padding the tensor_view/tile_window to hdim128/hdim256
76+
// naive_hdim_load is currently supported by the qr_ks_vs_whole_k_prefetch_pipeline
77+
template <typename T, typename = void>
78+
struct has_naive_hdim_load_flag : std::false_type
79+
{
80+
};
81+
82+
template <typename T>
83+
struct has_naive_hdim_load_flag<
84+
T,
85+
std::enable_if_t<std::is_convertible_v<decltype(T::kIsNaiveHDimLoad), bool> &&
86+
T::kIsNaiveHDimLoad>> : std::true_type
87+
{
88+
};
89+
90+
template <typename T>
91+
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
92+
93+
// A helper struct for detecting kUseTrLoad
94+
template <typename T, typename = void>
95+
struct has_use_trload_flag : std::false_type
96+
{
97+
};
98+
99+
template <typename T>
100+
struct has_use_trload_flag<
101+
T,
102+
std::enable_if_t<std::is_convertible_v<decltype(T::kUseTrLoad), bool> && T::kUseTrLoad>>
103+
: std::true_type
104+
{
105+
};
106+
107+
template <typename T>
108+
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
109+
110+
} // namespace detail
111+
35112
template <typename FmhaPipeline_, typename EpiloguePipeline_>
36113
struct FmhaFwdKernel
37114
{
@@ -74,13 +151,14 @@ struct FmhaFwdKernel
74151
static constexpr bool kHasMask = FmhaMask::IsMasking;
75152

76153
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
154+
static constexpr bool kUseTrLoad = detail::is_using_trload_v<FmhaPipeline>;
77155

78-
static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
79156
#if defined(__gfx950__)
80157
static constexpr bool kIsAvailable = true;
81158
#else
82159
static constexpr bool kIsAvailable = !kUseTrLoad;
83160
#endif
161+
84162
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
85163

86164
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
@@ -441,7 +519,9 @@ struct FmhaFwdKernel
441519
num_head_q,
442520
nhead_ratio_qk,
443521
#if CK_TILE_FMHA_FWD_FAST_EXP2
444-
static_cast<float>(scale_s * ck_tile::log2e_v<>),
522+
detail::ignore_fast_exp2_v<FmhaPipeline>
523+
? scale_s
524+
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
445525
#else
446526
scale_s,
447527
#endif
@@ -894,7 +974,9 @@ struct FmhaFwdKernel
894974
num_head_q,
895975
nhead_ratio_qk,
896976
#if CK_TILE_FMHA_FWD_FAST_EXP2
897-
static_cast<float>(scale_s * ck_tile::log2e_v<>),
977+
detail::ignore_fast_exp2_v<FmhaPipeline>
978+
? scale_s
979+
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
898980
#else
899981
scale_s,
900982
#endif
@@ -1036,6 +1118,7 @@ struct FmhaFwdKernel
10361118
const void* seqlen_k_ptr,
10371119
const void* block_scale_seqstart_q_ptr,
10381120
const void* block_scale_seqstart_k_ptr,
1121+
const void* seqstart_v_scale_ptr,
10391122
ck_tile::index_t hdim_q,
10401123
ck_tile::index_t hdim_v,
10411124
ck_tile::index_t num_head_q,
@@ -1094,6 +1177,7 @@ struct FmhaFwdKernel
10941177
seqlen_k_ptr,
10951178
block_scale_seqstart_q_ptr,
10961179
block_scale_seqstart_k_ptr,
1180+
seqstart_v_scale_ptr,
10971181
hdim_q,
10981182
hdim_v,
10991183
num_head_q,
@@ -1155,6 +1239,7 @@ struct FmhaFwdKernel
11551239
const void* seqlen_k_ptr,
11561240
const void* block_scale_seqstart_q_ptr,
11571241
const void* block_scale_seqstart_k_ptr,
1242+
const void* seqstart_v_scale_ptr,
11581243
ck_tile::index_t hdim_q,
11591244
ck_tile::index_t hdim_v,
11601245
ck_tile::index_t num_head_q,
@@ -1213,6 +1298,7 @@ struct FmhaFwdKernel
12131298
seqlen_k_ptr,
12141299
block_scale_seqstart_q_ptr,
12151300
block_scale_seqstart_k_ptr,
1301+
seqstart_v_scale_ptr,
12161302
hdim_q,
12171303
hdim_v,
12181304
num_head_q,
@@ -1599,6 +1685,10 @@ struct FmhaFwdKernel
15991685
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
16001686
batch_offset_o;
16011687

1688+
constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v<FmhaPipeline>
1689+
? FmhaPipeline::kQKHeaddim
1690+
: FmhaPipeline::kSubQKHeaddim;
1691+
16021692
// Q/K/V DRAM and DRAM window
16031693
const auto q_dram = [&]() {
16041694
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1609,10 +1699,10 @@ struct FmhaFwdKernel
16091699
number<1>{});
16101700
if constexpr(FmhaPipeline::kQLoadOnce)
16111701
{
1612-
return pad_tensor_view(q_dram_naive,
1613-
make_tuple(number<FmhaPipeline::kM0>{},
1614-
number<FmhaPipeline::kSubQKHeaddim>{}),
1615-
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
1702+
return pad_tensor_view(
1703+
q_dram_naive,
1704+
make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{}),
1705+
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
16161706
}
16171707
else
16181708
{
@@ -1631,10 +1721,21 @@ struct FmhaFwdKernel
16311721
number<1>{});
16321722

16331723
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1634-
return pad_tensor_view(
1635-
k_dram_naive,
1636-
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
1637-
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
1724+
1725+
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
1726+
{
1727+
return pad_tensor_view(
1728+
k_dram_naive,
1729+
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
1730+
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
1731+
}
1732+
else
1733+
{
1734+
return pad_tensor_view(
1735+
k_dram_naive,
1736+
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
1737+
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
1738+
}
16381739
}();
16391740
const auto v_dram = [&]() {
16401741
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1646,18 +1747,29 @@ struct FmhaFwdKernel
16461747
number<FmhaPipeline::kAlignmentV>{},
16471748
number<1>{});
16481749

1649-
const auto v_dram_transposed = transform_tensor_view(
1650-
v_dram_naive,
1651-
make_tuple(make_pass_through_transform(kargs.hdim_v),
1652-
make_pass_through_transform(kargs.seqlen_k)),
1653-
make_tuple(sequence<1>{}, sequence<0>{}),
1654-
make_tuple(sequence<0>{}, sequence<1>{}));
1750+
if constexpr(!kUseTrLoad)
1751+
{
1752+
const auto v_dram_transposed = transform_tensor_view(
1753+
v_dram_naive,
1754+
make_tuple(make_pass_through_transform(kargs.hdim_v),
1755+
make_pass_through_transform(kargs.seqlen_k)),
1756+
make_tuple(sequence<1>{}, sequence<0>{}),
1757+
make_tuple(sequence<0>{}, sequence<1>{}));
16551758

1656-
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1657-
return pad_tensor_view(
1658-
v_dram_transposed,
1659-
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
1660-
sequence<kPadHeadDimV, kPadSeqLenK_>{});
1759+
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1760+
1761+
return pad_tensor_view(
1762+
v_dram_transposed,
1763+
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
1764+
sequence<kPadHeadDimV, kPadSeqLenK_>{});
1765+
}
1766+
else
1767+
{
1768+
return pad_tensor_view(
1769+
v_dram_naive,
1770+
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
1771+
sequence<false, kPadHeadDimV>{});
1772+
};
16611773
}
16621774
else
16631775
{
@@ -1680,17 +1792,28 @@ struct FmhaFwdKernel
16801792
q_dram,
16811793
[&]() {
16821794
if constexpr(FmhaPipeline::kQLoadOnce)
1683-
return make_tuple(number<FmhaPipeline::kM0>{},
1684-
number<FmhaPipeline::kSubQKHeaddim>{});
1795+
return make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{});
16851796
else
16861797
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
16871798
}(),
16881799
{i_m0, 0});
16891800

1690-
auto k_dram_window = make_tile_window(
1691-
k_dram,
1692-
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
1693-
{0, 0});
1801+
auto k_dram_window = [&]() {
1802+
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
1803+
{
1804+
return make_tile_window(
1805+
k_dram,
1806+
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
1807+
{0, 0});
1808+
}
1809+
else
1810+
{
1811+
return make_tile_window(
1812+
k_dram,
1813+
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
1814+
{0, 0});
1815+
}
1816+
}();
16941817

16951818
auto v_dram_window = make_tile_window(
16961819
v_dram,
@@ -1840,7 +1963,10 @@ struct FmhaFwdKernel
18401963
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
18411964
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
18421965
#if CK_TILE_FMHA_FWD_FAST_EXP2
1843-
slope *= ck_tile::log2e_v<>;
1966+
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
1967+
{
1968+
slope *= ck_tile::log2e_v<>;
1969+
}
18441970
#endif
18451971
if constexpr(kHasMask)
18461972
{
@@ -2798,7 +2924,10 @@ struct FmhaFwdKernel
27982924
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
27992925
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
28002926
#if CK_TILE_FMHA_FWD_FAST_EXP2
2801-
slope *= ck_tile::log2e_v<>;
2927+
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
2928+
{
2929+
slope *= ck_tile::log2e_v<>;
2930+
}
28022931
#endif
28032932
if constexpr(kHasMask)
28042933
{

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,52 @@
99

1010
namespace ck_tile {
1111

12+
namespace detail {
13+
14+
template <typename DataType, index_t ElemPerThread>
15+
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
16+
{
17+
if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
18+
{
19+
// ToDo: need support in ck_tile for using buffer_load_dwordx3
20+
// if constexpr(ElemPerThread % 6 == 0)
21+
// return 6;
22+
if constexpr(ElemPerThread % 8 == 0)
23+
return 8;
24+
else if constexpr(ElemPerThread % 4 == 0)
25+
return 4;
26+
else if constexpr(ElemPerThread % 2 == 0)
27+
return 2;
28+
return 1;
29+
}
30+
else if constexpr(std::is_same_v<DataType, float>)
31+
{
32+
// ToDo: need support in ck_tile for using buffer_load_dwordx3
33+
// if constexpr(ElemPerThread % 3 == 0)
34+
// return 3;
35+
if constexpr(ElemPerThread % 4 == 0)
36+
return 4;
37+
else if constexpr(ElemPerThread % 2 == 0)
38+
return 2;
39+
return 1;
40+
}
41+
else
42+
return 1;
43+
};
44+
45+
template <typename DataType,
46+
index_t kThreadBlockSize,
47+
index_t kHigherDimSize,
48+
index_t kLowerDimSize>
49+
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
50+
{
51+
constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize;
52+
53+
return GetMaxVectorSize<DataType, ElemPerThread>();
54+
}
55+
56+
} // namespace detail
57+
1258
template <typename QDataType_,
1359
typename KDataType_,
1460
typename VDataType_,

0 commit comments

Comments
 (0)