Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/models/gemma4/e2e_runner.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -43,6 +43,11 @@
DEFINE_int32(max_vision_tokens, 140, "Maximum soft tokens for vision encoder.");
DEFINE_double(temperature, 0.0, "Sampling temperature (0.0 = greedy).");
DEFINE_int32(cpu_threads, -1, "Number of CPU threads. -1 = auto-detect.");
// Runtime toggle for XNNPACK workspace sharing. Defaults to true; main()
// passes the parsed value to the Gemma4Runner constructor as its third
// argument.
DEFINE_bool(
enable_workspace_sharing,
true,
"Enable XNNPACK PerModel workspace sharing + weight cache. "
"Pass --noenable_workspace_sharing to disable for debugging.");

int32_t main(int32_t argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
Expand All @@ -62,7 +67,10 @@
Gemma4Stats stats;
stats.on_load_begin();

Gemma4Runner runner(FLAGS_model_path, FLAGS_tokenizer_path);
Gemma4Runner runner(
FLAGS_model_path,
FLAGS_tokenizer_path,
FLAGS_enable_workspace_sharing);
auto err = runner.load();
ET_CHECK_MSG(err == executorch::runtime::Error::Ok, "Failed to load model");

Expand Down
23 changes: 22 additions & 1 deletion examples/models/gemma4/export_gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def _export_text_decoder(
tied_embedding: bool = False,
variant: str = "e2b",
quantize_kv_cache: bool = False,
use_custom_sdpa: bool = True,
):
"""Export text decoder. Returns ExportedProgram."""
from executorch.examples.models.gemma4.quant_utils import (
Expand All @@ -467,6 +468,12 @@ def _export_text_decoder(
config.use_kv_cache = True
config.max_seq_len = max_seq_len
config.enable_dynamic_shape = True
config.use_custom_sdpa = use_custom_sdpa

if use_custom_sdpa:
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401

logger.info("Custom SDPA enabled (tiled flash attention)")

model_wrapper = Gemma4Model(
config=config, checkpoint_path=checkpoint_path, dtype=torch.float32
Expand Down Expand Up @@ -507,7 +514,9 @@ def _export_text_decoder(
)

logger.info("Replacing KV cache with INT8 quantized KV cache...")
model = replace_kv_cache_with_quantized_kv_cache(model)
model = replace_kv_cache_with_quantized_kv_cache(
model, use_custom_sdpa=use_custom_sdpa
)
model.eval()

if linear_quant:
Expand Down Expand Up @@ -549,6 +558,7 @@ def _export_components(
quantize_kv_cache: bool,
include_audio: bool,
include_vision: bool,
use_custom_sdpa: bool,
) -> dict:
"""Export each requested component to an ExportedProgram."""
components = []
Expand Down Expand Up @@ -594,6 +604,7 @@ def _export_components(
tied_embedding=tied_embedding,
variant=variant,
quantize_kv_cache=quantize_kv_cache,
use_custom_sdpa=use_custom_sdpa,
)

return programs
Expand Down Expand Up @@ -687,6 +698,7 @@ def export_single_pte(
quantize_kv_cache: bool = False,
include_audio: bool = True,
include_vision: bool = True,
use_custom_sdpa: bool = True,
) -> Path:
"""Export components into a single PTE.

Expand Down Expand Up @@ -714,6 +726,7 @@ def export_single_pte(
quantize_kv_cache=quantize_kv_cache,
include_audio=include_audio,
include_vision=include_vision,
use_custom_sdpa=use_custom_sdpa,
)

logger.info("Combining into single PTE...")
Expand Down Expand Up @@ -840,6 +853,13 @@ def main():
default=False,
help="Exclude vision_encoder method to reduce PTE size.",
)
parser.add_argument(
"--use_custom_sdpa",
action=argparse.BooleanOptionalAction,
default=True,
help="Route attention through llama::custom_sdpa (tiled flash attention). "
"Pass --no-use_custom_sdpa to fall back to matmul attention.",
)
args = parser.parse_args()

export_single_pte(
Expand All @@ -856,6 +876,7 @@ def main():
quantize_kv_cache=args.quantize_kv_cache,
include_audio=not args.no_audio,
include_vision=not args.no_vision,
use_custom_sdpa=args.use_custom_sdpa,
)


Expand Down
33 changes: 31 additions & 2 deletions examples/models/gemma4/runner/gemma4_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

#include <executorch/examples/models/gemma4/runner/gemma4_runner.h>

#include <executorch/backends/xnnpack/runtime/XNNPACKBackend.h>
#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/log.h>

Expand All @@ -28,10 +31,36 @@ using ::executorch::runtime::EValue;

Gemma4Runner::Gemma4Runner(
const std::string& model_path,
const std::string& tokenizer_path)
: model_path_(model_path), tokenizer_path_(tokenizer_path) {}
const std::string& tokenizer_path,
bool enable_workspace_sharing)
: model_path_(model_path),
tokenizer_path_(tokenizer_path),
enable_workspace_sharing_(enable_workspace_sharing) {}

Error Gemma4Runner::load() {
// Set XNNPACK workspace sharing explicitly. The compile-time default
// varies across build configurations, and set_option state is process-
// global, so always set it here to get the intended mode regardless of
// how the binary was built or what other code in the process did first.
//
// Both options are driven by the single enable_workspace_sharing_ flag:
// the weight cache is switched on/off with it, and the sharing mode is
// PerModel when it is true and Disabled when it is false.
{
// Map the boolean member onto the backend's sharing-mode enum.
auto mode = enable_workspace_sharing_
? ::executorch::backends::xnnpack::WorkspaceSharingMode::PerModel
: ::executorch::backends::xnnpack::WorkspaceSharingMode::Disabled;
// NOTE(review): the template argument 2 presumably is the option capacity,
// matching the two set_option calls below — confirm before adding a third.
::executorch::runtime::BackendOptions<2> xnnpack_opts;
xnnpack_opts.set_option(
::executorch::backends::xnnpack::weight_cache_option_key,
enable_workspace_sharing_);
// The sharing mode is handed to the backend as a plain int, hence the cast.
xnnpack_opts.set_option(
::executorch::backends::xnnpack::workspace_sharing_mode_option_key,
static_cast<int>(mode));
// Apply both options to the XNNPACK backend in one call.
auto opts_status = ::executorch::runtime::set_option(
::executorch::backends::xnnpack::xnnpack_backend_key,
xnnpack_opts.view());
// A failure here is only logged: load() does not return early and goes on
// to load the model with whatever options the backend currently has.
if (opts_status != Error::Ok) {
ET_LOG(Error, "Failed to set XNNPACK options");
}
}

ET_LOG(Info, "Loading model: %s", model_path_.c_str());
module_ = std::make_unique<Module>(model_path_, Module::LoadMode::Mmap);

Expand Down
4 changes: 3 additions & 1 deletion examples/models/gemma4/runner/gemma4_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class Gemma4Runner {
public:
Gemma4Runner(
const std::string& model_path,
const std::string& tokenizer_path);
const std::string& tokenizer_path,
bool enable_workspace_sharing = true);

Error load();
bool is_loaded() const;
Expand Down Expand Up @@ -176,6 +177,7 @@ class Gemma4Runner {
std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
std::string model_path_;
std::string tokenizer_path_;
bool enable_workspace_sharing_;
Error load_audio_methods();
Error load_vision_methods();

Expand Down
6 changes: 4 additions & 2 deletions examples/models/gemma4/runner/gemma4_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <chrono>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <string>
Expand Down Expand Up @@ -198,7 +199,7 @@ struct Gemma4Stats {
"\"prefill_ms\":%.1f,\"generation_ms\":%.1f,"
"\"num_prompt_tokens\":%d,\"num_generated_tokens\":%d,"
"\"prefill_tok_per_s\":%.1f,\"gen_tok_per_s\":%.1f,"
"\"ttft_ms\":%.1f,\"total_ms\":%.1f,"
"\"ttft_ms\":%.1f,\"total_ms\":%.1f,\"rtf\":%.2f,"
"\"rss_after_load_mb\":%.0f,\"rss_peak_gen_mb\":%.0f}",
load_ms,
speech_transform_ms,
Expand All @@ -212,6 +213,7 @@ struct Gemma4Stats {
tokens_per_second(),
time_to_first_token_ms(),
total_inference_ms(),
rtf(),
rss_after_load_kb / 1024.0,
rss_peak_gen_kb / 1024.0);
return std::string(buf);
Expand All @@ -226,7 +228,7 @@ struct Gemma4Stats {
char line[256];
int64_t rss_kb = 0;
while (fgets(line, sizeof(line), f)) {
if (sscanf(line, "VmRSS: %ld kB", &rss_kb) == 1) {
if (sscanf(line, "VmRSS: %" SCNd64 " kB", &rss_kb) == 1) {
break;
}
}
Expand Down
7 changes: 3 additions & 4 deletions examples/models/gemma4/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ def define_common_targets():
"runner/gemma4_runner.cpp",
],
visibility = ["PUBLIC"],
preprocessor_flags = [
"-DENABLE_XNNPACK_SHARED_WORKSPACE",
],
deps = _KERNEL_BACKEND_DEPS + [
"//executorch/backends/xnnpack:xnnpack_interface",
"//executorch/runtime/backend:interface",
"//executorch/extension/llm/sampler:sampler",
"//executorch/extension/module:module",
"//executorch/extension/tensor:tensor",
Expand All @@ -72,5 +71,5 @@ def define_common_targets():
],
visibility = ["PUBLIC"],
compiler_flags = ["-Wno-global-constructors"],
preprocessor_flags = ["-DET_USE_THREADPOOL", "-DENABLE_XNNPACK_SHARED_WORKSPACE"],
preprocessor_flags = ["-DET_USE_THREADPOOL"],
)
Loading
Loading