Skip to content

Commit b252227

Browse files
feat/core: BoT-SORT block (#2349)
* initial commit for BoTSORT block * added botsort to overview doc --------- Co-authored-by: Paweł Pęczek <146137186+PawelPeczek-Roboflow@users.noreply.github.com>
1 parent 7074d8e commit b252227

8 files changed

Lines changed: 501 additions & 8 deletions

File tree

docs/workflows/video_processing/overview.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Video Processing with Workflows
22

33
We've begun our journey into video processing using Workflows. Over time, we've expanded the number of
4-
video-specific blocks — including object tracker blocks for **ByteTrack**, **SORT**, and **OC-SORT** — and
4+
video-specific blocks — including object tracker blocks for **ByteTrack**, **SORT**, **OC-SORT** and **BoT-SORT** — and
55
continue to dedicate efforts toward improving their performance and robustness. The current state of this
66
work is as follows:
77

inference/core/workflows/core_steps/loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,9 @@
453453
TwilioSMSNotificationBlockV2,
454454
)
455455
from inference.core.workflows.core_steps.sinks.webhook.v1 import WebhookSinkBlockV1
456+
from inference.core.workflows.core_steps.trackers.botsort.v1 import (
457+
BoTSORTBlockV1 as TrackerBoTSORTBlockV1,
458+
)
456459
from inference.core.workflows.core_steps.trackers.bytetrack.v1 import (
457460
ByteTrackBlockV1 as TrackerByteTrackBlockV1,
458461
)
@@ -919,6 +922,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
919922
ReferencePathVisualizationBlockV1,
920923
ByteTrackerBlockV3,
921924
TrackerByteTrackBlockV1,
925+
TrackerBoTSORTBlockV1,
922926
TrackerSORTBlockV1,
923927
TrackerOCSORTBlockV1,
924928
WebhookSinkBlockV1,

inference/core/workflows/core_steps/trackers/_base.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Shared base classes for tracker workflow blocks.
22
3-
Each concrete tracker block (ByteTrack, SORT, OC-SORT) inherits from
4-
``TrackerBlockBase`` and only needs to implement ``_create_tracker`` and
5-
``get_manifest``.
3+
Each concrete tracker block (ByteTrack, BoT-SORT, SORT, OC-SORT) inherits from
4+
``TrackerBlockBase`` and implements ``_create_tracker`` and ``get_manifest``.
5+
Sub-classes may override ``_tracker_update`` when the underlying tracker needs
6+
extra per-frame context (e.g. a video frame for camera motion compensation).
67
"""
78

89
from abc import abstractmethod
910
from collections import deque
10-
from typing import Any, Dict, List, Optional, Type
11+
from typing import Any, Dict, List, Type
1112

1213
import supervision as sv
1314

@@ -76,7 +77,9 @@ def _cache_new_tracker_id(self, tracker_id: int) -> None:
7677
class TrackerBlockBase(WorkflowBlock):
7778
"""Common run-loop shared by every tracker block.
7879
79-
Sub-classes only need to override ``_create_tracker`` and ``get_manifest``.
80+
Sub-classes implement ``_create_tracker`` and ``get_manifest``. Override
81+
``_tracker_update`` only when the tracker API requires additional context
82+
beyond ``sv.Detections`` (e.g. BoT-SORT with camera motion compensation).
8083
"""
8184

8285
def __init__(self) -> None:
@@ -92,6 +95,20 @@ def _create_tracker(self, fps: int, **kwargs: Any) -> Any:
9295
"""Instantiate the concrete tracker with algorithm-specific params."""
9396
...
9497

98+
def _tracker_update(
99+
self,
100+
tracker: Any,
101+
detections: sv.Detections,
102+
image: WorkflowImageData,
103+
) -> sv.Detections:
104+
"""Invoke the tracker for one frame.
105+
106+
Must call ``tracker.update`` only with arguments that library trackers
107+
define for the per-frame step (typically detections, optionally a frame
108+
tensor). Do **not** pass workflow/block kwargs used in ``_create_tracker``.
109+
"""
110+
return tracker.update(detections)
111+
95112
def _run_tracker(
96113
self,
97114
image: WorkflowImageData,
@@ -120,7 +137,7 @@ def _run_tracker(
120137
self._trackers[video_id] = self._create_tracker(fps=fps, **tracker_kwargs)
121138

122139
tracker = self._trackers[video_id]
123-
tracked_detections = tracker.update(detections)
140+
tracked_detections = self._tracker_update(tracker, detections, image)
124141

125142
# Filter out immature / unmatched tracks (tracker_id == -1)
126143
if tracked_detections.tracker_id is not None and len(tracked_detections) > 0:

0 commit comments

Comments
 (0)