diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index ed0cacf3dbc..ed144acb14b 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -203,13 +203,38 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         data_files_vector,
         cpp_load_mode);
     std::string decoder_model = "llama3"; // use llama3 for now
-    runner_ = std::make_unique<example::Runner<uint8_t>>( // QNN runner
-        std::move(module),
-        decoder_model.c_str(),
-        model_path->toStdString().c_str(),
-        tokenizer_path->toStdString().c_str(),
-        "",
-        "");
+    // Using 8bit as default since this meta is introduced with 16bit kv io
+    // support and older models only have 8bit kv io.
+    example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+    if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+      kv_bitwidth = static_cast<example::KvBitWidth>(
+          module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
+    }
+
+    if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+      runner_ = std::make_unique<example::Runner<uint8_t>>(
+          std::move(module),
+          decoder_model.c_str(),
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          "",
+          "",
+          temperature_);
+    } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+      runner_ = std::make_unique<example::Runner<uint16_t>>(
+          std::move(module),
+          decoder_model.c_str(),
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          "",
+          "",
+          temperature_);
+    } else {
+      ET_CHECK_MSG(
+          false,
+          "Unsupported kv bitwidth: %ld",
+          static_cast<int64_t>(kv_bitwidth));
+    }
     model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
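
Note for reviewers: the sketch below is a minimal, self-contained illustration of the dispatch pattern this diff introduces, not the actual QNN runner code. The `KvBitWidth` enum values, the stub `Runner`, `IRunner`, and `make_runner` here are hypothetical stand-ins; the real `example::KvBitWidth` and `example::Runner<T>` live in the Qualcomm llama runner headers, and the real enum's underlying values are not visible in this diff. It only demonstrates the two behaviors being reviewed: defaulting to 8-bit when the `get_kv_io_bit_width` metadata method is absent (older exports), and choosing the runner template instantiation from the reported width.

#include <cstdint>
#include <iostream>
#include <memory>

// Common interface so both template instantiations can sit behind one
// pointer, analogous to how runner_ holds either Runner<uint8_t> or
// Runner<uint16_t> in the JNI layer.
struct IRunner {
  virtual ~IRunner() = default;
  virtual void generate() = 0;
};

// Stand-in for example::KvBitWidth. The underlying values here (8/16) are
// an assumption for illustration; the real enum may encode them differently.
enum class KvBitWidth : int64_t {
  kWidth8 = 8,
  kWidth16 = 16,
};

// Stub for example::Runner<T>, templated on the kv-cache I/O element type.
template <typename KvType>
struct Runner : IRunner {
  void generate() override {
    std::cout << "kv io element size: " << sizeof(KvType) * 8 << " bits\n";
  }
};

// Mirrors the diff's dispatch: pick the runner instantiation from the bit
// width reported by the model's (optional) get_kv_io_bit_width method.
// has_meta stands in for the method_names()->count(...) check.
std::unique_ptr<IRunner> make_runner(bool has_meta, int64_t reported_width) {
  // Default to 8-bit: models exported before the metadata method existed
  // only have 8-bit kv io, so a missing method implies kWidth8.
  KvBitWidth kv_bitwidth = KvBitWidth::kWidth8;
  if (has_meta) {
    kv_bitwidth = static_cast<KvBitWidth>(reported_width);
  }
  switch (kv_bitwidth) {
    case KvBitWidth::kWidth8:
      return std::make_unique<Runner<uint8_t>>();
    case KvBitWidth::kWidth16:
      return std::make_unique<Runner<uint16_t>>();
    default:
      return nullptr; // the JNI layer aborts here via ET_CHECK_MSG
  }
}

int main() {
  make_runner(false, 0)->generate();  // no metadata: falls back to 8-bit
  make_runner(true, 16)->generate();  // metadata reports 16-bit kv io
}

Resolving the width once at load time and baking it into the Runner's template parameter keeps the per-token decode path free of element-type branching; the cost is the extra instantiation, which is why both widths must sit behind a shared interface in runner_.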