
Commit 1e95fc2

leixin authored and facebook-github-bot committed
On-device perf + memory optimizations: custom SDPA, on-the-fly RoPE, KV cache fix, XNNPACK workspace sharing
Summary: Six changes for the Gemma 4 text decoder + runner, all enabled by default. Custom SDPA can be opted out of via `--no-use_custom_sdpa` (for eager mode or non-XNNPACK backends); workspace sharing via `--noenable_workspace_sharing` (for debugging).

1. Custom SDPA: attention now runs through `torch.ops.llama.custom_sdpa` (tiled flash attention from the Llama runner). This skips the 8x KV expansion that GQA/MQA otherwise requires and never materializes the full `[seq, seq]` attention matrix; the matmul fallback's `[bs, heads, seq, seq]` tensor exceeds the S25's 8 MB L2 cache at `seq=2048` and causes a severe regression. Also adds an inline INT8 dequant path for `Gemma4QuantizedKVCache(return_float_values=False)` that stays inside the XNNPACK partition. (Minimal sketches of items 1-3 follow this summary.)

2. On-the-fly RoPE: the attention module stores only the `inv_freq` vector (~128-256 floats) and computes cos/sin on each forward, instead of registering precomputed `[max_seq_len, head_dim]` cos/sin buffers. Reduces PTE size by 3-7%.

3. KV cache allocation is skipped for layers with `is_kv_shared_layer=True`. In YOCO, 20 of the 35 layers consume the donor's KV via `shared_kv` and never write to their own cache, so the allocation was dead. Saves ~40 MB at `seq=1024` and ~80 MB at `seq=2048`.

4. XNNPACK workspace sharing in the runner. `Gemma4Runner::load()` now calls `set_option()` on the XNNPACK backend before module load, setting `workspace_sharing_mode_option_key=PerModel` and `weight_cache_option_key=true`. Default-on, with an `enable_workspace_sharing` constructor flag for opt-out. Without this, real Android/iOS app builds (which don't pass the bench's compile-time `--config xnnpack_workspace_sharing=1`) end up in `Disabled` mode and silently OOM-crash on E4B (the >2 GB peak-memory regression reported by app teams). The compile-time flag in xplat/.../gemma4/targets.bzl (`-DENABLE_XNNPACK_SHARED_WORKSPACE`) is also removed since it was dead: Buck preprocessor flags don't reach `XNNWorkspaceManager.cpp`, which lives in the separate `xnnpack_backend` compile unit.

5. Correctness fix for KV-cache quant + custom SDPA + YOCO. When `Gemma4QuantizedKVCache(return_float_values=False)` is in use, the donor layer now dequantizes K/V before storing them in `kv_to_share`, so cross-decoder layers (which have no access to the donor's scales) don't pass raw int8 to `custom_sdpa`. The bug was dormant: it only triggers with `--quantize_kv_cache --use_custom_sdpa`, and previously crashed export with `AssertionError: Expected key to be float32`. (A sketch follows the export_gemma4.py diff below.)

6. iOS VmRSS sscanf fix (consolidates D103030061). `Gemma4Stats::read_rss_kb()` uses `SCNd64` from `<cinttypes>` instead of `%ld`, so the format specifier matches `int64_t` both on LP64 Linux/Android (where `int64_t` is `long`) and on Apple arm64 (where `int64_t` is `long long`, so `%ld` trips `-Wformat`). Unblocks iOS sample-app builds with `-Werror,-Wformat`.

Mask construction is factored into `_build_attn_mask` / `_slice_mask` helpers shared between the custom-SDPA and matmul branches.

Differential Revision: D102710062
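Below are minimal Python sketches of items 1-3, under stated assumptions: the class and attribute names (`use_custom_sdpa`, `n_heads`, `n_kv_heads`, `is_kv_shared_layer`) and the exact argument order of `torch.ops.llama.custom_sdpa` are illustrative guesses, not the actual Gemma 4 sources.

Item 1, the attention dispatch: the custom op consumes unexpanded GQA K/V directly and tiles over the sequence, while the matmul fallback must both expand the KV heads and form the full score matrix.

```python
import torch
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401  (registers llama::custom_sdpa)

def attend(self, q, k, v, start_pos, mask):
    # q: [bs, seq, n_heads, head_dim]; k, v: [bs, seq, n_kv_heads, head_dim]
    if self.use_custom_sdpa:
        # Tiled flash attention: GQA is handled inside the op, so the KV
        # head expansion below is skipped and no [bs, heads, seq, seq]
        # score matrix is ever materialized. Argument order is assumed.
        return torch.ops.llama.custom_sdpa(q, k, v, start_pos, mask, 0.0, False)
    # Matmul fallback: expand K/V 8x to match the query heads, then form
    # the full attention matrix (the tensor that blows the 8 MB L2).
    n_rep = self.n_heads // self.n_kv_heads
    k = k.repeat_interleave(n_rep, dim=2).transpose(1, 2)
    v = v.repeat_interleave(n_rep, dim=2).transpose(1, 2)
    q = q.transpose(1, 2)
    scores = q @ k.transpose(-1, -2) * self.head_dim**-0.5 + mask
    return (torch.softmax(scores, dim=-1) @ v).transpose(1, 2)
```

Item 2, on-the-fly RoPE: only the head_dim/2 inverse frequencies are persisted, and cos/sin are derived per forward from the positions actually being processed.

```python
class OnTheFlyRope(torch.nn.Module):
    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        # ~head_dim/2 floats instead of [max_seq_len, head_dim] buffers.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, positions: torch.Tensor):
        # positions: [seq] absolute positions of the current tokens.
        freqs = torch.outer(positions.float(), self.inv_freq)  # [seq, head_dim/2]
        return freqs.cos(), freqs.sin()
```

Item 3, skipping dead KV allocations. Assuming fp32 caches with `n_kv_heads * head_dim = 256`, the arithmetic matches the quoted savings: 20 layers x 2 buffers x 1024 positions x 256 x 4 B ≈ 40 MB at `seq=1024`, doubling at `seq=2048`.

```python
class KVCacheSketch(torch.nn.Module):
    def __init__(self, cfg, is_kv_shared_layer: bool):
        super().__init__()
        self.is_kv_shared_layer = is_kv_shared_layer
        if not is_kv_shared_layer:
            # Donor (and non-YOCO) layers keep a real cache.
            shape = (1, cfg.max_seq_len, cfg.n_kv_heads, cfg.head_dim)
            self.register_buffer("k_cache", torch.zeros(shape))
            self.register_buffer("v_cache", torch.zeros(shape))
        # Shared layers read the donor's K/V via shared_kv and never write
        # their own, so registering these buffers was pure dead weight.
```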
1 parent d767516 · commit 1e95fc2

8 files changed: 242 additions & 164 deletions

examples/models/gemma4/e2e_runner.cpp

Lines changed: 9 additions & 1 deletion
@@ -43,6 +43,11 @@ DEFINE_int32(max_new_tokens, 100, "Maximum tokens to generate.");
 DEFINE_int32(max_vision_tokens, 140, "Maximum soft tokens for vision encoder.");
 DEFINE_double(temperature, 0.0, "Sampling temperature (0.0 = greedy).");
 DEFINE_int32(cpu_threads, -1, "Number of CPU threads. -1 = auto-detect.");
+DEFINE_bool(
+    enable_workspace_sharing,
+    true,
+    "Enable XNNPACK PerModel workspace sharing + weight cache. "
+    "Pass --noenable_workspace_sharing to disable for debugging.");
 
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -62,7 +67,10 @@ int32_t main(int32_t argc, char** argv) {
   Gemma4Stats stats;
   stats.on_load_begin();
 
-  Gemma4Runner runner(FLAGS_model_path, FLAGS_tokenizer_path);
+  Gemma4Runner runner(
+      FLAGS_model_path,
+      FLAGS_tokenizer_path,
+      FLAGS_enable_workspace_sharing);
   auto err = runner.load();
   ET_CHECK_MSG(err == executorch::runtime::Error::Ok, "Failed to load model");

examples/models/gemma4/export_gemma4.py

Lines changed: 22 additions & 1 deletion
@@ -451,6 +451,7 @@ def _export_text_decoder(
     tied_embedding: bool = False,
     variant: str = "e2b",
     quantize_kv_cache: bool = False,
+    use_custom_sdpa: bool = True,
 ):
     """Export text decoder. Returns ExportedProgram."""
     from executorch.examples.models.gemma4.quant_utils import (
@@ -467,6 +468,12 @@ def _export_text_decoder(
     config.use_kv_cache = True
     config.max_seq_len = max_seq_len
     config.enable_dynamic_shape = True
+    config.use_custom_sdpa = use_custom_sdpa
+
+    if use_custom_sdpa:
+        from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+        logger.info("Custom SDPA enabled (tiled flash attention)")
 
     model_wrapper = Gemma4Model(
         config=config, checkpoint_path=checkpoint_path, dtype=torch.float32
@@ -507,7 +514,9 @@ def _export_text_decoder(
     )
 
     logger.info("Replacing KV cache with INT8 quantized KV cache...")
-    model = replace_kv_cache_with_quantized_kv_cache(model)
+    model = replace_kv_cache_with_quantized_kv_cache(
+        model, use_custom_sdpa=use_custom_sdpa
+    )
     model.eval()
 
     if linear_quant:
@@ -549,6 +558,7 @@ def _export_components(
     quantize_kv_cache: bool,
     include_audio: bool,
     include_vision: bool,
+    use_custom_sdpa: bool,
 ) -> dict:
     """Export each requested component to an ExportedProgram."""
     components = []
@@ -594,6 +604,7 @@ def _export_components(
         tied_embedding=tied_embedding,
         variant=variant,
         quantize_kv_cache=quantize_kv_cache,
+        use_custom_sdpa=use_custom_sdpa,
     )
 
     return programs
@@ -687,6 +698,7 @@ def export_single_pte(
     quantize_kv_cache: bool = False,
     include_audio: bool = True,
     include_vision: bool = True,
+    use_custom_sdpa: bool = True,
 ) -> Path:
     """Export components into a single PTE.
 
@@ -714,6 +726,7 @@ def export_single_pte(
         quantize_kv_cache=quantize_kv_cache,
         include_audio=include_audio,
         include_vision=include_vision,
+        use_custom_sdpa=use_custom_sdpa,
     )
 
     logger.info("Combining into single PTE...")
@@ -840,6 +853,13 @@ def main():
         default=False,
         help="Exclude vision_encoder method to reduce PTE size.",
     )
+    parser.add_argument(
+        "--use_custom_sdpa",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Route attention through llama::custom_sdpa (tiled flash attention). "
+        "Pass --no-use_custom_sdpa to fall back to matmul attention.",
+    )
     args = parser.parse_args()
 
     export_single_pte(
@@ -856,6 +876,7 @@ def main():
         quantize_kv_cache=args.quantize_kv_cache,
         include_audio=not args.no_audio,
         include_vision=not args.no_vision,
+        use_custom_sdpa=args.use_custom_sdpa,
     )
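A hedged sketch of the item-5 fix that the new `use_custom_sdpa` argument to `replace_kv_cache_with_quantized_kv_cache` enables. The function name and the per-tensor scale handling here are assumptions about `Gemma4QuantizedKVCache`, not its actual API:

```python
import torch

def share_donor_kv(k_int8, v_int8, k_scales, v_scales):
    # YOCO donor layer with return_float_values=False: dequantize K/V
    # before publishing them in kv_to_share. Cross-decoder consumers never
    # see the donor's scales, and passing raw int8 to custom_sdpa used to
    # crash export with "AssertionError: Expected key to be float32".
    return (
        k_int8.to(torch.float32) * k_scales,
        v_int8.to(torch.float32) * v_scales,
    )
```

For the donor's own (non-shared) cache reads, the INT8 dequant is instead emitted inline so it stays inside the XNNPACK partition.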

examples/models/gemma4/runner/gemma4_runner.cpp

Lines changed: 31 additions & 2 deletions
@@ -10,9 +10,12 @@
 
 #include <executorch/examples/models/gemma4/runner/gemma4_runner.h>
 
+#include <executorch/backends/xnnpack/runtime/XNNPACKBackend.h>
 #include <executorch/extension/llm/runner/llm_runner_helper.h>
 #include <executorch/extension/llm/sampler/sampler.h>
 #include <executorch/extension/tensor/tensor_ptr_maker.h>
+#include <executorch/runtime/backend/interface.h>
+#include <executorch/runtime/backend/options.h>
 #include <executorch/runtime/core/evalue.h>
 #include <executorch/runtime/platform/log.h>
 
@@ -28,10 +31,36 @@ using ::executorch::runtime::EValue;
 
 Gemma4Runner::Gemma4Runner(
     const std::string& model_path,
-    const std::string& tokenizer_path)
-    : model_path_(model_path), tokenizer_path_(tokenizer_path) {}
+    const std::string& tokenizer_path,
+    bool enable_workspace_sharing)
+    : model_path_(model_path),
+      tokenizer_path_(tokenizer_path),
+      enable_workspace_sharing_(enable_workspace_sharing) {}
 
 Error Gemma4Runner::load() {
+  // Set XNNPACK workspace sharing explicitly. The compile-time default
+  // varies across build configurations, and set_option state is process-
+  // global, so always set it here to get the intended mode regardless of
+  // how the binary was built or what other code in the process did first.
+  {
+    auto mode = enable_workspace_sharing_
+        ? ::executorch::backends::xnnpack::WorkspaceSharingMode::PerModel
+        : ::executorch::backends::xnnpack::WorkspaceSharingMode::Disabled;
+    ::executorch::runtime::BackendOptions<2> xnnpack_opts;
+    xnnpack_opts.set_option(
+        ::executorch::backends::xnnpack::weight_cache_option_key,
+        enable_workspace_sharing_);
+    xnnpack_opts.set_option(
+        ::executorch::backends::xnnpack::workspace_sharing_mode_option_key,
+        static_cast<int>(mode));
+    auto opts_status = ::executorch::runtime::set_option(
+        ::executorch::backends::xnnpack::xnnpack_backend_key,
+        xnnpack_opts.view());
+    if (opts_status != Error::Ok) {
+      ET_LOG(Error, "Failed to set XNNPACK options");
+    }
+  }
+
   ET_LOG(Info, "Loading model: %s", model_path_.c_str());
   module_ = std::make_unique<Module>(model_path_, Module::LoadMode::Mmap);

examples/models/gemma4/runner/gemma4_runner.h

Lines changed: 3 additions & 1 deletion
@@ -39,7 +39,8 @@ class Gemma4Runner {
  public:
   Gemma4Runner(
       const std::string& model_path,
-      const std::string& tokenizer_path);
+      const std::string& tokenizer_path,
+      bool enable_workspace_sharing = true);
 
   Error load();
   bool is_loaded() const;
@@ -176,6 +177,7 @@ class Gemma4Runner {
   std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
   std::string model_path_;
   std::string tokenizer_path_;
+  bool enable_workspace_sharing_;
   Error load_audio_methods();
   Error load_vision_methods();

examples/models/gemma4/runner/gemma4_stats.h

Lines changed: 4 additions & 2 deletions
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <chrono>
+#include <cinttypes>
 #include <cstdint>
 #include <cstdio>
 #include <string>
@@ -198,7 +199,7 @@ struct Gemma4Stats {
         "\"prefill_ms\":%.1f,\"generation_ms\":%.1f,"
         "\"num_prompt_tokens\":%d,\"num_generated_tokens\":%d,"
         "\"prefill_tok_per_s\":%.1f,\"gen_tok_per_s\":%.1f,"
-        "\"ttft_ms\":%.1f,\"total_ms\":%.1f,"
+        "\"ttft_ms\":%.1f,\"total_ms\":%.1f,\"rtf\":%.2f,"
         "\"rss_after_load_mb\":%.0f,\"rss_peak_gen_mb\":%.0f}",
         load_ms,
         speech_transform_ms,
@@ -212,6 +213,7 @@ struct Gemma4Stats {
         tokens_per_second(),
         time_to_first_token_ms(),
         total_inference_ms(),
+        rtf(),
         rss_after_load_kb / 1024.0,
         rss_peak_gen_kb / 1024.0);
     return std::string(buf);
@@ -226,7 +228,7 @@ struct Gemma4Stats {
     char line[256];
     int64_t rss_kb = 0;
     while (fgets(line, sizeof(line), f)) {
-      if (sscanf(line, "VmRSS: %ld kB", &rss_kb) == 1) {
+      if (sscanf(line, "VmRSS: %" SCNd64 " kB", &rss_kb) == 1) {
        break;
      }
    }

examples/models/gemma4/targets.bzl

Lines changed: 3 additions & 4 deletions
@@ -43,10 +43,9 @@ def define_common_targets():
             "runner/gemma4_runner.cpp",
         ],
         visibility = ["PUBLIC"],
-        preprocessor_flags = [
-            "-DENABLE_XNNPACK_SHARED_WORKSPACE",
-        ],
         deps = _KERNEL_BACKEND_DEPS + [
+            "//executorch/backends/xnnpack:xnnpack_interface",
+            "//executorch/runtime/backend:interface",
             "//executorch/extension/llm/sampler:sampler",
             "//executorch/extension/module:module",
             "//executorch/extension/tensor:tensor",
@@ -72,5 +71,5 @@ def define_common_targets():
         ],
         visibility = ["PUBLIC"],
         compiler_flags = ["-Wno-global-constructors"],
-        preprocessor_flags = ["-DET_USE_THREADPOOL", "-DENABLE_XNNPACK_SHARED_WORKSPACE"],
+        preprocessor_flags = ["-DET_USE_THREADPOOL"],
     )
