Skip to content

Commit 0085452

Browse files
authored
Add SAM2 streaming video tracker (inference_models + workflow block) (#2245)
* Add SAM3ForStream and streaming-video tests in inference_models Introduces a SAM3 streaming tracker that mirrors the existing SAM2ForStream interface (prompt / track returning (masks, object_ids, state_dict)) so both can be used interchangeably by upstream code. - inference_models/models/sam3_rt/sam3_pytorch.py: SAM3ForStream backed by HuggingFace transformers' Sam3VideoModel / Sam3VideoProcessor. The native sam3 package's video predictor requires a full video resource upfront; the transformers port exposes init_video_session + per-frame model(frame=...), which is the shape we need for InferencePipeline-style streaming. - Accepts bbox and/or text prompts; state_dict is opaque (wraps the HF Sam3VideoInferenceSession) and must be kept in memory by the caller — it's not serializable across processes. - Register (segment-anything-3-rt, INSTANCE_SEGMENTATION_TASK, BackendType.HF) in models_registry alongside the existing SAM2-RT entry. Tests: - inference_models/tests/unit_tests/models/test_sam3_rt.py — 24 unit tests covering helpers (_normalise_bboxes, _unpack_processed_outputs, etc.) plus class behaviour using MagicMock model/processor. No weights required. - inference_models/tests/integration_tests/models/test_sam2_rt_predictions.py — new integration suite for the existing SAM2ForStream (prompt -> track on synthetic frames, centroid-moves assertion, track-without-prompt raises, torch.Tensor input). - inference_models/tests/integration_tests/models/test_sam3_rt_predictions.py — analogous suite for SAM3ForStream. - conftest.py: sam2_rt_package and sam3_rt_package fixtures download zips from rf-platform-models; docstrings list the expected file contents for upload. https://claude.ai/code/session_01T3k4sfbkaV3warwV7MHRZN * Add SAM2/SAM3 streaming video tracker workflow blocks Two new LOCAL-only workflow blocks that drive the inference_models streaming trackers (SAM2ForStream / SAM3ForStream) from workflows powered by InferencePipeline. Both blocks multiplex a single model instance across many videos by keying state_dicts on video_metadata.video_identifier, reset sessions when frame_number rolls back, and support three prompt modes: first_frame, every_n_frames, every_frame. - inference/core/workflows/core_steps/models/foundation/ _streaming_video_common.py: shared helpers (state bookkeeping, prompt-vs-track decision logic, sv.Detections assembly with SAM-assigned tracker_ids). - segment_anything2_video/v1.py: SAM2VideoTrackerBlockV1 (type: roboflow_core/segment_anything_2_video@v1, default model_id: segment-anything-2-rt). - segment_anything3_video/v1.py: SAM3VideoTrackerBlockV1 (type: roboflow_core/sam3_video@v1, default model_id: segment-anything-3-rt). Additionally accepts text prompts via class_names; boxes win when both are supplied. - Both raise NotImplementedError on REMOTE step execution — per-video session state cannot survive a remote boundary. - Models are loaded via inference_models.AutoModel.from_pretrained so backend negotiation / package download / caching flow through the standard inference_models pipeline. - Registered in core_steps/loader.py. Tests (30 total, all passing, no weights required): - test_segment_anything2_video.py — 10 tests covering manifest, REMOTE rejection, first_frame/every_n_frames/every_frame modes, state threading across track calls, multi-stream isolation, stream-restart detection. - test_segment_anything3_video.py — 9 tests with similar coverage plus text-vs-box prompt routing. - test_streaming_video_common.py — 11 tests for the shared helpers. https://claude.ai/code/session_01T3k4sfbkaV3warwV7MHRZN * Rename SAM video trackers to sam2video/sam3video; add SAM2Video (HF) Refactors the streaming trackers into a shared HuggingFace transformers base and adds a SAM2Video counterpart to the existing SAM3Video. The older sam2_rt (SAM2ForStream using Meta's sam2 camera predictor) is kept untouched — per the feedback it hasn't been exercised much in practice. Model classes ------------- - inference_models/models/common/hf_streaming_video.py: HFStreamingVideoBase containing all the HF streaming boilerplate — session init, prompt/track methods, mask/obj_id extraction, opaque state_dict contract. - inference_models/models/sam2_video/sam2_video_hf.py: SAM2Video (lazy-imports transformers.Sam2VideoModel / Sam2VideoProcessor; rejects text prompts). - inference_models/models/sam3_video/sam3_video_hf.py: SAM3Video, moved from the previous sam3_rt path; now a thin ~25-line subclass after the shared base absorbed the helpers (lazy-imports transformers.Sam3VideoModel / Sam3VideoProcessor; accepts both text and box prompts). Registry -------- - sam2video: (INSTANCE_SEGMENTATION_TASK, BackendType.HF) -> SAM2Video - sam3video: (INSTANCE_SEGMENTATION_TASK, BackendType.HF) -> SAM3Video - segment-anything-2-rt stays registered against SAM2ForStream. - segment-anything-3-rt entry dropped (never released). Workflow blocks now default to these ids: - roboflow_core/segment_anything_2_video@v1 -> "sam2video" - roboflow_core/sam3_video@v1 -> "sam3video" Tests ----- - Unit tests: added test_sam2_video.py (4 SAM2-specific), renamed test_sam3_rt.py -> test_sam3_video.py and updated imports (24 tests covering helpers on the shared base plus SAM3 class behaviour). - Integration tests: renamed SAM3 file, added SAM2 counterpart. New fixtures sam2_video_package / sam3_video_package (expected zips at rf-platform-models/sam2video.zip and rf-platform-models/sam3video.zip). - Workflow block tests updated to use sam2video / sam3video ids. - All 58 non-integration tests pass locally. https://claude.ai/code/session_01T3k4sfbkaV3warwV7MHRZN * Drop SAM2-RT tests/fixture so this PR doesn't touch the legacy model Removes the integration tests and fixture I'd added for the existing SAM2ForStream (sam2_rt) — keeping them would require uploading a segment-anything-2-rt.zip to the test assets bucket, but the goal for this PR is to leave that untested path alone. - Deleted tests/integration_tests/models/test_sam2_rt_predictions.py - Removed SAM2_RT_PACKAGE_URL + sam2_rt_package fixture from conftest.py - Fixed two docstring references that still said SAM2ForStream when they now point at the new SAM2Video. The SAM2ForStream registry entry itself stays — it's the legacy model that existed before this branch and we're not touching it. https://claude.ai/code/session_01T3k4sfbkaV3warwV7MHRZN * Add handoff doc for local testing + follow-up Captures the state of this branch while session context is fresh — what's where, how to test, what needs uploading, known gotchas, and a sketch of the follow-up "add a model" Claude skill. Doc is temporary and should be deleted before the PR merges. https://claude.ai/code/session_01T3k4sfbkaV3warwV7MHRZN * SAM2 video: variant model ids + input_boxes nesting fix - Default workflow block to sam2video/small, advertise all four Hiera backbones (tiny / small / base-plus / large) via examples and get_supported_model_variants. - Fix Sam2VideoProcessor.add_inputs_to_inference_session call: the processor expects input_boxes with 3 nesting levels ([image, boxes, coords]); we were passing 4, which raised ValueError on the first real-weights prompt. Unit tests missed it because they mock the processor — surfaced by end-to-end verify against the uploaded sam2video-small.zip. - Point the inference_models integration fixture URL at sam2video-small.zip (the variant that matches the new default). - Update sam2video workflow-block unit tests to pass the new sam2video/small default through the mocks. * Strip SAM3 video tracker from this PR Descope the SAM3 streaming-video work to a follow-up so this PR can ship SAM2 video alone. SAM3's HF port requires the gated facebook/sam3 checkpoints, which aren't available yet. Removed: - inference_models SAM3 model class, registry entry, unit + integration tests, and test fixture - sam3_video workflow block + loader registration + unit tests - SAM_VIDEO_HANDOFF.md (served its purpose during the branch work) Left in place: - HFStreamingVideoBase in inference_models/models/common — reusable, SAM2 uses it today and a future SAM3 port can inherit unchanged - _streaming_video_common workflow helpers — still used by the SAM2 video block - SAM2 video class, registry entry, workflow block, and all SAM2 tests Comments that referenced SAM3Video / test_sam3_video.py have been generalised or trimmed so nothing dangles.
1 parent 127dd3c commit 0085452

