53 changes: 48 additions & 5 deletions examples/apple/coreml/llama/export_static_llm_coreml.py
@@ -21,6 +21,7 @@

import argparse
import json
from typing import Optional

import coremltools as ct
import torch
@@ -170,6 +171,28 @@ def load_model(
    return model, args


def _resolve_cache_len(
    max_context_len: int, input_len: int, sliding_window: Optional[int] = None
) -> int:
    """Pick the per-layer KV cache length given context / input / window settings.

    Without sliding-window attention the cache must hold every token that can
    attend to the current step, i.e. ``max_context_len - input_len``. When the
    model is trained with sliding-window attention we instead cap the cache at
    ``sliding_window`` so longer contexts do not enlarge per-layer attention
    compute or KV cache memory.
    """
    cache_len = max_context_len - input_len
    if sliding_window is not None:
        if sliding_window <= 0:
            raise ValueError(
                f"sliding_window must be positive, got {sliding_window}"
            )
        if sliding_window < cache_len:
            cache_len = sliding_window
    return cache_len


def _create_example_inputs(
    model_args, input_len, max_context_len, float_dtype, cache_len=None
):
@@ -410,6 +433,18 @@ def main():
default=32,
help="Input sequence length per forward pass",
)
parser.add_argument(
"--sliding_window",
type=int,
default=None,
help=(
"Sliding window attention size. When set, every layer uses a KV cache "
"of this many tokens instead of (max_context_len - input_len), which "
"lets the model serve longer contexts without growing per-layer attention "
"compute or KV cache memory. Required for Mistral / Gemma3 / Gemma4 / "
"Llama4-style models that train with sliding-window attention."
),
)
parser.add_argument(
"--dtype",
type=str,
@@ -481,11 +516,15 @@ def main():
print(f"\tLinear quantize: {args.linear_quantize}")
print(f"\tDtype: {args.dtype}")

cache_len = args.max_context_len - args.input_len
cache_len = _resolve_cache_len(
args.max_context_len, args.input_len, args.sliding_window
)
print("\nGeneration configuration:")
print(f"\tMax context length: {args.max_context_len}")
print(f"\tInput length: {args.input_len}")
print(f"\tCache length: {cache_len}")
if args.sliding_window is not None:
print(f"\tSliding window: {args.sliding_window}")

print("\nLinear splitting:")
print(f"\tTarget split size: {args.target_split_size}")
@@ -513,9 +552,9 @@ def main():
# the same cache buffer at runtime without any copying.
decode_input_len = 1
prefill_input_len = args.input_len # default 32
shared_cache_len = (
args.max_context_len - decode_input_len
) # Use decode's cache size for both
shared_cache_len = _resolve_cache_len(
args.max_context_len, decode_input_len, args.sliding_window
)

print(f"\nShared cache length for prefill/decode: {shared_cache_len}")

@@ -643,7 +682,11 @@ def main():
# Single method mode: fixed seqlen with generate_full_logits=True for lookahead
print(f"\nCreating example inputs (seqlen={args.input_len})...")
example_inputs, example_cache_len = _create_example_inputs(
model_args, args.input_len, args.max_context_len, float_dtype
model_args,
args.input_len,
args.max_context_len,
float_dtype,
cache_len=cache_len,
)

# Test eager execution
2 changes: 2 additions & 0 deletions examples/apple/coreml/llama/readme.md
@@ -69,6 +69,7 @@ Key differences between the two modes:
| `--max_context_len` | 1024 | Maximum context length |
| `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. |
| `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. |
| `--sliding_window` | (off) | Sliding-window attention size. When set, every layer uses a KV cache of this many tokens instead of `max_context_len - input_len`. Required for Mistral / Gemma3 / Gemma4 / Llama4-style models trained with sliding-window attention; lets longer contexts run without growing per-layer attention compute or KV cache memory. |
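
For a quick sense of how the flag interacts with the other length options, here is a minimal sketch using the `_resolve_cache_len` helper added in this PR (the numeric values are illustrative only):

```python
from export_static_llm_coreml import _resolve_cache_len

# No sliding window: the cache holds the rest of the context.
assert _resolve_cache_len(max_context_len=1024, input_len=32) == 992

# Window smaller than the remaining context: the cache is capped at the window.
assert _resolve_cache_len(8192, 32, sliding_window=4096) == 4096

# Window larger than the remaining context: behaves as if no window were set.
assert _resolve_cache_len(1024, 32, sliding_window=4096) == 992
```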

### Quantization Options
| Option | Default | Description |
@@ -94,6 +95,7 @@ The static model has several ANE optimizations, including:
* Splitting linear layers for improved performance (controlled by target_split_size and max_splits args)
* Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks)
* Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833)
* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3/Gemma 4 alternate sliding and full layers), this both matches training behavior and roughly halves KV-cache memory plus per-token attention FLOPs at long context. Per-layer mixed sliding/full attention is not yet exposed; today every layer shares the same window when the flag is set.
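
The memory claim in the last bullet can be sanity-checked with a back-of-envelope sketch. The dimensions below are assumptions chosen to be roughly Mistral-7B-like (32 layers, 8 KV heads, head dim 128, fp16 caches), not values read from a real checkpoint:

```python
# All dimensions here are illustrative assumptions, not measured values.
n_layers, n_kv_heads, head_dim, bytes_per_elem = 32, 8, 128, 2  # fp16

def kv_cache_bytes(cache_len: int) -> int:
    # K and V caches for every layer, each roughly [n_kv_heads, cache_len, head_dim].
    return 2 * n_layers * n_kv_heads * cache_len * head_dim * bytes_per_elem

full = kv_cache_bytes(8192 - 1)  # decode-time cache without a window
windowed = kv_cache_bytes(4096)  # with --sliding_window 4096
print(f"{full / 2**30:.2f} GiB vs {windowed / 2**30:.2f} GiB")  # ~1.00 GiB vs ~0.50 GiB
```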

We are working on adding a C++ runner as well.

83 changes: 83 additions & 0 deletions examples/apple/coreml/llama/test.py
@@ -9,9 +9,13 @@
sys.path.insert(0, ".")
import copy

import pytest
import torch
from export_static_llm_coreml import _create_example_inputs, _resolve_cache_len
from utils import replace_linear_with_split_linear

from executorch.examples.models.llama.model_args import ModelArgs


def get_split_model(
    model,
@@ -44,5 +48,84 @@ def test_split_model():
    assert torch.allclose(model(inputs), model3(inputs), atol=1e-5)


def test_resolve_cache_len_no_sliding_window():
    # Without --sliding_window the cache fills the rest of the context.
    assert _resolve_cache_len(1024, 32) == 992
    assert _resolve_cache_len(1024, 1) == 1023


def test_resolve_cache_len_with_sliding_window():
    # When the window is smaller than the remaining context the cache shrinks.
    assert _resolve_cache_len(8192, 32, sliding_window=4096) == 4096
    assert _resolve_cache_len(8192, 1, sliding_window=4096) == 4096


def test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op():
    # A user-provided window larger than the remaining context degenerates to
    # the no-window case, so users can safely set --sliding_window to a value
    # the model was trained with even when the export uses a shorter context.
    assert _resolve_cache_len(1024, 32, sliding_window=4096) == 992


def test_resolve_cache_len_rejects_non_positive_window():
    with pytest.raises(ValueError):
        _resolve_cache_len(1024, 32, sliding_window=0)
    with pytest.raises(ValueError):
        _resolve_cache_len(1024, 32, sliding_window=-1)


def test_create_example_inputs_with_sliding_window_shrinks_kv_cache():
    # Build a tiny ModelArgs that does not need a checkpoint or torchao.
    model_args = ModelArgs(
        dim=32,
        n_layers=2,
        n_heads=4,
        n_kv_heads=2,
        head_dim=8,
        vocab_size=128,
        max_context_len=1024,
        max_seq_len=1024,
    )
    max_context_len = 1024
    input_len = 32
    sliding_window = 64

    cache_len = _resolve_cache_len(max_context_len, input_len, sliding_window)
    assert cache_len == sliding_window

    example_inputs, returned_cache_len = _create_example_inputs(
        model_args,
        input_len,
        max_context_len,
        float_dtype=torch.float32,
        cache_len=cache_len,
    )
    assert returned_cache_len == sliding_window

    # The KV cache tensors live inside the kwargs dict at index 1 under
    # in_cache_state. Walking that structure should find caches whose
    # sequence dimension equals the sliding window, not max_context_len.
    kwargs = example_inputs[1]
    in_cache_state = kwargs["in_cache_state"]
    cache_seq_dims = set()
    for per_kind in in_cache_state:  # (k_caches, v_caches)
        for cache_tensor in per_kind.values():
            cache_seq_dims.add(cache_tensor.size(-2))
    assert cache_seq_dims == {sliding_window}, (
        f"expected every KV cache to be sized to the sliding window {sliding_window}, "
        f"got {cache_seq_dims}"
    )

    # The attention mask covers (input_len + cache_len) along the last dim.
    masks = kwargs["masks"]
    assert sliding_window in masks
    assert masks[sliding_window].shape[-1] == input_len + sliding_window


if __name__ == "__main__":
    test_split_model()
    test_resolve_cache_len_no_sliding_window()
    test_resolve_cache_len_with_sliding_window()
    test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op()
    test_resolve_cache_len_rejects_non_positive_window()
    test_create_example_inputs_with_sliding_window_shrinks_kv_cache()