diff --git a/src/napari_deeplabcut/_tests/compat/test_compat_integration.py b/src/napari_deeplabcut/_tests/compat/test_compat_integration.py index 076428d9..92990f67 100644 --- a/src/napari_deeplabcut/_tests/compat/test_compat_integration.py +++ b/src/napari_deeplabcut/_tests/compat/test_compat_integration.py @@ -5,7 +5,6 @@ from __future__ import annotations import numpy as np -import pytest from napari_deeplabcut.napari_compat import ( apply_points_layer_ui_tweaks, @@ -111,7 +110,6 @@ def paste_func(this): assert seen == [layer] -@pytest.mark.xfail(reason="This test is fixed in a subsequent PR, to be added") def test_apply_points_layer_ui_tweaks_real_dropdown(qtbot): from types import SimpleNamespace @@ -147,7 +145,12 @@ def __init__(self): def layout(self): return self._layout - layer = SimpleNamespace(metadata={"colormap_name": "magma"}) + class DummyLayer: + def __init__(self): + self.metadata = {"colormap_name": "magma"} + + layer = DummyLayer() + point_controls = PointControls() qtbot.addWidget(point_controls) diff --git a/src/napari_deeplabcut/_tests/e2e/test_points_layers.py b/src/napari_deeplabcut/_tests/e2e/test_points_layers.py index 772bb4a4..84687131 100644 --- a/src/napari_deeplabcut/_tests/e2e/test_points_layers.py +++ b/src/napari_deeplabcut/_tests/e2e/test_points_layers.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import numpy as np import pytest from napari.layers import Points -from napari_deeplabcut.core.layers import populate_keypoint_layer_properties +from napari_deeplabcut.core.layers import PointsInteractionObserver, populate_keypoint_layer_properties @pytest.mark.usefixtures("qtbot") @@ -100,7 +102,7 @@ def test_layer_insert_does_not_crash_when_current_property_is_nan(viewer, keypoi layer = viewer.add_points(data, **md) # Plot cannot be formed because of the NaN, # but the layer must still be added and cycle mode must not be enabled. - assert keypoint_controls._matplotlib_canvas.df is None + assert keypoint_controls._traj_mpl_canvas.df is None assert isinstance(layer, Points) assert layer.face_color_mode != "cycle" @@ -239,3 +241,71 @@ def test_copy_paste_same_frame_does_not_duplicate_existing_keypoints( # no duplicates expected on same frame assert len(layer.data) == before + + +def test_points_interaction_observer_emits_on_selected_data_change(viewer, qtbot): + layer = viewer.add_points( + np.array([[0, 0], [1, 1]]), + properties={"label": np.array(["nose", "tail"], dtype=object)}, + ) + viewer.layers.selection.active = layer + + seen = [] + + observer = PointsInteractionObserver(viewer, seen.append, debounce_ms=0) + observer.install() + + layer.selected_data.select_only(1) + + qtbot.waitUntil(lambda: len(seen) >= 1, timeout=1000) + + evt = seen[-1] + assert evt.viewer is viewer + assert isinstance(evt.layer, Points) + assert "selection" in evt.reasons + assert tuple(sorted(evt.layer.selected_data)) == (1,) + + observer.close() + + +def test_points_interaction_observer_rebinds_when_active_layer_changes(viewer, qtbot): + layer1 = viewer.add_points( + np.array([[0, 0], [1, 1]]), + properties={"label": np.array(["nose", "tail"], dtype=object)}, + name="points-1", + ) + layer2 = viewer.add_points( + np.array([[2, 2], [3, 3]]), + properties={"label": np.array(["paw", "ear"], dtype=object)}, + name="points-2", + ) + + seen = [] + observer = PointsInteractionObserver(viewer, seen.append, debounce_ms=0) + observer.install() + + viewer.layers.selection.active = layer1 + layer1.selected_data.select_only(0) + qtbot.waitUntil(lambda: any("selection" in ev.reasons for ev in seen), timeout=1000) + + count_after_layer1 = len(seen) + + # Switch active layer + viewer.layers.selection.active = layer2 + qtbot.waitUntil(lambda: len(seen) > count_after_layer1, timeout=1000) + + count_after_active_switch = len(seen) + + # Mutating old inactive layer selection should not produce a new callback + layer1.selected_data.select_only(1) + qtbot.wait(50) + assert len(seen) == count_after_active_switch + + # Mutating new active layer selection should produce a callback + layer2.selected_data.select_only(1) + qtbot.waitUntil(lambda: len(seen) > count_after_active_switch, timeout=1000) + assert seen[-1].layer is not None + assert seen[-1].layer.name == layer2.name + assert "selection" in seen[-1].reasons + + observer.close() diff --git a/src/napari_deeplabcut/_tests/test_widgets.py b/src/napari_deeplabcut/_tests/test_widgets.py index c66e5bfd..d4b4cc1d 100644 --- a/src/napari_deeplabcut/_tests/test_widgets.py +++ b/src/napari_deeplabcut/_tests/test_widgets.py @@ -1,3 +1,7 @@ +""" +NOTE: This file can be somewhat non-specific, please ensure functionalities from ui/ are tested +in the _tests/ui folder and consider moving tests below to more specific files as needed.""" + # src/napari_deeplabcut/_tests/test_widgets.py import os import types @@ -16,7 +20,7 @@ from napari_deeplabcut.ui.color_scheme_display import ColorSchemeDisplay from napari_deeplabcut.ui.dialogs import ShortcutRow from napari_deeplabcut.ui.labels_and_dropdown import KeypointsDropdownMenu, LabelPair -from napari_deeplabcut.ui.plots.trajectory import KeypointMatplotlibCanvas +from napari_deeplabcut.ui.plots.trajectory import TrajectoryMatplotlibCanvas from .conftest import force_show @@ -235,7 +239,7 @@ def test_color_scheme_display(qtbot): @pytest.mark.usefixtures("qtbot") def test_matplotlib_canvas_initialization_and_slider(viewer, points, qtbot): # Create the canvas widget - canvas = KeypointMatplotlibCanvas(viewer) + canvas = TrajectoryMatplotlibCanvas(viewer) qtbot.add_widget(canvas) # Simulate adding a Points layer (triggers _load_dataframe) @@ -254,8 +258,11 @@ def test_matplotlib_canvas_initialization_and_slider(viewer, points, qtbot): assert canvas._window == initial_window + 100 assert canvas.slider_value.text() == str(initial_window + 100) - # Test plot refresh on frame change + # Test plot refresh does nothing when plot is hidden canvas.update_plot_range(event=type("Event", (), {"value": [5]})) + assert canvas._n == 0 + # Test plot refresh on frame change (forced as it is hidden) + canvas.update_plot_range(event=type("Event", (), {"value": [5]}), force=True) assert canvas._n == 5 # Check that x-limits reflect the new window start, end = canvas.ax.get_xlim() @@ -286,7 +293,7 @@ def fake_add_dock_widget(*args, **kwargs): # Ensure it wouldn't try to dock again monkeypatch.setattr(controls.viewer.window, "add_dock_widget", fake_add_dock_widget) - controls._ensure_mpl_canvas_docked() + controls._ensure_traj_canvas_docked() assert called["count"] == 0, "add_dock_widget should not be called when already docked" assert controls._mpl_docked is True # stays docked @@ -301,7 +308,7 @@ def test_ensure_mpl_canvas_docked_missing_window(keypoint_controls, qtbot): controls.viewer = types.SimpleNamespace() # no 'window' controls._mpl_docked = False - controls._ensure_mpl_canvas_docked() + controls._ensure_traj_canvas_docked() # Nothing should change; crucially, no exceptions should be raised assert controls._mpl_docked is False @@ -322,11 +329,11 @@ def test_trajectory_loader_ignores_invalid_properties(viewer, keypoint_controls, layer = viewer.add_points(np.array([[0.0, 10.0, 20.0]]), **md) assert layer is not None - assert keypoint_controls._matplotlib_canvas.df is None # loader should have bailed out safely + assert keypoint_controls._traj_mpl_canvas.df is None # loader should have bailed out safely @pytest.mark.usefixtures("qtbot") -def test_ensure_mpl_canvas_docked_missing_qt_window(keypoint_controls, qtbot): +def test_ensure_traj_canvas_docked_missing_qt_window(keypoint_controls, qtbot): """If window._qt_window is None, method should safely no-op.""" controls = keypoint_controls qtbot.add_widget(controls) @@ -341,7 +348,7 @@ def add_dock_widget(self, *args, **kwargs): controls.viewer = types.SimpleNamespace(window=DummyWindow()) controls._mpl_docked = False - controls._ensure_mpl_canvas_docked() + controls._ensure_traj_canvas_docked() # Still undocked, no crash assert controls._mpl_docked is False @@ -365,7 +372,7 @@ def add_dock_widget(self, *args, **kwargs): controls._mpl_docked = False # Should not raise - controls._ensure_mpl_canvas_docked() + controls._ensure_traj_canvas_docked() # Docking failed → remains undocked assert controls._mpl_docked is False diff --git a/src/napari_deeplabcut/_tests/ui/test_cropping.py b/src/napari_deeplabcut/_tests/ui/test_cropping.py index c750d230..3c63b180 100644 --- a/src/napari_deeplabcut/_tests/ui/test_cropping.py +++ b/src/napari_deeplabcut/_tests/ui/test_cropping.py @@ -514,3 +514,333 @@ def test_execute_frame_extraction_keeps_new_labels_row_on_duplicate_index(monkey df_written = pd.read_hdf(labels_path, key="df_with_missing") assert float(df_written.iloc[0]["bp1"]) == 222.0 + + +def test_ensure_dlc_crop_layer_reuses_existing(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + existing = FakeShapes( + data=[], + shape_type=[], + metadata={cropping_mod.DLC_CROP_LAYER_META_KEY: True}, + name=cropping_mod.DLC_CROP_LAYER_NAME, + ) + existing.visible = False + existing.mode = None + + viewer = SimpleNamespace( + layers=FakeLayerList([existing], active=None), + ) + + out = cropping_mod.ensure_dlc_crop_layer(viewer) + + assert out is existing + assert existing.visible is True + assert existing.mode == "add_rectangle" + assert viewer.layers.selection.active is existing + + +def test_ensure_dlc_crop_layer_creates_new(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + created = FakeShapes(data=[], shape_type=[], metadata={}, name="created") + created.mode = None + + class Viewer: + def __init__(self): + self.layers = FakeLayerList([], active=None) + + def add_shapes(self, name, metadata): + created.name = name + created.metadata = metadata + self.layers.append(created) + return created + + viewer = Viewer() + + out = cropping_mod.ensure_dlc_crop_layer(viewer) + + assert out is created + assert out.name == cropping_mod.DLC_CROP_LAYER_NAME + assert out.metadata[cropping_mod.DLC_CROP_LAYER_META_KEY] is True + assert out.mode == "add_rectangle" + assert viewer.layers.selection.active is out + + +def test_dlc_config_y_extent_falls_back_to_active_image(monkeypatch): + monkeypatch.setattr(cropping_mod, "Image", FakeImage) + + image = FakeImage(data=np.zeros((5, 40, 70), dtype=np.uint8)) + viewer = SimpleNamespace( + dims=SimpleNamespace(range=None), + layers=FakeLayerList([image], active=image), + ) + + assert cropping_mod._dlc_config_y_extent(viewer) == 40 + + +def test_dlc_config_y_extent_falls_back_to_last_image(monkeypatch): + monkeypatch.setattr(cropping_mod, "Image", FakeImage) + + image = FakeImage(data=np.zeros((5, 55, 90), dtype=np.uint8)) + viewer = SimpleNamespace( + dims=SimpleNamespace(range=None), + layers=FakeLayerList([image], active=None), + ) + + assert cropping_mod._dlc_config_y_extent(viewer) == 55 + + +def test_dlc_config_y_extent_returns_none_when_unavailable(): + viewer = SimpleNamespace( + dims=SimpleNamespace(range=None), + layers=FakeLayerList([], active=None), + ) + + assert cropping_mod._dlc_config_y_extent(viewer) is None + + +def test_find_rectangle_in_layer_falls_back_from_selected_to_last(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + bad_rect = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 10.0, 20.0], + [0.0, 10.0, 20.0], + [0.0, 10.0, 20.0], + ], + dtype=float, + ) + good_rect = np.array( + [ + [0.0, 5.0, 10.0], + [0.0, 5.0, 20.0], + [0.0, 15.0, 20.0], + [0.0, 15.0, 10.0], + ], + dtype=float, + ) + + layer = FakeShapes( + data=[bad_rect, good_rect], + shape_type=["rectangle", "rectangle"], + selected_data={0}, + ) + viewer = SimpleNamespace(dims=SimpleNamespace(range=[(0, 10, 1), (0, 100, 1), (0, 200, 1)])) + + spec = cropping_mod._find_rectangle_in_layer(viewer, layer, prefer_selected=True) + assert spec is not None + assert spec.viewer_crop.values == (10, 20, 5, 15) + + +def test_get_crop_source_summary_uses_active_shapes_when_no_dedicated(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + rect = np.array( + [ + [0.0, 5.0, 10.0], + [0.0, 5.0, 20.0], + [0.0, 15.0, 20.0], + [0.0, 15.0, 10.0], + ], + dtype=float, + ) + + active = FakeShapes( + data=[rect], + shape_type=["rectangle"], + metadata={}, + name="manual", + selected_data={0}, + ) + + viewer = SimpleNamespace( + layers=FakeLayerList([active], active=active), + dims=SimpleNamespace(range=[(0, 10, 1), (0, 100, 1), (0, 200, 1)]), + ) + + source, spec = cropping_mod.get_crop_source_summary(viewer) + assert source == "active Shapes layer (manual)" + assert spec is not None + + +def test_get_crop_source_summary_returns_none_when_no_valid_rectangles(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + poly = FakeShapes( + data=[np.array([[0, 1, 2], [0, 2, 3], [0, 3, 4]], dtype=float)], + shape_type=["polygon"], + metadata={}, + name="poly", + selected_data={0}, + ) + + viewer = SimpleNamespace( + layers=FakeLayerList([poly], active=poly), + dims=SimpleNamespace(range=[(0, 10, 1), (0, 100, 1), (0, 200, 1)]), + ) + + source, spec = cropping_mod.get_crop_source_summary(viewer) + assert source == "none" + assert spec is None + + +def test_plan_frame_extraction_rejects_missing_image_layer(): + viewer = SimpleNamespace(dims=SimpleNamespace(current_step=(0,))) + plan, error = cropping_mod.plan_frame_extraction(viewer, image_layer=None) + assert plan is None + assert "No image/video layer is active" in error + + +def test_plan_frame_extraction_rejects_missing_output_root(): + image = FakeImage(data=np.zeros((3, 10, 10), dtype=np.uint8), metadata={}) + viewer = SimpleNamespace(dims=SimpleNamespace(current_step=(0,))) + + plan, error = cropping_mod.plan_frame_extraction(viewer, image_layer=image) + assert plan is None + assert "Could not determine the output folder" in error + + +def test_plan_frame_extraction_requires_points_layer_for_label_export(tmp_path: Path): + image = FakeImage( + data=np.zeros((3, 10, 10), dtype=np.uint8), + metadata={"root": str(tmp_path)}, + ) + viewer = SimpleNamespace(dims=SimpleNamespace(current_step=(0,))) + + plan, error = cropping_mod.plan_frame_extraction( + viewer, + image_layer=image, + export_labels=True, + points_layer=None, + ) + assert plan is None + assert "no Points layer is available" in error + + +def test_plan_frame_extraction_requires_rectangle_when_crop_enabled(tmp_path: Path, monkeypatch): + monkeypatch.setattr(cropping_mod, "find_crop_rectangle", lambda viewer, prefer_selected=True: None) + + image = FakeImage( + data=np.zeros((3, 10, 10), dtype=np.uint8), + metadata={"root": str(tmp_path)}, + ) + viewer = SimpleNamespace(dims=SimpleNamespace(current_step=(0,))) + + plan, error = cropping_mod.plan_frame_extraction( + viewer, + image_layer=image, + apply_crop=True, + ) + assert plan is None + assert "no valid rectangle was found" in error + + +def test_plan_crop_save_rejects_missing_project_context(monkeypatch): + monkeypatch.setattr( + cropping_mod, + "infer_dlc_project_from_image_layer", + lambda image_layer, prefer_project_root=True: SimpleNamespace( + config_path=None, + project_root=None, + root_anchor=None, + ), + ) + + image = FakeImage(data=np.zeros((3, 10, 10), dtype=np.uint8)) + viewer = SimpleNamespace(layers=FakeLayerList([], active=None), dims=SimpleNamespace(range=[])) + + plan, error = cropping_mod.plan_crop_save(viewer, image_layer=image) + assert plan is None + assert "Could not determine a DLC config.yaml" in error + + +def test_plan_crop_save_rejects_missing_rectangle(monkeypatch, tmp_path: Path): + monkeypatch.setattr( + cropping_mod, + "infer_dlc_project_from_image_layer", + lambda image_layer, prefer_project_root=True: SimpleNamespace( + config_path=tmp_path / "config.yaml", + project_root=tmp_path, + root_anchor=tmp_path, + ), + ) + monkeypatch.setattr(cropping_mod, "find_crop_rectangle", lambda viewer, prefer_selected=True: None) + + image = FakeImage(data=np.zeros((3, 10, 10), dtype=np.uint8), name="demo.mp4") + viewer = SimpleNamespace(layers=FakeLayerList([], active=None), dims=SimpleNamespace(range=[])) + + plan, error = cropping_mod.plan_crop_save(viewer, image_layer=image) + assert plan is None + assert "No valid rectangle was found" in error + + +def test_execute_crop_save_replaces_non_dict_video_set_entry(monkeypatch, tmp_path: Path): + cfg = {"video_sets": {"video.mp4": "old-value"}} + + monkeypatch.setattr(cropping_mod.io, "load_config", lambda path: cfg) + written = {} + + def fake_write_config(path, out_cfg): + written["path"] = path + written["cfg"] = out_cfg + + monkeypatch.setattr(cropping_mod.io, "write_config", fake_write_config) + + plan = cropping_mod.CropSavePlan( + config_path=tmp_path / "config.yaml", + project_root=tmp_path, + video_key="video.mp4", + config_crop=cropping_mod.DLCConfigCropCoords(values=(1, 10, 20, 30)), + ) + + msg = cropping_mod.execute_crop_save(plan) + + assert "Saved crop" in msg + assert written["cfg"]["video_sets"]["video.mp4"]["crop"] == "1, 10, 20, 30" + + +class HiddenPanel(DummyPanel): + def isVisible(self): + return False + + def parentWidget(self): + return None + + +class VisiblePanel(DummyPanel): + def isVisible(self): + return True + + def parentWidget(self): + return None + + +def test_update_video_panel_context_hidden_panel_detaches_and_returns(monkeypatch): + panel = HiddenPanel() + called = {"detach": 0} + + monkeypatch.setattr( + cropping_mod, "_detach_crop_autorefresh", lambda p: called.__setitem__("detach", called["detach"] + 1) + ) + + viewer = SimpleNamespace(layers=FakeLayerList([], active=None), dims=SimpleNamespace(current_step=(0,))) + + cropping_mod.update_video_panel_context(viewer, panel) + + assert called["detach"] == 1 + assert panel.text is None + + +def test_update_video_panel_context_no_image_layer(monkeypatch): + panel = VisiblePanel() + monkeypatch.setattr(cropping_mod, "sync_crop_layer_autorefresh", lambda viewer, panel, refresh_callback: None) + + viewer = SimpleNamespace( + layers=FakeLayerList([], active=None), + dims=SimpleNamespace(current_step=(0,)), + ) + + cropping_mod.update_video_panel_context(viewer, panel) + assert panel.text == "No active video/image layer." diff --git a/src/napari_deeplabcut/_tests/ui/test_traj_select.py b/src/napari_deeplabcut/_tests/ui/test_traj_select.py new file mode 100644 index 00000000..04dc603c --- /dev/null +++ b/src/napari_deeplabcut/_tests/ui/test_traj_select.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from napari_deeplabcut.ui.plots.trajectory import TrajectoryMatplotlibCanvas + + +@pytest.mark.e2e +@pytest.mark.usefixtures("qtbot") +def test_sync_visible_lines_to_points_selection_shows_all_when_no_points_selected(viewer, qtbot): + layer = viewer.add_points( + np.array([[0, 0], [1, 1]]), + properties={"label": np.array(["nose", "tail"], dtype=object)}, + ) + + canvas = TrajectoryMatplotlibCanvas(viewer) + qtbot.add_widget(canvas) + + # Force the visibility sync to use this test layer directly. + canvas._get_plot_points_layer = lambda: layer + + # Avoid relying on df creation for this focused visibility test + canvas.df = object() + (line_nose,) = canvas.ax.plot([0, 1], [0, 1]) + (line_tail,) = canvas.ax.plot([0, 1], [1, 0]) + qtbot.wait(0) # ensure lines are fully initialized + canvas._lines = { + ("", "nose"): [line_nose], + ("", "tail"): [line_tail], + } + + layer.selected_data.clear() + canvas.sync_visible_lines_to_points_selection() + + assert line_nose.get_visible() is True + assert line_tail.get_visible() is True + + +@pytest.mark.e2e +@pytest.mark.usefixtures("qtbot") +def test_sync_visible_lines_to_points_selection_filters_by_selected_labels_in_bodypart_mode(viewer, qtbot): + layer = viewer.add_points( + np.array([[0, 0], [1, 1], [2, 2]]), + properties={"label": np.array(["nose", "tail", "nose"], dtype=object)}, + ) + + canvas = TrajectoryMatplotlibCanvas(viewer, get_color_mode=lambda: "bodypart") + qtbot.add_widget(canvas) + + # Force the visibility sync to use this test layer directly. + canvas._get_plot_points_layer = lambda: layer + + canvas.df = object() + (line_nose,) = canvas.ax.plot([0, 1], [0, 1]) + (line_tail,) = canvas.ax.plot([0, 1], [1, 0]) + qtbot.wait(0) # ensure lines are fully initialized + + canvas._lines = { + ("", "nose"): [line_nose], + ("", "tail"): [line_tail], + } + + # Select a point whose label is "tail" + layer.selected_data.select_only(1) + canvas.sync_visible_lines_to_points_selection() + + assert line_nose.get_visible() is False + assert line_tail.get_visible() is True + + +@pytest.mark.e2e +@pytest.mark.usefixtures("qtbot") +def test_sync_visible_lines_to_points_selection_shows_label_if_any_selected_point_has_that_label_in_bodypart_mode( + viewer, qtbot +): + layer = viewer.add_points( + np.array([[0, 0], [1, 1], [2, 2]]), + properties={"label": np.array(["nose", "tail", "nose"], dtype=object)}, + ) + + canvas = TrajectoryMatplotlibCanvas(viewer, get_color_mode=lambda: "bodypart") + qtbot.add_widget(canvas) + + # Force the visibility sync to use this test layer directly. + canvas._get_plot_points_layer = lambda: layer + + canvas.df = object() + (line_nose,) = canvas.ax.plot([0, 1], [0, 1]) + (line_tail,) = canvas.ax.plot([0, 1], [1, 0]) + qtbot.wait(0) # ensure lines are fully initialized + + canvas._lines = { + ("", "nose"): [line_nose], + ("", "tail"): [line_tail], + } + + # Select both nose points + layer.selected_data.update({0, 2}) + canvas.sync_visible_lines_to_points_selection() + + assert line_nose.get_visible() is True + assert line_tail.get_visible() is False diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index 293c9270..5cf0996c 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -3,10 +3,27 @@ NOTE: This file is generally already too long. For future development, please consider: - Moving existing responsibilities out into separate modules (existing or new) - Avoiding adding anything that is not strictly related to : - - Building the final UI (blocks can be moved to ui/ for better organization) - - Wiring to the core plugin functionality (e.g. via signals/slots, method calls, etc.) - - Anything that requires the full widget+viewer+signal/event context to function properly - - Similarly, test_widgets.py is a bit of a default drawer right now, please create new tests in _tests/ui + - Building the final UI (blocks can be moved to ui/ for better organization) + - Wiring to the core plugin functionality (e.g. via signals/slots, method calls, etc.) + - Anything that requires the full widget+viewer+signal/event context to function properly + - Similarly, test_widgets.py is a bit of a default drawer right now, please create new tests in _tests/ui +- Lifecycle of UI elements and Qt wiring should ideally: + - Use parent child widgets/controllers to KeypointControls + - Use child QTimers instead of fire-and-forget QTimer.singleShot for deferred UI work + - Use normal Qt signal connections for Qt-owned objects + - Keep explicit cleanup only for non-Qt subscriptions/resources + (e.g. napari event connections, observer install/uninstall, monkey-patch restoration) + +TODO: And general dev notes: +- The saving workflow is crammed into save_layers_dialog() right now, + and should move to a dedicated e.g. PointsLayerSaveFactory class in a dedicated file. +- Project/root/paths/image-meta/points-meta synchronization should be centralized. + It is too distributed right now, and it can be unclear "which truth" is authoritative. + Some sort of context-manager class would likely help. +- Maybe a dedicated layer lifecycle system would help, for layer adoption and setup. +- Something that owns UI sync state (what to refresh, why, when) could help with the heavy wiring. +- I'd suggest keeping in this file: + - color/label mode, menu/help actions, widget visibility and user interaction hooks. """ # src/napari_deeplabcut/_widgets.py @@ -61,6 +78,8 @@ from napari_deeplabcut.core.conflicts import compute_overwrite_report_for_points_save from napari_deeplabcut.core.layer_versioning import mark_layer_presentation_changed from napari_deeplabcut.core.layers import ( + PointsInteractionEvent, + PointsInteractionObserver, compute_label_progress, find_relevant_image_layer, get_first_points_layer, @@ -123,7 +142,7 @@ KeypointsDropdownMenu, ) from napari_deeplabcut.ui.layer_stats import LayerStatusPanel -from napari_deeplabcut.ui.plots.trajectory import KeypointMatplotlibCanvas +from napari_deeplabcut.ui.plots.trajectory import TrajectoryMatplotlibCanvas logger = logging.getLogger("napari-deeplabcut._widgets") # logger.setLevel(logging.DEBUG) # FIXME @C-Achard temp remove before merging @@ -196,6 +215,9 @@ def __init__(self, napari_viewer): # for future-proofing def _close_event(event): self.on_close(event) + points_inter = getattr(self, "_points_interactions", None) + if points_inter is not None: + points_inter.close() # if accepted, call original if event.isAccepted(): orig_close_event(event) @@ -255,10 +277,14 @@ def _close_event(event): ) self._mpl_docked = False - self._matplotlib_canvas = KeypointMatplotlibCanvas(self.viewer) + + self._traj_mpl_canvas = TrajectoryMatplotlibCanvas( + self.viewer, + get_color_mode=lambda: self.color_mode, + ) self._show_traj_plot_cb = QCheckBox("Show trajectories", parent=self) self._show_traj_plot_cb.setToolTip("Toggle to see trajectories in a t-y plot outside of the main video viewer") - self._show_traj_plot_cb.stateChanged.connect(self._show_matplotlib_canvas) + self._show_traj_plot_cb.stateChanged.connect(self._show_traj_canvas) self._show_traj_plot_cb.setChecked(False) self._show_traj_plot_cb.setEnabled(False) self._view_scheme_cb = QCheckBox("Show color scheme", parent=self) @@ -299,9 +325,17 @@ def _close_event(event): self._view_scheme_cb.setChecked(True) self._view_scheme_cb.toggled.connect(self._show_color_scheme) self._show_color_scheme() - self._color_scheme_panel.display.added.connect( - lambda w: w.part_label.clicked.connect(self._matplotlib_canvas._toggle_line_visibility), + # self._color_scheme_panel.display.added.connect( + # lambda w: w.part_label.clicked.connect(self._on_color_scheme_label_clicked), + # ) + + self._points_interactions = PointsInteractionObserver( + self.viewer, + self._on_points_interaction, + debounce_ms=0, + watch_content=False, ) + self._points_interactions.install() ### UI setup ends here # Modes init @@ -353,7 +387,7 @@ def _close_event(event): # NOTE while a timer may seem hacky, it is a simple, one-line solution that minimizes intrusion # There are to my knowledge no other way that is as concise and clean # (Of course this will be a problem if we start using it everywhere so do not reuse lightly) - QTimer.singleShot(10, self.silently_dock_matplotlib_canvas) + QTimer.singleShot(10, self.silently_dock_canvas) # If layers already exist (user loaded data before opening this widget), # adopt them so keypoint controls take ownership immediately. @@ -362,6 +396,10 @@ def _close_event(event): # Refresh layers stats widget QTimer.singleShot(0, self._refresh_layer_status_panel) + @cached_property + def settings(self): + return QSettings() + # ######################## # # Layer setup core methods # # ######################## # @@ -596,6 +634,11 @@ def _adopt_existing_layers(self) -> None: except Exception: pass + # Important: refresh the trajectory plot from the final adopted state. + # This fixes the case where layers were loaded before the plugin opened + # (e.g. drag-and-drop DLC data triggering the reader automatically). + self._refresh_trajectory_plot_from_layers() + def _adopt_layer(self, layer, index: int) -> None: """ Run the relevant portion of on_insert() for an already-existing layer. @@ -635,7 +678,7 @@ def _do(): QTimer.singleShot(0, _do) - def _ensure_mpl_canvas_docked(self) -> None: + def _ensure_traj_canvas_docked(self) -> None: """ Dock the Matplotlib canvas as a napari dock widget, exactly once, and only if the Qt window exists. Safe no-op in headless/proxy teardown. @@ -647,39 +690,77 @@ def _ensure_mpl_canvas_docked(self) -> None: if window is None: return - # If napari hasn't materialized its Qt window yet, skip (safe no-op). if getattr(window, "_qt_window", None) is None: - # In normal UI runs this won't happen when the user hits the checkbox. - # In tests/headless it may—so just do nothing. return try: - window.add_dock_widget(self._matplotlib_canvas, name="Trajectory plot", area="right", tabify=False) - self._matplotlib_canvas.canvas.draw_idle() - self._matplotlib_canvas.hide() + window.add_dock_widget(self._traj_mpl_canvas, name="Trajectory plot", area="right", tabify=False) + self._traj_mpl_canvas.canvas.draw_idle() + self._traj_mpl_canvas.hide() self._mpl_docked = True except Exception as e: - logging.debug("Skipping docking KeypointMatplotlibCanvas (not ready / teardown): %r", e) + logging.debug("Skipping docking canvas (not ready / teardown): %r", e) return - def silently_dock_matplotlib_canvas(self) -> None: + def _safe_get_traj_canvas(self): + canvas = getattr(self, "_traj_mpl_canvas", None) + if canvas is None: + return None + + try: + # Any Qt call is enough to verify the underlying C++ object still exists + canvas.isVisible() + except RuntimeError: + # Underlying Qt object was already deleted + self._traj_mpl_canvas = None + return None + + return canvas + + def silently_dock_canvas(self) -> None: """Dock the Matplotlib canvas without showing it.""" - self._ensure_mpl_canvas_docked() + self._ensure_traj_canvas_docked() if self._mpl_docked: - self._matplotlib_canvas.hide() + self._traj_mpl_canvas.hide() - def _show_matplotlib_canvas(self, state): + def _show_traj_canvas(self, state): if Qt.CheckState(state) == Qt.CheckState.Checked: - self._ensure_mpl_canvas_docked() + self._ensure_traj_canvas_docked() if self._mpl_docked: - self._matplotlib_canvas.show() + self._traj_mpl_canvas._apply_napari_theme() + self._traj_mpl_canvas.update_plot_range( + Event(type_name="", value=[self.viewer.dims.current_step[0]]), + force=True, + ) + self._traj_mpl_canvas.sync_visible_lines_to_points_selection() + self._traj_mpl_canvas.show() else: if self._mpl_docked: - self._matplotlib_canvas.hide() + self._traj_mpl_canvas.hide() - @cached_property - def settings(self): - return QSettings() + def _on_points_interaction(self, event: PointsInteractionEvent) -> None: + """ + Keep the trajectory plot in sync with the active points-layer selection. + + This is intentionally selection-driven: + - no selected points -> all trajectories + - selected points -> only selected labels' trajectories + """ + traj_canvas = self._safe_get_traj_canvas() + if traj_canvas is None or not traj_canvas.isVisible(): + return + if {"selection", "active_layer", "layers"} & set(event.reasons): + traj_canvas.sync_visible_lines_to_points_selection() + + def _refresh_trajectory_plot_from_layers(self) -> None: + """ + Refresh trajectory plot from the current viewer state. + + Deferred through QTimer so it runs after layer adoption/remap settles. + """ + traj_canvas = self._safe_get_traj_canvas() + if traj_canvas is not None: + QTimer.singleShot(0, traj_canvas.refresh_from_viewer_layers) def load_superkeypoints_diagram(self): points_layer = get_first_points_layer(self.viewer) @@ -1021,15 +1102,7 @@ def _refresh_layer_status_panel(self) -> None: self._layer_status_panel.set_point_size(get_uniform_point_size(active_dlc_points)) progress = compute_label_progress(active_dlc_points, fallback_paths=self._image_meta.paths) - self._layer_status_panel.set_progress_summary( - labeled_percent=progress.labeled_percent, - remaining_percent=progress.remaining_percent, - labeled_points=progress.labeled_points, - total_points=progress.total_points, - frame_count=progress.frame_count, - bodypart_count=progress.bodypart_count, - individual_count=progress.individual_count, - ) + self._layer_status_panel.set_progress_summary(progress=progress) def _on_active_points_size_changed(self, size: int) -> None: layer = self._current_dlc_points_layer() @@ -1199,7 +1272,7 @@ def _form_color_mode_selector(self): layout = QHBoxLayout() group = QButtonGroup(self) for i, mode in enumerate(keypoints.ColorMode.__members__, start=1): - btn = QRadioButton(mode.lower()) + btn = QRadioButton(mode.capitalize()) group.addButton(btn, i) layout.addWidget(btn) group.button(1).setChecked(True) @@ -1207,7 +1280,7 @@ def _form_color_mode_selector(self): self._layout.addWidget(group_box) def _func(): - self.color_mode = group.checkedButton().text() + self.color_mode = group.checkedButton().text().lower() group.buttonClicked.connect(_func) return group_box, group @@ -1776,6 +1849,13 @@ def color_mode(self, mode: str | keypoints.ColorMode): btn.setChecked(True) break + traj_canvas = self._safe_get_traj_canvas() + if traj_canvas is not None: + try: + traj_canvas.refresh_from_viewer_layers() + except Exception: + logger.debug("Failed to refresh trajectory plot after color mode change", exc_info=True) + self._update_color_scheme() self._trails_controller.on_points_visual_inputs_changed(checkbox_checked=self._trail_cb.isChecked()) diff --git a/src/napari_deeplabcut/config/_autostart.py b/src/napari_deeplabcut/config/_autostart.py index 2fb0b955..463561d9 100644 --- a/src/napari_deeplabcut/config/_autostart.py +++ b/src/napari_deeplabcut/config/_autostart.py @@ -50,6 +50,9 @@ def _ensure_keypoint_controls_open(viewer) -> None: ) except Exception: logger.debug("Failed to open Keypoint controls dock widget.", exc_info=True) + napari.utils.notifications.show_info( + "Failed to open Keypoint controls. Please open manually from the Plugins menu.", + ) def _maybe_open_for_inserted_layer(viewer, layer) -> None: diff --git a/src/napari_deeplabcut/config/models.py b/src/napari_deeplabcut/config/models.py index 39bf515d..36b6b24f 100644 --- a/src/napari_deeplabcut/config/models.py +++ b/src/napari_deeplabcut/config/models.py @@ -10,6 +10,8 @@ import numpy as np from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +# TODO @C-Achard: move to core/ + def unsorted_unique(array: Sequence) -> np.ndarray: """Return the unsorted unique elements of an array.""" diff --git a/src/napari_deeplabcut/core/layers.py b/src/napari_deeplabcut/core/layers.py index afbc9ea7..dd6f658d 100644 --- a/src/napari_deeplabcut/core/layers.py +++ b/src/napari_deeplabcut/core/layers.py @@ -1,3 +1,4 @@ +# src/napari_deeplabcut/core/layers.py from __future__ import annotations import logging @@ -8,6 +9,7 @@ import numpy as np from napari.layers import Image, Points, Shapes, Tracks +from qtpy.QtCore import QTimer from napari_deeplabcut.config.models import AnnotationKind, DLCHeaderModel from napari_deeplabcut.core.keypoints import build_color_cycles @@ -217,6 +219,11 @@ class LabelProgress: frame_count: int bodypart_count: int individual_count: int + completed_frames: int + completed_percent: float + incomplete_frames: tuple[int, ...] + incomplete_frames_by_individual: dict[str, int] + missing_points_by_individual: dict[str, int] def _get_header_model_from_metadata(md: dict) -> DLCHeaderModel | None: @@ -261,10 +268,10 @@ def set_uniform_point_size(layer: Points, size: int) -> None: layer.size = float(size) -def infer_frame_count(layer: Points, *, fallback_paths: list[str] | None = None) -> int: +def infer_frame_count(layer: Points, *, preferred_paths: list[str] | None = None) -> int: md = getattr(layer, "metadata", {}) or {} - paths = md.get("paths") or fallback_paths or [] + paths = preferred_paths or md.get("paths") or [] if paths: return len(paths) @@ -310,13 +317,124 @@ def infer_individual_count(layer: Points) -> int: return 1 +def _normalized_slot_id(value) -> str: + if value in ("", None): + return "" + try: + if np.isnan(value): + return "" + except Exception: + pass + text = str(value) + return "" if text.lower() == "nan" else text + + +def infer_observed_bodypart_names(layer: Points) -> list[str]: + """ + Ordered unique bodypart labels actually present in the active napari layer. + """ + props = getattr(layer, "properties", {}) or {} + labels = np.asarray(props.get("label", []), dtype=object).ravel() + + out: list[str] = [] + seen: set[str] = set() + for v in labels: + if v in ("", None): + continue + text = str(v) + if not text or text in seen: + continue + seen.add(text) + out.append(text) + return out + + +def infer_observed_individual_names(layer: Points) -> list[str]: + """ + Ordered unique individual ids actually present in the active napari layer. + + Single-animal convention: + - no ids / blank ids -> [''] + """ + props = getattr(layer, "properties", {}) or {} + ids_raw = props.get("id", None) + if ids_raw is None: + return [""] + + ids = np.asarray(ids_raw, dtype=object).ravel() + + out: list[str] = [] + seen: set[str] = set() + for v in ids: + text = _normalized_slot_id(v) + if text == "": + continue + if text in seen: + continue + seen.add(text) + out.append(text) + + return out if out else [""] + + +def _iter_labeled_slots(layer: Points): + """ + Yield unique annotatable slots currently represented in the active napari layer. + + Each slot is keyed by: + - frame index + - individual id ('' for single-animal / blank ids) + - bodypart label + """ + data = np.asarray(getattr(layer, "data", [])) + if data.ndim < 2 or data.shape[1] == 0: + return + + props = getattr(layer, "properties", {}) or {} + labels = np.asarray(props.get("label", []), dtype=object).ravel() + ids_raw = props.get("id", None) + + if ids_raw is None: + ids = np.array([""] * len(labels), dtype=object) + else: + ids = np.asarray(ids_raw, dtype=object).ravel() + + n = min(len(data), len(labels), len(ids)) + for i in range(n): + try: + frame = int(data[i, 0]) + except Exception: + continue + + label_val = labels[i] + if label_val in ("", None): + continue + label = str(label_val) + if not label: + continue + + id_text = _normalized_slot_id(ids[i]) + yield (frame, id_text, label) + + def compute_label_progress(layer: Points, *, fallback_paths: list[str] | None = None) -> LabelProgress: - frame_count = infer_frame_count(layer, fallback_paths=fallback_paths) + """ + Compute progress for the active napari layer. + + Semantics: + - Main percentage remains theoretical: + labeled_points / (frame_count * bodypart_count * individual_count) + - Richer frame-completion details are computed from the observed slot universe + currently represented in napari: + observed_ids × observed_labels + """ + frame_count = infer_frame_count(layer, preferred_paths=fallback_paths) bodypart_count = infer_bodypart_count(layer) individual_count = infer_individual_count(layer) total_points = frame_count * bodypart_count * individual_count + # Keep the top-line point percentage as before. data = np.asarray(getattr(layer, "data", [])) labeled_points = int(data.shape[0]) if data.ndim >= 2 else 0 @@ -328,6 +446,57 @@ def compute_label_progress(layer: Points, *, fallback_paths: list[str] | None = remaining_percent = max(0.0, 100.0 - labeled_percent) + # Richer completion details based on what is actually represented in napari. + slots = set(_iter_labeled_slots(layer) or []) + + observed_labels = infer_observed_bodypart_names(layer) + observed_ids = infer_observed_individual_names(layer) + expected_ids = observed_ids if observed_ids else [""] + + expected_pairs = {(id_text, label) for id_text in expected_ids for label in observed_labels} + expected_per_frame = len(expected_pairs) + + frame_to_pairs: dict[int, set[tuple[str, str]]] = {} + for frame, id_text, label in slots: + frame_to_pairs.setdefault(frame, set()).add((id_text, label)) + + # Count-based frame completion: match the diagnostic / user-facing intuition. + frame_slot_counts: dict[int, int] = {frame: len(pairs) for frame, pairs in frame_to_pairs.items()} + + completed_frames = 0 + incomplete_frames: list[int] = [] + incomplete_frames_by_individual: dict[str, int] = {id_text: 0 for id_text in expected_ids} + missing_points_by_individual: dict[str, int] = {id_text: 0 for id_text in expected_ids} + + if frame_count > 0 and expected_per_frame > 0: + for frame in range(frame_count): + present = frame_to_pairs.get(frame, set()) + present_count = frame_slot_counts.get(frame, 0) + + # A frame is considered complete if it has the expected number of unique slots. + if present_count >= expected_per_frame: + completed_frames += 1 + continue + + incomplete_frames.append(frame) + + # Still compute missing details by comparing against expected pairs. + # This is now only used for richer tooltip details on frames that are + # count-incomplete, which keeps the user-facing summary intuitive. + missing = expected_pairs - present + + missing_by_individual: dict[str, int] = {} + for id_text, _label in missing: + missing_by_individual[id_text] = missing_by_individual.get(id_text, 0) + 1 + + for id_text, missing_count in missing_by_individual.items(): + incomplete_frames_by_individual[id_text] = incomplete_frames_by_individual.get(id_text, 0) + 1 + missing_points_by_individual[id_text] = missing_points_by_individual.get(id_text, 0) + missing_count + + completed_percent = 100.0 * completed_frames / frame_count + else: + completed_percent = 0.0 + return LabelProgress( labeled_points=labeled_points, total_points=total_points, @@ -336,6 +505,11 @@ def compute_label_progress(layer: Points, *, fallback_paths: list[str] | None = frame_count=frame_count, bodypart_count=bodypart_count, individual_count=individual_count, + completed_frames=completed_frames, + completed_percent=completed_percent, + incomplete_frames=tuple(incomplete_frames), + incomplete_frames_by_individual=incomplete_frames_by_individual, + missing_points_by_individual=missing_points_by_individual, ) @@ -390,3 +564,235 @@ def find_relevant_image_layer(viewer) -> Image | None: return layer return None + + +# ----------------------------------------------- +# Points interaction observer +# ----------------------------------------------- + + +@dataclass(frozen=True) +class PointsInteractionEvent: + """ + Structured points-layer interaction event. + + Parameters + ---------- + viewer: + The napari viewer. + layer: + The active Points layer at the time the event flushes, or None. + reasons: + A normalized set of reasons that triggered the event. Typical values + include {"install"}, {"selection"}, {"active_layer"}, {"layers"}, + and {"content"}, depending on which observer hooks fired. + """ + + viewer: Any + layer: Points | None + reasons: frozenset[str] + + +def _iter_event_emitters(event_group: Any, names: tuple[str, ...]): + """ + Yield (name, emitter) pairs for names that exist on an event group. + """ + if event_group is None: + return + for name in names: + emitter = getattr(event_group, name, None) + if emitter is not None: + yield name, emitter + + +def capture_points_state( + layer: Points, + *, + include_data: bool = False, + include_properties: bool = False, +) -> dict[str, Any]: + """ + Best-effort snapshot helper for future history/undo systems. + + This is intentionally separate from the observer so callers can choose + whether they want lightweight interaction events or heavier snapshots. + """ + data = getattr(layer, "data", None) + try: + n_points = 0 if data is None else len(data) + except Exception: + n_points = 0 + + state: dict[str, Any] = { + "name": getattr(layer, "name", None), + "selected_data": tuple(sorted(int(i) for i in getattr(layer, "selected_data", set()) or set())), + "n_points": n_points, + } + + if include_data: + try: + state["data"] = getattr(layer, "data", None).copy() + except Exception: + state["data"] = None + + if include_properties: + try: + props = getattr(layer, "properties", {}) or {} + state["properties"] = {k: v.copy() if hasattr(v, "copy") else v for k, v in dict(props).items()} + except Exception: + state["properties"] = None + + return state + + +class PointsInteractionObserver: + """ + Observe the active napari Points layer and emit coalesced interaction events. + + Public / stable anchors used + ---------------------------- + - viewer.layers.selection.events.active + - layer.selected_data.events.changed / active + - layer.selected_data.events.items_changed (if present; useful in practice) + - viewer.layers.events.inserted / removed / reordered (if present) + + Notes + ----- + This is intentionally conservative: + - it avoids private napari APIs + - it tolerates event-name differences by connecting only to emitters that exist + - it coalesces bursts of events into one callback using a QTimer + """ + + def __init__( + self, + viewer: Any, + callback: Callable[[PointsInteractionEvent], None], + *, + debounce_ms: int = 0, + watch_content: bool = False, + ) -> None: + self.viewer = viewer + self.callback = callback + self.debounce_ms = max(0, int(debounce_ms)) + self.watch_content = watch_content + + self._active_layer: Points | None = None + self._viewer_connections: list[tuple[Any, Callable]] = [] + self._layer_connections: list[tuple[Any, Callable]] = [] + self._pending_reasons: set[str] = set() + + self._timer = QTimer() + self._timer.setSingleShot(True) + self._timer.timeout.connect(self._flush) + + # ------------------------------------------------------------------ + # Public lifecycle + # ------------------------------------------------------------------ + + def install(self) -> None: + """ + Install the observer onto the viewer. + """ + self._connect_viewer_events() + self._rebind_active_points_layer() + self._schedule("install") + + def close(self) -> None: + """ + Disconnect all emitters and stop the timer. + """ + self._timer.stop() + self._disconnect_all(self._layer_connections) + self._disconnect_all(self._viewer_connections) + self._active_layer = None + self._pending_reasons.clear() + + # ------------------------------------------------------------------ + # Internal connection helpers + # ------------------------------------------------------------------ + + def _connect(self, emitter: Any, callback: Callable, bucket: list[tuple[Any, Callable]]) -> None: + emitter.connect(callback) + bucket.append((emitter, callback)) + + def _disconnect_all(self, bucket: list[tuple[Any, Callable]]) -> None: + while bucket: + emitter, callback = bucket.pop() + try: + emitter.disconnect(callback) + except Exception: + pass + + def _connect_viewer_events(self) -> None: + # Active layer changes are the most important public hook. + active_emitter = getattr(self.viewer.layers.selection.events, "active", None) + if active_emitter is not None: + self._connect(active_emitter, self._on_active_layer_changed, self._viewer_connections) + + layer_events = getattr(self.viewer.layers, "events", None) + for _name, emitter in _iter_event_emitters(layer_events, ("inserted", "removed", "reordered")): + self._connect(emitter, self._on_layers_changed, self._viewer_connections) + + def _rebind_active_points_layer(self) -> None: + self._disconnect_all(self._layer_connections) + + active = getattr(self.viewer.layers.selection, "active", None) + if not isinstance(active, Points): + self._active_layer = None + return + + self._active_layer = active + + # Primary selection hooks: Selection model events + selection = getattr(active, "selected_data", None) + selection_events = getattr(selection, "events", None) + + for _name, emitter in _iter_event_emitters(selection_events, ("changed", "active", "items_changed")): + self._connect(emitter, self._on_selection_changed, self._layer_connections) + + # Optional content hooks, useful for future history/versioning use cases. + if self.watch_content: + layer_events = getattr(active, "events", None) + for event_name in ("data", "properties", "current_properties", "mode"): + for _name, emitter in _iter_event_emitters(layer_events, (event_name,)): + self._connect(emitter, self._on_content_changed, self._layer_connections) + + # ------------------------------------------------------------------ + # Event handlers + # ------------------------------------------------------------------ + + def _on_active_layer_changed(self, event=None) -> None: + self._rebind_active_points_layer() + self._schedule("active_layer") + + def _on_layers_changed(self, event=None) -> None: + # The active layer may or may not have changed, but rebinding is cheap and safe. + self._rebind_active_points_layer() + self._schedule("layers") + + def _on_selection_changed(self, event=None) -> None: + self._schedule("selection") + + def _on_content_changed(self, event=None) -> None: + self._schedule("content") + + # ------------------------------------------------------------------ + # Coalescing + # ------------------------------------------------------------------ + + def _schedule(self, reason: str) -> None: + self._pending_reasons.add(reason) + if not self._timer.isActive(): + self._timer.start(self.debounce_ms) + + def _flush(self) -> None: + reasons = frozenset(self._pending_reasons) + self._pending_reasons.clear() + + event = PointsInteractionEvent( + viewer=self.viewer, + layer=self._active_layer, + reasons=reasons, + ) + self.callback(event) diff --git a/src/napari_deeplabcut/napari.yaml b/src/napari_deeplabcut/napari.yaml index 481dfb4c..48883029 100644 --- a/src/napari_deeplabcut/napari.yaml +++ b/src/napari_deeplabcut/napari.yaml @@ -1,5 +1,5 @@ name: napari-deeplabcut -display_name: napari DeepLabCut +display_name: napari-deeplabcut contributions: commands: - id: napari-deeplabcut.get_hdf_reader diff --git a/src/napari_deeplabcut/napari_compat/color.py b/src/napari_deeplabcut/napari_compat/color.py index f277363b..c307dffc 100644 --- a/src/napari_deeplabcut/napari_compat/color.py +++ b/src/napari_deeplabcut/napari_compat/color.py @@ -19,18 +19,18 @@ def patch_color_manager_guess_continuous() -> None: try: import numpy as np from napari.layers.utils import color_manager - except Exception as e: + except Exception as e: # pragma: no cover logger.debug("Skipping color_manager patch (napari import failed): %r", e) return def guess_continuous(property_): try: return issubclass(property_.dtype.type, np.floating) - except Exception: + except Exception: # pragma: no cover return False try: color_manager.guess_continuous = guess_continuous - except Exception as e: + except Exception as e: # pragma: no cover logger.debug("Skipping color_manager patch (assignment failed): %r", e) return diff --git a/src/napari_deeplabcut/napari_compat/points_layer.py b/src/napari_deeplabcut/napari_compat/points_layer.py index 72fbb702..5ab69d21 100644 --- a/src/napari_deeplabcut/napari_compat/points_layer.py +++ b/src/napari_deeplabcut/napari_compat/points_layer.py @@ -228,8 +228,6 @@ def _offset_pasted_data( def apply_points_layer_ui_tweaks(viewer, layer, *, dropdown_cls, plt_module) -> object | None: """ - Best-effort private napari UI wiring. - Returns ------- object | None @@ -239,6 +237,7 @@ def apply_points_layer_ui_tweaks(viewer, layer, *, dropdown_cls, plt_module) -> controls = viewer.window.qt_viewer.dockLayerControls point_controls = controls.widget().widgets[layer] except Exception: + logger.debug("Failed to resolve point controls for layer UI tweaks", exc_info=True) return None widgets_to_hide = [ @@ -256,15 +255,24 @@ def apply_points_layer_ui_tweaks(viewer, layer, *, dropdown_cls, plt_module) -> widget = getattr(parent, widget_attr) widget.hide() except Exception: - pass + logger.debug( + "Failed to hide widget %s.%s in point controls", + parent_attr, + widget_attr, + exc_info=True, + ) try: cmap_source = plt_module.colormaps if callable(cmap_source): cmap_source = cmap_source() - colormap_selector = dropdown_cls(cmap_source, point_controls) - colormap_selector.update_to(layer.metadata.get("colormap_name", "viridis")) + # Normalize explicitly to strings for safety/readability. + cmap_names = [str(name) for name in cmap_source] + + colormap_selector = dropdown_cls(cmap_names, point_controls) + if hasattr(colormap_selector, "update_to"): + colormap_selector.update_to(layer.metadata.get("colormap_name", "viridis")) point_controls.layout().addRow("colormap", colormap_selector) return colormap_selector except Exception as e: diff --git a/src/napari_deeplabcut/ui/layer_stats.py b/src/napari_deeplabcut/ui/layer_stats.py index 1b3d9b56..da977334 100644 --- a/src/napari_deeplabcut/ui/layer_stats.py +++ b/src/napari_deeplabcut/ui/layer_stats.py @@ -1,17 +1,26 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from qtpy.QtCore import QSignalBlocker, Qt, Signal +from qtpy.QtGui import QIcon from qtpy.QtWidgets import ( + QApplication, QFormLayout, QGraphicsOpacityEffect, QGroupBox, QHBoxLayout, QLabel, + QMenu, QSlider, + QToolButton, QVBoxLayout, QWidget, ) +if TYPE_CHECKING: + from napari_deeplabcut.core.layers import LabelProgress + class LayerStatusPanel(QGroupBox): """ @@ -32,8 +41,31 @@ def __init__(self, parent: QWidget | None = None): self._progress_value = QLabel("No active keypoints layer") self._progress_value.setWordWrap(True) + self._progress_value.setCursor(Qt.WhatsThisCursor) self._progress_value.setTextInteractionFlags(Qt.TextSelectableByMouse) + # self._progress_info = QLabel("ℹ") + self._progress_info = QToolButton(self) + self._progress_info.setText("ℹ") + self._progress_info.setAutoRaise(True) + self._progress_info.setIcon(QIcon.fromTheme("help-about")) + self._progress_info.setCursor(Qt.WhatsThisCursor) + self._progress_info.setToolTip("Hover for more details") + + self._progress_details_text = "" + + self._progress_value.setContextMenuPolicy(Qt.CustomContextMenu) + self._progress_value.customContextMenuRequested.connect(self._show_progress_context_menu) + self._progress_info.setContextMenuPolicy(Qt.CustomContextMenu) + self._progress_info.customContextMenuRequested.connect(self._show_progress_context_menu) + + self._progress_container = QWidget(self) + progress_row = QHBoxLayout(self._progress_container) + progress_row.setContentsMargins(0, 0, 0, 0) + progress_row.setSpacing(4) + progress_row.addWidget(self._progress_value, stretch=1) + progress_row.addWidget(self._progress_info, stretch=0, alignment=Qt.AlignTop) + self._size_slider = QSlider(Qt.Horizontal, self) self._size_slider.setRange(1, 100) self._size_slider.setSingleStep(1) @@ -60,7 +92,7 @@ def __init__(self, parent: QWidget | None = None): form = QFormLayout() form.addRow("Folder", self._folder_value) - form.addRow("Progress", self._progress_value) + form.addRow("Progress", self._progress_container) form.addRow("Point size", self._size_controls) wrapper = QVBoxLayout(self) @@ -119,39 +151,100 @@ def set_point_size_enabled(self, enabled: bool, *, reason: str | None = None) -> def set_folder_name(self, folder_name: str) -> None: self._folder_value.setText(folder_name or "—") + def _show_progress_context_menu(self, pos) -> None: + if not self._progress_details_text: + return + + menu = QMenu(self) + copy_action = menu.addAction("Copy progress details") + chosen = menu.exec_(self._progress_value.mapToGlobal(pos)) + if chosen is copy_action: + self._copy_progress_details_to_clipboard() + + def _copy_progress_details_to_clipboard(self) -> None: + text = self._progress_details_text or self._progress_value.toolTip() or "" + if not text: + return + QApplication.clipboard().setText(text) + def set_progress_summary( self, *, - labeled_percent: float, - remaining_percent: float, - labeled_points: int, - total_points: int, - frame_count: int, - bodypart_count: int, - individual_count: int, + progress: LabelProgress, ) -> None: - if total_points <= 0: + + p = progress + + if p.total_points <= 0: self._progress_value.setText("Not enough metadata to estimate progress yet") self._progress_value.setToolTip("") + self._progress_details_text = "" return - if individual_count <= 1: - breakdown = f"{frame_count} frames × {bodypart_count} bodyparts" + if p.individual_count <= 1: + breakdown = f"{p.frame_count} frames × {p.bodypart_count} bodyparts" else: - breakdown = f"{frame_count} frames × {bodypart_count} bodyparts × {individual_count} individuals" - - self._progress_value.setText(f"{labeled_percent:.1f}% labeled") - self._progress_value.setToolTip( - f"{labeled_percent:.1f}% labeled, {remaining_percent:.1f}% remaining\n" - f"{labeled_points}/{total_points} of all possible points labeled • {breakdown}" + breakdown = f"{p.frame_count} frames × {p.bodypart_count} bodyparts × {p.individual_count} individuals" + + self._progress_value.setText(f"{p.labeled_percent:.1f}% labeled") + + incomplete_count = max(0, p.frame_count - p.completed_frames) + first_incomplete = list(p.incomplete_frames[:10]) + + details_lines = [ + f"{p.labeled_percent:.1f}% fully labeled, {p.remaining_percent:.1f}% remaining", + f"{p.labeled_points}/{p.total_points} possible keypoint slots currently labeled • {breakdown}", + f"{p.completed_frames}/{p.frame_count} frames fully labeled ({p.completed_percent:.1f}%)", + f"{incomplete_count} non fully labeled frames", + ] + + if first_incomplete: + details_lines.append("First non fully labeled frames: " + ", ".join(str(int(f)) for f in first_incomplete)) + + if p.individual_count > 1: + # Keep individual rows stable and compact + per_individual_lines = [] + for individual, n_frames in sorted(p.incomplete_frames_by_individual.items()): + if individual == "": + continue + int(p.missing_points_by_individual.get(individual, 0)) + per_individual_lines.append( + f"- {individual}: non fully labeled on {int(n_frames)} frame(s), " + f"{int(p.missing_points_by_individual.get(individual, 0))} missing keypoint(s)" + ) + + if per_individual_lines: + details_lines.append("") + details_lines.append("By individual:") + details_lines.extend(per_individual_lines) + + details_lines.append("") + details_lines.append( + "Tip: right-click this progress label to copy the full summary.\n" + "Please note that actual visibility of keypoints in the video cannot" + " be determined automatically.\nTherefore, not having all frames be fully labeled" + " does not necessarily imply the labeling is incomplete.\n" + "Please treat it as a relative progress estimate based on the maximum possible keypoints" ) + details_text = "\n".join(details_lines) + self._progress_details_text = details_text + self._progress_value.setToolTip(details_text) + self._progress_info.setToolTip(details_text) + self._progress_container.setToolTip(details_text) + def set_no_active_points_layer(self) -> None: self._progress_value.setText("No active keypoints layer") self._progress_value.setToolTip("") + self._progress_info.setToolTip("") + self._progress_container.setToolTip("") + self._progress_details_text = "" self.set_point_size_enabled(False, reason="Select a DLC keypoints layer to edit point size.") def set_invalid_points_layer(self) -> None: self._progress_value.setText("Active layer is not a DLC keypoints layer") self._progress_value.setToolTip("") + self._progress_info.setToolTip("") + self._progress_container.setToolTip("") + self._progress_details_text = "" self.set_point_size_enabled(False, reason="This control only works for DLC keypoints layers.") diff --git a/src/napari_deeplabcut/ui/plots/trajectory.py b/src/napari_deeplabcut/ui/plots/trajectory.py index e7aa6945..dd393227 100644 --- a/src/napari_deeplabcut/ui/plots/trajectory.py +++ b/src/napari_deeplabcut/ui/plots/trajectory.py @@ -16,23 +16,32 @@ import napari import numpy as np from matplotlib.backends.backend_qtagg import FigureCanvas, NavigationToolbar2QT +from napari.layers import Points from napari.utils.events import Event -from qtpy.QtCore import QSize, Qt +from qtpy.QtCore import QSize, Qt, QTimer from qtpy.QtGui import QIcon -from qtpy.QtWidgets import QHBoxLayout, QLabel, QSlider, QVBoxLayout, QWidget +from qtpy.QtWidgets import QHBoxLayout, QLabel, QSizePolicy, QSlider, QVBoxLayout, QWidget import napari_deeplabcut.core.io as io -from napari_deeplabcut.core.layers import get_first_points_layer, get_first_video_image_layer +from napari_deeplabcut.config.models import DLCHeaderModel +from napari_deeplabcut.config.settings import ( + DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP, + DEFAULT_SINGLE_ANIMAL_CMAP, +) +from napari_deeplabcut.core.keypoints import build_color_cycles +from napari_deeplabcut.core.layers import ( + get_first_image_layer, + get_first_video_image_layer, +) +from napari_deeplabcut.utils.deprecations import deprecated logger = logging.getLogger(__name__) - _PACKAGE = "napari_deeplabcut" @lru_cache(maxsize=1) def _pkg_root(): - # Traversable root of the package return files(_PACKAGE) @@ -52,6 +61,11 @@ def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] super().__init__(*args, **kwargs) self.setIconSize(QSize(28, 28)) + @staticmethod + def _qicon(pathlike) -> QIcon: + """Build a QIcon safely from any path-like object.""" + return QIcon(str(pathlike)) + def _update_buttons_checked(self) -> None: """Update toggle tool icons when selected/unselected.""" super()._update_buttons_checked() @@ -59,40 +73,41 @@ def _update_buttons_checked(self) -> None: if "pan" in self._actions: if self._actions["pan"].isChecked(): - self._actions["pan"].setIcon(QIcon(Path(icon_dir) / "Pan_checked.png")) + self._actions["pan"].setIcon(self._qicon(Path(icon_dir) / "Pan_checked.png")) else: - self._actions["pan"].setIcon(QIcon(Path(icon_dir) / "Pan.png")) + self._actions["pan"].setIcon(self._qicon(Path(icon_dir) / "Pan.png")) if "zoom" in self._actions: if self._actions["zoom"].isChecked(): - self._actions["zoom"].setIcon(QIcon(Path(icon_dir) / "Zoom_checked.png")) + self._actions["zoom"].setIcon(self._qicon(Path(icon_dir) / "Zoom_checked.png")) else: - self._actions["zoom"].setIcon(QIcon(Path(icon_dir) / "Zoom.png")) + self._actions["zoom"].setIcon(self._qicon(Path(icon_dir) / "Zoom.png")) -class KeypointMatplotlibCanvas(QWidget): +class TrajectoryMatplotlibCanvas(QWidget): """Trajectory plot using matplotlib for keypoints (t-y plot).""" - # FIXME: y axis should be reversed due to napari using top-left as origin - - def __init__(self, napari_viewer, parent=None): + def __init__(self, napari_viewer, parent=None, get_color_mode: callable = None): super().__init__(parent=parent) self.viewer = napari_viewer + self._get_color_mode = get_color_mode or (lambda: "bodypart") + with mplstyle.context(self.mpl_style_sheet_path): self.canvas = FigureCanvas() - self.canvas.figure.set_size_inches(4, 2, forward=True) + # self.canvas.figure.set_size_inches(4, 2, forward=True) self.canvas.figure.set_layout_engine("constrained") self.ax = self.canvas.figure.subplots() - - self.toolbar = NapariNavigationToolbar(self.canvas, parent=self) - self._replace_toolbar_icons() + self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") + self.ax.set_xlabel("Frame") + self.ax.set_ylabel("Y position") + + # self.toolbar = NapariNavigationToolbar(self.canvas, parent=self) + self.toolbar = NavigationToolbar2QT(self.canvas, parent=self) + self.toolbar.setIconSize(QSize(28, 28)) + self._set_toolbar_tooltips() self.canvas.mpl_connect("button_press_event", self.on_doubleclick) - self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") - self.ax.set_xlabel("Frame") - self.ax.set_ylabel("Y position") - self.slider = QSlider(Qt.Horizontal) self.slider.setMinimum(50) self.slider.setMaximum(10000) @@ -104,29 +119,147 @@ def __init__(self, napari_viewer, parent=None): self._window = self.slider.value() self.slider.valueChanged.connect(self.set_window) - layout = QVBoxLayout() - layout.addWidget(self.canvas) - layout.addWidget(self.toolbar) + self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.canvas.updateGeometry() + + self.toolbar.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) + self.slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) + self.slider_value.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + + layout = QVBoxLayout(self) + layout.setContentsMargins(4, 4, 4, 4) + layout.setSpacing(4) - layout2 = QHBoxLayout() - layout2.addWidget(self.slider) - layout2.addWidget(self.slider_value) - layout.addLayout(layout2) + # Give almost all vertical stretch to the plot canvas + layout.addWidget(self.canvas, stretch=1) + layout.addWidget(self.toolbar, stretch=0) - self.setLayout(layout) + slider_row = QHBoxLayout() + slider_row.addWidget(self.slider, stretch=1) + slider_row.addWidget(self.slider_value, stretch=0) + layout.addLayout(slider_row, stretch=0) + + # self.setLayout(layout) self.frames = [] self.keypoints = [] self.df = None - self.setMinimumHeight(300) + self.setMinimumSize(280, 350) self.viewer.dims.events.current_step.connect(self.update_plot_range) self._n = 0 - self.update_plot_range(Event(type_name="", value=[self.viewer.dims.current_step[0]])) + self.update_plot_range( + Event(type_name="", value=[self.viewer.dims.current_step[0]]), + force=True, + ) + self._apply_axis_theme() self.viewer.layers.events.inserted.connect(self._load_dataframe) self.viewer.dims.events.range.connect(self._update_slider_max) - self._lines: dict[str, list] = {} + self._lines: dict[tuple[str, str], list] = {} + + self._apply_napari_theme() + self._connect_theme_events() + # If layers already existed before this widget was created + # (e.g. drag-and-drop load before opening the plugin), populate + # the plot from the current viewer state on the next event-loop turn. + QTimer.singleShot(0, self.refresh_from_viewer_layers) + + def resizeEvent(self, event) -> None: + super().resizeEvent(event) + self.canvas.draw_idle() + + def sizeHint(self) -> QSize: + """ + Preferred initial size for the trajectory plot dock widget. + + Wide enough for the toolbar + slider, and tall enough that the plot + is useful by default without preventing later resizing. + """ + return QSize(480, 340) + + def minimumSizeHint(self) -> QSize: + """ + Smallest comfortable size before the widget becomes cramped. + """ + return QSize(280, 340) + + def _get_header_model_from_metadata(self, md: dict) -> DLCHeaderModel | None: + """Return DLCHeaderModel from metadata['header'] when possible. + TODO: Check codebase for duplicate logic and centralize if needed. + """ + if not isinstance(md, dict): + return None + + hdr = md.get("header", None) + if hdr is None: + return None + + if isinstance(hdr, DLCHeaderModel): + return hdr + + if isinstance(hdr, dict): + try: + return DLCHeaderModel.model_validate(hdr) + except Exception: + return None + + try: + return DLCHeaderModel(columns=hdr) + except Exception: + return None + + def _is_multianimal_layer(self, layer: Points) -> bool: + """TODO: Check codebase for duplicate logic and centralize if needed.""" + md = getattr(layer, "metadata", None) or {} + header = self._get_header_model_from_metadata(md) + if header is None: + return False + + try: + inds = getattr(header, "individuals", None) + return bool(inds and len(inds) > 0 and str(inds[0]) != "") + except Exception: + return False + + def _get_config_colormap(self, layer: Points) -> str: + """Return the colormap for the given layer based on its metadata. + TODO: Check codebase for duplicate logic and centralize if needed. + """ + md = getattr(layer, "metadata", None) or {} + cmap = md.get("config_colormap") + if isinstance(cmap, str) and cmap: + return cmap + return DEFAULT_SINGLE_ANIMAL_CMAP + + def _plot_mode(self) -> str: + try: + mode = str(self._get_color_mode()).lower() + except Exception: + mode = "bodypart" + + if "individual" in mode: + return "individual" + return "bodypart" + + @staticmethod + def _normalized_cycle(mapping) -> dict[str, object]: + """Return a cycle mapping with string keys for robust lookups.""" + try: + return {str(k): v for k, v in (mapping or {}).items()} + except Exception: + return {} + + @staticmethod + def _normalized_individual_name(value) -> str: + """Best-effort normalization for individual/id values used as cycle keys.""" + if value is None: + return "" + text = str(value) + if not text or text.lower() == "nan": + return "" + return text def on_doubleclick(self, event): if getattr(event, "dblclick", False): @@ -134,10 +267,98 @@ def on_doubleclick(self, event): return show = list(self._lines.values())[0][0].get_visible() for lines in self._lines.values(): - for l in lines: - l.set_visible(not show) + for line in lines: + line.set_visible(not show) self._refresh_canvas(value=self._n) + def refresh_from_viewer_layers(self) -> None: + """ + Refresh the trajectory plot from the current viewer state. + """ + try: + self._load_dataframe() + except Exception: + logger.debug("Trajectory plot: failed to load dataframe from viewer layers", exc_info=True) + + try: + self._update_slider_max() + except Exception: + logger.debug("Trajectory plot: failed to update slider max", exc_info=True) + + try: + self.sync_visible_lines_to_points_selection() + except Exception: + logger.debug("Trajectory plot: failed to sync visible lines to point selection", exc_info=True) + + def _get_plot_points_layer(self): + """ + Return the active plottable Points layer for DLC trajectories when possible. + + A generic napari Points layer may not have the DLC header required by io.form_df. + If the active layer is not a suitable DLC Points layer, fall back to the first + plottable Points layer in the viewer. + """ + + def _is_plottable_points_layer(layer) -> bool: + if not isinstance(layer, Points): + return False + + md = getattr(layer, "metadata", None) or {} + data = getattr(layer, "data", None) + + if md.get("header") is None: + return False + if data is None or len(data) == 0: + return False + + return True + + active_layer = getattr(getattr(self.viewer.layers, "selection", None), "active", None) + if _is_plottable_points_layer(active_layer): + return active_layer + + for layer in self.viewer.layers: + if _is_plottable_points_layer(layer): + return layer + + return None + + def _clear_plot(self) -> None: + """Clear plotted trajectories and reset axes.""" + self.df = None + self._lines = {} + + self.ax.clear() + self.ax.set_xlabel("Frame") + self.ax.set_ylabel("Y position") + self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") + self._apply_napari_theme() + + def _has_refreshable_df(self) -> bool: + """ + Return True if self.df looks like something we can safely use with len(...). + + This keeps selection-only tests and partial UI states from crashing. + """ + if self.df is None: + return False + try: + len(self.df) + except Exception: + return False + return True + + def _apply_axis_theme(self) -> None: + """Force axis/text colors to match napari theme.""" + is_light = self._napari_theme_has_light_bg() + fg = "black" if is_light else "white" + + self.ax.tick_params(axis="both", colors=fg, which="both") + self.ax.xaxis.label.set_color(fg) + self.ax.yaxis.label.set_color(fg) + self.ax.title.set_color(fg) + self.vline.set_color(fg) + def _napari_theme_has_light_bg(self) -> bool: theme = napari.utils.theme.get_theme(self.viewer.theme) _, _, bg_lightness = theme.background.as_hsl_tuple() @@ -147,15 +368,13 @@ def _napari_theme_has_light_bg(self) -> bool: def mpl_style_sheet_path(self) -> Path: if self._napari_theme_has_light_bg(): return _styles_traversable() / "light.mplstyle" - else: - return _styles_traversable() / "dark.mplstyle" + return _styles_traversable() / "dark.mplstyle" def _get_path_to_icon(self) -> Path: icon_root = _assets_traversable() / "icons" if self._napari_theme_has_light_bg(): return icon_root / "black" - else: - return icon_root / "white" + return icon_root / "white" def _replace_toolbar_icons(self) -> None: icon_dir = self._get_path_to_icon() @@ -167,53 +386,86 @@ def _replace_toolbar_icons(self) -> None: ) if text == "Zoom": action.setToolTip("Zoom to rectangle; Click once to activate; Click again to deactivate") - if len(text) > 0: - icon_path = icon_dir / (text + ".png") - action.setIcon(QIcon(str(icon_path))) + if text: + icon_path = Path(icon_dir) / f"{text}.png" + if icon_path.is_file(): + action.setIcon(QIcon(str(icon_path))) + else: + logger.debug(f"Failed to set toolbar icon from {icon_path}: file does not exist") def _load_dataframe(self, event=None) -> None: - points_layer = get_first_points_layer(self.viewer) - if points_layer is None: - return - # Preserve existing semantics (numpy bool inversion) from original code - try: - if ~np.any(points_layer.data): + with mplstyle.context(self.mpl_style_sheet_path): + points_layer = self._get_plot_points_layer() + if points_layer is None: + # No plottable DLC points layer present -> clear the plot quietly. + self._clear_plot() return - except Exception: - return - - # Silly hack so the window does not hang the first time it is shown - self.show() - self.hide() - try: - self.df = io.form_df( - points_layer.data, - layer_metadata=points_layer.metadata, - layer_properties=points_layer.properties, - ) - except Exception as e: - logger.error("Failed to form DataFrame from points layer: %r", e, exc_info=True) - return - - self._lines.clear() - self.ax.clear() - self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") - self.ax.set_xlabel("Frame") - self.ax.set_ylabel("Y position") + # Silly hack so the window does not hang the first time it is shown + was_visible = self.isVisible() + self.show() + if not was_visible: + self.hide() - for keypoint in self.df.columns.get_level_values("bodyparts").unique(): - y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]) - x = np.arange(len(y)) try: - color = points_layer.metadata["face_color_cycles"]["label"][keypoint] - except Exception: - color = "C0" - lines = self.ax.plot(x, y, color=color, label=str(keypoint)) - self._lines[str(keypoint)] = lines + self.df = io.form_df( + points_layer.data, + layer_metadata=points_layer.metadata, + layer_properties=points_layer.properties, + ) + except KeyError as e: + # Generic / incomplete points layer: not an error for the UI, just skip plotting. + logger.debug("Trajectory plot skipped for non-DLC/incomplete points layer: %r", e) + self._clear_plot() + return + except Exception as e: + logger.error("Failed to form DataFrame from points layer: %r", e, exc_info=True) + self._clear_plot() + return - self._refresh_canvas(value=self._n) + image_layer = get_first_video_image_layer(self.viewer) + if image_layer is None: + image_layer = get_first_image_layer(self.viewer) + + self._lines = {} + self.ax.clear() + self.ax.set_xlabel("Frame") + self.ax.set_ylabel("Y position") + self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") + self._apply_napari_theme() + + height = None + if image_layer is not None: + try: + img_data = image_layer.data + if getattr(image_layer, "rgb", False): + height = img_data.shape[-3] + else: + height = img_data.shape[-2] + except Exception: + height = None + + self._lines = {} + + for individual, bodypart, y in self._iter_series(): + x = np.arange(y.shape[0]) + color = self._line_color_for(points_layer, individual, bodypart) + label = self._legend_text_for(individual, bodypart) + artists = self.ax.plot(x, y, color=color, label=label) + self._lines[(individual, bodypart)] = artists + + # Match napari image coordinates: y increases downward + if height is not None: + self.ax.set_ylim(height, 0) + else: + self.ax.invert_yaxis() + + self._refresh_canvas(value=self._n) + @deprecated( + details="No longer used, instead visibility is based on napari Points selection.", + replacement="sync_visible_lines_to_points_selection", + ) def _toggle_line_visibility(self, keypoint: str) -> None: if keypoint not in self._lines: return @@ -221,31 +473,48 @@ def _toggle_line_visibility(self, keypoint: str) -> None: artist.set_visible(not artist.get_visible()) self._refresh_canvas(value=self._n) + def show_only_keypoint(self, keypoint: str) -> None: + """Show all trajectories matching one bodypart; if unknown, show all.""" + matches = {k for k in self._lines if k[1] == keypoint} + if not matches: + self._show_all_keypoints() + return + self._set_visible_keypoints(matches) + def _refresh_canvas(self, value: int) -> None: - if self.df is None: + if not self._has_refreshable_df(): + self.canvas.draw_idle() return - start = max(0, value - self._window // 2) - end = min(value + self._window // 2, len(self.df)) - self.ax.set_xlim(start, end) - self.vline.set_xdata([value]) - self.canvas.draw() + + with mplstyle.context(self.mpl_style_sheet_path): + start = max(0, value - self._window // 2) + end = min(value + self._window // 2, len(self.df)) + self.ax.set_xlim(start, end) + self.vline.set_xdata([value]) + self.canvas.draw() def set_window(self, value: int) -> None: self._window = value self.slider_value.setText(str(value)) self.update_plot_range(Event(type_name="", value=[self._n])) - def update_plot_range(self, event) -> None: + def update_plot_range(self, event, force: bool = False) -> None: + if not self.isVisible() and not force: + return + value = event.value[0] self._n = value + if self.df is None: return + self._refresh_canvas(value) def _update_slider_max(self, event=None) -> None: img = get_first_video_image_layer(self.viewer) if img is None: return + try: n_frames = img.data.shape[0] except Exception: @@ -255,3 +524,319 @@ def _update_slider_max(self, event=None) -> None: self.slider.setMaximum(self.slider.minimum()) else: self.slider.setMaximum(n_frames - 1) + + def _set_visible_keypoints(self, visible_keys: set) -> None: + if not self._lines: + return + + mode = self._plot_mode() + + for (individual, bodypart), artists in self._lines.items(): + if mode == "individual": + show = (individual, bodypart) in visible_keys + else: + show = bodypart in visible_keys + + for artist in artists: + artist.set_visible(show) + + if self.isVisible(): + self._refresh_canvas(value=self._n) + + def _show_all_keypoints(self) -> None: + """Show all trajectories.""" + if not self._lines: + return + + for artists in self._lines.values(): + for artist in artists: + artist.set_visible(True) + + if self.isVisible(): + self._refresh_canvas(value=self._n) + + def _selected_line_keys_from_points_layer(self) -> set: + points_layer = self._get_plot_points_layer() + if points_layer is None: + return set() + + selected = getattr(points_layer, "selected_data", None) + if not selected: + return set() + + props = getattr(points_layer, "properties", {}) or {} + labels = props.get("label", None) + ids = props.get("id", None) + + if labels is None: + return set() + + try: + labels_arr = np.asarray(labels, dtype=object).ravel() + except Exception: + return set() + + try: + ids_arr = np.asarray(ids, dtype=object).ravel() if ids is not None else None + except Exception: + ids_arr = None + + mode = self._plot_mode() + visible = set() + + for idx in selected: + try: + i = int(idx) + except Exception: + continue + if not (0 <= i < len(labels_arr)): + continue + + label = str(labels_arr[i]) + if not label: + continue + + if mode == "individual": + individual = "" + if ids_arr is not None and i < len(ids_arr): + val = ids_arr[i] + if val is not None: + text = str(val) + if text and text.lower() != "nan": + individual = text + visible.add((individual, label)) + else: + # bodypart mode -> show all series with this bodypart + visible.add(label) + + return visible + + def sync_visible_lines_to_points_selection(self) -> None: + """ + Sync trajectory visibility to the current napari Points selection. + + Behavior: + - no selected points -> show all trajectories + - selected points -> show only trajectories for the selected (id, label) pairs + """ + if not self._lines: + return + + visible = self._selected_line_keys_from_points_layer() + if not visible: + self._show_all_keypoints() + return + + self._set_visible_keypoints(visible) + + def _set_toolbar_tooltips(self) -> None: + """Set clearer tooltips for the stock matplotlib toolbar.""" + for action in self.toolbar.actions(): + text = action.text() + if text == "Pan": + action.setToolTip( + "Pan/Zoom: Left button pans; Right button zooms; Click once to activate; Click again to deactivate" + ) + elif text == "Zoom": + action.setToolTip("Zoom to rectangle; Click once to activate; Click again to deactivate") + + def _df_has_individuals(self) -> bool: + if self.df is None: + return False + try: + cols = self.df.columns + if "individuals" not in cols.names: + return False + vals = [str(v) for v in cols.get_level_values("individuals").unique()] + return any(v != "" for v in vals) + except Exception: + return False + + def _resolved_face_color_cycles(self, layer: Points) -> dict[str, dict]: + """ + Resolve the same label/id color cycles as ColorSchemeResolver. + + - label cycle uses the current config colormap + - id cycle uses DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP for multi-animal + - single-animal falls back to the bodypart cycle for both + """ + md = getattr(layer, "metadata", None) or {} + header = self._get_header_model_from_metadata(md) + if header is None: + return {} + + config_cmap = self._get_config_colormap(layer) + + try: + bodypart_cycles = build_color_cycles(header, config_cmap) or {} + except Exception: + logger.debug("Trajectory plot: failed to build bodypart color cycles", exc_info=True) + bodypart_cycles = {} + + if self._is_multianimal_layer(layer): + try: + individual_cycles = build_color_cycles(header, DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP) or {} + except Exception: + logger.debug("Trajectory plot: failed to build individual color cycles", exc_info=True) + individual_cycles = {} + else: + individual_cycles = bodypart_cycles + + return { + "label": self._normalized_cycle(bodypart_cycles.get("label", {})), + "id": self._normalized_cycle(individual_cycles.get("id", {})), + } + + def _line_color_for(self, points_layer: Points, individual: str, bodypart: str): + """ + Resolve line color from the same logic as ColorSchemeResolver. + + - individual mode -> prefer id cycle + - bodypart mode -> prefer label cycle + - fallback -> whichever exists + - final fallback -> Matplotlib default + """ + cycles = self._resolved_face_color_cycles(points_layer) + mode = self._plot_mode() + + id_cycle = cycles.get("id", {}) or {} + label_cycle = cycles.get("label", {}) or {} + + individual_key = self._normalized_individual_name(individual) + bodypart_key = str(bodypart) + + if mode == "individual" and self._is_multianimal_layer(points_layer): + if individual_key and individual_key in id_cycle: + return id_cycle[individual_key] + if bodypart_key in label_cycle: + return label_cycle[bodypart_key] + return "C0" + + if bodypart_key in label_cycle: + return label_cycle[bodypart_key] + if individual_key and individual_key in id_cycle: + return id_cycle[individual_key] + return "C0" + + def _legend_text_for(self, individual: str, bodypart: str) -> str: + if self._plot_mode() == "individual" and individual: + return f"{individual} • {bodypart}" + return bodypart + + def _iter_series(self): + if self.df is None: + return + + cols = self.df.columns + names = list(cols.names) + + has_inds = "individuals" in names + if has_inds: + individuals = [self._normalized_individual_name(v) for v in cols.get_level_values("individuals").unique()] + # preserve order while removing duplicates + seen = set() + individuals = [v for v in individuals if not (v in seen or seen.add(v))] + else: + individuals = [""] + + for individual in individuals: + for bodypart in cols.get_level_values("bodyparts").unique(): + mask = cols.get_level_values("bodyparts") == bodypart + mask &= cols.get_level_values("coords") == "y" + + if has_inds: + mask &= cols.get_level_values("individuals") == individual + + y_df = self.df.loc[:, mask] + if y_df.shape[1] == 0: + continue + + y = np.asarray(y_df.iloc[:, 0].to_numpy(), dtype=float).ravel() + if y.size == 0: + continue + + yield self._normalized_individual_name(individual), str(bodypart), y + + def _apply_toolbar_stylesheet(self) -> None: + """ + Apply a minimal Qt stylesheet so the toolbar background has enough contrast. + + In light mode: + - toolbar background becomes light gray + - buttons remain mostly transparent until hover/checked + + In dark mode: + - keep a subtle dark/transparent look + """ + is_light = self._napari_theme_has_light_bg() + + if is_light: + fg = "#202020" + toolbar_bg = "#ececec" # light gray background for the whole toolbar + hover = "rgba(0, 0, 0, 0.06)" + pressed = "rgba(0, 0, 0, 0.12)" + border = "rgba(0, 0, 0, 0.10)" + else: + fg = "#f2f2f2" + toolbar_bg = "transparent" + hover = "rgba(255, 255, 255, 0.10)" + pressed = "rgba(255, 255, 255, 0.18)" + border = "rgba(255, 255, 255, 0.12)" + + self.toolbar.setStyleSheet( + f""" + QToolBar {{ + background: {toolbar_bg}; + border: none; + spacing: 2px; + padding: 2px; + }} + QToolButton {{ + color: {fg}; + background: transparent; + border: 1px solid transparent; + border-radius: 4px; + padding: 2px; + margin: 1px; + }} + QToolButton:hover {{ + background: {hover}; + border: 1px solid {border}; + }} + QToolButton:pressed {{ + background: {pressed}; + }} + QToolButton:checked {{ + background: {pressed}; + border: 1px solid {border}; + }} + """ + ) + + def _apply_label_styles(self) -> None: + """Apply light/dark text color to simple Qt labels in this widget.""" + fg = "black" if self._napari_theme_has_light_bg() else "white" + self.slider_value.setStyleSheet(f"color: {fg};") + + def _apply_napari_theme(self) -> None: + """ + Re-apply all napari-dependent styling. + + Safe to call repeatedly. + """ + self._apply_axis_theme() + # self._apply_toolbar_icons() + self._apply_toolbar_stylesheet() + self._apply_label_styles() + self.canvas.draw_idle() + + def _connect_theme_events(self) -> None: + """ + Re-apply theme when the viewer theme changes, if the event exists. + + Safe across versions: if the event is absent, this becomes a no-op. + """ + viewer_events = getattr(self.viewer, "events", None) + theme_emitter = getattr(viewer_events, "theme", None) + if theme_emitter is not None: + theme_emitter.connect(lambda event=None: self._apply_napari_theme())