Skip to content

Commit 9bdf04e

Browse files
committed
Add per-layer hybrid sliding/full attention to CoreML static LLM export
Builds on the prior --sliding_window flag. Gemma 3, Gemma 4, and the Llama 4 Scout family interleave sliding and full attention layers rather than using one global setting: Gemma 4 E2B is '4 sliding + 1 full' repeated 7 times across 35 layers; Gemma 3 is '5 sliding + 1 full' repeated. HuggingFace expresses this as a single integer `sliding_window_pattern`, which is what the new `--sliding_window_pattern` flag mirrors.

Implementation:
- `_resolve_per_layer_cache_lens(...)` produces a per-layer cache_lens list using the HF rule (layer i is full iff (i+1) % P == 0); the IO manager and the model already accept per-layer cache_lens, so the attention mask dict and the per-layer KV cache shapes follow.
- `_get_metadata` now reads each cache's cache_len from the example tensor's sequence dimension instead of receiving a single scalar, so the C++ runner metadata describes each layer correctly under hybrid attention.
- Both single-method and multifunction export paths use the per-layer resolver.

The previous PR's uniform-sliding behavior is preserved when `--sliding_window_pattern` is not set.

Authored with Claude.
1 parent 8a2dfb5 commit 9bdf04e

3 files changed

Lines changed: 251 additions & 8 deletions

File tree

examples/apple/coreml/llama/export_static_llm_coreml.py

Lines changed: 101 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import argparse
2323
import json
24-
from typing import Optional
24+
from typing import List, Optional
2525

2626
import coremltools as ct
2727
import torch
@@ -193,6 +193,51 @@ def _resolve_cache_len(
193193
return cache_len
194194

195195

196+
def _resolve_per_layer_cache_lens(
    n_layers: int,
    max_context_len: int,
    input_len: int,
    sliding_window: Optional[int] = None,
    sliding_window_pattern: Optional[int] = None,
) -> List[int]:
    """Build the list of KV cache lengths, one entry per transformer layer.

    With ``sliding_window_pattern`` unset, every layer gets the single length
    returned by :func:`_resolve_cache_len` for ``sliding_window``.

    With ``sliding_window_pattern = P``, layers whose 1-indexed position is a
    multiple of ``P`` (0-indexed ``P-1, 2P-1, ...``) get the full
    ``max_context_len - input_len`` cache, and all other layers get the
    length returned by :func:`_resolve_cache_len`. This follows HuggingFace's
    ``sliding_window_pattern`` convention: ``P=6`` for Gemma 3 (5 sliding +
    1 full) and ``P=5`` for Gemma 4 E2B (4 sliding + 1 full).

    Args:
        n_layers: Number of layers to produce lengths for.
        max_context_len: Maximum context length of the exported model.
        input_len: Number of tokens fed per forward pass.
        sliding_window: Sliding-window size, or ``None`` for no window.
        sliding_window_pattern: Period ``P`` of the sliding/full pattern, or
            ``None`` for a uniform cache length.

    Returns:
        A list of ``n_layers`` cache lengths.

    Raises:
        ValueError: If a pattern is given without ``sliding_window``, or the
            pattern is less than 2.
    """
    unwindowed_len = max_context_len - input_len
    # Resolved before the pattern checks so that any validation performed by
    # _resolve_cache_len on sliding_window is reported first.
    windowed_len = _resolve_cache_len(max_context_len, input_len, sliding_window)

    # Uniform case: no interleaving requested.
    if sliding_window_pattern is None:
        return [windowed_len for _ in range(n_layers)]

    if sliding_window is None:
        raise ValueError("sliding_window_pattern requires sliding_window to be set")
    if sliding_window_pattern <= 1:
        raise ValueError(
            "sliding_window_pattern must be at least 2 (P=1 would make every "
            f"layer full attention); got {sliding_window_pattern}"
        )

    per_layer = []
    for position in range(1, n_layers + 1):
        # Positions are 1-indexed here, so every P-th layer is a full layer.
        is_full_layer = position % sliding_window_pattern == 0
        per_layer.append(unwindowed_len if is_full_layer else windowed_len)
    return per_layer
239+
240+
196241
def _create_example_inputs(
197242
model_args, input_len, max_context_len, float_dtype, cache_len=None
198243
):
@@ -296,6 +341,10 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype)
296341

