Skip to content

Commit 85ab161

Browse files
kirklandsign authored and facebook-github-bot committed
Integrate attention sink into ET LLM export and runner
Summary: Add custom op support, export pipeline integration, and C++ runner fixes for the attention sink ring buffer implementation.

- CustomKVCacheWithAttentionSink: custom op variant using update_cache_with_indices for scatter-write performance. Replaces KVCacheWithAttentionSink during export.
- CustomRingKVCache replacement: handle RingKVCache -> CustomRingKVCache in the replacement pass, and set SDPACustom.use_attention_mask=True for ring buffer models.
- Export transform ordering: replace SDPA before KV cache so that _replace_kv_cache_with_custom_kv_cache can set use_attention_mask=True on the already-existing SDPACustom (previously the ordering was reversed, causing the mask flag to be overwritten by a new SDPACustom).
- C++ runner: add max_seq_len prefill check; make context length check conditional for sliding window models (max_seq_len < max_context_len) since they handle position wrapping internally via ring buffer.

Differential Revision: D100216686
1 parent 56d6e4d commit 85ab161

4 files changed

Lines changed: 183 additions & 13 deletions

File tree

examples/models/llama/export_llama_lib.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,16 +1781,17 @@ def _get_source_transforms( # noqa
17811781
use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask
17821782

17831783
if use_sdpa_with_kv_cache:
1784-
transforms.append(replace_kv_cache_with_custom_kv_cache)
1785-
# todo: do this optionally
1786-
# if use attention mask instead of causal attention
1787-
# then create partial function that sets use_attention_mask=True
1784+
# Replace SDPA first, then KV cache. Order matters: the KV cache
1785+
# replacement sets SDPACustom.use_attention_mask=True for ring buffer
1786+
# models (attention sink, sliding window). If SDPA is replaced after,
1787+
# a new SDPACustom(use_attention_mask=False) would overwrite it.
17881788
if use_attention_mask_for_custom_sdpa:
17891789
transforms.append(
17901790
partial(replace_sdpa_with_custom_op, use_attention_mask=True)
17911791
)
17921792
else:
17931793
transforms.append(replace_sdpa_with_custom_op)
1794+
transforms.append(replace_kv_cache_with_custom_kv_cache)
17941795

17951796
if quantize_kv_cache:
17961797
assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,41 @@ def replace_kv_cache_with_custom_kv_cache(module):
371371

372372

373373
def _replace_kv_cache_with_custom_kv_cache(module):
374+
# Import here to avoid circular imports
375+
from executorch.examples.models.llama.source_transformation.attention_sink import (
376+
KVCacheWithAttentionSink,
377+
)
378+
374379
for name, child in module.named_children():
375-
if isinstance(child, KVCache):
380+
if isinstance(child, KVCacheWithAttentionSink):
381+
# Replace with custom op variant for performance
382+
setattr(
383+
module,
384+
name,
385+
CustomKVCacheWithAttentionSink.from_kv_cache_with_attention_sink(child),
386+
)
387+
# If parent has SDPACustom, enable explicit mask mode
388+
sdpa = getattr(module, "SDPA", None)
389+
if sdpa is not None and hasattr(sdpa, "use_attention_mask"):
390+
sdpa.use_attention_mask = True
391+
elif isinstance(child, RingKVCache):
392+
# RingKVCache (e.g., from attention sink with sink_size=0) needs
393+
# CustomRingKVCache, not plain CustomKVCache
394+
setattr(
395+
module,
396+
name,
397+
CustomRingKVCache(
398+
child.max_batch_size,
399+
child.window_size,
400+
child.n_heads,
401+
child.head_dim,
402+
dtype=child.k_cache.dtype,
403+
),
404+
)
405+
sdpa = getattr(module, "SDPA", None)
406+
if sdpa is not None and hasattr(sdpa, "use_attention_mask"):
407+
sdpa.use_attention_mask = True
408+
elif isinstance(child, KVCache):
376409
cache_shape = child.k_cache.shape
377410
cache_dtype = child.k_cache.dtype
378411
max_batch_size, n_heads, max_context_length, head_dim = cache_shape
@@ -466,6 +499,81 @@ def from_quantized_kv_cache(
466499
)
467500

468501

502+
class CustomKVCacheWithAttentionSink(CustomKVCache):
503+
"""
504+
CustomKVCache variant for attention sink models.
505+
506+
Uses the custom update_cache_with_indices op for performance while
507+
supporting sink tokens (fixed slots) + ring buffer (sliding window).
508+
Modeled after CustomRingKVCache but with CachePositionsManagerWithSink.
509+
"""
510+
511+
def __init__(
512+
self,
513+
max_batch_size,
514+
n_heads,
515+
head_dim,
516+
window_size,
517+
sink_size,
518+
dtype=torch.float32,
519+
):
520+
# Total cache size: sink slots + ring buffer (2x window for wrap safety)
521+
total_cache_size = sink_size + window_size * 2
522+
super().__init__(
523+
max_batch_size, total_cache_size, n_heads, head_dim, dtype
524+
)
525+
from executorch.examples.models.llama.source_transformation.attention_sink import (
526+
CachePositionsManagerWithSink,
527+
_create_causal_mask_for_attention_sink,
528+
)
529+
530+
self.cache_positions_manager = CachePositionsManagerWithSink(
531+
total_cache_size, sink_size
532+
)
533+
self.is_ring_buffer = True
534+
self.window_size = window_size
535+
self.sink_size = sink_size
536+
self._create_causal_mask_for_attention_sink = (
537+
_create_causal_mask_for_attention_sink
538+
)
539+
540+
def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
541+
cache_positions = self.cache_positions_manager.cache_positions
542+
if self.sink_size > 0:
543+
return self._create_causal_mask_for_attention_sink(
544+
cache_positions, self.window_size, self.sink_size, start_pos, seq_len
545+
)
546+
else:
547+
return _create_causal_mask_for_ring_buffer(
548+
cache_positions, self.window_size, start_pos, seq_len
549+
)
550+
551+
def update(self, input_pos, k_val, v_val):
552+
seq_len = k_val.transpose(1, 2).size(1)
553+
assert seq_len <= self.k_cache.size(
554+
1
555+
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
556+
indices = self.cache_positions_manager.calculate_positions_and_update_indices(
557+
input_pos, seq_len
558+
)
559+
indices = indices.unsqueeze(0)
560+
561+
return super().update(input_pos, k_val, v_val, indices)
562+
563+
@classmethod
564+
def from_kv_cache_with_attention_sink(cls, kv_cache):
565+
"""Create from an existing KVCacheWithAttentionSink."""
566+
max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
567+
return cls(
568+
max_batch_size,
569+
n_heads,
570+
head_dim,
571+
kv_cache.window_size,
572+
kv_cache.sink_size,
573+
dtype=kv_cache.k_cache.dtype,
574+
)
575+
576+
469577
class CustomRingKVCache(CustomKVCache):
470578
def __init__(
471579
self,

examples/models/llama/source_transformation/test_attention_sink.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,48 @@ def test_beyond_context_window_basic(self):
397397
self.assertTrue(
398398
torch.isfinite(out).all(), "Output contains non-finite values"
399399
)
400+
401+
def test_beyond_context_window_custom_sdpa(self):
402+
"""Generate tokens beyond context window with custom SDPA + custom KV cache."""
403+
sink_size = 4
404+
window_size = 16
405+
args = self._make_args(max_context_len=128)
406+
model = self._build_model(args, sink_size, window_size, use_custom_sdpa=True)
407+
408+
# Verify KV caches were replaced with CustomKVCacheWithAttentionSink
409+
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
410+
CustomKVCacheWithAttentionSink,
411+
)
412+
413+
found_custom_cache = False
414+
for m in model.modules():
415+
if isinstance(m, CustomKVCacheWithAttentionSink):
416+
found_custom_cache = True
417+
break
418+
self.assertTrue(
419+
found_custom_cache, "Expected CustomKVCacheWithAttentionSink in model"
420+
)
421+
422+
# Generate 80 tokens — well beyond KV cache size of 36
423+
outputs = self._run_generation(model, args, num_tokens=80)
424+
425+
self.assertEqual(len(outputs), 77)
426+
for out in outputs:
427+
self.assertTrue(
428+
torch.isfinite(out).all(), "Output contains non-finite values"
429+
)
430+
431+
def test_sink_zero_custom_sdpa(self):
432+
"""Degenerate case: sink_size=0 with custom SDPA (pure ring buffer)."""
433+
sink_size = 0
434+
window_size = 16
435+
args = self._make_args(max_context_len=128)
436+
model = self._build_model(args, sink_size, window_size, use_custom_sdpa=True)
437+
438+
outputs = self._run_generation(model, args, num_tokens=60)
439+
440+
self.assertEqual(len(outputs), 57)
441+
for out in outputs:
442+
self.assertTrue(
443+
torch.isfinite(out).all(), "Output contains non-finite values"
444+
)

extension/llm/runner/text_llm_runner.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ Error TextLLMRunner::generate(
110110
stats_->inference_start_ms = time_in_ms();
111111
shouldStop_ = false;
112112

113+
// Get max_seq_len for single prefill chunk limit
114+
int64_t max_seq_len = metadata_.at(kMaxSeqLen);
113115
int64_t max_context_len = metadata_.at(kMaxContextLen);
114116

115117
uint64_t cur_token = 0;
@@ -138,13 +140,26 @@ Error TextLLMRunner::generate(
138140
InvalidArgument,
139141
"Expected at least 1 prompt token");
140142
ET_CHECK_OR_RETURN_ERROR(
141-
pos_ + num_prompt_tokens < max_context_len,
143+
num_prompt_tokens <= max_seq_len,
142144
InvalidArgument,
143-
"pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
144-
", Max seq length exceeded - please increase max seq len value in your export script",
145-
pos_,
145+
"num_prompt_tokens %d > max_seq_len %" PRId64
146+
", Single prefill chunk too large - please reduce prompt size or increase max_seq_len",
146147
num_prompt_tokens,
147-
max_context_len);
148+
max_seq_len);
149+
// For non-sliding-window models, also check that we won't exceed
150+
// KV cache capacity. Sliding window models (where max_seq_len <
151+
// max_context_len) handle position wrapping internally.
152+
if (max_seq_len >= max_context_len) {
153+
ET_CHECK_OR_RETURN_ERROR(
154+
pos_ + num_prompt_tokens < max_context_len,
155+
InvalidArgument,
156+
"pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
157+
", Max seq length exceeded - please increase max seq len value in "
158+
"your export script",
159+
pos_,
160+
num_prompt_tokens,
161+
max_context_len);
162+
}
148163

149164
// print prompts
150165
if (config.echo) {
@@ -168,9 +183,10 @@ Error TextLLMRunner::generate(
168183
prefill_next_token_.reset();
169184
}
170185

171-
// Resolve max_new_tokens. pos_ now reflects all occupied positions
172-
// (including prompt tokens just prefilled).
173-
int max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
186+
// For sliding window models, the ring buffer recycles space — pos_ doesn't
187+
// represent consumed capacity, so pass 0 to get the full budget.
188+
int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_;
189+
int max_new_tokens = config.resolve_max_new_tokens(max_context_len, effective_pos);
174190

175191
ET_LOG(
176192
Info,

0 commit comments

Comments (0)