Skip to content

Commit 2955d77

Browse files
authored
Fix grouped conv fwd wmma porting (#3479)
* Fix grouped conv fwd wmma porting
* add more limitations
1 parent a8aebb7 commit 2955d77

3 files changed

Lines changed: 30 additions & 6 deletions

File tree

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -327,8 +327,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
327327
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
328328
GET_MXDL_PER_WAVE_IMPL
329329
// Force usage of 16x16 instruction for WMMA
330-
static constexpr index_t Wave32MaxMNPerXDL = 16;
331-
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
330+
static constexpr bool Wave32Force16MNPerXDL =
331+
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
332+
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
333+
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
334+
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
335+
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
336+
static constexpr index_t Wave32MaxMNPerXDL =
337+
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
338+
339+
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
332340
static constexpr auto MXdlPerWave32 =
333341
GetMXdlPerWave<false,
334342
Wave32MaxMNPerXDL,

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -402,8 +402,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
402402
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
403403
GET_MXDL_PER_WAVE_IMPL
404404
// Force usage of 16x16 instruction for WMMA
405-
static constexpr index_t Wave32MaxMNPerXDL = 16;
406-
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
405+
static constexpr bool Wave32Force16MNPerXDL =
406+
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
407+
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
408+
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
409+
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
410+
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
411+
static constexpr index_t Wave32MaxMNPerXDL =
412+
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
413+
414+
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
407415
static constexpr auto MXdlPerWave32 =
408416
GetMXdlPerWave<false,
409417
Wave32MaxMNPerXDL,

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -208,8 +208,16 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
208208
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
209209
GET_MXDL_PER_WAVE_IMPL
210210
// Force usage of 16x16 instruction for WMMA
211-
static constexpr index_t Wave32MaxMNPerXDL = 16;
212-
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
211+
static constexpr bool Wave32Force16MNPerXDL =
212+
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
213+
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
214+
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
215+
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
216+
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
217+
static constexpr index_t Wave32MaxMNPerXDL =
218+
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
219+
220+
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
213221
static constexpr auto MXdlPerWave32 =
214222
GetMXdlPerWave<false,
215223
Wave32MaxMNPerXDL,

0 commit comments

Comments
 (0)