Skip to content

Commit 04727f0

Browse files
committed
Android: Closeable lifecycle, error reporting, thread safety, and test coverage
Combines all previously reverted Android improvement PRs (#18669, #19012, #19028, #19092, #19099, #19124) plus new Module lifecycle tests into a single atomic change. Error reporting: all sync errors throw exceptions instead of returning status codes or logging silently. Module.loadMethod() throws ExecutorchRuntimeException on failure. LlmModule.generate() and load() throw on error. Native methods renamed with "Native" suffix; public wrappers check status and throw. Lifecycle: Module, LlmModule, and TrainingModule implement Closeable for try-with-resources. LlmModule adds ReentrantLock to serialize access to non-thread-safe native state. stop() remains lock-free (C++ atomic flag). TrainingModule replaces Log.e silent failures with IllegalStateException. Error consistency: ExecuTorchRuntime.validateFilePath throws IllegalArgumentException. SGD throws IllegalStateException. AsrModule throws ExecutorchRuntimeException. Cause-chaining constructor added to ExecutorchRuntimeException. JNI safety: Module and LlmModule constructors wrapped in try-catch to surface native initialization failures. LlmModule.load() uses throwExecutorchException with diagnostic detail. All @DoNotStrip annotations preserved on JNI-called methods; new @DoNotStrip added to renamed private native methods (generateNative, resetContextNative, loadNative). Tests: fix 4 @ignored Module tests by providing required input tensor. Add 13 new lifecycle/API coverage tests. Add LlmModule use-after-close and idempotent-close tests. Fix ModelRunner crash on loadMethod failure. This PR was authored with the help of Claude.
1 parent e84a418 commit 04727f0

13 files changed

Lines changed: 837 additions & 380 deletions

File tree

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ class LlmModuleInstrumentationTest : LlmCallback {
5454
@Test
5555
@Throws(IOException::class, URISyntaxException::class)
5656
fun testGenerate() {
57-
val loadResult = llmModule.load()
58-
// Check that the model can be load successfully
59-
assertEquals(OK.toLong(), loadResult.toLong())
57+
llmModule.load()
6058

6159
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
6260
assertEquals(results.size.toLong(), SEQ_LEN.toLong())
@@ -273,11 +271,26 @@ class LlmModuleInstrumentationTest : LlmCallback {
273271
}
274272
}
275273

274+
// --- Lifecycle tests ---
275+
276+
@Test
277+
fun testUseAfterCloseThrows() {
278+
llmModule.close()
279+
assertThrows(IllegalStateException::class.java) {
280+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
281+
}
282+
}
283+
284+
@Test
285+
fun testCloseIsIdempotent() {
286+
llmModule.close()
287+
llmModule.close()
288+
}
289+
276290
companion object {
277291
private const val TEST_FILE_NAME = "/stories.pte"
278292
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
279293
private const val TEST_PROMPT = "Hello"
280-
private const val OK = 0x00
281294
private const val SEQ_LEN = 32
282295
}
283296
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 172 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger
1717
import org.apache.commons.io.FileUtils
1818
import org.junit.Assert
1919
import org.junit.Before
20-
import org.junit.Ignore
2120
import org.junit.Test
2221
import org.junit.runner.RunWith
2322
import org.pytorch.executorch.TestFileUtils.getTestFilePath
@@ -40,49 +39,42 @@ class ModuleInstrumentationTest {
4039
inputStream.close()
4140
}
4241

43-
@Ignore(
44-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
45-
)
4642
@Test
4743
@Throws(IOException::class, URISyntaxException::class)
4844
fun testModuleLoadAndForward() {
4945
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
50-
51-
val results = module.forward()
52-
Assert.assertTrue(results[0].isTensor)
53-
}
54-
55-
@Test
56-
@Throws(IOException::class, URISyntaxException::class)
57-
fun testMethodMetadata() {
58-
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
46+
try {
47+
val results = module.forward(EValue.from(dummyInput()))
48+
Assert.assertTrue(results[0].isTensor)
49+
} finally {
50+
module.destroy()
51+
}
5952
}
6053

61-
@Ignore(
62-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
63-
)
6454
@Test
6555
@Throws(IOException::class)
6656
fun testModuleLoadMethodAndForward() {
6757
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
58+
try {
59+
module.loadMethod(FORWARD_METHOD)
6860

69-
val loadMethod = module.loadMethod(FORWARD_METHOD)
70-
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
71-
72-
val results = module.forward()
73-
Assert.assertTrue(results[0].isTensor)
61+
val results = module.forward(EValue.from(dummyInput()))
62+
Assert.assertTrue(results[0].isTensor)
63+
} finally {
64+
module.destroy()
65+
}
7466
}
7567

76-
@Ignore(
77-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
78-
)
7968
@Test
8069
@Throws(IOException::class)
8170
fun testModuleLoadForwardExplicit() {
8271
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
83-
84-
val results = module.execute(FORWARD_METHOD)
85-
Assert.assertTrue(results[0].isTensor)
72+
try {
73+
val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput()))
74+
Assert.assertTrue(results[0].isTensor)
75+
} finally {
76+
module.destroy()
77+
}
8678
}
8779

