Skip to content

Commit bf39024

Browse files
committed
MV3 instrumentation test and build fix
1 parent 32ba8af commit bf39024

4 files changed

Lines changed: 331 additions & 6 deletions

File tree

mv3/android/MV3Demo/README.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,46 @@ This is a sample Android application demonstrating MobileNet v3 (MV3) image clas
4646
- **Inference**: ExecuTorch Android API
4747
- **Image Processing**: `TensorImageUtils` for bitmap-to-tensor conversion
4848
49+
## Testing
50+
51+
The app includes an instrumentation test that validates the complete image classification workflow.
52+
53+
### What the test does
54+
55+
1. Launches the app
56+
2. Downloads the MV3 model if not already present
57+
3. Downloads a cat image from HuggingFace
58+
4. Runs inference on the image
59+
5. Validates that the model correctly classifies it as a cat
60+
61+
### Running the test
62+
63+
1. **Connect a device or start an emulator**
64+
65+
2. **Build and install the test APKs**:
66+
```bash
67+
./gradlew installDebug installDebugAndroidTest
68+
```
69+
70+
3. **Run the test**:
71+
```bash
72+
adb shell am instrument -w -r \
73+
-e class 'org.pytorch.executorchexamples.mv3.UIWorkflowTest#testCatImageClassification' \
74+
org.pytorch.executorchexamples.mv3.test/androidx.test.runner.AndroidJUnitRunner
75+
```
76+
77+
Or run all tests via Gradle:
78+
```bash
79+
./gradlew connectedDebugAndroidTest
80+
```
81+
82+
### Test output
83+
84+
The test logs classification results to logcat with the tag `MV3_RESULT`:
85+
```bash
86+
adb logcat -s MV3_RESULT
87+
```
88+
4989
## License
5090
5191
BSD-3-Clause
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorchexamples.mv3
10+
11+
import android.content.Context
12+
import android.graphics.Bitmap
13+
import android.graphics.BitmapFactory
14+
import android.util.Log
15+
import androidx.compose.ui.test.junit4.createAndroidComposeRule
16+
import androidx.compose.ui.test.onAllNodesWithText
17+
import androidx.compose.ui.test.onNodeWithText
18+
import androidx.compose.ui.test.performClick
19+
import androidx.test.core.app.ApplicationProvider
20+
import androidx.test.ext.junit.runners.AndroidJUnit4
21+
import androidx.test.filters.LargeTest
22+
import org.junit.Assert.assertNotNull
23+
import org.junit.Assert.assertTrue
24+
import org.junit.Before
25+
import org.junit.Rule
26+
import org.junit.Test
27+
import org.junit.runner.RunWith
28+
import org.pytorch.executorch.EValue
29+
import org.pytorch.executorch.Module
30+
import java.io.File
31+
import java.net.HttpURLConnection
32+
import java.net.URL
33+
import java.util.concurrent.CountDownLatch
34+
import java.util.concurrent.TimeUnit
35+
import kotlin.math.exp
36+
37+
/**
38+
* Instrumentation test for MobileNetV3 image classification demo.
39+
*
40+
* This test validates the complete end-to-end workflow:
41+
* 1. App launches successfully
42+
* 2. Model downloads if needed
43+
* 3. Downloads a cat image from HuggingFace
44+
* 4. Runs inference and validates the image is classified as a cat
45+
*/
46+
@RunWith(AndroidJUnit4::class)
47+
@LargeTest
48+
class UIWorkflowTest {
49+
50+
companion object {
51+
private const val TAG = "MV3UIWorkflowTest"
52+
private const val RESULT_TAG = "MV3_RESULT"
53+
54+
// Cat test image from HuggingFace
55+
private const val CAT_IMAGE_URL =
56+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/cat.jpg"
57+
58+
// Model filename (same as MainActivity)
59+
private const val MODEL_FILENAME = "mv3.pte"
60+
61+
// Cat-related ImageNet classes that we expect for a cat image
62+
private val CAT_CLASSES = setOf(
63+
"tabby", "tiger cat", "Persian cat", "Siamese cat", "Egyptian cat"
64+
)
65+
}
66+
67+
@get:Rule
68+
val composeTestRule = createAndroidComposeRule<MainActivity>()
69+
70+
private lateinit var context: Context
71+
72+
@Before
73+
fun setUp() {
74+
context = ApplicationProvider.getApplicationContext()
75+
}
76+
77+
/**
78+
* Downloads an image from URL and returns it as a Bitmap.
79+
*/
80+
private fun downloadImageFromUrl(imageUrl: String): Bitmap? {
81+
var bitmap: Bitmap? = null
82+
val latch = CountDownLatch(1)
83+
84+
Thread {
85+
try {
86+
Log.i(TAG, "Downloading image from: $imageUrl")
87+
val url = URL(imageUrl)
88+
val connection = url.openConnection() as HttpURLConnection
89+
connection.requestMethod = "GET"
90+
connection.connectTimeout = 30000
91+
connection.readTimeout = 30000
92+
connection.instanceFollowRedirects = true
93+
connection.connect()
94+
95+
if (connection.responseCode == HttpURLConnection.HTTP_OK) {
96+
connection.inputStream.use { inputStream ->
97+
bitmap = BitmapFactory.decodeStream(inputStream)
98+
}
99+
Log.i(TAG, "Image downloaded successfully")
100+
} else {
101+
Log.e(TAG, "Failed to download image: HTTP ${connection.responseCode}")
102+
}
103+
connection.disconnect()
104+
} catch (e: Exception) {
105+
Log.e(TAG, "Error downloading image", e)
106+
} finally {
107+
latch.countDown()
108+
}
109+
}.start()
110+
111+
latch.await(60, TimeUnit.SECONDS)
112+
return bitmap
113+
}
114+
115+
/**
116+
* Waits for the model to be ready.
117+
*/
118+
private fun waitForModelReady(timeoutMs: Long = 120000): Boolean {
119+
return try {
120+
composeTestRule.waitUntil(timeoutMillis = timeoutMs) {
121+
composeTestRule
122+
.onAllNodesWithText("Pick an image to start or use Live Camera", substring = true)
123+
.fetchSemanticsNodes()
124+
.isNotEmpty()
125+
}
126+
Log.i(TAG, "Model is ready")
127+
true
128+
} catch (e: Exception) {
129+
Log.i(TAG, "Model not ready after ${timeoutMs}ms: ${e.message}")
130+
false
131+
}
132+
}
133+
134+
/**
135+
* Ensures model is ready, downloading if necessary.
136+
*/
137+
private fun ensureModelReady(): Boolean {
138+
composeTestRule.waitForIdle()
139+
140+
// Check if model is already ready
141+
val readyNodes = composeTestRule
142+
.onAllNodesWithText("Pick an image to start or use Live Camera", substring = true)
143+
.fetchSemanticsNodes()
144+
if (readyNodes.isNotEmpty()) {
145+
Log.i(TAG, "Model is already ready")
146+
return true
147+
}
148+
149+
// Check if we need to download
150+
val downloadNodes = composeTestRule
151+
.onAllNodesWithText("Download Model", substring = true)
152+
.fetchSemanticsNodes()
153+
154+
if (downloadNodes.isNotEmpty()) {
155+
Log.i(TAG, "Downloading model...")
156+
composeTestRule.onNodeWithText("Download Model").performClick()
157+
158+
// Wait for download to complete (up to 5 minutes)
159+
composeTestRule.waitUntil(timeoutMillis = 300000) {
160+
val downloading = composeTestRule
161+
.onAllNodesWithText("Downloading...", substring = true)
162+
.fetchSemanticsNodes()
163+
val ready = composeTestRule
164+
.onAllNodesWithText("Pick an image to start or use Live Camera", substring = true)
165+
.fetchSemanticsNodes()
166+
downloading.isEmpty() && ready.isNotEmpty()
167+
}
168+
Log.i(TAG, "Model download complete")
169+
return true
170+
}
171+
172+
// Wait for UI to settle
173+
return waitForModelReady(10000)
174+
}
175+
176+
/**
177+
* Applies softmax to convert logits to probabilities.
178+
*/
179+
private fun softmax(scores: FloatArray): FloatArray {
180+
val max = scores.maxOrNull() ?: 0f
181+
val expScores = scores.map { exp((it - max).toDouble()) }
182+
val sumExp = expScores.sum()
183+
return expScores.map { (it / sumExp).toFloat() }.toFloatArray()
184+
}
185+
186+
/**
187+
* Gets top-K predictions from scores.
188+
*/
189+
private fun getTopK(scores: FloatArray, k: Int): List<Pair<Int, Float>> {
190+
val probabilities = softmax(scores)
191+
return probabilities.withIndex()
192+
.sortedByDescending { it.value }
193+
.take(k)
194+
.map { it.index to it.value }
195+
}
196+
197+
/**
198+
* Runs inference on the given bitmap using the model.
199+
*/
200+
private fun runInferenceOnBitmap(bitmap: Bitmap, module: Module): List<Pair<String, Float>> {
201+
val scaledBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true)
202+
203+
val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
204+
scaledBitmap,
205+
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
206+
TensorImageUtils.TORCHVISION_NORM_STD_RGB
207+
)
208+
209+
val outputTensor = module.forward(EValue.from(inputTensor))[0].toTensor()
210+
val scores = outputTensor.dataAsFloatArray
211+
val top3 = getTopK(scores, 3)
212+
213+
return top3.map { (index, score) ->
214+
val label = if (index in ImageNetClasses.IMAGENET_CLASSES.indices) {
215+
ImageNetClasses.IMAGENET_CLASSES[index]
216+
} else {
217+
"Unknown($index)"
218+
}
219+
label to score
220+
}
221+
}
222+
223+
/**
224+
* Tests the full end-to-end classification workflow:
225+
* 1. App launches
226+
* 2. Download model if needed
227+
* 3. Download cat image from HuggingFace
228+
* 4. Run inference
229+
* 5. Validate that the result is a cat class
230+
*/
231+
@Test
232+
fun testCatImageClassification() {
233+
composeTestRule.waitForIdle()
234+
235+
// Step 1: Ensure model is ready
236+
val modelReady = ensureModelReady()
237+
assertTrue("Model should be ready or download should start", modelReady)
238+
239+
val finalReady = waitForModelReady(300000)
240+
assertTrue("Model should be ready", finalReady)
241+
242+
// Step 2: Download cat image
243+
val bitmap = downloadImageFromUrl(CAT_IMAGE_URL)
244+
assertNotNull("Cat image should be downloaded", bitmap)
245+
246+
// Step 3: Load the model
247+
val modelPath = context.filesDir.absolutePath + "/" + MODEL_FILENAME
248+
val modelFile = File(modelPath)
249+
assertTrue("Model file should exist at $modelPath", modelFile.exists())
250+
251+
val module = Module.load(modelPath)
252+
assertNotNull("Module should be loaded", module)
253+
254+
// Step 4: Run inference
255+
val results = runInferenceOnBitmap(bitmap!!, module)
256+
assertTrue("Should have classification results", results.isNotEmpty())
257+
258+
Log.i(RESULT_TAG, "Classification results:")
259+
results.forEach { (label, prob) ->
260+
Log.i(RESULT_TAG, " $label: ${String.format("%.4f", prob)}")
261+
}
262+
263+
// Step 5: Validate that top prediction is a cat
264+
val topLabel = results.first().first
265+
val isCat = CAT_CLASSES.any { catClass ->
266+
topLabel.contains(catClass, ignoreCase = true)
267+
}
268+
269+
assertTrue(
270+
"Top prediction should be a cat class, but got: $topLabel. Expected one of: $CAT_CLASSES",
271+
isCat
272+
)
273+
274+
Log.i(TAG, "Cat image correctly classified as: $topLabel")
275+
}
276+
}

