Skip to content

Commit 4ce7d4c

Browse files
[ck_builder] add utility functions to convolution (#3459)
* reinstate conv_signature_utils.hpp * added tests for elementwise operation getters * add tests for getDataType functions * added test for no data type specified --------- Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
1 parent ead81d1 commit 4ce7d4c

3 files changed

Lines changed: 300 additions & 4 deletions

File tree

experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,24 @@ concept ConvOutputLayout3D =
8080
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
8181
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);
8282

83+
template <typename T>
84+
concept HasDataType = requires(T t) {
85+
{ t.data_type };
86+
};
87+
88+
// Note: for signature and TensorConfigDescriptor,
89+
// it is not required to provide a default data type, but if one is provided, check if well defined
90+
template <typename T>
91+
concept DataTypeWellDefinedIfProvided = requires(T t) {
92+
requires !HasDataType<T> || requires {
93+
{ t.data_type } -> std::convertible_to<DataType>;
94+
};
95+
};
96+
8397
template <typename T>
8498
concept TensorConfigDescriptor = requires(T t) {
8599
{ t.layout } -> std::convertible_to<TensorLayout>;
86-
// Only require that data type is defined. It might be set to undefined value, in which case the
87-
// signature's data type is used.
88-
{ t.data_type } -> std::convertible_to<DataType>;
100+
requires DataTypeWellDefinedIfProvided<T>;
89101
};
90102

