
Track FlashAttention-3 landing in PyTorch SDPA #41

@mmshad

Description


Context

FlashAttention-3 (Dao et al., 2024) offers a 1.5–2.0x speedup over FA2 on Hopper GPUs (H100/H200) for bf16/fp16, plus FP8 attention support.

KempnerForge routes attention through PyTorch SDPA (F.scaled_dot_product_attention) in kempnerforge/model/attention.py, with an sdpa_backend config option ("auto", "flash", "efficient", "cudnn", "math"). SDPA currently dispatches to FA2, not FA3.
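For reference, a minimal sketch of how that routing maps onto PyTorch's SDPA backend API (the dictionary and the sdpa helper below are illustrative, not the actual KempnerForge attention.py):

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative mapping from the sdpa_backend config string to PyTorch's
# backend selector (not the actual KempnerForge code).
_BACKENDS = {
    "flash": SDPBackend.FLASH_ATTENTION,        # currently dispatches to FA2
    "efficient": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn": SDPBackend.CUDNN_ATTENTION,
    "math": SDPBackend.MATH,
}

def sdpa(q, k, v, sdpa_backend="auto", is_causal=True):
    if sdpa_backend == "auto":
        # Let PyTorch pick the fastest backend available for these inputs/GPU.
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    # Otherwise restrict dispatch to the requested backend.
    with sdpa_kernel(_BACKENDS[sdpa_backend]):
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

Because "auto" defers backend selection to PyTorch entirely, an upstream FA3 backend would be picked up with no change to this code path.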

Why not add the standalone flash-attn package

FA3 ships outside PyTorch in the flash-attn Hopper build. Pulling it in conflicts with KempnerForge's PyTorch-native design:

  • Requires a source-build step (the default pip install flash-attn wheel ships FA2 kernels, not the FA3 Hopper build)
  • Hopper-only: adds a backend branch in Attention.forward() that only activates on H100/H200 (sketched below)
  • Adds a non-PyTorch dependency to track and version
  • Duplicates work PyTorch will land upstream
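To make the maintenance cost concrete, this is roughly what that Hopper-only branch would look like. The flash_attn_interface import path and the return handling are assumptions about the out-of-tree FA3 build, and attention_forward is a hypothetical stand-in for Attention.forward(), not code KempnerForge would ship:

```python
import torch
import torch.nn.functional as F

def _is_hopper(device: torch.device) -> bool:
    # Hopper (H100/H200) reports CUDA compute capability 9.0.
    return device.type == "cuda" and torch.cuda.get_device_capability(device) == (9, 0)

def attention_forward(q, k, v, is_causal=True, use_fa3=False):
    if use_fa3 and _is_hopper(q.device):
        # Assumed import path for the standalone FA3 Hopper build.
        from flash_attn_interface import flash_attn_func
        # flash-attn expects (B, S, H, D) layout while SDPA uses (B, H, S, D),
        # so the branch also carries layout transposes both ways.
        out = flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=is_causal
        )
        if isinstance(out, tuple):  # some FA3 builds also return the softmax LSE
            out = out[0]
        return out.transpose(1, 2)
    # Everywhere else, fall back to the existing SDPA path.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```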

What to track

  • PyTorch adding FA3 to SDPA (either as a new SDPBackend enum value or by upgrading the existing FLASH_ATTENTION backend to FA3 on Hopper)
  • Once landed, sdpa_backend="auto" picks it up automatically on Hopper — no code change needed in KempnerForge

Action

Watch PyTorch release notes. Revisit when SDPA exposes FA3.
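A quick check to run against each new PyTorch release; an FA3-related member appearing in the backend list (upstream has not settled on a name, so the check just prints what is exposed) would be the trigger to revisit:

```python
import torch
from torch.nn.attention import SDPBackend

# Print the SDP backends this PyTorch build exposes; a new FA3-related member
# showing up here is the signal to revisit this issue.
print(torch.__version__)
print([name for name in dir(SDPBackend) if name.isupper()])
```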

Priority

Blocked on upstream.