mv3/android/MV3Demo/app/src/main/AndroidManifest.xml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
<uses-permission android:name="android.permission.INTERNET" />
88
<uses-permission android:name="android.permission.CAMERA" />
99

10+
<uses-feature android:name="android.hardware.camera" android:required="false" />
11+
<uses-feature android:name="android.hardware.camera.autofocus" android:required="false" />
12+
1013
<application
1114
android:allowBackup="false"
1215
android:label="MV3 Demo"

mv3/android/MV3Demo/app/src/main/java/org/pytorch/executorchexamples/mv3/MainActivity.kt

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ import androidx.compose.ui.Modifier
4646
import androidx.compose.ui.graphics.asImageBitmap
4747
import androidx.compose.ui.platform.LocalContext
4848
import androidx.compose.ui.platform.LocalLifecycleOwner
49+
import androidx.compose.ui.platform.testTag
50+
import androidx.compose.ui.semantics.semantics
51+
import androidx.compose.ui.semantics.testTag
4952
import androidx.compose.ui.unit.dp
5053
import androidx.compose.ui.viewinterop.AndroidView
5154
import androidx.core.content.ContextCompat
@@ -163,7 +166,7 @@ class MainActivity : ComponentActivity() {
163166
}
164167
},
165168
enabled = !isProcessing && modelReady && !isLiveCameraMode && bitmap != null,
166-
modifier = Modifier.weight(1f)
169+
modifier = Modifier.weight(1f).testTag("run_inference_button")
167170
) {
168171
Icon(
169172
imageVector = Icons.Filled.PlayArrow,
@@ -177,7 +180,7 @@ class MainActivity : ComponentActivity() {
177180
imagePickerLauncher.launch("image/*")
178181
},
179182
enabled = !isLiveCameraMode && modelReady,
180-
modifier = Modifier.weight(1f)
183+
modifier = Modifier.weight(1f).testTag("pick_image_button")
181184
) {
182185
Icon(
183186
imageVector = Icons.Filled.PhotoLibrary,
@@ -195,7 +198,7 @@ class MainActivity : ComponentActivity() {
195198
}
196199
},
197200
enabled = modelReady,
198-
modifier = Modifier.weight(1f),
201+
modifier = Modifier.weight(1f).testTag("live_camera_button"),
199202
colors = IconButtonDefaults.iconButtonColors(
200203
containerColor = if (isLiveCameraMode) MaterialTheme.colorScheme.secondaryContainer else androidx.compose.ui.graphics.Color.Transparent
201204
)
@@ -274,12 +277,15 @@ class MainActivity : ComponentActivity() {
274277
LazyColumn(
275278
modifier = Modifier
276279
.fillMaxWidth()
277-
.height(120.dp),
280+
.height(120.dp)
281+
.testTag("results_list"),
278282
verticalArrangement = Arrangement.spacedBy(4.dp)
279283
) {
280284
items(classificationResults) { result ->
281285
Row(
282-
modifier = Modifier.fillMaxWidth(),
286+
modifier = Modifier
287+
.fillMaxWidth()
288+
.testTag("result_item_${result.first}"),
283289
horizontalArrangement = Arrangement.SpaceBetween
284290
) {
285291
Text(text = result.first, style = MaterialTheme.typography.bodyMedium)
@@ -308,7 +314,7 @@ class MainActivity : ComponentActivity() {
308314
}
309315
},
310316
enabled = !isDownloading,
311-
modifier = Modifier.fillMaxWidth()
317+
modifier = Modifier.fillMaxWidth().testTag("download_button")
312318
) {
313319
Text(if (isDownloading) "Downloading..." else "Download Model")
314320
}

0 commit comments

Comments
 (0)