
🤖FFPA: Yet another Faster Flash Prefill Attention with O(1)⚡️GPU SRAM complexity for large headdim🐑


FFPA(Split-D): Yet another Faster Flash Prefill Attention with a Split-D strategy, achieving O(1) SRAM complexity and O(d/4) register complexity for large headdim (> 256) and running 1.5~3x 🎉 faster than SDPA. 📚👇 Core features:

| Self Attn | GQA/MQA | Cross Attn | Causal/Mask | Dropout | Headdim | Fwd/Bwd |
|---|---|---|---|---|---|---|
| ✔️ (Nq=Nkv) | ✔️ (Hq!=Hkv) | ✔️ (Nq!=Nkv) | ✔️ (attn_mask) | ✔️ (p>0) | 320~1024 | 1.5~3x↑ |
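The GQA/MQA column means the number of query heads Hq may differ from the number of key/value heads Hkv. A minimal sketch of the usual head grouping follows, assuming FFPA uses the same convention as PyTorch SDPA's `enable_gqa` (each run of Hq/Hkv consecutive query heads shares one KV head). The helper `kv_head_for` is hypothetical and not part of ffpa-attn:

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query head index to the KV head it attends with (hypothetical helper)."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("Hq must be a multiple of Hkv")
    group = num_q_heads // num_kv_heads  # number of query heads per KV head
    return q_head // group  # consecutive query heads share one KV head


gqa = [kv_head_for(h, 8, 2) for h in range(8)]  # GQA: Hq=8, Hkv=2
mqa = [kv_head_for(h, 8, 1) for h in range(8)]  # MQA: Hq=8, Hkv=1
print(gqa)  # [0, 0, 0, 0, 1, 1, 1, 1]
print(mqa)  # [0, 0, 0, 0, 0, 0, 0, 0]
```

Plain multi-head self attention is the special case Hq = Hkv, where every query head maps to the KV head with the same index.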

📖 Quick Start

First, install the prebuilt package from PyPI or build ffpa-attn from source:

# First, install the prebuilt package from PyPI
pip3 install -U ffpa-attn # (supports: sm_{80,...,120})
# Or, build ffpa-attn from source:
git clone https://github.com/xlite-dev/ffpa-attn.git
# Then, build and install the package (Triton + CuTeDSL backends)
cd ffpa-attn && pip3 install -e . --no-build-isolation
# Optional: install ffpa-attn with CUDA backend support
ENABLE_FFPA_CUDA_IMPL=1 MAX_JOBS=32 pip3 install -e .

Then, accelerate attention for large headdim with just one line of code:

>>> import torch.nn.functional as F
>>> from ffpa_attn import ffpa_attn_func
>>> # Monkey-patch SDPA to point to FFPA. Everything that FFPA
>>> # does not support automatically falls back to SDPA (e.g., D <= 256).
>>> F.scaled_dot_product_attention = ffpa_attn_func # one-line code
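The one-liner above replaces SDPA for the whole process. If you would rather enable FFPA only inside a specific region, a small context manager can patch the attribute and restore it afterwards. This is a generic Python sketch, not an ffpa-attn API: `patched_attr`, `reference_sdpa` and `fake_ffpa_attn_func` are hypothetical stand-ins so the example runs without a GPU.

```python
import contextlib
import types


@contextlib.contextmanager
def patched_attr(module, name, replacement):
    """Temporarily replace module.<name> with replacement and always restore it on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield original  # hand back the original so callers can still reach it
    finally:
        setattr(module, name, original)  # runs even if the body raises


def reference_sdpa(q, k, v):  # stand-in for torch.nn.functional.scaled_dot_product_attention
    return "sdpa"


def fake_ffpa_attn_func(q, k, v):  # stand-in for ffpa_attn_func
    return "ffpa"


F = types.SimpleNamespace(scaled_dot_product_attention=reference_sdpa)  # stand-in for F

with patched_attr(F, "scaled_dot_product_attention", fake_ffpa_attn_func):
    inside = F.scaled_dot_product_attention(None, None, None)
after = F.scaled_dot_product_attention(None, None, None)
print(inside, after)  # ffpa sdpa
```

With the real libraries you would pass `torch.nn.functional`, `"scaled_dot_product_attention"` and `ffpa_attn_func`. Note that patching an attribute only affects callers that look the function up through `F` at call time; code that already ran `from torch.nn.functional import scaled_dot_product_attention` keeps its own reference to the original.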

For more advanced features, please refer to our online docs at 📘ffpa-attn.io.

📖 Split-D

We extend FlashAttention to support large headdim ($D>256$) via fine-grained tiling at the MMA level for the $QK^\top$ and $PV$ matrix multiplications, referred to as Split-D. This design keeps SRAM usage fixed at $B_r \times 16$ (with $B_r=B_c$) for Q, K and V, yielding constant SRAM complexity $O(B_r \times 16) \approx O(1)$ and register complexity $O(d/4)$.
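To make the idea concrete, here is a pure-Python sketch (not the FFPA kernel) of Split-D for one block of $B_r$ query rows: both $QK^\top$ and $PV$ are computed in 16-wide slices along the head dimension, so each step only needs a 16-column slice of Q, K or V. The slice width `TILE_D = 16` follows the text; everything else is a simplifying assumption, in particular a single K/V block with an exact softmax instead of FlashAttention's online softmax over many blocks. The result is checked against naive attention.

```python
import math
import random

TILE_D = 16  # Split-D slice width along the head dimension (the 16 in B_r x 16)


def qk_split_d(q, k, tile_d=TILE_D):
    """S = Q K^T, accumulated over tile_d-wide slices of the head dimension."""
    br, bc, d = len(q), len(k), len(q[0])
    s = [[0.0] * bc for _ in range(br)]
    for d0 in range(0, d, tile_d):
        d1 = min(d0 + tile_d, d)
        # Only q[:, d0:d1] and k[:, d0:d1] are needed in this step.
        for i in range(br):
            for j in range(bc):
                s[i][j] += sum(q[i][t] * k[j][t] for t in range(d0, d1))
    return s


def pv_split_d(p, v, tile_d=TILE_D):
    """O = P V, produced one tile_d-wide slice of output columns at a time."""
    br, bc, d = len(p), len(v), len(v[0])
    o = [[0.0] * d for _ in range(br)]
    for d0 in range(0, d, tile_d):
        d1 = min(d0 + tile_d, d)
        # Only v[:, d0:d1] is needed to finish o[:, d0:d1].
        for i in range(br):
            for t in range(d0, d1):
                o[i][t] = sum(p[i][j] * v[j][t] for j in range(bc))
    return o


def softmax_rows(s, scale):
    """Row-wise softmax of scale * s, subtracting the row max for stability."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp((x - m) * scale) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out


def split_d_attention(q, k, v, tile_d=TILE_D):
    scale = 1.0 / math.sqrt(len(q[0]))
    p = softmax_rows(qk_split_d(q, k, tile_d), scale)
    return pv_split_d(p, v, tile_d)


def naive_attention(q, k, v):
    scale = 1.0 / math.sqrt(len(q[0]))
    s = [[sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]
    p = softmax_rows(s, scale)
    return [[sum(pi[j] * v[j][t] for j in range(len(v))) for t in range(len(v[0]))] for pi in p]


random.seed(0)
Br, N, D = 4, 8, 320  # D > 256 is the regime FFPA targets
q, k, v = ([[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(rows)] for rows in (Br, N, N))
o_split = split_d_attention(q, k, v)
o_ref = naive_attention(q, k, v)
max_err = max(abs(a - b) for ra, rb in zip(o_split, o_ref) for a, b in zip(ra, rb))
print(f"max |split-D - naive| = {max_err:.1e}")
```

Splitting along D only changes the order in which partial products are summed, so the two results agree up to floating-point rounding.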

FFPA enables headdim > 256 and outperforms standard SDPA by 1.5~3x🎉.
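As a rough illustration of why the Split-D tile matters for large headdim, the sketch below counts only the Q, K and V tiles, assuming 2-byte (fp16/bf16) elements and $B_r = B_c = 64$. Both values are illustrative choices, not FFPA's actual tile configuration, and the count ignores other buffers such as multi-stage pipelining. A tile spanning the full headdim grows linearly with D, while the Split-D tile stays at $B_r \times 16$:

```python
BYTES_PER_ELEM = 2  # fp16/bf16 (assumed)


def full_d_tile_bytes(br: int, d: int) -> int:
    """Q, K and V tiles of shape Br x D each (Br = Bc), i.e. tiles spanning the whole headdim."""
    return 3 * br * d * BYTES_PER_ELEM


def split_d_tile_bytes(br: int, tile_d: int = 16) -> int:
    """Q, K and V tiles of shape Br x 16 each, independent of D."""
    return 3 * br * tile_d * BYTES_PER_ELEM


table = {d: (full_d_tile_bytes(64, d), split_d_tile_bytes(64)) for d in (128, 256, 512, 1024)}
for d, (full, split) in table.items():
    print(f"D={d:5d}: full-D {full / 1024:6.1f} KiB vs Split-D {split / 1024:4.1f} KiB")
```

Under these assumptions the full-D tiles need 48 KiB at D=128 but 384 KiB at D=1024, while the Split-D tiles stay at 6 KiB for every D.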

Note

FFPA has been tested on Ampere, Ada, Hopper, and Blackwell architectures (e.g., A30, L20, 4090, H200, 5090), achieving a 1.5~3×↑🎉 speedup over SDPA. FFPA is mainly designed for prefill with large headdim, and may not be faster than SDPA for 😈 small sequence lengths (N<512) or small headdim (D<=256).
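If your workload mixes short and long sequences, one option is a small user-side router that sends only long-sequence, large-headdim calls to FFPA, using the thresholds from the note above (N >= 512 and D > 256). This is a hypothetical sketch, not part of ffpa-attn: `make_attention_router` and `fake_tensor` are illustrative names, it assumes the SDPA-style `(..., N, D)` layout, and the right crossover point on your GPU should be measured rather than taken from these defaults.

```python
import types
from typing import Any, Callable

AttnFn = Callable[..., Any]


def make_attention_router(fast_fn: AttnFn, fallback_fn: AttnFn,
                          small_head_dim: int = 256, small_seq_len: int = 512) -> AttnFn:
    """Route large-headdim, long-sequence calls to fast_fn and everything else to fallback_fn."""
    def attention(q, k, v, **kwargs):
        seq_len, head_dim = q.shape[-2], q.shape[-1]  # assumes (..., N, D) layout
        if head_dim > small_head_dim and seq_len >= small_seq_len:
            return fast_fn(q, k, v, **kwargs)
        return fallback_fn(q, k, v, **kwargs)
    return attention


def fake_tensor(*shape):  # stand-in exposing only .shape, so the sketch runs anywhere
    return types.SimpleNamespace(shape=shape)


route = make_attention_router(lambda q, k, v, **kw: "ffpa", lambda q, k, v, **kw: "sdpa")
print(route(fake_tensor(1, 8, 4096, 512), None, None))  # ffpa
print(route(fake_tensor(1, 8, 256, 512), None, None))   # sdpa: short sequence
print(route(fake_tensor(1, 8, 4096, 128), None, None))  # sdpa: small headdim
```

In real code `fast_fn` would be `ffpa_attn_func` and `fallback_fn` the original `F.scaled_dot_product_attention`, captured before any monkey-patching. FFPA already falls back on its own for D <= 256, so the extra value of such a wrapper is mainly the sequence-length check.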

🎉 Benchmark

Runnable examples are provided in the examples directory. Performance benchmarks with large headdim on the NVIDIA L20 (Ada), NVIDIA GeForce RTX 5090 (Blackwell), NVIDIA H800 PCIe (Hopper), and NVIDIA H200 SXM (Hopper, CuTeDSL backend, up to 427 TFLOPS!🎉) are shown below:




🤖 Backends

FFPA supports multiple backends for the forward and backward passes: SDPA (baseline), CUDA (forward only), Triton, and CuTeDSL. The CuTeDSL backend is still at an early stage and has some constraints (e.g., D <= 512), but it can reach up to 427🎉 TFLOPS on H200! Stay tuned for future updates.

| Backend | Arch | Headdim | Speedup | Recommended Arch |
|---|---|---|---|---|
| SDPA | sm>=75 | All | 1.0x🤗 | sm>=75 |
| CUDA | sm>=80 | 320~1024 | 1.5x~3x🎉 | sm80~89, sm120 |
| Triton | sm>=80 | 320~1024 | 1.5x~3x🎉 | sm>=80 |
| CuTeDSL | sm80~89 | 320~1024 | 1.5x~2x🎉 | sm80~89 |
| CuTeDSL | sm90 | 320~512 | 3x~6x🎉 | sm90 |
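The Recommended Arch and Headdim columns can be read as a simple lookup. The sketch below transcribes them into a hypothetical helper, `recommended_backends`, which returns the names of the matching backend configs in table order and falls back to the SDPA baseline when no row matches. It only encodes the table; it does not query ffpa-attn, and it does not rank backends by measured speed.

```python
from typing import List

# (backend config name, recommended-arch predicate, (min headdim, max headdim)),
# transcribed from the backend table above.
RECOMMEND_TABLE = [
    ("CUDABackend", lambda sm: 80 <= sm <= 89 or sm == 120, (320, 1024)),
    ("TritonBackend", lambda sm: sm >= 80, (320, 1024)),
    ("CuTeDSLBackend", lambda sm: 80 <= sm <= 89, (320, 1024)),
    ("CuTeDSLBackend", lambda sm: sm == 90, (320, 512)),
]


def recommended_backends(sm: int, head_dim: int) -> List[str]:
    """Backend config names whose recommended arch and headdim range match; SDPA otherwise."""
    picks: List[str] = []
    for name, arch_ok, (lo, hi) in RECOMMEND_TABLE:
        if arch_ok(sm) and lo <= head_dim <= hi and name not in picks:
            picks.append(name)
    return picks or ["SDPABackend"]  # outside every listed range, stay on the SDPA baseline


print(recommended_backends(90, 512))   # ['TritonBackend', 'CuTeDSLBackend']
print(recommended_backends(120, 640))  # ['CUDABackend', 'TritonBackend']
print(recommended_backends(75, 512))   # ['SDPABackend']
```

On a real machine, `sm` can be derived from `torch.cuda.get_device_capability()`, which returns `(major, minor)`, as `10 * major + minor`. You would then construct the chosen config (e.g. `CuTeDSLBackend()`) and pass it to `ffpa_attn_func`.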

Special thanks to Butterfingrz for contributing to the CuTeDSL backend! Awesome work!🎉

To use a different backend for your scenario, pass one of the backend configs (SDPABackend, CUDABackend, TritonBackend, or CuTeDSLBackend) to ffpa_attn_func via the backend argument, for example:

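>>> import torch
>>> from ffpa_attn import ffpa_attn_func, CuTeDSLBackend
>>> # SDPA-style (B, H, N, D) fp16 inputs on the GPU, D=512 (layout assumed from the SDPA drop-in)
>>> q, k, v = (torch.randn(1, 8, 8192, 512, device="cuda", dtype=torch.float16) for _ in range(3))
>>> # CuTeDSL backend, D=512 scenario, fastest on H200!🎉
>>> o = ffpa_attn_func(q, k, v, backend=CuTeDSLBackend())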

©️License

Apache License 2.0

©️Citations

@misc{ffpa-attn@2025,
  title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
  url={https://github.com/xlite-dev/ffpa-attn.git},
  note={Open-source software available at https://github.com/xlite-dev/ffpa-attn.git},
  author={DefTruth, Butterfingrz},
  year={2025}
}


