Context
FlashAttention-3 (Shah et al., 2024) offers 1.5–2.0x speedup over FA2 on Hopper GPUs (H100/H200) for bf16/fp16, plus FP8 attention support.
KempnerForge routes attention through PyTorch SDPA (F.scaled_dot_product_attention) in kempnerforge/model/attention.py, with an sdpa_backend config option ("auto", "flash", "efficient", "cudnn", "math"). SDPA currently dispatches to FA2, not FA3.
Why not add the standalone flash-attn package
FA3 ships outside PyTorch in the flash-attn Hopper build. Pulling it in conflicts with KempnerForge's PyTorch-native design:
- Source-build step required (the default pip install flash-attn wheel ships FA2 kernels, not the FA3 Hopper build)
- Hopper-only: adds a backend branch in Attention.forward() that only activates on H100/H200
- Non-PyTorch dep to track and version
- Duplicates work PyTorch will land upstream
What to track
- PyTorch adding FA3 to SDPA (either as a new SDPBackend enum value or by upgrading the existing FLASH_ATTENTION backend to FA3 on Hopper)
- Once landed, sdpa_backend="auto" picks it up automatically on Hopper, with no code change needed in KempnerForge
Action
Watch PyTorch release notes. Revisit when SDPA exposes FA3.
Priority
Blocked on upstream.