Commit 6b9eebb

Author: Abhinay Kukkadapu (committed)
Fix QNN runner KV cache bitwidth detection in Android JNI
Summary:
The QNN runner in the Android JNI layer was hardcoded to use Runner<uint16_t>, but models can be exported with either an 8-bit or a 16-bit KV cache. The mismatch caused the KV cache data to be misinterpreted, producing gibberish output in the Android demo app while the same model worked correctly via the CLI runner. This change mirrors the dynamic KV bitwidth detection already present in qnn_llama_runner.cpp: it queries the model's get_kv_io_bit_width method and instantiates Runner<uint8_t> or Runner<uint16_t> accordingly. It also passes temperature_ to the Runner constructor, which was previously omitted.

Fixes #18571
Closes #17622

Test Plan:
- Built the Android AAR with QNN support (SDK 2.37); jni_layer_llama.cpp compiles cleanly with both the Runner<uint8_t> and Runner<uint16_t> template instantiations
- Unit tests pass (gradlew testDebugUnitTest)
1 parent e0e10cc · commit 6b9eebb

1 file changed: extension/android/jni/jni_layer_llama.cpp (35 additions, 7 deletions)
@@ -203,13 +203,41 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           data_files_vector,
           cpp_load_mode);
       std::string decoder_model = "llama3"; // use llama3 for now
-      runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
-          std::move(module),
-          decoder_model.c_str(),
-          model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          "",
-          "");
+      // Use 8-bit as the default, since this metadata was introduced with
+      // 16-bit KV IO support and older models only have 8-bit KV IO.
+      example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+      if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+        kv_bitwidth = static_cast<example::KvBitWidth>(
+            module->get("get_kv_io_bit_width")
+                .get()
+                .toScalar()
+                .to<int64_t>());
+      }
+
+      if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+        runner_ = std::make_unique<example::Runner<uint8_t>>(
+            std::move(module),
+            decoder_model.c_str(),
+            model_path->toStdString().c_str(),
+            tokenizer_path->toStdString().c_str(),
+            "",
+            "",
+            temperature_);
+      } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+        runner_ = std::make_unique<example::Runner<uint16_t>>(
+            std::move(module),
+            decoder_model.c_str(),
+            model_path->toStdString().c_str(),
+            tokenizer_path->toStdString().c_str(),
+            "",
+            "",
+            temperature_);
+      } else {
+        ET_CHECK_MSG(
+            false,
+            "Unsupported kv bitwidth: %ld",
+            static_cast<int64_t>(kv_bitwidth));
+      }
       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
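
For reference, the selection logic above reduces to a small runtime-dispatch pattern: read an optional bit-width value out of model metadata, fall back to a default when it is absent, and pick the matching template instantiation. Below is a minimal, self-contained C++ sketch of that pattern. IRunner, FakeModule, and make_runner are hypothetical stand-ins, not the ExecuTorch API; only the control flow mirrors the committed code.

// Minimal sketch of runtime KV-bitwidth dispatch. All names here
// (IRunner, FakeModule, make_runner) are illustrative stand-ins,
// not the ExecuTorch API; only the control flow mirrors the commit.
#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>

enum class KvBitWidth : int64_t { kWidth8 = 8, kWidth16 = 16 };

// Stand-in for a loaded model: older exports do not expose the
// "get_kv_io_bit_width" metadata method, modeled here as an empty optional.
struct FakeModule {
  std::optional<int64_t> kv_io_bit_width;
};

// Common interface so the two template instantiations share one pointer type.
struct IRunner {
  virtual ~IRunner() = default;
  virtual void generate() = 0;
};

template <typename KvType>
struct Runner : IRunner {
  void generate() override {
    std::cout << "running with a " << sizeof(KvType) * 8
              << "-bit KV cache\n";
  }
};

std::unique_ptr<IRunner> make_runner(const FakeModule& module) {
  // Default to 8-bit: the metadata method arrived together with 16-bit
  // support, so models exported before it only have 8-bit KV IO.
  KvBitWidth width = KvBitWidth::kWidth8;
  if (module.kv_io_bit_width.has_value()) {
    width = static_cast<KvBitWidth>(*module.kv_io_bit_width);
  }
  switch (width) {
    case KvBitWidth::kWidth8:
      return std::make_unique<Runner<uint8_t>>();
    case KvBitWidth::kWidth16:
      return std::make_unique<Runner<uint16_t>>();
    default:
      throw std::runtime_error("unsupported kv bitwidth");
  }
}

int main() {
  make_runner(FakeModule{})->generate();    // no metadata -> 8-bit path
  make_runner(FakeModule{16})->generate();  // 16-bit metadata -> 16-bit path
}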
