Skip to content

Commit 4558d37

Browse files
psiddhclaude
andcommitted
Android: Fix thread safety, use-after-close, and test resource leaks
- LlmModule: make mDestroyed volatile for cross-thread stop() visibility - LlmModule: add use-after-close guard to stop() via stopNative() rename - LlmModule: add checkNotReentrant() to prevent callback re-entrancy - jni_layer.cpp: add std::move(etdump_gen) to fix profiling build - jni_layer_llama.cpp: update JNI registration for stopNative - LlmModelRunner: wrap generate() in try-catch to prevent HandlerThread death - LlmModuleInstrumentationTest: add @after tearDown to prevent native leaks - ModuleInstrumentationTest: remove dead testMethodMetadata, fix testNonPteFile cleanup - TrainingModule: make mDestroyed volatile for thread safety Co-authored-by: Claude <noreply@anthropic.com>
1 parent 920259a commit 4558d37

7 files changed

Lines changed: 51 additions & 17 deletions

File tree

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import java.nio.ByteOrder
1616
import org.apache.commons.io.FileUtils
1717
import org.json.JSONException
1818
import org.json.JSONObject
19+
import org.junit.After
1920
import org.junit.Assert.assertEquals
2021
import org.junit.Assert.assertThrows
2122
import org.junit.Assert.assertTrue
@@ -51,6 +52,13 @@ class LlmModuleInstrumentationTest : LlmCallback {
5152
LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f)
5253
}
5354

55+
@After
56+
fun tearDown() {
57+
if (::llmModule.isInitialized) {
58+
llmModule.close()
59+
}
60+
}
61+
5462
@Test
5563
@Throws(IOException::class, URISyntaxException::class)
5664
fun testGenerate() {

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,6 @@ class ModuleInstrumentationTest {
5151
}
5252
}
5353

54-
@Test
55-
@Throws(IOException::class, URISyntaxException::class)
56-
fun testMethodMetadata() {
57-
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
58-
module.destroy()
59-
}
60-
6154
@Test
6255
@Throws(IOException::class)
6356
fun testModuleLoadMethodAndForward() {
@@ -108,12 +101,17 @@ class ModuleInstrumentationTest {
108101
}
109102
}
110103

111-
@Test(expected = RuntimeException::class)
104+
@Test
112105
@Throws(IOException::class)
113106
fun testNonPteFile() {
114-
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
115-
116-
module.loadMethod(FORWARD_METHOD)
107+
Assert.assertThrows(RuntimeException::class.java) {
108+
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
109+
try {
110+
module.loadMethod(FORWARD_METHOD)
111+
} finally {
112+
module.destroy()
113+
}
114+
}
117115
}
118116

