Plumb transposed cache config through export pipeline #18712

kimishpatel wants to merge 1 commit into gh/kimishpatel/232/base
Conversation
Benchmarking shows that a transposed KV cache layout [B, H, S, D] significantly outperforms the standard layout [B, S, H, D] in custom_sdpa, especially at longer cache fills: 1.64x at start_pos=1024, 1.14x at start_pos=512, and 1.13x for prefill seq_len=512 (Llama 3 8B config, Apple M-series). The improvement comes from better memory locality in the attn_score @ V GEMM, where the V stride along S_kv changes from H*D to D.

This commit replaces the hardcoded `is_seq_at_dim_2=True  # hacking temporarily` values in sdpa.py and custom_kv_cache.py with a proper configurable parameter threaded through the export pipeline:

- Add `use_transposed_cache: bool = True` to ModelConfig in llm_config.py
- Thread it through `_get_source_transforms` in export_llama_lib.py
- Add an `is_seq_at_dim_2` parameter to `replace_kv_cache_with_custom_kv_cache` and `replace_sdpa_with_custom_op` (defaulting to True for backward compatibility)

Also fixes:

- torchao aarch64:matmul BUCK: `deps` -> `exported_deps` for `:macro`, fixing transitive header visibility on arm64
- op_update_cache.cpp: `%zd` -> `PRId64` for int64_t format strings

Authored with Claude.

Differential Revision: [D99677679](https://our.internmc.facebook.com/intern/diff/D99677679/)

[ghstack-poisoned]
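A minimal sketch of the memory-locality argument above, using NumPy rather than the actual custom_sdpa kernel: stepping along the sequence dimension of V in the transposed layout touches memory with stride D, versus H*D in the standard layout. The dimensions are hypothetical Llama-like values, not taken from the benchmark.

```python
import numpy as np

# Hypothetical dims: batch, heads, cache sequence length, head dim.
B, H, S, D = 1, 32, 1024, 128

# Standard layout [B, S, H, D]: advancing one step along S for a fixed head
# skips H * D elements.
v_standard = np.zeros((B, S, H, D), dtype=np.float32)
stride_standard = v_standard.strides[1] // v_standard.itemsize

# Transposed layout [B, H, S, D]: advancing one step along S skips only D
# elements, so the attn_score @ V GEMM reads each head's V rows contiguously.
v_transposed = np.zeros((B, H, S, D), dtype=np.float32)
stride_transposed = v_transposed.strides[2] // v_transposed.itemsize

print(stride_standard, stride_transposed)  # 4096 128 for these dims
```

With H=32 and D=128, the per-step stride along S_kv drops from 4096 elements to 128, which is the locality win the benchmark numbers reflect.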
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18712

Note: links to docs will display an error until the docs builds have been completed.

❌ 9 New Failures, 2 Cancelled Jobs as of commit 6f31ee3 with merge base fb1618e. Cancelled jobs can be retried.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
submitted by accident, not meant to land immediately
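The config plumbing in the PR summary can be sketched as follows. ModelConfig, `_get_source_transforms`, and the two `replace_*` functions are named in the description, but the signatures here are assumed, and the dict-based "module" is a stand-in for the real graph-rewriting transforms, purely for illustration.

```python
# Hypothetical sketch of the plumbing this PR adds; names come from the PR
# description, signatures and bodies are illustrative assumptions.
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # New field from llm_config.py; True selects the transposed [B, H, S, D] layout.
    use_transposed_cache: bool = True


def replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2: bool = True):
    # Previously hardcoded as `is_seq_at_dim_2=True  # hacking temporarily`.
    module["kv_cache_is_seq_at_dim_2"] = is_seq_at_dim_2  # stand-in for the rewrite
    return module


def replace_sdpa_with_custom_op(module, is_seq_at_dim_2: bool = True):
    module["sdpa_is_seq_at_dim_2"] = is_seq_at_dim_2  # stand-in for the rewrite
    return module


def _get_source_transforms(config: ModelConfig):
    # Thread the single config value through to both source transforms.
    flag = config.use_transposed_cache
    return [
        lambda m: replace_kv_cache_with_custom_kv_cache(m, is_seq_at_dim_2=flag),
        lambda m: replace_sdpa_with_custom_op(m, is_seq_at_dim_2=flag),
    ]


module = {}
for transform in _get_source_transforms(ModelConfig(use_transposed_cache=False)):
    module = transform(module)
print(module)
```

Defaulting both `is_seq_at_dim_2` parameters to True preserves the current (transposed) behavior for existing callers while letting the exported config flip both transforms from one place.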