297342
# Output indices are in the same order (after logits)
298343
# Logits is output 0, then k_caches, then v_caches
344+
# Read each cache's actual cache_len from the example tensor shape so
345+
# per-layer hybrid sliding/full attention (Gemma 3/4) reports the right
346+
# length per layer instead of a single uniform value.
347+
k_cache_tensors = example_inputs[1]["in_cache_state"][0]
299348
kv_cache_specs = []
300349
for i, cache_id in enumerate(sorted_k_cache_ids):
301350
k_in_idx = k_cache_in_indices[cache_id]
@@ -304,7 +353,10 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype)
304353
# v_caches come after k_caches (idx n_layers+1 to 2*n_layers)
305354
k_out_idx = 1 + i
306355
v_out_idx = 1 + len(sorted_k_cache_ids) + i
307-
kv_cache_specs.append([k_in_idx, k_out_idx, v_in_idx, v_out_idx, cache_len])
356+
per_cache_len = k_cache_tensors[cache_id].size(-2)
357+
kv_cache_specs.append(
358+
[k_in_idx, k_out_idx, v_in_idx, v_out_idx, per_cache_len]
359+
)
308360

309361
print(f"KV cache specs (k_in, k_out, v_in, v_out, cache_len): {kv_cache_specs}")
310362

@@ -445,6 +497,19 @@ def main():
445497
"Llama4-style models that train with sliding-window attention."
446498
),
447499
)
500+
parser.add_argument(
501+
"--sliding_window_pattern",
502+
type=int,
503+
default=None,
504+
help=(
505+
"Period of the sliding/full attention pattern (HuggingFace's "
506+
"sliding_window_pattern). When set together with --sliding_window, "
507+
"every P-th layer (1-indexed) uses full attention while the rest use "
508+
"the sliding window. Use P=5 for Gemma 4 E2B (4 sliding + 1 full) "
509+
"and P=6 for Gemma 3 (5 sliding + 1 full). Without this flag every "
510+
"layer uses the sliding window."
511+
),
512+
)
448513
parser.add_argument(
449514
"--dtype",
450515
type=str,
@@ -516,6 +581,9 @@ def main():
516581
print(f"\tLinear quantize: {args.linear_quantize}")
517582
print(f"\tDtype: {args.dtype}")
518583

584+
if args.sliding_window_pattern is not None and args.sliding_window is None:
585+
parser.error("--sliding_window_pattern requires --sliding_window to be set")
586+
519587
cache_len = _resolve_cache_len(
520588
args.max_context_len, args.input_len, args.sliding_window
521589
)
@@ -525,6 +593,8 @@ def main():
525593
print(f"\tCache length: {cache_len}")
526594
if args.sliding_window is not None:
527595
print(f"\tSliding window: {args.sliding_window}")
596+
if args.sliding_window_pattern is not None:
597+
print(f"\tSliding window pattern: every {args.sliding_window_pattern}-th layer is full")
528598

529599
print("\nLinear splitting:")
530600
print(f"\tTarget split size: {args.target_split_size}")
@@ -552,11 +622,22 @@ def main():
552622
# the same cache buffer at runtime without any copying.
553623
decode_input_len = 1
554624
prefill_input_len = args.input_len # default 32
555-
shared_cache_len = _resolve_cache_len(
556-
args.max_context_len, decode_input_len, args.sliding_window
625+
shared_cache_len = _resolve_per_layer_cache_lens(
626+
n_layers=model_args.n_layers,
627+
max_context_len=args.max_context_len,
628+
input_len=decode_input_len,
629+
sliding_window=args.sliding_window,
630+
sliding_window_pattern=args.sliding_window_pattern,
557631
)
558632

559-
print(f"\nShared cache length for prefill/decode: {shared_cache_len}")
633+
if args.sliding_window_pattern is not None:
634+
n_full = sum(1 for cl in shared_cache_len if cl == args.max_context_len - decode_input_len)
635+
print(
636+
f"\nShared cache lengths for prefill/decode: {n_full} full + "
637+
f"{model_args.n_layers - n_full} sliding ({args.sliding_window} tokens)"
638+
)
639+
else:
640+
print(f"\nShared cache length for prefill/decode: {shared_cache_len[0]}")
560641

561642
print(f"\nCreating example inputs for decode (seqlen={decode_input_len})...")
562643
decode_inputs, decode_cache_len = _create_example_inputs(
@@ -680,13 +761,27 @@ def main():
680761
)
681762
else:
682763
# Single method mode: fixed seqlen with generate_full_logits=True for lookahead
764+
per_layer_cache_lens = _resolve_per_layer_cache_lens(
765+
n_layers=model_args.n_layers,
766+
max_context_len=args.max_context_len,
767+
input_len=args.input_len,
768+
sliding_window=args.sliding_window,
769+
sliding_window_pattern=args.sliding_window_pattern,
770+
)
771+
if args.sliding_window_pattern is not None:
772+
full_cache_len = args.max_context_len - args.input_len
773+
n_full = sum(1 for cl in per_layer_cache_lens if cl == full_cache_len)
774+
print(
775+
f"\nCache length per layer: {n_full} full ({full_cache_len} tokens) + "
776+
f"{model_args.n_layers - n_full} sliding ({args.sliding_window} tokens)"
777+
)
683778
print(f"\nCreating example inputs (seqlen={args.input_len})...")
684779
example_inputs, example_cache_len = _create_example_inputs(
685780
model_args,
686781
args.input_len,
687782
args.max_context_len,
688783
float_dtype,
689-
cache_len=cache_len,
784+
cache_len=per_layer_cache_lens,
690785
)
691786

692787
# Test eager execution

examples/apple/coreml/llama/readme.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ Key differences between the two modes:
7070
| `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. |
7171
| `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. |
7272
| `--sliding_window` | (off) | Sliding-window attention size. When set, every layer uses a KV cache of this many tokens instead of `max_context_len - input_len`, unless `--sliding_window_pattern` is also set (in which case only the sliding layers do). Required for Mistral / Gemma 3 / Gemma 4 / Llama 4-style models trained with sliding-window attention; lets longer contexts run without growing per-layer attention compute or KV cache memory. |
73+
| `--sliding_window_pattern` | (off) | Period of the sliding/full attention pattern (HuggingFace's `sliding_window_pattern`). Requires `--sliding_window` and must be at least 2. When set, every P-th layer (1-indexed) uses full attention while the rest use the sliding window. Use `P=5` for Gemma 4 E2B (4 sliding + 1 full) and `P=6` for Gemma 3 (5 sliding + 1 full). |
7374

7475
### Quantization Options
7576
| Option | Default | Description |
@@ -95,7 +96,7 @@ The static model has several ANE optimizations, including:
9596
* Splitting linear layers for improved performance (controlled by target_split_size and max_splits args)
9697
* Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks)
9798
* Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833)
98-
* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3/Gemma 4 alternate sliding and full layers), this both matches training behavior and roughly halves KV-cache memory plus per-token attention FLOPs at long context. Per-layer mixed sliding/full attention is not yet exposed; today every layer shares the same window when the flag is set.
99+
* Sliding-window attention (`--sliding_window N`) caps each layer's KV cache at `N` tokens regardless of `max_context_len`. For models trained with sliding-window attention (Mistral 7B uses 4096; Gemma 3 / Gemma 4 alternate sliding and full layers), this both matches training behavior and roughly halves KV-cache memory plus per-token attention FLOPs at long context. Pair with `--sliding_window_pattern P` to mix sliding and full layers in the HuggingFace pattern (every P-th layer is full attention): P=5 for Gemma 4 E2B (4 sliding + 1 full) and P=6 for Gemma 3 (5 sliding + 1 full).
99100

