From f9807548594828969f9cf21b45dab6ac8c06b1ca Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 22 Jan 2026 14:04:23 +0100 Subject: [PATCH 1/8] Refactor layer selection with utility functions Introduced find_layer, find_layers, and find_layers_of_types utilities in misc.py to simplify and standardize layer selection logic throughout _widgets.py. Replaced multiple for-loops and isinstance checks with these new functions, improving code readability and maintainability. --- src/napari_deeplabcut/_widgets.py | 271 ++++++++++++++++++------------ src/napari_deeplabcut/misc.py | 81 ++++++++- 2 files changed, 247 insertions(+), 105 deletions(-) diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index 0f97a50a..4d135a20 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -58,6 +58,9 @@ from napari_deeplabcut.misc import ( build_color_cycles, encode_categories, + find_layer, + find_layers, + find_layers_of_types, guarantee_multiindex_rows, to_os_dir_sep, ) @@ -370,12 +373,13 @@ def __init__(self, napari_viewer, parent=None): self.canvas = FigureCanvas() self.canvas.figure.set_layout_engine("constrained") self.ax = self.canvas.figure.subplots() + self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") + self.ax.set_xlabel("Frame") + self.ax.set_ylabel("Y position") + self.toolbar = NapariNavigationToolbar(self.canvas, parent=self) self._replace_toolbar_icons() self.canvas.mpl_connect("button_press_event", self.on_doubleclick) - self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") - self.ax.set_xlabel("Frame") - self.ax.set_ylabel("Y position") # Add a slot to specify the range of frames to plot self.slider = QSlider(Qt.Horizontal) self.slider.setMinimum(50) @@ -410,6 +414,7 @@ def __init__(self, napari_viewer, parent=None): # Run update plot range once to initialize the plot self._n = 0 self.update_plot_range(Event(type_name="", value=[self.viewer.dims.current_step[0]])) + self._apply_axis_theme() self.viewer.layers.events.inserted.connect(self._load_dataframe) self.viewer.dims.events.range.connect(self._update_slider_max) @@ -423,6 +428,17 @@ def on_doubleclick(self, event): l.set_visible(not show) self._refresh_canvas(value=self._n) + def _apply_axis_theme(self): + """Force axis/text colors to match napari theme.""" + is_light = self._napari_theme_has_light_bg() + fg = "black" if is_light else "white" + + self.ax.tick_params(axis="both", colors=fg, which="both") + self.ax.xaxis.label.set_color(fg) + self.ax.yaxis.label.set_color(fg) + self.ax.title.set_color(fg) + self.vline.set_color(fg) + def _napari_theme_has_light_bg(self) -> bool: """ Does this theme have a light background? @@ -480,33 +496,64 @@ def _replace_toolbar_icons(self) -> None: action.setIcon(QIcon(icon_path)) def _load_dataframe(self): - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points): - points_layer = layer - break + with mplstyle.context(self.mpl_style_sheet_path): + points_layer = find_layer( + self.viewer.layers, + Points, + ) - if points_layer is None or ~np.any(points_layer.data): - return + if points_layer is None or not np.any(points_layer.data): + return - self.show() # Silly hack so the window does not hang the first time it is shown - self.hide() + self.show() + self.hide() - self.df = _form_df( - points_layer.data, - { - "metadata": points_layer.metadata, - "properties": points_layer.properties, - }, - ) - for keypoint in self.df.columns.get_level_values("bodyparts").unique(): - y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]) - x = np.arange(len(y)) - color = points_layer.metadata["face_color_cycles"]["label"][keypoint] - lines = self.ax.plot(x, y, color=color, label=keypoint) - self._lines[keypoint] = lines + # Build dataframe from points + self.df = _form_df( + points_layer.data, + {"metadata": points_layer.metadata, "properties": points_layer.properties}, + ) - self._refresh_canvas(value=self._n) + # Find an Image layer (prefer a time stack, but allow 2D) + image_layer = find_layer( + self.viewer.layers, + Image, + predicate=lambda lyr: getattr(lyr.data, "ndim", 0) >= 3, + ) + # fallback to any Image if no stack found + if image_layer is None: + image_layer = find_layer(self.viewer.layers, Image) + + # Reset plot + self.ax.clear() + self.ax.set_xlabel("Frame") + self.ax.set_ylabel("Y position") + self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") + self._apply_axis_theme() + + # Compute height robustly + H = None + if image_layer is not None: + data = image_layer.data + # If RGB, last axis is channels -> height is -3 + H = data.shape[-3] if getattr(image_layer, "rgb", False) else data.shape[-2] + + # Plot trajectories + self._lines = {} + for keypoint in self.df.columns.get_level_values("bodyparts").unique(): + y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]).to_numpy().squeeze() + x = np.arange(len(y)) + color = points_layer.metadata["face_color_cycles"]["label"][keypoint] + lines = self.ax.plot(x, y, color=color, label=keypoint) + self._lines[keypoint] = lines + + # Match napari (y increases downward) without modifying y data + if H is not None: + self.ax.set_ylim(H, 0) + else: + self.ax.invert_yaxis() + + self.canvas.draw_idle() def _toggle_line_visibility(self, keypoint): for artist in self._lines[keypoint]: @@ -514,12 +561,13 @@ def _toggle_line_visibility(self, keypoint): self._refresh_canvas(value=self._n) def _refresh_canvas(self, value): - start = max(0, value - self._window // 2) - end = min(value + self._window // 2, len(self.df)) + with mplstyle.context(self.mpl_style_sheet_path): + start = max(0, value - self._window // 2) + end = min(value + self._window // 2, len(self.df)) - self.ax.set_xlim(start, end) - self.vline.set_xdata([value]) - self.canvas.draw() + self.ax.set_xlim(start, end) + self.vline.set_xdata([value]) + self.canvas.draw() def set_window(self, value): self._window = value @@ -537,15 +585,18 @@ def update_plot_range(self, event): def _update_slider_max(self, event): """Update the slider's maximum value based on the number of frames in the data.""" - for layer in self.viewer.layers: - if isinstance(layer, Image) and len(layer.data.shape) >= 3: - n_frames = layer.data.shape[0] - # if less than 50 frames, set max to min to avoid slider issues - if n_frames < self.slider.minimum(): - self.slider.setMaximum(self.slider.minimum()) - else: - self.slider.setMaximum(n_frames - 1) - break + layer = find_layer( + self.viewer.layers, + Image, + predicate=lambda lyr: getattr(lyr.data, "ndim", 0) >= 3, + ) + if layer is not None: + n_frames = layer.data.shape[0] + # if less than 50 frames, set max to min to avoid slider issues + if n_frames < self.slider.minimum(): + self.slider.setMaximum(self.slider.minimum()) + else: + self.slider.setMaximum(n_frames - 1) class KeypointControls(QWidget): @@ -740,11 +791,11 @@ def settings(self): return QSettings() def load_superkeypoints_diagram(self): - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points): - points_layer = layer - break + points_layer = find_layer( + self.viewer.layers, + Points, + predicate=lambda lyr: lyr.metadata.get("tables") is not None, + ) if points_layer is None: return @@ -776,12 +827,11 @@ def _map_keypoints(self, super_animal: str): # - Assumes _load_superkeypoints and _load_config succeed # and return well-formed data; I/O errors are not handled. # - Silently ignores keypoints that have no nearest neighbor in the superkeypoint set (no user feedback). - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata.get("tables"): - points_layer = layer - break - + points_layer = find_layer( + self.viewer.layers, + Points, + predicate=lambda lyr: lyr.metadata.get("tables") is not None, + ) if points_layer is None or ~np.any(points_layer.data): return @@ -833,9 +883,8 @@ def _show_trails(self, state): inds = encode_categories(categories) temp = np.c_[inds, store.layer.data] cmap = "viridis" - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata: - cmap = layer.metadata["colormap_name"] + for layer in find_layers(self.viewer.layers, Points): + cmap = layer.metadata["colormap_name"] self._trails = self.viewer.add_tracks( temp, tail_length=50, @@ -877,13 +926,14 @@ def _form_help_buttons(self): return layout def _extract_single_frame(self, *args): - image_layer = None - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Image): - image_layer = layer - elif isinstance(layer, Points): - points_layer = layer + image_layer = find_layer( + self.viewer.layers, + Image, + ) + points_layer = find_layer( + self.viewer.layers, + Points, + ) if image_layer is not None: ind = self.viewer.dims.current_step[0] frame = image_layer.data[ind] @@ -914,24 +964,27 @@ def _extract_single_frame(self, *args): def _store_crop_coordinates(self, *args): if not (project_path := self._images_meta.get("project")): return - for layer in self.viewer.layers: - if isinstance(layer, Shapes): - try: - ind = layer.shape_type.index("rectangle") - except ValueError: - return - bbox = layer.data[ind][:, 1:] - h = self.viewer.dims.range[2][1] - bbox[:, 0] = h - bbox[:, 0] - bbox = np.clip(bbox, 0, a_max=None).astype(int) - y1, x1 = bbox.min(axis=0) - y2, x2 = bbox.max(axis=0) - temp = {"crop": ", ".join(map(str, [x1, x2, y1, y2]))} - config_path = os.path.join(project_path, "config.yaml") - cfg = _load_config(config_path) - cfg["video_sets"][os.path.join(project_path, "videos", self._images_meta["name"])] = temp - _write_config(config_path, cfg) - break + layer = find_layer( + self.viewer.layers, + Shapes, + predicate=lambda lyr: "rectangle" in lyr.shape_type, + ) + if layer is not None: + try: + ind = layer.shape_type.index("rectangle") + except ValueError: + return + bbox = layer.data[ind][:, 1:] + h = self.viewer.dims.range[2][1] + bbox[:, 0] = h - bbox[:, 0] + bbox = np.clip(bbox, 0, a_max=None).astype(int) + y1, x1 = bbox.min(axis=0) + y2, x2 = bbox.max(axis=0) + temp = {"crop": ", ".join(map(str, [x1, x2, y1, y2]))} + config_path = os.path.join(project_path, "config.yaml") + cfg = _load_config(config_path) + cfg["video_sets"][os.path.join(project_path, "videos", self._images_meta["name"])] = temp + _write_config(config_path, cfg) def _form_dropdown_menus(self, store): menu = KeypointsDropdownMenu(store) @@ -1002,8 +1055,12 @@ def rgb2hex(r, g, b, _): if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): mode = "id" - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata: + for layer in find_layers( + self.viewer.layers, + Points, + predicate=lambda lyr: "face_color_cycles" in lyr.metadata, + ): + if layer: self._display.update_color_scheme( {name: to_hex(color) for name, color in layer.metadata["face_color_cycles"][mode].items()} ) @@ -1166,9 +1223,8 @@ def on_insert(self, event): colormap_selector.currentTextChanged.connect(self._update_colormap) point_controls.layout().addRow("colormap", colormap_selector) - for layer_ in self.viewer.layers: - if not isinstance(layer_, Image): - self._remap_frame_indices(layer_) + for layer_ in find_layers_of_types(self.viewer.layers, [Points, Tracks, Shapes]): + self._remap_frame_indices(layer_) def on_remove(self, event): layer = event.value @@ -1214,21 +1270,25 @@ def on_active_layer_change(self, event) -> None: menu.setHidden(True) def _update_colormap(self, colormap_name): - for layer in self.viewer.layers.selection: - if isinstance(layer, Points) and layer.metadata: - face_color_cycle_maps = build_color_cycles( - layer.metadata["header"], - colormap_name, - ) - layer.metadata["face_color_cycles"] = face_color_cycle_maps - face_color_prop = "label" - if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): - face_color_prop = "id" - - layer.face_color = face_color_prop - layer.face_color_cycle = face_color_cycle_maps[face_color_prop] - layer.events.face_color() - self._update_color_scheme() + layers = find_layers( + self.viewer.layers.selection, + Points, + predicate=lambda lyr: lyr.metadata, + ) + for layer in layers: + face_color_cycle_maps = build_color_cycles( + layer.metadata["header"], + colormap_name, + ) + layer.metadata["face_color_cycles"] = face_color_cycle_maps + face_color_prop = "label" + if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): + face_color_prop = "id" + + layer.face_color = face_color_prop + layer.face_color_cycle = face_color_cycle_maps[face_color_prop] + layer.events.face_color() + self._update_color_scheme() @register_points_action("Change labeling mode") def cycle_through_label_modes(self, *args): @@ -1271,11 +1331,14 @@ def color_mode(self, mode: str | keypoints.ColorMode): else: face_color_mode = "id" - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata: - layer.face_color = face_color_mode - layer.face_color_cycle = layer.metadata["face_color_cycles"][face_color_mode] - layer.events.face_color() + for layer in find_layers( + self.viewer.layers, + Points, + predicate=lambda lyr: "face_color_cycles" in lyr.metadata, + ): + layer.face_color = face_color_mode + layer.face_color_cycle = layer.metadata["face_color_cycles"][face_color_mode] + layer.events.face_color() for btn in self._color_mode_selector.buttons(): if btn.text().lower() == str(mode).lower(): diff --git a/src/napari_deeplabcut/misc.py b/src/napari_deeplabcut/misc.py index a510d695..3809527a 100644 --- a/src/napari_deeplabcut/misc.py +++ b/src/napari_deeplabcut/misc.py @@ -1,15 +1,18 @@ from __future__ import annotations import os -from collections.abc import Sequence +from collections.abc import Callable, Iterable, Sequence from enum import Enum, EnumMeta from itertools import cycle from pathlib import Path +from typing import TypeVar, overload import numpy as np import pandas as pd from napari.utils import colormaps +T = TypeVar("T") + def find_project_config_path(labeled_data_path: str) -> str: return str(Path(labeled_data_path).parents[2] / "config.yaml") @@ -211,3 +214,79 @@ def _generate_next_value_(name, start, count, last_values): def __str__(self): return self.value + + +@overload +def find_layer( + layers: Iterable[object], + layer_type: type[T], + *, + predicate: Callable[[T], bool] | None = None, + default: T | None = None, +) -> T | None: ... + + +def find_layer( + layers: Iterable[object], + layer_type: type[T], + *, + predicate: Callable[[T], bool] | None = None, + default: T | None = None, +) -> T | None: + """ + Return the first layer in `layers` that is an instance of `layer_type`. + + Parameters + ---------- + layers: + Any iterable of layer-like objects (e.g. viewer.layers). + layer_type: + The class/type to match (e.g. napari.layers.Image). + predicate: + Optional filter called on matching layers. + default: + Value returned if no match is found. + + Returns + ------- + The first matching layer, else `default`. + """ + for layer in layers: + if isinstance(layer, layer_type): + if predicate is None or predicate(layer): + return layer + return default + + +def find_layers( + layers: Iterable[object], + layer_type: type[T], + *, + predicate: Callable[[T], bool] | None = None, +) -> list[T]: + """ + Return all layers in `layers` that are instances of `layer_type`. + """ + out: list[T] = [] + for layer in layers: + if isinstance(layer, layer_type): + if predicate is None or predicate(layer): + out.append(layer) + return out + + +def find_layers_of_types( + layers: Iterable[object], + layer_types: Iterable[type[T]], + *, + predicate: Callable[[T], bool] | None = None, +) -> list[T]: + """ + Return all layers in `layers` that are instances of any of `layer_types`. + """ + out: list[T] = [] + for layer in layers: + if any(isinstance(layer, layer_type) for layer_type in layer_types): + if predicate is None or predicate(layer): + out.append(layer) + return out From 9b90e3ee56347ea569c69365e616051521403518 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 22 Jan 2026 14:11:03 +0100 Subject: [PATCH 2/8] Optimize plot update for hidden matplotlib canvas Added a visibility check to KeypointMatplotlibCanvas.update_plot_range to prevent unnecessary updates when the widget is hidden, unless forced. Ensures plot updates only occur when the canvas is visible or explicitly requested, improving performance. --- src/napari_deeplabcut/_widgets.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index 4d135a20..dd85b0df 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -574,7 +574,9 @@ def set_window(self, value): self.slider_value.setText(str(value)) self.update_plot_range(Event(type_name="", value=[self._n])) - def update_plot_range(self, event): + def update_plot_range(self, event, force=False): + if not self.isVisible() and not force: + return value = event.value[0] self._n = value @@ -781,6 +783,9 @@ def _show_matplotlib_canvas(self, state): if Qt.CheckState(state) == Qt.CheckState.Checked: self._ensure_mpl_canvas_docked() if self._mpl_docked: + self._matplotlib_canvas.update_plot_range( + Event(type_name="", value=[self.viewer.dims.current_step[0]]), force=True + ) self._matplotlib_canvas.show() else: if self._mpl_docked: From 3eee41544a1ad1809612a6d4bae528d154a58578 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 22 Jan 2026 14:31:13 +0100 Subject: [PATCH 3/8] Add tests for plot refresh behavior in canvas --- src/napari_deeplabcut/_tests/test_widgets.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_tests/test_widgets.py b/src/napari_deeplabcut/_tests/test_widgets.py index 05f60f56..ef1b2e9f 100644 --- a/src/napari_deeplabcut/_tests/test_widgets.py +++ b/src/napari_deeplabcut/_tests/test_widgets.py @@ -218,8 +218,11 @@ def test_matplotlib_canvas_initialization_and_slider(viewer, points, qtbot): assert canvas._window == initial_window + 100 assert canvas.slider_value.text() == str(initial_window + 100) - # Test plot refresh on frame change + # Test plot refresh does nothing when plot is hidden canvas.update_plot_range(event=type("Event", (), {"value": [5]})) + assert canvas._n == 0 + # Test plot refresh on frame change (forced as it is hidden) + canvas.update_plot_range(event=type("Event", (), {"value": [5]}), force=True) assert canvas._n == 5 # Check that x-limits reflect the new window start, end = canvas.ax.get_xlim() From 11fb46010ce33e904ffd81696d0c14824bf60014 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 22 Jan 2026 15:20:19 +0100 Subject: [PATCH 4/8] Fix points layer data checks and logic in widgets Improves robustness by handling cases where the points layer or its data attribute may be None or empty. Also corrects logical checks for found neighbors and simplifies color scheme updates. --- src/napari_deeplabcut/_widgets.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index dd85b0df..b88a7ed7 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -501,8 +501,8 @@ def _load_dataframe(self): self.viewer.layers, Points, ) - - if points_layer is None or not np.any(points_layer.data): + data = getattr(points_layer, "data", None) + if points_layer is None or data is None or len(data) == 0: return self.show() @@ -837,7 +837,8 @@ def _map_keypoints(self, super_animal: str): Points, predicate=lambda lyr: lyr.metadata.get("tables") is not None, ) - if points_layer is None or ~np.any(points_layer.data): + data = getattr(points_layer, "data", None) + if points_layer is None or data is None or not np.any(data): return xy = points_layer.data[:, 1:3] @@ -845,7 +846,7 @@ def _map_keypoints(self, super_animal: str): xy_ref = np.c_[[val for val in superkpts_dict.values()]] neighbors = keypoints._find_nearest_neighbors(xy, xy_ref) found = neighbors != -1 - if ~np.any(found): + if not found.any(): return project_path = points_layer.metadata["project"] @@ -1065,10 +1066,9 @@ def rgb2hex(r, g, b, _): Points, predicate=lambda lyr: "face_color_cycles" in lyr.metadata, ): - if layer: - self._display.update_color_scheme( - {name: to_hex(color) for name, color in layer.metadata["face_color_cycles"][mode].items()} - ) + self._display.update_color_scheme( + {name: to_hex(color) for name, color in layer.metadata["face_color_cycles"][mode].items()} + ) def _remap_frame_indices(self, layer): if not self._images_meta.get("paths"): From 2385277ceefbc4adacdaf01bc1097a9dee76bd3e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Thu, 22 Jan 2026 15:26:16 +0100 Subject: [PATCH 5/8] Use is_numeric_dtype for image_paths type check (pandas >=3.0) Replaces np.issubdtype with pandas' is_numeric_dtype to more robustly check if image_paths is numeric in read_hdf. This improves compatibility with different pandas index types. --- src/napari_deeplabcut/_reader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_reader.py b/src/napari_deeplabcut/_reader.py index aca3b7bc..9fb3f439 100644 --- a/src/napari_deeplabcut/_reader.py +++ b/src/napari_deeplabcut/_reader.py @@ -11,6 +11,7 @@ from dask_image.imread import imread from napari.types import LayerData from natsort import natsorted +from pandas.api.types import is_numeric_dtype from napari_deeplabcut import misc @@ -228,7 +229,7 @@ def read_hdf(filename: str) -> list[LayerData]: nrows = df.shape[0] data = np.empty((nrows, 3)) image_paths = df["level_0"] - if np.issubdtype(image_paths.dtype, np.number): + if is_numeric_dtype(getattr(image_paths, "dtype", np.asarray(image_paths).dtype)): image_inds = image_paths.values paths2inds = [] else: From 708a455f845b85f7572b4ffa6d86f0b24be94a0e Mon Sep 17 00:00:00 2001 From: Cyril Achard Date: Thu, 22 Jan 2026 16:36:47 +0100 Subject: [PATCH 6/8] Update src/napari_deeplabcut/_widgets.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/napari_deeplabcut/_widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index b88a7ed7..d9e80f68 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -838,7 +838,7 @@ def _map_keypoints(self, super_animal: str): predicate=lambda lyr: lyr.metadata.get("tables") is not None, ) data = getattr(points_layer, "data", None) - if points_layer is None or data is None or not np.any(data): + if points_layer is None or data is None or not data.any(): return xy = points_layer.data[:, 1:3] From 06861b0dad1f2454f446620049b61f23ed34dadc Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 23 Jan 2026 16:36:16 +0100 Subject: [PATCH 7/8] Improve robustness in keypoint plotting and data checks Enhanced the handling of keypoint trajectory extraction in KeypointMatplotlibCanvas to ensure only 1D trajectories are plotted and to skip empty selections. Updated data presence checks in KeypointControls to use .size for better reliability. Improved dtype handling in read_hdf for image_paths, and made shape_type checks more robust for Shapes layers. --- src/napari_deeplabcut/_reader.py | 5 ++++- src/napari_deeplabcut/_widgets.py | 21 ++++++++++++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/napari_deeplabcut/_reader.py b/src/napari_deeplabcut/_reader.py index 9fb3f439..721d466d 100644 --- a/src/napari_deeplabcut/_reader.py +++ b/src/napari_deeplabcut/_reader.py @@ -229,7 +229,10 @@ def read_hdf(filename: str) -> list[LayerData]: nrows = df.shape[0] data = np.empty((nrows, 3)) image_paths = df["level_0"] - if is_numeric_dtype(getattr(image_paths, "dtype", np.asarray(image_paths).dtype)): + dtype = getattr(image_paths, "dtype", None) + if dtype is None: + dtype = np.asarray(image_paths).dtype + if is_numeric_dtype(dtype): image_inds = image_paths.values paths2inds = [] else: diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index d9e80f68..cfde4a65 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -541,8 +541,19 @@ def _load_dataframe(self): # Plot trajectories self._lines = {} for keypoint in self.df.columns.get_level_values("bodyparts").unique(): - y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]).to_numpy().squeeze() - x = np.arange(len(y)) + y_sel = self.df.xs( + (keypoint, "y"), + axis=1, + level=["bodyparts", "coords"], + ) + if y_sel.empty: + continue + # Convert to numpy and enforce "one trajectory per keypoint" + y = np.atleast_1d(y_sel.to_numpy().squeeze()) + if y.ndim != 1: + raise ValueError(f"Expected 1D y trajectory for keypoint={keypoint!r}, got shape={y.shape}") + x = np.arange(y.size) + color = points_layer.metadata["face_color_cycles"]["label"][keypoint] lines = self.ax.plot(x, y, color=color, label=keypoint) self._lines[keypoint] = lines @@ -838,7 +849,7 @@ def _map_keypoints(self, super_animal: str): predicate=lambda lyr: lyr.metadata.get("tables") is not None, ) data = getattr(points_layer, "data", None) - if points_layer is None or data is None or not data.any(): + if points_layer is None or data is None or not data.size > 0: return xy = points_layer.data[:, 1:3] @@ -846,7 +857,7 @@ def _map_keypoints(self, super_animal: str): xy_ref = np.c_[[val for val in superkpts_dict.values()]] neighbors = keypoints._find_nearest_neighbors(xy, xy_ref) found = neighbors != -1 - if not found.any(): + if not found.size > 0: return project_path = points_layer.metadata["project"] @@ -973,7 +984,7 @@ def _store_crop_coordinates(self, *args): layer = find_layer( self.viewer.layers, Shapes, - predicate=lambda lyr: "rectangle" in lyr.shape_type, + predicate=lambda lyr: getattr(lyr, "shape_type", None) and "rectangle" in lyr.shape_type, ) if layer is not None: try: From 03543fa1842737cd47cc61123cb002e27b7cf15a Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 23 Jan 2026 16:47:02 +0100 Subject: [PATCH 8/8] Simplify keypoint trajectory plotting logic Refactored the trajectory plotting code in KeypointMatplotlibCanvas to remove unnecessary checks and conversions, streamlining the extraction and plotting of keypoint data. --- src/napari_deeplabcut/_widgets.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index cfde4a65..e7e9b502 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -541,19 +541,8 @@ def _load_dataframe(self): # Plot trajectories self._lines = {} for keypoint in self.df.columns.get_level_values("bodyparts").unique(): - y_sel = self.df.xs( - (keypoint, "y"), - axis=1, - level=["bodyparts", "coords"], - ) - if y_sel.empty: - continue - # Convert to numpy and enforce "one trajectory per keypoint" - y = np.atleast_1d(y_sel.to_numpy().squeeze()) - if y.ndim != 1: - raise ValueError(f"Expected 1D y trajectory for keypoint={keypoint!r}, got shape={y.shape}") - x = np.arange(y.size) - + y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]).to_numpy().squeeze() + x = np.arange(len(y)) color = points_layer.metadata["face_color_cycles"]["label"][keypoint] lines = self.ax.plot(x, y, color=color, label=keypoint) self._lines[keypoint] = lines