Skip to content

Commit b83a7b7

Browse files
committed
fix: fix sam3 infer + add convert/infer scripts
1 parent ff0701d commit b83a7b7

5 files changed

Lines changed: 124 additions & 12 deletions

File tree

convert_sam3.sh

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/bin/bash
2+
# Export SAM3 ViT-H to ONNX.
3+
#
4+
# Requirements:
5+
# - sam3 submodule initialised: git submodule update --init sam3
6+
# - osam installed for CLIP tokenisation: pip install osam
7+
#
8+
# Optional: set SIMPLIFY=1 in the environment to run onnxsim after export (reduces some
9+
# redundant ops; vision_pos_enc_0/1 may be removed from the decoder).
10+
11+
set -euo pipefail
12+
13+
OUTPUT_DIR="${1:-output_models/sam3}"
14+
SIMPLIFY="${SIMPLIFY:-}"
15+
16+
echo "Exporting SAM3 ViT-H to ONNX → $OUTPUT_DIR"
17+
18+
if [ -n "$SIMPLIFY" ]; then
19+
python -m samexporter.export_sam3 \
20+
--output_dir "$OUTPUT_DIR" \
21+
--opset 18 \
22+
--simplify
23+
else
24+
python -m samexporter.export_sam3 \
25+
--output_dir "$OUTPUT_DIR" \
26+
--opset 18
27+
fi
28+
29+
echo "Done – models written to $OUTPUT_DIR/"
30+
echo " sam3_image_encoder.onnx"
31+
echo " sam3_language_encoder.onnx"
32+
echo " sam3_decoder.onnx"

infer_sam3.sh

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/bin/bash
2+
# SAM3 inference examples.
3+
#
4+
# SAM3 supports three prompt modes:
5+
# 1. Text only – open-vocabulary detection (no geometric hint needed)
6+
# 2. Text + point – refine detection around a clicked pixel
7+
# 3. Text + rectangle – constrain detection to a bounding box
8+
#
9+
# The --text_prompt flag drives the language encoder; always supply it for
10+
# best results.  If omitted (and the prompt JSON has no text entry) the model falls back to "visual" (no language
11+
# guidance) which may return fewer or no detections.
12+
13+
set -euo pipefail
14+
15+
ENC="output_models/sam3/sam3_image_encoder.onnx"
16+
DEC="output_models/sam3/sam3_decoder.onnx"
17+
LANG="output_models/sam3/sam3_language_encoder.onnx"
18+
IMG="images/truck.jpg"
19+
20+
echo "--- SAM3: text-only prompt ('truck') ---"
21+
python -m samexporter.inference \
22+
--encoder_model "$ENC" \
23+
--decoder_model "$DEC" \
24+
--language_encoder_model "$LANG" \
25+
--image "$IMG" \
26+
--prompt images/truck_sam3.json \
27+
--text_prompt "truck" \
28+
--output output_images/sam3_truck_text.png \
29+
--sam_variant sam3
30+
31+
echo "--- SAM3: text + rectangle prompt ---"
32+
python -m samexporter.inference \
33+
--encoder_model "$ENC" \
34+
--decoder_model "$DEC" \
35+
--language_encoder_model "$LANG" \
36+
--image "$IMG" \
37+
--prompt images/truck_sam3_box.json \
38+
--text_prompt "truck" \
39+
--output output_images/sam3_truck_box.png \
40+
--sam_variant sam3
41+
42+
echo "--- SAM3: text + point prompt ---"
43+
python -m samexporter.inference \
44+
--encoder_model "$ENC" \
45+
--decoder_model "$DEC" \
46+
--language_encoder_model "$LANG" \
47+
--image "$IMG" \
48+
--prompt images/truck_sam3_point.json \
49+
--text_prompt "truck" \
50+
--output output_images/sam3_truck_point.png \
51+
--sam_variant sam3
52+
53+
echo "Done – outputs saved to output_images/"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"onnxsim==0.5.0",
3232
"numpy==1.26.4",
3333
"onnxscript==0.6.2",
34+
"osam", # CLIP tokeniser for SAM3 text prompts
3435
]
3536

3637
[project.urls]

