aiter-flash-attn: add AITER Flash Attention kernel for AMD ROCm#890
aiter-flash-attn: add AITER Flash Attention kernel for AMD ROCm#890Abdennacer-Badaoui wants to merge 3 commits into
Conversation
There was a problem hiding this comment.
Any reason we do not port the kv cache function as well? https://github.com/ROCm/aiter/blob/5de8c99ef2988e1700ae39dd6f41a195a7988906/aiter/ops/triton/attention/mha.py#L912
Could be useful for continuous batching which uses this and has quite nice perf boosts
There was a problem hiding this comment.
Good catch; there’s no real reason beyond trimming for v1. AITER’s flash_attn_with_kvcache is the only entry point that calls into the flash_attn_triton_amd subpackage (via flash_attn_2.fwd_kvcache), and that same subpackage also backs the optional _MHA_IMPL = "dao_ai" forward/backward switch, which nothing in our code path actually selects by default.
To keep the initial PR minimal and avoid shipping the larger v2/v3 backend tree (~10k LOC) on top of the core MHA Triton kernels, I dropped that subpackage along with flash_attn_with_kvcache.
But if we need it for CB, i will add it now
There was a problem hiding this comment.
Argh that's a bit awkward but yea I think it would be a nice to have, thanks for adding
This adds a Triton FlashAttention kernel for AMD ROCm, repackaged from the MHA implementation in AMD’s AITER project (https://github.com/ROCm/aiter).
The motivation is on the Transformers side: the ROCm FlashAttention fallback currently depends on installing the full
aiterpip package, which reviewers (rightly) pushed back on. Having an equivalent kernel here means Transformers can route throughget_kernellike every other FA backend instead of carrying a direct AITER dependency. Beyond just unblocking the dependency story, this is also an FA3-style kernel with native support for learnable attention sinks (sink=), which is required by models like gpt-oss on ROCm; somethingflash-attn2does not provide.It’s a slim copy; only the parts actually reachable from
flash_attn_funcandflash_attn_varlen_func, with the unuseddao_aiimplementation path removed and all absolute imports rewritten to be Hub-compliant. It has been tested locally on MI300X against an eager SDPA reference; numerics match within fp16 tolerance for dense, causal, and varlen cases.