Skip to content

Commit c204cd4

Browse files
committed
feat: Add GitHub Actions workflow for model export and enhance DeepLabV3 Android demo with improved segmentation visualization and inference time display.
1 parent 7b8a25c commit c204cd4

1 file changed

Lines changed: 88 additions & 1 deletion

File tree

  • dl3/android/DeepLabV3Demo/app/src/main/java/org/pytorch/executorchexamples/dl3

dl3/android/DeepLabV3Demo/app/src/main/java/org/pytorch/executorchexamples/dl3/MainActivity.java

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
import androidx.core.content.ContextCompat;
3434

3535
import java.io.File;
36+
import java.io.FileOutputStream;
3637
import java.io.IOException;
38+
import java.io.InputStream;
39+
import java.net.HttpURLConnection;
40+
import java.net.URL;
3741
import java.util.ArrayList;
3842
import org.pytorch.executorch.EValue;
3943
import org.pytorch.executorch.Module;
@@ -42,13 +46,19 @@
4246
public class MainActivity extends Activity implements Runnable {
4347
private ImageView mImageView;
4448
private Button mButtonXnnpack;
49+
private Button mDownloadModelButton;
4550
private ProgressBar mProgressBar;
4651
private android.widget.TextView mInferenceTimeText;
52+
private android.widget.TextView mModelStatusText;
4753
private Bitmap mBitmap = null;
4854
private Module mModule = null;
4955
private String mImagename = "corgi.jpeg";
5056
private long mInferenceTime = 0;
5157

58+
// Model download configuration
59+
private static final String MODEL_URL = "https://example.com/dl3_xnnpack_fp32.pte"; // TODO: Replace with actual URL
60+
private static final String MODEL_PATH = "/data/local/tmp/dl3_xnnpack_fp32.pte";
61+
5262
private final ArrayList<String> mImageFiles = new ArrayList<>();
5363

5464
private int mCurrentImageIndex = 0;
@@ -221,15 +231,21 @@ protected void onCreate(Bundle savedInstanceState) {
221231
// Initialize all views first!
222232
mImageView = findViewById(R.id.imageView);
223233
mButtonXnnpack = findViewById(R.id.xnnpackButton);
234+
mDownloadModelButton = findViewById(R.id.downloadModelButton);
224235
mProgressBar = findViewById(R.id.progressBar);
225236
mInferenceTimeText = findViewById(R.id.inferenceTimeText);
237+
mModelStatusText = findViewById(R.id.modelStatusText);
226238

227239
populateImagePathFromAssets();
228240
showImage();
229241

230-
mModule = Module.load("/data/local/tmp/dl3_xnnpack_fp32.pte");
242+
// Check if model exists and load it, otherwise show download button
243+
loadModelOrShowDownloadButton();
231244
mImageView.setImageBitmap(mBitmap);
232245

246+
// Download button click handler
247+
mDownloadModelButton.setOnClickListener(v -> downloadModel());
248+
233249
final Button buttonNext = findViewById(R.id.nextButton);
234250
buttonNext.setOnClickListener(
235251
new View.OnClickListener() {
@@ -372,4 +388,75 @@ private int blendColors(int background, int foreground, float alpha) {
372388
int b = (int) (bgB * (1 - alpha) + fgB * alpha);
373389
return 0xFF000000 | (r << 16) | (g << 8) | b;
374390
}
391+
392+
private void loadModelOrShowDownloadButton() {
393+
File modelFile = new File(MODEL_PATH);
394+
if (modelFile.exists()) {
395+
try {
396+
mModule = Module.load(MODEL_PATH);
397+
mButtonXnnpack.setEnabled(true);
398+
mDownloadModelButton.setVisibility(View.GONE);
399+
mModelStatusText.setText("Model loaded");
400+
mModelStatusText.setVisibility(View.VISIBLE);
401+
} catch (Exception e) {
402+
Log.e("MainActivity", "Failed to load model", e);
403+
showUIMessage(this, "Failed to load model: " + e.getMessage());
404+
mButtonXnnpack.setEnabled(false);
405+
mDownloadModelButton.setVisibility(View.VISIBLE);
406+
mModelStatusText.setText("Model load failed");
407+
mModelStatusText.setVisibility(View.VISIBLE);
408+
}
409+
} else {
410+
mButtonXnnpack.setEnabled(false);
411+
mDownloadModelButton.setVisibility(View.VISIBLE);
412+
mModelStatusText.setText("Model not found");
413+
mModelStatusText.setVisibility(View.VISIBLE);
414+
}
415+
}
416+
417+
private void downloadModel() {
418+
mDownloadModelButton.setEnabled(false);
419+
mDownloadModelButton.setText(R.string.downloading);
420+
mProgressBar.setVisibility(View.VISIBLE);
421+
mModelStatusText.setText("Downloading...");
422+
423+
new Thread(() -> {
424+
try {
425+
URL url = new URL(MODEL_URL);
426+
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
427+
connection.setRequestMethod("GET");
428+
connection.connect();
429+
430+
if (connection.getResponseCode() != HttpURLConnection.HTTP_OK) {
431+
throw new IOException("Server returned HTTP " + connection.getResponseCode());
432+
}
433+
434+
File outputFile = new File(MODEL_PATH);
435+
try (InputStream input = connection.getInputStream();
436+
FileOutputStream output = new FileOutputStream(outputFile)) {
437+
byte[] buffer = new byte[4096];
438+
int bytesRead;
439+
while ((bytesRead = input.read(buffer)) != -1) {
440+
output.write(buffer, 0, bytesRead);
441+
}
442+
}
443+
444+
runOnUiThread(() -> {
445+
mDownloadModelButton.setText(R.string.download_model);
446+
mProgressBar.setVisibility(View.INVISIBLE);
447+
loadModelOrShowDownloadButton();
448+
showUIMessage(this, "Model downloaded successfully!");
449+
});
450+
} catch (Exception e) {
451+
Log.e("MainActivity", "Failed to download model", e);
452+
runOnUiThread(() -> {
453+
mDownloadModelButton.setEnabled(true);
454+
mDownloadModelButton.setText(R.string.download_model);
455+
mProgressBar.setVisibility(View.INVISIBLE);
456+
mModelStatusText.setText("Download failed");
457+
showUIMessage(this, "Download failed: " + e.getMessage());
458+
});
459+
}
460+
}).start();
461+
}
375462
}

0 commit comments

Comments
 (0)