From 6f31ee3c3a2c6f6493c93c07f7b1a70642ca1be8 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 6 Apr 2026 08:47:45 -0700 Subject: [PATCH] Plumb transposed cache config through export pipeline Benchmarking shows that the transposed KV cache layout [B, H, S, D] significantly outperforms the standard layout [B, S, H, D] in custom_sdpa, especially at longer cache fills: 1.64x at start_pos=1024, 1.14x at start_pos=512, and 1.13x for prefill with seq_len=512 (Llama 3 8B config, Apple M-series). The improvement comes from better memory locality in the attn_score @ V GEMM, where the stride of V along S_kv changes from H*D to D. This commit replaces the hardcoded `is_seq_at_dim_2=True # hacking temporarily` values in sdpa.py and custom_kv_cache.py with a proper configurable parameter threaded through the export pipeline: - Add `use_transposed_cache: bool = True` to ModelConfig in llm_config.py - Thread it through _get_source_transforms in export_llama_lib.py - Add an `is_seq_at_dim_2` parameter to replace_kv_cache_with_custom_kv_cache and replace_sdpa_with_custom_op (defaulting to True for backward compatibility) Also fixes: - torchao aarch64:matmul BUCK: deps -> exported_deps for :macro, fixing transitive header visibility on arm64 - op_update_cache.cpp: %zd -> PRId64 for int64_t format strings Authored with Claude. 
Differential Revision: [D99677679](https://our.internmc.facebook.com/intern/diff/D99677679/) [ghstack-poisoned] --- examples/models/llama/export_llama_lib.py | 13 ++++++++++--- .../source_transformation/custom_kv_cache.py | 17 ++++++++--------- .../models/llama/source_transformation/sdpa.py | 10 +++++----- extension/llm/custom_ops/op_update_cache.cpp | 8 ++++---- extension/llm/export/config/llm_config.py | 5 +++++ 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 7d6371add44..be29dd448e5 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -765,6 +765,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: llm_config.model, "use_custom_sdpa_with_attention_mask", False ), use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, + use_transposed_cache=llm_config.model.use_transposed_cache, quantize_kv_cache=llm_config.model.quantize_kv_cache, use_kv_cache=llm_config.model.use_kv_cache, qnn=llm_config.backend.qnn.enabled, @@ -1603,6 +1604,7 @@ def _get_source_transforms( # noqa expand_rope_table: bool = False, use_custom_sdpa_with_attention_mask: bool = False, use_sdpa_with_kv_cache: bool = False, + use_transposed_cache: bool = True, quantize_kv_cache: bool = False, use_kv_cache: bool = False, qnn: bool = False, @@ -1639,6 +1641,7 @@ def _get_source_transforms( # noqa expand_rope_table: Whether to expand rope table. use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask. use_sdpa_with_kv_cache: Whether to use SDPA with KV cache. + use_transposed_cache: Whether to store KV cache in transposed layout [B, H, S, D]. quantize_kv_cache: Whether to quantize KV cache. use_kv_cache: Whether to use KV cache. qnn: Whether to use QNN. 
@@ -1734,16 +1737,20 @@ def _get_source_transforms( # noqa use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask if use_sdpa_with_kv_cache: - transforms.append(replace_kv_cache_with_custom_kv_cache) + transforms.append( + partial(replace_kv_cache_with_custom_kv_cache, is_seq_at_dim_2=use_transposed_cache) + ) # todo: do this optionally # if use attention mask instead of causal attention # then create partial function that sets use_attention_mask=True if use_attention_mask_for_custom_sdpa: transforms.append( - partial(replace_sdpa_with_custom_op, use_attention_mask=True) + partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=use_transposed_cache) ) else: - transforms.append(replace_sdpa_with_custom_op) + transforms.append( + partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=use_transposed_cache) + ) if quantize_kv_cache: assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 64822fd0c42..c5a056ff2e9 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -391,21 +391,20 @@ def update( return (self.k_cache, self.v_cache) -def replace_kv_cache_with_custom_kv_cache(module): +def replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True): """ Replace KVCache with CustomKVCache. This modifies the model in place. - At the moment custom kv cache only supports cache with shape - [B, S, H, D] as opposed to [B, H, S, D] - This is because the custom op treats second dim as sequence dim. - Future work: support [B, H, S, D] + When is_seq_at_dim_2=True, cache is stored as [B, H, S, D] (transposed), + which improves SDPA GEMM performance via better memory locality. + When is_seq_at_dim_2=False, cache is stored as [B, S, H, D] (standard). 
""" logging.info( "Replacing KVCache with CustomKVCache. This modifies the model in place." ) - return _replace_kv_cache_with_custom_kv_cache(module) + return _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=is_seq_at_dim_2) -def _replace_kv_cache_with_custom_kv_cache(module): +def _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True): for name, child in module.named_children(): if isinstance(child, KVCache): cache_dtype = child.k_cache.dtype @@ -422,11 +421,11 @@ def _replace_kv_cache_with_custom_kv_cache(module): n_heads, head_dim, dtype=cache_dtype, - is_seq_at_dim_2=True, # hacking temporarily + is_seq_at_dim_2=is_seq_at_dim_2, ), ) else: - _replace_kv_cache_with_custom_kv_cache(child) + _replace_kv_cache_with_custom_kv_cache(child, is_seq_at_dim_2=is_seq_at_dim_2) return module diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index a3544d5704e..abdd039c10f 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -82,7 +82,7 @@ def forward( def _replace_sdpa_with_custom_op( - module: torch.nn.Module, use_attention_mask: bool = False + module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True ): for name, child in module.named_children(): if isinstance(child, SDPA): @@ -92,19 +92,19 @@ def _replace_sdpa_with_custom_op( SDPACustom( child.dim, use_attention_mask=use_attention_mask, - is_seq_at_dim_2=True, # hacking temporarily + is_seq_at_dim_2=is_seq_at_dim_2, ), ) else: - _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask) + _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2) def replace_sdpa_with_custom_op( - module: torch.nn.Module, use_attention_mask: bool = False + module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True ) -> torch.nn.Module: from 
executorch.extension.llm.custom_ops import custom_ops # noqa - _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask) + _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2) return module diff --git a/extension/llm/custom_ops/op_update_cache.cpp b/extension/llm/custom_ops/op_update_cache.cpp index 5f918bd90bb..913714590fa 100644 --- a/extension/llm/custom_ops/op_update_cache.cpp +++ b/extension/llm/custom_ops/op_update_cache.cpp @@ -119,17 +119,17 @@ Tensor& update_cache_impl( ET_CHECK_MSG( value_batch_size == cache_batch_size, - "projected_value batch size (%zd) should be equal to the cache batch size (%zd).", + "projected_value batch size (%" PRId64 ") should be equal to the cache batch size (%" PRId64 ").", value_batch_size, cache_batch_size); ET_CHECK_MSG( value_num_heads == cache_num_heads, - "projected_value number of heads (%zd) should be equal to the cache number of heads (%zd).", + "projected_value number of heads (%" PRId64 ") should be equal to the cache number of heads (%" PRId64 ").", value_num_heads, cache_num_heads); ET_CHECK_MSG( value_head_dim == cache_head_dim, - "projected_value embedding dimension (%zd) should be equal to the cache embedding dimension (%zd).", + "projected_value embedding dimension (%" PRId64 ") should be equal to the cache embedding dimension (%" PRId64 ").", value_head_dim, cache_head_dim); ET_CHECK_MSG( @@ -210,7 +210,7 @@ Tensor& update_cache_impl( // Ensure the target position is valid ET_CHECK_MSG( target_pos >= 0 && target_pos < cache_seq_len, - "Index out of bounds: %" PRId64 " not in [0, %zd)", + "Index out of bounds: %" PRId64 " not in [0, %" PRId64 ")", target_pos, cache_seq_len); diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index e126ef54456..a9c1ffce349 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -174,6 +174,10 @@ class 
ModelConfig: use_sdpa_with_kv_cache: Whether to use flash attention by substituting for our custom SDPA op. Note that the naming is poor and this doesn't actually have anything to do with the kv_cache at the moment. + use_transposed_cache: Whether to store KV cache in transposed layout + [B, H, S, D] instead of standard [B, S, H, D]. Transposed layout + improves SDPA performance by enabling contiguous memory access in + the attn_score @ V GEMM (stride D instead of H*D along seq dim). expand_rope_table: Temporary workaround to expand sin/cos table in head dim to take vectorized path in optimized kernels. use_attention_sink: Whether to use attention sink to support multi-round @@ -194,6 +198,7 @@ class ModelConfig: enable_dynamic_shape: bool = True use_shared_embedding: bool = False use_sdpa_with_kv_cache: bool = False + use_transposed_cache: bool = True expand_rope_table: bool = False use_attention_sink: Optional[str] = None output_prune_map: Optional[str] = None