Commit 45468e1

Support separate K/V seq dim in custom_sdpa op
Previously custom_sdpa used a single is_seq_at_dim_2 flag for all tensors. This meant v_only transpose required a runtime transpose copy for K (converting from [B, H, S, D] to [B, S, H, D]), which caused a 2.3x decode slowdown (15.35 vs 35.63 tok/s). Now the C++ op accepts separate is_seq_dim_2, is_k_seq_dim_2, and is_v_seq_dim_2 flags, so Q, K, and V can each have independent layouts. The Python layer passes K and V in their native cache layout without any transpose, and the flash attention kernel handles the mixed strides directly.

Changes:
- op_sdpa_impl.h: cpu_flash_attention takes q_seq_dim, k_seq_dim, v_seq_dim instead of a single seq_dim
- op_sdpa.cpp/h: custom_sdpa_out takes 3 bool params
- op_sdpa_aot.cpp: updated schema strings and wrappers
- sdpa.py: SDPACustom uses is_k_seq_at_dim_2 / is_v_seq_at_dim_2; Q always at dim 2; no input transposes
- custom_kv_cache.py: update() returns the native cache layout; added is_seq_at_dim_2 compat property
- export_llama_lib.py: passes separate K/V flags

Differential Revision: [D99677678](https://our.internmc.facebook.com/intern/diff/D99677678/)

[ghstack-poisoned]
1 parent af088ae commit 45468e1

8 files changed

Lines changed: 158 additions & 142 deletions
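To make the layout issue concrete, here is a rough standalone sketch (shapes are illustrative and not taken from this commit): with a single flag, the K cache in the v_only case had to be repacked from its stored [B, S, H, D] layout into [B, H, S, D] on every decode step, which materializes a copy of the whole cache; with per-tensor flags the cache is read as stored.

import torch

# Illustrative shapes only; not from the commit.
B, H, S, D = 1, 32, 2048, 128

# K cache kept in its native [B, S, H, D] layout (seq at dim 1).
k_cache = torch.randn(B, S, H, D)

# Old behavior: a single is_seq_at_dim_2 flag covered Q, K and V, so when V
# was transposed, K had to be repacked to [B, H, S, D] as well.
k_repacked = k_cache.transpose(1, 2).contiguous()  # full copy, every step

# New behavior: K is passed as stored and the op is told its seq dim
# (is_k_seq_dim_2=False here), so no copy is made.
k_as_stored = k_cache

print(k_repacked.shape)   # torch.Size([1, 32, 2048, 128])
print(k_as_stored.shape)  # torch.Size([1, 2048, 32, 128])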

examples/models/llama/export_llama_lib.py

Lines changed: 11 additions & 6 deletions
@@ -1763,20 +1763,25 @@ def _get_source_transforms( # noqa
         transpose_k = use_transposed_cache
         transpose_v = use_transposed_cache

-        # SDPA uses is_seq_at_dim_2=True when any cache is transposed,
-        # since KVCache always returns [B, H, S, D] for Attention.
-        sdpa_seq_at_dim_2 = transpose_k or transpose_v
-
         transforms.append(
             partial(replace_kv_cache_with_custom_kv_cache, transpose_k=transpose_k, transpose_v=transpose_v)
         )
         if use_attention_mask_for_custom_sdpa:
             transforms.append(
-                partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=sdpa_seq_at_dim_2)
+                partial(
+                    replace_sdpa_with_custom_op,
+                    use_attention_mask=True,
+                    is_k_seq_at_dim_2=transpose_k,
+                    is_v_seq_at_dim_2=transpose_v,
+                )
             )
         else:
             transforms.append(
-                partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=sdpa_seq_at_dim_2)
+                partial(
+                    replace_sdpa_with_custom_op,
+                    is_k_seq_at_dim_2=transpose_k,
+                    is_v_seq_at_dim_2=transpose_v,
+                )
             )

         if quantize_kv_cache:
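For reference, a minimal sketch of how the two transforms above compose outside of export_llama_lib (the import paths are inferred from the file locations in this diff, and the model is assumed to be a module tree containing the llama SDPA/KVCache submodules):

from functools import partial

from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_custom_kv_cache,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

# Hypothetical configuration: V cache stays at [B, H, S, D], K cache does not.
transpose_k, transpose_v = False, True

transforms = [
    partial(
        replace_kv_cache_with_custom_kv_cache,
        transpose_k=transpose_k,
        transpose_v=transpose_v,
    ),
    partial(
        replace_sdpa_with_custom_op,
        is_k_seq_at_dim_2=transpose_k,
        is_v_seq_at_dim_2=transpose_v,
    ),
]

# model = ...  # eager llama model with SDPA / KVCache modules
# for transform in transforms:
#     model = transform(model)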

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 19 additions & 20 deletions
@@ -379,7 +379,7 @@ def update(
         # input_pos: [S], k_val/v_val: [B, H, S, D] from Attention
         start_pos = input_pos[0].item()

-        # Transpose k_val to match cache layout if needed
+        # Transpose k_val/v_val to match cache layout if needed
         k_for_cache = k_val if self.transpose_k else k_val.transpose(1, 2)
         v_for_cache = v_val if self.transpose_v else v_val.transpose(1, 2)

@@ -394,10 +394,15 @@ def update(
         _ = torch.ops.llama.update_cache(k_for_cache, self.k_cache, start_pos, self.transpose_k)
         _ = torch.ops.llama.update_cache(v_for_cache, self.v_cache, start_pos, self.transpose_v)

-        # Return both caches in [B, H, S, D] for Attention
-        k_out = self.k_cache if self.transpose_k else self.k_cache.transpose(1, 2)
-        v_out = self.v_cache if self.transpose_v else self.v_cache.transpose(1, 2)
-        return (k_out, v_out)
+        # Return caches in their native layout. The SDPA op handles
+        # mixed K/V layouts via separate seq dim parameters, avoiding
+        # expensive runtime transpose copies.
+        return (self.k_cache, self.v_cache)
+
+    @property
+    def is_seq_at_dim_2(self):
+        """Backward compat for quantized KV cache path."""
+        return self.transpose_k and self.transpose_v


 def replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False):
@@ -519,7 +524,6 @@ def from_quantized_kv_cache(
             kv_cache.cache_type,
             kv_cache.use_custom_update_cache_op,
             kv_cache.return_float_values,
-            kv_cache.is_seq_at_dim_2,
             is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
         )

@@ -532,11 +536,13 @@ def __init__(
         n_heads,
         head_dim,
         dtype=torch.float32,
-        is_seq_at_dim_2: bool = False,
+        transpose_k: bool = False,
+        transpose_v: bool = False,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
-            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype, is_seq_at_dim_2
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype,
+            transpose_k=transpose_k, transpose_v=transpose_v,
         )
         self.cache_positions_manager = CachePositionsManager(self.max_context_length)
         self.is_ring_buffer = True
@@ -551,18 +557,10 @@ def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
     def update(self, input_pos, k_val, v_val):
         """
         k_val, v_val: [B, H, S, D]
-        return: [B, H, S, D]
-        However the storage is [B, S, H, D] so we incur transpose in, transpose out
-        This shall be removed by subsequent post-export graph pass
+        Returns K/V caches in their native storage layout.
         """
-        # Need to transpose for two reasons
-        # 1. kv cache is stored as [B, S, H, D]
-        # 2. If seq_len = k_val.size(2), we wont be able be able to optimize
-        #    away transpose at the output of k, v projection
-        if not self.is_seq_at_dim_2:
-            seq_len = k_val.transpose(1, 2).size(1)
-        else:
-            seq_len = k_val.size(2)
+        # k_val is always [B, H, S, D] from Attention. Get seq_len from dim 2.
+        seq_len = k_val.size(2)
         assert seq_len <= self.k_cache.size(
             1
         ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
@@ -593,7 +591,8 @@ def from_custom_kv_cache(
             n_heads,
             head_dim,
             dtype=kv_cache.k_cache.dtype,
-            is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
+            transpose_k=kv_cache.transpose_k,
+            transpose_v=kv_cache.transpose_v,
         )

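A small shape sketch of the update() contract after this change (values illustrative; the layout mapping follows the k_for_cache/v_for_cache logic above): Attention always passes [B, H, S, D] tensors in, and the returned caches keep whatever storage layout was chosen at construction.

import torch

B, H, S_max, D, S_step = 1, 8, 128, 64, 1

# transpose_* = True  -> cache stored as [B, H, S_max, D] (seq at dim 2)
# transpose_* = False -> cache stored as [B, S_max, H, D] (seq at dim 1)
k_cache = torch.zeros(B, H, S_max, D)   # e.g. transpose_k=True
v_cache = torch.zeros(B, S_max, H, D)   # e.g. transpose_v=False

# update() always receives [B, H, S_step, D] from Attention ...
k_val = torch.randn(B, H, S_step, D)
v_val = torch.randn(B, H, S_step, D)

# ... and now returns the caches as stored, so the mixed layouts flow into
# custom_sdpa with is_k_seq_dim_2=True, is_v_seq_dim_2=False and no copies.
print(k_cache.shape, v_cache.shape)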

examples/models/llama/source_transformation/sdpa.py

Lines changed: 44 additions & 22 deletions
@@ -23,14 +23,17 @@ def __init__(
         self,
         dim: int,
         use_attention_mask: bool = False,
-        is_seq_at_dim_2: bool = False,
+        is_k_seq_at_dim_2: bool = False,
+        is_v_seq_at_dim_2: bool = False,
     ):
         super().__init__()
         self.dim = dim
         self.use_attention_mask = use_attention_mask
-        # When True, Q/K/V are in [B, H, S, D] and custom_sdpa uses seq_dim=2.
-        # When False, they are transposed to [B, S, H, D] and custom_sdpa uses seq_dim=1.
-        self.is_seq_at_dim_2 = is_seq_at_dim_2
+        # Separate seq dim flags for K and V allow mixed cache layouts.
+        # Q and output always use seq_dim=2 ([B, H, S, D]) since Q is
+        # always small (current step) and the transpose is negligible.
+        self.is_k_seq_at_dim_2 = is_k_seq_at_dim_2
+        self.is_v_seq_at_dim_2 = is_v_seq_at_dim_2

     def forward(
         self,
@@ -42,13 +45,8 @@ def forward(
         seqlen,
         mask,
     ):
-        # Q, K, V arrive in [B, H, S, D] from Attention.
-        # If is_seq_at_dim_2=False, transpose to [B, S, H, D] for the op.
-        if not self.is_seq_at_dim_2:
-            q = q.transpose(1, 2)
-            k = k.transpose(1, 2)
-            v = v.transpose(1, 2)
-
+        # Q arrives in [B, H, S, D] from Attention - always passed with seq_dim=2.
+        # K and V arrive in their native cache layout (may differ).
         input_dtype = q.dtype
         q = q.to(dtype=torch.float)
         k = k.to(dtype=torch.float)
@@ -64,7 +62,9 @@ def forward(
                 0,
                 False,
                 scale=None,
-                is_seq_dim_2=self.is_seq_at_dim_2,
+                is_seq_dim_2=True,
+                is_k_seq_dim_2=self.is_k_seq_at_dim_2,
+                is_v_seq_dim_2=self.is_v_seq_at_dim_2,
             )
         else:
             output = torch.ops.llama.custom_sdpa(
@@ -76,15 +76,20 @@ def forward(
                 0,
                 True,
                 scale=None,
-                is_seq_dim_2=self.is_seq_at_dim_2,
+                is_seq_dim_2=True,
+                is_k_seq_dim_2=self.is_k_seq_at_dim_2,
+                is_v_seq_dim_2=self.is_v_seq_at_dim_2,
             )
-        if self.is_seq_at_dim_2:
-            output = output.transpose(1, 2).contiguous()
+        # Output is [B, H, S, D] (matches Q layout), transpose for reshape
+        output = output.transpose(1, 2).contiguous()
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


 def _replace_sdpa_with_custom_op(
-    module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True
+    module: torch.nn.Module,
+    use_attention_mask: bool = False,
+    is_k_seq_at_dim_2: bool = False,
+    is_v_seq_at_dim_2: bool = False,
 ):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
@@ -94,19 +99,33 @@ def _replace_sdpa_with_custom_op(
                 SDPACustom(
                     child.dim,
                     use_attention_mask=use_attention_mask,
-                    is_seq_at_dim_2=is_seq_at_dim_2,
+                    is_k_seq_at_dim_2=is_k_seq_at_dim_2,
+                    is_v_seq_at_dim_2=is_v_seq_at_dim_2,
                 ),
             )
         else:
-            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2)
+            _replace_sdpa_with_custom_op(
+                child,
+                use_attention_mask=use_attention_mask,
+                is_k_seq_at_dim_2=is_k_seq_at_dim_2,
+                is_v_seq_at_dim_2=is_v_seq_at_dim_2,
+            )


 def replace_sdpa_with_custom_op(
-    module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True
+    module: torch.nn.Module,
+    use_attention_mask: bool = False,
+    is_k_seq_at_dim_2: bool = False,
+    is_v_seq_at_dim_2: bool = False,
 ) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa

-    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2)
+    _replace_sdpa_with_custom_op(
+        module,
+        use_attention_mask=use_attention_mask,
+        is_k_seq_at_dim_2=is_k_seq_at_dim_2,
+        is_v_seq_at_dim_2=is_v_seq_at_dim_2,
+    )
     return module

@@ -138,6 +157,7 @@ def __init__(
         self.float_dtype = torch.float32
         self.kv_cache = kv_cache
         self.use_attention_mask = use_attention_mask
+        # Quantized path uses a single flag for all tensors
         self.is_seq_at_dim_2 = is_seq_at_dim_2

     def forward(
@@ -225,8 +245,10 @@ def _update_attention_module_with_quantized_sdpa(
     sdpa = getattr(module, "SDPA", None)
     assert sdpa is not None
     assert isinstance(sdpa, SDPACustom)
-    # TODO: add support for SDPA with attention mask
-    setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache, is_seq_at_dim_2=sdpa.is_seq_at_dim_2))  # noqa: B010
+    # Quantized SDPA uses a single is_seq_at_dim_2 flag;
+    # derive from K/V flags (both must match for quantized path).
+    is_seq_at_dim_2 = sdpa.is_k_seq_at_dim_2 and sdpa.is_v_seq_at_dim_2
+    setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache, is_seq_at_dim_2=is_seq_at_dim_2))  # noqa: B010


 def _replace_sdpa_with_quantized_sdpa(module: torch.nn.Module):
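A sketch of the op call SDPACustom.forward now makes in the mask-free branch (shapes are illustrative; the positional argument order is assumed from the Python signature in custom_ops.py below, and the custom op library must be loadable):

import torch
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401  # registers torch.ops.llama.*

B, H, D = 1, 8, 64
start_pos, cache_len = 16, 128

q = torch.randn(B, H, 1, D)          # Q always [B, H, S, D]  -> is_seq_dim_2=True
k = torch.randn(B, H, cache_len, D)  # K cache with seq at dim 2
v = torch.randn(B, cache_len, H, D)  # V cache with seq at dim 1 (native layout)

out = torch.ops.llama.custom_sdpa(
    q,
    k,
    v,
    start_pos,
    None,   # attn_mask
    0,      # dropout_p
    True,   # is_causal
    scale=None,
    is_seq_dim_2=True,
    is_k_seq_dim_2=True,
    is_v_seq_dim_2=False,
)

# As in forward() above: the output matches Q's [B, H, S, D] layout, so it is
# transposed back before the final reshape.
out = out.transpose(1, 2).contiguous().view(B, 1, H * D)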

extension/llm/custom_ops/custom_ops.py

Lines changed: 4 additions & 15 deletions
@@ -168,22 +168,11 @@ def custom_sdpa(
     is_causal=False,
     scale=None,
     is_seq_dim_2=False,
+    is_k_seq_dim_2=False,
+    is_v_seq_dim_2=False,
 ):
-    seq_len = query.size(2) if is_seq_dim_2 else query.size(1)
-    _validate_params(
-        query,
-        key_cache,
-        value_cache,
-        key_cache,
-        value_cache,
-        start_pos,
-        seq_len,
-        attn_mask,
-        drpout_p,
-        is_causal,
-        scale,
-    )
-
+    # Skip _validate_params since it assumes K/V caches have the same layout.
+    # With mixed transpose (e.g. v_only), K and V have different shapes.
     return torch.empty_like(query)

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 26 additions & 11 deletions
@@ -345,7 +345,9 @@ Tensor& custom_sdpa_out_impl(
     const optional<Tensor>& k_scales = nullopt,
     const optional<Tensor>& v_zero_points = nullopt,
     const optional<Tensor>& v_scales = nullopt,
-    bool is_seq_at_dim_2 = false) {
+    bool is_seq_at_dim_2 = false,
+    bool is_k_seq_at_dim_2 = false,
+    bool is_v_seq_at_dim_2 = false) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
@@ -360,11 +362,10 @@ Tensor& custom_sdpa_out_impl(
       output,
       "Invalid arguments");

-  SeqDim seq_dim{SeqDim::TWO};
-  if (!is_seq_at_dim_2) {
-    seq_dim = SeqDim::ONE;
-  }
-  int64_t seq_len = q.size(static_cast<int64_t>(seq_dim));
+  SeqDim q_seq_dim = is_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
+  SeqDim k_seq_dim = is_k_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
+  SeqDim v_seq_dim = is_v_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
+  int64_t seq_len = q.size(static_cast<int64_t>(q_seq_dim));

   if (q.scalar_type() == ScalarType::Char) {
     ET_KERNEL_CHECK_MSG(
@@ -447,7 +448,9 @@ Tensor& custom_sdpa_out_impl(
         k_scales,
         v_zero_points,
         v_scales,
-        seq_dim,
+        q_seq_dim,
+        k_seq_dim,
+        v_seq_dim,
         start_pos,
         num_keys_for_causal_attention);
   } else if (seq_len >= 192) {
@@ -467,7 +470,9 @@ Tensor& custom_sdpa_out_impl(
         k_scales,
         v_zero_points,
         v_scales,
-        seq_dim,
+        q_seq_dim,
+        k_seq_dim,
+        v_seq_dim,
         start_pos,
         num_keys_for_causal_attention);
   } else {
@@ -487,7 +492,9 @@ Tensor& custom_sdpa_out_impl(
         k_scales,
         v_zero_points,
         v_scales,
-        seq_dim,
+        q_seq_dim,
+        k_seq_dim,
+        v_seq_dim,
         start_pos,
         num_keys_for_causal_attention);
   }
@@ -532,6 +539,8 @@ Tensor& custom_quantized_sdpa_out(
       k_scales,
       v_zero_points,
       v_scales,
+      is_seq_at_dim_2,
+      is_seq_at_dim_2,
       is_seq_at_dim_2);
 }

@@ -562,6 +571,8 @@ Tensor& custom_sdpa_out(
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const optional<double> scale,
     const bool is_seq_dim_2,
+    const bool is_k_seq_dim_2,
+    const bool is_v_seq_dim_2,
     Tensor& output) {
   return custom_sdpa_out_impl(
       ctx,
@@ -580,7 +591,9 @@ Tensor& custom_sdpa_out(
       nullopt,
       nullopt,
       nullopt,
-      is_seq_dim_2);
+      is_seq_dim_2,
+      is_k_seq_dim_2,
+      is_v_seq_dim_2);
 }
 /*
 Input params
@@ -635,7 +648,9 @@ Tensor& sdpa_with_kv_cache_out(
       dropout_p,
       is_causal,
       scale,
-      false, // is_seq_dim_2 - default to false for backward compatibility
+      false, // is_seq_dim_2
+      false, // is_k_seq_dim_2
+      false, // is_v_seq_dim_2
       output);

   return output;

extension/llm/custom_ops/op_sdpa.h

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ Tensor& custom_sdpa_out(
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const optional<double> scale,
     const bool is_seq_dim_2,
+    const bool is_k_seq_dim_2,
+    const bool is_v_seq_dim_2,
     Tensor& output);

 Tensor& flash_attention_kernel_out(
