Previously custom_sdpa used a single is_seq_at_dim_2 flag for all
tensors. This meant the v_only transpose mode required a runtime
transpose copy for K (converting from [B,H,S,D] to [B,S,H,D]), which
caused a 2.3x decode slowdown (15.35 vs. 35.63 tok/s).
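To make the cost concrete, here is a minimal stock-PyTorch sketch (shapes and the forced-copy direction are illustrative, not taken from the kernel) of why a layout mismatch under a single flag turns into a per-step copy rather than a free view:

```python
import torch

B, H, S, D = 1, 32, 128, 128

# K cache stored in its native [B, S, H, D] layout.
k_cache = torch.randn(B, S, H, D)

# Old path: the single is_seq_at_dim_2 flag forced every tensor into one
# layout, so K had to be materialized in the other layout on each step.
k_copied = k_cache.transpose(1, 2).contiguous()  # allocates and copies

# The transpose itself is free (it only permutes strides); the copy is
# what .contiguous() pays for.
k_view = k_cache.transpose(1, 2)
assert k_view.data_ptr() == k_cache.data_ptr()    # view: same storage
assert k_copied.data_ptr() != k_cache.data_ptr()  # copy: new storage
```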
Now the C++ op accepts separate is_seq_dim_2, is_k_seq_dim_2,
is_v_seq_dim_2 flags so Q, K, V can each have independent layouts.
The Python layer passes K and V in their native cache layout
without any transpose, and the flash attention kernel handles the
mixed strides directly.
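As a sanity check of why no copy is needed, here is a sketch using stock PyTorch's SDPA rather than the custom op: a kernel that honors strides can consume the transposed views, which share storage with the caches, and still match the transpose-copy path bit for bit:

```python
import torch
import torch.nn.functional as F

B, H, S, D = 1, 4, 16, 8
q = torch.randn(B, H, S, D)         # Q stays at [B, H, S, D]
k_cache = torch.randn(B, S, H, D)   # K in its native cache layout
v_cache = torch.randn(B, S, H, D)   # V in its native cache layout

# Old path: materialize contiguous [B, H, S, D] copies of K and V.
ref = F.scaled_dot_product_attention(
    q, k_cache.transpose(1, 2).contiguous(),
    v_cache.transpose(1, 2).contiguous())

# Stride-aware path: the transposed *views* share storage with the
# caches; no data moves, only the strides differ.
out = F.scaled_dot_product_attention(
    q, k_cache.transpose(1, 2), v_cache.transpose(1, 2))

torch.testing.assert_close(out, ref)
```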
Changes:
- op_sdpa_impl.h: cpu_flash_attention takes q_seq_dim, k_seq_dim,
  v_seq_dim instead of a single seq_dim (see the sketch after this list)
- op_sdpa.cpp/h: custom_sdpa_out takes 3 bool params
- op_sdpa_aot.cpp: Updated schema strings and wrappers
- sdpa.py: SDPACustom uses is_k_seq_at_dim_2 / is_v_seq_at_dim_2,
Q always at dim 2, no input transposes
- custom_kv_cache.py: update() returns native cache layout,
added is_seq_at_dim_2 compat property
- export_llama_lib.py: passes separate K/V flags
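As a rough Python sketch of the per-tensor seq-dim plumbing described for cpu_flash_attention (the real change is in C++; the helper names here are hypothetical, not the kernel's API):

```python
import torch

def seq_dim(is_seq_at_dim_2: bool) -> int:
    # The two layouts in play: [B, H, S, D] (seq at dim 2) and
    # [B, S, H, D] (seq at dim 1); the head dim is whichever remains.
    return 2 if is_seq_at_dim_2 else 1

def bhsd_strides(t: torch.Tensor, is_seq_at_dim_2: bool):
    # (batch, head, seq, feature) strides: all a flash-attention inner
    # loop needs to walk either layout in place, chosen per tensor.
    s = seq_dim(is_seq_at_dim_2)
    h = 1 if s == 2 else 2
    return t.stride(0), t.stride(h), t.stride(s), t.stride(3)

# Q fixed at [B, H, S, D]; K and V read in their native cache layout.
q_strides = bhsd_strides(torch.randn(1, 4, 16, 8), True)
k_strides = bhsd_strides(torch.randn(1, 16, 4, 8), False)
v_strides = bhsd_strides(torch.randn(1, 16, 4, 8), False)
```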
Differential Revision: [D99677678](https://our.internmc.facebook.com/intern/diff/D99677678/)