1414#include " text_decoder_runner.h"
1515
1616namespace executorch ::extension::llm {
17+ // Supports two PTE contracts, selected automatically at load time from
18+ // `token_embedding`'s output arity:
19+ //
20+ // * Legacy (default):
21+ // token_embedding(ids) -> inputs_embeds
22+ // text_decoder(inputs_embeds, input_pos)
23+ //
24+ // * Gemma-style PLE (when token_embedding emits 2 outputs):
25+ // token_embedding(ids) -> (inputs_embeds, ple_tok)
26+ // text_decoder(inputs_embeds, ple_tok, input_pos)
27+ // ple_tok carries Gemma4's per-layer PLE signal keyed on input_ids. It's
28+ // computed once in token_embedding and threaded through every decoder call
29+ // so PLE fires at every position (including multimodal placeholder slots).
1730class MultimodalDecoderRunner : public TextDecoderRunner {
1831public:
1932 explicit MultimodalDecoderRunner (Module &module , IOManager *io_manager,
2033 const GenerationConfig &config)
2134 : TextDecoderRunner(module , io_manager, config) {}
2235
36+ // True iff the loaded PTE uses the Gemma-style PLE contract above.
37+ // Meaningful only after load() has been called.
38+ bool uses_ple () const { return uses_ple_; }
39+
2340 inline ::executorch::runtime::Result<::executorch::aten::Tensor>
2441 step (TensorPtr &tokens, int64_t start_pos) override {
2542 auto embed_result = module_->execute (kTokenEmbeddingMethod , tokens);
43+
2644 if (!embed_result.ok ()) {
2745 return embed_result.error ();
2846 }
29- return decode ((*embed_result)[0 ], start_pos);
47+ auto &embed_outputs = *embed_result;
48+ if (uses_ple_) {
49+ ET_CHECK_MSG (embed_outputs.size () == 2 ,
50+ " Expected 2 outputs (inputs_embeds, ple_tok) from "
51+ " token_embedding, got %zu" ,
52+ embed_outputs.size ());
53+ return decode (embed_outputs[0 ], embed_outputs[1 ], start_pos);
54+ }
55+ return decode (embed_outputs[0 ], start_pos);
3056 }
3157
58+ // Legacy 2-input text_decoder(inputs_embeds, input_pos).
3259 inline ::executorch::runtime::Result<::executorch::aten::Tensor>
3360 decode (const ::executorch::runtime::EValue &embeddings, int64_t start_pos) {
3461 auto start_pos_tensor = ::executorch::extension::from_blob (
@@ -46,19 +73,45 @@ class MultimodalDecoderRunner : public TextDecoderRunner {
4673 return outputs[0 ].toTensor ();
4774 }
4875
76+ // PLE 3-input text_decoder(inputs_embeds, ple_tok, input_pos).
77+ inline ::executorch::runtime::Result<::executorch::aten::Tensor>
78+ decode (const ::executorch::runtime::EValue &embeddings,
79+ const ::executorch::runtime::EValue &ple_tok, int64_t start_pos) {
80+ auto start_pos_tensor = ::executorch::extension::from_blob (
81+ &start_pos, {1 }, ::executorch::aten::ScalarType::Long);
82+ auto outputs_result = module_->execute (
83+ kTextModelMethod , {embeddings, ple_tok, start_pos_tensor});
84+ if (!outputs_result.ok ()) {
85+ return outputs_result.error ();
86+ }
87+ auto &outputs = *outputs_result;
88+ ET_CHECK_MSG (outputs.size () == 1 ,
89+ " Expected 1 output from text_decoder, got %zu" ,
90+ outputs.size ());
91+ ET_CHECK_MSG (outputs[0 ].isTensor (), " text_decoder output is not a tensor" );
92+ return outputs[0 ].toTensor ();
93+ }
94+
4995 inline ::executorch::runtime::Error load () override {
5096 if (is_method_loaded ()) {
5197 return ::executorch::runtime::Error::Ok;
5298 }
5399 ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (kTokenEmbeddingMethod ));
54100 ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (kTextModelMethod ));
101+
102+ auto meta = module_->method_meta (kTokenEmbeddingMethod );
103+ ET_CHECK_OK_OR_RETURN_ERROR (meta.error ());
104+ uses_ple_ = (meta->num_outputs () == 2 );
55105 return ::executorch::runtime::Error::Ok;
56106 }
57107
58108 inline bool is_method_loaded () override {
59109 return module_->is_method_loaded (kTokenEmbeddingMethod ) &&
60110 module_->is_method_loaded (kTextModelMethod );
61111 }
112+
113+ private:
// Which PTE contract is in effect; overwritten by load() from
// token_embedding's output count. Starts as false because the legacy
// 2-input contract is the documented default.
bool uses_ple_ = false;
62115};
63116
64117} // namespace executorch::extension::llm
0 commit comments