Skip to content

Commit 9491563

Browse files
yrapartiassistant-librarian[bot]
authored andcommitted
[rocm-libraries] ROCm/rocm-libraries#6399 (commit 13bf528)
[CK][CK TILE] Modify elementwise kernel template signature to accept independent type arguments (#6399) ## Motivation modify elementwise kernel template signature to fix cshuffle epilogue build error ## Technical Details Encountered a build error while building conv fallback kernel with dispatcher. Error: Type mismatch in `ElementWiseKernel::operator()` where the template required all three parameters (lens, input_strides, output_strides) to be the same type, but the CShuffle epilogue was passing them with different tuple element types. Solution: Modified the template signature in elementwise_kernel.hpp to accept three independent type parameters: Changed from single typename `Dims` to typename `DimsLens`, typename `DimsInStrides`, typename `DimsOutStrides` Updated references to `Dims::size()` to use the appropriate specific type ## Test Plan - Test with dispatcher conv unit tests - Relying on CI tests ## Test Result - Dispatcher unit tests passed - Relying on CI tests ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent 918e8a1 commit 9491563

1 file changed

Lines changed: 20 additions & 15 deletions

File tree

include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,13 @@ struct ElementWiseKernel
2727
return is_wave32() ? kBlockSize / 2 : kBlockSize;
2828
}
2929

30-
template <typename... XDataType, typename Dims>
31-
CK_TILE_DEVICE void operator()(const Dims lens,
32-
const Dims input_strides,
33-
const Dims output_strides,
30+
template <typename... XDataType,
31+
typename DimsLens,
32+
typename DimsInStrides,
33+
typename DimsOutStrides>
34+
CK_TILE_DEVICE void operator()(const DimsLens lens,
35+
const DimsInStrides input_strides,
36+
const DimsOutStrides output_strides,
3437
const tuple<XDataType...>& input_tensors,
3538
YDataType* p_y) const
3639
{
@@ -49,10 +52,11 @@ struct ElementWiseKernel
4952
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
5053

5154
const auto transformed_tensor = pad_tensor_view(
52-
transform_tensor_view(tensor_view,
53-
ck_tile::make_tuple(merge_transform),
54-
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
55-
ck_tile::make_tuple(sequence<0>{})),
55+
transform_tensor_view(
56+
tensor_view,
57+
ck_tile::make_tuple(merge_transform),
58+
ck_tile::make_tuple(make_index_sequence<DimsLens::size()>{}),
59+
ck_tile::make_tuple(sequence<0>{})),
5660
ck_tile::make_tuple(number<S::kBlockM>{}),
5761
sequence<Problem::kPad>{});
5862

@@ -86,13 +90,14 @@ struct ElementWiseKernel
8690
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
8791
p_y, lens, output_strides, number<S::kVectorM>{});
8892

89-
const auto transformed_y_m_n = pad_tensor_view(
90-
transform_tensor_view(y_m_n,
91-
ck_tile::make_tuple(merge_transform),
92-
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
93-
ck_tile::make_tuple(sequence<0>{})),
94-
ck_tile::make_tuple(number<S::kBlockM>{}),
95-
sequence<Problem::kPad>{});
93+
const auto transformed_y_m_n =
94+
pad_tensor_view(transform_tensor_view(
95+
y_m_n,
96+
ck_tile::make_tuple(merge_transform),
97+
ck_tile::make_tuple(make_index_sequence<DimsOutStrides::size()>{}),
98+
ck_tile::make_tuple(sequence<0>{})),
99+
ck_tile::make_tuple(number<S::kBlockM>{}),
100+
sequence<Problem::kPad>{});
96101

97102
auto y_window = make_tile_window(transformed_y_m_n,
98103
make_tuple(number<S::kBlockM>{}),

0 commit comments

Comments
 (0)