Skip to content

Commit 2207d16

Browse files
committed
feat: Add GitHub Actions workflow for model export and enhance DeepLabV3 Android demo with improved segmentation visualization and inference time display.
1 parent d0d845a commit 2207d16

5 files changed

Lines changed: 182 additions & 23 deletions

File tree

.github/workflows/export-models.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
name: Export Models
88

99
on:
10+
pull_request:
1011
schedule:
1112
# Run nightly at midnight UTC
1213
- cron: '0 0 * * *'
@@ -41,6 +42,9 @@ jobs:
4142
4243
name: Export ${{ matrix.name }}
4344

45+
export-dl3:
46+
runs-on: ubuntu-latest
47+
name: Export DeepLabV3 Model
4448
steps:
4549
- name: Checkout repository
4650
uses: actions/checkout@v4
@@ -65,3 +69,14 @@ jobs:
6569
name: ${{ matrix.artifact }}
6670
path: ${{ matrix.output }}
6771
if-no-files-found: error
72+
run: pip install executorch torchvision
73+
74+
- name: Export DL3 model
75+
working-directory: dl3/python
76+
run: python export.py
77+
78+
- name: Upload PTE model
79+
uses: actions/upload-artifact@v4
80+
with:
81+
name: dl3_xnnpack_fp32.pte
82+
path: dl3/python/dl3_xnnpack_fp32.pte

dl3/android/DeepLabV3Demo/app/src/main/java/org/pytorch/executorchexamples/dl3/MainActivity.java

Lines changed: 150 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
import androidx.core.content.ContextCompat;
3434

