@@ -47,27 +47,19 @@ static constexpr auto kUseKVCache = "use_kv_cache";
4747static constexpr auto kUseSDPAWithKVCache = " use_sdpa_with_kv_cache" ;
4848} // namespace
4949
50- Runner::Runner (const std::string &model_path, const std::string &tokenizer_path,
51- const float temperature,
50+ Runner::Runner (Module *module , const std::string &model_path,
51+ const std::string &tokenizer_path,
52+ const bool extended_input_mode, const float temperature,
5253 std::optional<const std::string> data_path)
53- // NOTE: we observed ~2x loading performance increase on iPhone 15
54- // and a ~5% improvement on Galaxy S22 by switching to
55- // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
56- : temperature_(temperature), tokenizer_path_(tokenizer_path),
57- metadata_ ({
58- {kEnableDynamicShape , false },
59- {kMaxSeqLen , 128 },
60- {kMaxContextLen , 128 },
61- {kUseKVCache , true },
62- {kUseSDPAWithKVCache , false },
63- }) {
64- if (data_path.has_value ()) {
65- module_ = std::make_unique<Module>(model_path, data_path.value (),
66- Module::LoadMode::File);
67- } else {
68- module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
69- }
70- ET_LOG (Info, " Creating LLaMa runner: model_path=%s, tokenizer_path=%s" ,
54+ : module_(module ), temperature_(temperature),
55+ tokenizer_path_ (tokenizer_path), metadata_({
56+ {kEnableDynamicShape , false },
57+ {kMaxSeqLen , 128 },
58+ {kMaxContextLen , 128 },
59+ {kUseKVCache , true },
60+ {kUseSDPAWithKVCache , false },
61+ }) {
62+ ET_LOG (Info, " Creating LLM runner: model_path=%s, tokenizer_path=%s" ,
7163 model_path.c_str (), tokenizer_path.c_str ());
7264}
7365
@@ -116,7 +108,7 @@ Error Runner::load() {
116108 }
117109 }
118110 text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
119- module_. get () , metadata_.at (kUseKVCache ), metadata_.at (kVocabSize ),
111+ module_, metadata_.at (kUseKVCache ), metadata_.at (kVocabSize ),
120112 temperature_);
121113 text_prefiller_ = std::make_unique<llm::TextPrefiller>(
122114 text_decoder_runner_.get (), metadata_.at (kUseKVCache ),
@@ -206,7 +198,8 @@ Error Runner::generate(const std::string &prompt,
206198 wrapped_callback (prompt);
207199 }
208200 int64_t pos = 0 ;
209- auto prefill_res = text_prefiller_->prefill (prompt_tokens_uint64, pos);
201+ auto prefill_res = text_prefiller_->prefill (prompt_tokens_uint64, pos,
202+ extend_position_input_);
210203 stats_.first_token_ms = llm::time_in_ms ();
211204 stats_.prompt_eval_end_ms = llm::time_in_ms ();
212205 ET_CHECK_OK_OR_RETURN_ERROR (prefill_res.error ());
@@ -269,6 +262,10 @@ void Runner::stop() {
269262 }
270263}
271264
265+ void Runner::set_extended_input_mode (bool extend_position_input) {
266+ extend_position_input_ = extend_position_input;
267+ }
268+
272269void Runner::set_count_interval (size_t count_interval) {
273270 text_token_generator_->set_count_interval (count_interval);
274271}
0 commit comments