100101
We are working on adding a C++ runner as well.
101102

examples/apple/coreml/llama/test.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
import pytest
1313
import torch
14-
from export_static_llm_coreml import _create_example_inputs, _resolve_cache_len
14+
from export_static_llm_coreml import (
15+
_create_example_inputs,
16+
_resolve_cache_len,
17+
_resolve_per_layer_cache_lens,
18+
)
1519
from utils import replace_linear_with_split_linear
1620

1721
from executorch.examples.models.llama.model_args import ModelArgs
@@ -122,10 +126,153 @@ def test_create_example_inputs_with_sliding_window_shrinks_kv_cache():
122126
assert masks[sliding_window].shape[-1] == input_len + sliding_window
123127

124128

129+
def test_per_layer_cache_lens_uniform_when_no_pattern():
    """With no pattern, all layers share the one resolved cache length."""
    layer_count = 4
    window = 64
    lens = _resolve_per_layer_cache_lens(
        n_layers=layer_count,
        max_context_len=1024,
        input_len=32,
        sliding_window=window,
    )
    assert lens == [window] * layer_count
135+
136+
137+
def test_per_layer_cache_lens_uniform_full_when_no_window():
    """With neither window nor pattern, each layer gets max_context_len - input_len."""
    layer_count = 4
    context_len = 1024
    step_len = 32
    lens = _resolve_per_layer_cache_lens(
        n_layers=layer_count,
        max_context_len=context_len,
        input_len=step_len,
    )
    # 1024 - 32 == 992 for every layer.
    assert lens == [context_len - step_len] * layer_count