8880
@Test(expected = RuntimeException::class)
@@ -95,18 +87,26 @@ class ModuleInstrumentationTest {
9587
@Throws(IOException::class)
9688
fun testModuleLoadMethodNonExistantMethod() {
9789
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
98-
99-
val loadMethod = module.loadMethod(NONE_METHOD)
100-
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
90+
try {
91+
val exception =
92+
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
93+
module.loadMethod(NONE_METHOD)
94+
}
95+
Assert.assertEquals(
96+
ExecutorchRuntimeException.INVALID_ARGUMENT,
97+
exception.getErrorCode(),
98+
)
99+
} finally {
100+
module.destroy()
101+
}
101102
}
102103

103104
@Test(expected = RuntimeException::class)
104105
@Throws(IOException::class)
105106
fun testNonPteFile() {
106107
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
107108

108-
val loadMethod = module.loadMethod(FORWARD_METHOD)
109-
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
109+
module.loadMethod(FORWARD_METHOD)
110110
}
111111

112112
@Test
@@ -116,27 +116,21 @@ class ModuleInstrumentationTest {
116116

117117
module.destroy()
118118

119-
val loadMethod = module.loadMethod(FORWARD_METHOD)
120-
Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong())
119+
Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) }
121120
}
122121

123122
@Test
124123
@Throws(IOException::class)
125124
fun testForwardOnDestroyedModule() {
126125
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
127126

128-
val loadMethod = module.loadMethod(FORWARD_METHOD)
129-
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
127+
module.loadMethod(FORWARD_METHOD)
130128

131129
module.destroy()
132130

133-
val results = module.forward()
134-
Assert.assertEquals(0, results.size.toLong())
131+
Assert.assertThrows(IllegalStateException::class.java) { module.forward() }
135132
}
136133

