
On-device perf + memory optimizations: custom SDPA, on-the-fly RoPE, KV cache fix, XNNPACK workspace sharing (#19214)

Open
leixin wants to merge 1 commit into pytorch:main from leixin:export-D102710062

Conversation


@leixin leixin commented Apr 29, 2026

Summary:

Six changes for the Gemma 4 text decoder + runner, enabled by default. Custom SDPA can be opted out via `--no-use_custom_sdpa` for eager mode or non-XNNPACK backends; workspace sharing can be opted out via `--noenable_workspace_sharing` for debugging.

  1. Custom SDPA: attention now runs through `torch.ops.llama.custom_sdpa` (tiled flash attention from the Llama runner). This skips the 8x KV expansion that GQA/MQA otherwise requires and never materializes the full `[seq, seq]` attention matrix; the matmul fallback's `[bs, heads, seq, seq]` tensor exceeds the S25's 8 MB L2 cache at `seq=2048` and causes a severe regression. Also adds an inline INT8 dequant path for `Gemma4QuantizedKVCache(return_float_values=False)` that stays inside the XNNPACK partition (see the dequant sketch after this list).

  2. On-the-fly RoPE: the attention module stores only the `inv_freq` vector (~128-256 floats) and computes cos/sin per forward, instead of registering precomputed `[max_seq_len, head_dim]` cos/sin buffers (see the RoPE sketch after this list). Reduces PTE size by 3-7%.

  3. KV cache allocation is skipped for `is_kv_shared_layer=True`. In YOCO, 20 of 35 layers consume the donor's KV via `shared_kv` and never write to their own cache, so the allocation was dead (see the cache sketch after this list). Saves ~40 MB at `seq=1024` and ~80 MB at `seq=2048`.

  4. XNNPACK workspace sharing in the runner. `Gemma4Runner::load()` now calls `set_option(workspace_sharing_mode_option_key=PerModel, weight_cache_option_key=true)` on the XNNPACK backend before module load. This is on by default, with an `enable_workspace_sharing` constructor flag for opt-out. Without it, real Android/iOS app builds (which don't pass the bench's compile-time `--config xnnpack_workspace_sharing=1`) end up in `Disabled` mode and silently OOM-crash on E4B (a >2 GB peak-memory regression reported by app teams). The compile-time flag in xplat/.../gemma4/targets.bzl (`-DENABLE_XNNPACK_SHARED_WORKSPACE`) is also removed, since it was dead: Buck preprocessor flags don't reach `XNNWorkspaceManager.cpp`, which lives in the `xnnpack_backend` compile unit.

  5. Correctness fix for KV cache quant + custom SDPA + YOCO. When `Gemma4QuantizedKVCache(return_float_values=False)` is in use, the donor layer now dequants K/V before storing them in `kv_to_share`, so cross-decoder layers (which lack access to the donor's scales) don't pass raw int8 to `custom_sdpa` (see the sharing sketch after this list). This was a dormant bug that only triggers with `--quantize_kv_cache --use_custom_sdpa`; previously it crashed export with `AssertionError: Expected key to be float32`.

  6. iOS VmRSS sscanf fix (consolidates D103030061). `Gemma4Stats::read_rss_kb()` uses `SCNd64` from `<cinttypes>` instead of `%ld`, so the format matches `int64_t` on both Linux/Android (where `int64_t` is `long`) and Apple arm64 (where it is `long long` despite the platform being LP64). Unblocks iOS sample app builds with `-Werror,-Wformat`.
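
For item 1, a minimal sketch of the inline INT8 dequant idea, with stock `torch.nn.functional.scaled_dot_product_attention` standing in for `torch.ops.llama.custom_sdpa` (whose signature isn't shown in this PR description); the tensor names and scale layout are assumptions, not the actual code:

```python
import torch
import torch.nn.functional as F

def dequant_then_sdpa(q, k_int8, v_int8, k_scales, v_scales, mask):
    # Inline INT8 dequant: an int8 -> fp32 cast plus a per-token scale multiply.
    # Both are plain aten ops, so the XNNPACK partitioner can keep the dequant
    # inside the partition instead of bouncing K/V out to portable kernels.
    k = k_int8.to(torch.float32) * k_scales  # scales assumed broadcastable, e.g. [bs, n_kv_heads, seq, 1]
    v = v_int8.to(torch.float32) * v_scales
    # Stand-in for the custom op, which additionally skips the GQA/MQA KV head
    # expansion and never materializes the [bs, heads, seq, seq] score matrix.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```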
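
For item 2, a sketch of on-the-fly RoPE; the module and argument names are hypothetical, but the storage trade-off matches the description (only `inv_freq` is persisted, cos/sin are recomputed each forward):

```python
import torch

class OnTheFlyRope(torch.nn.Module):
    """Hypothetical module: persists inv_freq only, never full cos/sin tables."""

    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        # head_dim / 2 floats end up in the PTE, versus
        # 2 * max_seq_len * head_dim for precomputed cos/sin buffers.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, positions: torch.Tensor):
        # positions: [seq] token positions for this forward pass.
        freqs = torch.outer(positions.to(torch.float32), self.inv_freq)  # [seq, head_dim/2]
        emb = torch.cat((freqs, freqs), dim=-1)                          # [seq, head_dim]
        return emb.cos(), emb.sin()
```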
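
For item 3, a sketch of the allocation skip; the class and buffer names are hypothetical:

```python
import torch

class LayerKVBuffers(torch.nn.Module):
    """Hypothetical sketch: YOCO consumer layers get no cache of their own."""

    def __init__(self, max_seq_len, n_kv_heads, head_dim, is_kv_shared_layer):
        super().__init__()
        self.is_kv_shared_layer = is_kv_shared_layer
        if not is_kv_shared_layer:
            # Donor / non-shared layer: allocate the cache as before.
            self.register_buffer("k_cache", torch.zeros(1, n_kv_heads, max_seq_len, head_dim))
            self.register_buffer("v_cache", torch.zeros(1, n_kv_heads, max_seq_len, head_dim))
        # else: the layer reads the donor's K/V via shared_kv and never writes,
        # so these buffers would be dead weight across the 20 consumer layers.
```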
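
For item 5, a sketch of the donor-side fix; the cache's `update()` signature and the scale attribute names are assumptions for illustration:

```python
import torch

def publish_shared_kv(kv_cache, input_pos, k, v):
    # Donor layer: write K/V into its quantized cache, then publish *float*
    # values for the consumer layers, which cannot see the donor's scales.
    k, v = kv_cache.update(input_pos, k, v)  # assumed API; yields int8 when
                                             # return_float_values=False
    if k.dtype == torch.int8:
        k = k.to(torch.float32) * kv_cache.k_scales  # attribute names assumed
        v = v.to(torch.float32) * kv_cache.v_scales
    return k, v  # becomes kv_to_share; safe to hand to custom_sdpa downstream
```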

Mask construction is factored into `_build_attn_mask` / `_slice_mask` helpers shared between the custom-SDPA and matmul branches (a sketch follows below).
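
The helper names come from the PR; the bodies below are one plausible shape for them, not the actual implementation:

```python
import torch

def _build_attn_mask(max_seq_len: int) -> torch.Tensor:
    # One full additive causal mask, built once and shared by both branches.
    return torch.triu(
        torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1
    )

def _slice_mask(mask: torch.Tensor, input_pos: int, seq_len: int) -> torch.Tensor:
    # Rows for the tokens in this forward pass, columns over the whole cache.
    return mask[input_pos : input_pos + seq_len, :]
```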

Differential Revision: D102710062

@leixin leixin requested a review from lucylq as a code owner April 29, 2026 22:42

pytorch-bot Bot commented Apr 29, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19214

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures, 1 Unrelated Failure

As of commit 3b254e4 with merge base d767516:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Apr 29, 2026
@meta-codesync
Contributor

meta-codesync Bot commented Apr 29, 2026

@leixin has exported this pull request. If you are a Meta employee, you can view the originating Diff in D102710062.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-codesync meta-codesync Bot changed the title On-device perf + memory optimizations: custom SDPA, on-the-fly RoPE, KV cache fix, XNNPACK workspace sharing On-device perf + memory optimizations: custom SDPA, on-the-fly RoPE, KV cache fix, XNNPACK workspace sharing (#19214) Apr 29, 2026
@leixin leixin force-pushed the export-D102710062 branch from 1e95fc2 to 3b254e4 on April 29, 2026 at 22:46

Labels

CLA Signed, fb-exported, meta-exported
