Skip to content

Commit 331a097

Browse files
feat : idtracker tests added
feat : optional : probablities logic fixed feat : updated expectations feat : strings seprated
1 parent 365c909 commit 331a097

4 files changed

Lines changed: 122 additions & 5 deletions

File tree

movement/io/load_poses.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def from_file(
106106
"LightningPose",
107107
"Anipose",
108108
"NWB",
109+
"idtracker.ai",
109110
],
110111
fps: float | None = None,
111112
**kwargs,
@@ -175,6 +176,8 @@ def from_file(
175176
return from_lp_file(file, fps)
176177
elif source_software == "Anipose":
177178
return from_anipose_file(file, fps, **kwargs)
179+
elif source_software == "idtracker.ai":
180+
return from_idtracker_file(file, fps)
178181
elif source_software == "NWB":
179182
if fps is not None:
180183
logger.warning(
@@ -313,6 +316,9 @@ def from_idtracker_style_dict(
313316
trajectories = np.asarray(idtracker_data["trajectories"])
314317
n_frames, n_individuals, _ = trajectories.shape
315318

319+
# Reshape from idtracker (frames, individuals, space)
320+
# to movement (frames, space, keypoints, individuals)
321+
# Note: idtracker does not track multiple keypoints, so keypoints=1
316322
pos_reshaped = np.moveaxis(trajectories, source=-1, destination=1)
317323
position_array = np.expand_dims(pos_reshaped, axis=2)
318324

@@ -327,6 +333,8 @@ def from_idtracker_style_dict(
327333
else:
328334
probs = np.asarray(probs)
329335

336+
# Handle idtracker.ai edge case: some versions output probabilities
337+
# with an undocumented trailing singleton dimension (e.g., (N, M, 1))
330338
if probs.ndim == 3 and probs.shape[2] == 1:
331339
probs = probs[:, :, 0]
332340
elif probs.ndim != 2:
@@ -657,11 +665,10 @@ def _ds_from_idtracker_file(
657665
"""
658666
file_path = valid_file.file
659667

660-
# --- THE ROUTER ---
661668
if isinstance(valid_file, ValidIdtrackerH5):
662669
idtracker_data = _dict_from_idtracker_h5(file_path)
663670
else:
664-
logger.error(
671+
raise logger.error(
665672
TypeError(f"Unsupported idtracker file type: {type(valid_file)}")
666673
)
667674

@@ -867,7 +874,7 @@ def _dict_from_idtracker_h5(path: Path) -> dict[str, Any]:
867874
"""Create a dictionary of idtracker.ai pose data from an .h5 file."""
868875
with h5py.File(path, "r") as f:
869876
trajectories = f["trajectories"][:]
870-
probs = f["id_probabilities"][:]
877+
probs = f["id_probabilities"][:] if "id_probabilities" in f else None
871878

872879
fps = f.attrs.get("frames_per_second")
873880
if isinstance(fps, (list, tuple, np.ndarray)):

movement/validators/files.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,10 @@ class ValidIdtrackerH5:
418418
converter=Path,
419419
validator=validators.and_(
420420
_file_validator(permission="r", suffixes=suffixes),
421-
_hdf5_validator(datasets={"trajectories", "id_probabilities"}),
421+
_hdf5_validator(datasets={"trajectories"}),
422422
),
423423
)
424+
"""Path to the idtracker.ai .h5 file to validate."""
424425

425426

426427
@define

tests/fixtures/files.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from unittest.mock import mock_open, patch
99

1010
import h5py
11+
import numpy as np
1112
import pytest
1213
import xarray as xr
1314
from sleap_io.io.slp import read_labels, write_labels
@@ -515,6 +516,57 @@ def anipose_csv_file():
515516
)
516517

517518

519+
# ---------------- idtracker.ai file fixtures ----------------------------
520+
@pytest.fixture
521+
def idtracker_valid_h5_file(tmp_path):
522+
"""Return the path to a valid idtracker.ai .h5 file."""
523+
file_path = tmp_path / "valid_idtracker.h5"
524+
525+
with h5py.File(file_path, "w") as f:
526+
# 10 frames, 3 individuals, 2 spatial dimensions (x, y)
527+
f.create_dataset("trajectories", data=np.ones((10, 3, 2)))
528+
f.create_dataset("id_probabilities", data=np.ones((10, 3)))
529+
f.attrs["frames_per_second"] = np.array([30.0])
530+
return file_path
531+
532+
533+
@pytest.fixture
534+
def idtracker_buggy_shape_h5_file(tmp_path):
535+
"""Return the path to an idtracker.ai .h5
536+
file with trailing singleton dimension.
537+
"""
538+
file_path = tmp_path / "buggy_idtracker.h5"
539+
with h5py.File(file_path, "w") as f:
540+
f.create_dataset("trajectories", data=np.ones((10, 3, 2)))
541+
# Buggy shape: (10, 3, 1) instead of (10, 3)
542+
f.create_dataset("id_probabilities", data=np.ones((10, 3, 1)))
543+
f.attrs["frames_per_second"] = np.array([30.0])
544+
return file_path
545+
546+
547+
@pytest.fixture
548+
def idtracker_trackless_h5_file(tmp_path):
549+
"""Return the path to an idtracker.ai .h5 file missing id_probabilities."""
550+
file_path = tmp_path / "trackless_idtracker.h5"
551+
with h5py.File(file_path, "w") as f:
552+
f.create_dataset("trajectories", data=np.ones((10, 3, 2)))
553+
# Intentionally omitting the id_probabilities dataset
554+
f.attrs["frames_per_second"] = np.array([30.0])
555+
return file_path
556+
557+
558+
@pytest.fixture(
559+
params=[
560+
"idtracker_valid_h5_file",
561+
"idtracker_buggy_shape_h5_file",
562+
"idtracker_trackless_h5_file",
563+
]
564+
)
565+
def idtracker_h5_file(request):
566+
"""Fixture to parametrize various idtracker.ai files."""
567+
return request.getfixturevalue(request.param)
568+
569+
518570
# ---------------- netCDF file fixtures ----------------------------
519571
@pytest.fixture(scope="session")
520572
def invalid_netcdf_file_missing_confidence(tmp_path_factory):

tests/test_unit/test_io/test_load_poses.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,54 @@ def test_load_from_anipose_file():
245245
]
246246

247247

248+
def test_load_from_idtracker_file(idtracker_h5_file, helpers):
249+
ds = load_poses.from_idtracker_file(idtracker_h5_file)
250+
expected_values = {
251+
**expected_values_poses,
252+
"source_software": "idtracker.ai",
253+
"file_path": idtracker_h5_file,
254+
"fps": 30.0,
255+
}
256+
helpers.assert_valid_dataset(ds, expected_values)
257+
258+
259+
def test_load_from_idtracker_style_dict(helpers):
260+
"""Test loading pose tracks from an idtracker.ai style dictionary."""
261+
idtracker_dict = {
262+
"trajectories": np.ones(
263+
(10, 3, 2)
264+
), # 10 frames, 3 individuals, 2 spatial (x, y)
265+
"id_probabilities": np.ones((10, 3)),
266+
"frames_per_second": 30.0,
267+
}
268+
269+
# Pass it directly to our formatter
270+
ds = load_poses.from_idtracker_style_dict(idtracker_dict)
271+
272+
# Verify it creates a perfect movement dataset
273+
expected_values = {
274+
**expected_values_poses,
275+
"source_software": "idtracker.ai",
276+
"fps": 30.0,
277+
}
278+
helpers.assert_valid_dataset(ds, expected_values)
279+
280+
281+
def test_load_idtracker_without_probs(idtracker_trackless_h5_file):
282+
"""Test that loading an idtracker.ai file without identity probabilities
283+
returns a dataset with NaN confidence scores and default individual names.
284+
"""
285+
ds = load_poses.from_idtracker_file(idtracker_trackless_h5_file)
286+
287+
# 1. Check if default individual names were assigned
288+
# (our fixture has 3 individuals)
289+
assert ds.individuals.values.tolist() == ["id_0", "id_1", "id_2"]
290+
291+
# 2. Check if confidence scores are NaN
292+
# (since no probabilities were provided)
293+
assert np.isnan(ds.confidence.values).all()
294+
295+
248296
@pytest.mark.parametrize("kwargs", [{}, {"rate": 10.0, "starting_time": 0.0}])
249297
@pytest.mark.parametrize("input_type", ["nwb_file", "nwbfile_object"])
250298
def test_load_from_nwb_file(input_type, kwargs, request):
@@ -275,7 +323,15 @@ def test_load_from_nwb_file(input_type, kwargs, request):
275323
@pytest.mark.filterwarnings("ignore:.*is deprecated:DeprecationWarning")
276324
@pytest.mark.parametrize(
277325
"source_software",
278-
["DeepLabCut", "SLEAP", "LightningPose", "Anipose", "NWB", "Unknown"],
326+
[
327+
"DeepLabCut",
328+
"SLEAP",
329+
"LightningPose",
330+
"Anipose",
331+
"NWB",
332+
"idtracker.ai",
333+
"Unknown",
334+
],
279335
)
280336
@pytest.mark.parametrize("fps", [None, 30, 60.0])
281337
def test_from_file_delegates_correctly(source_software, fps, caplog):
@@ -288,6 +344,7 @@ def test_from_file_delegates_correctly(source_software, fps, caplog):
288344
"LightningPose": "movement.io.load_poses.from_lp_file",
289345
"Anipose": "movement.io.load_poses.from_anipose_file",
290346
"NWB": "movement.io.load_poses.from_nwb_file",
347+
"idtracker.ai": "movement.io.load_poses.from_idtracker_file",
291348
}
292349
if source_software == "Unknown":
293350
with pytest.raises(ValueError, match="Unsupported source"):

0 commit comments

Comments
 (0)