Skip to content

Commit 45715db

Browse files
authored
Merge pull request #156 from DeepLabCut/cy/expand-testing
Expanded test coverage & small tweaks/fixes
2 parents 5ef5f88 + 822f5d4 commit 45715db

File tree

16 files changed

+1328
-118
lines changed

16 files changed

+1328
-118
lines changed

.github/workflows/testing.yml

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
df -h
4444
4545
- name: Checkout code
46-
uses: actions/checkout@v4
46+
uses: actions/checkout@v4 # uses: actions/checkout@v6
4747

4848
- name: Install uv
4949
uses: astral-sh/setup-uv@v6
@@ -53,9 +53,9 @@ jobs:
5353
python-version: ${{ matrix.python-version }}
5454

5555
- name: Install the project
56-
run: uv sync --no-cache --all-extras --dev
56+
run: uv sync --all-extras --dev
5757
shell: bash
58-
58+
5959
- name: Install ffmpeg
6060
run: |
6161
if [ "$RUNNER_OS" == "Linux" ]; then
@@ -67,9 +67,27 @@ jobs:
6767
choco install ffmpeg
6868
fi
6969
shell: bash
70-
71-
- name: Run DLC Live Tests
70+
71+
- name: Run Model Benchmark Test
7272
run: uv run dlc-live-test --nodisplay
7373

74-
- name: Run Functional Benchmark Test
74+
- name: Run DLC Live Unit Tests
7575
run: uv run pytest
76+
# - name: Run DLC Live Unit Tests
77+
# run: uv run pytest --cov=dlclive --cov-report=xml --cov-report=term-missing
78+
79+
# - name: Coverage Report
80+
# uses: codecov/codecov-action@v5
81+
# with:
82+
# files: ./coverage.xml
83+
# flags: ${{ matrix.os }}-py${{ matrix.python-version }}
84+
# name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}
85+
# - name: Add coverage to job summary
86+
# if: always()
87+
# shell: bash
88+
# run: |
89+
# uv run python -m coverage report -m > coverage.txt
90+
# echo "## Coverage (dlclive)" >> "$GITHUB_STEP_SUMMARY"
91+
# echo '```' >> "$GITHUB_STEP_SUMMARY"
92+
# cat coverage.txt >> "$GITHUB_STEP_SUMMARY"
93+
# echo '```' >> "$GITHUB_STEP_SUMMARY"

.pre-commit-config.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v6.0.0
4+
hooks:
5+
- id: check-docstring-first
6+
- id: end-of-file-fixer
7+
- id: trailing-whitespace
8+
- repo: https://github.com/asottile/setup-cfg-fmt
9+
rev: v3.2.0
10+
hooks:
11+
- id: setup-cfg-fmt
12+
- repo: https://github.com/astral-sh/ruff-pre-commit
13+
rev: v0.14.10
14+
hooks:
15+
# Run the formatter.
16+
- id: ruff-format
17+
# Run the linter.
18+
- id: ruff-check
19+
args: [--fix,--unsafe-fixes]

dlclive/core/inferenceutils.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
#
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
11+
12+
13+
# NOTE - DUPLICATED @C-Achard 2026-01-26: Copied from the original DeepLabCut codebase
14+
# from deeplabcut/core/inferenceutils.py
1115
from __future__ import annotations
1216

1317
import heapq
@@ -17,9 +21,10 @@
1721
import pickle
1822
import warnings
1923
from collections import defaultdict
24+
from collections.abc import Iterable
2025
from dataclasses import dataclass
2126
from math import erf, sqrt
22-
from typing import Any, Iterable, Tuple
27+
from typing import Any
2328

2429
import networkx as nx
2530
import numpy as np
@@ -41,7 +46,7 @@ def _conv_square_to_condensed_indices(ind_row, ind_col, n):
4146
return n * ind_col - ind_col * (ind_col + 1) // 2 + ind_row - 1 - ind_col
4247

4348

44-
Position = Tuple[float, float]
49+
Position = tuple[float, float]
4550

4651