119117
@Test

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class LlmModule implements Closeable {
3333

3434
private final HybridData mHybridData;
3535
private final ReentrantLock mLock = new ReentrantLock();
36-
private boolean mDestroyed = false;
36+
private volatile boolean mDestroyed = false;
3737
private static final int DEFAULT_SEQ_LEN = 128;
3838
private static final boolean DEFAULT_ECHO = true;
3939
private static final float DEFAULT_TEMPERATURE = -1.0f;
@@ -194,6 +194,12 @@ private void checkNotDestroyed() {
194194
if (mDestroyed) throw new IllegalStateException("LlmModule has been destroyed");
195195
}
196196

197+
private void checkNotReentrant() {
198+
if (mLock.getHoldCount() > 1) {
199+
throw new IllegalStateException("Cannot call LlmModule methods from within a callback");
200+
}
201+
}
202+
197203
/**
198204
* Releases native resources. Callers must ensure no other methods are in-flight. Call {@link
199205
* #stop()} and wait for {@link #generate(String, LlmCallback)} to return before calling this
@@ -321,6 +327,7 @@ public void generate(
321327
int numEos) {
322328
mLock.lock();
323329
try {
330+
checkNotReentrant();
324331
checkNotDestroyed();
325332
int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos);
326333
if (err != 0) {
@@ -458,6 +465,7 @@ public void generate(
458465
int numEos) {
459466
mLock.lock();
460467
try {
468+
checkNotReentrant();
461469
checkNotDestroyed();
462470
if (image != null) {
463471
int nativeResult = prefillImagesInput(image, width, height, channels);
@@ -487,6 +495,7 @@ public void generate(
487495
public void prefillImages(int[] image, int width, int height, int channels) {
488496
mLock.lock();
489497
try {
498+
checkNotReentrant();
490499
checkNotDestroyed();
491500
int nativeResult = prefillImagesInput(image, width, height, channels);
492501
if (nativeResult != 0) {
@@ -516,6 +525,7 @@ public void prefillImages(int[] image, int width, int height, int channels) {
516525
public void prefillImages(ByteBuffer image, int width, int height, int channels) {
517526
mLock.lock();
518527
try {
528+
checkNotReentrant();
519529
checkNotDestroyed();
520530
if (!image.isDirect()) {
521531
throw new IllegalArgumentException("Input ByteBuffer must be direct.");
@@ -573,6 +583,7 @@ public void prefillImages(ByteBuffer image, int width, int height, int channels)
573583
public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) {
574584
mLock.lock();
575585
try {
586+
checkNotReentrant();
576587
checkNotDestroyed();
577588
if (!image.isDirect()) {
578589
throw new IllegalArgumentException("Input ByteBuffer must be direct.");
@@ -646,6 +657,7 @@ private native int prefillNormalizedImagesInputBuffer(
646657
public void prefillImages(float[] image, int width, int height, int channels) {
647658
mLock.lock();
648659
try {
660+
checkNotReentrant();
649661
checkNotDestroyed();
650662
int nativeResult = prefillNormalizedImagesInput(image, width, height, channels);
651663
if (nativeResult != 0) {
@@ -672,6 +684,7 @@ private native int prefillNormalizedImagesInput(
672684
public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
673685
mLock.lock();
674686
try {
687+
checkNotReentrant();
675688
checkNotDestroyed();
676689
int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames);
677690
if (nativeResult != 0) {
@@ -697,6 +710,7 @@ public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames)
697710
public void prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
698711
mLock.lock();
699712
try {
713+
checkNotReentrant();
700714
checkNotDestroyed();
701715
int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames);
702716
if (nativeResult != 0) {
@@ -723,6 +737,7 @@ private native int prefillAudioInputFloat(
723737
public void prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
724738
mLock.lock();
725739
try {
740+
checkNotReentrant();
726741
checkNotDestroyed();
727742
int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples);
728743
if (nativeResult != 0) {
@@ -746,6 +761,7 @@ private native int prefillRawAudioInput(
746761
public void prefillPrompt(String prompt) {
747762
mLock.lock();
748763
try {
764+
checkNotReentrant();
749765
checkNotDestroyed();
750766
int nativeResult = prefillTextInput(prompt);
751767
if (nativeResult != 0) {
@@ -767,6 +783,7 @@ public void prefillPrompt(String prompt) {
767783
public void resetContext() {
768784
mLock.lock();
769785
try {
786+
checkNotReentrant();
770787
checkNotDestroyed();
771788
resetContextNative();
772789
} finally {
@@ -778,14 +795,20 @@ public void resetContext() {
778795
private native void resetContextNative();
779796

780797
/** Stop current generate() before it finishes. */
798+
public void stop() {
799+
if (mDestroyed) return;
800+
stopNative();
801+
}
802+
781803
@DoNotStrip
782-
public native void stop();
804+
private native void stopNative();
783805

784806
/** Force loading the module. Otherwise the model is loaded during first generate(). */
785807
@DoNotStrip
786808
public void load() {
787809
mLock.lock();
788810
try {
811+
checkNotReentrant();
789812
checkNotDestroyed();
790813
int err = loadNative();
791814
if (err != 0) {

extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public class TrainingModule implements Closeable {
3636
}
3737

3838
private final HybridData mHybridData;
39-
private boolean mDestroyed = false;
39+
private volatile boolean mDestroyed = false;
4040

4141
@DoNotStrip
4242
private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);

extension/android/jni/jni_layer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
286286
#endif
287287
try {
288288
module_ = std::make_unique<Module>(
289-
modelPath->toStdString(), load_mode, etdump_gen);
289+
modelPath->toStdString(), load_mode, std::move(etdump_gen));
290290
} catch (const std::exception& e) {
291291
executorch::jni_helper::throwExecutorchException(
292292
static_cast<uint32_t>(Error::Internal),

extension/android/jni/jni_layer_llama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
630630
registerHybrid({
631631
makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid),
632632
makeNativeMethod("generateNative", ExecuTorchLlmJni::generate),
633-
makeNativeMethod("stop", ExecuTorchLlmJni::stop),
633+
makeNativeMethod("stopNative", ExecuTorchLlmJni::stop),
634634
makeNativeMethod("loadNative", ExecuTorchLlmJni::load),
635635
makeNativeMethod(
636636
"prefillImagesInput", ExecuTorchLlmJni::prefill_images_input),

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import android.os.HandlerThread;
1313
import android.os.Looper;
1414
import android.os.Message;
15+
import android.util.Log;
1516
import org.pytorch.executorch.extension.llm.LlmCallback;
1617
import org.pytorch.executorch.extension.llm.LlmModule;
1718

@@ -98,7 +99,11 @@ public void handleMessage(android.os.Message msg) {
9899
}
99100
mLlmModelRunner.mCallback.onModelLoaded(status);
100101
} else if (msg.what == MESSAGE_GENERATE) {
101-
mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner);
102+
try {
103+
mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner);
104+
} catch (Exception e) {
105+
Log.e("LlmModelRunner", "generate() failed", e);
106+
}
102107
mLlmModelRunner.mCallback.onGenerationStopped();
103108
}
104109
}

0 commit comments

Comments
 (0)