diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index be29dd448e5..5835296b183 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -334,6 +334,21 @@ def build_args_parser() -> argparse.ArgumentParser: action="store_true", help="Whether to use sdpa_with_kv_cache update op when using kv cache", ) + parser.add_argument( + "--no_transposed_cache", + dest="use_transposed_cache", + default=True, + action="store_false", + help="Disable transposed KV cache layout [B, H, S, D]. By default transposed cache is used for better SDPA performance.", + ) + parser.add_argument( + "--cache_transpose", + type=str, + default=None, + choices=["none", "all", "v_only", "k_only"], + help="Per-cache transpose control. Overrides --no_transposed_cache. " + "'v_only' transposes only the V cache for SDPA locality benefits.", + ) parser.add_argument( "--disable_dynamic_shape", dest="enable_dynamic_shape", @@ -766,6 +781,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: ), use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, use_transposed_cache=llm_config.model.use_transposed_cache, + cache_transpose=llm_config.model.cache_transpose, quantize_kv_cache=llm_config.model.quantize_kv_cache, use_kv_cache=llm_config.model.use_kv_cache, qnn=llm_config.backend.qnn.enabled, @@ -1605,6 +1621,7 @@ def _get_source_transforms( # noqa use_custom_sdpa_with_attention_mask: bool = False, use_sdpa_with_kv_cache: bool = False, use_transposed_cache: bool = True, + cache_transpose: Optional[str] = None, quantize_kv_cache: bool = False, use_kv_cache: bool = False, qnn: bool = False, @@ -1642,6 +1659,7 @@ def _get_source_transforms( # noqa use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask. use_sdpa_with_kv_cache: Whether to use SDPA with KV cache. use_transposed_cache: Whether to store KV cache in transposed layout [B, H, S, D]. + cache_transpose: Per-cache transpose control ('none','all','v_only','k_only'). Overrides use_transposed_cache. quantize_kv_cache: Whether to quantize KV cache. use_kv_cache: Whether to use KV cache. qnn: Whether to use QNN. @@ -1737,19 +1755,28 @@ def _get_source_transforms( # noqa use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask if use_sdpa_with_kv_cache: + # Resolve per-cache transpose flags + if cache_transpose is not None: + transpose_k = cache_transpose in ("all", "k_only") + transpose_v = cache_transpose in ("all", "v_only") + else: + transpose_k = use_transposed_cache + transpose_v = use_transposed_cache + + # SDPA uses is_seq_at_dim_2=True when any cache is transposed, + # since KVCache always returns [B, H, S, D] for Attention. + sdpa_seq_at_dim_2 = transpose_k or transpose_v + transforms.append( - partial(replace_kv_cache_with_custom_kv_cache, is_seq_at_dim_2=use_transposed_cache) + partial(replace_kv_cache_with_custom_kv_cache, transpose_k=transpose_k, transpose_v=transpose_v) ) - # todo: do this optionally - # if use attention mask instead of causal attention - # then create partial function that sets use_attention_mask=True if use_attention_mask_for_custom_sdpa: transforms.append( - partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=use_transposed_cache) + partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=sdpa_seq_at_dim_2) ) else: transforms.append( - partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=use_transposed_cache) + partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=sdpa_seq_at_dim_2) ) if quantize_kv_cache: diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index c5a056ff2e9..8aefc2c5973 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -334,6 +334,14 @@ def _replace_kv_cache_with_quantized_kv_cache(module): class CustomKVCache(nn.Module): + """Custom KV cache with independent K/V transpose control. + + Args: + transpose_k: If True, K cache is stored as [B, H, S, D] (transposed). + If False, stored as [B, S, H, D] (standard). + transpose_v: If True, V cache is stored as [B, H, S, D] (transposed). + If False, stored as [B, S, H, D] (standard). + """ def __init__( self, max_batch_size: int, @@ -341,24 +349,24 @@ def __init__( n_heads: int, head_dim: int, dtype=torch.float32, - is_seq_at_dim_2: bool = False, + transpose_k: bool = False, + transpose_v: bool = False, ): - self.is_seq_at_dim_2 = is_seq_at_dim_2 super().__init__() + self.transpose_k = transpose_k + self.transpose_v = transpose_v self.max_context_length = max_context_length - if self.is_seq_at_dim_2: - cache_shape = (max_batch_size, n_heads, max_context_length, head_dim) - else: - cache_shape = (max_batch_size, max_context_length, n_heads, head_dim) - self.max_batch_size = max_batch_size self.n_heads = n_heads self.head_dim = head_dim + + transposed_shape = (max_batch_size, n_heads, max_context_length, head_dim) + standard_shape = (max_batch_size, max_context_length, n_heads, head_dim) self.register_buffer( - "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + "k_cache", torch.zeros(transposed_shape if transpose_k else standard_shape, dtype=dtype, device="cpu") ) self.register_buffer( - "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + "v_cache", torch.zeros(transposed_shape if transpose_v else standard_shape, dtype=dtype, device="cpu") ) def update( @@ -368,43 +376,45 @@ def update( v_val: torch.Tensor, indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, H, S, D] - if not self.is_seq_at_dim_2: - k_val = k_val.transpose(1, 2) - v_val = v_val.transpose(1, 2) + # input_pos: [S], k_val/v_val: [B, H, S, D] from Attention start_pos = input_pos[0].item() + # Transpose k_val to match cache layout if needed + k_for_cache = k_val if self.transpose_k else k_val.transpose(1, 2) + v_for_cache = v_val if self.transpose_v else v_val.transpose(1, 2) + if indices is not None: _ = torch.ops.llama.update_cache_with_indices( - k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2 + k_for_cache, self.k_cache, start_pos, indices, self.transpose_k ) _ = torch.ops.llama.update_cache_with_indices( - v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2 + v_for_cache, self.v_cache, start_pos, indices, self.transpose_v ) else: - _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, self.is_seq_at_dim_2) - _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, self.is_seq_at_dim_2) + _ = torch.ops.llama.update_cache(k_for_cache, self.k_cache, start_pos, self.transpose_k) + _ = torch.ops.llama.update_cache(v_for_cache, self.v_cache, start_pos, self.transpose_v) - if not self.is_seq_at_dim_2: - return (k_val.transpose(1, 2), v_val.transpose(1, 2)) - else: - return (self.k_cache, self.v_cache) + # Return both caches in [B, H, S, D] for Attention + k_out = self.k_cache if self.transpose_k else self.k_cache.transpose(1, 2) + v_out = self.v_cache if self.transpose_v else self.v_cache.transpose(1, 2) + return (k_out, v_out) -def replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True): +def replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False): """ Replace KVCache with CustomKVCache. This modifies the model in place. - When is_seq_at_dim_2=True, cache is stored as [B, H, S, D] (transposed), - which improves SDPA GEMM performance via better memory locality. - When is_seq_at_dim_2=False, cache is stored as [B, S, H, D] (standard). + K and V caches can be independently transposed: + - transpose_k=True: K cache stored as [B, H, S, D] (transposed) + - transpose_v=True: V cache stored as [B, H, S, D] (transposed) + - When False, cache is stored as [B, S, H, D] (standard) """ logging.info( "Replacing KVCache with CustomKVCache. This modifies the model in place." ) - return _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=is_seq_at_dim_2) + return _replace_kv_cache_with_custom_kv_cache(module, transpose_k=transpose_k, transpose_v=transpose_v) -def _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True): +def _replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False): for name, child in module.named_children(): if isinstance(child, KVCache): cache_dtype = child.k_cache.dtype @@ -421,11 +431,12 @@ def _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True): n_heads, head_dim, dtype=cache_dtype, - is_seq_at_dim_2=is_seq_at_dim_2, + transpose_k=transpose_k, + transpose_v=transpose_v, ), ) else: - _replace_kv_cache_with_custom_kv_cache(child, is_seq_at_dim_2=is_seq_at_dim_2) + _replace_kv_cache_with_custom_kv_cache(child, transpose_k=transpose_k, transpose_v=transpose_v) return module diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index abdd039c10f..e6de718f38e 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -28,6 +28,8 @@ def __init__( super().__init__() self.dim = dim self.use_attention_mask = use_attention_mask + # When True, Q/K/V are in [B, H, S, D] and custom_sdpa uses seq_dim=2. + # When False, they are transposed to [B, S, H, D] and custom_sdpa uses seq_dim=1. self.is_seq_at_dim_2 = is_seq_at_dim_2 def forward( @@ -40,13 +42,13 @@ def forward( seqlen, mask, ): + # Q, K, V arrive in [B, H, S, D] from Attention. + # If is_seq_at_dim_2=False, transpose to [B, S, H, D] for the op. if not self.is_seq_at_dim_2: - q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) + q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) - # Custom op only supports float32 currently. Converting to/from float32 is - # faster than not having the op. input_dtype = q.dtype q = q.to(dtype=torch.float) k = k.to(dtype=torch.float) @@ -58,9 +60,9 @@ def forward( k, v, input_pos[0].item(), - mask, # Attention mask - 0, # dropout probability. Ignored by the code - False, # is_causal + mask, + 0, + False, scale=None, is_seq_dim_2=self.is_seq_at_dim_2, ) @@ -70,9 +72,9 @@ def forward( k, v, input_pos[0].item(), - None, # Attention mask - 0, # dropout probability. Ignored by the code - True, # is_causal + None, + 0, + True, scale=None, is_seq_dim_2=self.is_seq_at_dim_2, ) diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index a9c1ffce349..1595c78d4f3 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -178,6 +178,12 @@ class ModelConfig: [B, H, S, D] instead of standard [B, S, H, D]. Transposed layout improves SDPA performance by enabling contiguous memory access in the attn_score @ V GEMM (stride D instead of H*D along seq dim). + Controls both K and V caches together. For per-cache control, use + cache_transpose instead. + cache_transpose: Per-cache transpose control. One of 'none', 'all', + 'v_only', 'k_only'. Overrides use_transposed_cache when set. + 'v_only' transposes only the V cache, which may give SDPA locality + benefits for the attn @ V GEMM without the overhead of transposing K. expand_rope_table: Temporary workaround to expand sin/cos table in head dim to take vectorized path in optimized kernels. use_attention_sink: Whether to use attention sink to support multi-round @@ -199,6 +205,7 @@ class ModelConfig: use_shared_embedding: bool = False use_sdpa_with_kv_cache: bool = False use_transposed_cache: bool = True + cache_transpose: Optional[str] = None expand_rope_table: bool = False use_attention_sink: Optional[str] = None output_prune_map: Optional[str] = None @@ -686,6 +693,10 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 llm_config.model.use_shared_embedding = args.use_shared_embedding if hasattr(args, "use_sdpa_with_kv_cache"): llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache + if hasattr(args, "use_transposed_cache"): + llm_config.model.use_transposed_cache = args.use_transposed_cache + if hasattr(args, "cache_transpose") and args.cache_transpose is not None: + llm_config.model.cache_transpose = args.cache_transpose if hasattr(args, "expand_rope_table"): llm_config.model.expand_rope_table = args.expand_rope_table if hasattr(args, "use_attention_sink"):