You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (… (#22286)
* ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32)
Adds MMA-f16 and tile kernel configs, dispatch logic, template instances,
and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to
ncols2=32 to support GQA ratio 32 only.
* Adding check to return BEST_FATTN_KERNEL_NONE in case GQA!=32
* Apply suggestions from code review
Address review comments
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* Address review comments and making kernel config default to DQK=512, DV=512 instead of DQK=256,DV=256
* Fixed bug with sinks=1, with ncols=32, there are two warp-groups created but sinks index is same(0,...,15) for both the groups hence with sinks=1, output is not matching with CPU output. Added sink_base which will be base index for each warp_group (threadIdx.y / np)
* Apply suggestions from code review
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* Update ggml/src/ggml-cuda/template-instances/generate_cu_files.py
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
0 commit comments