Skip to content

Commit 3ec63f4

Browse files
psiddhCopilot
andauthored
Ignored Module tests: provide required input tensor (#19028)
All 4 tests failed because they called forward() with zero arguments on mobilenet_v2 which expects a [1,3,224,224] float input. This was a test bug, not a runtime bug. Add a dummyInput() helper that creates a Tensor.ones with the correct shape, and remove all @ignore annotations. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 8a77f9b commit 3ec63f4

1 file changed

Lines changed: 37 additions & 33 deletions

File tree

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger
1717
import org.apache.commons.io.FileUtils
1818
import org.junit.Assert
1919
import org.junit.Before
20-
import org.junit.Ignore
2120
import org.junit.Test
2221
import org.junit.runner.RunWith
2322
import org.pytorch.executorch.TestFileUtils.getTestFilePath
@@ -40,48 +39,49 @@ class ModuleInstrumentationTest {
4039
inputStream.close()
4140
}
4241

43-
@Ignore(
44-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
45-
)
4642
@Test
4743
@Throws(IOException::class, URISyntaxException::class)
4844
fun testModuleLoadAndForward() {
4945
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
50-
51-
val results = module.forward()
52-
Assert.assertTrue(results[0].isTensor)
46+
try {
47+
val results = module.forward(EValue.from(dummyInput()))
48+
Assert.assertTrue(results[0].isTensor)
49+
} finally {
50+
module.destroy()
51+
}
5352
}
5453

5554
@Test
5655
@Throws(IOException::class, URISyntaxException::class)
5756
fun testMethodMetadata() {
5857
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
58+
module.destroy()
5959
}
6060

61-
@Ignore(
62-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
63-
)
6461
@Test
6562
@Throws(IOException::class)
6663
fun testModuleLoadMethodAndForward() {
6764
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
65+
try {
66+
module.loadMethod(FORWARD_METHOD)
6867

69-
module.loadMethod(FORWARD_METHOD)
70-
71-
val results = module.forward()
72-
Assert.assertTrue(results[0].isTensor)
68+
val results = module.forward(EValue.from(dummyInput()))
69+
Assert.assertTrue(results[0].isTensor)
70+
} finally {
71+
module.destroy()
72+
}
7373
}
7474

75-
@Ignore(
76-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
77-
)
7875
@Test
7976
@Throws(IOException::class)
8077
fun testModuleLoadForwardExplicit() {
8178
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
82-
83-
val results = module.execute(FORWARD_METHOD)
84-
Assert.assertTrue(results[0].isTensor)
79+
try {
80+
val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput()))
81+
Assert.assertTrue(results[0].isTensor)
82+
} finally {
83+
module.destroy()
84+
}
8585
}
8686

8787
@Test(expected = RuntimeException::class)
@@ -94,15 +94,18 @@ class ModuleInstrumentationTest {
9494
@Throws(IOException::class)
9595
fun testModuleLoadMethodNonExistantMethod() {
9696
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
97-
98-
val exception =
99-
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
100-
module.loadMethod(NONE_METHOD)
101-
}
102-
Assert.assertEquals(
103-
ExecutorchRuntimeException.INVALID_ARGUMENT,
104-
exception.getErrorCode(),
105-
)
97+
try {
98+
val exception =
99+
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
100+
module.loadMethod(NONE_METHOD)
101+
}
102+
Assert.assertEquals(
103+
ExecutorchRuntimeException.INVALID_ARGUMENT,
104+
exception.getErrorCode(),
105+
)
106+
} finally {
107+
module.destroy()
108+
}
106109
}
107110

108111
@Test(expected = RuntimeException::class)
@@ -135,9 +138,6 @@ class ModuleInstrumentationTest {
135138
Assert.assertThrows(IllegalStateException::class.java) { module.forward() }
136139
}
137140

138-
@Ignore(
139-
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
140-
)
141141
@Test
142142
@Throws(InterruptedException::class, IOException::class)
143143
fun testForwardFromMultipleThreads() {
@@ -151,7 +151,7 @@ class ModuleInstrumentationTest {
151151
try {
152152
latch.countDown()
153153
latch.await(5000, TimeUnit.MILLISECONDS)
154-
val results = module.forward()
154+
val results = module.forward(EValue.from(dummyInput()))
155155
Assert.assertTrue(results[0].isTensor)
156156
completed.incrementAndGet()
157157
} catch (_: InterruptedException) {}
@@ -168,6 +168,7 @@ class ModuleInstrumentationTest {
168168
}
169169

170170
Assert.assertEquals(numThreads.toLong(), completed.get().toLong())
171+
module.destroy()
171172
}
172173

173174
companion object {
@@ -176,5 +177,8 @@ class ModuleInstrumentationTest {
176177
private const val NON_PTE_FILE_NAME = "/test.txt"
177178
private const val FORWARD_METHOD = "forward"
178179
private const val NONE_METHOD = "none"
180+
private val inputShape = longArrayOf(1, 3, 224, 224)
181+
182+
private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT)
179183
}
180184
}

0 commit comments

Comments
 (0)