Commit acf3d65
[rocm-libraries] ROCm/rocm-libraries#7256 (commit 1fc20eb)
=?UTF-8?q?Skip=20numeric=20drop-out=20when=20PComputeWind?=
=?UTF-8?q?ow=20is=20a=20null=5Ftile=5Fwindow=20in=20Bl=E2=80=A6=20(#7256)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The BlockDropout implementation already provides very complete logic for
generating random numbers and executing dropout for the P tensor after
first attention Gemm with capability to support both Warp-Gemm 32x32 and
16x16 as well as to run on both wave32 and wave64 arch.
But in some situation, we only need the block-layer process to generate
random numbers, rather than simultaneously execute dropout in real-time
on the vgpr tile. For example, xformers'
`test_mem_eff_attention.py::test_dropout_ck` requires the host reference
implementation of `attention forward with dropout` to use the same
random numbers to compare & verify the device side implementation of
`attention forward with dropout`, so a standalone kernel to generate
random numbers only is required.
This PR will enable xformers's random_val generating kernel (in file
`ck_tiled_rand_uniform_kernel.h`) to depend on BlockDropout's `Run()`
operator completely to generate random numbers for a `[MPerBlock,
NPerBlock]` tile during the tile iteration, no need to replicate the
logic of BlockDropout in the xformers kernel1 parent 5c7b7ec commit acf3d65
1 file changed
Lines changed: 21 additions & 17 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
381 | 381 | | |
382 | 382 | | |
383 | 383 | | |
384 | | - | |
385 | | - | |
386 | | - | |
387 | | - | |
388 | | - | |
389 | | - | |
390 | | - | |
391 | | - | |
392 | | - | |
393 | | - | |
394 | | - | |
395 | | - | |
396 | | - | |
397 | | - | |
398 | | - | |
399 | | - | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
400 | 404 | | |
401 | | - | |
| 405 | + | |
402 | 406 | | |
403 | 407 | | |
404 | 408 | | |
| |||
0 commit comments