4752
@dataclass(frozen=True)
@@ -155,7 +160,7 @@ def soft_identity(self):
155160
unq, idx, cnt = np.unique(data[:, 3], return_inverse=True, return_counts=True)
156161
avg = np.bincount(idx, weights=data[:, 2]) / cnt
157162
soft = softmax(avg)
158-
return dict(zip(unq.astype(int), soft))
163+
return dict(zip(unq.astype(int), soft, strict=False))
159164

160165
@property
161166
def affinity(self):
@@ -262,7 +267,8 @@ def __init__(
262267
self._has_identity = "identity" in self[0]
263268
if identity_only and not self._has_identity:
264269
warnings.warn(
265-
"The network was not trained with identity; setting `identity_only` to False."
270+
"The network was not trained with identity; setting `identity_only` to False.",
271+
stacklevel=2,
266272
)
267273
self.identity_only = identity_only & self._has_identity
268274
self.nan_policy = nan_policy
@@ -344,15 +350,19 @@ def calibrate(self, train_data_file):
344350
pass
345351
n_bpts = len(df.columns.get_level_values("bodyparts").unique())
346352
if n_bpts == 1:
347-
warnings.warn("There is only one keypoint; skipping calibration...")
353+
warnings.warn(
354+
"There is only one keypoint; skipping calibration...", stacklevel=2
355+
)
348356
return
349357

350358
xy = df.to_numpy().reshape((-1, n_bpts, 2))
351359
frac_valid = np.mean(~np.isnan(xy), axis=(1, 2))
352360
# Only keeps skeletons that are more than 90% complete
353361
xy = xy[frac_valid >= 0.9]
354362
if not xy.size:
355-
warnings.warn("No complete poses were found. Skipping calibration...")
363+
warnings.warn(
364+
"No complete poses were found. Skipping calibration...", stacklevel=2
365+
)
356366
return
357367

358368
# TODO Normalize dists by longest length?
@@ -369,7 +379,8 @@ def calibrate(self, train_data_file):
369379
except np.linalg.LinAlgError:
370380
# Covariance matrix estimation fails due to numerical singularities
371381
warnings.warn(
372-
"The assembler could not be robustly calibrated. Continuing without it..."
382+
"The assembler could not be robustly calibrated. Continuing without it...",
383+
stacklevel=2,
373384
)
374385

375386
def calc_assembly_mahalanobis_dist(
@@ -428,10 +439,12 @@ def _flatten_detections(data_dict):
428439
ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence]
429440
else:
430441
ids = [arr.argmax(axis=1) for arr in ids]
431-
for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids)):
442+
for i, (coords, conf, id_) in enumerate(
443+
zip(coordinates, confidence, ids, strict=False)
444+
):
432445
if not np.any(coords):
433446
continue
434-
for xy, p, g in zip(coords, conf, id_):
447+
for xy, p, g in zip(coords, conf, id_, strict=False):
435448
joint = Joint(tuple(xy), p.item(), i, ind, g)
436449
ind += 1
437450
yield joint
@@ -474,13 +487,13 @@ def extract_best_links(self, joints_dict, costs, trees=None):
474487
(conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity)
475488
)
476489
candidates = sorted(
477-
zip(rows, cols, aff[rows, cols], lengths[rows, cols]),
490+
zip(rows, cols, aff[rows, cols], lengths[rows, cols], strict=False),
478491
key=lambda x: x[2],
479492
reverse=True,
480493
)
481494
i_seen = set()
482495
j_seen = set()
483-
for i, j, w, l in candidates:
496+
for i, j, w, _l in candidates:
484497
if i not in i_seen and j not in j_seen:
485498
i_seen.add(i)
486499
j_seen.add(j)
@@ -502,7 +515,7 @@ def extract_best_links(self, joints_dict, costs, trees=None):
502515
]
503516
aff = aff[np.ix_(keep_s, keep_t)]
504517
rows, cols = linear_sum_assignment(aff, maximize=True)
505-
for row, col in zip(rows, cols):
518+
for row, col in zip(rows, cols, strict=False):
506519
w = aff[row, col]
507520
if w >= self.min_affinity:
508521
links.append(Link(dets_s[keep_s[row]], dets_t[keep_t[col]], w))
@@ -548,9 +561,9 @@ def push_to_stack(i):
548561
d = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy)
549562
if d < d_old:
550563
push_to_stack(new_ind)
551-
if tabu:
552-
_, _, link = heapq.heappop(tabu)
553-
heapq.heappush(stack, (-link.affinity, next(counter), link))
564+
if tabu:
565+
_, _, link = heapq.heappop(tabu)
566+
heapq.heappush(stack, (-link.affinity, next(counter), link))
554567
else:
555568
heapq.heappush(tabu, (d - d_old, next(counter), best))
556569
assembly.__dict__.update(assembly._dict)
@@ -665,7 +678,7 @@ def build_assemblies(self, links):
665678
for idx in store[j]._idx:
666679
store[idx] = store[i]
667680
except KeyError:
668-
# Some links may reference indices that were never added to `store`;
681+
# Some links may reference indices that were never added to `store`;
669682
# in that case we intentionally skip merging for this link
670683
pass
671684

