diff --git a/pyproject.toml b/pyproject.toml index 1256114c..ea088ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,9 @@ dev = [ "pytest-qt", "tox", ] +tracking = [ + "torch", +] [project.urls] "Bug Tracker" = "https://github.com/DeepLabCut/napari-deeplabcut/issues" Documentation = "https://github.com/DeepLabCut/napari-deeplabcut#README.md" diff --git a/src/napari_deeplabcut/_reader.py b/src/napari_deeplabcut/_reader.py index 6dcff093..52c29822 100644 --- a/src/napari_deeplabcut/_reader.py +++ b/src/napari_deeplabcut/_reader.py @@ -4,9 +4,11 @@ from __future__ import annotations import logging +from functools import partial from pathlib import Path from napari_deeplabcut.config._autostart import maybe_install_keypoint_controls_autostart +from napari_deeplabcut.config.models import DLCProjectContext from napari_deeplabcut.core.discovery import discover_annotations from napari_deeplabcut.core.io import ( SUPPORTED_IMAGES, @@ -17,11 +19,41 @@ read_images, read_video, ) -from napari_deeplabcut.core.project_paths import looks_like_dlc_labeled_folder +from napari_deeplabcut.core.project_paths import ( + infer_dlc_project_from_labeled_folder, + infer_dlc_project_from_video_path, + looks_like_dlc_labeled_folder, + session_key_from_project_context, +) logger = logging.getLogger(__name__) +def _build_dlc_layer_meta( + *, + session_role: str | None, + project_context: DLCProjectContext | None, +) -> dict: + """ + Build explicit DLC lifecycle metadata for image/video layers. + + If session_role is None or project_context is None, the layer should be + treated as a non-session image/video by lifecycle code. + """ + if session_role is None or project_context is None: + return { + "session_role": None, + "project_context": None, + "session_key": None, + } + + return { + "session_role": session_role, + "project_context": project_context.model_dump(mode="python", exclude_none=True), + "session_key": session_key_from_project_context(project_context), + } + + def get_hdf_reader(path): if isinstance(path, list): path = path[0] @@ -40,9 +72,28 @@ def get_image_reader(path): def get_video_reader(path): - if isinstance(path, str) and any(path.lower().endswith(ext) for ext in SUPPORTED_VIDEOS): - return read_video - return None + if not isinstance(path, str) or not any(path.lower().endswith(ext) for ext in SUPPORTED_VIDEOS): + return None + + ctx = infer_dlc_project_from_video_path(path) + if ctx is None: + # Generic non-DLC video layer: allowed, but ignored by lifecycle session context. 
+ return partial( + read_video, + dlc_meta=_build_dlc_layer_meta( + session_role=None, + project_context=None, + ), + ) + + maybe_install_keypoint_controls_autostart() + return partial( + read_video, + dlc_meta=_build_dlc_layer_meta( + session_role="video", + project_context=ctx, + ), + ) def get_config_reader(path): @@ -87,7 +138,17 @@ def get_folder_parser(path): ) return None - layers.extend(read_images(images)) + ctx = infer_dlc_project_from_labeled_folder(path) + + layers.extend( + read_images( + images, + dlc_meta=_build_dlc_layer_meta( + session_role="image", + project_context=ctx, + ), + ) + ) # Deterministic discovery: load ALL H5 artifacts artifacts = discover_annotations(path) diff --git a/src/napari_deeplabcut/_tests/config/test_keybinds.py b/src/napari_deeplabcut/_tests/config/test_keybinds.py index 4e6c0b88..31350339 100644 --- a/src/napari_deeplabcut/_tests/config/test_keybinds.py +++ b/src/napari_deeplabcut/_tests/config/test_keybinds.py @@ -23,8 +23,12 @@ def bind_key(self, key, callback, overwrite=False): def test_iter_shortcuts_returns_registry(): shortcuts = tuple(keybinds.iter_shortcuts()) - assert shortcuts == keybinds.SHORTCUTS - assert shortcuts, "SHORTCUTS should not be empty" + expected = keybinds.SHORTCUTS + if keybinds.TRACKING_SHORTCUTS_ENABLED: + expected = expected + keybinds.TRACKING_SHORTCUTS + + assert shortcuts == expected + assert keybinds.SHORTCUTS, "SHORTCUTS should not be empty" def test_shortcuts_registry_points_layer_entries_have_callbacks(): diff --git a/src/napari_deeplabcut/_tests/conftest.py b/src/napari_deeplabcut/_tests/conftest.py index 0b149102..32a3b7e7 100644 --- a/src/napari_deeplabcut/_tests/conftest.py +++ b/src/napari_deeplabcut/_tests/conftest.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd import pytest +from napari.utils.events import Event from qtpy.QtWidgets import QApplication, QDockWidget from skimage.io import imsave @@ -199,9 +200,43 @@ def images(tmp_path_factory, viewer, fake_image): return viewer.open(output_path, plugin="napari-deeplabcut")[0] +class DummyDimsForStore: + def __init__(self, nsteps=5, current_step=0): + self.nsteps = (nsteps,) + self.current_step = (current_step,) + self.set_calls = [] + + def set_current_step(self, axis, value): + self.set_calls.append((axis, value)) + steps = list(self.current_step) + while len(steps) <= axis: + steps.append(0) + steps[axis] = value + self.current_step = tuple(steps) + + +class DummyViewerForStore: + def __init__(self, nsteps=5, current_step=0): + self.dims = DummyDimsForStore(nsteps=nsteps, current_step=current_step) + + @pytest.fixture -def store(viewer, points): - return keypoints.KeypointStore(viewer, points) +def store(points): + try: + data = np.asarray(points.data) + nsteps = int(np.nanmax(data[:, 0])) + 1 if data.size else 1 + except Exception: + nsteps = 1 + + viewer = DummyViewerForStore(nsteps=nsteps) + store = keypoints.KeypointStore(viewer, points) + + # Mimic the minimal runtime wiring used by LOOP mode + if not hasattr(points.events, "query_next_frame"): + points.events.add(query_next_frame=Event) + points.events.query_next_frame.connect(store._advance_step) + + return store @pytest.fixture diff --git a/src/napari_deeplabcut/_tests/core/io/test_hdf_reader.py b/src/napari_deeplabcut/_tests/core/io/test_hdf_reader.py index 8aacb79b..8e3ba14e 100644 --- a/src/napari_deeplabcut/_tests/core/io/test_hdf_reader.py +++ b/src/napari_deeplabcut/_tests/core/io/test_hdf_reader.py @@ -29,7 +29,7 @@ def _write_h5_single_animal( ) df = pd.DataFrame([values], 
index=list(index), columns=cols) path.parent.mkdir(parents=True, exist_ok=True) - df.to_hdf(path, key="keypoints", mode="w") + df.to_hdf(path, key="df_with_missing", mode="w") return df diff --git a/src/napari_deeplabcut/_tests/core/io/test_write_routing.py b/src/napari_deeplabcut/_tests/core/io/test_write_routing.py index 60602d12..4edbbb5e 100644 --- a/src/napari_deeplabcut/_tests/core/io/test_write_routing.py +++ b/src/napari_deeplabcut/_tests/core/io/test_write_routing.py @@ -18,7 +18,7 @@ def test_resolve_output_path_returns_none_for_machine_without_save_target(): "project_root": str(Path.cwd()), "source_relpath_posix": "machinelabels-iter0.h5", "kind": AnnotationKind.MACHINE, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } } } @@ -39,7 +39,7 @@ def test_write_hdf_refuses_machine_without_promotion(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": "machinelabels-iter0.h5", "kind": AnnotationKind.MACHINE, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", }, # header is required by writer "header": { @@ -84,7 +84,7 @@ def test_write_hdf_aborts_machine_without_promotion_target(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": "machinelabels-iter0.h5", "kind": AnnotationKind.MACHINE, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", }, "header": {"columns": [("S", "", "bp1", "x"), ("S", "", "bp1", "y")]}, }, diff --git a/src/napari_deeplabcut/_tests/core/layer_manager/__init__.py b/src/napari_deeplabcut/_tests/core/layer_manager/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/_tests/core/layer_manager/test_manager.py b/src/napari_deeplabcut/_tests/core/layer_manager/test_manager.py new file mode 100644 index 00000000..95e756c7 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/layer_manager/test_manager.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +import gc +from types import SimpleNamespace + +import numpy as np +import pytest +from napari.layers import Image, Points + +from napari_deeplabcut.core.layer_lifecycle import LayerLifecycleManager + + +def mark_as_dlc_session_image(layer, *, role="image"): + layer.metadata = dict(layer.metadata or {}) + layer.metadata["dlc"] = { + "session_role": role, + "project_context": { + "root_anchor": "C:/project/labeled-data/test", + "project_root": "C:/project", + "config_path": "C:/project/config.yaml", + "dataset_folder": "C:/project/labeled-data/test", + }, + "session_key": "C:/project", + } + return layer + + +# --------------------------------------------------------------------------- +# Minimal fake viewer/layer event infrastructure +# --------------------------------------------------------------------------- +class DummySignal: + def __init__(self): + self._callbacks = [] + + def connect(self, callback): + if callback not in self._callbacks: + self._callbacks.append(callback) + + def disconnect(self, callback): + if callback in self._callbacks: + self._callbacks.remove(callback) + + @property + def callbacks(self): + return list(self._callbacks) + + +class DummyLayerEvents: + def __init__(self): + self.inserted = DummySignal() + self.removed = DummySignal() + + +class DummyLayerList(list): + def __init__(self, layers=()): + super().__init__(layers) + self.events = DummyLayerEvents() + + +class DummyViewer: + def __init__(self, layers=()): + self.layers = DummyLayerList(layers) + + +class DummyImageMeta: + def __init__(self): + self.root = None + self.paths = None + + def model_dump(self, 
**kwargs): + return {} + + +class SignalRecorder: + def __init__(self): + self.calls = [] + + def __call__(self, *args): + self.calls.append(args) + + @property + def count(self): + return len(self.calls) + + +def connect_signal_recorders(manager): + rec = SimpleNamespace( + refresh_video=SignalRecorder(), + refresh_status=SignalRecorder(), + setup_points=SignalRecorder(), + merged_points=SignalRecorder(), + removed_points=SignalRecorder(), + removed_tracks=SignalRecorder(), + move_image_bottom=SignalRecorder(), + video_visibility=SignalRecorder(), + adopted=SignalRecorder(), + inserted=SignalRecorder(), + removed=SignalRecorder(), + conflicts=SignalRecorder(), + ) + + manager.refresh_video_panel_requested.connect(rec.refresh_video) + manager.refresh_layer_status_requested.connect(rec.refresh_status) + manager.points_layer_setup_requested.connect(rec.setup_points) + manager.points_layers_merged_requested.connect(rec.merged_points) + manager.points_layer_removed_requested.connect(rec.removed_points) + manager.tracks_layer_removed_requested.connect(rec.removed_tracks) + manager.move_image_layer_to_bottom_requested.connect(rec.move_image_bottom) + manager.video_widget_visibility_requested.connect(rec.video_visibility) + manager.adopted_existing_layers.connect(rec.adopted) + manager.layer_insert_processed.connect(rec.inserted) + manager.layer_remove_processed.connect(rec.removed) + manager.session_conflict_rejected.connect(rec.conflicts) + return rec + + +# --------------------------------------------------------------------------- +# Fake store used so manager tests do not depend on real KeypointStore/viewer +# --------------------------------------------------------------------------- + + +class FakeStore: + def __init__(self, viewer, layer): + self.viewer = viewer + self._layer = layer + self._layer_id = id(layer) + self._resolver = None + self._get_label_mode = None + + @property + def layer(self): + return self._layer + + @layer.setter + def layer(self, layer): + self._layer = layer + self._layer_id = id(layer) + + @property + def layer_id(self): + return self._layer_id + + def attach_layer_resolver(self, resolver): + self._resolver = resolver + + def set_label_mode_getter(self, getter): + self._get_label_mode = getter + + def _advance_step(self, event=None): + return None + + def add(self, coord): + return None + + +# --------------------------------------------------------------------------- +# Shared factories +# --------------------------------------------------------------------------- + + +def make_image(name="img"): + layer = Image(np.zeros((5, 5))) + layer.name = name + return layer + + +def make_points(name="pts"): + layer = Points(np.zeros((0, 3))) + layer.name = name + return layer + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def immediate_qtimer(monkeypatch): + from napari_deeplabcut.core.layer_lifecycle import manager as manager_module + + monkeypatch.setattr( + manager_module.QTimer, + "singleShot", + staticmethod(lambda _ms, fn: fn()), + ) + + +@pytest.fixture +def fake_store(monkeypatch): + from napari_deeplabcut.core.layer_lifecycle import manager as manager_module + + monkeypatch.setattr(manager_module.keypoints, "KeypointStore", FakeStore) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + 
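A note on the autouse fixture above, since several assertions below depend on it: `QTimer.singleShot(ms, fn)` queues `fn` on the Qt event loop, so the manager's debounced refreshes would otherwise complete only after the test body returns. A minimal sketch of the substitution (the name `immediate_single_shot` is illustrative, not part of this diff):

```python
def immediate_single_shot(_ms, callback):
    """Stand-in for QTimer.singleShot that runs the callback synchronously."""
    callback()

# Wired the same way the autouse fixture does it, via pytest's monkeypatch:
#   monkeypatch.setattr(manager_module.QTimer, "singleShot",
#                       staticmethod(immediate_single_shot))
```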
+def test_manager_attach_and_detach_are_idempotent(qtbot): + viewer = DummyViewer() + manager = LayerLifecycleManager(viewer=viewer) + + manager.attach() + manager.attach() + + assert viewer.layers.events.inserted.callbacks == [manager.on_insert] + assert viewer.layers.events.removed.callbacks == [manager.on_remove] + + manager.detach() + manager.detach() + + assert viewer.layers.events.inserted.callbacks == [] + assert viewer.layers.events.removed.callbacks == [] + + +def test_manager_register_and_query_managed_points(qtbot): + viewer = DummyViewer() + manager = LayerLifecycleManager(viewer=viewer) + + pts = make_points() + store = object() + + assert manager.has_managed_points() is False + assert manager.managed_points_count() == 0 + assert manager.managed_points_layers() == () + assert manager.resolve_live_layer(pts) is None + assert manager.get_live_runtime(pts) is None + assert manager.get_store(pts) is None + + manager.register_managed_points_layer(pts, store) + + assert manager.is_managed(pts) is True + assert manager.has_managed_points() is True + assert manager.managed_points_count() == 1 + assert manager.managed_points_layers() == (pts,) + assert list(manager.iter_managed_points()) == [(pts, store)] + + assert manager.resolve_live_layer(pts) is pts + runtime = manager.get_live_runtime(pts) + assert runtime is not None + assert runtime.layer_id == id(pts) + assert runtime.store is store + assert manager.get_store(pts) is store + assert manager.require_store(pts) is store + + removed = manager.unregister_managed_layer(pts) + assert removed is store + + assert manager.is_managed(pts) is False + assert manager.has_managed_points() is False + assert manager.managed_points_count() == 0 + assert manager.managed_points_layers() == () + assert manager.resolve_live_layer(pts) is None + assert manager.get_live_runtime(pts) is None + assert manager.get_store(pts) is None + + +def test_manager_on_insert_points_sets_up_points_and_refreshes_ui( + qtbot, + fake_store, + monkeypatch, +): + img = make_image() + pts_existing = make_points("existing") + pts_inserted = make_points("inserted") + + viewer = DummyViewer([img, pts_existing, pts_inserted]) + manager = LayerLifecycleManager(viewer=viewer) + rec = connect_signal_recorders(manager) + + monkeypatch.setattr(manager, "validate_header", lambda layer: True) + + remap_calls = [] + manager._remap_frame_indices = lambda layer: remap_calls.append(layer) + + event = SimpleNamespace(value=pts_inserted, index=2, source=viewer.layers) + + manager.on_insert(event) + + assert manager.is_managed(pts_inserted) is True + + assert rec.setup_points.count == 1 + req = rec.setup_points.calls[0][0] + assert req.layer is pts_inserted + assert req.store is manager.get_store(pts_inserted) + + assert rec.refresh_video.count >= 1 + assert rec.refresh_status.count >= 1 + assert rec.inserted.count == 1 + assert rec.inserted.calls[0][0] is pts_inserted + + assert pts_existing in remap_calls + assert pts_inserted in remap_calls + + +def test_manager_on_insert_image_updates_context_and_refreshes_ui(qtbot): + img = mark_as_dlc_session_image(make_image("inserted-image")) + pts = make_points() + + viewer = DummyViewer([img, pts]) + manager = LayerLifecycleManager(viewer=viewer) + rec = connect_signal_recorders(manager) + + remap_calls = [] + manager._remap_frame_indices = lambda layer: remap_calls.append(layer) + + event = SimpleNamespace(value=img, index=0, source=viewer.layers) + + manager.on_insert(event) + + assert manager.active_dlc_image_layer() is img + assert 
manager.image_meta.name == "inserted-image" + + assert rec.refresh_video.count >= 1 + assert rec.refresh_status.count >= 1 + assert rec.move_image_bottom.count == 1 + assert rec.move_image_bottom.calls[0][0] is img + + assert pts in remap_calls + + +def test_manager_adopt_existing_layers_skips_already_managed_points( + qtbot, + fake_store, + monkeypatch, +): + img = mark_as_dlc_session_image(make_image()) + pts_managed = make_points("managed") + pts_unmanaged = make_points("unmanaged") + + viewer = DummyViewer([img, pts_managed, pts_unmanaged]) + manager = LayerLifecycleManager(viewer=viewer) + rec = connect_signal_recorders(manager) + + monkeypatch.setattr(manager, "validate_header", lambda layer: True) + + remap_calls = [] + manager._remap_frame_indices = lambda layer: remap_calls.append(layer) + + manager.register_managed_points_layer(pts_managed, object()) + + manager.adopt_existing_layers() + + assert manager.active_dlc_image_layer() is img + assert manager.image_meta.name == img.name + + assert rec.move_image_bottom.count == 1 + assert rec.move_image_bottom.calls[0][0] is img + + assert rec.setup_points.count == 1 + req = rec.setup_points.calls[0][0] + assert req.layer is pts_unmanaged + + assert rec.adopted.count == 1 + + assert pts_managed in remap_calls + assert pts_unmanaged in remap_calls + + +def test_manager_on_remove_triggers_ui_cleanup_and_refresh(qtbot): + pts = make_points() + viewer = DummyViewer([pts]) + manager = LayerLifecycleManager(viewer=viewer) + rec = connect_signal_recorders(manager) + + manager.register_managed_points_layer(pts, object()) + + event = SimpleNamespace(value=pts) + + manager.on_remove(event) + + assert rec.removed_points.count == 1 + removed_layer, remaining = rec.removed_points.calls[0] + assert removed_layer is pts + assert remaining == 1 + + manager._flush_post_remove_refresh() + + assert rec.refresh_video.count >= 1 + assert rec.refresh_status.count >= 1 + assert rec.removed.count == 1 + assert rec.removed.calls[0][0] is pts + + +def test_manager_reap_dead_entries_removes_stale_entry(qtbot): + viewer = DummyViewer() + manager = LayerLifecycleManager(viewer=viewer) + + pts = make_points() + store = object() + + manager.register_managed_points_layer(pts, store) + + layer_id = id(pts) + del pts + gc.collect() + + report_before = manager.audit_registry() + assert report_before.dead_count == 1 + assert any(issue.code == "dead-entry" and issue.layer_id == layer_id for issue in report_before.issues) + + reaped = manager.clear_dead_entries(log=False) + + assert len(reaped) == 1 + assert reaped[0].layer_id == layer_id + assert reaped[0].runtime.store is store + + report_after = manager.audit_registry() + assert report_after.dead_count == 0 + assert report_after.issues == () + + +@pytest.mark.parametrize( + ("event_factory", "expected_name"), + [ + (lambda viewer, img, pts: SimpleNamespace(value=pts, index=1, source=viewer.layers), "pts"), + (lambda viewer, img, pts: SimpleNamespace(index=1, source=viewer.layers), "pts"), + (lambda viewer, img, pts: SimpleNamespace(source=[img, pts]), "pts"), + ], +) +def test_manager_resolve_inserted_layer_prefers_value_then_index_then_source(qtbot, event_factory, expected_name): + img = make_image("img") + pts = make_points("pts") + + viewer = DummyViewer([img, pts]) + manager = LayerLifecycleManager(viewer=viewer) + + event = event_factory(viewer, img, pts) + + layer = manager._resolve_inserted_layer(event) + + assert layer is pts + assert layer.name == expected_name diff --git 
a/src/napari_deeplabcut/_tests/core/layer_manager/test_registry.py b/src/napari_deeplabcut/_tests/core/layer_manager/test_registry.py new file mode 100644 index 00000000..2df3c5b3 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/layer_manager/test_registry.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import gc + +import pytest + +from napari_deeplabcut.core.layer_lifecycle.registry import ( + ManagedPointsRuntime, + RuntimeRegistry, +) + + +class DummyLayer: + pass + + +def test_registry_register_resolve_require_unregister_roundtrip(): + registry = RuntimeRegistry() + layer = DummyLayer() + runtime = ManagedPointsRuntime(layer_id=id(layer), store="store-1") + + registry.register(layer, runtime) + + assert registry.is_managed(layer) is True + assert registry.resolve_live_layer(layer) is layer + assert registry.get_live_runtime(layer) is runtime + assert registry.require_live_runtime(layer) is runtime + assert registry.get_store(layer) == "store-1" + assert registry.require_store(layer) == "store-1" + assert len(registry) == 1 + + removed = registry.unregister(layer) + + assert removed is runtime + assert registry.is_managed(layer) is False + assert registry.resolve_live_layer(layer) is None + assert registry.get_live_runtime(layer) is None + assert registry.get_store(layer) is None + assert len(registry) == 0 + + +def test_registry_duplicate_registration_raises(): + registry = RuntimeRegistry() + layer = DummyLayer() + + registry.register( + layer, + ManagedPointsRuntime(layer_id=id(layer), store="a"), + ) + + with pytest.raises(ValueError): + registry.register( + layer, + ManagedPointsRuntime(layer_id=id(layer), store="b"), + ) + + +def test_registry_iter_live_items_layers_runtimes_are_consistent(): + registry = RuntimeRegistry() + layer1 = DummyLayer() + layer2 = DummyLayer() + + runtime1 = ManagedPointsRuntime(layer_id=id(layer1), store="s1") + runtime2 = ManagedPointsRuntime(layer_id=id(layer2), store="s2") + + registry.register(layer1, runtime1) + registry.register(layer2, runtime2) + + items = list(registry.iter_live_items()) + layers = list(registry.iter_live_layers()) + runtimes = list(registry.iter_live_runtimes()) + + assert items == [(layer1, runtime1), (layer2, runtime2)] + assert layers == [layer1, layer2] + assert runtimes == [runtime1, runtime2] + assert set(registry.layer_ids()) == {id(layer1), id(layer2)} + + assert registry.get_store(layer1) == "s1" + assert registry.get_store(layer2) == "s2" + + registry.assert_consistent() + + +def test_registry_clear_dead_entries_removes_stale_entry_and_reports_it(): + registry = RuntimeRegistry() + + layer = DummyLayer() + runtime = ManagedPointsRuntime(layer_id=id(layer), store="store") + registry.register(layer, runtime) + + assert len(registry) == 1 + layer_id = id(layer) + + # Remove the only strong reference held by the test. + del layer + gc.collect() + + # Before reaping, the stale id may still be present in the registry index, + # but it should no longer count as live. 
+ assert len(registry) == 0 + assert layer_id in registry.layer_ids() + + reaped = registry.clear_dead_entries(log=False) + + assert len(reaped) == 1 + assert reaped[0].layer_id == layer_id + assert reaped[0].runtime is runtime + + assert layer_id not in registry.layer_ids() + assert len(registry) == 0 + assert list(registry.iter_live_items()) == [] + + +def test_registry_audit_reports_dead_entry_before_reap(): + registry = RuntimeRegistry() + + layer = DummyLayer() + runtime = ManagedPointsRuntime(layer_id=id(layer), store="store") + registry.register(layer, runtime) + + layer_id = id(layer) + del layer + gc.collect() + + report = registry.audit() + + assert report.live_count == 0 + assert report.dead_count == 1 + assert any(issue.code == "dead-entry" and issue.layer_id == layer_id for issue in report.issues) + + reaped = registry.clear_dead_entries(log=False) + assert len(reaped) == 1 + + report_after = registry.audit() + assert report_after.live_count == 0 + assert report_after.dead_count == 0 + assert report_after.issues == () diff --git a/src/napari_deeplabcut/_tests/core/test_conflicts.py b/src/napari_deeplabcut/_tests/core/test_conflicts.py index f2da0b15..d8189283 100644 --- a/src/napari_deeplabcut/_tests/core/test_conflicts.py +++ b/src/napari_deeplabcut/_tests/core/test_conflicts.py @@ -254,7 +254,7 @@ def fake_build_report(conflicts, *, layer_name, destination_path): assert result is report assert seen["set_df_scorer"] == (raw_new_df, "target_scorer") - assert seen["read_hdf_calls"] == [(out, "keypoints")] + assert seen["read_hdf_calls"] == [(out, "df_with_missing")] assert seen["keypoint_conflicts"] == (old_df, promoted_df) assert seen["build_report"] == ( key_conflict, @@ -316,7 +316,7 @@ def test_compute_overwrite_report_falls_back_when_keyed_hdf_read_fails(monkeypat def fake_read_hdf(path, key=None): calls.append((Path(path), key)) - if key == "keypoints": + if key == "df_with_missing": raise KeyError("missing key") return old_df @@ -335,7 +335,7 @@ def fake_read_hdf(path, key=None): assert result is report assert calls == [ - (out, "keypoints"), + (out, "df_with_missing"), (out, None), ] diff --git a/src/napari_deeplabcut/_tests/core/test_dataframes.py b/src/napari_deeplabcut/_tests/core/test_dataframes.py index ca37d768..e7aac883 100644 --- a/src/napari_deeplabcut/_tests/core/test_dataframes.py +++ b/src/napari_deeplabcut/_tests/core/test_dataframes.py @@ -85,6 +85,13 @@ def test_guarantee_multiindex_rows_leaves_numeric_index_unchanged(): assert not isinstance(df.index, pd.MultiIndex) +def test_guarantee_multiindex_rows_empty_df_is_noop(): + df = pd.DataFrame(columns=["x", "y"]) + guarantee_multiindex_rows(df) + assert len(df.index) == 0 + assert not isinstance(df.index, pd.MultiIndex) + + # ----------------------------------------------------------------------------- # 2) harmonize_keypoint_column_index # ----------------------------------------------------------------------------- diff --git a/src/napari_deeplabcut/_tests/core/test_keypoints.py b/src/napari_deeplabcut/_tests/core/test_keypoints.py index 215f02c9..0f9d0cbe 100644 --- a/src/napari_deeplabcut/_tests/core/test_keypoints.py +++ b/src/napari_deeplabcut/_tests/core/test_keypoints.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from napari_deeplabcut.core import keypoints @@ -11,8 +12,8 @@ def test_store_advance_step(store): def test_store_labels(store, fake_keypoints): + assert store.layer_id == id(store.layer) assert store.n_steps == fake_keypoints.shape[0] - # Labels are derived from the header 
(bodyparts level); this asserts the store mirrors header order assert store.labels == list(fake_keypoints.columns.get_level_values("bodyparts").unique()) @@ -49,6 +50,7 @@ def test_store_keypoints(store, fake_keypoints): store.next_keypoint() +@pytest.mark.usefixtures("qtbot") def test_point_resize(qtbot, viewer, points): viewer.layers.selection.add(points) layer = viewer.layers[0] @@ -76,7 +78,7 @@ def test_add_unannotated(store): # IMPORTANT: pass coord with the CURRENT frame index so we truly add to the frame we're on # Data layout is (frame, y, x) - keypoints._add(store, coord=(ind_to_remove, 1, 1)) + store.add((ind_to_remove, 1, 1)) # Exactly one new point should be appended assert store.layer.data.shape[0] == n_points + 1 @@ -99,7 +101,7 @@ def test_add_quick(store): # Add (or move) at the CURRENT frame; coord uses (frame, y, x) coord = store.current_step, -1, -1 - keypoints._add(store, coord=coord) + store.add(coord) # After QUICK add/move, the point for the current frame should match the requested coord # (If it existed, it was moved; if not, it was added.) @@ -107,3 +109,47 @@ def test_add_quick(store): store.layer.data[store.current_step], coord, ) + + +def test_store_can_attach_layer_resolver(store): + original_layer = store.layer + layer_id = id(original_layer) + + # Resolver returns the original live layer by id. + store.attach_layer_resolver(lambda requested_id: original_layer if requested_id == layer_id else None) + + assert store.layer_id == layer_id + assert store.maybe_layer() is original_layer + assert store.layer is original_layer + + +def test_store_layer_raises_when_resolver_returns_none(store): + store.attach_layer_resolver(lambda requested_id: None) + + assert store.maybe_layer() is None + + with pytest.raises(keypoints.LayerUnavailableError): + _ = store.layer + + +def test_store_resolver_is_authoritative_over_local_fallback(store): + + # Even though the store still has fallback refs, resolver should dominate. 
+ store.attach_layer_resolver(lambda requested_id: None) + + assert store.maybe_layer() is None + + with pytest.raises(keypoints.LayerUnavailableError): + _ = store.layer + + +def test_store_layer_setter_updates_layer_id_and_keypoints(store, viewer): + old_layer = store.layer + old_layer_id = store.layer_id + + new_layer = viewer.layers[0].copy() if hasattr(viewer.layers[0], "copy") else old_layer + store.layer = new_layer + + assert store.layer is new_layer + assert store.layer_id == id(new_layer) + assert store.layer_id != old_layer_id or new_layer is old_layer diff --git a/src/napari_deeplabcut/_tests/core/test_metadata.py b/src/napari_deeplabcut/_tests/core/test_metadata.py index cf3347b8..d6cbad65 100644 --- a/src/napari_deeplabcut/_tests/core/test_metadata.py +++ b/src/napari_deeplabcut/_tests/core/test_metadata.py @@ -285,7 +285,7 @@ def fake_model_validate(payload): "kind": "gt", "project_root": "/tmp", "source_relpath_posix": "CollectedData_A.h5", - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", }, "save_target": {"kind": "MACHINE"}, } @@ -361,7 +361,7 @@ def test_prepare_points_payload_migrates_legacy_source_h5(monkeypatch): monkeypatch.setattr( metadata_mod, "_build_io_from_source_h5", - lambda src, dataset_key="keypoints": {"kind": AnnotationKind.GT, "dataset_key": dataset_key}, + lambda src, dataset_key="df_with_missing": {"kind": AnnotationKind.GT, "dataset_key": dataset_key}, ) payload = metadata_mod._prepare_points_payload( @@ -370,7 +370,7 @@ def test_prepare_points_payload_migrates_legacy_source_h5(monkeypatch): ) assert payload["io"]["kind"] == AnnotationKind.GT - assert payload["io"]["dataset_key"] == "keypoints" + assert payload["io"]["dataset_key"] == "df_with_missing" # ----------------------------------------------------------------------------- @@ -394,7 +394,7 @@ def test_attach_source_and_io_to_layer_kwargs_sets_legacy_fields_and_io(monkeypa assert inner["source_h5"].endswith("CollectedData_Jane.h5") assert inner["io"]["kind"] == AnnotationKind.GT assert inner["io"]["source_relpath_posix"] == "CollectedData_Jane.h5" - assert inner["io"]["dataset_key"] == "keypoints" + assert inner["io"]["dataset_key"] == "df_with_missing" # ----------------------------------------------------------------------------- @@ -532,14 +532,14 @@ def test_build_io_provenance_dict_keeps_enum_kind_object(tmp_path: Path): project_root=tmp_path, source_relpath_posix="CollectedData_Jane.h5", kind=AnnotationKind.GT, - dataset_key="keypoints", + dataset_key="df_with_missing", ) # mode="python" => should keep enum object at runtime assert isinstance(d["kind"], AnnotationKind) assert d["kind"] == AnnotationKind.GT assert d["project_root"] == str(tmp_path) assert d["source_relpath_posix"] == "CollectedData_Jane.h5" - assert d["dataset_key"] == "keypoints" + assert d["dataset_key"] == "df_with_missing" assert d["schema_version"] == 1 @@ -548,6 +548,6 @@ def test_build_io_provenance_dict_excludes_none_fields(tmp_path: Path): project_root=tmp_path, source_relpath_posix="CollectedData_Jane.h5", kind=None, # exclude_none=True => kind should be absent - dataset_key="keypoints", + dataset_key="df_with_missing", ) assert "kind" not in d diff --git a/src/napari_deeplabcut/_tests/core/test_provenance.py b/src/napari_deeplabcut/_tests/core/test_provenance.py index 36856c54..6a202579 100644 --- a/src/napari_deeplabcut/_tests/core/test_provenance.py +++ b/src/napari_deeplabcut/_tests/core/test_provenance.py @@ -26,7 +26,7 @@ def test_ensure_io_provenance_accepts_model_instance(tmp_path: Path): 
project_root=str(tmp_path), source_relpath_posix="CollectedData_Jane.h5", kind=AnnotationKind.GT, - dataset_key="keypoints", + dataset_key="df_with_missing", ) out = ensure_io_provenance(io) assert out is io @@ -38,7 +38,7 @@ def test_ensure_io_provenance_accepts_dict_with_enum_kind(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": "CollectedData_Jane.h5", "kind": AnnotationKind.GT, # IMPORTANT: enum instance, not string - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } out = ensure_io_provenance(payload) assert isinstance(out, IOProvenance) @@ -57,7 +57,7 @@ def test_ensure_io_provenance_rejects_dict_with_string_kind(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": "CollectedData_Jane.h5", "kind": "gt", # invalid at runtime by policy - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } with pytest.raises(MissingProvenanceError): ensure_io_provenance(payload) @@ -69,7 +69,7 @@ def test_ensure_io_provenance_rejects_invalid_kind_value(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": "CollectedData_Jane.h5", "kind": "not-a-kind", - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } with pytest.raises(MissingProvenanceError): ensure_io_provenance(payload) @@ -86,7 +86,7 @@ def test_ensure_io_provenance_rejects_missing_required_relpath(tmp_path: Path): "project_root": str(tmp_path), # missing source_relpath_posix "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } # ensure_io_provenance validates; resolve_provenance_path is stricter about missing relpath out = ensure_io_provenance(payload) @@ -108,7 +108,7 @@ def test_normalize_provenance_converts_backslashes(tmp_path: Path): project_root=str(tmp_path), source_relpath_posix=r"labeled-data\test\CollectedData_Jane.h5", kind=AnnotationKind.GT, - dataset_key="keypoints", + dataset_key="df_with_missing", ) out = normalize_provenance(io) assert out is not None @@ -135,7 +135,7 @@ def test_resolve_provenance_path_uses_root_anchor_when_provided(tmp_path: Path): "project_root": str(other_root), # valid dir, but not where file exists "source_relpath_posix": "CollectedData_Jane.h5", "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } resolved = resolve_provenance_path(io, root_anchor=anchor) @@ -152,7 +152,7 @@ def test_resolve_provenance_path_uses_project_root_when_root_anchor_missing(tmp_ "project_root": str(root), "source_relpath_posix": "CollectedData_Jane.h5", "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } resolved = resolve_provenance_path(io, root_anchor=None) @@ -165,7 +165,7 @@ def test_resolve_provenance_path_requires_source_relpath_posix(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": None, "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } with pytest.raises(MissingProvenanceError): resolve_provenance_path(payload) @@ -177,7 +177,7 @@ def test_resolve_provenance_path_requires_anchor_or_project_root(tmp_path: Path) "project_root": None, "source_relpath_posix": "CollectedData_Jane.h5", "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } with pytest.raises(UnresolvablePathError): resolve_provenance_path(payload, root_anchor=None) @@ -189,7 +189,7 @@ def test_resolve_provenance_path_raises_if_missing_by_default(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": 
"CollectedData_Jane.h5", "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } with pytest.raises(UnresolvablePathError): resolve_provenance_path(payload, allow_missing=False) @@ -201,7 +201,7 @@ def test_resolve_provenance_path_allows_missing_when_flag_true(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": "CollectedData_Jane.h5", "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } resolved = resolve_provenance_path(payload, allow_missing=True) assert resolved == tmp_path / "CollectedData_Jane.h5" @@ -217,7 +217,7 @@ def test_resolve_provenance_path_normalizes_backslashes(tmp_path: Path): "project_root": str(tmp_path), "source_relpath_posix": r"labeled-data\CollectedData_Jane.h5", "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } resolved = resolve_provenance_path(payload) assert resolved == tmp_path / "labeled-data" / "CollectedData_Jane.h5" diff --git a/src/napari_deeplabcut/_tests/core/test_reader_layerdata_contract.py b/src/napari_deeplabcut/_tests/core/test_reader_layerdata_contract.py index 4cd165d8..b42825c1 100644 --- a/src/napari_deeplabcut/_tests/core/test_reader_layerdata_contract.py +++ b/src/napari_deeplabcut/_tests/core/test_reader_layerdata_contract.py @@ -19,7 +19,7 @@ def _write_minimal_h5(path: Path, scorer: str, with_likelihood: bool = False, al row = [10.0, 20.0] + ([0.9] if with_likelihood else []) df = pd.DataFrame([row], index=["img000.png"], columns=cols) path.parent.mkdir(parents=True, exist_ok=True) - df.to_hdf(path, key="keypoints", mode="w") + df.to_hdf(path, key="df_with_missing", mode="w") return df @@ -68,7 +68,7 @@ def test_read_hdf_single_filters_data_and_properties_consistently(tmp_path: Path h5 = tmp_path / "CollectedData_John.h5" cols = pd.MultiIndex.from_product([["John"], ["bp1", "bp2"], ["x", "y"]], names=["scorer", "bodyparts", "coords"]) df = pd.DataFrame([[10.0, 20.0, np.nan, np.nan]], index=["img000.png"], columns=cols) - df.to_hdf(h5, key="keypoints", mode="w") + df.to_hdf(h5, key="df_with_missing", mode="w") layers = read_hdf_single(h5, kind=AnnotationKind.GT) data, meta, _ = layers[0] diff --git a/src/napari_deeplabcut/_tests/core/test_writer_promotion.py b/src/napari_deeplabcut/_tests/core/test_writer_promotion.py index 1a8edcee..7c85dc21 100644 --- a/src/napari_deeplabcut/_tests/core/test_writer_promotion.py +++ b/src/napari_deeplabcut/_tests/core/test_writer_promotion.py @@ -32,7 +32,7 @@ def _make_minimal_points_metadata( "project_root": str(root), "source_relpath_posix": f"{name}.h5", "kind": kind, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", }, }, } @@ -42,7 +42,7 @@ def _make_minimal_points_metadata( def _read_keypoints_h5(p: Path) -> pd.DataFrame: - return pd.read_hdf(p, key="keypoints") + return pd.read_hdf(p, key="df_with_missing") def test_writer_aborts_if_machine_source_without_save_target(tmp_path: Path): @@ -86,7 +86,7 @@ def test_writer_promotion_writes_collecteddata_and_rewrites_scorer(tmp_path: Pat "project_root": str(tmp_path), "source_relpath_posix": "CollectedData_Alice.h5", "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", "scorer": "Alice", } @@ -101,7 +101,7 @@ def test_writer_promotion_writes_collecteddata_and_rewrites_scorer(tmp_path: Pat # Create a dummy machine file and snapshot it (writer must not touch it) machine_path = tmp_path / "machinelabels-iter0.h5" df_machine = pd.DataFrame(np.nan, columns=cols, 
index=["img000.png"]) - df_machine.to_hdf(machine_path, key="keypoints", mode="w") + df_machine.to_hdf(machine_path, key="df_with_missing", mode="w") machine_before = _read_keypoints_h5(machine_path) points = np.array( diff --git a/src/napari_deeplabcut/_tests/e2e/conftest.py b/src/napari_deeplabcut/_tests/e2e/conftest.py index 31060437..8fee2958 100644 --- a/src/napari_deeplabcut/_tests/e2e/conftest.py +++ b/src/napari_deeplabcut/_tests/e2e/conftest.py @@ -98,9 +98,9 @@ def _patched_maybe_confirm_overwrite(parent, report): return state["result"] - import napari_deeplabcut.ui.dialogs as dlg + import napari_deeplabcut.ui.ui_dialogs.save as save_dlg - monkeypatch.setattr(dlg, "maybe_confirm_overwrite", _patched_maybe_confirm_overwrite) + monkeypatch.setattr(save_dlg, "maybe_confirm_overwrite", _patched_maybe_confirm_overwrite) class Controller: @property diff --git a/src/napari_deeplabcut/_tests/e2e/test_layer_coloring.py b/src/napari_deeplabcut/_tests/e2e/test_layer_coloring.py index 685f5a8c..e71d5d27 100644 --- a/src/napari_deeplabcut/_tests/e2e/test_layer_coloring.py +++ b/src/napari_deeplabcut/_tests/e2e/test_layer_coloring.py @@ -55,7 +55,7 @@ def test_config_placeholder_points_layer_colors_after_first_keypoint_added(viewe assert "header" in md, "Expected header in metadata for config.yaml placeholder layer" # 3) Begin editing: add bodypart1 then bodypart2 - store = controls._stores.get(placeholder) + store = controls.get_layer_store(placeholder) assert store is not None, "Expected KeypointStore to be registered for placeholder Points layer" # Add first point @@ -134,7 +134,7 @@ def test_config_placeholder_multianimal_colors_by_id_after_first_keypoint_added( assert "animal2" in id_cycles, f"Expected 'animal2' in derived id cycles; got keys={list(id_cycles)[:10]}" # 3) Begin editing: add a point for animal1, then animal2 - store = controls._stores.get(placeholder) + store = controls.get_layer_store(placeholder) assert store is not None, "Expected KeypointStore for placeholder Points layer" # Add first point: (frame, y, x) @@ -212,7 +212,7 @@ def test_color_scheme_panel_toggle_shows_active_then_full_config_bodyparts( # Make sure the placeholder is the active target layer viewer.layers.selection.active = placeholder - store = controls._stores.get(placeholder) + store = controls.get_layer_store(placeholder) assert store is not None # Deterministically add bodypart1 @@ -273,8 +273,8 @@ def test_color_scheme_panel_multianimal_toggle_shows_active_then_full_config_ind assert placeholder.data is None or len(placeholder.data) == 0 # Wait until the existing controls instance has wired the layer - qtbot.waitUntil(lambda: placeholder in controls._stores, timeout=5_000) - store = controls._stores.get(placeholder) + qtbot.waitUntil(lambda: controls.get_layer_store(placeholder) is not None, timeout=5_000) + store = controls.get_layer_store(placeholder) assert store is not None # This assertion is now valid because we're using the controls instance diff --git a/src/napari_deeplabcut/_tests/e2e/test_overwrite_and_merge.py b/src/napari_deeplabcut/_tests/e2e/test_overwrite_and_merge.py index c4a40ca5..d4e4a4ce 100644 --- a/src/napari_deeplabcut/_tests/e2e/test_overwrite_and_merge.py +++ b/src/napari_deeplabcut/_tests/e2e/test_overwrite_and_merge.py @@ -13,7 +13,7 @@ @pytest.mark.usefixtures("qtbot") -def test_config_first_hazard_regression_no_silent_deletion(viewer, keypoint_controls, qtbot, tmp_path, caplog): +def test_config_regression_no_silent_deletion(viewer, keypoint_controls, qtbot, tmp_path, 
caplog): """ Regression for the original report: Save the WRONG (placeholder) layer and still preserve previous labels due to merge-on-save. @@ -22,7 +22,7 @@ def test_config_first_hazard_regression_no_silent_deletion(viewer, keypoint_cont project, config_path, labeled_folder, h5_path = _make_minimal_dlc_project(tmp_path) - pre = pd.read_hdf(h5_path, key="keypoints") + pre = pd.read_hdf(h5_path, key="df_with_missing") assert np.isfinite(_get_coord_from_df(pre, "bodypart1", "x")) assert np.isnan(_get_coord_from_df(pre, "bodypart2", "x")) @@ -46,7 +46,7 @@ def test_config_first_hazard_regression_no_silent_deletion(viewer, keypoint_cont # Placeholder should still be present for this regression to apply assert placeholder in viewer.layers - store = keypoint_controls._stores.get(placeholder) + store = keypoint_controls.get_layer_store(placeholder) assert store is not None # Add a new bodypart2 point to placeholder using (frame, y, x) @@ -57,7 +57,7 @@ def test_config_first_hazard_regression_no_silent_deletion(viewer, keypoint_cont viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") qtbot.wait(200) - post = pd.read_hdf(h5_path, key="keypoints") + post = pd.read_hdf(h5_path, key="df_with_missing") b1x_post = _get_coord_from_df(post, "bodypart1", "x") b2x_post = _get_coord_from_df(post, "bodypart2", "x") @@ -119,7 +119,7 @@ def test_no_overwrite_warning_when_only_filling_nans(viewer, keypoint_controls, logger.info("any NaNs in points.data = %s", np.isnan(points.data).any()) logger.info("labels[:10] = %s", points.properties.get("label")[:10]) logger.info("ids[:10] = %s", points.properties.get("id")[:10] if "id" in points.properties else None) - store = keypoint_controls._stores.get(points) + store = keypoint_controls.get_layer_store(points) assert store is not None # Fill NaNs for bodypart2 @@ -129,7 +129,7 @@ def test_no_overwrite_warning_when_only_filling_nans(viewer, keypoint_controls, viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") qtbot.wait(200) - post = pd.read_hdf(h5_path, key="keypoints") + post = pd.read_hdf(h5_path, key="df_with_missing") assert np.isfinite(_get_coord_from_df(post, "bodypart1", "x")) assert np.isfinite(_get_coord_from_df(post, "bodypart2", "x")) @@ -149,7 +149,7 @@ def test_overwrite_warning_triggers_on_conflict(viewer, keypoint_controls, qtbot qtbot.wait(200) points = _get_points_layer_with_data(viewer) - store = keypoint_controls._stores.get(points) + store = keypoint_controls.get_layer_store(points) assert store is not None # Create conflict: overwrite bodypart1 from (10,20) -> (99,88) @@ -163,7 +163,7 @@ def test_overwrite_warning_triggers_on_conflict(viewer, keypoint_controls, qtbot assert overwrite_confirm.calls[0]["n_pairs"] is not None assert overwrite_confirm.calls[0]["n_pairs"] >= 1 - post = pd.read_hdf(h5_path, key="keypoints") + post = pd.read_hdf(h5_path, key="df_with_missing") assert _get_coord_from_df(post, "bodypart1", "x") == 99.0 @@ -176,7 +176,7 @@ def test_overwrite_warning_cancel_aborts_write(viewer, keypoint_controls, qtbot, project, config_path, labeled_folder, h5_path = _make_minimal_dlc_project(tmp_path) - pre = pd.read_hdf(h5_path, key="keypoints") + pre = pd.read_hdf(h5_path, key="df_with_missing") b1x_pre = _get_coord_from_df(pre, "bodypart1", "x") b1y_pre = _get_coord_from_df(pre, "bodypart1", "y") @@ -187,7 +187,7 @@ def test_overwrite_warning_cancel_aborts_write(viewer, keypoint_controls, qtbot, qtbot.wait(200) points = _get_points_layer_with_data(viewer) - store = 
keypoint_controls._stores.get(points) + store = keypoint_controls.get_layer_store(points) assert store is not None _set_or_add_bodypart_xy(points, store, "bodypart1", x=456.0, y=123.0) @@ -203,6 +203,6 @@ def test_overwrite_warning_cancel_aborts_write(viewer, keypoint_controls, qtbot, assert len(overwrite_confirm.calls) == 1, "Expected overwrite confirmation to be requested once." - post = pd.read_hdf(h5_path, key="keypoints") + post = pd.read_hdf(h5_path, key="df_with_missing") assert _get_coord_from_df(post, "bodypart1", "x") == b1x_pre assert _get_coord_from_df(post, "bodypart1", "y") == b1y_pre diff --git a/src/napari_deeplabcut/_tests/e2e/test_points_layers.py b/src/napari_deeplabcut/_tests/e2e/test_points_layers.py index 84687131..f41f4057 100644 --- a/src/napari_deeplabcut/_tests/e2e/test_points_layers.py +++ b/src/napari_deeplabcut/_tests/e2e/test_points_layers.py @@ -158,8 +158,8 @@ def test_copy_paste_points_to_new_frame_does_not_crash_and_offsets_frame( layer = viewer.add_points(data, **md) assert isinstance(layer, Points) - qtbot.waitUntil(lambda: layer in controls._stores, timeout=5_000) - assert layer in controls._stores + qtbot.waitUntil(lambda: controls.get_layer_store(layer) is not None, timeout=5_000) + assert controls.get_layer_store(layer) is not None # frame 0: select and copy viewer.dims.set_point(0, 0) diff --git a/src/napari_deeplabcut/_tests/e2e/test_routing_and_provenance.py b/src/napari_deeplabcut/_tests/e2e/test_routing_and_provenance.py index 149d02a2..243e1056 100644 --- a/src/napari_deeplabcut/_tests/e2e/test_routing_and_provenance.py +++ b/src/napari_deeplabcut/_tests/e2e/test_routing_and_provenance.py @@ -1,3 +1,4 @@ +import importlib import logging import numpy as np @@ -5,7 +6,9 @@ import pytest from napari.layers import Points -from napari_deeplabcut.core.io import AnnotationKind, MissingProvenanceError +import napari_deeplabcut._widgets as widgets_mod +from napari_deeplabcut.config.models import AnnotationKind +from napari_deeplabcut.core.errors import MissingProvenanceError from .utils import ( _assert_only_these_files_changed, @@ -21,23 +24,35 @@ @pytest.fixture -def forbid_project_config_dialog(monkeypatch): +def save_workflow_mod(): + """ + Module object where PointsLayerSaveWorkflow is defined and whose imported + names must be monkeypatched for save-flow tests. + """ + return importlib.import_module(widgets_mod.PointsLayerSaveWorkflow.__module__) + + +@pytest.fixture +def forbid_project_config_dialog(monkeypatch, save_workflow_mod): monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + save_workflow_mod, + "prompt_for_project_config_for_save", lambda *args, **kwargs: pytest.fail("Unexpected project-config dialog."), ) monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.maybe_confirm_dataset_path_rewrite", + save_workflow_mod, + "maybe_confirm_dataset_path_rewrite", lambda *args, **kwargs: pytest.fail("Unexpected dataset path rewrite confirmation."), ) monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.warn_existing_dataset_folder_conflict", + save_workflow_mod, + "warn_existing_dataset_folder_conflict", lambda *args, **kwargs: pytest.fail("Unexpected dataset-folder conflict warning."), ) @pytest.fixture -def skip_project_config_dialog(monkeypatch): +def skip_project_config_dialog(monkeypatch, save_workflow_mod): """ Simulate the new promotion policy when no config.yaml exists. 
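The `save_workflow_mod` fixture above encodes the usual monkeypatching rule: patch a name in the module that looks it up, not in the module that defines it. Because `PointsLayerSaveWorkflow` imports its dialog helpers at module load time, patches must land in its defining module's namespace; a minimal sketch of the resolution step, using only names from this diff:

```python
import importlib

import napari_deeplabcut._widgets as widgets_mod

# Resolve the module whose namespace PointsLayerSaveWorkflow actually reads,
# so the fixtures keep working even if the workflow class moves.
save_workflow_mod = importlib.import_module(
    widgets_mod.PointsLayerSaveWorkflow.__module__
)

# Typical use in a test:
#   monkeypatch.setattr(save_workflow_mod,
#                       "prompt_for_project_config_for_save",
#                       lambda *args, **kwargs: fake_result)
```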
@@ -56,10 +71,7 @@ def _skip(*args, **kwargs): action=ui_dialogs.ProjectConfigPromptAction.SKIP, ) - monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", - _skip, - ) + monkeypatch.setattr(save_workflow_mod, "prompt_for_project_config_for_save", _skip) return calls @@ -89,7 +101,7 @@ def test_save_routes_to_correct_gt_when_multiple_gt_exist( points_b = next((ly for ly in viewer.layers if isinstance(ly, Points) and ly.name == gt_b.stem), None) assert points_b is not None, f"Expected a Points layer named {gt_b.stem}" - store_b = keypoint_controls._stores.get(points_b) + store_b = keypoint_controls.get_layer_store(points_b) assert store_b is not None # Fill NaNs for bodypart2 in B only (no overwrite dialog) @@ -136,7 +148,7 @@ def test_machine_layer_does_not_modify_gt_on_save(viewer, keypoint_controls, qtb machine_layer = next((ly for ly in viewer.layers if isinstance(ly, Points) and ly.name == machine_path.stem), None) assert machine_layer is not None - store = keypoint_controls._stores.get(machine_layer) + store = keypoint_controls.get_layer_store(machine_layer) assert store is not None # Fill NaNs in machine file (no overwrite prompt) @@ -177,7 +189,7 @@ def test_layer_rename_does_not_change_save_target(viewer, keypoint_controls, qtb layer = next((ly for ly in viewer.layers if isinstance(ly, Points) and ly.name == gt_a.stem), None) assert layer is not None - store = keypoint_controls._stores.get(layer) + store = keypoint_controls.get_layer_store(layer) assert store is not None # Rename in UI @@ -230,7 +242,7 @@ def test_ambiguous_placeholder_save_aborts_when_multiple_gt_exist( viewer.open(str(labeled_folder), plugin="napari-deeplabcut") qtbot.wait(200) - store = keypoint_controls._stores.get(placeholder) + store = keypoint_controls.get_layer_store(placeholder) assert store is not None # Add a point to placeholder @@ -311,7 +323,7 @@ def test_config_first_save_writes_gt_into_dataset_folder(viewer, keypoint_contro assert pts_layers, "Expected a Points layer from config.yaml" points = pts_layers[0] - store = keypoint_controls._stores.get(points) + store = keypoint_controls.get_layer_store(points) assert store is not None # Add a point and save @@ -344,7 +356,7 @@ def test_promotion_first_save_skip_config_then_prompt_scorer_and_create_sidecar( labeled_folder = _make_labeled_folder_with_machine_only(tmp_path) machine_path = labeled_folder / "machinelabels-iter0.h5" - machine_pre = pd.read_hdf(machine_path, key="keypoints") + machine_pre = pd.read_hdf(machine_path, key="df_with_missing") # Open folder viewer.open(str(labeled_folder), plugin="napari-deeplabcut") @@ -357,7 +369,7 @@ def test_promotion_first_save_skip_config_then_prompt_scorer_and_create_sidecar( machine_layer = next(p for p in pts_layers if p.name == "machinelabels-iter0") # Edit: add bodypart2 (use helper that works across versions) - store = keypoint_controls._stores.get(machine_layer) + store = keypoint_controls.get_layer_store(machine_layer) assert store is not None _set_or_add_bodypart_xy(machine_layer, store, "bodypart2", x=44.0, y=33.0) @@ -394,7 +406,7 @@ def test_promotion_first_save_skip_config_then_prompt_scorer_and_create_sidecar( assert gt_path.exists() # Machine file unchanged - machine_post = pd.read_hdf(machine_path, key="keypoints") + machine_post = pd.read_hdf(machine_path, key="df_with_missing") pd.testing.assert_frame_equal(machine_pre, machine_post) @@ -417,7 +429,7 @@ def test_promotion_second_save_skip_config_then_use_sidecar_without_scorer_promp 
sidecar.write_text('{"schema_version": 1, "default_scorer": "Alice"}', encoding="utf-8") machine_path = labeled_folder / "machinelabels-iter0.h5" - machine_pre = pd.read_hdf(machine_path, key="keypoints") + machine_pre = pd.read_hdf(machine_path, key="df_with_missing") controls = keypoint_controls viewer.window.add_dock_widget(controls, name="Keypoint controls", area="right") @@ -429,7 +441,7 @@ def test_promotion_second_save_skip_config_then_use_sidecar_without_scorer_promp pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] machine_layer = next(p for p in pts_layers if p.name == "machinelabels-iter0") - store = controls._stores.get(machine_layer) + store = controls.get_layer_store(machine_layer) assert store is not None _set_or_add_bodypart_xy(machine_layer, store, "bodypart1", x=99.0, y=88.0) @@ -445,7 +457,7 @@ def test_promotion_second_save_skip_config_then_use_sidecar_without_scorer_promp gt_path = labeled_folder / "CollectedData_Alice.h5" assert gt_path.exists() - machine_post = pd.read_hdf(machine_path, key="keypoints") + machine_post = pd.read_hdf(machine_path, key="df_with_missing") pd.testing.assert_frame_equal(machine_pre, machine_post) controls._save_layers_dialog(selected=True) @@ -464,6 +476,7 @@ def test_projectless_folder_save_can_associate_with_config_and_coerce_paths_to_d tmp_path, monkeypatch, overwrite_confirm, + save_workflow_mod, ): """ Contract: a project-less labeled folder can be associated with a chosen DLC @@ -506,7 +519,7 @@ def test_projectless_folder_save_can_associate_with_config_and_coerce_paths_to_d qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) points = next(ly for ly in viewer.layers if isinstance(ly, Points)) - store = keypoint_controls._stores.get(points) + store = keypoint_controls.get_layer_store(points) assert store is not None # Simulate project-less folder metadata: @@ -523,31 +536,32 @@ def test_projectless_folder_save_can_associate_with_config_and_coerce_paths_to_d store.current_keypoint = keypoints.Keypoint("bodypart1", "") points.add(np.array([0.0, 11.0, 22.0], dtype=float)) + import napari_deeplabcut.core.conflicts as conflicts from napari_deeplabcut.ui import dialogs as ui_dialogs + real_compute = conflicts.compute_overwrite_report_for_points_save + captured = {} + + def _wrapped_compute(data, attributes): + captured["attributes"] = attributes + return real_compute(data, attributes) + monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + save_workflow_mod, + "prompt_for_project_config_for_save", lambda *args, **kwargs: ui_dialogs.ProjectConfigPromptResult( action=ui_dialogs.ProjectConfigPromptAction.ASSOCIATE, config_path=str(config_path), ), ) monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.maybe_confirm_dataset_path_rewrite", + save_workflow_mod, + "maybe_confirm_dataset_path_rewrite", lambda *args, **kwargs: True, ) - - import napari_deeplabcut.core.conflicts as conflicts - - real_compute = conflicts.compute_overwrite_report_for_points_save - captured = {} - - def _wrapped_compute(data, attributes): - captured["attributes"] = attributes - return real_compute(data, attributes) - monkeypatch.setattr( - "napari_deeplabcut._widgets.compute_overwrite_report_for_points_save", + save_workflow_mod, + "compute_overwrite_report_for_points_save", _wrapped_compute, ) @@ -587,7 +601,7 @@ def _wrapped_compute(data, attributes): assert points.metadata["paths"] == expected_paths # H5 row index contains canonical DLC row keys for the 
safe cases - df = pd.read_hdf(expected_h5, key="keypoints") + df = pd.read_hdf(expected_h5, key="df_with_missing") if isinstance(df.index, pd.MultiIndex): observed_rows = ["/".join(map(str, idx)) for idx in df.index] else: @@ -607,6 +621,7 @@ def test_projectless_folder_save_refuses_when_target_dataset_folder_already_cont tmp_path, monkeypatch, overwrite_confirm, + save_workflow_mod, ): """ Contract: project-association save must refuse if the target dataset folder @@ -636,7 +651,7 @@ def test_projectless_folder_save_refuses_when_target_dataset_folder_already_cont qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) points = next(ly for ly in viewer.layers if isinstance(ly, Points)) - store = keypoint_controls._stores.get(points) + store = keypoint_controls.get_layer_store(points) assert store is not None points.metadata = dict(points.metadata or {}) @@ -652,18 +667,21 @@ def test_projectless_folder_save_refuses_when_target_dataset_folder_already_cont from napari_deeplabcut.ui import dialogs as ui_dialogs monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + save_workflow_mod, + "prompt_for_project_config_for_save", lambda *args, **kwargs: ui_dialogs.ProjectConfigPromptResult( action=ui_dialogs.ProjectConfigPromptAction.ASSOCIATE, config_path=str(config_path), ), ) monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.warn_existing_dataset_folder_conflict", + save_workflow_mod, + "warn_existing_dataset_folder_conflict", lambda *args, **kwargs: warned.setdefault("called", True), ) monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.maybe_confirm_dataset_path_rewrite", + save_workflow_mod, + "maybe_confirm_dataset_path_rewrite", lambda *args, **kwargs: True, ) @@ -687,6 +705,7 @@ def test_promotion_nearby_config_wins_no_dialog_no_prompt( tmp_path, monkeypatch, inputdialog, + save_workflow_mod, ): """ If a valid DLC config.yaml is discoverable near a machine-labeled layer, @@ -705,7 +724,7 @@ def test_promotion_nearby_config_wins_no_dialog_no_prompt( sidecar = labeled_folder / ".napari-deeplabcut.json" sidecar.write_text('{"schema_version": 1, "default_scorer": "Alice"}', encoding="utf-8") - machine_pre = pd.read_hdf(machine_path, key="keypoints") + machine_pre = pd.read_hdf(machine_path, key="df_with_missing") dialog_calls = {"count": 0} @@ -714,7 +733,8 @@ def _unexpected_config_dialog(*args, **kwargs): pytest.fail("Config-selection dialog must not appear when nearby config.yaml is auto-discovered.") monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + save_workflow_mod, + "prompt_for_project_config_for_save", _unexpected_config_dialog, ) @@ -728,7 +748,7 @@ def _unexpected_config_dialog(*args, **kwargs): pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] machine_layer = next(p for p in pts_layers if p.name == machine_path.stem) - store = keypoint_controls._stores.get(machine_layer) + store = keypoint_controls.get_layer_store(machine_layer) assert store is not None _set_or_add_bodypart_xy(machine_layer, store, "bodypart2", x=54.0, y=43.0) @@ -749,7 +769,7 @@ def _unexpected_config_dialog(*args, **kwargs): assert expected_gt.exists(), f"Expected GT with config scorer to be created: {expected_gt}" assert not unexpected_gt.exists(), f"Sidecar scorer must be ignored when config.yaml is nearby: {unexpected_gt}" - machine_post = pd.read_hdf(machine_path, key="keypoints") + machine_post = pd.read_hdf(machine_path, 
key="df_with_missing") pd.testing.assert_frame_equal(machine_pre, machine_post) @@ -761,6 +781,7 @@ def test_promotion_selected_external_config_wins_no_scorer_prompt( tmp_path, monkeypatch, inputdialog, + save_workflow_mod, ): """ If no nearby config.yaml is found, but the user points the save flow to a @@ -771,7 +792,7 @@ def test_promotion_selected_external_config_wins_no_scorer_prompt( """ labeled_folder = _make_labeled_folder_with_machine_only(tmp_path) machine_path = labeled_folder / "machinelabels-iter0.h5" - machine_pre = pd.read_hdf(machine_path, key="keypoints") + machine_pre = pd.read_hdf(machine_path, key="df_with_missing") # External DLC project whose config scorer should be used. external_project, external_config_path, _external_dataset = _make_project_config_and_frames_no_gt( @@ -797,7 +818,8 @@ def _choose_external_config(*args, **kwargs): ) monkeypatch.setattr( - "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + save_workflow_mod, + "prompt_for_project_config_for_save", _choose_external_config, ) @@ -811,7 +833,7 @@ def _choose_external_config(*args, **kwargs): pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] machine_layer = next(p for p in pts_layers if p.name == "machinelabels-iter0") - store = keypoint_controls._stores.get(machine_layer) + store = keypoint_controls.get_layer_store(machine_layer) assert store is not None _set_or_add_bodypart_xy(machine_layer, store, "bodypart1", x=91.0, y=82.0) @@ -834,5 +856,102 @@ def _choose_external_config(*args, **kwargs): f"Sidecar scorer must be ignored when a valid external config is selected: {unexpected_gt}" ) - machine_post = pd.read_hdf(machine_path, key="keypoints") + machine_post = pd.read_hdf(machine_path, key="df_with_missing") pd.testing.assert_frame_equal(machine_pre, machine_post) + + +@pytest.mark.usefixtures("qtbot") +def test_direct_video_labeling_save_is_blocked_without_paths( + viewer, + keypoint_controls, + qtbot, + tmp_path, + monkeypatch, + overwrite_confirm, +): + """ + Unsupported workflow guard: + - user has a video layer open + - user adds config.yaml / placeholder points layer + - points layer has no extracted-frame paths + - save must abort with a warning before overwrite preflight / writer save + """ + overwrite_confirm.forbid() + + project, config_path, labeled_folder = _make_project_config_and_frames_no_gt(tmp_path) + + # 1) Open config first -> placeholder Points layer with valid DLC header + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) + qtbot.wait(200) + + points = next(ly for ly in viewer.layers if isinstance(ly, Points)) + store = keypoint_controls.get_layer_store(points) + assert store is not None + + # Ensure the placeholder truly has no extracted-frame paths + points.metadata = dict(points.metadata or {}) + points.metadata["paths"] = [] + # Keep a root/project hint if your normal workflow would have one + points.metadata.setdefault("project", str(project)) + points.metadata.setdefault("root", str(labeled_folder)) + + # Add one point so the layer looks "dirty" / save-worthy + _set_or_add_bodypart_xy(points, store, "bodypart1", x=11.0, y=22.0) + + # 2) Add a synthetic video image layer to create the unsupported context + viewer.add_image( + np.zeros((3, 8, 8, 3), dtype=np.uint8), + name="clip.mp4", + metadata={"dlc": {"session_role": "video"}}, + ) + qtbot.wait(100) + + # Select the points layer for save + viewer.layers.selection.active = points 
+ keypoint_controls.viewer.layers.selection.active = points + keypoint_controls.viewer.layers.selection.select_only(points) + + warned = {"called": False} + + # Patch the save workflow module where the imported symbols are actually used + save_mod = importlib.import_module(keypoint_controls._save_workflow.__class__.__module__) + + # If you added a dedicated warning helper, patch that directly (cleanest) + if hasattr(keypoint_controls._save_workflow, "_warn_unsupported_direct_video_label_save"): + monkeypatch.setattr( + keypoint_controls._save_workflow, + "_warn_unsupported_direct_video_label_save", + lambda layer, metadata: warned.__setitem__("called", True), + ) + else: + # Fallback if you still use QMessageBox.warning directly in the workflow module + monkeypatch.setattr( + save_mod.QMessageBox, + "warning", + lambda *args, **kwargs: warned.__setitem__("called", True), + ) + + # Guard must abort BEFORE overwrite preflight + monkeypatch.setattr( + save_mod, + "compute_overwrite_report_for_points_save", + lambda *args, **kwargs: pytest.fail("Overwrite preflight must not run for unsupported direct-video save."), + ) + + # Optional extra safety: writer save must not be reached either + + def _unexpected_save(*args, **kwargs): + pytest.fail("viewer.layers.save must not be called for unsupported direct-video save.") + + monkeypatch.setattr(viewer.layers, "save", _unexpected_save) + + # Call the workflow directly so we can assert on the outcome + outcome = keypoint_controls._save_workflow.save_layers(selected=True) + qtbot.wait(100) + + assert outcome.saved is False + assert warned["called"] is True + + # No GT file should have been created in the dataset folder + assert not (labeled_folder / "CollectedData_John.h5").exists() diff --git a/src/napari_deeplabcut/_tests/e2e/test_save_e2e.py b/src/napari_deeplabcut/_tests/e2e/test_save_e2e.py new file mode 100644 index 00000000..3023192b --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/test_save_e2e.py @@ -0,0 +1,256 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +import yaml +from napari.layers import Points + +from napari_deeplabcut.config.models import DLCHeaderModel +from napari_deeplabcut.core.io import _read_hdf_any_key + +from .utils import ( + _make_project_config_and_frames_no_gt, + _set_or_add_bodypart_xy, +) + + +def _assert_single_animal_on_disk(path: Path, *, expected_bodyparts: tuple[str, ...] | None = None) -> pd.DataFrame: + """ + Assert canonical SA on-disk format: + columns = 3-level MultiIndex [scorer, bodyparts, coords] + """ + df = _read_hdf_any_key(path) + assert isinstance(df.columns, pd.MultiIndex), f"Expected MultiIndex columns in {path}" + assert df.columns.nlevels == 3, ( + f"Expected SA 3-level columns in {path}, got {df.columns.nlevels}: {df.columns.names}" + ) + assert list(df.columns.names) == ["scorer", "bodyparts", "coords"], ( + f"Expected SA names ['scorer','bodyparts','coords'], got {df.columns.names}" + ) + + if expected_bodyparts is not None: + observed = tuple(dict.fromkeys(df.columns.get_level_values("bodyparts"))) + assert observed == expected_bodyparts, f"Expected bodyparts {expected_bodyparts}, got {observed}" + + return df + + +def _seed_single_animal_gt( + labeled_folder: Path, + *, + scorer: str = "John", + bodyparts: tuple[str, ...] 
= ("bodypart1",), +) -> Path: + """ + Create a canonical SA GT file in labeled_folder: + - 3-level columns + - 1 labeled image row + - first bodypart gets finite coords, the rest NaN + """ + image_files = sorted(labeled_folder.glob("*.png")) + assert image_files, f"No extracted frames found in {labeled_folder}" + + img_name = image_files[0].name + dataset_name = labeled_folder.name + + cols = pd.MultiIndex.from_product( + [[scorer], list(bodyparts), ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + idx = pd.MultiIndex.from_tuples( + [("labeled-data", dataset_name, img_name)], + ) + + arr = np.full((1, len(cols)), np.nan, dtype=float) + + # Give the first bodypart one finite label so the file is non-empty + arr[0, 0] = 10.0 # bodypart1 x + arr[0, 1] = 20.0 # bodypart1 y + + df = pd.DataFrame(arr, index=idx, columns=cols) + + gt_path = labeled_folder / f"CollectedData_{scorer}.h5" + df.to_hdf(gt_path, key="df_with_missing", mode="w") + df.to_csv(gt_path.with_suffix(".csv")) + + return gt_path + + +def _append_bodypart_to_config(config_path: Path, bodypart: str) -> None: + cfg = yaml.safe_load(config_path.read_text(encoding="utf-8")) + bodyparts = list(cfg.get("bodyparts", [])) + if bodypart not in bodyparts: + bodyparts.append(bodypart) + cfg["bodyparts"] = bodyparts + config_path.write_text(yaml.safe_dump(cfg), encoding="utf-8") + + +def _header_model_for_layer(layer: Points) -> DLCHeaderModel: + hdr = (layer.metadata or {}).get("header") + if isinstance(hdr, DLCHeaderModel): + return hdr + return DLCHeaderModel.model_validate(hdr) + + +@pytest.mark.usefixtures("qtbot") +def test_single_animal_direct_h5_roundtrip_preserves_sa_format( + viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm +): + """ + Open a canonical SA GT .h5 directly, edit, save. + + This isolates the plain reader/writer path: + - NO config layer + - NO config merge + - NO placeholder workflow + + If this test fails, then config merge is NOT required to reproduce the bug. + """ + overwrite_confirm.capture() + + project, config_path, labeled_folder = _make_project_config_and_frames_no_gt(tmp_path) + gt_path = _seed_single_animal_gt(labeled_folder, bodyparts=("bodypart1",)) + + # Sanity: seed file really is canonical SA on disk + _assert_single_animal_on_disk(gt_path, expected_bodyparts=("bodypart1",)) + + viewer.open(str(gt_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len([ly for ly in viewer.layers if isinstance(ly, Points)]) == 1, timeout=10_000) + qtbot.wait(100) + + layer = next(ly for ly in viewer.layers if isinstance(ly, Points)) + store = keypoint_controls.get_layer_store(layer) + assert store is not None + + # Internal diagnostic: direct H5 read already normalizes SA -> canonical_4 with individuals=[""] + hdr = _header_model_for_layer(layer) + assert hdr.as_multiindex().nlevels == 4 + assert hdr.individuals == [""] + + # Edit existing bodypart and save + _set_or_add_bodypart_xy(layer, store, "bodypart1", x=101.0, y=202.0) + + viewer.layers.selection.active = layer + keypoint_controls.viewer.layers.selection.select_only(layer) + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(100) + + # This is the real regression assertion. 
+ _assert_single_animal_on_disk(gt_path, expected_bodyparts=("bodypart1",)) + + +@pytest.mark.usefixtures("qtbot") +def test_single_animal_gt_then_config_merge_preserves_sa_format( + viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm +): + """ + 1) existing SA GT on disk + 2) config.yaml edited to add bodypart2 + 3) open GT first + 4) open config.yaml + 5) save + + This isolates the 'config merge into existing GT layer' path. + """ + overwrite_confirm.capture() + + project, config_path, labeled_folder = _make_project_config_and_frames_no_gt(tmp_path) + gt_path = _seed_single_animal_gt(labeled_folder, bodyparts=("bodypart1",)) + _append_bodypart_to_config(config_path, "bodypart2") + + # Open GT first + viewer.open(str(gt_path), plugin="napari-deeplabcut") + qtbot.waitUntil( + lambda: len([ly for ly in viewer.layers if isinstance(ly, Points)]) == 1, + timeout=10_000, + ) + qtbot.wait(100) + + gt_layer = next(ly for ly in viewer.layers if isinstance(ly, Points)) + gt_store = keypoint_controls.get_layer_store(gt_layer) + assert gt_store is not None + + # Then open config -> should merge and settle back to one Points layer + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil( + lambda: len([ly for ly in viewer.layers if isinstance(ly, Points)]) == 1, + timeout=10_000, + ) + qtbot.wait(100) + + pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] + assert len(pts_layers) == 1, f"Expected merged single Points layer, got {[p.name for p in pts_layers]}" + + layer = pts_layers[0] + store = keypoint_controls.get_layer_store(layer) + assert store is not None + + hdr = _header_model_for_layer(layer) + assert "bodypart2" in hdr.bodyparts, f"Expected merged header to contain bodypart2, got {hdr.bodyparts}" + assert "bodypart2" in [kp.label for kp in store._keypoints], ( + f"Store keypoints are stale after config merge: {store._keypoints}" + ) + + _set_or_add_bodypart_xy(layer, store, "bodypart2", x=77.0, y=88.0) + + viewer.layers.selection.active = layer + keypoint_controls.viewer.layers.selection.select_only(layer) + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(100) + + _assert_single_animal_on_disk(gt_path, expected_bodyparts=("bodypart1", "bodypart2")) + + +@pytest.mark.usefixtures("qtbot") +def test_single_animal_config_first_then_folder_new_bodypart_preserves_sa_format( + viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm +): + """ + 1) existing SA GT on disk + 2) config.yaml edited to add bodypart2 + 3) open config.yaml first + 4) open labeled-data folder + 5) add labels + 6) save + + This exercises the config-first / placeholder path. 
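+    Unlike the merge test above, the GT file is never opened directly; it is
+    discovered during the folder open and has to merge into the layer that
+    config.yaml created.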
+ """ + overwrite_confirm.capture() + + _project, config_path, labeled_folder = _make_project_config_and_frames_no_gt(tmp_path) + gt_path = _seed_single_animal_gt(labeled_folder, bodyparts=("bodypart1",)) + + # Simulate user editing config.yaml outside the plugin + _append_bodypart_to_config(config_path, "bodypart2") + + # Open config first -> placeholder points layer + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) + qtbot.wait(100) + + # Then open the labeled-data folder + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + # qtbot.wait(500) + + pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] + assert pts_layers, "Expected at least one Points layer after config-first + folder open" + + # In this workflow the surviving layer is typically the placeholder / merged layer + layer = pts_layers[0] + store = keypoint_controls.get_layer_store(layer) + assert store is not None + + hdr = _header_model_for_layer(layer) + assert "bodypart2" in hdr.bodyparts, f"Expected config-first layer header to contain bodypart2, got {hdr.bodyparts}" + + _set_or_add_bodypart_xy(layer, store, "bodypart2", x=55.0, y=66.0) + + viewer.layers.selection.active = layer + keypoint_controls.viewer.layers.selection.select_only(layer) + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(100) + + # Expected good behavior: still canonical SA on disk + _assert_single_animal_on_disk(gt_path, expected_bodyparts=("bodypart1", "bodypart2")) diff --git a/src/napari_deeplabcut/_tests/e2e/utils.py b/src/napari_deeplabcut/_tests/e2e/utils.py index 467a397e..49fd319e 100644 --- a/src/napari_deeplabcut/_tests/e2e/utils.py +++ b/src/napari_deeplabcut/_tests/e2e/utils.py @@ -74,7 +74,7 @@ def _make_minimal_dlc_project(tmp_path: Path): df0 = pd.DataFrame([[10.0, 20.0, np.nan, np.nan]], index=idx, columns=cols) h5_path = labeled / "CollectedData_John.h5" - df0.to_hdf(h5_path, key="keypoints", mode="w") + df0.to_hdf(h5_path, key="df_with_missing", mode="w") df0.to_csv(str(h5_path).replace(".h5", ".csv")) return project, config_path, labeled, h5_path @@ -97,7 +97,7 @@ def _make_labeled_folder_with_machine_only(tmp_path: Path) -> Path: ) df0 = pd.DataFrame([[np.nan, np.nan, np.nan, np.nan]], index=["img000.png"], columns=cols) (folder / "machinelabels-iter0.h5").unlink(missing_ok=True) - df0.to_hdf(folder / "machinelabels-iter0.h5", key="keypoints", mode="w") + df0.to_hdf(folder / "machinelabels-iter0.h5", key="df_with_missing", mode="w") df0.to_csv(str(folder / "machinelabels-iter0.csv")) return folder @@ -126,7 +126,7 @@ def _write_keypoints_h5( df = pd.DataFrame([values], index=idx, columns=cols) path.parent.mkdir(parents=True, exist_ok=True) - df.to_hdf(path, key="keypoints", mode="w") + df.to_hdf(path, key="df_with_missing", mode="w") df.to_csv(str(path).replace(".h5", ".csv")) return path @@ -211,7 +211,7 @@ def _make_project_config_and_frames_no_gt(tmp_path: Path): def _read_h5_keypoints(path: Path) -> pd.DataFrame: - return pd.read_hdf(path, key="keypoints") + return pd.read_hdf(path, key="df_with_missing") def _index_mask_for_img(df: pd.DataFrame, basename: str) -> np.ndarray: diff --git a/src/napari_deeplabcut/_tests/test_reader.py b/src/napari_deeplabcut/_tests/test_reader.py index f43994da..5a3082c5 100644 --- a/src/napari_deeplabcut/_tests/test_reader.py +++ b/src/napari_deeplabcut/_tests/test_reader.py @@ -85,7 +85,7 
@@ def test_read_hdf_old_index(tmp_path_factory, fake_keypoints): path = str(tmp_path_factory.mktemp("folder") / "data.h5") old_index = [f"labeled-data/video/img{i}.png" for i in range(fake_keypoints.shape[0])] fake_keypoints.index = old_index - fake_keypoints.to_hdf(path, key="keypoints") + fake_keypoints.to_hdf(path, key="df_with_missing") layers = read_hdf(path) assert len(layers) == 1 image_paths = layers[0][1]["metadata"]["paths"] @@ -104,7 +104,7 @@ def test_read_hdf_new_index(tmp_path_factory, fake_keypoints): ] ) fake_keypoints.index = new_index - fake_keypoints.to_hdf(path, key="keypoints") + fake_keypoints.to_hdf(path, key="df_with_missing") layers = read_hdf(path) assert len(layers) == 1 image_paths = layers[0][1]["metadata"]["paths"] diff --git a/src/napari_deeplabcut/_tests/test_widgets.py b/src/napari_deeplabcut/_tests/test_widgets.py index d4b4cc1d..e59e143e 100644 --- a/src/napari_deeplabcut/_tests/test_widgets.py +++ b/src/napari_deeplabcut/_tests/test_widgets.py @@ -53,7 +53,7 @@ def test_save_layers(viewer, keypoint_controls, points): @pytest.mark.usefixtures("qtbot") def test_show_trails(viewer, keypoint_controls, store): - keypoint_controls._stores[store.layer] = store + keypoint_controls.layer_manager.register_managed_layer(store.layer, store) viewer.layers.selection.active = store.layer keypoint_controls._is_saved = True @@ -484,7 +484,7 @@ def test_widget_map_keypoints_writes_to_config(keypoint_controls, qtbot, points, # neighbors indices correspond to ordering of list(dummy_superkpts) # Here: ["nose", "upper_jaw"] -> indices [0, 1] - monkeypatch.setattr(keypoints, "_find_nearest_neighbors", lambda xy, xy_ref: np.array([0, 1])) + monkeypatch.setattr(keypoints, "find_nearest_neighbors", lambda xy, xy_ref: np.array([0, 1])) # If your io.load_config / io.write_config do more than YAML I/O, # you can keep them. Otherwise stubbing them makes the test isolated. @@ -552,7 +552,6 @@ } -@pytest.mark.usefixtures("qtbot") def test_points_layer_with_tables_shows_superkeypoints_button(keypoint_controls, qtbot, points): controls = keypoint_controls qtbot.add_widget(controls) @@ -561,26 +560,26 @@ points.metadata["tables"] = {"superanimal_quadruped": {"bp1": "nose", "bp2": "upper_jaw"}} - # Simulate the same setup path that real inserted/adopted layers use - controls._setup_points_layer(points, allow_merge=False) + controls.layer_manager._setup_points_layer(points, allow_merge=False) assert not controls._keypoint_mapping_button.isHidden() - assert controls._keypoint_mapping_button.text() == "Load superkeypoints diagram" @pytest.mark.usefixtures("qtbot") -def test_points_layer_with_tables_button_not_lost_on_merge_path(keypoint_controls, qtbot, points, monkeypatch): +def test_points_layer_with_tables_button_not_lost_on_merge_path(keypoint_controls, qtbot, points): controls = keypoint_controls qtbot.add_widget(controls) points.metadata["tables"] = {"superanimal_quadruped": {"bp1": "nose"}} - # Force the merge branch to happen - monkeypatch.setattr(controls, "_maybe_merge_config_points_layer", lambda layer: True) + # First do a normal setup so the button becomes visible.
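+    # (The manager-driven merge refresh below must not hide it again.)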
+ controls.layer_manager._setup_points_layer(points, allow_merge=False) + assert not controls._keypoint_mapping_button.isHidden() - controls._setup_points_layer(points, allow_merge=True) + # Simulate manager-driven merge refresh on the already-managed layer. + controls._on_points_layers_merged_requested((points,)) - assert controls._keypoint_mapping_button.isHidden() + assert not controls._keypoint_mapping_button.isHidden() @pytest.mark.usefixtures("qtbot") diff --git a/src/napari_deeplabcut/_tests/test_writer.py b/src/napari_deeplabcut/_tests/test_writer.py index 355fb018..b9228a66 100644 --- a/src/napari_deeplabcut/_tests/test_writer.py +++ b/src/napari_deeplabcut/_tests/test_writer.py @@ -89,7 +89,7 @@ def _add_source_io(metadata: dict, *, root: Path, kind: AnnotationKind, source_n "project_root": str(root), "source_relpath_posix": source_name.replace("\\", "/"), "kind": kind, # AnnotationKind.GT or AnnotationKind.MACHINE - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", } # legacy migration compatibility (optional but good) md["source_h5"] = str((root / source_name).resolve()) @@ -105,7 +105,7 @@ def _add_save_target(metadata: dict, *, root: Path, scorer: str) -> None: "project_root": str(root), "source_relpath_posix": f"CollectedData_{scorer}.h5", "kind": AnnotationKind.GT, - "dataset_key": "keypoints", + "dataset_key": "df_with_missing", "scorer": scorer, } @@ -231,9 +231,6 @@ def test_write_hdf_promotion_merges_into_existing_gt(tmp_path, fake_keypoints, m root = tmp_path / "proj" root.mkdir() - # Always allow overwrite confirmation in unit test - # monkeypatch.setattr(dialogs, "maybe_confirm_overwrite", lambda *args, **kwargs: True) - header = DLCHeaderModel(columns=fake_keypoints.columns) n_rows = len(fake_keypoints) @@ -274,13 +271,13 @@ def test_write_hdf_promotion_merges_into_existing_gt(tmp_path, fake_keypoints, m # Convert to MultiIndex of path components (matches refactored indexing model) guarantee_multiindex_rows(gt) - gt.to_hdf(gt_path, key="keypoints", mode="w") + gt.to_hdf(gt_path, key="df_with_missing", mode="w") # Create a machine file too; it must remain untouched machine_path = root / "machinelabels-iter0.h5" df_machine = pd.DataFrame(np.nan, index=[0], columns=fake_keypoints.columns) - df_machine.to_hdf(machine_path, key="keypoints", mode="w") - machine_before = pd.read_hdf(machine_path, key="keypoints") + df_machine.to_hdf(machine_path, key="df_with_missing", mode="w") + machine_before = pd.read_hdf(machine_path, key="df_with_missing") points = np.column_stack([np.arange(n_rows), rng.random(n_rows), rng.random(n_rows)]) @@ -288,14 +285,14 @@ def test_write_hdf_promotion_merges_into_existing_gt(tmp_path, fake_keypoints, m assert Path(fnames[0]).name == "CollectedData_me.h5" # GT should exist and be readable - df = pd.read_hdf(fnames[0], key="keypoints") + df = pd.read_hdf(fnames[0], key="df_with_missing") assert isinstance(df, pd.DataFrame) # Must still be scored as "me" after promotion assert df.columns.get_level_values("scorer")[0] == "me" # Machine file must be unchanged - machine_after = pd.read_hdf(machine_path, key="keypoints") + machine_after = pd.read_hdf(machine_path, key="df_with_missing") pd.testing.assert_frame_equal(machine_before, machine_after) @@ -350,8 +347,6 @@ def test_write_hdf_promotion_creates_gt_when_missing(tmp_path, fake_keypoints, m root = tmp_path / "proj" root.mkdir() - # monkeypatch.setattr(dialogs, "maybe_confirm_overwrite", lambda *args, **kwargs: True) - header = DLCHeaderModel(columns=fake_keypoints.columns) n_rows = 
len(fake_keypoints) @@ -387,7 +382,7 @@ def test_write_hdf_promotion_creates_gt_when_missing(tmp_path, fake_keypoints, m out_h5 = Path(fnames[0]) assert out_h5.exists() - df = pd.read_hdf(out_h5, key="keypoints") + df = pd.read_hdf(out_h5, key="df_with_missing") assert df.columns.get_level_values("scorer")[0] == "alice" # Ensure we still did NOT write back to a machine source file diff --git a/src/napari_deeplabcut/_tests/tracking/conftest.py b/src/napari_deeplabcut/_tests/tracking/conftest.py new file mode 100644 index 00000000..bf55398a --- /dev/null +++ b/src/napari_deeplabcut/_tests/tracking/conftest.py @@ -0,0 +1,218 @@ +import numpy as np +import pandas as pd +import pytest + +from napari_deeplabcut.tracking.core.data import TrackingModelInputs, TrackingWorkerData, TrackingWorkerOutput +from napari_deeplabcut.tracking.core.models import AVAILABLE_TRACKERS, RawModelOutputs, TrackingModel + +# --- Tracking fixtures --- +DUMMY_TRACKER_NAME = "TestTracker" + + +class DummyTracker(TrackingModel): + """ + Minimal tracker that: + - echoes inputs to outputs with a tiny deterministic transform, + - emits progress via the callback, + - honors stop_callback. + """ + + name = DUMMY_TRACKER_NAME + info_text = "Dummy tracker for unit testing." + + def load_model(self, device: str): + # No-op model; keep a simple config to emulate 'step' like CoTracker. + class _NoOpModel: + step = 3 + + return _NoOpModel() + + def prepare_inputs(self, cfg: "TrackingWorkerData", **kwargs) -> TrackingModelInputs: + # Ensure video is (T, H, W, C) and keypoints is (K, 3) where columns: [frame_idx, x, y] or [id, x, y] + video = np.asarray(cfg.video) + queries = np.asarray(cfg.keypoints).copy() + metadata = { + "keypoint_range": cfg.keypoint_range, + "backward_tracking": getattr(cfg, "backward_tracking", False), + } + return TrackingModelInputs(video=video, keypoints=queries, metadata=metadata) + + def run(self, inputs: TrackingModelInputs, progress_callback, stop_callback, **kwargs) -> RawModelOutputs: + # Fake progression per frame; stop if requested. 
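+        # Sketch of the assumed wiring (see the worker tests): TrackingWorker
+        # forwards progress_callback(current, total) to its Qt `progress`
+        # signal, and stop_callback() reads a flag that stop_tracking() flips.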
+ T = inputs.video.shape[0] + K = inputs.keypoints.shape[0] + + # Produce tracks of shape (T, K, 2) with a deterministic offset (e.g., +1 pixel) + tracks = np.zeros((T, K, 2), dtype=float) + for t in range(T): + progress_callback(t, T) + if stop_callback(): + # Return partial result up to t + tracks = tracks[: t + 1] + vis = np.ones_like(tracks[..., 0], dtype=bool) # visibility dummy + return RawModelOutputs(keypoints=tracks, keypoint_features={"visibility": vis}) + # Use the input (x, y) for all K points and add a tiny drift proportional to t + tracks[t, :, 0] = inputs.keypoints[:, 1] + 0.1 * t # x + tracks[t, :, 1] = inputs.keypoints[:, 2] + 0.1 * t # y + + vis = np.ones_like(tracks[..., 0], dtype=bool) + return RawModelOutputs(keypoints=tracks, keypoint_features={"visibility": vis}) + + def prepare_outputs( + self, model_outputs: RawModelOutputs, worker_inputs: "TrackingWorkerData" = None, **kwargs + ) -> "TrackingWorkerOutput": + # Flatten (T, K, 2) -> (N, 3) with [frame_idx, x, y] + tracks = model_outputs.keypoints + T = tracks.shape[0] + K = tracks.shape[1] + + T1, T2 = worker_inputs.keypoint_range + frame_ids = np.repeat(np.arange(T1, T1 + T), K) + flat = tracks.reshape(-1, 2) + keypoints = np.column_stack((frame_ids, flat)) # (N, 3) + + # Minimal features: concat original per-keypoint features replicated per frame + keypoints_features = pd.concat( + [worker_inputs.keypoint_features] * T, + ignore_index=True, + ) + + return TrackingWorkerOutput( + keypoints=keypoints, + keypoint_features=keypoints_features, + ) + + def validate_outputs(self, inputs: TrackingModelInputs, outputs: "TrackingWorkerOutput") -> tuple[bool, str]: + """ + Validate DummyTracker outputs. + + Expectations for DummyTracker: + - outputs.keypoints is an (N, 3) float array of [frame_idx, x, y] + - N == (T2 - T1) * K where: + T1, T2 = inputs.metadata["keypoint_range"] + T = T2 - T1 (number of frames produced) + K = inputs.keypoints.shape[0] (number of query points) + - frame_idx are integers in [T1, T2-1] + - x, y are finite. 
If video shape known, also check bounds: x∈[0,W), y∈[0,H) + - outputs.keypoint_features is a DataFrame with length N + and contains at least the columns present in worker_inputs.keypoint_features + (as repeated by the DummyTracker) + """ + + # -------- Basic structure checks + kp = outputs.keypoints + if not isinstance(kp, np.ndarray): + return False, "outputs.keypoints must be a numpy array" + + if kp.ndim != 2 or kp.shape[1] != 3: + return False, f"outputs.keypoints must have shape (N, 3); got {kp.shape}" + + # -------- Expected length: N = (T2 - T1) * K + meta = inputs.metadata or {} + if ( + "keypoint_range" not in meta + or not isinstance(meta["keypoint_range"], (tuple, list)) + or len(meta["keypoint_range"]) != 2 + ): + return False, "inputs.metadata.keypoint_range must be a (T1, T2) tuple" + + T1, T2 = meta["keypoint_range"] + if not (isinstance(T1, (int, np.integer)) and isinstance(T2, (int, np.integer)) and T2 >= T1): + return False, "Invalid keypoint_range; expected integers with T2 >= T1" + + K = inputs.keypoints.shape[0] + expected_len = (T2 - T1) * K + if kp.shape[0] != expected_len: + return False, f"Expected (T*K)={expected_len} rows; got {kp.shape[0]}" + + # -------- Frame index checks + frames = kp[:, 0] + # Allow float dtype but must be whole numbers + if not np.all(np.isfinite(frames)): + return False, "Frame indices contain non-finite values" + + if not np.allclose(frames, np.round(frames)): + return False, "Frame indices must be integers" + + frames_int = frames.astype(int) + if frames_int.min() < T1 or frames_int.max() > (T2 - 1): + return False, f"Frame indices out of range [{T1}, {T2 - 1}]" + + # -------- Coordinate checks + xy = kp[:, 1:3] + if not np.all(np.isfinite(xy)): + return False, "Coordinates contain NaN/Inf" + + # -------- Features checks + feats = outputs.keypoint_features + if not isinstance(feats, pd.DataFrame): + return False, "outputs.keypoint_features must be a pandas DataFrame" + + if len(feats) != expected_len: + return False, f"keypoint_features length mismatch: expected {expected_len}, got {len(feats)}" + + # When produced by DummyTracker, features are a concat of the input per frame + # Ensure at least the same columns are present and non-null + required_cols = [] + try: + # worker_inputs.keypoint_features is replicated in DummyTracker.prepare_outputs + required_cols = list(self.cfg.keypoint_features.columns) # may exist on the tracker + except Exception: + # fallback to inputs.shape if not accessible; skip strict column match + pass + + missing = [c for c in required_cols if c not in feats.columns] + if missing: + return False, f"Missing required feature columns: {missing}" + + if required_cols: + if feats[required_cols].isna().any().any(): + return False, "keypoint_features contain NaN in required columns" + + return True, "" + + +@pytest.fixture(autouse=True) +def register_dummy_tracker(): + """ + Auto-register DummyTracker for all tests and restore registry afterwards. + """ + prev = dict(AVAILABLE_TRACKERS) + AVAILABLE_TRACKERS[DUMMY_TRACKER_NAME] = {"class": DummyTracker} + try: + yield + finally: + AVAILABLE_TRACKERS.clear() + AVAILABLE_TRACKERS.update(prev) + + +@pytest.fixture +def track_worker_inputs(): + """ + Provide minimal valid TrackingWorkerData with: + - 5-frame RGB video of 4x4 pixels, + - 2 keypoints, + - keypoint_range covering all frames, + - simple features DataFrame. 
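+    With these inputs DummyTracker yields keypoints of shape (T * K, 3)
+    = (5 * 2, 3) = (10, 3) and a 10-row features frame (see
+    test_tracking_worker).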
+ """ + video = np.zeros((5, 4, 4, 3), dtype=np.uint8) + + keypoints = np.array( + [ + [0, 10.0, 20.0], + [0, 30.0, 40.0], + ], + dtype=float, + ) + + keypoint_features = pd.DataFrame({"id": [0, 1], "name": ["kp0", "kp1"]}) + + # Build TrackingWorkerData + return TrackingWorkerData( + tracker_name=DUMMY_TRACKER_NAME, + video=video, + keypoints=keypoints, + keypoint_range=(0, 5), # frames 0..4 + keypoint_features=keypoint_features, + backward_tracking=False, + ) diff --git a/src/napari_deeplabcut/_tests/tracking/test_widgets.py b/src/napari_deeplabcut/_tests/tracking/test_widgets.py new file mode 100644 index 00000000..a0f3dc9b --- /dev/null +++ b/src/napari_deeplabcut/_tests/tracking/test_widgets.py @@ -0,0 +1,260 @@ +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import pytest +from qtpy.QtCore import Qt + +from napari_deeplabcut.tracking._widgets import TrackingControls +from napari_deeplabcut.tracking.core.data import TrackingWorkerData +from napari_deeplabcut.tracking.core.models import AVAILABLE_TRACKERS + +if TYPE_CHECKING: + import napari + +_DUMMY_VIDEO_N_FRAMES = 10 + + +def _get_tracking_controls(viewer: "napari.Viewer") -> TrackingControls: + viewer.window.add_dock_widget( + TrackingControls(viewer), + name="Tracking controls", + area="right", + ) + for _title, wdg in viewer.window.dock_widgets.items(): + if isinstance(wdg, TrackingControls) and wdg.property("ndlc_tracking_controls"): + return wdg + raise RuntimeError("Tracking controls dock widget not found") + + +@pytest.fixture +def setup_tracking_widget(qtbot, viewer, monkeypatch): + """ + Factory fixture that returns a function to set up the TrackingControls test environment. + + Usage in tests: + tc, video_layer, points_layer = setup_tracking_env(add_data=True) + # or for just the widget: + tc = setup_tracking_env() + """ + + def _setup_tracking_data(*, add_data: bool = False, disable_insert_hooks: bool = True): + tc = _get_tracking_controls(viewer) + qtbot.addWidget(tc) + + # Optionally disable plugin insert hooks that assume DLC metadata/header. + if add_data and disable_insert_hooks: + monkeypatch.setattr( + "napari_deeplabcut._widgets.KeypointControls.on_insert", + lambda *args, **kwargs: None, + raising=False, + ) + monkeypatch.setattr( + "napari_deeplabcut.ui.plots.trajectory.TrajectoryMatplotlibCanvas._load_dataframe", + lambda *args, **kwargs: None, + raising=False, + ) + + if add_data: + # Minimal, deterministic video and points layers + video_layer = viewer.add_image(np.zeros((_DUMMY_VIDEO_N_FRAMES, 4, 4), dtype=np.uint8), name="video_stack") + points_layer = viewer.add_points( + np.array([[0, 1.0, 2.0], [0, 3.0, 4.0]]), + features=pd.DataFrame({"id": [0, 1], "name": ["kp0", "kp1"]}), + name="keypoints", + ) + tc._video_layer_combo.value = video_layer + tc._keypoint_layer_combo.value = points_layer + return tc, video_layer, points_layer + + return tc + + return _setup_tracking_data + + +def _put_all_points_on_frame(points_layer, frame: int) -> None: + """ + Move all test points onto a specific frame so tracking has a valid seed frame. + Keeps the same number/order of points and therefore preserves features alignment. + """ + data = np.asarray(points_layer.data, dtype=float).copy() + assert data.size > 0, "Test fixture produced no points." + data[:, 0] = float(frame) + points_layer.data = data + + +def _set_current_tracking_frame(tc, viewer, frame: int) -> None: + """ + Drive the tracking widget the same way the real UI does: + the reference frame follows the viewer's current step. 
+ """ + viewer.dims.current_step = (frame,) + (0,) * (viewer.dims.ndim - 1) + tc._video_layer_changed() + assert tc._reference_spinbox.value() == frame + + +@pytest.mark.usefixtures("qtbot") +def test_tracking_controls_initial_state(setup_tracking_widget): + tc = setup_tracking_widget(add_data=False) + + items = [tc._tracking_method_combo.itemText(i) for i in range(tc._tracking_method_combo.count())] + assert set(items) >= set(AVAILABLE_TRACKERS.keys()) + current = tc._tracking_method_combo.currentText() + info = AVAILABLE_TRACKERS[current]["class"].info_text + assert tc._model_info_button.toolTip() == info + + +def test_tracking_frame_controls_layer_selection_and_ranges(setup_tracking_widget, viewer): + tc, video_layer, points_layer = setup_tracking_widget(add_data=True) + viewer.dims.current_step = (2,) + (0,) * (viewer.dims.ndim - 1) + tc._video_layer_changed() + # Forward range + assert tc._forward_slider.minimum() == 0 + assert tc._forward_slider.maximum() == _DUMMY_VIDEO_N_FRAMES - 1 - 2 # 7 steps forward possible from frame 2 + assert tc._forward_spinbox_absolute.minimum() == 2 + assert tc._forward_spinbox_absolute.maximum() == _DUMMY_VIDEO_N_FRAMES - 1 + # Backward range + assert tc._backward_slider.minimum() == -2 + assert tc._backward_slider.maximum() == 0 + assert tc._backward_spinbox_absolute.minimum() == 0 + assert tc._backward_spinbox_absolute.maximum() == 2 + # Reference spinbox + assert tc._reference_spinbox.value() == 2 + assert tc._reference_spinbox.minimum() == 0 + assert tc._reference_spinbox.maximum() == _DUMMY_VIDEO_N_FRAMES - 1 + + big_video_n_frames = 200 + new_video = np.zeros((big_video_n_frames, 4, 4), dtype=np.uint8) + tc._video_layer_combo.value = viewer.add_image(new_video, name="big_video") + tc._video_layer_changed() + frame = 150 + viewer.dims.current_step = (frame,) + (0,) * (viewer.dims.ndim - 1) + # Forward range + assert tc._forward_slider.minimum() == 0 + assert tc._forward_slider.maximum() == big_video_n_frames - 1 - frame + assert tc._forward_spinbox_absolute.minimum() == frame + assert tc._forward_spinbox_absolute.maximum() == big_video_n_frames - 1 + # Backward range + assert tc._backward_slider.minimum() == -frame + assert tc._backward_slider.maximum() == 0 + assert tc._backward_spinbox_absolute.minimum() == 0 + assert tc._backward_spinbox_absolute.maximum() == frame + # Reference spinbox + assert tc._reference_spinbox.value() == frame + assert tc._reference_spinbox.minimum() == 0 + assert tc._reference_spinbox.maximum() == big_video_n_frames - 1 + + +@pytest.mark.usefixtures("qtbot") +def test_forward_track(setup_tracking_widget, qtbot, viewer): + tc, video_layer, points_layer = setup_tracking_widget(add_data=True) + + # Only change internal state + def fake_start_worker(self): + self.worker_started = True + + from types import MethodType + + tc._start_worker = MethodType(fake_start_worker, tc) + + # Set current frame to 0, set forward absolute to 3 + viewer.dims.current_step = (0,) + (0,) * (viewer.dims.ndim - 1) + tc._video_layer_changed() + tc._reference_spinbox.setValue(0) + tc._forward_spinbox_absolute.setValue(3) + + with qtbot.waitSignal(tc.trackingRequested, timeout=1500) as req: + qtbot.mouseClick(tc._tracking_forward_button, Qt.LeftButton) + + twd: TrackingWorkerData = req.args[0] + assert isinstance(twd, TrackingWorkerData) + assert twd.tracker_name == tc._tracking_method_combo.currentText() + assert twd.backward_tracking is False + # video slice should have length 4 (frames 0..3 inclusive when +1 applied) + assert twd.video.shape[0] == 4 # because
track_forward uses forward_frame_idx + 1 + # keypoints should be those from ref frame with frame index reset to 0 + assert (twd.keypoints[:, 0] == 0).all() + # features are taken from the ref-frame rows only, so the length matches the layer's features + assert len(twd.keypoint_features) == len(points_layer.features) + + +@pytest.mark.usefixtures("qtbot") +def test_backward_track(setup_tracking_widget, qtbot, viewer): + tc, video_layer, points_layer = setup_tracking_widget(add_data=True) + from types import MethodType + + tc._start_worker = MethodType(lambda self: setattr(self, "worker_started", True), tc) + + _put_all_points_on_frame(points_layer, 2) + _set_current_tracking_frame(tc, viewer, 2) + + # Set ref frame to 2; backward absolute to 0 so it’s < ref + tc._backward_spinbox_absolute.setValue(0) + + with qtbot.waitSignal(tc.trackingRequested, timeout=1500) as req: + qtbot.mouseClick(tc._tracking_backward_button, Qt.LeftButton) + twd = req.args[0] + assert twd.backward_tracking is True + assert twd.reference_frame_index == 2 + # For backward, track() reverses the video slice + assert twd.video.shape[0] == (2 - 0 + 1) # inclusive range when +1 is applied in TrackingControls + + # Seed keypoints are re-based to local frame 0 inside the sliced tracking video. + assert np.all(twd.keypoints[:, 0] == 0) + # Original point properties should still be present on the worker input. + assert "id" in twd.keypoint_features.columns + assert "name" in twd.keypoint_features.columns + + # New tracking identity columns should also be present. + assert "tracking_query_index" in twd.keypoint_features.columns + assert "tracking_query_frame" in twd.keypoint_features.columns + assert set(twd.keypoint_features["tracking_query_frame"]) == {2} + + +@pytest.mark.usefixtures("qtbot") +def test_bothway_track(setup_tracking_widget, qtbot, viewer): + tc, video_layer, points_layer = setup_tracking_widget(add_data=True) + from types import MethodType + + tc._start_worker = MethodType(lambda self: setattr(self, "worker_started", True), tc) + + # New invariant: current/reference frame must actually contain seed keypoints. + _put_all_points_on_frame(points_layer, 3) + _set_current_tracking_frame(tc, viewer, 3) + + # Forward target > ref, backward target < ref + tc._forward_spinbox_absolute.setValue(6) + tc._backward_spinbox_absolute.setValue(0) + + captured = [] + tc.trackingRequested.connect(lambda d: captured.append(d)) + + # Ensure backward path doesn't fail due to missing keypoint_widget + tc.keypoint_widget = object() + + with qtbot.waitSignals([tc.trackingRequested, tc.trackingRequested], timeout=2000): + qtbot.mouseClick(tc._tracking_bothway_button, Qt.LeftButton) + tc.trackedKeypointsAdded.emit() + + assert len(captured) == 2 + assert captured[0].backward_tracking is False + assert captured[1].backward_tracking is True + + # Both requests should use the same seed frame, because the ref frame is still 3. + assert captured[0].reference_frame_index == 3 + assert captured[1].reference_frame_index == 3 + + # Do the same when forward == reference -> only backward tracking should run. + captured.clear() + + # Since the widget keeps reference == current frame, make that explicit.
+ _set_current_tracking_frame(tc, viewer, 3) + tc._forward_spinbox_absolute.setValue(3) # == ref, so forward is invalid + tc._backward_spinbox_absolute.setValue(0) + + with qtbot.waitSignal(tc.trackingRequested, timeout=1500): + qtbot.mouseClick(tc._tracking_bothway_button, Qt.LeftButton) + + assert len(captured) == 1 + assert captured[0].backward_tracking is True + assert captured[0].reference_frame_index == 3 diff --git a/src/napari_deeplabcut/_tests/tracking/test_worker.py b/src/napari_deeplabcut/_tests/tracking/test_worker.py new file mode 100644 index 00000000..067de853 --- /dev/null +++ b/src/napari_deeplabcut/_tests/tracking/test_worker.py @@ -0,0 +1,53 @@ +from napari_deeplabcut.tracking.core.data import TrackingWorkerOutput +from napari_deeplabcut.tracking.ui.worker import TrackingWorker + + +def test_tracking_worker(qtbot, track_worker_inputs): + worker = TrackingWorker() + + with qtbot.waitSignal(worker.trackingFinished, timeout=1000) as await_finished: + worker.track(track_worker_inputs) + + output: TrackingWorkerOutput = await_finished.args[0] + assert isinstance(output, TrackingWorkerOutput) + # Expect (T * K, 3) with columns [frame_idx, x, y] + T = track_worker_inputs.video.shape[0] + K = track_worker_inputs.keypoints.shape[0] + assert output.keypoints.shape == (T * K, 3) + # Check frame indices + frames = output.keypoints[:, 0] + assert frames.min() == 0 + assert frames.max() == T - 1 + + +def test_progress_emitted(track_worker_inputs): + worker = TrackingWorker() + + progress_events = [] + worker.progress.connect(lambda current, total: progress_events.append((current, total))) + worker.track(track_worker_inputs) + + T = track_worker_inputs.video.shape[0] + assert len(progress_events) >= T # at least one per frame + assert progress_events[-1] == (T - 1, T) # final progress emit is done elsewhere + + +def test_stop_tracking_emits_stopped(qtbot, track_worker_inputs): + worker = TrackingWorker() + await_stopped = qtbot.waitSignal(worker.trackingStopped, timeout=1000) + + def stop_on_first_progress(current, total): + worker.stop_tracking() + + worker.progress.connect(stop_on_first_progress) + worker.track(track_worker_inputs) + assert await_stopped.signal_triggered + + +def test_unknown_tracker_emits_stopped(qtbot, track_worker_inputs): + worker = TrackingWorker() + inval_cfg = track_worker_inputs + inval_cfg.tracker_name = "DoesNotExist" + await_stopped = qtbot.waitSignal(worker.trackingStopped, timeout=1000) + worker.track(inval_cfg) + assert await_stopped.signal_triggered diff --git a/src/napari_deeplabcut/_tests/ui/test_save.py b/src/napari_deeplabcut/_tests/ui/test_save.py new file mode 100644 index 00000000..ed8696e2 --- /dev/null +++ b/src/napari_deeplabcut/_tests/ui/test_save.py @@ -0,0 +1,438 @@ +from __future__ import annotations + +import logging +from types import SimpleNamespace + +import numpy as np +import pytest +from napari.layers import Image, Points + +import napari_deeplabcut.ui.ui_dialogs.save as save_mod + +# --------------------------------------------------------------------- +# Lightweight test doubles +# --------------------------------------------------------------------- + + +class DummySelection(list): + def __init__(self, items=(), *, active=None): + super().__init__(items) + self.active = active + + def select_only(self, layer): + self[:] = [layer] + self.active = layer + + +class DummyLayers(list): + def __init__(self, items=()): + super().__init__(items) + self.selection = DummySelection() + self.save_calls = [] + + def save(self, *args, 
**kwargs): + self.save_calls.append((args, kwargs)) + + +class DummyViewer: + def __init__(self, layers=()): + self.layers = DummyLayers(layers) + + +class DummyLayerManager: + def __init__(self): + self.image_root = None + self.image_paths = None + self._active_image = None + self._managed_points = () + + def active_dlc_image_layer(self): + return self._active_image + + def managed_points_layers(self): + return tuple(self._managed_points) + + +class DummyTrailsController: + def __init__(self): + self.persist_calls = [] + + def persist_folder_ui_state_for_points_layer(self, layer, *, checkbox_checked: bool): + self.persist_calls.append((layer, checkbox_checked)) + + +# --------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------- + + +@pytest.fixture +def points_layer(): + return Points( + np.empty((0, 3), dtype=float), + name="points", + metadata={}, + properties={}, + ) + + +@pytest.fixture +def points_layer_2(): + return Points( + np.empty((0, 3), dtype=float), + name="points-2", + metadata={}, + properties={}, + ) + + +@pytest.fixture +def image_layer(): + return Image( + np.zeros((5, 8, 8), dtype=np.uint8), + name="img-stack", + metadata={}, + ) + + +@pytest.fixture +def workflow_factory(): + def _make(*, viewer=None, layer_manager=None, trails_controller=None, trail_checked=False, resolve_config=None): + viewer = viewer or DummyViewer() + layer_manager = layer_manager or DummyLayerManager() + trails_controller = trails_controller or DummyTrailsController() + + wf = save_mod.PointsLayerSaveWorkflow( + parent=None, + viewer=viewer, + layer_manager=layer_manager, + trails_controller=trails_controller, + trail_checkbox_getter=lambda: trail_checked, + resolve_config_path_for_layer=resolve_config or (lambda _layer: None), + current_project_path_getter=lambda: None, + current_image_meta_getter=lambda: None, + logger=logging.getLogger("test.save_workflow"), + ) + return wf, viewer, layer_manager, trails_controller + + return _make + + +# --------------------------------------------------------------------- +# _save_multiple_layers +# --------------------------------------------------------------------- + + +def test_save_multiple_layers_returns_false_when_dialog_cancelled( + monkeypatch, + workflow_factory, + points_layer, +): + class FakeDialog: + def __init__(self): + self.history = None + + def setHistory(self, hist): + self.history = hist + + def getSaveFileName(self, **kwargs): + return "", "" + + monkeypatch.setattr(save_mod, "QFileDialog", FakeDialog) + monkeypatch.setattr(save_mod, "get_save_history", lambda: [r"C:\tmp"]) + + viewer = DummyViewer([points_layer]) + wf, viewer, _lm, trails = workflow_factory(viewer=viewer) + + outcome = wf._save_multiple_layers(selected=True, selected_layers=[points_layer]) + + assert outcome.saved is False + assert viewer.layers.save_calls == [] + assert trails.persist_calls == [] + + +def test_save_multiple_layers_selected_persists_only_selected_points_layers( + monkeypatch, + workflow_factory, + points_layer, + image_layer, +): + captured = {} + + class FakeDialog: + def __init__(self): + self.history = None + + def setHistory(self, hist): + self.history = hist + captured["history"] = hist + + def getSaveFileName(self, **kwargs): + captured["dialog_kwargs"] = kwargs + return r"C:\tmp\out.tif", "ignored" + + monkeypatch.setattr(save_mod, "QFileDialog", FakeDialog) + monkeypatch.setattr(save_mod, "get_save_history", lambda: [r"C:\start"]) + + viewer 
= DummyViewer([points_layer, image_layer]) + wf, viewer, _lm, trails = workflow_factory(viewer=viewer, trail_checked=True) + + selected_layers = [points_layer, image_layer] + outcome = wf._save_multiple_layers(selected=True, selected_layers=selected_layers) + + assert outcome.saved is True + assert outcome.status_message == "Data successfully saved" + + assert viewer.layers.save_calls == [((r"C:\tmp\out.tif",), {"selected": True})] + + # Only Points layers from the selected set should be persisted + assert trails.persist_calls == [(points_layer, True)] + + assert captured["history"] == [r"C:\start"] + assert captured["dialog_kwargs"]["caption"] == "Save selected layers" + assert captured["dialog_kwargs"]["dir"] == r"C:\start" + + +def test_save_multiple_layers_all_uses_managed_points_layers( + monkeypatch, + workflow_factory, + points_layer, + points_layer_2, +): + class FakeDialog: + def setHistory(self, hist): + pass + + def getSaveFileName(self, **kwargs): + return r"C:\tmp\all_layers.npy", "ignored" + + monkeypatch.setattr(save_mod, "QFileDialog", FakeDialog) + monkeypatch.setattr(save_mod, "get_save_history", lambda: [r"C:\start"]) + + viewer = DummyViewer([points_layer, points_layer_2]) + lm = DummyLayerManager() + lm._managed_points = (points_layer, points_layer_2) + + wf, viewer, _lm, trails = workflow_factory( + viewer=viewer, + layer_manager=lm, + trail_checked=False, + ) + + outcome = wf._save_multiple_layers(selected=False, selected_layers=[]) + + assert outcome.saved is True + assert viewer.layers.save_calls == [((r"C:\tmp\all_layers.npy",), {"selected": False})] + assert trails.persist_calls == [ + (points_layer, False), + (points_layer_2, False), + ] + + +def test_save_layers_dispatches_to_save_multiple_for_non_single_points_selection( + monkeypatch, + workflow_factory, + points_layer, + image_layer, +): + viewer = DummyViewer([points_layer, image_layer]) + viewer.layers.selection[:] = [points_layer, image_layer] + viewer.layers.selection.active = points_layer + + wf, _viewer, _lm, _trails = workflow_factory(viewer=viewer) + + called = {} + + def _fake_save_multiple(*, selected, selected_layers): + called["selected"] = selected + called["selected_layers"] = list(selected_layers) + return save_mod.SaveOutcome(saved=True, status_message="ok") + + monkeypatch.setattr(wf, "_save_multiple_layers", _fake_save_multiple) + + outcome = wf.save_layers(selected=True) + + assert outcome.saved is True + assert called["selected"] is True + assert called["selected_layers"] == [points_layer, image_layer] + + +# --------------------------------------------------------------------- +# _best_image_context_layer +# --------------------------------------------------------------------- + + +def test_best_image_context_layer_prefers_lifecycle_owned_active_image( + workflow_factory, + image_layer, +): + selected_image = Image(np.zeros((2, 2), dtype=np.uint8), name="selected") + first_image = Image(np.zeros((2, 2), dtype=np.uint8), name="first") + + viewer = DummyViewer([first_image]) + viewer.layers.selection.active = selected_image + + lm = DummyLayerManager() + lm._active_image = image_layer + + wf, *_ = workflow_factory(viewer=viewer, layer_manager=lm) + + assert wf._best_image_context_layer() is image_layer + + +def test_best_image_context_layer_falls_back_to_selected_image(workflow_factory): + selected_image = Image(np.zeros((2, 2), dtype=np.uint8), name="selected") + viewer = DummyViewer([selected_image]) + viewer.layers.selection.active = selected_image + + wf, *_ = 
workflow_factory(viewer=viewer) + + assert wf._best_image_context_layer() is selected_image + + +def test_best_image_context_layer_falls_back_to_first_image_in_viewer(workflow_factory): + img1 = Image(np.zeros((2, 2), dtype=np.uint8), name="img1") + img2 = Image(np.zeros((2, 2), dtype=np.uint8), name="img2") + + viewer = DummyViewer([img1, img2]) + viewer.layers.selection.active = None + + wf, *_ = workflow_factory(viewer=viewer) + + assert wf._best_image_context_layer() is img1 + + +# --------------------------------------------------------------------- +# _enrich_points_metadata_for_save +# --------------------------------------------------------------------- + + +def test_enrich_metadata_returns_unchanged_if_root_already_present( + workflow_factory, + points_layer, +): + wf, _viewer, lm, _trails = workflow_factory() + lm.image_root = r"C:\project\labeled-data\ctx" + lm.image_paths = ["img001.png", "img002.png"] + + md = {"root": r"C:\already\set", "paths": []} + + out = wf._enrich_points_metadata_for_save(points_layer, md) + + # Current behavior: early return if root exists + assert out == md + + +def test_enrich_metadata_fills_from_layer_manager_image_context( + workflow_factory, + points_layer, +): + wf, _viewer, lm, _trails = workflow_factory() + lm.image_root = r"C:\project\labeled-data\session1" + lm.image_paths = ["img001.png", "img002.png"] + + out = wf._enrich_points_metadata_for_save(points_layer, {}) + + assert out["root"] == r"C:\project\labeled-data\session1" + assert out["paths"] == ["img001.png", "img002.png"] + + +def test_enrich_metadata_returns_unchanged_when_no_context_and_no_config( + workflow_factory, + points_layer, +): + wf, *_ = workflow_factory(resolve_config=lambda _layer: None) + + md = {} + out = wf._enrich_points_metadata_for_save(points_layer, md) + + assert out == {} + + +def test_enrich_metadata_adds_project_when_config_resolves_but_no_image_layer( + monkeypatch, + workflow_factory, + points_layer, + tmp_path, +): + project_root = tmp_path / "project" + config_path = project_root / "config.yaml" + project_root.mkdir() + config_path.write_text("dummy", encoding="utf-8") + + monkeypatch.setattr(save_mod, "resolve_project_root_from_config", lambda p: project_root) + + wf, *_ = workflow_factory(resolve_config=lambda _layer: config_path) + + out = wf._enrich_points_metadata_for_save(points_layer, {}) + + assert out["project"] == str(project_root) + assert "root" not in out + + +def test_enrich_metadata_uses_source_anchor_when_it_looks_like_labeled_folder( + monkeypatch, + workflow_factory, + points_layer, + tmp_path, +): + project_root = tmp_path / "project" + project_root.mkdir() + config_path = project_root / "config.yaml" + config_path.write_text("dummy", encoding="utf-8") + + inferred_root = project_root / "labeled-data" / "sessionA" + + monkeypatch.setattr(save_mod, "resolve_project_root_from_config", lambda p: project_root) + monkeypatch.setattr(save_mod, "normalize_anchor_candidate", lambda src: inferred_root) + monkeypatch.setattr(save_mod, "looks_like_dlc_labeled_folder", lambda p: True) + + lm = DummyLayerManager() + lm._active_image = SimpleNamespace( + source=SimpleNamespace(path=r"C:\whatever\img001.png"), + name="ignored", + metadata={}, + ) + + wf, _viewer, _lm, _trails = workflow_factory( + layer_manager=lm, + resolve_config=lambda _layer: config_path, + ) + + out = wf._enrich_points_metadata_for_save(points_layer, {}) + + assert out["project"] == str(project_root) + assert out["root"] == str(inferred_root) + + +def 
test_enrich_metadata_falls_back_to_project_labeled_data_image_name_folder( + monkeypatch, + workflow_factory, + points_layer, + tmp_path, +): + project_root = tmp_path / "project" + dataset_dir = project_root / "labeled-data" / "session42" + dataset_dir.mkdir(parents=True) + config_path = project_root / "config.yaml" + config_path.write_text("dummy", encoding="utf-8") + + monkeypatch.setattr(save_mod, "resolve_project_root_from_config", lambda p: project_root) + monkeypatch.setattr(save_mod, "normalize_anchor_candidate", lambda src: None) + + lm = DummyLayerManager() + lm._active_image = SimpleNamespace( + source=SimpleNamespace(path=None), + name="session42", + metadata={}, + ) + + wf, _viewer, _lm, _trails = workflow_factory( + layer_manager=lm, + resolve_config=lambda _layer: config_path, + ) + + out = wf._enrich_points_metadata_for_save(points_layer, {}) + + assert out["project"] == str(project_root) + assert out["root"] == str(dataset_dir) diff --git a/src/napari_deeplabcut/_tests/ui/test_singleton_widget.py b/src/napari_deeplabcut/_tests/ui/test_singleton_widget.py new file mode 100644 index 00000000..ccf5bf7c --- /dev/null +++ b/src/napari_deeplabcut/_tests/ui/test_singleton_widget.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import pytest +from qtpy.QtWidgets import QWidget + +from napari_deeplabcut.ui.base_widget import ViewerSingletonWidget + + +class DummyWindow: + def __init__(self): + self.dock_widgets = {} + + +class DummyViewer: + def __init__(self): + self.window = DummyWindow() + + +class Wrapper: + def __init__(self, wrapped=None, *, use_obj=False): + if use_obj: + self._obj = wrapped + else: + self.__wrapped__ = wrapped + + +class DockWrapper: + def __init__(self, widget): + self._widget = widget + + def widget(self): + return self._widget + + +class TestSingleton(ViewerSingletonWidget): + def __init__(self, napari_viewer): + if not self._singleton_prepare_init(napari_viewer=napari_viewer): + return + + super().__init__() + self._singleton_finalize_init() + + self.viewer = self.canonical_viewer(napari_viewer) + self.init_count = getattr(self, "init_count", 0) + 1 + + +class OtherSingleton(ViewerSingletonWidget): + def __init__(self, napari_viewer): + if not self._singleton_prepare_init(napari_viewer=napari_viewer): + return + + super().__init__() + self._singleton_finalize_init() + + self.viewer = self.canonical_viewer(napari_viewer) + self.init_count = getattr(self, "init_count", 0) + 1 + + +@pytest.fixture(autouse=True) +def clear_singleton_registry(): + """ + Keep tests isolated because the registry is class-level global state. 
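+ Clearing it both before and after each test keeps one test's instance from leaking into the next.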
+ """ + ViewerSingletonWidget._instances_by_cls.clear() + yield + ViewerSingletonWidget._instances_by_cls.clear() + + +def test_extract_viewer_from_call_positional(): + viewer = DummyViewer() + assert ViewerSingletonWidget._extract_viewer_from_call((viewer,), {}) is viewer + + +def test_extract_viewer_from_call_napari_viewer_kwarg(): + viewer = DummyViewer() + assert ViewerSingletonWidget._extract_viewer_from_call((), {"napari_viewer": viewer}) is viewer + + +def test_extract_viewer_from_call_viewer_kwarg(): + viewer = DummyViewer() + assert ViewerSingletonWidget._extract_viewer_from_call((), {"viewer": viewer}) is viewer + + +def test_extract_viewer_from_call_none_when_missing(): + assert ViewerSingletonWidget._extract_viewer_from_call((), {}) is None + + +def test_canonical_viewer_unwraps_wrapped_chain(): + viewer = DummyViewer() + wrapped = Wrapper(Wrapper(viewer)) + assert ViewerSingletonWidget.canonical_viewer(wrapped) is viewer + + +def test_canonical_viewer_unwraps_obj_chain(): + viewer = DummyViewer() + wrapped = Wrapper(Wrapper(viewer, use_obj=True), use_obj=True) + assert ViewerSingletonWidget.canonical_viewer(wrapped) is viewer + + +def test_canonical_viewer_stops_on_self_wrapped(): + wrapped = Wrapper(None) + wrapped.__wrapped__ = wrapped + assert ViewerSingletonWidget.canonical_viewer(wrapped) is wrapped + + +@pytest.mark.usefixtures("qtbot") +def test_same_viewer_same_subclass_returns_same_instance(qtbot): + viewer = DummyViewer() + + w1 = TestSingleton(viewer) + qtbot.addWidget(w1) + + w2 = TestSingleton(viewer) + + assert w1 is w2 + assert w1.init_count == 1 + assert w2.init_count == 1 + + +@pytest.mark.usefixtures("qtbot") +def test_get_or_create_returns_existing_instance(qtbot): + viewer = DummyViewer() + + w1 = TestSingleton(viewer) + qtbot.addWidget(w1) + + w2 = TestSingleton.get_or_create(viewer) + + assert w1 is w2 + assert w1.init_count == 1 + + +@pytest.mark.usefixtures("qtbot") +def test_get_existing_accepts_wrapped_viewer(qtbot): + viewer = DummyViewer() + proxy = Wrapper(viewer) + + w = TestSingleton(viewer) + qtbot.addWidget(w) + + assert TestSingleton.get_existing(proxy) is w + + +@pytest.mark.usefixtures("qtbot") +def test_different_subclasses_have_independent_singletons(qtbot): + viewer = DummyViewer() + + w1 = TestSingleton(viewer) + qtbot.addWidget(w1) + + w2 = OtherSingleton(viewer) + qtbot.addWidget(w2) + + assert w1 is not w2 + assert isinstance(w1, TestSingleton) + assert isinstance(w2, OtherSingleton) + + +@pytest.mark.usefixtures("qtbot") +def test_is_docked_true_when_widget_directly_registered(qtbot): + viewer = DummyViewer() + widget = TestSingleton(viewer) + qtbot.addWidget(widget) + + viewer.window.dock_widgets["a"] = widget + + assert TestSingleton.is_docked(viewer, widget) is True + + +@pytest.mark.usefixtures("qtbot") +def test_is_docked_true_when_wrapper_returns_widget(qtbot): + viewer = DummyViewer() + widget = TestSingleton(viewer) + qtbot.addWidget(widget) + + viewer.window.dock_widgets["a"] = DockWrapper(widget) + + assert TestSingleton.is_docked(viewer, widget) is True + + +@pytest.mark.usefixtures("qtbot") +def test_is_docked_false_when_widget_not_present(qtbot): + viewer = DummyViewer() + widget = TestSingleton(viewer) + qtbot.addWidget(widget) + + other = QWidget() + qtbot.addWidget(other) + + assert TestSingleton.is_docked(viewer, widget) is False + + +@pytest.mark.usefixtures("qtbot") +def test_get_existing_returns_none_after_widget_deleted(qtbot): + viewer = DummyViewer() + + widget = TestSingleton(viewer) + qtbot.addWidget(widget) 
+ + assert TestSingleton.get_existing(viewer) is widget + + widget.deleteLater() + qtbot.waitUntil(lambda: TestSingleton.get_existing(viewer) is None, timeout=1000) + + assert TestSingleton.get_existing(viewer) is None diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index 7ab286e2..9c9c4a4d 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -17,10 +17,6 @@ TODO: And general dev notes: - The saving workflow is crammed into save_layers_dialog() right now, and should move to a dedicated e.g. PointsLayerSaveFactory class in a dedicated file. -- Project/root/paths/image-meta/points-meta synchronization should be centralized. - It is too distributed right now, and it can be unclear "which truth" is authoritative. - Some sort of context-manager class would likely help. -- Maybe a dedicated layer lifecycle system would help, for layer adoption and setup. - Something that owns UI sync state (what to refresh, why, when) could help with the heavy wiring. - I'd suggest keeping in this file: - color/label mode, menu/help actions, widget visibility and user interaction hooks. @@ -30,185 +26,130 @@ from __future__ import annotations import logging -from contextlib import contextmanager from copy import deepcopy from datetime import datetime -from functools import cached_property, partial +from functools import cached_property from pathlib import Path -from types import MethodType import matplotlib.pyplot as plt import numpy as np -from napari.layers import Image, Points, Tracks +from napari.layers import Image, Points from napari.utils.events import Event -from napari.utils.history import get_save_history, update_save_history -from pydantic import ValidationError from qtpy.QtCore import QSettings, QSignalBlocker, Qt, QTimer from qtpy.QtGui import QAction from qtpy.QtWidgets import ( QButtonGroup, QCheckBox, - QFileDialog, QGridLayout, QGroupBox, QHBoxLayout, - QInputDialog, QLabel, QMessageBox, QPushButton, QRadioButton, QVBoxLayout, - QWidget, ) -import napari_deeplabcut.core.io as io -from napari_deeplabcut import misc -from napari_deeplabcut.config import settings -from napari_deeplabcut.config.keybinds import ( +from . 
import misc +from .config import settings +from .config.keybinds import ( install_global_points_keybindings, - install_points_layer_keybindings, ) -from napari_deeplabcut.config.models import DLCHeaderModel, ImageMetadata, PointsMetadata -from napari_deeplabcut.core import keypoints -from napari_deeplabcut.core.config_sync import ( +from .config.models import ImageMetadata +from .core import io, keypoints +from .core.config_sync import ( load_point_size_from_config, resolve_config_path_from_layer, save_point_size_to_config, ) -from napari_deeplabcut.core.conflicts import compute_overwrite_report_for_points_save -from napari_deeplabcut.core.layer_versioning import mark_layer_presentation_changed -from napari_deeplabcut.core.layers import ( +from .core.layer_lifecycle import ( + MergeDecisionRequest, + MergeDecisionResult, + MergeDisposition, + PointsLayerSetupRequest, + get_or_create_layer_manager, +) +from .core.layer_versioning import mark_layer_presentation_changed +from .core.layers import ( PointsInteractionEvent, PointsInteractionObserver, compute_label_progress, - find_relevant_image_layer, get_first_points_layer, get_points_layer_with_tables, get_uniform_point_size, infer_folder_display_name, - is_machine_layer, set_uniform_point_size, ) -from napari_deeplabcut.core.metadata import ( - MergePolicy, - apply_project_paths_override_to_points_meta, - infer_image_root, - migrate_points_layer_metadata, +from .core.metadata import ( read_points_meta, - sync_points_from_image, - write_points_meta, -) -from napari_deeplabcut.core.project_paths import ( - PathMatchPolicy, - coerce_paths_to_dlc_row_keys, - dataset_folder_has_files, - find_nearest_config, - resolve_project_root_from_config, - target_dataset_folder_for_config, -) -from napari_deeplabcut.core.provenance import ( - apply_gt_save_target, - is_projectless_folder_association_candidate, - requires_gt_promotion, - suggest_human_placeholder, -) -from napari_deeplabcut.core.remap import remap_layer_data_by_paths -from napari_deeplabcut.core.sidecar import ( - get_default_scorer, - set_default_scorer, ) -from napari_deeplabcut.core.trails import TrailsController, safe_folder_anchor_from_points_layer -from napari_deeplabcut.napari_compat import ( +from .core.trails import TrailsController +from .napari_compat import ( apply_points_layer_ui_tweaks, - install_add_wrapper, - install_paste_patch, patch_color_manager_guess_continuous, register_points_action, ) -from napari_deeplabcut.napari_compat.points_layer import make_paste_data -from napari_deeplabcut.ui import dialogs as ui_dialogs -from napari_deeplabcut.ui.color_scheme_display import ColorSchemePanel -from napari_deeplabcut.ui.cropping import ( +from .ui.base_widget import ViewerSingletonWidget +from .ui.color_scheme_display import ColorSchemePanel +from .ui.cropping import ( build_video_action_menu, handle_apply_crop_toggled, - resolve_project_path_from_image_layer, run_extract_current_frame, run_store_crop_coordinates, update_video_panel_context, ) -from napari_deeplabcut.ui.debug_window import DebugTextWindow, make_issue_report_provider -from napari_deeplabcut.ui.dialogs import Shortcuts, Tutorial -from napari_deeplabcut.ui.labels_and_dropdown import ( +from .ui.debug_window import DebugTextWindow, make_issue_report_provider +from .ui.dialogs import Shortcuts, Tutorial +from .ui.labels_and_dropdown import ( DropdownMenu, KeypointsDropdownMenu, ) -from napari_deeplabcut.ui.layer_stats import LayerStatusPanel -from napari_deeplabcut.ui.plots.trajectory import TrajectoryMatplotlibCanvas 
-from napari_deeplabcut.utils.debug import get_debug_recorder, install_debug_recorder +from .ui.layer_stats import LayerStatusPanel +from .ui.plots.trajectory import TrajectoryMatplotlibCanvas +from .ui.ui_dialogs.save import PointsLayerSaveWorkflow +from .utils.debug import get_debug_recorder, install_debug_recorder, log_timing +from .utils.deprecations import deprecated logger = logging.getLogger("napari-deeplabcut._widgets") # logger.setLevel(logging.DEBUG) # FIXME @C-Achard temp remove before merging -def _prompt_for_scorer(parent_widget, *, anchor: str, suggested: str) -> str | None: - """Prompt user for a scorer name. Returns non-empty string or None if cancelled.""" - text, ok = QInputDialog.getText( - parent_widget, - "Choose scorer", - "No DLC config.yaml scorer found.\n" - "Please enter a scorer name for the CollectedData file.\n\n" - "Tip: Use your name or a stable lab identifier.\n" - "(We strongly discourage keeping the generic 'human_xxxxxx'.)", - text=suggested, - ) - if not ok: - return None - scorer = (text or "").strip() - if not scorer: - return None - return scorer - - -@contextmanager -def _temporary_layer_metadata(layer: Points, metadata: dict): - old_metadata = dict(layer.metadata or {}) - layer.metadata = metadata - try: - yield - finally: - layer.metadata = old_metadata - - -class KeypointControls(QWidget): +class KeypointControls(ViewerSingletonWidget): def __init__(self, napari_viewer): + if not self._singleton_prepare_init(napari_viewer=napari_viewer): + return + super().__init__() + self._singleton_finalize_init() + self.viewer = self.canonical_viewer(napari_viewer) + # Monkey-patch napari continuous variable type guess patch_color_manager_guess_continuous() - self._is_saved = False - self.viewer = napari_viewer - - self.viewer.layers.events.inserted.connect(self.on_insert) - self.viewer.layers.events.removed.connect(self.on_remove) + # Layer lifecycle manager + self.layer_manager = get_or_create_layer_manager(self.viewer) + self.layer_manager.set_merge_decision_provider(self) + ## Hook up signals for layer lifecycle events as needed, e.g.: + self.layer_manager.session_conflict_rejected.connect(self._on_session_conflict_detected) + self.layer_manager.refresh_video_panel_requested.connect(self._refresh_video_panel_context) + self.layer_manager.refresh_layer_status_requested.connect(self._refresh_layer_status_panel) + self.layer_manager.video_widget_visibility_requested.connect(self._on_video_widget_visibility_requested) + self.layer_manager.move_image_layer_to_bottom_requested.connect(self._move_image_layer_to_bottom) + + self.layer_manager.points_layer_setup_requested.connect(self._on_points_layer_setup_requested) + self.layer_manager.points_layers_merged_requested.connect(self._on_points_layers_merged_requested) + self.layer_manager.points_layer_removed_requested.connect(self._on_points_layer_removed_requested) + self.layer_manager.tracks_layer_removed_requested.connect(self._on_tracks_layer_removed_requested) + ## Timers + self._timers = {} + self._temp_timers = set() + self.destroyed.connect(self._cleanup_timers) ## Debug ## self._debug_recorder = install_debug_recorder() self._debug_window = None - - show_debug_action = QAction("&Generate napari-dlc log", self) - show_debug_action.setToolTip("Show a debug report with recent plugin logs") - show_debug_action.triggered.connect(self._show_debug_window) - self.viewer.window.help_menu.addAction(show_debug_action) ########### - - # self.viewer.window.qt_viewer._get_and_try_preferred_reader = MethodType( - # 
_get_and_try_preferred_reader, - # self.viewer.window.qt_viewer, - # ) - # Project data - self._project_path: str | None = None - status_bar = self.viewer.window._qt_window.statusBar() self.last_saved_label = QLabel("") self.last_saved_label.hide() @@ -217,8 +158,6 @@ def __init__(self, napari_viewer): self._color_mode = keypoints.ColorMode.default() self._label_mode = keypoints.LabelMode.default() - # Hold references to the KeypointStores - self._stores = {} # Intercept close event if data were not saved qt_win = self.viewer.window._qt_window orig_close_event = qt_win.closeEvent @@ -239,12 +178,13 @@ def _close_event(event): # Storage for extra image metadata that are relevant to other layers. # These are updated anytime images are added to the Viewer # and passed on to the other layers upon creation. - self._image_meta = ImageMetadata() # Storage for layers requiring recoloring self._recolor_pending = set() # Add some more controls self._layout = QVBoxLayout(self) + # TODO just use a single synced DropdownMenu instance instead of one per layer + # it would reduce tracking and centralize the menu logic. Updating should be cheap anyways. self._menus = [] self._layer_to_menu = {} self.viewer.layers.selection.events.active.connect(self.on_active_layer_change) @@ -259,7 +199,8 @@ def _close_event(event): self._video_group.apply_crop_cb.toggled.connect(self._on_apply_crop_toggled) self.viewer.dims.events.current_step.connect(lambda event: self._refresh_video_panel_context()) self.viewer.layers.selection.events.active.connect(lambda event: self._refresh_video_panel_context()) - QTimer.singleShot(0, self._refresh_video_panel_context) + # QTimer.singleShot(0, self._refresh_video_panel_context) + self._schedule_once("startup_video_panel_refresh", 0, self._refresh_video_panel_context) # form helper display self._keypoint_mapping_button = None @@ -283,7 +224,7 @@ def _close_event(event): self._trail_cb.stateChanged.connect(self._on_show_trails_toggled) self._trails_controller = TrailsController( self.viewer, - managed_points_layers_getter=lambda: tuple(self._stores.keys()), + managed_points_layers_getter=self.layer_manager.managed_points_layers, color_mode_getter=lambda: self.color_mode, resolved_cycle_getter=self._resolved_cycle_for_layer, ) @@ -326,7 +267,7 @@ def _close_event(event): self._color_scheme_panel = ColorSchemePanel( viewer=self.viewer, get_color_mode=lambda: self.color_mode, - get_header_model=self._get_header_model_from_metadata, + get_header_model=self.layer_manager.get_header_model_from_metadata, parent=self, ) self._color_scheme_display = self.viewer.window.add_dock_widget( @@ -348,6 +289,18 @@ def _close_event(event): watch_content=False, ) self._points_interactions.install() + + self._save_workflow = PointsLayerSaveWorkflow( + parent=self, + viewer=self.viewer, + layer_manager=self.layer_manager, + trails_controller=self._trails_controller, + trail_checkbox_getter=lambda: self._trail_cb.isChecked(), + resolve_config_path_for_layer=self._resolve_config_path_for_layer, + current_project_path_getter=lambda: self.layer_manager.project_path, + current_image_meta_getter=lambda: self.layer_manager.image_meta, + logger=logger, + ) ### UI setup ends here # Modes init @@ -367,17 +320,7 @@ def _close_event(event): elif "save all layers" in action_name: self.viewer.window.file_menu.removeAction(action) - # Add action to show the walkthrough again - launch_tutorial = QAction("&Launch napari-dlc tutorial", self) - launch_tutorial.triggered.connect(self.start_tutorial) - 
self.viewer.window.view_menu.addAction(launch_tutorial) - - # Add action to view keyboard shortcuts - display_shortcuts_action = QAction("&Show napari-dlc shortcuts", self) - display_shortcuts_action.triggered.connect(self.display_shortcuts) - self.viewer.window.help_menu.addAction(display_shortcuts_action) - # Install global keybinds - install_global_points_keybindings() + self._add_plugin_actions() # Hide some unused viewer buttons # NOTE (future) do we truly want to disable these ? Tracking util may need to create new points layers @@ -388,30 +331,108 @@ def _close_event(event): # self.viewer.window._qt_viewer.layerButtons.newPointsButton.setDisabled(True) self.viewer.window._qt_viewer.layerButtons.newLabelsButton.setDisabled(True) - # Disable tutorial on first launch for now. Can be accessed any time from the button. - # if self.settings.value("first_launch", True) and not os.environ.get( - # "NAPARI_DLC_HIDE_TUTORIAL", False - # ): - # QTimer.singleShot(10, self.start_tutorial) - # self.settings.setValue("first_launch", False) - # Slightly delay docking so it is shown underneath the KeypointsControls widget - # NOTE while a timer may seem hacky, it is a simple, one-line solution that minimizes intrusion - # There are to my knowledge no other way that is as concise and clean - # (Of course this will be a problem if we start using it everywhere so do not reuse lightly) - QTimer.singleShot(10, self.silently_dock_canvas) + # NOTE we may want to switch to timers owned by the widget instead of fire-and-forget + # QTimer.singleShot(10, self.silently_dock_canvas) + self._schedule_once("startup_dock_canvas", 10, self.silently_dock_canvas) # If layers already exist (user loaded data before opening this widget), # adopt them so keypoint controls take ownership immediately. - QTimer.singleShot(0, self._adopt_existing_layers) + # QTimer.singleShot(0, self._adopt_existing_layers) + self.layer_manager.schedule_initial_adoption() # Refresh layers stats widget - QTimer.singleShot(0, self._refresh_layer_status_panel) + # QTimer.singleShot(0, self._refresh_layer_status_panel) + self._schedule_once("initial_layer_status_refresh", 0, self._refresh_layer_status_panel) @cached_property def settings(self): return QSettings() + def _cleanup_timers(self, *_args) -> None: + # FIXME @C-Achard move to singleton base class for other widgets to reuse, and maybe add logs + """Stop/delete all timers owned by this widget. + + This is intentionally defensive: by teardown time, some underlying + C++ objects may already be in the process of destruction. 
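+ Such timers raise RuntimeError on stop()/deleteLater(), which is caught and ignored.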
+ """ + for timer in list(getattr(self, "_timers", {}).values()): + try: + timer.stop() + timer.deleteLater() + except RuntimeError: + pass + self._timers = {} + + for timer in list(getattr(self, "_temp_timers", set())): + try: + timer.stop() + timer.deleteLater() + except RuntimeError: + pass + self._temp_timers.clear() + + def _schedule_once(self, name: str, msec: int, callback) -> None: + """Schedule/coalesce a named single-shot callback owned by this widget.""" + timer = self._timers.get(name) + if timer is None: + timer = QTimer(self) + timer.setSingleShot(True) + timer.timeout.connect(callback) + self._timers[name] = timer + try: + timer.start(msec) + except RuntimeError: + # Widget/timer already in teardown + pass + + def _single_shot_owned(self, msec: int, callback) -> None: + """Schedule a one-off callback using a child QTimer tracked by this widget.""" + timer = QTimer(self) + timer.setSingleShot(True) + self._temp_timers.add(timer) + + def _fire(): + try: + callback() + finally: + self._temp_timers.discard(timer) + try: + timer.deleteLater() + except RuntimeError: + pass + + timer.timeout.connect(_fire) + try: + timer.start(msec) + except RuntimeError: + self._temp_timers.discard(timer) + + def _flush_recolor_pending(self) -> None: + pending = tuple(self._recolor_pending) + self._recolor_pending.clear() + for layer in pending: + try: + if layer not in self.viewer.layers: + continue + self._apply_points_coloring_from_metadata(layer) + except RuntimeError: + logger.debug("Skipping recolor for deleted/tearing-down layer", exc_info=True) + except Exception: + logger.debug("Failed deferred recolor", exc_info=True) + + @deprecated(details="Use the layer manager's project context.", replacement="layer_manager.image_meta") + @property + def _image_meta(self) -> ImageMetadata: + """Compatibility shim: lifecycle-owned image context now lives in manager.""" + return self.layer_manager.image_meta + + @deprecated(details="Use the layer manager's project context.", replacement="layer_manager.project_path") + @property + def _project_path(self) -> str | None: + """Compatibility shim: lifecycle-owned project path now lives in manager.""" + return self.layer_manager.project_path + @register_points_action("Change labeling mode") def cycle_through_label_modes(self, *args): self.label_mode = next(keypoints.LabelMode) @@ -449,7 +470,7 @@ def color_mode(self): def color_mode(self, mode: str | keypoints.ColorMode): self._color_mode = keypoints.ColorMode(mode) - for layer in list(self._stores.keys()): + for layer in list(self.layer_manager.managed_points_layers()): if isinstance(layer, Points) and layer.metadata: self._apply_points_coloring_from_metadata(layer) @@ -468,29 +489,22 @@ def color_mode(self, mode: str | keypoints.ColorMode): self._update_color_scheme() self._trails_controller.on_points_visual_inputs_changed(checkbox_checked=self._trail_cb.isChecked()) - def _is_multianimal(self, layer) -> bool: - if layer is None or not isinstance(layer, Points): - return False - - md = layer.metadata or {} - hdr = self._get_header_model_from_metadata(md) - if hdr is None: - return False - - try: - inds = hdr.individuals - return bool(inds and len(inds) > 0 and str(inds[0]) != "") - except Exception: - return False - def _active_layer_is_multianimal(self) -> bool: """Returns: whether the active layer is a multi-animal points layer""" for layer in self.viewer.layers.selection: - if self._is_multianimal(layer): + if self.layer_manager.is_multianimal(layer): return True return False + def 
_on_session_conflict_detected(self, reason: str) -> None: + QMessageBox.warning( + self, + "A labeled data folder is already loaded!", + reason, + QMessageBox.Ok, + ) + def _show_debug_window(self) -> None: try: if self._debug_window is None: @@ -518,124 +532,52 @@ def _show_debug_window(self) -> None: # ######################## # # Layer setup core methods # # ######################## # + def resolve_merge(self, request: MergeDecisionRequest) -> MergeDecisionResult: + if not request.added_keypoints: + return MergeDecisionResult(disposition=MergeDisposition.KEEP_BOTH) - def _setup_image_layer(self, layer: Image, index: int | None = None, *, reorder: bool = True) -> None: - md = layer.metadata or {} - paths = md.get("paths") - if paths is None and io.is_video(layer.name): - self.video_widget.setVisible(True) + shared = "Do you want to hide the existing keypoints and add the new ones as a separate layer?" + text = f"{request.message}\n\n{shared}" if request.message else shared + answer = QMessageBox.question( + self, + "", + text, + QMessageBox.Yes | QMessageBox.No, + ) - self._update_image_meta_from_layer(layer) + if answer == QMessageBox.Yes: + return MergeDecisionResult(disposition=MergeDisposition.HIDE_EXISTING) - if not self._project_path: - self._cache_project_path_from_image_layer(layer) - if self._project_path is not None: + return MergeDecisionResult(disposition=MergeDisposition.KEEP_BOTH) + + def _on_points_layers_merged_requested(self, layers: tuple[Points, ...]) -> None: + """Refresh widget-owned UI after manager merged placeholder config into managed layers.""" + try: + # Refresh dropdown menus after header/bodypart changes. + for menu in self._menus: try: - layer.metadata = dict(layer.metadata or {}) - layer.metadata.setdefault("project", self._project_path) + menu._map_individuals_to_bodyparts() + menu._update_items() + except Exception: + logger.debug("Failed to refresh dropdown menu after config merge", exc_info=True) + + # Re-apply presentation metadata to affected layers. 
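+ # Recolor failures are logged per layer so a single bad layer cannot block the others.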
+ for layer in layers: + try: + self._apply_points_coloring_from_metadata(layer) except Exception: logger.debug( - "Failed to set project path metadata on image layer %r", + "Failed to re-apply points coloring after merge for layer=%r", getattr(layer, "name", layer), exc_info=True, ) - self._sync_points_layers_from_image_meta() - self._refresh_video_panel_context() - logger.debug( - "Setup image layer=%r index=%s reorder=%s paths_count=%s root=%r", - getattr(layer, "name", layer), - index, - reorder, - len(md.get("paths") or []), - md.get("root"), - ) - - if reorder and index is not None: - QTimer.singleShot(10, partial(self._move_image_layer_to_bottom, index)) - - def _maybe_merge_config_points_layer(self, layer: Points) -> bool: - md = layer.metadata or {} - logger.debug( - "Maybe merge config points layer=%r project=%r stores=%d", - getattr(layer, "name", layer), - md.get("project"), - len(self._stores), - ) - if not md.get("project", "") or not self._stores: - return False - - new_metadata = md.copy() - - keypoints_menu = self._menus[0].menus["label"] - current_keypoint_set = {keypoints_menu.itemText(i) for i in range(keypoints_menu.count())} - hdr = self._get_header_model_from_metadata(new_metadata) - if hdr is None: - return False - new_keypoint_set = set(hdr.bodyparts) - diff = new_keypoint_set.difference(current_keypoint_set) - - if diff: - answer = QMessageBox.question(self, "", "Do you want to display the new keypoints only?") - if answer == QMessageBox.Yes: - self.viewer.layers[-2].shown = False - logger.debug( - "Merging config points layer=%r new_keypoints=%s", - getattr(layer, "name", layer), - sorted(diff), - ) - - self.viewer.status = f"New keypoint{'s' if len(diff) > 1 else ''} {', '.join(diff)} found." - for _layer, store in self._stores.items(): - pts = read_points_meta(_layer, migrate_legacy=True, drop_controls=True, drop_header=False) - if not hasattr(pts, "errors"): - updated = pts.model_copy(update={"header": hdr}) - write_points_meta(_layer, updated, merge_policy=MergePolicy.MERGE, fields={"header"}) - store.layer = _layer - - for menu in self._menus: - menu._map_individuals_to_bodyparts() - menu._update_items() - - QTimer.singleShot(10, self.viewer.layers.pop) - - # apply the new color cycles + recolor safely - for _layer, store in self._stores.items(): - _layer.metadata["config_colormap"] = new_metadata.get( - "config_colormap", _layer.metadata.get("config_colormap") - ) - _layer.metadata["face_color_cycles"] = new_metadata["face_color_cycles"] - _layer.metadata["colormap_name"] = new_metadata.get("colormap_name", _layer.metadata.get("colormap_name")) - mark_layer_presentation_changed(_layer) - self._apply_points_coloring_from_metadata(_layer) - store.layer = _layer - - self._update_color_scheme() - return True - - def _get_header_model_from_metadata(self, md: dict) -> DLCHeaderModel | None: - """Return DLCHeaderModel regardless of whether md['header'] is a model, dict payload, or MultiIndex.""" - if not isinstance(md, dict): - return None - hdr = md.get("header", None) - if hdr is None: - return None - - if isinstance(hdr, DLCHeaderModel): - logger.debug("Header is already a DLCHeaderModel instance.") - return hdr - - if isinstance(hdr, dict): - try: - return DLCHeaderModel.model_validate(hdr) - except Exception: - return None + self._update_color_scheme() + self._trails_controller.on_points_visual_inputs_changed(checkbox_checked=self._trail_cb.isChecked()) + self._refresh_layer_status_panel() - # fallback: allow MultiIndex / list-of-tuples / Index inputs - try: 
- return DLCHeaderModel(columns=hdr) except Exception: - return None + logger.debug("Failed to refresh widget state after merged points update", exc_info=True) @staticmethod def get_layer_controls(layer: Points) -> KeypointControls | None: @@ -645,179 +587,140 @@ def get_layer_controls(layer: Points) -> KeypointControls | None: def get_layer_store(layer: Points) -> keypoints.KeypointStore | None: return getattr(layer, "_dlc_store", None) - def _wire_points_layer(self, layer: Points) -> keypoints.KeypointStore | None: - if not self._validate_header(layer): - return None - existing = getattr(layer, "_dlc_store", None) - if existing is not None: - self._stores[layer] = existing - layer._dlc_controls = self - return existing - - # ensure presence of IO metadata for saving & routing - mig = migrate_points_layer_metadata(layer) - if hasattr(mig, "errors"): - logger.warning( - "Points metadata validation failed during wiring for layer=%r: %s", - getattr(layer, "name", layer), - mig, - ) - - store = keypoints.KeypointStore(self.viewer, layer) - self._stores[layer] = store - layer._dlc_store = store - layer._dlc_controls = self - - # default root/paths from current image meta if missing - if not layer.metadata.get("root") and self._image_meta.root: - layer.metadata["root"] = self._image_meta.root - if not layer.metadata.get("paths") and self._image_meta.paths: - layer.metadata["paths"] = self._image_meta.paths - - # save history - if root := layer.metadata.get("root"): - update_save_history(root) - - store._get_label_mode = lambda: self._label_mode - layer.text.visible = False + # ------------------------------------------------------------------ # + # UI-only hooks used by LayerLifecycleManager # + # ------------------------------------------------------------------ # - paste_func = make_paste_data(self, store=store) - install_paste_patch(layer, paste_func=paste_func) + def _set_points_controls_enabled(self, enabled: bool) -> None: + self._radio_box.setEnabled(enabled) + self._color_grp.setEnabled(enabled) + self._trail_cb.setEnabled(enabled) + self._show_traj_plot_cb.setEnabled(enabled) - add_impl = MethodType(keypoints._add, store) # bind store to add implementation - install_add_wrapper(layer, add_impl=add_impl, schedule_recolor=self._schedule_recolor) - - # store events / navigation - layer.events.add(query_next_frame=Event) - layer.events.query_next_frame.connect(store._advance_step) - - # Install keybinds - install_points_layer_keybindings(layer, self, store) + def _complete_points_layer_ui_setup(self, layer: Points, store: keypoints.KeypointStore) -> None: + """UI-only completion after lifecycle manager finished points setup.""" + if layer.metadata.get("tables", ""): + self._keypoint_mapping_button.show() - if len(self._stores) == 1 and self._is_multianimal(layer): - # set internal mode without triggering recolor storms - self._color_mode = keypoints.ColorMode.INDIVIDUAL - # update button state so UI matches (optional but good) - for btn in self._color_mode_selector.buttons(): - if btn.text().lower() == str(self._color_mode).lower(): - btn.setChecked(True) - break + selector = apply_points_layer_ui_tweaks(self.viewer, layer, dropdown_cls=DropdownMenu, plt_module=plt) + if selector is not None: + try: + selector.currentTextChanged.connect(self._update_colormap) + except Exception: + pass - # apply cycles (works even if empty; see method) self._apply_points_coloring_from_metadata(layer) - self._maybe_initialize_layer_point_size_from_config(layer) - self._connect_layer_status_events(layer) - # 
refresh trails if enabled (e.g. when merging a config points layer with trails metadata) self._trails_controller.on_points_layer_added_or_rewired(checkbox_checked=self._trail_cb.isChecked()) - # menus - self._form_dropdown_menus(store) + if layer not in self._layer_to_menu: + self._form_dropdown_menus(store) - # project path - proj = layer.metadata.get("project") - if proj: - self._project_path = proj + # proj = layer.metadata.get("project") # MOVED to LayerLifecycleManager + # if proj: + # self._project_path = proj - # enable GUI groups - self._radio_box.setEnabled(True) - self._color_grp.setEnabled(True) - self._trail_cb.setEnabled(True) - self._show_traj_plot_cb.setEnabled(True) - - md = layer.metadata or {} + self._set_points_controls_enabled(True) + self._update_color_scheme() logger.debug( - "Wire points layer=%r existing_store=%s project=%s root=%s len_paths=%s", + "Setup points layer=%r metadata_keys=%s", getattr(layer, "name", layer), - getattr(layer, "_dlc_store", None) is not None, - md.get("project"), - md.get("root"), - len(md.get("paths", [])), + sorted((layer.metadata or {}).keys()), ) - return store - - def _setup_points_layer(self, layer: Points, *, allow_merge: bool = True) -> None: - if not self._validate_header(layer): - return + def _on_points_layer_setup_requested(self, req: PointsLayerSetupRequest) -> None: + layer = req.layer + store = req.store - if allow_merge and self._maybe_merge_config_points_layer(layer): - return + try: + resources = self.layer_manager.attach_points_layer_runtime( + layer=layer, + store=store, + controls=self, + resolve_layer_by_id=self.layer_manager.resolve_live_layer, + get_label_mode=lambda: self._label_mode, + schedule_recolor=self._schedule_recolor, + existing_resources=req.existing_resources, + ) + req.runtime_resources = resources - if layer.metadata.get("tables", ""): - self._keypoint_mapping_button.show() + layer._dlc_controls = self - store = self._wire_points_layer(layer) - if store is None: - return + if self.layer_manager.managed_points_count() == 1 and self.layer_manager.is_multianimal(layer): + self._color_mode = keypoints.ColorMode.INDIVIDUAL + for btn in self._color_mode_selector.buttons(): + if btn.text().lower() == str(self._color_mode).lower(): + btn.setChecked(True) + break - selector = apply_points_layer_ui_tweaks(self.viewer, layer, dropdown_cls=DropdownMenu, plt_module=plt) - if selector is not None: - try: - selector.currentTextChanged.connect(self._update_colormap) - except Exception: - pass + self._maybe_initialize_layer_point_size_from_config(layer) + self._connect_layer_status_events(layer) + self._complete_points_layer_ui_setup(layer, store) - self._update_color_scheme() - logger.debug( - "Setup points layer=%r allow_merge=%s metadata_keys=%s", - getattr(layer, "name", layer), - allow_merge, - sorted((layer.metadata or {}).keys()), - ) + except Exception: + logger.debug( + "Failed to complete points layer setup for layer=%r", + getattr(layer, "name", layer), + exc_info=True, + ) - def _adopt_existing_layers(self) -> None: - """ - When the widget is opened after layers already exist, we need to - run the same initialization as if they had been inserted. 
- """ - # Iterate over a snapshot, because on_insert may modify layer order - logger.debug("Adopting existing layers count=%d", len(self.viewer.layers)) - layers_snapshot = list(self.viewer.layers) + def _on_video_widget_visibility_requested(self, visible: bool) -> None: + try: + self.video_widget.setVisible(bool(visible)) + except Exception: + logger.debug("Failed to update video widget visibility", exc_info=True) - for idx, layer in enumerate(layers_snapshot): - self._adopt_layer(layer, idx) + def _on_points_layer_removed_requested(self, layer: Points, remaining_points_layers: int) -> None: + self._on_points_layer_removed_ui(layer, remaining_points_layers=remaining_points_layers) - # After adoption, refresh UI state + def _on_tracks_layer_removed_requested(self, layer) -> None: try: - active = self.viewer.layers.selection.active - if active is not None: - # Force the GUI to update visibility of menus, etc. - self.on_active_layer_change(Event(type_name="active", value=active)) + was_trails = self._trails_controller.on_tracks_layer_removed(layer) except Exception: - pass + logger.debug("Failed to process tracks layer removal", exc_info=True) + return - # Important: refresh the trajectory plot from the final adopted state. - # This fixes the case where layers were loaded before the plugin opened - # (e.g. drag-and-drop DLC data triggering the reader automatically). - self._refresh_trajectory_plot_from_layers() + if was_trails: + with QSignalBlocker(self._trail_cb): + self._trail_cb.setChecked(False) - def _adopt_layer(self, layer, index: int) -> None: - """ - Run the relevant portion of on_insert() for an already-existing layer. - This avoids duplicating your logic and prevents reliance on napari's Event object. - """ - logger.debug( - "Adopt layer=%r type=%s index=%s", - getattr(layer, "name", layer), - type(layer).__name__, - index, - ) - if isinstance(layer, Image): - self._setup_image_layer(layer, index, reorder=True) - elif isinstance(layer, Points): - if layer not in self._stores: - self._setup_points_layer(layer, allow_merge=False) # typically don’t merge during adopt - if not isinstance(layer, Image): - self._remap_frame_indices(layer) - - def _validate_header(self, layer) -> bool: - res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) - if isinstance(res, ValidationError) or res.header is None: - self.viewer.status = ( - "This Points layer does not look like a DLC keypoints layer. Missing a valid DLC header." 
- ) - return False - return True + def _on_points_layer_removed_ui(self, layer: Points, *, remaining_points_layers: int) -> None: + """UI-only cleanup after lifecycle manager removed a managed Points layer.""" + with log_timing( + logger, + f"_on_points_layer_removed_ui total layer={getattr(layer, 'name', layer)!r}", + threshold_ms=0.01, + ): + with log_timing( + logger, + "color scheme update after points removal", + threshold_ms=0.01, + ): + self._update_color_scheme() + + with log_timing( + logger, + f"trails_controller.on_points_layer_removed layer={getattr(layer, 'name', layer)!r}", + threshold_ms=0.01, + ): + self._trails_controller.on_points_layer_removed(layer) + + if remaining_points_layers == 0: + with log_timing( + logger, + "points menu teardown", + threshold_ms=0.01, + ): + while self._menus: + menu = self._menus.pop() + try: + self._layout.removeWidget(menu) + except Exception: + pass + menu.deleteLater() + + self._layer_to_menu = {} + self._set_points_controls_enabled(False) + self.last_saved_label.hide() def _schedule_recolor(self, layer: Points) -> None: if not hasattr(self, "_recolor_pending"): @@ -828,13 +731,7 @@ def _schedule_recolor(self, layer: Points) -> None: self._recolor_pending.add(layer) - def _do(): - try: - self._apply_points_coloring_from_metadata(layer) - finally: - self._recolor_pending.discard(layer) - - QTimer.singleShot(0, _do) + self._schedule_once(f"recolor_{layer.name}", 0, self._flush_recolor_pending) def _ensure_traj_canvas_docked(self) -> None: """ @@ -910,16 +807,6 @@ def _on_points_interaction(self, event: PointsInteractionEvent) -> None: if {"selection", "active_layer", "layers"} & set(event.reasons): traj_canvas.sync_visible_lines_to_points_selection() - def _refresh_trajectory_plot_from_layers(self) -> None: - """ - Refresh trajectory plot from the current viewer state. - - Deferred through QTimer so it runs after layer adoption/remap settles. 
- """ - traj_canvas = self._safe_get_traj_canvas() - if traj_canvas is not None: - QTimer.singleShot(0, traj_canvas.refresh_from_viewer_layers) - def load_superkeypoints_diagram(self): points_layer = get_first_points_layer(self.viewer) if points_layer is None: @@ -963,7 +850,7 @@ def _map_keypoints(self, super_animal: str): xy = points_layer.data[:, 1:3] superkpts_dict = io.load_superkeypoints(super_animal) xy_ref = np.asarray(list(superkpts_dict.values()), dtype=float) - neighbors = keypoints._find_nearest_neighbors(xy, xy_ref) + neighbors = keypoints.find_nearest_neighbors(xy, xy_ref) found = neighbors != -1 if not np.any(found): return @@ -972,7 +859,7 @@ def _map_keypoints(self, super_animal: str): config_path = str(Path(project_path) / "config.yaml") cfg = io.load_config(config_path) conversion_tables = cfg.get("SuperAnimalConversionTables", {}) - hdr = self._get_header_model_from_metadata(points_layer.metadata or {}) + hdr = self.layer_manager.get_header_model_from_metadata(points_layer.metadata or {}) if hdr is None: return bdprts_map = map(str, hdr.bodyparts) @@ -993,226 +880,37 @@ def start_tutorial(self): def display_shortcuts(self): Shortcuts(self.viewer.window._qt_window.current(), viewer=self.viewer).show() - def _move_image_layer_to_bottom(self, index): - if (ind := index) != 0: - self.viewer.layers.move_selected(ind, 0) - self.viewer.layers.select_next() # Auto-select the Points layer + def _move_image_layer_to_bottom(self, layer: Image): + try: + if layer not in self.viewer.layers: + return + ind = list(self.viewer.layers).index(layer) + if ind != 0: + self.viewer.layers.selection.clear() + self.viewer.layers.selection.add(layer) + self.viewer.layers.move_selected(ind, 0) + self.viewer.layers.select_next() + except Exception: + logger.debug("Failed to move image layer to bottom", exc_info=True) # ------------------------------------------------------------------ # Metadata helpers (authoritative models + napari-friendly dict sync) # ------------------------------------------------------------------ - @staticmethod - def _layer_source_path(layer) -> str | None: - """Best-effort access to napari layer source path (may not exist).""" - try: - src = getattr(layer, "source", None) - p = getattr(src, "path", None) if src is not None else None - return str(p) if p else None - except Exception: - return None - - def _update_image_meta_from_layer(self, layer: Image) -> None: - """ - Update authoritative self._images_meta using an Image layer. - Also keep a dict-like subset synced for other layers (non-breaking). - """ - md = layer.metadata or {} - - paths = md.get("paths") - shape = None - try: - shape = layer.level_shapes[0] - except Exception: - shape = None - - root = infer_image_root( - explicit_root=md.get("root"), - paths=paths, - source_path=self._layer_source_path(layer), - ) - - incoming = ImageMetadata( - paths=list(paths) if paths else None, - root=str(root) if root else None, - shape=tuple(shape) if shape is not None else None, - name=getattr(layer, "name", None), - ) - - # Merge without clobbering already-known values - # (same behavior as old "only set if non-empty") - base = self._image_meta - merged = base.model_copy(deep=True) - for field, value in incoming.model_dump().items(): - if getattr(merged, field) in (None, "", []) and value not in (None, "", []): - setattr(merged, field, value) - - self._image_meta = merged - - def _sync_points_layers_from_image_meta(self) -> None: - """ - Ensure all Points layers have core fields required for saving. 
- - Adapter-based flow: - - read validated points meta (visible failures) - - apply sync logic against authoritative self._image_meta - - write back validated dict via gateway - """ - if self._image_meta is None: - return - - for ly in list(self.viewer.layers): - if not isinstance(ly, Points): - continue - - if ly.metadata is None: - ly.metadata = {} - - # 1) Read + migrate legacy (io from source_h5, header coercion, etc.) - res = read_points_meta(ly, migrate_legacy=True, drop_controls=False, drop_header=False) - if hasattr(res, "errors"): # ValidationError duck-typing - logger.warning( - "Points metadata validation failed during sync for layer=%r: %s", - getattr(ly, "name", ly), - res, - ) - continue - - pts_model: PointsMetadata = res - - # 2) Sync missing fields from image meta (pure model transform) - synced = sync_points_from_image(self._image_meta, pts_model) - - # 3) Write back through gateway (fill missing only; never clobber) - out = write_points_meta( - ly, - synced, - merge_policy=MergePolicy.MERGE_MISSING, - migrate_legacy=True, - validate=True, - ) - if hasattr(out, "errors"): - logger.warning( - "Failed to write synced points metadata for layer=%r: %s", - getattr(ly, "name", ly), - out, - ) - def _resolve_config_path_for_layer(self, layer: Points | None) -> Path | None: if layer is None: return None - image_layer = find_relevant_image_layer(self.viewer) + image_layer = self.layer_manager.active_dlc_image_layer() return resolve_config_path_from_layer( layer, - fallback_project=self._project_path, - fallback_root=self._image_meta.root, + fallback_project=self.layer_manager.project_path, + fallback_root=self.layer_manager.image_root, image_layer=image_layer, prefer_project_root=True, max_levels=5, ) - def _maybe_prepare_project_path_override_metadata(self, layer: Points) -> tuple[dict | None, bool]: - """ - Optionally prepare save-time metadata by associating a project-less labeled - folder with an explicit DLC project chosen via config.yaml. - - Returns - ------- - tuple[dict | None, bool] - (overridden_metadata, abort_save) - - - (None, False): feature not applicable; continue normal save - - (metadata, False): apply metadata override and continue - - (None, True): user cancelled or operation was refused; abort save - """ - res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) - if isinstance(res, ValidationError): - return None, False - - pts_meta: PointsMetadata = res - paths = pts_meta.paths or [] - if not paths: - return None, False - - if not is_projectless_folder_association_candidate(pts_meta): - return None, False - - source_root = pts_meta.root - if not source_root: - return None, False - - try: - source_root_path = Path(source_root).expanduser().resolve(strict=False) - except Exception: - source_root_path = Path(source_root) - - # NOTE: @C-Achard 2026-03-27 Currently does not let user choose - # a different dataset name than the source folder, - # to keep a lightweight workflow. - # This could be allowed in the future if there is demand. 
- dataset_name = source_root_path.name - if not dataset_name: - return None, False - - initial_dir = self._project_path or pts_meta.project or str(source_root_path) - dialog_result = ui_dialogs.prompt_for_project_config_for_save(self, initial_dir=initial_dir) - - if dialog_result.action is ui_dialogs.ProjectConfigPromptAction.CANCEL: - logger.debug("User cancelled project association prompt.") - return None, True # abort save - - if dialog_result.action is ui_dialogs.ProjectConfigPromptAction.SKIP: - logger.debug("User chose to continue without project association.") - return None, False # continue normal save path - - if dialog_result.action is not ui_dialogs.ProjectConfigPromptAction.ASSOCIATE: - logger.warning("Unexpected project association dialog result: %r", dialog_result) - return None, True # fail safe: abort save - - config_path = dialog_result.config_path - if not config_path: - logger.warning("Project association result was ASSOCIATE but config_path was empty.") - return None, True # fail safe: abort save - - project_root = resolve_project_root_from_config(config_path) - if project_root is None: - QMessageBox.warning( - self, - "Invalid project configuration", - "The selected file is not a valid DeepLabCut config.yaml or project root. " - "The save operation has been cancelled.", - ) - return None, True - - target_folder = target_dataset_folder_for_config(config_path, dataset_name=dataset_name) - if dataset_folder_has_files(target_folder): - ui_dialogs.warn_existing_dataset_folder_conflict(self, target_folder=target_folder) - return None, True # refuse the save - - rewritten_paths, unresolved = coerce_paths_to_dlc_row_keys( - paths, - source_root=source_root_path, - dataset_name=dataset_name, - ) - - if not ui_dialogs.maybe_confirm_dataset_path_rewrite( - self, - project_root=project_root, - dataset_name=dataset_name, - n_paths=len(paths), - n_unresolved=len(unresolved), - ): - return None, True # user declined - - overridden = apply_project_paths_override_to_points_meta( - pts_meta, - project_root=project_root, - rewritten_paths=rewritten_paths, - ) - - return overridden.model_dump(mode="python", exclude_none=True), False - def _show_color_scheme(self): show = self._view_scheme_cb.isChecked() self._color_scheme_display.setVisible(show) @@ -1227,7 +925,8 @@ def _current_dlc_points_layer(self) -> Points | None: except Exception: return None - if isinstance(res, ValidationError): + # if isinstance(res, ValidationError): + if hasattr(res, "errors"): return None if getattr(res, "header", None) is None: @@ -1238,13 +937,20 @@ def _current_dlc_points_layer(self) -> Points | None: def _refresh_layer_status_panel(self) -> None: active_layer = self.viewer.layers.selection.active active_dlc_points = self._current_dlc_points_layer() - active_image = find_relevant_image_layer(self.viewer) + active_image = self.layer_manager.active_dlc_image_layer() + fallback_n_frames = None + try: + if active_image is not None and active_image.data.ndim == 4: # (T, H, W, C) + fallback_n_frames = int(active_image.data.shape[0]) + except Exception: + logger.debug("Refresh layer stats - Failed to determine frame count from active image layer", exc_info=True) folder_name = infer_folder_display_name( active_image if active_image is not None else active_layer, - fallback_root=self._image_meta.root, + fallback_root=self.layer_manager.image_root, ) - self._layer_status_panel.set_folder_name(folder_name) + project_path = self.layer_manager.project_path + self._layer_status_panel.set_folder_name(folder_name, 
full_path=project_path) # No active layer or not a Points layer at all if active_layer is None or not isinstance(active_layer, Points): @@ -1259,7 +965,11 @@ def _refresh_layer_status_panel(self) -> None: self._layer_status_panel.set_point_size_enabled(True) self._layer_status_panel.set_point_size(get_uniform_point_size(active_dlc_points)) - progress = compute_label_progress(active_dlc_points, fallback_paths=self._image_meta.paths) + progress = compute_label_progress( + active_dlc_points, + fallback_paths=self.layer_manager.image_paths, + fallback_n_frames=fallback_n_frames, + ) self._layer_status_panel.set_progress_summary(progress=progress) def _on_active_points_size_changed(self, size: int) -> None: @@ -1328,70 +1038,38 @@ def _connect_layer_status_events(self, layer: Points) -> None: except Exception: pass - def _form_help_buttons(self): - layout = QVBoxLayout() - help_buttons_layout = QHBoxLayout() - self.show_shortcuts_btn = QPushButton("View shortcuts") - self.show_shortcuts_btn.clicked.connect(self.display_shortcuts) - help_buttons_layout.addWidget(self.show_shortcuts_btn) - self.tutorial_btn = QPushButton("Start tutorial") - self.tutorial_btn.clicked.connect(self.start_tutorial) - help_buttons_layout.addWidget(self.tutorial_btn) - layout.addLayout(help_buttons_layout) - self._keypoint_mapping_button = QPushButton("Load superkeypoints diagram") - self._load_superkeypoints_action = self._keypoint_mapping_button.clicked.connect( - self.load_superkeypoints_diagram - ) - self._keypoint_mapping_button.hide() - layout.addWidget(self._keypoint_mapping_button) - return layout - - def _refresh_video_panel_context(self) -> None: - update_video_panel_context(self.viewer, self._video_group) - - def _cache_project_path_from_image_layer(self, layer: Image) -> None: - """Best-effort cache of project path from an image/video layer.""" - project_path = resolve_project_path_from_image_layer(layer) - if project_path is None: - return - - self._project_path = project_path - try: - layer.metadata = dict(layer.metadata or {}) - layer.metadata.setdefault("project", self._project_path) - except Exception: - logger.debug( - "Failed to set project path metadata on image layer %r", - getattr(layer, "name", layer), - exc_info=True, - ) + # ------------------------------------------------------------------ + # UI setup + # ------------------------------------------------------------------ + def _add_plugin_actions(self): + # Help menu + self.viewer.window.help_menu.addSeparator() + # Add action to show the walkthrough again + launch_tutorial = QAction("&Launch napari-dlc tutorial", self) + launch_tutorial.triggered.connect(self.start_tutorial) + self.viewer.window.help_menu.addAction(launch_tutorial) - self._refresh_video_panel_context() + # Add action to view keyboard shortcuts + display_shortcuts_action = QAction("&Show napari-dlc shortcuts", self) + display_shortcuts_action.triggered.connect(self.display_shortcuts) + self.viewer.window.help_menu.addAction(display_shortcuts_action) + # Install global keybinds hook + install_global_points_keybindings() - def _extract_single_frame(self, *args): - ok, msg = run_extract_current_frame( - self.viewer, - self._video_group, - validate_points_layer=self._validate_header, - ) - self.viewer.status = msg - self._refresh_video_panel_context() + # Add debug action to generate a log report for troubleshooting + show_debug_action = QAction("&Generate napari-dlc log", self) + show_debug_action.setToolTip("Show a debug report with recent plugin logs") + 
show_debug_action.triggered.connect(self._show_debug_window) + self.viewer.window.help_menu.addAction(show_debug_action) - def _on_apply_crop_toggled(self, checked) -> None: - handle_apply_crop_toggled(self.viewer, self._video_group, bool(checked)) - self._refresh_video_panel_context() + def _sync_dropdown_visibility(self) -> None: + active = self.viewer.layers.selection.active + menu_idx = -1 + if active is not None and isinstance(active, Points): + menu_idx = self._layer_to_menu.get(active, -1) - def _store_crop_coordinates(self, *args): - ok, msg, project_path = run_store_crop_coordinates( - self.viewer, - self._video_group, - explicit_project_path=self._project_path, - fallback_video_name=self._image_meta.name, - ) - if project_path is not None: - self._project_path = project_path - self.viewer.status = msg - self._refresh_video_panel_context() + for idx, menu in enumerate(self._menus): + menu.setHidden(idx != menu_idx) def _form_dropdown_menus(self, store): menu = KeypointsDropdownMenu(store) @@ -1406,6 +1084,8 @@ def _form_dropdown_menus(self, store): layout.addWidget(menu) self._layout.addLayout(layout) + self._sync_dropdown_visibility() + def _form_mode_radio_buttons(self): group_box = QGroupBox("Labeling mode") layout = QHBoxLayout() @@ -1443,6 +1123,52 @@ def _func(): group.buttonClicked.connect(_func) return group_box, group + def _form_help_buttons(self): + layout = QVBoxLayout() + help_buttons_layout = QHBoxLayout() + self.show_shortcuts_btn = QPushButton("View shortcuts") + self.show_shortcuts_btn.clicked.connect(self.display_shortcuts) + help_buttons_layout.addWidget(self.show_shortcuts_btn) + self.tutorial_btn = QPushButton("Start tutorial") + self.tutorial_btn.clicked.connect(self.start_tutorial) + help_buttons_layout.addWidget(self.tutorial_btn) + layout.addLayout(help_buttons_layout) + self._keypoint_mapping_button = QPushButton("Load superkeypoints diagram") + self._load_superkeypoints_action = self._keypoint_mapping_button.clicked.connect( + self.load_superkeypoints_diagram + ) + self._keypoint_mapping_button.hide() + layout.addWidget(self._keypoint_mapping_button) + return layout + + def _refresh_video_panel_context(self) -> None: + update_video_panel_context(self.viewer, self._video_group) + + def _extract_single_frame(self, *args): + ok, msg = run_extract_current_frame( + self.viewer, + self._video_group, + validate_points_layer=self.layer_manager.validate_header, + ) + self.viewer.status = msg + self._refresh_video_panel_context() + + def _on_apply_crop_toggled(self, checked) -> None: + handle_apply_crop_toggled(self.viewer, self._video_group, bool(checked)) + self._refresh_video_panel_context() + + def _store_crop_coordinates(self, *args): + ok, msg, project_path = run_store_crop_coordinates( + self.viewer, + self._video_group, + explicit_project_path=self.layer_manager.project_path, + fallback_video_name=self.layer_manager.image_meta.name, + ) + if project_path is not None: + self.layer_manager.project_path = project_path + self.viewer.status = msg + self._refresh_video_panel_context() + def _update_color_scheme(self): if hasattr(self, "_color_scheme_panel"): self._color_scheme_panel.schedule_update() @@ -1494,312 +1220,9 @@ def _apply_points_coloring_from_metadata(self, layer: Points) -> None: except Exception: pass - def _remap_frame_indices(self, layer): - """ - Best-effort remap of time/frame indices in non-Image layers to match current Image order. - - Safety principles - ----------------- - - Never delete or silently corrupt user data. 
- - Only write back to layer.data after a remap has been accepted as safe. - - Always sync non-path image metadata when possible. - - Do NOT replace metadata["paths"] unless remap is accepted as safe. - - Specifically reject ambiguous basename-only remaps (depth=1 with duplicate / - non-bijective warnings), which commonly happen when data are moved out of the - standard DLC labeled-data layout. - """ - try: - new_paths = self._image_meta.paths - if not new_paths: - return - - if layer.metadata is None: - layer.metadata = {} - - md = layer.metadata - old_paths = md.get("paths") or [] - - # Always sync safe non-path metadata from image meta. - # Do NOT sync "paths" yet; that is only safe after we decide remap is acceptable. - try: - safe_image_meta = self._image_meta.model_dump(exclude_none=True) - safe_image_meta.pop("paths", None) - layer.metadata.update(safe_image_meta) - except Exception: - logger.debug( - "Failed to sync non-path image metadata for layer=%r", - getattr(layer, "name", str(layer)), - exc_info=True, - ) - - if not old_paths: - logger.debug( - "Skipping remap for layer=%r: no existing layer metadata paths.", - getattr(layer, "name", str(layer)), - ) - return - - # Determine time column (napari-specific choice) - time_col = 1 if isinstance(layer, Tracks) else 0 - - if logger.isEnabledFor(logging.DEBUG): - arr_before = np.asarray(layer.data) - logger.debug( - "Remap start layer=%r old_paths_len=%s new_paths_len=%s data_shape=%s frame_min=%s frame_max=%s", - getattr(layer, "name", str(layer)), - len(old_paths), - len(new_paths or []), - getattr(arr_before, "shape", None), - int(np.nanmin(arr_before[:, time_col])) if arr_before.size else None, - int(np.nanmax(arr_before[:, time_col])) if arr_before.size else None, - ) - - res = remap_layer_data_by_paths( - data=layer.data, - old_paths=old_paths, - new_paths=new_paths, - time_col=time_col, - policy=PathMatchPolicy.ORDERED_DEPTHS, - ) - - logger.debug( - "Remap result layer=%r changed=%s mapped_count=%s depth=%s message=%s warnings=%s", - getattr(layer, "name", str(layer)), - res.changed, - res.mapped_count, - res.depth_used, - res.message, - res.warnings, - ) - - if res.applied and res.data is not None: - layer.data = res.data - - if res.accept_paths_update: - layer.metadata["paths"] = list(new_paths) - if isinstance(layer, Points): - mark_layer_presentation_changed(layer) - - # Final debug logging - if res.depth_used is None: - logger.debug("Remap skipped for %s: %s", getattr(layer, "name", str(layer)), res.message) - else: - logger.debug( - "Remap %s for %s (depth=%s, mapped=%s): %s", - "applied" if res.changed else "accepted-noop", - getattr(layer, "name", str(layer)), - res.depth_used, - res.mapped_count, - res.message, - ) - - except Exception: - logger.exception("Failed to remap frame indices for layer %s", getattr(layer, "name", str(layer))) - return - - def on_insert(self, event): - layer = event.source[-1] - logger.debug( - "on_insert layer=%r type=%s index=%s", - getattr(layer, "name", layer), - type(layer).__name__, - getattr(event, "index", None), - ) - if isinstance(layer, Image): - self._setup_image_layer(layer, event.index, reorder=True) - elif isinstance(layer, Points): - self._setup_points_layer(layer, allow_merge=True) - - for layer_ in self.viewer.layers: - if not isinstance(layer_, Image): - self._remap_frame_indices(layer_) - self._refresh_video_panel_context() - self._refresh_layer_status_panel() - - def on_remove(self, event): - layer = event.value - n_points_layer = sum(isinstance(l, Points) for l in 
self.viewer.layers) - - if isinstance(layer, Points): - self._stores.pop(layer, None) - - # Refresh color scheme panel regardless; it will clear itself if no valid target remains. - self._update_color_scheme() - self._trails_controller.on_points_layer_removed(layer) - - if n_points_layer == 0: - while self._menus: - menu = self._menus.pop() - self._layout.removeWidget(menu) - menu.deleteLater() - menu.destroy() - - self._layer_to_menu = {} - self._trail_cb.setEnabled(False) - self._show_traj_plot_cb.setEnabled(False) - self.last_saved_label.hide() - - elif isinstance(layer, Image): - self._image_meta = ImageMetadata() - paths = layer.metadata.get("paths") - if paths is None: - self.video_widget.setVisible(False) - - elif isinstance(layer, Tracks): - was_trails = self._trails_controller.on_tracks_layer_removed(layer) - if was_trails: - with QSignalBlocker(self._trail_cb): - self._trail_cb.setChecked(False) - - self._refresh_video_panel_context() - self._refresh_layer_status_panel() - def _on_show_trails_toggled(self, state): self._trails_controller.toggle(Qt.CheckState(state) == Qt.CheckState.Checked) - def _ensure_promotion_save_target(self, layer: Points) -> bool: - """Ensure a prediction/machine source layer has a GT save_target set. - - Returns True if save_target is set (or already existed), False if user cancels. - """ - if not is_machine_layer(layer): - return True - - mig = migrate_points_layer_metadata(layer) - if hasattr(mig, "errors"): - logger.warning( - "Failed to migrate points layer metadata for layer=%r: %s", - getattr(layer, "name", layer), - mig, - ) - - res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) - if isinstance(res, ValidationError): - logger.warning( - "Points metadata validation failed for layer=%r during save target check: %s", - getattr(layer, "name", layer), - res, - ) - QMessageBox.warning(self, "Cannot save", "Layer metadata is invalid; see logs for details.") - return False - - pts: PointsMetadata = res - - if not requires_gt_promotion(pts): - return True - - anchor = safe_folder_anchor_from_points_layer(layer) - if not anchor: - QMessageBox.warning(self, "Cannot save", "Could not determine a folder anchor for saving.") - return False - - scorer = None - - # 1) Auto-discovered config.yaml always wins - cfg_path = None - try: - cfg_path = find_nearest_config(anchor) - except Exception: - logger.debug("Automatic config discovery failed for anchor=%r", anchor, exc_info=True) - - if cfg_path: - try: - scorer = ui_dialogs.load_scorer_from_config(cfg_path) - except Exception: - logger.exception("Failed to load auto-discovered config.yaml: %s", cfg_path) - ui_dialogs.warn_invalid_config_for_scorer( - self, - config_path=cfg_path, - reason="unreadable", - auto_found=True, - ) - return False - - if not scorer: - ui_dialogs.warn_invalid_config_for_scorer( - self, - config_path=cfg_path, - reason="missing_scorer", - auto_found=True, - ) - return False - - else: - # 2) No config found automatically -> let the user choose one - dialog_result = ui_dialogs.prompt_for_project_config_for_save( - self, - initial_dir=self._project_path or anchor, - window_title="Locate DLC config for scorer resolution", - message=( - "No DeepLabCut config.yaml could be found automatically for this machine-labeled layer.\n\n" - "If this layer belongs to a DLC project, choose its config.yaml so the save uses the " - "project scorer and standard naming.\n\n" - "If no config.yaml exists, you can continue without one." 
- ), - choose_button_text="Choose config.yaml", - skip_button_text="Continue without config", - resolve_scorer=True, - ) - - if dialog_result.action is ui_dialogs.ProjectConfigPromptAction.CANCEL: - return False - - if dialog_result.action is ui_dialogs.ProjectConfigPromptAction.ASSOCIATE: - scorer = dialog_result.scorer - - else: - # 3) Only if no config is available at all may sidecar be consulted - scorer = get_default_scorer(anchor) - - # 4) Final fallback: prompt manually - if not scorer: - suggested = suggest_human_placeholder(anchor) - while True: - s = _prompt_for_scorer(self, anchor=anchor, suggested=suggested) - if s is None: - return False - if s.startswith("human_"): - choice = QMessageBox.question( - self, - "Generic scorer name", - "You entered a generic scorer name starting with 'human_'.\n\n" - "We strongly recommend using a real name or stable identifier.\n" - "Do you want to keep this generic scorer anyway?", - QMessageBox.Yes | QMessageBox.No, - ) - if choice == QMessageBox.No: - suggested = s - continue - scorer = s - break - try: - set_default_scorer(anchor, scorer) - except Exception: - logger.debug("Failed to persist default scorer to sidecar", exc_info=True) - - updated = apply_gt_save_target( - pts, - anchor=anchor, - scorer=scorer, - dataset_key="keypoints", - ) - - out = write_points_meta( - layer, - updated, - merge_policy=MergePolicy.MERGE, - fields={"save_target"}, - migrate_legacy=True, - validate=True, - ) - - if hasattr(out, "errors"): - logger.warning("Failed to write save_target for layer=%r: %s", getattr(layer, "name", layer), out) - QMessageBox.warning(self, "Cannot save", "Failed to write save target metadata; see logs for details.") - return False - - return True - def _toggle_overwrite_confirmation(self, state) -> None: enabled = Qt.CheckState(state) == Qt.CheckState.Checked settings.set_overwrite_confirmation_enabled(enabled) @@ -1807,120 +1230,19 @@ def _toggle_overwrite_confirmation(self, state) -> None: # Hack to save a KeyPoints layer without showing the Save dialog def _save_layers_dialog(self, selected=False): - """Save layers (all or selected) to disk, using ``LayerList.save()``. - Parameters - ---------- - selected : bool - If True, only layers that are selected in the viewer will be saved. - By default, all layers are saved. - """ - - selected_layers = list(self.viewer.layers.selection) - msg = "" - if not len(self.viewer.layers): - msg = "There are no layers in the viewer to save." - elif selected and not len(selected_layers): - msg = "Please select a Points layer to save." - if msg: - QMessageBox.warning(self, "Nothing to save", msg, QMessageBox.Ok) + """Save layers (all or selected) using the dedicated save workflow.""" + outcome = self._save_workflow.save_layers(selected=selected) + if not outcome.saved: return - if len(selected_layers) == 1 and isinstance(selected_layers[0], Points): - layer = selected_layers[0] - # Promotion-to-GT policy: never write back to machine/prediction sources. - ok = self._ensure_promotion_save_target(layer) - if not ok: - return - - logger.debug( - "About to save. 
io.kind=%r save_target=%r", - layer.metadata.get("io", {}).get("kind"), - layer.metadata.get("save_target"), - ) - try: - overridden_metadata, abort_save = self._maybe_prepare_project_path_override_metadata(layer) - if abort_save: - logger.debug("Save aborted during project-association path handling.") - return - - attributes = { - "name": layer.name, - "metadata": overridden_metadata if overridden_metadata is not None else dict(layer.metadata or {}), - "properties": dict(layer.properties or {}), - } - report = compute_overwrite_report_for_points_save(layer.data, attributes) - except Exception as e: - logger.exception("Failed to compute overwrite preflight for layer %r", getattr(layer, "name", layer)) - QMessageBox.warning( - self, - "Cannot save", - f"Failed to prepare save preflight:\n{e}", - QMessageBox.Ok, - ) - return - - if report is not None: - if not ui_dialogs.maybe_confirm_overwrite( - parent=self, - report=report, - ): - logger.debug("Save cancelled.") - return - - if overridden_metadata is not None: - with _temporary_layer_metadata(layer, overridden_metadata): - self.viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") - # Persist the successful override into live metadata after save - layer.metadata = dict(overridden_metadata) - else: - self.viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") - # hook to persist UI state on successful save - try: - self._trails_controller.persist_folder_ui_state_for_points_layer( - layer, - checkbox_checked=self._trail_cb.isChecked(), - ) - except Exception: - logger.debug( - "Failed to persist folder UI state after save for layer=%r", - getattr(layer, "name", layer), - exc_info=True, - ) - self.viewer.status = "Data successfully saved" - else: - dlg = QFileDialog() - hist = get_save_history() - dlg.setHistory(hist) - filename, _ = dlg.getSaveFileName( - caption=f"Save {'selected' if selected else 'all'} layers", - dir=hist[0], # home dir by default - ) - if filename: - self.viewer.layers.save(filename, selected=selected) - # hook to persist UI state on successful save - try: - if selected: - candidate_layers = [ly for ly in selected_layers if isinstance(ly, Points)] - else: - candidate_layers = list(self._stores.keys()) - - for ly in candidate_layers: - if ly in self.viewer.layers: - self._trails_controller.persist_folder_ui_state_for_points_layer( - ly, - checkbox_checked=self._trail_cb.isChecked(), - ) - except Exception: - logger.debug("Failed to persist sidecar UI state after multi-layer save", exc_info=True) - - else: - return self._is_saved = True + if outcome.status_message: + self.viewer.status = outcome.status_message self.last_saved_label.setText(f"Last saved at {str(datetime.now().time()).split('.')[0]}") self.last_saved_label.show() def on_close(self, event): - if self._stores and not self._is_saved: + if self.layer_manager.has_managed_points() and not self._is_saved: choice = QMessageBox.warning( self, "Warning", @@ -1933,6 +1255,15 @@ def on_close(self, event): event.ignore() else: event.accept() + cleared = self.layer_manager.clear_dead_entries(log=True) + if cleared: + logger.debug("Cleared %d dead entries from layer manager on close", len(cleared)) + report = self.layer_manager.audit_registry() + if report.issues: + logger.warning("Layer manager audit on close reported issues:\n%s", report.issues) + + if self.layer_manager is not None: + self.layer_manager.set_merge_decision_provider(None) def on_active_layer_change(self, event) -> None: """Updates the GUI when the active layer 
changes @@ -1940,20 +1271,15 @@ def on_active_layer_change(self, event) -> None: * Sets the visibility of the "Color mode" box to True if the selected layer is a multi-animal one, or False otherwise """ - self._color_grp.setVisible(self._is_multianimal(event.value)) - # self._update_color_scheme() # if needed - menu_idx = -1 - if event.value is not None and isinstance(event.value, Points): - menu_idx = self._layer_to_menu.get(event.value, -1) - - for idx, menu in enumerate(self._menus): - if idx == menu_idx: - menu.setHidden(False) - else: - menu.setHidden(True) + with log_timing( + logger, f"on_active_layer_change value={getattr(event.value, 'name', None)!r}", threshold_ms=0.0 + ): + self._color_grp.setVisible(self.layer_manager.is_multianimal(event.value)) + # self._update_color_scheme() # if needed + self._sync_dropdown_visibility() - self._refresh_video_panel_context() - self._refresh_layer_status_panel() + self._refresh_video_panel_context() + self._refresh_layer_status_panel() def _update_colormap(self, colormap_name: str): for layer in self.viewer.layers.selection: diff --git a/src/napari_deeplabcut/config/_autostart.py b/src/napari_deeplabcut/config/_autostart.py index 463561d9..017e5332 100644 --- a/src/napari_deeplabcut/config/_autostart.py +++ b/src/napari_deeplabcut/config/_autostart.py @@ -8,9 +8,9 @@ from napari.utils.events import Event from qtpy.QtCore import QTimer -from napari_deeplabcut._widgets import KeypointControls from napari_deeplabcut.config.settings import get_auto_open_keypoint_controls from napari_deeplabcut.core.metadata import read_points_meta +from napari_deeplabcut.widget_factory import get_existing_keypoint_controls logger = logging.getLogger(__name__) @@ -29,13 +29,6 @@ def _is_dlc_points_layer(layer) -> bool: return res.header is not None -def get_existing_keypoint_controls(viewer): - for widget in viewer.window.dock_widgets.values(): - if isinstance(widget, KeypointControls): - return widget - return None - - def _ensure_keypoint_controls_open(viewer) -> None: """Open Keypoint controls dock widget if enabled in settings.""" if viewer is None or not get_auto_open_keypoint_controls(): diff --git a/src/napari_deeplabcut/config/keybinds.py b/src/napari_deeplabcut/config/keybinds.py index 765a6974..d9433823 100644 --- a/src/napari_deeplabcut/config/keybinds.py +++ b/src/napari_deeplabcut/config/keybinds.py @@ -10,6 +10,8 @@ import numpy as np from napari.layers import Points +from .settings import TRACKING_SHORTCUTS_ENABLED + _global_points_bindings_installed = False @@ -118,9 +120,83 @@ def _jump_unlabeled_frame(ctx: BindingContext): ), ) +# -------------------------------- +# Tracking shortcuts +# -------------------------------- + + +@dataclass(frozen=True) +class TrackingKeybindConfig: + key: str + tooltip: str + + def get_display(self) -> str: + txt = self.tooltip + if TRACKING_SHORTCUTS_ENABLED: + txt += f" ({self.key})" + return txt + + +TRACK_FORWARD = TrackingKeybindConfig(key="l", tooltip="Track forward") +TRACK_FORWARD_END = TrackingKeybindConfig(key="k", tooltip="Track forward to end") +TRACK_BACKWARD = TrackingKeybindConfig(key="h", tooltip="Track backward") +TRACK_BACKWARD_END = TrackingKeybindConfig(key="j", tooltip="Track backward to start") +MOVE_FORWARD_FRAME = TrackingKeybindConfig(key="i", tooltip="Move forward one frame") +MOVE_BACKWARD_FRAME = TrackingKeybindConfig(key="u", tooltip="Move backward one frame") + + +TRACKING_SHORTCUTS: tuple[ShortcutSpec, ...] 
= ( + ShortcutSpec( + keys=(TRACK_FORWARD.key,), + description=TRACK_FORWARD.tooltip, + group="Tracking", + scope="tracking-points-layer", + when="Tracking widget is open", + ), + ShortcutSpec( + keys=(TRACK_FORWARD_END.key,), + description=TRACK_FORWARD_END.tooltip, + group="Tracking", + scope="tracking-points-layer", + when="Tracking widget is open", + ), + ShortcutSpec( + keys=(TRACK_BACKWARD.key,), + description=TRACK_BACKWARD.tooltip, + group="Tracking", + scope="tracking-points-layer", + when="Tracking widget is open", + ), + ShortcutSpec( + keys=(TRACK_BACKWARD_END.key,), + description=TRACK_BACKWARD_END.tooltip, + group="Tracking", + scope="tracking-points-layer", + when="Tracking widget is open", + ), + ShortcutSpec( + keys=(MOVE_FORWARD_FRAME.key,), + description=MOVE_FORWARD_FRAME.tooltip, + group="Tracking", + scope="tracking-points-layer", + when="Tracking widget is open", + ), + ShortcutSpec( + keys=(MOVE_BACKWARD_FRAME.key,), + description=MOVE_BACKWARD_FRAME.tooltip, + group="Tracking", + scope="tracking-points-layer", + when="Tracking widget is open", + ), +) + +# ----- Keybind registry functions ------ + def iter_shortcuts() -> Iterable[ShortcutSpec]: - return SHORTCUTS + yield from SHORTCUTS + if TRACKING_SHORTCUTS_ENABLED: + yield from TRACKING_SHORTCUTS def _bind_each_key(layer: Points, keys: tuple[str, ...], callback, *, overwrite: bool = False) -> None: diff --git a/src/napari_deeplabcut/config/models.py b/src/napari_deeplabcut/config/models.py index 36b6b24f..0e6269af 100644 --- a/src/napari_deeplabcut/config/models.py +++ b/src/napari_deeplabcut/config/models.py @@ -313,9 +313,14 @@ def with_scorer(self, scorer: str) -> DLCHeaderModel: Replaces legacy `header.scorer = ...`. """ - canon = self._canonical_4() - new_cols = [(str(scorer), ind, bp, coord) for _, ind, bp, coord in canon] - return self.model_copy(update={"columns": new_cols, "names": ["scorer", "individuals", "bodyparts", "coords"]}) + new_cols = [] + for col in self.columns: + t = tuple(map(str, col)) + if len(t) >= 1: + new_cols.append((str(scorer), *t[1:])) + else: + new_cols.append((str(scorer),)) + return self.model_copy(update={"columns": new_cols, "names": self.names}) def form_individual_bodypart_pairs(self) -> list[tuple[str, str]]: """ @@ -430,7 +435,7 @@ class IOProvenance(BaseModel): description="Project-relative POSIX path to the source .h5 (forward slashes).", ) kind: AnnotationKind | None = Field(default=None, description="Annotation kind for routing", strict=True) - dataset_key: str = Field(default="keypoints", description="HDF5 key for keypoints table") + dataset_key: str = Field(default="df_with_missing", description="HDF5 key for keypoints table") @field_validator("source_relpath_posix") @classmethod diff --git a/src/napari_deeplabcut/config/settings.py b/src/napari_deeplabcut/config/settings.py index 8edbc833..5ade51a9 100644 --- a/src/napari_deeplabcut/config/settings.py +++ b/src/napari_deeplabcut/config/settings.py @@ -1,11 +1,18 @@ +import os + from qtpy.QtCore import QSettings +# Colormap settings DEFAULT_SINGLE_ANIMAL_CMAP = "rainbow" DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP = "Set3" +# UI settings _OVERWRITE_CONFIRM_ENABLED_KEY = "napari_deeplabcut/overwrite/confirm_enabled" AUTO_OPEN_KEYPOINT_CONTROLS_KEY = "napari_deeplabcut/ui/auto_open_keypoint_controls" +# Tracking settings +TRACKING_SHORTCUTS_ENABLED = os.environ.get("NAPARI_DLC_TRACKING_SHORTCUTS_ENABLED", "1") == "1" + def get_overwrite_confirmation_enabled() -> bool: """Return whether overwrite confirmation dialogs are 
enabled.""" diff --git a/src/napari_deeplabcut/core/conflicts.py b/src/napari_deeplabcut/core/conflicts.py index 41360df8..597f9b0f 100644 --- a/src/napari_deeplabcut/core/conflicts.py +++ b/src/napari_deeplabcut/core/conflicts.py @@ -9,6 +9,7 @@ from napari_deeplabcut.core import schemas as dlc_schemas from napari_deeplabcut.core.dataframes import set_df_scorer from napari_deeplabcut.core.errors import AmbiguousSaveError, MissingProvenanceError +from napari_deeplabcut.core.io import DLC_CANONICAL_H5_KEY from napari_deeplabcut.core.metadata import parse_points_metadata from napari_deeplabcut.core.project_paths import infer_dlc_project_from_points_meta from napari_deeplabcut.core.provenance import ( @@ -143,7 +144,7 @@ def compute_overwrite_report_for_points_save( return None try: - df_old = pd.read_hdf(out, key="keypoints") + df_old = pd.read_hdf(out, key=DLC_CANONICAL_H5_KEY) except (KeyError, ValueError): df_old = pd.read_hdf(out) @@ -193,7 +194,7 @@ def compute_overwrite_report_for_extracted_labels_row( return None try: - df_old = pd.read_hdf(out, key="df_with_missing") + df_old = pd.read_hdf(out, key=DLC_CANONICAL_H5_KEY) except (KeyError, ValueError): df_old = pd.read_hdf(out) diff --git a/src/napari_deeplabcut/core/dataframes.py b/src/napari_deeplabcut/core/dataframes.py index 8a0fc142..fb65a482 100644 --- a/src/napari_deeplabcut/core/dataframes.py +++ b/src/napari_deeplabcut/core/dataframes.py @@ -8,11 +8,100 @@ import pandas as pd from napari_deeplabcut.config.models import ConflictEntry, OverwriteConflictReport -from napari_deeplabcut.core.schemas import PointsWriteInputModel +from napari_deeplabcut.core.schemas import DLCHeaderModel, PointsWriteInputModel logger = logging.getLogger(__name__) +def _is_logical_single_animal_header( + header: DLCHeaderModel | None, + *, + is_ma_project: bool | None = None, +) -> bool: + """ + Determine whether the target should be written in canonical single-animal format. + + Priority + -------- + 1) Explicit project-mode signal from DLC config (is_ma_project) + 2) Original header shape + 3) Blank-individuals fallback + """ + if is_ma_project is not None: + return not bool(is_ma_project) + + if header is None: + return False + + # Canonical/original SA header + if header.nlevels == 3: + return True + + # Fallback for normalized 4-level headers that still represent SA + if header.nlevels == 4: + inds = [str(i) for i in header.individuals if str(i) != ""] + return len(inds) == 0 + + return False + + +def restore_dlc_on_disk_header_shape( + df: pd.DataFrame, header: DLCHeaderModel, *, is_ma_project: bool | None = None +) -> pd.DataFrame: + """ + Args: + df: DataFrame with arbitrary column structure produced from napari Points + metadata. + header: Authoritative DLCHeaderModel that defines the expected column structure on disk. + is_ma_from_config: Optional boolean indicating if the project is multi-animal based on DLC config. + + Restore the DataFrame column structure to match the authoritative DLC header + that should be used on disk. + + - If the logical target is single-animal, collapse an empty 'individuals' level. + - Reindex to the authoritative header order. + """ + if not isinstance(df.columns, pd.MultiIndex): + return df + + df_out = df.copy() + + if _is_logical_single_animal_header(header, is_ma_project=is_ma_project): + # If the normalized dataframe has an empty individuals level, collapse it. 
+        if df_out.columns.nlevels == 4 and "individuals" in (df_out.columns.names or []):
+            inds = pd.Index(df_out.columns.get_level_values("individuals")).astype(str)
+            non_empty_inds = {x for x in inds if x != ""}
+            if non_empty_inds:
+                raise ValueError(
+                    "Refusing to write single-animal format because dataframe contains "
+                    f"non-empty individuals: {sorted(non_empty_inds)}"
+                )
+
+            df_out = df_out.droplevel("individuals", axis=1)
+            df_out.columns = df_out.columns.set_names(["scorer", "bodyparts", "coords"])
+
+        # Reindex to the original authoritative 3-level header order
+        try:
+            # For an SA header, header.columns is the original 3-level tuples
+            target_cols = pd.MultiIndex.from_tuples(
+                header.columns,
+                names=header.names or ["scorer", "bodyparts", "coords"],
+            )
+            df_out = df_out.reindex(target_cols, axis=1)
+        except Exception:
+            logger.debug("Could not reindex collapsed SA dataframe to authoritative header", exc_info=True)
+
+        return df_out
+
+    # Multi-animal: use canonical 4-level ordering
+    try:
+        target_cols = header.as_multiindex()
+        df_out = df_out.reindex(target_cols, axis=1)
+    except Exception:
+        logger.debug("Could not reindex MA dataframe to authoritative header", exc_info=True)
+
+    return df_out
+
+
 def set_df_scorer(df: pd.DataFrame, scorer: str) -> pd.DataFrame:
     """Return df with scorer level set to the given scorer (if present)."""
     scorer = (scorer or "").strip()
@@ -124,6 +213,9 @@ def guarantee_multiindex_rows(df: pd.DataFrame) -> None:
     Legacy DLC data may use an index with pathto/video/file.png strings as Index.
     The new format uses a MultiIndex with each path component as a level.
     """
+    if len(df.index) == 0:
+        return
+
     # Make paths platform-agnostic if they are not already
     if not isinstance(df.index, pd.MultiIndex):  # Backwards compatibility
         path = df.index[0]
diff --git a/src/napari_deeplabcut/core/io.py b/src/napari_deeplabcut/core/io.py
index d7fa59d8..3f98793a 100644
--- a/src/napari_deeplabcut/core/io.py
+++ b/src/napari_deeplabcut/core/io.py
@@ -43,6 +43,7 @@
     harmonize_keypoint_column_index,
     harmonize_keypoint_row_index,
     merge_multiple_scorers,
+    restore_dlc_on_disk_header_shape,
     set_df_scorer,
 )
 from napari_deeplabcut.core.errors import AmbiguousSaveError, MissingProvenanceError
@@ -62,6 +63,7 @@
 # -----------------------------------------------------------------------------
 _GLOB_MAGIC = set("*?[")
 _SUPPORTED_SUFFIXES = {ext.lower() for ext in SUPPORTED_IMAGES}
+DLC_CANONICAL_H5_KEY = "df_with_missing"  # TODO: use this constant instead of the string literal everywhere


 def _has_glob_magic(name: str) -> bool:
@@ -116,6 +118,26 @@ def write_config(config_path: str | Path, params: dict[str, Any]) -> None:
 # and attaches provenance via attach_source_and_io_to_layer_kwargs.


+def _read_hdf_any_key(file: Path) -> pd.DataFrame:
+    """Read an HDF file without knowing the key in advance, trying common DLC keys first."""
+    file = str(file)
+    try:
+        return pd.read_hdf(file, key=DLC_CANONICAL_H5_KEY)
+    except (KeyError, ValueError):
+        logger.warning(f"Key '{DLC_CANONICAL_H5_KEY}' not found in {file}. Trying fallback keys.")
+        fallback_keys = ["keypoints"]
+        for k in fallback_keys:
+            try:
+                return pd.read_hdf(file, key=k)
+            except (KeyError, ValueError):
+                logger.warning(f"Key '{k}' not found in {file}.")
+        logger.warning(
+            f"None of the expected keys {fallback_keys + [DLC_CANONICAL_H5_KEY]} were found in {file}. "
+            "Falling back to default read_hdf, which may raise its own error if no valid key is found."
+        )
+        return pd.read_hdf(file)  # Let pandas guess instead
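+
+# Illustrative round-trip of the key fallback (file name is hypothetical):
+#     df = pd.DataFrame({"x": [1.0]})
+#     df.to_hdf("labels.h5", key="keypoints")
+#     _read_hdf_any_key(Path("labels.h5"))
+#     # the 'df_with_missing' attempt raises KeyError and is logged as a warning,
+#     # then the 'keypoints' fallback succeeds and the frame is returned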
+ ) + return pd.read_hdf(file) # Let pandas guess instead + + def read_hdf(filename: str) -> list[LayerData]: layers = [] for file in Path(filename).parent.glob(Path(filename).name): @@ -125,13 +147,16 @@ def read_hdf(filename: str) -> list[LayerData]: def read_hdf_single(file: Path, *, kind: AnnotationKind | None = None) -> list[LayerData]: """Read a single H5 file and attach provenance with optional explicit kind. + Dataset may be under keypoints/ or df_with_missing/ for compatibility with various DLC versions. + We use read_hdf without specifying a key to allow pandas to auto-detect the correct one. - Produces one Points layer per H5 file - Points.data contains only finite coordinates - Unlabeled keypoints are omitted from Points.data - Empty Points layers are valid """ - temp = pd.read_hdf(str(file)) + # temp = pd.read_hdf(str(file)) + temp = _read_hdf_any_key(file) temp = merge_multiple_scorers(temp) header = DLCHeaderModel(columns=temp.columns) temp = temp.droplevel("scorer", axis=1) @@ -288,7 +313,48 @@ def form_df( return df -def _atomic_to_hdf(df: pd.DataFrame, out_path: Path, key: str = "keypoints") -> None: +def _resolve_multianimalproject_for_write( + *, + out_path: Path, + pts_meta: PointsMetadata, +) -> bool | None: + """ + Best-effort resolve of DLC config multianimalproject flag for save. + + Resolution order + ---------------- + 1) pts_meta.project / explicit project context if available + 2) config near output path + 3) config near pts_meta.root + 4) None if not resolvable + """ + candidates: list[Path] = [] + + project = getattr(pts_meta, "project", None) + if project: + candidates.append(Path(project) / "config.yaml") + + candidates.append(out_path.parent) + + root = getattr(pts_meta, "root", None) + if root: + candidates.append(Path(root)) + + for candidate in candidates: + try: + cfg_path = find_nearest_config(candidate, max_levels=3) + cfg = load_config(str(cfg_path)) + if isinstance(cfg, dict) and "multianimalproject" in cfg: + logger.debug("Resolved multianimalproject=True from config at %s", cfg_path) + return bool(cfg.get("multianimalproject", False)) + except Exception: + continue + + logger.debug("Could not resolve multianimalproject flag from any candidate configs.") + return None + + +def _atomic_to_hdf(df: pd.DataFrame, out_path: Path, key: str = DLC_CANONICAL_H5_KEY) -> None: """Best-effort atomic write: write to temp and replace.""" out_path.parent.mkdir(parents=True, exist_ok=True) tmp = out_path.with_suffix(out_path.suffix + ".tmp") @@ -351,6 +417,7 @@ def writer(path: str, data: Any, attributes: dict) -> List[str] # If promoting to GT and scorer is known, rewrite scorer level if target_scorer: df_new = set_df_scorer(df_new, target_scorer) + header_for_write = pts_meta.header.with_scorer(target_scorer) if target_scorer else pts_meta.header # Never write back to machine sources without an explicit promotion target if not out_path and source_kind == AnnotationKind.MACHINE: @@ -404,10 +471,7 @@ def writer(path: str, data: Any, attributes: dict) -> List[str] # Merge-on-save for GT if destination_kind == AnnotationKind.GT and out.exists(): - try: - df_old = pd.read_hdf(out, key="keypoints") - except (KeyError, ValueError): - df_old = pd.read_hdf(out) + df_old = _read_hdf_any_key(out) # Harmonize indices and merge try: @@ -420,25 +484,33 @@ def writer(path: str, data: Any, attributes: dict) -> List[str] df_new = harmonize_keypoint_column_index(df_new) df_old = harmonize_keypoint_column_index(df_old) df_out = df_new.combine_first(df_old) - - # Normalize columns 
to DLC header if possible - try: - header = DLCHeaderModel(columns=df_out.columns) - df_out = df_out.reindex(header.columns, axis=1) - except Exception: - pass else: df_out = df_new - # Final cleanup + # Final cleanup of rows try: guarantee_multiindex_rows(df_out) except Exception: pass df_out.sort_index(inplace=True) + # Restore canonical on-disk DLC header shape from the header. + is_ma_project = _resolve_multianimalproject_for_write( + out_path=out, + pts_meta=pts_meta, + ) + df_out = restore_dlc_on_disk_header_shape(df_out, header_for_write, is_ma_project=is_ma_project) + + logger.debug("FINAL WRITE columns nlevels: %s", getattr(df_out.columns, "nlevels", None)) + logger.debug("FINAL WRITE columns names: %s", getattr(df_out.columns, "names", None)) + if isinstance(df_out.columns, pd.MultiIndex) and "individuals" in (df_out.columns.names or []): + logger.debug( + "FINAL WRITE individuals values: %s", + list(dict.fromkeys(df_out.columns.get_level_values("individuals").astype(str))), + ) + # Write .h5 and .csv - _atomic_to_hdf(df_out, out, key="keypoints") + _atomic_to_hdf(df_out, out, key=DLC_CANONICAL_H5_KEY) csv_path = out.with_suffix(".csv") df_out.to_csv(csv_path) @@ -661,65 +733,44 @@ def make_delayed_array(fp: Path, first_shape: tuple[int, ...], first_dtype: np.d ) from e -# Read images from a list of files or a glob/string path -def read_images(path: str | Path | list[str | Path]): - """Reads one or multiple images and returns a Napari Image layer. - - Uses `_expand_image_paths` to resolve the input into a list of valid - image files. Supports single paths, glob expressions, directories, - and lists or tuples of such paths. - - Behavior: - * If one file is found: - - Loaded using `dask_image.imread.imread`. - * If multiple files are found: - - Loaded lazily using `lazy_imread` into a stacked image - layer. +def _build_image_layer_kwargs( + *, + filepaths: list[Path], + dlc_meta: dict | None = None, + name: str = "images", +) -> dict[str, Any]: + metadata = { + "paths": [canonicalize_path(fp, 3) for fp in filepaths], + "root": str(filepaths[0].parent), + } + if dlc_meta is not None: + metadata["dlc"] = dlc_meta - Args: - path (str | Path | list[str | Path]): - Input path(s), directory, or glob pattern(s) to expand into - supported image files. + return { + "name": name, + "metadata": metadata, + } - Returns: - list[LayerData]: - A list containing one Napari layer tuple of the form - `(data, metadata, "image")`. - Raises: - OSError: If no supported images are found after expansion. - """ +# Read images from a list of files or a glob/string path +def read_images( + path: str | Path | list[str | Path], + *, + dlc_meta: dict | None = None, +) -> list[LayerData]: filepaths = _expand_image_paths(path) - if not filepaths: raise OSError(f"No supported images were found in {path}") filepaths = natsorted(filepaths, key=str) + kwargs = _build_image_layer_kwargs(filepaths=filepaths, dlc_meta=dlc_meta, name="images") - # Multiple images → lazy-imread stack if len(filepaths) > 1: - # NOTE: canonicalize_path(fp, 3) stores a stable relative-ish path for the UI/metadata. 
- relative_paths = [canonicalize_path(fp, 3) for fp in filepaths] - params = { - "name": "images", - "metadata": { - "paths": relative_paths, - "root": str(filepaths[0].parent), - }, - } data = _lazy_imread(filepaths, use_dask=True, stack=True) - return [(data, params, "image")] + else: + data = imread(str(filepaths[0])) - # Single image → old behavior - image_path = filepaths[0] - params = { - "name": "images", - "metadata": { - "paths": [canonicalize_path(image_path, 3)], - "root": str(image_path.parent), - }, - } - return [(imread(str(image_path)), params, "image")] + return [(data, kwargs, "image")] # ============================================================================= @@ -772,7 +823,7 @@ def close(self): self.stream.release() -def read_video(filename: str, opencv: bool = True): +def read_video(filename: str, opencv: bool = True, *, dlc_meta: dict | None = None): if opencv: stream = Video(filename) # NOTE construct output shape tuple in (H, W, C) order to match read_frame() data @@ -805,4 +856,7 @@ def _read_frame(ind): "root": root, }, } + if dlc_meta is not None: + params["metadata"]["dlc"] = dlc_meta + return [(movie, params, "image")] diff --git a/src/napari_deeplabcut/core/keypoints.py b/src/napari_deeplabcut/core/keypoints.py index ff91b27f..3a39a9eb 100644 --- a/src/napari_deeplabcut/core/keypoints.py +++ b/src/napari_deeplabcut/core/keypoints.py @@ -1,7 +1,8 @@ # src/napari_deeplabcut/keypoints.py import logging +import weakref from collections import namedtuple -from collections.abc import Sequence +from collections.abc import Callable, Sequence from enum import auto import numpy as np @@ -17,10 +18,15 @@ from napari_deeplabcut.config.models import DLCHeaderModel from napari_deeplabcut.core.metadata import read_points_meta from napari_deeplabcut.misc import CycleEnum, HeaderLike +from napari_deeplabcut.utils.deprecations import deprecated logger = logging.getLogger(__name__) +class LayerUnavailableError(RuntimeError): + """Raised when a KeypointStore can no longer resolve its backing layer.""" + + # Monkeypatch the point size slider def _change_size(self, value): """Resize all points at once regardless of the current selection.""" @@ -44,9 +50,9 @@ def _change_symbol(self, text): QtPointsControls.changeCurrentSymbol = _change_symbol +@deprecated(details="Unused currently, should be removed.") def _validate_points_meta_best_effort(layer) -> bool: """ - Phase-2 friendly: validate points metadata without mutating it. We drop header + controls during validation to avoid runtime-object issues. 
""" res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=True) @@ -106,21 +112,97 @@ def default(cls): class KeypointStore: - def __init__(self, viewer, layer: Points): + def __init__( + self, + viewer, + layer: Points, + *, + resolve_layer_by_id: Callable[[int], Points] | None = None, + ): self.viewer = viewer self._keypoints = [] self._header: DLCHeaderModel | None = None - self.layer = layer + + self._layer_id: int | None = None + self._resolve_layer_by_id = resolve_layer_by_id + + # Fallback if no resolver is provided + self._layer_ref: weakref.ReferenceType[Points] | None = None + self._strong_layer_ref: Points | None = None # Used to keep the layer alive if no resolver is provided + + self.layer = layer # Use the setter to initialize keypoints and header + self.viewer.dims.set_current_step(0, 0) + def set_label_mode_getter(self, getter: Callable[[], LabelMode]): + self._get_label_mode = getter + @property - def layer(self): - return self._layer + def layer_id(self) -> int | None: + return self._layer_id + + def attach_layer_resolver(self, resolve_layer_by_id: Callable[[int], Points | None]) -> None: + """Attach a narrow lifecycle-owned resolver. + + The resolver should accept a layer_id and return the currently live layer, + or None if the layer is no longer available. + """ + self._resolve_layer_by_id = resolve_layer_by_id + + def maybe_layer(self) -> Points | None: + """Resolve the current layer if still available, else return None.""" + if self._layer_id is None: + return None + + # Lifecycle resolver is authoritative when present. + if self._resolve_layer_by_id is not None: + try: + return self._resolve_layer_by_id(self._layer_id) + except Exception: + logger.debug("Layer resolver failed for layer_id=%r", self._layer_id, exc_info=True) + return None + + # Fallback for tests / legacy contexts. + if self._layer_ref is not None: + return self._layer_ref() + return self._strong_layer_ref + + def require_layer(self) -> Points: + layer = self.maybe_layer() + if layer is None: + raise LayerUnavailableError(f"Layer is no longer available for KeypointStore layer_id={self._layer_id}") + return layer + + @property + def layer(self) -> Points: + return self.require_layer() @layer.setter - def layer(self, layer): - self._layer = layer + def layer(self, layer: Points): + self._layer_id = id(layer) + try: + self._layer_ref = weakref.ref(layer) + self._strong_layer_ref = None + except TypeError: + # Fallback if a given object cannot be weak-referenced. + self._layer_ref = None + self._strong_layer_ref = layer + + # Avoid repeated validated metadata reads when rebinding the same live layer. + # NOTE: metadata/header may have changed even when layer is the same, + # therefore this is unsafe. 
+
+    @property
+    def layer(self) -> Points:
+        return self.require_layer()

     @layer.setter
-    def layer(self, layer):
-        self._layer = layer
+    def layer(self, layer: Points):
+        self._layer_id = id(layer)
+        try:
+            self._layer_ref = weakref.ref(layer)
+            self._strong_layer_ref = None
+        except TypeError:
+            # Fallback if a given object cannot be weak-referenced.
+            self._layer_ref = None
+            self._strong_layer_ref = layer
+
+        # NOTE: a "skip the refresh when rebinding the same live layer" optimization
+        # was considered here, but metadata/header may change even when the layer
+        # object is the same, so skipping is unsafe; doing it properly would require
+        # comparing header signatures.
+        self._refresh_header_from_layer(layer)
+
+    def _refresh_header_from_layer(self, layer: Points) -> None:
         res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False)
         if isinstance(res, ValidationError) or res.header is None:
             self._header = None
@@ -149,14 +231,16 @@ def n_steps(self):

     @property
     def annotated_keypoints(self) -> list[Keypoint]:
+        layer = self.layer
         mask = self.current_mask
-        labels = self.layer.properties["label"][mask]
-        ids = self.layer.properties["id"][mask]
+        labels = layer.properties["label"][mask]
+        ids = layer.properties["id"][mask]
         return [Keypoint(label, id_) for label, id_ in zip(labels, ids, strict=False)]

     @property
     def current_mask(self) -> Sequence[bool]:
-        return np.asarray(self.layer.data[:, 0] == self.current_step)
+        layer = self.layer
+        return np.asarray(layer.data[:, 0] == self.current_step)

     @property
     def current_keypoint(self) -> Keypoint:
@@ -173,12 +257,13 @@ def current_keypoint(self):

     @current_keypoint.setter
     def current_keypoint(self, keypoint: Keypoint):
+        layer = self.layer
         # Avoid changing the properties of a selected point
-        if not len(self.layer.selected_data):
-            current_properties = self.layer.current_properties
+        if not len(layer.selected_data):
+            current_properties = layer.current_properties
             current_properties["label"] = np.asarray([keypoint.label])
             current_properties["id"] = np.asarray([keypoint.id])
-            self.layer.current_properties = current_properties
+            layer.current_properties = current_properties

     def next_keypoint(self, *args):
         ind = self._keypoints.index(self.current_keypoint) + 1
@@ -196,10 +281,11 @@ def current_label(self) -> str:

     @current_label.setter
     def current_label(self, label: str):
-        if not len(self.layer.selected_data):
-            current_properties = self.layer.current_properties
+        layer = self.layer
+        if not len(layer.selected_data):
+            current_properties = layer.current_properties
             current_properties["label"] = np.asarray([label])
-            self.layer.current_properties = current_properties
+            layer.current_properties = current_properties

     @property
     def current_id(self) -> str:
@@ -207,84 +293,123 @@ def current_id(self):

     @current_id.setter
     def current_id(self, id_: str):
-        if not len(self.layer.selected_data):
-            current_properties = self.layer.current_properties
+        layer = self.layer
+        if not len(layer.selected_data):
+            current_properties = layer.current_properties
             current_properties["id"] = np.asarray([id_])
-            self.layer.current_properties = current_properties
+            layer.current_properties = current_properties

     def _advance_step(self, event):
         ind = (self.current_step + 1) % self.n_steps
         self.viewer.dims.set_current_step(0, ind)

     def _find_first_unlabeled_frame(self, event):
+        layer = self.layer
         inds = set(range(self.n_steps))
-        unlabeled_inds = inds.difference(self.layer.data[:, 0].astype(int))
+        unlabeled_inds = inds.difference(layer.data[:, 0].astype(int))
         if not unlabeled_inds:
             self.viewer.dims.set_current_step(0, self.n_steps - 1)
         else:
             self.viewer.dims.set_current_step(0, min(unlabeled_inds))

+    def _clear_stale_selection_if_off_frame(self) -> None:
+        layer = self.layer
+        if not len(layer.selected_data):
+            return
+
+        try:
+            sel = np.fromiter(layer.selected_data, dtype=int)
+        except Exception:
+            layer.selected_data = set()
+            return
+
+        if sel.size == 0:
+
layer.selected_data = set() + return + + try: + selected_steps = np.asarray(layer.data[sel, 0]) + except Exception: + layer.selected_data = set() + return + + # If none of the selected points belong to the current frame, clear selection. + if not np.any(selected_steps == self.current_step): + layer.selected_data = set() + + def add(self, coord): + coord = np.atleast_2d(coord) + + # Clear stale cross-frame selection before any logic + self._clear_stale_selection_if_off_frame() + + get_mode = getattr(self, "_get_label_mode", None) + label_mode = get_mode() if callable(get_mode) else None -def _add(store, coord): - coord = np.atleast_2d(coord) + changed = False - # Controls are runtime-only; prefer layer attribute, fall back to metadata. - get_mode = getattr(store, "_get_label_mode", None) - label_mode = get_mode() if callable(get_mode) else None + if self.current_keypoint not in self.annotated_keypoints: + layer = self.layer - if store.current_keypoint not in store.annotated_keypoints: - # 1) append data - store.layer.data = np.append(store.layer.data, coord, axis=0) + # 1) append data + layer.data = np.append(layer.data, coord, axis=0) - # 2) append/align properties to match number of points - kp = store.current_keypoint - n_new = coord.shape[0] - n_total = len(store.layer.data) - n_old = n_total - n_new + # 2) append/align properties to match number of points + kp = self.current_keypoint + n_new = coord.shape[0] + n_total = len(layer.data) + n_old = n_total - n_new - props = store.layer.properties.copy() + props = layer.properties.copy() - def _as_array(key, dtype): - arr = props.get(key, None) - if arr is None: - return np.array([], dtype=dtype) - return np.asarray(arr, dtype=dtype) + def _as_array(key, dtype): + arr = props.get(key, None) + if arr is None: + return np.array([], dtype=dtype) + return np.asarray(arr, dtype=dtype) - # Existing values truncated/padded to n_old, then append new rows - label_arr = _as_array("label", object)[:n_old] - id_arr = _as_array("id", object)[:n_old] - lik_arr = _as_array("likelihood", float)[:n_old] + label_arr = _as_array("label", object)[:n_old] + id_arr = _as_array("id", object)[:n_old] + lik_arr = _as_array("likelihood", float)[:n_old] - # If any are shorter than n_old, pad (rare but safe) - if label_arr.size < n_old: - label_arr = np.concatenate([label_arr, np.array([kp.label] * (n_old - label_arr.size), dtype=object)]) - if id_arr.size < n_old: - id_arr = np.concatenate([id_arr, np.array([kp.id] * (n_old - id_arr.size), dtype=object)]) - if lik_arr.size < n_old: - lik_arr = np.concatenate([lik_arr, np.ones(n_old - lik_arr.size, dtype=float)]) + if label_arr.size < n_old: + label_arr = np.concatenate([label_arr, np.array([kp.label] * (n_old - label_arr.size), dtype=object)]) + if id_arr.size < n_old: + id_arr = np.concatenate([id_arr, np.array([kp.id] * (n_old - id_arr.size), dtype=object)]) + if lik_arr.size < n_old: + lik_arr = np.concatenate([lik_arr, np.ones(n_old - lik_arr.size, dtype=float)]) - props["label"] = np.concatenate([label_arr, np.array([kp.label] * n_new, dtype=object)]) - props["id"] = np.concatenate([id_arr, np.array([kp.id] * n_new, dtype=object)]) - props["likelihood"] = np.concatenate([lik_arr, np.ones(n_new, dtype=float)]) + props["label"] = np.concatenate([label_arr, np.array([kp.label] * n_new, dtype=object)]) + props["id"] = np.concatenate([id_arr, np.array([kp.id] * n_new, dtype=object)]) + props["likelihood"] = np.concatenate([lik_arr, np.ones(n_new, dtype=float)]) - store.layer.properties = props + layer.properties = 
props + changed = True - elif label_mode is LabelMode.QUICK: - ind = store.annotated_keypoints.index(store.current_keypoint) - data = store.layer.data - data[np.flatnonzero(store.current_mask)[ind]] = coord.squeeze() - store.layer.data = data + elif label_mode is LabelMode.QUICK: + layer = self.layer + ind = self.annotated_keypoints.index(self.current_keypoint) + data = layer.data + data[np.flatnonzero(self.current_mask)[ind]] = coord.squeeze() + layer.data = data + changed = True + + self.layer.selected_data = set() + + if label_mode is LabelMode.LOOP: + if changed: + self.layer.events.query_next_frame() + else: + if changed: + self.next_keypoint() - store.layer.selected_data = set() - # If controls are missing, behave like the default mode (advance keypoint) - if label_mode is LabelMode.LOOP: - store.layer.events.query_next_frame() - else: - store.next_keypoint() +@deprecated(details="Temporary compat shim, remove once KeypointStore.add is properly integrated.") +def add(store: KeypointStore, coord): + return store.add(coord) -def _find_nearest_neighbors(xy_true, xy_pred, k=5): +def find_nearest_neighbors(xy_true, xy_pred, k=5): n_preds = xy_pred.shape[0] tree = cKDTree(xy_pred) dist, inds = tree.query(xy_true, k=k) diff --git a/src/napari_deeplabcut/core/layer_lifecycle/__init__.py b/src/napari_deeplabcut/core/layer_lifecycle/__init__.py new file mode 100644 index 00000000..ba7198ef --- /dev/null +++ b/src/napari_deeplabcut/core/layer_lifecycle/__init__.py @@ -0,0 +1,22 @@ +from .manager import LayerLifecycleManager +from .merge import ( + MergeDecisionProvider, + MergeDecisionRequest, + MergeDecisionResult, + MergeDisposition, +) +from .registry import ManagedPointsRuntime, PointsLayerSetupRequest, RuntimeRegistry +from .spawn import get_layer_manager, get_or_create_layer_manager + +__all__ = [ + "LayerLifecycleManager", + "ManagedPointsRuntime", + "RuntimeRegistry", + "MergeDecisionProvider", + "PointsLayerSetupRequest", + "MergeDecisionRequest", + "MergeDecisionResult", + "MergeDisposition", + "get_layer_manager", + "get_or_create_layer_manager", +] diff --git a/src/napari_deeplabcut/core/layer_lifecycle/manager.py b/src/napari_deeplabcut/core/layer_lifecycle/manager.py new file mode 100644 index 00000000..1c9e27fb --- /dev/null +++ b/src/napari_deeplabcut/core/layer_lifecycle/manager.py @@ -0,0 +1,1177 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable, Iterator +from types import MethodType +from typing import TYPE_CHECKING, Any + +import numpy as np +from napari.layers import Image, Layer, Points, Tracks +from napari.utils.events import Event +from napari.utils.history import update_save_history +from qtpy.QtCore import QObject, QTimer, Signal + +from ...config.keybinds import install_points_layer_keybindings +from ...config.models import DLCHeaderModel, ImageMetadata, PointsMetadata +from ...core import keypoints +from ...core.io import is_video +from ...core.layer_versioning import mark_layer_presentation_changed +from ...core.metadata import ( + MergePolicy, + infer_image_root, + migrate_points_layer_metadata, + read_points_meta, + sync_points_from_image, + write_points_meta, +) +from ...core.project_paths import PathMatchPolicy +from ...core.remap import remap_layer_data_by_paths +from ...napari_compat import install_add_wrapper, install_paste_patch +from ...napari_compat.points_layer import make_paste_data +from ...ui.cropping import resolve_project_path_from_image_layer +from ...utils.debug import log_timing +from .merge import 
MergeDecisionProvider, MergeDecisionRequest, MergeDecisionResult, MergeDisposition +from .registry import ( + ClearedRegistryEntry, + ManagedPointsRuntime, + PointsLayerSetupRequest, + PointsRuntimeResources, + RegistryAuditReport, + RuntimeRegistry, +) + +if TYPE_CHECKING: + import napari + + from ..keypoints import KeypointStore + +logger = logging.getLogger("napari-deeplabcut.lifecycle") + + +class LayerLifecycleManager(QObject): + """Lifecycle wrapper around existing widget behavior. + + Goals + ----- + - centralize viewer layer event entry points + - own the runtime registry for managed Points layers + - centralize layer liveness / store resolution + - keep current behavior by delegating back into widget-owned hooks + + Non-goals (for now) + ------------------- + - full reconciliation engine + - moving all logic out of the widget + - changing merge/remap policies + """ + + # UI signals for widget hooks + refresh_video_panel_requested = Signal() + refresh_layer_status_requested = Signal() + video_widget_visibility_requested = Signal(bool) + move_image_layer_to_bottom_requested = Signal(object) + + # Layer setup/teardown + points_layer_setup_requested = Signal(object) # PointsLayerSetupRequest + points_layers_merged_requested = Signal(object) # tuple[Points, ...] + points_layer_removed_requested = Signal(object, int) # layer, remaining_points_layers + tracks_layer_removed_requested = Signal(object) + + # Layer insertion/adoption + adopted_existing_layers = Signal() + layer_insert_processed = Signal(object) + layer_remove_processed = Signal(object) + + # Session management + session_conflict_rejected = Signal(str) # if a new DLC folder is loaded on top of the current one + + def __init__(self, viewer: napari.Viewer, *, parent: QObject | None = None) -> None: + super().__init__(parent=parent) + + self.viewer = viewer + self.registry: RuntimeRegistry[Any] = RuntimeRegistry() + self._merge_decision_provider: MergeDecisionProvider | None = None + + # Lifecycle-owned viewer/image context + self._active_dlc_image_layer_id: int | None = None + self._image_meta = ImageMetadata() + self._project_path: str | None = None + + # Layers management + ## Layer insertion + self._initial_adopt_timer = QTimer(self) + self._initial_adopt_timer.setSingleShot(True) + self._initial_adopt_timer.timeout.connect(self.adopt_existing_layers) + ## Layer removal + self._post_remove_refresh_timer = QTimer(self) + self._post_remove_refresh_timer.setSingleShot(True) + self._post_remove_refresh_timer.timeout.connect(self._flush_post_remove_refresh) + + self._attached = False + + # ------------------------------------------------------------------ # + # Centralized access API # + # ------------------------------------------------------------------ # + + def resolve_live_layer(self, layer_or_id: Any) -> Layer | None: + layer = self.registry.resolve_live_layer(layer_or_id) + return layer if isinstance(layer, Layer) else None + + def get_live_runtime(self, layer_or_id: Any) -> ManagedPointsRuntime[Any] | None: + return self.registry.get_live_runtime(layer_or_id) + + def get_store(self, layer_or_id: Any) -> KeypointStore | None: + store = self.registry.get_store(layer_or_id) + return store # typed via TYPE_CHECKING + + def require_store(self, layer_or_id: Any) -> KeypointStore: + store = self.registry.require_store(layer_or_id) + return store # typed via TYPE_CHECKING + + def iter_managed_points(self) -> Iterator[tuple[Points, KeypointStore]]: + """Iterate only live managed Points layers and their stores.""" + for layer, 
runtime in self.registry.iter_live_items(): + if isinstance(layer, Points): + yield layer, runtime.store + + def managed_points_layers(self) -> tuple[Points, ...]: + return tuple(layer for layer, _ in self.iter_managed_points()) + + def managed_points_count(self) -> int: + return sum(1 for _ in self.iter_managed_points()) + + def has_managed_points(self) -> bool: + return any(True for _ in self.iter_managed_points()) + + def clear_dead_entries(self, *, log: bool = True) -> tuple[ClearedRegistryEntry[Any], ...]: + return self.registry.clear_dead_entries(log=log) + + def audit_registry(self) -> RegistryAuditReport: + return self.registry.audit() + + def set_merge_decision_provider(self, provider: MergeDecisionProvider | None) -> None: + self._merge_decision_provider = provider + + # ------------------------------------------------------------------ # + # lifecycle-owned image/project context # + # ------------------------------------------------------------------ # + + @property + def image_meta(self) -> ImageMetadata: + return self._image_meta + + @property + def project_path(self) -> str | None: + return self._project_path + + @project_path.setter + def project_path(self, value: str | None) -> None: + self._project_path = value + + @property + def image_root(self) -> str | None: + return self._image_meta.root + + @property + def image_paths(self) -> list[str] | None: + return self._image_meta.paths + + @property + def image_name(self) -> str | None: + return self._image_meta.name + + # ------------------------------------------------------------------ # + # Lifecycle wiring # + # ------------------------------------------------------------------ # + + def attach(self) -> None: + """Attach to viewer layer events.""" + if self._attached: + return + + self.viewer.layers.events.inserted.connect(self.on_insert) + self.viewer.layers.events.removed.connect(self.on_remove) + self._attached = True + + def detach(self) -> None: + """Detach from viewer layer events.""" + if not self._attached: + return + + try: + self.viewer.layers.events.inserted.disconnect(self.on_insert) + except Exception: + pass + + try: + self.viewer.layers.events.removed.disconnect(self.on_remove) + except Exception: + pass + + self._attached = False + + def schedule_initial_adoption(self) -> None: + """Schedule adoption of existing layers after the event loop starts.""" + self._initial_adopt_timer.start(0) + + def _dlc_meta_for_layer(self, layer: Layer) -> dict | None: + md = layer.metadata or {} + if not isinstance(md, dict): + return None + payload = md.get("dlc", None) + return payload if isinstance(payload, dict) else None + + def is_dlc_session_image_layer(self, layer: Image) -> bool: + payload = self._dlc_meta_for_layer(layer) + if not payload: + return False + + role = payload.get("session_role", None) + ctx = payload.get("project_context", None) + + return role in {"image", "video"} and isinstance(ctx, dict) and bool(ctx) + + def active_dlc_image_layer(self) -> Image | None: + if self._active_dlc_image_layer_id is None: + return None + + for layer in self.viewer.layers: + if id(layer) == self._active_dlc_image_layer_id and isinstance(layer, Image): + return layer + + return None + + def can_accept_dlc_session_image(self, layer: Image) -> tuple[bool, str | None]: + active = self.active_dlc_image_layer() + if active is None: + return True, None + if active is layer: + return True, None + return ( + False, + "A DLC project/video is already open.\n" + "The plugin will attempt to load annotations from the new project, " + 
"but will not load the video.\n\n" + "If you meant to load extra annotations for the current video, " + "please only load the corresponding h5 files.\n" + "If you meant to switch to a different project/video, " + "please save and clear the current layers before loading the new labeled data folder.", + ) + + def _reject_conflicting_dlc_image_layer(self, layer: Image, reason: str) -> None: + """Reject a conflicting DLC session image safely. + + Do not remove synchronously inside the insert callback: + napari may still be finalizing list insertion / selection. + """ + self.viewer.status = reason + self.session_conflict_rejected.emit(reason) + + def _remove_later(ly=layer): + try: + if ly in self.viewer.layers: + self.viewer.layers.remove(ly) + except Exception: + logger.debug( + "Failed to remove conflicting DLC image layer %r", + getattr(ly, "name", ly), + exc_info=True, + ) + + QTimer.singleShot(0, _remove_later) + + # ------------------------------------------------------------------ # + # Layer setup managers # + # ------------------------------------------------------------------ # + @staticmethod + def _layer_source_path(layer) -> str | None: + """Best-effort access to napari layer source path (may not exist).""" + try: + src = getattr(layer, "source", None) + p = getattr(src, "path", None) if src is not None else None + return str(p) if p else None + except Exception: + return None + + @staticmethod + def get_header_model_from_metadata(md: dict) -> DLCHeaderModel | None: + """Return DLCHeaderModel from metadata payload, if possible.""" + if not isinstance(md, dict): + return None + + hdr = md.get("header", None) + if hdr is None: + return None + + if isinstance(hdr, DLCHeaderModel): + return hdr + + if isinstance(hdr, dict): + try: + return DLCHeaderModel.model_validate(hdr) + except Exception: + return None + + try: + return DLCHeaderModel(columns=hdr) + except Exception: + return None + + @staticmethod + def is_multianimal(layer) -> bool: + """Return True if this layer looks like a multi-animal Points layer.""" + if layer is None or not isinstance(layer, Points): + return False + + md = layer.metadata or {} + hdr = LayerLifecycleManager.get_header_model_from_metadata(md) + if hdr is None: + return False + + try: + inds = hdr.individuals + # return bool(inds and len(inds) > 0 and str(inds[0]) != "") + return any(str(ind) != "" for ind in inds) + except Exception: + return False + + @staticmethod + def is_config_placeholder_points_layer(layer: Points) -> bool: + """Return True if this looks like the temporary config placeholder layer. + + - must be a Points layer + - must carry a project hint + - must not already be tied to image/root/paths context + - must not contain actual point data + """ + if layer is None or not isinstance(layer, Points): + return False + + md = layer.metadata or {} + if not md.get("project"): + return False + + # Real labeled/prediction layers usually carry image/root/paths context. + if md.get("root") or md.get("paths"): + return False + + try: + data = np.asarray(layer.data) if layer.data is not None else np.empty((0, 3)) + except Exception: + data = np.empty((0, 3)) + + return data.size == 0 + + @staticmethod + def validate_header(layer: Points) -> bool: + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if hasattr(res, "errors") or getattr(res, "header", None) is None: + # self.viewer.status = ( + # "This Points layer does not look like a DLC keypoints layer. Missing a valid DLC header." 
+ # ) + logger.debug("Points layer %r failed DLC header validation: %s", getattr(layer, "name", layer), res) + return False + return True + + def _update_image_meta_from_layer(self, layer: Image) -> bool: + """Update lifecycle-owned image metadata from an Image layer. + + Returns + ------- + bool + True if the authoritative image context changed, else False. + """ + md = layer.metadata or {} + + paths = md.get("paths") + try: + shape = layer.level_shapes[0] + except Exception: + shape = None + + root = infer_image_root( + explicit_root=md.get("root"), + paths=paths, + source_path=self._layer_source_path(layer), + ) + + incoming = ImageMetadata( + paths=list(paths) if paths else None, + root=str(root) if root else None, + shape=tuple(shape) if shape is not None else None, + name=getattr(layer, "name", None), + ) + + base = self._image_meta + merged = base.model_copy(deep=True) + for field, value in incoming.model_dump().items(): + if getattr(merged, field) in (None, "", []) and value not in (None, "", []): + setattr(merged, field, value) + + changed = merged != self._image_meta + self._image_meta = merged + return changed + + def _sync_points_layers_from_image_meta(self) -> None: + """Sync managed/all points metadata from lifecycle-owned image context.""" + if self._image_meta is None: + return + + for ly in list(self.viewer.layers): + if not isinstance(ly, Points): + continue + + if ly.metadata is None: + ly.metadata = {} + + res = read_points_meta(ly, migrate_legacy=True, drop_controls=False, drop_header=False) + if hasattr(res, "errors"): + logger.warning( + "Points metadata validation failed during sync for layer=%r: %s", + getattr(ly, "name", ly), + res, + ) + continue + + pts_model: PointsMetadata = res + synced = sync_points_from_image(self._image_meta, pts_model) + + out = write_points_meta( + ly, + synced, + merge_policy=MergePolicy.MERGE_MISSING, + migrate_legacy=True, + validate=True, + ) + if hasattr(out, "errors"): + logger.warning( + "Failed to write synced points metadata for layer=%r: %s", + getattr(ly, "name", ly), + out, + ) + + def _cache_project_path_from_image_layer(self, layer: Image) -> None: + """Best-effort lifecycle-owned cache of project path from an image/video layer.""" + project_path = resolve_project_path_from_image_layer(layer) + if project_path is None: + return + + self._project_path = project_path + try: + layer.metadata = dict(layer.metadata or {}) + layer.metadata.setdefault("project", self._project_path) + except Exception: + logger.debug( + "Failed to set project path metadata on image layer %r", + getattr(layer, "name", layer), + exc_info=True, + ) + + self.refresh_video_panel_requested.emit() + + def _setup_image_layer(self, layer: Image, index: int | None = None, *, reorder: bool = True) -> None: + """Lifecycle-owned setup for an inserted/adopted Image layer.""" + md = layer.metadata or {} + paths = md.get("paths") + if paths is None: + try: + if is_video(layer.name): + self.video_widget_visibility_requested.emit(True) + except Exception: + pass + + self._active_dlc_image_layer_id = id(layer) + context_changed = self._update_image_meta_from_layer(layer) + + if not self._project_path: + self._cache_project_path_from_image_layer(layer) + if self._project_path is not None: + try: + layer.metadata = dict(layer.metadata or {}) + layer.metadata.setdefault("project", self._project_path) + except Exception: + logger.debug( + "Failed to set project path metadata on image layer %r", + getattr(layer, "name", layer), + exc_info=True, + ) + + if context_changed: 
+ self._sync_points_layers_from_image_meta() + + self.refresh_video_panel_requested.emit() + + logger.debug( + "Setup image layer=%r index=%s reorder=%s paths_count=%s root=%r context_changed=%s", + getattr(layer, "name", layer), + index, + reorder, + len(md.get("paths") or []), + md.get("root"), + context_changed, + ) + + if reorder and index is not None: + QTimer.singleShot(10, lambda ly=layer: self.move_image_layer_to_bottom_requested.emit(ly)) + + def _wire_points_layer(self, layer: Points) -> KeypointStore | None: + """Lifecycle-owned wiring of a managed Points layer. + + The manager owns registration and lifecycle sequencing. + Runtime/UI completion is delegated via points_layer_setup_requested. + """ + if not self.validate_header(layer): + return None + + existing = getattr(layer, "_dlc_store", None) + if existing is not None: + self.register_managed_points_layer(layer, existing) + + runtime = self.get_live_runtime(layer) + existing_resources = None + if runtime is not None: + existing_resources = runtime.resources.get("points_runtime", None) + + req = PointsLayerSetupRequest( + layer=layer, + store=existing, + existing_resources=existing_resources, + ) + self.points_layer_setup_requested.emit(req) + + runtime = self.get_live_runtime(layer) + if runtime is not None and req.runtime_resources is not None: + runtime.resources["points_runtime"] = req.runtime_resources + + return existing + + mig = migrate_points_layer_metadata(layer) + if hasattr(mig, "errors"): + logger.warning( + "Points metadata validation failed during wiring for layer=%r: %s", + getattr(layer, "name", layer), + mig, + ) + + store = keypoints.KeypointStore(self.viewer, layer) + self.register_managed_points_layer(layer, store) + + layer._dlc_store = store + + proj = layer.metadata.get("project") + if proj: + self._project_path = proj + + if not layer.metadata.get("root") and self._image_meta.root: + layer.metadata["root"] = self._image_meta.root + if not layer.metadata.get("paths") and self._image_meta.paths: + layer.metadata["paths"] = self._image_meta.paths + + if root := layer.metadata.get("root"): + update_save_history(root) + + layer.text.visible = False + + req = PointsLayerSetupRequest(layer=layer, store=store) + self.points_layer_setup_requested.emit(req) + + runtime = self.get_live_runtime(layer) + if runtime is not None and req.runtime_resources is not None: + runtime.resources["points_runtime"] = req.runtime_resources + + md = layer.metadata or {} + logger.debug( + "Wire points layer=%r existing_store=%s project=%s root=%s len_paths=%s", + getattr(layer, "name", layer), + getattr(layer, "_dlc_store", None) is not None, + md.get("project"), + md.get("root"), + len(md.get("paths", [])), + ) + + return store + + def _remove_layer_if_present(self, layer: Layer) -> None: + try: + if layer in self.viewer.layers: + with log_timing( + logger, + f"viewer.layers.remove layer={getattr(layer, 'name', layer)!r}", + threshold_ms=0.01, + ): + self.viewer.layers.remove(layer) + except Exception: + logger.debug("Failed to remove layer=%r", getattr(layer, "name", layer), exc_info=True) + + @staticmethod + def _set_layer_visible(layer: Layer, visible: bool) -> None: + try: + layer.visible = visible + except Exception: + try: + layer.shown = visible + except Exception: + logger.debug("Failed to set visibility for layer=%r", getattr(layer, "name", layer), exc_info=True) + + def _setup_points_layer(self, layer: Points, *, allow_merge: bool = True) -> None: + """Lifecycle-owned setup for an inserted/adopted Points layer.""" + if 
not self.validate_header(layer): + return + + if allow_merge: + consumed = self._maybe_merge_config_points_layer(layer) + if consumed: + logger.debug( + "Consumed temporary config placeholder layer=%r during merge path", + getattr(layer, "name", layer), + ) + return + + store = self._wire_points_layer(layer) + if store is None: + return + + logger.debug( + "Setup points layer=%r allow_merge=%s metadata_keys=%s", + getattr(layer, "name", layer), + allow_merge, + sorted((layer.metadata or {}).keys()), + ) + + def _schedule_post_remove_refresh(self) -> None: + """Coalesce repeated UI refreshes during layer removal bursts.""" + self._post_remove_refresh_timer.start(0) + + def _flush_post_remove_refresh(self) -> None: + with log_timing(logger, "_flush_post_remove_refresh total", threshold_ms=0.01): + self.refresh_video_panel_requested.emit() + self.refresh_layer_status_requested.emit() + + def _handle_removed_layer(self, layer: Any) -> None: + """Lifecycle-owned remove handling. + + Transitional note: + ------------------ + UI/menu cleanup still delegates to a widget-owned UI hook. + """ + with log_timing( + logger, + f"_handle_removed_layer total layer={getattr(layer, 'name', layer)!r}", + threshold_ms=0.01, + ): + n_points_layer = sum(isinstance(l, Points) for l in self.viewer.layers) + + if isinstance(layer, Points): + with log_timing( + logger, + f"unregister_managed_layer layer={getattr(layer, 'name', layer)!r}", + threshold_ms=0.01, + ): + store = self.unregister_managed_layer(layer) + + if store is not None: + with log_timing( + logger, + f"points_layer_removed_requested layer={getattr(layer, 'name', layer)!r}", + threshold_ms=0.01, + ): + self.points_layer_removed_requested.emit(layer, n_points_layer) + + elif isinstance(layer, Image): + if self._active_dlc_image_layer_id == id(layer): + self._active_dlc_image_layer_id = None + self._image_meta = ImageMetadata() + self._project_path = None + + paths = layer.metadata.get("paths") + if paths is None: + self.video_widget_visibility_requested.emit(False) + else: + logger.debug( + "Removed non-session or inactive image layer=%r; keeping current DLC session context.", + getattr(layer, "name", layer), + ) + + elif isinstance(layer, Tracks): + self.tracks_layer_removed_requested.emit(layer) + + self._schedule_post_remove_refresh() + + def _remap_frame_indices(self, layer: Any) -> None: + """Lifecycle-owned remap of non-Image layer time/frame indices.""" + try: + new_paths = self._image_meta.paths + if not new_paths: + return + + if layer.metadata is None: + layer.metadata = {} + + md = layer.metadata + old_paths = md.get("paths") or [] + + try: + safe_image_meta = self._image_meta.model_dump(exclude_none=True) + safe_image_meta.pop("paths", None) + layer.metadata.update(safe_image_meta) + except Exception: + logger.debug( + "Failed to sync non-path image metadata for layer=%r", + getattr(layer, "name", str(layer)), + exc_info=True, + ) + + if not old_paths: + logger.debug( + "Skipping remap for layer=%r: no existing layer metadata paths.", + getattr(layer, "name", str(layer)), + ) + return + + time_col = 1 if isinstance(layer, Tracks) else 0 + + if logger.isEnabledFor(logging.DEBUG): + arr_before = np.asarray(layer.data) + logger.debug( + "Remap start layer=%r old_paths_len=%s new_paths_len=%s data_shape=%s frame_min=%s frame_max=%s", + getattr(layer, "name", str(layer)), + len(old_paths), + len(new_paths or []), + getattr(arr_before, "shape", None), + int(np.nanmin(arr_before[:, time_col])) if arr_before.size else None, + 
int(np.nanmax(arr_before[:, time_col])) if arr_before.size else None, + ) + + res = remap_layer_data_by_paths( + data=layer.data, + old_paths=old_paths, + new_paths=new_paths, + time_col=time_col, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + logger.debug( + "Remap result layer=%r changed=%s mapped_count=%s depth=%s message=%s warnings=%s", + getattr(layer, "name", str(layer)), + res.changed, + res.mapped_count, + res.depth_used, + res.message, + res.warnings, + ) + + if res.applied and res.data is not None: + layer.data = res.data + + if res.accept_paths_update: + layer.metadata["paths"] = list(new_paths) + if isinstance(layer, Points): + mark_layer_presentation_changed(layer) + + if res.depth_used is None: + logger.debug("Remap skipped for %s: %s", getattr(layer, "name", str(layer)), res.message) + else: + logger.debug( + "Remap %s for %s (depth=%s, mapped=%s): %s", + "applied" if res.changed else "accepted-noop", + getattr(layer, "name", str(layer)), + res.depth_used, + res.mapped_count, + res.message, + ) + + except Exception: + logger.exception("Failed to remap frame indices for layer %s", getattr(layer, "name", str(layer))) + + def _maybe_merge_config_points_layer(self, layer: Points) -> bool: + """Merge a temporary config placeholder layer into existing managed layers. + + Returns + ------- + bool + True if the placeholder layer was consumed by the merge flow + (including explicit HIDE_NEW / CANCEL handling), else False. + """ + if not self.is_config_placeholder_points_layer(layer): + return False + + managed = list(self.iter_managed_points()) + if not managed: + return False + + md = layer.metadata or {} + logger.debug( + "Maybe merge config placeholder layer=%r project=%r managed_layers=%d", + getattr(layer, "name", layer), + md.get("project"), + len(managed), + ) + + new_metadata = md.copy() + new_header = self.get_header_model_from_metadata(new_metadata) + if new_header is None: + logger.debug( + "Skipping config placeholder merge for layer=%r: missing/invalid header", + getattr(layer, "name", layer), + ) + return False + + reference_layer, _reference_store = managed[0] + reference_header = self.get_header_model_from_metadata(reference_layer.metadata or {}) + if reference_header is None: + logger.debug( + "Skipping config placeholder merge for layer=%r: reference managed layer has no valid header", + getattr(layer, "name", layer), + ) + return False + + current_keypoint_set = set(reference_header.bodyparts) + new_keypoint_set = set(new_header.bodyparts) + diff = tuple(sorted(new_keypoint_set.difference(current_keypoint_set))) + + visible_existing_layer = None + for managed_layer, _store in managed: + if managed_layer is layer: + continue + try: + if getattr(managed_layer, "visible", True): + visible_existing_layer = managed_layer + break + except Exception: + visible_existing_layer = managed_layer + break + + message = f"New keypoint{'s' if len(diff) > 1 else ''} {', '.join(diff)} found." 
if diff else "" + + decision = self._resolve_merge_decision( + new_layer=layer, + existing_layers=tuple(ly for ly, _ in managed if ly is not layer), + added_keypoints=diff, + message=message, + ) + + disposition = decision.disposition + + # Optional visibility policy before merge + if disposition is MergeDisposition.HIDE_EXISTING and visible_existing_layer is not None: + self._set_layer_visible(visible_existing_layer, False) + logger.debug( + "Config placeholder merge layer=%r hid_existing_layer=%r", + getattr(layer, "name", layer), + getattr(visible_existing_layer, "name", visible_existing_layer), + ) + + if disposition is MergeDisposition.HIDE_NEW: + QTimer.singleShot(0, lambda ly=layer: self._remove_layer_if_present(ly)) + return True + + if disposition is MergeDisposition.CANCEL: + # Conservative behavior: stop lifecycle setup and leave the placeholder + # layer untouched/unmanaged for now. + logger.debug( + "Config placeholder merge cancelled for layer=%r", + getattr(layer, "name", layer), + ) + return True + + if diff: + self.viewer.status = message + + # Merge header into all managed layers. + affected_layers: list[Points] = [] + for managed_layer, store in managed: + pts = read_points_meta( + managed_layer, + migrate_legacy=True, + drop_controls=True, + drop_header=False, + ) + if not hasattr(pts, "errors"): + updated = pts.model_copy(update={"header": new_header}) + write_points_meta( + managed_layer, + updated, + merge_policy=MergePolicy.MERGE, + fields={"header"}, + ) + store.layer = managed_layer + affected_layers.append(managed_layer) + + # Apply updated presentation metadata to existing managed layers. + for managed_layer, store in managed: + managed_layer.metadata["config_colormap"] = new_metadata.get( + "config_colormap", + managed_layer.metadata.get("config_colormap"), + ) + if "face_color_cycles" in new_metadata: + managed_layer.metadata["face_color_cycles"] = new_metadata["face_color_cycles"] + managed_layer.metadata["colormap_name"] = new_metadata.get( + "colormap_name", + managed_layer.metadata.get("colormap_name"), + ) + + mark_layer_presentation_changed(managed_layer) + store.layer = managed_layer + + # Ask UI consumers to refresh menus/colors based on updated managed layers. + self.points_layers_merged_requested.emit(tuple(affected_layers)) + + # Remove the temporary placeholder layer explicitly by identity. 
+ QTimer.singleShot(0, lambda ly=layer: self._remove_layer_if_present(ly)) + + # General panel refreshes + self.refresh_layer_status_requested.emit() + + return True + + def _resolve_merge_decision( + self, + *, + new_layer: Any, + existing_layers: tuple[Any, ...], + added_keypoints: tuple[str, ...], + message: str, + ) -> MergeDecisionResult: + provider = self._merge_decision_provider + if provider is None: + return MergeDecisionResult(disposition=MergeDisposition.KEEP_BOTH) + + req = MergeDecisionRequest( + new_layer=new_layer, + existing_layers=existing_layers, + added_keypoints=added_keypoints, + message=message, + ) + + try: + result = provider.resolve_merge(req) + except Exception: + logger.debug("Merge decision provider failed; defaulting to KEEP_BOTH", exc_info=True) + return MergeDecisionResult(disposition=MergeDisposition.KEEP_BOTH) + + if result is None: + return MergeDecisionResult(disposition=MergeDisposition.KEEP_BOTH) + + return result + + def attach_points_layer_runtime( + self, + *, + layer: Points, + store: keypoints.KeypointStore, + controls: Any, + resolve_layer_by_id: Callable[[int], Points | None], + get_label_mode: Callable[[], Any], + schedule_recolor: Callable[[Points], None], + existing_resources: PointsRuntimeResources | None = None, + ) -> PointsRuntimeResources: + """Attach managed runtime behavior to a Points layer. + + Responsibilities + ---------------- + - bind lifecycle-backed layer resolution to the store + - bind label-mode getter to the store + - install paste patch + - install add wrapper + - add/connect query_next_frame event + - install points-layer keybindings + + This helper does NOT: + - register the runtime in the lifecycle registry + - own UI/menu setup + - decide merge/remap/save policy + """ + resources = existing_resources or PointsRuntimeResources() + + # Narrow lifecycle dependencies injected explicitly. 
+ store.attach_layer_resolver(resolve_layer_by_id) + store.set_label_mode_getter(get_label_mode) + + # Copy/paste patch + if not resources.paste_patch_installed: + paste_func = make_paste_data(controls, store=store) + install_paste_patch(layer, paste_func=paste_func) + resources.paste_patch_installed = True + + # Add layer to store + if not resources.add_wrapper_installed: + add_impl = MethodType(keypoints.KeypointStore.add, store) + install_add_wrapper(layer, add_impl=add_impl, schedule_recolor=schedule_recolor) + resources.add_wrapper_installed = True + + # layer-specific navigation event + if not hasattr(layer.events, "query_next_frame"): + layer.events.add(query_next_frame=Event) + resources.query_next_frame_event_added = True + + if not resources.query_next_frame_connected: + try: + layer.events.query_next_frame.connect(store._advance_step) + resources.query_next_frame_connected = True + except Exception: + pass + + if not resources.keybindings_installed: + install_points_layer_keybindings(layer, controls, store) + resources.keybindings_installed = True + + return resources + + # ------------------------------------------------------------------ # + # Registry facade # + # ------------------------------------------------------------------ # + + def is_managed(self, layer: Any) -> bool: + """Whether this exact currently live layer is registered.""" + return self.registry.is_managed(layer) + + def register_managed_layer(self, layer: Layer, store: KeypointStore, **resources: Any) -> None: + if isinstance(layer, Points): + self.register_managed_points_layer(layer, store, **resources) + else: + raise ValueError(f"Unsupported layer type for management: {type(layer).__name__}") + + def register_managed_points_layer(self, layer: Points, store: KeypointStore, **resources: Any) -> None: + """Register a managed Points layer if not already registered.""" + if self.registry.is_managed(layer): + return + + self.registry.register( + layer, + ManagedPointsRuntime( + layer_id=id(layer), + store=store, + resources=dict(resources), + ), + ) + + def unregister_managed_layer(self, layer_or_id: Any) -> Any | None: + """Unregister a managed layer by layer object or layer id.""" + runtime = self.registry.unregister(layer_or_id) + return None if runtime is None else runtime.store + + # ------------------------------------------------------------------ # + # Event entry points # + # ------------------------------------------------------------------ # + + def adopt_existing_layers(self) -> None: + logger.debug("Lifecycle manager adopting existing layers count=%d", len(self.viewer.layers)) + + layers_snapshot = list(self.viewer.layers) + for idx, layer in enumerate(layers_snapshot): + self._adopt_layer(layer, idx) + + self.adopted_existing_layers.emit() + + def on_insert(self, event: Any) -> None: + layer = self._resolve_inserted_layer(event) + if layer is None: + logger.debug("Lifecycle manager could not resolve inserted layer for event=%r", event) + return + + logger.debug( + "Lifecycle manager processing insert layer=%r type=%s index=%s", + getattr(layer, "name", layer), + type(layer).__name__, + getattr(event, "index", None), + ) + + should_remap = False + + if isinstance(layer, Image): + should_remap = self._maybe_accept_and_setup_image_layer( + layer, + getattr(event, "index", None), + ) + elif isinstance(layer, Points): + self._setup_points_layer(layer, allow_merge=True) + should_remap = True + + if should_remap: + for layer_ in self.viewer.layers: + if not isinstance(layer_, Image): + 
self._remap_frame_indices(layer_) + + self.refresh_video_panel_requested.emit() + self.refresh_layer_status_requested.emit() + self.layer_insert_processed.emit(layer) + + def on_remove(self, event: Any) -> None: + layer = getattr(event, "value", None) + if layer is None: + logger.debug("Lifecycle manager received remove event without value: %r", event) + return + + logger.debug( + "Lifecycle manager processing remove layer=%r type=%s", + getattr(layer, "name", layer), + type(layer).__name__, + ) + + with log_timing( + logger, + f"on_remove total layer={getattr(layer, 'name', layer)!r}", + threshold_ms=0.01, + ): + self._handle_removed_layer(layer) + self.layer_remove_processed.emit(layer) + + def _maybe_accept_and_setup_image_layer(self, layer: Image, index: int | None) -> bool: + if not self.is_dlc_session_image_layer(layer): + logger.debug( + "Ignoring non-DLC image layer during lifecycle setup: %r", + getattr(layer, "name", layer), + ) + return False + + ok, reason = self.can_accept_dlc_session_image(layer) + if not ok: + self._reject_conflicting_dlc_image_layer( + layer, + reason or "Conflicting DLC project/video layer", + ) + return False + + self._setup_image_layer(layer, index, reorder=True) + return True + + def _adopt_layer(self, layer: Any, index: int) -> None: + logger.debug( + "Lifecycle manager adopt layer=%r type=%s index=%s", + getattr(layer, "name", layer), + type(layer).__name__, + index, + ) + + if isinstance(layer, Image): + self._maybe_accept_and_setup_image_layer(layer, index) + elif isinstance(layer, Points): + if not self.registry.is_managed(layer): + self._setup_points_layer(layer, allow_merge=False) + + if not isinstance(layer, Image): + self._remap_frame_indices(layer) + + def _resolve_inserted_layer(self, event: Any) -> Any | None: + # Best case: event carries the inserted value directly + layer = getattr(event, "value", None) + if layer is not None: + return layer + + # Prefer explicit event index over “last item in source” + index = getattr(event, "index", None) + if isinstance(index, int): + try: + return self.viewer.layers[index] + except Exception: + pass + + # Conservative fallback only + source = getattr(event, "source", None) + try: + if source: + return source[-1] + except Exception: + pass + + return None diff --git a/src/napari_deeplabcut/core/layer_lifecycle/merge.py b/src/napari_deeplabcut/core/layer_lifecycle/merge.py new file mode 100644 index 00000000..dc6963c8 --- /dev/null +++ b/src/napari_deeplabcut/core/layer_lifecycle/merge.py @@ -0,0 +1,30 @@ +# src/napari_deeplabcut/core/layer_lifecycle/merge.py +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Protocol + + +class MergeDisposition(str, Enum): + KEEP_BOTH = "keep_both" + HIDE_EXISTING = "hide_existing" + HIDE_NEW = "hide_new" + CANCEL = "cancel" + + +@dataclass(frozen=True, slots=True) +class MergeDecisionRequest: + new_layer: Any + existing_layers: tuple[Any, ...] + added_keypoints: tuple[str, ...] + message: str + + +@dataclass(frozen=True, slots=True) +class MergeDecisionResult: + disposition: MergeDisposition + + +class MergeDecisionProvider(Protocol): + def resolve_merge(self, request: MergeDecisionRequest) -> MergeDecisionResult: ... 
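The merge flow above is deliberately provider-driven: `_maybe_merge_config_points_layer` calls `_resolve_merge_decision`, which defers to whatever `MergeDecisionProvider` was registered via `set_merge_decision_provider`, and falls back to `KEEP_BOTH` when no provider is set or the provider raises. As a reading aid, here is a minimal sketch of a headless provider; the auto-hide policy is hypothetical and not part of this PR, only the `merge.py` types above are assumed:

```python
# Sketch only: a headless MergeDecisionProvider. The policy (auto-hide the
# placeholder when it adds no new keypoints) is illustrative, not part of
# this PR; the types come from merge.py above.
from napari_deeplabcut.core.layer_lifecycle.merge import (
    MergeDecisionRequest,
    MergeDecisionResult,
    MergeDisposition,
)


class AutoMergeProvider:
    """Resolve placeholder merges without prompting the user."""

    def resolve_merge(self, request: MergeDecisionRequest) -> MergeDecisionResult:
        if not request.added_keypoints:
            # No new keypoints: drop the incoming placeholder outright.
            return MergeDecisionResult(disposition=MergeDisposition.HIDE_NEW)
        # Otherwise let the default merge path run and keep both layers.
        return MergeDecisionResult(disposition=MergeDisposition.KEEP_BOTH)
```

Wired up via the manager facade, e.g. `manager.set_merge_decision_provider(AutoMergeProvider())`; passing `None` restores the default `KEEP_BOTH` behavior.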
diff --git a/src/napari_deeplabcut/core/layer_lifecycle/registry.py b/src/napari_deeplabcut/core/layer_lifecycle/registry.py new file mode 100644 index 00000000..fd8af29d --- /dev/null +++ b/src/napari_deeplabcut/core/layer_lifecycle/registry.py @@ -0,0 +1,325 @@ +# src/napari_deeplabcut/core/layer_lifecycle/registry.py +from __future__ import annotations + +import logging +import weakref +from collections.abc import Iterator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +if TYPE_CHECKING: + from napari.layers import Points + + from ..keypoints import KeypointStore + +logger = logging.getLogger("napari-deeplabcut.lifecycle.registry") +StoreT = TypeVar("StoreT") + + +@dataclass(slots=True) +class ManagedPointsRuntime(Generic[StoreT]): + """Runtime attachment for a managed Points layer. + + Important: + ---------- + This runtime should not strongly store the layer object itself. + + It stores: + - the stable layer identity used by the registry (`layer_id`) + - the store (KeypointStore) associated with this layer + - generic resources for future cleanup hooks / attachments + + Layer liveness / resolution is owned by the registry. + Everything downstream should use it to resolve whether the layer is still live. + This is meant to let napari own true layer lifecycles without interference, + while still enabling robust cleanup of plugin-managed + runtime attachments when layers are removed. + """ + + layer_id: int + store: StoreT + resources: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class PointsRuntimeResources: + """Non-Qt runtime attachments installed on a managed Points layer. + + It gives the lifecycle manager one place to record what it attached and later clean up or audit. + Intended to fit in ManagedPointsRuntime.resources. + """ + + query_next_frame_event_added: bool = False + query_next_frame_connected: bool = False + add_wrapper_installed: bool = False + paste_patch_installed: bool = False + keybindings_installed: bool = False + extra: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PointsLayerSetupRequest: + layer: Points + store: KeypointStore + existing_resources: PointsRuntimeResources | None = None + runtime_resources: PointsRuntimeResources | None = None + + +@dataclass(frozen=True, slots=True) +class ClearedRegistryEntry(Generic[StoreT]): + """A registry entry that was removed because its layer was no longer live.""" + + layer_id: int + runtime: ManagedPointsRuntime[StoreT] + + +@dataclass(frozen=True, slots=True) +class RegistryAuditIssue: + code: str + message: str + layer_id: int | None = None + + +@dataclass(frozen=True, slots=True) +class RegistryAuditReport: + live_count: int + dead_count: int + issues: tuple[RegistryAuditIssue, ...] + + +@dataclass(slots=True) +class _RegistryEntry(Generic[StoreT]): + layer_id: int + layer_ref: weakref.ReferenceType[Any] | None + strong_layer: Any | None + runtime: ManagedPointsRuntime[StoreT] + + def resolve_layer(self) -> Any | None: + if self.layer_ref is not None: + return self.layer_ref() + return self.strong_layer + + +class RuntimeRegistry(Generic[StoreT]): + """Single owner of managed runtime attachments. + + Invariants + ---------- + - At most one runtime bundle per registered layer identity. + - Registration is explicit. + - Layer liveness is resolved here. + - Runtime attachments do not need to strongly own the layer object. 
+ """ + + def __init__(self) -> None: + self._entries_by_id: dict[int, _RegistryEntry[StoreT]] = {} + + # ------------------------------------------------------------------ # + # core identity / registration # + # ------------------------------------------------------------------ # + + def __len__(self) -> int: + """Number of currently live entries.""" + return sum(1 for _layer, _runtime in self.iter_live_items()) + + def layer_ids(self) -> tuple[int, ...]: + """All currently registered entry ids, including stale/dead ones.""" + return tuple(self._entries_by_id.keys()) + + def is_managed(self, layer: Any) -> bool: + """Whether this exact live layer object is currently registered and live.""" + entry = self._entries_by_id.get(id(layer)) + if entry is None: + return False + resolved = entry.resolve_layer() + return resolved is layer + + def contains_layer_id(self, layer_id: int) -> bool: + """Whether a registry entry exists for this id (live or stale).""" + return layer_id in self._entries_by_id + + def register(self, layer: Any, runtime: ManagedPointsRuntime[StoreT]) -> None: + layer_id = id(layer) + if layer_id in self._entries_by_id: + raise ValueError(f"Layer already registered: id={layer_id}") + + if runtime.layer_id != layer_id: + raise ValueError(f"Runtime layer_id mismatch: runtime.layer_id={runtime.layer_id}, actual={layer_id}") + + try: + layer_ref: weakref.ReferenceType[Any] | None = weakref.ref(layer) + strong_layer = None + except TypeError: + logger.error("Could not cleanly register layer as a weakref; storing strong reference instead: %r", layer) + # Fallback for objects that do not support weakref. + # This means the registry *will* strongly hold such layers. + layer_ref = None + strong_layer = layer + + self._entries_by_id[layer_id] = _RegistryEntry( + layer_id=layer_id, + layer_ref=layer_ref, + strong_layer=strong_layer, + runtime=runtime, + ) + + def unregister(self, layer_or_id: Any) -> ManagedPointsRuntime[StoreT] | None: + entry = self._entries_by_id.pop(self._coerce_layer_id(layer_or_id), None) + return None if entry is None else entry.runtime + + # ------------------------------------------------------------------ # + # centralized live resolution # + # ------------------------------------------------------------------ # + + def resolve_live_layer(self, layer_or_id: Any) -> Any | None: + """Resolve a currently live layer object from a layer or layer id.""" + entry = self._entries_by_id.get(self._coerce_layer_id(layer_or_id)) + if entry is None: + return None + return entry.resolve_layer() + + def get_live_runtime(self, layer_or_id: Any) -> ManagedPointsRuntime[StoreT] | None: + """Return runtime only if the corresponding layer is currently live.""" + entry = self._entries_by_id.get(self._coerce_layer_id(layer_or_id)) + if entry is None: + return None + if entry.resolve_layer() is None: + return None + return entry.runtime + + def require_live_runtime(self, layer_or_id: Any) -> ManagedPointsRuntime[StoreT]: + runtime = self.get_live_runtime(layer_or_id) + if runtime is None: + raise KeyError(f"Managed live runtime not found: {layer_or_id!r}") + return runtime + + def get_store(self, layer_or_id: Any) -> StoreT | None: + runtime = self.get_live_runtime(layer_or_id) + return None if runtime is None else runtime.store + + def require_store(self, layer_or_id: Any) -> StoreT: + runtime = self.require_live_runtime(layer_or_id) + return runtime.store + + # ------------------------------------------------------------------ # + # live iteration # + # 
------------------------------------------------------------------ # + + def iter_live_items(self) -> Iterator[tuple[Any, ManagedPointsRuntime[StoreT]]]: + """Yield only currently live (layer, runtime) pairs.""" + for entry in list(self._entries_by_id.values()): + layer = entry.resolve_layer() + if layer is not None: + yield layer, entry.runtime + + def iter_live_layers(self) -> Iterator[Any]: + for layer, _runtime in self.iter_live_items(): + yield layer + + def iter_live_runtimes(self) -> Iterator[ManagedPointsRuntime[StoreT]]: + for _layer, runtime in self.iter_live_items(): + yield runtime + + # ------------------------------------------------------------------ # + # dead-entry handling / reporting # + # ------------------------------------------------------------------ # + + def dead_layer_ids(self) -> tuple[int, ...]: + """Return ids whose registered layer object is no longer live.""" + dead: list[int] = [] + for layer_id, entry in self._entries_by_id.items(): + if entry.resolve_layer() is None: + dead.append(layer_id) + return tuple(dead) + + def clear_dead_entries(self, *, log: bool = True) -> tuple[ClearedRegistryEntry[StoreT], ...]: + """Remove dead entries and return what was reaped. + + This is intentionally observable so lifecycle cleanup bugs are not silently hidden. + """ + reaped: list[ClearedRegistryEntry[StoreT]] = [] + + for layer_id in list(self.dead_layer_ids()): + entry = self._entries_by_id.pop(layer_id, None) + if entry is None: + continue + + item = ClearedRegistryEntry(layer_id=layer_id, runtime=entry.runtime) + reaped.append(item) + + if log: + logger.warning( + "Cleared dead managed layer entry without explicit unregister: layer_id=%s", + layer_id, + ) + + return tuple(reaped) + + # ------------------------------------------------------------------ # + # diagnostics / auditing # + # ------------------------------------------------------------------ # + + def audit(self) -> RegistryAuditReport: + issues: list[RegistryAuditIssue] = [] + live_count = 0 + dead_count = 0 + seen_ids: set[int] = set() + + for layer_id, entry in self._entries_by_id.items(): + if layer_id in seen_ids: + issues.append( + RegistryAuditIssue( + code="duplicate-layer-id", + message="Duplicate registry layer id detected", + layer_id=layer_id, + ) + ) + seen_ids.add(layer_id) + + resolved = entry.resolve_layer() + if resolved is None: + dead_count += 1 + issues.append( + RegistryAuditIssue( + code="dead-entry", + message="Managed entry has no live layer", + layer_id=layer_id, + ) + ) + continue + + live_count += 1 + + if entry.runtime.layer_id != layer_id: + issues.append( + RegistryAuditIssue( + code="runtime-layer-id-mismatch", + message=( + f"Runtime layer_id ({entry.runtime.layer_id}) does not match registry entry id ({layer_id})" + ), + layer_id=layer_id, + ) + ) + + return RegistryAuditReport( + live_count=live_count, + dead_count=dead_count, + issues=tuple(issues), + ) + + def assert_consistent(self) -> None: + report = self.audit() + assert not report.issues, f"Registry consistency issues: {report.issues!r}" + + # ------------------------------------------------------------------ # + # misc # + # ------------------------------------------------------------------ # + + def clear(self) -> None: + self._entries_by_id.clear() + + @staticmethod + def _coerce_layer_id(layer_or_id: Any) -> int: + if isinstance(layer_or_id, int): + return layer_or_id + return id(layer_or_id) diff --git a/src/napari_deeplabcut/core/layer_lifecycle/spawn.py b/src/napari_deeplabcut/core/layer_lifecycle/spawn.py 
new file mode 100644 index 00000000..4ef79208 --- /dev/null +++ b/src/napari_deeplabcut/core/layer_lifecycle/spawn.py @@ -0,0 +1,61 @@ +# src/napari_deeplabcut/core/layer_lifecycle/spawn.py +from __future__ import annotations + +import threading +import weakref + +from qtpy.QtCore import QObject + +from .manager import LayerLifecycleManager + +_MANAGER_REGISTRY: weakref.WeakKeyDictionary[object, LayerLifecycleManager] = weakref.WeakKeyDictionary() +_MANAGER_LOCK = threading.RLock() +_VIEWER_ATTR = "_ndlc_layer_manager" + + +def _viewer_qparent(viewer) -> QObject | None: + try: + window = getattr(viewer, "window", None) + qt_window = getattr(window, "_qt_window", None) + return qt_window if isinstance(qt_window, QObject) else None + except Exception: + return None + + +def get_layer_manager(viewer) -> LayerLifecycleManager | None: + with _MANAGER_LOCK: + mgr = _MANAGER_REGISTRY.get(viewer) + if mgr is not None: + return mgr + + try: + mgr = getattr(viewer, _VIEWER_ATTR, None) + except Exception: + mgr = None + + if mgr is not None and getattr(mgr, "viewer", None) is viewer: + _MANAGER_REGISTRY[viewer] = mgr + return mgr + + return None + + +def get_or_create_layer_manager(viewer) -> LayerLifecycleManager: + with _MANAGER_LOCK: + mgr = get_layer_manager(viewer) + if mgr is not None: + return mgr + + mgr = LayerLifecycleManager( + viewer=viewer, + parent=_viewer_qparent(viewer), + ) + mgr.attach() + + _MANAGER_REGISTRY[viewer] = mgr + try: + setattr(viewer, _VIEWER_ATTR, mgr) + except Exception: + pass + + return mgr diff --git a/src/napari_deeplabcut/core/layers.py b/src/napari_deeplabcut/core/layers.py index dd6f658d..cb2fa37d 100644 --- a/src/napari_deeplabcut/core/layers.py +++ b/src/napari_deeplabcut/core/layers.py @@ -13,6 +13,7 @@ from napari_deeplabcut.config.models import AnnotationKind, DLCHeaderModel from napari_deeplabcut.core.keypoints import build_color_cycles +from napari_deeplabcut.utils.deprecations import deprecated T = TypeVar("T") @@ -268,12 +269,16 @@ def set_uniform_point_size(layer: Points, size: int) -> None: layer.size = float(size) -def infer_frame_count(layer: Points, *, preferred_paths: list[str] | None = None) -> int: +def infer_frame_count( + layer: Points, *, preferred_paths: list[str] | None = None, fallback_n_frames: int | None = None +) -> int: md = getattr(layer, "metadata", {}) or {} paths = preferred_paths or md.get("paths") or [] if paths: return len(paths) + if fallback_n_frames is not None: + return fallback_n_frames data = np.asarray(getattr(layer, "data", [])) if data.size == 0: @@ -417,7 +422,9 @@ def _iter_labeled_slots(layer: Points): yield (frame, id_text, label) -def compute_label_progress(layer: Points, *, fallback_paths: list[str] | None = None) -> LabelProgress: +def compute_label_progress( + layer: Points, *, fallback_paths: list[str] | None = None, fallback_n_frames: int | None = None +) -> LabelProgress: """ Compute progress for the active napari layer. 
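The `layers.py` hunks here thread an explicit `fallback_n_frames` hint through frame-count inference, so progress can be computed for video-backed layers that carry no `paths` metadata. A quick sketch of the resulting precedence (paths first, then the hint, then the layer data); the bare `Points` layer below is illustrative only:

```python
# Illustrative precedence for infer_frame_count with the new kwarg
# (a bare Points layer with no "paths" metadata; values are made up).
import numpy as np
from napari.layers import Points

from napari_deeplabcut.core.layers import infer_frame_count

layer = Points(np.array([[3.0, 10.0, 12.0]]))

# 1. explicit/metadata paths still win when present
assert infer_frame_count(layer, preferred_paths=["img0.png", "img1.png"]) == 2
# 2. otherwise the new fallback_n_frames hint is used
assert infer_frame_count(layer, fallback_n_frames=120) == 120
```

`compute_label_progress` forwards the same hint, which is what the hunk below wires in.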
@@ -428,7 +435,7 @@ def compute_label_progress(layer: Points, *, fallback_paths: list[str] | None = currently represented in napari: observed_ids × observed_labels """ - frame_count = infer_frame_count(layer, preferred_paths=fallback_paths) + frame_count = infer_frame_count(layer, preferred_paths=fallback_paths, fallback_n_frames=fallback_n_frames) bodypart_count = infer_bodypart_count(layer) individual_count = infer_individual_count(layer) @@ -554,6 +561,7 @@ def infer_folder_display_name( return "—" +@deprecated(details="Use LayerLifecycleManager.active_dlc_image_layer instead") def find_relevant_image_layer(viewer) -> Image | None: active = viewer.layers.selection.active if isinstance(active, Image): diff --git a/src/napari_deeplabcut/core/metadata.py b/src/napari_deeplabcut/core/metadata.py index 36710eb3..916e3802 100644 --- a/src/napari_deeplabcut/core/metadata.py +++ b/src/napari_deeplabcut/core/metadata.py @@ -421,7 +421,7 @@ def attach_source_and_io_to_layer_kwargs( file_path: Path, *, kind: AnnotationKind | None = None, - dataset_key: str = "keypoints", + dataset_key: str = "df_with_missing", ) -> None: """ Attach authoritative source info + IO provenance to napari layer metadata dict. @@ -542,7 +542,7 @@ def _infer_kind_from_source_name(p: Path) -> AnnotationKind | None: def _build_io_from_source_h5( src: str, *, - dataset_key: str = "keypoints", + dataset_key: str = "df_with_missing", ) -> dict[str, Any] | None: """Legacy migration: build io provenance dict from source_h5 string.""" if not isinstance(src, str) or not src: @@ -596,7 +596,7 @@ def _prepare_points_payload( # legacy migration: io from source_h5 if migrate_legacy and not raw.get("io"): src = raw.get("source_h5") - io_dict = _build_io_from_source_h5(src, dataset_key="keypoints") + io_dict = _build_io_from_source_h5(src, dataset_key="df_with_missing") if io_dict: raw["io"] = io_dict @@ -717,7 +717,7 @@ def write_points_meta( # legacy migration (optional): if caller writes anything, keep io stable if migrate_legacy and not merged.get("io") and merged.get("source_h5"): - io_dict = _build_io_from_source_h5(str(merged.get("source_h5")), dataset_key="keypoints") + io_dict = _build_io_from_source_h5(str(merged.get("source_h5")), dataset_key="df_with_missing") if io_dict: merged["io"] = io_dict diff --git a/src/napari_deeplabcut/core/project_paths.py b/src/napari_deeplabcut/core/project_paths.py index 3c7286e8..0cd37aa1 100644 --- a/src/napari_deeplabcut/core/project_paths.py +++ b/src/napari_deeplabcut/core/project_paths.py @@ -665,3 +665,94 @@ def infer_dlc_project_from_image_layer( prefer_project_root=prefer_project_root, max_levels=max_levels, ) + + +# FIXME @C-Achard 2026-04-22 Add unit tests for below funcs +def infer_dlc_project_from_labeled_folder( + folder: str | Path, + *, + prefer_project_root: bool = True, + max_levels: int = 5, +) -> DLCProjectContext: + """ + Infer DLC project context from a labeled-data/ folder. 
+
+    This is the common case for folder-parser workflows:
+    - the dataset folder itself is a valid root anchor
+    - config.yaml may or may not exist above it
+    - if config exists, we may elevate project_root/config_path
+    """
+    dataset_folder = normalize_anchor_candidate(folder)
+    if dataset_folder is None:
+        raise ValueError(f"Could not normalize labeled folder: {folder!r}")
+
+    return infer_dlc_project(
+        anchor_candidates=[dataset_folder],
+        dataset_candidates=[dataset_folder],
+        explicit_root=dataset_folder,
+        prefer_project_root=prefer_project_root,
+        max_levels=max_levels,
+    )
+
+
+def infer_dlc_project_from_video_path(
+    video_path: str | Path,
+    *,
+    max_levels: int = 5,
+) -> DLCProjectContext | None:
+    """
+    Infer DLC project context for a directly opened video.
+
+    Important:
+    ----------
+    A directly opened video should count as a DLC session video only if it can
+    be confidently associated with a DLC project context (typically via a nearby
+    config.yaml).
+    TODO @C-Achard 2026-04-22: Ensure that adding a config.yaml later properly
+    updates the inferred context and associated session image layers.
+
+    Returns
+    -------
+    DLCProjectContext | None
+        None if no recognizable DLC project context can be inferred.
+    """
+    ctx = infer_dlc_project_from_opened(
+        video_path,
+        explicit_root=None,
+        prefer_project_root=True,
+        max_levels=max_levels,
+    )
+
+    # For directly opened videos, require a recognizable project context.
+    if ctx.project_root is None and ctx.config_path is None:
+        return None
+
+    return ctx
+
+
+# -----------------------------------------------------------------------------
+# Lifecycle/session helpers
+# -----------------------------------------------------------------------------
+def session_key_from_project_context(ctx: DLCProjectContext | None) -> str | None:
+    """
+    Build a stable session key from the strongest available project context hint.
+
+    Priority:
+    - project_root
+    - config_path parent (the project root)
+    - dataset_folder
+    - root_anchor
+
+    This is intended for lifecycle/session identity, not for IO routing.
+    """
+    if ctx is None:
+        return None
+
+    # Follow the documented priority; falling back to the config.yaml parent
+    # keeps the key stable when project_root was not elevated explicitly.
+    key = ctx.project_root
+    if key is None and ctx.config_path is not None:
+        key = Path(ctx.config_path).parent
+    if key is None:
+        key = ctx.dataset_folder or ctx.root_anchor
+    if key is None:
+        return None
+
+    try:
+        return str(key.expanduser().resolve())
+    except Exception:
+        return str(key)
diff --git a/src/napari_deeplabcut/core/provenance.py b/src/napari_deeplabcut/core/provenance.py
index 78b9d34b..18c410b1 100644
--- a/src/napari_deeplabcut/core/provenance.py
+++ b/src/napari_deeplabcut/core/provenance.py
@@ -51,7 +51,7 @@ def build_gt_save_target(
     anchor: str,
     scorer: str,
     *,
-    dataset_key: str = "keypoints",
+    dataset_key: str = "df_with_missing",
 ) -> IOProvenance:
     """
     Build a GT save_target pointing to CollectedData_.h5 under a folder anchor.
@@ -72,7 +72,7 @@ def apply_gt_save_target(
     *,
     anchor: str,
     scorer: str,
-    dataset_key: str = "keypoints",
+    dataset_key: str = "df_with_missing",
 ) -> PointsMetadata:
     """
     Return an updated PointsMetadata with a GT promotion save_target attached.
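Note the recurring `dataset_key` default change in `metadata.py` and `provenance.py`: "df_with_missing" is the legacy HDFStore key DeepLabCut uses when writing `CollectedData_*.h5`, so provenance defaults now match what is actually on disk. A minimal sketch of the round-trip this keeps consistent (the file name is a placeholder):

```python
# Sketch: the dataset_key recorded in IO provenance should match the HDFStore
# key used on disk, so a plain pandas round-trip works. File names are
# placeholders.
import pandas as pd

df = pd.read_hdf("CollectedData_scorer.h5", key="df_with_missing")
df.to_hdf("CollectedData_scorer_edited.h5", key="df_with_missing", mode="w")
```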
diff --git a/src/napari_deeplabcut/napari.yaml b/src/napari_deeplabcut/napari.yaml
index 48883029..5da89702 100644
--- a/src/napari_deeplabcut/napari.yaml
+++ b/src/napari_deeplabcut/napari.yaml
@@ -23,6 +23,9 @@ contributions:
     - id: napari-deeplabcut.make_keypoint_controls
       python_name: napari_deeplabcut._widgets:KeypointControls
       title: Make keypoint controls
+    - id: napari-deeplabcut.make_tracking_controls
+      python_name: napari_deeplabcut.tracking._widgets:TrackingControls
+      title: Make the tracking controls widget
   readers:
     - command: napari-deeplabcut.get_hdf_reader
       accepts_directories: false
@@ -46,3 +49,5 @@
   widgets:
     - command: napari-deeplabcut.make_keypoint_controls
       display_name: Keypoint controls
+    - command: napari-deeplabcut.make_tracking_controls
+      display_name: Tracking controls
diff --git a/src/napari_deeplabcut/tracking/README.md b/src/napari_deeplabcut/tracking/README.md
new file mode 100644
index 00000000..093c168e
--- /dev/null
+++ b/src/napari_deeplabcut/tracking/README.md
@@ -0,0 +1,39 @@
+# Point-tracker assisted labeling
+
+**EXPERIMENTAL FEATURE**
+
+This feature lets you speed up labeling by propagating annotations with a simple point tracker.
+Please note this is still a very basic implementation, and **it WILL overwrite existing annotations**
+in your napari-DLC Points layer!
+
+Always **back up your data** before using this feature, and **try it on a copy of your data first.**
+We cannot be held responsible for any lost annotations!
+
+The intended workflow is to annotate a single frame, run the tracker to propagate the annotations,
+and manually correct any mistakes before saving.
+
+Based on interest, we may polish the user experience and add more advanced tracking algorithms in the future.
+
+**Basic usage:**
+
+- The tracking widget is opened via the "Plugins > napari-deeplabcut: Tracking controls" menu.
+- In the layer selection lists, select both the video layer and the Points layer to be used for tracking.
+- Select the starting frame for tracking by moving the viewer slider to the desired frame.
+- Select how many frames to track forward and backward, either relative to the current frame or in absolute terms (the controls labeled Rel and Abs, respectively).
+- Use the track forward/backward/both buttons to run the tracker.
+
+**Key bindings:**
+
+- Tracking Controls
+  - **`l`** → Track **forward**
+  - **`k`** → Track **forward (to end)**
+  - **`h`** → Track **backward**
+  - **`j`** → Track **backward (to end)**
+
+- Frame Navigation
+  - **`i`** → Move **forward one frame**
+  - **`u`** → Move **backward one frame**
+
+**Known issues:**
+- After several runs, keypoint attributions may get shuffled. Do not run the tracker repeatedly without checking the results in between.
+- Tracking can only be run on plugin-managed Points layers; a Points layer created manually cannot be tracked.
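For completeness, the README workflow can also be driven from a script; a hedged sketch using napari's `add_plugin_dock_widget` API and the `display_name` registered in `napari.yaml` above (the video path is a placeholder):

```python
# Sketch: open the tracking dock widget programmatically. The video path is
# a placeholder; "Tracking controls" matches the display_name added to
# napari.yaml in this diff.
import napari

viewer = napari.Viewer()
viewer.open("path/to/videos/session1.mp4", plugin="napari-deeplabcut")

dock, tracking_controls = viewer.window.add_plugin_dock_widget(
    "napari-deeplabcut", widget_name="Tracking controls"
)

napari.run()
```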
diff --git a/src/napari_deeplabcut/tracking/_widgets.py b/src/napari_deeplabcut/tracking/_widgets.py new file mode 100644 index 00000000..12069c9b --- /dev/null +++ b/src/napari_deeplabcut/tracking/_widgets.py @@ -0,0 +1,848 @@ +import logging +from copy import deepcopy +from functools import partial + +import napari +import numpy as np +import pandas as pd +from magicgui.widgets import ComboBox, create_widget +from napari.layers import Image, Points +from napari.viewer import Viewer +from qtpy.QtCore import Qt, Signal, Slot +from qtpy.QtGui import QIcon +from qtpy.QtWidgets import ( + QApplication, + QComboBox, + QGridLayout, + QHBoxLayout, + QLabel, + QProgressBar, + QPushButton, + QSlider, + QSpinBox, + QStyle, + QToolButton, + QVBoxLayout, + QWidget, +) + +from napari_deeplabcut._widgets import KeypointControls +from napari_deeplabcut.config.keybinds import ( + MOVE_BACKWARD_FRAME, + MOVE_FORWARD_FRAME, + TRACK_BACKWARD, + TRACK_BACKWARD_END, + TRACK_FORWARD, + TRACK_FORWARD_END, +) + +# Keybinds +from napari_deeplabcut.config.settings import TRACKING_SHORTCUTS_ENABLED +from napari_deeplabcut.core.keypoints import KeypointStore +from napari_deeplabcut.core.layer_lifecycle import get_or_create_layer_manager +from napari_deeplabcut.tracking.core.data import ( + TrackingWorkerData, + TrackingWorkerOutput, + add_query_identity_columns, + build_tracking_result_metadata, +) +from napari_deeplabcut.tracking.core.models import AVAILABLE_TRACKERS +from napari_deeplabcut.tracking.ui.worker import TrackingWorker + +logger = logging.getLogger(__name__) +# TODO @C-Achard: fix the sliders sync not firing (on existing layers ?) + + +class TrackingControls(QWidget): + trackingRequested = Signal(object) + trackedKeypointsAdded = Signal() + + def __init__(self, viewer: "napari.viewer.Viewer"): + super().__init__() + self._viewer: Viewer = viewer + self.lifecycle_manager = get_or_create_layer_manager(viewer) + # self.setObjectName("napari-deeplabcut-tracking-controls") + self.setProperty("ndlc_tracking_controls", True) + + # Layout + ## Data and model selection + self._tracking_method_combo = QComboBox() + self._model_info_button = QToolButton() + self._model_info_button.setIcon(QIcon.fromTheme("help-about")) + self._model_info_button.setIconSize(self._model_info_button.iconSize() * 1.2) + self._keypoint_layer_combo: ComboBox = create_widget(annotation=Points) + self._video_layer_combo: ComboBox = create_widget(annotation=Image) + self._video_layer_combo.changed.connect(self._video_layer_changed) + ## Frame selection controls + self._set_ref_button = QPushButton() + self._reference_spinbox = QSpinBox() + self._reference_spinbox.setReadOnly(True) + self._reference_spinbox.setButtonSymbols(QSpinBox.NoButtons) + self._updating_controls = False + ### Backward + self._backward_slider = QSlider(Qt.Horizontal) + self._backward_spinbox_absolute = QSpinBox() + self._backward_spinbox_relative = QSpinBox() + ### Forward + self._forward_slider = QSlider(Qt.Horizontal) + self._forward_spinbox_absolute = QSpinBox() + self._forward_spinbox_relative = QSpinBox() + ## Tracking controls + self._tracking_stop_button = QPushButton() + self._tracking_forward_button = QPushButton() + self._tracking_forward_button.clicked.connect(self.track_forward) + self._tracking_forward_end_button = QPushButton() + self._tracking_forward_end_button.clicked.connect(self.track_forward_end) + self._tracking_backward_button = QPushButton() + self._tracking_backward_end_button = QPushButton() + 
self._tracking_backward_end_button.clicked.connect(self.track_backward_end) + self._tracking_backward_button.clicked.connect(self.track_backward) + self._tracking_bothway_button = QPushButton() + self._tracking_bothway_button.clicked.connect(self.track_bothway) + self._tracking_progress_bar = QProgressBar() + + # Controls + ## Forward controls + self._forward_slider.valueChanged.connect(partial(self._forward_update, from_absolute=False, from_slider=True)) + self._forward_spinbox_relative.valueChanged.connect( + partial(self._forward_update, from_absolute=False, from_slider=False) + ) + self._forward_spinbox_absolute.valueChanged.connect( + partial(self._forward_update, from_absolute=True, from_slider=False) + ) + ## Backward controls + self._backward_slider.valueChanged.connect( + partial(self._backward_update, from_absolute=False, from_slider=True) + ) + self._backward_spinbox_relative.valueChanged.connect( + partial(self._backward_update, from_absolute=False, from_slider=False) + ) + self._backward_spinbox_absolute.valueChanged.connect( + partial(self._backward_update, from_absolute=True, from_slider=False) + ) + + # when the range of viewer dims changes (e.g. on opening a new video), update the reference spinbox max + self._viewer.dims.events.current_step.connect(lambda e: self._reference_spinbox.setValue(int(e.value[0]))) + self._viewer.dims.events.current_step.connect(self._set_frame_controls_range) + self._reference_spinbox.valueChanged.connect(self._set_frame_controls_range) + + # Worker + self.is_tracking = False + self.worker_started = False + self.worker: TrackingWorker | None = None + + # Reference to the keypoint control widget. + # this gets assigned after the user requests tracking for the first time. + self.keypoint_widget: KeypointControls | None = None + + self._setup_keybindings(viewer=viewer) + + self._build_layout() + + def _set_model_info_tooltip(self, current_model_name: str = None): + """Retrieves the display info for the selected model and sets it as tooltip for the model info button.""" + tracker_info = AVAILABLE_TRACKERS.get(current_model_name, None) + tracker_info = tracker_info["class"].info_text if tracker_info is not None else None + if tracker_info is not None: + self._model_info_button.setToolTip(tracker_info) + else: + self._model_info_button.setToolTip("") + + def _set_tooltips(self): + self._tracking_forward_button.setToolTip(TRACK_FORWARD.get_display()) + self._tracking_forward_end_button.setToolTip(TRACK_FORWARD_END.get_display()) + self._tracking_backward_button.setToolTip(TRACK_BACKWARD.get_display()) + self._tracking_backward_end_button.setToolTip(TRACK_BACKWARD_END.get_display()) + self._tracking_bothway_button.setToolTip("Track both ways") + self._tracking_stop_button.setToolTip("Stop tracking") + self._set_ref_button.setToolTip("Set reference frame") + + def _dock_widget(self): + try: + for dock in self._viewer.window._dock_widgets.values(): + if dock.widget() is self: + return dock + except Exception: + return None + return None + + def _tracking_shortcuts_active(self) -> bool: + if not TRACKING_SHORTCUTS_ENABLED: + return False + dock = self._dock_widget() + return dock is not None and dock.isVisible() + + def _setup_keybindings(self, viewer: "napari.viewer.Viewer"): + if not TRACKING_SHORTCUTS_ENABLED: + self._set_tooltips() + return + + @Points.bind_key(TRACK_FORWARD.key, overwrite=True) + def track_forward(event): + if not self._tracking_shortcuts_active(): + return + self.track_forward() + + @Points.bind_key(TRACK_FORWARD_END.key, 
overwrite=True) + def track_forward_end(event): + if not self._tracking_shortcuts_active(): + return + self.track_forward_end() + + @Points.bind_key(TRACK_BACKWARD.key, overwrite=True) + def track_backward(event): + if not self._tracking_shortcuts_active(): + return + self.track_backward() + + @Points.bind_key(TRACK_BACKWARD_END.key, overwrite=True) + def track_backward_end(event): + if not self._tracking_shortcuts_active(): + return + self.track_backward_end() + + @Points.bind_key(MOVE_FORWARD_FRAME.key, overwrite=True) + def move_forward_frame(event): + if not self._tracking_shortcuts_active(): + return + viewer.dims.current_step = ( + viewer.dims.current_step[0] + 1, + *viewer.dims.current_step[1:], + ) + + @Points.bind_key(MOVE_BACKWARD_FRAME.key, overwrite=True) + def move_backward_frame(event): + if not self._tracking_shortcuts_active(): + return + viewer.dims.current_step = ( + viewer.dims.current_step[0] - 1, + *viewer.dims.current_step[1:], + ) + + self._set_tooltips() + + def _update_frame_controls( + self, + slider, + relative_spinbox, + absolute_spinbox, + reference_spinbox, + value, + direction, + from_absolute=False, + from_slider=False, + ): + """ + Generic function to update slider, relative spinbox, and absolute spinbox. + + Parameters: + - slider: The slider widget. + - relative_spinbox: The relative spinbox widget. + - absolute_spinbox: The absolute spinbox widget. + - reference_spinbox: The reference spinbox widget. + - value: The new value to set. + - direction: "forward" or "backward". + - from_absolute: Whether the update is triggered by the absolute spinbox. + - from_slider: Whether the update is triggered by the slider. + """ + if self._updating_controls: + return + self._updating_controls = True + try: + if from_absolute: + # Update relative and slider from absolute spinbox + relative_value = value - reference_spinbox.value() + if direction == "forward": + relative_value = max(0, relative_value) + else: # backward + relative_value = min(0, relative_value) + relative_spinbox.setValue(relative_value) + slider.setValue(relative_value) + elif from_slider: + # Update relative and absolute spinboxes from slider + relative_spinbox.setValue(value) + absolute_spinbox.setValue(reference_spinbox.value() + value) + else: + # Update slider and absolute spinbox from relative spinbox + slider.setValue(value) + absolute_spinbox.setValue(reference_spinbox.value() + value) + finally: + self._updating_controls = False + + def _seed_query_points_and_features( + self, + ref_frame_idx: int, + ) -> tuple[np.ndarray, pd.DataFrame]: + """ + Extract the points/features from the chosen reference frame and attach + stable query identity columns before sending them to the model. + """ + layer = self.keypoint_layer + if layer is None: + raise ValueError("No keypoint layer selected.") + + mask = np.asarray(layer.data[:, 0]).astype(int) == int(ref_frame_idx) + + keypoints = np.asarray(layer.data[mask], dtype=float).copy() + if len(keypoints) == 0: + raise ValueError(f"No keypoints found on reference frame {ref_frame_idx}.") + + layer_features = layer.features + if isinstance(layer_features, pd.DataFrame): + seed_features = layer_features.loc[mask].reset_index(drop=True).copy() + else: + seed_features = pd.DataFrame(layer_features).loc[mask].reset_index(drop=True).copy() + + if len(seed_features) != len(keypoints): + raise ValueError( + f"Seed feature row count mismatch: got {len(seed_features)} feature rows " + f"for {len(keypoints)} keypoints." 
+ ) + + seed_features = add_query_identity_columns( + seed_features, + query_frame=ref_frame_idx, + source_layer_name=layer.name, + ) + + # In the sliced tracking video, the query frame is always time 0 + keypoints[:, 0] = 0.0 + + return keypoints, seed_features + + def _create_tracking_result_layer( + self, + keypoints: np.ndarray, + features: pd.DataFrame, + *, + tracker_name: str, + ref_frame_idx: int, + ) -> Points: + """ + Create a new Points layer holding the tracking result. + This must NOT modify the original DLC annotation layer. + """ + source = self.keypoint_layer + if source is None: + raise ValueError("No source keypoint layer selected.") + + metadata = build_tracking_result_metadata( + source.metadata, + tracker_name=tracker_name, + source_layer_name=source.name, + query_frame=ref_frame_idx, + ) + + layer = self._viewer.add_points( + data=keypoints, + features=features, + name=f"[Tracked] {source.name}", + metadata=metadata, + ) + + # Distinguish tracking results visually + try: + layer.symbol = "cross" + except Exception: + pass + + try: + layer.opacity = 0.85 + except Exception: + pass + + try: + layer.size = deepcopy(source.size) + except Exception: + pass + + try: + # Optional: keep source colors if useful, but still visually distinct + layer.face_color = deepcopy(source.face_color) + layer.face_color_mode = source.face_color_mode + except Exception: + pass + + try: + layer.border_width = 0.15 + layer.border_color = "green" + except Exception: + pass + + return layer + + def _forward_update(self, value: int, from_absolute: bool, from_slider: bool): + """Helper to update forward controls. + + Parameters: + - value: The new value to set. + - from_absolute: Whether the update is triggered by the absolute spinbox. + - from_slider: Whether the update is triggered by the slider. + """ + self._update_frame_controls( + slider=self._forward_slider, + relative_spinbox=self._forward_spinbox_relative, + absolute_spinbox=self._forward_spinbox_absolute, + reference_spinbox=self._reference_spinbox, + value=value, + direction="forward", + from_absolute=from_absolute, + from_slider=from_slider, + ) + + def _backward_update(self, value: int, from_absolute: bool, from_slider: bool): + """Helper to update backward controls. + + Parameters: + - value: The new value to set. + - from_absolute: Whether the update is triggered by the absolute spinbox. + - from_slider: Whether the update is triggered by the slider. 
+ """ + self._update_frame_controls( + slider=self._backward_slider, + relative_spinbox=self._backward_spinbox_relative, + absolute_spinbox=self._backward_spinbox_absolute, + reference_spinbox=self._reference_spinbox, + value=value, + direction="backward", + from_absolute=from_absolute, + from_slider=from_slider, + ) + + @Slot() + def _set_frame_controls_range(self): + if self._updating_controls: + return + self._updating_controls = True + try: + if self.video_layer is None: + return + + max_frames = self.video_layer.data.shape[0] - 1 + logger.debug(f"Updating tracking controls for video with {max_frames + 1} frames.") + current_frame = max(0, min(self._viewer.dims.current_step[0], max_frames)) + logger.debug(f"Current frame: {current_frame}") + self._reference_spinbox.setRange(0, max_frames) + self._reference_spinbox.setValue(current_frame) + + forward_delta = self._forward_spinbox_relative.value() + self._forward_slider.setRange(0, max_frames - current_frame) + self._forward_slider.setValue(forward_delta) + self._forward_spinbox_relative.setRange(0, max_frames - current_frame) + self._forward_spinbox_relative.setValue(forward_delta) + self._forward_spinbox_absolute.setRange(current_frame, max_frames) + self._forward_spinbox_absolute.setValue(current_frame + forward_delta) # see _forward_update + + backward_delta = self._backward_spinbox_relative.value() + self._backward_slider.setRange(-current_frame, 0) + self._backward_slider.setValue(backward_delta) + self._backward_spinbox_relative.setRange(-current_frame, 0) + self._backward_spinbox_relative.setValue(backward_delta) + self._backward_spinbox_absolute.setRange(0, current_frame) + self._backward_spinbox_absolute.setValue( + current_frame + self._backward_spinbox_relative.value() + ) # see _backward_update + + finally: + self._updating_controls = False + + def _start_worker(self): + self.is_tracking = False + self.worker_started = False + self.worker = TrackingWorker() + + # Explicit queued connections for cross-thread delivery back to this QWidget + self.worker.trackingStarted.connect(self.tracking_started, Qt.QueuedConnection) + self.worker.started.connect(self._on_worker_started, Qt.QueuedConnection) + self.worker.finished.connect(self._on_worker_finished, Qt.QueuedConnection) + self.worker.progress.connect(self._on_worker_progress, Qt.QueuedConnection) + self.worker.trackingFinished.connect(self.tracking_finished, Qt.QueuedConnection) + self.worker.trackingStopped.connect(self.tracking_stopped, Qt.QueuedConnection) + + # Main thread -> worker thread + self.trackingRequested.connect(self.worker.track, Qt.QueuedConnection) + + # Main-thread button click, no UI work in worker + self._tracking_stop_button.clicked.connect(self._request_worker_stop) + + self.worker.start() + + def _request_worker_stop(self): + if self.worker is not None: + self.worker.request_stop() + + def _debug_thread(self, where: str) -> None: + import threading + + from qtpy.QtCore import QThread + + logger.debug( + "%s | python_thread=%s | qt_current_thread=%r | widget_thread=%r", + where, + threading.current_thread().name, + QThread.currentThread(), + self.thread, + ) + + @Slot() + def _on_worker_started(self): + self._debug_thread("_on_worker_started") + self.worker_started = True + + @Slot() + def _on_worker_finished(self): + self._debug_thread("_on_worker_finished") + self.worker_started = False + + @Slot(int, int) + def _on_worker_progress(self, current: int, total: int): + self._debug_thread("_on_worker_progress") + if self._tracking_progress_bar.maximum() != 
total: + self._tracking_progress_bar.setMaximum(total) + self._tracking_progress_bar.setValue(current) + + @property + def keypoint_layer(self) -> Points | None: + return self._keypoint_layer_combo.value + + @property + def keypoint_store(self) -> KeypointStore | None: + return self.keypoint_widget._stores.get(self.keypoint_layer) if self.keypoint_widget else None + + @property + def video_layer(self) -> Image | None: + return self._video_layer_combo.value + + @Slot() + def _video_layer_changed(self): + if self._viewer.dims.ndim != 3: + return + self._set_frame_controls_range() + + @Slot() + def tracking_started(self): + self.is_tracking = True + self._tracking_progress_bar.setValue(0) + + @Slot(object) + def tracking_finished(self, out: TrackingWorkerOutput): + self.is_tracking = False + try: + new_features_df = ( + out.keypoint_features + if isinstance(out.keypoint_features, pd.DataFrame) + else pd.DataFrame(out.keypoint_features) + ) + + ref_frame_idx = ( + int(new_features_df["tracking_query_frame"].iloc[0]) + if "tracking_query_frame" in new_features_df.columns and len(new_features_df) + else int(self._reference_spinbox.value()) + ) + + layer = self._create_tracking_result_layer( + out.keypoints, + new_features_df, + tracker_name=self._tracking_method_combo.currentText(), + ref_frame_idx=ref_frame_idx, + ) + + self._viewer.layers.selection.active = layer + self._viewer.status = f'Created tracking result layer "{layer.name}"' + except Exception as e: + logger.exception("Error creating tracking result layer", exc_info=e) + + self._tracking_progress_bar.setValue(self._tracking_progress_bar.maximum()) + self._tracking_progress_bar.setFormat("%p% Done") + self.trackedKeypointsAdded.emit() + + @Slot() + def tracking_stopped(self): + self.is_tracking = False + self._tracking_progress_bar.setValue(self._tracking_progress_bar.maximum()) + self._tracking_progress_bar.setFormat("%p% Stopped") + + # def add_keypoints_to_layer(self, new_keypoints: np.ndarray, new_features: pd.DataFrame): + # current_keypoints = self.keypoint_layer.data + # current_features: pd.DataFrame = self.keypoint_layer.features + + # # Extract unique frame indices + # unique_frames = np.sort(np.unique(np.concatenate((current_keypoints[:, 0], new_keypoints[:, 0])))) + + # merged_keypoints = [] + # merged_features = [] + + # for frame in unique_frames: + # # Select keypoints and features for the current frame + # frame_old_keypoints = current_keypoints[current_keypoints[:, 0] == frame] + # frame_old_features = current_features[current_keypoints[:, 0] == frame] + + # frame_new_keypoints = new_keypoints[new_keypoints[:, 0] == frame] + # frame_new_features = new_features[new_keypoints[:, 0] == frame] + + # # Here we can add custom logic when merging. Right now we overwrite any previous keypoints. 
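+ # # (One possible refinement, sketched only: prefer whichever row carries the higher "likelihood" feature per bodypart, instead of always overwriting with the new keypoints.)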
+ # if len(frame_new_keypoints) > 0: + # # If there are keypoints in new, take those + # merged_keypoints.append(frame_new_keypoints) + # merged_features.append(frame_new_features) + # else: + # merged_keypoints.append(frame_old_keypoints) + # merged_features.append(frame_old_features) + + # merged_keypoints = ( + # np.vstack(merged_keypoints) if merged_keypoints else np.empty((0, current_keypoints.shape[1])) + # ) + # merged_feature_df = pd.concat(merged_features, ignore_index=True) + + # self.keypoint_layer.data = merged_keypoints + # self.keypoint_layer.features = merged_feature_df + + @Slot() + def track_forward(self): + ref_frame_idx: int = self._reference_spinbox.value() + forward_frame_idx: int = self._forward_spinbox_absolute.value() + if forward_frame_idx <= ref_frame_idx: + return + self.track( + (ref_frame_idx, forward_frame_idx + 1), + ref_frame_idx, + backward_tracking=False, + ) + + @Slot() + def track_forward_end(self): + if self.video_layer is None: + return + ref_frame_idx: int = self._reference_spinbox.value() + forward_frame_idx: int = self.video_layer.data.shape[0] - 1 + if forward_frame_idx <= ref_frame_idx: + return + self.track( + (ref_frame_idx, forward_frame_idx + 1), + ref_frame_idx, + backward_tracking=False, + ) + + @Slot() + def track_backward(self): + ref_frame_idx: int = self._reference_spinbox.value() + backward_frame_idx: int = self._backward_spinbox_absolute.value() + if backward_frame_idx >= ref_frame_idx: + return + logger.debug(f"Tracking backward from {ref_frame_idx} to {backward_frame_idx}") + self.track( + (backward_frame_idx, ref_frame_idx + 1), + ref_frame_idx, + backward_tracking=True, + ) + + @Slot() + def track_backward_end(self): + ref_frame_idx: int = self._reference_spinbox.value() + backward_frame_idx: int = 0 + if backward_frame_idx >= ref_frame_idx: + return + self.track( + (backward_frame_idx, ref_frame_idx + 1), + ref_frame_idx, + backward_tracking=True, + ) + + @Slot() + def track_bothway(self): + # if forward target is invalid, go directly backward + ref = self._reference_spinbox.value() + fwd = self._forward_spinbox_absolute.value() + if fwd <= ref: + self.track_backward() + return + + self.track_forward() + self.trackedKeypointsAdded.connect(self.track_backward, type=Qt.ConnectionType.SingleShotConnection) + + def track(self, keypoint_range: tuple[int, int], ref_frame_idx, backward_tracking=False): + if not self.worker_started: + self._start_worker() + + if not self.keypoint_widget: + for k, v in self._viewer.window._dock_widgets.items(): + if "Keypoint controls" in k and "napari-deeplabcut" in k: + self.keypoint_widget = v.widget() + break + + if self.is_tracking: + return + + if self.video_layer is None: + logger.warning("No video layer selected.") + return + + if self.keypoint_layer is None: + logger.warning("No keypoint layer selected.") + return + + self._tracking_progress_bar.setFormat("%p%") + + if backward_tracking: + video_slice = self.video_layer.data[keypoint_range[0] : keypoint_range[1]][::-1] + else: + video_slice = self.video_layer.data[keypoint_range[0] : keypoint_range[1]] + + try: + seed_keypoints, seed_features = self._seed_query_points_and_features(ref_frame_idx) + except ValueError as e: + logger.warning(str(e)) + napari.utils.notifications.show_warning(str(e)) + return + + tracking_data = TrackingWorkerData( + tracker_name=self._tracking_method_combo.currentText(), + video=video_slice, + keypoints=seed_keypoints, + keypoint_features=seed_features, + keypoint_range=keypoint_range, + backward_tracking=backward_tracking, + 
reference_frame_index=int(ref_frame_idx), + ) + self.trackingRequested.emit(tracking_data) + + def _build_layout(self): + # Layout + self.setLayout(QVBoxLayout()) + ## Model selection + self._tracking_method_combo.addItems(AVAILABLE_TRACKERS.keys()) + self._tracking_method_combo.setCurrentIndex(0) + + _model_info_layout = QHBoxLayout() + _model_info_layout.addWidget(QLabel("Tracker")) + _model_info_layout.addWidget(self._model_info_button) + + _tracking_method_layout = QHBoxLayout() + _tracking_method_layout.addLayout(_model_info_layout) + _tracking_method_layout.addWidget(self._tracking_method_combo) + self._tracking_method_combo.currentTextChanged.connect(self._set_model_info_tooltip) + self._set_model_info_tooltip(self._tracking_method_combo.currentText()) + ## Layer selection + ### Keypoint layer + self._viewer.layers.events.inserted.connect(self._keypoint_layer_combo.reset_choices) + self._viewer.layers.events.removed.connect(self._keypoint_layer_combo.reset_choices) + self._viewer.layers.events.reordered.connect(self._keypoint_layer_combo.reset_choices) + _keypoint_layer_method_layout = QHBoxLayout() + _keypoint_layer_method_layout.addWidget(QLabel("Keypoints")) + _keypoint_layer_method_layout.addWidget(self._keypoint_layer_combo.native) + ### Video layer + self._viewer.layers.events.inserted.connect(self._video_layer_combo.reset_choices) + self._viewer.layers.events.removed.connect(self._video_layer_combo.reset_choices) + self._viewer.layers.events.reordered.connect(self._video_layer_combo.reset_choices) + _video_layer_method_layout = QHBoxLayout() + _video_layer_method_layout.addWidget(QLabel("Video")) + _video_layer_method_layout.addWidget(self._video_layer_combo.native) + + # Stack previous layouts + self.layout().addLayout(_tracking_method_layout) + self.layout().addLayout(_keypoint_layer_method_layout) + self.layout().addLayout(_video_layer_method_layout) + + ## Frame range controls + range_controls_layout = QGridLayout() # 3 by 5 + self._backward_slider.setRange(-100, 0) + # NOTE : check why this is not inverting the slider appearance as expected + # self._backward_slider.setInvertedAppearance(True) + range_controls_layout.addWidget(self._backward_slider, 0, 0, 1, 2) + self._backward_spinbox_absolute.setRange(0, 100) + self._backward_spinbox_absolute.setAlignment(Qt.AlignCenter) + self._backward_spinbox_absolute.setStyleSheet( + """ + QSpinBox { + padding: 0; + } + """ + ) + range_controls_layout.addWidget(self._backward_spinbox_absolute, 1, 1) + range_controls_layout.addWidget(QLabel("<< Abs"), 1, 0) + self._backward_spinbox_relative.setRange(-100, 0) + self._backward_spinbox_relative.setAlignment(Qt.AlignCenter) + self._backward_spinbox_relative.setStyleSheet( + """ + QSpinBox { + padding: 0; + } + """ + ) + range_controls_layout.addWidget(QLabel("<< Rel"), 2, 0) + range_controls_layout.addWidget(self._backward_spinbox_relative, 2, 1) + _ref_label = QLabel("Current") + self._reference_spinbox.setRange(0, 100) + self._reference_spinbox.setAlignment(Qt.AlignCenter) + self._reference_spinbox.setStyleSheet( + """ + QSpinBox { + padding: 0; + } + """ + ) + range_controls_layout.addWidget(self._reference_spinbox, 1, 2) + _ref_label.setAlignment(Qt.AlignCenter) + range_controls_layout.addWidget(_ref_label, 0, 2) + self._forward_slider.setRange(0, 100) + range_controls_layout.addWidget(self._forward_slider, 0, 3, 1, 2) + self._forward_spinbox_absolute.setRange(0, 100) + self._forward_spinbox_absolute.setAlignment(Qt.AlignCenter) + self._forward_spinbox_absolute.setStyleSheet( + """ 
+ QSpinBox { + padding: 0; + } + """ + ) + range_controls_layout.addWidget(QLabel("Abs >>"), 1, 4) + range_controls_layout.addWidget(self._forward_spinbox_absolute, 1, 3) + self._forward_spinbox_relative.setRange(0, 100) + self._forward_spinbox_relative.setAlignment(Qt.AlignCenter) + self._forward_spinbox_relative.setStyleSheet( + """ + QSpinBox { + padding: 0; + } + """ + ) + range_controls_layout.addWidget(QLabel("Rel >>"), 2, 4) + range_controls_layout.addWidget(self._forward_spinbox_relative, 2, 3) + + self.layout().addLayout(range_controls_layout) + + ## Start/stop tracking controls + + def themed_icon(name: str, fallback: QStyle.StandardPixmap) -> QIcon: + """Use napari's QApplication instance to get style and fallback icons.""" + # More consistent than using Unicode characters that may vary in appearance across platforms + # Esp. on high-res displays they are very small (4K laptop screens etc) + style = QApplication.instance().style() # reuse existing app + return QIcon.fromTheme(name, style.standardIcon(fallback)) + + tracking_controls_layout = QGridLayout() # 2 by 3 + # NOTE : leaving previous unicode characters as comments for reference + # self._tracking_backward_button.setText("⇤") + # self._tracking_backward_end_button.setText("⇤⇤") + # self._tracking_stop_button.setText("□") + # self._tracking_forward_button.setText("⇥") + # self._tracking_forward_end_button.setText("⇥⇥") + # self._tracking_bothway_button.setText("↹") + self._tracking_backward_button.setIcon(themed_icon("go-previous", QStyle.SP_ArrowLeft)) + tracking_controls_layout.addWidget(self._tracking_backward_button, 0, 0) + self._tracking_backward_end_button.setIcon(themed_icon("media-seek-backward", QStyle.SP_MediaSeekBackward)) + tracking_controls_layout.addWidget(self._tracking_backward_end_button, 1, 0) + + self._tracking_stop_button.setIcon(themed_icon("media-playback-stop", QStyle.SP_MediaStop)) + tracking_controls_layout.addWidget(self._tracking_stop_button, 0, 1) + + self._tracking_forward_button.setIcon(themed_icon("go-next", QStyle.SP_ArrowRight)) + tracking_controls_layout.addWidget(self._tracking_forward_button, 0, 2) + self._tracking_forward_end_button.setIcon(themed_icon("media-seek-forward", QStyle.SP_MediaSeekForward)) + tracking_controls_layout.addWidget(self._tracking_forward_end_button, 1, 2) + + # NOTE: Find a better icon ? 
Not really any standard icon for "both ways" + self._tracking_bothway_button.setIcon(themed_icon("view-refresh", QStyle.SP_BrowserReload)) + tracking_controls_layout.addWidget(self._tracking_bothway_button, 1, 1) + + self._tracking_progress_bar.setRange(0, 100) + self.layout().addLayout(tracking_controls_layout) + self.layout().addWidget(self._tracking_progress_bar) diff --git a/src/napari_deeplabcut/tracking/core/__init__.py b/src/napari_deeplabcut/tracking/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/tracking/core/data.py b/src/napari_deeplabcut/tracking/core/data.py new file mode 100644 index 00000000..491eb8de --- /dev/null +++ b/src/napari_deeplabcut/tracking/core/data.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +import numpy as np +import pandas as pd + +TRACKING_LAYER_METADATA_KEY = "ndlc_tracking" +TRACKING_SCHEMA_VERSION = 1 + + +# ----- Data schemas ----- +@dataclass +class TrackingWorkerData: + tracker_name: str # model name + video: np.ndarray + keypoints: np.ndarray # (num_keypoint, 3) + # [0]: frame number in `video`, [1]: y, [2]: x (napari row/column order) + + keypoint_features: pd.DataFrame + # one row per query keypoint. Order must be preserved. + + keypoint_range: tuple[int, int] + backward_tracking: bool + reference_frame_index: int | None = None + + +@dataclass(frozen=True) +class TrackingModelInputs: + """Inputs required for tracking model processing.""" + + video: np.ndarray # (num_frames, height, width, channels) + keypoints: np.ndarray # based on model requirements + metadata: dict[str, Any] # Additional metadata if needed + + +@dataclass +class RawModelOutputs: + """Outputs from tracking model processing.""" + + keypoints: np.ndarray # (num_frames, num_keypoints, 2) + keypoint_features: dict[str, Any] # Additional features if needed + + +@dataclass(frozen=True) +class TrackingWorkerOutput: + """ + Returned by models and passed on to the plugin by the worker. + + keypoints: (N, 3) + [:, 0] = frame index (int) + [:, 1] = y coordinate (float) + [:, 2] = x coordinate (float) + + keypoint_features: + shape (N, M), one row per tracked keypoint, aligned with `keypoints` + """ + + keypoints: np.ndarray + keypoint_features: pd.DataFrame + + +# ------ Data features ------ + + +def coerce_features_df(features) -> pd.DataFrame: + """Return a defensive DataFrame copy with a clean RangeIndex.""" + if isinstance(features, pd.DataFrame): + return features.reset_index(drop=True).copy() + return pd.DataFrame(features).reset_index(drop=True).copy() + + +def add_query_identity_columns( + seed_features: pd.DataFrame, + *, + query_frame: int, + source_layer_name: str, +) -> pd.DataFrame: + """ + Add stable identity columns for each seed query before tracking. + Aims to recover semantic point identity + after tracker inference. + """ + df = coerce_features_df(seed_features) + + df["tracking_query_index"] = np.arange(len(df), dtype=int) + df["tracking_query_frame"] = int(query_frame) + df["tracking_source_layer_name"] = str(source_layer_name) + + return df + + +def expand_query_features_over_time( + seed_features: pd.DataFrame, + *, + frame_ids: np.ndarray, + visibility: np.ndarray | None, + tracker_name: str, +) -> pd.DataFrame: + """ + Repeat seed features across all tracked frames, preserving original + semantic columns (e.g. label, id) and adding tracking-specific fields. 
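+ The result has len(frame_ids) * len(seed_features) rows, ordered frame-major + (all seed rows for frame_ids[0] first, then frame_ids[1], and so on).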
+ + Parameters + ---------- + seed_features + One row per seed/query point, in the same order as the model query order. + frame_ids + Actual frame indices corresponding to the model output time axis. + visibility + Optional visibility array of shape (T, K) or (T, K, 1). + tracker_name + Human-readable tracker name, e.g. "Cotracker 3". + """ + seed = coerce_features_df(seed_features) + + K = len(seed) + T = len(frame_ids) + + repeated = pd.concat([seed] * T, ignore_index=True) + + repeated["tracking_tracker_name"] = str(tracker_name) + repeated["tracking_frame"] = np.repeat(np.asarray(frame_ids, dtype=int), K) + repeated["tracking_is_prediction"] = True + + if visibility is not None: + vis = np.asarray(visibility) + if vis.ndim == 3 and vis.shape[-1] == 1: + vis = vis[..., 0] + elif vis.ndim == 3 and vis.shape[0] == 1: + vis = vis.squeeze(0) + + expected = (T, K) + if vis.shape != expected: + raise ValueError(f"Visibility shape mismatch. Expected {expected}, got {vis.shape}.") + + repeated["tracking_visible"] = vis.reshape(T * K).astype(bool) + else: + repeated["tracking_visible"] = True + + return repeated + + +def build_tracking_result_metadata( + source_metadata: dict | None, + *, + tracker_name: str, + source_layer_name: str, + query_frame: int, +) -> dict: + """ + Build metadata for a tracking-result Points layer while keeping the source + metadata around as much as possible. + """ + md = deepcopy(source_metadata or {}) + md[TRACKING_LAYER_METADATA_KEY] = { + "schema_version": TRACKING_SCHEMA_VERSION, + "kind": "cotracker-result", + "tracker_name": str(tracker_name), + "source_layer_name": str(source_layer_name), + "query_frame": int(query_frame), + } + return md + + +def is_tracking_result_layer(layer) -> bool: + md = getattr(layer, "metadata", {}) or {} + info = md.get(TRACKING_LAYER_METADATA_KEY) + return isinstance(info, dict) and info.get("kind") == "cotracker-result" diff --git a/src/napari_deeplabcut/tracking/core/models.py b/src/napari_deeplabcut/tracking/core/models.py new file mode 100644 index 00000000..ce4d13fa --- /dev/null +++ b/src/napari_deeplabcut/tracking/core/models.py @@ -0,0 +1,312 @@ +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import numpy as np + +from napari_deeplabcut.tracking.core.data import ( + RawModelOutputs, + TrackingModelInputs, + TrackingWorkerData, + TrackingWorkerOutput, + coerce_features_df, + expand_query_features_over_time, +) + +if TYPE_CHECKING: + from napari_deeplabcut.tracking.core.models import TrackingModel + +# List of available tracking models. +# Automatically populated via the @register_backbone decorator. +AVAILABLE_TRACKERS: dict[str, dict[str, Any]] = {} + +# TODO @C-Achard: consider splitting into base.py (TrackingModel) and putting models in core/models/ + +logger = logging.getLogger(__name__) + + +def register_backbone(model_name: str): + def decorator(cls): + AVAILABLE_TRACKERS[model_name] = { + "class": cls, + } + return cls + + return decorator + + +class TrackingModel(ABC): + """Abstract base class for tracking models. + Use this to add new tracking models. 
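+ + A minimal registration sketch (the "MyTracker" name and its strings are hypothetical): + + @register_backbone("My tracker") + class MyTracker(TrackingModel): + name = "My tracker" + info_text = "Shown as the tracker's tooltip in the UI." + # ...then implement load_model, prepare_inputs, run, + # prepare_outputs and validate_outputs.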
+ """ + + # These fields must be set per model + name: str + info_text: str + + def __init__(self, cfg: "TrackingWorkerData"): + super().__init__() + self.cfg: TrackingWorkerData = cfg + if not isinstance(cfg, TrackingWorkerData): + raise ValueError("cfg must be an instance of TrackingWorkerData") + + self.device = self.auto_set_device() + self.model = self.load_model(self.device) + + def auto_set_device(self): + """Automatically set the device for the model. + Override this method if you have specific device requirements. + """ + import torch + + # check for MPS + if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + self.device = "mps" + elif torch.cuda.is_available(): + self.device = "cuda" + else: + self.device = "cpu" + return self.device + + @abstractmethod + def load_model(self, device: str) -> Any: + """Load the model on the specified device. + Make sure to use proper device and set to eval mode. + """ + raise NotImplementedError + + @abstractmethod + def prepare_inputs(self, cfg: "TrackingWorkerData", **kwargs) -> TrackingModelInputs: + """Prepare inputs for processing.""" + raise NotImplementedError + + @abstractmethod + def run(self, inputs: TrackingModelInputs, progress_callback, stop_callback, **kwargs) -> RawModelOutputs: + """Process the inputs and return outputs.""" + raise NotImplementedError + + @abstractmethod + def prepare_outputs( + self, model_outputs: RawModelOutputs, worker_inputs: "TrackingWorkerData" = None, **kwargs + ) -> "TrackingWorkerOutput": + """Prepare outputs after processing.""" + raise NotImplementedError + + @abstractmethod + def validate_outputs(self, inputs: TrackingModelInputs, outputs: "TrackingWorkerOutput") -> tuple[bool, str]: + """Validate the outputs.""" + raise NotImplementedError + + +# pragma: no cover +ct3 = "Cotracker 3" + + +@register_backbone(ct3) +class Cotracker3(TrackingModel): + name = ct3 + info_text = ( + "Cotracker 3 model from Facebook Research.\n" + "See https://cotracker3.github.io/ and CoTracker3: " + "Simpler and Better Point Tracking by " + "Pseudo-Labelling Real Videos by Karaev et al., 2024." 
+ ) + + def __init__(self, cfg): + super().__init__(cfg) + + def load_model(self, device: str): + import torch + + model = torch.hub.load( + "facebookresearch/co-tracker", + "cotracker3_offline", + ).to(device) + model.eval() + return model + + def prepare_inputs(self, cfg: "TrackingWorkerData", **kwargs): + self.cfg = cfg + + # video is originally (num_frames, H, W, C) + video = np.asarray(self.cfg.video) + + # IMPORTANT: do NOT mutate worker_inputs in place + queries = np.asarray(self.cfg.keypoints, dtype=float).copy() + + # napari points are [t, y, x]; CoTracker queries expect [t, x, y] + queries[:, [1, 2]] = queries[:, [2, 1]] + + metadata = { + "keypoint_range": self.cfg.keypoint_range, + "backward_tracking": self.cfg.backward_tracking, + "reference_frame_index": self.cfg.reference_frame_index, + } + + return TrackingModelInputs(video=video, keypoints=queries, metadata=metadata) + + def run(self, inputs: TrackingModelInputs, progress_callback, stop_callback) -> RawModelOutputs: + import torch + + if stop_callback(): + return None + + # inputs.video is (T, H, W, C) + video = ( + torch.from_numpy(inputs.video).to(self.device).float().permute(0, 3, 1, 2)[None] # -> (1, T, C, H, W) + ) + + # inputs.keypoints is (K, 3), already converted in prepare_inputs() + queries = torch.from_numpy(inputs.keypoints).to(self.device).float()[None] # -> (1, K, 3) + + total_frames = int(inputs.video.shape[0]) + progress_callback(0, total_frames) + + with torch.inference_mode(): + pred_tracks, pred_visibility = self.model( + video, + queries=queries, + ) + + if pred_tracks is None or pred_visibility is None: + raise RuntimeError("CoTracker offline returned no predictions for the provided clip and queries.") + + progress_callback(total_frames, total_frames) + + if stop_callback(): + return None + + tracks = pred_tracks.detach().cpu().numpy() + visibility = pred_visibility.detach().cpu().numpy() + + # Normalize shapes: + # expected from model: (1, T, K, 2) and (1, T, K, 1) or similar + if tracks.ndim == 4 and tracks.shape[0] == 1: + tracks = tracks[0] # -> (T, K, 2) + + if visibility.ndim == 4 and visibility.shape[0] == 1: + visibility = visibility[0] # -> (T, K, 1) + + if visibility.ndim == 3 and visibility.shape[-1] == 1: + visibility = visibility[..., 0] # -> (T, K) + + return RawModelOutputs( + keypoints=tracks, + keypoint_features={"visibility": visibility}, + ) + + def prepare_outputs( + self, + model_outputs: RawModelOutputs, + worker_inputs: TrackingWorkerData = None, + **kwargs, + ) -> "TrackingWorkerOutput": + """ + Convert CoTracker outputs into canonical plugin format while preserving + original per-query semantic identity. + + Result: + - keypoints: (N, 3) = [frame_idx, y, x] + - keypoint_features: one row per tracked point, aligned with keypoints + """ + tracks = np.asarray(model_outputs.keypoints, dtype=float) # expected (T, K, 2) + + if tracks.ndim != 3 or tracks.shape[-1] != 2: + raise ValueError(f"Expected tracks with shape (T, K, 2), got {tracks.shape}.") + + visibility = model_outputs.keypoint_features.get("visibility") + if visibility is not None: + visibility = np.asarray(visibility) + if visibility.ndim == 3 and visibility.shape[-1] == 1: + visibility = visibility[..., 0] + + T1, T2 = map(int, worker_inputs.keypoint_range) + frame_ids = np.arange(T1, T2, dtype=int) + + # IMPORTANT: + # When backward_tracking=True, the model saw the time-reversed video slice. + # So we must reverse the TIME axis before flattening, not the flattened rows. 
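+ # Worked example: keypoint_range=(3, 6) with backward_tracking=True means the model + # saw frames [5, 4, 3]; reversing the time axis makes output row 0 correspond to + # frame 3, lining the rows up with frame_ids = [3, 4, 5] below.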
+ if worker_inputs.backward_tracking: + tracks = tracks[::-1, :, :] + if visibility is not None: + visibility = visibility[::-1, :] + + T, K, _ = tracks.shape + if T != len(frame_ids): + raise ValueError(f"Time dimension mismatch. tracks has T={T}, frame_ids has len={len(frame_ids)}.") + + seed_features = coerce_features_df(worker_inputs.keypoint_features) + if len(seed_features) != K: + raise ValueError(f"Seed feature row count mismatch. Expected K={K}, got {len(seed_features)}.") + + # Flatten frame-major, preserving query order inside each frame + xy = tracks.reshape(T * K, 2) + keypoints = np.column_stack((np.repeat(frame_ids, K), xy)) + + # Convert from CoTracker's [frame, x, y] back to the napari convention [frame, y, x] + keypoints[:, [1, 2]] = keypoints[:, [2, 1]] + + keypoint_features = expand_query_features_over_time( + seed_features, + frame_ids=frame_ids, + visibility=visibility, + tracker_name=self.name, + ) + + return TrackingWorkerOutput( + keypoints=keypoints, + keypoint_features=keypoint_features, + ) + + def _process_step(self, window_frames, is_first_step, queries): + """ + Internal helper for chunked processing. + """ + import torch + + video_chunk = ( + torch.tensor(np.stack(window_frames[-self.model.step * 2 :]), device=self.device) + .float() + .permute(0, 3, 1, 2)[None] + ) # (1, T, 3, H, W) + logger.debug(f"Video chunk shape: {video_chunk.shape}, Queries shape: {queries.shape}") + return self.model( + video_chunk, + is_first_step=is_first_step, + queries=queries[None], + add_support_grid=True, + ) + + def validate_outputs(self, inputs: TrackingModelInputs, outputs: "TrackingWorkerOutput") -> tuple[bool, str]: + """Validate the outputs.""" + if not isinstance(outputs.keypoints, np.ndarray): + return False, "Outputs keypoints is not a numpy array." + if not outputs.keypoints.ndim == 2: + return False, "Outputs keypoints is not a 2D array." + if not outputs.keypoints.shape[1] == 3: + return False, "Outputs keypoints does not have 3 columns." + # For CoTracker3, outputs contain tracked keypoints for every frame in the + # keypoint_range. If keypoint_range = (T1, T2) and there are K input + # keypoints, we expect (T2 - T1) * K output rows. + metadata = getattr(inputs, "metadata", None) + if metadata is None or "keypoint_range" not in metadata: + return False, "Missing keypoint_range metadata required for validation." + keypoint_range = metadata["keypoint_range"] + if not isinstance(keypoint_range, (tuple, list)) or len(keypoint_range) != 2: + return False, "Invalid keypoint_range metadata; expected (T1, T2)." + T1, T2 = keypoint_range + try: + T1_int = int(T1) + T2_int = int(T2) + except (TypeError, ValueError): + return False, "keypoint_range values must be integers." + if T2_int <= T1_int: + return False, "keypoint_range must satisfy T2 > T1." 
+ K = inputs.keypoints.shape[0] + expected_n_keypoints = (T2_int - T1_int) * K + if outputs.keypoints.shape[0] != expected_n_keypoints: + return ( + False, + f"Number of output keypoints does not match expected ((T2 - T1) * K) = {expected_n_keypoints}.", + ) + return True, "" diff --git a/src/napari_deeplabcut/tracking/ui/__init__.py b/src/napari_deeplabcut/tracking/ui/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/tracking/ui/worker.py b/src/napari_deeplabcut/tracking/ui/worker.py new file mode 100644 index 00000000..9c41f72a --- /dev/null +++ b/src/napari_deeplabcut/tracking/ui/worker.py @@ -0,0 +1,173 @@ +import logging +import threading +from dataclasses import dataclass + +from qtpy.QtCore import QObject, QThread, Signal, Slot + +from napari_deeplabcut.tracking.core.data import ( + RawModelOutputs, + TrackingModelInputs, + TrackingWorkerData, + TrackingWorkerOutput, +) +from napari_deeplabcut.tracking.core.models import AVAILABLE_TRACKERS + +try: + import importlib + + # import torch + importlib.import_module("torch") +except ImportError as e: + raise ImportError( + "TrackingWorker requires PyTorch to be installed. Please install with `pip install napari-deeplabcut[tracking]`." + ) from e + +"""Worker is not allowed to perform main thread operations, such as: +- viewer.add_* +- viewer.layers.* +- layer.data = ... +- QWidget updates +- QCoreApplication.processEvents() +- anything vispy / OpenGL / rendering-related + +Please be careful when editing this file. +""" + +logger = logging.getLogger(__name__) + +DEBUG = False +if DEBUG: + logger.setLevel(logging.DEBUG) + import debugpy + + +@dataclass +class TorchHubModel: + org: str + model: str + + +class TrackingWorker(QObject): + started = Signal() + finished = Signal() + progress = Signal(int, int) # current, total + trackingStarted = Signal() + trackingFinished = Signal(object) # emits TrackingWorkerOutput + trackingStopped = Signal() + + def __init__(self): + super().__init__() + self._stop_requested = threading.Event() + self.thread = None + + @Slot(object) + def track(self, cfg: TrackingWorkerData): + """ + Tracking core logic: + 1. Instantiate model from registry. + 2. prepare_inputs(cfg) + 3. run(inputs, progress_cb, stop_cb) + 4. prepare_outputs(raw, inputs) + 5. Emit results to the plugin. 
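+ + A stop request (see request_stop) is polled between these steps; stopping + emits trackingStopped instead of trackingFinished.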
+ """ + logger.debug( + "TrackingWorker.track | python_thread=%s | qt_current_thread=%r | worker_thread=%r", + threading.current_thread().name, + QThread.currentThread(), + self.thread, + ) + model = None + self._stop_requested.clear() + + try: + if DEBUG: + debugpy.debug_this_thread() + + logger.debug( + "TrackingWorker.track | python_thread=%s | qt_current_thread=%r | worker_thread=%r", + threading.current_thread().name, + QThread.currentThread(), + self.thread, + ) + + self.trackingStarted.emit() + + model_name = cfg.tracker_name + try: + model_cls = AVAILABLE_TRACKERS[model_name]["class"] + except KeyError: + logger.error(f"Unknown tracker: {model_name}") + self.trackingStopped.emit() + self.finished.emit() + return + + model = model_cls(cfg) + + def progress_callback(current: int, total: int): + self.progress.emit(int(current), int(total)) + + def stop_callback() -> bool: + return self._should_stop() + + try: + inputs: TrackingModelInputs = model.prepare_inputs(cfg) + raw: RawModelOutputs = model.run(inputs, progress_callback, stop_callback) + + if self._should_stop(): + self.trackingStopped.emit() + self.finished.emit() + return + + output: TrackingWorkerOutput = model.prepare_outputs(raw, cfg) + + if hasattr(model, "validate_outputs"): + valid, msg = model.validate_outputs(inputs, output) + if not valid: + raise ValueError(f"Invalid model outputs: {msg}") + + except Exception as exc: + logger.exception("Tracking failed", exc_info=exc) + self.trackingStopped.emit() + self.finished.emit() + return + + self.trackingFinished.emit(output) + self.finished.emit() + + finally: + try: + import torch + + torch.cuda.empty_cache() + except Exception: + logger.debug("Could not clear CUDA cache", exc_info=True) + + if model is not None: + del model + + def run(self): + self.started.emit() + + def start(self): + self.thread = QThread() + self.moveToThread(self.thread) + + self.finished.connect(self.thread.quit) + self.thread.started.connect(self.run) + self.thread.finished.connect(self.thread.deleteLater) + + self.thread.start() + + def request_stop(self): + self._stop_requested.set() + + @Slot() + def stop_tracking(self): + self.request_stop() + + def _should_stop(self) -> bool: + return self._stop_requested.is_set() diff --git a/src/napari_deeplabcut/ui/base_widget/__init__.py b/src/napari_deeplabcut/ui/base_widget/__init__.py new file mode 100644 index 00000000..784d48ac --- /dev/null +++ b/src/napari_deeplabcut/ui/base_widget/__init__.py @@ -0,0 +1,3 @@ +from .singleton_widget import ViewerSingletonWidget + +__all__ = ["ViewerSingletonWidget"] diff --git a/src/napari_deeplabcut/ui/base_widget/singleton_widget.py b/src/napari_deeplabcut/ui/base_widget/singleton_widget.py new file mode 100644 index 00000000..f3a7a44d --- /dev/null +++ b/src/napari_deeplabcut/ui/base_widget/singleton_widget.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import weakref +from typing import ClassVar + +from qtpy.QtWidgets import QWidget + + +class ViewerSingletonWidget(QWidget): + """Base QWidget enforcing at most one live instance per viewer per subclass.""" + + _instances_by_cls: ClassVar[ + dict[type, weakref.WeakKeyDictionary[object, weakref.ReferenceType[ViewerSingletonWidget]]] + ] = {} + + # ------------------------------------------------------------------ # + # Viewer extraction / normalization # + # ------------------------------------------------------------------ # + + @staticmethod + def _extract_viewer_from_call(args, kwargs): + if args: + return args[0] + if "napari_viewer" in kwargs: + 
return kwargs["napari_viewer"] + if "viewer" in kwargs: + return kwargs["viewer"] + return None + + @staticmethod + def canonical_viewer(viewer): + current = viewer + seen = set() + + while True: + wrapped = getattr(current, "__wrapped__", None) + if wrapped is None: + wrapped = getattr(current, "_obj", None) + + if wrapped is None or wrapped is current or id(wrapped) in seen: + return current + + seen.add(id(current)) + current = wrapped + + # ------------------------------------------------------------------ # + # Registry helpers # + # ------------------------------------------------------------------ # + + @classmethod + def _instance_registry(cls): + reg = ViewerSingletonWidget._instances_by_cls.get(cls) + if reg is None: + reg = weakref.WeakKeyDictionary() + ViewerSingletonWidget._instances_by_cls[cls] = reg + return reg + + @staticmethod + def _is_qt_alive(widget) -> bool: + try: + widget.objectName() # any QObject call is enough + except RuntimeError: + return False + except Exception: + return True + return True + + @classmethod + def get_existing(cls, viewer): + canonical = cls.canonical_viewer(viewer) + ref = cls._instance_registry().get(canonical) + widget = ref() if ref is not None else None + if widget is None: + return None + if not cls._is_qt_alive(widget): + cls._instance_registry().pop(canonical, None) + return None + return widget + + @classmethod + def get_or_create(cls, viewer, *args, **kwargs): + existing = cls.get_existing(viewer) + if existing is not None: + return existing + return cls(viewer, *args, **kwargs) + + @classmethod + def is_docked(cls, viewer, widget) -> bool: + try: + for obj in viewer.window.dock_widgets.values(): + if obj is widget: + return True + try: + if obj.widget() is widget: + return True + except Exception: + pass + except Exception: + pass + return False + + # ------------------------------------------------------------------ # + # Singleton construction # + # ------------------------------------------------------------------ # + + def __new__(cls, *args, **kwargs): + viewer = cls._extract_viewer_from_call(args, kwargs) + if viewer is None: + raise TypeError( + f"{cls.__name__} requires a viewer argument (positional, napari_viewer=..., or viewer=...)." + ) + + canonical = cls.canonical_viewer(viewer) + existing = cls.get_existing(canonical) + if existing is not None: + return existing + + obj = super().__new__(cls) + cls._instance_registry()[canonical] = weakref.ref(obj) + return obj + + def _singleton_prepare_init(self, *args, **kwargs) -> bool: + """Pre-Qt-init guard. Safe to call before QWidget.__init__().""" + if getattr(self, "_viewer_singleton_initialized", False): + return False + + viewer = self._extract_viewer_from_call(args, kwargs) + if viewer is None: + raise TypeError(f"{self.__class__.__name__} requires a viewer argument during initialization.") + + self._viewer_singleton_initialized = True + self._viewer_singleton_key = self.canonical_viewer(viewer) + return True + + def _singleton_finalize_init(self) -> None: + """Post-Qt-init finalization. 
Call after QWidget.__init__().""" + # only connect once, and only after QObject exists + if getattr(self, "_viewer_singleton_finalize_done", False): + return + self._viewer_singleton_finalize_done = True + self.destroyed.connect(self._on_singleton_destroyed) + + def _on_singleton_destroyed(self, *args) -> None: + key = getattr(self, "_viewer_singleton_key", None) + if key is None: + return + + reg = self.__class__._instance_registry() + ref = reg.get(key) + if ref is not None and ref() is self: + reg.pop(key, None) diff --git a/src/napari_deeplabcut/ui/dialogs.py b/src/napari_deeplabcut/ui/dialogs.py index ba7333e7..37094f78 100644 --- a/src/napari_deeplabcut/ui/dialogs.py +++ b/src/napari_deeplabcut/ui/dialogs.py @@ -1,3 +1,8 @@ +"""Common dialogs used in napari-deeplabcut. + +NOTE: could be split into dialogs/tutorial.py, dialogs/shortcuts.py, dialogs/config_prompt.py, etc. +""" + # src/napari_deeplabcut/ui/dialogs.py from __future__ import annotations @@ -67,6 +72,8 @@ def _scope_label(scope: str) -> str: return "Points layer" if scope == "global-points": return "All Points layers" + if scope == "tracking-points-layer": + return "Points layer + tracking widget" return scope @@ -231,6 +238,33 @@ def closeEvent(self, event): super().closeEvent(event) + def _find_tracking_dock(self): + if self.viewer is None: + return None + + window = getattr(self.viewer, "window", None) + if window is None: + return None + + try: + for dock in window._dock_widgets.values(): + widget = dock.widget() + if widget is None: + continue + + if widget.objectName() == "napari-deeplabcut-tracking-controls" or bool( + widget.property("ndlc_tracking_controls") + ): + return dock + except Exception: + return None + + return None + + def _tracking_widget_is_open(self) -> bool: + dock = self._find_tracking_dock() + return dock is not None and dock.isVisible() + def _build_rows(self) -> None: grouped = defaultdict(list) for spec in iter_shortcuts(): @@ -265,12 +299,20 @@ def _availability_for_spec(self, spec) -> tuple[bool, str | None]: """ if self.viewer is None: return False, "No viewer available." + active = self._active_layer() active_is_points = isinstance(active, Points) if spec.scope == "points-layer" and not active_is_points: return False, "No active Points layer." + if spec.scope == "tracking-points-layer": + if not self._tracking_widget_is_open(): + return False, "Tracking widget is not open." + if not active_is_points: + return False, "No active Points layer." + return True, None + # Optional: support extra conditions later, e.g. 
multi-animal-only # if getattr(spec, "requires_multianimal", False): # if not active_is_points or not self._is_multianimal(active): diff --git a/src/napari_deeplabcut/ui/layer_stats.py b/src/napari_deeplabcut/ui/layer_stats.py index da977334..0a301e6b 100644 --- a/src/napari_deeplabcut/ui/layer_stats.py +++ b/src/napari_deeplabcut/ui/layer_stats.py @@ -148,8 +148,12 @@ def set_point_size_enabled(self, enabled: bool, *, reason: str | None = None) -> self._size_slider.setToolTip(tooltip) self._size_value.setToolTip(tooltip) - def set_folder_name(self, folder_name: str) -> None: - self._folder_value.setText(folder_name or "—") + def set_folder_name(self, folder_name: str, *, full_path: str | None = None) -> None: + text = folder_name or "—" + self._folder_value.setText(text) + + tooltip = (f"Project: {full_path}" if full_path else "No open project folder").strip() + self._folder_value.setToolTip(tooltip) def _show_progress_context_menu(self, pos) -> None: if not self._progress_details_text: diff --git a/src/napari_deeplabcut/ui/plots/trajectory.py b/src/napari_deeplabcut/ui/plots/trajectory.py index dd393227..e9342623 100644 --- a/src/napari_deeplabcut/ui/plots/trajectory.py +++ b/src/napari_deeplabcut/ui/plots/trajectory.py @@ -23,12 +23,12 @@ from qtpy.QtWidgets import QHBoxLayout, QLabel, QSizePolicy, QSlider, QVBoxLayout, QWidget import napari_deeplabcut.core.io as io -from napari_deeplabcut.config.models import DLCHeaderModel from napari_deeplabcut.config.settings import ( DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP, DEFAULT_SINGLE_ANIMAL_CMAP, ) from napari_deeplabcut.core.keypoints import build_color_cycles +from napari_deeplabcut.core.layer_lifecycle import LayerLifecycleManager from napari_deeplabcut.core.layers import ( get_first_image_layer, get_first_video_image_layer, @@ -185,44 +185,6 @@ def minimumSizeHint(self) -> QSize: """ return QSize(280, 340) - def _get_header_model_from_metadata(self, md: dict) -> DLCHeaderModel | None: - """Return DLCHeaderModel from metadata['header'] when possible. - TODO: Check codebase for duplicate logic and centralize if needed. - """ - if not isinstance(md, dict): - return None - - hdr = md.get("header", None) - if hdr is None: - return None - - if isinstance(hdr, DLCHeaderModel): - return hdr - - if isinstance(hdr, dict): - try: - return DLCHeaderModel.model_validate(hdr) - except Exception: - return None - - try: - return DLCHeaderModel(columns=hdr) - except Exception: - return None - - def _is_multianimal_layer(self, layer: Points) -> bool: - """TODO: Check codebase for duplicate logic and centralize if needed.""" - md = getattr(layer, "metadata", None) or {} - header = self._get_header_model_from_metadata(md) - if header is None: - return False - - try: - inds = getattr(header, "individuals", None) - return bool(inds and len(inds) > 0 and str(inds[0]) != "") - except Exception: - return False - def _get_config_colormap(self, layer: Points) -> str: """Return the colormap for the given layer based on its metadata. TODO: Check codebase for duplicate logic and centralize if needed. 
@@ -661,7 +623,7 @@ def _resolved_face_color_cycles(self, layer: Points) -> dict[str, dict]: - single-animal falls back to the bodypart cycle for both """ md = getattr(layer, "metadata", None) or {} - header = self._get_header_model_from_metadata(md) + header = LayerLifecycleManager.get_header_model_from_metadata(md) if header is None: return {} @@ -673,7 +635,7 @@ def _resolved_face_color_cycles(self, layer: Points) -> dict[str, dict]: logger.debug("Trajectory plot: failed to build bodypart color cycles", exc_info=True) bodypart_cycles = {} - if self._is_multianimal_layer(layer): + if LayerLifecycleManager.is_multianimal(layer): try: individual_cycles = build_color_cycles(header, DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP) or {} except Exception: @@ -705,7 +667,7 @@ def _line_color_for(self, points_layer: Points, individual: str, bodypart: str): individual_key = self._normalized_individual_name(individual) bodypart_key = str(bodypart) - if mode == "individual" and self._is_multianimal_layer(points_layer): + if mode == "individual" and LayerLifecycleManager.is_multianimal(points_layer): if individual_key and individual_key in id_cycle: return id_cycle[individual_key] if bodypart_key in label_cycle: diff --git a/src/napari_deeplabcut/ui/ui_dialogs/__init__.py b/src/napari_deeplabcut/ui/ui_dialogs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/ui/ui_dialogs/save.py b/src/napari_deeplabcut/ui/ui_dialogs/save.py new file mode 100644 index 00000000..1cca5b20 --- /dev/null +++ b/src/napari_deeplabcut/ui/ui_dialogs/save.py @@ -0,0 +1,710 @@ +# src/napari_deeplabcut/ui/dialogs/save.py +from __future__ import annotations + +import logging +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass +from html import escape +from pathlib import Path +from typing import TYPE_CHECKING + +from napari.layers import Image, Points +from napari.utils.history import get_save_history +from qtpy.QtCore import Qt +from qtpy.QtWidgets import QFileDialog, QInputDialog, QMessageBox + +from ...core.conflicts import compute_overwrite_report_for_points_save +from ...core.errors import MissingProvenanceError +from ...core.io import is_video +from ...core.layers import is_machine_layer +from ...core.metadata import ( + MergePolicy, + apply_project_paths_override_to_points_meta, + migrate_points_layer_metadata, + read_points_meta, + write_points_meta, +) +from ...core.project_paths import ( + coerce_paths_to_dlc_row_keys, + dataset_folder_has_files, + find_nearest_config, + looks_like_dlc_labeled_folder, + normalize_anchor_candidate, + resolve_project_root_from_config, + target_dataset_folder_for_config, +) +from ...core.provenance import ( + apply_gt_save_target, + is_projectless_folder_association_candidate, + requires_gt_promotion, + suggest_human_placeholder, +) +from ...core.sidecar import get_default_scorer, set_default_scorer +from ...core.trails import safe_folder_anchor_from_points_layer +from ..dialogs import ( + ProjectConfigPromptAction, + load_scorer_from_config, + maybe_confirm_dataset_path_rewrite, + maybe_confirm_overwrite, + prompt_for_project_config_for_save, + warn_existing_dataset_folder_conflict, + warn_invalid_config_for_scorer, +) + +if TYPE_CHECKING: + from collections.abc import Iterable + + from ...config.models import ImageMetadata + from ...core.layer_lifecycle.manager import LayerLifecycleManager + from ...core.trails import TrailsController + + +@dataclass(slots=True) +class SaveOutcome: + saved: 
bool + status_message: str | None = None + + +def _prompt_for_scorer(parent_widget, *, anchor: str, suggested: str) -> str | None: + """Prompt user for a scorer name. Returns non-empty string or None if cancelled.""" + text, ok = QInputDialog.getText( + parent_widget, + "Choose scorer", + "No DLC config.yaml scorer found.\n" + "Please enter a scorer name for the CollectedData file.\n\n" + "Tip: Use your name or a stable lab identifier.\n" + "(We strongly discourage keeping the generic 'human_xxxxxx'.)", + text=suggested, + ) + if not ok: + return None + scorer = (text or "").strip() + if not scorer: + return None + return scorer + + +@contextmanager +def _temporary_layer_metadata(layer: Points, metadata: dict): + old_metadata = dict(layer.metadata or {}) + layer.metadata = metadata + try: + yield + finally: + layer.metadata = old_metadata + + +class PointsLayerSaveWorkflow: + """Orchestrate save flows for napari-deeplabcut points layers. + + This class owns: + - single-layer save routing + - promotion-to-GT checks + - project association / metadata override flow + - overwrite preflight + confirmation + - save-time provenance enrichment + - post-save sidecar UI-state persistence + + It intentionally does NOT own widget-specific UI state updates such as: + - KeypointControls._is_saved + - last-saved timestamp label + """ + + def __init__( + self, + *, + parent, + viewer, + layer_manager: LayerLifecycleManager, + trails_controller: TrailsController, + trail_checkbox_getter: Callable[[], bool], + resolve_config_path_for_layer: Callable[[Points | None], Path | None], + current_project_path_getter: Callable[[], str | None], + current_image_meta_getter: Callable[[], ImageMetadata], + logger: logging.Logger, + ) -> None: + self.parent = parent + self.viewer = viewer + self.layer_manager = layer_manager + self.trails_controller = trails_controller + self.trail_checkbox_getter = trail_checkbox_getter + self.resolve_config_path_for_layer = resolve_config_path_for_layer + self.current_project_path_getter = current_project_path_getter + self.current_image_meta_getter = current_image_meta_getter + self.logger = logger + + # ------------------------------------------------------------------ # + # Public entry point # + # ------------------------------------------------------------------ # + + def save_layers(self, *, selected: bool = False) -> SaveOutcome: + selected_layers = list(self.viewer.layers.selection) + + msg = "" + if not len(self.viewer.layers): + msg = "There are no layers in the viewer to save." + elif selected and not len(selected_layers): + msg = "Please select a Points layer to save." + + if msg: + QMessageBox.warning(self.parent, "Nothing to save", msg, QMessageBox.Ok) + return SaveOutcome(saved=False) + + if len(selected_layers) == 1 and isinstance(selected_layers[0], Points): + return self._save_single_points_layer(selected_layers[0]) + + return self._save_multiple_layers(selected=selected, selected_layers=selected_layers) + + # ------------------------------------------------------------------ # + # Single-layer points save # + # ------------------------------------------------------------------ # + + def _save_single_points_layer(self, layer: Points) -> SaveOutcome: + ok = self._ensure_promotion_save_target(layer) + if not ok: + return SaveOutcome(saved=False) + + self.logger.debug( + "About to save. 
io.kind=%r save_target=%r", + layer.metadata.get("io", {}).get("kind"), + layer.metadata.get("save_target"), + ) + + save_metadata: dict = dict(layer.metadata or {}) + + try: + overridden_metadata, abort_save = self._maybe_prepare_project_path_override_metadata(layer) + if abort_save: + self.logger.debug("Save aborted during project-association path handling.") + return SaveOutcome(saved=False) + + base_metadata = overridden_metadata if overridden_metadata is not None else dict(layer.metadata or {}) + save_metadata = self._enrich_points_metadata_for_save(layer, base_metadata) + + if self._is_unsupported_direct_video_label_save(layer, save_metadata): + self.logger.debug( + "Save aborted due to unsupported direct video + config.yaml label save case. Layer=%r", + getattr(layer, "name", layer), + ) + self._warn_unsupported_direct_video_label_save(layer, save_metadata) + return SaveOutcome(saved=False) + + attributes = { + "name": layer.name, + "metadata": save_metadata, + "properties": dict(layer.properties or {}), + } + + report = compute_overwrite_report_for_points_save(layer.data, attributes) + + except MissingProvenanceError: + self.logger.exception( + "Missing save provenance for layer %r", + getattr(layer, "name", layer), + ) + QMessageBox.warning( + self.parent, + "Cannot save keypoints", + self._format_missing_provenance_save_message(layer, save_metadata), + QMessageBox.Ok, + ) + return SaveOutcome(saved=False) + + except Exception as e: + self.logger.exception( + "Failed to prepare save checks for layer %r", + getattr(layer, "name", layer), + ) + QMessageBox.warning( + self.parent, + "Cannot save keypoints", + f"Something went wrong while preparing this layer for saving:\n{e}", + QMessageBox.Ok, + ) + return SaveOutcome(saved=False) + + if report is not None: + if not maybe_confirm_overwrite( + parent=self.parent, + report=report, + ): + self.logger.debug("Save cancelled.") + return SaveOutcome(saved=False) + + metadata_changed = save_metadata != dict(layer.metadata or {}) + + with _temporary_layer_metadata(layer, save_metadata): + self.viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + + # Persist successful save-time metadata improvements into the live layer. 
+ if metadata_changed: + layer.metadata = dict(save_metadata) + + self._persist_folder_ui_state_for_layers([layer]) + + return SaveOutcome(saved=True, status_message="Data successfully saved") + + # ------------------------------------------------------------------ # + # Multi-layer / generic save # + # ------------------------------------------------------------------ # + + def _save_multiple_layers(self, *, selected: bool, selected_layers: list) -> SaveOutcome: + dlg = QFileDialog() + hist = get_save_history() + dlg.setHistory(hist) + + filename, _ = dlg.getSaveFileName( + caption=f"Save {'selected' if selected else 'all'} layers", + dir=hist[0], # home dir by default + ) + + if not filename: + return SaveOutcome(saved=False) + + self.viewer.layers.save(filename, selected=selected) + + if selected: + candidate_layers = [ly for ly in selected_layers if isinstance(ly, Points)] + else: + candidate_layers = list(self.layer_manager.managed_points_layers()) + + self._persist_folder_ui_state_for_layers(candidate_layers) + + return SaveOutcome(saved=True, status_message="Data successfully saved") + + # ------------------------------------------------------------------ # + # Save-time metadata enrichment # + # ------------------------------------------------------------------ # + + def _best_image_context_layer(self) -> Image | None: + """Return the best available image layer for save-time provenance inference.""" + active = self.layer_manager.active_dlc_image_layer() + if active is not None: + return active + + selected = self.viewer.layers.selection.active + if isinstance(selected, Image): + return selected + + for layer in self.viewer.layers: + if isinstance(layer, Image): + return layer + + return None + + def _enrich_points_metadata_for_save(self, layer: Points, metadata: dict) -> dict: + """Best-effort save-time metadata enrichment for DLC routing. 
+ + Conservative policy: + - never overwrite explicit metadata already present + - first reuse lifecycle-owned image context + - then try nearby config.yaml + image/source hints + """ + md = dict(metadata or {}) + + if md.get("root"): + return md + + if self.layer_manager.image_root: + md.setdefault("root", self.layer_manager.image_root) + if self.layer_manager.image_paths: + md.setdefault("paths", self.layer_manager.image_paths) + + if md.get("root"): + return md + + config_path = self.resolve_config_path_for_layer(layer) + if config_path is None: + return md + + project_root = resolve_project_root_from_config(config_path) + if project_root is None: + return md + + md.setdefault("project", str(project_root)) + + image_layer = self._best_image_context_layer() + if image_layer is None: + return md + + try: + src = getattr(getattr(image_layer, "source", None), "path", None) + except Exception: + src = None + + src_anchor = normalize_anchor_candidate(src) if src else None + if src_anchor is not None and looks_like_dlc_labeled_folder(src_anchor): + md.setdefault("root", str(src_anchor)) + return md + + image_name = getattr(image_layer, "name", None) + if image_name: + candidate = project_root / "labeled-data" / str(image_name) + if candidate.is_dir(): + md.setdefault("root", str(candidate)) + + return md + + def _is_video_context_layer(self, layer: Image | None) -> bool: + if layer is None: + return False + + md = getattr(layer, "metadata", {}) or {} + dlc_md = md.get("dlc") or {} + if dlc_md.get("session_role") == "video": + return True + + try: + src = getattr(getattr(layer, "source", None), "path", None) + except Exception: + src = None + + for candidate in (src, getattr(layer, "name", None)): + if candidate and is_video(str(candidate)): + return True + + return False + + def _is_unsupported_direct_video_label_save(self, layer: Points, metadata: dict) -> bool: + """ + Unsupported case: + - points layer has no extracted-frame row keys (`paths`) + - current save context is a video session + + This corresponds to labeling directly on a loaded video after adding + a config.yaml / placeholder config, which bypasses DLC frame extraction. + """ + paths = metadata.get("paths") or [] + if paths: + return False + + image_layer = self._best_image_context_layer() + return self._is_video_context_layer(image_layer) + + def _warn_unsupported_direct_video_label_save(self, layer: Points, metadata: dict) -> None: + image_layer = self._best_image_context_layer() + image_name = getattr(image_layer, "name", None) if image_layer is not None else "current video" + config_path = self.resolve_config_path_for_layer(layer) + + image_html = escape(image_name or "not found") + config_html = escape(str(config_path) if config_path is not None else "not found") + + msg = QMessageBox(self.parent) + msg.setIcon(QMessageBox.Warning) + msg.setWindowTitle("Cannot save labels from video directly") + msg.setTextFormat(Qt.RichText) + msg.setText( + "
<p>The currently loaded image layer looks like it is a video, and the keypoints layer "
+            "has no paths to individual image files.</p>"
+            "<p><b>Saving labels created directly on a loaded video by adding config.yaml "
+            "is not supported.</b></p>"
+            "<p>This bypasses DeepLabCut's frame extraction workflow, and the plugin cannot write a valid "
+            "CollectedData_&lt;scorer&gt;.h5 file from video frame indices alone.</p>"
+            f"<p>Video: {image_html}<br>"
+            f"Config: {config_html}</p>"
+            "<p><b>What to do instead:</b> extract frames with DeepLabCut first, then open and "
+            "label the extracted images from the project's labeled-data folder.</p>
" + "" + ) + msg.setStandardButtons(QMessageBox.Ok) + msg.exec() + + def _format_missing_provenance_save_message(self, layer: Points, metadata: dict) -> str: + config_path = self.resolve_config_path_for_layer(layer) + image_layer = self._best_image_context_layer() + + layer_name = getattr(layer, "name", "Unnamed layer") + project_hint = metadata.get("project") or self.current_project_path_getter() or "not available" + root_hint = metadata.get("root") or "not available" + n_paths = len(metadata.get("paths") or []) + + image_name = getattr(image_layer, "name", None) if image_layer is not None else None + + return ( + "Couldn't determine where to save this keypoints layer in a DeepLabCut project.\n\n" + f"Layer: {layer_name}\n" + f"Project hint: {project_hint}\n" + f"Dataset folder (root): {root_hint}\n" + f"Paths metadata: {n_paths} item(s)\n" + f"Nearby config.yaml: {str(config_path) if config_path is not None else 'not found'}\n" + f"Image/video context: {image_name or 'not available'}\n\n" + "To save this layer, the plugin needs either:\n" + "• a dataset folder (metadata['root']), or\n" + "• enough project/image context to infer one automatically.\n\n" + "Try one of the following:\n" + "• open the image/video from the DLC project first,\n" + "• load the project's config.yaml or labeled-data folder, or\n" + "• make sure this layer is associated with the correct dataset folder." + ) + + # ------------------------------------------------------------------ # + # Promotion / provenance helpers # + # ------------------------------------------------------------------ # + + def _ensure_promotion_save_target(self, layer: Points) -> bool: + """Ensure a prediction/machine source layer has a GT save_target set. + + Returns True if save_target is set (or already existed), False if user cancels. 
+ """ + if not getattr(layer, "metadata", None): + layer.metadata = {} + + from ...config.models import PointsMetadata # local import to avoid unnecessary module load in import path + + if not safe_folder_anchor_from_points_layer(layer) and not is_machine_layer(layer := layer): + # Preserve old fast path for non-machine layers + return True + + if not is_machine_layer(layer): + return True + + mig = migrate_points_layer_metadata(layer) + if hasattr(mig, "errors"): + self.logger.warning( + "Failed to migrate points layer metadata for layer=%r: %s", + getattr(layer, "name", layer), + mig, + ) + + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if hasattr(res, "errors"): + self.logger.warning( + "Points metadata validation failed for layer=%r during save target check: %s", + getattr(layer, "name", layer), + res, + ) + QMessageBox.warning(self.parent, "Cannot save", "Layer metadata is invalid; see logs for details.") + return False + + pts: PointsMetadata = res + + if not requires_gt_promotion(pts): + return True + + anchor = safe_folder_anchor_from_points_layer(layer) + if not anchor: + QMessageBox.warning(self.parent, "Cannot save", "Could not determine a folder anchor for saving.") + return False + + scorer = None + + cfg_path = None + try: + cfg_path = find_nearest_config(anchor) + except Exception: + self.logger.debug("Automatic config discovery failed for anchor=%r", anchor, exc_info=True) + + if cfg_path: + try: + scorer = load_scorer_from_config(cfg_path) + except Exception: + self.logger.exception("Failed to load auto-discovered config.yaml: %s", cfg_path) + warn_invalid_config_for_scorer( + self.parent, + config_path=cfg_path, + reason="unreadable", + auto_found=True, + ) + return False + + if not scorer: + warn_invalid_config_for_scorer( + self.parent, + config_path=cfg_path, + reason="missing_scorer", + auto_found=True, + ) + return False + + else: + dialog_result = prompt_for_project_config_for_save( + self.parent, + initial_dir=self.current_project_path_getter() or anchor, + window_title="Locate DLC config for scorer resolution", + message=( + "No DeepLabCut config.yaml could be found automatically for this machine-labeled layer.\n\n" + "If this layer belongs to a DLC project, choose its config.yaml so the save uses the " + "project scorer and standard naming.\n\n" + "If no config.yaml exists, you can continue without one." 
+ ), + choose_button_text="Choose config.yaml", + skip_button_text="Continue without config", + resolve_scorer=True, + ) + + if dialog_result.action is ProjectConfigPromptAction.CANCEL: + return False + + if dialog_result.action is ProjectConfigPromptAction.ASSOCIATE: + scorer = dialog_result.scorer + + else: + scorer = get_default_scorer(anchor) + + if not scorer: + suggested = suggest_human_placeholder(anchor) + while True: + s = _prompt_for_scorer(self.parent, anchor=anchor, suggested=suggested) + if s is None: + return False + if s.startswith("human_"): + choice = QMessageBox.question( + self.parent, + "Generic scorer name", + "You entered a generic scorer name starting with 'human_'.\n\n" + "We strongly recommend using a real name or stable identifier.\n" + "Do you want to keep this generic scorer anyway?", + QMessageBox.Yes | QMessageBox.No, + ) + if choice == QMessageBox.No: + suggested = s + continue + scorer = s + break + try: + set_default_scorer(anchor, scorer) + except Exception: + self.logger.debug("Failed to persist default scorer to sidecar", exc_info=True) + + updated = apply_gt_save_target( + pts, + anchor=anchor, + scorer=scorer, + dataset_key="df_with_missing", + ) + + out = write_points_meta( + layer, + updated, + merge_policy=MergePolicy.MERGE, + fields={"save_target"}, + migrate_legacy=True, + validate=True, + ) + + if hasattr(out, "errors"): + self.logger.warning("Failed to write save_target for layer=%r: %s", getattr(layer, "name", layer), out) + QMessageBox.warning( + self.parent, + "Cannot save", + "Failed to write save target metadata; see logs for details.", + ) + return False + + return True + + def _maybe_prepare_project_path_override_metadata(self, layer: Points) -> tuple[dict | None, bool]: + """Optionally prepare save-time metadata by associating a project-less labeled + folder with an explicit DLC project chosen via config.yaml. 
+ """ + from ...config.models import PointsMetadata # local import to avoid unnecessary module load in import path + + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if hasattr(res, "errors"): + return None, False + + pts_meta: PointsMetadata = res + paths = pts_meta.paths or [] + if not paths: + return None, False + + if not is_projectless_folder_association_candidate(pts_meta): + return None, False + + source_root = pts_meta.root + if not source_root: + return None, False + + try: + source_root_path = Path(source_root).expanduser().resolve(strict=False) + except Exception: + source_root_path = Path(source_root) + + dataset_name = source_root_path.name + if not dataset_name: + return None, False + + initial_dir = self.current_project_path_getter() or pts_meta.project or str(source_root_path) + dialog_result = prompt_for_project_config_for_save(self.parent, initial_dir=initial_dir) + + if dialog_result.action is ProjectConfigPromptAction.CANCEL: + self.logger.debug("User cancelled project association prompt.") + return None, True + + if dialog_result.action is ProjectConfigPromptAction.SKIP: + self.logger.debug("User chose to continue without project association.") + return None, False + + if dialog_result.action is not ProjectConfigPromptAction.ASSOCIATE: + self.logger.warning("Unexpected project association dialog result: %r", dialog_result) + return None, True + + config_path = dialog_result.config_path + if not config_path: + self.logger.warning("Project association result was ASSOCIATE but config_path was empty.") + return None, True + + project_root = resolve_project_root_from_config(config_path) + if project_root is None: + QMessageBox.warning( + self.parent, + "Invalid project configuration", + "The selected file is not a valid DeepLabCut config.yaml or project root. 
" + "The save operation has been cancelled.", + ) + return None, True + + target_folder = target_dataset_folder_for_config(config_path, dataset_name=dataset_name) + if dataset_folder_has_files(target_folder): + warn_existing_dataset_folder_conflict(self.parent, target_folder=target_folder) + return None, True + + rewritten_paths, unresolved = coerce_paths_to_dlc_row_keys( + paths, + source_root=source_root_path, + dataset_name=dataset_name, + ) + + if not maybe_confirm_dataset_path_rewrite( + self.parent, + project_root=project_root, + dataset_name=dataset_name, + n_paths=len(paths), + n_unresolved=len(unresolved), + ): + return None, True + + overridden = apply_project_paths_override_to_points_meta( + pts_meta, + project_root=project_root, + rewritten_paths=rewritten_paths, + ) + + return overridden.model_dump(mode="python", exclude_none=True), False + + # ------------------------------------------------------------------ # + # Post-save state persistence # + # ------------------------------------------------------------------ # + + def _persist_folder_ui_state_for_layers(self, layers: Iterable[Points]) -> None: + try: + checked = bool(self.trail_checkbox_getter()) + for layer in layers: + if layer in self.viewer.layers: + self.trails_controller.persist_folder_ui_state_for_points_layer( + layer, + checkbox_checked=checked, + ) + except Exception: + self.logger.debug("Failed to persist folder UI state after save", exc_info=True) diff --git a/src/napari_deeplabcut/utils/debug.py b/src/napari_deeplabcut/utils/debug.py index d33e60b7..29da3d2a 100644 --- a/src/napari_deeplabcut/utils/debug.py +++ b/src/napari_deeplabcut/utils/debug.py @@ -7,14 +7,46 @@ import threading import traceback from collections import deque +from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from importlib import metadata from pathlib import Path +from time import perf_counter_ns _DEBUG_HANDLER_ATTR = "_napari_dlc_debug_recorder" LOG_QUEUE_MAXLEN = 1000 +# FIXME disable for release to avoid any overhead, or make configurable via env var/settings +NAPARI_DLC_LOG_TIMING = True + + +@contextmanager +def log_timing( + logger: logging.Logger, + label: str, + *, + level: int = logging.DEBUG, + threshold_ms: float | None = None, +): + """Lightweight scoped timer for debug instrumentation. + + Uses perf_counter_ns() for monotonic timing. + Logs only if logger is enabled for the requested level. + Optionally suppresses tiny timings below threshold_ms. 
+ """ + if not logger.isEnabledFor(level) or not NAPARI_DLC_LOG_TIMING: + yield + return + + t0 = perf_counter_ns() + try: + yield + finally: + dt_ms = (perf_counter_ns() - t0) / 1e6 + if threshold_ms is None or dt_ms >= threshold_ms: + logger.log(level, "%s took %.3f ms", label, dt_ms) + def install_debug_recorder( *, @@ -136,11 +168,20 @@ def render_text(self, *, limit: int = 200) -> str: lines: list[str] = [] try: records = self.snapshot()[-max(1, int(limit)) :] + if not records: + return "" + + base = records[0].created for rec in records: - ts = datetime.fromtimestamp(rec.created).strftime("%H:%M:%S") - lines.append(f"{ts} | {rec.level:<8} | {rec.logger_name} | {rec.message}") + ts = datetime.fromtimestamp(rec.created).strftime("%H:%M:%S.%f")[:-3] + if NAPARI_DLC_LOG_TIMING: + rel_ms = (rec.created - base) * 1000.0 + lines.append(f"{ts} (+{rel_ms:8.1f} ms) | {rec.level:<8} | {rec.logger_name} | {rec.message}") + else: + lines.append(f"{ts} | {rec.level:<8} | {rec.logger_name} | {rec.message}") if rec.exc_text: lines.append(rec.exc_text.rstrip()) + if self._dropped: lines.append(f"[debug-recorder] dropped internal failures: {self._dropped}") except Exception: diff --git a/src/napari_deeplabcut/widget_factory.py b/src/napari_deeplabcut/widget_factory.py new file mode 100644 index 00000000..2ccda730 --- /dev/null +++ b/src/napari_deeplabcut/widget_factory.py @@ -0,0 +1,12 @@ +# src/napari_deeplabcut/_widget_factory.py +from __future__ import annotations + +from ._widgets import KeypointControls + + +def get_existing_keypoint_controls(viewer) -> KeypointControls | None: + return KeypointControls.get_existing(viewer) + + +def get_or_create_keypoint_controls(viewer) -> KeypointControls: + return KeypointControls(viewer) diff --git a/tox.ini b/tox.ini index adcf5741..2a997888 100644 --- a/tox.ini +++ b/tox.ini @@ -28,5 +28,6 @@ passenv = PYTEST_QT_API PYVISTA_OFF_SCREEN extras = - testing + dev + tracking commands = pytest -v --color=yes --cov=napari_deeplabcut --cov-report=xml