Skip to content

Commit 8a2dfb5

Browse files
committed
Add --sliding_window flag to CoreML static LLM export
Models trained with sliding-window attention (Mistral 7B, Gemma 3, Gemma 4, Llama 4 Scout, …) only need each layer to attend to the last `W` tokens, but `export_static_llm_coreml.py` was always sizing the per-layer KV cache to `max_context_len - input_len`. That made longer contexts proportionally more expensive in both KV cache memory and per-token attention compute, even though the model was trained to ignore everything outside the window. Add a `--sliding_window` flag that caps the cache at the trained window. The downstream pieces — `StaticAttentionMask` invariants under cache eviction and the `StaticAttentionIOManager`'s per-layer `cache_lens` plumbing — already support this; the export script just needed to expose it. Per-layer mixed sliding/full attention (Gemma 3/4) is left for a follow-up; this PR uses one window for every layer. The cache_len computation is factored into `_resolve_cache_len` so it is unit-testable, and the README's ANE Optimizations section documents the new option. ### Memory savings example For a 32-layer / n_kv_heads=8 / head_dim=128 model exported with `max_context_len=8192` in fp16, dropping the cache from 8160 to 4096 cuts the per-method KV cache from ~1.07 GB to ~0.54 GB.
1 parent 94d2881 commit 8a2dfb5

3 files changed

Lines changed: 133 additions & 5 deletions

File tree

examples/apple/coreml/llama/export_static_llm_coreml.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import argparse
2323
import json
24+
from typing import Optional
2425

2526
import coremltools as ct
2627
import torch
@@ -170,6 +171,28 @@ def load_model(
170171
return model, args
171172

172173

174+
def _resolve_cache_len(
175+
max_context_len: int, input_len: int, sliding_window: Optional[int] = None
176+
) -> int:
177+
"""Pick the per-layer KV cache length given context / input / window settings.
178+
179+
Without sliding-window attention the cache must hold every token that can
180+
attend to the current step, i.e. ``max_context_len - input_len``. When the
181+
model is trained with sliding-window attention we instead cap the cache at
182+
``sliding_window`` so longer contexts do not enlarge per-layer attention
183+
compute or KV cache memory.
184+
"""
185+
cache_len = max_context_len - input_len
186+
if sliding_window is not None:
187+
if sliding_window <= 0:
188+
raise ValueError(
189+
f"sliding_window must be positive, got {sliding_window}"
190+
)
191+
if sliding_window < cache_len:
192+
cache_len = sliding_window
193+
return cache_len
194+
195+
173196
def _create_example_inputs(
174197
model_args, input_len, max_context_len, float_dtype, cache_len=None
175198
):
@@ -410,6 +433,18 @@ def main():
410433
default=32,
411434
help="Input sequence length per forward pass",
412435
)
436+
parser.add_argument(
437+
"--sliding_window",
438+
type=int,
439+
default=None,
440+
help=(
441+
"Sliding window attention size. When set, every layer uses a KV cache "
442+
"of this many tokens instead of (max_context_len - input_len), which "
443+
"lets the model serve longer contexts without growing per-layer attention "
444+
"compute or KV cache memory. Required for Mistral / Gemma3 / Gemma4 / "
445+
"Llama4-style models that train with sliding-window attention."
446+
),
447+
)
413448
parser.add_argument(
414449
"--dtype",
415450
type=str,
@@ -481,11 +516,15 @@ def main():
481516
print(f"\tLinear quantize: {args.linear_quantize}")
482517
print(f"\tDtype: {args.dtype}")
483518

484-
cache_len = args.max_context_len - args.input_len
519+
cache_len = _resolve_cache_len(
520+
args.max_context_len, args.input_len, args.sliding_window
521+
)
485522
print("\nGeneration configuration:")
486523
print(f"\tMax context length: {args.max_context_len}")
487524
print(f"\tInput length: {args.input_len}")
488525
print(f"\tCache length: {cache_len}")
526+
if args.sliding_window is not None:
527+
print(f"\tSliding window: {args.sliding_window}")
489528

490529
print("\nLinear splitting:")
491530
print(f"\tTarget split size: {args.target_split_size}")
@@ -513,9 +552,9 @@ def main():
513552
# the same cache buffer at runtime without any copying.
514553
decode_input_len = 1
515554
prefill_input_len = args.input_len # default 32
516-
shared_cache_len = (
517-
args.max_context_len - decode_input_len
518-
) # Use decode's cache size for both
555+
shared_cache_len = _resolve_cache_len(
556+
args.max_context_len, decode_input_len, args.sliding_window
557+
)
519558

520559
print(f"\nShared cache length for prefill/decode: {shared_cache_len}")
521560

@@ -643,7 +682,11 @@ def main():
643682
# Single method mode: fixed seqlen with generate_full_logits=True for lookahead
644683
print(f"\nCreating example inputs (seqlen={args.input_len})...")
645684
example_inputs, example_cache_len = _create_example_inputs(
646-
model_args, args.input_len, args.max_context_len, float_dtype
685+
model_args,
686+
args.input_len,
687+
args.max_context_len,
688+
float_dtype,
689+
cache_len=cache_len,
647690
)
648691

649692
# Test eager execution

examples/apple/coreml/llama/readme.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ Key differences between the two modes:
6969
| `--max_context_len` | 1024 | Maximum context length |
7070
| `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. |
7171
| `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. |
72+
| `--sliding_window` | (off) | Sliding-window attention size. When set, every layer uses a KV cache of this many tokens instead of `max_context_len - input_len`. Required for Mistral / Gemma3 / Gemma4 / Llama4-style models trained with sliding-window attention; lets longer contexts run without growing per-layer attention compute or KV cache memory. |
7273

7374
### Quantization Options
7475
| Option | Default | Description |
@@ -94,6 +95,7 @@ The static model has several ANE optimizations, including:
9495
* Splitting linear layers for improved performance (controlled by target_split_size and max_splits args)
9596
* Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks)
9697
* Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833)
98+
* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3/Gemma 4 alternate sliding and full layers), this both matches training behavior and roughly halves KV-cache memory plus per-token attention FLOPs at long context. Per-layer mixed sliding/full attention is not yet exposed; today every layer shares the same window when the flag is set.
9799