@@ -791,7 +804,7 @@ def _assemble(self, data_dict, ind_frame):
791804
]
792805
else:
793806
scores = [ass._affinity for ass in assemblies]
794-
lst = list(zip(scores, assemblies))
807+
lst = list(zip(scores, assemblies, strict=False))
795808
assemblies = []
796809
while lst:
797810
temp = max(lst, key=lambda x: x[0])
@@ -1074,7 +1087,7 @@ def match_assemblies(
10741087
if ~np.isnan(oks):
10751088
mat[i, j] = oks
10761089
rows, cols = linear_sum_assignment(mat, maximize=True)
1077-
for row, col in zip(rows, cols):
1090+
for row, col in zip(rows, cols, strict=False):
10781091
matched[row].ground_truth = ground_truth[col]
10791092
matched[row].oks = mat[row, col]
10801093
_ = inds_true.remove(col)
@@ -1087,7 +1100,7 @@ def parse_ground_truth_data_file(h5_file):
10871100
try:
10881101
df.drop("single", axis=1, level="individuals", inplace=True)
10891102
except KeyError:
1090-
# Ignore if the "single" individual column is absent
1103+
# Ignore if the "single" individual column is absent
10911104
pass
10921105
# Cast columns of dtype 'object' to float to avoid TypeError
10931106
# further down in _parse_ground_truth_data.
@@ -1128,7 +1141,7 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)):
11281141
for frame_ind, assemblies in dict_of_assemblies.items():
11291142
for assembly in assemblies:
11301143
tuples.append((frame_ind, getattr(assembly, criterion)))
1131-
frame_inds, vals = zip(*tuples)
1144+
frame_inds, vals = zip(*tuples, strict=False)
11321145
vals = np.asarray(vals)
11331146
lo, up = np.percentile(vals, qs, interpolation="nearest")
11341147
inds = np.flatnonzero((vals < lo) | (vals > up)).tolist()
@@ -1246,12 +1259,14 @@ def evaluate_assembly(
12461259
ass_pred_dict,
12471260
ass_true_dict,
12481261
oks_sigma=0.072,
1249-
oks_thresholds=np.linspace(0.5, 0.95, 10),
1262+
oks_thresholds=None,
12501263
margin=0,
12511264
symmetric_kpts=None,
12521265
greedy_matching=False,
12531266
with_tqdm: bool = True,
12541267
):
1268+
if oks_thresholds is None:
1269+
oks_thresholds = np.linspace(0.5, 0.95, 10)
12551270
if greedy_matching:
12561271
return evaluate_assembly_greedy(
12571272
ass_true_dict,

dlclive/display.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
try:
99
from tkinter import Label, Tk
10+
1011
from PIL import ImageTk
12+
1113
_TKINTER_AVAILABLE = True
1214
except ImportError:
1315
_TKINTER_AVAILABLE = False
@@ -59,7 +61,9 @@ def set_display(self, im_size, bodyparts):
5961
self.lab.pack()
6062

6163
all_colors = getattr(cc, self.cmap)
62-
self.colors = all_colors[:: int(len(all_colors) / bodyparts)]
64+
# Avoid 0 step
65+
step = max(1, int(len(all_colors) / bodyparts))
66+
self.colors = all_colors[::step]
6367

6468
def display_frame(self, frame, pose=None):
6569
"""
@@ -75,10 +79,10 @@ def display_frame(self, frame, pose=None):
7579
"""
7680
if not _TKINTER_AVAILABLE:
7781
raise ImportError("tkinter is not available. Cannot display frames.")
78-
82+
7983
im_size = (frame.shape[1], frame.shape[0])
84+
img = Image.fromarray(frame) # avoid undefined image if pose is None
8085
if pose is not None:
81-
img = Image.fromarray(frame)
8286
draw = ImageDraw.Draw(img)
8387

8488
if len(pose.shape) == 2:
@@ -91,33 +95,16 @@ def display_frame(self, frame, pose=None):
9195
for j in range(pose.shape[1]):
9296
if pose[i, j, 2] > self.pcutoff:
9397
try:
94-
x0 = (
95-
pose[i, j, 0] - self.radius
96-
if pose[i, j, 0] - self.radius > 0
97-
else 0
98-
)
99-
x1 = (
100-
pose[i, j, 0] + self.radius
101-
if pose[i, j, 0] + self.radius < im_size[0]
102-
else im_size[1]
103-
)
104-
y0 = (
105-
pose[i, j, 1] - self.radius
106-
if pose[i, j, 1] - self.radius > 0
107-
else 0
108-
)
109-
y1 = (
110-
pose[i, j, 1] + self.radius
111-
if pose[i, j, 1] + self.radius < im_size[1]
112-
else im_size[0]
113-
)
98+
x0 = max(0, pose[i, j, 0] - self.radius)
99+
x1 = min(im_size[0], pose[i, j, 0] + self.radius)
100+
y0 = max(0, pose[i, j, 1] - self.radius)
101+
y1 = min(im_size[1], pose[i, j, 1] + self.radius)
114102
coords = [x0, y0, x1, y1]
115103
draw.ellipse(
116104
coords, fill=self.colors[j], outline=self.colors[j]
117105
)
118106
except Exception as e:
119107
print(e)
120-
121108
img_tk = ImageTk.PhotoImage(image=img, master=self.window)
122109
self.lab.configure(image=img_tk)
123110
self.window.update()

dlclive/dlclive.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
55
Licensed under GNU Lesser General Public License v3.0
66
"""
7+
78
from __future__ import annotations
89

910
from pathlib import Path
@@ -197,12 +198,12 @@ def __init__(
197198
self.processor = processor
198199
self.convert2rgb = convert2rgb
199200

201+
self.pose: np.ndarray | None = None
202+
200203
if isinstance(display, Display):
201204
self.display = display
202205
elif display:
203-
self.display = Display(
204-
pcutoff=pcutoff, radius=display_radius, cmap=display_cmap
205-
)
206+
self.display = Display(pcutoff=pcutoff, radius=display_radius, cmap=display_cmap)
206207
else:
207208
self.display = None
208209

@@ -250,9 +251,7 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray:
250251
processed frame: convert type, crop, convert color
251252
"""
252253
if self.cropping:
253-
frame = frame[
254-
self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1]
255-
]
254+
frame = frame[self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1]]
256255

257256
if self.dynamic[0]:
258257
if self.pose is not None:
@@ -263,9 +262,7 @@ def process_frame(self, frame: np.ndarray) -> np.ndarray:
263262
elif len(self.pose) == 1:
264263
pose = self.pose[0]
265264
else:
266-
raise ValueError(
267-
"Cannot use Dynamic Cropping - more than 1 individual found"
268-
)
265+
raise ValueError("Cannot use Dynamic Cropping - more than 1 individual found")
269266

270267
else:
271268
pose = self.pose

0 commit comments

Comments
 (0)