Skip to content

Commit 2203b0d

Browse files
EnricoDeg and bartekxk authored
Add padding to 1x1Stride1Pad0 conv specialization (grouped conv bwd weight) (#2610)
* Add padding 1x1Stride1Pad0 conv specialization
* Add gridwise checks for conv cshufflev3
* Merge padding with previous transforms
* Apply transform changes for padding to default specialization as well

---------

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
1 parent cbfecf8 commit 2203b0d

5 files changed

Lines changed: 290 additions & 168 deletions

File tree

include/ck/ck.hpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -222,9 +222,6 @@
222222
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
223223
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
224224

225-
// workaround: conv crash when K, C is even
226-
#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1
227-
228225
// workaround: compiler crash when compiling recursive lambda
229226
#define CK_WORKAROUND_SWDEV_275126 1
230227

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -331,9 +331,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
331331
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
332332

333333
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
334-
tensor_layout::gemm::RowMajor,
335334
tensor_layout::gemm::ColumnMajor,
336335
tensor_layout::gemm::RowMajor,
336+
tensor_layout::gemm::RowMajor,
337337
ADataType,
338338
BDataType,
339339
AccDataType,
@@ -1299,13 +1299,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
12991299
if constexpr(ConvBackwardWeightSpecialization ==
13001300
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
13011301
{
1302-
// workaround: disable when K, C is even
1303-
#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN
1304-
if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0)
1305-
{
1306-
return false;
1307-
}
1308-
#endif
13091302
// check if it's 1x1, stride=1 pad = 0 conv
13101303
for(int i = 0; i < NDimSpatial; i++)
13111304
{
@@ -1330,7 +1323,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
13301323
}
13311324

13321325
// Gridwise GEMM size
1333-
return true;
1326+
return GridwiseGemm::CheckValidity(gemm_arg);
13341327
}
13351328

13361329
bool IsSupportedArgument(const BaseArgument* p_arg) override

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp

Lines changed: 198 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#pragma once
55

66
#include "ck/utility/common_header.hpp"
7+
#include "ck/utility/env.hpp"
78
#include "ck/tensor_description/multi_index_transform_helper.hpp"
89
#include "ck/tensor_description/tensor_descriptor.hpp"
910
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -606,6 +607,203 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
606607
c_block_size * sizeof(CShuffleDataType));
607608
}
608609

610+
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
611+
__host__ static constexpr bool CheckValidity(const Argument& karg)
612+
{
613+
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
614+
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
615+
"Invalid tuning param!");
616+
617+
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
618+
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
619+
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
620+
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
621+
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
622+
{
623+
if(!(karg.M % MPerBlock == 0))
624+
{
625+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
626+
{
627+
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
628+
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
629+
<< std::endl;
630+
}
631+
return false;
632+
}
633+
}
634+
635+
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
636+
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
637+
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
638+
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
639+
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
640+
{
641+
if(!(karg.N % NPerBlock == 0))
642+
{
643+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
644+
{
645+
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
646+
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
647+
<< std::endl;
648+
}
649+
return false;
650+
}
651+
}
652+
653+
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
654+
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
655+
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
656+
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
657+
{
658+
659+
auto K_t = karg.KBatch * KPerBlock;
660+
if(!(karg.K % K_t == 0))
661+
{
662+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
663+
{
664+
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
665+
<< karg.K << " " << __FILE__ << ":" << __LINE__
666+
<< ", in function: " << __func__ << std::endl;
667+
}
668+
return false;
669+
}
670+
}
671+
else
672+
{
673+
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
674+
auto K_t = karg.KBatch * KReadVec;
675+
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
676+
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
677+
{
678+
return false;
679+
}
680+
}
681+
682+
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
683+
{
684+
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
685+
{
686+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
687+
{
688+
std::cout << "Arg K (" << karg.K
689+
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
690+
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
691+
<< __LINE__ << ", in function: " << __func__ << std::endl;
692+
}
693+
return false;
694+
}
695+
}
696+
else
697+
{
698+
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
699+
{
700+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
701+
{
702+
std::cout << "Arg M (" << karg.M
703+
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
704+
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
705+
<< __LINE__ << ", in function: " << __func__ << std::endl;
706+
}
707+
return false;
708+
}
709+
}
710+
711+
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
712+
{
713+
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
714+
{
715+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
716+
{
717+
std::cout << "Arg N (" << karg.N
718+
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
719+
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
720+
<< __LINE__ << ", in function: " << __func__ << std::endl;
721+
}
722+
return false;
723+
}
724+
}
725+
else
726+
{
727+
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
728+
{
729+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
730+
{
731+
std::cout << "Arg K (" << karg.K
732+
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
733+
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
734+
<< __LINE__ << ", in function: " << __func__ << std::endl;
735+
}
736+
return false;
737+
}
738+
}
739+
740+
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
741+
{
742+
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
743+
{
744+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
745+
{
746+
std::cout << "Arg N (" << karg.N
747+
<< ") value is not a multiple of "
748+
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
749+
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
750+
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
751+
<< std::endl;
752+
}
753+
return false;
754+
}
755+
}
756+
else
757+
{
758+
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
759+
{
760+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
761+
{
762+
std::cout << "Arg M (" << karg.M
763+
<< ") value is not a multiple of "
764+
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
765+
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
766+
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
767+
<< std::endl;
768+
}
769+
return false;
770+
}
771+
}
772+
773+
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
774+
is_same<remove_cvref_t<CDataType>, float>::value ||
775+
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
776+
is_same<remove_cvref_t<CDataType>, int32_t>::value))
777+
{
778+
if(!karg.IsReduceAdd())
779+
{
780+
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
781+
{
782+
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
783+
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
784+
}
785+
if(karg.KBatch > 1)
786+
{
787+
return false;
788+
}
789+
}
790+
}
791+
792+
// check gridwise gemm pipeline
793+
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
794+
795+
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
796+
{
797+
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
798+
{
799+
return false;
800+
}
801+
}
802+
803+
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
804+
return true;
805+
}
806+
609807
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
610808
{
611809
const index_t num_loop = K / KPerBlock;

0 commit comments

Comments
 (0)