3333import androidx .core .content .ContextCompat ;
3434
3535import java .io .File ;
36+ import java .io .FileOutputStream ;
3637import java .io .IOException ;
38+ import java .io .InputStream ;
39+ import java .net .HttpURLConnection ;
40+ import java .net .URL ;
3741import java .util .ArrayList ;
3842import org .pytorch .executorch .EValue ;
3943import org .pytorch .executorch .Module ;
4246public class MainActivity extends Activity implements Runnable {
4347 private ImageView mImageView ;
4448 private Button mButtonXnnpack ;
49+ private Button mDownloadModelButton ;
4550 private ProgressBar mProgressBar ;
51+ private android .widget .TextView mInferenceTimeText ;
52+ private android .widget .TextView mModelStatusText ;
4653 private Bitmap mBitmap = null ;
4754 private Module mModule = null ;
4855 private String mImagename = "corgi.jpeg" ;
56+ private long mInferenceTime = 0 ;
57+
58+ // Model download configuration
59+ private static final String MODEL_URL = "https://example.com/dl3_xnnpack_fp32.pte" ; // TODO: Replace with actual URL
60+ private static final String MODEL_PATH = "/data/local/tmp/dl3_xnnpack_fp32.pte" ;
4961
5062 private final ArrayList <String > mImageFiles = new ArrayList <>();
5163
@@ -57,9 +69,31 @@ public class MainActivity extends Activity implements Runnable {
5769 // see http://host.robots.ox.ac.uk:8080/pascal/VOC/voc2007/segexamples/index.html for the list of
5870 // classes with indexes
5971 private static final int CLASSNUM = 21 ;
60- private static final int DOG = 12 ;
61- private static final int PERSON = 15 ;
62- private static final int SHEEP = 17 ;
72+
73+ // Colors for all 21 PASCAL VOC classes
74+ private static final int [] CLASS_COLORS = {
75+ 0x00000000 , // 0: Background (transparent)
76+ 0xFFE6194B , // 1: Aeroplane (red)
77+ 0xFF3CB44B , // 2: Bicycle (green)
78+ 0xFFFFE119 , // 3: Bird (yellow)
79+ 0xFF4363D8 , // 4: Boat (blue)
80+ 0xFFF58231 , // 5: Bottle (orange)
81+ 0xFF911EB4 , // 6: Bus (purple)
82+ 0xFF46F0F0 , // 7: Car (cyan)
83+ 0xFFF032E6 , // 8: Cat (magenta)
84+ 0xFFBCF60C , // 9: Chair (lime)
85+ 0xFFFABEBE , // 10: Cow (pink)
86+ 0xFF008080 , // 11: Dining Table (teal)
87+ 0xFF00FF00 , // 12: Dog (bright green)
88+ 0xFF9A6324 , // 13: Horse (brown)
89+ 0xFFFFD8B1 , // 14: Motorbike (peach)
90+ 0xFFFF0000 , // 15: Person (red)
91+ 0xFF800000 , // 16: Potted Plant (maroon)
92+ 0xFF0000FF , // 17: Sheep (blue)
93+ 0xFF808000 , // 18: Sofa (olive)
94+ 0xFFE6BEFF , // 19: Train (lavender)
95+ 0xFFAA6E28 , // 20: TV/Monitor (tan)
96+ };
6397
6498 private void checkAndRequestStoragePermission () {
6599 if (ContextCompat .checkSelfPermission (this , Manifest .permission .READ_EXTERNAL_STORAGE )
@@ -197,14 +231,21 @@ protected void onCreate(Bundle savedInstanceState) {
197231 // Initialize all views first!
198232 mImageView = findViewById (R .id .imageView );
199233 mButtonXnnpack = findViewById (R .id .xnnpackButton );
234+ mDownloadModelButton = findViewById (R .id .downloadModelButton );
200235 mProgressBar = findViewById (R .id .progressBar );
236+ mInferenceTimeText = findViewById (R .id .inferenceTimeText );
237+ mModelStatusText = findViewById (R .id .modelStatusText );
201238
202239 populateImagePathFromAssets ();
203240 showImage ();
204241
205- mModule = Module .load ("/data/local/tmp/dl3_xnnpack_fp32.pte" );
242+ // Check if model exists and load it, otherwise show download button
243+ loadModelOrShowDownloadButton ();
206244 mImageView .setImageBitmap (mBitmap );
207245
246+ // Download button click handler
247+ mDownloadModelButton .setOnClickListener (v -> downloadModel ());
248+
208249 final Button buttonNext = findViewById (R .id .nextButton );
209250 buttonNext .setOnClickListener (
210251 new View .OnClickListener () {
@@ -223,10 +264,9 @@ public void onClick(View v) {
223264 mButtonXnnpack .setOnClickListener (
224265 new View .OnClickListener () {
225266 public void onClick (View v ) {
226- mModule .destroy ();
227- mModule = Module .load ("/data/local/tmp/dl3_xnnpack_fp32.pte" );
228267 mButtonXnnpack .setEnabled (false );
229268 mProgressBar .setVisibility (ProgressBar .VISIBLE );
269+ mInferenceTimeText .setVisibility (View .INVISIBLE );
230270 mButtonXnnpack .setText (getString (R .string .run_model ));
231271
232272 Thread thread = new Thread (MainActivity .this );
@@ -262,32 +302,38 @@ public void run() {
262302 boolean imageSegementationSuccess = false ;
263303 final long startTime = SystemClock .elapsedRealtime ();
264304 Tensor outputTensor = mModule .forward (EValue .from (inputTensor ))[0 ].toTensor ();
265- final long inferenceTime = SystemClock .elapsedRealtime () - startTime ;
266- Log .d ("ImageSegmentation" , "inference time (ms): " + inferenceTime );
305+ mInferenceTime = SystemClock .elapsedRealtime () - startTime ;
306+ Log .d ("ImageSegmentation" , "inference time (ms): " + mInferenceTime );
267307
268308 final float [] scores = outputTensor .getDataAsFloatArray ();
269309 int width = mBitmap .getWidth ();
270310 int height = mBitmap .getHeight ();
271311
312+ // Get original pixels for blending
313+ int [] originalPixels = new int [width * height ];
314+ mBitmap .getPixels (originalPixels , 0 , width , 0 , 0 , width , height );
315+
272316 int [] intValues = new int [width * height ];
273317 for (int j = 0 ; j < height ; j ++) {
274318 for (int k = 0 ; k < width ; k ++) {
275- int maxi = 0 , maxj = 0 , maxk = 0 ;
319+ int maxi = 0 ;
276320 double maxnum = -Double .MAX_VALUE ;
277321 for (int i = 0 ; i < CLASSNUM ; i ++) {
278322 float score = scores [i * (width * height ) + j * width + k ];
279323 if (score > maxnum ) {
280324 maxnum = score ;
281325 maxi = i ;
282- maxj = j ;
283- maxk = k ;
284326 }
285327 }
286- if (maxi == PERSON ) intValues [maxj * width + maxk ] = 0xFFFF0000 ; // R
287- else if (maxi == DOG ) intValues [maxj * width + maxk ] = 0xFF00FF00 ; // G
288- else if (maxi == SHEEP ) intValues [maxj * width + maxk ] = 0xFF0000FF ; // B
289- else intValues [maxj * width + maxk ] = 0xFF000000 ;
290- if (maxi == PERSON || maxi == DOG || maxi == SHEEP ) {
328+ int pixelIndex = j * width + k ;
329+ int classColor = CLASS_COLORS [maxi ];
330+
331+ if (maxi == 0 ) {
332+ // Background: show original image
333+ intValues [pixelIndex ] = originalPixels [pixelIndex ];
334+ } else {
335+ // Blend segmentation color with original at 50% opacity
336+ intValues [pixelIndex ] = blendColors (originalPixels [pixelIndex ], classColor , 0.5f );
291337 imageSegementationSuccess = true ;
292338 }
293339 }
@@ -310,12 +356,14 @@ public void run() {
310356 runOnUiThread (
311357 () -> {
312358 if (showUserIndicationOnImgSegFail ) {
313- Toast .makeText (this , "ImageSegmentation Failed " , Toast .LENGTH_SHORT ).show ();
359+ Toast .makeText (this , "No objects detected " , Toast .LENGTH_SHORT ).show ();
314360 }
315361 mImageView .setImageBitmap (transferredBitmap );
316362 mButtonXnnpack .setEnabled (true );
317363 mButtonXnnpack .setText (R .string .run_xnnpack );
318364 mProgressBar .setVisibility (ProgressBar .INVISIBLE );
365+ mInferenceTimeText .setText ("Inference: " + mInferenceTime + " ms" );
366+ mInferenceTimeText .setVisibility (View .VISIBLE );
319367 });
320368 }
321369
@@ -326,4 +374,89 @@ public void run() {
326374 }
327375 });
328376 }
377+
378+ // Blend two colors with given alpha for the overlay
379+ private int blendColors (int background , int foreground , float alpha ) {
380+ int bgR = (background >> 16 ) & 0xFF ;
381+ int bgG = (background >> 8 ) & 0xFF ;
382+ int bgB = background & 0xFF ;
383+ int fgR = (foreground >> 16 ) & 0xFF ;
384+ int fgG = (foreground >> 8 ) & 0xFF ;
385+ int fgB = foreground & 0xFF ;
386+ int r = (int ) (bgR * (1 - alpha ) + fgR * alpha );
387+ int g = (int ) (bgG * (1 - alpha ) + fgG * alpha );
388+ int b = (int ) (bgB * (1 - alpha ) + fgB * alpha );
389+ return 0xFF000000 | (r << 16 ) | (g << 8 ) | b ;
390+ }
391+
392+ private void loadModelOrShowDownloadButton () {
393+ File modelFile = new File (MODEL_PATH );
394+ if (modelFile .exists ()) {
395+ try {
396+ mModule = Module .load (MODEL_PATH );
397+ mButtonXnnpack .setEnabled (true );
398+ mDownloadModelButton .setVisibility (View .GONE );
399+ mModelStatusText .setText ("Model loaded" );
400+ mModelStatusText .setVisibility (View .VISIBLE );
401+ } catch (Exception e ) {
402+ Log .e ("MainActivity" , "Failed to load model" , e );
403+ showUIMessage (this , "Failed to load model: " + e .getMessage ());
404+ mButtonXnnpack .setEnabled (false );
405+ mDownloadModelButton .setVisibility (View .VISIBLE );
406+ mModelStatusText .setText ("Model load failed" );
407+ mModelStatusText .setVisibility (View .VISIBLE );
408+ }
409+ } else {
410+ mButtonXnnpack .setEnabled (false );
411+ mDownloadModelButton .setVisibility (View .VISIBLE );
412+ mModelStatusText .setText ("Model not found" );
413+ mModelStatusText .setVisibility (View .VISIBLE );
414+ }
415+ }
416+
417+ private void downloadModel () {
418+ mDownloadModelButton .setEnabled (false );
419+ mDownloadModelButton .setText (R .string .downloading );
420+ mProgressBar .setVisibility (View .VISIBLE );
421+ mModelStatusText .setText ("Downloading..." );
422+
423+ new Thread (() -> {
424+ try {
425+ URL url = new URL (MODEL_URL );
426+ HttpURLConnection connection = (HttpURLConnection ) url .openConnection ();
427+ connection .setRequestMethod ("GET" );
428+ connection .connect ();
429+
430+ if (connection .getResponseCode () != HttpURLConnection .HTTP_OK ) {
431+ throw new IOException ("Server returned HTTP " + connection .getResponseCode ());
432+ }
433+
434+ File outputFile = new File (MODEL_PATH );
435+ try (InputStream input = connection .getInputStream ();
436+ FileOutputStream output = new FileOutputStream (outputFile )) {
437+ byte [] buffer = new byte [4096 ];
438+ int bytesRead ;
439+ while ((bytesRead = input .read (buffer )) != -1 ) {
440+ output .write (buffer , 0 , bytesRead );
441+ }
442+ }
443+
444+ runOnUiThread (() -> {
445+ mDownloadModelButton .setText (R .string .download_model );
446+ mProgressBar .setVisibility (View .INVISIBLE );
447+ loadModelOrShowDownloadButton ();
448+ showUIMessage (this , "Model downloaded successfully!" );
449+ });
450+ } catch (Exception e ) {
451+ Log .e ("MainActivity" , "Failed to download model" , e );
452+ runOnUiThread (() -> {
453+ mDownloadModelButton .setEnabled (true );
454+ mDownloadModelButton .setText (R .string .download_model );
455+ mProgressBar .setVisibility (View .INVISIBLE );
456+ mModelStatusText .setText ("Download failed" );
457+ showUIMessage (this , "Download failed: " + e .getMessage ());
458+ });
459+ }
460+ }).start ();
461+ }
329462}
0 commit comments