Skip to content

Commit ef4ff46

Browse files
bartekxkassistant-librarian[bot]
authored andcommitted
[rocm-libraries] ROCm/rocm-libraries#5842 (commit 04c5690)
[CK][CK Tile] Force padding for atomic_add bf16 C tensor (#5842) ## Motivation Force padding for atomic_add bf16 C tensor to avoid memfaults. ## Technical Details - add global atomic add for bf16 and enable them - add padding for atomic add bf16 due to the lack of oob - remove padding for not continous dims in conv for other cases - minor bwd data conv fixes ## Test Plan test_grouped_conv_*_tile ## Test Result pending ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent 66dc81d commit ef4ff46

7 files changed

Lines changed: 174 additions & 171 deletions

include/ck_tile/core/arch/amd_buffer_addressing.hpp

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
#include "ck_tile/core/utility/ignore.hpp"
1919
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
2020

21+
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
22+
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
23+
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
24+
2125
// This attribute gives a hint to the compiler that a branch is likely to be taken.
2226
// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
2327
// have been generated.
@@ -2317,6 +2321,34 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
23172321
}
23182322
}
23192323

2324+
template <typename T, index_t N>
2325+
CK_TILE_DEVICE void
2326+
amd_global_atomic_add_impl([[maybe_unused]] const thread_buffer<T, N>& src_thread_data,
2327+
[[maybe_unused]] T* addr)
2328+
{
2329+
static_assert((std::is_same<T, ck_tile::bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
2330+
(std::is_same<T, ck_tile::fp16_t>::value && (N == 2 || N == 4 || N == 8)),
2331+
"wrong! not implemented");
2332+
2333+
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
2334+
if constexpr(__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16) &&
2335+
std::is_same<T, ck_tile::bf16_t>::value)
2336+
{
2337+
static_for<0, N / 2, 1>{}([&](auto i) {
2338+
__builtin_amdgcn_global_atomic_fadd_v2bf16(
2339+
bit_cast<ck_tile::bf16x2_t*>(addr) + i,
2340+
src_thread_data.template get_as<ck_tile::bf16x2_t>()[i]);
2341+
});
2342+
}
2343+
else
2344+
{
2345+
static_assert(false, "Not supported!");
2346+
}
2347+
#else
2348+
static_assert(false, "Not supported!");
2349+
#endif
2350+
}
2351+
23202352
template <typename T, index_t N>
23212353
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
23222354
int32x4_t dst_wave_buffer_resource,
@@ -2325,8 +2357,11 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
23252357
{
23262358
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
23272359
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
2328-
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
2329-
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
2360+
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4))
2361+
#if defined(__gfx950__)
2362+
|| (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8))
2363+
#endif
2364+
,
23302365
"wrong! not implemented");
23312366

23322367
if constexpr(std::is_same<T, float>::value)
@@ -2931,23 +2966,37 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
29312966
const bool dst_thread_element_valid,
29322967
const index_t dst_element_space_size)
29332968
{
2934-
const int32x4_t dst_wave_buffer_resource =
2935-
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
2936-
2937-
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
2969+
#if defined(__gfx942__)
2970+
if constexpr(std::is_same<T, bf16_t>::value)
2971+
{
2972+
if(dst_thread_element_valid)
2973+
{
2974+
amd_global_atomic_add_impl<T, N>(src_thread_data,
2975+
p_dst_wave + dst_thread_element_offset);
2976+
}
2977+
}
2978+
else
2979+
{
2980+
#endif
2981+
const int32x4_t dst_wave_buffer_resource =
2982+
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
29382983

2984+
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
29392985
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
2940-
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
2986+
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
29412987

2942-
amd_buffer_atomic_add_impl<T, N>(
2943-
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
2988+
amd_buffer_atomic_add_impl<T, N>(
2989+
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
29442990
#else
29452991
if(dst_thread_element_valid)
29462992
{
29472993
amd_buffer_atomic_add_impl<T, N>(
29482994
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
29492995
}
29502996
#endif
2997+
#if defined(__gfx942__)
2998+
}
2999+
#endif
29513000
}
29523001

29533002
template <typename T,

include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
#include "ck_tile/core/utility/ignore.hpp"
1919
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
2020

21+
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
22+
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
23+
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
24+
2125
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
2226

2327
namespace ck_tile {
@@ -2143,6 +2147,33 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
21432147
}
21442148
}
21452149

2150+
template <typename T, index_t N>
2151+
CK_TILE_DEVICE void
2152+
amd_global_atomic_add_impl([[maybe_unused]] const thread_buffer<T, N>& src_thread_data,
2153+
[[maybe_unused]] T* addr)
2154+
{
2155+
static_assert((std::is_same<T, ck_tile::bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
2156+
(std::is_same<T, ck_tile::fp16_t>::value && (N == 2 || N == 4 || N == 8)),
2157+
"wrong! not implemented");
2158+
2159+
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
2160+
if constexpr(std::is_same<T, ck_tile::bf16_t>::value)
2161+
{
2162+
static_for<0, N / 2, 1>{}([&](auto i) {
2163+
__builtin_amdgcn_global_atomic_fadd_v2bf16(
2164+
bit_cast<ck_tile::bf16x2_t*>(addr) + i,
2165+
src_thread_data.template get_as<ck_tile::bf16x2_t>()[i]);
2166+
});
2167+
}
2168+
else
2169+
{
2170+
static_assert(false, "Not supported!");
2171+
}
2172+
#else
2173+
static_assert(false, "Not supported!");
2174+
#endif
2175+
}
2176+
21462177
template <typename T, index_t N>
21472178
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
21482179
int32x4_t dst_wave_buffer_resource,
@@ -2151,8 +2182,11 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
21512182
{
21522183
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
21532184
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
2154-
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
2155-
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
2185+
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4))
2186+
#if defined(__gfx950__)
2187+
|| (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8))
2188+
#endif
2189+
,
21562190
"wrong! not implemented");
21572191

