Skip to content

Commit 945849b

Browse files
bartekxkassistant-librarian[bot]
authored and committed
[rocm-libraries] ROCm/rocm-libraries#6838 (commit ff7a665)
[CK_TILE] Add depthwise conv2d forward kernel (FP16/FP32) (#6838) ## Motivation CK currently has no kernel optimized for depthwise convolution (G=C_in=C_out, C=K=1 per group) and the existing generic paths perform poorly for this workload. This PR adds a dedicated depthwise conv forward kernel in CK Tile. ## Technical Details Adds a dedicated depthwise conv2d forward op to CK Tile that performs direct convolution rather than falling back to the generic GEMM path. The kernel is templatized by filter size, stride, and data type, and compiled into ~60 instances covering common configurations (kernel 3/5/7/9, stride 1/2, FP16/FP32). Supports both CDNA (gfx942/gfx950) and RDNA (gfx1100/gfx1200) architectures. ## Test Plan - [x] Correctness and performance validated on gfx942, gfx950, and gfx1100, with ckProfiler `grouped_conv_fwd` as baseline. - [ ] MI300A (gfx942) and gfx1200 validation. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-1137
1 parent fe2e29f commit 945849b

22 files changed

Lines changed: 1997 additions & 316 deletions

experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,27 @@ concept TileOptimizationsDescriptor = requires(T t) {
158158
{ t.two_stage } -> std::convertible_to<bool>;
159159
};
160160

161+
// Concept to check if struct specifies depthwise convolution tile parameters.
// Satisfied when an object `t` of type T exposes every member listed below and
// each member expression is convertible to int. Only member presence and
// convertibility are checked; no value ranges and no relationships between the
// members are validated by this concept.
// NOTE(review): the member meanings are inferred from their names only
// (thread-block size, tile/sub-tile extents, filter/stride/dilation/padding per
// spatial axis, batch count, input/output vector widths) -- confirm against the
// traits type that consumes these values.
162+
template <typename T>
163+
concept DepthwiseConvParamsDescriptor = requires(T t) {
164+
{ t.block_size } -> std::convertible_to<int>;
165+
{ t.tile_h } -> std::convertible_to<int>;
166+
{ t.tile_w } -> std::convertible_to<int>;
167+
{ t.filter_h } -> std::convertible_to<int>;
168+
{ t.filter_w } -> std::convertible_to<int>;
169+
{ t.stride_h } -> std::convertible_to<int>;
170+
{ t.stride_w } -> std::convertible_to<int>;
171+
{ t.dilation_h } -> std::convertible_to<int>;
172+
{ t.dilation_w } -> std::convertible_to<int>;
173+
{ t.pad_h } -> std::convertible_to<int>;
174+
{ t.pad_w } -> std::convertible_to<int>;
175+
{ t.nbatch } -> std::convertible_to<int>;
176+
{ t.subtile_h } -> std::convertible_to<int>;
177+
{ t.subtile_w } -> std::convertible_to<int>;
178+
{ t.in_vec } -> std::convertible_to<int>;
179+
{ t.out_vec } -> std::convertible_to<int>;
180+
};
181+
161182
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
162183
// concept.
163184
template <typename T>
@@ -299,6 +320,27 @@ concept SpecifiesTileOptimizations = requires {
299320
{ T::optimizations.two_stage } -> std::convertible_to<bool>;
300321
};
301322

323+
// Concept to check if struct specifies depthwise convolution tile parameters.
// Unlike DepthwiseConvParamsDescriptor, which inspects the parameter struct
// itself, this concept inspects an algorithm type T: it is satisfied when
// `T::depthwise_params` names a member whose sub-members listed below are each
// convertible to int. The sub-member list is the same sixteen names that
// DepthwiseConvParamsDescriptor requires.
// Only presence and convertibility are checked; values are not validated.
324+
template <typename T>
325+
concept SpecifiesDepthwiseConvParams = requires {
326+
{ T::depthwise_params.block_size } -> std::convertible_to<int>;
327+
{ T::depthwise_params.tile_h } -> std::convertible_to<int>;
328+
{ T::depthwise_params.tile_w } -> std::convertible_to<int>;
329+
{ T::depthwise_params.filter_h } -> std::convertible_to<int>;
330+
{ T::depthwise_params.filter_w } -> std::convertible_to<int>;
331+
{ T::depthwise_params.stride_h } -> std::convertible_to<int>;
332+
{ T::depthwise_params.stride_w } -> std::convertible_to<int>;
333+
{ T::depthwise_params.dilation_h } -> std::convertible_to<int>;
334+
{ T::depthwise_params.dilation_w } -> std::convertible_to<int>;
335+
{ T::depthwise_params.pad_h } -> std::convertible_to<int>;
336+
{ T::depthwise_params.pad_w } -> std::convertible_to<int>;
337+
{ T::depthwise_params.nbatch } -> std::convertible_to<int>;
338+
{ T::depthwise_params.subtile_h } -> std::convertible_to<int>;
339+
{ T::depthwise_params.subtile_w } -> std::convertible_to<int>;
340+
{ T::depthwise_params.in_vec } -> std::convertible_to<int>;
341+
{ T::depthwise_params.out_vec } -> std::convertible_to<int>;
342+
};
343+
302344
template <typename T>
303345
concept SpecifiesTileConvSpecialization = requires {
304346
{ T::specialization } -> std::convertible_to<TileConvSpecialization>;

experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ concept TileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T
6363
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
6464
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
6565

66+
// Depthwise tile-based algorithm concept (no GEMM — direct spatial pipeline)
// Satisfied by any T that meets ConvAlgorithmDescriptor and additionally
// provides `depthwise_params` as checked by SpecifiesDepthwiseConvParams.
// None of the SpecifiesTile* requirements used by TileAlgorithm (transfer,
// conv specialization, block GEMM, optimizations) are required here.
67+
template <typename T>
68+
concept DepthwiseAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesDepthwiseConvParams<T>;
69+
6670
// FWD XDL algorithm concepts
6771
template <typename T>
6872
concept FwdXdlAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesGenericInstance<T>;
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#pragma once
5+
6+
#include "ck_tile/core.hpp"
7+
#include "ck_tile/host/kernel_launch.hpp"
8+
#include "ck_tile/ops/epilogue.hpp"
9+
#include "ck_tile/ops/grouped_convolution.hpp"
10+
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
11+
#include "ck_tile/builder/conv_signature_concepts.hpp"
12+
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
13+
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
14+
15+
namespace ck_tile::builder::factory {
16+
17+
// Factory for CK Tile depthwise grouped convolution kernels.
18+
// Instantiates GroupedConvolutionForwardKernel with DepthwiseConvFwdPipeline.
//
// Template parameters:
//   SIGNATURE - compile-time convolution signature; `spatial_dim` and
//               `data_type` are read from it.
//   ALGORITHM - compile-time algorithm descriptor; `depthwise_params` is read
//               from it.
//   VERSION   - string literal; not referenced anywhere in this struct's body.
//
// The resulting kernel type is exposed as the nested alias `Instance`.
//
// NOTE(review): ALGORITHM is constrained only by ConvAlgorithmDescriptor, yet
// the body reads ALGORITHM.depthwise_params. An algorithm without that member
// would fail inside this struct rather than at the template constraint; consider
// also requiring SpecifiesDepthwiseConvParams for the algorithm's type.
19+
template <ConvSignatureDescriptor auto SIGNATURE,
20+
ConvAlgorithmDescriptor auto ALGORITHM,
21+
StringLiteral VERSION>
22+
struct ConvDepthwiseTileFactory
23+
{
24+
// Number of spatial dimensions, taken directly from the signature.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
25+
// Bundle of tensor element types selected by the signature's data_type.
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
26+
27+
// Compile-time copy of the algorithm's depthwise parameter struct; every DW.*
// value used below comes from here.
static constexpr auto DW = ALGORITHM.depthwise_params;
28+
29+
// Element types for input, weight, accumulator and output, mapped from the
// A/B/Acc/E aliases of Types.
using InDataType = typename Types::ADataType;
30+
using WeiDataType = typename Types::BDataType;
31+
using AccDataType = typename Types::AccDataType;
32+
using OutDataType = typename Types::EDataType;
33+
34+
// Depthwise traits: the four element types followed by the sixteen DW values,
// passed positionally in the order shown.
using DwTraits = ck_tile::DepthwiseConvFwdTraits<InDataType,
35+
WeiDataType,
36+
AccDataType,
37+
OutDataType,
38+
DW.block_size,
39+
DW.tile_h,
40+
DW.tile_w,
41+
DW.filter_h,
42+
DW.filter_w,
43+
DW.stride_h,
44+
DW.stride_w,
45+
DW.dilation_h,
46+
DW.dilation_w,
47+
DW.pad_h,
48+
DW.pad_w,
49+
DW.nbatch,
50+
DW.subtile_h,
51+
DW.subtile_w,
52+
DW.in_vec,
53+
DW.out_vec>;
54+
55+
// Pipeline specialised on the traits above.
using DwPipeline = ck_tile::DepthwiseConvFwdPipeline<DwTraits>;
56+
57+
// Grouped-convolution traits handed to the kernel. Arguments are positional
// and GroupedConvTraits' parameter list is not visible in this file.
// NOTE(review): the void / tuple<> arguments are presumably layout and
// D-tensor placeholders, and the in_vec, in_vec, out_vec triple presumably
// maps to input/weight/output vector sizes -- confirm against
// GroupedConvTraits, including that the weight slot should reuse in_vec.
using ConvTraitsType = ck_tile::GroupedConvTraits<SPATIAL_DIM,
58+
ck_tile::ConvolutionSpecialization::Default,
59+
void,
60+
void,
61+
ck_tile::tuple<>,
62+
void,
63+
DW.in_vec,
64+
DW.in_vec,
65+
DW.out_vec,
66+
1,
67+
false,
68+
false,
69+
DwTraits>;
70+
71+
// Epilogue type passed to the kernel. It declares only type aliases: empty
// D-tensor layout/type tuples, the output type, an accumulator type and a
// PassThrough element-wise op.
// NOTE(review): AccDataType below is hard-coded to float and hides the
// factory-level AccDataType alias inside this struct -- confirm this is
// intentional rather than meant to follow Types::AccDataType.
struct DepthwiseNullEpilogue
72+
{
73+
using DsLayout = ck_tile::tuple<>;
74+
using DsDataType = ck_tile::tuple<>;
75+
using ODataType = OutDataType;
76+
using AccDataType = float;
77+
using CDElementwise = ck_tile::element_wise::PassThrough;
78+
};
79+
80+
// Final kernel type produced by this factory. The second template argument
// is void.
// NOTE(review): that slot presumably holds a tile partitioner or GEMM
// pipeline that the depthwise path does not use -- confirm against
// GroupedConvolutionForwardKernel.
using Instance = ck_tile::
81+
GroupedConvolutionForwardKernel<ConvTraitsType, void, DwPipeline, DepthwiseNullEpilogue>;
82+
};
83+
84+
} // namespace ck_tile::builder::factory

experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
7070
#include "ck_tile/builder/factory/reference_factory.hpp"
7171
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
72+
#include "ck_tile/builder/factory/conv_depthwise_tile_factory.hpp"
7273
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp"
7374
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp"
7475
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp"
@@ -115,6 +116,11 @@ constexpr auto make_conv_instance()
115116
{
116117
return typename ReferenceFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
117118
}
119+
// Depthwise tile algorithm — direct spatial pipeline, no GEMM
120+
else if constexpr(DepthwiseAlgorithm<AlgoType>)
121+
{
122+
return typename ConvDepthwiseTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
123+
}
118124
// CK Tile supports common factory for each direction
119125
else if constexpr(TileAlgorithm<AlgoType>)
120126
{

0 commit comments

Comments
 (0)