11"""Shared base classes for tracker workflow blocks.
22
3- Each concrete tracker block (ByteTrack, SORT, OC-SORT) inherits from
4- ``TrackerBlockBase`` and only needs to implement ``_create_tracker`` and
5- ``get_manifest``.
3+ Each concrete tracker block (ByteTrack, BoT-SORT, SORT, OC-SORT) inherits from
4+ ``TrackerBlockBase`` and implements ``_create_tracker`` and ``get_manifest``.
5+ Sub-classes may override ``_tracker_update`` when the underlying tracker needs
6+ extra per-frame context (e.g. a video frame for camera motion compensation).
67"""
78
89from abc import abstractmethod
910from collections import deque
10- from typing import Any , Dict , List , Optional , Type
11+ from typing import Any , Dict , List , Type
1112
1213import supervision as sv
1314
@@ -76,7 +77,9 @@ def _cache_new_tracker_id(self, tracker_id: int) -> None:
7677class TrackerBlockBase (WorkflowBlock ):
7778 """Common run-loop shared by every tracker block.
7879
79- Sub-classes only need to override ``_create_tracker`` and ``get_manifest``.
80+ Sub-classes implement ``_create_tracker`` and ``get_manifest``. Override
81+ ``_tracker_update`` only when the tracker API requires additional context
82+ beyond ``sv.Detections`` (e.g. BoT-SORT with camera motion compensation).
8083 """
8184
8285 def __init__ (self ) -> None :
@@ -92,6 +95,20 @@ def _create_tracker(self, fps: int, **kwargs: Any) -> Any:
9295 """Instantiate the concrete tracker with algorithm-specific params."""
9396 ...
9497
98+ def _tracker_update (
99+ self ,
100+ tracker : Any ,
101+ detections : sv .Detections ,
102+ image : WorkflowImageData ,
103+ ) -> sv .Detections :
104+ """Invoke the tracker for one frame.
105+
106+ Must call ``tracker.update`` only with arguments that library trackers
107+ define for the per-frame step (typically detections, optionally a frame
108+ tensor). Do **not** pass workflow/block kwargs used in ``_create_tracker``.
109+ """
110+ return tracker .update (detections )
111+
95112 def _run_tracker (
96113 self ,
97114 image : WorkflowImageData ,
@@ -120,7 +137,7 @@ def _run_tracker(
120137 self ._trackers [video_id ] = self ._create_tracker (fps = fps , ** tracker_kwargs )
121138
122139 tracker = self ._trackers [video_id ]
123- tracked_detections = tracker . update ( detections )
140+ tracked_detections = self . _tracker_update ( tracker , detections , image )
124141
125142 # Filter out immature / unmatched tracks (tracker_id == -1)
126143 if tracked_detections .tracker_id is not None and len (tracked_detections ) > 0 :
0 commit comments