1010
1111#include < cstdint>
1212#include < memory>
13+ #include < mutex>
1314#include < string>
1415#include < unordered_set>
1516#include < vector>
@@ -68,32 +69,42 @@ bool utf8_check_validity(const char* str, size_t length) {
6869// Thread-local token buffer for UTF-8 accumulation
6970thread_local std::string asr_token_buffer;
7071
71- // Cached JNI references for callback
72+ // Global cached JNI references for callback (shared across threads)
7273struct AsrCallbackCache {
7374 jclass callbackClass = nullptr ;
7475 jmethodID onTokenMethod = nullptr ;
7576 jmethodID onCompleteMethod = nullptr ;
76- bool initialized = false ;
77+ } ;
7778
78- void init (JNIEnv* env) {
79- if (initialized) {
80- return ;
81- }
79+ AsrCallbackCache callbackCache;
80+ std::once_flag callbackCacheInitFlag;
81+
82+ void initCallbackCache (JNIEnv* env) {
83+ std::call_once (callbackCacheInitFlag, [env]() {
8284 jclass localClass =
8385 env->FindClass (" org/pytorch/executorch/extension/asr/AsrCallback" );
8486 if (localClass != nullptr ) {
85- callbackClass = (jclass)env->NewGlobalRef (localClass);
86- onTokenMethod =
87- env-> GetMethodID ( callbackClass, " onToken" , " (Ljava/lang/String;)V" );
88- onCompleteMethod = env->GetMethodID (
89- callbackClass, " onComplete" , " (Ljava/lang/String;)V" );
87+ callbackCache. callbackClass = (jclass)env->NewGlobalRef (localClass);
88+ callbackCache. onTokenMethod = env-> GetMethodID (
89+ callbackCache. callbackClass , " onToken" , " (Ljava/lang/String;)V" );
90+ callbackCache. onCompleteMethod = env->GetMethodID (
91+ callbackCache. callbackClass , " onComplete" , " (Ljava/lang/String;)V" );
9092 env->DeleteLocalRef (localClass);
91- initialized = true ;
9293 }
93- }
94- };
94+ });
95+ }
9596
96- thread_local AsrCallbackCache callbackCache;
97+ // Helper to create a unique_ptr for JNI global references
98+ auto make_scoped_global_ref (JNIEnv* env, jobject obj) {
99+ auto deleter = [env](jobject ref) {
100+ if (ref != nullptr ) {
101+ env->DeleteGlobalRef (ref);
102+ }
103+ };
104+ jobject globalRef = obj ? env->NewGlobalRef (obj) : nullptr ;
105+ return std::unique_ptr<std::remove_pointer_t <jobject>, decltype (deleter)>(
106+ globalRef, deleter);
107+ }
97108
98109} // namespace
99110
@@ -222,6 +233,17 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
222233 return -1 ;
223234 }
224235
236+ // Validate dimension parameters are positive
237+ if (batchSize <= 0 || timeSteps <= 0 || featureDim <= 0 ) {
238+ env->ThrowNew (
239+ env->FindClass (" java/lang/IllegalArgumentException" ),
240+ (" Dimensions must be positive: batchSize=" + std::to_string (batchSize) +
241+ " , timeSteps=" + std::to_string (timeSteps) +
242+ " , featureDim=" + std::to_string (featureDim))
243+ .c_str ());
244+ return -1 ;
245+ }
246+
225247 auto * runner = reinterpret_cast <asr::AsrRunner*>(nativeHandle);
226248
227249 // Get features from Java array
@@ -233,11 +255,27 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
233255 return -1 ;
234256 }
235257
236- // Copy feature data
258+ // Validate that dimensions match the array length
259+ jsize expectedLen =
260+ static_cast <jsize>(batchSize) * timeSteps * featureDim;
261+ if (featuresLen != expectedLen) {
262+ env->ThrowNew (
263+ env->FindClass (" java/lang/IllegalArgumentException" ),
264+ (" Features array length (" + std::to_string (featuresLen) +
265+ " ) does not match dimensions: " + std::to_string (batchSize) + " x " +
266+ std::to_string (timeSteps) + " x " + std::to_string (featureDim) +
267+ " = " + std::to_string (expectedLen))
268+ .c_str ());
269+ return -1 ;
270+ }
271+
272+ // Copy feature data - this vector must remain in scope for the duration of
273+ // transcribe() since from_blob creates a non-owning view over the data.
237274 std::vector<float > featuresData (featuresLen);
238275 env->GetFloatArrayRegion (features, 0 , featuresLen, featuresData.data ());
239276
240- // Create tensor from features
277+ // Create tensor from features. Note: from_blob does NOT copy the data,
278+ // it creates a view. featuresData must outlive the tensor and transcribe().
241279 auto featuresTensor = ::executorch::extension::from_blob (
242280 featuresData.data (),
243281 {static_cast <::executorch::aten::SizesType>(batchSize),
@@ -254,17 +292,16 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
254292 // Set up callback
255293 std::function<void (const std::string&)> tokenCallback = nullptr ;
256294
257- // We need to keep a global ref to the callback for the duration of
258- // transcription
259- jobject globalCallback = nullptr ;
260- if (callback != nullptr ) {
261- globalCallback = env->NewGlobalRef (callback);
262- callbackCache.init (env);
295+ // Use unique_ptr with custom deleter to ensure global ref is released
296+ auto scopedCallback = make_scoped_global_ref (env, callback);
297+ if (scopedCallback) {
298+ initCallbackCache (env);
263299
264300 // Reset token buffer
265301 asr_token_buffer.clear ();
266302
267- tokenCallback = [env, globalCallback](const std::string& token) {
303+ jobject callbackRef = scopedCallback.get ();
304+ tokenCallback = [env, callbackRef](const std::string& token) {
268305 asr_token_buffer += token;
269306 if (!utf8_check_validity (
270307 asr_token_buffer.c_str (), asr_token_buffer.size ())) {
@@ -277,7 +314,11 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
277314 asr_token_buffer.clear ();
278315
279316 jstring jToken = env->NewStringUTF (completeToken.c_str ());
280- env->CallVoidMethod (globalCallback, callbackCache.onTokenMethod , jToken);
317+ env->CallVoidMethod (callbackRef, callbackCache.onTokenMethod , jToken);
318+ if (env->ExceptionCheck ()) {
319+ ET_LOG (Error, " Exception occurred in AsrCallback.onToken" );
320+ env->ExceptionClear ();
321+ }
281322 env->DeleteLocalRef (jToken);
282323 };
283324 }
@@ -286,12 +327,15 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
286327 auto result = runner->transcribe (featuresTensor, config, tokenCallback);
287328
288329 // Call onComplete if callback provided
289- if (globalCallback != nullptr ) {
330+ if (scopedCallback ) {
290331 jstring emptyStr = env->NewStringUTF (" " );
291332 env->CallVoidMethod (
292- globalCallback, callbackCache.onCompleteMethod , emptyStr);
333+ scopedCallback.get (), callbackCache.onCompleteMethod , emptyStr);
334+ if (env->ExceptionCheck ()) {
335+ ET_LOG (Error, " Exception occurred in AsrCallback.onComplete" );
336+ env->ExceptionClear ();
337+ }
293338 env->DeleteLocalRef (emptyStr);
294- env->DeleteGlobalRef (globalCallback);
295339 }
296340
297341 if (!result.ok ()) {
0 commit comments