143+
144+
145+
def test_per_layer_cache_lens_gemma4_e2b_pattern():
    """Gemma 4 E2B layout: 35 layers with P=5, i.e. (4 sliding + 1 full) x 7."""
    layer_count = 35
    period = 5
    context_len = 8192
    step_len = 32
    window = 4096

    lens = _resolve_per_layer_cache_lens(
        n_layers=layer_count,
        max_context_len=context_len,
        input_len=step_len,
        sliding_window=window,
        sliding_window_pattern=period,
    )

    full_len = context_len - step_len
    assert len(lens) == layer_count

    # 1-indexed positions 5, 10, ..., 35 are the full-attention layers.
    expected = [
        full_len if position % period == 0 else window
        for position in range(1, layer_count + 1)
    ]
    assert list(lens) == expected

    assert lens.count(full_len) == 7
    assert lens.count(window) == 28
163+
164+
165+
def test_per_layer_cache_lens_gemma3_pattern():
    """Gemma 3 layout: P=6, so each group is 5 sliding layers then 1 full."""
    context_len = 2048
    step_len = 32
    window = 512

    lens = _resolve_per_layer_cache_lens(
        n_layers=12,
        max_context_len=context_len,
        input_len=step_len,
        sliding_window=window,
        sliding_window_pattern=6,
    )

    full_len = context_len - step_len
    one_group = [window] * 5 + [full_len]
    # 12 layers == two complete groups.
    assert lens == one_group * 2
190+
191+
192+
def test_per_layer_cache_lens_pattern_requires_sliding_window():
    """A pattern without a sliding window must raise ValueError."""
    kwargs = {
        "n_layers": 8,
        "max_context_len": 1024,
        "input_len": 32,
        "sliding_window": None,
        "sliding_window_pattern": 5,
    }
    with pytest.raises(ValueError):
        _resolve_per_layer_cache_lens(**kwargs)
201+
202+
203+
def test_per_layer_cache_lens_rejects_pattern_le_one():
    """P=1 must raise ValueError.

    A period of 1 would mark every layer as full attention, which is almost
    certainly a typo, so the resolver is expected to raise instead of
    accepting it.
    """
    kwargs = {
        "n_layers": 8,
        "max_context_len": 1024,
        "input_len": 32,
        "sliding_window": 64,
        "sliding_window_pattern": 1,
    }
    with pytest.raises(ValueError):
        _resolve_per_layer_cache_lens(**kwargs)
214+
215+
216+
def test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes():
    """Per-layer cache lengths should yield two cache sizes and two masks."""
    max_context_len = 1024
    input_len = 32
    sliding_window = 64
    pattern = 5

    model_args = ModelArgs(
        dim=32,
        n_layers=10,
        n_heads=4,
        n_kv_heads=2,
        head_dim=8,
        vocab_size=128,
        max_context_len=1024,
        max_seq_len=1024,
    )

    cache_lens = _resolve_per_layer_cache_lens(
        n_layers=model_args.n_layers,
        max_context_len=max_context_len,
        input_len=input_len,
        sliding_window=sliding_window,
        sliding_window_pattern=pattern,
    )

    example_inputs, _ = _create_example_inputs(
        model_args,
        input_len,
        max_context_len,
        float_dtype=torch.float32,
        cache_len=cache_lens,
    )
    model_kwargs = example_inputs[1]

    # Collect the distinct size(-2) values over every cache tensor in
    # in_cache_state; that is the dimension this test treats as cache length.
    seen = {
        tensor.size(-2)
        for per_kind in model_kwargs["in_cache_state"]
        for tensor in per_kind.values()
    }
    full = max_context_len - input_len
    assert seen == {sliding_window, full}, (
        f"expected both {sliding_window} (sliding) and {full} (full) cache sizes, got {seen}"
    )

    # One mask per distinct cache length, each (input_len + cache_len) wide.
    masks = model_kwargs["masks"]
    assert set(masks.keys()) == {sliding_window, full}
    for cache_size in (sliding_window, full):
        assert masks[cache_size].shape[-1] == input_len + cache_size
263+
264+
125265
if __name__ == "__main__":
    # Script entry point: run each test in this module, in declaration order.
    for _run_test in (
        test_split_model,
        test_resolve_cache_len_no_sliding_window,
        test_resolve_cache_len_with_sliding_window,
        test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op,
        test_resolve_cache_len_rejects_non_positive_window,
        test_create_example_inputs_with_sliding_window_shrinks_kv_cache,
        test_per_layer_cache_lens_uniform_when_no_pattern,
        test_per_layer_cache_lens_uniform_full_when_no_window,
        test_per_layer_cache_lens_gemma4_e2b_pattern,
        test_per_layer_cache_lens_gemma3_pattern,
        test_per_layer_cache_lens_pattern_requires_sliding_window,
        test_per_layer_cache_lens_rejects_pattern_le_one,
        test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes,
    ):
        _run_test()

0 commit comments

Comments
 (0)