21582192
if constexpr(std::is_same<T, float>::value)
@@ -2759,23 +2793,38 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
27592793
const bool dst_thread_element_valid,
27602794
const index_t dst_element_space_size)
27612795
{
2762-
const int32x4_t dst_wave_buffer_resource =
2763-
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
2796+
#if defined(__gfx942__)
2797+
if constexpr(std::is_same<T, bf16_t>::value)
2798+
{
2799+
if(dst_thread_element_valid)
2800+
{
2801+
amd_global_atomic_add_impl<T, N>(src_thread_data,
2802+
p_dst_wave + dst_thread_element_offset);
2803+
}
2804+
}
2805+
else
2806+
{
2807+
#endif
2808+
const int32x4_t dst_wave_buffer_resource =
2809+
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
27642810

2765-
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
2811+
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
27662812

27672813
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
2768-
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
2814+
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
27692815

2770-
amd_buffer_atomic_add_impl<T, N>(
2771-
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
2816+
amd_buffer_atomic_add_impl<T, N>(
2817+
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
27722818
#else
27732819
if(dst_thread_element_valid)
27742820
{
27752821
amd_buffer_atomic_add_impl<T, N>(
27762822
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
27772823
}
27782824
#endif
2825+
#if defined(__gfx942__)
2826+
}
2827+
#endif
27792828
}
27802829

27812830
template <typename T,

include/ck_tile/core/tensor/buffer_view.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ struct buffer_view<address_space_enum::global,
630630
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
631631
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
632632
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
633-
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
633+
#if defined(__gfx942__) || defined(__gfx950__) // only gfx942 and gfx950 support atomic_pk_add_bf16
634634
||
635635
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
636636
#endif
@@ -642,7 +642,7 @@ struct buffer_view<address_space_enum::global,
642642
bool constexpr use_amd_buffer_addressing =
643643
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
644644
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
645-
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
645+
#if defined(__gfx942__) || defined(__gfx950__) // only gfx942 and gfx950 support atomic_pk_add_bf16
646646
||
647647
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
648648
#endif

include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,21 +1021,26 @@ struct UniversalGemmKernel
10211021
const auto& e_tensor_view =
10221022
make_tensor_view<address_space_enum::global, DstInMemOp>(e_ptr, e_desc);
10231023

1024+
// For bf16_t and atomic_add global_atomic_add is used instead of buffer_atomic_add
1025+
// Add padding for not contiguous dim due to the lack of OOB check
1026+
constexpr bool pad_not_contiguous_dim =
1027+
std::is_same_v<EDataType, bf16_t> && DstInMemOp == memory_operation_enum::atomic_add;
1028+
10241029
// Step 2: Create padded view
10251030
const auto& e_pad_view = [&]() {
10261031
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
10271032
{
10281033
return pad_tensor_view(e_tensor_view,
10291034
make_tuple(number<TilePartitioner::MPerBlock>{},
10301035
number<TilePartitioner::NPerBlock>{}),
1031-
sequence<false, GemmPipeline::kPadN>{});
1036+
sequence<pad_not_contiguous_dim, GemmPipeline::kPadN>{});
10321037
}
10331038
else
10341039
{
10351040
return pad_tensor_view(e_tensor_view,
10361041
make_tuple(number<TilePartitioner::MPerBlock>{},
10371042
number<TilePartitioner::NPerBlock>{}),
1038-
sequence<GemmPipeline::kPadM, false>{});
1043+
sequence<GemmPipeline::kPadM, pad_not_contiguous_dim>{});
10391044
}
10401045
}();
10411046

0 commit comments

Comments
 (0)