Skip to content

Commit 19090d0

Browse files
Port SAM3 from inference/models to inference-models (#1946)
1 parent a8f1126 commit 19090d0

19 files changed

Lines changed: 3271 additions & 55 deletions

inference/core/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.2.9"
1+
__version__ = "1.2.10"
22

33

44
if __name__ == "__main__":
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
from time import perf_counter
2+
from typing import Any, Dict, List, Optional, Tuple
3+
4+
import numpy as np
5+
import torch
6+
from pycocotools import mask as mask_utils
7+
8+
from inference.core.entities.requests.inference import InferenceRequestImage
9+
from inference.core.entities.requests.sam3 import (
10+
Sam3InferenceRequest,
11+
Sam3Prompt,
12+
Sam3SegmentationRequest,
13+
)
14+
from inference.core.entities.responses.sam3 import (
15+
Sam3PromptEcho,
16+
Sam3PromptResult,
17+
Sam3SegmentationPrediction,
18+
Sam3SegmentationResponse,
19+
)
20+
from inference.core.env import (
21+
ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
22+
ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
23+
API_KEY,
24+
DEVICE,
25+
DISABLED_INFERENCE_MODELS_BACKENDS,
26+
VALID_INFERENCE_MODELS_BACKENDS,
27+
)
28+
from inference.core.models.base import Model
29+
from inference.core.roboflow_api import get_extra_weights_provider_headers
30+
from inference.core.utils.image_utils import load_image_rgb
31+
from inference.core.utils.postprocess import masks2multipoly
32+
from inference.usage_tracking.collector import usage_collector
33+
from inference_models import AutoModel
34+
from inference_models.models.sam3.sam3_torch import SAM3Torch
35+
36+
# When the imported DEVICE setting is unset, pick the first CUDA device if
# torch reports CUDA as available; otherwise fall back to the CPU.
if DEVICE is None:
    if torch.cuda.is_available():
        DEVICE = "cuda:0"
    else:
        DEVICE = "cpu"
38+
39+
40+
class InferenceModelsSAM3Adapter(Model):
    """Adapter wrapping inference_models SAM3Torch for open-vocabulary segmentation.

    Replaces inference.models.sam3.segment_anything3.SegmentAnything3.
    Handles Sam3SegmentationRequest with text and/or visual (box) prompts via
    SAM3Torch.segment_with_text_prompts.
    """

    def __init__(
        self,
        *args,
        model_id: str = "sam3/sam3_final",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """Load the SAM3 model through ``AutoModel.from_pretrained``.

        Args:
            *args: Accepted for signature compatibility; not used here.
            model_id: Model identifier or path handed to ``AutoModel``.
            api_key: API key used when loading the model; ``API_KEY`` is used
                when this is empty or ``None``.
            **kwargs: Forwarded unchanged to ``AutoModel.from_pretrained``.
                ``countinference`` and ``service_secret`` are additionally read
                (not removed) to build the extra weights-provider headers.
        """
        super().__init__()
        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
        self.api_key = api_key if api_key else API_KEY
        self.task_type = "unsupervised-segmentation"

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Every valid backend except the explicitly disabled ones.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: SAM3Torch = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )

    @usage_collector("model")
    def infer_from_request(self, request: Sam3InferenceRequest):
        """Dispatch a SAM3 request to the matching handler.

        Args:
            request: The incoming request; only ``Sam3SegmentationRequest`` is
                supported.

        Returns:
            The ``Sam3SegmentationResponse`` produced by ``segment_image``.

        Raises:
            ValueError: If ``request`` is not a ``Sam3SegmentationRequest``.
        """
        t1 = perf_counter()
        if isinstance(request, Sam3SegmentationRequest):
            # Compare against None explicitly: an explicit threshold of 0.0 is
            # falsy and must not be silently replaced by the 0.5 default.
            output_prob_thresh = request.output_prob_thresh
            if output_prob_thresh is None:
                output_prob_thresh = 0.5
            return self.segment_image(
                image=request.image,
                prompts=request.prompts,
                output_prob_thresh=output_prob_thresh,
                format=request.format or "polygon",
                nms_iou_threshold=request.nms_iou_threshold,
                inference_start_timestamp=t1,
            )
        raise ValueError(f"Invalid request type {type(request)}")

    def segment_image(
        self,
        image: InferenceRequestImage,
        prompts: List[Sam3Prompt],
        output_prob_thresh: float = 0.5,
        format: str = "polygon",
        nms_iou_threshold: Optional[float] = None,
        inference_start_timestamp: Optional[float] = None,
    ) -> Sam3SegmentationResponse:
        """Segment one image with the given prompts.

        Args:
            image: Request image, decoded with ``load_image_rgb``.
            prompts: Prompts to run; a prompt may carry its own
                ``output_prob_thresh`` that overrides the default.
            output_prob_thresh: Default score threshold for prompts that do not
                define their own.
            format: Output format passed to ``_masks_to_predictions``.
            nms_iou_threshold: When not ``None`` (and there is at least one
                prompt), cross-prompt NMS is applied with this IoU threshold.
            inference_start_timestamp: ``perf_counter`` value marking the start
                of the request; taken now when omitted.

        Returns:
            A ``Sam3SegmentationResponse`` with the elapsed time and one
            ``Sam3PromptResult`` per prompt, in prompt order.
        """
        if inference_start_timestamp is None:
            inference_start_timestamp = perf_counter()
        np_image = load_image_rgb(image)

        # Effective threshold per prompt: its own value, else the default.
        thresholds: List[float] = []
        for p in prompts:
            own_thresh = getattr(p, "output_prob_thresh", None)
            thresholds.append(
                output_prob_thresh if own_thresh is None else own_thresh
            )

        # The backend applies a single threshold floor; use the min so the
        # per-prompt thresholds applied below can still refine higher values.
        min_threshold = min([output_prob_thresh, *thresholds])

        prompt_dicts = [_sam3_prompt_to_dict(p) for p in prompts]

        # segment_with_text_prompts returns List[per-image] of List[per-prompt] dicts
        # with keys: prompt_index, masks (N,H,W ndarray), scores (list).
        per_image_results = self._model.segment_with_text_prompts(
            images=[np_image],
            prompts=prompt_dicts,
            output_prob_thresh=float(min_threshold),
        )
        per_prompt = per_image_results[0]

        # processed: prompt_idx -> {"masks": ndarray, "scores": list}
        processed: Dict[int, Dict[str, Any]] = {}
        for idx, r in enumerate(per_prompt):
            processed[idx] = {
                "masks": r.get("masks"),
                "scores": list(r.get("scores", [])),
            }

        prompt_results: List[Sam3PromptResult] = []
        if nms_iou_threshold is not None and len(prompts) > 0:
            all_masks = _collect_masks_with_per_prompt_threshold(
                processed=processed,
                prompts=prompts,
                default_threshold=output_prob_thresh,
            )
            if len(all_masks) > 0:
                all_masks = _apply_nms_cross_prompt(all_masks, nms_iou_threshold)
            regrouped = _regroup_masks_by_prompt(all_masks, len(prompts))

            for idx, p in enumerate(prompts):
                echo = _build_echo(idx, p)
                bucket = regrouped.get(idx, [])
                if bucket:
                    masks_np = np.stack([m for m, _ in bucket], axis=0)
                    scores = [s for _, s in bucket]
                else:
                    masks_np = np.zeros((0, 0, 0), dtype=np.uint8)
                    scores = []
                preds = _masks_to_predictions(masks_np, scores, format)
                prompt_results.append(
                    Sam3PromptResult(prompt_index=idx, echo=echo, predictions=preds)
                )
        else:
            for idx, p in enumerate(prompts):
                masks_np = _to_numpy_masks(processed[idx]["masks"])
                scores = processed[idx]["scores"]
                # Always filter by the prompt's effective threshold. The
                # backend floor may have been lowered by another prompt, so a
                # prompt without its own threshold must still honour the
                # default -- the same rule the NMS path applies.
                masks_np, scores = _filter_by_threshold(
                    masks_np, scores, thresholds[idx]
                )
                preds = _masks_to_predictions(masks_np, scores, format)
                prompt_results.append(
                    Sam3PromptResult(
                        prompt_index=idx,
                        echo=_build_echo(idx, p),
                        predictions=preds,
                    )
                )

        return Sam3SegmentationResponse(
            time=perf_counter() - inference_start_timestamp,
            prompt_results=prompt_results,
        )
180+
181+
182+
def _sam3_prompt_to_dict(p: Sam3Prompt) -> Dict[str, Any]:
    """Convert a prompt into the plain dict passed to the backend.

    The result always contains ``"text"``. ``"boxes"`` and ``"box_labels"``
    are added only when the prompt has boxes; missing labels become ``[]``.
    """
    if not p.boxes:
        return {"text": p.text}
    return {
        "text": p.text,
        # backend's _build_visual_query handles pydantic Box/BoxXYXY
        "boxes": p.boxes,
        "box_labels": p.box_labels or [],
    }
190+
191+
192+
def _build_echo(prompt_index: int, p: Sam3Prompt) -> Sam3PromptEcho:
    """Build the echo record describing prompt ``p`` at ``prompt_index``.

    A prompt with at least one box is reported as ``"visual"`` with its box
    count; otherwise it is reported as ``"text"`` with zero boxes.
    """
    if p.boxes:
        prompt_type, box_count = "visual", len(p.boxes)
    else:
        prompt_type, box_count = "text", 0
    return Sam3PromptEcho(
        prompt_index=prompt_index,
        type=prompt_type,
        text=p.text,
        num_boxes=box_count,
    )
200+
201+
202+
def _to_numpy_masks(masks_any) -> np.ndarray:
    """Normalise backend mask output to a ``uint8`` array of rank 3.

    Accepts ``None``, an object exposing ``detach()`` (converted via
    ``.detach().cpu().numpy()``), or an iterable of such objects / array-likes.
    ``None`` and empty iterables yield an empty ``(0, 0, 0)`` array. A 4-D
    result with a singleton second axis has that axis dropped, and a 2-D
    result gains a leading axis; other shapes are returned as they are.
    """

    def _as_uint8(item) -> np.ndarray:
        # Objects exposing ``detach`` are moved to host memory first.
        if hasattr(item, "detach"):
            return item.detach().cpu().numpy().astype(np.uint8)
        return np.asarray(item, dtype=np.uint8)

    if masks_any is None:
        return np.zeros((0, 0, 0), dtype=np.uint8)

    if hasattr(masks_any, "detach"):
        stacked = _as_uint8(masks_any)
    else:
        converted = [_as_uint8(item) for item in masks_any]
        if not converted:
            return np.zeros((0, 0, 0), dtype=np.uint8)
        stacked = np.stack(converted, axis=0)

    if stacked.ndim == 2:
        return stacked[None, ...]
    if stacked.ndim == 4 and stacked.shape[1] == 1:
        return stacked[:, 0, ...]
    return stacked
222+
223+
224+
def _filter_by_threshold(
    masks_np: np.ndarray,
    scores: List[float],
    threshold: float,
) -> Tuple[np.ndarray, List[float]]:
    """Keep only masks whose score is at least ``threshold``.

    Inputs that are not rank 3, or that hold no masks, are returned untouched.
    When nothing survives, an empty ``(0, 0, 0)`` ``uint8`` array and an empty
    score list are returned.
    """
    has_masks = masks_np.ndim == 3 and masks_np.shape[0] > 0
    if not has_masks:
        return masks_np, scores

    surviving = [i for i, score in enumerate(scores) if score >= threshold]
    if surviving:
        return masks_np[surviving], [scores[i] for i in surviving]
    return np.zeros((0, 0, 0), dtype=np.uint8), []
235+
236+
237+
def _masks_to_predictions(
    masks_np: np.ndarray, scores: List[float], fmt: str
) -> List[Sam3SegmentationPrediction]:
    """Turn a stack of masks and their scores into prediction objects.

    ``"polygon"`` and ``"json"`` produce polygon predictions via
    ``masks2multipoly``; ``"rle"`` produces pycocotools RLE dicts with the
    ``counts`` field decoded to ``str``. Any other format, a stack that is not
    rank 3, or a stack with a zero-sized axis yields an empty list.
    """
    if masks_np.ndim != 3 or 0 in masks_np.shape:
        return []

    if fmt in ("polygon", "json"):
        binary_stack = (masks_np > 0).astype(np.uint8)
        polygons = masks2multipoly(binary_stack)
        return [
            Sam3SegmentationPrediction(
                masks=[ring.tolist() for ring in poly],
                confidence=float(score),
                format="polygon",
            )
            for poly, score in zip(polygons, scores[: len(polygons)])
        ]

    if fmt == "rle":
        encoded: List[Sam3SegmentationPrediction] = []
        for mask, score in zip(masks_np, scores[: masks_np.shape[0]]):
            binary = (mask > 0).astype(np.uint8)
            # pycocotools encodes Fortran-ordered arrays and returns bytes counts.
            rle = mask_utils.encode(np.asfortranarray(binary))
            rle["counts"] = rle["counts"].decode("utf-8")
            encoded.append(
                Sam3SegmentationPrediction(
                    masks=rle, confidence=float(score), format="rle"
                )
            )
        return encoded

    return []
264+
265+
266+
def _collect_masks_with_per_prompt_threshold(
    processed: Dict[int, Dict[str, Any]],
    prompts: List[Sam3Prompt],
    default_threshold: float,
) -> List[Tuple[int, np.ndarray, float]]:
    """Flatten per-prompt results into ``(prompt_idx, mask, score)`` tuples.

    Each prompt is filtered by its own ``output_prob_thresh`` when it has one,
    otherwise by ``default_threshold``. Prompts whose masks are not a rank-3
    stack, or are empty, contribute nothing.
    """
    collected: List[Tuple[int, np.ndarray, float]] = []
    for prompt_idx, prompt in enumerate(prompts):
        own_thresh = getattr(prompt, "output_prob_thresh", None)
        cutoff = default_threshold if own_thresh is None else own_thresh

        entry = processed[prompt_idx]
        masks_np = _to_numpy_masks(entry["masks"])
        scores = entry["scores"]
        if masks_np.ndim != 3 or 0 in masks_np.shape:
            continue

        collected.extend(
            (prompt_idx, mask, float(score))
            for mask, score in zip(masks_np, scores)
            if score >= cutoff
        )
    return collected
284+
285+
286+
def _nms_greedy_pycocotools(
    rles: List[Dict],
    confidences: np.ndarray,
    iou_threshold: float = 0.5,
) -> np.ndarray:
    """Greedy mask NMS using pycocotools IoU.

    Detections are visited from highest to lowest confidence; each detection
    still kept suppresses every lower-ranked one whose IoU with it exceeds
    ``iou_threshold``.

    Returns:
        Boolean keep-mask aligned with the original order of ``rles``.
    """
    total = len(rles)
    if total == 0:
        return np.array([], dtype=bool)

    order = np.argsort(confidences)[::-1]
    ranked_rles = [rles[i] for i in order]
    # Third argument is the per-detection ``iscrowd`` flag list (all zero).
    pairwise_iou = mask_utils.iou(ranked_rles, ranked_rles, [0] * total)
    overlaps = pairwise_iou > iou_threshold

    survivors = np.ones(total, dtype=bool)
    for rank in range(total):
        if not survivors[rank]:
            continue
        # Drop every lower-ranked detection overlapping this kept one.
        survivors[rank + 1 :] &= ~overlaps[rank, rank + 1 :]

    # argsort of the ordering is its inverse permutation: back to input order.
    return survivors[np.argsort(order)]
303+
304+
305+
def _apply_nms_cross_prompt(
    all_masks: List[Tuple[int, np.ndarray, float]],
    iou_threshold: float,
) -> List[Tuple[int, np.ndarray, float]]:
    """Run mask NMS over detections pooled from all prompts.

    Masks are binarised (``> 0``), RLE-encoded with pycocotools and passed to
    ``_nms_greedy_pycocotools``. Surviving entries are returned in their
    original order; an empty input is returned as is.
    """
    if not all_masks:
        return all_masks

    rles = [
        mask_utils.encode(np.asfortranarray((mask_np > 0).astype(np.uint8)))
        for _, mask_np, _ in all_masks
    ]
    confidences = np.array([entry[2] for entry in all_masks])
    keep = _nms_greedy_pycocotools(rles, confidences, iou_threshold)
    return [entry for entry, kept in zip(all_masks, keep) if kept]
319+
320+
321+
def _regroup_masks_by_prompt(
    filtered_masks: List[Tuple[int, np.ndarray, float]],
    num_prompts: int,
) -> Dict[int, List[Tuple[np.ndarray, float]]]:
    """Group ``(prompt_idx, mask, score)`` tuples back by prompt index.

    Every index in ``range(num_prompts)`` is present in the result, with an
    empty list when the prompt has no masks. Order within a prompt follows the
    input order.
    """
    buckets: Dict[int, List[Tuple[np.ndarray, float]]] = {}
    for index in range(num_prompts):
        buckets[index] = []
    for owner_idx, mask_np, score in filtered_masks:
        pair = (mask_np, score)
        buckets[owner_idx].append(pair)
    return buckets

inference/models/sam3/visual_segmentation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,11 +446,14 @@ def find_prior_prompt_in_cache(
446446
"""
447447
Performs search over the cache to see if prior used prompts are subset of this one.
448448
"""
449+
num_points = initial_prompt_set.num_points()
450+
if num_points <= 1:
451+
return None # there is only 1 point, hence no prior prompt can be found
449452

450453
logits_for_image = [cache[k] for k in cache if k[0] == image_id]
451454
maxed_size = 0
452455
best_match: Optional[np.ndarray] = None
453-
desired_size = initial_prompt_set.num_points() - 1
456+
desired_size = num_points - 1
454457
for cached_dict in logits_for_image[::-1]:
455458
logits = cached_dict["logits"]
456459
prompt_set: Sam2PromptSet = cached_dict["prompt_set"]

0 commit comments

Comments
 (0)