Skip to content

Commit a96bd6b

Browse files
committed
Android: Closeable lifecycle, error reporting, thread safety, and test coverage
Combines all previously reverted Android improvement PRs (#18669, #19012, #19028, #19092, #19099, #19124) plus new Module lifecycle tests into a single atomic change. Error reporting: all sync errors throw exceptions instead of returning status codes or logging silently. Module.loadMethod() throws ExecutorchRuntimeException on failure. LlmModule.generate() and load() throw on error. Native methods renamed with "Native" suffix; public wrappers check status and throw. Lifecycle: Module, LlmModule, and TrainingModule implement Closeable for try-with-resources. LlmModule adds ReentrantLock to serialize access to non-thread-safe native state. stop() remains lock-free (C++ atomic flag). TrainingModule replaces Log.e silent failures with IllegalStateException. Error consistency: ExecuTorchRuntime.validateFilePath throws IllegalArgumentException. SGD throws IllegalStateException. AsrModule throws ExecutorchRuntimeException. Cause-chaining constructor added to ExecutorchRuntimeException. JNI safety: Module and LlmModule constructors wrapped in try-catch to surface native initialization failures. LlmModule.load() uses throwExecutorchException with diagnostic detail. All @DoNotStrip annotations preserved on JNI-called methods; new @DoNotStrip added to renamed private native methods (generateNative, resetContextNative, loadNative). Tests: fix 4 @ignored Module tests by providing required input tensor. Add 13 new lifecycle/API coverage tests. Add LlmModule use-after-close and idempotent-close tests. Fix ModelRunner crash on loadMethod failure. This PR was authored with the help of Claude.
1 parent 524f037 commit a96bd6b

12 files changed

Lines changed: 816 additions & 371 deletions

File tree

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ class LlmModuleInstrumentationTest : LlmCallback {
5454
@Test
5555
@Throws(IOException::class, URISyntaxException::class)
5656
fun testGenerate() {
57-
val loadResult = llmModule.load()
58-
// Check that the model can be load successfully
59-
assertEquals(OK.toLong(), loadResult.toLong())
57+
llmModule.load()
6058

6159
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
6260
assertEquals(results.size.toLong(), SEQ_LEN.toLong())
@@ -273,11 +271,26 @@ class LlmModuleInstrumentationTest : LlmCallback {
273271
}
274272
}
275273

274+
// --- Lifecycle tests ---
275+
276+
@Test
277+
fun testUseAfterCloseThrows() {
278+
llmModule.close()
279+
assertThrows(IllegalStateException::class.java) {
280+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
281+
}
282+
}
283+
284+
@Test
285+
fun testCloseIsIdempotent() {
286+
llmModule.close()
287+
llmModule.close()
288+
}
289+
276290
companion object {
277291
private const val TEST_FILE_NAME = "/stories.pte"
278292
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
279293
private const val TEST_PROMPT = "Hello"
280-
private const val OK = 0x00
281294
private const val SEQ_LEN = 32
282295
}
283296
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 173 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger
1717
import org.apache.commons.io.FileUtils
1818
import org.junit.Assert
1919
import org.junit.Before
20-
import org.junit.Ignore
2120
import org.junit.Test
2221
import org.junit.runner.RunWith
2322
import org.pytorch.executorch.TestFileUtils.getTestFilePath
@@ -40,49 +39,49 @@ class ModuleInstrumentationTest {
4039
inputStream.close()
4140
}
4241

43-
@Ignore(
44-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
45-
)
4642
@Test
4743
@Throws(IOException::class, URISyntaxException::class)
4844
fun testModuleLoadAndForward() {
4945
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
50-
51-
val results = module.forward()
52-
Assert.assertTrue(results[0].isTensor)
46+
try {
47+
val results = module.forward(EValue.from(dummyInput()))
48+
Assert.assertTrue(results[0].isTensor)
49+
} finally {
50+
module.destroy()
51+
}
5352
}
5453

5554
@Test
5655
@Throws(IOException::class, URISyntaxException::class)
5756
fun testMethodMetadata() {
5857
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
58+
module.destroy()
5959
}
6060

61-
@Ignore(
62-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
63-
)
6461
@Test
6562
@Throws(IOException::class)
6663
fun testModuleLoadMethodAndForward() {
6764
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
65+
try {
66+
module.loadMethod(FORWARD_METHOD)
6867

69-
val loadMethod = module.loadMethod(FORWARD_METHOD)
70-
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
71-
72-
val results = module.forward()
73-
Assert.assertTrue(results[0].isTensor)
68+
val results = module.forward(EValue.from(dummyInput()))
69+
Assert.assertTrue(results[0].isTensor)
70+
} finally {
71+
module.destroy()
72+
}
7473
}
7574

76-
@Ignore(
77-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
78-
)
7975
@Test
8076
@Throws(IOException::class)
8177
fun testModuleLoadForwardExplicit() {
8278
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
83-
84-
val results = module.execute(FORWARD_METHOD)
85-
Assert.assertTrue(results[0].isTensor)
79+
try {
80+
val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput()))
81+
Assert.assertTrue(results[0].isTensor)
82+
} finally {
83+
module.destroy()
84+
}
8685
}
8786

