@@ -149,103 +149,117 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint num_bos = 0,
       jint num_eos = 0,
       jint load_mode = 1) {
-    temperature_ = temperature;
-    num_bos_ = num_bos;
-    num_eos_ = num_eos;
+    try {
+      temperature_ = temperature;
+      num_bos_ = num_bos;
+      num_eos_ = num_eos;
 #if defined(ET_USE_THREADPOOL)
-    // Reserve 1 thread for the main thread.
-    int32_t num_performant_cores =
-        ::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
-    if (num_performant_cores > 0) {
-      ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores);
-      ::executorch::extension::threadpool::get_threadpool()
-          ->_unsafe_reset_threadpool(num_performant_cores);
-    }
+      // Reserve 1 thread for the main thread.
+      int32_t num_performant_cores =
+          ::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
+      if (num_performant_cores > 0) {
+        ET_LOG(
+            Info, "Resetting threadpool to %d threads", num_performant_cores);
+        ::executorch::extension::threadpool::get_threadpool()
+            ->_unsafe_reset_threadpool(num_performant_cores);
+      }
 #endif
 
-    model_type_category_ = model_type_category;
-    auto cpp_load_mode = load_mode_from_int(load_mode);
-    std::vector<std::string> data_files_vector;
-    if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      runner_ = llm::create_multimodal_runner(
-          model_path->toStdString().c_str(),
-          llm::load_tokenizer(tokenizer_path->toStdString()),
-          std::nullopt,
-          cpp_load_mode);
-    } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
-      if (data_files != nullptr) {
-        // Convert Java List<String> to C++ std::vector<string>
-        auto list_class = facebook::jni::findClassStatic("java/util/List");
-        auto size_method = list_class->getMethod<jint()>("size");
-        auto get_method =
-            list_class->getMethod<facebook::jni::local_ref<jobject>(jint)>(
-                "get");
-
-        jint size = size_method(data_files);
-        for (jint i = 0; i < size; ++i) {
-          auto str_obj = get_method(data_files, i);
-          auto jstr = facebook::jni::static_ref_cast<jstring>(str_obj);
-          data_files_vector.push_back(jstr->toStdString());
+      model_type_category_ = model_type_category;
+      auto cpp_load_mode = load_mode_from_int(load_mode);
+      std::vector<std::string> data_files_vector;
+      if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+        runner_ = llm::create_multimodal_runner(
+            model_path->toStdString().c_str(),
+            llm::load_tokenizer(tokenizer_path->toStdString()),
+            std::nullopt,
+            cpp_load_mode);
+      } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
+        if (data_files != nullptr) {
+          // Convert Java List<String> to C++ std::vector<string>
+          auto list_class = facebook::jni::findClassStatic("java/util/List");
+          auto size_method = list_class->getMethod<jint()>("size");
+          auto get_method =
+              list_class->getMethod<facebook::jni::local_ref<jobject>(jint)>(
+                  "get");
+
+          jint size = size_method(data_files);
+          for (jint i = 0; i < size; ++i) {
+            auto str_obj = get_method(data_files, i);
+            auto jstr = facebook::jni::static_ref_cast<jstring>(str_obj);
+            data_files_vector.push_back(jstr->toStdString());
+          }
         }
-      }
-      runner_ = executorch::extension::llm::create_text_llm_runner(
-          model_path->toStdString(),
-          llm::load_tokenizer(tokenizer_path->toStdString()),
-          data_files_vector,
-          /*temperature=*/-1.0f,
-          /*event_tracer=*/nullptr,
-          /*method_name=*/"forward",
-          cpp_load_mode);
+        runner_ = executorch::extension::llm::create_text_llm_runner(
+            model_path->toStdString(),
+            llm::load_tokenizer(tokenizer_path->toStdString()),
+            data_files_vector,
+            /*temperature=*/-1.0f,
+            /*event_tracer=*/nullptr,
+            /*method_name=*/"forward",
+            cpp_load_mode);
 #if defined(EXECUTORCH_BUILD_QNN)
-    } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
-      std::unique_ptr<executorch::extension::Module> module =
-          std::make_unique<executorch::extension::Module>(
-              model_path->toStdString().c_str(),
-              data_files_vector,
-              cpp_load_mode);
-      std::string decoder_model = "llama3"; // use llama3 for now
-      // Using 8bit as default since this meta is introduced with 16bit kv io
-      // support and older models only have 8bit kv io.
-      example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
-      if (module->method_names()->count("get_kv_io_bit_width") > 0) {
-        kv_bitwidth = static_cast<example::KvBitWidth>(
-            module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
-      }
+      } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
+        std::unique_ptr<executorch::extension::Module> module =
+            std::make_unique<executorch::extension::Module>(
+                model_path->toStdString().c_str(),
+                data_files_vector,
+                cpp_load_mode);
+        std::string decoder_model = "llama3"; // use llama3 for now
+        // Using 8bit as default since this meta is introduced with 16bit kv io
+        // support and older models only have 8bit kv io.
+        example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+        if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+          kv_bitwidth = static_cast<example::KvBitWidth>(
+              module->get("get_kv_io_bit_width")
+                  .get()
+                  .toScalar()
+                  .to<int64_t>());
+        }
 
-      if (kv_bitwidth == example::KvBitWidth::kWidth8) {
-        runner_ = std::make_unique<example::Runner<uint8_t>>(
-            std::move(module),
-            decoder_model.c_str(),
-            model_path->toStdString().c_str(),
-            tokenizer_path->toStdString().c_str(),
-            "",
-            "",
-            temperature_);
-      } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
-        runner_ = std::make_unique<example::Runner<uint16_t>>(
-            std::move(module),
-            decoder_model.c_str(),
-            model_path->toStdString().c_str(),
-            tokenizer_path->toStdString().c_str(),
-            "",
-            "",
-            temperature_);
-      } else {
-        ET_CHECK_MSG(
-            false,
-            "Unsupported kv bitwidth: %ld",
-            static_cast<int64_t>(kv_bitwidth));
-      }
-      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
+        if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+          runner_ = std::make_unique<example::Runner<uint8_t>>(
+              std::move(module),
+              decoder_model.c_str(),
+              model_path->toStdString().c_str(),
+              tokenizer_path->toStdString().c_str(),
+              "",
+              "",
+              temperature_);
+        } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+          runner_ = std::make_unique<example::Runner<uint16_t>>(
+              std::move(module),
+              decoder_model.c_str(),
+              model_path->toStdString().c_str(),
+              tokenizer_path->toStdString().c_str(),
+              "",
+              "",
+              temperature_);
+        } else {
+          ET_CHECK_MSG(
+              false,
+              "Unsupported kv bitwidth: %ld",
+              static_cast<int64_t>(kv_bitwidth));
+        }
+        model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
-    } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
-      runner_ = std::make_unique<MTKLlamaRunner>(
-          model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str());
-      // Interpret the model type as LLM
-      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
+      } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
+        runner_ = std::make_unique<MTKLlamaRunner>(
+            model_path->toStdString().c_str(),
+            tokenizer_path->toStdString().c_str());
+        // Interpret the model type as LLM
+        model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
+      }
+    } catch (const std::exception& e) {
+      executorch::jni_helper::throwExecutorchException(
+          static_cast<uint32_t>(Error::Internal),
+          std::string("Failed to create LlmModule: ") + e.what());
+    } catch (...) {
+      executorch::jni_helper::throwExecutorchException(
+          static_cast<uint32_t>(Error::Internal),
+          "Failed to create LlmModule: unknown native error");
     }
   }
 
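A note on `load_mode_from_int`, called near the top of the hunk: it maps the Java-side `load_mode` integer onto `executorch::extension::Module::LoadMode`, with `load_mode = 1` as the default in the JNI signature. Below is a minimal sketch of such a mapping, assuming the four load modes the Module API exposes; it is illustrative only, not this file's actual definition.

// Illustrative sketch, assuming Module::LoadMode's four values; the real
// load_mode_from_int() is defined elsewhere in this JNI file.
#include <executorch/extension/module/module.h>

#include <cstdint>

using executorch::extension::Module;

Module::LoadMode load_mode_from_int(int32_t load_mode) {
  switch (load_mode) {
    case 0:
      return Module::LoadMode::File; // read the whole program into memory
    case 1:
      return Module::LoadMode::Mmap; // matches the default load_mode = 1
    case 2:
      return Module::LoadMode::MmapUseMlock;
    case 3:
      return Module::LoadMode::MmapUseMlockIgnoreErrors;
    default:
      return Module::LoadMode::Mmap; // assumed fallback for unknown values
  }
}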
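For readers unfamiliar with the new error path: `jni_helper::throwExecutorchException` turns a native failure into a pending Java exception at the JNI boundary, so the Java caller sees a catchable exception instead of a process abort. A minimal sketch of such a helper on top of fbjni follows; the Java exception class name and the message format are assumptions for illustration, not necessarily what `jni_helper` actually does.

// Hypothetical sketch of an fbjni-based exception bridge. The class name
// org/pytorch/executorch/ExecutorchRuntimeException is an assumed
// placeholder, not necessarily the real one.
#include <fbjni/fbjni.h>

#include <cstdint>
#include <string>

namespace executorch {
namespace jni_helper {

[[noreturn]] void throwExecutorchException(
    uint32_t error_code,
    const std::string& message) {
  // throwNewJavaException() instantiates the named Java throwable with the
  // formatted message, makes it the pending JNI exception, and throws a C++
  // exception that fbjni's HybridClass glue propagates back to Java.
  facebook::jni::throwNewJavaException(
      "org/pytorch/executorch/ExecutorchRuntimeException",
      "Error code %u: %s",
      error_code,
      message.c_str());
}

} // namespace jni_helper
} // namespace executorch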