
Commit 7cff409

kirklandsign authored and facebook-github-bot committed
Attention sink support for LLM runner
Summary: Rewrite the Attention Sink KV cache implementation from an eviction-based approach to a ring-buffer approach for torch.export compatibility.

Key changes:
- Ring buffer KV cache: Replace dynamic eviction (torch.cat, narrow, shift) with a fixed-size ring buffer using index_copy_. Cache layout: [sink slots | ring buffer slots]. Sink tokens (e.g., BOS) occupy fixed positions; window tokens wrap around in the ring buffer region.
- Remove eviction_batch_size: No longer needed -- the ring buffer overwrites old entries automatically. Removed from all interfaces (attention_sink.py, model.py, llm_config.py, yaml config).
- Remove attention_sink_forward: No more monkey-patching AttentionMHA.forward. Instead, KVCacheWithAttentionSink sets is_ring_buffer=True, and AttentionMHA.forward handles ring buffer models natively (skip the start_pos bounds check, compute the mask after the KV update).
- Remove rerotate_k / position shifting: The ring buffer uses original positions for RoPE -- no re-rotation needed.
- Fix C++ runner: Remove the TEMPORARY max_new_tokens hack. Add a max_seq_len prefill check. Make the context length check conditional for sliding window models.
- Rewrite tests: Replace 16 eviction-based tests with 18 ring buffer tests covering sink preservation, ring wrapping, causal masking, and degenerate (sink_size=0) cases.
- Add llama_attention_sink.yaml: Example config for attention sink export.

Differential Revision: D99900289
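A minimal sketch of the update path described above may help. It assumes a hypothetical RingBufferKVCache with the [sink slots | ring buffer slots] layout; names, shapes, and the ring-region sizing are illustrative and are not the actual KVCacheWithAttentionSink code in this commit.

import torch

def ring_buffer_slots(input_pos: torch.Tensor, sink_size: int, ring_size: int) -> torch.Tensor:
    # Positions below sink_size map to fixed sink slots; later positions wrap
    # around inside the ring-buffer region [sink_size, sink_size + ring_size).
    in_sink = input_pos < sink_size
    ring_slot = sink_size + (input_pos - sink_size) % ring_size
    return torch.where(in_sink, input_pos, ring_slot)

class RingBufferKVCache(torch.nn.Module):
    is_ring_buffer = True  # flag that AttentionMHA.forward checks (see attention.py diff below)

    def __init__(self, n_heads, head_dim, sink_size, ring_size, dtype=torch.float32):
        super().__init__()
        self.sink_size, self.ring_size = sink_size, ring_size
        cache_len = sink_size + ring_size  # the example config sizes the ring region as window_size * 2
        self.register_buffer("k_cache", torch.zeros(1, n_heads, cache_len, head_dim, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, cache_len, head_dim, dtype=dtype))

    def update(self, input_pos, k, v):
        # index_copy_ overwrites the oldest ring entries in place; no torch.cat,
        # narrow, or shifting, which keeps the graph friendly to torch.export.
        slots = ring_buffer_slots(input_pos, self.sink_size, self.ring_size)
        self.k_cache.index_copy_(2, slots, k)
        self.v_cache.index_copy_(2, slots, v)
        return self.k_cache, self.v_cache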
1 parent 1eaafd9 commit 7cff409

9 files changed

Lines changed: 506 additions & 612 deletions


examples/models/llama/attention.py

Lines changed: 9 additions & 2 deletions
@@ -550,7 +550,14 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-            if self.enable_dynamic_shape:
+            is_ring_buffer = getattr(self.kv_cache, "is_ring_buffer", False)
+
+            if is_ring_buffer:
+                # Ring buffer models compute their own mask after KV cache
+                # update; skip start_pos bounds check since start_pos can
+                # exceed max_context_len for sliding window / attention sink.
+                attn_mask = None
+            elif self.enable_dynamic_shape:
                 start_pos = input_pos[-1].item()
                 torch._check_is_size(start_pos)
                 torch._check(start_pos < self.max_context_len)
@@ -569,7 +576,7 @@ def forward(
             )
         k, v = self.kv_cache.update(input_pos, k, v)

-        if getattr(self.kv_cache, "is_ring_buffer", False):
+        if is_ring_buffer:
             attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
                 input_pos[0].item(), seqlen
             )
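The call to create_causal_mask_for_ring_buffer above is what replaces the static triangular mask once the ring buffer can hold tokens out of positional order. A hypothetical sketch of such a mask follows; it assumes the cache tracks which absolute token position each slot currently holds (cache_positions, -1 for empty slots). This is an assumption made for illustration, not the upstream implementation.

import torch

def create_causal_mask_for_ring_buffer(
    cache_positions: torch.Tensor,  # [cache_len]; absolute position stored in each slot, -1 if empty
    start_pos: int,
    seqlen: int,
) -> torch.Tensor:
    # Boolean mask of shape [seqlen, cache_len]: True where a query token at
    # absolute position start_pos + i may attend to the token held in a slot.
    query_pos = torch.arange(start_pos, start_pos + seqlen).unsqueeze(1)  # [seqlen, 1]
    occupied = cache_positions.unsqueeze(0) >= 0           # slot holds a real token
    causal = cache_positions.unsqueeze(0) <= query_pos     # never attend to the future
    return occupied & causal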
llama_attention_sink.yaml

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+base:
+  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+
+model:
+  use_sdpa_with_kv_cache: False  # attention_sink requires standard SDPA
+  use_kv_cache: True
+  dtype_override: fp32
+  enable_dynamic_shape: True
+  # Attention Sink: "sink_size,window_size"
+  # sink_size=4: Keep first 4 tokens (e.g., BOS + system prompt)
+  # window_size=124: sliding window size
+  # KV cache size = sink_size + window_size * 2 = 4 + 124*2 = 252
+  use_attention_sink: "4,124"
+
+export:
+  # max_context_length controls the RoPE frequency table size.
+  # It must be >= sink_size + window_size (128), but larger values are
+  # recommended to support generation beyond the sliding window.
+  # The model default (e.g., 8192 or 131072) is typically used if not specified.
+  # For testing, we use the model's default by not setting this explicitly.
+
+quantization:
+  qmode: 8da4w
+  group_size: 128
+  embedding_quantize: 4,32
+
+backend:
+  xnnpack:
+    enabled: True
+    extended_ops: True
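As a quick check of the cache-size comment in the config above, the arithmetic can be restated in a couple of lines (a plain restatement of the values already given, nothing new):

sink_size, window_size = (int(x) for x in "4,124".split(","))
kv_cache_size = sink_size + window_size * 2   # 4 + 124 * 2
assert kv_cache_size == 252
assert sink_size + window_size == 128          # lower bound for max_context_length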

examples/models/llama/config/test_llm_config.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
 class TestValidation(unittest.TestCase):
     def test_invalid_attention_sink(self):
         with self.assertRaises(ValueError):
-            ModelConfig(use_attention_sink="4,2048")
+            ModelConfig(use_attention_sink="4")

     def test_invalid_local_global_attention_format(self):
         with self.assertRaises(ValueError):
@@ -79,7 +79,7 @@ def test_valid_llm_config(self):
             ),
             model=ModelConfig(
                 dtype_override="fp32",
-                use_attention_sink="4,2048,1024",
+                use_attention_sink="4,2048",
                 use_kv_cache=True,
                 local_global_attention="[16, 32]",
             ),
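A small usage example mirroring the updated test: the two-value "sink_size,window_size" form is accepted, while a single value is rejected. The import path below is an assumption for illustration only; ModelConfig lives wherever llm_config.py is defined in the repo.

# Import path assumed for illustration; adjust to the actual llm_config.py location.
from executorch.extension.llm.export.config.llm_config import ModelConfig

ModelConfig(use_attention_sink="4,2048")        # valid: sink_size=4, window_size=2048

try:
    ModelConfig(use_attention_sink="4")         # invalid: missing window_size
except ValueError:
    pass  # rejected, as asserted in test_invalid_attention_sink above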

examples/models/llama/model.py

Lines changed: 14 additions & 4 deletions
@@ -203,19 +203,29 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
             from .source_transformation.attention_sink import enable_attention_sink

             attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
-            assert len(attention_sink_params) == 3
+            assert len(attention_sink_params) == 2, (
+                f"use_attention_sink expects exactly 2 comma-separated values "
+                f"(sink_size,window_size), got {len(attention_sink_params)}"
+            )
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
-            eviction_batch_size = int(attention_sink_params[2])

-            assert self.llm_config.export.max_context_length == sink_size + window_size
+            # max_context_length must be >= sink_size + window_size to have enough RoPE frequencies.
+            # A larger max_context_length is allowed (and recommended) to support generation beyond
+            # the sliding window size.
+            assert self.llm_config.export.max_context_length >= sink_size + window_size, (
+                f"max_context_length ({self.llm_config.export.max_context_length}) must be >= "
+                f"sink_size + window_size ({sink_size + window_size})"
+            )
+            assert not self.llm_config.model.use_sdpa_with_kv_cache, (
+                "Attention sink is not compatible with use_sdpa_with_kv_cache"
+            )

             self.model_ = enable_attention_sink(
                 module=self.model_,
                 params=model_args,
                 sink_size=sink_size,
                 window_size=window_size,
-                eviction_batch_size=eviction_batch_size,
             )

             missing, unexpected = None, None
