Skip to content

Commit e72eecd

Browse files
committed
Android: Closeable lifecycle, error reporting, thread safety, and test coverage (pytorch#19229)
Summary: Combines all previously reverted Android improvement PRs (pytorch#18669, pytorch#19012, pytorch#19028, pytorch#19092, pytorch#19099, pytorch#19124) plus new Module lifecycle tests into a single atomic change. Error reporting: all sync errors throw exceptions instead of returning status codes or logging silently. Module.loadMethod() throws ExecutorchRuntimeException on failure. LlmModule.generate() and load() throw on error. Native methods renamed with "Native" suffix; public wrappers check status and throw. Lifecycle: Module, LlmModule, and TrainingModule implement Closeable for try-with-resources. LlmModule adds ReentrantLock to serialize access to non-thread-safe native state. stop() remains lock-free (C++ atomic flag) with volatile mDestroyed guard to prevent use-after-close. TrainingModule replaces Log.e silent failures with IllegalStateException. Thread safety: LlmModule adds checkNotReentrant() guard to all lock-acquiring public methods, preventing callbacks from re-entering module methods and corrupting native state. TrainingModule.mDestroyed made volatile for cross-thread visibility. Error consistency: ExecuTorchRuntime.validateFilePath throws IllegalArgumentException. SGD throws IllegalStateException. AsrModule throws ExecutorchRuntimeException. Cause-chaining constructor added to ExecutorchRuntimeException. JNI safety: Module and LlmModule constructors wrapped in try-catch to surface native initialization failures. LlmModule.load() uses throwExecutorchException with diagnostic detail. All DoNotStrip annotations preserved on JNI-called methods. std::move added for etdump_gen unique_ptr in profiling build path. LlmModelRunner generate path wrapped in try-catch to prevent HandlerThread death. Test Plan: fix 4 Ignored Module tests by providing required input tensor. Add 13 new lifecycle/API coverage tests. Add LlmModule use-after-close and idempotent close tests. Add @after tearDown to LlmModuleInstrumentationTest to prevent native resource leaks. Remove dead testMethodMetadata test. Fix testNonPteFile cleanup with assertThrows + finally. This diff updates callers of Module.loadMethod() to use try/catch instead of checking the return code, calling destroy() on failure. Co-authored-by: Claude <noreply@anthropic.com> Reviewed By: kirklandsign Differential Revision: D103233465 Pulled By: psiddh
1 parent fe98297 commit e72eecd

13 files changed

Lines changed: 876 additions & 386 deletions

File tree

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import java.nio.ByteOrder
1616
import org.apache.commons.io.FileUtils
1717
import org.json.JSONException
1818
import org.json.JSONObject
19+
import org.junit.After
1920
import org.junit.Assert.assertEquals
2021
import org.junit.Assert.assertThrows
2122
import org.junit.Assert.assertTrue
@@ -51,12 +52,17 @@ class LlmModuleInstrumentationTest : LlmCallback {
5152
LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f)
5253
}
5354

55+
@After
56+
fun tearDown() {
57+
if (::llmModule.isInitialized) {
58+
llmModule.close()
59+
}
60+
}
61+
5462
@Test
5563
@Throws(IOException::class, URISyntaxException::class)
5664
fun testGenerate() {
57-
val loadResult = llmModule.load()
58-
// Check that the model can be load successfully
59-
assertEquals(OK.toLong(), loadResult.toLong())
65+
llmModule.load()
6066

6167
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
6268
assertEquals(results.size.toLong(), SEQ_LEN.toLong())
@@ -273,11 +279,26 @@ class LlmModuleInstrumentationTest : LlmCallback {
273279
}
274280
}
275281

282+
// --- Lifecycle tests ---
283+
284+
@Test
285+
fun testUseAfterCloseThrows() {
286+
llmModule.close()
287+
assertThrows(IllegalStateException::class.java) {
288+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
289+
}
290+
}
291+
292+
@Test
293+
fun testCloseIsIdempotent() {
294+
llmModule.close()
295+
llmModule.close()
296+
}
297+
276298
companion object {
277299
private const val TEST_FILE_NAME = "/stories.pte"
278300
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
279301
private const val TEST_PROMPT = "Hello"
280-
private const val OK = 0x00
281302
private const val SEQ_LEN = 32
282303
}
283304
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 180 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger
1717
import org.apache.commons.io.FileUtils
1818
import org.junit.Assert
1919
import org.junit.Before
20-
import org.junit.Ignore
2120
import org.junit.Test
2221
import org.junit.runner.RunWith
2322
import org.pytorch.executorch.TestFileUtils.getTestFilePath
@@ -40,49 +39,42 @@ class ModuleInstrumentationTest {
4039
inputStream.close()
4140
}
4241

43-
@Ignore(
44-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
45-
)
4642
@Test
4743
@Throws(IOException::class, URISyntaxException::class)
4844
fun testModuleLoadAndForward() {
4945
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
50-
51-
val results = module.forward()
52-
Assert.assertTrue(results[0].isTensor)
53-
}
54-
55-
@Test
56-
@Throws(IOException::class, URISyntaxException::class)
57-
fun testMethodMetadata() {
58-
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
46+
try {
47+
val results = module.forward(EValue.from(dummyInput()))
48+
Assert.assertTrue(results[0].isTensor)
49+
} finally {
50+
module.destroy()
51+
}
5952
}
6053

61-
@Ignore(
62-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
63-
)
6454
@Test
6555
@Throws(IOException::class)
6656
fun testModuleLoadMethodAndForward() {
6757
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
58+
try {
59+
module.loadMethod(FORWARD_METHOD)
6860

69-
val loadMethod = module.loadMethod(FORWARD_METHOD)
70-
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
71-
72-
val results = module.forward()
73-
Assert.assertTrue(results[0].isTensor)
61+
val results = module.forward(EValue.from(dummyInput()))
62+
Assert.assertTrue(results[0].isTensor)
63+
} finally {
64+
module.destroy()
65+
}
7466
}
7567

76-
@Ignore(
77-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
78-
)
7968
@Test
8069
@Throws(IOException::class)
8170
fun testModuleLoadForwardExplicit() {
8271
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
83-
84-
val results = module.execute(FORWARD_METHOD)
85-
Assert.assertTrue(results[0].isTensor)
72+
try {
73+
val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput()))
74+
Assert.assertTrue(results[0].isTensor)
75+
} finally {
76+
module.destroy()
77+
}
8678
}
8779

