Support separate K/V seq dim in custom_sdpa op #18714
kimishpatel wants to merge 1 commit into gh/kimishpatel/234/base from
Conversation
Previously `custom_sdpa` used a single `is_seq_at_dim_2` flag for all tensors. This meant the v_only transpose path required a runtime transpose copy for K (converting from [B, H, S, D] to [B, S, H, D]), which caused a 2.3x decode slowdown (15.35 vs 35.63 tok/s).

Now the C++ op accepts separate `is_seq_dim_2`, `is_k_seq_dim_2`, and `is_v_seq_dim_2` flags so Q, K, and V can each have independent layouts. The Python layer passes K and V in their native cache layout without any transpose, and the flash attention kernel handles the mixed strides directly.

Changes:
- op_sdpa_impl.h: `cpu_flash_attention` takes `q_seq_dim`, `k_seq_dim`, `v_seq_dim` instead of a single `seq_dim`
- op_sdpa.cpp/h: `custom_sdpa_out` takes 3 bool params
- op_sdpa_aot.cpp: updated schema strings and wrappers
- sdpa.py: `SDPACustom` uses `is_k_seq_at_dim_2` / `is_v_seq_at_dim_2`; Q is always at dim 2, no input transposes
- custom_kv_cache.py: `update()` returns the native cache layout; added an `is_seq_at_dim_2` compat property
- export_llama_lib.py: passes separate K/V flags

Differential Revision: [D99677678](https://our.internmc.facebook.com/intern/diff/D99677678/)

[ghstack-poisoned]
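The key idea — letting each of Q/K/V carry its own sequence dimension so the kernel reads through mixed strides instead of copying K into a canonical layout — can be sketched outside ExecuTorch. This is a minimal NumPy illustration, not the actual `custom_sdpa` implementation: the `sdpa` helper and its `k_seq_dim`/`v_seq_dim` parameters are hypothetical names chosen to mirror the PR's flags, and `np.swapaxes` stands in for the stride swap the C++ kernel performs.

```python
import numpy as np

def sdpa(q, k, v, k_seq_dim=2, v_seq_dim=2):
    """Toy scaled dot-product attention accepting K/V with the sequence
    dimension at either dim 1 ([B, S, H, D]) or dim 2 ([B, H, S, D]).
    Q is assumed to be [B, H, S, D] (seq at dim 2), as in the PR.

    Instead of copying K/V into a canonical layout, we take transposed
    *views* (np.swapaxes returns a view, analogous to the stride swap
    in the C++ kernel), so no data movement occurs.
    """
    if k_seq_dim == 1:  # [B, S, H, D] -> viewed as [B, H, S, D]
        k = np.swapaxes(k, 1, 2)
    if v_seq_dim == 1:
        v = np.swapaxes(v, 1, 2)
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = q @ np.swapaxes(k, -1, -2) * scale    # [B, H, S_q, S_k]
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v                               # [B, H, S_q, D]

rng = np.random.default_rng(0)
B, H, S, D = 1, 2, 4, 8
q = rng.standard_normal((B, H, S, D))
k_bhsd = rng.standard_normal((B, H, S, D))
v_bhsd = rng.standard_normal((B, H, S, D))
# The same data stored in a [B, S, H, D] cache layout:
k_bshd = np.ascontiguousarray(np.swapaxes(k_bhsd, 1, 2))
v_bshd = np.ascontiguousarray(np.swapaxes(v_bhsd, 1, 2))

out_ref = sdpa(q, k_bhsd, v_bhsd, k_seq_dim=2, v_seq_dim=2)
out_mix = sdpa(q, k_bshd, v_bshd, k_seq_dim=1, v_seq_dim=1)
assert np.allclose(out_ref, out_mix)
# The swapped view aliases the cache buffer -- no transpose copy was made:
assert np.swapaxes(k_bshd, 1, 2).base is k_bshd
```

The point of the sketch is the final assertion: reading K through a swapped-stride view produces identical results to the canonical layout while sharing the cache's memory, which is why dropping the runtime transpose copy recovers the decode throughput cited above.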
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18714.

❌ 126 New Failures, 2 Cancelled Jobs as of commit 45468e1 with merge base fb1618e. Note: links to docs will display an error until the docs builds have completed. Cancelled jobs may be retried.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
submitted by accident, not meant to land immediately