@@ -14,56 +14,6 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
1414{
1515 using BasePolicy = UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>;
1616
17- // 3d + padding
18- template <typename Problem>
19- CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor ()
20- {
21- using namespace ck_tile ;
22- constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM ;
23- constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK ;
24- constexpr index_t kKPack = GetSmemPackA<Problem>();
25- using ADataType = remove_cvref_t <typename Problem::ADataType>;
26-
27- constexpr auto DataTypeSize = sizeof (ADataType);
28- constexpr auto MLdsLayer =
29- (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
30-
31- constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor (
32- make_tuple (number<kKPerBlock / kKPack * MLdsLayer>{},
33- number<kMPerBlock / MLdsLayer>{},
34- number<kKPack >{}),
35- make_tuple (number<kKPack >{}, number<kKPerBlock * MLdsLayer>{}, number<1 >{}),
36- number<kKPack >{},
37- number<1 >{});
38-
39- constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor (
40- a_lds_block_desc_0,
41- make_tuple (make_xor_transform (make_tuple (number<kMPerBlock / MLdsLayer>{},
42- number<kKPerBlock / kKPack * MLdsLayer>{})),
43- make_pass_through_transform (number<kKPack >{})),
44- make_tuple (sequence<1 , 0 >{}, sequence<2 >{}),
45- make_tuple (sequence<1 , 0 >{}, sequence<2 >{}));
46-
47- constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor (
48- a_lds_block_desc_permuted,
49- make_tuple (make_unmerge_transform (
50- make_tuple (number<MLdsLayer>{}, number<kKPerBlock / kKPack >{})),
51- make_pass_through_transform (number<kMPerBlock / MLdsLayer>{}),
52- make_pass_through_transform (number<kKPack >{})),
53- make_tuple (sequence<0 >{}, sequence<1 >{}, sequence<2 >{}),
54- make_tuple (sequence<0 , 2 >{}, sequence<1 >{}, sequence<3 >{}));
55-
56- constexpr auto a_lds_block_desc = transform_tensor_descriptor (
57- a_lds_block_desc_xk0_mnldslayer_mn_xk1,
58- make_tuple (
59- make_merge_transform (
60- make_tuple (number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
61- make_merge_transform (make_tuple (number<kKPerBlock / kKPack >{}, number<kKPack >{}))),
62- make_tuple (sequence<1 , 0 >{}, sequence<2 , 3 >{}),
63- make_tuple (sequence<0 >{}, sequence<1 >{}));
64- return a_lds_block_desc;
65- }
66-
6717 template <typename Problem>
6818 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA ()
6919 {
@@ -291,6 +241,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
291241 }
292242 }
293243
244+ template <typename Problem>
245+ CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm ()
246+ {
247+ return GetBlockWeightPreshuffle<Problem>();
248+ }
249+
294250 template <typename Problem>
295251 CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle ()
296252 {
0 commit comments