Skip to content

Commit d07b043

Browse files
committed
Fix issues
1 parent 6277aef commit d07b043

2 files changed

Lines changed: 91 additions & 49 deletions

File tree

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package org.pytorch.executorch.extension.asr
1010

1111
import java.io.Closeable
1212
import java.io.File
13+
import java.util.concurrent.atomic.AtomicLong
1314
import org.pytorch.executorch.annotations.Experimental
1415

1516
/**
@@ -28,7 +29,7 @@ import org.pytorch.executorch.annotations.Experimental
2829
@Experimental
2930
class AsrModule(modelPath: String, tokenizerPath: String, dataPath: String? = null) : Closeable {
3031

31-
private var nativeHandle: Long
32+
private val nativeHandle = AtomicLong(0L)
3233

3334
init {
3435
val modelFile = File(modelPath)
@@ -38,10 +39,11 @@ class AsrModule(modelPath: String, tokenizerPath: String, dataPath: String? = nu
3839
"Cannot load tokenizer path $tokenizerPath"
3940
}
4041

41-
nativeHandle = nativeCreate(modelPath, tokenizerPath, dataPath)
42-
if (nativeHandle == 0L) {
42+
val handle = nativeCreate(modelPath, tokenizerPath, dataPath)
43+
if (handle == 0L) {
4344
throw RuntimeException("Failed to create native AsrModule")
4445
}
46+
nativeHandle.set(handle)
4547
}
4648

4749
companion object {
@@ -78,17 +80,20 @@ class AsrModule(modelPath: String, tokenizerPath: String, dataPath: String? = nu
7880

7981
/** Check if the native handle is valid. */
8082
val isValid: Boolean
81-
get() = nativeHandle != 0L
83+
get() = nativeHandle.get() != 0L
8284

8385
/** Check if the module is loaded and ready for inference. */
8486
val isLoaded: Boolean
85-
get() = nativeHandle != 0L && nativeIsLoaded(nativeHandle)
87+
get() {
88+
val handle = nativeHandle.get()
89+
return handle != 0L && nativeIsLoaded(handle)
90+
}
8691

