Skip to content

Commit 7e1c887

Browse files
committed
gemma4 support
1 parent 9f752b6 commit 7e1c887

11 files changed

Lines changed: 343 additions & 87 deletions

File tree

packages/react-native-executorch/common/runner/base_llm_runner.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ BaseLLMRunner::BaseLLMRunner(std::unique_ptr<Module> module,
1818
tokenizer_(std::make_unique<tokenizers::HFTokenizer>()),
1919
metadata_({
2020
{kEnableDynamicShape, false},
21-
{kMaxSeqLen, 128},
22-
{kMaxContextLen, 128},
21+
{kMaxSeqLen, 2048},
22+
{kMaxContextLen, 2048},
2323
{kUseKVCache, true},
2424
}) {}
2525

@@ -69,6 +69,7 @@ Error BaseLLMRunner::load() {
6969
eos_ids_->emplace(static_cast<uint64_t>(eos_id.toScalar().to<int64_t>()));
7070
}
7171
}
72+
eos_ids_->emplace(static_cast<uint64_t>(1));
7273
if (eos_ids_->empty()) {
7374
throw rnexecutorch::RnExecutorchError(
7475
rnexecutorch::RnExecutorchErrorCode::InvalidModelOutput,

packages/react-native-executorch/common/runner/constants.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ inline constexpr auto kMaxSeqLen = "get_max_seq_len";
1717
inline constexpr auto kMaxContextLen = "get_max_context_len";
1818
inline constexpr auto kVocabSize = "get_vocab_size";
1919
inline constexpr auto kUseKVCache = "use_kv_cache";
20+
// PLE models only: token id that marks image placeholder slots in input_ids.
21+
// token_embedding run on this id produces the per-layer PLE signal for image
22+
// positions; the inputs_embeds output for those positions is discarded (the
23+
// vision encoder output replaces it).
24+
inline constexpr auto kImagePlaceholderId = "image_placeholder_id";
2025

2126
// Multimodal method name conventions
2227
inline constexpr auto kVisionEncoderMethod = "vision_encoder";

packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ bool VisionEncoder::is_loaded() const noexcept {
4141
}
4242

4343
int32_t VisionEncoder::encoderTokenCount() const {
44+
rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
45+
"VisionEncoder::encoderTokenCount");
4446
if (!is_loaded()) {
4547
return 0;
4648
}
@@ -102,6 +104,8 @@ VisionEncoder::preprocessImage(const std::string &path,
102104
}
103105

104106
Result<EValue> VisionEncoder::encode(const MultimodalInput &input) {
107+
rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
108+
"VisionEncoder::encode start");
105109
if (!is_loaded()) {
106110
return Error::InvalidState;
107111
}
@@ -128,9 +132,14 @@ Result<EValue> VisionEncoder::encode(const MultimodalInput &input) {
128132
auto image_tensor = ::executorch::extension::from_blob(
129133
chw.data(), sizes, ::executorch::aten::ScalarType::Float);
130134

135+
rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
136+
"VisionEncoder::encode start1");
131137
auto result = ET_UNWRAP(module_->execute(kVisionEncoderMethod, image_tensor));
138+
rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
139+
"VisionEncoder::encode end1");
132140
auto embedding = result[0];
133141
embedding_cache_.emplace(path, embedding);
142+
rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, "VisionEncoder::encode end");
134143
return embedding;
135144
}
136145

packages/react-native-executorch/common/runner/multimodal_decoder_runner.h

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,48 @@
1414
#include "text_decoder_runner.h"
1515