3535
import java.io.File;
36+
import java.io.FileOutputStream;
3637
import java.io.IOException;
38+
import java.io.InputStream;
39+
import java.net.HttpURLConnection;
40+
import java.net.URL;
3741
import java.util.ArrayList;
3842
import org.pytorch.executorch.EValue;
3943
import org.pytorch.executorch.Module;
@@ -42,10 +46,18 @@
4246
public class MainActivity extends Activity implements Runnable {
4347
private ImageView mImageView;
4448
private Button mButtonXnnpack;
49+
private Button mDownloadModelButton;
4550
private ProgressBar mProgressBar;
51+
private android.widget.TextView mInferenceTimeText;
52+
private android.widget.TextView mModelStatusText;
4653
private Bitmap mBitmap = null;
4754
private Module mModule = null;
4855
private String mImagename = "corgi.jpeg";
56+
private long mInferenceTime = 0;
57+
58+
// Model download configuration
59+
private static final String MODEL_URL = "https://example.com/dl3_xnnpack_fp32.pte"; // TODO: Replace with actual URL
60+
private static final String MODEL_PATH = "/data/local/tmp/dl3_xnnpack_fp32.pte";
4961

5062
private final ArrayList<String> mImageFiles = new ArrayList<>();
5163

@@ -57,9 +69,31 @@ public class MainActivity extends Activity implements Runnable {
5769
// see http://host.robots.ox.ac.uk:8080/pascal/VOC/voc2007/segexamples/index.html for the list of
5870
// classes with indexes
5971
private static final int CLASSNUM = 21;
60-
private static final int DOG = 12;
61-
private static final int PERSON = 15;
62-
private static final int SHEEP = 17;
72+
73+
// Colors for all 21 PASCAL VOC classes
74+
private static final int[] CLASS_COLORS = {
75+
0x00000000, // 0: Background (transparent)
76+
0xFFE6194B, // 1: Aeroplane (red)
77+
0xFF3CB44B, // 2: Bicycle (green)
78+
0xFFFFE119, // 3: Bird (yellow)
79+
0xFF4363D8, // 4: Boat (blue)
80+
0xFFF58231, // 5: Bottle (orange)
81+
0xFF911EB4, // 6: Bus (purple)
82+
0xFF46F0F0, // 7: Car (cyan)
83+
0xFFF032E6, // 8: Cat (magenta)
84+
0xFFBCF60C, // 9: Chair (lime)
85+
0xFFFABEBE, // 10: Cow (pink)
86+
0xFF008080, // 11: Dining Table (teal)
87+
0xFF00FF00, // 12: Dog (bright green)
88+
0xFF9A6324, // 13: Horse (brown)
89+
0xFFFFD8B1, // 14: Motorbike (peach)
90+
0xFFFF0000, // 15: Person (red)
91+
0xFF800000, // 16: Potted Plant (maroon)
92+
0xFF0000FF, // 17: Sheep (blue)
93+
0xFF808000, // 18: Sofa (olive)
94+
0xFFE6BEFF, // 19: Train (lavender)
95+
0xFFAA6E28, // 20: TV/Monitor (tan)
96+
};
6397

6498
private void checkAndRequestStoragePermission() {
6599
if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE)
@@ -197,14 +231,21 @@ protected void onCreate(Bundle savedInstanceState) {
197231
// Initialize all views first!
198232
mImageView = findViewById(R.id.imageView);
199233
mButtonXnnpack = findViewById(R.id.xnnpackButton);
234+
mDownloadModelButton = findViewById(R.id.downloadModelButton);
200235
mProgressBar = findViewById(R.id.progressBar);
236+
mInferenceTimeText = findViewById(R.id.inferenceTimeText);
237+
mModelStatusText = findViewById(R.id.modelStatusText);
201238

202239
populateImagePathFromAssets();
203240
showImage();
204241

205-
mModule = Module.load("/data/local/tmp/dl3_xnnpack_fp32.pte");
242+
// Check if model exists and load it, otherwise show download button
243+
loadModelOrShowDownloadButton();
206244
mImageView.setImageBitmap(mBitmap);
207245

246+
// Download button click handler
247+
mDownloadModelButton.setOnClickListener(v -> downloadModel());
248+
208249
final Button buttonNext = findViewById(R.id.nextButton);
209250
buttonNext.setOnClickListener(
210251
new View.OnClickListener() {
@@ -223,10 +264,9 @@ public void onClick(View v) {
223264
mButtonXnnpack.setOnClickListener(
224265
new View.OnClickListener() {
225266
public void onClick(View v) {
226-
mModule.destroy();
227-
mModule = Module.load("/data/local/tmp/dl3_xnnpack_fp32.pte");
228267
mButtonXnnpack.setEnabled(false);
229268
mProgressBar.setVisibility(ProgressBar.VISIBLE);
269+
mInferenceTimeText.setVisibility(View.INVISIBLE);
230270
mButtonXnnpack.setText(getString(R.string.run_model));
231271

232272
Thread thread = new Thread(MainActivity.this);
@@ -262,32 +302,38 @@ public void run() {
262302
boolean imageSegementationSuccess = false;
263303
final long startTime = SystemClock.elapsedRealtime();
264304
Tensor outputTensor = mModule.forward(EValue.from(inputTensor))[0].toTensor();
265-
final long inferenceTime = SystemClock.elapsedRealtime() - startTime;
266-
Log.d("ImageSegmentation", "inference time (ms): " + inferenceTime);
305+
mInferenceTime = SystemClock.elapsedRealtime() - startTime;
306+
Log.d("ImageSegmentation", "inference time (ms): " + mInferenceTime);
267307

268308
final float[] scores = outputTensor.getDataAsFloatArray();
269309
int width = mBitmap.getWidth();
270310
int height = mBitmap.getHeight();
271311

312+
// Get original pixels for blending
313+
int[] originalPixels = new int[width * height];
314+
mBitmap.getPixels(originalPixels, 0, width, 0, 0, width, height);
315+
272316
int[] intValues = new int[width * height];
273317
for (int j = 0; j < height; j++) {
274318
for (int k = 0; k < width; k++) {
275-
int maxi = 0, maxj = 0, maxk = 0;
319+
int maxi = 0;
276320
double maxnum = -Double.MAX_VALUE;
277321
for (int i = 0; i < CLASSNUM; i++) {
278322
float score = scores[i * (width * height) + j * width + k];
279323
if (score > maxnum) {
280324
maxnum = score;
281325
maxi = i;
282-
maxj = j;
283-
maxk = k;
284326
}
285327
}
286-
if (maxi == PERSON) intValues[maxj * width + maxk] = 0xFFFF0000; // R
287-
else if (maxi == DOG) intValues[maxj * width + maxk] = 0xFF00FF00; // G
288-
else if (maxi == SHEEP) intValues[maxj * width + maxk] = 0xFF0000FF; // B
289-
else intValues[maxj * width + maxk] = 0xFF000000;
290-
if (maxi == PERSON || maxi == DOG || maxi == SHEEP) {
328+
int pixelIndex = j * width + k;
329+
int classColor = CLASS_COLORS[maxi];
330+
331+
if (maxi == 0) {
332+
// Background: show original image
333+
intValues[pixelIndex] = originalPixels[pixelIndex];
334+
} else {
335+
// Blend segmentation color with original at 50% opacity
336+
intValues[pixelIndex] = blendColors(originalPixels[pixelIndex], classColor, 0.5f);
291337
imageSegementationSuccess = true;
292338
}
293339
}
@@ -310,12 +356,14 @@ public void run() {
310356
runOnUiThread(
311357
() -> {
312358
if (showUserIndicationOnImgSegFail) {
313-
Toast.makeText(this, "ImageSegmentation Failed", Toast.LENGTH_SHORT).show();
359+
Toast.makeText(this, "No objects detected", Toast.LENGTH_SHORT).show();
314360
}
315361
mImageView.setImageBitmap(transferredBitmap);
316362
mButtonXnnpack.setEnabled(true);
317363
mButtonXnnpack.setText(R.string.run_xnnpack);
318364
mProgressBar.setVisibility(ProgressBar.INVISIBLE);
365+
mInferenceTimeText.setText("Inference: " + mInferenceTime + " ms");
366+
mInferenceTimeText.setVisibility(View.VISIBLE);
319367
});
320368
}
321369

@@ -326,4 +374,89 @@ public void run() {
326374
}
327375
});
328376
}
377+
378+
// Blend two colors with given alpha for the overlay
379+
private int blendColors(int background, int foreground, float alpha) {
380+
int bgR = (background >> 16) & 0xFF;
381+
int bgG = (background >> 8) & 0xFF;
382+
int bgB = background & 0xFF;
383+
int fgR = (foreground >> 16) & 0xFF;
384+
int fgG = (foreground >> 8) & 0xFF;
385+
int fgB = foreground & 0xFF;
386+
int r = (int) (bgR * (1 - alpha) + fgR * alpha);
387+
int g = (int) (bgG * (1 - alpha) + fgG * alpha);
388+
int b = (int) (bgB * (1 - alpha) + fgB * alpha);
389+
return 0xFF000000 | (r << 16) | (g << 8) | b;
390+
}
391+
392+
private void loadModelOrShowDownloadButton() {
393+
File modelFile = new File(MODEL_PATH);
394+
if (modelFile.exists()) {
395+
try {
396+
mModule = Module.load(MODEL_PATH);
397+
mButtonXnnpack.setEnabled(true);
398+
mDownloadModelButton.setVisibility(View.GONE);
399+
mModelStatusText.setText("Model loaded");
400+
mModelStatusText.setVisibility(View.VISIBLE);
401+
} catch (Exception e) {
402+
Log.e("MainActivity", "Failed to load model", e);
403+
showUIMessage(this, "Failed to load model: " + e.getMessage());
404+
mButtonXnnpack.setEnabled(false);
405+
mDownloadModelButton.setVisibility(View.VISIBLE);
406+
mModelStatusText.setText("Model load failed");
407+
mModelStatusText.setVisibility(View.VISIBLE);
408+
}
409+
} else {
410+
mButtonXnnpack.setEnabled(false);
411+
mDownloadModelButton.setVisibility(View.VISIBLE);
412+
mModelStatusText.setText("Model not found");
413+
mModelStatusText.setVisibility(View.VISIBLE);
414+
}
415+
}
416+
417+
private void downloadModel() {
418+
mDownloadModelButton.setEnabled(false);
419+
mDownloadModelButton.setText(R.string.downloading);
420+
mProgressBar.setVisibility(View.VISIBLE);
421+
mModelStatusText.setText("Downloading...");
422+
423+
new Thread(() -> {
424+
try {
425+
URL url = new URL(MODEL_URL);
426+
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
427+
connection.setRequestMethod("GET");
428+
connection.connect();
429+
430+
if (connection.getResponseCode() != HttpURLConnection.HTTP_OK) {
431+
throw new IOException("Server returned HTTP " + connection.getResponseCode());
432+
}
433+
434+
File outputFile = new File(MODEL_PATH);
435+
try (InputStream input = connection.getInputStream();
436+
FileOutputStream output = new FileOutputStream(outputFile)) {
437+
byte[] buffer = new byte[4096];
438+
int bytesRead;
439+
while ((bytesRead = input.read(buffer)) != -1) {
440+
output.write(buffer, 0, bytesRead);
441+
}
442+
}
443+
444+
runOnUiThread(() -> {
445+
mDownloadModelButton.setText(R.string.download_model);
446+
mProgressBar.setVisibility(View.INVISIBLE);
447+
loadModelOrShowDownloadButton();
448+
showUIMessage(this, "Model downloaded successfully!");
449+
});
450+
} catch (Exception e) {
451+
Log.e("MainActivity", "Failed to download model", e);
452+
runOnUiThread(() -> {
453+
mDownloadModelButton.setEnabled(true);
454+
mDownloadModelButton.setText(R.string.download_model);
455+
mProgressBar.setVisibility(View.INVISIBLE);
456+
mModelStatusText.setText("Download failed");
457+
showUIMessage(this, "Download failed: " + e.getMessage());
458+
});
459+
}
460+
}).start();
461+
}
329462
}

dl3/android/DeepLabV3Demo/app/src/main/java/org/pytorch/executorchexamples/dl3/TensorImageUtils.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,6 @@ public static void bitmapToFloatBuffer(
7070
bitmap.getPixels(pixels, 0, width, x, y, width, height);
7171
final int offset_g = pixelsCount;
7272
final int offset_b = 2 * pixelsCount;
73-
for (int i = 0; i < 100; i++) {
74-
final int c = pixels[i];
75-
Log.i("Image", ": " + i + " " + ((c >> 16) & 0xff));
76-
}
7773
for (int i = 0; i < pixelsCount; i++) {
7874
final int c = pixels[i];
7975
float r = ((c >> 16) & 0xff) / 255.0f;

dl3/android/DeepLabV3Demo/app/src/main/res/layout/activity_main.xml

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,28 @@
3030
app:layout_constraintStart_toStartOf="parent"
3131
app:layout_constraintEnd_toEndOf="parent" />
3232

33+
<!-- Inference Time Display -->
34+
<TextView
35+
android:id="@+id/inferenceTimeText"
36+
android:layout_width="wrap_content"
37+
android:layout_height="wrap_content"
38+
android:text="@string/inference_time_placeholder"
39+
android:textSize="16sp"
40+
android:textStyle="bold"
41+
android:textColor="#2196F3"
42+
android:visibility="invisible"
43+
app:layout_constraintTop_toBottomOf="@+id/progressBar"
44+
app:layout_constraintStart_toStartOf="parent"
45+
app:layout_constraintEnd_toEndOf="parent" />
46+
3347
<!-- Row 1: Next and Reset buttons side by side -->
3448
<Button
3549
android:id="@+id/nextButton"
3650
android:layout_width="0dp"
3751
android:layout_height="wrap_content"
3852
android:text="@string/next"
3953
android:textAllCaps="false"
40-
app:layout_constraintTop_toBottomOf="@+id/progressBar"
54+
app:layout_constraintTop_toBottomOf="@+id/inferenceTimeText"
4155
app:layout_constraintStart_toStartOf="parent"
4256
app:layout_constraintEnd_toStartOf="@+id/resetImage"
4357
app:layout_constraintHorizontal_weight="1" />
@@ -48,7 +62,7 @@
4862
android:layout_height="wrap_content"
4963
android:text="@string/reset"
5064
android:textAllCaps="false"
51-
app:layout_constraintTop_toBottomOf="@+id/progressBar"
65+
app:layout_constraintTop_toBottomOf="@+id/inferenceTimeText"
5266
app:layout_constraintStart_toEndOf="@+id/nextButton"
5367
app:layout_constraintEnd_toEndOf="parent"
5468
app:layout_constraintHorizontal_weight="1" />

dl3/android/DeepLabV3Demo/app/src/main/res/values/strings.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
<string name="reset">Reset</string>
1212
<string name="load_and_refresh">Load And Refresh</string>
1313
<string name="run">Run</string>
14+
<string name="inference_time_placeholder">Inference: -- ms</string>
1415
</resources>

0 commit comments

Comments
 (0)