137-
@Ignore(
138-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
139-
)
140134
@Test
141135
@Throws(InterruptedException::class, IOException::class)
142136
fun testForwardFromMultipleThreads() {
@@ -150,7 +144,7 @@ class ModuleInstrumentationTest {
150144
try {
151145
latch.countDown()
152146
latch.await(5000, TimeUnit.MILLISECONDS)
153-
val results = module.forward()
147+
val results = module.forward(EValue.from(dummyInput()))
154148
Assert.assertTrue(results[0].isTensor)
155149
completed.incrementAndGet()
156150
} catch (_: InterruptedException) {}
@@ -167,6 +161,139 @@ class ModuleInstrumentationTest {
167161
}
168162

169163
Assert.assertEquals(numThreads.toLong(), completed.get().toLong())
164+
module.destroy()
165+
}
166+
167+
// --- Load mode tests ---
168+
169+
@Test
170+
@Throws(IOException::class)
171+
fun testLoadWithMmapMode() {
172+
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_MMAP)
173+
try {
174+
val results = module.forward(EValue.from(dummyInput()))
175+
Assert.assertTrue(results[0].isTensor)
176+
} finally {
177+
module.destroy()
178+
}
179+
}
180+
181+
@Test
182+
@Throws(IOException::class)
183+
fun testLoadWithFileMode() {
184+
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_FILE)
185+
try {
186+
val results = module.forward(EValue.from(dummyInput()))
187+
Assert.assertTrue(results[0].isTensor)
188+
} finally {
189+
module.destroy()
190+
}
191+
}
192+
193+
// --- getMethods / getMethodMetadata tests ---
194+
195+
@Test
196+
@Throws(IOException::class)
197+
fun testGetMethods() {
198+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
199+
try {
200+
val methods = module.getMethods()
201+
Assert.assertNotNull(methods)
202+
Assert.assertTrue(methods.contains(FORWARD_METHOD))
203+
} finally {
204+
module.destroy()
205+
}
206+
}
207+
208+
@Test
209+
@Throws(IOException::class)
210+
fun testGetMethodMetadata() {
211+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
212+
try {
213+
val metadata = module.getMethodMetadata(FORWARD_METHOD)
214+
Assert.assertNotNull(metadata)
215+
Assert.assertEquals(FORWARD_METHOD, metadata.name)
216+
Assert.assertNotNull(metadata.backends)
217+
} finally {
218+
module.destroy()
219+
}
220+
}
221+
222+
// --- Log buffer tests ---
223+
224+
@Test
225+
@Throws(IOException::class)
226+
fun testReadLogBuffer() {
227+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
228+
try {
229+
val logs = module.readLogBuffer()
230+
Assert.assertNotNull(logs)
231+
} finally {
232+
module.destroy()
233+
}
234+
}
235+
236+
@Test
237+
fun testReadLogBufferStatic() {
238+
val logs = Module.readLogBufferStatic()
239+
Assert.assertNotNull(logs)
240+
}
241+
242+
// --- etdump test ---
243+
244+
@Test
245+
@Throws(IOException::class)
246+
fun testEtdump() {
247+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
248+
try {
249+
module.etdump()
250+
} finally {
251+
module.destroy()
252+
}
253+
}
254+
255+
// --- Destroyed-state tests for remaining methods ---
256+
257+
@Test
258+
@Throws(IOException::class)
259+
fun testGetMethodsOnDestroyedModule() {
260+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
261+
module.destroy()
262+
Assert.assertThrows(IllegalStateException::class.java) { module.getMethods() }
263+
}
264+
265+
@Test
266+
@Throws(IOException::class)
267+
fun testGetMethodMetadataOnDestroyedModule() {
268+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
269+
module.destroy()
270+
Assert.assertThrows(IllegalStateException::class.java) {
271+
module.getMethodMetadata(FORWARD_METHOD)
272+
}
273+
}
274+
275+
@Test
276+
@Throws(IOException::class)
277+
fun testReadLogBufferOnDestroyedModule() {
278+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
279+
module.destroy()
280+
Assert.assertThrows(IllegalStateException::class.java) { module.readLogBuffer() }
281+
}
282+
283+
@Test
284+
@Throws(IOException::class)
285+
fun testEtdumpOnDestroyedModule() {
286+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
287+
module.destroy()
288+
Assert.assertThrows(IllegalStateException::class.java) { module.etdump() }
289+
}
290+
291+
@Test
292+
@Throws(IOException::class)
293+
fun testDoubleDestroyIsSafe() {
294+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
295+
module.destroy()
296+
module.destroy()
170297
}
171298

172299
companion object {
@@ -175,9 +302,8 @@ class ModuleInstrumentationTest {
175302
private const val NON_PTE_FILE_NAME = "/test.txt"
176303
private const val FORWARD_METHOD = "forward"
177304
private const val NONE_METHOD = "none"
178-
private const val OK = 0x00
179-
private const val INVALID_STATE = 0x2
180-
private const val INVALID_ARGUMENT = 0x12
181-
private const val ACCESS_FAILED = 0x22
305+
private val inputShape = longArrayOf(1, 3, 224, 224)
306+
307+
private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT)
182308
}
183309
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,25 @@ public static ExecuTorchRuntime getRuntime() {
3636
/**
3737
* Validates that the given path points to a readable file.
3838
*
39-
* @throws RuntimeException if the file does not exist or is not readable.
39+
* @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not
40+
* readable.
4041
*/
4142
public static void validateFilePath(String path, String description) {
43+
if (path == null) {
44+
throw new IllegalArgumentException("Cannot load " + description + ": path is null");
45+
}
4246
File file = new File(path);
43-
if (!file.canRead() || !file.isFile()) {
44-
throw new RuntimeException("Cannot load " + description + " " + path);
47+
if (!file.exists()) {
48+
throw new IllegalArgumentException(
49+
"Cannot load " + description + ": path does not exist: " + path);
50+
}
51+
if (!file.isFile()) {
52+
throw new IllegalArgumentException(
53+
"Cannot load " + description + ": path is not a file: " + path);
54+
}
55+
if (!file.canRead()) {
56+
throw new IllegalArgumentException(
57+
"Cannot load " + description + ": path is not readable: " + path);
4558
}
4659
}
4760

0 commit comments

Comments
 (0)