
Commit af088ae

Add --no_transposed_cache CLI flag for export pipeline

Add a CLI argument to control the transposed KV cache layout during export. By default the transposed cache is used (is_seq_at_dim_2=True); pass --no_transposed_cache to disable it for baseline comparison. The diff also adds a companion --cache_transpose option for per-cache (K/V) control, which overrides --no_transposed_cache when set.

Differential Revision: [D99677684](https://our.internmc.facebook.com/intern/diff/D99677684/)

[ghstack-poisoned]

4 files changed

Lines changed: 95 additions & 44 deletions


examples/models/llama/export_llama_lib.py

Lines changed: 33 additions & 6 deletions
@@ -334,6 +334,21 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether to use sdpa_with_kv_cache update op when using kv cache",
     )
+    parser.add_argument(
+        "--no_transposed_cache",
+        dest="use_transposed_cache",
+        default=True,
+        action="store_false",
+        help="Disable transposed KV cache layout [B, H, S, D]. By default transposed cache is used for better SDPA performance.",
+    )
+    parser.add_argument(
+        "--cache_transpose",
+        type=str,
+        default=None,
+        choices=["none", "all", "v_only", "k_only"],
+        help="Per-cache transpose control. Overrides --no_transposed_cache. "
+        "'v_only' transposes only the V cache for SDPA locality benefits.",
+    )
     parser.add_argument(
         "--disable_dynamic_shape",
         dest="enable_dynamic_shape",
@@ -766,6 +781,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         ),
         use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
         use_transposed_cache=llm_config.model.use_transposed_cache,
+        cache_transpose=llm_config.model.cache_transpose,
         quantize_kv_cache=llm_config.model.quantize_kv_cache,
         use_kv_cache=llm_config.model.use_kv_cache,
         qnn=llm_config.backend.qnn.enabled,
@@ -1605,6 +1621,7 @@ def _get_source_transforms( # noqa
     use_custom_sdpa_with_attention_mask: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     use_transposed_cache: bool = True,
+    cache_transpose: Optional[str] = None,
     quantize_kv_cache: bool = False,
     use_kv_cache: bool = False,
     qnn: bool = False,
@@ -1642,6 +1659,7 @@ def _get_source_transforms( # noqa
         use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
         use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
         use_transposed_cache: Whether to store KV cache in transposed layout [B, H, S, D].
+        cache_transpose: Per-cache transpose control ('none', 'all', 'v_only', 'k_only'). Overrides use_transposed_cache.
         quantize_kv_cache: Whether to quantize KV cache.
         use_kv_cache: Whether to use KV cache.
         qnn: Whether to use QNN.
@@ -1737,19 +1755,28 @@ def _get_source_transforms( # noqa
         use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

     if use_sdpa_with_kv_cache:
+        # Resolve per-cache transpose flags
+        if cache_transpose is not None:
+            transpose_k = cache_transpose in ("all", "k_only")
+            transpose_v = cache_transpose in ("all", "v_only")
+        else:
+            transpose_k = use_transposed_cache
+            transpose_v = use_transposed_cache
+
+        # SDPA uses is_seq_at_dim_2=True when any cache is transposed,
+        # since KVCache always returns [B, H, S, D] for Attention.
+        sdpa_seq_at_dim_2 = transpose_k or transpose_v
+
         transforms.append(
-            partial(replace_kv_cache_with_custom_kv_cache, is_seq_at_dim_2=use_transposed_cache)
+            partial(replace_kv_cache_with_custom_kv_cache, transpose_k=transpose_k, transpose_v=transpose_v)
         )
-        # todo: do this optionally
-        # if use attention mask instead of causal attention
-        # then create partial function that sets use_attention_mask=True
         if use_attention_mask_for_custom_sdpa:
             transforms.append(
-                partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=use_transposed_cache)
+                partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=sdpa_seq_at_dim_2)
             )
         else:
             transforms.append(
-                partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=use_transposed_cache)
+                partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=sdpa_seq_at_dim_2)
             )

     if quantize_kv_cache:
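To make the precedence concrete, the following standalone helper restates the resolution logic from the hunk above and checks the resulting (transpose_k, transpose_v, sdpa_seq_at_dim_2) triples (resolve_transpose is illustrative only, not a function in the diff):

from typing import Optional, Tuple

def resolve_transpose(
    use_transposed_cache: bool, cache_transpose: Optional[str]
) -> Tuple[bool, bool, bool]:
    # Same logic as the hunk above: cache_transpose wins whenever it is set.
    if cache_transpose is not None:
        transpose_k = cache_transpose in ("all", "k_only")
        transpose_v = cache_transpose in ("all", "v_only")
    else:
        transpose_k = transpose_v = use_transposed_cache
    return transpose_k, transpose_v, transpose_k or transpose_v

assert resolve_transpose(True, None) == (True, True, True)        # default
assert resolve_transpose(False, None) == (False, False, False)    # --no_transposed_cache
assert resolve_transpose(True, "none") == (False, False, False)   # override disables both
assert resolve_transpose(False, "v_only") == (False, True, True)  # override enables V only

Note that cache_transpose="none" disables transposition even when use_transposed_cache is True, which is why the override is a string enum rather than a second boolean.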

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 40 additions & 29 deletions
@@ -334,31 +334,39 @@ def _replace_kv_cache_with_quantized_kv_cache(module):


 class CustomKVCache(nn.Module):
+    """Custom KV cache with independent K/V transpose control.
+
+    Args:
+        transpose_k: If True, K cache is stored as [B, H, S, D] (transposed).
+            If False, stored as [B, S, H, D] (standard).
+        transpose_v: If True, V cache is stored as [B, H, S, D] (transposed).
+            If False, stored as [B, S, H, D] (standard).
+    """
     def __init__(
         self,
         max_batch_size: int,
         max_context_length: int,
         n_heads: int,
         head_dim: int,
         dtype=torch.float32,
-        is_seq_at_dim_2: bool = False,
+        transpose_k: bool = False,
+        transpose_v: bool = False,
     ):
-        self.is_seq_at_dim_2 = is_seq_at_dim_2
         super().__init__()
+        self.transpose_k = transpose_k
+        self.transpose_v = transpose_v
         self.max_context_length = max_context_length
-        if self.is_seq_at_dim_2:
-            cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
-
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
+
+        transposed_shape = (max_batch_size, n_heads, max_context_length, head_dim)
+        standard_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         self.register_buffer(
-            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+            "k_cache", torch.zeros(transposed_shape if transpose_k else standard_shape, dtype=dtype, device="cpu")
         )
         self.register_buffer(
-            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+            "v_cache", torch.zeros(transposed_shape if transpose_v else standard_shape, dtype=dtype, device="cpu")
         )

     def update(
@@ -368,43 +376,45 @@ def update(
         v_val: torch.Tensor,
         indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D]
-        if not self.is_seq_at_dim_2:
-            k_val = k_val.transpose(1, 2)
-            v_val = v_val.transpose(1, 2)
+        # input_pos: [S], k_val/v_val: [B, H, S, D] from Attention
         start_pos = input_pos[0].item()

+        # Transpose k_val/v_val to match the cache layout if needed
+        k_for_cache = k_val if self.transpose_k else k_val.transpose(1, 2)
+        v_for_cache = v_val if self.transpose_v else v_val.transpose(1, 2)
+
         if indices is not None:
             _ = torch.ops.llama.update_cache_with_indices(
-                k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2
+                k_for_cache, self.k_cache, start_pos, indices, self.transpose_k
             )
             _ = torch.ops.llama.update_cache_with_indices(
-                v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2
+                v_for_cache, self.v_cache, start_pos, indices, self.transpose_v
             )
         else:
-            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, self.is_seq_at_dim_2)
-            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, self.is_seq_at_dim_2)
+            _ = torch.ops.llama.update_cache(k_for_cache, self.k_cache, start_pos, self.transpose_k)
+            _ = torch.ops.llama.update_cache(v_for_cache, self.v_cache, start_pos, self.transpose_v)

-        if not self.is_seq_at_dim_2:
-            return (k_val.transpose(1, 2), v_val.transpose(1, 2))
-        else:
-            return (self.k_cache, self.v_cache)
+        # Return both caches in [B, H, S, D] for Attention
+        k_out = self.k_cache if self.transpose_k else self.k_cache.transpose(1, 2)
+        v_out = self.v_cache if self.transpose_v else self.v_cache.transpose(1, 2)
+        return (k_out, v_out)


-def replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True):
+def replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False):
     """
     Replace KVCache with CustomKVCache. This modifies the model in place.
-    When is_seq_at_dim_2=True, cache is stored as [B, H, S, D] (transposed),
-    which improves SDPA GEMM performance via better memory locality.
-    When is_seq_at_dim_2=False, cache is stored as [B, S, H, D] (standard).
+    K and V caches can be independently transposed:
+    - transpose_k=True: K cache stored as [B, H, S, D] (transposed)
+    - transpose_v=True: V cache stored as [B, H, S, D] (transposed)
+    - When False, cache is stored as [B, S, H, D] (standard)
     """
     logging.info(
         "Replacing KVCache with CustomKVCache. This modifies the model in place."
     )
-    return _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=is_seq_at_dim_2)
+    return _replace_kv_cache_with_custom_kv_cache(module, transpose_k=transpose_k, transpose_v=transpose_v)


-def _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True):
+def _replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False):
     for name, child in module.named_children():
         if isinstance(child, KVCache):
             cache_dtype = child.k_cache.dtype
@@ -421,11 +431,12 @@ def _replace_kv_cache_with_custom_kv_cache(module, is_seq_at_dim_2=True):
                     n_heads,
                     head_dim,
                     dtype=cache_dtype,
-                    is_seq_at_dim_2=is_seq_at_dim_2,
+                    transpose_k=transpose_k,
+                    transpose_v=transpose_v,
                 ),
             )
         else:
-            _replace_kv_cache_with_custom_kv_cache(child, is_seq_at_dim_2=is_seq_at_dim_2)
+            _replace_kv_cache_with_custom_kv_cache(child, transpose_k=transpose_k, transpose_v=transpose_v)
     return module
examples/models/llama/source_transformation/sdpa.py

Lines changed: 11 additions & 9 deletions
@@ -28,6 +28,8 @@ def __init__(
         super().__init__()
         self.dim = dim
         self.use_attention_mask = use_attention_mask
+        # When True, Q/K/V are in [B, H, S, D] and custom_sdpa uses seq_dim=2.
+        # When False, they are transposed to [B, S, H, D] and custom_sdpa uses seq_dim=1.
         self.is_seq_at_dim_2 = is_seq_at_dim_2

     def forward(
@@ -40,13 +42,13 @@ def forward(
         seqlen,
         mask,
     ):
+        # Q, K, V arrive in [B, H, S, D] from Attention.
+        # If is_seq_at_dim_2=False, transpose to [B, S, H, D] for the op.
         if not self.is_seq_at_dim_2:
-            q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
+            q = q.transpose(1, 2)
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)

-        # Custom op only supports float32 currently. Converting to/from float32 is
-        # faster than not having the op.
         input_dtype = q.dtype
         q = q.to(dtype=torch.float)
         k = k.to(dtype=torch.float)
@@ -58,9 +60,9 @@ def forward(
                 k,
                 v,
                 input_pos[0].item(),
-                mask,  # Attention mask
-                0,  # dropout probability. Ignored by the code
-                False,  # is_causal
+                mask,
+                0,
+                False,
                 scale=None,
                 is_seq_dim_2=self.is_seq_at_dim_2,
             )
@@ -70,9 +72,9 @@ def forward(
                 k,
                 v,
                 input_pos[0].item(),
-                None,  # Attention mask
-                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
+                None,
+                0,
+                True,
                 scale=None,
                 is_seq_dim_2=self.is_seq_at_dim_2,
             )
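The is_seq_dim_2 argument is purely a storage-layout contract; the attention math is identical either way. A hedged reference check with einsum (this stands in for the custom op, which is not exercised here, and omits the causal mask since only layout is being compared):

import torch

B, H, S, D = 1, 4, 8, 16
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# seq at dim 2: tensors kept as [B, H, S, D]
scores = torch.einsum("bhsd,bhtd->bhst", q, k) / D**0.5
out_dim2 = torch.einsum("bhst,bhtd->bhsd", scores.softmax(-1), v)

# seq at dim 1: the same tensors viewed as [B, S, H, D]
q1, k1, v1 = (t.transpose(1, 2) for t in (q, k, v))
scores1 = torch.einsum("bshd,bthd->bhst", q1, k1) / D**0.5
out_dim1 = torch.einsum("bhst,bthd->bshd", scores1.softmax(-1), v1)

# Same result up to layout, so the flag only decides where transposes happen.
assert torch.allclose(out_dim2, out_dim1.transpose(1, 2), atol=1e-6)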

extension/llm/export/config/llm_config.py

Lines changed: 11 additions & 0 deletions
@@ -178,6 +178,12 @@ class ModelConfig:
             [B, H, S, D] instead of standard [B, S, H, D]. Transposed layout
             improves SDPA performance by enabling contiguous memory access in
             the attn_score @ V GEMM (stride D instead of H*D along seq dim).
+            Controls both K and V caches together. For per-cache control, use
+            cache_transpose instead.
+        cache_transpose: Per-cache transpose control. One of 'none', 'all',
+            'v_only', 'k_only'. Overrides use_transposed_cache when set.
+            'v_only' transposes only the V cache, which may give SDPA locality
+            benefits for the attn @ V GEMM without the overhead of transposing K.
         expand_rope_table: Temporary workaround to expand sin/cos table in head
             dim to take vectorized path in optimized kernels.
         use_attention_sink: Whether to use attention sink to support multi-round
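The stride claim in the docstring is visible directly in tensor metadata: stepping one position along the sequence axis moves D elements in the transposed layout but H*D elements in the standard one, so the attn_score @ V GEMM walks V contiguously only when V is transposed. A quick illustration with toy sizes:

import torch

B, H, S, D = 1, 8, 128, 64
transposed = torch.zeros(B, H, S, D)  # [B, H, S, D]
standard = torch.zeros(B, S, H, D)    # [B, S, H, D]

print(transposed.stride(2))  # 64 == D: consecutive positions contiguous per head
print(standard.stride(1))    # 512 == H * D: consecutive positions strided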
@@ -199,6 +205,7 @@ class ModelConfig:
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
     use_transposed_cache: bool = True
+    cache_transpose: Optional[str] = None
     expand_rope_table: bool = False
     use_attention_sink: Optional[str] = None
     output_prune_map: Optional[str] = None
@@ -686,6 +693,10 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         llm_config.model.use_shared_embedding = args.use_shared_embedding
         if hasattr(args, "use_sdpa_with_kv_cache"):
             llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
+        if hasattr(args, "use_transposed_cache"):
+            llm_config.model.use_transposed_cache = args.use_transposed_cache
+        if hasattr(args, "cache_transpose") and args.cache_transpose is not None:
+            llm_config.model.cache_transpose = args.cache_transpose
         if hasattr(args, "expand_rope_table"):
             llm_config.model.expand_rope_table = args.expand_rope_table
         if hasattr(args, "use_attention_sink"):
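One detail worth noting: use_transposed_cache is copied unconditionally (argparse always supplies it via its default), while cache_transpose carries an "is not None" guard so that omitting the flag cannot clobber a value set elsewhere in the config. A toy restatement with a bare Namespace (the dict stands in for llm_config.model):

import argparse

model_cfg = {"use_transposed_cache": True, "cache_transpose": None}
args = argparse.Namespace(use_transposed_cache=False, cache_transpose="v_only")

# Same guards as from_args above.
if hasattr(args, "use_transposed_cache"):
    model_cfg["use_transposed_cache"] = args.use_transposed_cache
if hasattr(args, "cache_transpose") and args.cache_transpose is not None:
    model_cfg["cache_transpose"] = args.cache_transpose

print(model_cfg)  # {'use_transposed_cache': False, 'cache_transpose': 'v_only'}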
