Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
Expand Down Expand Up @@ -765,6 +765,7 @@
llm_config.model, "use_custom_sdpa_with_attention_mask", False
),
use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
use_transposed_cache=llm_config.model.use_transposed_cache,
quantize_kv_cache=llm_config.model.quantize_kv_cache,
use_kv_cache=llm_config.model.use_kv_cache,
qnn=llm_config.backend.qnn.enabled,
Expand Down Expand Up @@ -1603,6 +1604,7 @@
expand_rope_table: bool = False,
use_custom_sdpa_with_attention_mask: bool = False,
use_sdpa_with_kv_cache: bool = False,
use_transposed_cache: bool = True,
quantize_kv_cache: bool = False,
use_kv_cache: bool = False,
qnn: bool = False,
Expand Down Expand Up @@ -1639,6 +1641,7 @@
expand_rope_table: Whether to expand rope table.
use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
use_transposed_cache: Whether to store KV cache in transposed layout [B, H, S, D].
quantize_kv_cache: Whether to quantize KV cache.
use_kv_cache: Whether to use KV cache.
qnn: Whether to use QNN.
Expand Down Expand Up @@ -1734,16 +1737,20 @@
use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

if use_sdpa_with_kv_cache:
transforms.append(replace_kv_cache_with_custom_kv_cache)
transforms.append(
partial(replace_kv_cache_with_custom_kv_cache, is_seq_at_dim_2=use_transposed_cache)
)
# TODO: make this replacement optional.
# If an attention mask is used instead of causal attention,
# create a partial function that sets use_attention_mask=True.
if use_attention_mask_for_custom_sdpa:
transforms.append(
partial(replace_sdpa_with_custom_op, use_attention_mask=True)
partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=use_transposed_cache)
)
else:
transforms.append(replace_sdpa_with_custom_op)
transforms.append(
partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=use_transposed_cache)
)

if quantize_kv_cache:
assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
Expand Down
17 changes: 8 additions & 9 deletions examples/models/llama/source_transformation/custom_kv_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -391,21 +391,20 @@
return (self.k_cache, self.v_cache)


def replace_kv_cache_with_custom_kv_cache(module):
def replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True):
"""
Replace KVCache with CustomKVCache. This modifies the model in place.
At the moment custom kv cache only supports cache with shape
[B, S, H, D] as opposed to [B, H, S, D]
This is because the custom op treats second dim as sequence dim.
Future work: support [B, H, S, D]
When is_seq_at_dim_2=True, cache is stored as [B, H, S, D] (transposed),
which improves SDPA GEMM performance via better memory locality.
When is_seq_at_dim_2=False, cache is stored as [B, S, H, D] (standard).
"""
logging.info(
"Replacing KVCache with CustomKVCache. This modifies the model in place."
)
return _replace_kv_cache_with_custom_kv_cache(module)
return _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=is_seq_at_dim_2)


def _replace_kv_cache_with_custom_kv_cache(module):
def _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True):
for name, child in module.named_children():
if isinstance(child, KVCache):
cache_dtype = child.k_cache.dtype
Expand All @@ -422,11 +421,11 @@
n_heads,
head_dim,
dtype=cache_dtype,
is_seq_at_dim_2=True, # hacking temporarily
is_seq_at_dim_2=is_seq_at_dim_2,
),
)
else:
_replace_kv_cache_with_custom_kv_cache(child)
_replace_kv_cache_with_custom_kv_cache(child, is_seq_at_dim_2=is_seq_at_dim_2)
return module


Expand Down
10 changes: 5 additions & 5 deletions examples/models/llama/source_transformation/sdpa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -82,7 +82,7 @@


def _replace_sdpa_with_custom_op(
module: torch.nn.Module, use_attention_mask: bool = False
module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True
):
for name, child in module.named_children():
if isinstance(child, SDPA):
Expand All @@ -92,19 +92,19 @@
SDPACustom(
child.dim,
use_attention_mask=use_attention_mask,
is_seq_at_dim_2=True, # hacking temporarily
is_seq_at_dim_2=is_seq_at_dim_2,
),
)
else:
_replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)
_replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2)


def replace_sdpa_with_custom_op(
module: torch.nn.Module, use_attention_mask: bool = False
module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True
) -> torch.nn.Module:
from executorch.extension.llm.custom_ops import custom_ops # noqa

_replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
_replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2)
return module


Expand Down
8 changes: 4 additions & 4 deletions extension/llm/custom_ops/op_update_cache.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -119,17 +119,17 @@

ET_CHECK_MSG(
value_batch_size == cache_batch_size,
"projected_value batch size (%zd) should be equal to the cache batch size (%zd).",
"projected_value batch size (%" PRId64 ") should be equal to the cache batch size (%" PRId64 ").",
value_batch_size,
cache_batch_size);
ET_CHECK_MSG(
value_num_heads == cache_num_heads,
"projected_value number of heads (%zd) should be equal to the cache number of heads (%zd).",
"projected_value number of heads (%" PRId64 ") should be equal to the cache number of heads (%" PRId64 ").",
value_num_heads,
cache_num_heads);
ET_CHECK_MSG(
value_head_dim == cache_head_dim,
"projected_value embedding dimension (%zd) should be equal to the cache embedding dimension (%zd).",
"projected_value embedding dimension (%" PRId64 ") should be equal to the cache embedding dimension (%" PRId64 ").",
value_head_dim,
cache_head_dim);
ET_CHECK_MSG(
Expand Down Expand Up @@ -210,7 +210,7 @@
// Ensure the target position is valid
ET_CHECK_MSG(
target_pos >= 0 && target_pos < cache_seq_len,
"Index out of bounds: %" PRId64 " not in [0, %zd)",
"Index out of bounds: %" PRId64 " not in [0, %" PRId64 ")",
target_pos,
cache_seq_len);

Expand Down
5 changes: 5 additions & 0 deletions extension/llm/export/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ class ModelConfig:
use_sdpa_with_kv_cache: Whether to use flash attention by substituting
in our custom SDPA op. Note that the naming is poor and this
doesn't actually have anything to do with the kv_cache at the moment.
use_transposed_cache: Whether to store KV cache in transposed layout
[B, H, S, D] instead of standard [B, S, H, D]. Transposed layout
improves SDPA performance by enabling contiguous memory access in
the attn_score @ V GEMM (stride D instead of H*D along seq dim).
expand_rope_table: Temporary workaround to expand sin/cos table in head
dim to take vectorized path in optimized kernels.
use_attention_sink: Whether to use attention sink to support multi-round
Expand All @@ -194,6 +198,7 @@ class ModelConfig:
enable_dynamic_shape: bool = True
use_shared_embedding: bool = False
use_sdpa_with_kv_cache: bool = False
use_transposed_cache: bool = True
expand_rope_table: bool = False
use_attention_sink: Optional[str] = None
output_prune_map: Optional[str] = None
Expand Down
Loading