|
33 | 33 | import androidx.core.content.ContextCompat; |
34 | 34 |
|
35 | 35 | import java.io.File; |
| 36 | +import java.io.FileOutputStream; |
36 | 37 | import java.io.IOException; |
| 38 | +import java.io.InputStream; |
| 39 | +import java.net.HttpURLConnection; |
| 40 | +import java.net.URL; |
37 | 41 | import java.util.ArrayList; |
38 | 42 | import org.pytorch.executorch.EValue; |
39 | 43 | import org.pytorch.executorch.Module; |
|
42 | 46 | public class MainActivity extends Activity implements Runnable { |
43 | 47 | private ImageView mImageView; |
44 | 48 | private Button mButtonXnnpack; |
| 49 | + private Button mDownloadModelButton; |
45 | 50 | private ProgressBar mProgressBar; |
46 | 51 | private android.widget.TextView mInferenceTimeText; |
| 52 | + private android.widget.TextView mModelStatusText; |
47 | 53 | private Bitmap mBitmap = null; |
48 | 54 | private Module mModule = null; |
49 | 55 | private String mImagename = "corgi.jpeg"; |
50 | 56 | private long mInferenceTime = 0; |
51 | 57 |
|
| 58 | + // Model download configuration |
| 59 | + private static final String MODEL_URL = "https://example.com/dl3_xnnpack_fp32.pte"; // TODO: Replace with actual URL |
| 60 | + private static final String MODEL_PATH = "/data/local/tmp/dl3_xnnpack_fp32.pte"; |
| 61 | + |
52 | 62 | private final ArrayList<String> mImageFiles = new ArrayList<>(); |
53 | 63 |
|
54 | 64 | private int mCurrentImageIndex = 0; |
@@ -221,15 +231,21 @@ protected void onCreate(Bundle savedInstanceState) { |
221 | 231 | // Initialize all views first! |
222 | 232 | mImageView = findViewById(R.id.imageView); |
223 | 233 | mButtonXnnpack = findViewById(R.id.xnnpackButton); |
| 234 | + mDownloadModelButton = findViewById(R.id.downloadModelButton); |
224 | 235 | mProgressBar = findViewById(R.id.progressBar); |
225 | 236 | mInferenceTimeText = findViewById(R.id.inferenceTimeText); |
| 237 | + mModelStatusText = findViewById(R.id.modelStatusText); |
226 | 238 |
|
227 | 239 | populateImagePathFromAssets(); |
228 | 240 | showImage(); |
229 | 241 |
|
230 | | - mModule = Module.load("/data/local/tmp/dl3_xnnpack_fp32.pte"); |
| 242 | + // Check if model exists and load it, otherwise show download button |
| 243 | + loadModelOrShowDownloadButton(); |
231 | 244 | mImageView.setImageBitmap(mBitmap); |
232 | 245 |
|
| 246 | + // Download button click handler |
| 247 | + mDownloadModelButton.setOnClickListener(v -> downloadModel()); |
| 248 | + |
233 | 249 | final Button buttonNext = findViewById(R.id.nextButton); |
234 | 250 | buttonNext.setOnClickListener( |
235 | 251 | new View.OnClickListener() { |
@@ -372,4 +388,75 @@ private int blendColors(int background, int foreground, float alpha) { |
372 | 388 | int b = (int) (bgB * (1 - alpha) + fgB * alpha); |
373 | 389 | return 0xFF000000 | (r << 16) | (g << 8) | b; |
374 | 390 | } |
| 391 | + |
| 392 | + private void loadModelOrShowDownloadButton() { |
| 393 | + File modelFile = new File(MODEL_PATH); |
| 394 | + if (modelFile.exists()) { |
| 395 | + try { |
| 396 | + mModule = Module.load(MODEL_PATH); |
| 397 | + mButtonXnnpack.setEnabled(true); |
| 398 | + mDownloadModelButton.setVisibility(View.GONE); |
| 399 | + mModelStatusText.setText("Model loaded"); |
| 400 | + mModelStatusText.setVisibility(View.VISIBLE); |
| 401 | + } catch (Exception e) { |
| 402 | + Log.e("MainActivity", "Failed to load model", e); |
| 403 | + showUIMessage(this, "Failed to load model: " + e.getMessage()); |
| 404 | + mButtonXnnpack.setEnabled(false); |
| 405 | + mDownloadModelButton.setVisibility(View.VISIBLE); |
| 406 | + mModelStatusText.setText("Model load failed"); |
| 407 | + mModelStatusText.setVisibility(View.VISIBLE); |
| 408 | + } |
| 409 | + } else { |
| 410 | + mButtonXnnpack.setEnabled(false); |
| 411 | + mDownloadModelButton.setVisibility(View.VISIBLE); |
| 412 | + mModelStatusText.setText("Model not found"); |
| 413 | + mModelStatusText.setVisibility(View.VISIBLE); |
| 414 | + } |
| 415 | + } |
| 416 | + |
| 417 | + private void downloadModel() { |
| 418 | + mDownloadModelButton.setEnabled(false); |
| 419 | + mDownloadModelButton.setText(R.string.downloading); |
| 420 | + mProgressBar.setVisibility(View.VISIBLE); |
| 421 | + mModelStatusText.setText("Downloading..."); |
| 422 | + |
| 423 | + new Thread(() -> { |
| 424 | + try { |
| 425 | + URL url = new URL(MODEL_URL); |
| 426 | + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); |
| 427 | + connection.setRequestMethod("GET"); |
| 428 | + connection.connect(); |
| 429 | + |
| 430 | + if (connection.getResponseCode() != HttpURLConnection.HTTP_OK) { |
| 431 | + throw new IOException("Server returned HTTP " + connection.getResponseCode()); |
| 432 | + } |
| 433 | + |
| 434 | + File outputFile = new File(MODEL_PATH); |
| 435 | + try (InputStream input = connection.getInputStream(); |
| 436 | + FileOutputStream output = new FileOutputStream(outputFile)) { |
| 437 | + byte[] buffer = new byte[4096]; |
| 438 | + int bytesRead; |
| 439 | + while ((bytesRead = input.read(buffer)) != -1) { |
| 440 | + output.write(buffer, 0, bytesRead); |
| 441 | + } |
| 442 | + } |
| 443 | + |
| 444 | + runOnUiThread(() -> { |
| 445 | + mDownloadModelButton.setText(R.string.download_model); |
| 446 | + mProgressBar.setVisibility(View.INVISIBLE); |
| 447 | + loadModelOrShowDownloadButton(); |
| 448 | + showUIMessage(this, "Model downloaded successfully!"); |
| 449 | + }); |
| 450 | + } catch (Exception e) { |
| 451 | + Log.e("MainActivity", "Failed to download model", e); |
| 452 | + runOnUiThread(() -> { |
| 453 | + mDownloadModelButton.setEnabled(true); |
| 454 | + mDownloadModelButton.setText(R.string.download_model); |
| 455 | + mProgressBar.setVisibility(View.INVISIBLE); |
| 456 | + mModelStatusText.setText("Download failed"); |
| 457 | + showUIMessage(this, "Download failed: " + e.getMessage()); |
| 458 | + }); |
| 459 | + } |
| 460 | + }).start(); |
| 461 | + } |
375 | 462 | } |
0 commit comments