samexporter/inference.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def str2bool(v):
3636
default=None,
3737
help="Path to the ONNX language encoder model (for SAM3)",
3838
)
39+
argparser.add_argument(
40+
"--text_prompt",
41+
type=str,
42+
default=None,
43+
help="Text prompt for SAM3 (e.g. 'truck'). Overrides any text entry in the prompt JSON.",
44+
)
3945
argparser.add_argument(
4046
"--image",
4147
type=str,
@@ -90,11 +96,14 @@ def str2bool(v):
9096

9197
text_prompt = None
9298
if args.sam_variant == "sam3":
93-
# Extract text prompt from JSON if available, otherwise default to "visual"
94-
for p in prompt:
95-
if p["type"] == "text":
96-
text_prompt = p["data"]
97-
break
99+
# --text_prompt takes priority; fall back to any text entry in the JSON.
100+
if args.text_prompt:
101+
text_prompt = args.text_prompt
102+
else:
103+
for p in prompt:
104+
if p["type"] == "text":
105+
text_prompt = p["data"]
106+
break
98107
if text_prompt is None:
99108
text_prompt = "visual"
100109

@@ -109,14 +118,14 @@ def str2bool(v):
109118
# Merge masks
110119
mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
111120
if args.sam_variant == "sam3":
112-
# SAM3 returns (N, 1, H, W) – render all N detected instances.
121+
# SAM3 returns bool (N, 1, H, W) – render all N detected instances.
113122
for i in range(masks.shape[0]):
114-
m = masks[i, 0] # (H, W)
115-
mask[m > 0.5] = [255, 0, 0]
123+
m = masks[i, 0] # (H, W) bool
124+
mask[m] = [255, 0, 0]
116125
else:
117-
# SAM1/SAM2 returns (1, 3, H, W) – merge all quality levels.
126+
# SAM1/SAM2 return raw logits (1, 3, H, W) – threshold at 0 (= sigmoid 0.5).
118127
for m in masks[0, :, :, :]:
119-
mask[m > 0.5] = [255, 0, 0]
128+
mask[m > 0.0] = [255, 0, 0]
120129

121130
# Binding image and mask
122131
visualized = cv2.addWeighted(image, 0.5, mask, 0.5, 0)

samexporter/sam3_onnx.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@ def encode(self, cv_image: np.ndarray, text_prompt=None) -> dict[str, Any]:
6868

6969
return embedding
7070

71-
def predict_masks(self, embedding: dict[str, Any], prompt) -> np.ndarray:
71+
def predict_masks(
72+
self,
73+
embedding: dict[str, Any],
74+
prompt,
75+
confidence_threshold: float = 0.5,
76+
) -> np.ndarray:
7277
"""Run the decoder for the given geometric prompt.
7378
7479
Parameters
@@ -78,6 +83,9 @@ def predict_masks(self, embedding: dict[str, Any], prompt) -> np.ndarray:
7883
prompt:
7984
List of mark dicts, each with keys ``"type"`` (``"rectangle"``
8085
or ``"point"``) and ``"data"``.
86+
confidence_threshold:
87+
Score cutoff for keeping a detection.  Detections with score at or below
88+
this value are discarded. Defaults to ``0.5``.
8189
8290
Returns
8391
-------
@@ -114,7 +122,7 @@ def predict_masks(self, embedding: dict[str, Any], prompt) -> np.ndarray:
114122
box_labels_np = np.array([box_labels], dtype=np.int64)
115123
box_masks_np = np.array([box_masks], dtype=np.bool_)
116124

117-
masks, _scores, _boxes = self.decoder(
125+
masks, scores, _boxes = self.decoder(
118126
original_size,
119127
embedding["vision_pos_enc_0"],
120128
embedding["vision_pos_enc_1"],
@@ -130,6 +138,15 @@ def predict_masks(self, embedding: dict[str, Any], prompt) -> np.ndarray:
130138
box_masks_np,
131139
)
132140

141+
# Filter detections by confidence score.
142+
if len(scores) > 0:
143+
keep = np.where(scores > confidence_threshold)[0]
144+
masks = (
145+
masks[keep]
146+
if len(keep) > 0
147+
else np.zeros((0,) + masks.shape[1:], dtype=masks.dtype)
148+
)
149+
133150
return masks
134151

135152
def transform_masks(self, masks, original_size, transform_matrix):

0 commit comments

Comments
 (0)