|
| 1 | +from time import perf_counter |
| 2 | +from typing import Any, Dict, List, Optional, Tuple |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +from pycocotools import mask as mask_utils |
| 7 | + |
| 8 | +from inference.core.entities.requests.inference import InferenceRequestImage |
| 9 | +from inference.core.entities.requests.sam3 import ( |
| 10 | + Sam3InferenceRequest, |
| 11 | + Sam3Prompt, |
| 12 | + Sam3SegmentationRequest, |
| 13 | +) |
| 14 | +from inference.core.entities.responses.sam3 import ( |
| 15 | + Sam3PromptEcho, |
| 16 | + Sam3PromptResult, |
| 17 | + Sam3SegmentationPrediction, |
| 18 | + Sam3SegmentationResponse, |
| 19 | +) |
| 20 | +from inference.core.env import ( |
| 21 | + ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES, |
| 22 | + ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, |
| 23 | + API_KEY, |
| 24 | + DEVICE, |
| 25 | + DISABLED_INFERENCE_MODELS_BACKENDS, |
| 26 | + VALID_INFERENCE_MODELS_BACKENDS, |
| 27 | +) |
| 28 | +from inference.core.models.base import Model |
| 29 | +from inference.core.roboflow_api import get_extra_weights_provider_headers |
| 30 | +from inference.core.utils.image_utils import load_image_rgb |
| 31 | +from inference.core.utils.postprocess import masks2multipoly |
| 32 | +from inference.usage_tracking.collector import usage_collector |
| 33 | +from inference_models import AutoModel |
| 34 | +from inference_models.models.sam3.sam3_torch import SAM3Torch |
| 35 | + |
# No device configured through the environment: prefer the first CUDA GPU
# when torch can see one, otherwise fall back to the CPU.
if DEVICE is None:
    if torch.cuda.is_available():
        DEVICE = "cuda:0"
    else:
        DEVICE = "cpu"
