@@ -29,13 +29,16 @@ class RopeWithAttentionSink(Rope):
     """
     Rope subclass for Attention Sink models.
 
-    For torch.export compatibility, this passes through the original position
-    unchanged - the sliding window is handled by the cache index management
-    (ring buffer), not by position shifting.
+    Remaps input positions using modular arithmetic so RoPE frequencies stay
+    within the cache size bounds, enabling generation beyond max_context_len.
 
-    Note: This class uses the model's max_context_len (params.max_context_len) for
-    RoPE frequency table size, which should be large enough to support generation
-    beyond the sliding window. The actual KV cache size is sink_size + window_size * 2.
+    Position mapping:
+    - Sink tokens (pos < sink_size): position preserved as-is
+    - Window tokens (pos >= sink_size): wrapped into ring buffer range
+      [sink_size, sink_size + ring_size) via modulo
+
+    The ring buffer is 2x window_size, so the live window (window_size tokens)
+    never spans a wrap boundary, preserving correct relative distances in RoPE.
     """
 
     def __init__(
@@ -47,19 +50,48 @@ def __init__(
         super().__init__(params)
         self.window_size = window_size
         self.sink_size = sink_size
-        # max_context_len from params is used for RoPE frequencies (should be large)
-        self.max_context_length = self.params.max_context_len
+        self.ring_size = window_size * 2
+
+    def _remap_input_pos(self, input_pos: torch.Tensor) -> torch.Tensor:
+        """Remap positions: sink tokens stay, window tokens wrap in ring buffer."""
+        return torch.where(
+            input_pos < self.sink_size,
+            input_pos,
+            self.sink_size + (input_pos - self.sink_size) % self.ring_size,
+        )
 
     def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
         """
-        Get rotary embedding frequencies.
-        For attention sink, we use the original position - the sliding window
-        is handled by the cache index management, not by position shifting.
+        Get rotary embedding frequencies with position remapping.
+
+        For dynamic shape mode (input_pos is a single start position), we remap
+        the start and use narrow. For static shape mode (input_pos is the full
+        position tensor), we remap all positions and index directly.
         """
         assert input_pos is not None
-        # Use torch._check for export compatibility (data-dependent guard)
-        torch._check(input_pos[0].item() + seq_len <= self.max_context_length)
-        return super().get_freqs(input_pos, seq_len)
+        if not self.params.use_kv_cache:
+            return self.freqs_cos[:seq_len], self.freqs_sin[:seq_len]
+
+        if self.params.enable_dynamic_shape:
+            # Dynamic shape: input_pos is [start_pos], remap and narrow
+            input_pos_item = input_pos[-1].item()
+            if input_pos_item < self.sink_size:
+                remapped_item = input_pos_item
+            else:
+                remapped_item = (
+                    self.sink_size + (input_pos_item - self.sink_size) % self.ring_size
+                )
+            torch._check_is_size(remapped_item)
+            torch._check(remapped_item + seq_len <= self.sink_size + self.ring_size)
+            freqs_cos = self.freqs_cos.narrow(0, remapped_item, seq_len)
+            freqs_sin = self.freqs_sin.narrow(0, remapped_item, seq_len)
+        else:
+            # Static shape: remap full position tensor and index
+            remapped = self._remap_input_pos(input_pos)
+            freqs_cos = self.freqs_cos[remapped]
+            freqs_sin = self.freqs_sin[remapped]
+
+        return freqs_cos, freqs_sin
 
 
 def _create_causal_mask_for_attention_sink(
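For context, here is a standalone sketch (not part of the diff above) that reproduces the modulo arithmetic from `_remap_input_pos` with assumed toy values `sink_size=4` and `window_size=8`, to show how sink positions pass through unchanged while later positions wrap into `[sink_size, sink_size + ring_size)`:

```python
import torch

# Assumed toy configuration for illustration only.
sink_size = 4
window_size = 8
ring_size = window_size * 2  # 16, mirroring self.ring_size = window_size * 2

def remap(input_pos: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as RopeWithAttentionSink._remap_input_pos:
    # sink positions are kept, later positions wrap into the ring range.
    return torch.where(
        input_pos < sink_size,
        input_pos,
        sink_size + (input_pos - sink_size) % ring_size,
    )

pos = torch.tensor([0, 3, 4, 19, 20, 35, 36])
print(remap(pos))  # -> 0, 3, 4, 19, 4, 19, 4
```

With these assumed sizes, the remapped index never exceeds `sink_size + ring_size - 1 = 19`, so the RoPE frequency table only needs to cover the cache size rather than the full generated sequence length, which is what allows generation beyond `max_context_len`.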