8792
/** Releases native resources. Call this when done with the module. */
8893
fun destroy() {
89-
if (nativeHandle != 0L) {
90-
nativeDestroy(nativeHandle)
91-
nativeHandle = 0L
94+
val handle = nativeHandle.getAndSet(0L)
95+
if (handle != 0L) {
96+
nativeDestroy(handle)
9297
}
9398
}
9499

@@ -97,20 +102,16 @@ class AsrModule(modelPath: String, tokenizerPath: String, dataPath: String? = nu
97102
destroy()
98103
}
99104

100-
@Throws(Throwable::class)
101-
protected fun finalize() {
102-
destroy()
103-
}
104-
105105
/**
106106
* Force loading the module. Otherwise the model is loaded during first transcribe() call.
107107
*
108108
* @return 0 on success, error code otherwise
109109
* @throws IllegalStateException if the module has been destroyed
110110
*/
111111
fun load(): Int {
112-
checkNotDestroyed()
113-
return nativeLoad(nativeHandle)
112+
val handle = nativeHandle.get()
113+
check(handle != 0L) { "AsrModule has been destroyed" }
114+
return nativeLoad(handle)
114115
}
115116

116117
/**
@@ -152,9 +153,10 @@ class AsrModule(modelPath: String, tokenizerPath: String, dataPath: String? = nu
152153
config: AsrTranscribeConfig,
153154
callback: AsrCallback? = null,
154155
): Int {
155-
checkNotDestroyed()
156+
val handle = nativeHandle.get()
157+
check(handle != 0L) { "AsrModule has been destroyed" }
156158
return nativeTranscribe(
157-
nativeHandle,
159+
handle,
158160
features,
159161
batchSize,
160162
timeSteps,
@@ -212,8 +214,4 @@ class AsrModule(modelPath: String, tokenizerPath: String, dataPath: String? = nu
212214

213215
return result.toString()
214216
}
215-
216-
private fun checkNotDestroyed() {
217-
check(nativeHandle != 0L) { "AsrModule has been destroyed" }
218-
}
219217
}

extension/android/jni/jni_layer_asr.cpp

Lines changed: 72 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <cstdint>
1212
#include <memory>
13+
#include <mutex>
1314
#include <string>
1415
#include <unordered_set>
1516
#include <vector>
@@ -68,32 +69,42 @@ bool utf8_check_validity(const char* str, size_t length) {
6869
// Thread-local token buffer for UTF-8 accumulation
6970
thread_local std::string asr_token_buffer;
7071

71-
// Cached JNI references for callback
72+
// Global cached JNI references for callback (shared across threads)
7273
struct AsrCallbackCache {
7374
jclass callbackClass = nullptr;
7475
jmethodID onTokenMethod = nullptr;
7576
jmethodID onCompleteMethod = nullptr;
76-
bool initialized = false;
77+
};
7778

78-
void init(JNIEnv* env) {
79-
if (initialized) {
80-
return;
81-
}
79+
AsrCallbackCache callbackCache;
80+
std::once_flag callbackCacheInitFlag;
81+
82+
void initCallbackCache(JNIEnv* env) {
83+
std::call_once(callbackCacheInitFlag, [env]() {
8284
jclass localClass =
8385
env->FindClass("org/pytorch/executorch/extension/asr/AsrCallback");
8486
if (localClass != nullptr) {
85-
callbackClass = (jclass)env->NewGlobalRef(localClass);
86-
onTokenMethod =
87-
env->GetMethodID(callbackClass, "onToken", "(Ljava/lang/String;)V");
88-
onCompleteMethod = env->GetMethodID(
89-
callbackClass, "onComplete", "(Ljava/lang/String;)V");
87+
callbackCache.callbackClass = (jclass)env->NewGlobalRef(localClass);
88+
callbackCache.onTokenMethod = env->GetMethodID(
89+
callbackCache.callbackClass, "onToken", "(Ljava/lang/String;)V");
90+
callbackCache.onCompleteMethod = env->GetMethodID(
91+
callbackCache.callbackClass, "onComplete", "(Ljava/lang/String;)V");
9092
env->DeleteLocalRef(localClass);
91-
initialized = true;
9293
}
93-
}
94-
};
94+
});
95+
}
9596

96-
thread_local AsrCallbackCache callbackCache;
97+
// Helper to create a unique_ptr for JNI global references
98+
auto make_scoped_global_ref(JNIEnv* env, jobject obj) {
99+
auto deleter = [env](jobject ref) {
100+
if (ref != nullptr) {
101+
env->DeleteGlobalRef(ref);
102+
}
103+
};
104+
jobject globalRef = obj ? env->NewGlobalRef(obj) : nullptr;
105+
return std::unique_ptr<std::remove_pointer_t<jobject>, decltype(deleter)>(
106+
globalRef, deleter);
107+
}
97108

98109
} // namespace
99110

@@ -222,6 +233,17 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
222233
return -1;
223234
}
224235

236+
// Validate dimension parameters are positive
237+
if (batchSize <= 0 || timeSteps <= 0 || featureDim <= 0) {
238+
env->ThrowNew(
239+
env->FindClass("java/lang/IllegalArgumentException"),
240+
("Dimensions must be positive: batchSize=" + std::to_string(batchSize) +
241+
", timeSteps=" + std::to_string(timeSteps) +
242+
", featureDim=" + std::to_string(featureDim))
243+
.c_str());
244+
return -1;
245+
}
246+
225247
auto* runner = reinterpret_cast<asr::AsrRunner*>(nativeHandle);
226248

227249
// Get features from Java array
@@ -233,11 +255,27 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
233255
return -1;
234256
}
235257

236-
// Copy feature data
258+
// Validate that dimensions match the array length
259+
jsize expectedLen =
260+
static_cast<jsize>(batchSize) * timeSteps * featureDim;
261+
if (featuresLen != expectedLen) {
262+
env->ThrowNew(
263+
env->FindClass("java/lang/IllegalArgumentException"),
264+
("Features array length (" + std::to_string(featuresLen) +
265+
") does not match dimensions: " + std::to_string(batchSize) + " x " +
266+
std::to_string(timeSteps) + " x " + std::to_string(featureDim) +
267+
" = " + std::to_string(expectedLen))
268+
.c_str());
269+
return -1;
270+
}
271+
272+
// Copy feature data - this vector must remain in scope for the duration of
273+
// transcribe() since from_blob creates a non-owning view over the data.
237274
std::vector<float> featuresData(featuresLen);
238275
env->GetFloatArrayRegion(features, 0, featuresLen, featuresData.data());
239276

240-
// Create tensor from features
277+
// Create tensor from features. Note: from_blob does NOT copy the data,
278+
// it creates a view. featuresData must outlive the tensor and transcribe().
241279
auto featuresTensor = ::executorch::extension::from_blob(
242280
featuresData.data(),
243281
{static_cast<::executorch::aten::SizesType>(batchSize),
@@ -254,17 +292,16 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
254292
// Set up callback
255293
std::function<void(const std::string&)> tokenCallback = nullptr;
256294

257-
// We need to keep a global ref to the callback for the duration of
258-
// transcription
259-
jobject globalCallback = nullptr;
260-
if (callback != nullptr) {
261-
globalCallback = env->NewGlobalRef(callback);
262-
callbackCache.init(env);
295+
// Use unique_ptr with custom deleter to ensure global ref is released
296+
auto scopedCallback = make_scoped_global_ref(env, callback);
297+
if (scopedCallback) {
298+
initCallbackCache(env);
263299

264300
// Reset token buffer
265301
asr_token_buffer.clear();
266302

267-
tokenCallback = [env, globalCallback](const std::string& token) {
303+
jobject callbackRef = scopedCallback.get();
304+
tokenCallback = [env, callbackRef](const std::string& token) {
268305
asr_token_buffer += token;
269306
if (!utf8_check_validity(
270307
asr_token_buffer.c_str(), asr_token_buffer.size())) {
@@ -277,7 +314,11 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
277314
asr_token_buffer.clear();
278315

279316
jstring jToken = env->NewStringUTF(completeToken.c_str());
280-
env->CallVoidMethod(globalCallback, callbackCache.onTokenMethod, jToken);
317+
env->CallVoidMethod(callbackRef, callbackCache.onTokenMethod, jToken);
318+
if (env->ExceptionCheck()) {
319+
ET_LOG(Error, "Exception occurred in AsrCallback.onToken");
320+
env->ExceptionClear();
321+
}
281322
env->DeleteLocalRef(jToken);
282323
};
283324
}
@@ -286,12 +327,15 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
286327
auto result = runner->transcribe(featuresTensor, config, tokenCallback);
287328

288329
// Call onComplete if callback provided
289-
if (globalCallback != nullptr) {
330+
if (scopedCallback) {
290331
jstring emptyStr = env->NewStringUTF("");
291332
env->CallVoidMethod(
292-
globalCallback, callbackCache.onCompleteMethod, emptyStr);
333+
scopedCallback.get(), callbackCache.onCompleteMethod, emptyStr);
334+
if (env->ExceptionCheck()) {
335+
ET_LOG(Error, "Exception occurred in AsrCallback.onComplete");
336+
env->ExceptionClear();
337+
}
293338
env->DeleteLocalRef(emptyStr);
294-
env->DeleteGlobalRef(globalCallback);
295339
}
296340

297341
if (!result.ok()) {

0 commit comments

Comments
 (0)