91103
template <typename T>
@@ -164,11 +176,11 @@ concept HasElementwiseOpWithAuxiliaryOperands = requires(T t) {
164176
template <typename T>
165177
concept ConvSignatureDescriptor = requires(T t) {
166178
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
167-
{ t.data_type } -> std::convertible_to<DataType>;
168179
{ t.input } -> ConvTensorDescriptor;
169180
{ t.weight } -> ConvTensorDescriptor;
170181
{ t.output } -> ConvTensorDescriptor;
171182
requires ConvolutionDirectionWellDefinedIfProvided<T>;
183+
requires DataTypeWellDefinedIfProvided<T>;
172184
};
173185

174186
// Concept to validate a convolution signature's values.
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#pragma once
5+
6+
#include <concepts>
7+
#include <type_traits>
8+
9+
#include "ck_tile/builder/conv_signature_concepts.hpp"
10+
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
11+
#include "ck_tile/builder/types.hpp"
12+
13+
namespace ck_tile::builder {
14+
/**********************************************
15+
* constexpr helper functions for optional parameters
16+
**********************************************/
17+
18+
template <auto Sig>
19+
concept ProvidesElementwiseOperation = requires { Sig.elementwise_operation; };
20+
21+
template <auto Sig>
22+
concept ProvidesDataType = requires { Sig.data_type; };
23+
24+
template <auto ConvTensor>
25+
concept ConvTensorHasOp = requires { ConvTensor.operation; };
26+
27+
template <auto Sig>
28+
concept ProvidesConvolutionDirection = requires { Sig.direction; };
29+
30+
// returns elementwise operation for input tensor
31+
// will defalut to signature's generic type if provided
32+
// otherwise, default to PASS_THROUGH
33+
template <auto Sig>
34+
requires ValidConvSignature<Sig>
35+
constexpr auto getInputElementwiseOperation()
36+
{
37+
if constexpr(ConvTensorHasOp<Sig.input>)
38+
{
39+
return Sig.input.operation.elementwise_operation;
40+
}
41+
else if constexpr(ProvidesElementwiseOperation<Sig>)
42+
{
43+
return Sig.elementwise_operation;
44+
}
45+
else
46+
{
47+
return ElementwiseOperation::PASS_THROUGH;
48+
}
49+
}
50+
51+
// returns elementwise operation for weight tensor
52+
// will defalut to signature's generic type if provided
53+
// otherwise, default to PASS_THROUGH
54+
template <auto Sig>
55+
requires ValidConvSignature<Sig>
56+
constexpr auto getWeightElementwiseOperation()
57+
{
58+
if constexpr(ConvTensorHasOp<Sig.weight>)
59+
{
60+
return Sig.weight.operation.elementwise_operation;
61+
}
62+
else if constexpr(ProvidesElementwiseOperation<Sig>)
63+
{
64+
return Sig.elementwise_operation;
65+
}
66+
else
67+
{
68+
return ElementwiseOperation::PASS_THROUGH;
69+
}
70+
}
71+
72+
// returns elementwise operation for output tensor
73+
// will defalut to signature's generic type if provided
74+
// otherwise, default to PASS_THROUGH
75+
template <auto Sig>
76+
requires ValidConvSignature<Sig>
77+
constexpr auto getOutputElementwiseOperation()
78+
{
79+
if constexpr(ConvTensorHasOp<Sig.output>)
80+
{
81+
return Sig.output.operation.elementwise_operation;
82+
}
83+
else if constexpr(ProvidesElementwiseOperation<Sig>)
84+
{
85+
return Sig.elementwise_operation;
86+
}
87+
else
88+
{
89+
return ElementwiseOperation::PASS_THROUGH;
90+
}
91+
}
92+
93+
// returns convolution direction for signature. Will default to FORWARD if not provided by signature
94+
template <auto Sig>
95+
requires ValidConvSignature<Sig>
96+
constexpr auto getConvDirection()
97+
{
98+
if constexpr(ProvidesConvolutionDirection<Sig>)
99+
{
100+
return Sig.direction;
101+
}
102+
else
103+
{
104+
return ConvDirection::FORWARD;
105+
}
106+
}
107+
108+
// generic helper that returns data_type if provided and UNDEFINED otherwise
109+
// can be used on both signature and TensorConfigDescriptor objects
110+
template <auto TensorConfigOrSig>
111+
constexpr auto getDataType()
112+
{
113+
if constexpr(ProvidesDataType<TensorConfigOrSig>)
114+
{
115+
return TensorConfigOrSig.data_type;
116+
}
117+
else
118+
{
119+
return DataType::UNDEFINED_DATA_TYPE;
120+
}
121+
}
122+
123+
// return data type of input tensor
124+
template <auto Sig>
125+
requires ValidConvSignature<Sig>
126+
consteval auto getInputDataType()
127+
{
128+
constexpr auto tensorDataType = getDataType<Sig.input.config>();
129+
constexpr auto universalDataType = getDataType<Sig>();
130+
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
131+
{
132+
return tensorDataType;
133+
}
134+
else
135+
{
136+
return universalDataType;
137+
}
138+
}
139+
140+
template <auto Sig>
141+
requires ValidConvSignature<Sig>
142+
consteval auto getWeightDataType()
143+
{
144+
constexpr auto tensorDataType = getDataType<Sig.weight.config>();
145+
constexpr auto universalDataType = getDataType<Sig>();
146+
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
147+
{
148+
return tensorDataType;
149+
}
150+
else
151+
{
152+
return universalDataType;
153+
}
154+
}
155+
156+
template <auto Sig>
157+
requires ValidConvSignature<Sig>
158+
consteval auto getOutputDataType()
159+
{
160+
constexpr auto tensorDataType = getDataType<Sig.output.config>();
161+
constexpr auto universalDataType = getDataType<Sig>();
162+
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
163+
{
164+
return tensorDataType;
165+
}
166+
else
167+
{
168+
return universalDataType;
169+
}
170+
}
171+
172+
// returns data type if and only if all tensors have the same type.
173+
// Otherwise, return DataType::UNDEFINED_DATA_TYPE
174+
template <auto Sig>
175+
requires ValidConvSignature<Sig>
176+
consteval auto getDataTypeIfCommon()
177+
{
178+
179+
auto inputDataType = getInputDataType<Sig>();
180+
auto weightDataType = getWeightDataType<Sig>();
181+
auto outputDataType = getOutputDataType<Sig>();
182+
183+
if(inputDataType == weightDataType && inputDataType == outputDataType)
184+
{
185+
return inputDataType;
186+
}
187+
else
188+
{
189+
return DataType::UNDEFINED_DATA_TYPE;
190+
}
191+
}
192+
} // namespace ck_tile::builder

experimental/builder/test/test_conv_description.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "testing_utils.hpp"
1111
#include "impl/conv_signature_types.hpp"
1212
#include "impl/conv_algorithm_types.hpp"
13+
#include "ck_tile/builder/conv_signature_utils.hpp"
1314

1415
namespace {
1516

@@ -35,6 +36,18 @@ struct TensorConfig
3536
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
3637
};
3738

39+
struct TensorConfigNoDataType
40+
{
41+
ckb::TensorLayout layout;
42+
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
43+
};
44+
45+
struct ConvTensorNoDataType
46+
{
47+
TensorConfigNoDataType config;
48+
TensorOp operation{};
49+
};
50+
3851
struct ConvTensorSimple
3952
{
4053
TensorConfig config;
@@ -155,6 +168,85 @@ struct DefaultAlgorithm
155168
};
156169
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
157170

171+
struct ConvSignatureUtilsTest1
172+
{
173+
using enum ckb::DataType;
174+
using enum ckb::TensorLayout;
175+
using enum ckb::ConvDirection;
176+
using enum ckb::ElementwiseOperation;
177+
178+
int spatial_dim = 2;
179+
ckb::DataType data_type = FP16;
180+
ckb::DataType accumulation_data_type = FP32;
181+
ckb::ConvDirection direction = FORWARD;
182+
ConvTensorWithOp input = {
183+
.config = {GNHWC, FP16},
184+
};
185+
ConvTensorWithOp weight = {.config = {GKYXC, FP16}};
186+
ConvTensorWithOp output = {.config = {GNHWK, UNDEFINED_DATA_TYPE}, .operation = {SCALE}};
187+
};
188+
189+
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureUtilsTest1>);
190+
191+
struct ConvSignatureUtilsTest2
192+
{
193+
using enum ckb::DataType;
194+
using enum ckb::TensorLayout;
195+
using enum ckb::ConvDirection;
196+
using enum ckb::ElementwiseOperation;
197+
198+
int spatial_dim = 2;
199+
ckb::DataType data_type = FP16;
200+
ckb::ElementwiseOperation elementwise_operation = CONV_INVSCALE;
201+
ckb::DataType accumulation_data_type = FP32;
202+
ckb::ConvDirection direction = FORWARD;
203+
ConvTensorSimple input = {
204+
.config = {GNHWC, FP16},
205+
};
206+
ConvTensorNoDataType weight = {.config = {GKYXC}, .operation = {POWER}};
207+
ConvTensorWithOp output = {.config = {GNHWK, BF16}, .operation = {GELU}};
208+
};
209+
210+
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureUtilsTest2>);
211+
212+
TEST(ConvUtilsTest, getDataType1)
213+
{
214+
using enum ckb::DataType;
215+
static constexpr const ConvSignatureUtilsTest1 SIGNATURE;
216+
EXPECT_THAT(ckb::getInputDataType<SIGNATURE>(), FP16);
217+
EXPECT_THAT(ckb::getWeightDataType<SIGNATURE>(), FP16);
218+
EXPECT_THAT(ckb::getOutputDataType<SIGNATURE>(), FP16);
219+
EXPECT_THAT(ckb::getDataTypeIfCommon<SIGNATURE>(), FP16);
220+
}
221+
222+
TEST(ConvUtilsTest, getDataType2)
223+
{
224+
using enum ckb::DataType;
225+
static constexpr const ConvSignatureUtilsTest2 SIGNATURE;
226+
EXPECT_THAT(ckb::getInputDataType<SIGNATURE>(), FP16);
227+
EXPECT_THAT(ckb::getWeightDataType<SIGNATURE>(), FP16);
228+
EXPECT_THAT(ckb::getOutputDataType<SIGNATURE>(), BF16);
229+
EXPECT_THAT(ckb::getDataTypeIfCommon<SIGNATURE>(), UNDEFINED_DATA_TYPE);
230+
}
231+
232+
TEST(ConvUtilsTest, getElementwiseOperation1)
233+
{
234+
using enum ckb::ElementwiseOperation;
235+
static constexpr const ConvSignatureUtilsTest1 SIGNATURE;
236+
EXPECT_THAT(ckb::getInputElementwiseOperation<SIGNATURE>(), PASS_THROUGH);
237+
EXPECT_THAT(ckb::getWeightElementwiseOperation<SIGNATURE>(), PASS_THROUGH);
238+
EXPECT_THAT(ckb::getOutputElementwiseOperation<SIGNATURE>(), SCALE);
239+
}
240+
241+
TEST(ConvUtilsTest, getElementwiseOperation2)
242+
{
243+
using enum ckb::ElementwiseOperation;
244+
static constexpr const ConvSignatureUtilsTest2 SIGNATURE;
245+
EXPECT_THAT(ckb::getInputElementwiseOperation<SIGNATURE>(), CONV_INVSCALE);
246+
EXPECT_THAT(ckb::getWeightElementwiseOperation<SIGNATURE>(), POWER);
247+
EXPECT_THAT(ckb::getOutputElementwiseOperation<SIGNATURE>(), GELU);
248+
}
249+
158250
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
159251
{
160252
static constexpr const ConvSignature SIGNATURE;

0 commit comments

Comments
 (0)