Commit acf3d65

qianfengz authored and assistant-librarian[bot] committed
[rocm-libraries] ROCm/rocm-libraries#7256 (commit 1fc20eb)
Skip numeric drop-out when PComputeWindow is a null_tile_window in Bl… (#7256)

The BlockDropout implementation already provides complete logic for generating random numbers and applying dropout to the P tensor after the first attention GEMM, supporting both 32x32 and 16x16 Warp-GEMM shapes and running on both wave32 and wave64 architectures. In some situations, however, we only need the block-level machinery to generate the random numbers, without simultaneously applying dropout to the vgpr tile in real time. For example, xformers' `test_mem_eff_attention.py::test_dropout_ck` requires the host reference implementation of attention forward with dropout to use the same random numbers as the device-side implementation in order to compare and verify it, so a standalone kernel that only generates random numbers is required.

This PR enables xformers' random-value-generating kernel (in `ck_tiled_rand_uniform_kernel.h`) to rely entirely on BlockDropout's `Run()` operator to generate random numbers for a `[MPerBlock, NPerBlock]` tile during tile iteration, so there is no need to replicate BlockDropout's logic in the xformers kernel.
1 parent 5c7b7ec commit acf3d65
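The guard this commit adds follows a common `if constexpr` dispatch pattern: detect a sentinel "null window" type at compile time and discard the dropout branch entirely, so random-number generation still runs but the P-tile update compiles away. Below is a minimal, self-contained sketch of that pattern; the `null_tile_window` and `is_null_tile_window_v` names mirror ck_tile's, but these types are simplified stand-ins for illustration, not the library's actual API.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for ck_tile's sentinel "no output window" type.
struct null_tile_window {};

// Hypothetical stand-in for ck_tile's is_null_tile_window_v trait.
template <typename W>
inline constexpr bool is_null_tile_window_v = std::is_same_v<W, null_tile_window>;

// Sketch of the guarded step in BlockDropout::Run(): the caller always gets
// the random values, but the numeric drop-out on the P tile only happens
// when a real window type was passed in.
template <typename PWindow>
void run_dropout(PWindow& p, const std::array<uint8_t, 4>& randval,
                 uint8_t p_undrop_u8, float rp_undrop)
{
    if constexpr (!is_null_tile_window_v<PWindow>) {
        // Keep (and rescale) elements whose random byte is <= the threshold,
        // zero the rest -- mirroring the randval/p_undrop comparison in the diff.
        for (std::size_t i = 0; i < randval.size(); ++i)
            p[i] = randval[i] <= p_undrop_u8 ? p[i] * rp_undrop : 0.0f;
    }
    // else: the branch is a discarded statement -- it is never instantiated,
    // so null_tile_window does not even need an operator[].
}
```

Because the discarded branch of `if constexpr` is not instantiated for the sentinel type, a randval-only kernel can pass a null window without the P-tile arithmetic ever being compiled for it.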

1 file changed

Lines changed: 21 additions & 17 deletions

include/ck_tile/ops/fmha/block/block_dropout.hpp
```diff
@@ -381,24 +381,28 @@ struct BlockDropout
                     store_tile(randval_dram_window, randval_store);
                 }
                 move_tile_window(randval_dram_window, {0, kNPerStep});
-                // Drop values of P based on the generated probabilities
-                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
-                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
-                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
-                        constexpr auto p_idx0 =
-                            tile_distributed_index<i_m0 * MIterPerWarp +
-                                                   idx0.impl_.template at<0>()>{};
-                        constexpr auto p_idx1 =
-                            tile_distributed_index<i_n0,
-                                                   idx1.impl_.template at<1>(),
-                                                   idx1.impl_.template at<2>()>{};
-                        constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
-                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
-                        p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
-                                               ? p_compute[p_idx] * rp_undrop
-                                               : PComputeDataType(0);
+
+                if constexpr(!is_null_tile_window_v<PComputeWindow>)
+                {
+                    // Drop values of P based on the generated probabilities
+                    constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
+                    sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
+                        sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
+                            constexpr auto p_idx0 =
+                                tile_distributed_index<i_m0 * MIterPerWarp +
+                                                       idx0.impl_.template at<0>()>{};
+                            constexpr auto p_idx1 =
+                                tile_distributed_index<i_n0,
+                                                       idx1.impl_.template at<1>(),
+                                                       idx1.impl_.template at<2>()>{};
+                            constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
+                            constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
+                            p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
+                                                   ? p_compute[p_idx] * rp_undrop
+                                                   : PComputeDataType(0);
+                        });
                     });
-                });
+                }
             });
             move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
         });
```
