Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
Expand Down Expand Up @@ -334,6 +334,21 @@
action="store_true",
help="Whether to use sdpa_with_kv_cache update op when using kv cache",
)
parser.add_argument(
"--no_transposed_cache",
dest="use_transposed_cache",
default=True,
action="store_false",
help="Disable transposed KV cache layout [B, H, S, D]. By default transposed cache is used for better SDPA performance.",
)
parser.add_argument(
"--cache_transpose",
type=str,
default=None,
choices=["none", "all", "v_only", "k_only"],
help="Per-cache transpose control. Overrides --no_transposed_cache. "
"'v_only' transposes only the V cache for SDPA locality benefits.",
)
parser.add_argument(
"--disable_dynamic_shape",
dest="enable_dynamic_shape",
Expand Down Expand Up @@ -766,6 +781,7 @@
),
use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
use_transposed_cache=llm_config.model.use_transposed_cache,
cache_transpose=llm_config.model.cache_transpose,
quantize_kv_cache=llm_config.model.quantize_kv_cache,
use_kv_cache=llm_config.model.use_kv_cache,
qnn=llm_config.backend.qnn.enabled,
Expand Down Expand Up @@ -1605,6 +1621,7 @@
use_custom_sdpa_with_attention_mask: bool = False,
use_sdpa_with_kv_cache: bool = False,
use_transposed_cache: bool = True,
cache_transpose: Optional[str] = None,
quantize_kv_cache: bool = False,
use_kv_cache: bool = False,
qnn: bool = False,
Expand Down Expand Up @@ -1642,6 +1659,7 @@
use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
use_transposed_cache: Whether to store KV cache in transposed layout [B, H, S, D].
cache_transpose: Per-cache transpose control ('none','all','v_only','k_only'). Overrides use_transposed_cache.
quantize_kv_cache: Whether to quantize KV cache.
use_kv_cache: Whether to use KV cache.
qnn: Whether to use QNN.
Expand Down Expand Up @@ -1737,19 +1755,28 @@
use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

if use_sdpa_with_kv_cache:
# Resolve per-cache transpose flags
if cache_transpose is not None:
transpose_k = cache_transpose in ("all", "k_only")
transpose_v = cache_transpose in ("all", "v_only")
else:
transpose_k = use_transposed_cache
transpose_v = use_transposed_cache

# SDPA uses is_seq_at_dim_2=True when any cache is transposed,
# since KVCache always returns [B, H, S, D] for Attention.
sdpa_seq_at_dim_2 = transpose_k or transpose_v

transforms.append(
partial(replace_kv_cache_with_custom_kv_cache, is_seq_at_dim_2=use_transposed_cache)
partial(replace_kv_cache_with_custom_kv_cache, transpose_k=transpose_k, transpose_v=transpose_v)
)
# todo: do this optionally
# if use attention mask instead of causal attention
# then create partial function that sets use_attention_mask=True
if use_attention_mask_for_custom_sdpa:
transforms.append(
partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=use_transposed_cache)
partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=sdpa_seq_at_dim_2)
)
else:
transforms.append(
partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=use_transposed_cache)
partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=sdpa_seq_at_dim_2)
)

if quantize_kv_cache:
Expand Down
69 changes: 40 additions & 29 deletions examples/models/llama/source_transformation/custom_kv_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -334,31 +334,39 @@


