Add --sliding_window flag to CoreML static LLM export #19250

Open

john-rocky wants to merge 1 commit into pytorch:main from
Models trained with sliding-window attention (Mistral 7B, Gemma 3, Gemma 4, Llama 4 Scout, …) only need each layer to attend to the last `W` tokens, but `export_static_llm_coreml.py` was always sizing the per-layer KV cache to `max_context_len - input_len`. That made longer contexts proportionally more expensive in both KV cache memory and per-token attention compute, even though the model was trained to ignore everything outside the window.

Add a `--sliding_window` flag that caps the cache at the trained window. The downstream pieces (the `StaticAttentionMask` invariants under cache eviction and the `StaticAttentionIOManager`'s per-layer `cache_lens` plumbing) already support this; the export script just needed to expose it. Per-layer mixed sliding/full attention (Gemma 3/4) is left for a follow-up; this PR uses one window for every layer.

The `cache_len` computation is factored into `_resolve_cache_len` so it is unit-testable, and the README's ANE Optimizations section documents the new option.

### Memory savings example

For a 32-layer / `n_kv_heads=8` / `head_dim=128` model exported with `max_context_len=8192` in fp16, dropping the cache from 8160 to 4096 cuts the per-method KV cache from ~1.07 GB to ~0.54 GB.
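For reference, the resolution rule the new flag controls is small enough to sketch. This is a minimal illustration of the behavior described above and pinned down by the tests in the plan below; the PR's actual `_resolve_cache_len` may differ in details:

```python
# Sketch only: assumed shape of the cache_len resolution; the PR's actual
# _resolve_cache_len may differ.
from typing import Optional


def _resolve_cache_len(
    max_context_len: int, input_len: int, sliding_window: Optional[int]
) -> int:
    # Default: cache everything that can precede the current input chunk.
    cache_len = max_context_len - input_len
    if sliding_window is None:
        return cache_len
    if sliding_window <= 0:
        raise ValueError("--sliding_window must be a positive integer")
    # A window >= the remaining context degenerates to the no-window case,
    # so users can pass the model's training window verbatim.
    return min(cache_len, sliding_window)
```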
Summary
Models trained with sliding-window attention (Mistral 7B, Gemma 3, Gemma 4,
Llama 4 Scout, etc.) only need each layer to attend to the last
`W` tokens. `export_static_llm_coreml.py` was always sizing the per-layer
KV cache to `max_context_len - input_len`, so longer contexts were
proportionally more expensive in both KV cache memory and per-token
attention compute even though the model was trained to ignore everything
outside the window.

Add a `--sliding_window` flag that caps the cache at the trained window.
The downstream pieces (the `StaticAttentionMask` invariants under cache
eviction, validated by `test_sliding_window_cache_and_mask`, and the
`StaticAttentionIOManager`'s per-layer `cache_lens` plumbing) already
support this; the export script just needed to expose it.
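To make that invariant concrete, here is a toy causal sliding-window mask (illustrative only, not the repo's `StaticAttentionMask`). A query at position `i` can see at most the last `W` key positions, so cache entries older than `W` tokens can never contribute:

```python
import torch

# Toy causal sliding-window mask: True where query position q may attend
# to key position k, i.e. k <= q and k > q - window.
def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    q = torch.arange(seq_len).unsqueeze(1)  # query positions (column)
    k = torch.arange(seq_len).unsqueeze(0)  # key positions (row)
    return (k <= q) & (k > q - window)

print(sliding_window_causal_mask(5, 2).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1]])
```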
The `cache_len` computation is factored into `_resolve_cache_len` so it is
unit-testable. Per-layer mixed sliding/full attention (Gemma 3 / Gemma 4
alternate sliding and full layers) is intentionally left for a follow-up;
this PR uses one window for every layer. Documented in the ANE
Optimizations section of `readme.md`.

Memory savings example
For a 32-layer / `n_kv_heads=8` / `head_dim=128` model exported with
`max_context_len=8192` in fp16:

- without the flag: `cache_len = 8160`, ~1.07 GB per-method KV cache
- with `--sliding_window 4096`: `cache_len = 4096`, ~0.54 GB
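For reference, the arithmetic behind those figures (K and V stored per layer, fp16 = 2 bytes per element):

```python
# Back-of-envelope KV cache size:
# layers * (K and V) * n_kv_heads * head_dim * cache_len * bytes per element.
layers, n_kv_heads, head_dim, fp16_bytes = 32, 8, 128, 2

def kv_cache_bytes(cache_len: int) -> int:
    return layers * 2 * n_kv_heads * head_dim * cache_len * fp16_bytes

print(kv_cache_bytes(8160) / 1e9)  # ~1.07 GB
print(kv_cache_bytes(4096) / 1e9)  # ~0.54 GB
```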
Test plan

Added unit tests in `examples/apple/coreml/llama/test.py`:

- `test_resolve_cache_len_no_sliding_window`: default path is unchanged.
- `test_resolve_cache_len_with_sliding_window`: cache shrinks to the window.
- `test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op`: a
  user-provided window larger than the remaining context degenerates to the
  no-window case (so users can pass the model's training window verbatim).
- `test_resolve_cache_len_rejects_non_positive_window`: input validation.
- `test_create_example_inputs_with_sliding_window_shrinks_kv_cache`: full
  path: every cache tensor in the example inputs has its sequence dimension
  equal to the sliding window, and the attention mask covers
  `input_len + sliding_window`.

I also confirmed the existing
`examples/models/llama/tests/test_static_attention.py::test_sliding_window_cache_and_mask`
already covers the cache + mask invariants under both `shift_pointer` and
`smart_mask` eviction styles when `cache_len < total_tokens`, so this PR
does not need to re-test that.

Authored with Claude.
examples/models/llama/tests/test_static_attention.py::test_sliding_window_cache_and_maskalready covers the cache + mask invariants under bothshift_pointerandsmart_maskeviction styles whencache_len < total_tokens, so this PR does not need to re-test that.Authored with Claude.