8887
@Test(expected = RuntimeException::class)
@@ -95,18 +94,26 @@ class ModuleInstrumentationTest {
9594
@Throws(IOException::class)
9695
fun testModuleLoadMethodNonExistantMethod() {
9796
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
98-
99-
val loadMethod = module.loadMethod(NONE_METHOD)
100-
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
97+
try {
98+
val exception =
99+
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
100+
module.loadMethod(NONE_METHOD)
101+
}
102+
Assert.assertEquals(
103+
ExecutorchRuntimeException.INVALID_ARGUMENT,
104+
exception.getErrorCode(),
105+
)
106+
} finally {
107+
module.destroy()
108+
}
101109
}
102110

103111
@Test(expected = RuntimeException::class)
104112
@Throws(IOException::class)
105113
fun testNonPteFile() {
106114
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
107115

108-
val loadMethod = module.loadMethod(FORWARD_METHOD)
109-
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
116+
module.loadMethod(FORWARD_METHOD)
110117
}
111118

112119
@Test
@@ -116,27 +123,21 @@ class ModuleInstrumentationTest {
116123

117124
module.destroy()
118125

119-
val loadMethod = module.loadMethod(FORWARD_METHOD)
120-
Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong())
126+
Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) }
121127
}
122128

123129
@Test
124130
@Throws(IOException::class)
125131
fun testForwardOnDestroyedModule() {
126132
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
127133

128-
val loadMethod = module.loadMethod(FORWARD_METHOD)
129-
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
134+
module.loadMethod(FORWARD_METHOD)
130135

131136
module.destroy()
132137

133-
val results = module.forward()
134-
Assert.assertEquals(0, results.size.toLong())
138+
Assert.assertThrows(IllegalStateException::class.java) { module.forward() }
135139
}
136140

