Skip to content

Commit 5aaa031

Browse files
AviralGoelAMD and Thomas Ning authored
WIP: extract MakeALdsDescriptor() from child to parent class for code readability (#3392)
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
1 parent e809861 commit 5aaa031

2 files changed

Lines changed: 7 additions & 51 deletions

File tree

include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ struct UniversalGemmBasePolicy
112112
using ADataType = OverrideADataType;
113113
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
114114
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
115-
constexpr index_t KPack = GetSmemPackA<Problem>();
115+
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
116116

117117
if constexpr(is_a_load_tr<Problem>)
118118
{

include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp

Lines changed: 6 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -14,56 +14,6 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
1414
{
1515
using BasePolicy = UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>;
1616

17-
// 3d + padding
18-
template <typename Problem>
19-
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
20-
{
21-
using namespace ck_tile;
22-
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
23-
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
24-
constexpr index_t kKPack = GetSmemPackA<Problem>();
25-
using ADataType = remove_cvref_t<typename Problem::ADataType>;
26-
27-
constexpr auto DataTypeSize = sizeof(ADataType);
28-
constexpr auto MLdsLayer =
29-
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
30-
31-
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
32-
make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
33-
number<kMPerBlock / MLdsLayer>{},
34-
number<kKPack>{}),
35-
make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
36-
number<kKPack>{},
37-
number<1>{});
38-
39-
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
40-
a_lds_block_desc_0,
41-
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
42-
number<kKPerBlock / kKPack * MLdsLayer>{})),
43-
make_pass_through_transform(number<kKPack>{})),
44-
make_tuple(sequence<1, 0>{}, sequence<2>{}),
45-
make_tuple(sequence<1, 0>{}, sequence<2>{}));
46-
47-
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
48-
a_lds_block_desc_permuted,
49-
make_tuple(make_unmerge_transform(
50-
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
51-
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
52-
make_pass_through_transform(number<kKPack>{})),
53-
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
54-
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
55-
56-
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
57-
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
58-
make_tuple(
59-
make_merge_transform(
60-
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
61-
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
62-
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
63-
make_tuple(sequence<0>{}, sequence<1>{}));
64-
return a_lds_block_desc;
65-
}
66-
6717
template <typename Problem>
6818
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
6919
{
@@ -291,6 +241,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
291241
}
292242
}
293243

244+
template <typename Problem>
245+
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
246+
{
247+
return GetBlockWeightPreshuffle<Problem>();
248+
}
249+
294250
template <typename Problem>
295251
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle()
296252
{

0 commit comments

Comments (0)