class CustomKVCache(nn.Module):
    """Custom KV cache with independent K/V transpose control.

    Each of the K and V caches can be stored either in the standard
    [B, S, H, D] layout or in the transposed [B, H, S, D] layout. The
    transposed layout improves SDPA memory locality for the attn @ V GEMM.

    Args:
        max_batch_size: Maximum batch size (B) the cache supports.
        max_context_length: Maximum sequence length (S) the cache supports.
        n_heads: Number of attention heads (H).
        head_dim: Per-head embedding dimension (D).
        dtype: Buffer dtype; defaults to torch.float32.
        transpose_k: If True, K cache is stored as [B, H, S, D] (transposed).
            If False, stored as [B, S, H, D] (standard).
        transpose_v: If True, V cache is stored as [B, H, S, D] (transposed).
            If False, stored as [B, S, H, D] (standard).
    """

    def __init__(
        self,
        max_batch_size: int,
        max_context_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
        transpose_k: bool = False,
        transpose_v: bool = False,
    ):
        super().__init__()
        self.transpose_k = transpose_k
        self.transpose_v = transpose_v
        self.max_context_length = max_context_length
        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim

        transposed_shape = (max_batch_size, n_heads, max_context_length, head_dim)
        standard_shape = (max_batch_size, max_context_length, n_heads, head_dim)
        # Each cache picks its storage layout independently from its flag.
        self.register_buffer(
            "k_cache",
            torch.zeros(
                transposed_shape if transpose_k else standard_shape,
                dtype=dtype,
                device="cpu",
            ),
        )
        self.register_buffer(
            "v_cache",
            torch.zeros(
                transposed_shape if transpose_v else standard_shape,
                dtype=dtype,
                device="cpu",
            ),
        )

    def update(
        self,
        # NOTE(review): the leading parameters of this signature were hidden
        # behind a collapsed diff region; reconstructed as (input_pos, k_val)
        # from the body's usage — confirm against the full file.
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write new K/V values into the caches and return the full caches.

        Args:
            input_pos: [S] positions; input_pos[0] is the write start position.
            k_val: [B, H, S, D] new keys from Attention.
            v_val: [B, H, S, D] new values from Attention.
            indices: Optional explicit write indices for scattered updates.

        Returns:
            (k, v): the full caches, both in [B, H, S, D] for Attention,
            regardless of the underlying storage layout.
        """
        start_pos = input_pos[0].item()

        # Transpose the incoming [B, H, S, D] values to match each cache's
        # storage layout before writing.
        k_for_cache = k_val if self.transpose_k else k_val.transpose(1, 2)
        v_for_cache = v_val if self.transpose_v else v_val.transpose(1, 2)

        if indices is not None:
            _ = torch.ops.llama.update_cache_with_indices(
                k_for_cache, self.k_cache, start_pos, indices, self.transpose_k
            )
            _ = torch.ops.llama.update_cache_with_indices(
                v_for_cache, self.v_cache, start_pos, indices, self.transpose_v
            )
        else:
            _ = torch.ops.llama.update_cache(
                k_for_cache, self.k_cache, start_pos, self.transpose_k
            )
            _ = torch.ops.llama.update_cache(
                v_for_cache, self.v_cache, start_pos, self.transpose_v
            )

        # Return both caches in [B, H, S, D]; untransposed caches are viewed
        # through a transpose (no copy).
        k_out = self.k_cache if self.transpose_k else self.k_cache.transpose(1, 2)
        v_out = self.v_cache if self.transpose_v else self.v_cache.transpose(1, 2)
        return (k_out, v_out)


def replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False):
    """Replace KVCache with CustomKVCache. This modifies the model in place.

    K and V caches can be independently transposed:
    - transpose_k=True: K cache stored as [B, H, S, D] (transposed)
    - transpose_v=True: V cache stored as [B, H, S, D] (transposed)
    - When False, cache is stored as [B, S, H, D] (standard)

    Args:
        module: Model (or submodule) whose KVCache children are replaced.
        transpose_k: Whether the K cache uses the transposed layout.
        transpose_v: Whether the V cache uses the transposed layout.

    Returns:
        The same module, mutated in place.
    """
    logging.info(
        "Replacing KVCache with CustomKVCache. This modifies the model in place."
    )
    return _replace_kv_cache_with_custom_kv_cache(
        module, transpose_k=transpose_k, transpose_v=transpose_v
    )


def _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True):
def _replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False):
for name, child in module.named_children():
if isinstance(child, KVCache):
cache_dtype = child.k_cache.dtype
Expand All @@ -421,11 +431,12 @@
n_heads,
head_dim,
dtype=cache_dtype,
is_seq_at_dim_2=is_seq_at_dim_2,
transpose_k=transpose_k,
transpose_v=transpose_v,
),
)
else:
_replace_kv_cache_with_custom_kv_cache(child, is_seq_at_dim_2=is_seq_at_dim_2)
_replace_kv_cache_with_custom_kv_cache(child, transpose_k=transpose_k, transpose_v=transpose_v)
return module


Expand Down
20 changes: 11 additions & 9 deletions examples/models/llama/source_transformation/sdpa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -28,6 +28,8 @@
super().__init__()
self.dim = dim
self.use_attention_mask = use_attention_mask
# When True, Q/K/V are in [B, H, S, D] and custom_sdpa uses seq_dim=2.
# When False, they are transposed to [B, S, H, D] and custom_sdpa uses seq_dim=1.
self.is_seq_at_dim_2 = is_seq_at_dim_2

def forward(
Expand All @@ -40,13 +42,13 @@
seqlen,
mask,
):
# Q, K, V arrive in [B, H, S, D] from Attention.
# If is_seq_at_dim_2=False, transpose to [B, S, H, D] for the op.
if not self.is_seq_at_dim_2:
q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
input_dtype = q.dtype
q = q.to(dtype=torch.float)
k = k.to(dtype=torch.float)
Expand All @@ -58,9 +60,9 @@
k,
v,
input_pos[0].item(),
mask, # Attention mask
0, # dropout probability. Ignored by the code
False, # is_causal
mask,
0,
False,
scale=None,
is_seq_dim_2=self.is_seq_at_dim_2,
)
Expand All @@ -70,9 +72,9 @@
k,
v,
input_pos[0].item(),
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal
None,
0,
True,
scale=None,
is_seq_dim_2=self.is_seq_at_dim_2,
)
Expand Down
11 changes: 11 additions & 0 deletions extension/llm/export/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ class ModelConfig:
[B, H, S, D] instead of standard [B, S, H, D]. Transposed layout
improves SDPA performance by enabling contiguous memory access in
the attn_score @ V GEMM (stride D instead of H*D along seq dim).
Controls both K and V caches together. For per-cache control, use
cache_transpose instead.
cache_transpose: Per-cache transpose control. One of 'none', 'all',
'v_only', 'k_only'. Overrides use_transposed_cache when set.
'v_only' transposes only the V cache, which may give SDPA locality
benefits for the attn @ V GEMM without the overhead of transposing K.
expand_rope_table: Temporary workaround to expand sin/cos table in head
dim to take vectorized path in optimized kernels.
use_attention_sink: Whether to use attention sink to support multi-round
Expand All @@ -199,6 +205,7 @@ class ModelConfig:
use_shared_embedding: bool = False
use_sdpa_with_kv_cache: bool = False
use_transposed_cache: bool = True
cache_transpose: Optional[str] = None
expand_rope_table: bool = False
use_attention_sink: Optional[str] = None
output_prune_map: Optional[str] = None
Expand Down Expand Up @@ -686,6 +693,10 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
llm_config.model.use_shared_embedding = args.use_shared_embedding
if hasattr(args, "use_sdpa_with_kv_cache"):
llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
if hasattr(args, "use_transposed_cache"):
llm_config.model.use_transposed_cache = args.use_transposed_cache
if hasattr(args, "cache_transpose") and args.cache_transpose is not None:
llm_config.model.cache_transpose = args.cache_transpose
if hasattr(args, "expand_rope_table"):
llm_config.model.expand_rope_table = args.expand_rope_table
if hasattr(args, "use_attention_sink"):
Expand Down
Loading