opencl: flash attention improvement#25069
Open
wanghqc wants to merge 10 commits into
Open
Conversation
- flash_attn_kv_pad_f16 pads the tail KV tile to a BLOCK_N multiple
- flash_attn_mask_pad_f16 pads the matching mask tile
- flash_attn_blk_f16 classifies each KV tile per query block as
fully masked / mixed / fully unmasked, so
the main kernel can skip fully-masked tiles
and the mask lookup for fully-unmasked ones
|
Hi @wanghqc, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Overview
Rework the FA for OpenCL backend to improve precision and performance, support quantized KV cache. Tested with gpt-oss-20b model. Works well with models with head_dim of 64. For larger head_dim, the main benefit is the data traffic savings.
Additional information
This targets the Adreno GPUs, tested with Adreno GPUs in flagship android devices, and Windows on Snapdragon (WoS) (X1,and X2 GPUs).
Requirements