1616
namespace executorch::extension::llm {
17+
// Supports two PTE contracts, selected automatically at load time from
18+
// `token_embedding`'s output arity:
19+
//
20+
// * Legacy (default):
21+
// token_embedding(ids) -> inputs_embeds
22+
// text_decoder(inputs_embeds, input_pos)
23+
//
24+
// * Gemma-style PLE (when token_embedding emits 2 outputs):
25+
// token_embedding(ids) -> (inputs_embeds, ple_tok)
26+
// text_decoder(inputs_embeds, ple_tok, input_pos)
27+
// ple_tok carries Gemma4's per-layer PLE signal keyed on input_ids. It's
28+
// computed once in token_embedding and threaded through every decoder call
29+
// so PLE fires at every position (including multimodal placeholder slots).
1730
class MultimodalDecoderRunner : public TextDecoderRunner {
1831
public:
1932
explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager,
2033
const GenerationConfig &config)
2134
: TextDecoderRunner(module, io_manager, config) {}
2235

36+
// True iff the loaded PTE uses the Gemma-style PLE contract above.
37+
// Meaningful only after load() has been called.
38+
bool uses_ple() const { return uses_ple_; }
39+
2340
inline ::executorch::runtime::Result<::executorch::aten::Tensor>
2441
step(TensorPtr &tokens, int64_t start_pos) override {
2542
auto embed_result = module_->execute(kTokenEmbeddingMethod, tokens);
43+
2644
if (!embed_result.ok()) {
2745
return embed_result.error();
2846
}
29-
return decode((*embed_result)[0], start_pos);
47+
auto &embed_outputs = *embed_result;
48+
if (uses_ple_) {
49+
ET_CHECK_MSG(embed_outputs.size() == 2,
50+
"Expected 2 outputs (inputs_embeds, ple_tok) from "
51+
"token_embedding, got %zu",
52+
embed_outputs.size());
53+
return decode(embed_outputs[0], embed_outputs[1], start_pos);
54+
}
55+
return decode(embed_outputs[0], start_pos);
3056
}
3157

58+
// Legacy 2-input text_decoder(inputs_embeds, input_pos).
3259
inline ::executorch::runtime::Result<::executorch::aten::Tensor>
3360
decode(const ::executorch::runtime::EValue &embeddings, int64_t start_pos) {
3461
auto start_pos_tensor = ::executorch::extension::from_blob(
@@ -46,19 +73,45 @@ class MultimodalDecoderRunner : public TextDecoderRunner {
4673
return outputs[0].toTensor();
4774
}
4875

76+
// PLE 3-input text_decoder(inputs_embeds, ple_tok, input_pos).
77+
inline ::executorch::runtime::Result<::executorch::aten::Tensor>
78+
decode(const ::executorch::runtime::EValue &embeddings,
79+
const ::executorch::runtime::EValue &ple_tok, int64_t start_pos) {
80+
auto start_pos_tensor = ::executorch::extension::from_blob(
81+
&start_pos, {1}, ::executorch::aten::ScalarType::Long);
82+
auto outputs_result = module_->execute(
83+
kTextModelMethod, {embeddings, ple_tok, start_pos_tensor});
84+
if (!outputs_result.ok()) {
85+
return outputs_result.error();
86+
}
87+
auto &outputs = *outputs_result;
88+
ET_CHECK_MSG(outputs.size() == 1,
89+
"Expected 1 output from text_decoder, got %zu",
90+
outputs.size());
91+
ET_CHECK_MSG(outputs[0].isTensor(), "text_decoder output is not a tensor");
92+
return outputs[0].toTensor();
93+
}
94+
4995
inline ::executorch::runtime::Error load() override {
5096
if (is_method_loaded()) {
5197
return ::executorch::runtime::Error::Ok;
5298
}
5399
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod));
54100
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod));
101+
102+
auto meta = module_->method_meta(kTokenEmbeddingMethod);
103+
ET_CHECK_OK_OR_RETURN_ERROR(meta.error());
104+
uses_ple_ = (meta->num_outputs() == 2);
55105
return ::executorch::runtime::Error::Ok;
56106
}
57107

58108
inline bool is_method_loaded() override {
59109
return module_->is_method_loaded(kTokenEmbeddingMethod) &&
60110
module_->is_method_loaded(kTextModelMethod);
61111
}
112+
113+
private:
114+
bool uses_ple_ = true;
62115
};
63116

64117
} // namespace executorch::extension::llm

0 commit comments

Comments
 (0)