From 407676e952fcefd46786012100a16e22097f47d2 Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Mon, 6 Apr 2026 17:29:40 -0700 Subject: [PATCH] Fix QNN runner KV cache bitwidth detection in Android JNI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The QNN runner in the Android JNI layer was hardcoded to use Runner<uint16_t>, but models can be exported with either 8-bit or 16-bit KV caches. This mismatch caused the KV cache data to be misinterpreted, resulting in gibberish output in the Android demo app while the same model worked correctly via the CLI runner. This change mirrors the dynamic KV bitwidth detection already present in qnn_llama_runner.cpp by querying the model's get_kv_io_bit_width method and instantiating the correct Runner<uint8_t> or Runner<uint16_t> accordingly. Also passes temperature_ to the Runner constructor which was previously omitted. Fixes #18571 Closes #17622 Test Plan: - Built Android AAR with QNN support (SDK 2.37) — jni_layer_llama.cpp compiles cleanly with both Runner<uint8_t> and Runner<uint16_t> template instantiations - Unit tests pass (gradlew testDebugUnitTest) --- extension/android/jni/jni_layer_llama.cpp | 39 +++++++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index ed0cacf3dbc..ed144acb14b 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -203,13 +203,38 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> { data_files_vector, cpp_load_mode); std::string decoder_model = "llama3"; // use llama3 for now - runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - ""); + // Using 8bit as default since this meta is introduced with 16bit kv io + // support and older models only have 8bit kv io. 
+ example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; + if (module->method_names()->count("get_kv_io_bit_width") > 0) { + kv_bitwidth = static_cast<example::KvBitWidth>( + module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>()); + } + + if (kv_bitwidth == example::KvBitWidth::kWidth8) { + runner_ = std::make_unique<example::Runner<uint8_t>>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { + runner_ = std::make_unique<example::Runner<uint16_t>>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else { + ET_CHECK_MSG( + false, + "Unsupported kv bitwidth: %ld", + static_cast<long>(kv_bitwidth)); + } model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK)