Skip to content

Commit f875ab0

Browse files
authored
Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic-add (#3236)
* Add validity checks for MoE FlatMM scatter and enable bf16 hardware atomic * correct clang-format * removed unused rtol_atol variable from example code * clang-format correction * remove unused variable max_accumulated_value from example
1 parent 30727c4 commit f875ab0

2 files changed

Lines changed: 10 additions & 2 deletions

File tree

include/ck_tile/core/arch/generic_memory_space_atomic.hpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
102102
template <>
103103
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
104104
{
105+
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
106+
__builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast<bf16x2_t*>(p_dst), x);
107+
#else
105108
union U32BF162_ADDR
106109
{
107110
uint32_t* u32_a;
@@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
128131
new_v = new_.u32;
129132
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
130133
} while(cur_v.u32 != old_v);
134+
#endif
131135
}
132136

133137
template <>

include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -623,7 +623,7 @@ struct MoeFlatmmKernel
623623
{
624624
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
625625
e_ptr,
626-
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
626+
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
627627
IsGateUp ? kargs.N / 2 : kargs.N),
628628
make_tuple(1, kargs.stride_C),
629629
number<1>{},
@@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel
12501250
constexpr int MPerThread = TileEncodingPattern::Y2;
12511251
statically_indexed_array<statically_indexed_array<index_t, MPerThread>, NumMEpiTile>
12521252
c_scatter_offsets;
1253+
statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
1254+
c_scatter_valids;
12531255
auto c_coord = dram_tile_distribution.calculate_index();
12541256
static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
12551257
static_for<0, MPerThread, 1>{}([&](auto m0) {
@@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel
12621264
scatter_token_id =
12631265
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
12641266
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
1267+
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
12651268
});
12661269
});
12671270

@@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel
13021305
c_block_window.get_window_lengths(),
13031306
c_block_window.get_window_origin(),
13041307
dram_tile_distribution,
1305-
c_scatter_offsets[mIter]);
1308+
c_scatter_offsets[mIter],
1309+
c_scatter_valids[mIter]);
13061310

13071311
if constexpr(!IsInputGemm ||
13081312
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)

0 commit comments

Comments
 (0)