8880
@Test(expected = RuntimeException::class)
@@ -95,18 +87,31 @@ class ModuleInstrumentationTest {
9587
@Throws(IOException::class)
9688
fun testModuleLoadMethodNonExistantMethod() {
9789
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
98-
99-
val loadMethod = module.loadMethod(NONE_METHOD)
100-
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
90+
try {
91+
val exception =
92+
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
93+
module.loadMethod(NONE_METHOD)
94+
}
95+
Assert.assertEquals(
96+
ExecutorchRuntimeException.INVALID_ARGUMENT,
97+
exception.getErrorCode(),
98+
)
99+
} finally {
100+
module.destroy()
101+
}
101102
}
102103

103-
@Test(expected = RuntimeException::class)
104+
@Test
104105
@Throws(IOException::class)
105106
fun testNonPteFile() {
106-
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
107-
108-
val loadMethod = module.loadMethod(FORWARD_METHOD)
109-
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
107+
Assert.assertThrows(RuntimeException::class.java) {
108+
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
109+
try {
110+
module.loadMethod(FORWARD_METHOD)
111+
} finally {
112+
module.destroy()
113+
}
114+
}
110115
}
111116

112117
@Test
@@ -116,27 +121,21 @@ class ModuleInstrumentationTest {
116121

117122
module.destroy()
118123

119-
val loadMethod = module.loadMethod(FORWARD_METHOD)
120-
Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong())
124+
Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) }
121125
}
122126

123127
@Test
124128
@Throws(IOException::class)
125129
fun testForwardOnDestroyedModule() {
126130
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
127131

128-
val loadMethod = module.loadMethod(FORWARD_METHOD)
129-
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
132+
module.loadMethod(FORWARD_METHOD)
130133

131134
module.destroy()
132135

133-
val results = module.forward()
134-
Assert.assertEquals(0, results.size.toLong())
136+
Assert.assertThrows(IllegalStateException::class.java) { module.forward() }
135137
}
136138

