
Commit 90a02ef

feat: support sam/sam2 with qnn (microsoft#315)
Co-authored-by: ziyuanguo <ziyuanguo@microsoft.com>
1 parent ad62f41 commit 90a02ef

40 files changed: 3206 additions & 11 deletions

.aitk/configs/checks.json

Lines changed: 9 additions & 9 deletions
@@ -1,16 +1,16 @@
 {
-    "configCheck": 165,
-    "copyCheck": 180,
+    "configCheck": 167,
+    "copyCheck": 182,
     "extensionCheck": 2,
-    "gitignoreCheck": 42,
+    "gitignoreCheck": 44,
     "inferenceModelCheck": 25,
-    "ipynbCheck": 42,
-    "licenseCheck": 39,
-    "modelProjectCheck": 44,
+    "ipynbCheck": 44,
+    "licenseCheck": 41,
+    "modelProjectCheck": 46,
     "oliveCheck": 60,
-    "oliveJsonCheck": 165,
-    "pathCheck": 1397,
+    "oliveJsonCheck": 167,
+    "pathCheck": 1423,
     "requirementsCheck": 37,
     "templateCheck": 3,
-    "venvRequirementsCheck": 16
+    "venvRequirementsCheck": 17
 }

.aitk/configs/model_list.json

Lines changed: 33 additions & 0 deletions
@@ -471,6 +471,38 @@
             "text-generation"
         ]
     },
+    {
+        "displayName": "facebook/sam-vit-base",
+        "icon": "meta",
+        "modelLink": "https://huggingface.co/facebook/sam-vit-base",
+        "id": "huggingface/facebook/sam-vit-base",
+        "runtimes": [
+            "QNN"
+        ],
+        "architecture": "Transformer",
+        "status": "Hide",
+        "relativePath": "sam-vit-base/aitk",
+        "version": 1,
+        "pipeline_tags": [
+            "fill-mask"
+        ]
+    },
+    {
+        "displayName": "facebook/sam2.1-hiera-small",
+        "icon": "meta",
+        "modelLink": "https://huggingface.co/facebook/sam2.1-hiera-small",
+        "id": "huggingface/facebook/sam2.1-hiera-small",
+        "runtimes": [
+            "QNN"
+        ],
+        "architecture": "Transformer",
+        "status": "Hide",
+        "relativePath": "sam2.1-hiera-small/aitk",
+        "version": 1,
+        "pipeline_tags": [
+            "fill-mask"
+        ]
+    },
     {
         "displayName": "meta-llama/Llama-3.1-8B-Instruct",
         "icon": "meta",
@@ -925,6 +957,7 @@
         "AIMClab-RUC/COCO-CN": "https://huggingface.co/datasets/AIMClab-RUC/COCO-CN",
         "librispeech_asr": "https://huggingface.co/datasets/openslr/librispeech_asr",
         "phiyodr/coco2017": "https://huggingface.co/datasets/phiyodr/coco2017",
+        "nielsr/coco-panoptic-val2017": "https://huggingface.co/datasets/nielsr/coco-panoptic-val2017",
         "pileval_for_awq_benchmark": "https://huggingface.co/datasets/mit-han-lab/pile-val-backup"
     },
     "LoginRequiredDatasets": [
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+onnxsim==0.6.2
+sam2==1.1.0
+transformers==4.56.2

.aitk/scripts/project_processor.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ def fetch_pipeline_tags(model_link: str) -> Optional[List[str]]:
     "google": IconEnum.Gemini,
     "deepseek-ai": IconEnum.DeepSeek,
     "Qwen": IconEnum.qwen,
+    "facebook": IconEnum.Meta,
     "meta-llama": IconEnum.Meta,
     "mistralai": IconEnum.mistralai,
     # TODO add
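
The new `"facebook": IconEnum.Meta` entry extends the org-to-icon table used when building model metadata. A hypothetical sketch of how such a prefix table is typically consumed (the `ICON_MAP` dict and `icon_for` helper are illustrative, not the repository's code):

```python
# Hypothetical mirror of the table in project_processor.py: the Hugging Face
# org prefix (the part of the model id before "/") selects the display icon.
ICON_MAP = {
    "facebook": "meta",       # added by this commit (IconEnum.Meta upstream)
    "meta-llama": "meta",
    "deepseek-ai": "deepseek",
}

def icon_for(model_id: str, default: str = "generic") -> str:
    """Pick an icon by the org prefix of an id like 'facebook/sam-vit-base'."""
    return ICON_MAP.get(model_id.split("/", 1)[0], default)

assert icon_for("facebook/sam-vit-base") == "meta"
```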

sam-vit-base/QNN/sam_mask_generator.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def main():
     parser.add_argument("--image_path", required=True, help="Path to input image")
     parser.add_argument("--output_path", default="mask_output.png", help="Path to save the output mask image")
     parser.add_argument("--box_x", type=int, default=40, help="Top-Left X coordinate of input box")
-    parser.add_argument("--box_y", type=int, default=235, help="To-Left Y coordinate of input box")
+    parser.add_argument("--box_y", type=int, default=235, help="Top-Left Y coordinate of input box")
     parser.add_argument("--box_w", type=int, default=940, help="Width of input box")
     parser.add_argument("--box_h", type=int, default=490, help="Height of input box")
     args = parser.parse_args()
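
Note that this CLI describes the box as top-left corner plus width and height, whereas SAM's `input_boxes` tensor (shape `(1, 4)` in the decoder config later in this commit) expects `[x1, y1, x2, y2]` corners. A small conversion sketch; the helper name is ours:

```python
def box_xywh_to_corners(x: int, y: int, w: int, h: int) -> list[list[int]]:
    """Convert a top-left/width/height box to SAM's [x1, y1, x2, y2] layout."""
    return [[x, y, x + w, y + h]]

# The script's defaults (--box_x 40 --box_y 235 --box_w 940 --box_h 490):
print(box_xywh_to_corners(40, 235, 940, 490))  # [[40, 235, 980, 725]]
```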

sam-vit-base/aitk/.gitignore

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+__pycache__
+/cache
+/history/*/*
+!/history/*/history.config
+!/history/*/olive_config.json
+/quantization_dataset

sam-vit-base/aitk/README.md

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+# SAM Model Conversion
+
+This repository demonstrates the optimization of the [facebook/sam-vit-base](https://huggingface.co/facebook/sam-vit-base) model using **post-training quantization (PTQ)** techniques.
+
+
+### Run the Quantization + Compilation Config
+Activate the **Quantization Python Environment** and run the workflow:
+
+For the encoder model:
+```bash
+olive run --config sam_vision_encoder_qnn.json
+```
+
+For the point- and box-based decoder model:
+```bash
+olive run --config sam_mask_decoder_qnn_fp16_ctx.json
+```
+
+### Model ORT Execution
+
+Execute the SAM model in the **AOT Compilation Python Environment** using the following command:
+
+```bash
+python sam_mask_generator.py --model_ve path/to/encoder_model.onnx --model_md path/to/decoder_model.onnx --image_path car.png --box_x 40 --box_y 235 --box_w 940 --box_h 490 --output_path car_mask.png
+```
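
Before running the generator, it can be useful to confirm that both compiled models load and expose the expected inputs. A minimal sanity-check sketch, assuming onnxruntime falls back to CPU when the QNN execution provider is absent from the local build; the model paths are placeholders:

```python
import onnxruntime as ort

# Prefer QNN but keep CPU as a fallback so the check also runs off-device.
wanted = ("QNNExecutionProvider", "CPUExecutionProvider")
providers = [p for p in wanted if p in ort.get_available_providers()]

# Placeholder paths; substitute the Olive output models.
for path in ("path/to/encoder_model.onnx", "path/to/decoder_model.onnx"):
    sess = ort.InferenceSession(path, providers=providers)
    print(path, [(i.name, i.shape, i.type) for i in sess.get_inputs()])
```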
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+{
+    "copies": [
+        {
+            "src": "../../intel-bert-base-uncased-mrpc/aitk/winml.py",
+            "dst": "winml.py"
+        }
+    ]
+}
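
This copy spec declares files to mirror into the project, with `src` resolved relative to the config's own directory. A plausible minimal interpreter for the format; the `apply_copies` name and the resolution rule are assumptions, not code from this commit:

```python
import json
import shutil
from pathlib import Path

def apply_copies(config_path: str) -> None:
    """Copy each declared src file to dst, both resolved against the config's directory."""
    base = Path(config_path).parent
    spec = json.loads(Path(config_path).read_text(encoding="utf-8"))
    for entry in spec["copies"]:
        src = (base / entry["src"]).resolve()
        shutil.copyfile(src, base / entry["dst"])
```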

sam-vit-base/aitk/config.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+class ModelConfig:
+    model_name = "facebook/sam-vit-base"
+    data_dir = "quantization_dataset"
+    image_dataset = "nielsr/coco-panoptic-val2017"
+    image_dataset_split = "train"
+    ve_input_name = "pixel_values"
+    ve_sample_size = 1024
+    ve_channel_size = 3
+    mask_point_input_names = ("input_points", "image_embeddings")
+    mask_point_input_shapes = ((1, 1, 2), (256, 64, 64))
+    mask_box_input_names = ("input_boxes", "image_embeddings")
+    mask_box_input_shapes = ((1, 4), (256, 64, 64))
+    mask_input_names = ("input_points", "input_labels", "image_embeddings")
+    mask_input_shapes = ((1, 2, 2), (1, 2), (256, 64, 64))
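
The name/shape tuples in `ModelConfig` are enough to fabricate random inputs for smoke-testing an exported decoder. A sketch assuming float32 tensors (the actual quantized models may expect float16, as the notebook below checks at runtime):

```python
import numpy as np

def dummy_inputs(names, shapes, dtype=np.float32):
    """Build an ORT feed dict of random tensors from ModelConfig's tuples."""
    return {n: np.random.rand(*s).astype(dtype) for n, s in zip(names, shapes)}

# Box-based decoder inputs, per ModelConfig above:
feeds = dummy_inputs(("input_boxes", "image_embeddings"), ((1, 4), (256, 64, 64)))
# e.g. outputs = session.run(None, feeds) against the exported decoder
```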
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "547a25de",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoder_path = \"./model/encoder/model.onnx\"\n",
+    "decoder_path = \"./model/decoder/model.onnx\"\n",
+    "\n",
+    "ExecutionProvider=\"QNNExecutionProvider\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eed9c231",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# reference: https://learn.microsoft.com/en-us/windows/ai/new-windows-ml/tutorial?tabs=python#acquiring-the-model-and-preprocessing\n",
+    "import subprocess\n",
+    "import json\n",
+    "import sys\n",
+    "import os\n",
+    "import onnxruntime as ort\n",
+    "\n",
+    "def register_execution_providers():\n",
+    "    worker_script = os.path.abspath('winml.py')\n",
+    "    print(worker_script)\n",
+    "    result = subprocess.check_output([sys.executable, worker_script], text=True)\n",
+    "    paths = json.loads(result)\n",
+    "    for item in paths.items():\n",
+    "        try:\n",
+    "            ort.register_execution_provider_library(item[0], item[1])\n",
+    "        except Exception as e:\n",
+    "            print(f\"Failed to register execution provider {item[0]}: {e}\")\n",
+    "\n",
+    "register_execution_providers()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9e9d8984",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from urllib import request\n",
+    "\n",
+    "test_image_url = \"https://github.com/facebookresearch/segment-anything/blob/main/notebooks/images/truck.jpg?raw=true\"\n",
+    "test_image_name = \"truck.jpg\"\n",
+    "\n",
+    "request.urlretrieve(test_image_url, test_image_name)\n",
+    "\n",
+    "from IPython.display import Image, display\n",
+    "\n",
+    "display(Image(filename=test_image_name))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ffcd22ad",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "from pathlib import Path\n",
+    "\n",
+    "NOTEBOOK_DIR = Path(__file__).parent if \"__file__\" in globals() else Path.cwd()\n",
+    "PROJECT_ROOT = NOTEBOOK_DIR.parents[1]\n",
+    "sys.path.insert(0, str(PROJECT_ROOT))\n",
+    "\n",
+    "import numpy as np\n",
+    "from PIL import Image\n",
+    "from sam_mask_generator import get_mask_ort\n",
+    "\n",
+    "def add_ep_for_device(session_options, ep_name, device_type, ep_options=None):\n",
+    "    ep_devices = ort.get_ep_devices()\n",
+    "    for ep_device in ep_devices:\n",
+    "        if ep_device.ep_name == ep_name and ep_device.device.type == device_type:\n",
+    "            print(f\"Adding {ep_name} for {device_type}\")\n",
+    "            session_options.add_provider_for_devices([ep_device], {} if ep_options is None else ep_options)\n",
+    "            break\n",
+    "\n",
+    "\n",
+    "sess_options = ort.SessionOptions()\n",
+    "\n",
+    "add_ep_for_device(sess_options, ExecutionProvider, ort.OrtHardwareDeviceType.CPU)\n",
+    "\n",
+    "# Load image\n",
+    "raw_image = Image.open(test_image_name).convert(\"RGB\")\n",
+    "input_box = [[[100, 300], [1750, 900]]]\n",
+    "\n",
+    "# Load models\n",
+    "sess_ve = ort.InferenceSession(encoder_path, sess_options=sess_options)\n",
+    "sess_md = ort.InferenceSession(decoder_path, sess_options=sess_options)\n",
+    "\n",
+    "sess_ve_inputs = sess_ve.get_inputs()\n",
+    "sess_md_inputs = sess_md.get_inputs()\n",
+    "\n",
+    "ve_dtype = np.float32 if sess_ve_inputs[0].type == 'tensor(float)' else np.float16\n",
+    "md_dtype = np.float32 if sess_md_inputs[0].type == 'tensor(float)' else np.float16\n",
+    "\n",
+    "# Get mask\n",
+    "mask = get_mask_ort(sess_ve, sess_md, raw_image, input_box, ve_dtype, md_dtype, sess_ve_inputs, sess_md_inputs)\n",
+    "\n",
+    "# Save mask using PIL\n",
+    "mask_img = Image.fromarray(mask * 255) # Convert binary mask to 0-255\n",
+    "\n",
+    "from IPython.display import display\n",
+    "display(mask_img)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
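
The notebook displays the raw binary mask by itself; a common follow-up is compositing the mask over the source image. A short PIL-only sketch; the red tint and 40% alpha are arbitrary choices, and the mask is assumed to match the image's height and width:

```python
import numpy as np
from PIL import Image

def overlay_mask(image: Image.Image, mask: np.ndarray, alpha: int = 102) -> Image.Image:
    """Blend a binary (H, W) mask over an RGB image as a translucent red layer."""
    overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
    overlay[mask.astype(bool)] = (255, 0, 0, alpha)  # red at ~40% opacity
    return Image.alpha_composite(image.convert("RGBA"), Image.fromarray(overlay, "RGBA"))

# overlay_mask(raw_image, mask).save("truck_overlay.png")
```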
