|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": null, |
| 6 | + "id": "547a25de", |
| 7 | + "metadata": {}, |
| 8 | + "outputs": [], |
| 9 | + "source": [ |
| 10 | + "encoder_path = \"./model/encoder/model.onnx\"\n", |
| 11 | + "decoder_path = \"./model/decoder/model.onnx\"\n", |
| 12 | + "\n", |
| 13 | + "ExecutionProvider=\"QNNExecutionProvider\"" |
| 14 | + ] |
| 15 | + }, |
| 16 | + { |
| 17 | + "cell_type": "code", |
| 18 | + "execution_count": null, |
| 19 | + "id": "eed9c231", |
| 20 | + "metadata": {}, |
| 21 | + "outputs": [], |
| 22 | + "source": [ |
| 23 | + "# reference: https://learn.microsoft.com/en-us/windows/ai/new-windows-ml/tutorial?tabs=python#acquiring-the-model-and-preprocessing\n", |
| 24 | + "import subprocess\n", |
| 25 | + "import json\n", |
| 26 | + "import sys\n", |
| 27 | + "import os\n", |
| 28 | + "import onnxruntime as ort\n", |
| 29 | + "\n", |
| 30 | + "def register_execution_providers():\n", |
| 31 | + " worker_script = os.path.abspath('winml.py')\n", |
| 32 | + " print(worker_script)\n", |
| 33 | + " result = subprocess.check_output([sys.executable, worker_script], text=True)\n", |
| 34 | + " paths = json.loads(result)\n", |
| 35 | + " for item in paths.items():\n", |
| 36 | + " try:\n", |
| 37 | + " ort.register_execution_provider_library(item[0], item[1])\n", |
| 38 | + " except Exception as e:\n", |
| 39 | + " print(f\"Failed to register execution provider {item[0]}: {e}\")\n", |
| 40 | + "\n", |
| 41 | + "register_execution_providers()" |
| 42 | + ] |
| 43 | + }, |
| 44 | + { |
| 45 | + "cell_type": "code", |
| 46 | + "execution_count": null, |
| 47 | + "id": "9e9d8984", |
| 48 | + "metadata": {}, |
| 49 | + "outputs": [], |
| 50 | + "source": [ |
| 51 | + "from urllib import request\n", |
| 52 | + "\n", |
| 53 | + "test_image_url = \"https://github.com/facebookresearch/segment-anything/blob/main/notebooks/images/truck.jpg?raw=true\"\n", |
| 54 | + "test_image_name = \"truck.jpg\"\n", |
| 55 | + "\n", |
| 56 | + "request.urlretrieve(test_image_url, test_image_name)\n", |
| 57 | + "\n", |
| 58 | + "from IPython.display import Image, display\n", |
| 59 | + "\n", |
| 60 | + "display(Image(filename=test_image_name))" |
| 61 | + ] |
| 62 | + }, |
| 63 | + { |
| 64 | + "cell_type": "code", |
| 65 | + "execution_count": null, |
| 66 | + "id": "ffcd22ad", |
| 67 | + "metadata": {}, |
| 68 | + "outputs": [], |
| 69 | + "source": [ |
| 70 | + "import sys\n", |
| 71 | + "from pathlib import Path\n", |
| 72 | + "\n", |
| 73 | + "NOTEBOOK_DIR = Path(__file__).parent if \"__file__\" in globals() else Path.cwd()\n", |
| 74 | + "PROJECT_ROOT = NOTEBOOK_DIR.parents[1]\n", |
| 75 | + "sys.path.insert(0, str(PROJECT_ROOT))\n", |
| 76 | + "\n", |
| 77 | + "import numpy as np\n", |
| 78 | + "from PIL import Image\n", |
| 79 | + "from sam_mask_generator import get_mask_ort\n", |
| 80 | + "\n", |
| 81 | + "def add_ep_for_device(session_options, ep_name, device_type, ep_options=None):\n", |
| 82 | + " ep_devices = ort.get_ep_devices()\n", |
| 83 | + " for ep_device in ep_devices:\n", |
| 84 | + " if ep_device.ep_name == ep_name and ep_device.device.type == device_type:\n", |
| 85 | + " print(f\"Adding {ep_name} for {device_type}\")\n", |
| 86 | + " session_options.add_provider_for_devices([ep_device], {} if ep_options is None else ep_options)\n", |
| 87 | + " break\n", |
| 88 | + "\n", |
| 89 | + "\n", |
| 90 | + "sess_options = ort.SessionOptions()\n", |
| 91 | + "\n", |
| 92 | + "add_ep_for_device(sess_options, ExecutionProvider, ort.OrtHardwareDeviceType.CPU)\n", |
| 93 | + "\n", |
| 94 | + "# Load image\n", |
| 95 | + "raw_image = Image.open(test_image_name).convert(\"RGB\")\n", |
| 96 | + "input_box = [[[100, 300], [1750, 900]]]\n", |
| 97 | + "\n", |
| 98 | + "# Load models\n", |
| 99 | + "sess_ve = ort.InferenceSession(encoder_path, sess_options=sess_options)\n", |
| 100 | + "sess_md = ort.InferenceSession(decoder_path, sess_options=sess_options)\n", |
| 101 | + "\n", |
| 102 | + "sess_ve_inputs = sess_ve.get_inputs()\n", |
| 103 | + "sess_md_inputs = sess_md.get_inputs()\n", |
| 104 | + "\n", |
| 105 | + "ve_dtype = np.float32 if sess_ve_inputs[0].type == 'tensor(float)' else np.float16\n", |
| 106 | + "md_dtype = np.float32 if sess_md_inputs[0].type == 'tensor(float)' else np.float16\n", |
| 107 | + "\n", |
| 108 | + "# Get mask\n", |
| 109 | + "mask = get_mask_ort(sess_ve, sess_md, raw_image, input_box, ve_dtype, md_dtype, sess_ve_inputs, sess_md_inputs)\n", |
| 110 | + "\n", |
| 111 | + "# Save mask using PIL\n", |
| 112 | + "mask_img = Image.fromarray(mask * 255) # Convert binary mask to 0-255\n", |
| 113 | + "\n", |
| 114 | + "from IPython.display import display\n", |
| 115 | + "display(mask_img)" |
| 116 | + ] |
| 117 | + } |
| 118 | + ], |
| 119 | + "metadata": { |
| 120 | + "language_info": { |
| 121 | + "name": "python" |
| 122 | + } |
| 123 | + }, |
| 124 | + "nbformat": 4, |
| 125 | + "nbformat_minor": 5 |
| 126 | +} |
0 commit comments