
Commit 239f7cc

Enable infinite generation with RoPE position remapping for attention sink (#19011)
Differential Revision: D100728748
Pull Request resolved: #19011
1 parent 2d53535 commit 239f7cc

2 files changed

Lines changed: 64 additions & 14 deletions


examples/models/llama/source_transformation/attention_sink.py

Lines changed: 46 additions & 14 deletions
@@ -29,13 +29,16 @@ class RopeWithAttentionSink(Rope):
     """
     Rope subclass for Attention Sink models.
 
-    For torch.export compatibility, this passes through the original position
-    unchanged - the sliding window is handled by the cache index management
-    (ring buffer), not by position shifting.
+    Remaps input positions using modular arithmetic so RoPE frequencies stay
+    within the cache size bounds, enabling generation beyond max_context_len.
 
-    Note: This class uses the model's max_context_len (params.max_context_len) for
-    RoPE frequency table size, which should be large enough to support generation
-    beyond the sliding window. The actual KV cache size is sink_size + window_size * 2.
+    Position mapping:
+    - Sink tokens (pos < sink_size): position preserved as-is
+    - Window tokens (pos >= sink_size): wrapped into ring buffer range
+      [sink_size, sink_size + ring_size) via modulo
+
+    The ring buffer is 2x window_size, so the live window (window_size tokens)
+    never spans a wrap boundary, preserving correct relative distances in RoPE.
     """
 
     def __init__(
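As a minimal standalone sketch (not part of this diff) of the mapping the new docstring describes, using illustrative values sink_size=4 and window_size=16 so that ring_size=32 and every remapped position lands below sink_size + ring_size = 36, no matter how far generation runs:

import torch

# Illustrative values only; the diff does not fix these.
sink_size = 4
window_size = 16
ring_size = window_size * 2  # 32

def remap(pos: torch.Tensor) -> torch.Tensor:
    # Sink tokens keep their position; window tokens wrap into
    # [sink_size, sink_size + ring_size) via modulo.
    return torch.where(
        pos < sink_size,
        pos,
        sink_size + (pos - sink_size) % ring_size,
    )

pos = torch.arange(100)
remapped = remap(pos)
assert int(remapped.max()) < sink_size + ring_size  # never reaches 36
# pos 0..3 stay 0..3 (sink); pos 4 -> 4, pos 36 -> 4, pos 99 -> 35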
@@ -47,19 +50,48 @@ def __init__(
         super().__init__(params)
         self.window_size = window_size
         self.sink_size = sink_size
-        # max_context_len from params is used for RoPE frequencies (should be large)
-        self.max_context_length = self.params.max_context_len
+        self.ring_size = window_size * 2
+
+    def _remap_input_pos(self, input_pos: torch.Tensor) -> torch.Tensor:
+        """Remap positions: sink tokens stay, window tokens wrap in ring buffer."""
+        return torch.where(
+            input_pos < self.sink_size,
+            input_pos,
+            self.sink_size + (input_pos - self.sink_size) % self.ring_size,
+        )
 
     def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
         """
-        Get rotary embedding frequencies.
-        For attention sink, we use the original position - the sliding window
-        is handled by the cache index management, not by position shifting.
+        Get rotary embedding frequencies with position remapping.
+
+        For dynamic shape mode (input_pos is a single start position), we remap
+        the start and use narrow. For static shape mode (input_pos is the full
+        position tensor), we remap all positions and index directly.
         """
         assert input_pos is not None
-        # Use torch._check for export compatibility (data-dependent guard)
-        torch._check(input_pos[0].item() + seq_len <= self.max_context_length)
-        return super().get_freqs(input_pos, seq_len)
+        if not self.params.use_kv_cache:
+            return self.freqs_cos[:seq_len], self.freqs_sin[:seq_len]
+
+        if self.params.enable_dynamic_shape:
+            # Dynamic shape: input_pos is [start_pos], remap and narrow
+            input_pos_item = input_pos[-1].item()
+            if input_pos_item < self.sink_size:
+                remapped_item = input_pos_item
+            else:
+                remapped_item = (
+                    self.sink_size + (input_pos_item - self.sink_size) % self.ring_size
+                )
+            torch._check_is_size(remapped_item)
+            torch._check(remapped_item + seq_len <= self.sink_size + self.ring_size)
+            freqs_cos = self.freqs_cos.narrow(0, remapped_item, seq_len)
+            freqs_sin = self.freqs_sin.narrow(0, remapped_item, seq_len)
+        else:
+            # Static shape: remap full position tensor and index
+            remapped = self._remap_input_pos(input_pos)
+            freqs_cos = self.freqs_cos[remapped]
+            freqs_sin = self.freqs_sin[remapped]
+
+        return freqs_cos, freqs_sin
 
 
 def _create_causal_mask_for_attention_sink(
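A standalone sketch of the dynamic-shape branch above, with the same hypothetical sink_size and window_size and stand-in frequency tables (the real freqs_cos/freqs_sin are Rope's precomputed buffers); it mirrors the scalar remap-and-narrow logic for single-token decode steps:

import torch

sink_size, window_size = 4, 16      # hypothetical values
ring_size = window_size * 2          # 32
table_rows = sink_size + ring_size   # frequency table must cover [0, 36)

# Stand-in tables of arbitrary width for illustration.
freqs_cos = torch.randn(table_rows, 8)
freqs_sin = torch.randn(table_rows, 8)

def freqs_for_decode(start_pos: int, seq_len: int = 1):
    # Scalar version of the dynamic-shape branch in get_freqs above.
    if start_pos < sink_size:
        remapped = start_pos
    else:
        remapped = sink_size + (start_pos - sink_size) % ring_size
    assert remapped + seq_len <= table_rows  # stands in for torch._check
    return (
        freqs_cos.narrow(0, remapped, seq_len),
        freqs_sin.narrow(0, remapped, seq_len),
    )

# A decode step at absolute position 1000 still reads inside the 36-row table.
cos, sin = freqs_for_decode(1000)
print(cos.shape)  # torch.Size([1, 8])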

examples/models/llama/source_transformation/test_attention_sink.py

Lines changed: 18 additions & 0 deletions
@@ -398,6 +398,24 @@ def test_beyond_context_window_basic(self):
             torch.isfinite(out).all(), "Output contains non-finite values"
         )
 
+    def test_beyond_max_context_len(self):
+        """Generate tokens beyond max_context_len with RoPE position remapping."""
+        sink_size = 4
+        window_size = 16
+        # KV cache size = 36, max_context_len = 64
+        # Generate 100 tokens — well beyond max_context_len
+        args = self._make_args(max_context_len=64)
+        model = self._build_model(args, sink_size, window_size, use_custom_sdpa=False)
+
+        outputs = self._run_generation(model, args, num_tokens=100)
+
+        self.assertEqual(len(outputs), 97)  # 1 prefill + 96 decode steps
+        for out in outputs:
+            self.assertTrue(
+                torch.isfinite(out).all(),
+                "Output contains non-finite values beyond max_context_len",
+            )
+
     def test_beyond_context_window_custom_sdpa(self):
         """Generate tokens beyond context window with custom SDPA + custom KV cache."""
         sink_size = 4