Integrate attention sink into ET LLM export and runner (#18860)
meta-codesync[bot] merged 1 commit into main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18860

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 3 Pending as of commit df7ede5 with merge base 3180927.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D100216686. |
Pull request overview
Integrates the “attention sink” ring-buffer KV-cache path into the Llama export pipeline and runtime runner, including custom-op KV cache replacements and updated runner-side constraints for sliding-window models.
Changes:
- Add `CustomKVCacheWithAttentionSink` and extend KV-cache replacement to handle `KVCacheWithAttentionSink` and `RingKVCache`, including enabling attention-mask mode on `SDPACustom`.
- Reorder export transforms so SDPA replacement happens before KV-cache replacement (to preserve `use_attention_mask=True` for ring-buffer models).
- Update the C++ text runner to (a) enforce a single-prefill-chunk `max_seq_len` limit and (b) relax context-limit logic for sliding-window models; add new end-to-end tests for custom SDPA + custom KV cache.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| extension/llm/runner/text_llm_runner.cpp | Adds max-seq-len prefill guard and modifies max-new-tokens budgeting for sliding-window models. |
| examples/models/llama/source_transformation/test_attention_sink.py | Adds end-to-end generation tests for custom SDPA + custom KV cache and sink_size=0 edge case. |
| examples/models/llama/source_transformation/custom_kv_cache.py | Introduces CustomKVCacheWithAttentionSink and expands KV-cache replacement logic for attention-sink/ring KV caches. |
| examples/models/llama/export_llama_lib.py | Reorders source transforms so SDPA replacement precedes KV-cache replacement to preserve mask settings. |
```cpp
// For sliding window models, the ring buffer recycles space — pos_ doesn't
// represent consumed capacity, so pass 0 to get the full budget.
int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_;
int max_new_tokens = config.resolve_max_new_tokens(max_context_len, effective_pos);
```
`effective_pos` is set to 0 for sliding-window models, but `resolve_max_new_tokens(max_context_len, effective_pos)` is intended to ensure `start_pos + max_new_tokens <= max_context_len`. With `effective_pos=0`, a nonzero `pos_` can lead to generating past `max_context_len` (e.g., RoPE tables typically guard `input_pos + seq_len <= max_context_len`), causing runtime failures. Consider keeping `pos_` (or a wrapped position if the model truly wraps positions) when resolving `max_new_tokens`, and only relax KV-cache-capacity checks separately from the RoPE/context-length limit.
```diff
- // For sliding window models, the ring buffer recycles space — pos_ doesn't
- // represent consumed capacity, so pass 0 to get the full budget.
- int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_;
- int max_new_tokens = config.resolve_max_new_tokens(max_context_len, effective_pos);
+ // Resolve generation budget from the actual starting position so
+ // start_pos + max_new_tokens stays within max_context_len. Sliding-window
+ // KV-cache reuse must not bypass positional/context-length limits.
+ int64_t effective_pos = pos_;
+ int max_new_tokens =
+     config.resolve_max_new_tokens(max_context_len, effective_pos);
```
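For reference, a hypothetical Python mirror of the budget resolution this comment describes; the real method lives in the C++ runner's config, and the signature here is illustrative, not the actual API:

```python
def resolve_max_new_tokens(max_context_len: int, start_pos: int,
                           requested_max_new_tokens: int = -1) -> int:
    # The remaining context window is what enforces
    # start_pos + max_new_tokens <= max_context_len.
    remaining = max_context_len - start_pos
    if requested_max_new_tokens < 0:
        return max(0, remaining)
    return max(0, min(requested_max_new_tokens, remaining))

# Passing start_pos=0 for a sliding-window model hands out the full
# max_context_len budget even when pos_ > 0, which is the failure mode
# the comment above describes.
```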
```python
def update(self, input_pos, k_val, v_val):
    seq_len = k_val.transpose(1, 2).size(1)
    assert seq_len <= self.k_cache.size(
        1
    ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
    indices = self.cache_positions_manager.calculate_positions_and_update_indices(
        input_pos, seq_len
    )
    indices = indices.unsqueeze(0)

    return super().update(input_pos, k_val, v_val, indices)
```
`CustomKVCacheWithAttentionSink.update()` computes ring-buffer write indices but does not validate that the number of window tokens written in a single update fits within the ring portion (i.e., avoids duplicate indices due to modulo wrap). `KVCacheWithAttentionSink.update()` has an explicit `num_window_tokens <= ring_size` guard to prevent non-deterministic `index_copy_` behavior; the same kind of guard is needed here as well (especially since `update_cache_with_indices` is also a scatter-style update).
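A minimal standalone sketch of the kind of guard meant here; the names `sink_size` and `ring_size` and the window-token formula are illustrative, with the eager `KVCacheWithAttentionSink.update()` holding the authoritative version:

```python
def check_ring_capacity(input_pos: int, seq_len: int,
                        sink_size: int, ring_size: int) -> None:
    # Tokens written past the sink region wrap into the ring buffer modulo
    # ring_size. If a single update carries more window tokens than the ring
    # holds, two tokens map to the same slot and a scatter-style write
    # (index_copy_ / update_cache_with_indices) picks a winner
    # non-deterministically.
    num_window_tokens = max(0, seq_len - max(0, sink_size - input_pos))
    assert num_window_tokens <= ring_size, (
        f"update writes {num_window_tokens} window tokens into a ring of "
        f"size {ring_size}; chunk the update so each call fits in the ring"
    )
```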
@claude review the patch
Claude finished @kirklandsign's task in 1m 51s. Review of PR #18860: Integrate attention sink into ET LLM export and runner.
Overall this is a well-structured PR that correctly integrates the attention sink ring buffer into the export pipeline and C++ runner. The transform ordering fix is sound.
Issues 1 and 4 will be addressed in a follow-up patch; 2 and 3 are not blockers.
Summary: Pull Request resolved: #18860

Add custom op support, export pipeline integration, and C++ runner fixes for the attention sink ring buffer implementation.

- CustomKVCacheWithAttentionSink: custom op variant using update_cache_with_indices for scatter-write performance. Replaces KVCacheWithAttentionSink during export.
- CustomRingKVCache replacement: handle RingKVCache -> CustomRingKVCache in the replacement pass, and set SDPACustom.use_attention_mask=True for ring buffer models.
- Export transform ordering: replace SDPA before KV cache so that _replace_kv_cache_with_custom_kv_cache can set use_attention_mask=True on the already-existing SDPACustom (previously the ordering was reversed, causing the mask flag to be overwritten by a new SDPACustom).
- C++ runner: add max_seq_len prefill check; make the context length check conditional for sliding window models (max_seq_len < max_context_len) since they handle position wrapping internally via the ring buffer.

Differential Revision: D100216686
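A toy illustration of the ordering bug the summary describes; the class and functions below are stand-ins, not the real `SDPACustom` module or transform functions:

```python
class SDPACustom:
    def __init__(self):
        self.use_attention_mask = False  # default when freshly installed

def replace_sdpa(model: dict) -> dict:
    model["sdpa"] = SDPACustom()  # installs a fresh module, resetting the flag
    return model

def replace_kv_cache(model: dict) -> dict:
    # Ring-buffer caches need the existing SDPA module to honor the mask.
    model["sdpa"].use_attention_mask = True
    return model

# Old order: the KV-cache pass runs first, then a new SDPACustom
# overwrites the mask flag it set.
m = replace_sdpa(replace_kv_cache({"sdpa": SDPACustom()}))
assert m["sdpa"].use_attention_mask is False

# Fixed order (this PR): SDPA replacement first, then the KV-cache pass.
m = replace_kv_cache(replace_sdpa({"sdpa": SDPACustom()}))
assert m["sdpa"].use_attention_mask is True
```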
Force-pushed 85ab161 to e9b14af.
Force-pushed e9b14af to ec586db.
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
```diff
  ET_CHECK_OR_RETURN_ERROR(
-     pos_ + num_prompt_tokens < max_context_len,
+     num_prompt_tokens <= max_seq_len,
      InvalidArgument,
-     "pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
-     ", Max seq length exceeded - please increase max seq len value in your export script",
-     pos_,
+     "num_prompt_tokens %d > max_seq_len %" PRId64
+     ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len",
      num_prompt_tokens,
-     max_context_len);
+     max_seq_len);
```
The new `num_prompt_tokens <= max_seq_len` check rejects prompts longer than `max_seq_len`, but `TextPrefiller::prefill()` already chunks prompts larger than `max_seq_len_` into multiple `prefill_chunk()` calls. This is a behavior regression (previously long prompts could succeed via chunking) and isn't necessary to enforce the per-chunk limit. Consider removing this guard and relying on `TextPrefiller` chunking, or (if needed) enforce the max chunk size inside the prefiller rather than rejecting the whole prompt here.
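A Python sketch of the chunking behavior this comment refers to; `prefill_chunk` stands in for the runner's per-chunk call, and its signature here is an assumption for illustration:

```python
from typing import Callable

def prefill(tokens: list[int], max_seq_len: int,
            prefill_chunk: Callable[[list[int], int], int]) -> int:
    # Long prompts are fed in windows of at most max_seq_len tokens, so a
    # whole-prompt num_prompt_tokens <= max_seq_len guard upstream rejects
    # inputs this loop would otherwise handle fine.
    pos = 0
    for start in range(0, len(tokens), max_seq_len):
        pos = prefill_chunk(tokens[start:start + max_seq_len], pos)
    return pos
```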
Force-pushed ec586db to df7ede5.
Reviewed By: lucylq
Differential Revision: D100216686