aiter-flash-attn: add AITER Flash Attention kernel for AMD ROCm #890

Open: Abdennacer-Badaoui wants to merge 3 commits into huggingface:main from Abdennacer-Badaoui:aiter-flash-attn-add

@Abdennacer-Badaoui
Member

This adds a Triton FlashAttention kernel for AMD ROCm, repackaged from the MHA implementation in AMD’s AITER project (https://github.com/ROCm/aiter).

The motivation is on the Transformers side: the ROCm FlashAttention fallback currently depends on installing the full aiter pip package, which reviewers (rightly) pushed back on. Having an equivalent kernel here means Transformers can route through get_kernel like every other FA backend instead of carrying a direct AITER dependency. Beyond unblocking the dependency story, this is also an FA3-style kernel with native support for learnable attention sinks (sink=), which models like gpt-oss require on ROCm and which flash-attn2 does not provide.
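For readers unfamiliar with attention sinks, here is a minimal pure-Python sketch of the idea as it is commonly described: a learnable sink logit joins the softmax denominator but contributes no value vector, so the attention weights over real keys sum to less than 1. The function name `softmax_with_sink` is hypothetical, and this illustrates the math only; it is not the kernel's code and makes no claim about how `sink=` is implemented in this PR.

```python
import math

def softmax_with_sink(scores, sink_logit):
    """Softmax over `scores` with one extra sink logit in the denominator.

    The sink absorbs probability mass but has no value vector, so the
    returned weights over the real keys sum to less than 1.
    """
    m = max(max(scores), sink_logit)               # shift for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink_logit - m)   # sink only enters the denominator
    return [e / denom for e in exps]

scores = [2.0, 1.0, 0.5]
plain = softmax_with_sink(scores, float("-inf"))   # sink disabled: ordinary softmax
sunk = softmax_with_sink(scores, 1.5)              # sink takes some of the mass
```

With the sink logit at negative infinity the result reduces to an ordinary softmax, and the relative weights between real keys are unchanged by the sink; only their total shrinks.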

It’s a slim copy: only the parts actually reachable from flash_attn_func and flash_attn_varlen_func are included, with the unused dao_ai implementation path removed and all absolute imports rewritten to be Hub-compliant. It has been tested locally on MI300X against an eager SDPA reference; numerics match within fp16 tolerance for dense, causal, and varlen cases.
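As an illustration of what comparing against an eager reference involves, the sketch below checks a flash-style tiled attention (running max and running sum over key blocks) against a straightforward eager implementation for the dense and causal cases. It is a toy under stated assumptions: pure Python in float64, no GPU, no varlen, and all function names (`reference_attention`, `tiled_attention`, `max_abs_diff`) are hypothetical. It does not reproduce the MI300X test described above.

```python
import math
import random

def reference_attention(q, k, v, causal=False):
    """Eager attention: materialize each full score row, softmax, weight V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        n = i + 1 if causal else len(k)        # causal: query i sees keys 0..i
        s = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(n)]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * v[j][d] for j in range(n)) / z
                    for d in range(len(v[0]))])
    return out

def tiled_attention(q, k, v, causal=False, block=2):
    """Flash-style attention: walk K/V in blocks with a running max and sum."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        n = i + 1 if causal else len(k)
        m, z, acc = float("-inf"), 0.0, [0.0] * len(v[0])
        for start in range(0, n, block):
            idx = range(start, min(start + block, n))
            s = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in idx]
            m_new = max(m, max(s))
            corr = math.exp(m - m_new)         # rescale earlier partial results
            z *= corr
            acc = [a * corr for a in acc]
            for j, sj in zip(idx, s):
                w = math.exp(sj - m_new)
                z += w
                acc = [a + w * vd for a, vd in zip(acc, v[j])]
            m = m_new
        out.append([a / z for a in acc])
    return out

def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two 2-D lists."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))

rng = random.Random(0)

def rand(rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

q, k, v = rand(5, 4), rand(5, 4), rand(5, 4)
dense_err = max_abs_diff(reference_attention(q, k, v), tiled_attention(q, k, v))
causal_err = max_abs_diff(reference_attention(q, k, v, True),
                          tiled_attention(q, k, v, True))
```

In float64 the two paths should agree to near machine precision; a real fp16 kernel comparison would need a much looser tolerance, chosen per dtype.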

Collaborator

Any reason we do not port the kv cache function as well? https://github.com/ROCm/aiter/blob/5de8c99ef2988e1700ae39dd6f41a195a7988906/aiter/ops/triton/attention/mha.py#L912

It could be useful for continuous batching, which uses this function and gets quite nice perf boosts from it.

Member Author

Good catch; there’s no real reason beyond trimming for v1. AITER’s flash_attn_with_kvcache is the only entry point that calls into the flash_attn_triton_amd subpackage (via flash_attn_2.fwd_kvcache), and that same subpackage also backs the optional _MHA_IMPL = "dao_ai" forward/backward switch, which nothing in our code path actually selects by default.

To keep the initial PR minimal and avoid shipping the larger v2/v3 backend tree (~10k LOC) on top of the core MHA Triton kernels, I dropped that subpackage along with flash_attn_with_kvcache.

But if we need it for continuous batching (CB), I will add it now.

Collaborator

Argh, that's a bit awkward, but yeah, I think it would be a nice-to-have. Thanks for adding it.