| 38 | + |
| 39 | + |
class InferenceModelsSAM3Adapter(Model):
    """Adapter wrapping inference_models SAM3Torch for open-vocabulary segmentation.

    Replaces inference.models.sam3.segment_anything3.SegmentAnything3.
    Handles Sam3SegmentationRequest with text and/or visual (box) prompts via
    SAM3Torch.segment_with_text_prompts.
    """

    def __init__(
        self,
        *args,
        model_id: str = "sam3/sam3_final",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """Load the SAM3 model through ``AutoModel.from_pretrained``.

        Args:
            *args: Accepted for signature compatibility; not used here.
            model_id: Identifier (or path) handed to ``AutoModel.from_pretrained``.
            api_key: API key used for loading; falls back to the ``API_KEY``
                environment setting when falsy.
            **kwargs: ``countinference`` and ``service_secret`` are read to
                build the extra weights-provider headers; the complete mapping
                is also forwarded unchanged to ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0}
        self.api_key = api_key if api_key else API_KEY
        self.task_type = "unsupervised-segmentation"

        extra_weights_provider_headers = get_extra_weights_provider_headers(
            countinference=kwargs.get("countinference"),
            service_secret=kwargs.get("service_secret"),
        )
        # Every valid backend that has not been explicitly disabled.
        backend = list(
            VALID_INFERENCE_MODELS_BACKENDS.difference(
                DISABLED_INFERENCE_MODELS_BACKENDS
            )
        )
        self._model: SAM3Torch = AutoModel.from_pretrained(
            model_id_or_path=model_id,
            api_key=self.api_key,
            allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
            allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
            weights_provider_extra_headers=extra_weights_provider_headers,
            backend=backend,
            **kwargs,
        )

    @usage_collector("model")
    def infer_from_request(self, request: Sam3InferenceRequest):
        """Dispatch a SAM3 request to the matching handler.

        Args:
            request: Incoming request; only ``Sam3SegmentationRequest`` is
                supported.

        Returns:
            The ``Sam3SegmentationResponse`` produced by ``segment_image``.

        Raises:
            ValueError: If ``request`` is not a ``Sam3SegmentationRequest``.
        """
        # Started before dispatch so the reported time covers the whole call.
        t1 = perf_counter()
        if isinstance(request, Sam3SegmentationRequest):
            return self.segment_image(
                image=request.image,
                prompts=request.prompts,
                # NOTE(review): `or` treats an explicit 0.0 threshold (and an
                # empty format string) as unset; kept as-is for parity.
                output_prob_thresh=request.output_prob_thresh or 0.5,
                format=request.format or "polygon",
                nms_iou_threshold=request.nms_iou_threshold,
                inference_start_timestamp=t1,
            )
        raise ValueError(f"Invalid request type {type(request)}")

    def segment_image(
        self,
        image: InferenceRequestImage,
        prompts: List[Sam3Prompt],
        output_prob_thresh: float = 0.5,
        format: str = "polygon",
        nms_iou_threshold: Optional[float] = None,
        inference_start_timestamp: Optional[float] = None,
    ) -> Sam3SegmentationResponse:
        """Segment one image for every prompt and package the results.

        Args:
            image: Request image, decoded with ``load_image_rgb``.
            prompts: Prompts to run; each may carry its own
                ``output_prob_thresh`` overriding the default.
            output_prob_thresh: Default score threshold for prompts that do
                not define their own.
            format: Output encoding passed to ``_masks_to_predictions``
                (``"polygon"``/``"json"`` or ``"rle"``).
            nms_iou_threshold: When not ``None``, masks from all prompts are
                pooled and de-duplicated with cross-prompt NMS at this IoU.
            inference_start_timestamp: ``perf_counter`` value marking the
                start of the request; taken now when omitted.

        Returns:
            A ``Sam3SegmentationResponse`` with one ``Sam3PromptResult`` per
            prompt (in prompt order) and the elapsed time.
        """
        if inference_start_timestamp is None:
            inference_start_timestamp = perf_counter()
        np_image = load_image_rgb(image)

        # The backend applies a single threshold floor; use the min so the
        # per-prompt thresholds applied below can still refine higher values.
        prompt_thresholds = [getattr(p, "output_prob_thresh", None) for p in prompts]
        min_threshold = output_prob_thresh
        for prompt_thresh in prompt_thresholds:
            if prompt_thresh is not None:
                min_threshold = min(min_threshold, prompt_thresh)

        prompt_dicts = [_sam3_prompt_to_dict(p) for p in prompts]

        # segment_with_text_prompts returns List[per-image] of List[per-prompt] dicts
        # with keys: prompt_index, masks (N,H,W ndarray), scores (list).
        per_image_results = self._model.segment_with_text_prompts(
            images=[np_image],
            prompts=prompt_dicts,
            output_prob_thresh=float(min_threshold),
        )
        per_prompt = per_image_results[0]

        # processed: prompt_idx -> {"masks": ndarray, "scores": list}
        # Pre-filled so a prompt the backend returned nothing for yields an
        # empty result instead of a KeyError further down.
        processed: Dict[int, Dict[str, Any]] = {
            idx: {"masks": None, "scores": []} for idx in range(len(prompts))
        }
        for position, r in enumerate(per_prompt):
            # Trust the backend's own prompt_index when it is provided and
            # fall back to the positional order otherwise.
            prompt_idx = r.get("prompt_index")
            if prompt_idx is None:
                prompt_idx = position
            raw_scores = r.get("scores")
            processed[int(prompt_idx)] = {
                "masks": r.get("masks"),
                "scores": [] if raw_scores is None else list(raw_scores),
            }

        prompt_results: List[Sam3PromptResult] = []
        if nms_iou_threshold is not None and len(prompts) > 0:
            all_masks = _collect_masks_with_per_prompt_threshold(
                processed=processed,
                prompts=prompts,
                default_threshold=output_prob_thresh,
            )
            if len(all_masks) > 0:
                all_masks = _apply_nms_cross_prompt(all_masks, nms_iou_threshold)
            regrouped = _regroup_masks_by_prompt(all_masks, len(prompts))

            for idx, p in enumerate(prompts):
                bucket = regrouped.get(idx, [])
                if bucket:
                    masks_np = np.stack([m for m, _ in bucket], axis=0)
                    scores = [s for _, s in bucket]
                else:
                    masks_np = np.zeros((0, 0, 0), dtype=np.uint8)
                    scores = []
                prompt_results.append(
                    Sam3PromptResult(
                        prompt_index=idx,
                        echo=_build_echo(idx, p),
                        predictions=_masks_to_predictions(masks_np, scores, format),
                    )
                )
        else:
            for idx, p in enumerate(prompts):
                masks_np = _to_numpy_masks(processed[idx]["masks"])
                scores = processed[idx]["scores"]
                # The backend floor is the minimum over all prompts, so a
                # prompt without its own threshold must still be filtered by
                # the default; this matches what the NMS path does.
                effective_thresh = prompt_thresholds[idx]
                if effective_thresh is None:
                    effective_thresh = output_prob_thresh
                masks_np, scores = _filter_by_threshold(
                    masks_np, scores, effective_thresh
                )
                prompt_results.append(
                    Sam3PromptResult(
                        prompt_index=idx,
                        echo=_build_echo(idx, p),
                        predictions=_masks_to_predictions(masks_np, scores, format),
                    )
                )

        return Sam3SegmentationResponse(
            time=perf_counter() - inference_start_timestamp,
            prompt_results=prompt_results,
        )
| 180 | + |
| 181 | + |
| 182 | +def _sam3_prompt_to_dict(p: Sam3Prompt) -> Dict[str, Any]: |
| 183 | + d: Dict[str, Any] = {"text": p.text} |
| 184 | + if p.boxes: |
| 185 | + d["boxes"] = ( |
| 186 | + p.boxes |
| 187 | + ) # backend's _build_visual_query handles pydantic Box/BoxXYXY |
| 188 | + d["box_labels"] = p.box_labels or [] |
| 189 | + return d |
| 190 | + |
| 191 | + |
def _build_echo(prompt_index: int, p: Sam3Prompt) -> Sam3PromptEcho:
    """Build the echo record that mirrors a prompt back in the response.

    A prompt with any boxes is reported as ``"visual"`` together with its box
    count; every other prompt is reported as ``"text"`` with zero boxes.
    """
    if p.boxes:
        prompt_type, box_count = "visual", len(p.boxes)
    else:
        prompt_type, box_count = "text", 0
    return Sam3PromptEcho(
        prompt_index=prompt_index,
        type=prompt_type,
        text=p.text,
        num_boxes=box_count,
    )
| 200 | + |
| 201 | + |
| 202 | +def _to_numpy_masks(masks_any) -> np.ndarray: |
| 203 | + if masks_any is None: |
| 204 | + return np.zeros((0, 0, 0), dtype=np.uint8) |
| 205 | + if hasattr(masks_any, "detach"): |
| 206 | + masks_np = masks_any.detach().cpu().numpy().astype(np.uint8) |
| 207 | + else: |
| 208 | + arrs = [] |
| 209 | + for m in masks_any: |
| 210 | + if hasattr(m, "detach"): |
| 211 | + arrs.append(m.detach().cpu().numpy().astype(np.uint8)) |
| 212 | + else: |
| 213 | + arrs.append(np.asarray(m, dtype=np.uint8)) |
| 214 | + if not arrs: |
| 215 | + return np.zeros((0, 0, 0), dtype=np.uint8) |
| 216 | + masks_np = np.stack(arrs, axis=0) |
| 217 | + if masks_np.ndim == 4 and masks_np.shape[1] == 1: |
| 218 | + masks_np = masks_np[:, 0, ...] |
| 219 | + elif masks_np.ndim == 2: |
| 220 | + masks_np = masks_np[None, ...] |
| 221 | + return masks_np |
| 222 | + |
| 223 | + |
| 224 | +def _filter_by_threshold( |
| 225 | + masks_np: np.ndarray, |
| 226 | + scores: List[float], |
| 227 | + threshold: float, |
| 228 | +) -> Tuple[np.ndarray, List[float]]: |
| 229 | + if masks_np.ndim != 3 or masks_np.shape[0] == 0: |
| 230 | + return masks_np, scores |
| 231 | + keep = [i for i, s in enumerate(scores) if s >= threshold] |
| 232 | + if not keep: |
| 233 | + return np.zeros((0, 0, 0), dtype=np.uint8), [] |
| 234 | + return masks_np[keep], [scores[i] for i in keep] |
| 235 | + |
| 236 | + |
| 237 | +def _masks_to_predictions( |
| 238 | + masks_np: np.ndarray, scores: List[float], fmt: str |
| 239 | +) -> List[Sam3SegmentationPrediction]: |
| 240 | + preds: List[Sam3SegmentationPrediction] = [] |
| 241 | + if masks_np.ndim != 3 or 0 in masks_np.shape: |
| 242 | + return preds |
| 243 | + if fmt in ("polygon", "json"): |
| 244 | + polygons = masks2multipoly((masks_np > 0).astype(np.uint8)) |
| 245 | + for poly, score in zip(polygons, scores[: len(polygons)]): |
| 246 | + preds.append( |
| 247 | + Sam3SegmentationPrediction( |
| 248 | + masks=[p.tolist() for p in poly], |
| 249 | + confidence=float(score), |
| 250 | + format="polygon", |
| 251 | + ) |
| 252 | + ) |
| 253 | + elif fmt == "rle": |
| 254 | + for m, score in zip(masks_np, scores[: masks_np.shape[0]]): |
| 255 | + mb = (m > 0).astype(np.uint8) |
| 256 | + rle = mask_utils.encode(np.asfortranarray(mb)) |
| 257 | + rle["counts"] = rle["counts"].decode("utf-8") |
| 258 | + preds.append( |
| 259 | + Sam3SegmentationPrediction( |
| 260 | + masks=rle, confidence=float(score), format="rle" |
| 261 | + ) |
| 262 | + ) |
| 263 | + return preds |
| 264 | + |
| 265 | + |
| 266 | +def _collect_masks_with_per_prompt_threshold( |
| 267 | + processed: Dict[int, Dict[str, Any]], |
| 268 | + prompts: List[Sam3Prompt], |
| 269 | + default_threshold: float, |
| 270 | +) -> List[Tuple[int, np.ndarray, float]]: |
| 271 | + all_masks: List[Tuple[int, np.ndarray, float]] = [] |
| 272 | + for idx, p in enumerate(prompts): |
| 273 | + prompt_thresh = getattr(p, "output_prob_thresh", None) |
| 274 | + if prompt_thresh is None: |
| 275 | + prompt_thresh = default_threshold |
| 276 | + masks_np = _to_numpy_masks(processed[idx]["masks"]) |
| 277 | + scores = processed[idx]["scores"] |
| 278 | + if masks_np.ndim != 3 or 0 in masks_np.shape: |
| 279 | + continue |
| 280 | + for mask, score in zip(masks_np, scores): |
| 281 | + if score >= prompt_thresh: |
| 282 | + all_masks.append((idx, mask, float(score))) |
| 283 | + return all_masks |
| 284 | + |
| 285 | + |
| 286 | +def _nms_greedy_pycocotools( |
| 287 | + rles: List[Dict], |
| 288 | + confidences: np.ndarray, |
| 289 | + iou_threshold: float = 0.5, |
| 290 | +) -> np.ndarray: |
| 291 | + num_detections = len(rles) |
| 292 | + if num_detections == 0: |
| 293 | + return np.array([], dtype=bool) |
| 294 | + sort_index = np.argsort(confidences)[::-1] |
| 295 | + sorted_rles = [rles[i] for i in sort_index] |
| 296 | + ious = mask_utils.iou(sorted_rles, sorted_rles, [0] * num_detections) |
| 297 | + keep = np.ones(num_detections, dtype=bool) |
| 298 | + for i in range(num_detections): |
| 299 | + if keep[i]: |
| 300 | + condition = ious[i, :] > iou_threshold |
| 301 | + keep[i + 1 :] = np.where(condition[i + 1 :], False, keep[i + 1 :]) |
| 302 | + return keep[np.argsort(sort_index)] |
| 303 | + |
| 304 | + |
| 305 | +def _apply_nms_cross_prompt( |
| 306 | + all_masks: List[Tuple[int, np.ndarray, float]], |
| 307 | + iou_threshold: float, |
| 308 | +) -> List[Tuple[int, np.ndarray, float]]: |
| 309 | + if not all_masks: |
| 310 | + return all_masks |
| 311 | + rles = [] |
| 312 | + for _, mask_np, _ in all_masks: |
| 313 | + mb = (mask_np > 0).astype(np.uint8) |
| 314 | + rle = mask_utils.encode(np.asfortranarray(mb)) |
| 315 | + rles.append(rle) |
| 316 | + confidences = np.array([score for _, _, score in all_masks]) |
| 317 | + keep = _nms_greedy_pycocotools(rles, confidences, iou_threshold) |
| 318 | + return [all_masks[i] for i in range(len(all_masks)) if keep[i]] |
| 319 | + |
| 320 | + |
| 321 | +def _regroup_masks_by_prompt( |
| 322 | + filtered_masks: List[Tuple[int, np.ndarray, float]], |
| 323 | + num_prompts: int, |
| 324 | +) -> Dict[int, List[Tuple[np.ndarray, float]]]: |
| 325 | + result: Dict[int, List[Tuple[np.ndarray, float]]] = { |
| 326 | + i: [] for i in range(num_prompts) |
| 327 | + } |
| 328 | + for prompt_idx, mask_np, score in filtered_masks: |
| 329 | + result[prompt_idx].append((mask_np, score)) |
| 330 | + return result |
0 commit comments