Skip to content

Commit 2cb88b4

Browse files
omkar-334SkalskiPBordaclaudeCopilot
authored
add missing and modern typehints (#302)
* add missing and modern typehints * Remove unused import from __init__.py * add missing `: object` to __exit__ *_ params for consistency * add TYPE_CHECKING and annotations to resolve potential circular imports --------- Co-authored-by: Piotr Skalski <piotr.skalski92@gmail.com> Co-authored-by: jirka <6035284+Borda@users.noreply.github.com> Co-authored-by: Claude Code <noreply@anthropic.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 5c9d9a5 commit 2cb88b4

6 files changed

Lines changed: 29 additions & 13 deletions

File tree

test/core/test_tracker_integration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def _run_tracker_on_flat_dataset(
7070
tracked = tracker.update(detections)
7171
if tracked.tracker_id is not None:
7272
mature = tracked[tracked.tracker_id != -1]
73+
assert isinstance(mature, sv.Detections)
7374
mot.write(frame_idx, mature)
7475
else:
7576
mot.write(frame_idx, tracked)
@@ -110,6 +111,9 @@ def test_tracker_regression(
110111
)
111112

112113
aggregate = result.aggregate
114+
assert aggregate.HOTA is not None
115+
assert aggregate.CLEAR is not None
116+
assert aggregate.Identity is not None
113117
assert aggregate.HOTA.HOTA * 100 == pytest.approx(expected["HOTA"], abs=0.001)
114118
assert aggregate.CLEAR.MOTA * 100 == pytest.approx(expected["MOTA"], abs=0.001)
115119
assert aggregate.Identity.IDF1 * 100 == pytest.approx(expected["IDF1"], abs=0.001)

trackers/core/bytetrack/kalman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_next_tracker_id(cls) -> int:
5151
cls.count_id += 1
5252
return next_id
5353

54-
def __init__(self, bbox: np.ndarray):
54+
def __init__(self, bbox: np.ndarray) -> None:
5555
# Initialize with a temporary ID of -1
5656
# Will be assigned a real ID when the track is considered mature
5757
self.tracker_id = -1

trackers/eval/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77
"""Evaluation metrics and utilities for tracking benchmarks."""
88

9+
from __future__ import annotations
10+
11+
from typing import TYPE_CHECKING
12+
913
from trackers.eval.box import box_ioa, box_iou
1014
from trackers.eval.clear import aggregate_clear_metrics, compute_clear_metrics
1115
from trackers.eval.hota import aggregate_hota_metrics, compute_hota_metrics
@@ -18,8 +22,11 @@
1822
SequenceResult,
1923
)
2024

25+
if TYPE_CHECKING:
26+
from trackers.eval.evaluate import evaluate_mot_sequence, evaluate_mot_sequences
27+
2128

22-
def __getattr__(name: str):
29+
def __getattr__(name: str) -> object:
2330
"""Lazy imports for evaluate functions to avoid circular imports."""
2431
if name in ("evaluate_mot_sequence", "evaluate_mot_sequences"):
2532
from trackers.eval import evaluate as _evaluate

trackers/io/mot.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import csv
1010
from dataclasses import dataclass
1111
from pathlib import Path
12+
from typing import TextIO
1213

1314
import numpy as np
1415
import supervision as sv
@@ -418,9 +419,9 @@ def _prepare_mot_sequence(
418419
class _MOTOutput:
419420
"""Context manager for MOT format file writing."""
420421

421-
def __init__(self, path: Path | None):
422+
def __init__(self, path: Path | None) -> None:
422423
self.path = path
423-
self._file = None
424+
self._file: TextIO | None = None
424425

425426
def write(self, frame_idx: int, detections: sv.Detections) -> None:
426427
"""Write detections for a frame in MOT format."""
@@ -447,12 +448,12 @@ def write(self, frame_idx: int, detections: sv.Detections) -> None:
447448
f"{conf:.4f},-1,-1,-1\n"
448449
)
449450

450-
def __enter__(self):
451+
def __enter__(self) -> _MOTOutput:
451452
if self.path is not None:
452453
self.path.parent.mkdir(parents=True, exist_ok=True)
453454
self._file = open(self.path, "w")
454455
return self
455456

456-
def __exit__(self, *_):
457+
def __exit__(self, *_: object) -> None:
457458
if self._file is not None:
458459
self._file.close()

trackers/io/video.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _iter_image_folder_frames(
9090
class _VideoOutput:
9191
"""Context manager for lazy video file writing."""
9292

93-
def __init__(self, path: Path | None, *, fps: float = _DEFAULT_OUTPUT_FPS):
93+
def __init__(self, path: Path | None, *, fps: float = _DEFAULT_OUTPUT_FPS) -> None:
9494
self.path = path
9595
self.fps = fps
9696
self._writer: cv2.VideoWriter | None = None
@@ -137,7 +137,7 @@ def __exit__(self, *_: object) -> None:
137137
class _DisplayWindow:
138138
"""Context manager for OpenCV display window with resizable output."""
139139

140-
def __init__(self, window_name: str = "Tracking"):
140+
def __init__(self, window_name: str = "Tracking") -> None:
141141
self.window_name = window_name
142142
self._quit_requested = False
143143
cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
@@ -159,8 +159,8 @@ def quit_requested(self) -> bool:
159159
"""Return True if user pressed quit key."""
160160
return self._quit_requested
161161

162-
def __enter__(self):
162+
def __enter__(self) -> _DisplayWindow:
163163
return self
164164

165-
def __exit__(self, *_):
165+
def __exit__(self, *_: object) -> None:
166166
cv2.destroyWindow(self.window_name)

trackers/scripts/track.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import sys
1212
from contextlib import nullcontext
1313
from pathlib import Path
14+
from typing import TYPE_CHECKING
1415

1516
import numpy as np
1617
import supervision as sv
@@ -23,6 +24,9 @@
2324
from trackers.scripts.progress import _classify_source, _SourceInfo, _TrackingProgress
2425
from trackers.utils.device import _best_device
2526

27+
if TYPE_CHECKING:
28+
from inference_models import AnyModel
29+
2630
# Defaults
2731
DEFAULT_MODEL = "rfdetr-nano"
2832
DEFAULT_TRACKER = "bytetrack"
@@ -555,7 +559,7 @@ def _init_model(
555559
*,
556560
device: str = DEFAULT_DEVICE,
557561
api_key: str | None = None,
558-
):
562+
) -> AnyModel:
559563
"""Load detection model via inference-models.
560564
561565
Args:
@@ -585,7 +589,7 @@ def _init_model(
585589
)
586590

587591

588-
def _run_model(model, frame: np.ndarray, confidence: float) -> sv.Detections:
592+
def _run_model(model: AnyModel, frame: np.ndarray, confidence: float) -> sv.Detections:
589593
"""Run model inference and return sv.Detections."""
590594
predictions = model(frame)
591595
if not predictions:
@@ -627,7 +631,7 @@ def _extract_tracker_params(
627631
return params
628632

629633

630-
def _init_tracker(tracker_id: str, **kwargs) -> BaseTracker:
634+
def _init_tracker(tracker_id: str, **kwargs: object) -> BaseTracker:
631635
"""Create tracker instance from registry.
632636
633637
Args:

0 commit comments

Comments
 (0)