137-
@Ignore(
138-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
139-
)
140141
@Test
141142
@Throws(InterruptedException::class, IOException::class)
142143
fun testForwardFromMultipleThreads() {
@@ -150,7 +151,7 @@ class ModuleInstrumentationTest {
150151
try {
151152
latch.countDown()
152153
latch.await(5000, TimeUnit.MILLISECONDS)
153-
val results = module.forward()
154+
val results = module.forward(EValue.from(dummyInput()))
154155
Assert.assertTrue(results[0].isTensor)
155156
completed.incrementAndGet()
156157
} catch (_: InterruptedException) {}
@@ -167,6 +168,139 @@ class ModuleInstrumentationTest {
167168
}
168169

169170
Assert.assertEquals(numThreads.toLong(), completed.get().toLong())
171+
module.destroy()
172+
}
173+
174+
// --- Load mode tests ---
175+
176+
@Test
177+
@Throws(IOException::class)
178+
fun testLoadWithMmapMode() {
179+
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_MMAP)
180+
try {
181+
val results = module.forward(EValue.from(dummyInput()))
182+
Assert.assertTrue(results[0].isTensor)
183+
} finally {
184+
module.destroy()
185+
}
186+
}
187+
188+
@Test
189+
@Throws(IOException::class)
190+
fun testLoadWithFileMode() {
191+
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_FILE)
192+
try {
193+
val results = module.forward(EValue.from(dummyInput()))
194+
Assert.assertTrue(results[0].isTensor)
195+
} finally {
196+
module.destroy()
197+
}
198+
}
199+
200+
// --- getMethods / getMethodMetadata tests ---
201+
202+
@Test
203+
@Throws(IOException::class)
204+
fun testGetMethods() {
205+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
206+
try {
207+
val methods = module.getMethods()
208+
Assert.assertNotNull(methods)
209+
Assert.assertTrue(methods.contains(FORWARD_METHOD))
210+
} finally {
211+
module.destroy()
212+
}
213+
}
214+
215+
@Test
216+
@Throws(IOException::class)
217+
fun testGetMethodMetadata() {
218+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
219+
try {
220+
val metadata = module.getMethodMetadata(FORWARD_METHOD)
221+
Assert.assertNotNull(metadata)
222+
Assert.assertEquals(FORWARD_METHOD, metadata.name)
223+
Assert.assertNotNull(metadata.backends)
224+
} finally {
225+
module.destroy()
226+
}
227+
}
228+
229+
// --- Log buffer tests ---
230+
231+
@Test
232+
@Throws(IOException::class)
233+
fun testReadLogBuffer() {
234+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
235+
try {
236+
val logs = module.readLogBuffer()
237+
Assert.assertNotNull(logs)
238+
} finally {
239+
module.destroy()
240+
}
241+
}
242+
243+
@Test
244+
fun testReadLogBufferStatic() {
245+
val logs = Module.readLogBufferStatic()
246+
Assert.assertNotNull(logs)
247+
}
248+
249+
// --- etdump test ---
250+
251+
@Test
252+
@Throws(IOException::class)
253+
fun testEtdump() {
254+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
255+
try {
256+
module.etdump()
257+
} finally {
258+
module.destroy()
259+
}
260+
}
261+
262+
// --- Destroyed-state tests for remaining methods ---
263+
264+
@Test
265+
@Throws(IOException::class)
266+
fun testGetMethodsOnDestroyedModule() {
267+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
268+
module.destroy()
269+
Assert.assertThrows(IllegalStateException::class.java) { module.getMethods() }
270+
}
271+
272+
@Test
273+
@Throws(IOException::class)
274+
fun testGetMethodMetadataOnDestroyedModule() {
275+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
276+
module.destroy()
277+
Assert.assertThrows(IllegalStateException::class.java) {
278+
module.getMethodMetadata(FORWARD_METHOD)
279+
}
280+
}
281+
282+
@Test
283+
@Throws(IOException::class)
284+
fun testReadLogBufferOnDestroyedModule() {
285+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
286+
module.destroy()
287+
Assert.assertThrows(IllegalStateException::class.java) { module.readLogBuffer() }
288+
}
289+
290+
@Test
291+
@Throws(IOException::class)
292+
fun testEtdumpOnDestroyedModule() {
293+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
294+
module.destroy()
295+
Assert.assertThrows(IllegalStateException::class.java) { module.etdump() }
296+
}
297+
298+
@Test
299+
@Throws(IOException::class)
300+
fun testDoubleDestroyIsSafe() {
301+
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
302+
module.destroy()
303+
module.destroy()
170304
}
171305

172306
companion object {
@@ -175,9 +309,8 @@ class ModuleInstrumentationTest {
175309
private const val NON_PTE_FILE_NAME = "/test.txt"
176310
private const val FORWARD_METHOD = "forward"
177311
private const val NONE_METHOD = "none"
178-
private const val OK = 0x00
179-
private const val INVALID_STATE = 0x2
180-
private const val INVALID_ARGUMENT = 0x12
181-
private const val ACCESS_FAILED = 0x22
312+
private val inputShape = longArrayOf(1, 3, 224, 224)
313+
314+
private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT)
182315
}
183316
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,22 @@ public static ExecuTorchRuntime getRuntime() {
3636
/**
3737
* Validates that the given path points to a readable file.
3838
*
39-
* @throws RuntimeException if the file does not exist or is not readable.
39+
* @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not
40+
* readable.
4041
*/
4142
public static void validateFilePath(String path, String description) {
43+
if (path == null) {
44+
throw new IllegalArgumentException("Cannot load " + description + ": path is null");
45+
}
4246
File file = new File(path);
43-
if (!file.canRead() || !file.isFile()) {
44-
throw new RuntimeException("Cannot load " + description + " " + path);
47+
if (!file.exists()) {
48+
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
49+
}
50+
if (!file.isFile()) {
51+
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
52+
}
53+
if (!file.canRead()) {
54+
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
4555
}
4656
}
4757

0 commit comments

Comments
 (0)