Skip to content

Commit a8aebb7

Browse files
authored
Post-merge cleanup for WMMA grouped conv fwd (#3468)
* remove duplicate aliases * Split scaleadd_ab instances for WMMA grouped conv fwd * removed big shape from the test
1 parent 44f1b5c commit a8aebb7

13 files changed

Lines changed: 570 additions & 62 deletions

library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp

Lines changed: 124 additions & 16 deletions
Large diffs are not rendered by default.

library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,46 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_ins
9191
#ifdef CK_USE_WMMA
9292
#ifdef CK_ENABLE_BF16
9393
// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
94-
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
94+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
95+
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
96+
NDHWGC,
97+
GKZYXC,
98+
ck::Tuple<>,
99+
NDHWGK,
100+
ck::Tuple<BF16, BF16>,
101+
ck::Tuple<BF16, BF16>,
102+
ck::Tuple<>,
103+
BF16,
104+
ScaleAdd,
105+
ScaleAdd,
106+
PassThrough>>>& instances);
107+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
108+
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
109+
NDHWGC,
110+
GKZYXC,
111+
ck::Tuple<>,
112+
NDHWGK,
113+
ck::Tuple<BF16, BF16>,
114+
ck::Tuple<BF16, BF16>,
115+
ck::Tuple<>,
116+
BF16,
117+
ScaleAdd,
118+
ScaleAdd,
119+
PassThrough>>>& instances);
120+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
121+
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
122+
NDHWGC,
123+
GKZYXC,
124+
ck::Tuple<>,
125+
NDHWGK,
126+
ck::Tuple<BF16, BF16>,
127+
ck::Tuple<BF16, BF16>,
128+
ck::Tuple<>,
129+
BF16,
130+
ScaleAdd,
131+
ScaleAdd,
132+
PassThrough>>>& instances);
133+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
95134
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
96135
NDHWGC,
97136
GKZYXC,
@@ -107,7 +146,46 @@ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndh
107146
#endif
108147

109148
#ifdef CK_ENABLE_FP16
110-
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
149+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
150+
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
151+
NDHWGC,
152+
GKZYXC,
153+
ck::Tuple<>,
154+
NDHWGK,
155+
ck::Tuple<F16, F16>,
156+
ck::Tuple<F16, F16>,
157+
ck::Tuple<>,
158+
F16,
159+
ScaleAdd,
160+
ScaleAdd,
161+
PassThrough>>>& instances);
162+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
163+
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
164+
NDHWGC,
165+
GKZYXC,
166+
ck::Tuple<>,
167+
NDHWGK,
168+
ck::Tuple<F16, F16>,
169+
ck::Tuple<F16, F16>,
170+
ck::Tuple<>,
171+
F16,
172+
ScaleAdd,
173+
ScaleAdd,
174+
PassThrough>>>& instances);
175+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
176+
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
177+
NDHWGC,
178+
GKZYXC,
179+
ck::Tuple<>,
180+
NDHWGK,
181+
ck::Tuple<F16, F16>,
182+
ck::Tuple<F16, F16>,
183+
ck::Tuple<>,
184+
F16,
185+
ScaleAdd,
186+
ScaleAdd,
187+
PassThrough>>>& instances);
188+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
111189
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
112190
NDHWGC,
113191
GKZYXC,
@@ -218,7 +296,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
218296
is_same_v<WeiDataType, ck::Tuple<half_t, half_t>> &&
219297
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
220298
{
221-
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
299+
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
300+
op_ptrs);
301+
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
302+
op_ptrs);
303+
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
304+
op_ptrs);
305+
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
222306
op_ptrs);
223307
}
224308
#endif
@@ -227,7 +311,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
227311
is_same_v<WeiDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
228312
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, ck::bhalf_t>)
229313
{
230-
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
314+
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
315+
op_ptrs);
316+
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
317+
op_ptrs);
318+
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
319+
op_ptrs);
320+
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
231321
op_ptrs);
232322
}
233323
#endif

library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@ set(GROUPED_CONV3D_FWD_SCALEADD_AB
99
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
1010

1111
# WMMA CSHUFFLE V3
12-
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
13-
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
12+
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp
13+
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp
14+
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp
15+
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp
16+
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp
17+
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp
18+
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp
19+
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp
1420
)
1521

1622
add_instance_library(device_grouped_conv3d_fwd_scaleadd_ab_instance ${GROUPED_CONV3D_FWD_SCALEADD_AB})

library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp renamed to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace tensor_operation {
99
namespace device {
1010
namespace instance {
1111

12-
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
12+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
1313
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
1414
NDHWGC,
1515
GKZYXC,
@@ -25,25 +25,25 @@ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndh
2525
{
2626
add_device_operation_instances(
2727
instances,
28-
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
29-
NDHWGC,
30-
GKZYXC,
31-
NDHWGK,
32-
ConvFwdDefault>{});
28+
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part1<3,
29+
NDHWGC,
30+
GKZYXC,
31+
NDHWGK,
32+
ConvFwdDefault>{});
3333
add_device_operation_instances(
3434
instances,
35-
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
36-
NDHWGC,
37-
GKZYXC,
38-
NDHWGK,
39-
ConvFwd1x1P0>{});
35+
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part1<3,
36+
NDHWGC,
37+
GKZYXC,
38+
NDHWGK,
39+
ConvFwd1x1P0>{});
4040
add_device_operation_instances(
4141
instances,
42-
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
43-
NDHWGC,
44-
GKZYXC,
45-
NDHWGK,
46-
ConvFwd1x1S1P0>{});
42+
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part1<3,
43+
NDHWGC,
44+
GKZYXC,
45+
NDHWGK,
46+
ConvFwd1x1S1P0>{});
4747
}
4848

4949
} // namespace instance
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
5+
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp"
6+
7+
namespace ck {
8+
namespace tensor_operation {
9+
namespace device {
10+
namespace instance {
11+
12+
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
13+
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
14+
NDHWGC,
15+
GKZYXC,
16+
ck::Tuple<>,
17+
NDHWGK,
18+
ck::Tuple<BF16, BF16>,
19+
ck::Tuple<BF16, BF16>,
20+
ck::Tuple<>,
21+
BF16,
22+
ScaleAdd,
23+
ScaleAdd,
24+
PassThrough>>>& instances)
25+
{
26+
add_device_operation_instances(
27+
instances,
28+
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part2<3,
29+
NDHWGC,
30+
GKZYXC,
31+
NDHWGK,
32+
ConvFwdDefault>{});
33+
add_device_operation_instances(
34+
instances,
35+
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part2<3,
36+
NDHWGC,
37+
GKZYXC,
38+
NDHWGK,
39+
ConvFwd1x1P0>{});
40+
add_device_operation_instances(
41+
instances,
42+
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part2<3,
43+
NDHWGC,
44+
GKZYXC,
45+
NDHWGK,
46+
ConvFwd1x1S1P0>{});
47+
}
48+
49+
} // namespace instance
50+
} // namespace device
51+
} // namespace tensor_operation
52+
} // namespace ck

0 commit comments

Comments
 (0)