13 files changed

Lines changed: 1746 additions & 0 deletions

File tree

inference/core/workflows/core_steps/loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@
277277
from inference.core.workflows.core_steps.models.foundation.segment_anything2.v1 import (
278278
SegmentAnything2BlockV1,
279279
)
280+
from inference.core.workflows.core_steps.models.foundation.segment_anything2_video.v1 import (
281+
SegmentAnything2VideoBlockV1,
282+
)
280283
from inference.core.workflows.core_steps.models.foundation.segment_anything3.v1 import (
281284
SegmentAnything3BlockV1,
282285
)
@@ -835,6 +838,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
835838
SIFTComparisonBlockV1,
836839
SIFTComparisonBlockV2,
837840
SegmentAnything2BlockV1,
841+
SegmentAnything2VideoBlockV1,
838842
SegmentAnything3BlockV1,
839843
SegmentAnything3BlockV2,
840844
SegmentAnything3BlockV3,
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
"""Shared helpers for SAM2/SAM3 streaming video tracker workflow blocks.
2+
3+
Both blocks multiplex a single ``inference_models``-backed streaming
4+
model across many videos by keying ``state_dict``s on
5+
``video_identifier``. They follow the same decision logic on every
6+
frame: reset a session if the source stream restarted, and re-prompt
7+
only on the frames requested by ``prompt_mode``. Everything that is
8+
independent of "SAM2 vs SAM3" lives here so each concrete block is just
9+
a thin wrapper around ``inference_models.AutoModel``.
10+
"""
11+
12+
from dataclasses import dataclass, field
13+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
14+
from uuid import uuid4
15+
16+
import numpy as np
17+
import supervision as sv
18+
19+
from inference.core.workflows.execution_engine.constants import (
20+
DETECTION_ID_KEY,
21+
IMAGE_DIMENSIONS_KEY,
22+
PARENT_ID_KEY,
23+
)
24+
from inference.core.workflows.execution_engine.entities.base import WorkflowImageData
25+
26+
DETECTIONS_CLASS_NAME_FIELD = "class_name"
27+
28+
PromptMode = Literal["first_frame", "every_n_frames", "every_frame"]
29+
30+
31+
@dataclass
32+
class VideoSessionBookkeeping:
33+
"""Per-video bookkeeping that lives alongside the model's opaque
34+
``state_dict``.
35+
36+
We store the last state returned from the model so the next call
37+
can continue the same session; ``obj_id_metadata`` holds the
38+
detector-provided class name / id / parent detection id for each
39+
prompted track so the emitted masks inherit them.
40+
"""
41+
42+
state_dict: Optional[dict] = None
43+
last_frame_number: int = -1
44+
frames_since_prompt: int = 0
45+
obj_id_metadata: Dict[int, Dict[str, Any]] = field(default_factory=dict)
46+
47+
48+
@dataclass
49+
class BoxPromptMetadata:
50+
"""Class info carried from an upstream detector to the emitted mask."""
51+
52+
class_id: int
53+
class_name: str
54+
confidence: float
55+
parent_id: Optional[str]
56+
57+
58+
def extract_box_prompts(
59+
boxes_for_image: Optional[sv.Detections],
60+
) -> Tuple[List[Tuple[float, float, float, float]], List[BoxPromptMetadata]]:
61+
"""Flatten an ``sv.Detections`` into xyxy tuples + per-box metadata.
62+
63+
Empty / missing input returns two empty lists; class_name defaults
64+
to "foreground" when the detection doesn't carry one.
65+
"""
66+
if boxes_for_image is None or len(boxes_for_image) == 0:
67+
return [], []
68+
69+
boxes_xyxy: List[Tuple[float, float, float, float]] = []
70+
metas: List[BoxPromptMetadata] = []
71+
for xyxy, _mask, confidence, class_id, _tracker_id, data in boxes_for_image:
72+
x1, y1, x2, y2 = xyxy
73+
boxes_xyxy.append((float(x1), float(y1), float(x2), float(y2)))
74+
class_name = (
75+
data.get(DETECTIONS_CLASS_NAME_FIELD, "foreground")
76+
if isinstance(data, dict)
77+
else "foreground"
78+
)
79+
parent_id = data.get("detection_id") if isinstance(data, dict) else None
80+
metas.append(
81+
BoxPromptMetadata(
82+
class_id=int(class_id) if class_id is not None else 0,
83+
class_name=str(class_name),
84+
confidence=float(confidence) if confidence is not None else 1.0,
85+
parent_id=str(parent_id) if parent_id is not None else None,
86+
)
87+
)
88+
return boxes_xyxy, metas
89+
90+
91+
def decide_prompt_vs_track(
92+
session: VideoSessionBookkeeping,
93+
frame_number: int,
94+
prompt_mode: PromptMode,
95+
prompt_interval: int,
96+
has_prompts: bool,
97+
) -> Tuple[bool, bool]:
98+
"""Return ``(should_reset, should_prompt)`` for a single frame.
99+
100+
- A reset fires when the source stream's ``frame_number`` rolls
101+
back (or this is the first frame we've seen for this video).
102+
- ``should_prompt`` is gated on prompt availability: there's no
103+
point issuing a prompt call with nothing to prompt on.
104+
"""
105+
fresh_session = session.last_frame_number < 0 or session.state_dict is None
106+
reset = fresh_session or frame_number < session.last_frame_number
107+
108+
if prompt_mode == "every_frame":
109+
return reset, has_prompts
110+
if prompt_mode == "every_n_frames":
111+
due = reset or session.frames_since_prompt >= max(1, prompt_interval)
112+
return reset, due and has_prompts
113+
# first_frame
114+
return reset, reset and has_prompts
115+
116+
117+
def masks_to_sv_detections(
118+
masks: np.ndarray,
119+
obj_ids: np.ndarray,
120+
image: WorkflowImageData,
121+
obj_id_metadata: Dict[int, BoxPromptMetadata],
122+
threshold: float,
123+
) -> sv.Detections:
124+
"""Assemble one ``sv.Detections`` of instance-seg predictions.
125+
126+
Emits one detection per SAM-assigned object (preserving the
127+
one-to-one mapping with ``tracker_id``). Masks without any positive
128+
pixels are dropped.
129+
"""
130+
h, w = image.numpy_image.shape[:2]
131+
if masks.shape[0] == 0:
132+
return _empty_detections(h, w)
133+
134+
xyxy: List[List[float]] = []
135+
confidences: List[float] = []
136+
class_ids: List[int] = []
137+
class_names: List[str] = []
138+
tracker_ids: List[int] = []
139+
detection_ids: List[str] = []
140+
parent_ids: List[str] = []
141+
kept_masks: List[np.ndarray] = []
142+
143+
for mask, obj_id in zip(masks, obj_ids.tolist()):
144+
meta = obj_id_metadata.get(int(obj_id))
145+
confidence = meta.confidence if meta is not None else 1.0
146+
if confidence < threshold:
147+
continue
148+
ys, xs = np.where(mask)
149+
if xs.size == 0:
150+
continue
151+
xyxy.append(
152+
[
153+
float(xs.min()),
154+
float(ys.min()),
155+
float(xs.max()),
156+
float(ys.max()),
157+
]
158+
)
159+
confidences.append(float(confidence))
160+
class_ids.append(meta.class_id if meta is not None else 0)
161+
class_names.append(meta.class_name if meta is not None else "foreground")
162+
tracker_ids.append(int(obj_id))
163+
parent = meta.parent_id if meta is not None else None
164+
parent_ids.append(str(parent) if parent is not None else "")
165+
detection_ids.append(str(uuid4()))
166+
kept_masks.append(mask.astype(bool))
167+
168+
if not kept_masks:
169+
return _empty_detections(h, w)
170+
171+
detections = sv.Detections(
172+
xyxy=np.asarray(xyxy, dtype=np.float32),
173+
mask=np.stack(kept_masks, axis=0),
174+
confidence=np.asarray(confidences, dtype=np.float32),
175+
class_id=np.asarray(class_ids, dtype=int),
176+
tracker_id=np.asarray(tracker_ids, dtype=int),
177+
)
178+
detections.data[DETECTIONS_CLASS_NAME_FIELD] = np.asarray(class_names, dtype=object)
179+
detections[DETECTION_ID_KEY] = np.asarray(detection_ids, dtype=object)
180+
detections[PARENT_ID_KEY] = np.asarray(parent_ids, dtype=object)
181+
detections[IMAGE_DIMENSIONS_KEY] = np.asarray([[h, w]] * len(detections), dtype=int)
182+
return detections
183+
184+
185+
def _empty_detections(h: int, w: int) -> sv.Detections:
186+
empty = sv.Detections.empty()
187+
empty[DETECTION_ID_KEY] = np.array([], dtype=object)
188+
empty[PARENT_ID_KEY] = np.array([], dtype=object)
189+
empty[IMAGE_DIMENSIONS_KEY] = np.zeros((0, 2), dtype=int)
190+
empty.data[DETECTIONS_CLASS_NAME_FIELD] = np.array([], dtype=object)
191+
return empty
192+
193+
194+
def build_obj_id_metadata_from_boxes(
195+
obj_ids: np.ndarray,
196+
box_metas: List[BoxPromptMetadata],
197+
) -> Dict[int, BoxPromptMetadata]:
198+
"""Align SAM-assigned object ids with the detector-provided metadata.
199+
200+
The model hands us object ids in the same order as the prompts we
201+
issued; we zip them together so later frames (which only have
202+
``obj_ids``) can still be labelled.
203+
"""
204+
return dict(zip([int(i) for i in obj_ids.tolist()], box_metas))
205+
206+
207+
def build_obj_id_metadata_from_text(
208+
obj_ids: np.ndarray,
209+
class_names: List[str],
210+
) -> Dict[int, BoxPromptMetadata]:
211+
"""For text-prompt sessions where we don't have per-object class
212+
info, fall back to a single class name (if only one was supplied)
213+
or "foreground" (if multiple or none).
214+
"""
215+
label = class_names[0] if len(class_names) == 1 and class_names[0] else "foreground"
216+
return {
217+
int(oid): BoxPromptMetadata(
218+
class_id=0, class_name=label, confidence=1.0, parent_id=None
219+
)
220+
for oid in obj_ids.tolist()
221+
}
222+
223+
224+
def normalise_class_names(
225+
class_names: Optional[Any],
226+
) -> List[str]:
227+
"""Accept a list, comma-separated string, or None and return a list."""
228+
if class_names is None:
229+
return []
230+
if isinstance(class_names, str):
231+
return [c.strip() for c in class_names.split(",") if c.strip()]
232+
return [c for c in class_names if c]

inference/core/workflows/core_steps/models/foundation/segment_anything2_video/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)