137-
@Ignore(
138-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
139-
)
140139
@Test
141140
@Throws(InterruptedException::class, IOException::class)
142141
fun testForwardFromMultipleThreads() {
@@ -150,7 +149,7 @@ class ModuleInstrumentationTest {
150149
try {
151150
latch.countDown()
152151
latch.await(5000, TimeUnit.MILLISECONDS)
153-
val results = module.forward()
152+
val results = module.forward(EValue.from(dummyInput()))
154153
Assert.assertTrue(results[0].isTensor)
155154
completed.incrementAndGet()
156155
} catch (_: InterruptedException) {}
@@ -167,6 +166,139 @@ class ModuleInstrumentationTest {
167166
}
168167

169168
Assert.assertEquals(numThreads.toLong(), completed.get().toLong())
169+
module.destroy()
170+
}
171+
172+
// --- Load mode tests ---
173+
174+
@Test
175+
@Throws(IOException::class)
176+
fun testLoadWithMmapMode() {
177+
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_MMAP)
178+
try {
179+
val results = module.forward(EValue.from(dummyInput()))
180+
Assert.assertTrue(results[0].isTensor)
181+
} finally {
182+
module.destroy()
183+
}
184+
}
185+
186+
@Test
187+
@Throws(IOException::class)
188+
fun testLoadWithFileMode() {
189+
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_FILE)
190+
try {
191+
val results = module.forward(EValue.from(dummyInput()))
192+
Assert.assertTrue(results[0].isTensor)
193+
} finally {
194+
module.destroy()
195+
}
196+
}
197+
198+
// --- getMethods / getMethodMetadata tests ---
199+
200+
@Test
201+
@Throws(IOException::class)
202+
fun testGetMethods() {
203+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
204+
try {
205+
val methods = module.getMethods()
206+
Assert.assertNotNull(methods)
207+
Assert.assertTrue(methods.contains(FORWARD_METHOD))
208+
} finally {
209+
module.destroy()
210+
}
211+
}
212+
213+
@Test
214+
@Throws(IOException::class)
215+
fun testGetMethodMetadata() {
216+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
217+
try {
218+
val metadata = module.getMethodMetadata(FORWARD_METHOD)
219+
Assert.assertNotNull(metadata)
220+
Assert.assertEquals(FORWARD_METHOD, metadata.name)
221+
Assert.assertNotNull(metadata.backends)
222+
} finally {
223+
module.destroy()
224+
}
225+
}
226+
227+
// --- Log buffer tests ---
228+
229+
@Test
230+
@Throws(IOException::class)
231+
fun testReadLogBuffer() {
232+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
233+
try {
234+
val logs = module.readLogBuffer()
235+
Assert.assertNotNull(logs)
236+
} finally {
237+
module.destroy()
238+
}
239+
}
240+
241+
@Test
242+
fun testReadLogBufferStatic() {
243+
val logs = Module.readLogBufferStatic()
244+
Assert.assertNotNull(logs)
245+
}
246+
247+
// --- etdump test ---
248+
249+
@Test
250+
@Throws(IOException::class)
251+
fun testEtdump() {
252+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
253+
try {
254+
module.etdump()
255+
} finally {
256+
module.destroy()
257+
}
258+
}
259+
260+
// --- Destroyed-state tests for remaining methods ---
261+
262+
@Test
263+
@Throws(IOException::class)
264+
fun testGetMethodsOnDestroyedModule() {
265+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
266+
module.destroy()
267+
Assert.assertThrows(IllegalStateException::class.java) { module.getMethods() }
268+
}
269+
270+
@Test
271+
@Throws(IOException::class)
272+
fun testGetMethodMetadataOnDestroyedModule() {
273+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
274+
module.destroy()
275+
Assert.assertThrows(IllegalStateException::class.java) {
276+
module.getMethodMetadata(FORWARD_METHOD)
277+
}
278+
}
279+
280+
@Test
281+
@Throws(IOException::class)
282+
fun testReadLogBufferOnDestroyedModule() {
283+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
284+
module.destroy()
285+
Assert.assertThrows(IllegalStateException::class.java) { module.readLogBuffer() }
286+
}
287+
288+
@Test
289+
@Throws(IOException::class)
290+
fun testEtdumpOnDestroyedModule() {
291+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
292+
module.destroy()
293+
Assert.assertThrows(IllegalStateException::class.java) { module.etdump() }
294+
}
295+
296+
@Test
297+
@Throws(IOException::class)
298+
fun testDoubleDestroyIsSafe() {
299+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
300+
module.destroy()
301+
module.destroy()
170302
}
171303

172304
companion object {
@@ -175,9 +307,8 @@ class ModuleInstrumentationTest {
175307
private const val NON_PTE_FILE_NAME = "/test.txt"
176308
private const val FORWARD_METHOD = "forward"
177309
private const val NONE_METHOD = "none"
178-
private const val OK = 0x00
179-
private const val INVALID_STATE = 0x2
180-
private const val INVALID_ARGUMENT = 0x12
181-
private const val ACCESS_FAILED = 0x22
310+
private val inputShape = longArrayOf(1, 3, 224, 224)
311+
312+
private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT)
182313
}
183314
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,25 @@ public static ExecuTorchRuntime getRuntime() {
3636
/**
3737
* Validates that the given path points to a readable file.
3838
*
39-
* @throws RuntimeException if the file does not exist or is not readable.
39+
* @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not
40+
* readable.
4041
*/
4142
public static void validateFilePath(String path, String description) {
43+
if (path == null) {
44+
throw new IllegalArgumentException("Cannot load " + description + ": path is null");
45+
}
4246
File file = new File(path);
43-
if (!file.canRead() || !file.isFile()) {
44-
throw new RuntimeException("Cannot load " + description + " " + path);
47+
if (!file.exists()) {
48+
throw new IllegalArgumentException(
49+
"Cannot load " + description + ": path does not exist: " + path);
50+
}
51+
if (!file.isFile()) {
52+
throw new IllegalArgumentException(
53+
"Cannot load " + description + ": path is not a file: " + path);
54+
}
55+
if (!file.canRead()) {
56+
throw new IllegalArgumentException(
57+
"Cannot load " + description + ": path is not readable: " + path);
4558
}
4659
}
4760

0 commit comments

Comments
 (0)