Skip to content

Commit fb1e212

Browse files
fix: Add Android model E2E test (pytorch#19927)
1 parent 1925a86 commit fb1e212

3 files changed

Lines changed: 45 additions & 5 deletions

File tree

extension/android/executorch_android/android_test_setup.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ prepare_golden() {
3939
done
4040
}
4141

42+
prepare_add() {
43+
pushd "${BASEDIR}/../../.."
44+
"${PYTHON_EXECUTABLE}" -m test.models.export_program --modules "ModuleAdd" --outdir "${BASEDIR}/src/androidTest/resources/"
45+
popd
46+
}
47+
4248
prepare_xor
4349
prepare_tinyllama
4450
prepare_golden
51+
prepare_add

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,28 @@ class ModuleE2ETest {
7878
fun testVitB16() {
7979
testGoldenModel("vit_b_16", longArrayOf(1, 3, 224, 224))
8080
}
81+
82+
@Test
83+
fun testAdd() {
84+
val x = Tensor.fromBlob(floatArrayOf(1f, 2f, 3f, 4f), longArrayOf(2, 2))
85+
val y = Tensor.fromBlob(floatArrayOf(5f, 6f, 7f, 8f), longArrayOf(2, 2))
86+
87+
val pteFile = File(getTestFilePath("/ModuleAdd.pte"))
88+
javaClass.getResourceAsStream("/ModuleAdd.pte")!!.use {
89+
FileUtils.copyInputStreamToFile(it, pteFile)
90+
}
91+
92+
val module = Module.load(pteFile.absolutePath)
93+
try {
94+
// ModuleAdd computes torch.add(x, y, alpha=alpha). The alpha scalar is
95+
// passed as a Double because EValue only exposes a Double scalar factory
96+
// (TYPE_CODE_DOUBLE); the float32 output dtype is determined by x and y.
97+
val results = module.forward(EValue.from(x), EValue.from(y), EValue.from(1.0))
98+
assertTrue(results[0].isTensor)
99+
val actualOutput = results[0].toTensor().dataAsFloatArray
100+
assertOutputsClose(actualOutput, floatArrayOf(6f, 8f, 10f, 12f))
101+
} finally {
102+
module.destroy()
103+
}
104+
}
81105
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ModuleInstrumentationTest {
4545
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
4646
try {
4747
val results = module.forward(EValue.from(dummyInput()))
48-
Assert.assertTrue(results[0].isTensor)
48+
assertSingleTensorResultWithShape(results, expectedOutputShape)
4949
} finally {
5050
module.destroy()
5151
}
@@ -59,7 +59,7 @@ class ModuleInstrumentationTest {
5959
module.loadMethod(FORWARD_METHOD)
6060

6161
val results = module.forward(EValue.from(dummyInput()))
62-
Assert.assertTrue(results[0].isTensor)
62+
assertSingleTensorResultWithShape(results, expectedOutputShape)
6363
} finally {
6464
module.destroy()
6565
}
@@ -71,7 +71,7 @@ class ModuleInstrumentationTest {
7171
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
7272
try {
7373
val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput()))
74-
Assert.assertTrue(results[0].isTensor)
74+
assertSingleTensorResultWithShape(results, expectedOutputShape)
7575
} finally {
7676
module.destroy()
7777
}
@@ -177,7 +177,7 @@ class ModuleInstrumentationTest {
177177
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_MMAP)
178178
try {
179179
val results = module.forward(EValue.from(dummyInput()))
180-
Assert.assertTrue(results[0].isTensor)
180+
assertSingleTensorResultWithShape(results, expectedOutputShape)
181181
} finally {
182182
module.destroy()
183183
}
@@ -189,7 +189,7 @@ class ModuleInstrumentationTest {
189189
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_FILE)
190190
try {
191191
val results = module.forward(EValue.from(dummyInput()))
192-
Assert.assertTrue(results[0].isTensor)
192+
assertSingleTensorResultWithShape(results, expectedOutputShape)
193193
} finally {
194194
module.destroy()
195195
}
@@ -308,7 +308,16 @@ class ModuleInstrumentationTest {
308308
private const val FORWARD_METHOD = "forward"
309309
private const val NONE_METHOD = "none"
310310
private val inputShape = longArrayOf(1, 3, 224, 224)
311+
private val expectedOutputShape = longArrayOf(1, 1000)
311312

312313
private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT)
314+
315+
private fun assertSingleTensorResultWithShape(
316+
results: Array<EValue>,
317+
expectedShape: LongArray,
318+
) {
319+
Assert.assertTrue(results[0].isTensor)
320+
Assert.assertArrayEquals(expectedShape, results[0].toTensor().shape())
321+
}
313322
}
314323
}

0 commit comments

Comments
 (0)