98100
We are working on adding a C++ runner as well.
99101

examples/apple/coreml/llama/test.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
sys.path.insert(0, ".")
1010
import copy
1111

12+
import pytest
1213
import torch
14+
from export_static_llm_coreml import _create_example_inputs, _resolve_cache_len
1315
from utils import replace_linear_with_split_linear
1416

17+
from executorch.examples.models.llama.model_args import ModelArgs
18+
1519

1620
def get_split_model(
1721
model,
@@ -44,5 +48,84 @@ def test_split_model():
4448
assert torch.allclose(model(inputs), model3(inputs), atol=1e-5)
4549

4650

51+
def test_resolve_cache_len_no_sliding_window():
52+
# Without --sliding_window the cache fills the rest of the context.
53+
assert _resolve_cache_len(1024, 32) == 992
54+
assert _resolve_cache_len(1024, 1) == 1023
55+
56+
57+
def test_resolve_cache_len_with_sliding_window():
58+
# When the window is smaller than the remaining context the cache shrinks.
59+
assert _resolve_cache_len(8192, 32, sliding_window=4096) == 4096
60+
assert _resolve_cache_len(8192, 1, sliding_window=4096) == 4096
61+
62+
63+
def test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op():
64+
# A user-provided window larger than the remaining context degenerates to
65+
# the no-window case, so users can safely set --sliding_window to a value
66+
# the model was trained with even when the export uses a shorter context.
67+
assert _resolve_cache_len(1024, 32, sliding_window=4096) == 992
68+
69+
70+
def test_resolve_cache_len_rejects_non_positive_window():
71+
with pytest.raises(ValueError):
72+
_resolve_cache_len(1024, 32, sliding_window=0)
73+
with pytest.raises(ValueError):
74+
_resolve_cache_len(1024, 32, sliding_window=-1)
75+
76+
77+
def test_create_example_inputs_with_sliding_window_shrinks_kv_cache():
78+
# Build a tiny ModelArgs that does not need a checkpoint or torchao.
79+
model_args = ModelArgs(
80+
dim=32,
81+
n_layers=2,
82+
n_heads=4,
83+
n_kv_heads=2,
84+
head_dim=8,
85+
vocab_size=128,
86+
max_context_len=1024,
87+
max_seq_len=1024,
88+
)
89+
max_context_len = 1024
90+
input_len = 32
91+
sliding_window = 64
92+
93+
cache_len = _resolve_cache_len(max_context_len, input_len, sliding_window)
94+
assert cache_len == sliding_window
95+
96+
example_inputs, returned_cache_len = _create_example_inputs(
97+
model_args,
98+
input_len,
99+
max_context_len,
100+
float_dtype=torch.float32,
101+
cache_len=cache_len,
102+
)
103+
assert returned_cache_len == sliding_window
104+
105+
# The KV cache tensors live inside the kwargs dict at index 1 under
106+
# in_cache_state. Walking that structure should find caches whose
107+
# sequence dimension equals the sliding window, not max_context_len.
108+
kwargs = example_inputs[1]
109+
in_cache_state = kwargs["in_cache_state"]
110+
cache_seq_dims = set()
111+
for per_kind in in_cache_state: # (k_caches, v_caches)
112+
for cache_tensor in per_kind.values():
113+
cache_seq_dims.add(cache_tensor.size(-2))
114+
assert cache_seq_dims == {sliding_window}, (
115+
f"expected every KV cache to be sized to the sliding window {sliding_window}, "
116+
f"got {cache_seq_dims}"
117+
)
118+
119+
# The attention mask covers (input_len + cache_len) along the last dim.
120+
masks = kwargs["masks"]
121+
assert sliding_window in masks
122+
assert masks[sliding_window].shape[-1] == input_len + sliding_window
123+
124+
47125
if __name__ == "__main__":
48126
test_split_model()
127+
test_resolve_cache_len_no_sliding_window()
128+
test_resolve_cache_len_with_sliding_window()
129+
test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op()
130+
test_resolve_cache_len_rejects_non_positive_window()
131+
test_create_example_inputs_with_sliding_window_shrinks_kv_cache()

0 commit comments

Comments
 (0)