@@ -91,7 +91,46 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_ins
9191#ifdef CK_USE_WMMA
9292#ifdef CK_ENABLE_BF16
9393// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
94- void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances (
94+ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1 (
95+ std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3 ,
96+ NDHWGC ,
97+ GKZYXC ,
98+ ck::Tuple<>,
99+ NDHWGK ,
100+ ck::Tuple<BF16 , BF16 >,
101+ ck::Tuple<BF16 , BF16 >,
102+ ck::Tuple<>,
103+ BF16 ,
104+ ScaleAdd,
105+ ScaleAdd,
106+ PassThrough>>>& instances);
107+ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2 (
108+ std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3 ,
109+ NDHWGC ,
110+ GKZYXC ,
111+ ck::Tuple<>,
112+ NDHWGK ,
113+ ck::Tuple<BF16 , BF16 >,
114+ ck::Tuple<BF16 , BF16 >,
115+ ck::Tuple<>,
116+ BF16 ,
117+ ScaleAdd,
118+ ScaleAdd,
119+ PassThrough>>>& instances);
120+ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3 (
121+ std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3 ,
122+ NDHWGC ,
123+ GKZYXC ,
124+ ck::Tuple<>,
125+ NDHWGK ,
126+ ck::Tuple<BF16 , BF16 >,
127+ ck::Tuple<BF16 , BF16 >,
128+ ck::Tuple<>,
129+ BF16 ,
130+ ScaleAdd,
131+ ScaleAdd,
132+ PassThrough>>>& instances);
133+ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4 (
95134 std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3 ,
96135 NDHWGC ,
97136 GKZYXC ,
@@ -107,7 +146,46 @@ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndh
107146#endif
108147
109148#ifdef CK_ENABLE_FP16
110- void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances (
149+ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1 (
150+ std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3 ,
151+ NDHWGC ,
152+ GKZYXC ,
153+ ck::Tuple<>,
154+ NDHWGK ,
155+ ck::Tuple<F16 , F16 >,
156+ ck::Tuple<F16 , F16 >,
157+ ck::Tuple<>,
158+ F16 ,
159+ ScaleAdd,
160+ ScaleAdd,
161+ PassThrough>>>& instances);
162+ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2 (
163+ std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3 ,
164+ NDHWGC ,
165+ GKZYXC ,
166+ ck::Tuple<>,
167+ NDHWGK ,
168+ ck::Tuple<F16 , F16 >,
169+ ck::Tuple<F16 , F16 >,
170+ ck::Tuple<>,
171+ F16 ,
172+ ScaleAdd,
173+ ScaleAdd,
174+ PassThrough>>>& instances);
175+ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3 (
176+ std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3 ,
177+ NDHWGC ,
178+ GKZYXC ,
179+ ck::Tuple<>,
180+ NDHWGK ,
181+ ck::Tuple<F16 , F16 >,
182+ ck::Tuple<F16 , F16 >,
183+ ck::Tuple<>,
184+ F16 ,
185+ ScaleAdd,
186+ ScaleAdd,
187+ PassThrough>>>& instances);
188+ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4 (
111189 std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3 ,
112190 NDHWGC ,
113191 GKZYXC ,
@@ -218,7 +296,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
218296 is_same_v<WeiDataType, ck::Tuple<half_t , half_t >> &&
219297 is_same_v<OutDataType, half_t > && is_same_v<ComputeType, half_t >)
220298 {
221- add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances (
299+ add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1 (
300+ op_ptrs);
301+ add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2 (
302+ op_ptrs);
303+ add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3 (
304+ op_ptrs);
305+ add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4 (
222306 op_ptrs);
223307 }
224308#endif
@@ -227,7 +311,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
227311 is_same_v<WeiDataType, ck::Tuple<ck::bhalf_t , ck::bhalf_t >> &&
228312 is_same_v<OutDataType, ck::bhalf_t > && is_same_v<ComputeType, ck::bhalf_t >)
229313 {
230- add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances (
314+ add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1 (
315+ op_ptrs);
316+ add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2 (
317+ op_ptrs);
318+ add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3 (
319+ op_ptrs);
320+ add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4 (
231321 op_ptrs);
232322 }
233323#endif
0 commit comments