diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py
index 276ff6d193a..3a70ad02d53 100644
--- a/examples/apple/coreml/llama/export_static_llm_coreml.py
+++ b/examples/apple/coreml/llama/export_static_llm_coreml.py
@@ -21,6 +21,7 @@
 import argparse
 import json
+from typing import Optional
 
 import coremltools as ct
 import torch
 
@@ -170,6 +171,28 @@ def load_model(
     return model, args
 
 
+def _resolve_cache_len(
+    max_context_len: int, input_len: int, sliding_window: Optional[int] = None
+) -> int:
+    """Pick the per-layer KV cache length given context / input / window settings.
+
+    Without sliding-window attention the cache must hold every token that can
+    attend to the current step, i.e. ``max_context_len - input_len``. When the
+    model is trained with sliding-window attention we instead cap the cache at
+    ``sliding_window`` so longer contexts do not enlarge per-layer attention
+    compute or KV cache memory.
+    """
+    cache_len = max_context_len - input_len
+    if sliding_window is not None:
+        if sliding_window <= 0:
+            raise ValueError(
+                f"sliding_window must be positive, got {sliding_window}"
+            )
+        if sliding_window < cache_len:
+            cache_len = sliding_window
+    return cache_len
+
+
 def _create_example_inputs(
     model_args, input_len, max_context_len, float_dtype, cache_len=None
 ):
@@ -410,6 +433,18 @@ def main():
         default=32,
         help="Input sequence length per forward pass",
     )
+    parser.add_argument(
+        "--sliding_window",
+        type=int,
+        default=None,
+        help=(
+            "Sliding window attention size. When set, every layer uses a KV cache "
+            "of this many tokens instead of (max_context_len - input_len), which "
+            "lets the model serve longer contexts without growing per-layer "
+            "attention compute or KV cache memory. Required for Mistral / Gemma2 / "
+            "Gemma3 / Llama4-style models trained with sliding-window attention."
+        ),
+    )
     parser.add_argument(
         "--dtype",
         type=str,
@@ -481,11 +516,15 @@ def main():
     print(f"\tLinear quantize: {args.linear_quantize}")
     print(f"\tDtype: {args.dtype}")
 
-    cache_len = args.max_context_len - args.input_len
+    cache_len = _resolve_cache_len(
+        args.max_context_len, args.input_len, args.sliding_window
+    )
     print("\nGeneration configuration:")
     print(f"\tMax context length: {args.max_context_len}")
     print(f"\tInput length: {args.input_len}")
     print(f"\tCache length: {cache_len}")
+    if args.sliding_window is not None:
+        print(f"\tSliding window: {args.sliding_window}")
 
     print("\nLinear splitting:")
     print(f"\tTarget split size: {args.target_split_size}")
@@ -513,9 +552,9 @@ def main():
         # the same cache buffer at runtime without any copying.
         decode_input_len = 1
         prefill_input_len = args.input_len  # default 32
-        shared_cache_len = (
-            args.max_context_len - decode_input_len
-        )  # Use decode's cache size for both
+        shared_cache_len = _resolve_cache_len(
+            args.max_context_len, decode_input_len, args.sliding_window
+        )
 
         print(f"\nShared cache length for prefill/decode: {shared_cache_len}")
 
@@ -643,7 +682,11 @@ def main():
     # Single method mode: fixed seqlen with generate_full_logits=True for lookahead
     print(f"\nCreating example inputs (seqlen={args.input_len})...")
     example_inputs, example_cache_len = _create_example_inputs(
-        model_args, args.input_len, args.max_context_len, float_dtype
+        model_args,
+        args.input_len,
+        args.max_context_len,
+        float_dtype,
+        cache_len=cache_len,
     )
 
     # Test eager execution
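To make the cache-sizing rule concrete, here is a minimal sketch (illustration only, not part of the patch) mirroring the two call sites in `main()`, assuming `--max_context_len 8192 --input_len 32 --sliding_window 4096`:

```python
# Illustration of _resolve_cache_len at the two call sites above, assuming
# --max_context_len 8192 --input_len 32 --sliding_window 4096.
from export_static_llm_coreml import _resolve_cache_len

# Single-method mode: one cache sized for the fixed-seqlen forward pass.
assert _resolve_cache_len(8192, 32, sliding_window=4096) == 4096  # was 8160
# Multifunction mode: prefill and decode share decode's cache (input_len=1).
assert _resolve_cache_len(8192, 1, sliding_window=4096) == 4096   # was 8191
# Without the flag the cache keeps growing with the context instead.
assert _resolve_cache_len(8192, 32) == 8160
```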
diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md
index ae3852a7828..5f8aa1bcd78 100644
--- a/examples/apple/coreml/llama/readme.md
+++ b/examples/apple/coreml/llama/readme.md
@@ -69,6 +69,7 @@ Key differences between the two modes:
 | `--max_context_len` | 1024 | Maximum context length |
 | `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. |
 | `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. |
+| `--sliding_window` | (off) | Sliding-window attention size. When set, every layer uses a KV cache of this many tokens instead of `max_context_len - input_len`. Required for Mistral / Gemma2 / Gemma3 / Llama4-style models trained with sliding-window attention; lets longer contexts run without growing per-layer attention compute or KV cache memory. |
 
 ### Quantization Options
 | Option | Default | Description |
@@ -94,6 +95,7 @@ The static model has several ANE optimizations, including:
 * Splitting linear layers for improved performance (controlled by target_split_size and max_splits args)
 * Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks)
 * Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833)
+* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses a 4096-token window; Gemma 2 and Gemma 3 alternate sliding-window and full-attention layers), this matches training behavior and caps KV-cache memory and per-token attention FLOPs at the window size rather than the full context, roughly halving both when the context is twice the window. Per-layer mixed sliding/full attention is not yet exposed; today every layer shares the same window when the flag is set.
 
 We are working on adding a C++ runner as well.
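A back-of-envelope check of the memory claim in the new readme bullet (a sketch under assumed shapes: the public Mistral 7B configuration of 32 layers, 8 KV heads, and head dim 128, in fp16):

```python
# Rough fp16 KV-cache footprint for the readme's Mistral example. The layer /
# head shapes are the public Mistral 7B config, assumed here for illustration.
def kv_cache_bytes(cache_len, n_layers=32, n_kv_heads=8, head_dim=128):
    # 2x for K and V caches, 2 bytes per fp16 element.
    return 2 * n_layers * n_kv_heads * head_dim * cache_len * 2

full = kv_cache_bytes(8192 - 1)  # no window: cache_len = max_context_len - 1
windowed = kv_cache_bytes(4096)  # --sliding_window 4096
print(f"{full / 2**20:.0f} MiB vs {windowed / 2**20:.0f} MiB")  # ~1024 vs 512
```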
diff --git a/examples/apple/coreml/llama/test.py b/examples/apple/coreml/llama/test.py
index 895cf2e1cce..73ca7712b1a 100644
--- a/examples/apple/coreml/llama/test.py
+++ b/examples/apple/coreml/llama/test.py
@@ -9,9 +9,13 @@
 sys.path.insert(0, ".")
 
 import copy
+import pytest
 import torch
+from export_static_llm_coreml import _create_example_inputs, _resolve_cache_len
 from utils import replace_linear_with_split_linear
 
+from executorch.examples.models.llama.model_args import ModelArgs
+
 
 def get_split_model(
     model,
@@ -44,5 +48,84 @@ def test_split_model():
     assert torch.allclose(model(inputs), model3(inputs), atol=1e-5)
 
 
+def test_resolve_cache_len_no_sliding_window():
+    # Without --sliding_window the cache fills the rest of the context.
+    assert _resolve_cache_len(1024, 32) == 992
+    assert _resolve_cache_len(1024, 1) == 1023
+
+
+def test_resolve_cache_len_with_sliding_window():
+    # When the window is smaller than the remaining context the cache shrinks.
+    assert _resolve_cache_len(8192, 32, sliding_window=4096) == 4096
+    assert _resolve_cache_len(8192, 1, sliding_window=4096) == 4096
+
+
+def test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op():
+    # A user-provided window larger than the remaining context degenerates to
+    # the no-window case, so users can safely set --sliding_window to a value
+    # the model was trained with even when the export uses a shorter context.
+    assert _resolve_cache_len(1024, 32, sliding_window=4096) == 992
+
+
+def test_resolve_cache_len_rejects_non_positive_window():
+    with pytest.raises(ValueError):
+        _resolve_cache_len(1024, 32, sliding_window=0)
+    with pytest.raises(ValueError):
+        _resolve_cache_len(1024, 32, sliding_window=-1)
+
+
+def test_create_example_inputs_with_sliding_window_shrinks_kv_cache():
+    # Build a tiny ModelArgs that does not need a checkpoint or torchao.
+    model_args = ModelArgs(
+        dim=32,
+        n_layers=2,
+        n_heads=4,
+        n_kv_heads=2,
+        head_dim=8,
+        vocab_size=128,
+        max_context_len=1024,
+        max_seq_len=1024,
+    )
+    max_context_len = 1024
+    input_len = 32
+    sliding_window = 64
+
+    cache_len = _resolve_cache_len(max_context_len, input_len, sliding_window)
+    assert cache_len == sliding_window
+
+    example_inputs, returned_cache_len = _create_example_inputs(
+        model_args,
+        input_len,
+        max_context_len,
+        float_dtype=torch.float32,
+        cache_len=cache_len,
+    )
+    assert returned_cache_len == sliding_window
+
+    # The KV cache tensors live inside the kwargs dict at index 1 under
+    # in_cache_state. Walking that structure should find caches whose
+    # sequence dimension equals the sliding window, not max_context_len.
+    kwargs = example_inputs[1]
+    in_cache_state = kwargs["in_cache_state"]
+    cache_seq_dims = set()
+    for per_kind in in_cache_state:  # (k_caches, v_caches)
+        for cache_tensor in per_kind.values():
+            cache_seq_dims.add(cache_tensor.size(-2))
+    assert cache_seq_dims == {sliding_window}, (
+        f"expected every KV cache to be sized to the sliding window {sliding_window}, "
+        f"got {cache_seq_dims}"
+    )
+
+    # The attention mask covers (input_len + cache_len) along the last dim.
+    masks = kwargs["masks"]
+    assert sliding_window in masks
+    assert masks[sliding_window].shape[-1] == input_len + sliding_window
+
+
 if __name__ == "__main__":
     test_split_model()
+    test_resolve_cache_len_no_sliding_window()
+    test_resolve_cache_len_with_sliding_window()
+    test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op()
+    test_resolve_cache_len_rejects_non_positive_window()
+    test_create_example_inputs_with_sliding_window_shrinks_kv_cache()
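A note on running the new tests: `test.py` inserts `"."` at the front of `sys.path` and imports `export_static_llm_coreml` as a top-level module, so they are meant to be run from `examples/apple/coreml/llama/`, either via `pytest test.py` or via `python test.py`, whose `__main__` block now calls each sliding-window test directly.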