diff --git a/src/napari_deeplabcut/_reader.py b/src/napari_deeplabcut/_reader.py index aca3b7bc..721d466d 100644 --- a/src/napari_deeplabcut/_reader.py +++ b/src/napari_deeplabcut/_reader.py @@ -11,6 +11,7 @@ from dask_image.imread import imread from napari.types import LayerData from natsort import natsorted +from pandas.api.types import is_numeric_dtype from napari_deeplabcut import misc @@ -228,7 +229,10 @@ def read_hdf(filename: str) -> list[LayerData]: nrows = df.shape[0] data = np.empty((nrows, 3)) image_paths = df["level_0"] - if np.issubdtype(image_paths.dtype, np.number): + dtype = getattr(image_paths, "dtype", None) + if dtype is None: + dtype = np.asarray(image_paths).dtype + if is_numeric_dtype(dtype): image_inds = image_paths.values paths2inds = [] else: diff --git a/src/napari_deeplabcut/_tests/test_widgets.py b/src/napari_deeplabcut/_tests/test_widgets.py index 05f60f56..ef1b2e9f 100644 --- a/src/napari_deeplabcut/_tests/test_widgets.py +++ b/src/napari_deeplabcut/_tests/test_widgets.py @@ -218,8 +218,11 @@ def test_matplotlib_canvas_initialization_and_slider(viewer, points, qtbot): assert canvas._window == initial_window + 100 assert canvas.slider_value.text() == str(initial_window + 100) - # Test plot refresh on frame change + # Test plot refresh does nothing when plot is hidden canvas.update_plot_range(event=type("Event", (), {"value": [5]})) + assert canvas._n == 0 + # Test plot refresh on frame change (forced as it is hidden) + canvas.update_plot_range(event=type("Event", (), {"value": [5]}), force=True) assert canvas._n == 5 # Check that x-limits reflect the new window start, end = canvas.ax.get_xlim() diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index 0f97a50a..e7e9b502 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -58,6 +58,9 @@ from napari_deeplabcut.misc import ( build_color_cycles, encode_categories, + find_layer, + find_layers, + find_layers_of_types, guarantee_multiindex_rows, to_os_dir_sep, ) @@ -370,12 +373,13 @@ def __init__(self, napari_viewer, parent=None): self.canvas = FigureCanvas() self.canvas.figure.set_layout_engine("constrained") self.ax = self.canvas.figure.subplots() + self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") + self.ax.set_xlabel("Frame") + self.ax.set_ylabel("Y position") + self.toolbar = NapariNavigationToolbar(self.canvas, parent=self) self._replace_toolbar_icons() self.canvas.mpl_connect("button_press_event", self.on_doubleclick) - self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") - self.ax.set_xlabel("Frame") - self.ax.set_ylabel("Y position") # Add a slot to specify the range of frames to plot self.slider = QSlider(Qt.Horizontal) self.slider.setMinimum(50) @@ -410,6 +414,7 @@ def __init__(self, napari_viewer, parent=None): # Run update plot range once to initialize the plot self._n = 0 self.update_plot_range(Event(type_name="", value=[self.viewer.dims.current_step[0]])) + self._apply_axis_theme() self.viewer.layers.events.inserted.connect(self._load_dataframe) self.viewer.dims.events.range.connect(self._update_slider_max) @@ -423,6 +428,17 @@ def on_doubleclick(self, event): l.set_visible(not show) self._refresh_canvas(value=self._n) + def _apply_axis_theme(self): + """Force axis/text colors to match napari theme.""" + is_light = self._napari_theme_has_light_bg() + fg = "black" if is_light else "white" + + self.ax.tick_params(axis="both", colors=fg, which="both") + self.ax.xaxis.label.set_color(fg) + self.ax.yaxis.label.set_color(fg) + self.ax.title.set_color(fg) + self.vline.set_color(fg) + def _napari_theme_has_light_bg(self) -> bool: """ Does this theme have a light background? @@ -480,33 +496,64 @@ def _replace_toolbar_icons(self) -> None: action.setIcon(QIcon(icon_path)) def _load_dataframe(self): - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points): - points_layer = layer - break + with mplstyle.context(self.mpl_style_sheet_path): + points_layer = find_layer( + self.viewer.layers, + Points, + ) + data = getattr(points_layer, "data", None) + if points_layer is None or data is None or len(data) == 0: + return - if points_layer is None or ~np.any(points_layer.data): - return + self.show() + self.hide() - self.show() # Silly hack so the window does not hang the first time it is shown - self.hide() + # Build dataframe from points + self.df = _form_df( + points_layer.data, + {"metadata": points_layer.metadata, "properties": points_layer.properties}, + ) - self.df = _form_df( - points_layer.data, - { - "metadata": points_layer.metadata, - "properties": points_layer.properties, - }, - ) - for keypoint in self.df.columns.get_level_values("bodyparts").unique(): - y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]) - x = np.arange(len(y)) - color = points_layer.metadata["face_color_cycles"]["label"][keypoint] - lines = self.ax.plot(x, y, color=color, label=keypoint) - self._lines[keypoint] = lines + # Find an Image layer (prefer a time stack, but allow 2D) + image_layer = find_layer( + self.viewer.layers, + Image, + predicate=lambda lyr: getattr(lyr.data, "ndim", 0) >= 3, + ) + # fallback to any Image if no stack found + if image_layer is None: + image_layer = find_layer(self.viewer.layers, Image) + + # Reset plot + self.ax.clear() + self.ax.set_xlabel("Frame") + self.ax.set_ylabel("Y position") + self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") + self._apply_axis_theme() + + # Compute height robustly + H = None + if image_layer is not None: + data = image_layer.data + # If RGB, last axis is channels -> height is -3 + H = data.shape[-3] if getattr(image_layer, "rgb", False) else data.shape[-2] + + # Plot trajectories + self._lines = {} + for keypoint in self.df.columns.get_level_values("bodyparts").unique(): + y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]).to_numpy().squeeze() + x = np.arange(len(y)) + color = points_layer.metadata["face_color_cycles"]["label"][keypoint] + lines = self.ax.plot(x, y, color=color, label=keypoint) + self._lines[keypoint] = lines + + # Match napari (y increases downward) without modifying y data + if H is not None: + self.ax.set_ylim(H, 0) + else: + self.ax.invert_yaxis() - self._refresh_canvas(value=self._n) + self.canvas.draw_idle() def _toggle_line_visibility(self, keypoint): for artist in self._lines[keypoint]: @@ -514,19 +561,22 @@ def _toggle_line_visibility(self, keypoint): self._refresh_canvas(value=self._n) def _refresh_canvas(self, value): - start = max(0, value - self._window // 2) - end = min(value + self._window // 2, len(self.df)) + with mplstyle.context(self.mpl_style_sheet_path): + start = max(0, value - self._window // 2) + end = min(value + self._window // 2, len(self.df)) - self.ax.set_xlim(start, end) - self.vline.set_xdata([value]) - self.canvas.draw() + self.ax.set_xlim(start, end) + self.vline.set_xdata([value]) + self.canvas.draw() def set_window(self, value): self._window = value self.slider_value.setText(str(value)) self.update_plot_range(Event(type_name="", value=[self._n])) - def update_plot_range(self, event): + def update_plot_range(self, event, force=False): + if not self.isVisible() and not force: + return value = event.value[0] self._n = value @@ -537,15 +587,18 @@ def update_plot_range(self, event): def _update_slider_max(self, event): """Update the slider's maximum value based on the number of frames in the data.""" - for layer in self.viewer.layers: - if isinstance(layer, Image) and len(layer.data.shape) >= 3: - n_frames = layer.data.shape[0] - # if less than 50 frames, set max to min to avoid slider issues - if n_frames < self.slider.minimum(): - self.slider.setMaximum(self.slider.minimum()) - else: - self.slider.setMaximum(n_frames - 1) - break + layer = find_layer( + self.viewer.layers, + Image, + predicate=lambda lyr: getattr(lyr.data, "ndim", 0) >= 3, + ) + if layer is not None: + n_frames = layer.data.shape[0] + # if less than 50 frames, set max to min to avoid slider issues + if n_frames < self.slider.minimum(): + self.slider.setMaximum(self.slider.minimum()) + else: + self.slider.setMaximum(n_frames - 1) class KeypointControls(QWidget): @@ -730,6 +783,9 @@ def _show_matplotlib_canvas(self, state): if Qt.CheckState(state) == Qt.CheckState.Checked: self._ensure_mpl_canvas_docked() if self._mpl_docked: + self._matplotlib_canvas.update_plot_range( + Event(type_name="", value=[self.viewer.dims.current_step[0]]), force=True + ) self._matplotlib_canvas.show() else: if self._mpl_docked: @@ -740,11 +796,11 @@ def settings(self): return QSettings() def load_superkeypoints_diagram(self): - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points): - points_layer = layer - break + points_layer = find_layer( + self.viewer.layers, + Points, + predicate=lambda lyr: lyr.metadata.get("tables") is not None, + ) if points_layer is None: return @@ -776,13 +832,13 @@ def _map_keypoints(self, super_animal: str): # - Assumes _load_superkeypoints and _load_config succeed # and return well-formed data; I/O errors are not handled. # - Silently ignores keypoints that have no nearest neighbor in the superkeypoint set (no user feedback). - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata.get("tables"): - points_layer = layer - break - - if points_layer is None or ~np.any(points_layer.data): + points_layer = find_layer( + self.viewer.layers, + Points, + predicate=lambda lyr: lyr.metadata.get("tables") is not None, + ) + data = getattr(points_layer, "data", None) + if points_layer is None or data is None or not data.size > 0: return xy = points_layer.data[:, 1:3] @@ -790,7 +846,7 @@ def _map_keypoints(self, super_animal: str): xy_ref = np.c_[[val for val in superkpts_dict.values()]] neighbors = keypoints._find_nearest_neighbors(xy, xy_ref) found = neighbors != -1 - if ~np.any(found): + if not found.size > 0: return project_path = points_layer.metadata["project"] @@ -833,9 +889,8 @@ def _show_trails(self, state): inds = encode_categories(categories) temp = np.c_[inds, store.layer.data] cmap = "viridis" - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata: - cmap = layer.metadata["colormap_name"] + for layer in find_layers(self.viewer.layers, Points): + cmap = layer.metadata["colormap_name"] self._trails = self.viewer.add_tracks( temp, tail_length=50, @@ -877,13 +932,14 @@ def _form_help_buttons(self): return layout def _extract_single_frame(self, *args): - image_layer = None - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Image): - image_layer = layer - elif isinstance(layer, Points): - points_layer = layer + image_layer = find_layer( + self.viewer.layers, + Image, + ) + points_layer = find_layer( + self.viewer.layers, + Points, + ) if image_layer is not None: ind = self.viewer.dims.current_step[0] frame = image_layer.data[ind] @@ -914,24 +970,27 @@ def _extract_single_frame(self, *args): def _store_crop_coordinates(self, *args): if not (project_path := self._images_meta.get("project")): return - for layer in self.viewer.layers: - if isinstance(layer, Shapes): - try: - ind = layer.shape_type.index("rectangle") - except ValueError: - return - bbox = layer.data[ind][:, 1:] - h = self.viewer.dims.range[2][1] - bbox[:, 0] = h - bbox[:, 0] - bbox = np.clip(bbox, 0, a_max=None).astype(int) - y1, x1 = bbox.min(axis=0) - y2, x2 = bbox.max(axis=0) - temp = {"crop": ", ".join(map(str, [x1, x2, y1, y2]))} - config_path = os.path.join(project_path, "config.yaml") - cfg = _load_config(config_path) - cfg["video_sets"][os.path.join(project_path, "videos", self._images_meta["name"])] = temp - _write_config(config_path, cfg) - break + layer = find_layer( + self.viewer.layers, + Shapes, + predicate=lambda lyr: getattr(lyr, "shape_type", None) and "rectangle" in lyr.shape_type, + ) + if layer is not None: + try: + ind = layer.shape_type.index("rectangle") + except ValueError: + return + bbox = layer.data[ind][:, 1:] + h = self.viewer.dims.range[2][1] + bbox[:, 0] = h - bbox[:, 0] + bbox = np.clip(bbox, 0, a_max=None).astype(int) + y1, x1 = bbox.min(axis=0) + y2, x2 = bbox.max(axis=0) + temp = {"crop": ", ".join(map(str, [x1, x2, y1, y2]))} + config_path = os.path.join(project_path, "config.yaml") + cfg = _load_config(config_path) + cfg["video_sets"][os.path.join(project_path, "videos", self._images_meta["name"])] = temp + _write_config(config_path, cfg) def _form_dropdown_menus(self, store): menu = KeypointsDropdownMenu(store) @@ -1002,11 +1061,14 @@ def rgb2hex(r, g, b, _): if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): mode = "id" - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata: - self._display.update_color_scheme( - {name: to_hex(color) for name, color in layer.metadata["face_color_cycles"][mode].items()} - ) + for layer in find_layers( + self.viewer.layers, + Points, + predicate=lambda lyr: "face_color_cycles" in lyr.metadata, + ): + self._display.update_color_scheme( + {name: to_hex(color) for name, color in layer.metadata["face_color_cycles"][mode].items()} + ) def _remap_frame_indices(self, layer): if not self._images_meta.get("paths"): @@ -1166,9 +1228,8 @@ def on_insert(self, event): colormap_selector.currentTextChanged.connect(self._update_colormap) point_controls.layout().addRow("colormap", colormap_selector) - for layer_ in self.viewer.layers: - if not isinstance(layer_, Image): - self._remap_frame_indices(layer_) + for layer_ in find_layers_of_types(self.viewer.layers, [Points, Tracks, Shapes]): + self._remap_frame_indices(layer_) def on_remove(self, event): layer = event.value @@ -1214,21 +1275,25 @@ def on_active_layer_change(self, event) -> None: menu.setHidden(True) def _update_colormap(self, colormap_name): - for layer in self.viewer.layers.selection: - if isinstance(layer, Points) and layer.metadata: - face_color_cycle_maps = build_color_cycles( - layer.metadata["header"], - colormap_name, - ) - layer.metadata["face_color_cycles"] = face_color_cycle_maps - face_color_prop = "label" - if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): - face_color_prop = "id" - - layer.face_color = face_color_prop - layer.face_color_cycle = face_color_cycle_maps[face_color_prop] - layer.events.face_color() - self._update_color_scheme() + layers = find_layers( + self.viewer.layers.selection, + Points, + predicate=lambda lyr: lyr.metadata, + ) + for layer in layers: + face_color_cycle_maps = build_color_cycles( + layer.metadata["header"], + colormap_name, + ) + layer.metadata["face_color_cycles"] = face_color_cycle_maps + face_color_prop = "label" + if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): + face_color_prop = "id" + + layer.face_color = face_color_prop + layer.face_color_cycle = face_color_cycle_maps[face_color_prop] + layer.events.face_color() + self._update_color_scheme() @register_points_action("Change labeling mode") def cycle_through_label_modes(self, *args): @@ -1271,11 +1336,14 @@ def color_mode(self, mode: str | keypoints.ColorMode): else: face_color_mode = "id" - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata: - layer.face_color = face_color_mode - layer.face_color_cycle = layer.metadata["face_color_cycles"][face_color_mode] - layer.events.face_color() + for layer in find_layers( + self.viewer.layers, + Points, + predicate=lambda lyr: "face_color_cycles" in lyr.metadata, + ): + layer.face_color = face_color_mode + layer.face_color_cycle = layer.metadata["face_color_cycles"][face_color_mode] + layer.events.face_color() for btn in self._color_mode_selector.buttons(): if btn.text().lower() == str(mode).lower(): diff --git a/src/napari_deeplabcut/misc.py b/src/napari_deeplabcut/misc.py index a510d695..3809527a 100644 --- a/src/napari_deeplabcut/misc.py +++ b/src/napari_deeplabcut/misc.py @@ -1,15 +1,18 @@ from __future__ import annotations import os -from collections.abc import Sequence +from collections.abc import Callable, Iterable, Sequence from enum import Enum, EnumMeta from itertools import cycle from pathlib import Path +from typing import TypeVar, overload import numpy as np import pandas as pd from napari.utils import colormaps +T = TypeVar("T") + def find_project_config_path(labeled_data_path: str) -> str: return str(Path(labeled_data_path).parents[2] / "config.yaml") @@ -211,3 +214,79 @@ def _generate_next_value_(name, start, count, last_values): def __str__(self): return self.value + + +@overload +def find_layer( + layers: Iterable[object], + layer_type: type[T], + *, + predicate: Callable[[T], bool] | None = None, + default: T | None = None, +) -> T | None: ... + + +def find_layer( + layers: Iterable[object], + layer_type: type[T], + *, + predicate: Callable[[T], bool] | None = None, + default: T | None = None, +) -> T | None: + """ + Return the first layer in `layers` that is an instance of `layer_type`. + + Parameters + ---------- + layers: + Any iterable of layer-like objects (e.g. viewer.layers). + layer_type: + The class/type to match (e.g. napari.layers.Image). + predicate: + Optional filter called on matching layers. + default: + Value returned if no match is found. + + Returns + ------- + The first matching layer, else `default`. + """ + for layer in layers: + if isinstance(layer, layer_type): + if predicate is None or predicate(layer): + return layer + return default + + +def find_layers( + layers: Iterable[object], + layer_type: type[T], + *, + predicate: Callable[[T], bool] | None = None, +) -> list[T]: + """ + Return all layers in `layers` that are instances of `layer_type`. + """ + out: list[T] = [] + for layer in layers: + if isinstance(layer, layer_type): + if predicate is None or predicate(layer): + out.append(layer) + return out + + +def find_layers_of_types( + layers: Iterable[object], + layer_types: Iterable[type[T]], + *, + predicate: Callable[[T], bool] | None = None, +) -> list[T]: + """ + Return all layers in `layers` that are instances of any of `layer_types`. + """ + out: list[T] = [] + for layer in layers: + if any(isinstance(layer, layer_type) for layer_type in layer_types): + if predicate is None or predicate(layer): + out.append(layer) + return out