diff --git a/.gitignore b/.gitignore index 73d56d31..4b9390b3 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ venv/ # written by setuptools_scm **/_version.py + +# Package managers +uv.lock diff --git a/README.md b/README.md index b3d755e8..acfe2276 100644 --- a/README.md +++ b/README.md @@ -1,143 +1,230 @@ -# napari-deeplabcut: keypoint annotation for pose estimation +# napari-deeplabcut - Keypoint annotation tool for pose estimation +napari+deeplabcut - -napari+deeplabcut - -[📚Documentation](https://deeplabcut.github.io/DeepLabCut/README.html) | +[📚 Plugin Documentation](https://deeplabcut.github.io/DeepLabCut/docs/gui/napari_GUI.html) | [🛠️ DeepLabCut Installation](https://deeplabcut.github.io/DeepLabCut/docs/installation.html) | -[🌎 Home Page](https://www.deeplabcut.org) | +[🌎 DeepLabCut Home Page](https://www.deeplabcut.org) | [![License: LGPL-3.0](https://img.shields.io/badge/License-LGPL%203.0-blue.svg)](https://www.gnu.org/licenses/lgpl-3.0) [![PyPI](https://img.shields.io/pypi/v/napari-deeplabcut.svg?color=green)](https://pypi.org/project/napari-deeplabcut) [![Python Version](https://img.shields.io/pypi/pyversions/napari-deeplabcut.svg?color=green)](https://python.org) -[![tests](https://github.com/DeepLabCut/napari-deeplabcut/workflows/tests/badge.svg)](https://github.com/DeepLabCut/napari-deeplabcut/actions) +[![tests](https://github.com/DeepLabCut/napari-deeplabcut/actions/workflows/test_and_deploy.yml/badge.svg?branch=main)](https://github.com/DeepLabCut/napari-deeplabcut/actions/workflows/test_and_deploy.yml) [![codecov](https://codecov.io/gh/DeepLabCut/napari-deeplabcut/branch/main/graph/badge.svg)](https://codecov.io/gh/DeepLabCut/napari-deeplabcut) [![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-deeplabcut)](https://napari-hub.org/plugins/napari-deeplabcut) -A napari plugin for keypoint annotation, also used within DeepLabCut! +A napari plugin for keypoint annotation and label refinement, also used within DeepLabCut. +--- ## Installation -If you installed DeepLabCut[gui], this plugin is already installed. However, you can also use this as a stand-alone keypoint annotator without using DeepLabCut. Instructions below! +If you installed `DeepLabCut[gui]`, this plugin is already included. + +You can also install `napari-deeplabcut` as a standalone keypoint annotation plugin without using the full DeepLabCut GUI. + +### Standard install -Start by installing PySide6 with `pip install "pyside6==6.4.2"`; this is the library we now use to build GUIs. +Using `pip` (e.g. in a `conda` environment): -You can then install `napari-deeplabcut` via [pip]: +```bash +pip install napari-deeplabcut +``` + +Using `uv`: - pip install napari-deeplabcut +```bash +uv venv -p 3.12 # create a new virtual environment with Python 3.12 +source .venv/bin/activate # activate the virtual environment. Use the relevant command for your OS/shell if different. +uv pip install napari-deeplabcut +``` +> [!NOTE] +> A conda environment or uv venv is not strictly required. Please use your preferred package manager! +### Latest development version -Alternatively, to install the latest development version, run: +Using `pip`: - pip install git+https://github.com/DeepLabCut/napari-deeplabcut.git +```bash +pip install git+https://github.com/DeepLabCut/napari-deeplabcut.git +``` +--- ## Usage -To use the plugin, please run: +Start napari: - napari +```bash +napari +``` + +Then activate the plugin in: -Then, activate the plugin in Plugins > napari-deeplabcut: Keypoint controls. +> **Plugins → napari-deeplabcut: Keypoint controls** -All accepted files (config.yaml, images, h5 data files) can be loaded -either by dropping them directly onto the canvas or via the File menu. +Accepted files such as `config.yaml`, image folders, videos, and `.h5` annotation files can be loaded either by dragging them onto the canvas or through the **File** menu. -The easiest way to get started is to drop a folder (typically a folder from within a DeepLabCut's `labeled-data` directory), and, if labeling from scratch, drop the corresponding `config.yaml` to automatically add a `Points layer` and populate the dropdown menus. +> [!TIP] +> The widget opens automatically when drag-and-dropping a compatible labeled-data folder. -[🎥 DEMO -](https://youtu.be/hsA9IB5r73E) +### Recommended way to get started -**Tools & shortcuts are:** +The easiest way to start labeling from scratch is: -- `2` and `3`, to easily switch between labeling and selection mode -- `4`, to enable pan & zoom (which is achieved using the mouse wheel or finger scrolling on the Trackpad) -- `M`, to cycle through regular (sequential), quick, and cycle annotation mode (see the description [here](https://github.com/DeepLabCut/napari-deeplabcut/blob/5a5709dd38868341568d66eab548ae8abf37cd63/src/napari_deeplabcut/keypoints.py#L25-L34)) -- `E`, to enable edge coloring (by default, if using this in refinement GUI mode, points with a confidence lower than 0.6 are marked -in red) -- `F`, to toggle between animal and body part color scheme. -- `V`, to toggle visibility of the selected layer. -- `backspace` to delete a point. -- Check the box "display text" to show the label names on the canvas. -- To move to another folder, be sure to save (`Ctrl+S`), then delete the layers, and re-drag/drop the next folder. -- One can jump to a specific image by double-clicking and editing the current frame number (located to the right of the slider). -- Selected points can be copied with `Ctrl+C`, and pasted onto other images with `Ctrl+V`. +1. Open (or drag-and-drop) an image-only folder from your computer, or within a DeepLabCut project's `labeled-data` directory + - This means that only the images are loaded, without any existing annotations +2. Open (or drag-and-drop) the `config.yaml` from your project +This creates: +- an **Image** layer with the images (or video frames), and +- an empty **Points** layer populated with the keypoint metadata from the config. -### Save Layers +You may then start annotating in the points layer that was created. -Annotations and segmentations are saved with `File > Save Selected Layer(s)...` (or its shortcut `Ctrl+S`). -Only when saving segmentation masks does a save file dialog pop up to name the destination folder; -keypoint annotations are otherwise automatically saved in the corresponding folder as `CollectedData_.h5`. -- As a reminder, DLC will only use the H5 file; so be sure if you open already labeled images you save/overwrite the H5. -- Note, before saving a layer, make sure the points layer is selected. If the user clicked on the image(s) layer first, does `Save As`, then closes the window, any labeling work during that session will be lost! -- Modifying and then saving points in a `machinelabels...` layer will add to or overwrite the existing `CollectedData` layer and will **not** save to the `machinelabels` file. +> [!NOTE] +> If you load a folder from outside a DeepLabCut project and try to save a Points layer, you will be prompted to provide the config.yaml file +> used by the project. You may then move the labeled data folder into your project directory for downstream use. +[🎥 DEMO](https://youtu.be/hsA9IB5r73E) -### Video frame extraction and prediction refinement +--- -Since v0.0.4, videos can be viewed in the GUI. +## Tools and shortcuts -Since v0.0.5, trailing points can be visualized; e.g., helping in the identification -of swaps or outlier, jittery predictions. +- `2` / `3`: switch between labeling and selection mode when a Points layer is active +- `4`: enable pan & zoom +- `M`: cycle through sequential, quick, and cycle annotation modes +- `E`: toggle edge coloring +- `F`: toggle between individual and body-part coloring modes +- `V`: toggle visibility of the selected layer +- `Backspace`: delete selected point(s) +- `Ctrl+C` / `Ctrl+V`: copy and paste selected points +- Double-click the current frame number to jump to a specific frame -Loading a video (and its corresponding output h5 file) will enable the video actions -at the top of the dock widget: they offer the option to manually extract video -frames from the GUI, or to define cropping coordinates. -Note that keypoints can be displaced and saved, as when annotating individual frames. +> [!TIP] +> Press the "View shortcuts" button in the dock widget for a reference. +Additional dock controls include: -## Workflow +- **Warn on overwrite**: enable or disable confirmation prompts when saving would overwrite existing annotations +- **Show trails**: display keypoint trails over time in the main viewer +- **Show trajectories**: open a trajectory plot in a separate dock widget +- **Show color scheme**: display the active/configured color mapping reference +- **Video tools**: extract the current frame and store crop coordinates for videos -Suggested workflows, depending on the image folder contents: +--- -1. **Labeling from scratch** – the image folder does not contain `CollectedData_.h5` file. +## Saving layers - Open *napari* as described in [Usage](#usage) and open an image folder together with the DeepLabCut project's `config.yaml`. - The image folder creates an *image layer* with the images to label. - Supported image formats are: `jpg`, `jpeg`, `png`. - The `config.yaml` file creates a *Points layer*, which holds metadata (such as keypoints read from the config file) necessary for labeling. - Select the *Points layer* in the layer list (lower left pane on the GUI) and click on the *+*-symbol in the layer controls menu (upper left pane) to start labeling. - The current keypoint can be viewed/selected in the keypoints dropdown menu (right pane). - The slider below the displayed image (or the left/right arrow keys) allows selecting the image to label. +Use: - To save the labeling progress refer to [Save Layers](#save-layers). - `Data successfully saved` should be shown in the status bar, and the image folder should now contain a `CollectedData_.h5` file. - (Note: For convenience, a CSV file with the same name is also saved.) +> **File → Save Selected Layer(s)...** -2. **Resuming labeling** – the image folder contains a `CollectedData_.h5` file. +or the shortcut: + +```text +Ctrl+S +``` - Open *napari* and open an image folder (which needs to contain a `CollectedData_.h5` file). - In this case, it is not necessary to open the DLC project's `config.yaml` file, as all necessary metadata is read from the `h5` data file. +### Keypoint save behavior + +Keypoint annotations are automatically saved into the corresponding dataset folder as: + +```text +CollectedData_.h5 +``` - Saving works as described in *1*. +For convenience, the companion `.csv` file is written in the same folder. - ***Note that if a new body part has been added to the `config.yaml` file after having started to label, loading the config in the GUI is necessary to update the dropdown menus and other metadata.*** +### Important notes - ***As `viridis` is `napari-deeplabcut` default colormap, loading the config in the GUI is also needed to update the color scheme.*** +- DeepLabCut uses the **H5** file as the authoritative annotation file. +- Before saving, make sure the **Points** layer you want to save is selected. + - The plugin will not save if several Points layers are selected at the same time, to avoid ambiguity. +- Saving a `machinelabels...` layer does **not** write back to the machine labels file. + Instead, refined annotations are written into the appropriate `CollectedData...` file. +- If saving would overwrite existing annotations, the plugin will prompt for confirmation. + - While labeling, confirmation can be disabled by unchecking the "Warn on overwrite" option in the dock widget. +- Several plugin functions implicitly expect `config.yaml` to be present two folders up from the saved `CollectedData...` file, so make sure to keep the config in the project directory structure for best results. Fallback behaviors are present but may not cover all edge cases. + - If you save a Points layer without a config file present in the expected location, you will be prompted to provide the path to the config file that matches the dataset you are working on. The plugin will then save the points and metadata into the correct folder based on the config path provided. Afterwards, it is recommended to move the dataset folder into the correct location within the project directory structure for best compatibility with other DeepLabCut functions. Please edit the `config.yaml` file if needed to update the paths to the videos and image folders. -3. **Refining labels** – the image folder contains a `machinelabels-iter<#>.h5` file. +--- - The process is analog to *2*. - Open *napari* and open an image folder. - If the video was originally labeled, *and* had outliers extracted it will contain a `CollectedData_.h5` file and a `machinelabels-iter<#>.h5` file. In this case, select the `machinelabels` layer in the GUI, and type `e` to show edges. Red indicates likelihood < 0.6. As you navigate through frames, images with labels with edges will need to be refined (moved, deleted, etc). Images with labels without edges will be on the `CollectedData` (previous manual annotations) layer and shouldn't need refining. However, you can switch to that layer and fix errors. You can also right-click on the `CollectedData` layer and select `toggle visibility` to hide that layer. Select the `machinelabels` layer before saving which will append your refined annotations to `CollectedData`. +## Video support + +Videos can be opened directly in the GUI. + +When a video is loaded, the plugin enables a small video action panel that can be used to: + +- Extract the current frame into the dataset +- Optionally export existing machine labels for that frame +- Define and save crop coordinates to the DLC `config.yaml` + +Keypoints in video-based workflows can be edited and saved in the same way as ordinary image-folder workflows. + +--- + +## Workflow (outside of DLC GUI) + +Suggested workflows depend on what is already present in the dataset folder. + +Please note this describes the workflow when napari is launched as a standalone application, outside of the DeepLabCut GUI. + +### 1) Labeling from scratch + +Use this when the image folder does **not** yet contain a `CollectedData_.h5` file. + +1. Open a folder of extracted images +2. Open the corresponding DeepLabCut `config.yaml` +3. Select the created **Points** layer +4. Start labeling +5. Save the points layer with `Ctrl+S` + +After saving, the folder should now contain: + +```text +CollectedData_.h5 +CollectedData_.csv +``` - If the folder only had outliers extracted and wasn't originally labeled, it will not have a `CollectedData` layer. Work with the `machinelabels` layer selected to refine annotation positions, then save. +--- - In this case, it is not necessary to open the DLC project's `config.yaml` file, as all necessary metadata is read from the `h5` data file. +### 2) Resuming labeling - Saving works as described in *1*. +Use this when the folder already contains a `CollectedData_.h5` file. -4. **Drawing segmentation masks** +Open (or drag-and-drop) the folder in napari. The existing keypoint metadata and annotations will be loaded from the H5 file, so loading `config.yaml` is not needed nor recommended. - Drop an image folder as in *1*, manually add a *shapes layer*. Then select the *rectangle* in the layer controls (top left pane), - and start drawing rectangles over the images. Masks and rectangle vertices are saved as described in [Save Layers](#save-layers). - Note that masks can be reloaded and edited at a later stage by dropping the `vertices.csv` file onto the canvas. +However, loading the config is still useful if: -### Workflow flowchart +- The project’s bodyparts changed +- You would like to refresh the configured color scheme from the project config + +--- + +### 3) Refining machine labels + +Use this when the folder contains a machine predictions file such as: + +```text +machinelabels-iter<...>.h5 +``` + +Open the folder in napari. + +If both a `CollectedData...` file and a `machinelabels...` file are present: + +- Edit the `machinelabels` layer to refine predictions +- Optionally use edge coloring (`E`) to highlight low-confidence labels +- Save the selected `machinelabels` layer to merge refinements into `CollectedData` + +If the folder contains only `machinelabels...` and no `CollectedData...`, refined annotations will still be saved into a new `CollectedData...` target. + +--- + +## Workflow flowchart ```mermaid %%{init: {"flowchart": {"htmlLabels": false}} }%% @@ -145,81 +232,100 @@ graph TD id1[What stage of labeling?] id2[deeplabcut.label_frames] id3[deeplabcut.refine_labels] - id4[Add labels to, or modify in, \n `CollectedData...` layer and save that layer] - id5[Modify labels in `machinelabels` layer and save \n which will create a `CollectedData...` file] + id4[Add labels to, or modify in, + `CollectedData...` layer and save that layer] + id5[Modify labels in `machinelabels` layer and save + which will create or update `CollectedData...`] id6[Have you refined some labels from the most recent iteration and saved already?] id7["All extracted frames are already saved in `CollectedData...`. -1. Hide or trash all `machinelabels` layers. -2. Then modify in and save `CollectedData`"] - id8[" -1. hide or trash all `machinelabels` layers except for the most recent. -2. Select most recent `machinelabels` and hit `e` to show edges. -3. Modify only in `machinelabels` and skip frames with labels without edges shown. -4. Save `machinelabels` layer, which will add data to `CollectedData`. - - If you need to revisit this video later, ignore `machinelabels` and work only in `CollectedData`"] - - id1 -->|I need to manually label new frames \n or fix my labels|id2 - id1 ---->|I need to refine outlier frames \nfrom analyzed videos|id3 - id2 -->id4 +1. Hide or remove all `machinelabels` layers. +2. Continue working in `CollectedData`."] + id8["1. Keep only the most recent `machinelabels` layer. +2. Select it and press `E` to show edges. +3. Refine labels in `machinelabels`. +4. Save to merge into `CollectedData`. +- If you revisit the dataset later, you can continue working in `CollectedData`."] + + id1 -->|I need to manually label new frames + or fix existing labels|id2 + id1 -->|I need to refine outlier frames + from analyzed videos|id3 + id2 --> id4 id3 -->|I only have a `machinelabels...` file|id5 - id3 ---->|I have both `machinelabels` and `CollectedData` files|id6 + id3 -->|I have both `machinelabels` and `CollectedData` files|id6 id6 -->|yes|id7 - id6 ---->|no, I just extracted outliers|id8 + id6 -->|no, I just extracted outliers|id8 ``` -### Labeling multiple image folders +--- + +## Labeling multiple image folders + +Only one dataset folder should be worked on at a time. + +After finishing a folder: -Labeling multiple image folders has to be done in sequence; i.e., only one image folder can be opened at a time. -After labeling the images of a particular folder is done and the associated *Points layer* has been saved, *all* layers should be removed from the layers list (lower left pane on the GUI) by selecting them and clicking on the trashcan icon. -Now, another image folder can be labeled, following the process described in *1*, *2*, or *3*, depending on the particular image folder. +1. Save the relevant **Points** layer +2. Remove the current layers from the viewer +3. Open the next folder +This keeps plugin operation and saving unambiguous. -### Defining cropping coordinates +--- -Prior to defining cropping coordinates, two elements should be loaded in the GUI: -a video and the DLC project's `config.yaml` file (into which the crop dimensions will be stored). -Then it suffices to add a `Shapes layer`, draw a `rectangle` in it with the desired area, -and hit the button `Store crop coordinates`; coordinates are automatically written to the configuration file. +## Defining crop coordinates +To store crop coordinates in a DLC project: + +1. Open the video from the project’s `videos` folder +2. Enable cropping in the video tools +3. Draw a rectangle in the newly created crop layer (the tool is selected by default) +4. Click **Store crop coordinates** after checking the coordinates in the widget. + +The crop coordinates are then written back to the project configuration. + +--- ## Contributing -Contributions are very welcome. Tests can be run with [tox], please ensure -the coverage at least stays the same before you submit a pull request. +Contributions are welcome. -To locally install the code, please git clone the repo and then run `pip install -e .` +Tests can be run locally with [tox]. +Please note we use pre-commit hooks to run linters and formatters on changed files, so make sure to install the pre-commit dependencies: -## License +```bash +pip install pre-commit +pre-commit install +``` -Distributed under the terms of the [BSD-3] license, -"napari-deeplabcut" is free and open source software. +### Development install -## Issues +Clone the repository and install it in editable mode. + +Using `pip`: -If you encounter any problems, please [file an issue] along with a detailed description. +```bash +pip install -e . +``` -[file an issue]: https://github.com/DeepLabCut/napari-deeplabcut/issues +If you need development dependencies as well, use the project’s `dev` extra: +```bash +pip install -e .[dev] +``` -## Acknowledgements +## License +Distributed under the terms of the [LGPL-3.0](https://www.gnu.org/licenses/lgpl-3.0). -This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. We thank the Chan Zuckerberg Initiative (CZI) for funding the initial development of this work! +## Issues + +If you encounter any problems, please [file an issue](https://github.com/DeepLabCut/napari-deeplabcut/issues) with a detailed description and, if possible, a minimal reproducible example. - +This [napari](https://github.com/napari/napari) plugin was originally generated with [Cookiecutter](https://github.com/audreyr/cookiecutter) using [@napari](https://github.com/napari)'s [cookiecutter-napari-plugin](https://github.com/napari/cookiecutter-napari-plugin) template. +We thank the Chan Zuckerberg Initiative (CZI) for funding the initial development of this work! -[napari]: https://github.com/napari/napari -[Cookiecutter]: https://github.com/audreyr/cookiecutter -[@napari]: https://github.com/napari -[cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin -[BSD-3]: http://opensource.org/licenses/BSD-3-Clause [tox]: https://tox.readthedocs.io/en/latest/ -[pip]: https://pypi.org/project/pip/ -[PyPI]: https://pypi.org/ diff --git a/pyproject.toml b/pyproject.toml index 9f3a7f13..1256114c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ email = "admin@deeplabcut.org" [project.entry-points."napari.manifest"] napari-deeplabcut = "napari_deeplabcut:napari.yaml" [project.optional-dependencies] -testing = [ +dev = [ "pillow", "pytest", "pytest-cov", @@ -97,3 +97,7 @@ table_format = "long" [tool.pytest.ini_options] qt_api = "pyside6" +testpaths = [ "src/napari_deeplabcut/_tests" ] +markers = [ + "e2e: end-to-end tests. Invoke napari viewer fixtures, slow.", +] diff --git a/src/napari_deeplabcut/__init__.py b/src/napari_deeplabcut/__init__.py index 0c5dbf49..aac9a753 100644 --- a/src/napari_deeplabcut/__init__.py +++ b/src/napari_deeplabcut/__init__.py @@ -12,8 +12,7 @@ "get_hdf_reader", "get_image_reader", "get_video_reader", - "write_hdf", - "write_masks", + "write_hdf_napari_dlc", "__version__", ) @@ -26,8 +25,7 @@ get_video_reader, ) from ._writer import ( # noqa: F401 (explicit re-export via __all__) - write_hdf, - write_masks, + write_hdf_napari_dlc, ) try: diff --git a/src/napari_deeplabcut/_reader.py b/src/napari_deeplabcut/_reader.py index fdc04585..6dcff093 100644 --- a/src/napari_deeplabcut/_reader.py +++ b/src/napari_deeplabcut/_reader.py @@ -1,44 +1,41 @@ -import json -from collections.abc import Callable, Sequence -from pathlib import Path - -import cv2 -import dask.array as da -import numpy as np -import pandas as pd -import yaml -from dask import delayed -from dask_image.imread import imread -from napari.types import LayerData -from natsort import natsorted +# src/napari_deeplabcut/_reader.py +"""Readers for DeepLabCut data formats.""" -from napari_deeplabcut import misc +from __future__ import annotations -SUPPORTED_IMAGES = ".jpg", ".jpeg", ".png" -SUPPORTED_VIDEOS = ".mp4", ".mov", ".avi" +import logging +from pathlib import Path +from napari_deeplabcut.config._autostart import maybe_install_keypoint_controls_autostart +from napari_deeplabcut.core.discovery import discover_annotations +from napari_deeplabcut.core.io import ( + SUPPORTED_IMAGES, + SUPPORTED_VIDEOS, + read_config, + read_hdf, + read_hdf_single, + read_images, + read_video, +) +from napari_deeplabcut.core.project_paths import looks_like_dlc_labeled_folder -def is_video(filename: str): - return any(filename.lower().endswith(ext) for ext in SUPPORTED_VIDEOS) +logger = logging.getLogger(__name__) def get_hdf_reader(path): if isinstance(path, list): path = path[0] - - if not path.endswith(".h5"): + if not str(path).endswith(".h5"): return None - + maybe_install_keypoint_controls_autostart() return read_hdf def get_image_reader(path): if isinstance(path, list): path = path[0] - - if not any(path.lower().endswith(ext) for ext in SUPPORTED_IMAGES): + if not any(str(path).lower().endswith(ext) for ext in SUPPORTED_IMAGES): return None - return read_images @@ -51,454 +48,68 @@ def get_video_reader(path): def get_config_reader(path): if isinstance(path, list): path = path[0] - - if not path.endswith(".yaml"): + if not str(path).endswith(".yaml"): return None - + maybe_install_keypoint_controls_autostart() return read_config -def _filter_extensions( - image_paths: list[str | Path], - valid_extensions: tuple[str] = SUPPORTED_IMAGES, -) -> list[Path]: - """ - Filter image paths by valid extensions. - """ +def _filter_extensions(image_paths, valid_extensions=SUPPORTED_IMAGES) -> list[Path]: return [Path(p) for p in image_paths if Path(p).suffix.lower() in valid_extensions] def get_folder_parser(path): if not path or not Path(path).is_dir(): return None - layers = [] + if not looks_like_dlc_labeled_folder(path): + # TODO: @C-Achard raise explaining why it is not considered a DLC-labeled folder + # e.g. no h5 files, no labeled-data in path, etc. + return None + layers = [] images = _filter_extensions(Path(path).iterdir(), valid_extensions=SUPPORTED_IMAGES) if not images: - raise OSError(f"No supported images were found in {path} with extensions {SUPPORTED_IMAGES}.") - - image_layer = read_images(images) - layers.extend(image_layer) - for file in Path(path).iterdir(): - if file.name.endswith(".h5"): - try: - layers.extend(read_hdf(str(file))) - break # one h5 per annotated video - except Exception as e: - raise RuntimeError(f"Could not read annotation data from {file}") from e - return lambda _: layers - - -# Helper functions for lazy image reading and normalization -# NOTE : forced keyword-only arguments for clarity -def _read_and_normalize(*, filepath: Path, normalize_func: Callable[[np.ndarray], np.ndarray]) -> np.ndarray: - arr = cv2.imread(str(filepath), cv2.IMREAD_UNCHANGED) - if arr is None: - raise OSError(f"Could not read image: {filepath}") - return normalize_func(arr) - - -def _normalize_to_rgb(arr: np.ndarray) -> np.ndarray: - if arr.ndim == 2: - return cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB) - if arr.ndim == 3 and arr.shape[2] == 4: - return cv2.cvtColor(arr, cv2.COLOR_BGRA2RGB) - return cv2.cvtColor(arr, cv2.COLOR_BGR2RGB) - - -def _expand_image_paths(path: str | Path | list[str | Path] | tuple[str | Path, ...]) -> list[Path]: - # Normalize input to list[Path] - raw_paths = [Path(p) for p in path] if isinstance(path, (list, tuple)) else [Path(path)] - - expanded: list[Path] = [] - for p in raw_paths: - if p.is_dir() and p.suffix.lower() != ".zarr": - file_matches: list[Path] = [] - for ext in SUPPORTED_IMAGES: - file_matches.extend(p.glob(f"*{ext}")) - expanded.extend(x for x in natsorted(file_matches, key=str) if x.is_file()) - else: - matches = list(p.parent.glob(p.name)) - expanded.extend(matches or [p]) - - return [p for p in expanded if p.is_file() and p.suffix.lower() in SUPPORTED_IMAGES] - - -# Lazy image reader that supports directories and lists of files -def _lazy_imread( - filenames: str | Path | list[str | Path], - use_dask: bool | None = None, - stack: bool = True, -) -> np.ndarray | da.Array | list[np.ndarray | da.Array]: - """Lazily reads one or more images with optional Dask support. - - Resolves file paths using `_expand_image_paths`, ensuring consistent - handling of directories, glob patterns, and lists/tuples of paths. - Images are normalized to RGB and may be wrapped in Dask delayed - objects for lazy loading. - - Behavior: - * If a single image is resolved: - - The image is read eagerly and returned as a NumPy array. - * If multiple images are resolved: - - The first image is read eagerly to determine shape and dtype. - - Subsequent images are loaded lazily via Dask unless - `use_dask=False`. - - Stacking behavior is controlled by `stack`. - - Args: - filenames (str | Path | list[str | Path]): - File path(s), directory, or glob pattern(s) to load. - use_dask (bool | None, optional): - Whether to load images lazily using Dask. - Defaults to `True` when multiple files are found, otherwise - `False`. - stack (bool, optional): - If True, stack images along axis 0 into a single array. - If False, return a list of arrays or delayed arrays. - Defaults to True. - - Returns: - np.ndarray | da.Array | list[np.ndarray | da.Array]: - Loaded image data. The return type depends on the number of - images found, the `use_dask` flag, and the `stack` option. - - Raises: - ValueError: If no supported images are found. - """ - expanded = _expand_image_paths(filenames) - - if not expanded: - raise ValueError(f"No supported images were found for input: {filenames}") - - if use_dask is None: - use_dask = len(expanded) > 1 - - images = [] - first_shape = None - first_dtype = None - - def make_delayed_array(fp: Path, first_shape: tuple[int, ...], first_dtype: np.dtype) -> da.Array: - """Create a dask array for a single file.""" - return da.from_delayed( - delayed(_read_and_normalize)(filepath=fp, normalize_func=_normalize_to_rgb), - shape=first_shape, - dtype=first_dtype, - ) - - for fp in expanded: - if first_shape is None: - arr0 = _read_and_normalize(filepath=fp, normalize_func=_normalize_to_rgb) - first_shape = arr0.shape - first_dtype = arr0.dtype - - if use_dask: - images.append(make_delayed_array(fp, first_shape, first_dtype)) - else: - images.append(arr0) - continue - - if use_dask: - images.append(make_delayed_array(fp, first_shape, first_dtype)) - else: - images.append(_read_and_normalize(filepath=fp, normalize_func=_normalize_to_rgb)) - - if len(images) == 1: - return images[0] - - try: - return da.stack(images) if use_dask and stack else (np.stack(images) if stack else images) - except ValueError as e: - raise ValueError( - "Cannot stack images with different shapes using NumPy. " - "Ensure all images have the same shape or set stack=False." - ) from e - - -# Read images from a list of files or a glob/string path -def read_images(path: str | Path | list[str | Path]): - """Reads one or multiple images and returns a Napari Image layer. - - Uses `_expand_image_paths` to resolve the input into a list of valid - image files. Supports single paths, glob expressions, directories, - and lists or tuples of such paths. - - Behavior: - * If one file is found: - - Loaded using `dask_image.imread.imread`. - * If multiple files are found: - - Loaded lazily using `lazy_imread` into a stacked image - layer. - - Args: - path (str | Path | list[str | Path]): - Input path(s), directory, or glob pattern(s) to expand into - supported image files. - - Returns: - list[LayerData]: - A list containing one Napari layer tuple of the form - `(data, metadata, "image")`. - - Raises: - OSError: If no supported images are found after expansion. - """ - filepaths = _expand_image_paths(path) - - if not filepaths: - raise OSError(f"No supported images were found in {path}") - - filepaths = natsorted(filepaths, key=str) - - # Multiple images → lazy-imread stack - if len(filepaths) > 1: - relative_paths = [misc.canonicalize_path(fp, 3) for fp in filepaths] - params = { - "name": "images", - "metadata": { - "paths": relative_paths, - "root": str(filepaths[0].parent), - }, - } - data = _lazy_imread(filepaths, use_dask=True, stack=True) - return [(data, params, "image")] - - # Single image → old behavior - image_path = filepaths[0] - params = { - "name": "images", - "metadata": { - "paths": [misc.canonicalize_path(image_path, 3)], - "root": str(image_path.parent), - }, - } - return [(imread(str(image_path)), params, "image")] - - -# Helper to populate keypoint layer metadata -def _populate_metadata( - header: misc.DLCHeader, - *, - labels: Sequence[str] | None = None, - ids: Sequence[str] | None = None, - likelihood: Sequence[float] | None = None, - paths: list[str] | None = None, - size: int | None = 8, - pcutoff: float | None = 0.6, - colormap: str | None = "viridis", -) -> dict: - if labels is None: - labels = header.bodyparts - if ids is None: - ids = header.individuals - if likelihood is None: - likelihood = np.ones(len(labels)) - face_color_cycle_maps = misc.build_color_cycles(header, colormap) - face_color_prop = "id" if ids[0] else "label" - return { - "name": "keypoints", - "text": "{id}–{label}" if ids[0] else "label", - "properties": { - "label": list(labels), - "id": list(ids), - "likelihood": likelihood, - "valid": likelihood > pcutoff, - }, - "face_color_cycle": face_color_cycle_maps[face_color_prop], - "face_color": face_color_prop, - "face_colormap": colormap, - "border_color": "valid", - "border_color_cycle": ["black", "red"], - "border_width": 0, - "border_width_is_relative": False, - "size": size, - "metadata": { - "header": header, - "face_color_cycles": face_color_cycle_maps, - "colormap_name": colormap, - "paths": paths or [], - }, - } - - -def _load_superkeypoints_diagram(super_animal: str): - path = str(Path(__file__).parent / "assets" / f"{super_animal}.jpg") - try: - return imread(path), {"root": ""}, "images" - except Exception as e: - raise FileNotFoundError(f"Superkeypoints diagram not found for {super_animal}.") from e - - -def _load_superkeypoints(super_animal: str): - path = str(Path(__file__).parent / "assets" / f"{super_animal}.json") - if not Path(path).is_file(): - raise FileNotFoundError(f"Superkeypoints JSON file not found for {super_animal}.") - with open(path) as f: - return json.load(f) - - -def _load_config(config_path: str): - with open(config_path) as file: - return yaml.safe_load(file) - - -# Read config file and create keypoint layer metadata -def read_config(configname: str) -> list[LayerData]: - config = _load_config(configname) - header = misc.DLCHeader.from_config(config) - metadata = _populate_metadata( - header, - size=config["dotsize"], - pcutoff=config["pcutoff"], - colormap=config["colormap"], - likelihood=np.array([1]), - ) - metadata["name"] = f"CollectedData_{config['scorer']}" - metadata["ndim"] = 3 - metadata["property_choices"] = metadata.pop("properties") - metadata["metadata"]["project"] = str(Path(configname).parent) - conversion_tables = config.get("SuperAnimalConversionTables") - if conversion_tables is not None: - super_animal, table = conversion_tables.popitem() - metadata["metadata"]["tables"] = {super_animal: table} - return [(None, metadata, "points")] - - -# Read HDF file and create keypoint layers -def read_hdf(filename: str) -> list[LayerData]: - config_path = misc.find_project_config_path(filename) - layers = [] - for file in Path(filename).parent.glob(Path(filename).name): - temp = pd.read_hdf(str(file)) - temp = misc.merge_multiple_scorers(temp) - header = misc.DLCHeader(temp.columns) - temp = temp.droplevel("scorer", axis=1) - if "individuals" not in temp.columns.names: - # Append a fake level to the MultiIndex - # to make it look like a multi-animal DataFrame - old_idx = temp.columns.to_frame() - old_idx.insert(0, "individuals", "") - temp.columns = pd.MultiIndex.from_frame(old_idx) - try: - cfg = _load_config(config_path) - colormap = cfg["colormap"] - except FileNotFoundError: - colormap = "rainbow" - else: - colormap = "Set3" - if isinstance(temp.index, pd.MultiIndex): - temp.index = [str(Path(*row)) for row in temp.index] - df = ( - temp.stack(["individuals", "bodyparts"]) - .reindex(header.individuals, level="individuals") - .reindex(header.bodyparts, level="bodyparts") - .reset_index() - ) - nrows = df.shape[0] - data = np.empty((nrows, 3)) - image_paths = df["level_0"] - if pd.api.types.is_numeric_dtype(getattr(image_paths, "dtype", np.asarray(image_paths).dtype)): - image_inds = image_paths.values - paths2inds = [] + has_video = any(Path(path).glob(f"*{ext}") for ext in SUPPORTED_VIDEOS) + if has_video: + logger.info( + "No supported images found in '%s' (extensions: %s). " + "A supported video appears to be present; open the video directly to view frames.", + path, + SUPPORTED_IMAGES, + ) else: - image_inds, paths2inds = misc.encode_categories( - image_paths, - is_path=True, - return_unique=True, - do_sort=True, + logger.warning( + "No supported images found in '%s' (extensions: %s), and no supported videos found (extensions: %s).", + path, + SUPPORTED_IMAGES, + SUPPORTED_VIDEOS, ) + return None - data[:, 0] = image_inds - data[:, 1:] = df[["y", "x"]].to_numpy() - metadata = _populate_metadata( - header, - labels=df["bodyparts"], - ids=df["individuals"], - likelihood=df.get("likelihood"), - paths=list(paths2inds), - colormap=colormap, - ) - metadata["name"] = Path(filename).stem - metadata["metadata"]["root"] = str(Path(filename).parent) - # Store file name in case the layer's name is edited by the user - metadata["metadata"]["name"] = metadata["name"] - layers.append((data, metadata, "points")) - return layers - - -# Video reader using OpenCV -class Video: - def __init__(self, video_path): - if not Path(video_path).is_file(): - raise ValueError(f'Video path "{video_path}" does not point to a file.') - - self.path = video_path - self.stream = cv2.VideoCapture(video_path) - if not self.stream.isOpened(): - raise OSError("Video could not be opened.") - - self._n_frames = int(self.stream.get(cv2.CAP_PROP_FRAME_COUNT)) - self._width = int(self.stream.get(cv2.CAP_PROP_FRAME_WIDTH)) - self._height = int(self.stream.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self._frame = cv2.UMat(self._height, self._width, cv2.CV_8UC3) - - def __len__(self): - return self._n_frames - - @property - def width(self): - return self._width - - @property - def height(self): - return self._height - - def set_to_frame(self, ind): - ind = min(ind, len(self) - 1) - ind += 1 # Unclear why this is needed at all - self.stream.set(cv2.CAP_PROP_POS_FRAMES, ind) - - def read_frame(self): - self.stream.retrieve(self._frame) - cv2.cvtColor(self._frame, cv2.COLOR_BGR2RGB, self._frame, 3) - return self._frame.get() - - def close(self): - self.stream.release() - - -def read_video(filename: str, opencv: bool = True): - if opencv: - stream = Video(filename) - # NOTE construct output shape tuple in (H, W, C) order to match read_frame() data - shape = stream.height, stream.width, 3 + layers.extend(read_images(images)) - def _read_frame(ind): - stream.set_to_frame(ind) - return stream.read_frame() + # Deterministic discovery: load ALL H5 artifacts + artifacts = discover_annotations(path) + h5_artifacts = [(Path(a.h5_path), a.kind) for a in artifacts if a.h5_path is not None] - lazy_reader = delayed(_read_frame) - else: # pragma: no cover - from pims import PyAVReaderIndexed + if not h5_artifacts: + return lambda _: layers + errors = [] + for h5_path, kind in h5_artifacts: try: - stream = PyAVReaderIndexed(filename) - except ImportError: - raise ImportError("`pip install av` to use the PyAV video reader.") from None + layers.extend(read_hdf_single(h5_path, kind=kind)) + except Exception as e: + logger.debug("Could not read annotation data from %s", h5_path, exc_info=True) + errors.append((RuntimeError, f"Could not read annotation data from {h5_path}", e)) - shape = stream.frame_shape - lazy_reader = delayed(stream.get_frame) + n_points_layers = sum(1 for _, _, layer_type in layers if layer_type == "points") + if n_points_layers == 0 and errors: + exc_type, msg, cause = errors[0] + raise exc_type(msg) from cause - movie = da.stack([da.from_delayed(lazy_reader(i), shape=shape, dtype=np.uint8) for i in range(len(stream))]) - elems = list(Path(filename).parts) - elems[-2] = "labeled-data" - elems[-1] = Path(elems[-1]).stem # + Path(filename).suffix - root = str(Path(*elems)) - params = { - "name": filename, - "metadata": { - "root": root, - }, - } - return [(movie, params, "image")] + if n_points_layers > 0: + maybe_install_keypoint_controls_autostart() + + return lambda _: layers diff --git a/src/napari_deeplabcut/_tests/compat/conftest.py b/src/napari_deeplabcut/_tests/compat/conftest.py new file mode 100644 index 00000000..a274130e --- /dev/null +++ b/src/napari_deeplabcut/_tests/compat/conftest.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import sys +import types +from dataclasses import dataclass +from types import SimpleNamespace + +import numpy as np +import pandas as pd +import pytest +from qtpy.QtWidgets import ( + QCheckBox, + QFormLayout, + QLabel, + QLineEdit, + QWidget, +) + + +class DummyDropdown(QWidget): + """Simple QWidget-based dropdown stand-in for QFormLayout.addRow().""" + + def __init__(self, colormaps, parent=None): + super().__init__(parent) + self.colormaps = colormaps + self.updated_to = None + + def update_to(self, value): + self.updated_to = value + + +@pytest.fixture +def dropdown_cls(): + return DummyDropdown + + +@pytest.fixture +def plt_module(): + # minimal matplotlib-like object for compat function + return SimpleNamespace(colormaps=["viridis", "magma", "plasma"]) + + +class _FaceColorControl: + def __init__(self): + self.face_color_edit = QLineEdit() + self.face_color_label = QLabel("face") + + +class _BorderColorControl: + def __init__(self): + self.border_color_edit = QLineEdit() + self.border_color_edit_label = QLabel("border") + + +class _OutSliceCheckboxControl: + def __init__(self): + self.out_of_slice_checkbox = QCheckBox("out") + self.out_of_slice_checkbox_label = QLabel("out label") + + +class DummyPointControls(QWidget): + def __init__(self): + super().__init__() + self._layout = QFormLayout(self) + self.setLayout(self._layout) + + self._face_color_control = _FaceColorControl() + self._border_color_control = _BorderColorControl() + self._out_slice_checkbox_control = _OutSliceCheckboxControl() + + def layout(self): + return self._layout + + +@pytest.fixture +def ui_env(qtbot): + """ + Minimal viewer/layer environment for apply_points_layer_ui_tweaks(). + Uses real Qt widgets and a tiny viewer-shaped object. + """ + layer = object() + + point_controls = DummyPointControls() + qtbot.addWidget(point_controls) + + dock_layer_controls = SimpleNamespace(widget=lambda: SimpleNamespace(widgets={layer: point_controls})) + viewer = SimpleNamespace(window=SimpleNamespace(qt_viewer=SimpleNamespace(dockLayerControls=dock_layer_controls))) + + layer_obj = SimpleNamespace(metadata={"colormap_name": "magma"}) + + plt_module = SimpleNamespace(colormaps=["viridis", "magma", "plasma"]) + + return SimpleNamespace( + viewer=viewer, + layer_key=layer, + layer_obj=layer_obj, + point_controls=point_controls, + dropdown_cls=DummyDropdown, + plt_module=plt_module, + ) + + +@dataclass(frozen=True) +class Keypoint: + label: object + id: object + + +class Recorder: + def __init__(self): + self.calls = [] + + def __call__(self, *args, **kwargs): + self.calls.append((args, kwargs)) + + +class DummyFeatureTable: + def __init__(self): + self.appended = [] + + def append(self, features): + self.appended.append(features.copy()) + + +class DummyText: + def __init__(self): + self.calls = [] + + def _paste(self, **kwargs): + self.calls.append(kwargs) + + +class DummyColorManager: + def __init__(self): + self.calls = [] + + def _paste(self, **kwargs): + self.calls.append(kwargs) + + +class DummyLayerForPaste: + def __init__(self): + # existing layer state + self.data = np.array([[1, 2, 3]], dtype=float) + self.shown = np.array([True], dtype=bool) + self.size = np.array([5.0], dtype=float) + self.symbol = np.array(["o"], dtype=object) + self.edge_width = np.array([0.5], dtype=float) + + # these are the attrs mutated by make_paste_data() + self._data = self.data.copy() + self._shown = self.shown.copy() + self._size = self.size.copy() + self._symbol = self.symbol.copy() + self._edge_width = self.edge_width.copy() + + self._view_data = np.array([[1, 2, 3]], dtype=float) + + self._slice_input = SimpleNamespace(not_displayed=[1]) + self._slice_indices = np.array([0, 7, 0]) + + self._feature_table = DummyFeatureTable() + self.text = DummyText() + self._edge = DummyColorManager() + self._face = DummyColorManager() + + self._selected_view = [] + self._selected_data = set() + + self.refresh_count = 0 + + self._clipboard = { + "features": pd.DataFrame( + { + "label": ["nose", "tail"], + "id": [1, 2], + } + ), + "indices": np.array([0, 5, 0]), + "text": { + "string": np.array(["nose-1", "tail-2"], dtype=object), + "color": "white", + }, + "data": np.array( + [ + [10.0, 20.0, 30.0], + [40.0, 50.0, 60.0], + ] + ), + "shown": np.array([True, False], dtype=bool), + "size": np.array([3.0, 4.0], dtype=float), + "symbol": np.array(["x", "+"], dtype=object), + "edge_width": np.array([1.0, 2.0], dtype=float), + "edge_color": np.array( + [ + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 1.0], + ] + ), + "face_color": np.array( + [ + [0.5, 0.0, 0.0, 1.0], + [0.0, 0.5, 0.0, 1.0], + ] + ), + } + + def refresh(self): + self.refresh_count += 1 + + +@pytest.fixture +def paste_env(monkeypatch): + """ + Environment for make_paste_data(). + + We monkeypatch the private napari helper import because compat code imports it + inside the returned closure. + """ + fake_layer_utils = types.ModuleType("napari.layers.utils.layer_utils") + fake_layer_utils._features_to_properties = lambda features: { + col: features[col].to_numpy() for col in features.columns + } + monkeypatch.setitem(sys.modules, "napari.layers.utils.layer_utils", fake_layer_utils) + + recolor_calls = [] + + def schedule_recolor(layer): + recolor_calls.append(layer) + + controls = SimpleNamespace( + np=np, + keypoints=SimpleNamespace(Keypoint=Keypoint), + _schedule_recolor=schedule_recolor, + ) + + store_layer = object() + store = SimpleNamespace( + annotated_keypoints={Keypoint("nose", 1)}, + layer=store_layer, + ) + + layer = DummyLayerForPaste() + + return SimpleNamespace( + controls=controls, + store=store, + layer=layer, + recolor_calls=recolor_calls, + Keypoint=Keypoint, + ) diff --git a/src/napari_deeplabcut/_tests/compat/test_compat_integration.py b/src/napari_deeplabcut/_tests/compat/test_compat_integration.py new file mode 100644 index 00000000..076428d9 --- /dev/null +++ b/src/napari_deeplabcut/_tests/compat/test_compat_integration.py @@ -0,0 +1,165 @@ +"""Integration/smoke tests for napari API compatibility code. +Validate the overrides and monkeypatches don't crash and have the intended effect on napari layers""" +# src/napari_deeplabcut/_tests/compat/test_compat_integration.py + +from __future__ import annotations + +import numpy as np +import pytest + +from napari_deeplabcut.napari_compat import ( + apply_points_layer_ui_tweaks, + install_add_wrapper, + install_paste_patch, +) + + +def _get_point_controls(viewer, layer): + return viewer.window.qt_viewer.dockLayerControls.widget().widgets[layer] + + +def test_apply_points_layer_ui_tweaks_smoke_real_viewer(viewer, qtbot, dropdown_cls, plt_module): + """Smoke test against a real napari viewer + real Points controls. + + This should fail if napari's private control wiring changed in a way that + breaks our compat layer for supported versions. + """ + layer = viewer.add_points( + np.array([[1.0, 2.0], [3.0, 4.0]]), + features={"label": ["a", "b"], "id": [1, 2]}, + ) + layer.metadata["colormap_name"] = "magma" + + # Ensure this layer is the active one so napari builds/shows its controls. + viewer.layers.selection.active = layer + + qtbot.waitUntil( + lambda: layer in viewer.window.qt_viewer.dockLayerControls.widget().widgets, + timeout=3000, + ) + + point_controls = _get_point_controls(viewer, layer) + before_rows = point_controls.layout().rowCount() + + selector = apply_points_layer_ui_tweaks( + viewer, + layer, + dropdown_cls=dropdown_cls, + plt_module=plt_module, + ) + + # True smoke-test expectation: + # on supported napari versions, we expect the compat hook to really wire in. + assert selector is not None + assert selector.updated_to == "magma" + assert point_controls.layout().rowCount() == before_rows + 1 + + # These are the real private widgets our compat layer is supposed to hide. + assert point_controls._face_color_control.face_color_edit.isHidden() + assert point_controls._face_color_control.face_color_label.isHidden() + assert point_controls._border_color_control.border_color_edit.isHidden() + assert point_controls._border_color_control.border_color_edit_label.isHidden() + assert point_controls._out_slice_checkbox_control.out_of_slice_checkbox.isHidden() + assert point_controls._out_slice_checkbox_control.out_of_slice_checkbox_label.isHidden() + + +def test_install_add_wrapper_smoke_real_points_layer(viewer): + """Smoke test that method rebinding works on a real napari Points layer.""" + layer = viewer.add_points(np.array([[0.0, 0.0]])) + + calls = [] + + def add_impl(*args, **kwargs): + calls.append((args, kwargs)) + return "added" + + def schedule_recolor(layer_obj): + calls.append(layer_obj) + + install_add_wrapper(layer, add_impl=add_impl, schedule_recolor=schedule_recolor) + + # Bound method really installed on a real layer + assert layer.add.__self__ is layer + + payload = np.array([[1.0, 2.0]]) + result = layer.add(payload, source="smoke-test") + + assert result == "added" + + add_args, add_kwargs = calls[0] + np.testing.assert_array_equal(add_args[0], payload) + assert add_kwargs == {"source": "smoke-test"} + + assert calls[1] is layer + + +def test_install_paste_patch_smoke_real_points_layer(viewer): + """Smoke test that _paste_data can be rebound on a real napari Points layer.""" + layer = viewer.add_points(np.array([[0.0, 0.0]])) + + seen = [] + + def paste_func(this): + seen.append(this) + + install_paste_patch(layer, paste_func=paste_func) + + assert layer._paste_data.__self__ is layer + + layer._paste_data() + + assert seen == [layer] + + +@pytest.mark.xfail(reason="This test is fixed in a subsequent PR, to be added") +def test_apply_points_layer_ui_tweaks_real_dropdown(qtbot): + from types import SimpleNamespace + + import matplotlib.pyplot as plt + from qtpy.QtWidgets import QCheckBox, QFormLayout, QLabel, QLineEdit, QWidget + + from napari_deeplabcut.ui.labels_and_dropdown import DropdownMenu + + class FaceColorControl: + def __init__(self): + self.face_color_edit = QLineEdit() + self.face_color_label = QLabel("face") + + class BorderColorControl: + def __init__(self): + self.border_color_edit = QLineEdit() + self.border_color_edit_label = QLabel("border") + + class OutSliceControl: + def __init__(self): + self.out_of_slice_checkbox = QCheckBox("out") + self.out_of_slice_checkbox_label = QLabel("out label") + + class PointControls(QWidget): + def __init__(self): + super().__init__() + self._layout = QFormLayout(self) + self.setLayout(self._layout) + self._face_color_control = FaceColorControl() + self._border_color_control = BorderColorControl() + self._out_slice_checkbox_control = OutSliceControl() + + def layout(self): + return self._layout + + layer = SimpleNamespace(metadata={"colormap_name": "magma"}) + point_controls = PointControls() + qtbot.addWidget(point_controls) + + dock_layer_controls = SimpleNamespace(widget=lambda: SimpleNamespace(widgets={layer: point_controls})) + viewer = SimpleNamespace(window=SimpleNamespace(qt_viewer=SimpleNamespace(dockLayerControls=dock_layer_controls))) + + selector = apply_points_layer_ui_tweaks(viewer, layer, dropdown_cls=DropdownMenu, plt_module=plt) + assert selector is not None + + assert point_controls._face_color_control.face_color_edit.isHidden() + assert point_controls._face_color_control.face_color_label.isHidden() + assert point_controls._border_color_control.border_color_edit.isHidden() + assert point_controls._border_color_control.border_color_edit_label.isHidden() + assert point_controls._out_slice_checkbox_control.out_of_slice_checkbox.isHidden() + assert point_controls._out_slice_checkbox_control.out_of_slice_checkbox_label.isHidden() diff --git a/src/napari_deeplabcut/_tests/compat/test_compat_internal.py b/src/napari_deeplabcut/_tests/compat/test_compat_internal.py new file mode 100644 index 00000000..47ec907b --- /dev/null +++ b/src/napari_deeplabcut/_tests/compat/test_compat_internal.py @@ -0,0 +1,181 @@ +"""Internal unit tests for compat module logic. Does not check actual napari integration.""" + +# src/napari_deeplabcut/_tests/compat/test_compat_internal.py +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np +import pandas as pd +import pytest + +from napari_deeplabcut.napari_compat import ( + apply_points_layer_ui_tweaks, + install_add_wrapper, + install_paste_patch, +) +from napari_deeplabcut.napari_compat.points_layer import make_paste_data + + +def test_apply_points_layer_ui_tweaks_returns_none_when_viewer_shape_is_missing(): + viewer = SimpleNamespace(window=SimpleNamespace(qt_viewer=SimpleNamespace())) + layer = SimpleNamespace(metadata={}) + + result = apply_points_layer_ui_tweaks( + viewer, + layer, + dropdown_cls=object, + plt_module=SimpleNamespace(colormaps=[]), + ) + + assert result is None + + +def test_install_add_wrapper_calls_add_impl_and_schedule_recolor(): + calls = [] + + class Layer: + pass + + layer = Layer() + + def add_impl(*args, **kwargs): + calls.append(("add_impl", args, kwargs)) + return "added" + + def schedule_recolor(obj): + calls.append(("schedule_recolor", obj)) + + install_add_wrapper(layer, add_impl=add_impl, schedule_recolor=schedule_recolor) + + result = layer.add(1, 2, kind="point") + + assert result == "added" + assert calls[0] == ("add_impl", (1, 2), {"kind": "point"}) + assert calls[1] == ("schedule_recolor", layer) + + +def test_install_add_wrapper_swallows_schedule_recolor_errors(): + class Layer: + pass + + layer = Layer() + add_called = [] + + def add_impl(*args, **kwargs): + add_called.append((args, kwargs)) + return 123 + + def schedule_recolor(_layer): + raise RuntimeError("boom") + + install_add_wrapper(layer, add_impl=add_impl, schedule_recolor=schedule_recolor) + + assert layer.add("x") == 123 + assert add_called == [(("x",), {})] + + +def test_install_paste_patch_binds_method_to_layer(): + class Layer: + pass + + layer = Layer() + seen = [] + + def paste_func(this): + seen.append(this) + + install_paste_patch(layer, paste_func=paste_func) + + layer._paste_data() + + assert seen == [layer] + + +def test_make_paste_data_returns_early_when_all_points_are_annotated(monkeypatch): + import sys + import types + from dataclasses import dataclass + + fake_layer_utils = types.ModuleType("napari.layers.utils.layer_utils") + fake_layer_utils._features_to_properties = lambda features: { + col: features[col].to_numpy() for col in features.columns + } + monkeypatch.setitem(sys.modules, "napari.layers.utils.layer_utils", fake_layer_utils) + + @dataclass(frozen=True) + class Keypoint: + label: object + id: object + + recolor_calls = [] + + controls = SimpleNamespace( + np=np, + keypoints=SimpleNamespace(Keypoint=Keypoint), + _schedule_recolor=lambda layer: recolor_calls.append(layer), + ) + + store = SimpleNamespace( + annotated_keypoints={Keypoint("nose", 1)}, + layer=object(), + ) + + layer = SimpleNamespace( + _clipboard={ + "features": pd.DataFrame({"label": ["nose"], "id": [1]}), + "indices": np.array([0, 0]), + "text": None, + }, + data=np.array([[1.0, 2.0]]), + shown=np.array([True]), + size=np.array([1.0]), + symbol=np.array(["o"], dtype=object), + edge_width=np.array([1.0]), + _view_data=np.array([[1.0, 2.0]]), + refresh=lambda: pytest.fail("refresh should not be called"), + ) + + paste = make_paste_data(controls, store=store) + paste(layer) + + assert recolor_calls == [] + # features are popped before the early return; smoke test just checks no crash/no recolor + + +def test_make_paste_data_pastes_only_unannotated_points_and_recolors(paste_env): + paste = make_paste_data(paste_env.controls, store=paste_env.store) + paste(paste_env.layer) + + layer = paste_env.layer + + # original 1 point + only 1 pasted point survives (tail, id=2) + assert layer._data.shape == (2, 3) + np.testing.assert_allclose(layer._data[0], [1.0, 2.0, 3.0]) + + # clipboard indices were [0, 5, 0], current slice idx is [0, 7, 0], so +2 on axis 1 + np.testing.assert_allclose(layer._data[1], [40.0, 52.0, 60.0]) + + np.testing.assert_array_equal(layer._shown, np.array([True, False])) + np.testing.assert_allclose(layer._size, np.array([5.0, 4.0])) + np.testing.assert_array_equal(layer._symbol, np.array(["o", "+"], dtype=object)) + np.testing.assert_allclose(layer._edge_width, np.array([0.5, 2.0])) + + assert len(layer._feature_table.appended) == 1 + appended_features = layer._feature_table.appended[0] + assert list(appended_features["label"]) == ["tail"] + assert list(appended_features["id"]) == [2] + + assert len(layer.text.calls) == 1 + assert layer.text.calls[0]["color"] == "white" + np.testing.assert_array_equal(layer.text.calls[0]["string"], np.array(["tail-2"], dtype=object)) + + assert len(layer._edge.calls) == 1 + assert len(layer._face.calls) == 1 + np.testing.assert_array_equal(layer._edge.calls[0]["colors"], np.array([[0.0, 1.0, 0.0, 1.0]])) + np.testing.assert_array_equal(layer._face.calls[0]["colors"], np.array([[0.0, 0.5, 0.0, 1.0]])) + + assert layer._selected_view == [1] + assert layer._selected_data == {1} + assert layer.refresh_count == 1 + assert paste_env.recolor_calls == [paste_env.store.layer] diff --git a/src/napari_deeplabcut/_tests/config/test_keybinds.py b/src/napari_deeplabcut/_tests/config/test_keybinds.py new file mode 100644 index 00000000..4e6c0b88 --- /dev/null +++ b/src/napari_deeplabcut/_tests/config/test_keybinds.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np + +import napari_deeplabcut.config.keybinds as keybinds + + +class DummyLayer: + def __init__(self): + self.bound = [] + + def bind_key(self, key, callback, overwrite=False): + self.bound.append( + { + "key": key, + "callback": callback, + "overwrite": overwrite, + } + ) + + +def test_iter_shortcuts_returns_registry(): + shortcuts = tuple(keybinds.iter_shortcuts()) + assert shortcuts == keybinds.SHORTCUTS + assert shortcuts, "SHORTCUTS should not be empty" + + +def test_shortcuts_registry_points_layer_entries_have_callbacks(): + for spec in keybinds.SHORTCUTS: + assert spec.keys + assert spec.description + assert spec.group + assert spec.scope in {"points-layer", "global-points"} + + if spec.scope == "points-layer": + assert spec.get_callback is not None + assert spec.action is not None + + +def test_shortcuts_registry_has_no_duplicate_keys_within_scope(): + seen = set() + + for spec in keybinds.SHORTCUTS: + for key in spec.keys: + item = (spec.scope, key) + assert item not in seen, f"Duplicate shortcut declared for scope/key: {item}" + seen.add(item) + + +def test_bind_each_key_binds_all_keys(): + layer = DummyLayer() + + def callback(): + return None + + keybinds._bind_each_key(layer, ("A", "B"), callback, overwrite=True) + + assert layer.bound == [ + {"key": "A", "callback": callback, "overwrite": True}, + {"key": "B", "callback": callback, "overwrite": True}, + ] + + +def test_install_points_layer_keybindings_binds_registry_declared_shortcuts(): + layer = DummyLayer() + + controls = SimpleNamespace( + cycle_through_label_modes=object(), + cycle_through_color_modes=object(), + ) + store = SimpleNamespace( + next_keypoint=object(), + prev_keypoint=object(), + _find_first_unlabeled_frame=object(), + ) + + keybinds.install_points_layer_keybindings(layer, controls, store) + + expected = [] + ctx = keybinds.BindingContext(controls=controls, store=store) + + for spec in keybinds.SHORTCUTS: + if spec.scope != "points-layer": + continue + callback = spec.get_callback(ctx) + for key in spec.keys: + expected.append( + { + "key": key, + "callback": callback, + "overwrite": spec.overwrite, + } + ) + + assert layer.bound == expected + + +def test_callback_resolvers_return_expected_methods(): + controls = SimpleNamespace( + cycle_through_label_modes=object(), + cycle_through_color_modes=object(), + ) + store = SimpleNamespace( + next_keypoint=object(), + prev_keypoint=object(), + _find_first_unlabeled_frame=object(), + ) + ctx = keybinds.BindingContext(controls=controls, store=store) + + assert keybinds._cycle_label_mode(ctx) is controls.cycle_through_label_modes + assert keybinds._cycle_color_mode(ctx) is controls.cycle_through_color_modes + assert keybinds._next_keypoint(ctx) is store.next_keypoint + assert keybinds._prev_keypoint(ctx) is store.prev_keypoint + assert keybinds._jump_unlabeled_frame(ctx) is store._find_first_unlabeled_frame + + +def test_toggle_edge_color_toggles_between_0_and_2(): + layer = SimpleNamespace(border_width=np.array([0, 2, 0, 2])) + + keybinds.toggle_edge_color(layer) + np.testing.assert_array_equal(layer.border_width, np.array([2, 0, 2, 0])) + + keybinds.toggle_edge_color(layer) + np.testing.assert_array_equal(layer.border_width, np.array([0, 2, 0, 2])) + + +def test_install_global_points_keybindings_installs_once(monkeypatch): + calls = [] + + class DummyPoints: + @staticmethod + def bind_key(key): + calls.append(("bind_key", key)) + + def decorator(callback): + calls.append(("decorated", key, callback)) + return callback + + return decorator + + monkeypatch.setattr(keybinds, "Points", DummyPoints) + monkeypatch.setattr(keybinds, "_global_points_bindings_installed", False) + + keybinds.install_global_points_keybindings() + first_calls = list(calls) + + keybinds.install_global_points_keybindings() + second_calls = list(calls) + + assert first_calls, "Expected at least one global binding registration" + assert second_calls == first_calls, "Second install should be a no-op" + + +def test_global_shortcuts_are_not_missing_install_support(): + """ + Help ensure all global shortcuts have corresponding installation code + by asserting that all global actions are accounted for. + """ + global_actions = {spec.action for spec in keybinds.SHORTCUTS if spec.scope == "global-points"} + assert global_actions == {keybinds.ShortcutAction.TOGGLE_EDGE_COLOR} diff --git a/src/napari_deeplabcut/_tests/conftest.py b/src/napari_deeplabcut/_tests/conftest.py index a5485251..0b149102 100644 --- a/src/napari_deeplabcut/_tests/conftest.py +++ b/src/napari_deeplabcut/_tests/conftest.py @@ -1,5 +1,6 @@ # src/napari_deeplabcut/_tests/conftest.py import json +import logging import os from pathlib import Path @@ -7,11 +8,13 @@ import numpy as np import pandas as pd import pytest -from PIL import Image -from qtpy.QtWidgets import QDockWidget +from qtpy.QtWidgets import QApplication, QDockWidget from skimage.io import imsave -from napari_deeplabcut import _writer, keypoints +from napari_deeplabcut.config.models import DLCHeaderModel +from napari_deeplabcut.config.settings import set_auto_open_keypoint_controls +from napari_deeplabcut.core import io as io +from napari_deeplabcut.core import keypoints # os.environ["NAPARI_DLC_HIDE_TUTORIAL"] = "True" # no longer on by default @@ -20,6 +23,89 @@ # os.environ["QT_QPA_PLATFORM"] = "offscreen" # headless QT for CI # os.environ["QT_OPENGL"] = "software" # avoid some CI issues with OpenGL # os.environ["PYTEST_QT_API"] = "pyqt6" # only for local testing with pyqt6, we use pyside6 otherwise +logging.getLogger("napari_deeplabcut").propagate = True +logging.getLogger("napari-deeplabcut").propagate = True + + +def force_show(widget, qtbot, *, process_ms: int = 50): + """ + Best-effort show of a widget and all its Qt parents, even under + QT_QPA_PLATFORM=offscreen. + + This does NOT require real screen exposure; it only ensures the widget + is no longer hidden and that Qt processes the resulting show events. + """ + chain = [] + w = widget + while w is not None: + chain.append(w) + w = w.parentWidget() + + # Show from top-most parent down so child visibility can resolve. + for w in reversed(chain): + try: + w.show() + except RuntimeError: + pass + + QApplication.processEvents() + qtbot.wait(process_ms) + QApplication.processEvents() + + return widget + + +@pytest.fixture(autouse=True) +def disable_auto_open_keypoint_controls(): + """Disable auto-opening of keypoint controls in tests by default.""" + original_value = set_auto_open_keypoint_controls(False) + yield + set_auto_open_keypoint_controls(original_value) + + +@pytest.fixture(autouse=True) +def only_deeplabcut_debug_logs(): + """ + Show DEBUG logs only for napari-deeplabcut. + Suppress DEBUG from all other libraries. + """ + logging.getLogger() + + # Store original levels + original_levels = {} + + try: + for name, logger in logging.root.manager.loggerDict.items(): + if not isinstance(logger, logging.Logger): + continue + + original_levels[name] = logger.level + + if not (name.startswith("napari_deeplabcut") or name.startswith("napari-deeplabcut")): + logger.setLevel(logging.INFO) + + # Ensure our plugin is verbose + logging.getLogger("napari_deeplabcut").setLevel(logging.DEBUG) + + yield + finally: + # Restore original logger levels + for name, level in original_levels.items(): + logger = logging.getLogger(name) + logger.setLevel(level) + + +def make_real_header(bodyparts=("bodypart1", "bodypart2"), individuals=("",), scorer="S"): + cols = pd.MultiIndex.from_product( + [[scorer], list(individuals), list(bodyparts), ["x", "y"]], + names=["scorer", "individuals", "bodyparts", "coords"], + ) + return DLCHeaderModel(columns=cols) + + +@pytest.fixture +def make_real_header_factory(): + return make_real_header @pytest.fixture @@ -52,6 +138,27 @@ def viewer(make_napari_viewer_proxy): pass +@pytest.fixture +def keypoint_controls_and_dock(viewer): + dock, controls = viewer.window.add_plugin_dock_widget( + "napari-deeplabcut", + "Keypoint controls", + ) + return controls, dock + + +@pytest.fixture +def keypoint_controls(keypoint_controls_and_dock): + controls, _dock = keypoint_controls_and_dock + return controls + + +@pytest.fixture +def keypoint_controls_dock(keypoint_controls_and_dock): + _controls, dock = keypoint_controls_and_dock + return dock + + @pytest.fixture def fake_keypoints(): n_rows = 10 @@ -133,7 +240,7 @@ def config_path(tmp_path_factory): }, } path = str(tmp_path_factory.mktemp("configs") / "config.yaml") - _writer._write_config( + io.write_config( path, params=cfg, ) @@ -158,75 +265,60 @@ def video_path(tmp_path_factory): @pytest.fixture -def superkeypoints_assets(tmp_path, monkeypatch): - """ - Create a fake module dir with the expected assets layout: +def superkeypoints_assets(): + super_animal = "superanimal_quadruped" + json_path = Path(__file__).resolve().parents[1] / "assets" / f"{super_animal}.json" + data = json.loads(json_path.read_text(encoding="utf-8")) + return {"data": data, "super_animal": super_animal} - module_dir/_reader_fake.py -> patched as __file__ - module_dir/assets/fake.json - module_dir/assets/fake.jpg - This mirrors the code under test: - Path(__file__).parent / "assets" / f"{super_animal}.json|.jpg" - """ - module_dir = tmp_path / "module" - assets_dir = module_dir / "assets" - assets_dir.mkdir(parents=True) - - super_animal = "fake" - data = { - "SK1": [10.0, 20.0], - "SK2": [40.0, 60.0], - } - - # JSON with superkeypoints coordinates - (assets_dir / f"{super_animal}.json").write_text(json.dumps(data)) - - # Small 10x10 RGB diagram - Image.new("RGB", (10, 10), "white").save(assets_dir / f"{super_animal}.jpg") - - # Patch the module's __file__ so that Path(__file__).parent == module_dir - fake_module_file = module_dir / "_reader_fake.py" - fake_module_file.write_text("# fake") - monkeypatch.setattr("napari_deeplabcut._reader.__file__", str(fake_module_file)) +@pytest.fixture +def single_animal_project(tmp_path: Path): + project = tmp_path / "project_single" + project.mkdir(parents=True, exist_ok=True) - return { - "module_dir": module_dir, - "assets_dir": assets_dir, - "super_animal": super_animal, - "data": data, + cfg = { + "scorer": "John", + "dotsize": 8, + "pcutoff": 0.6, + "colormap": "viridis", + "bodyparts": ["cfg1", "cfg2"], + "video_sets": {}, } + config_path = project / "config.yaml" + io.write_config(config_path, cfg) + return project, config_path + @pytest.fixture -def mapped_points(points, superkeypoints_assets, config_path): - """ - Return a DLC Points layer that is ready for _map_keypoints(): - - metadata['project'] is set (so the widget can write config.yaml) - - metadata['tables'] contains a mapping for two real bodyparts -> SK1/SK2 - - at least two rows have coordinates exactly on the SK1/SK2 positions - and their labels are set to those bodyparts, guaranteeing a neighbor match. +def multianimal_config_project(tmp_path: Path): """ - layer = points # DLC layer created via viewer.open(..., plugin="napari-deeplabcut") - super_animal = superkeypoints_assets["super_animal"] - superkpts = superkeypoints_assets["data"] + Minimal DLC-style multi-animal project config used for E2E tests. - # Required by _map_keypoints to locate and write config.yaml - # NOTE: This relies on config_path pointing to a file directly under the - # project directory, so that Path(config_path).parent is the project root. - layer.metadata["project"] = str(Path(config_path).parent) - header = layer.metadata["header"] - bp1, bp2 = header.bodyparts[:2] - - # Inject a conversion table into metadata - layer.metadata["tables"] = {super_animal: {bp1: "SK1", bp2: "SK2"}} + Returns: + project_dir, config_path + """ + import yaml - # Ensure _map_keypoints finds matches: - # Put the first two rows exactly on SK1/SK2 and set their labels accordingly. - layer.data[0, 1:] = np.array(superkpts["SK1"], dtype=float) - layer.properties["label"][0] = bp1 + project = tmp_path / "project_multi" + project.mkdir(parents=True, exist_ok=True) - layer.data[1, 1:] = np.array(superkpts["SK2"], dtype=float) - layer.properties["label"][1] = bp2 + cfg = { + "scorer": "John", + "dotsize": 8, + "pcutoff": 0.6, + "colormap": "viridis", + # DLC multi-animal flags/fields expected by napari_deeplabcut.config.models.DLCHeaderModel.from_config + "multianimalproject": True, + "individuals": ["animal1", "animal2"], + "multianimalbodyparts": ["bodypart1", "bodypart2"], + # Often present in DLC configs; safe to include even if empty + "uniquebodyparts": [], + # Keep bodyparts too (some parts of the plugin still read it) + "bodyparts": ["bodypart1", "bodypart2"], + } - return layer, super_animal, bp1, bp2 + config_path = project / "config.yaml" + config_path.write_text(yaml.safe_dump(cfg), encoding="utf-8") + return project, config_path diff --git a/src/napari_deeplabcut/_tests/core/__init__.py b/src/napari_deeplabcut/_tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/_tests/core/io/__init__.py b/src/napari_deeplabcut/_tests/core/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/_tests/core/io/test_form_df.py b/src/napari_deeplabcut/_tests/core/io/test_form_df.py new file mode 100644 index 00000000..f9552f48 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/io/test_form_df.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from napari_deeplabcut.core.io import form_df + + +def test_form_df_autofills_likelihood_when_missing(): + points = np.array([[0.0, 44.0, 55.0]]) # [frame,y,x] + layer_metadata = { + "header": {"columns": [("S", "", "bp1", "x"), ("S", "", "bp1", "y")]}, + "paths": ["labeled-data/test/img000.png"], + } + layer_properties = {"label": ["bp1"], "id": [""]} # no likelihood + + df = form_df(points, layer_metadata, layer_properties) + # Should contain finite coords + assert np.isfinite(df.to_numpy()).any() + + +def test_form_df_rejects_properties_length_mismatch(): + points = np.array([[0.0, 44.0, 55.0]]) + layer_metadata = {"header": {"columns": [("S", "", "bp1", "x"), ("S", "", "bp1", "y")]}} + layer_properties = {"label": ["bp1", "bp2"], "id": [""]} # label length mismatch + + with pytest.raises(ValueError): + form_df(points, layer_metadata, layer_properties) diff --git a/src/napari_deeplabcut/_tests/core/io/test_hdf_reader.py b/src/napari_deeplabcut/_tests/core/io/test_hdf_reader.py new file mode 100644 index 00000000..8aacb79b --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/io/test_hdf_reader.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd + +from napari_deeplabcut.config.models import AnnotationKind +from napari_deeplabcut.core.io import read_hdf_single + +# ----------------------------- +# Helpers to write minimal DLC-like H5 files +# ----------------------------- + + +def _write_h5_single_animal( + path: Path, + *, + scorer: str = "John", + bodyparts=("bp1", "bp2"), + values=None, # [bp1x,bp1y,bp2x,bp2y] for one image + index=("img000.png",), +): + if values is None: + values = [np.nan, np.nan, np.nan, np.nan] + cols = pd.MultiIndex.from_product( + [[scorer], list(bodyparts), ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + df = pd.DataFrame([values], index=list(index), columns=cols) + path.parent.mkdir(parents=True, exist_ok=True) + df.to_hdf(path, key="keypoints", mode="w") + return df + + +def _assert_layerdata_invariants(layerdata): + data, meta, layer_type = layerdata + assert layer_type == "points" + arr = np.asarray(data) + assert arr.ndim == 2 and arr.shape[1] == 3, "Points data must be (N,3): [frame,y,x] for napari" + n = arr.shape[0] + + # Napari expects per-point properties lengths to match n + props = meta.get("properties") or {} + assert isinstance(props, dict) + for k, v in props.items(): + assert len(v) == n, f"property '{k}' length {len(v)} != N points {n}" + + +# ----------------------------- +# Tests +# ----------------------------- + + +def test_read_hdf_single_all_nan_machine_is_empty_points_and_properties(tmp_path: Path): + """ + Contract: empty Points layers are valid, but must have empty properties too. + This test directly prevents the napari ValueError you saw in E2Es. + """ + h5 = tmp_path / "machinelabels-iter0.h5" + _write_h5_single_animal(h5, scorer="machine", values=[np.nan, np.nan, np.nan, np.nan]) + + layers = read_hdf_single(h5, kind=AnnotationKind.MACHINE) + assert len(layers) == 1 + data, meta, _ = layers[0] + + assert np.asarray(data).shape[0] == 0 + _assert_layerdata_invariants(layers[0]) + + +def test_read_hdf_single_one_finite_point_produces_one_point_and_properties(tmp_path: Path): + h5 = tmp_path / "CollectedData_John.h5" + _write_h5_single_animal(h5, scorer="John", values=[10.0, 20.0, np.nan, np.nan]) + + layers = read_hdf_single(h5, kind=AnnotationKind.GT) + data, meta, _ = layers[0] + + assert np.asarray(data).shape[0] == 1 + _assert_layerdata_invariants(layers[0]) + + # Ensure the point corresponds to bp1 (since bp2 is NaN) + assert meta["properties"]["label"] == ["bp1"] + + +def test_read_hdf_single_filters_data_and_properties_consistently(tmp_path: Path): + """ + Regression guard: if finite mask filters df, it must also filter data. + This is the exact failure mode that triggers napari's 'length mismatch' error. + """ + h5 = tmp_path / "CollectedData_John.h5" + _write_h5_single_animal(h5, scorer="John", values=[10.0, 20.0, np.nan, np.nan]) + + layers = read_hdf_single(h5) + data, meta, _ = layers[0] + + n = np.asarray(data).shape[0] + assert len(meta["properties"]["label"]) == n + assert len(meta["properties"]["id"]) == n + assert len(meta["properties"]["likelihood"]) == n + + +def test_read_hdf_single_accepts_3level_header_and_inserts_individuals(tmp_path: Path): + """ + Reader must accept classic 3-level single-animal DLC tables and normalize internally. + """ + h5 = tmp_path / "CollectedData_John.h5" + _write_h5_single_animal(h5, scorer="John", values=[10.0, 20.0, np.nan, np.nan]) + + layers = read_hdf_single(h5) + _, meta, _ = layers[0] + + # 'id' must exist and match data length; for single animal, values are expected to be empty string + assert "id" in meta["properties"] + assert meta["properties"]["id"] == [""] + + +def test_read_hdf_single_metadata_contains_root_and_name(tmp_path: Path): + """ + Reader contract used by multiple E2E tests: metadata.root/name should be present. + """ + h5 = tmp_path / "CollectedData_John.h5" + _write_h5_single_animal(h5, scorer="John", values=[10.0, 20.0, np.nan, np.nan]) + + layers = read_hdf_single(h5) + _, meta, _ = layers[0] + assert meta["name"] == "CollectedData_John" + assert meta["metadata"]["root"] == str(h5.parent) + assert meta["metadata"]["name"] == "CollectedData_John" diff --git a/src/napari_deeplabcut/_tests/core/io/test_write_routing.py b/src/napari_deeplabcut/_tests/core/io/test_write_routing.py new file mode 100644 index 00000000..60602d12 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/io/test_write_routing.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from napari_deeplabcut.config.models import AnnotationKind +from napari_deeplabcut.core.errors import AmbiguousSaveError, MissingProvenanceError +from napari_deeplabcut.core.io import resolve_output_path_from_metadata, write_hdf + + +def test_resolve_output_path_returns_none_for_machine_without_save_target(): + md = { + "metadata": { + "io": { + "schema_version": 1, + "project_root": str(Path.cwd()), + "source_relpath_posix": "machinelabels-iter0.h5", + "kind": AnnotationKind.MACHINE, + "dataset_key": "keypoints", + } + } + } + out_path, scorer, kind = resolve_output_path_from_metadata(md) + assert out_path is None + assert scorer is None + assert kind == AnnotationKind.MACHINE + + +def test_write_hdf_refuses_machine_without_promotion(tmp_path: Path): + # minimal points + metadata for a machine source + data = np.zeros((1, 3), dtype=float) + attrs = { + "metadata": { + "root": str(tmp_path), + "io": { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": "machinelabels-iter0.h5", + "kind": AnnotationKind.MACHINE, + "dataset_key": "keypoints", + }, + # header is required by writer + "header": { + "columns": [("S", "", "bp1", "x"), ("S", "", "bp1", "y")], + }, + }, + "properties": {"label": ["bp1"], "id": [""], "likelihood": [1.0]}, + } + + with pytest.raises(MissingProvenanceError): + write_hdf("__dlc__.h5", data, attrs) + + +def test_write_hdf_raises_ambiguous_when_multiple_gt_candidates_and_no_provenance(tmp_path: Path): + # Create two GT candidates in root folder + (tmp_path / "CollectedData_John.h5").write_bytes(b"dummy") + (tmp_path / "CollectedData_Jane.h5").write_bytes(b"dummy") + + data = np.zeros((1, 3), dtype=float) + attrs = { + "metadata": { + "root": str(tmp_path), + "header": { + "columns": [("S", "", "bp1", "x"), ("S", "", "bp1", "y")], + }, + # No io/source_h5 => triggers GT fallback scan => ambiguous + }, + "properties": {"label": ["bp1"], "id": [""], "likelihood": [1.0]}, + } + + with pytest.raises(AmbiguousSaveError): + write_hdf("__dlc__.h5", data, attrs) + + +def test_write_hdf_aborts_machine_without_promotion_target(tmp_path: Path): + data = np.array([[0.0, 44.0, 55.0]], dtype=float) + attrs = { + "metadata": { + "root": str(tmp_path), + "io": { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": "machinelabels-iter0.h5", + "kind": AnnotationKind.MACHINE, + "dataset_key": "keypoints", + }, + "header": {"columns": [("S", "", "bp1", "x"), ("S", "", "bp1", "y")]}, + }, + "properties": {"label": ["bp1"], "id": [""], "likelihood": [1.0]}, + } + + with pytest.raises(MissingProvenanceError): + write_hdf("__dlc__.h5", data, attrs) diff --git a/src/napari_deeplabcut/_tests/core/test_config_sync.py b/src/napari_deeplabcut/_tests/core/test_config_sync.py new file mode 100644 index 00000000..28125293 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_config_sync.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import napari_deeplabcut.core.config_sync as cs + + +class DummyLayer: + def __init__(self, *, metadata=None, source_path=None): + self.metadata = metadata or {} + self.source = SimpleNamespace(path=source_path) if source_path is not None else None + + +# ----------------------------------------------------------------------------- +# Small helpers +# ----------------------------------------------------------------------------- + + +def test_coerce_point_size_rounds_and_clamps(): + assert cs._coerce_point_size(12) == 12 + assert cs._coerce_point_size(12.6) == 13 + assert cs._coerce_point_size("7") == 7 + assert cs._coerce_point_size(-5) == 1 + assert cs._coerce_point_size(999) == 100 + assert cs._coerce_point_size("not-a-number") == 6 + + +def test_layer_source_path_returns_string_when_available(): + layer = DummyLayer(source_path="/tmp/some/file.png") + assert cs._layer_source_path(layer) == "/tmp/some/file.png" + + +def test_layer_source_path_returns_none_when_source_missing(): + layer = DummyLayer() + assert cs._layer_source_path(layer) is None + + +def test_layer_source_path_returns_none_when_source_path_access_fails(): + class BadSource: + @property + def path(self): + raise RuntimeError("boom") + + layer = DummyLayer() + layer.source = BadSource() + assert cs._layer_source_path(layer) is None + + +# ----------------------------------------------------------------------------- +# resolve_config_path_from_layer +# ----------------------------------------------------------------------------- + + +def test_resolve_config_prefers_points_meta_inference(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer(metadata={}) + + monkeypatch.setattr( + cs, + "read_points_meta", + lambda *args, **kwargs: SimpleNamespace(project=None, root=None, paths=[]), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project_from_points_meta", + lambda *args, **kwargs: SimpleNamespace(config_path=config_path), + ) + + resolved = cs.resolve_config_path_from_layer(layer) + + assert resolved == config_path + + +def test_resolve_config_uses_image_layer_inference_when_points_meta_has_no_config(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer(metadata={}) + image_layer = DummyLayer(metadata={"root": str(tmp_path)}) + + monkeypatch.setattr( + cs, + "read_points_meta", + lambda *args, **kwargs: SimpleNamespace(project=None, root=None, paths=[]), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project_from_points_meta", + lambda *args, **kwargs: SimpleNamespace(config_path=None), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project_from_image_layer", + lambda *args, **kwargs: SimpleNamespace(config_path=config_path), + ) + + resolved = cs.resolve_config_path_from_layer(layer, image_layer=image_layer) + + assert resolved == config_path + + +def test_resolve_config_uses_generic_fallback_hints(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer(metadata={}) + + monkeypatch.setattr( + cs, + "read_points_meta", + lambda *args, **kwargs: SimpleNamespace(project=None, root=None, paths=[]), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project_from_points_meta", + lambda *args, **kwargs: SimpleNamespace(config_path=None), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project", + lambda *args, **kwargs: SimpleNamespace(config_path=config_path), + ) + + resolved = cs.resolve_config_path_from_layer(layer, fallback_project=str(tmp_path)) + + assert resolved == config_path + + +def test_resolve_config_uses_find_nearest_config_as_last_resort(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer(metadata={"root": str(tmp_path)}) + + monkeypatch.setattr(cs, "read_points_meta", lambda *args, **kwargs: None) + monkeypatch.setattr(cs, "infer_dlc_project", lambda *args, **kwargs: SimpleNamespace(config_path=None)) + monkeypatch.setattr(cs, "find_nearest_config", lambda *args, **kwargs: config_path) + + resolved = cs.resolve_config_path_from_layer(layer) + + assert resolved == config_path + + +def test_resolve_config_returns_none_when_everything_fails(monkeypatch): + layer = DummyLayer(metadata={}) + + monkeypatch.setattr(cs, "read_points_meta", lambda *args, **kwargs: None) + monkeypatch.setattr(cs, "infer_dlc_project", lambda *args, **kwargs: SimpleNamespace(config_path=None)) + monkeypatch.setattr(cs, "find_nearest_config", lambda *args, **kwargs: None) + + resolved = cs.resolve_config_path_from_layer(layer) + + assert resolved is None + + +def test_resolve_config_ignores_points_meta_when_read_points_meta_raises(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer(metadata={"root": str(tmp_path)}) + + monkeypatch.setattr(cs, "read_points_meta", lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("boom"))) + monkeypatch.setattr(cs, "infer_dlc_project", lambda *args, **kwargs: SimpleNamespace(config_path=config_path)) + + resolved = cs.resolve_config_path_from_layer(layer) + + assert resolved == config_path + + +def test_resolve_config_skips_points_meta_when_errors_attribute_present(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer(metadata={"root": str(tmp_path)}) + + monkeypatch.setattr(cs, "read_points_meta", lambda *args, **kwargs: SimpleNamespace(errors=["bad meta"])) + monkeypatch.setattr(cs, "infer_dlc_project", lambda *args, **kwargs: SimpleNamespace(config_path=config_path)) + + resolved = cs.resolve_config_path_from_layer(layer) + + assert resolved == config_path + + +def test_resolve_config_skips_non_file_points_meta_config_and_falls_through(monkeypatch, tmp_path): + missing_config = tmp_path / "missing_config.yaml" + real_config = tmp_path / "config.yaml" + real_config.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer(metadata={"root": str(tmp_path)}) + + monkeypatch.setattr( + cs, + "read_points_meta", + lambda *args, **kwargs: SimpleNamespace(project=None, root=None, paths=[]), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project_from_points_meta", + lambda *args, **kwargs: SimpleNamespace(config_path=missing_config), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project", + lambda *args, **kwargs: SimpleNamespace(config_path=real_config), + ) + + resolved = cs.resolve_config_path_from_layer(layer) + + assert resolved == real_config + + +def test_resolve_config_passes_paths_into_generic_inference(monkeypatch, tmp_path): + captured = {} + + layer = DummyLayer( + metadata={ + "project": str(tmp_path / "proj"), + "root": str(tmp_path / "root"), + "paths": [ + "labeled-data/session_001/img001.png", + "labeled-data/session_001/img002.png", + "labeled-data/session_001/img003.png", + "labeled-data/session_001/img004.png", + ], + }, + source_path=str(tmp_path / "video.mp4"), + ) + + monkeypatch.setattr(cs, "read_points_meta", lambda *args, **kwargs: None) + + def fake_infer_dlc_project(**kwargs): + captured.update(kwargs) + return SimpleNamespace(config_path=None) + + monkeypatch.setattr(cs, "infer_dlc_project", fake_infer_dlc_project) + monkeypatch.setattr(cs, "find_nearest_config", lambda *args, **kwargs: None) + + resolved = cs.resolve_config_path_from_layer( + layer, + fallback_project=str(tmp_path / "fallback_project"), + fallback_root=str(tmp_path / "fallback_root"), + prefer_project_root=False, + max_levels=7, + ) + + assert resolved is None + assert captured["dataset_candidates"] == ["labeled-data/session_001/img001.png"] + assert captured["anchor_candidates"] == [ + str(tmp_path / "proj"), + str(tmp_path / "root"), + str(tmp_path / "video.mp4"), + str(tmp_path / "fallback_project"), + str(tmp_path / "fallback_root"), + "labeled-data/session_001/img001.png", + "labeled-data/session_001/img002.png", + "labeled-data/session_001/img003.png", + ] + assert captured["prefer_project_root"] is False + assert captured["max_levels"] == 7 + + +def test_resolve_config_uses_image_inference_after_points_inference_exception(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer(metadata={}) + image_layer = DummyLayer(metadata={"root": str(tmp_path)}) + + monkeypatch.setattr( + cs, + "read_points_meta", + lambda *args, **kwargs: SimpleNamespace(project=None, root=None, paths=[]), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project_from_points_meta", + lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("boom")), + ) + monkeypatch.setattr( + cs, + "infer_dlc_project_from_image_layer", + lambda *args, **kwargs: SimpleNamespace(config_path=config_path), + ) + + resolved = cs.resolve_config_path_from_layer(layer, image_layer=image_layer) + + assert resolved == config_path + + +def test_resolve_config_continues_when_find_nearest_config_raises_for_one_candidate(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("dotsize: 6\n", encoding="utf-8") + + layer = DummyLayer( + metadata={ + "project": "bad-candidate", + "root": str(tmp_path), + } + ) + + monkeypatch.setattr(cs, "read_points_meta", lambda *args, **kwargs: None) + monkeypatch.setattr(cs, "infer_dlc_project", lambda *args, **kwargs: SimpleNamespace(config_path=None)) + + def fake_find(candidate, **kwargs): + if candidate == "bad-candidate": + raise RuntimeError("boom") + return config_path + + monkeypatch.setattr(cs, "find_nearest_config", fake_find) + + resolved = cs.resolve_config_path_from_layer(layer) + + assert resolved == config_path + + +# ----------------------------------------------------------------------------- +# load_point_size_from_config +# ----------------------------------------------------------------------------- + + +def test_load_point_size_from_config_returns_none_for_missing_path(): + assert cs.load_point_size_from_config(None) is None + + +def test_load_point_size_from_config_returns_none_when_load_fails(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("boom"))) + + assert cs.load_point_size_from_config(config_path) is None + + +def test_load_point_size_from_config_returns_none_when_key_missing(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: {"colormap": "rainbow"}) + + assert cs.load_point_size_from_config(config_path) is None + + +def test_load_point_size_from_config_coerces_and_clamps(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: {"dotsize": "250"}) + + assert cs.load_point_size_from_config(config_path) == 100 + + +# ----------------------------------------------------------------------------- +# save_point_size_to_config +# ----------------------------------------------------------------------------- + + +def test_save_point_size_to_config_returns_false_when_path_missing(): + assert cs.save_point_size_to_config(None, 12) is False + + +def test_save_point_size_to_config_returns_false_when_load_fails(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("boom"))) + + assert cs.save_point_size_to_config(config_path, 12) is False + + +def test_save_point_size_to_config_returns_false_when_value_unchanged(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + written = [] + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: {"dotsize": 12}) + monkeypatch.setattr(cs.io, "write_config", lambda *args, **kwargs: written.append(True)) + + assert cs.save_point_size_to_config(config_path, 12) is False + assert written == [] + + +def test_save_point_size_to_config_writes_updated_value(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + written = {} + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: {"dotsize": 6, "colormap": "rainbow"}) + + def fake_write(path, cfg): + written["path"] = path + written["cfg"] = cfg + + monkeypatch.setattr(cs.io, "write_config", fake_write) + + assert cs.save_point_size_to_config(config_path, 12) is True + assert written["path"] == str(config_path) + assert written["cfg"]["dotsize"] == 12 + assert written["cfg"]["colormap"] == "rainbow" + + +def test_save_point_size_to_config_clamps_before_writing(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + written = {} + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: {}) + + def fake_write(path, cfg): + written["cfg"] = cfg + + monkeypatch.setattr(cs.io, "write_config", fake_write) + + assert cs.save_point_size_to_config(config_path, 999) is True + assert written["cfg"]["dotsize"] == 100 + + +def test_save_point_size_to_config_still_writes_when_old_value_is_not_coercible(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + written = {} + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: {"dotsize": object()}) + + def fake_write(path, cfg): + written["cfg"] = cfg + + monkeypatch.setattr(cs.io, "write_config", fake_write) + + assert cs.save_point_size_to_config(config_path, 15) is True + assert written["cfg"]["dotsize"] == 15 + + +def test_save_point_size_to_config_returns_false_when_write_fails(monkeypatch, tmp_path): + config_path = tmp_path / "config.yaml" + + monkeypatch.setattr(cs.io, "load_config", lambda *args, **kwargs: {"dotsize": 6}) + monkeypatch.setattr(cs.io, "write_config", lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("boom"))) + + assert cs.save_point_size_to_config(config_path, 15) is False diff --git a/src/napari_deeplabcut/_tests/core/test_conflicts.py b/src/napari_deeplabcut/_tests/core/test_conflicts.py new file mode 100644 index 00000000..f2da0b15 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_conflicts.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import pandas as pd +import pytest + +import napari_deeplabcut.core.conflicts as conflicts_mod +import napari_deeplabcut.core.dataframes as dataframes_mod +from napari_deeplabcut.config.models import AnnotationKind, DLCProjectContext +from napari_deeplabcut.core.errors import AmbiguousSaveError, MissingProvenanceError + + +def _make_points_meta( + *, + header_scorer: str | None = "scorerA", + save_target=None, + io_kind=AnnotationKind.GT, + root: str | None = None, +): + header = None if header_scorer is None else SimpleNamespace(scorer=header_scorer) + io = None if io_kind is None else SimpleNamespace(kind=io_kind) + return SimpleNamespace( + header=header, + save_target=save_target, + io=io, + root=root, + ) + + +def _stub_validation_pipeline( + monkeypatch, + *, + pts_meta, + attrs_obj=None, + points_obj=None, + props_obj=None, + ctx_obj=None, + df_new=None, +): + """ + Stub schema validation + parse_points_metadata + form_df_from_validated. + + This keeps tests focused on the routing/conflict logic rather than on + pydantic/schema correctness (which is tested elsewhere). + """ + if attrs_obj is None: + attrs_obj = SimpleNamespace( + metadata={"some": "meta"}, + properties={"label": ["nose"], "id": [1]}, + ) + if points_obj is None: + points_obj = SimpleNamespace(data="POINTS") + if props_obj is None: + props_obj = SimpleNamespace(properties="PROPS") + if ctx_obj is None: + ctx_obj = SimpleNamespace(ctx="CTX") + if df_new is None: + df_new = pd.DataFrame({"dummy": [1]}) + + monkeypatch.setattr( + conflicts_mod.dlc_schemas.PointsLayerAttributesModel, + "model_validate", + lambda payload: attrs_obj, + ) + monkeypatch.setattr( + conflicts_mod, + "parse_points_metadata", + lambda metadata, drop_header=False: pts_meta, + ) + monkeypatch.setattr( + conflicts_mod.dlc_schemas.PointsDataModel, + "model_validate", + lambda payload: points_obj, + ) + monkeypatch.setattr( + conflicts_mod.dlc_schemas.KeypointPropertiesModel, + "model_validate", + lambda payload: props_obj, + ) + monkeypatch.setattr( + conflicts_mod.dlc_schemas.PointsWriteInputModel, + "model_validate", + lambda payload: ctx_obj, + ) + monkeypatch.setattr( + dataframes_mod, + "form_df_from_validated", + lambda ctx: df_new, + ) + + return SimpleNamespace( + attrs_obj=attrs_obj, + points_obj=points_obj, + props_obj=props_obj, + ctx_obj=ctx_obj, + df_new=df_new, + ) + + +def test_compute_overwrite_report_raises_when_header_missing(monkeypatch): + pts_meta = _make_points_meta(header_scorer=None) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta) + + with pytest.raises(ValueError, match="valid DLC header"): + conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "points"}, + ) + + +def test_compute_overwrite_report_raises_for_machine_source_without_resolved_target(monkeypatch): + pts_meta = _make_points_meta(io_kind=AnnotationKind.MACHINE) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta) + + monkeypatch.setattr( + conflicts_mod, + "resolve_output_path_from_metadata", + lambda attributes: (None, None, AnnotationKind.MACHINE), + ) + + with pytest.raises(MissingProvenanceError, match="MACHINE source"): + conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "points"}, + ) + + +def test_compute_overwrite_report_raises_when_gt_fallback_is_ambiguous(monkeypatch, tmp_path): + pts_meta = _make_points_meta(io_kind=AnnotationKind.GT) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta) + + monkeypatch.setattr( + conflicts_mod, + "resolve_output_path_from_metadata", + lambda attributes: (None, None, AnnotationKind.GT), + ) + monkeypatch.setattr( + conflicts_mod, + "infer_dlc_project_from_points_meta", + lambda *args, **kwargs: DLCProjectContext( + root_anchor=tmp_path, + dataset_folder=tmp_path, + ), + ) + + (tmp_path / "CollectedData_a.h5").touch() + (tmp_path / "CollectedData_b.h5").touch() + + with pytest.raises(AmbiguousSaveError) as excinfo: + conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "points"}, + ) + + err = excinfo.value + assert "Multiple CollectedData_*.h5 files found" in str(err) + assert sorted(Path(p).name for p in err.candidates) == [ + "CollectedData_a.h5", + "CollectedData_b.h5", + ] + + +def test_compute_overwrite_report_returns_none_for_non_gt_destination(monkeypatch, tmp_path): + out = tmp_path / "machine_output.h5" + + pts_meta = _make_points_meta( + save_target=None, + io_kind=AnnotationKind.MACHINE, + ) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta) + + monkeypatch.setattr( + conflicts_mod, + "resolve_output_path_from_metadata", + lambda attributes: (str(out), None, AnnotationKind.GT), + ) + + result = conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "points"}, + ) + + assert result is None + + +def test_compute_overwrite_report_returns_none_when_output_does_not_exist(monkeypatch, tmp_path): + out = tmp_path / "CollectedData_scorerA.h5" + + pts_meta = _make_points_meta(io_kind=AnnotationKind.GT) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta) + + monkeypatch.setattr( + conflicts_mod, + "resolve_output_path_from_metadata", + lambda attributes: (str(out), None, AnnotationKind.GT), + ) + + result = conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "points"}, + ) + + assert result is None + + +def test_compute_overwrite_report_returns_report_for_existing_gt_file_with_conflicts(monkeypatch, tmp_path): + out = tmp_path / "CollectedData_target.h5" + out.touch() + + old_df = pd.DataFrame({"old": [1]}) + raw_new_df = pd.DataFrame({"new": [1]}) + promoted_df = pd.DataFrame({"promoted": [1]}) + key_conflict = object() + report = SimpleNamespace(has_conflicts=True, marker="REPORT") + + pts_meta = _make_points_meta(io_kind=AnnotationKind.GT) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta, df_new=raw_new_df) + + monkeypatch.setattr( + conflicts_mod, + "resolve_output_path_from_metadata", + lambda attributes: (str(out), "target_scorer", AnnotationKind.GT), + ) + + seen = {} + + def fake_set_df_scorer(df, scorer): + seen["set_df_scorer"] = (df, scorer) + return promoted_df + + def fake_read_hdf(path, key=None): + seen.setdefault("read_hdf_calls", []).append((Path(path), key)) + return old_df + + def fake_keypoint_conflicts(df_old, df_new): + seen["keypoint_conflicts"] = (df_old, df_new) + return key_conflict + + def fake_build_report(conflicts, *, layer_name, destination_path): + seen["build_report"] = (conflicts, layer_name, destination_path) + return report + + monkeypatch.setattr(conflicts_mod, "set_df_scorer", fake_set_df_scorer) + monkeypatch.setattr(pd, "read_hdf", fake_read_hdf) + monkeypatch.setattr(dataframes_mod, "keypoint_conflicts", fake_keypoint_conflicts) + monkeypatch.setattr(dataframes_mod, "build_overwrite_conflict_report", fake_build_report) + + result = conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "my-points-layer"}, + ) + + assert result is report + assert seen["set_df_scorer"] == (raw_new_df, "target_scorer") + assert seen["read_hdf_calls"] == [(out, "keypoints")] + assert seen["keypoint_conflicts"] == (old_df, promoted_df) + assert seen["build_report"] == ( + key_conflict, + "my-points-layer", + str(out), + ) + + +def test_compute_overwrite_report_returns_none_when_report_has_no_conflicts(monkeypatch, tmp_path): + out = tmp_path / "CollectedData_target.h5" + out.touch() + + old_df = pd.DataFrame({"old": [1]}) + new_df = pd.DataFrame({"new": [1]}) + report = SimpleNamespace(has_conflicts=False) + + pts_meta = _make_points_meta(io_kind=AnnotationKind.GT) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta, df_new=new_df) + + monkeypatch.setattr( + conflicts_mod, + "resolve_output_path_from_metadata", + lambda attributes: (str(out), None, AnnotationKind.GT), + ) + monkeypatch.setattr(pd, "read_hdf", lambda path, key=None: old_df) + monkeypatch.setattr(dataframes_mod, "keypoint_conflicts", lambda df_old, df_new: "conflicts") + monkeypatch.setattr( + dataframes_mod, + "build_overwrite_conflict_report", + lambda conflicts, *, layer_name, destination_path: report, + ) + + result = conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "my-points-layer"}, + ) + + assert result is None + + +def test_compute_overwrite_report_falls_back_when_keyed_hdf_read_fails(monkeypatch, tmp_path): + out = tmp_path / "CollectedData_target.h5" + out.touch() + + old_df = pd.DataFrame({"old": [1]}) + new_df = pd.DataFrame({"new": [1]}) + report = SimpleNamespace(has_conflicts=True) + + pts_meta = _make_points_meta(io_kind=AnnotationKind.GT) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta, df_new=new_df) + + monkeypatch.setattr( + conflicts_mod, + "resolve_output_path_from_metadata", + lambda attributes: (str(out), None, AnnotationKind.GT), + ) + + calls = [] + + def fake_read_hdf(path, key=None): + calls.append((Path(path), key)) + if key == "keypoints": + raise KeyError("missing key") + return old_df + + monkeypatch.setattr(pd, "read_hdf", fake_read_hdf) + monkeypatch.setattr(dataframes_mod, "keypoint_conflicts", lambda df_old, df_new: "conflicts") + monkeypatch.setattr( + dataframes_mod, + "build_overwrite_conflict_report", + lambda conflicts, *, layer_name, destination_path: report, + ) + + result = conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "my-points-layer"}, + ) + + assert result is report + assert calls == [ + (out, "keypoints"), + (out, None), + ] + + +def test_compute_overwrite_report_raises_when_gt_fallback_has_no_root_and_no_dataset_dir(monkeypatch): + pts_meta = _make_points_meta(io_kind=AnnotationKind.GT, root=None) + _stub_validation_pipeline(monkeypatch, pts_meta=pts_meta) + + monkeypatch.setattr( + conflicts_mod, + "resolve_output_path_from_metadata", + lambda attributes: (None, None, AnnotationKind.GT), + ) + monkeypatch.setattr( + conflicts_mod, + "infer_dlc_project_from_points_meta", + lambda *args, **kwargs: DLCProjectContext(), + ) + + with pytest.raises(MissingProvenanceError, match="GT fallback requires root"): + conflicts_mod.compute_overwrite_report_for_points_save( + data=[[0, 1, 2]], + attributes={"name": "points"}, + ) diff --git a/src/napari_deeplabcut/_tests/core/test_dataframes.py b/src/napari_deeplabcut/_tests/core/test_dataframes.py new file mode 100644 index 00000000..ca37d768 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_dataframes.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +from napari_deeplabcut.config.models import DLCHeaderModel, PointsMetadata +from napari_deeplabcut.core.dataframes import ( + align_old_new, + form_df_from_validated, + guarantee_multiindex_rows, + harmonize_keypoint_column_index, + harmonize_keypoint_row_index, + keypoint_conflicts, + merge_multiple_scorers, +) +from napari_deeplabcut.core.schemas import PointsWriteInputModel + +# ----------------------------------------------------------------------------- +# Helpers: create DLC-like column MultiIndex +# ----------------------------------------------------------------------------- + + +def cols_3level(scorer="S", bodyparts=("bp1",), coords=("x", "y")) -> pd.MultiIndex: + """Classic single-animal DLC layout often appears as 3-level columns. [1](https://github.com/DeepLabCut/DeepLabCut/issues/3072)""" + return pd.MultiIndex.from_product( + [[scorer], list(bodyparts), list(coords)], + names=["scorer", "bodyparts", "coords"], + ) + + +def cols_4level(scorer="S", individuals=("",), bodyparts=("bp1",), coords=("x", "y")) -> pd.MultiIndex: + """Multi-animal DLC layout is 4-level columns (and some pipelines may store empty individuals for single-animal). [1](https://github.com/DeepLabCut/DeepLabCut/issues/3072)""" + return pd.MultiIndex.from_product( + [[scorer], list(individuals), list(bodyparts), list(coords)], + names=["scorer", "individuals", "bodyparts", "coords"], + ) + + +def make_points_ctx( + *, + header_cols: pd.MultiIndex, + scorer: str = "S", + # napari points data: [frame, y, x] + points_data: np.ndarray, + labels: list[str], + ids: list[str], + likelihood: list[float] | None = None, + paths: list[str] | None = None, +) -> PointsWriteInputModel: + """Build a real PointsWriteInputModel from schemas.py.""" + meta = PointsMetadata( + header=DLCHeaderModel(columns=header_cols), + paths=paths, + ) + ctx = PointsWriteInputModel( + points={"data": points_data}, + meta=meta, + props={"label": labels, "id": ids, "likelihood": likelihood}, + ) + return ctx + + +# ----------------------------------------------------------------------------- +# 1) guarantee_multiindex_rows +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize("sep_path", ["labeled-data/test/img000.png", r"labeled-data\test\img000.png"]) +def test_guarantee_multiindex_rows_splits_string_index(sep_path): + df = pd.DataFrame([[1.0]], columns=[("S", "bp1", "x")], index=[sep_path]) + df.columns = pd.MultiIndex.from_tuples(df.columns, names=["scorer", "bodyparts", "coords"]) + + guarantee_multiindex_rows(df) + assert isinstance(df.index, pd.MultiIndex) + assert df.index.to_list()[0][-1] == "img000.png" + + +def test_guarantee_multiindex_rows_leaves_numeric_index_unchanged(): + df = pd.DataFrame([[1.0]], columns=[("S", "bp1", "x")], index=[0]) + df.columns = pd.MultiIndex.from_tuples(df.columns, names=["scorer", "bodyparts", "coords"]) + + guarantee_multiindex_rows(df) + # numeric index should remain (TypeError branch) + assert not isinstance(df.index, pd.MultiIndex) + + +# ----------------------------------------------------------------------------- +# 2) harmonize_keypoint_column_index +# ----------------------------------------------------------------------------- + + +def test_harmonize_keypoint_column_index_upgrades_3_to_4_with_empty_individuals(): + df = pd.DataFrame([[1.0, 2.0]], columns=cols_3level(bodyparts=("bp1",), coords=("x", "y")), index=["img000.png"]) + out = harmonize_keypoint_column_index(df) + + assert out.columns.nlevels == 4 + assert out.columns.names == ["scorer", "individuals", "bodyparts", "coords"] + assert set(out.columns.get_level_values("individuals")) == {""} + + +def test_harmonize_keypoint_column_index_keeps_4level_and_sets_names(): + df = pd.DataFrame( + [[1.0, 2.0]], + columns=cols_4level(individuals=("animal1",), bodyparts=("bp1",), coords=("x", "y")), + index=["img000.png"], + ) + out = harmonize_keypoint_column_index(df) + assert out.columns.nlevels == 4 + assert out.columns.names == ["scorer", "individuals", "bodyparts", "coords"] + + +# ----------------------------------------------------------------------------- +# 3) align_old_new should be stable for 3-level vs 4-level mixes +# ----------------------------------------------------------------------------- + + +def test_align_old_new_works_when_old_is_3level_and_new_is_4level(): + df_old = pd.DataFrame( + [[10.0, 20.0]], + columns=cols_3level(bodyparts=("bp1",), coords=("x", "y")), + index=["img000.png"], + ) + df_new = pd.DataFrame( + [[11.0, 22.0]], + columns=cols_4level(individuals=("",), bodyparts=("bp1",), coords=("x", "y")), + index=["img000.png"], + ) + + old2, new2 = align_old_new(df_old, df_new) + assert old2.columns.nlevels == 4 + assert new2.columns.nlevels == 4 + assert old2.columns.equals(new2.columns) + assert old2.index.equals(new2.index) + + +# Overwrite conflict detection uses harmonized index to avoid (Multi)Index mismatches +def test_align_old_new_handles_basename_index_vs_deep_path_multiindex_regression(): + """ + Regression test for overwrite-preflight crash: + + Old on-disk GT files may use basename-like row indices such as: + Index(["img00001.png"]) + + Newly formed dataframes from runtime points metadata may use deeper + path-like row indices that become a 3-level MultiIndex after + guarantee_multiindex_rows(), e.g.: + MultiIndex([("labeled-data", "test", "img00001.png")]) + + align_old_new() must harmonize row indices before reindex/union, + otherwise pandas may raise: + AssertionError: Length of new_levels (3) must be <= self.nlevels (1) + """ + cols = cols_4level( + scorer="S", + individuals=("",), + bodyparts=("nose",), + coords=("x", "y"), + ) + + # Simulate existing on-disk GT dataframe with basename-only row key + df_old = pd.DataFrame( + [[10.0, 20.0]], + columns=cols, + index=["img00001.png"], + ) + guarantee_multiindex_rows(df_old) # -> MultiIndex([("img00001.png",)]) + + # Simulate newly formed dataframe from runtime metadata with deep relpath + df_new = pd.DataFrame( + [[11.0, 22.0]], + columns=cols, + index=["labeled-data/test/img00001.png"], + ) + guarantee_multiindex_rows(df_new) # -> MultiIndex([("labeled-data","test","img00001.png")]) + + old2, new2 = align_old_new(df_old, df_new) + + # The key regression assertion: no crash, and indices are now aligned + assert isinstance(old2.index, pd.MultiIndex) + assert isinstance(new2.index, pd.MultiIndex) + assert old2.index.equals(new2.index) + + # After harmonization, both should collapse to basename representation + assert old2.index.nlevels == 1 + assert new2.index.nlevels == 1 + assert old2.index.to_list() == [("img00001.png",)] + assert new2.index.to_list() == [("img00001.png",)] + + # Columns should also be aligned + assert old2.columns.equals(new2.columns) + + # And values should still be present on the aligned row + row = ("img00001.png",) + assert old2.loc[row, ("S", "", "nose", "x")] == 10.0 + assert old2.loc[row, ("S", "", "nose", "y")] == 20.0 + assert new2.loc[row, ("S", "", "nose", "x")] == 11.0 + assert new2.loc[row, ("S", "", "nose", "y")] == 22.0 + + +def test_keypoint_conflicts_handles_basename_index_vs_deep_path_multiindex_regression(): + """ + Regression test for overwrite-conflict preflight: + keypoint_conflicts() should handle shallow-vs-deep row indices without crashing. + """ + cols = cols_4level( + scorer="S", + individuals=("",), + bodyparts=("nose",), + coords=("x", "y"), + ) + + df_old = pd.DataFrame( + [[10.0, 20.0]], + columns=cols, + index=["img00001.png"], + ) + guarantee_multiindex_rows(df_old) + + df_new = pd.DataFrame( + [[11.0, 20.0]], # x differs -> conflict + columns=cols, + index=["labeled-data/test/img00001.png"], + ) + guarantee_multiindex_rows(df_new) + + kc = keypoint_conflicts(df_old, df_new) + + assert isinstance(kc.index, pd.MultiIndex) + assert kc.index.nlevels == 1 + assert kc.index.to_list() == [("img00001.png",)] + assert kc.loc[("img00001.png",)].any() + + +# ----------------------------------------------------------------------------- +# 4) keypoint_conflicts DLC semantics +# ----------------------------------------------------------------------------- + + +def test_keypoint_conflicts_detects_conflict_single_animal_3level(): + df_old = pd.DataFrame( + [[10.0, 20.0]], + columns=cols_3level(bodyparts=("bp1",), coords=("x", "y")), + index=["img000.png"], + ) + df_new = pd.DataFrame( + [[99.0, 20.0]], # x differs + columns=cols_3level(bodyparts=("bp1",), coords=("x", "y")), + index=["img000.png"], + ) + kc = keypoint_conflicts(df_old, df_new) + + assert kc.loc[("img000.png",)].any() + assert any("bp1" in str(c) for c in kc.columns) + + +def test_keypoint_conflicts_detects_conflict_multi_animal_4level(): + df_old = pd.DataFrame( + [[10.0, 20.0]], + columns=cols_4level(individuals=("animal1",), bodyparts=("bp1",), coords=("x", "y")), + index=["img000.png"], + ) + df_new = pd.DataFrame( + [[99.0, 20.0]], + columns=cols_4level(individuals=("animal1",), bodyparts=("bp1",), coords=("x", "y")), + index=["img000.png"], + ) + kc = keypoint_conflicts(df_old, df_new) + + assert kc.loc[("img000.png",)].any() + assert any(("animal1" in str(c) and "bp1" in str(c)) for c in kc.columns) + + +# ----------------------------------------------------------------------------- +# 5) harmonize_keypoint_row_index: collapse deep index to basenames when overlap is high +# ----------------------------------------------------------------------------- + + +def test_harmonize_keypoint_row_index_collapses_to_basename_when_overlap_high(): + cols = cols_3level(bodyparts=("bp1",), coords=("x", "y")) + df_deep = pd.DataFrame( + [[1.0, 2.0]], + columns=cols, + index=pd.MultiIndex.from_tuples([("labeled-data", "test", "img000.png")]), + ) + df_shallow = pd.DataFrame( + [[1.0, 2.0]], + columns=cols, + index=pd.MultiIndex.from_tuples([("img000.png",)]), + ) + a, b = harmonize_keypoint_row_index(df_deep, df_shallow) + + assert isinstance(a.index, pd.MultiIndex) and isinstance(b.index, pd.MultiIndex) + assert a.index.nlevels == b.index.nlevels == 1 + assert a.index.to_list()[0][0] == "img000.png" + + +# ----------------------------------------------------------------------------- +# 6) merge_multiple_scorers +# ----------------------------------------------------------------------------- + + +def test_merge_multiple_scorers_picks_highest_likelihood_when_present(): + """ + DLC output often includes likelihood per keypoint, per frame. [1](https://github.com/DeepLabCut/DeepLabCut/issues/3072) + If multiple scorers exist, merge should keep the max-likelihood annotation. + """ + cols = pd.MultiIndex.from_product( + [["A", "B"], ["bp1"], ["x", "y", "likelihood"]], + names=["scorer", "bodyparts", "coords"], + ) + df = pd.DataFrame( + [ + [10, 20, 0.1, 100, 200, 0.9], # frame0 -> pick B + [11, 21, 0.8, 101, 201, 0.2], # frame1 -> pick A + ], + columns=cols, + index=[0, 1], + ) + + out = merge_multiple_scorers(df) + assert out.shape == (2, 3) + assert np.allclose(out.to_numpy()[0], [100, 200, 0.9], equal_nan=True) + assert np.allclose(out.to_numpy()[1], [11, 21, 0.8], equal_nan=True) + + +def test_merge_multiple_scorers_all_nan_likelihood_does_not_crash(): + cols = pd.MultiIndex.from_product( + [["A", "B"], ["bp1"], ["x", "y", "likelihood"]], + names=["scorer", "bodyparts", "coords"], + ) + df = pd.DataFrame([[np.nan] * 6], columns=cols, index=[0]) + + out = merge_multiple_scorers(df) + assert out.shape == (1, 3) + assert np.isnan(out.to_numpy()).all() + + +def test_merge_multiple_scorers_without_likelihood_picks_first_scorer(): + cols = pd.MultiIndex.from_product( + [["A", "B"], ["bp1"], ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + df = pd.DataFrame([[10, 20, 100, 200]], columns=cols, index=[0]) + out = merge_multiple_scorers(df) + + assert set(out.columns.get_level_values("scorer")) == {"A"} + assert out.to_numpy().tolist() == [[10, 20]] + + +# ----------------------------------------------------------------------------- +# 7) form_df_from_validated (writer-facing) using real schemas +# ----------------------------------------------------------------------------- + + +def test_form_df_from_validated_writes_xy_in_dlc_order_and_maps_paths_index(): + """ + PointsDataModel converts napari [frame,y,x] -> DLC [x,y]. + When meta.paths is present, df index should be replaced with those path keys. + """ + header = cols_4level(scorer="S", individuals=("",), bodyparts=("bp1",), coords=("x", "y")) + ctx = make_points_ctx( + header_cols=header, + points_data=np.array([[0.0, 44.0, 55.0]]), # frame=0, y=44, x=55 + labels=["bp1"], + ids=[""], + likelihood=None, + paths=["labeled-data/test/img000.png"], + ) + + df = form_df_from_validated(ctx) + + # Index should become path string (then guarantee_multiindex_rows may split it) + assert isinstance(df.index, pd.MultiIndex) + assert df.index.to_list()[0][-1] == "img000.png" + + # Ensure x,y written correctly + # Column selection: (scorer, individuals, bodyparts, coords) + assert df.loc[df.index[0], ("S", "", "bp1", "x")] == 55.0 + assert df.loc[df.index[0], ("S", "", "bp1", "y")] == 44.0 + + +def test_form_df_from_validated_accepts_3level_header_and_preserves_values(): + """ + DLC still uses both 3-level (single-animal) and 4-level (multi-animal) headers in the wild. [1](https://github.com/DeepLabCut/DeepLabCut/issues/3072) + Our writer must accept a 3-level header and must not lose semantic meaning. + The current implementation upgrades to 4-level with individuals="". + """ + header3 = cols_3level(scorer="S", bodyparts=("bp1",), coords=("x", "y")) + ctx = make_points_ctx( + header_cols=header3, + points_data=np.array([[0.0, 44.0, 55.0]]), + labels=["bp1"], + ids=[""], # single-animal sentinel + likelihood=None, + paths=["labeled-data/test/img000.png"], + ) + + df = form_df_from_validated(ctx) + assert isinstance(df.columns, pd.MultiIndex) + + if df.columns.nlevels == 3: + assert df.columns.names == ["scorer", "bodyparts", "coords"] + assert df.loc[df.index[0], ("S", "bp1", "x")] == 55.0 + assert df.loc[df.index[0], ("S", "bp1", "y")] == 44.0 + else: + assert df.columns.nlevels == 4 + assert df.columns.names == ["scorer", "individuals", "bodyparts", "coords"] + assert set(df.columns.get_level_values("individuals")) == {""} + assert df.loc[df.index[0], ("S", "", "bp1", "x")] == 55.0 + assert df.loc[df.index[0], ("S", "", "bp1", "y")] == 44.0 + + +def test_form_df_from_validated_raises_when_header_reindex_drops_all_finite_points(): + """ + The function has an invariant: if layer has finite points, df must retain finite coords. + We trigger mismatch by giving a header that does NOT include the labeled bodypart. + """ + header = cols_4level(scorer="S", individuals=("",), bodyparts=("DIFFERENT_BP",), coords=("x", "y")) + ctx = make_points_ctx( + header_cols=header, + points_data=np.array([[0.0, 44.0, 55.0]]), + labels=["bp1"], # not present in header -> reindex wipes coords + ids=[""], + likelihood=None, + paths=["labeled-data/test/img000.png"], + ) + + with pytest.raises(RuntimeError, match="Writer produced no finite coordinates"): + _ = form_df_from_validated(ctx) diff --git a/src/napari_deeplabcut/_tests/core/test_discovery.py b/src/napari_deeplabcut/_tests/core/test_discovery.py new file mode 100644 index 00000000..7b53072c --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_discovery.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from pathlib import Path + +from napari_deeplabcut.config.models import AnnotationKind +from napari_deeplabcut.core.discovery import ( + discover_annotation_paths, + discover_annotations, + iter_annotation_candidates, +) + + +def _touch(p: Path) -> Path: + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text("x", encoding="utf-8") + return p + + +def test_discover_annotation_artifacts_groups_h5_and_csv(tmp_path: Path): + # Two GT + one machine, with mixed presence of CSV companions. + _touch(tmp_path / "CollectedData_John.h5") + _touch(tmp_path / "CollectedData_John.csv") + _touch(tmp_path / "CollectedData_Jane.h5") # no csv + _touch(tmp_path / "machinelabels-iter0.h5") + _touch(tmp_path / "machinelabels-iter0.csv") + + arts = discover_annotations(tmp_path) + assert len(arts) == 3 + + # Deterministic ordering by filename + assert [a.stem for a in arts] == ["CollectedData_Jane", "CollectedData_John", "machinelabels-iter0"] + + # Kind inference + assert arts[0].kind == AnnotationKind.GT + assert arts[1].kind == AnnotationKind.GT + assert arts[2].kind == AnnotationKind.MACHINE + + # Pairing behavior + by_stem = {a.stem: a for a in arts} + assert by_stem["CollectedData_John"].h5_path.name.endswith(".h5") + assert by_stem["CollectedData_John"].csv_path.name.endswith(".csv") + assert by_stem["CollectedData_Jane"].csv_path is None + assert by_stem["machinelabels-iter0"].csv_path is not None + + +def test_discover_annotation_paths_prefers_h5(tmp_path: Path): + _touch(tmp_path / "CollectedData_John.csv") + _touch(tmp_path / "CollectedData_John.h5") + paths = discover_annotation_paths(tmp_path) + assert len(paths) == 1 + assert paths[0].suffix.lower() == ".h5" + + +def test_discover_annotation_paths_supports_csv_only(tmp_path: Path): + _touch(tmp_path / "CollectedData_John.csv") + paths = discover_annotation_paths(tmp_path) + assert len(paths) == 1 + assert paths[0].suffix.lower() == ".csv" + + +def test_iter_annotation_candidates_expands_folders_and_files(tmp_path: Path): + folder = tmp_path / "shared" + folder.mkdir() + + _touch(folder / "CollectedData_John.h5") + _touch(folder / "machinelabels-iter0.h5") + + # Provide a mixture of folder and direct file input + extra_file = _touch(tmp_path / "CollectedData_Jane.h5") + + out = iter_annotation_candidates([folder, extra_file]) + + # Deterministic order by filename + assert [p.name for p in out] == ["CollectedData_Jane.h5", "CollectedData_John.h5", "machinelabels-iter0.h5"] diff --git a/src/napari_deeplabcut/_tests/test_keypoints.py b/src/napari_deeplabcut/_tests/core/test_keypoints.py similarity index 96% rename from src/napari_deeplabcut/_tests/test_keypoints.py rename to src/napari_deeplabcut/_tests/core/test_keypoints.py index 76e4c811..215f02c9 100644 --- a/src/napari_deeplabcut/_tests/test_keypoints.py +++ b/src/napari_deeplabcut/_tests/core/test_keypoints.py @@ -1,6 +1,6 @@ import numpy as np -from napari_deeplabcut import keypoints +from napari_deeplabcut.core import keypoints def test_store_advance_step(store): @@ -61,7 +61,7 @@ def test_point_resize(qtbot, viewer, points): def test_add_unannotated(store): # LOOP mode: after a successful add/move, the viewer advances to the next frame - store.layer.metadata["controls"].label_mode = "loop" + store._get_label_mode = lambda: keypoints.LabelMode.LOOP # Make frame 1 unannotated by removing all its rows from the layer data ind_to_remove = 1 @@ -92,7 +92,7 @@ def test_add_unannotated(store): def test_add_quick(store): # QUICK mode: if the keypoint for the current frame already exists, it is MOVED; otherwise, it is ADDED. # QUICK does NOT auto-advance the viewer. - store.layer.metadata["controls"].label_mode = "quick" + store._get_label_mode = lambda: keypoints.LabelMode.QUICK # Choose a specific keypoint to act on; this determines which (label, id) is added/moved store.current_keypoint = store._keypoints[0] diff --git a/src/napari_deeplabcut/_tests/core/test_layer_versioning.py b/src/napari_deeplabcut/_tests/core/test_layer_versioning.py new file mode 100644 index 00000000..1084d494 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_layer_versioning.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import numpy as np +import pytest +from napari.layers import Points + +from napari_deeplabcut.core import keypoints +from napari_deeplabcut.core.layer_versioning import ( + detach_layer_change_hooks, + layer_change_generations, + mark_layer_content_changed, + mark_layer_presentation_changed, +) +from napari_deeplabcut.core.trails import trails_geometry_signature, trails_signature + + +@pytest.fixture +def points_layer() -> Points: + return Points( + data=np.array( + [ + [0, 1, 2], + [1, 3, 4], + [2, 5, 6], + [3, 7, 8], + ], + dtype=float, + ), + properties={ + "label": np.array(["nose", "tail", "nose", "tail"], dtype=object), + "id": np.array(["", "", "", ""], dtype=object), + }, + metadata={ + "colormap_name": "magma", + }, + name="points", + ) + + +def test_layer_change_generations_start_at_zero(points_layer: Points): + generations = layer_change_generations(points_layer) + + assert generations.content == 0 + assert generations.presentation == 0 + + +def test_layer_change_generations_bump_on_data_assignment(points_layer: Points): + before = layer_change_generations(points_layer).content + + data = np.asarray(points_layer.data).copy() + data[0, 0] += 1 + points_layer.data = data + + after = layer_change_generations(points_layer).content + assert after > before + + +def test_layer_change_generations_bump_on_properties_assignment(points_layer: Points): + before = layer_change_generations(points_layer).content + + props = dict(points_layer.properties) + props["label"] = np.array(["nose", "tail", "ear", "tail"], dtype=object) + points_layer.properties = props + + after = layer_change_generations(points_layer).content + assert after > before + + +def test_layer_change_generations_bump_on_metadata_assignment(points_layer: Points): + before = layer_change_generations(points_layer).presentation + + md = dict(points_layer.metadata or {}) + md["colormap_name"] = "viridis" + points_layer.metadata = md + + after = layer_change_generations(points_layer).presentation + assert after == before + 1 + + +def test_manual_content_mark_bumps_generation(points_layer: Points): + before = layer_change_generations(points_layer).content + + mark_layer_content_changed(points_layer) + + after = layer_change_generations(points_layer).content + assert after == before + 1 + + +def test_manual_presentation_mark_bumps_generation(points_layer: Points): + before = layer_change_generations(points_layer).presentation + + mark_layer_presentation_changed(points_layer) + + after = layer_change_generations(points_layer).presentation + assert after == before + 1 + + +def test_trails_geometry_signature_tracks_content_generation(points_layer: Points): + sig_before = trails_geometry_signature(points_layer) + + data = np.asarray(points_layer.data).copy() + data[0, 0] += 1 + points_layer.data = data + + sig_after = trails_geometry_signature(points_layer) + + assert sig_before[0] == id(points_layer) + assert sig_after[0] == id(points_layer) + assert sig_after != sig_before + + +def test_trails_signature_tracks_presentation_generation(points_layer: Points): + sig_before = trails_signature(points_layer, keypoints.ColorMode.BODYPART) + + md = dict(points_layer.metadata or {}) + md["colormap_name"] = "viridis" + points_layer.metadata = md + + sig_after = trails_signature(points_layer, keypoints.ColorMode.BODYPART) + + assert sig_before[0] == id(points_layer) + assert sig_after[0] == id(points_layer) + assert sig_before[1] == str(keypoints.ColorMode.BODYPART) + assert sig_after[1] == str(keypoints.ColorMode.BODYPART) + assert sig_after != sig_before + + +def test_trails_signature_tracks_content_generation(points_layer: Points): + sig_before = trails_signature(points_layer, keypoints.ColorMode.BODYPART) + + props = dict(points_layer.properties) + props["label"] = np.array(["nose", "tail", "ear", "tail"], dtype=object) + points_layer.properties = props + + sig_after = trails_signature(points_layer, keypoints.ColorMode.BODYPART) + + assert sig_after != sig_before + + +def test_detach_layer_change_hooks_reinstalls_cleanly(points_layer: Points): + sig_before = trails_geometry_signature(points_layer) + + detach_layer_change_hooks(points_layer) + + _ = layer_change_generations(points_layer) + + data = np.asarray(points_layer.data).copy() + data[0, 0] += 1 + points_layer.data = data + + sig_after = trails_geometry_signature(points_layer) + + assert sig_after != sig_before diff --git a/src/napari_deeplabcut/_tests/core/test_layers_metadata.py b/src/napari_deeplabcut/_tests/core/test_layers_metadata.py new file mode 100644 index 00000000..90c28c2c --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_layers_metadata.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from napari_deeplabcut.config.models import AnnotationKind +from napari_deeplabcut.core import keypoints +from napari_deeplabcut.core.layers import is_machine_layer, populate_keypoint_layer_properties + + +class HeaderStub: + """Minimal stand-in for DLCHeaderModel. + + In DLC single-animal projects, there is no individuals level; your code + treats that as individuals == [''] so ids[0] is falsy. In multi-animal, + individuals contains names like ['animal1', ...] so ids[0] is truthy. [3](https://forum.image.sc/t/how-to-generate-h5-files-from-pre-collected-data/61936) + """ + + def __init__(self, bodyparts=("bp1", "bp2"), individuals=("",)): + self.bodyparts = list(bodyparts) + self.individuals = list(individuals) + + +@pytest.fixture +def patch_color_cycles(monkeypatch): + """Make build_color_cycles deterministic and independent of colormap internals.""" + + def fake_build_color_cycles(header, colormap): + # Must return cycles for both candidate face_color properties. + return { + "label": {"bp1": "red", "bp2": "blue"}, + "id": {"": "gray", "animal1": "green"}, + } + + monkeypatch.setattr(keypoints, "build_color_cycles", fake_build_color_cycles) + + +# ----------------------------------------------------------------------------- +# Face color/text selection: single animal vs multi animal +# ----------------------------------------------------------------------------- + + +def test_single_animal_uses_label_for_face_color_and_text(patch_color_cycles): + """Single-animal: no individuals dimension => ids[0] is falsy, use label. [3](https://forum.image.sc/t/how-to-generate-h5-files-from-pre-collected-data/61936)""" + header = HeaderStub(bodyparts=("bp1", "bp2"), individuals=("",)) # single-animal sentinel + md = populate_keypoint_layer_properties(header) + + assert md["face_color"] == "label" + assert md["text"] == "label" + assert md["face_color_cycle"] == md["metadata"]["face_color_cycles"]["label"] + + +def test_multi_animal_uses_id_for_face_color_and_text(patch_color_cycles): + """Multi-animal: individuals dimension present => ids[0] truthy, use id. [3](https://forum.image.sc/t/how-to-generate-h5-files-from-pre-collected-data/61936)""" + header = HeaderStub(bodyparts=("bp1", "bp2"), individuals=("animal1", "animal2")) + md = populate_keypoint_layer_properties(header) + + assert md["face_color"] == "id" + assert md["text"] == "{id}–{label}" + assert md["face_color_cycle"] == md["metadata"]["face_color_cycles"]["id"] + + +# ----------------------------------------------------------------------------- +# Robustness: must accept empty labels/ids/likelihood +# ----------------------------------------------------------------------------- + + +def test_empty_ids_must_not_crash_defaults_to_label(patch_color_cycles): + """Regression guard for the E2E failure: ids[0] must not be assumed.""" + header = HeaderStub(bodyparts=("bp1", "bp2"), individuals=("",)) + md = populate_keypoint_layer_properties(header, labels=["bp1"], ids=[]) + + # When ids is empty, treat like single-animal (label-based) rather than crashing. + assert md["face_color"] == "label" + assert md["text"] == "label" + assert md["properties"]["id"] == [] + + +def test_empty_labels_and_ids_produce_empty_properties(patch_color_cycles): + header = HeaderStub(bodyparts=(), individuals=()) + md = populate_keypoint_layer_properties(header, labels=[], ids=[], likelihood=None) + + assert md["properties"]["label"] == [] + assert md["properties"]["id"] == [] + assert md["properties"]["likelihood"].shape == (0,) + assert md["properties"]["valid"].shape == (0,) + assert md["face_color"] == "label" + assert md["text"] == "label" + + +# ----------------------------------------------------------------------------- +# Likelihood + pcutoff behavior +# ----------------------------------------------------------------------------- + + +def test_valid_is_thresholded_by_pcutoff(patch_color_cycles): + """Likelihood is per-frame confidence in DLC outputs; valid derived via cutoff. [1](https://deeplabcut.github.io/DeepLabCut/docs/HelperFunctions.html)[4](https://forum.image.sc/t/what-to-do-with-likelihoods-0-95/45897)""" + header = HeaderStub(bodyparts=("bp1", "bp2"), individuals=("",)) + likelihood = np.array([0.2, 0.9], dtype=float) + + md = populate_keypoint_layer_properties( + header, + labels=["bp1", "bp2"], + ids=[""], # single-animal sentinel + likelihood=likelihood, + pcutoff=0.6, + ) + assert md["properties"]["valid"].tolist() == [False, True] + + +def test_default_likelihood_is_ones_of_len_labels(patch_color_cycles): + """Default fallback behavior should be stable even when likelihood not provided.""" + header = HeaderStub(bodyparts=("bp1", "bp2"), individuals=("",)) + md = populate_keypoint_layer_properties(header, labels=["bp1", "bp2"], ids=[""]) + + assert np.all(md["properties"]["likelihood"] == np.ones(2)) + assert np.all(md["properties"]["valid"] == (np.ones(2) > 0.6)) + + +# ----------------------------------------------------------------------------- +# is_machine_layer +# ----------------------------------------------------------------------------- + + +class LayerStub: + def __init__(self, metadata): + self.metadata = metadata + + +def test_is_machine_layer_true_for_enum_kind(caplog): + layer = LayerStub(metadata={"io": {"kind": AnnotationKind.MACHINE}}) + assert is_machine_layer(layer) is True + assert "literal 'machine' str" not in caplog.text + + +@pytest.mark.parametrize("k", ["machine", "MACHINE", "Machine"]) +def test_is_machine_layer_true_for_string_kind_logs_info(caplog, k): + layer = LayerStub(metadata={"io": {"kind": k}}) + assert is_machine_layer(layer) is True + assert "literal 'machine' str was used for io.kind" in caplog.text + + +@pytest.mark.parametrize("metadata", [{}, {"io": {}}, {"io": {"kind": None}}, {"io": {"kind": AnnotationKind.GT}}]) +def test_is_machine_layer_false_for_missing_or_non_machine(metadata): + layer = LayerStub(metadata=metadata) + assert is_machine_layer(layer) is False + + +def test_ids_as_pandas_series_single_animal_does_not_crash(patch_color_cycles): + import pandas as pd + + header = HeaderStub(bodyparts=("bp1",), individuals=("",)) + md = populate_keypoint_layer_properties(header, labels=["bp1"], ids=pd.Series([""], name="individuals")) + assert md["face_color"] == "label" + assert md["text"] == "label" + + +def test_ids_as_empty_pandas_series_does_not_crash_defaults_to_label(patch_color_cycles): + import pandas as pd + + header = HeaderStub(bodyparts=("bp1",), individuals=("",)) + md = populate_keypoint_layer_properties(header, labels=["bp1"], ids=pd.Series([], dtype=str, name="individuals")) + assert md["face_color"] == "label" + assert md["text"] == "label" + assert md["properties"]["id"] == [] + + +def test_ids_as_pandas_series_multi_animal_uses_id(patch_color_cycles): + import pandas as pd + + header = HeaderStub(bodyparts=("bp1",), individuals=("animal1",)) + md = populate_keypoint_layer_properties(header, labels=["bp1"], ids=pd.Series(["animal1"], name="individuals")) + assert md["face_color"] == "id" + assert md["text"] == "{id}–{label}" diff --git a/src/napari_deeplabcut/_tests/core/test_metadata.py b/src/napari_deeplabcut/_tests/core/test_metadata.py new file mode 100644 index 00000000..cf3347b8 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_metadata.py @@ -0,0 +1,553 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import pandas as pd +import pytest +from pydantic import BaseModel, ValidationError + +import napari_deeplabcut.core.metadata as metadata_mod +from napari_deeplabcut.config.models import AnnotationKind, DLCHeaderModel, ImageMetadata, PointsMetadata +from napari_deeplabcut.core.errors import AmbiguousSaveError, MissingProvenanceError +from napari_deeplabcut.core.metadata import build_io_provenance_dict + +# ----------------------------------------------------------------------------- +# small helpers +# ----------------------------------------------------------------------------- + + +def _make_validation_error() -> ValidationError: + class TmpModel(BaseModel): + x: int + + try: + TmpModel.model_validate({"x": "not-an-int"}) + except ValidationError as e: + return e + raise AssertionError("expected ValidationError") + + +class DummyLayer: + def __init__(self, metadata=None, name="dummy-layer"): + self.metadata = metadata + self.name = name + + +# ----------------------------------------------------------------------------- +# pure helper coverage +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("path_str", "expected"), + [ + ("project/labeled-data/mouse1", True), + ("project/LABELED-DATA/mouse1", True), + ("project/labeled-data", False), + ("project/images/mouse1", False), + ], +) +def test_is_dlc_dataset_root(path_str: str, expected: bool): + assert metadata_mod._is_dlc_dataset_root(Path(path_str)) is expected + + +@pytest.mark.parametrize( + ("paths", "expected"), + [ + (None, False), + ([], False), + (["images/img001.png"], False), + (["labeled-data/test/img001.png"], True), + ([r"labeled-data\test\img001.png"], True), + ], +) +def test_paths_look_like_labeled_data(paths, expected): + assert metadata_mod._paths_look_like_labeled_data(paths) is expected + + +def test_looks_like_project_root_true_when_same_path(tmp_path: Path): + assert metadata_mod._looks_like_project_root(str(tmp_path), str(tmp_path)) is True + + +def test_looks_like_project_root_false_when_different(tmp_path: Path): + other = tmp_path / "other" + assert metadata_mod._looks_like_project_root(str(tmp_path), str(other)) is False + + +def test_infer_image_root_prefers_explicit_root(tmp_path: Path): + p = tmp_path / "images" / "img001.png" + p.parent.mkdir(parents=True) + p.touch() + + result = metadata_mod.infer_image_root( + explicit_root="/explicit/root", + paths=[str(p)], + source_path=str(p), + ) + + assert result == "/explicit/root" + + +def test_infer_image_root_uses_first_path_parent(tmp_path: Path): + p = tmp_path / "images" / "img001.png" + p.parent.mkdir(parents=True) + p.touch() + + result = metadata_mod.infer_image_root(paths=[str(p)]) + assert result == str(p.parent.resolve()) + + +def test_infer_image_root_falls_back_to_source_path_parent(tmp_path: Path): + p = tmp_path / "images" / "img001.png" + p.parent.mkdir(parents=True) + p.touch() + + result = metadata_mod.infer_image_root(source_path=str(p)) + assert result == str(p.parent.resolve()) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, True), + ("", True), + ([], True), + ({}, True), + ((), True), + (False, False), + (0, False), + ("x", False), + ([1], False), + ], +) +def test_is_empty_value(value, expected): + assert metadata_mod._is_empty_value(value) is expected + + +def test_require_unique_target_returns_single_candidate(tmp_path: Path): + candidate = tmp_path / "CollectedData_A.h5" + assert metadata_mod.require_unique_target([candidate]) == candidate + + +def test_require_unique_target_raises_when_missing(): + with pytest.raises(MissingProvenanceError, match="No candidates found"): + metadata_mod.require_unique_target([], context="save target") + + +def test_require_unique_target_raises_when_ambiguous(tmp_path: Path): + c1 = tmp_path / "CollectedData_A.h5" + c2 = tmp_path / "CollectedData_B.h5" + + with pytest.raises(AmbiguousSaveError, match="Ambiguous save target"): + metadata_mod.require_unique_target([c1, c2], context="save target") + + +# ----------------------------------------------------------------------------- +# merge / sync +# ----------------------------------------------------------------------------- + + +def test_merge_image_metadata_only_fills_missing_fields(): + base = ImageMetadata(root="rootA", name="", shape=None) + incoming = ImageMetadata(root="rootB", name="images", shape=[10, 20]) + + merged = metadata_mod.merge_image_metadata(base, incoming) + + assert merged.root == "rootA" # preserved + assert merged.name == "images" # filled + assert tuple(merged.shape) == (10, 20) # filled + + +def test_merge_points_metadata_does_not_clobber_and_skips_controls(): + base = PointsMetadata(root="rootA", name="", controls={"runtime": 1}) + incoming = PointsMetadata(root="rootB", name="points", controls={"runtime": 2}) + + merged = metadata_mod.merge_points_metadata(base, incoming) + + assert merged.root == "rootA" + assert merged.name == "points" + # controls should not be copied from incoming + assert getattr(merged, "controls", None) in (None, {"runtime": 1}) + + +def test_sync_points_from_image_fills_missing_fields(): + image_meta = ImageMetadata( + root="project/labeled-data/mouse1", + paths=["labeled-data/mouse1/img001.png"], + shape=[100, 200], + name="images", + ) + points_meta = PointsMetadata() + + synced = metadata_mod.sync_points_from_image(image_meta, points_meta) + + assert synced.root == "project/labeled-data/mouse1" + assert synced.paths == ["labeled-data/mouse1/img001.png"] + assert tuple(synced.shape) == (100, 200) + assert synced.name == "images" + + +def test_sync_points_from_image_overrides_project_root_with_dataset_root(tmp_path: Path): + project_root = tmp_path / "project" + dataset_root = project_root / "labeled-data" / "mouse1" + dataset_root.mkdir(parents=True) + + image_meta = ImageMetadata( + root=str(dataset_root), + paths=[str(dataset_root / "img001.png")], + name="images", + ) + points_meta = PointsMetadata( + root=str(project_root), # stale / wrong + project=str(project_root), + ) + + synced = metadata_mod.sync_points_from_image(image_meta, points_meta) + + assert synced.root == str(dataset_root) + + +def test_sync_points_from_image_keeps_existing_dataset_root_when_already_good(tmp_path: Path): + project_root = tmp_path / "project" + good_points_root = project_root / "labeled-data" / "mouse1" + other_dataset_root = project_root / "labeled-data" / "mouse2" + good_points_root.mkdir(parents=True) + other_dataset_root.mkdir(parents=True) + + image_meta = ImageMetadata(root=str(other_dataset_root)) + points_meta = PointsMetadata( + root=str(good_points_root), + project=str(project_root), + ) + + synced = metadata_mod.sync_points_from_image(image_meta, points_meta) + + # already a valid dataset root -> do not overwrite + assert synced.root == str(good_points_root) + + +def test_ensure_metadata_models_accepts_dicts_and_models(): + ImageMetadata(root="img-root") + pts_model = PointsMetadata(root="pts-root") + + img, pts = metadata_mod.ensure_metadata_models( + {"root": "img-root"}, + pts_model, + ) + + assert isinstance(img, ImageMetadata) + assert img.root == "img-root" + assert pts is pts_model + + +# ----------------------------------------------------------------------------- +# parsing / coercion +# ----------------------------------------------------------------------------- + + +def test_normalize_columns_handles_index_and_multiindex(): + idx = pd.Index(["a", "b"]) + mi = pd.MultiIndex.from_tuples([("scorer", "nose"), ("scorer", "tail")]) + + assert metadata_mod._normalize_columns(idx) == ["a", "b"] + assert metadata_mod._normalize_columns(mi) == [("scorer", "nose"), ("scorer", "tail")] + + +def test_coerce_io_kind_accepts_value_and_name(): + d1 = {"kind": "gt"} + d2 = {"kind": "MACHINE"} + + metadata_mod._coerce_io_kind(d1) + metadata_mod._coerce_io_kind(d2) + + assert d1["kind"] == AnnotationKind.GT + assert d2["kind"] == AnnotationKind.MACHINE + + +def test_parse_points_metadata_none_returns_empty_model(): + parsed = metadata_mod.parse_points_metadata(None) + assert isinstance(parsed, PointsMetadata) + + +def test_parse_points_metadata_drops_controls_and_coerces_kinds(monkeypatch): + captured = {} + + def fake_model_validate(payload): + captured["payload"] = payload + return PointsMetadata() + + monkeypatch.setattr(metadata_mod.PointsMetadata, "model_validate", fake_model_validate) + + md = { + "controls": {"runtime": object()}, + "io": { + "kind": "gt", + "project_root": "/tmp", + "source_relpath_posix": "CollectedData_A.h5", + "dataset_key": "keypoints", + }, + "save_target": {"kind": "MACHINE"}, + } + + parsed = metadata_mod.parse_points_metadata(md) + + assert isinstance(parsed, PointsMetadata) + assert "controls" not in captured["payload"] + assert captured["payload"]["io"]["kind"] == AnnotationKind.GT + assert captured["payload"]["save_target"]["kind"] == AnnotationKind.MACHINE + + +def test_parse_points_metadata_drop_header_removes_header(monkeypatch): + captured = {} + + def fake_model_validate(payload): + captured["payload"] = payload + return PointsMetadata() + + monkeypatch.setattr(metadata_mod.PointsMetadata, "model_validate", fake_model_validate) + + parsed = metadata_mod.parse_points_metadata( + {"header": {"columns": [("scorer", "nose")]}, "root": "x"}, + drop_header=True, + ) + + assert isinstance(parsed, PointsMetadata) + assert "header" not in captured["payload"] + + +def test_parse_points_metadata_falls_back_to_empty_model_on_validation_error(monkeypatch): + def boom(payload): + raise RuntimeError("bad metadata") + + monkeypatch.setattr(metadata_mod.PointsMetadata, "model_validate", boom) + + parsed = metadata_mod.parse_points_metadata({"root": "x"}) + assert isinstance(parsed, PointsMetadata) + + +def test_coerce_header_model_none_passthrough(): + assert metadata_mod.coerce_header_model(None) is None + + +def test_coerce_header_model_returns_existing_model(): + # NOTE: might need to use the header fixture here + header = DLCHeaderModel(columns=[("scorer", "nose")]) + assert metadata_mod.coerce_header_model(header) is header + + +# ----------------------------------------------------------------------------- +# metadata dict / legacy migration helpers +# ----------------------------------------------------------------------------- + + +def test_layer_metadata_dict_handles_none_and_mapping_like(): + assert metadata_mod._layer_metadata_dict(SimpleNamespace(metadata=None)) == {} + assert metadata_mod._layer_metadata_dict(SimpleNamespace(metadata={"a": 1})) == {"a": 1} + + class MappingLike: + def __iter__(self): + return iter([("x", 1)]) + + assert metadata_mod._layer_metadata_dict(SimpleNamespace(metadata=MappingLike())) == {"x": 1} + + +def test_build_io_from_source_h5_returns_none_for_empty_source(): + assert metadata_mod._build_io_from_source_h5("") is None + assert metadata_mod._build_io_from_source_h5(None) is None + + +def test_prepare_points_payload_migrates_legacy_source_h5(monkeypatch): + monkeypatch.setattr( + metadata_mod, + "_build_io_from_source_h5", + lambda src, dataset_key="keypoints": {"kind": AnnotationKind.GT, "dataset_key": dataset_key}, + ) + + payload = metadata_mod._prepare_points_payload( + {"source_h5": "/tmp/CollectedData_A.h5"}, + migrate_legacy=True, + ) + + assert payload["io"]["kind"] == AnnotationKind.GT + assert payload["io"]["dataset_key"] == "keypoints" + + +# ----------------------------------------------------------------------------- +# attach_source_and_io_to_layer_kwargs +# ----------------------------------------------------------------------------- + + +def test_attach_source_and_io_to_layer_kwargs_sets_legacy_fields_and_io(monkeypatch, tmp_path: Path): + file_path = tmp_path / "CollectedData_Jane.h5" + file_path.touch() + + monkeypatch.setattr(metadata_mod, "canonicalize_path", lambda p, n=1: "CollectedData_Jane.h5") + monkeypatch.setattr(metadata_mod, "infer_annotation_kind_for_file", lambda p: AnnotationKind.GT) + + metadata = {} + metadata_mod.attach_source_and_io_to_layer_kwargs(metadata, file_path) + + inner = metadata["metadata"] + assert inner["source_h5_name"] == "CollectedData_Jane.h5" + assert inner["source_h5_stem"] == "CollectedData_Jane" + assert inner["source_h5"].endswith("CollectedData_Jane.h5") + assert inner["io"]["kind"] == AnnotationKind.GT + assert inner["io"]["source_relpath_posix"] == "CollectedData_Jane.h5" + assert inner["io"]["dataset_key"] == "keypoints" + + +# ----------------------------------------------------------------------------- +# read/write adapter gateway +# ----------------------------------------------------------------------------- + + +def test_read_points_meta_returns_validation_error(monkeypatch): + err = _make_validation_error() + + monkeypatch.setattr(metadata_mod, "_prepare_points_payload", lambda *args, **kwargs: {"bad": "payload"}) + monkeypatch.setattr(metadata_mod.PointsMetadata, "model_validate", lambda payload: (_ for _ in ()).throw(err)) + + layer = DummyLayer(metadata={"root": "x"}) + result = metadata_mod.read_points_meta(layer) + + assert isinstance(result, ValidationError) + + +def test_write_points_meta_merge_missing_preserves_existing_and_restores_header(monkeypatch): + header = DLCHeaderModel(columns=[("scorer", "nose")]) + layer = DummyLayer(metadata={"root": "existing-root", "header": header}) + + validated = PointsMetadata(root="existing-root", name="incoming-name", header=header) + + def fake_model_validate(payload): + # root should remain existing because MERGE_MISSING only fills empties + assert payload["root"] == "existing-root" + assert payload["name"] == "incoming-name" + return validated + + monkeypatch.setattr(metadata_mod.PointsMetadata, "model_validate", fake_model_validate) + + result = metadata_mod.write_points_meta( + layer, + {"root": "new-root", "name": "incoming-name"}, + metadata_mod.MergePolicy.MERGE_MISSING, + validate=True, + ) + + assert isinstance(result, PointsMetadata) + assert layer.metadata["root"] == "existing-root" + assert layer.metadata["name"] == "incoming-name" + assert "header" in layer.metadata + + +def test_write_points_meta_replace_without_validation_writes_raw_mapping(): + layer = DummyLayer(metadata={"root": "old", "name": "old-name"}) + + result = metadata_mod.write_points_meta( + layer, + {"root": "new", "name": "new-name", "controls": {"runtime": 1}}, + metadata_mod.MergePolicy.REPLACE, + validate=False, + ) + + assert layer.metadata == {"root": "new", "name": "new-name"} + assert isinstance(result, (PointsMetadata, ValidationError)) + + +def test_write_points_meta_returns_validation_error_and_leaves_metadata_stable(monkeypatch): + err = _make_validation_error() + layer = DummyLayer(metadata={"root": "old"}) + + monkeypatch.setattr( + metadata_mod.PointsMetadata, + "model_validate", + lambda payload: (_ for _ in ()).throw(err), + ) + + result = metadata_mod.write_points_meta( + layer, + {"root": "new"}, + metadata_mod.MergePolicy.MERGE, + validate=True, + ) + + assert isinstance(result, ValidationError) + # write should not have replaced metadata after failed validation + assert layer.metadata == {"root": "old"} + + +def test_read_image_meta_returns_validation_error(monkeypatch): + err = _make_validation_error() + layer = DummyLayer(metadata={"root": "x"}) + + monkeypatch.setattr( + metadata_mod.ImageMetadata, + "model_validate", + lambda payload: (_ for _ in ()).throw(err), + ) + + result = metadata_mod.read_image_meta(layer) + assert isinstance(result, ValidationError) + + +def test_write_image_meta_merge_policy_string_and_fields_filter(): + layer = DummyLayer(metadata={"root": "old", "name": ""}) + + result = metadata_mod.write_image_meta( + layer, + {"root": "new", "name": "images", "shape": [10, 20]}, + "merge_missing", + fields={"name"}, + validate=True, + ) + + assert isinstance(result, ImageMetadata) + assert layer.metadata["root"] == "old" + assert layer.metadata["name"] == "images" + assert "shape" not in layer.metadata + + +def test_migrate_points_layer_metadata_round_trips_through_gateway(monkeypatch): + layer = DummyLayer(metadata={"source_h5": "/tmp/CollectedData_A.h5"}) + + read_result = PointsMetadata(root="rootA") + write_result = PointsMetadata(root="rootA", name="points") + + monkeypatch.setattr(metadata_mod, "read_points_meta", lambda *args, **kwargs: read_result) + monkeypatch.setattr(metadata_mod, "write_points_meta", lambda *args, **kwargs: write_result) + + result = metadata_mod.migrate_points_layer_metadata(layer) + + assert result is write_result + + +# ----------------------------------------------------------------------------- +# build_io_provenance_dict +# ----------------------------------------------------------------------------- + + +def test_build_io_provenance_dict_keeps_enum_kind_object(tmp_path: Path): + d = build_io_provenance_dict( + project_root=tmp_path, + source_relpath_posix="CollectedData_Jane.h5", + kind=AnnotationKind.GT, + dataset_key="keypoints", + ) + # mode="python" => should keep enum object at runtime + assert isinstance(d["kind"], AnnotationKind) + assert d["kind"] == AnnotationKind.GT + assert d["project_root"] == str(tmp_path) + assert d["source_relpath_posix"] == "CollectedData_Jane.h5" + assert d["dataset_key"] == "keypoints" + assert d["schema_version"] == 1 + + +def test_build_io_provenance_dict_excludes_none_fields(tmp_path: Path): + d = build_io_provenance_dict( + project_root=tmp_path, + source_relpath_posix="CollectedData_Jane.h5", + kind=None, # exclude_none=True => kind should be absent + dataset_key="keypoints", + ) + assert "kind" not in d diff --git a/src/napari_deeplabcut/_tests/core/test_project_paths.py b/src/napari_deeplabcut/_tests/core/test_project_paths.py new file mode 100644 index 00000000..ad1e4853 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_project_paths.py @@ -0,0 +1,620 @@ +from __future__ import annotations + +import inspect +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import napari_deeplabcut.core.project_paths as paths_mod +from napari_deeplabcut.core.project_paths import ( + coerce_paths_to_dlc_row_keys, + dataset_folder_has_files, + infer_dlc_project_from_config, + resolve_project_root_from_config, + target_dataset_folder_for_config, +) + + +# ----------------------------------------------------------------------------- +# canonicalize_path +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + ("value", "n", "expected"), + [ + # basic POSIX cases + ("root/sub1/sub2/file.png", 3, "sub1/sub2/file.png"), + ("root/sub/file.png", 2, "sub/file.png"), + ("root/sub/file.png", 1, "file.png"), + ("a/b/c", 10, "a/b/c"), + (Path("a/b/c/d.txt"), 3, "b/c/d.txt"), + # empty / degenerate inputs + ("", 3, ""), + (".", 3, ""), + ("..", 3, ""), + ("/", 3, ""), + ("a/b/c/", 3, "a/b/c"), + # non-string coercion + (123, 3, "123"), + # Windows / mixed separators + (r"a\b\c\file.png", 3, "b/c/file.png"), + (r"frames\\test\video0/img001.png", 3, "test/video0/img001.png"), + # string-based normalization, not filesystem resolution + ("./a/b/../c/d", 3, "b/c/d"), + # invalid n + ("a/b/c", 0, ValueError), + ("a/b/c/d", -1, ValueError), + ], +) +def test_canonicalize_path_contract(value, n, expected): + is_exc_class = inspect.isclass(expected) and issubclass(expected, Exception) + if is_exc_class: + with pytest.raises(expected): + paths_mod.canonicalize_path(value, n=n) + return + + assert paths_mod.canonicalize_path(value, n=n) == expected + + +def test_canonicalize_path_removes_backslashes(): + out = paths_mod.canonicalize_path(r"a\b\c\file.png", n=3) + assert "\\" not in out + + +def test_canonicalize_path_stringifies_objects_and_normalizes_separators(): + class Weird: + def __str__(self): + return r"x\y\z" + + out = paths_mod.canonicalize_path(Weird(), n=3) # type: ignore[arg-type] + assert out == "x/y/z" + assert "\\" not in out + + +def test_canonicalize_path_returns_empty_string_when_stringify_fails(): + class BadPath: + def __str__(self): + raise RuntimeError("boom") + + assert paths_mod.canonicalize_path(BadPath(), n=3) == "" + + +# ----------------------------------------------------------------------------- +# PathMatchPolicy / find_matching_depth +# ----------------------------------------------------------------------------- +def test_path_match_policy_ordered_depths(): + assert paths_mod.PathMatchPolicy.ORDERED_DEPTHS.depths == (3, 2, 1) + + +def test_find_matching_depth_prefers_deepest_first_match(): + old_paths = [ + "/project/labeled-data/mouse1/img001.png", + "/project/labeled-data/mouse1/img002.png", + ] + new_paths = [ + "/other/root/labeled-data/mouse1/img002.png", + "/other/root/labeled-data/mouse1/img003.png", + ] + + # exact overlap at depth=3 -> labeled-data/mouse1/img002.png + assert paths_mod.find_matching_depth(old_paths, new_paths) == 3 + + +def test_find_matching_depth_falls_back_to_shallower_depth(): + old_paths = ["/a/b/c/img001.png"] + new_paths = ["/x/y/z/img001.png"] + + # depth=3 -> b/c/img001 vs y/z/img001 (no overlap) + # depth=2 -> c/img001 vs z/img001 (no overlap) + # depth=1 -> img001 (overlap) + assert paths_mod.find_matching_depth(old_paths, new_paths) == 1 + + +def test_find_matching_depth_returns_none_when_no_overlap(): + old_paths = ["/a/b/c/img001.png"] + new_paths = ["/x/y/z/img999.png"] + + assert paths_mod.find_matching_depth(old_paths, new_paths) is None + + +@pytest.mark.parametrize( + ("old_paths", "new_paths"), + [ + ([], ["/x/y/z/img001.png"]), + (["/a/b/c/img001.png"], []), + ([], []), + ], +) +def test_find_matching_depth_returns_none_for_empty_inputs(old_paths, new_paths): + assert paths_mod.find_matching_depth(old_paths, new_paths) is None + + +# ----------------------------------------------------------------------------- +# config.yaml / DLC artifact heuristics +# ----------------------------------------------------------------------------- +def test_is_config_yaml_true_only_for_existing_config_yaml(tmp_path: Path): + cfg = tmp_path / "config.yaml" + cfg.touch() + + other = tmp_path / "not_config.yaml" + other.touch() + + assert paths_mod.is_config_yaml(cfg) is True + assert paths_mod.is_config_yaml(other) is False + assert paths_mod.is_config_yaml(tmp_path / "missing_config.yaml") is False + + +def test_is_config_yaml_returns_false_for_bad_input(): + assert paths_mod.is_config_yaml(None) is False + + +def test_has_dlc_datafiles_detects_collecteddata_and_machinelabels(tmp_path: Path): + folder = tmp_path / "dataset" + folder.mkdir() + + assert paths_mod.has_dlc_datafiles(folder) is False + + (folder / "CollectedData_Jane.h5").touch() + assert paths_mod.has_dlc_datafiles(folder) is True + + # also cover another supported pattern + other = tmp_path / "dataset2" + other.mkdir() + (other / "machinelabels_alex.csv").touch() + assert paths_mod.has_dlc_datafiles(other) is True + + +def test_has_dlc_datafiles_returns_false_for_missing_or_non_directory(tmp_path: Path): + missing = tmp_path / "missing" + file_path = tmp_path / "a_file.txt" + file_path.touch() + + assert paths_mod.has_dlc_datafiles(missing) is False + assert paths_mod.has_dlc_datafiles(file_path) is False + + +def test_looks_like_dlc_labeled_folder_true_when_artifacts_present(tmp_path: Path): + folder = tmp_path / "some_folder" + folder.mkdir() + (folder / "CollectedData_Jane.csv").touch() + + assert paths_mod.looks_like_dlc_labeled_folder(folder) is True + + +def test_looks_like_dlc_labeled_folder_true_inside_labeled_data(tmp_path: Path): + folder = tmp_path / "project" / "labeled-data" / "mouse1" + folder.mkdir(parents=True) + + assert paths_mod.looks_like_dlc_labeled_folder(folder) is True + + +def test_looks_like_dlc_labeled_folder_false_for_regular_folder(tmp_path: Path): + folder = tmp_path / "images" + folder.mkdir() + + assert paths_mod.looks_like_dlc_labeled_folder(folder) is False + + +def test_should_force_dlc_reader_true_for_config_yaml(tmp_path: Path): + cfg = tmp_path / "config.yaml" + cfg.touch() + + assert paths_mod.should_force_dlc_reader(cfg) is True + + +def test_should_force_dlc_reader_true_for_labeled_folder(tmp_path: Path): + folder = tmp_path / "project" / "labeled-data" / "mouse1" + folder.mkdir(parents=True) + + assert paths_mod.should_force_dlc_reader(folder) is True + + +def test_should_force_dlc_reader_false_for_empty_or_regular_inputs(tmp_path: Path): + regular = tmp_path / "images" + regular.mkdir() + + assert paths_mod.should_force_dlc_reader([]) is False + assert paths_mod.should_force_dlc_reader(regular) is False + assert paths_mod.should_force_dlc_reader([regular]) is False + + +# ----------------------------------------------------------------------------- +# normalize_anchor_candidate +# ----------------------------------------------------------------------------- +def test_normalize_anchor_candidate_returns_directory_for_directory(tmp_path: Path): + folder = tmp_path / "dataset" + folder.mkdir() + + assert paths_mod.normalize_anchor_candidate(folder) == folder.resolve() + + +def test_normalize_anchor_candidate_returns_parent_for_file(tmp_path: Path): + folder = tmp_path / "dataset" + folder.mkdir() + file_path = folder / "CollectedData_Jane.h5" + file_path.touch() + + assert paths_mod.normalize_anchor_candidate(file_path) == folder.resolve() + + +def test_normalize_anchor_candidate_returns_path_for_missing_path(tmp_path: Path): + missing = tmp_path / "does_not_exist" + + result = paths_mod.normalize_anchor_candidate(missing) + assert result == missing.resolve() + + +def test_normalize_anchor_candidate_returns_none_for_none(): + assert paths_mod.normalize_anchor_candidate(None) is None + + +# ----------------------------------------------------------------------------- +# find_nearest_config +# ----------------------------------------------------------------------------- +def test_find_nearest_config_finds_config_in_current_directory(tmp_path: Path): + project = tmp_path / "project" + project.mkdir() + cfg = project / "config.yaml" + cfg.touch() + + assert paths_mod.find_nearest_config(project) == cfg.resolve() + + +def test_find_nearest_config_finds_parent_project_from_nested_file(tmp_path: Path): + project = tmp_path / "project" + dataset = project / "labeled-data" / "mouse1" + dataset.mkdir(parents=True) + cfg = project / "config.yaml" + cfg.touch() + + img = dataset / "img001.png" + img.touch() + + assert paths_mod.find_nearest_config(img) == cfg.resolve() + + +def test_find_nearest_config_respects_max_levels(tmp_path: Path): + project = tmp_path / "project" + deep = project / "a" / "b" / "c" / "d" / "e" / "f" + deep.mkdir(parents=True) + cfg = project / "config.yaml" + cfg.touch() + + assert paths_mod.find_nearest_config(deep, max_levels=2) is None + assert paths_mod.find_nearest_config(deep, max_levels=6) == cfg.resolve() + + +def test_find_nearest_config_returns_none_when_no_config(tmp_path: Path): + folder = tmp_path / "no_project_here" + folder.mkdir() + + assert paths_mod.find_nearest_config(folder) is None + + +# ----------------------------------------------------------------------------- +# infer_labeled_data_folder_from_paths +# ----------------------------------------------------------------------------- +def test_infer_labeled_data_folder_from_paths_uses_fallback_root_when_already_dataset(tmp_path: Path): + dataset = tmp_path / "project" / "labeled-data" / "mouse1" + dataset.mkdir(parents=True) + + result = paths_mod.infer_labeled_data_folder_from_paths( + [], + fallback_root=dataset, + ) + + assert result == dataset.resolve() + + +def test_infer_labeled_data_folder_from_paths_builds_folder_from_project_and_relpaths(tmp_path: Path): + project = tmp_path / "project" + project.mkdir() + + result = paths_mod.infer_labeled_data_folder_from_paths( + ["labeled-data/mouse1/img001.png"], + project_root=project, + ) + + assert result == (project / "labeled-data" / "mouse1").resolve() + + +def test_infer_labeled_data_folder_from_paths_returns_none_without_dataset_name(tmp_path: Path): + project = tmp_path / "project" + project.mkdir() + + result = paths_mod.infer_labeled_data_folder_from_paths( + ["images/img001.png"], + project_root=project, + ) + + assert result is None + + +def test_infer_labeled_data_folder_from_paths_returns_none_without_project_root(): + result = paths_mod.infer_labeled_data_folder_from_paths( + ["labeled-data/mouse1/img001.png"], + project_root=None, + ) + + assert result is None + + +# ----------------------------------------------------------------------------- +# infer_dlc_project +# ----------------------------------------------------------------------------- +def test_infer_dlc_project_prefers_explicit_root_and_finds_config(tmp_path: Path): + project = tmp_path / "project" + project.mkdir() + cfg = project / "config.yaml" + cfg.touch() + + other = tmp_path / "other" + other.mkdir() + + ctx = paths_mod.infer_dlc_project( + explicit_root=project, + anchor_candidates=[other], + prefer_project_root=True, + ) + + assert ctx.project_root == project.resolve() + assert ctx.config_path == cfg.resolve() + assert ctx.root_anchor == project.resolve() + + +def test_infer_dlc_project_keeps_local_anchor_when_prefer_project_root_false(tmp_path: Path): + project = tmp_path / "project" + dataset = project / "labeled-data" / "mouse1" + dataset.mkdir(parents=True) + cfg = project / "config.yaml" + cfg.touch() + + ctx = paths_mod.infer_dlc_project( + anchor_candidates=[dataset], + prefer_project_root=False, + ) + + assert ctx.project_root == project.resolve() + assert ctx.config_path == cfg.resolve() + assert ctx.root_anchor == dataset.resolve() + + +def test_infer_dlc_project_returns_best_effort_without_config(tmp_path: Path): + dataset = tmp_path / "labeled-data" / "mouse1" + dataset.mkdir(parents=True) + + ctx = paths_mod.infer_dlc_project( + anchor_candidates=[dataset], + prefer_project_root=True, + ) + + assert ctx.project_root is None + assert ctx.config_path is None + assert ctx.root_anchor == dataset.resolve() + + +def test_infer_dlc_project_uses_dataset_candidate_when_no_anchor_candidates(tmp_path: Path): + dataset = tmp_path / "project" / "labeled-data" / "mouse1" + dataset.mkdir(parents=True) + + ctx = paths_mod.infer_dlc_project( + dataset_candidates=[dataset], + ) + + assert ctx.dataset_folder == dataset.resolve() + assert ctx.root_anchor == dataset.resolve() + + +# ----------------------------------------------------------------------------- +# infer_dlc_project_from_opened +# ----------------------------------------------------------------------------- +def test_infer_dlc_project_from_opened_uses_opened_path(tmp_path: Path): + dataset = tmp_path / "dataset" + dataset.mkdir() + + ctx = paths_mod.infer_dlc_project_from_opened(dataset) + + assert ctx.root_anchor == dataset.resolve() + assert ctx.project_root is None + assert ctx.config_path is None + + +def test_infer_dlc_project_from_opened_can_find_project_root(tmp_path: Path): + project = tmp_path / "project" + dataset = project / "labeled-data" / "mouse1" + dataset.mkdir(parents=True) + cfg = project / "config.yaml" + cfg.touch() + + ctx = paths_mod.infer_dlc_project_from_opened(dataset) + + assert ctx.project_root == project.resolve() + assert ctx.config_path == cfg.resolve() + assert ctx.root_anchor == project.resolve() + + +# ----------------------------------------------------------------------------- +# infer_dlc_project_from_points_meta +# ----------------------------------------------------------------------------- +def test_infer_dlc_project_from_points_meta_infers_dataset_and_project(tmp_path: Path): + project = tmp_path / "project" + project.mkdir() + cfg = project / "config.yaml" + cfg.touch() + + pts_meta = SimpleNamespace( + project=str(project), + root=None, + paths=["labeled-data/mouse1/img001.png"], + ) + + ctx = paths_mod.infer_dlc_project_from_points_meta(pts_meta) + + assert ctx.project_root == project.resolve() + assert ctx.config_path == cfg.resolve() + assert ctx.dataset_folder == (project / "labeled-data" / "mouse1").resolve() + assert ctx.root_anchor == project.resolve() + + +def test_infer_dlc_project_from_points_meta_uses_root_as_dataset_fallback(tmp_path: Path): + dataset = tmp_path / "project" / "labeled-data" / "mouse1" + dataset.mkdir(parents=True) + + pts_meta = SimpleNamespace( + project=None, + root=str(dataset), + paths=[], + ) + + ctx = paths_mod.infer_dlc_project_from_points_meta(pts_meta) + + assert ctx.dataset_folder == dataset.resolve() + assert ctx.root_anchor == dataset.resolve() + + +# ----------------------------------------------------------------------------- +# infer_dlc_project_from_image_layer +# ----------------------------------------------------------------------------- +def test_infer_dlc_project_from_image_layer_uses_metadata_project(tmp_path: Path): + project = tmp_path / "project" + project.mkdir() + cfg = project / "config.yaml" + cfg.touch() + + layer = SimpleNamespace( + metadata={"project": str(project)}, + source=SimpleNamespace(path=None), + ) + + ctx = paths_mod.infer_dlc_project_from_image_layer(layer) + + assert ctx.project_root == project.resolve() + assert ctx.config_path == cfg.resolve() + assert ctx.root_anchor == project.resolve() + + +def test_infer_dlc_project_from_image_layer_falls_back_to_source_path(tmp_path: Path): + project = tmp_path / "project" + videos = project / "videos" + videos.mkdir(parents=True) + cfg = project / "config.yaml" + cfg.touch() + + video = videos / "demo.mp4" + video.touch() + + layer = SimpleNamespace( + metadata={}, + source=SimpleNamespace(path=str(video)), + ) + + ctx = paths_mod.infer_dlc_project_from_image_layer(layer) + + assert ctx.project_root == project.resolve() + assert ctx.config_path == cfg.resolve() + assert ctx.root_anchor == project.resolve() + + +def test_infer_dlc_project_from_image_layer_returns_best_effort_without_config(tmp_path: Path): + videos = tmp_path / "videos" + videos.mkdir() + + video = videos / "demo.mp4" + video.touch() + + layer = SimpleNamespace( + metadata={}, + source=SimpleNamespace(path=str(video)), + ) + + ctx = paths_mod.infer_dlc_project_from_image_layer(layer) + + assert ctx.project_root is None + assert ctx.config_path is None + assert ctx.root_anchor == videos.resolve() + + +def test_resolve_project_root_from_config(tmp_path): + project = tmp_path / "my-project" + project.mkdir() + cfg = project / "config.yaml" + cfg.write_text("scorer: test\n", encoding="utf-8") + + assert resolve_project_root_from_config(cfg) == project + assert resolve_project_root_from_config(project / "not_config.yaml") is None + assert resolve_project_root_from_config(project / "config.yml") is None + assert resolve_project_root_from_config(project / "missing" / "config.yaml") is None + + +def test_coerce_paths_to_dlc_row_keys_for_projectless_folder(tmp_path): + source_root = tmp_path / "session_42" + source_root.mkdir() + + inside_abs = source_root / "img001.png" + nested_abs = source_root / "nested" / "img_nested.png" + nested_abs.parent.mkdir() + outside_abs = tmp_path / "elsewhere" / "img999.png" + outside_abs.parent.mkdir() + + rewritten, unresolved = coerce_paths_to_dlc_row_keys( + [ + inside_abs, + "img002.png", + "labeled-data\\session_42\\img003.png", + nested_abs, + outside_abs, + ], + source_root=source_root, + ) + + assert rewritten == [ + "labeled-data/session_42/img001.png", + "labeled-data/session_42/img002.png", + "labeled-data/session_42/img003.png", + nested_abs.as_posix(), + outside_abs.as_posix(), + ] + assert unresolved == (3, 4) + + +def test_target_dataset_folder_and_existing_files_guard(tmp_path): + project = tmp_path / "proj" + project.mkdir() + cfg = project / "config.yaml" + cfg.write_text("scorer: John\n", encoding="utf-8") + + target = target_dataset_folder_for_config(cfg, dataset_name="session_42") + assert target == project / "labeled-data" / "session_42" + assert not dataset_folder_has_files(target) + + target.mkdir(parents=True) + assert not dataset_folder_has_files(target) + + (target / "img001.png").write_bytes(b"x") + assert dataset_folder_has_files(target) + + +def test_infer_dlc_project_from_config_returns_explicit_project_context(tmp_path): + project = tmp_path / "my-project" + project.mkdir() + + config_path = project / "config.yaml" + config_path.write_text("scorer: John\n", encoding="utf-8") + + ctx = infer_dlc_project_from_config(config_path) + + assert ctx.root_anchor == project + assert ctx.project_root == project + assert ctx.config_path == config_path + assert ctx.dataset_folder is None + + +def test_infer_dlc_project_from_config_rejects_invalid_path(tmp_path): + project = tmp_path / "my-project" + project.mkdir() + + bad_config = project / "not_config.yaml" + bad_config.write_text("scorer: John\n", encoding="utf-8") + + with pytest.raises(ValueError): + infer_dlc_project_from_config(bad_config) diff --git a/src/napari_deeplabcut/_tests/core/test_provenance.py b/src/napari_deeplabcut/_tests/core/test_provenance.py new file mode 100644 index 00000000..36856c54 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_provenance.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from napari_deeplabcut.config.models import AnnotationKind, IOProvenance +from napari_deeplabcut.core.errors import MissingProvenanceError, UnresolvablePathError +from napari_deeplabcut.core.provenance import ( + ensure_io_provenance, + normalize_provenance, + resolve_provenance_path, +) + +# ----------------------------------------------------------------------------- +# ensure_io_provenance +# ----------------------------------------------------------------------------- + + +def test_ensure_io_provenance_none_returns_none(): + assert ensure_io_provenance(None) is None + + +def test_ensure_io_provenance_accepts_model_instance(tmp_path: Path): + io = IOProvenance( + project_root=str(tmp_path), + source_relpath_posix="CollectedData_Jane.h5", + kind=AnnotationKind.GT, + dataset_key="keypoints", + ) + out = ensure_io_provenance(io) + assert out is io + + +def test_ensure_io_provenance_accepts_dict_with_enum_kind(tmp_path: Path): + payload = { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": "CollectedData_Jane.h5", + "kind": AnnotationKind.GT, # IMPORTANT: enum instance, not string + "dataset_key": "keypoints", + } + out = ensure_io_provenance(payload) + assert isinstance(out, IOProvenance) + assert out.kind == AnnotationKind.GT + assert out.source_relpath_posix == "CollectedData_Jane.h5" + + +def test_ensure_io_provenance_rejects_dict_with_string_kind(tmp_path: Path): + """ + Contract from provenance.py docstring: + runtime must carry AnnotationKind objects; strings invalid. + This is enforced by IOProvenance.kind strict=True. + """ + payload = { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": "CollectedData_Jane.h5", + "kind": "gt", # invalid at runtime by policy + "dataset_key": "keypoints", + } + with pytest.raises(MissingProvenanceError): + ensure_io_provenance(payload) + + +def test_ensure_io_provenance_rejects_invalid_kind_value(tmp_path: Path): + payload = { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": "CollectedData_Jane.h5", + "kind": "not-a-kind", + "dataset_key": "keypoints", + } + with pytest.raises(MissingProvenanceError): + ensure_io_provenance(payload) + + +def test_ensure_io_provenance_rejects_invalid_type(): + with pytest.raises(MissingProvenanceError): + ensure_io_provenance(["not", "a", "dict"]) # type: ignore[arg-type] + + +def test_ensure_io_provenance_rejects_missing_required_relpath(tmp_path: Path): + payload = { + "schema_version": 1, + "project_root": str(tmp_path), + # missing source_relpath_posix + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + } + # ensure_io_provenance validates; resolve_provenance_path is stricter about missing relpath + out = ensure_io_provenance(payload) + assert isinstance(out, IOProvenance) + assert out.source_relpath_posix is None + + +# ----------------------------------------------------------------------------- +# normalize_provenance +# ----------------------------------------------------------------------------- + + +def test_normalize_provenance_none_returns_none(): + assert normalize_provenance(None) is None + + +def test_normalize_provenance_converts_backslashes(tmp_path: Path): + io = IOProvenance( + project_root=str(tmp_path), + source_relpath_posix=r"labeled-data\test\CollectedData_Jane.h5", + kind=AnnotationKind.GT, + dataset_key="keypoints", + ) + out = normalize_provenance(io) + assert out is not None + assert out.source_relpath_posix == "labeled-data/test/CollectedData_Jane.h5" + + +# ----------------------------------------------------------------------------- +# resolve_provenance_path +# ----------------------------------------------------------------------------- + + +def test_resolve_provenance_path_uses_root_anchor_when_provided(tmp_path: Path): + # Two valid roots + anchor = tmp_path / "anchor" + anchor.mkdir() + other_root = tmp_path / "other_root" + other_root.mkdir() + + # File exists ONLY under anchor + (anchor / "CollectedData_Jane.h5").write_bytes(b"dummy") + + io = { + "schema_version": 1, + "project_root": str(other_root), # valid dir, but not where file exists + "source_relpath_posix": "CollectedData_Jane.h5", + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + } + + resolved = resolve_provenance_path(io, root_anchor=anchor) + assert resolved == anchor / "CollectedData_Jane.h5" + + +def test_resolve_provenance_path_uses_project_root_when_root_anchor_missing(tmp_path: Path): + root = tmp_path / "root" + root.mkdir() + (root / "CollectedData_Jane.h5").write_bytes(b"dummy") + + io = { + "schema_version": 1, + "project_root": str(root), + "source_relpath_posix": "CollectedData_Jane.h5", + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + } + + resolved = resolve_provenance_path(io, root_anchor=None) + assert resolved == root / "CollectedData_Jane.h5" + + +def test_resolve_provenance_path_requires_source_relpath_posix(tmp_path: Path): + payload = { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": None, + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + } + with pytest.raises(MissingProvenanceError): + resolve_provenance_path(payload) + + +def test_resolve_provenance_path_requires_anchor_or_project_root(tmp_path: Path): + payload = { + "schema_version": 1, + "project_root": None, + "source_relpath_posix": "CollectedData_Jane.h5", + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + } + with pytest.raises(UnresolvablePathError): + resolve_provenance_path(payload, root_anchor=None) + + +def test_resolve_provenance_path_raises_if_missing_by_default(tmp_path: Path): + payload = { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": "CollectedData_Jane.h5", + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + } + with pytest.raises(UnresolvablePathError): + resolve_provenance_path(payload, allow_missing=False) + + +def test_resolve_provenance_path_allows_missing_when_flag_true(tmp_path: Path): + payload = { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": "CollectedData_Jane.h5", + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + } + resolved = resolve_provenance_path(payload, allow_missing=True) + assert resolved == tmp_path / "CollectedData_Jane.h5" + + +def test_resolve_provenance_path_normalizes_backslashes(tmp_path: Path): + # Create expected file + (tmp_path / "labeled-data").mkdir() + (tmp_path / "labeled-data" / "CollectedData_Jane.h5").write_bytes(b"dummy") + + payload = { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": r"labeled-data\CollectedData_Jane.h5", + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + } + resolved = resolve_provenance_path(payload) + assert resolved == tmp_path / "labeled-data" / "CollectedData_Jane.h5" diff --git a/src/napari_deeplabcut/_tests/core/test_reader_layerdata_contract.py b/src/napari_deeplabcut/_tests/core/test_reader_layerdata_contract.py new file mode 100644 index 00000000..4cd165d8 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_reader_layerdata_contract.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd + +from napari_deeplabcut.config.models import AnnotationKind +from napari_deeplabcut.core.io import read_hdf_single + + +def _write_minimal_h5(path: Path, scorer: str, with_likelihood: bool = False, all_nan: bool = True): + # Create a single-row DLC-style dataframe with x/y(/likelihood) + coords = ["x", "y"] + (["likelihood"] if with_likelihood else []) + cols = pd.MultiIndex.from_product([[scorer], ["bp1"], coords], names=["scorer", "bodyparts", "coords"]) + if all_nan: + row = [np.nan] * len(coords) + else: + row = [10.0, 20.0] + ([0.9] if with_likelihood else []) + df = pd.DataFrame([row], index=["img000.png"], columns=cols) + path.parent.mkdir(parents=True, exist_ok=True) + df.to_hdf(path, key="keypoints", mode="w") + return df + + +def _assert_properties_match_data_len(meta: dict, n: int): + props = meta.get("properties") or {} + assert isinstance(props, dict) + for k, v in props.items(): + # v can be list, np array, pandas series + ln = len(v) + assert ln == n, f"Property '{k}' has len={ln} but data has n={n}" + + +def test_read_hdf_single_machine_all_nan_returns_empty_points_layer(tmp_path: Path): + h5 = tmp_path / "machinelabels-iter0.h5" + _write_minimal_h5(h5, scorer="machine", with_likelihood=False, all_nan=True) + + layers = read_hdf_single(h5, kind=AnnotationKind.MACHINE) + assert len(layers) == 1 + + data, meta, layer_type = layers[0] + assert layer_type == "points" + assert np.asarray(data).shape[1] == 3 + + # Expect: no finite coords => empty points + assert np.asarray(data).shape[0] == 0 + + # Properties must match data length (0) + _assert_properties_match_data_len(meta, 0) + + +def test_read_hdf_single_gt_with_point_produces_one_point_and_matching_properties(tmp_path: Path): + h5 = tmp_path / "CollectedData_John.h5" + _write_minimal_h5(h5, scorer="John", with_likelihood=False, all_nan=False) + + layers = read_hdf_single(h5, kind=AnnotationKind.GT) + data, meta, layer_type = layers[0] + + assert np.asarray(data).shape[0] == 1 + _assert_properties_match_data_len(meta, 1) + assert meta["properties"]["label"][0] == "bp1" + + +def test_read_hdf_single_filters_data_and_properties_consistently(tmp_path: Path): + # Two "rows" will appear after stack: simulate one finite, one non-finite in the stacked frame. + # Easiest: two bodyparts, one is NaN, one is finite. + h5 = tmp_path / "CollectedData_John.h5" + cols = pd.MultiIndex.from_product([["John"], ["bp1", "bp2"], ["x", "y"]], names=["scorer", "bodyparts", "coords"]) + df = pd.DataFrame([[10.0, 20.0, np.nan, np.nan]], index=["img000.png"], columns=cols) + df.to_hdf(h5, key="keypoints", mode="w") + + layers = read_hdf_single(h5, kind=AnnotationKind.GT) + data, meta, _ = layers[0] + + # Only bp1 is finite => one point + assert np.asarray(data).shape[0] == 1 + assert meta["properties"]["label"] == ["bp1"] + assert len(meta["properties"]["id"]) == 1 + + +def test_reader_never_returns_nonempty_data_with_empty_properties(tmp_path: Path): + h5 = tmp_path / "machinelabels-iter0.h5" + _write_minimal_h5(h5, scorer="machine", all_nan=True) + + data, meta, _ = read_hdf_single(h5, kind=AnnotationKind.MACHINE)[0] + n = np.asarray(data).shape[0] + props = meta.get("properties") or {} + if n > 0: + assert all(len(v) == n for v in props.values()) diff --git a/src/napari_deeplabcut/_tests/core/test_remap.py b/src/napari_deeplabcut/_tests/core/test_remap.py new file mode 100644 index 00000000..869864ad --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_remap.py @@ -0,0 +1,303 @@ +import logging + +import numpy as np + +from napari_deeplabcut.core.project_paths import PathMatchPolicy +from napari_deeplabcut.core.remap import ( + build_frame_index_map, + remap_layer_data_by_paths, + remap_time_indices, +) + + +def test_build_frame_index_map_depth3_overlap_and_reorder(): + """ + If canonicalization depth=3 yields overlap, and new_paths order differs, + build_frame_index_map must produce correct old_idx -> new_idx mapping. + """ + old_paths = [ + "projA/labeled-data/test/img000.png", + "projA/labeled-data/test/img001.png", + ] + # reverse order in new paths + new_paths = [ + "projB/labeled-data/test/img001.png", + "projB/labeled-data/test/img000.png", + ] + + idx_map, depth = build_frame_index_map( + old_paths=old_paths, + new_paths=new_paths, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert depth == 3 + assert idx_map == {0: 1, 1: 0} + + +def test_build_frame_index_map_falls_back_to_depth2_when_depth3_has_no_overlap(): + """ + Construct paths such that: + - last 3 components differ (no overlap at n=3) + - last 2 components match (overlap at n=2) + """ + old_paths = [ + "A/one/img000.png", # last3 => A/one/img000.png + "A/one/img001.png", + ] + new_paths = [ + "B/one/img000.png", # last3 => B/one/img000.png (no overlap) + "B/one/img001.png", + ] + idx_map, depth = build_frame_index_map( + old_paths=old_paths, + new_paths=new_paths, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert depth == 2 + assert idx_map == {0: 0, 1: 1} + + +def test_remap_layer_data_by_paths_array_like_points_time_col0(): + """ + Array-like data (Points/Tracks) should have its time column remapped. + """ + old_paths = ["x/y/img0.png", "x/y/img1.png"] + new_paths = ["x/y/img1.png", "x/y/img0.png"] # swap order -> mapping {0:1,1:0} + + # data columns: (time, y, x) + data = np.array( + [ + [0.0, 10.0, 20.0], + [1.0, 11.0, 21.0], + ], + dtype=float, + ) + + res = remap_layer_data_by_paths( + data=data, + old_paths=old_paths, + new_paths=new_paths, + time_col=0, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert res.depth_used in (3, 2, 1) # depends on how many components are present + assert res.changed is True + assert res.data is not None + assert np.array_equal(res.data[:, 0].astype(int), np.array([1, 0])) + + +def test_remap_layer_data_by_paths_array_like_tracks_time_col1(): + """ + For Tracks layers, time column is typically column 1. + Ensure remap works for arbitrary time_col. + """ + old_paths = ["root/seq/img0.png", "root/seq/img1.png"] + new_paths = ["root/seq/img1.png", "root/seq/img0.png"] + + # tracks-like: (track_id, time, y, x) + data = np.array( + [ + [5.0, 0.0, 10.0, 20.0], + [5.0, 1.0, 11.0, 21.0], + ], + dtype=float, + ) + + res = remap_layer_data_by_paths( + data=data, + old_paths=old_paths, + new_paths=new_paths, + time_col=1, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert res.changed is True + assert res.data is not None + assert np.array_equal(res.data[:, 1].astype(int), np.array([1, 0])) + # track_id must remain unchanged + assert np.array_equal(res.data[:, 0].astype(int), np.array([5, 5])) + + +def test_remap_layer_data_by_paths_list_like_shapes(): + """ + Shapes-like data is a list of arrays; remap should apply per-vertex array. + """ + old_paths = ["p/q/img0.png", "p/q/img1.png"] + new_paths = ["p/q/img1.png", "p/q/img0.png"] + + shapes = [ + np.array( + [ + [0.0, 10.0, 20.0], + [1.0, 11.0, 21.0], + ], + dtype=float, + ) + ] + + res = remap_layer_data_by_paths( + data=shapes, + old_paths=old_paths, + new_paths=new_paths, + time_col=0, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert res.changed is True + assert isinstance(res.data, list) + assert np.array_equal(np.asarray(res.data[0])[:, 0].astype(int), np.array([1, 0])) + + +def test_no_overlap_returns_no_change_and_no_data(): + old_paths = ["a/b/c/img0.png"] + new_paths = ["x/y/z/other.png"] + + data = np.array([[0.0, 1.0, 2.0]], dtype=float) + + res = remap_layer_data_by_paths( + data=data, + old_paths=old_paths, + new_paths=new_paths, + time_col=0, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert res.changed is False + assert res.data is None + assert res.depth_used is None + assert "No overlap" in res.message or "skipping" in res.message.lower() + + +def test_already_aligned_paths_skips_remap(): + """ + If canonicalized old_keys == new_keys, remap_layer_data_by_paths should no-op. + """ + paths = ["labeled-data/test/img0.png", "labeled-data/test/img1.png"] + + data = np.array( + [ + [0.0, 10.0, 20.0], + [1.0, 11.0, 21.0], + ], + dtype=float, + ) + + res = remap_layer_data_by_paths( + data=data, + old_paths=paths, + new_paths=paths, + time_col=0, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert res.changed is False + assert res.data is None + assert res.depth_used is not None # a matching depth exists + assert "already aligned" in res.message.lower() or "no remap needed" in res.message.lower() + + +def test_missing_old_or_new_paths_skips(): + data = np.array([[0.0, 1.0, 2.0]], dtype=float) + + res1 = remap_layer_data_by_paths(data=data, old_paths=[], new_paths=["a/b/c.png"], time_col=0) + assert res1.changed is False + assert res1.data is None + + res2 = remap_layer_data_by_paths(data=data, old_paths=["a/b/c.png"], new_paths=[], time_col=0) + assert res2.changed is False + assert res2.data is None + + +def test_remap_time_indices_leaves_unmapped_indices_unchanged(): + """ + If idx_map doesn't contain a value, it should remain unchanged. + """ + idx_map = {0: 10} # only frame 0 remaps + data = np.array( + [ + [0.0, 10.0, 20.0], + [1.0, 11.0, 21.0], # frame 1 unmapped -> should stay 1 + ], + dtype=float, + ) + + res = remap_time_indices(data=data, time_col=0, idx_map=idx_map) + + assert res.changed is True + assert res.data is not None + assert np.array_equal(res.data[:, 0].astype(int), np.array([10, 1])) + + +def test_remap_time_indices_gracefully_handles_empty_and_none(): + res_none = remap_time_indices(data=None, time_col=0, idx_map={0: 1}) + assert res_none.changed is False + assert res_none.data is None + + res_empty = remap_time_indices(data=np.array([]), time_col=0, idx_map={0: 1}) + assert res_empty.changed is False + assert res_empty.data is None + + +def test_remap_warns_on_duplicate_canonical_keys(caplog): + caplog.set_level(logging.WARNING, logger="napari_deeplabcut.core.remap") + + # Depth=2 canonical keys will be: + # old_keys = ["dup/img0.png", "dup/img0.png", "dup/img1.png"] + # new_keys = ["dup/img1.png", "dup/img0.png", "dup/img0.png"] + # -> not equal => remap path is taken + old_paths = [ + "A/dup/img0.png", + "B/dup/img0.png", + "C/dup/img1.png", + ] + new_paths = [ + "X/dup/img1.png", + "Y/dup/img0.png", + "Z/dup/img0.png", + ] + + data = np.array( + [ + [0.0, 1.0, 2.0], + [1.0, 3.0, 4.0], + [2.0, 5.0, 6.0], + ], + dtype=float, + ) + + res = remap_layer_data_by_paths( + data=data, + old_paths=old_paths, + new_paths=new_paths, + time_col=0, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert "Remap may be ambiguous/risky" in caplog.text + assert "Duplicate canonical keys" in caplog.text + assert any("Duplicate canonical keys" in w for w in res.warnings) + + +def test_remap_warns_on_low_overlap_ratio(caplog): + caplog.set_level(logging.WARNING) + + old_paths = ["x/y/img0.png", "x/y/img1.png", "x/y/img2.png", "x/y/img3.png"] + # same length, only one overlapping key + new_paths = ["x/y/img0.png", "x/y/img9.png", "x/y/img8.png", "x/y/img7.png"] + + data = np.array([[0.0, 1.0, 2.0]], dtype=float) + + res = remap_layer_data_by_paths( + data=data, + old_paths=old_paths, + new_paths=new_paths, + time_col=0, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) + + assert "Low path overlap ratio" in caplog.text + assert "Low mapping coverage" in caplog.text + assert any("Low path overlap ratio" in w for w in res.warnings) diff --git a/src/napari_deeplabcut/_tests/core/test_sidecar.py b/src/napari_deeplabcut/_tests/core/test_sidecar.py new file mode 100644 index 00000000..ca24dd56 --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_sidecar.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from napari_deeplabcut.config.models import FolderUIState, TrailsDisplayConfig +from napari_deeplabcut.core.sidecar import ( + _migrate_sidecar_payload, + get_default_scorer, + get_trails_config, + read_sidecar_state, + set_default_scorer, + set_trails_config, + sidecar_path, + update_sidecar_state, + update_trails_config, + write_sidecar_state, +) + + +def test_sidecar_path_joins_anchor_and_filename(tmp_path: Path): + p = sidecar_path(tmp_path) + assert p == tmp_path / ".napari-deeplabcut.json" + + +def test_read_sidecar_state_missing_file_returns_defaults(tmp_path: Path): + state = read_sidecar_state(tmp_path) + + assert isinstance(state, FolderUIState) + assert state.schema_version == 1 + assert state.default_scorer is None + assert state.trails == TrailsDisplayConfig() + + +def test_read_sidecar_state_invalid_json_returns_defaults(tmp_path: Path): + p = sidecar_path(tmp_path) + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text("{not valid json", encoding="utf-8") + + state = read_sidecar_state(tmp_path) + + assert state.schema_version == 1 + assert state.default_scorer is None + assert state.trails == TrailsDisplayConfig() + + +def test_read_sidecar_state_non_dict_json_returns_defaults(tmp_path: Path): + p = sidecar_path(tmp_path) + p.write_text(json.dumps([1, 2, 3]), encoding="utf-8") + + state = read_sidecar_state(tmp_path) + assert state == FolderUIState() + + +def test_migrate_sidecar_payload_normalizes_schema_version(): + payload = {"default_scorer": "John"} + migrated = _migrate_sidecar_payload(payload) + + assert migrated["schema_version"] == 1 + assert migrated["default_scorer"] == "John" + + +def test_write_and_read_sidecar_state_roundtrip(tmp_path: Path): + state = FolderUIState( + default_scorer="John", + trails=TrailsDisplayConfig( + tail_length=70, + head_length=10, + tail_width=4.25, + opacity=0.6, + blending="opaque", + visible=False, + ), + ) + + write_sidecar_state(tmp_path, state) + reloaded = read_sidecar_state(tmp_path) + + assert reloaded.schema_version == 1 + assert reloaded.default_scorer == "John" + assert reloaded.trails == state.trails + + +def test_update_sidecar_state_shallow_patch_preserves_other_fields(tmp_path: Path): + initial = FolderUIState( + default_scorer="John", + trails=TrailsDisplayConfig(tail_length=80, head_length=20, tail_width=5.0), + ) + write_sidecar_state(tmp_path, initial) + + updated = update_sidecar_state(tmp_path, default_scorer="Jane") + + assert updated.default_scorer == "Jane" + assert updated.trails.tail_length == 80 + assert updated.trails.head_length == 20 + assert updated.trails.tail_width == 5.0 + + +def test_get_and_set_default_scorer_roundtrip(tmp_path: Path): + assert get_default_scorer(tmp_path) is None + + set_default_scorer(tmp_path, "John") + assert get_default_scorer(tmp_path) == "John" + + +def test_set_default_scorer_rejects_empty(tmp_path: Path): + with pytest.raises(ValueError, match="default_scorer must be non-empty"): + set_default_scorer(tmp_path, " ") + + +def test_get_trails_config_returns_defaults_when_absent(tmp_path: Path): + cfg = get_trails_config(tmp_path) + assert cfg == TrailsDisplayConfig() + + +def test_set_trails_config_roundtrip(tmp_path: Path): + cfg = TrailsDisplayConfig( + tail_length=90, + head_length=15, + tail_width=3.75, + opacity=0.85, + blending="minimum", + visible=False, + ) + + set_trails_config(tmp_path, cfg) + reloaded = get_trails_config(tmp_path) + + assert reloaded == cfg + + +def test_set_trails_config_accepts_dict_payload(tmp_path: Path): + set_trails_config( + tmp_path, + { + "tail_length": 65, + "head_length": 11, + "tail_width": 2.5, + "opacity": 0.7, + "blending": "opaque", + "visible": False, + }, + ) + + cfg = get_trails_config(tmp_path) + assert cfg == TrailsDisplayConfig( + tail_length=65, + head_length=11, + tail_width=2.5, + opacity=0.7, + blending="opaque", + visible=False, + ) + + +def test_update_trails_config_patches_only_requested_fields(tmp_path: Path): + set_trails_config( + tmp_path, + TrailsDisplayConfig( + tail_length=50, + head_length=50, + tail_width=6.0, + opacity=1.0, + blending="translucent", + visible=True, + ), + ) + + updated = update_trails_config(tmp_path, tail_length=120, visible=False) + + assert updated == TrailsDisplayConfig( + tail_length=120, + head_length=50, + tail_width=6.0, + opacity=1.0, + blending="translucent", + visible=False, + ) + + reloaded = get_trails_config(tmp_path) + assert reloaded == updated + + +def test_sidecar_json_written_with_schema_version(tmp_path: Path): + set_default_scorer(tmp_path, "John") + + raw = json.loads(sidecar_path(tmp_path).read_text(encoding="utf-8")) + assert raw["schema_version"] == 1 + assert raw["default_scorer"] == "John" + + +def test_read_sidecar_state_old_payload_gets_default_trails(tmp_path: Path): + sidecar_path(tmp_path).write_text( + json.dumps({"default_scorer": "John"}), + encoding="utf-8", + ) + + state = read_sidecar_state(tmp_path) + assert state.schema_version == 1 + assert state.default_scorer == "John" + assert state.trails == TrailsDisplayConfig() diff --git a/src/napari_deeplabcut/_tests/core/test_trails.py b/src/napari_deeplabcut/_tests/core/test_trails.py new file mode 100644 index 00000000..f05c983e --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_trails.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +import numpy as np +import pytest +from napari.layers import Points, Tracks + +from napari_deeplabcut.config.models import TrailsDisplayConfig +from napari_deeplabcut.core import keypoints +from napari_deeplabcut.core.layer_versioning import layer_change_generations +from napari_deeplabcut.core.trails import ( + _trails_rgba_array, + active_trails_color_property, + build_trails_payload, + categorical_colormap_from_points_layer, + display_config_from_tracks_layer, + is_multianimal_points_layer, + tracks_kwargs_from_display_config, + trails_geometry_signature, + trails_signature, + trails_track_ids, +) + + +def _make_points( + data: np.ndarray, + *, + labels=None, + ids=None, + face_color_cycles=None, + colormap_name="viridis", +) -> Points: + properties = {} + if labels is not None: + properties["label"] = np.asarray(labels, dtype=object) + if ids is not None and len(ids) > 0: + properties["id"] = np.asarray(ids, dtype=object) + + metadata = { + "colormap_name": colormap_name, + "face_color_cycles": face_color_cycles or {}, + } + + return Points( + data=np.asarray(data, dtype=float), + properties=properties, + metadata=metadata, + name="points", + ) + + +@pytest.fixture +def single_points_layer(): + data = np.array( + [ + [0, 10, 20], + [1, 11, 21], + [2, 12, 22], + [3, 13, 23], + ], + dtype=float, + ) + labels = ["nose", "tail", "nose", "tail"] + ids = ["", "", "", ""] + face_color_cycles = { + "label": { + "nose": [1.0, 0.0, 0.0, 1.0], + "tail": [0.0, 1.0, 0.0, 1.0], + }, + "id": { + "": [0.3, 0.3, 0.3, 1.0], + }, + } + return _make_points( + data, + labels=labels, + ids=ids, + face_color_cycles=face_color_cycles, + colormap_name="magma", + ) + + +@pytest.fixture +def multi_points_layer(): + data = np.array( + [ + [0, 10, 20], + [0, 12, 22], + [0, 14, 24], + [0, 16, 26], + [1, 11, 21], + [1, 13, 23], + ], + dtype=float, + ) + labels = ["nose", "tail", "nose", "tail", "nose", "tail"] + ids = ["mouseA", "mouseA", "mouseB", "mouseB", "mouseA", "mouseA"] + face_color_cycles = { + "label": { + "nose": [1.0, 0.0, 0.0, 1.0], + "tail": [0.0, 1.0, 0.0, 1.0], + }, + "id": { + "mouseA": [0.2, 0.4, 0.6, 1.0], + "mouseB": [0.8, 0.6, 0.2, 1.0], + }, + } + return _make_points( + data, + labels=labels, + ids=ids, + face_color_cycles=face_color_cycles, + colormap_name="plasma", + ) + + +@pytest.fixture +def three_id_points_layer(): + data = np.array( + [ + [0, 10, 20], + [0, 12, 22], + [0, 14, 24], + [1, 11, 21], + [1, 13, 23], + [1, 15, 25], + ], + dtype=float, + ) + labels = ["nose", "nose", "nose", "nose", "nose", "nose"] + ids = ["mouseA", "mouseB", "mouseC", "mouseA", "mouseB", "mouseC"] + face_color_cycles = { + "label": { + "nose": [0.9, 0.1, 0.1, 1.0], + }, + "id": { + "mouseA": [1.0, 0.0, 0.0, 1.0], + "mouseB": [0.0, 1.0, 0.0, 1.0], + "mouseC": [0.0, 0.0, 1.0, 1.0], + }, + } + return _make_points( + data, + labels=labels, + ids=ids, + face_color_cycles=face_color_cycles, + colormap_name="viridis", + ) + + +@pytest.fixture +def no_label_points_layer(): + data = np.array([[0, 1, 2], [1, 2, 3]], dtype=float) + return _make_points( + data, + labels=None, + ids=["animal1", "animal1"], + face_color_cycles={}, + ) + + +@pytest.fixture +def tracks_layer(): + data = np.array( + [ + [0, 0, 10, 20], + [0, 1, 11, 21], + [1, 0, 30, 40], + ], + dtype=float, + ) + layer = Tracks( + data, + tail_length=12, + head_length=7, + tail_width=3.5, + opacity=0.4, + blending="opaque", + name="trails", + ) + layer.visible = False + return layer + + +def test_trails_signature_contains_expected_fields(single_points_layer): + sig = trails_signature(single_points_layer, keypoints.ColorMode.BODYPART) + generations = layer_change_generations(single_points_layer) + + assert sig[0] == id(single_points_layer) + assert sig[1] == str(keypoints.ColorMode.BODYPART) + assert sig[2] == "magma" + assert sig[3] == generations.content + assert sig[4] == generations.presentation + + +def test_trails_geometry_signature_contains_shape_and_properties(single_points_layer): + sig = trails_geometry_signature(single_points_layer) + generations = layer_change_generations(single_points_layer) + + assert sig[0] == id(single_points_layer) + assert sig[1] == generations.content + + +@pytest.mark.parametrize( + ("ids", "expected"), + [ + (["animal1", "animal1"], True), + (["", ""], False), + ([1, 2], False), + (None, False), + ([], False), + ], +) +def test_is_multianimal_points_layer(ids, expected): + layer = _make_points( + np.array([[0, 1, 2], [1, 2, 3]], dtype=float), + labels=["nose", "tail"], + ids=ids, + face_color_cycles={}, + ) + assert is_multianimal_points_layer(layer) is expected + + +def test_active_trails_color_property_individual_mode_multi(multi_points_layer): + color_prop, categories, is_multi = active_trails_color_property( + multi_points_layer, + keypoints.ColorMode.INDIVIDUAL, + ) + + assert color_prop == "id" + assert is_multi is True + np.testing.assert_array_equal( + categories, + np.array(["mouseA", "mouseA", "mouseB", "mouseB", "mouseA", "mouseA"], dtype=object), + ) + + +def test_active_trails_color_property_bodypart_mode_multi(multi_points_layer): + color_prop, categories, is_multi = active_trails_color_property( + multi_points_layer, + keypoints.ColorMode.BODYPART, + ) + + assert color_prop == "label" + assert is_multi is True + np.testing.assert_array_equal( + categories, + np.array(["nose", "tail", "nose", "tail", "nose", "tail"], dtype=object), + ) + + +def test_active_trails_color_property_individual_mode_single_falls_back_to_label(single_points_layer): + color_prop, categories, is_multi = active_trails_color_property( + single_points_layer, + keypoints.ColorMode.INDIVIDUAL, + ) + + assert color_prop == "label" + assert is_multi is False + np.testing.assert_array_equal( + categories, + np.array(["nose", "tail", "nose", "tail"], dtype=object), + ) + + +def test_active_trails_color_property_raises_without_labels(no_label_points_layer): + with pytest.raises(ValueError, match="no 'label' property"): + active_trails_color_property(no_label_points_layer, keypoints.ColorMode.BODYPART) + + +def test_rgba_array_converts_rgb_rgba_and_scalar(): + arr = _trails_rgba_array( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.5], + 7.0, + ] + ) + + expected = np.array( + [ + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 0.5], + [0.5, 0.5, 0.5, 1.0], + ], + dtype=float, + ) + np.testing.assert_allclose(arr, expected) + + +def test_categorical_colormap_from_points_layer_uses_face_color_cycles(multi_points_layer): + categories = np.array(["mouseB", "mouseA", "mouseB"], dtype=object) + + cmap, uniq_color, codes_norm = categorical_colormap_from_points_layer( + multi_points_layer, + "id", + categories, + ) + + assert uniq_color == ["mouseB", "mouseA"] + np.testing.assert_allclose(codes_norm, np.array([0.25, 0.75, 0.25])) + assert cmap.name == "id_categorical" + assert cmap.interpolation == "zero" + + expected_colors = np.array( + [ + [0.8, 0.6, 0.2, 1.0], + [0.2, 0.4, 0.6, 1.0], + ], + dtype=float, + ) + np.testing.assert_allclose(np.asarray(cmap.colors), expected_colors) + np.testing.assert_allclose(np.asarray(cmap.controls), np.array([0.0, 0.5, 1.0])) + + +def test_categorical_colormap_from_points_layer_prefers_cycle_override(multi_points_layer): + categories = np.array(["mouseB", "mouseA", "mouseB"], dtype=object) + override = { + "mouseA": [0.11, 0.22, 0.33, 1.0], + "mouseB": [0.44, 0.55, 0.66, 1.0], + } + + cmap, uniq_color, codes_norm = categorical_colormap_from_points_layer( + multi_points_layer, + "id", + categories, + cycle_override=override, + ) + + assert uniq_color == ["mouseB", "mouseA"] + np.testing.assert_allclose(codes_norm, np.array([0.25, 0.75, 0.25])) + expected_colors = np.array( + [ + [0.44, 0.55, 0.66, 1.0], + [0.11, 0.22, 0.33, 1.0], + ], + dtype=float, + ) + np.testing.assert_allclose(np.asarray(cmap.colors), expected_colors) + + +def test_categorical_colormap_from_points_layer_single_category_duplicates_color(single_points_layer): + categories = np.array(["nose", "nose", "nose"], dtype=object) + + cmap, uniq_color, codes_norm = categorical_colormap_from_points_layer( + single_points_layer, + "label", + categories, + ) + + assert uniq_color == ["nose"] + np.testing.assert_allclose(codes_norm, np.array([0.5, 0.5, 0.5])) + expected_colors = np.array( + [ + [1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + ], + dtype=float, + ) + np.testing.assert_allclose(np.asarray(cmap.colors), expected_colors) + np.testing.assert_allclose(np.asarray(cmap.controls), np.array([0.0, 0.5, 1.0])) + + +def test_build_trails_payload_three_individuals_maps_to_distinct_colors(three_id_points_layer): + payload = build_trails_payload(three_id_points_layer, keypoints.ColorMode.INDIVIDUAL) + + assert payload.color_by == "id_codes" + cmap = payload.colormaps_dict["id_codes"] + codes = payload.properties["id_codes"] + + mapped = np.asarray(cmap.map(codes[:3])) + unique_rows = np.unique(np.round(mapped, decimals=8), axis=0) + + assert unique_rows.shape[0] == 3 + + +def test_categorical_colormap_from_points_layer_falls_back_to_tab20(monkeypatch, multi_points_layer): + class DummyCmap: + colors = [ + (0.11, 0.22, 0.33), + (0.44, 0.55, 0.66), + (0.77, 0.88, 0.99), + ] + + def fake_get_cmap(name): + assert name == "tab20" + return DummyCmap() + + monkeypatch.setattr("napari_deeplabcut.core.trails.plt.get_cmap", fake_get_cmap) + + categories = np.array(["missingA", "missingB", "missingA"], dtype=object) + cmap, uniq_color, codes_norm = categorical_colormap_from_points_layer( + multi_points_layer, + "id", + categories, + ) + + assert uniq_color == ["missingA", "missingB"] + np.testing.assert_allclose(codes_norm, np.array([0.25, 0.75, 0.25])) + expected_colors = np.array( + [ + [0.11, 0.22, 0.33, 1.0], + [0.44, 0.55, 0.66, 1.0], + ], + dtype=float, + ) + np.testing.assert_allclose(np.asarray(cmap.colors), expected_colors) + + +def test_categorical_colormap_from_points_layer_raises_on_empty_categories(single_points_layer): + with pytest.raises(ValueError, match="No categories found"): + categorical_colormap_from_points_layer( + single_points_layer, + "label", + np.array([], dtype=object), + ) + + +def test_trails_track_ids_single_animal_groups_by_label(single_points_layer): + track_ids = trails_track_ids(single_points_layer, is_multi=False) + np.testing.assert_array_equal(track_ids, np.array([0, 1, 0, 1])) + + +def test_trails_track_ids_multi_animal_groups_by_id_and_label(multi_points_layer): + track_ids = trails_track_ids(multi_points_layer, is_multi=True) + np.testing.assert_array_equal(track_ids, np.array([0, 1, 2, 3, 0, 1])) + + +def test_trails_track_ids_raises_without_labels(no_label_points_layer): + with pytest.raises(ValueError, match="no 'label' property"): + trails_track_ids(no_label_points_layer, is_multi=True) + + +def test_build_trails_payload_multi_individual_mode(multi_points_layer): + payload = build_trails_payload(multi_points_layer, keypoints.ColorMode.INDIVIDUAL) + + assert payload.color_by == "id_codes" + assert set(payload.properties) == {"id_codes"} + assert set(payload.colormaps_dict) == {"id_codes"} + assert payload.signature == trails_signature(multi_points_layer, keypoints.ColorMode.INDIVIDUAL) + assert payload.geometry_signature == trails_geometry_signature(multi_points_layer) + + np.testing.assert_array_equal(payload.tracks_data[:, 0], np.array([0, 1, 2, 3, 0, 1])) + np.testing.assert_allclose(payload.tracks_data[:, 1:], multi_points_layer.data) + np.testing.assert_allclose( + payload.properties["id_codes"], + np.array([0.25, 0.25, 0.75, 0.75, 0.25, 0.25]), + ) + + +def test_build_trails_payload_single_individual_mode_falls_back_to_label(single_points_layer): + payload = build_trails_payload(single_points_layer, keypoints.ColorMode.INDIVIDUAL) + + assert payload.color_by == "label_codes" + assert set(payload.properties) == {"label_codes"} + assert set(payload.colormaps_dict) == {"label_codes"} + assert payload.signature == trails_signature(single_points_layer, keypoints.ColorMode.INDIVIDUAL) + assert payload.geometry_signature == trails_geometry_signature(single_points_layer) + + np.testing.assert_array_equal(payload.tracks_data[:, 0], np.array([0, 1, 0, 1])) + np.testing.assert_allclose(payload.tracks_data[:, 1:], single_points_layer.data) + np.testing.assert_allclose(payload.properties["label_codes"], np.array([0.25, 0.75, 0.25, 0.75])) + + +def test_tracks_kwargs_from_display_config_excludes_visible(): + cfg = TrailsDisplayConfig( + tail_length=70, + head_length=12, + tail_width=4.5, + opacity=0.75, + blending="opaque", + visible=False, + ) + + kwargs = tracks_kwargs_from_display_config(cfg) + + assert kwargs == { + "tail_length": 70, + "head_length": 12, + "tail_width": 4.5, + "opacity": 0.75, + "blending": "opaque", + } + assert "visible" not in kwargs + + +def test_display_config_from_tracks_layer_reads_all_display_fields(tracks_layer): + cfg = display_config_from_tracks_layer(tracks_layer) + + assert cfg == TrailsDisplayConfig( + tail_length=12, + head_length=7, + tail_width=3.5, + opacity=0.4, + blending="opaque", + visible=False, + ) + + +def test_display_config_from_tracks_layer_visible_override(tracks_layer): + cfg = display_config_from_tracks_layer(tracks_layer, visible=True) + assert cfg.visible is True diff --git a/src/napari_deeplabcut/_tests/core/test_writer_promotion.py b/src/napari_deeplabcut/_tests/core/test_writer_promotion.py new file mode 100644 index 00000000..1a8edcee --- /dev/null +++ b/src/napari_deeplabcut/_tests/core/test_writer_promotion.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from napari_deeplabcut import _writer +from napari_deeplabcut.config.models import AnnotationKind, DLCHeaderModel +from napari_deeplabcut.core.errors import MissingProvenanceError +from napari_deeplabcut.ui import dialogs + + +def _make_minimal_points_metadata( + root: Path, header, *, name: str, kind: AnnotationKind, save_target: dict | None = None +): + # Minimal metadata payload compatible with form_df usage + md = { + "name": name, + "properties": { + "label": ["bodypart1", "bodypart2"], + "id": ["", ""], + "likelihood": [1.0, 1.0], + }, + "metadata": { + "header": header, + "paths": ["img000.png", "img000.png"], # indices 0 and 1 refer to same img for simplicity + "root": str(root), + "io": { + "schema_version": 1, + "project_root": str(root), + "source_relpath_posix": f"{name}.h5", + "kind": kind, + "dataset_key": "keypoints", + }, + }, + } + if save_target is not None: + md["metadata"]["save_target"] = save_target + return md + + +def _read_keypoints_h5(p: Path) -> pd.DataFrame: + return pd.read_hdf(p, key="keypoints") + + +def test_writer_aborts_if_machine_source_without_save_target(tmp_path: Path): + # Create minimal header for single animal: scorer/bodyparts/coords + cols = pd.MultiIndex.from_product( + [["John"], ["bodypart1", "bodypart2"], ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + header = DLCHeaderModel(columns=cols) + + metadata = _make_minimal_points_metadata( + tmp_path, header, name="machinelabels-iter0", kind=AnnotationKind.MACHINE, save_target=None + ) + + points = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 30.0, 40.0], + ], + dtype=float, + ) + + with pytest.raises(MissingProvenanceError): + _writer.write_hdf_napari_dlc("ignored.h5", points, metadata) + + +def test_writer_promotion_writes_collecteddata_and_rewrites_scorer(tmp_path: Path, monkeypatch): + # Build minimal header for single animal + cols = pd.MultiIndex.from_product( + [["machine"], ["bodypart1", "bodypart2"], ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + header = DLCHeaderModel(columns=cols) + + # Ensure overwrite confirm always returns True + monkeypatch.setattr(dialogs, "maybe_confirm_overwrite", lambda *args, **kwargs: True) + + # Pretend we loaded from a machine file but will promote to GT file CollectedData_Alice.h5 + save_target = { + "schema_version": 1, + "project_root": str(tmp_path), + "source_relpath_posix": "CollectedData_Alice.h5", + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + "scorer": "Alice", + } + + metadata = _make_minimal_points_metadata( + tmp_path, + header, + name="machinelabels-iter0", + kind=AnnotationKind.MACHINE, + save_target=save_target, + ) + + # Create a dummy machine file and snapshot it (writer must not touch it) + machine_path = tmp_path / "machinelabels-iter0.h5" + df_machine = pd.DataFrame(np.nan, columns=cols, index=["img000.png"]) + df_machine.to_hdf(machine_path, key="keypoints", mode="w") + machine_before = _read_keypoints_h5(machine_path) + + points = np.array( + [ + [0.0, 33.0, 44.0], # bodypart1 + [0.0, 55.0, 66.0], # bodypart2 + ], + dtype=float, + ) + + fnames = _writer.write_hdf_napari_dlc("ignored.h5", points, metadata) + assert Path(fnames[0]).name == "CollectedData_Alice.h5" + assert Path(fnames[1]).name == Path(fnames[0]).with_suffix(".csv").name + + gt_path = tmp_path / "CollectedData_Alice.h5" + assert gt_path.exists() + + df_gt = _read_keypoints_h5(gt_path) + + # Scorer level should be rewritten to Alice + assert "scorer" in df_gt.columns.names + assert set(df_gt.columns.get_level_values("scorer")) == {"Alice"} + + # Machine file should be unchanged + machine_after = _read_keypoints_h5(machine_path) + pd.testing.assert_frame_equal(machine_before, machine_after) diff --git a/src/napari_deeplabcut/_tests/e2e/__init__.py b/src/napari_deeplabcut/_tests/e2e/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/_tests/e2e/conftest.py b/src/napari_deeplabcut/_tests/e2e/conftest.py new file mode 100644 index 00000000..31060437 --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/conftest.py @@ -0,0 +1,134 @@ +# src/napari_deeplabcut/_tests/e2e/conftest.py +from __future__ import annotations + +from pathlib import Path + +import pytest +from qtpy.QtWidgets import QInputDialog, QMessageBox + + +def pytest_collection_modifyitems(config, items): + """ + Automatically mark all tests collected from this folder as @pytest.mark.e2e. + """ + here = Path(__file__).resolve().parent + + for item in items: + try: + item_path = Path(str(item.fspath)).resolve() + except Exception: + continue + + if here in item_path.parents or item_path == here: + item.add_marker(pytest.mark.e2e) + + +@pytest.fixture(autouse=True) +def _auto_accept_qmessagebox(monkeypatch): + """Prevent any QMessageBox modal dialogs from blocking tests.""" + monkeypatch.setattr(QMessageBox, "warning", lambda *args, **kwargs: QMessageBox.Yes) + monkeypatch.setattr(QMessageBox, "question", lambda *args, **kwargs: QMessageBox.Yes) + + +@pytest.fixture +def inputdialog(monkeypatch): + """ + Controller for QInputDialog.getText used by promotion-to-GT first save prompt. + """ + state = {"value": "Alice", "ok": True, "calls": 0, "forbid": False} + + def _fake_getText(*args, **kwargs): + state["calls"] += 1 + if state["forbid"]: + raise AssertionError("QInputDialog.getText was called but forbid=True") + return state["value"], state["ok"] + + monkeypatch.setattr(QInputDialog, "getText", _fake_getText) + + class Controller: + @property + def calls(self): + return state["calls"] + + def set(self, value: str, ok: bool = True): + state["value"] = value + state["ok"] = ok + return self + + def forbid(self): + state["forbid"] = True + return self + + return Controller() + + +@pytest.fixture(autouse=True) +def overwrite_confirm(monkeypatch): + """ + Control the overwrite-confirmation path used by the UI preflight. + + API: + - forbid(): fail test if confirmation is requested for a real overwrite + - cancel(): return False (simulate user cancel) + - capture(): record calls and return True + - set_result(bool): return chosen bool + - reset_calls(): clear recorded calls + """ + calls = [] + state = {"mode": "always_true", "result": True} + + def _patched_maybe_confirm_overwrite(parent, report): + n_pairs = getattr(report, "n_overwrites", 0) + n_images = getattr(report, "n_frames", None) + + calls.append( + { + "parent_type": type(parent).__name__ if parent is not None else None, + "layer_name": getattr(report, "layer_name", None), + "destination_path": getattr(report, "destination_path", None), + "n_pairs": n_pairs, + "n_images": n_images, + "details_text": getattr(report, "details_text", None), + } + ) + + # In "forbid" mode: only fail if there is a real overwrite. + if state["mode"] == "forbid" and n_pairs > 0: + raise AssertionError("maybe_confirm_overwrite was called unexpectedly for a real overwrite (n_pairs > 0).") + + return state["result"] + + import napari_deeplabcut.ui.dialogs as dlg + + monkeypatch.setattr(dlg, "maybe_confirm_overwrite", _patched_maybe_confirm_overwrite) + + class Controller: + @property + def calls(self): + return calls + + def forbid(self): + state["mode"] = "forbid" + state["result"] = True + return self + + def cancel(self): + state["mode"] = "capture" + state["result"] = False + return self + + def capture(self): + state["mode"] = "capture" + state["result"] = True + return self + + def set_result(self, value: bool): + state["mode"] = "capture" + state["result"] = bool(value) + return self + + def reset_calls(self): + calls.clear() + return self + + return Controller() diff --git a/src/napari_deeplabcut/_tests/e2e/test_e2e_misc.py b/src/napari_deeplabcut/_tests/e2e/test_e2e_misc.py new file mode 100644 index 00000000..a000013a --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/test_e2e_misc.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import pytest + +from .utils import assert_only_these_changed_nan_safe, sig_equal + +# -----------------------------------------------------------------------------# +# Assertions on signatures of written files, with NaN-stable equality +# This is required because in DLC h5s NaN means "unlabeled", +# so NaN to value changes are meaningful and should be detected, +# but NaN to NaN should be treated as unchanged (remains unlabeled). +# Below tests are meant to avoid any future regressions in this logic, +# which is critical for correct writer behavior and testability. +# -----------------------------------------------------------------------------# + + +def test_sig_equal_treats_nan_as_equal(): + a = {"b2x": float("nan"), "b2y": float("nan")} + b = {"b2x": float("nan"), "b2y": float("nan")} + assert sig_equal(a, b) + + +def test_sig_equal_detects_nan_to_value_change(): + a = {"b2x": float("nan")} + b = {"b2x": 77.0} + assert not sig_equal(a, b) + + +def test_sig_equal_detects_value_change(): + a = {"b1x": 10.0} + b = {"b1x": 11.0} + assert not sig_equal(a, b) + + +def test_assert_only_these_changed_nan_safe_passes_expected_case(tmp_path: Path): + p1 = tmp_path / "A.h5" + p2 = tmp_path / "B.h5" + + before = { + p1: {"b2x": float("nan")}, + p2: {"b2x": float("nan")}, + } + after = { + p1: {"b2x": float("nan")}, # unchanged + p2: {"b2x": 77.0}, # changed + } + + assert_only_these_changed_nan_safe(before, after, changed={p2}) + + +def test_assert_only_these_changed_nan_safe_fails_when_unexpected_change(tmp_path: Path): + p1 = tmp_path / "A.h5" + before = {p1: {"b2x": float("nan")}} + after = {p1: {"b2x": 1.0}} + + with pytest.raises(AssertionError): + assert_only_these_changed_nan_safe(before, after, changed=set()) diff --git a/src/napari_deeplabcut/_tests/e2e/test_layer_coloring.py b/src/napari_deeplabcut/_tests/e2e/test_layer_coloring.py new file mode 100644 index 00000000..685f5a8c --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/test_layer_coloring.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import numpy as np +import pytest +from napari.layers import Points +from qtpy.QtWidgets import QDockWidget + +from napari_deeplabcut.config.models import DLCHeaderModel + +from ..conftest import force_show +from .utils import _cycles_from_policy, _make_minimal_dlc_project, _scheme_from_policy + + +def _get_existing_keypoint_controls(viewer): + from napari_deeplabcut._widgets import KeypointControls + + matches = list(viewer.window._qt_window.findChildren(KeypointControls)) + assert matches, "Expected viewer fixture to provide a KeypointControls widget" + assert len(matches) == 1, f"Expected exactly one KeypointControls widget, found {len(matches)}" + + controls = matches[0] + + dock = controls.parentWidget() + while dock is not None and not isinstance(dock, QDockWidget): + dock = dock.parentWidget() + + assert dock is not None, "Expected KeypointControls to be docked in a QDockWidget" + return controls, dock + + +@pytest.mark.usefixtures("qtbot") +def test_config_placeholder_points_layer_colors_after_first_keypoint_added(viewer, qtbot, tmp_path): + """ + E2E regression: a Points layer created from config.yaml starts empty (placeholder). + When the user begins adding keypoints, the layer must switch into categorical + coloring (cycle mode) and colors must follow the derived bodypart policy. + """ + project, config_path, labeled_folder, h5_path = _make_minimal_dlc_project(tmp_path) + + from napari_deeplabcut.core import keypoints + + controls, controls_dock = _get_existing_keypoint_controls(viewer) + + # 1) Open config.yaml -> creates placeholder Points layer (empty) + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5000) + + placeholder = next((ly for ly in viewer.layers if isinstance(ly, Points)), None) + assert placeholder is not None + assert placeholder.data is None or len(placeholder.data) == 0 + + # 2) Placeholder should carry the policy input we rely on + md = placeholder.metadata or {} + assert "config_colormap" in md, "Expected config_colormap in metadata for config.yaml placeholder layer" + assert "header" in md, "Expected header in metadata for config.yaml placeholder layer" + + # 3) Begin editing: add bodypart1 then bodypart2 + store = controls._stores.get(placeholder) + assert store is not None, "Expected KeypointStore to be registered for placeholder Points layer" + + # Add first point + placeholder.add(np.array([0.0, 20.0, 10.0], dtype=float)) + qtbot.waitUntil(lambda: placeholder.data is not None and len(placeholder.data) == 1, timeout=2000) + + # Wait for recolor to activate cycle mode + qtbot.waitUntil(lambda: placeholder.face_color_mode == "cycle", timeout=5000) + + # Validate color matches derived policy for the actual stored label + label0 = str(placeholder.properties["label"][0]) + expected_cycles = _cycles_from_policy(placeholder) + expected0 = np.asarray(expected_cycles["label"][label0], dtype=float) + c0 = np.asarray(placeholder._face.colors[0], dtype=float) + assert np.allclose(c0, expected0, atol=1e-6), f"color mismatch for {label0!r}: got={c0}, expected={expected0}" + + # Ensure the second add targets a different bodypart. + hdr = placeholder.metadata.get("header") + assert hdr is not None, "Expected header in placeholder metadata" + + header_model = hdr if isinstance(hdr, DLCHeaderModel) else DLCHeaderModel.model_validate(hdr) + all_bodyparts = list(header_model.bodyparts) + assert all_bodyparts, "Header has no bodyparts; cannot drive second add deterministically." + + label_alt = next((bp for bp in all_bodyparts if str(bp) != label0), None) + assert label_alt is not None, f"Only one bodypart present; cannot add a second distinct keypoint. label0={label0!r}" + + placeholder.selected_data = set() + store.current_keypoint = keypoints.Keypoint(str(label_alt), "") + + placeholder.add(np.array([0.0, 33.0, 44.0], dtype=float)) + qtbot.waitUntil(lambda: placeholder.data is not None and len(placeholder.data) == 2, timeout=2000) + + label1 = str(placeholder.properties["label"][1]) + expected1 = np.asarray(expected_cycles["label"][label1], dtype=float) + c1 = np.asarray(placeholder._face.colors[1], dtype=float) + assert np.allclose(c1, expected1, atol=1e-6), f"color mismatch for {label1!r}: got={c1}, expected={expected1}" + + assert label0 != label1, f"Expected successive adds to label different keypoints, got {label0!r} then {label1!r}" + assert not np.allclose(c0, c1, atol=1e-6), "Expected distinct colors for different labels in cycle mode" + + +@pytest.mark.usefixtures("qtbot") +def test_config_placeholder_multianimal_colors_by_id_after_first_keypoint_added( + viewer, + qtbot, + multianimal_config_project, +): + """ + E2E regression: a Points layer created from a multi-animal config.yaml starts empty. + When the user adds keypoints, the layer must switch into categorical coloring + (cycle mode) and, in multi-animal mode, color by id according to the derived policy. + """ + _, config_path = multianimal_config_project + + from napari_deeplabcut.core import keypoints + + controls, controls_dock = _get_existing_keypoint_controls(viewer) + + # 1) Open config.yaml -> empty placeholder Points layer + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) + + placeholder = next((ly for ly in viewer.layers if isinstance(ly, Points)), None) + assert placeholder is not None + assert placeholder.data is None or len(placeholder.data) == 0 + + # 2) Must carry policy inputs; exact cycles are derived, not metadata-authored + md = placeholder.metadata or {} + assert "config_colormap" in md, "Expected config_colormap in metadata" + assert "header" in md, "Expected header in metadata" + + expected_cycles = _cycles_from_policy(placeholder) + id_cycles = expected_cycles["id"] + assert "animal1" in id_cycles, f"Expected 'animal1' in derived id cycles; got keys={list(id_cycles)[:10]}" + assert "animal2" in id_cycles, f"Expected 'animal2' in derived id cycles; got keys={list(id_cycles)[:10]}" + + # 3) Begin editing: add a point for animal1, then animal2 + store = controls._stores.get(placeholder) + assert store is not None, "Expected KeypointStore for placeholder Points layer" + + # Add first point: (frame, y, x) + store.current_keypoint = keypoints.Keypoint("bodypart1", "animal1") + placeholder.add(np.array([0.0, 12.0, 34.0], dtype=float)) + + qtbot.waitUntil(lambda: placeholder.data is not None and len(placeholder.data) == 1, timeout=2_000) + + # Wait for recolor timer to switch to cycle mode + qtbot.waitUntil(lambda: placeholder.face_color_mode == "cycle", timeout=5_000) + + # Must be coloring by id in multi-animal case + assert placeholder._face.color_properties.name == "id" + + got0 = np.asarray(placeholder._face.colors[0], dtype=float) + exp0 = np.asarray(id_cycles["animal1"], dtype=float) + assert np.allclose(got0, exp0, atol=1e-6), f"animal1 color mismatch: got={got0}, expected={exp0}" + + # Add second point for animal2 + store.current_keypoint = keypoints.Keypoint("bodypart2", "animal2") + placeholder.add(np.array([0.0, 56.0, 78.0], dtype=float)) + + qtbot.waitUntil(lambda: placeholder.data is not None and len(placeholder.data) == 2, timeout=2_000) + qtbot.wait(50) # small buffer for color refresh + + assert placeholder.face_color_mode == "cycle" + assert placeholder._face.color_properties.name == "id" + + got1 = np.asarray(placeholder._face.colors[1], dtype=float) + exp1 = np.asarray(id_cycles["animal2"], dtype=float) + assert np.allclose(got1, exp1, atol=1e-6), f"animal2 color mismatch: got={got1}, expected={exp1}" + + assert not np.allclose(got0, got1, atol=1e-6), "Expected distinct colors for animal1 vs animal2" + + +@pytest.mark.usefixtures("qtbot") +def test_color_scheme_panel_toggle_shows_active_then_full_config_bodyparts( + viewer, + qtbot, + tmp_path, +): + """ + E2E: + - open config first -> placeholder points layer + - open dataset folder for context + - add one visible keypoint to placeholder + - color scheme panel (unchecked) should show only the active/current visible keypoint(s) + - toggling config preview should show all bodyparts from config.yaml + """ + project, config_path, labeled_folder, _h5_path = _make_minimal_dlc_project(tmp_path) + + from napari_deeplabcut.core import keypoints + + controls, controls_dock = _get_existing_keypoint_controls(viewer) + # Force-show the viewer hierarchy and relevant docks/panels. + force_show(viewer.window._qt_window, qtbot) + force_show(controls_dock, qtbot) + force_show(controls, qtbot) + force_show(controls._color_scheme_display, qtbot) + force_show(controls._color_scheme_panel, qtbot) + + # 1) Open config first -> placeholder Points layer + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) + + placeholder = next((ly for ly in viewer.layers if isinstance(ly, Points)), None) + assert placeholder is not None + assert placeholder.data is None or len(placeholder.data) == 0 + + # 2) Open folder so image/dataset context exists + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + # Make sure the placeholder is the active target layer + viewer.layers.selection.active = placeholder + + store = controls._stores.get(placeholder) + assert store is not None + + # Deterministically add bodypart1 + store.current_keypoint = keypoints.Keypoint("bodypart1", "") + placeholder.add(np.array([0.0, 20.0, 10.0], dtype=float)) + + qtbot.waitUntil(lambda: placeholder.data is not None and len(placeholder.data) == 1, timeout=2_000) + qtbot.waitUntil(lambda: placeholder.face_color_mode == "cycle", timeout=5_000) + + panel = controls._color_scheme_panel + + expected_active = _scheme_from_policy(placeholder, "label", ["bodypart1"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected_active, timeout=5_000) + + # Toggle full config preview + panel._toggle.setChecked(True) + + expected_config = _scheme_from_policy(placeholder, "label", ["bodypart1", "bodypart2"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected_config, timeout=5_000) + + +@pytest.mark.usefixtures("qtbot") +def test_color_scheme_panel_multianimal_toggle_shows_active_then_full_config_individuals( + viewer, + qtbot, + multianimal_config_project, +): + """ + E2E: + - open multi-animal config first -> placeholder points layer + - add one keypoint for animal1 + - because the first/real KeypointControls switches multi-animal coloring + to individual mode, the color scheme panel should: + * unchecked: show only currently visible active individual(s) + * checked: show all configured individuals from config.yaml + """ + _, config_path = multianimal_config_project + + from napari_deeplabcut.core import keypoints + + controls, controls_dock = _get_existing_keypoint_controls(viewer) + + force_show(viewer.window._qt_window, qtbot) + force_show(controls_dock, qtbot) + force_show(controls, qtbot) + force_show(controls._color_scheme_display, qtbot) + force_show(controls._color_scheme_panel, qtbot) + + # Open config -> placeholder Points layer + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil( + lambda: any(isinstance(ly, Points) for ly in viewer.layers), + timeout=5_000, + ) + + placeholder = next((ly for ly in viewer.layers if isinstance(ly, Points)), None) + assert placeholder is not None + assert placeholder.data is None or len(placeholder.data) == 0 + + # Wait until the existing controls instance has wired the layer + qtbot.waitUntil(lambda: placeholder in controls._stores, timeout=5_000) + store = controls._stores.get(placeholder) + assert store is not None + + # This assertion is now valid because we're using the controls instance + # that actually wired the layer. + qtbot.waitUntil(lambda: controls.color_mode == "individual", timeout=5_000) + assert controls.color_mode == "individual" + + # Make the placeholder the active target layer + viewer.layers.selection.active = placeholder + + # Add one keypoint for animal1 + store.current_keypoint = keypoints.Keypoint("bodypart1", "animal1") + placeholder.add(np.array([0.0, 12.0, 34.0], dtype=float)) + + qtbot.waitUntil(lambda: placeholder.data is not None and len(placeholder.data) == 1, timeout=2_000) + qtbot.waitUntil(lambda: placeholder.face_color_mode == "cycle", timeout=5_000) + qtbot.waitUntil(lambda: placeholder._face.color_properties.name == "id", timeout=5_000) + + panel = controls._color_scheme_panel + + expected_active = _scheme_from_policy(placeholder, "id", ["animal1"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected_active, timeout=5_000) + + # Toggle full config preview -> should show both configured individuals + panel._toggle.setChecked(True) + + expected_config = _scheme_from_policy(placeholder, "id", ["animal1", "animal2"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected_config, timeout=5_000) diff --git a/src/napari_deeplabcut/_tests/e2e/test_overwrite_and_merge.py b/src/napari_deeplabcut/_tests/e2e/test_overwrite_and_merge.py new file mode 100644 index 00000000..c4a40ca5 --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/test_overwrite_and_merge.py @@ -0,0 +1,208 @@ +import logging + +import numpy as np +import pandas as pd +import pytest +from napari.layers import Points + +from napari_deeplabcut.config.models import DLCHeaderModel + +from .utils import _get_coord_from_df, _get_points_layer_with_data, _make_minimal_dlc_project, _set_or_add_bodypart_xy + +logger = logging.getLogger(__name__) + + +@pytest.mark.usefixtures("qtbot") +def test_config_first_hazard_regression_no_silent_deletion(viewer, keypoint_controls, qtbot, tmp_path, caplog): + """ + Regression for the original report: + Save the WRONG (placeholder) layer and still preserve previous labels due to merge-on-save. + """ + caplog.set_level(logging.DEBUG) + + project, config_path, labeled_folder, h5_path = _make_minimal_dlc_project(tmp_path) + + pre = pd.read_hdf(h5_path, key="keypoints") + assert np.isfinite(_get_coord_from_df(pre, "bodypart1", "x")) + assert np.isnan(_get_coord_from_df(pre, "bodypart2", "x")) + + from napari_deeplabcut.core import keypoints + + viewer.window.add_dock_widget(keypoint_controls, name="Keypoint controls", area="right") + + # Open config first -> placeholder Points layer exists + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len([ly for ly in viewer.layers if isinstance(ly, Points)]) >= 1, timeout=5_000) + + placeholder = next((ly for ly in viewer.layers if isinstance(ly, Points)), None) + assert placeholder is not None + assert placeholder.data is None or len(placeholder.data) == 0 + + # Open folder -> images + GT points layer + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + # Placeholder should still be present for this regression to apply + assert placeholder in viewer.layers + + store = keypoint_controls._stores.get(placeholder) + assert store is not None + + # Add a new bodypart2 point to placeholder using (frame, y, x) + store.current_keypoint = keypoints.Keypoint("bodypart2", "") + placeholder.add(np.array([0.0, 33.0, 44.0], dtype=float)) + + viewer.layers.selection.active = placeholder + viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + qtbot.wait(200) + + post = pd.read_hdf(h5_path, key="keypoints") + b1x_post = _get_coord_from_df(post, "bodypart1", "x") + b2x_post = _get_coord_from_df(post, "bodypart2", "x") + + assert np.isfinite(b1x_post), "bodypart1 must be preserved (no silent deletion)." + assert np.isfinite(b2x_post), "bodypart2 must be saved." + assert b2x_post == 44.0 + + +@pytest.mark.usefixtures("qtbot") +def test_no_overwrite_warning_when_only_filling_nans(viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm): + """ + Adding new labels (filling NaNs) must not prompt overwrite confirmation. + """ + overwrite_confirm.forbid() + + _, _, labeled_folder, h5_path = _make_minimal_dlc_project(tmp_path) + + viewer.window.add_dock_widget(keypoint_controls, name="Keypoint controls", area="right") + + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + points = _get_points_layer_with_data(viewer) + + logger.debug("points.name: %s", points.name) + logger.debug("points.data shape: %s", None if points.data is None else np.asarray(points.data).shape) + logger.debug("points.data[:5]: %s", None if points.data is None else np.asarray(points.data)[:5]) + logger.debug( + "any NaNs in points.data: %s", False if points.data is None else np.isnan(np.asarray(points.data)).any() + ) + logger.debug( + "any finite xy: %s", False if points.data is None else np.isfinite(np.asarray(points.data)[:, 1:3]).any() + ) + + logger.debug("len(label): %s", len(points.properties.get("label", []))) + logger.debug("len(id): %s", len(points.properties.get("id", []))) + logger.debug("label[:10]: %s", points.properties.get("label", [])[:10]) + logger.debug("id[:10]: %s", points.properties.get("id", [])[:10]) + + hdr = points.metadata.get("header") + logger.debug("header type: %s", type(hdr)) + if hdr is not None: + if isinstance(hdr, DLCHeaderModel): + header_model = hdr + elif isinstance(hdr, dict): + header_model = DLCHeaderModel.model_validate(hdr) + else: + header_model = DLCHeaderModel(columns=hdr) + + # Prefer portable inspection: tuple columns (pandas optional) + logger.debug("header ncols=%s", len(header_model.columns)) + logger.debug("header scorer=%s", header_model.scorer) + logger.debug("header individuals=%s", header_model.individuals) + logger.debug("header bodyparts=%s", header_model.bodyparts) + logger.debug("header coords=%s", header_model.coords) + + logger.info("points.data[:5] = %s", points.data[:5]) + logger.info("any NaNs in points.data = %s", np.isnan(points.data).any()) + logger.info("labels[:10] = %s", points.properties.get("label")[:10]) + logger.info("ids[:10] = %s", points.properties.get("id")[:10] if "id" in points.properties else None) + store = keypoint_controls._stores.get(points) + assert store is not None + + # Fill NaNs for bodypart2 + _set_or_add_bodypart_xy(points, store, "bodypart2", x=44.0, y=33.0) + + viewer.layers.selection.active = points + viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + qtbot.wait(200) + + post = pd.read_hdf(h5_path, key="keypoints") + assert np.isfinite(_get_coord_from_df(post, "bodypart1", "x")) + assert np.isfinite(_get_coord_from_df(post, "bodypart2", "x")) + + +@pytest.mark.usefixtures("qtbot") +def test_overwrite_warning_triggers_on_conflict(viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm): + """ + Modifying an existing non-NaN label must trigger overwrite confirmation. + """ + overwrite_confirm.capture().reset_calls() + + project, config_path, labeled_folder, h5_path = _make_minimal_dlc_project(tmp_path) + viewer.window.add_dock_widget(keypoint_controls, name="Keypoint controls", area="right") + + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + points = _get_points_layer_with_data(viewer) + store = keypoint_controls._stores.get(points) + assert store is not None + + # Create conflict: overwrite bodypart1 from (10,20) -> (99,88) + _set_or_add_bodypart_xy(points, store, "bodypart1", x=99.0, y=88.0) + + viewer.layers.selection.active = points + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(200) + + assert len(overwrite_confirm.calls) == 1, "Expected overwrite confirmation to be requested once." + assert overwrite_confirm.calls[0]["n_pairs"] is not None + assert overwrite_confirm.calls[0]["n_pairs"] >= 1 + + post = pd.read_hdf(h5_path, key="keypoints") + assert _get_coord_from_df(post, "bodypart1", "x") == 99.0 + + +@pytest.mark.usefixtures("qtbot") +def test_overwrite_warning_cancel_aborts_write(viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm): + """ + If overwrite confirmation is requested and user cancels, file must remain unchanged. + """ + overwrite_confirm.cancel().reset_calls() + + project, config_path, labeled_folder, h5_path = _make_minimal_dlc_project(tmp_path) + + pre = pd.read_hdf(h5_path, key="keypoints") + b1x_pre = _get_coord_from_df(pre, "bodypart1", "x") + b1y_pre = _get_coord_from_df(pre, "bodypart1", "y") + + viewer.window.add_dock_widget(keypoint_controls, name="Keypoint controls", area="right") + + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + points = _get_points_layer_with_data(viewer) + store = keypoint_controls._stores.get(points) + assert store is not None + + _set_or_add_bodypart_xy(points, store, "bodypart1", x=456.0, y=123.0) + + viewer.layers.selection.active = points + try: + keypoint_controls._save_layers_dialog(selected=True) + except Exception: + # Some napari/npe2 versions may raise when writer aborts; file integrity is what matters. + pass + + qtbot.wait(200) + + assert len(overwrite_confirm.calls) == 1, "Expected overwrite confirmation to be requested once." + + post = pd.read_hdf(h5_path, key="keypoints") + assert _get_coord_from_df(post, "bodypart1", "x") == b1x_pre + assert _get_coord_from_df(post, "bodypart1", "y") == b1y_pre diff --git a/src/napari_deeplabcut/_tests/e2e/test_points_layers.py b/src/napari_deeplabcut/_tests/e2e/test_points_layers.py new file mode 100644 index 00000000..772bb4a4 --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/test_points_layers.py @@ -0,0 +1,241 @@ +import numpy as np +import pytest +from napari.layers import Points + +from napari_deeplabcut.core.layers import populate_keypoint_layer_properties + + +@pytest.mark.usefixtures("qtbot") +def test_on_insert_empty_points_layer_does_not_crash(viewer, make_real_header_factory): + """ + Contract: inserting an empty Points layer must not crash. + This guards against KeyError: nan coming from napari cycle colormap logic. + """ + header = make_real_header_factory(individuals=("animal1",)) # either is fine + md = populate_keypoint_layer_properties( + header, + labels=[], # empty properties + ids=[], # empty properties + likelihood=np.array([], dtype=float), + paths=[], + colormap="viridis", + ) + + # napari Points layer coordinates: (N, D). For your plugin D=3 (frame,y,x) + empty_data = np.empty((0, 3), dtype=float) + + # The test is simply: adding must not raise + viewer.add_points(empty_data, **md) + + +@pytest.mark.usefixtures("qtbot") +def test_on_insert_empty_points_layer_does_not_enable_cycle_mode(viewer, make_real_header_factory): + """ + Contract: for empty layers, widget should not set face_color_mode='cycle' + (or should otherwise avoid the cycle colormap path that crashes on nan). + """ + header = make_real_header_factory(individuals=("",)) # single animal + md = populate_keypoint_layer_properties( + header, + labels=[], + ids=[], + likelihood=np.array([], dtype=float), + paths=[], + colormap="viridis", + ) + layer = viewer.add_points(np.empty((0, 3), dtype=float), **md) + + assert isinstance(layer, Points) + assert layer.data.shape[0] == 0 + + # Allow either "direct" or something else, but cycle is unsafe for empties. + assert layer.face_color_mode != "cycle" + + +@pytest.mark.usefixtures("qtbot") +def test_adopt_existing_empty_points_layer_does_not_crash(viewer, make_real_header_factory): + """ + Contract: adoption path must not crash for empty points layers. + This exercises _adopt_existing_layers -> _handle_existing_points_layer. + """ + + # Add layer BEFORE creating the widget (forces adoption path) + header = make_real_header_factory(individuals=("animal1",)) + md = populate_keypoint_layer_properties( + header, + labels=[], + ids=[], + likelihood=np.array([], dtype=float), + paths=[], + colormap="viridis", + ) + viewer.add_points(np.empty((0, 3), dtype=float), **md) + + # If we got here without exception, adoption didn’t crash + pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] + assert pts_layers, "Expected the empty Points layer to exist" + + +@pytest.mark.usefixtures("qtbot") +def test_layer_insert_does_not_crash_when_current_property_is_nan(viewer, keypoint_controls, make_real_header_factory): + """ + Contract: even if a property value is NaN (bad input), widget must not crash. + It may fall back to direct mode or sanitize the property. + """ + header = make_real_header_factory(individuals=("",)) + md = populate_keypoint_layer_properties( + header, + labels=["bodypart1"], + ids=[""], + likelihood=np.array([1.0], dtype=float), + paths=[], + colormap="viridis", + ) + + # One point, but corrupt the property used for cycling + data = np.array([[0.0, 10.0, 20.0]], dtype=float) + md["properties"]["label"] = [np.nan] # intentionally wrong + + # Must not raise on insertion + layer = viewer.add_points(data, **md) + # Plot cannot be formed because of the NaN, + # but the layer must still be added and cycle mode must not be enabled. + assert keypoint_controls._matplotlib_canvas.df is None + assert isinstance(layer, Points) + assert layer.face_color_mode != "cycle" + + +@pytest.mark.usefixtures("qtbot") +def test_copy_paste_points_to_new_frame_does_not_crash_and_offsets_frame( + viewer, + keypoint_controls, + make_real_header_factory, + qtbot, +): + """ + Regression test for DLC's patched Points._paste_data. + + Scenario: + - create a 3D (t, y, x) points layer + - copy selected points on frame 0 + - move to frame 1 + - paste + + Contract: + - must not crash + - pasted points must appear on the current frame + - point properties (e.g. labels) must be preserved + """ + controls = keypoint_controls + # Add an image stack to make the time/frame axis explicit in a realistic way. + viewer.add_image( + np.zeros((2, 64, 64), dtype=np.uint8), + name="frames", + metadata={"paths": ["frame0.png", "frame1.png"]}, + ) + + header = make_real_header_factory(individuals=("",)) # single-animal layout + md = populate_keypoint_layer_properties( + header, + labels=["head", "tail"], + ids=["", ""], + likelihood=np.array([1.0, 1.0], dtype=float), + paths=["frame0.png", "frame1.png"], + colormap="viridis", + ) + + # two points on frame 0 + data = np.array( + [ + [0.0, 10.0, 20.0], # (t, y, x) + [0.0, 30.0, 40.0], + ], + dtype=float, + ) + + layer = viewer.add_points(data, **md) + + assert isinstance(layer, Points) + qtbot.waitUntil(lambda: layer in controls._stores, timeout=5_000) + assert layer in controls._stores + + # frame 0: select and copy + viewer.dims.set_point(0, 0) + qtbot.wait(0) + + layer.selected_data = {0, 1} + layer._copy_data() + + assert "data" in layer._clipboard + assert "features" in layer._clipboard + assert len(layer._clipboard["data"]) == 2 + + # move to frame 1 and paste + viewer.dims.set_point(0, 1) + qtbot.wait(0) + + layer._paste_data() + qtbot.wait(0) + + # original 2 + pasted 2 + assert len(layer.data) == 4 + + pasted = np.asarray(layer.data)[-2:] + np.testing.assert_array_equal(pasted[:, 0], np.array([1.0, 1.0])) + np.testing.assert_allclose(pasted[:, 1:], data[:, 1:]) + + # labels/features should be preserved for pasted points + labels = list(layer.properties["label"]) + assert labels[-2:] == ["head", "tail"] + + +@pytest.mark.usefixtures("qtbot") +def test_copy_paste_same_frame_does_not_duplicate_existing_keypoints( + viewer, + make_real_header_factory, + qtbot, +): + """ + Contract: + If the copied keypoints are already annotated on the current frame, + DLC's patched paste should not duplicate them. + """ + viewer.add_image( + np.zeros((1, 64, 64), dtype=np.uint8), + name="frame", + metadata={"paths": ["frame0.png"]}, + ) + + header = make_real_header_factory(individuals=("",)) + md = populate_keypoint_layer_properties( + header, + labels=["head", "tail"], + ids=["", ""], + likelihood=np.array([1.0, 1.0], dtype=float), + paths=["frame0.png"], + colormap="viridis", + ) + + data = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 30.0, 40.0], + ], + dtype=float, + ) + + layer = viewer.add_points(data, **md) + assert isinstance(layer, Points) + + viewer.dims.set_point(0, 0) + qtbot.wait(0) + + layer.selected_data = {0, 1} + layer._copy_data() + + before = len(layer.data) + layer._paste_data() + qtbot.wait(0) + + # no duplicates expected on same frame + assert len(layer.data) == before diff --git a/src/napari_deeplabcut/_tests/e2e/test_routing_and_provenance.py b/src/napari_deeplabcut/_tests/e2e/test_routing_and_provenance.py new file mode 100644 index 00000000..149d02a2 --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/test_routing_and_provenance.py @@ -0,0 +1,838 @@ +import logging + +import numpy as np +import pandas as pd +import pytest +from napari.layers import Points + +from napari_deeplabcut.core.io import AnnotationKind, MissingProvenanceError + +from .utils import ( + _assert_only_these_files_changed, + _make_dlc_project_with_multiple_gt, + _make_labeled_folder_with_machine_only, + _make_project_config_and_frames_no_gt, + _set_or_add_bodypart_xy, + _snapshot_coords, + file_sig, +) + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def forbid_project_config_dialog(monkeypatch): + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + lambda *args, **kwargs: pytest.fail("Unexpected project-config dialog."), + ) + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.maybe_confirm_dataset_path_rewrite", + lambda *args, **kwargs: pytest.fail("Unexpected dataset path rewrite confirmation."), + ) + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.warn_existing_dataset_folder_conflict", + lambda *args, **kwargs: pytest.fail("Unexpected dataset-folder conflict warning."), + ) + + +@pytest.fixture +def skip_project_config_dialog(monkeypatch): + """ + Simulate the new promotion policy when no config.yaml exists. + + The save flow now asks whether the user wants to locate a DLC config.yaml + before falling back to sidecar/manual scorer entry. In these no-config e2e + scenarios, emulate the user explicitly choosing to continue without config. + """ + from napari_deeplabcut.ui import dialogs as ui_dialogs + + calls = {"count": 0, "kwargs": None} + + def _skip(*args, **kwargs): + calls["count"] += 1 + calls["kwargs"] = kwargs + return ui_dialogs.ProjectConfigPromptResult( + action=ui_dialogs.ProjectConfigPromptAction.SKIP, + ) + + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + _skip, + ) + return calls + + +@pytest.mark.usefixtures("qtbot") +def test_save_routes_to_correct_gt_when_multiple_gt_exist( + viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm +): + """ + Contract: Saving a Points layer must write back ONLY to the file it came from. + No 'first CollectedData*.h5' selection when multiple exist. + """ + overwrite_confirm.forbid() + + project, config_path, labeled_folder, gt_paths, _ = _make_dlc_project_with_multiple_gt( + tmp_path, scorers=("John", "Jane"), with_machine=False + ) + gt_a, gt_b = gt_paths + + before = {p: _snapshot_coords(p) for p in gt_paths} + # Open both GT files explicitly so we get two Points layers + viewer.open(str(gt_a), plugin="napari-deeplabcut") + viewer.open(str(gt_b), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len([ly for ly in viewer.layers if isinstance(ly, Points)]) >= 2, timeout=10_000) + qtbot.wait(200) + + # Select the layer corresponding to gt_b + points_b = next((ly for ly in viewer.layers if isinstance(ly, Points) and ly.name == gt_b.stem), None) + assert points_b is not None, f"Expected a Points layer named {gt_b.stem}" + + store_b = keypoint_controls._stores.get(points_b) + assert store_b is not None + + # Fill NaNs for bodypart2 in B only (no overwrite dialog) + _set_or_add_bodypart_xy(points_b, store_b, "bodypart2", x=77.0, y=66.0) + + logger.info("BEFORE SAVE : name=%s, sig=%s", gt_paths[0].name, file_sig(gt_paths[0])) + logger.info("BEFORE SAVE : name=%s, sig=%s", gt_paths[1].name, file_sig(gt_paths[1])) + + viewer.layers.selection.active = points_b + logger.info("Layer selected for save: %s", points_b.name) + # logger.info("Layer metadata: %s", points_b.metadata) + viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + + logger.info("AFTER SAVE : name=%s, sig=%s", gt_paths[0].name, file_sig(gt_paths[0])) + logger.info("AFTER SAVE : name=%s, sig=%s", gt_paths[1].name, file_sig(gt_paths[1])) + + qtbot.wait(200) + + after = {p: _snapshot_coords(p) for p in gt_paths} + + _assert_only_these_files_changed(before, after, changed={gt_b}) + assert after[gt_b]["b2x"] == 77.0 + + +@pytest.mark.usefixtures("qtbot") +def test_machine_layer_does_not_modify_gt_on_save(viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm): + """ + Contract: machine outputs must never save to their own file. + Users must explicitly provide a scorer name that is then used to save the h5. + """ + overwrite_confirm.forbid() + + project, config_path, labeled_folder, gt_paths, machine_path = _make_dlc_project_with_multiple_gt( + tmp_path, scorers=("John", "Jane"), with_machine=True + ) + assert machine_path is not None + + before = {p: _snapshot_coords(p) for p in gt_paths + [machine_path]} + + viewer.open(str(machine_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len([ly for ly in viewer.layers if isinstance(ly, Points)]) >= 1, timeout=10_000) + qtbot.wait(200) + + machine_layer = next((ly for ly in viewer.layers if isinstance(ly, Points) and ly.name == machine_path.stem), None) + assert machine_layer is not None + + store = keypoint_controls._stores.get(machine_layer) + assert store is not None + + # Fill NaNs in machine file (no overwrite prompt) + _set_or_add_bodypart_xy(machine_layer, store, "bodypart2", x=55.0, y=44.0) + + viewer.layers.selection.active = machine_layer + + with pytest.raises(MissingProvenanceError): + viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + + qtbot.wait(200) + + after = {p: _snapshot_coords(p) for p in gt_paths + [machine_path]} + + # Machine file should be unchanged (no save path), + # and GT files should be unchanged (machine edits must not touch GT). + _assert_only_these_files_changed(before, after, changed=set()) + # assert after[machine_path]["b2x"] == 55.0 + + +@pytest.mark.usefixtures("qtbot") +def test_layer_rename_does_not_change_save_target(viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm): + """ + Contract: layer renaming must not redirect output or create new file. + """ + overwrite_confirm.forbid() + + project, config_path, labeled_folder, gt_paths, _ = _make_dlc_project_with_multiple_gt( + tmp_path, scorers=("John", "Jane"), with_machine=False + ) + gt_a = gt_paths[0] + + before = {p: _snapshot_coords(p) for p in gt_paths} + + viewer.open(str(gt_a), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len([ly for ly in viewer.layers if isinstance(ly, Points)]) >= 1, timeout=10_000) + qtbot.wait(200) + + layer = next((ly for ly in viewer.layers if isinstance(ly, Points) and ly.name == gt_a.stem), None) + assert layer is not None + store = keypoint_controls._stores.get(layer) + assert store is not None + + # Rename in UI + layer.name = "foo" + + # Fill NaNs so no overwrite dialog + _set_or_add_bodypart_xy(layer, store, "bodypart2", x=12.0, y=34.0) + + viewer.layers.selection.active = layer + viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + qtbot.wait(200) + + # Must not create foo.h5 in the folder + assert not (gt_a.parent / "foo.h5").exists(), "Renaming must not create foo.h5" + + after = {p: _snapshot_coords(p) for p in gt_paths} + _assert_only_these_files_changed(before, after, changed={gt_a}) + + +@pytest.mark.usefixtures("qtbot") +def test_ambiguous_placeholder_save_aborts_when_multiple_gt_exist( + viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm +): + """ + Contract: If provenance is missing and multiple candidate GT files exist, + save must refuse (deterministic) rather than silently choosing. + """ + overwrite_confirm.forbid() + + project, config_path, labeled_folder, gt_paths, _ = _make_dlc_project_with_multiple_gt( + tmp_path, scorers=("John", "Jane"), with_machine=False + ) + + before = {p: _snapshot_coords(p) for p in gt_paths} + + from napari_deeplabcut.core import keypoints + + # Open config first => placeholder points layer + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len([ly for ly in viewer.layers if isinstance(ly, Points)]) >= 1, timeout=5_000) + qtbot.wait(200) + + placeholder = next((ly for ly in viewer.layers if isinstance(ly, Points)), None) + assert placeholder is not None + + # Ensure it's a placeholder (no actual data) + assert placeholder.data is None or len(placeholder.data) == 0 + + # Open labeled folder (images) so root/paths are present for saving attempt + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.wait(200) + + store = keypoint_controls._stores.get(placeholder) + assert store is not None + + # Add a point to placeholder + store.current_keypoint = keypoints.Keypoint("bodypart2", "") + placeholder.add(np.array([0.0, 33.0, 44.0], dtype=float)) + + viewer.layers.selection.active = placeholder + + # Expect save to abort deterministically + try: + viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + except Exception: + pass # acceptable in headless/test mode + + qtbot.wait(200) + + after = {p: _snapshot_coords(p) for p in gt_paths} + _assert_only_these_files_changed(before, after, changed=set()) + + +@pytest.mark.usefixtures("qtbot") +def test_folder_open_loads_all_h5_when_multiple_exist(viewer, qtbot, tmp_path): + """ + Contract: Opening a labeled-data folder with multiple H5 files should not + silently pick the first one. Preferred policy: load all as separate Points layers. + """ + project, config_path, labeled_folder, gt_paths, machine_path = _make_dlc_project_with_multiple_gt( + tmp_path, scorers=("John", "Jane"), with_machine=True + ) + + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) # images + points at least + qtbot.wait(200) + + pts = [ly for ly in viewer.layers if isinstance(ly, Points)] + # Expected: one points layer per H5 file (2 GT + 1 machine) + assert len(pts) == 3, f"Expected 3 Points layers (2 GT + 1 machine), got {len(pts)}: {[p.name for p in pts]}" + # ------------------------------------------------------------------ + # New assertion: each Points layer must carry authoritative source_h5 + # matching the file it originated from (stable across layer renames). + # ------------------------------------------------------------------ + all_expected = list(gt_paths) + ([machine_path] if machine_path is not None else []) + expected_by_stem = {p.stem: str(p.expanduser().resolve()) for p in all_expected} + + for ly in pts: + assert "source_h5" in ly.metadata, f"Missing source_h5 in layer.metadata for {ly.name}" + # Ensure it points to the actual file for that layer stem + assert ly.metadata["source_h5"] == expected_by_stem[ly.name], ( + f"Layer {ly.name} has wrong source_h5:\n" + f" got: {ly.metadata['source_h5']}\n" + f" expected: {expected_by_stem[ly.name]}" + ) + + assert "io" in (ly.metadata or {}), f"Missing io provenance dict in layer.metadata for {ly.name}" + assert ly.metadata["io"].get("source_relpath_posix"), f"io.source_relpath_posix missing for {ly.name}" + + +@pytest.mark.usefixtures("qtbot") +def test_config_first_save_writes_gt_into_dataset_folder(viewer, keypoint_controls, qtbot, tmp_path, overwrite_confirm): + """ + Regression: config-first workflow must save CollectedData_.h5 inside + project/labeled-data//, not next to config.yaml. + """ + overwrite_confirm.forbid() + + project, config_path, labeled_folder = _make_project_config_and_frames_no_gt(tmp_path) + + # Open config first -> placeholder points layer + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) + + # Open dataset folder -> provides dataset context + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] + assert pts_layers, "Expected a Points layer from config.yaml" + + points = pts_layers[0] + store = keypoint_controls._stores.get(points) + assert store is not None + + # Add a point and save + _set_or_add_bodypart_xy(points, store, "bodypart1", x=11.0, y=22.0) + + viewer.layers.selection.active = points + viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + qtbot.wait(300) + + expected = labeled_folder / "CollectedData_John.h5" + assert expected.exists(), f"Expected GT to be created in dataset folder: {expected}" + + wrong = project / "CollectedData_John.h5" + assert not wrong.exists(), f"Must not save next to config.yaml: {wrong}" + + +@pytest.mark.usefixtures("qtbot") +def test_promotion_first_save_skip_config_then_prompt_scorer_and_create_sidecar( + viewer, keypoint_controls, qtbot, tmp_path, inputdialog, skip_project_config_dialog +): + """ + First save on a machine/prediction layer (no config.yaml, no sidecar): + - offers project-config lookup first + - user continues without config + - then prompts for scorer + - writes .napari-deeplabcut.json sidecar + - creates CollectedData_.h5 + - does NOT modify machinelabels-iter0.h5 + """ + labeled_folder = _make_labeled_folder_with_machine_only(tmp_path) + + machine_path = labeled_folder / "machinelabels-iter0.h5" + machine_pre = pd.read_hdf(machine_path, key="keypoints") + + # Open folder + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + # Find machine points layer + pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] + assert any(p.name == "machinelabels-iter0" for p in pts_layers) + machine_layer = next(p for p in pts_layers if p.name == "machinelabels-iter0") + + # Edit: add bodypart2 (use helper that works across versions) + store = keypoint_controls._stores.get(machine_layer) + assert store is not None + _set_or_add_bodypart_xy(machine_layer, store, "bodypart2", x=44.0, y=33.0) + + # Set user input for scorer + inputdialog.set("Alice", ok=True) + + # Save via the widget path (ensures prompt runs) + viewer.layers.selection.active = machine_layer + keypoint_controls.viewer.layers.selection.active = machine_layer + keypoint_controls.viewer.layers.selection.select_only(machine_layer) + + assert "io" in machine_layer.metadata + assert machine_layer.metadata["io"].get("kind") in ("machine", AnnotationKind.MACHINE) + + # Call your menu-hooked save action (this hits promotion logic) + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(200) + assert "save_target" in machine_layer.metadata, machine_layer.metadata.keys() + + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(200) + + assert skip_project_config_dialog["count"] == 1 + assert skip_project_config_dialog["kwargs"]["resolve_scorer"] is True + assert "save_target" in machine_layer.metadata, machine_layer.metadata.keys() + + # Sidecar created + sidecar = labeled_folder / ".napari-deeplabcut.json" + assert sidecar.exists() + assert "Alice" in sidecar.read_text(encoding="utf-8") + + # GT created + gt_path = labeled_folder / "CollectedData_Alice.h5" + assert gt_path.exists() + + # Machine file unchanged + machine_post = pd.read_hdf(machine_path, key="keypoints") + pd.testing.assert_frame_equal(machine_pre, machine_post) + + +@pytest.mark.usefixtures("qtbot") +def test_promotion_second_save_skip_config_then_use_sidecar_without_scorer_prompt( + viewer, keypoint_controls, qtbot, tmp_path, inputdialog, skip_project_config_dialog +): + """ + After sidecar exists, saving again with no config.yaml available: + - offers project-config lookup first + - user continues without config + - QInputDialog.getText not called because sidecar provides scorer + - writes/updates same CollectedData_.h5 + - machine file unchanged + """ + labeled_folder = _make_labeled_folder_with_machine_only(tmp_path) + + # Pre-create sidecar (as if first run already happened) + sidecar = labeled_folder / ".napari-deeplabcut.json" + sidecar.write_text('{"schema_version": 1, "default_scorer": "Alice"}', encoding="utf-8") + + machine_path = labeled_folder / "machinelabels-iter0.h5" + machine_pre = pd.read_hdf(machine_path, key="keypoints") + + controls = keypoint_controls + viewer.window.add_dock_widget(controls, name="Keypoint controls", area="right") + + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] + machine_layer = next(p for p in pts_layers if p.name == "machinelabels-iter0") + + store = controls._stores.get(machine_layer) + assert store is not None + _set_or_add_bodypart_xy(machine_layer, store, "bodypart1", x=99.0, y=88.0) + + # No prompt expected + inputdialog.forbid() + + # Save via widget path + controls._save_layers_dialog(selected=True) + qtbot.wait(200) + + assert inputdialog.calls == 0 + + gt_path = labeled_folder / "CollectedData_Alice.h5" + assert gt_path.exists() + + machine_post = pd.read_hdf(machine_path, key="keypoints") + pd.testing.assert_frame_equal(machine_pre, machine_post) + + controls._save_layers_dialog(selected=True) + qtbot.wait(200) + + assert skip_project_config_dialog["count"] == 1 + assert skip_project_config_dialog["kwargs"]["resolve_scorer"] is True + assert inputdialog.calls == 0 + + +@pytest.mark.usefixtures("qtbot") +def test_projectless_folder_save_can_associate_with_config_and_coerce_paths_to_dlc_row_keys( + viewer, + keypoint_controls, + qtbot, + tmp_path, + monkeypatch, + overwrite_confirm, +): + """ + Contract: a project-less labeled folder can be associated with a chosen DLC + project at save time by rewriting safe paths to canonical DLC row keys. + + Goals + ----- + - Use current external folder name as the target dataset name. + - Save safe paths as labeled-data//. + - Use the same normalized metadata for overwrite preflight and actual write. + - Persist the improved metadata on the live layer after successful save. + + Non-goals + --------- + - Do NOT require the current files to already be inside the selected project. + - Do NOT coerce nested/multi-folder layouts into DLC row keys. + - Do NOT rewrite unrelated outside paths. + """ + overwrite_confirm.forbid() + + project, config_path, _project_dataset_folder = _make_project_config_and_frames_no_gt(tmp_path) + + # External project-less folder that the user labeled outside the project. + external_folder = tmp_path / "session_external" + external_folder.mkdir() + + inside_abs = external_folder / "img001.png" + inside_abs.write_bytes(b"placeholder") + dataset = external_folder.name + + outside_dir = tmp_path / "external-images" + outside_dir.mkdir() + outside_img = outside_dir / "img999.png" + outside_img.write_bytes(b"placeholder") + + from napari_deeplabcut.core import keypoints + + # Open config first -> placeholder points layer + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) + + points = next(ly for ly in viewer.layers if isinstance(ly, Points)) + store = keypoint_controls._stores.get(points) + assert store is not None + + # Simulate project-less folder metadata: + points.metadata = dict(points.metadata or {}) + points.metadata["root"] = str(external_folder) + points.metadata["paths"] = [ + str(inside_abs), # direct child of source_root -> should coerce + "img002.png", # basename -> should coerce + f"labeled-data/{dataset}/img003.png", # already canonical -> preserve + str(outside_img), # unrelated absolute path -> preserve unchanged + ] + points.metadata.pop("project", None) + + store.current_keypoint = keypoints.Keypoint("bodypart1", "") + points.add(np.array([0.0, 11.0, 22.0], dtype=float)) + + from napari_deeplabcut.ui import dialogs as ui_dialogs + + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + lambda *args, **kwargs: ui_dialogs.ProjectConfigPromptResult( + action=ui_dialogs.ProjectConfigPromptAction.ASSOCIATE, + config_path=str(config_path), + ), + ) + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.maybe_confirm_dataset_path_rewrite", + lambda *args, **kwargs: True, + ) + + import napari_deeplabcut.core.conflicts as conflicts + + real_compute = conflicts.compute_overwrite_report_for_points_save + captured = {} + + def _wrapped_compute(data, attributes): + captured["attributes"] = attributes + return real_compute(data, attributes) + + monkeypatch.setattr( + "napari_deeplabcut._widgets.compute_overwrite_report_for_points_save", + _wrapped_compute, + ) + + viewer.layers.selection.active = points + keypoint_controls.viewer.layers.selection.active = points + keypoint_controls.viewer.layers.selection.select_only(points) + + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(300) + + # After project association, save should route into the chosen project's + # labeled-data// folder inferred from the rewritten metadata. + expected_dataset_dir = project / "labeled-data" / dataset + expected_h5 = expected_dataset_dir / "CollectedData_John.h5" + expected_csv = expected_dataset_dir / "CollectedData_John.csv" + + assert expected_h5.exists() + assert expected_csv.exists() + + # And it should NOT create a GT file next to the external source folder. + assert not (external_folder / "CollectedData_John.h5").exists() + assert not (external_folder / "CollectedData_John.csv").exists() + + expected_paths = [ + f"labeled-data/{dataset}/{inside_abs.name}", + f"labeled-data/{dataset}/img002.png", + f"labeled-data/{dataset}/img003.png", + outside_img.as_posix(), + ] + + # Preflight saw normalized metadata + assert captured["attributes"]["metadata"]["project"] == str(project) + assert captured["attributes"]["metadata"]["paths"] == expected_paths + + # Live layer metadata persisted the successful normalization + assert points.metadata["project"] == str(project) + assert points.metadata["paths"] == expected_paths + + # H5 row index contains canonical DLC row keys for the safe cases + df = pd.read_hdf(expected_h5, key="keypoints") + if isinstance(df.index, pd.MultiIndex): + observed_rows = ["/".join(map(str, idx)) for idx in df.index] + else: + observed_rows = [str(idx).replace("\\", "/") for idx in df.index] + + assert f"labeled-data/{dataset}/{inside_abs.name}" in observed_rows + assert f"labeled-data/{dataset}/img002.png" not in observed_rows + assert f"labeled-data/{dataset}/img003.png" not in observed_rows + assert outside_img.as_posix() not in observed_rows + + +@pytest.mark.usefixtures("qtbot") +def test_projectless_folder_save_refuses_when_target_dataset_folder_already_contains_files( + viewer, + keypoint_controls, + qtbot, + tmp_path, + monkeypatch, + overwrite_confirm, +): + """ + Contract: project-association save must refuse if the target dataset folder + already exists in the chosen project and contains files. + """ + overwrite_confirm.forbid() + + project, config_path, existing_project_dataset = _make_project_config_and_frames_no_gt(tmp_path) + dataset = existing_project_dataset.name + + # Existing populated target dataset folder inside project -> must refuse + assert existing_project_dataset.exists() + assert any(existing_project_dataset.iterdir()), "Expected existing project dataset folder to already contain files." + + # External project-less folder with the SAME dataset name + external_parent = tmp_path / "external-root" + external_parent.mkdir() + external_folder = external_parent / dataset + external_folder.mkdir() + + external_img = external_folder / "img_external.png" + external_img.write_bytes(b"placeholder") + + from napari_deeplabcut.core import keypoints + + viewer.open(str(config_path), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: any(isinstance(ly, Points) for ly in viewer.layers), timeout=5_000) + + points = next(ly for ly in viewer.layers if isinstance(ly, Points)) + store = keypoint_controls._stores.get(points) + assert store is not None + + points.metadata = dict(points.metadata or {}) + points.metadata["root"] = str(external_folder) + points.metadata["paths"] = [str(external_img)] + points.metadata.pop("project", None) + + store.current_keypoint = keypoints.Keypoint("bodypart1", "") + points.add(np.array([0.0, 11.0, 22.0], dtype=float)) + + warned = {} + + from napari_deeplabcut.ui import dialogs as ui_dialogs + + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + lambda *args, **kwargs: ui_dialogs.ProjectConfigPromptResult( + action=ui_dialogs.ProjectConfigPromptAction.ASSOCIATE, + config_path=str(config_path), + ), + ) + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.warn_existing_dataset_folder_conflict", + lambda *args, **kwargs: warned.setdefault("called", True), + ) + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.maybe_confirm_dataset_path_rewrite", + lambda *args, **kwargs: True, + ) + + keypoint_controls.viewer.layers.selection.active = points + keypoint_controls.viewer.layers.selection.select_only(points) + + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(200) + + assert warned.get("called", False), "Expected conflict warning for populated target dataset folder." + + # No GT should be created in the external folder because association was refused. + assert not (external_folder / "CollectedData_John.h5").exists() + + +@pytest.mark.usefixtures("qtbot") +def test_promotion_nearby_config_wins_no_dialog_no_prompt( + viewer, + keypoint_controls, + qtbot, + tmp_path, + monkeypatch, + inputdialog, +): + """ + If a valid DLC config.yaml is discoverable near a machine-labeled layer, + promotion must use the scorer from that config without showing either: + - the project-config selection dialog + - the manual scorer prompt + + Sidecar, if present, must be ignored in favor of config.yaml. + """ + project, config_path, labeled_folder, _gt_paths, machine_path = _make_dlc_project_with_multiple_gt( + tmp_path, scorers=("John", "Jane"), with_machine=True + ) + assert machine_path is not None + + # Create a conflicting sidecar scorer to prove config.yaml wins. + sidecar = labeled_folder / ".napari-deeplabcut.json" + sidecar.write_text('{"schema_version": 1, "default_scorer": "Alice"}', encoding="utf-8") + + machine_pre = pd.read_hdf(machine_path, key="keypoints") + + dialog_calls = {"count": 0} + + def _unexpected_config_dialog(*args, **kwargs): + dialog_calls["count"] += 1 + pytest.fail("Config-selection dialog must not appear when nearby config.yaml is auto-discovered.") + + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + _unexpected_config_dialog, + ) + + # Manual scorer prompt must not be used either. + inputdialog.forbid() + + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] + machine_layer = next(p for p in pts_layers if p.name == machine_path.stem) + + store = keypoint_controls._stores.get(machine_layer) + assert store is not None + _set_or_add_bodypart_xy(machine_layer, store, "bodypart2", x=54.0, y=43.0) + + viewer.layers.selection.active = machine_layer + keypoint_controls.viewer.layers.selection.active = machine_layer + keypoint_controls.viewer.layers.selection.select_only(machine_layer) + + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(300) + + assert dialog_calls["count"] == 0 + assert inputdialog.calls == 0 + assert "save_target" in machine_layer.metadata, machine_layer.metadata.keys() + + # Config scorer must win over sidecar scorer. + expected_gt = labeled_folder / "CollectedData_John.h5" + unexpected_gt = labeled_folder / "CollectedData_Alice.h5" + assert expected_gt.exists(), f"Expected GT with config scorer to be created: {expected_gt}" + assert not unexpected_gt.exists(), f"Sidecar scorer must be ignored when config.yaml is nearby: {unexpected_gt}" + + machine_post = pd.read_hdf(machine_path, key="keypoints") + pd.testing.assert_frame_equal(machine_pre, machine_post) + + +@pytest.mark.usefixtures("qtbot") +def test_promotion_selected_external_config_wins_no_scorer_prompt( + viewer, + keypoint_controls, + qtbot, + tmp_path, + monkeypatch, + inputdialog, +): + """ + If no nearby config.yaml is found, but the user points the save flow to a + valid external DLC config.yaml, promotion must use that config scorer and + must not show the manual scorer prompt. + + Sidecar, if present, must be ignored in favor of the user-selected config. + """ + labeled_folder = _make_labeled_folder_with_machine_only(tmp_path) + machine_path = labeled_folder / "machinelabels-iter0.h5" + machine_pre = pd.read_hdf(machine_path, key="keypoints") + + # External DLC project whose config scorer should be used. + external_project, external_config_path, _external_dataset = _make_project_config_and_frames_no_gt( + tmp_path / "extproj" + ) + assert external_config_path.exists() + + # Create a conflicting sidecar scorer to prove selected config wins. + sidecar = labeled_folder / ".napari-deeplabcut.json" + sidecar.write_text('{"schema_version": 1, "default_scorer": "Alice"}', encoding="utf-8") + + from napari_deeplabcut.ui import dialogs as ui_dialogs + + dialog_calls = {"count": 0, "kwargs": None} + + def _choose_external_config(*args, **kwargs): + dialog_calls["count"] += 1 + dialog_calls["kwargs"] = kwargs + return ui_dialogs.ProjectConfigPromptResult( + action=ui_dialogs.ProjectConfigPromptAction.ASSOCIATE, + config_path=str(external_config_path), + scorer="John", + ) + + monkeypatch.setattr( + "napari_deeplabcut._widgets.ui_dialogs.prompt_for_project_config_for_save", + _choose_external_config, + ) + + # Manual scorer prompt must not be used when selected config already resolves scorer. + inputdialog.forbid() + + viewer.open(str(labeled_folder), plugin="napari-deeplabcut") + qtbot.waitUntil(lambda: len(viewer.layers) >= 2, timeout=10_000) + qtbot.wait(200) + + pts_layers = [ly for ly in viewer.layers if isinstance(ly, Points)] + machine_layer = next(p for p in pts_layers if p.name == "machinelabels-iter0") + + store = keypoint_controls._stores.get(machine_layer) + assert store is not None + _set_or_add_bodypart_xy(machine_layer, store, "bodypart1", x=91.0, y=82.0) + + viewer.layers.selection.active = machine_layer + keypoint_controls.viewer.layers.selection.active = machine_layer + keypoint_controls.viewer.layers.selection.select_only(machine_layer) + + keypoint_controls._save_layers_dialog(selected=True) + qtbot.wait(300) + + assert dialog_calls["count"] == 1 + assert dialog_calls["kwargs"]["resolve_scorer"] is True + assert inputdialog.calls == 0 + assert "save_target" in machine_layer.metadata, machine_layer.metadata.keys() + + expected_gt = labeled_folder / "CollectedData_John.h5" + unexpected_gt = labeled_folder / "CollectedData_Alice.h5" + assert expected_gt.exists(), f"Expected GT with user-selected config scorer to be created: {expected_gt}" + assert not unexpected_gt.exists(), ( + f"Sidecar scorer must be ignored when a valid external config is selected: {unexpected_gt}" + ) + + machine_post = pd.read_hdf(machine_path, key="keypoints") + pd.testing.assert_frame_equal(machine_pre, machine_post) diff --git a/src/napari_deeplabcut/_tests/e2e/test_trails_e2e.py b/src/napari_deeplabcut/_tests/e2e/test_trails_e2e.py new file mode 100644 index 00000000..0e1b4cc8 --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/test_trails_e2e.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import warnings +from pathlib import Path + +import numpy as np +import pandas as pd +from napari.layers import Tracks + +from napari_deeplabcut._widgets import KeypointControls +from napari_deeplabcut.core import keypoints + + +def _open_multianimal_points(viewer, tmp_path: Path, *, n_animals: int = 3, n_kpts: int = 2, n_frames: int = 4): + rng = np.random.default_rng(123) + data = rng.random((n_frames, n_animals * n_kpts * 2)) + + cols = pd.MultiIndex.from_product( + [ + ["me"], + [f"animal_{i}" for i in range(n_animals)], + [f"kpt_{i}" for i in range(n_kpts)], + ["x", "y"], + ], + names=["scorer", "individuals", "bodyparts", "coords"], + ) + df = pd.DataFrame(data, columns=cols, index=range(n_frames)) + + path = tmp_path / "three_animals.h5" + df.to_hdf(path, key="data") + + layer = viewer.open(path, plugin="napari-deeplabcut")[0] + return layer + + +def _trails_layer(controls) -> Tracks | None: + """Current live trails layer managed by the extracted trails controller.""" + return controls._trails_controller.layer + + +def _expected_cycle_colors_from_controls(controls, points_layer): + """ + Expected trails colors must come from the same resolved cycle path as the widget, + not directly from raw metadata. + """ + prop = "id" if controls.color_mode == str(keypoints.ColorMode.INDIVIDUAL) else "label" + vals = list(dict.fromkeys(map(str, points_layer.properties[prop]))) + + cycle = controls._resolved_cycle_for_layer(points_layer) + out = [] + for v in vals: + c = np.asarray(cycle[v], dtype=float) + if c.shape[0] == 3: + c = np.r_[c, 1.0] + out.append(c) + return np.asarray(out, dtype=float), vals + + +def _current_trails_cmap_colors(tracks_layer: Tracks): + cmap = tracks_layer.colormaps_dict[tracks_layer.color_by] + return np.asarray(cmap.colors) + + +def test_trails_mode_switch_does_not_fallback_to_track_id(viewer, tmp_path): + points = _open_multianimal_points(viewer, tmp_path, n_animals=3, n_kpts=2, n_frames=4) + controls = KeypointControls.get_layer_controls(points) + assert controls is not None + + # Make the points layer active so trails source selection is deterministic + viewer.layers.selection.active = points + + controls._trail_cb.setChecked(True) + trails = _trails_layer(controls) + assert trails is not None + assert trails.color_by == "id_codes" + + with warnings.catch_warnings(record=True) as rec: + warnings.simplefilter("always") + + controls.color_mode = keypoints.ColorMode.BODYPART + trails = _trails_layer(controls) + assert trails is not None + assert trails.color_by == "label_codes" + + controls.color_mode = keypoints.ColorMode.INDIVIDUAL + trails = _trails_layer(controls) + assert trails is not None + assert trails.color_by == "id_codes" + + msgs = [str(w.message) for w in rec] + assert not any("Falling back to track_id" in m for m in msgs), msgs + assert not any("Previous color_by key" in m for m in msgs), msgs + + +def test_trails_repeated_mode_switch_keeps_expected_colormap(viewer, tmp_path): + points = _open_multianimal_points(viewer, tmp_path, n_animals=3, n_kpts=2, n_frames=4) + controls = KeypointControls.get_layer_controls(points) + assert controls is not None + + viewer.layers.selection.active = points + controls._trail_cb.setChecked(True) + + trails = _trails_layer(controls) + assert trails is not None + + # Initial individual mode colors + expected_id_colors, _ = _expected_cycle_colors_from_controls(controls, points) + actual_id_colors = _current_trails_cmap_colors(trails) + np.testing.assert_allclose(actual_id_colors, expected_id_colors) + + # Switch to bodypart mode + controls.color_mode = keypoints.ColorMode.BODYPART + trails = _trails_layer(controls) + assert trails is not None + assert trails.color_by == "label_codes" + + expected_label_colors, _ = _expected_cycle_colors_from_controls(controls, points) + actual_label_colors = _current_trails_cmap_colors(trails) + np.testing.assert_allclose(actual_label_colors, expected_label_colors) + + # Switch back to individual mode + controls.color_mode = keypoints.ColorMode.INDIVIDUAL + trails = _trails_layer(controls) + assert trails is not None + assert trails.color_by == "id_codes" + + actual_id_colors_2 = _current_trails_cmap_colors(trails) + np.testing.assert_allclose(actual_id_colors_2, expected_id_colors) + + # And one more round trip for stability + controls.color_mode = keypoints.ColorMode.BODYPART + trails = _trails_layer(controls) + assert trails is not None + + actual_label_colors_2 = _current_trails_cmap_colors(trails) + np.testing.assert_allclose(actual_label_colors_2, expected_label_colors) + + +def test_trails_individual_mode_three_animals_have_three_distinct_mapped_colors(viewer, tmp_path): + points = _open_multianimal_points(viewer, tmp_path, n_animals=3, n_kpts=1, n_frames=4) + controls = KeypointControls.get_layer_controls(points) + assert controls is not None + + viewer.layers.selection.active = points + controls.color_mode = keypoints.ColorMode.INDIVIDUAL + controls._trail_cb.setChecked(True) + + trails = _trails_layer(controls) + assert trails is not None + assert trails.color_by == "id_codes" + + cmap = trails.colormaps_dict["id_codes"] + codes = np.asarray(trails.properties["id_codes"]) + + # Only inspect the first instance of each animal + uniq_codes = list(dict.fromkeys(codes.tolist())) + mapped = np.asarray(cmap.map(np.asarray(uniq_codes, dtype=float))) + unique_rows = np.unique(np.round(mapped, decimals=8), axis=0) + + assert unique_rows.shape[0] == 3, ( + f"Expected 3 unique colors for 3 animals, got {unique_rows.shape[0]} unique colors: {unique_rows}" + ) diff --git a/src/napari_deeplabcut/_tests/e2e/utils.py b/src/napari_deeplabcut/_tests/e2e/utils.py new file mode 100644 index 00000000..467a397e --- /dev/null +++ b/src/napari_deeplabcut/_tests/e2e/utils.py @@ -0,0 +1,345 @@ +# src/napari_deeplabcut/_tests/e2e/_helpers.py +from __future__ import annotations + +import hashlib +import math +import os +from pathlib import Path + +import numpy as np +import pandas as pd +from napari.layers import Points + +from napari_deeplabcut.config.models import DLCHeaderModel +from napari_deeplabcut.config.settings import ( + DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP, + DEFAULT_SINGLE_ANIMAL_CMAP, +) +from napari_deeplabcut.core.keypoints import build_color_cycles +from napari_deeplabcut.ui.color_scheme_display import _to_hex + + +def file_sig(p: Path): + b = p.read_bytes() + return { + "mtime": os.path.getmtime(p), + "size": len(b), + "sha256": hashlib.sha256(b).hexdigest()[:16], + } + + +def _write_minimal_png(path: Path, *, shape=(64, 64, 3)) -> None: + """Write a tiny RGB image to satisfy the folder reader.""" + from skimage.io import imsave + + path.parent.mkdir(parents=True, exist_ok=True) + img = np.zeros(shape, dtype=np.uint8) + img[8:24, 8:24, 0] = 255 + imsave(str(path), img, check_contrast=False) + + +def _make_minimal_dlc_project(tmp_path: Path): + """ + Build a minimal DLC-like folder: + project/ + config.yaml + labeled-data/test/img000.png + labeled-data/test/CollectedData_John.h5 (bodypart1 labeled, bodypart2 NaN) + """ + import yaml + + project = tmp_path / "project" + labeled = project / "labeled-data" / "test" + labeled.mkdir(parents=True, exist_ok=True) + + img_rel = ("labeled-data", "test", "img000.png") + img_path = project / Path(*img_rel) + _write_minimal_png(img_path) + + cfg = { + "scorer": "John", + "bodyparts": ["bodypart1", "bodypart2"], + "dotsize": 8, + "pcutoff": 0.6, + "colormap": "viridis", + } + config_path = project / "config.yaml" + config_path.write_text(yaml.safe_dump(cfg), encoding="utf-8") + + cols = pd.MultiIndex.from_product( + [["John"], ["bodypart1", "bodypart2"], ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + idx = pd.MultiIndex.from_tuples([img_rel]) + df0 = pd.DataFrame([[10.0, 20.0, np.nan, np.nan]], index=idx, columns=cols) + + h5_path = labeled / "CollectedData_John.h5" + df0.to_hdf(h5_path, key="keypoints", mode="w") + df0.to_csv(str(h5_path).replace(".h5", ".csv")) + + return project, config_path, labeled, h5_path + + +def _make_labeled_folder_with_machine_only(tmp_path: Path) -> Path: + """ + Folder contains: + - images + - machinelabels-iter0.h5 (no CollectedData*, no config.yaml) + """ + folder = tmp_path / "shared" / "labeled-data" / "test" + folder.mkdir(parents=True, exist_ok=True) + + _write_minimal_png(folder / "img000.png") + + cols = pd.MultiIndex.from_product( + [["machine"], ["bodypart1", "bodypart2"], ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + df0 = pd.DataFrame([[np.nan, np.nan, np.nan, np.nan]], index=["img000.png"], columns=cols) + (folder / "machinelabels-iter0.h5").unlink(missing_ok=True) + df0.to_hdf(folder / "machinelabels-iter0.h5", key="keypoints", mode="w") + df0.to_csv(str(folder / "machinelabels-iter0.csv")) + + return folder + + +def _write_keypoints_h5( + path: Path, + *, + scorer: str, + img_rel: tuple[str, ...], + bodyparts=("bodypart1", "bodypart2"), + values=None, +) -> Path: + """ + Write a single-row DLC keypoints H5 in the same format used by _make_minimal_dlc_project. + `values` should be [b1x, b1y, b2x, b2y] where some can be NaN. + """ + if values is None: + values = [10.0, 20.0, np.nan, np.nan] + + cols = pd.MultiIndex.from_product( + [[scorer], list(bodyparts), ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + idx = pd.MultiIndex.from_tuples([img_rel]) + df = pd.DataFrame([values], index=idx, columns=cols) + + path.parent.mkdir(parents=True, exist_ok=True) + df.to_hdf(path, key="keypoints", mode="w") + df.to_csv(str(path).replace(".h5", ".csv")) + return path + + +def _make_dlc_project_with_multiple_gt( + tmp_path: Path, + *, + scorers=("John", "Jane"), + with_machine: bool = False, +): + """ + Build a minimal DLC-like labeled-data folder with multiple GT files. + """ + import yaml + + project = tmp_path / "project" + labeled = project / "labeled-data" / "test" + labeled.mkdir(parents=True, exist_ok=True) + + img_rel = ("labeled-data", "test", "img000.png") + img_path = project / Path(*img_rel) + _write_minimal_png(img_path) + + cfg = { + "scorer": scorers[0], + "bodyparts": ["bodypart1", "bodypart2"], + "dotsize": 8, + "pcutoff": 0.6, + "colormap": "viridis", + } + config_path = project / "config.yaml" + config_path.write_text(yaml.safe_dump(cfg), encoding="utf-8") + + gt_paths = [] + base = 10.0 + for i, scorer in enumerate(scorers): + vals = [base + i * 100.0, base + i * 100.0 + 10.0, np.nan, np.nan] + gt_path = labeled / f"CollectedData_{scorer}.h5" + _write_keypoints_h5(gt_path, scorer=scorer, img_rel=img_rel, values=vals) + gt_paths.append(gt_path) + + machine_path = None + if with_machine: + machine_path = labeled / "machinelabels-iter0.h5" + _write_keypoints_h5( + machine_path, + scorer="machine", + img_rel=img_rel, + values=[np.nan, np.nan, np.nan, np.nan], + ) + + return project, config_path, labeled, gt_paths, machine_path + + +def _make_project_config_and_frames_no_gt(tmp_path: Path): + """ + Project with: + project/config.yaml + project/labeled-data/test/img000.png + No CollectedData*.h5 initially. + """ + import yaml + + project = tmp_path / "project" + labeled = project / "labeled-data" / "test" + labeled.mkdir(parents=True, exist_ok=True) + + img_rel = ("labeled-data", "test", "img000.png") + _write_minimal_png(project / Path(*img_rel)) + + cfg = { + "scorer": "John", + "bodyparts": ["bodypart1", "bodypart2"], + "dotsize": 8, + "pcutoff": 0.6, + "colormap": "magma", + } + config_path = project / "config.yaml" + config_path.write_text(yaml.safe_dump(cfg), encoding="utf-8") + + return project, config_path, labeled + + +def _read_h5_keypoints(path: Path) -> pd.DataFrame: + return pd.read_hdf(path, key="keypoints") + + +def _index_mask_for_img(df: pd.DataFrame, basename: str) -> np.ndarray: + """Return boolean mask selecting rows that correspond to a given image basename.""" + if isinstance(df.index, pd.MultiIndex): + return np.array([str(Path(*t)).endswith(basename) for t in df.index]) + return df.index.astype(str).str.endswith(basename).to_numpy() + + +def _get_coord_from_df(df: pd.DataFrame, bodypart: str, coord: str, basename: str = "img000.png") -> float: + """Extract the single value for (bodypart, coord) in the row matching basename.""" + series = df.xs((bodypart, coord), axis=1, level=["bodyparts", "coords"]) + mask = _index_mask_for_img(series, basename) + assert mask.any(), f"Could not find row for {basename} in saved dataframe index: {df.index!r}" + return float(series.loc[series.index[mask]].iloc[0, 0]) + + +def _snapshot_coords(path: Path) -> dict[str, float]: + df = _read_h5_keypoints(path) + return { + "b1x": _get_coord_from_df(df, "bodypart1", "x"), + "b1y": _get_coord_from_df(df, "bodypart1", "y"), + "b2x": _get_coord_from_df(df, "bodypart2", "x"), + "b2y": _get_coord_from_df(df, "bodypart2", "y"), + } + + +def sig_equal(a: dict, b: dict) -> bool: + """NaN-stable signature equality for test signatures.""" + if a.keys() != b.keys(): + return False + for k in a.keys(): + va, vb = a[k], b[k] + if isinstance(va, float) and isinstance(vb, float): + if math.isnan(va) and math.isnan(vb): + continue + if va != vb: + return False + return True + + +def assert_only_these_changed_nan_safe(before: dict[Path, dict], after: dict[Path, dict], changed: set[Path]): + for p in before: + if p in changed: + assert not sig_equal(before[p], after[p]), f"Expected {p.name} to change, but signature did not." + else: + assert sig_equal(before[p], after[p]), f"Expected {p.name} NOT to change, but signature changed." + + +def _assert_only_these_files_changed(before: dict[Path, dict], after: dict[Path, dict], changed: set[Path]): + return assert_only_these_changed_nan_safe(before, after, changed) + + +def _get_points_layer_with_data(viewer) -> Points: + """Return the first Points layer with actual data; fallback to first Points layer.""" + pts = [ly for ly in viewer.layers if isinstance(ly, Points)] + assert pts, "Expected at least one Points layer in viewer." + return next((ly for ly in pts if ly.data is not None and np.isfinite(np.asarray(ly.data)[:, 1:3]).any()), pts[0]) + + +def _set_or_add_bodypart_xy(points_layer: Points, store, bodypart: str, *, x: float, y: float, frame: int = 0): + """ + Cross-version helper: + - If the bodypart already exists as a row (possibly NaN placeholder), update it. + - Otherwise, add a new point for that bodypart via the store/Points.add. + """ + labels = np.asarray(points_layer.properties.get("label", []), dtype=object) + mask = labels == bodypart + + if mask.any(): + data = np.array(points_layer.data, copy=True) # (frame, y, x) + data[mask, 1] = y + data[mask, 2] = x + points_layer.data = data + return + + from napari_deeplabcut.core import keypoints + + store.current_keypoint = keypoints.Keypoint(bodypart, "") + points_layer.add(np.array([float(frame), float(y), float(x)], dtype=float)) + + +def _header_model_from_layer(layer) -> DLCHeaderModel: + hdr = layer.metadata.get("header") + if hdr is None: + raise AssertionError("Expected header in layer metadata") + return hdr if isinstance(hdr, DLCHeaderModel) else DLCHeaderModel.model_validate(hdr) + + +def _is_multianimal_header(header: DLCHeaderModel) -> bool: + inds = list(getattr(header, "individuals", []) or []) + return bool(inds and str(inds[0]) != "") + + +def _config_colormap_from_layer(layer) -> str: + md = layer.metadata or {} + cmap = md.get("config_colormap") + if isinstance(cmap, str) and cmap: + return cmap + return DEFAULT_SINGLE_ANIMAL_CMAP + + +def _cycles_from_policy(layer) -> dict[str, dict[str, np.ndarray]]: + """ + Compute expected cycles from the new centralized color policy. + + Source of truth: + - layer header + - metadata['config_colormap'] + - multi-animal id coloring uses DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP + """ + header = _header_model_from_layer(layer) + config_cmap = _config_colormap_from_layer(layer) + + config_cycles = build_color_cycles(header, config_cmap) or {} + + if _is_multianimal_header(header): + individual_cycles = build_color_cycles(header, DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP) or {} + else: + individual_cycles = config_cycles + + return { + "label": config_cycles.get("label", {}), + "id": individual_cycles.get("id", {}), + } + + +def _scheme_from_policy(layer, prop: str, names: list[str]) -> dict[str, str]: + cycles = _cycles_from_policy(layer) + mapping = cycles.get(prop, {}) + return {name: _to_hex(mapping[name]) for name in names if name in mapping} diff --git a/src/napari_deeplabcut/_tests/test_misc.py b/src/napari_deeplabcut/_tests/test_misc.py index 99f02569..d0cf1d3c 100644 --- a/src/napari_deeplabcut/_tests/test_misc.py +++ b/src/napari_deeplabcut/_tests/test_misc.py @@ -1,12 +1,15 @@ # test_misc.py -import inspect from pathlib import Path import numpy as np import pandas as pd import pytest -from napari_deeplabcut import _reader, misc +from napari_deeplabcut import misc +from napari_deeplabcut.config.models import DLCHeaderModel +from napari_deeplabcut.core.dataframes import guarantee_multiindex_rows, merge_multiple_scorers +from napari_deeplabcut.core.io import load_config +from napari_deeplabcut.core.keypoints import build_color_cycle # ---------------------------- @@ -25,62 +28,6 @@ def test_unsorted_unique(seq, expected): assert list(out) == expected -# ---------------------------- -# canonicalize_path tests -# ---------------------------- -@pytest.mark.parametrize( - "p, n, expected", - [ - # basic POSIX cases - ("root/sub1/sub2/file.png", 3, "sub1/sub2/file.png"), - ("root/sub/file.png", 2, "sub/file.png"), - ("root/sub/file.png", 1, "file.png"), - ("a/b/c", 10, "a/b/c"), - (Path("a/b/c/d.txt"), 3, "b/c/d.txt"), - ("", 3, ""), - (".", 3, ""), - ("..", 3, ""), - ("/", 3, ""), - ("a/b/c/", 3, "a/b/c"), - # n <= 0 raises ValueError - ("a/b/c", 0, ValueError), - ("a/b/c/d", -1, ValueError), - # non-string coercion - (123, 3, "123"), - # Windows-style backslashes normalized to POSIX; last 3 components kept - (r"a\b\c\file.png", 3, "b/c/file.png"), - # Mixed separators: double backslash becomes empty path component after replace -> filtered out - (r"frames\\test\video0/img001.png", 3, "test/video0/img001.png"), - ], -) -def test_canonicalize_path_cases(p, n, expected): - # If expected is an Exception class, assert it is raised - is_exc_class = inspect.isclass(expected) and issubclass(expected, Exception) - if is_exc_class: - with pytest.raises(expected): - misc.canonicalize_path(p, n=n) - return - - out = misc.canonicalize_path(p, n=n) - assert out == expected - - -def test_canonicalize_path_converts_and_drops_backslashes(): - # Dedicated check that backslashes are removed - out = misc.canonicalize_path(r"a\b\c\file.png", n=3) - assert "\\" not in out - - -def test_canonicalize_path_exception_fallback_still_replaces_backslashes(): - class Weird: - def __str__(self): - return r"x\y\z" - - out = misc.canonicalize_path(Weird(), n=3) # type: ignore[arg-type] - assert out == "x/y/z" - assert "\\" not in out - - # ---------------------------- # encode_categories helpers (kept for existing tests) # ---------------------------- @@ -308,7 +255,7 @@ def test_merge_multiple_scorers_no_likelihood(fake_keypoints): temp = fake_keypoints.copy(deep=True) temp.columns = temp.columns.set_levels(["you"], level="scorer") df = fake_keypoints.merge(temp, left_index=True, right_index=True) - df = misc.merge_multiple_scorers(df) + df = merge_multiple_scorers(df) pd.testing.assert_frame_equal(df, fake_keypoints) @@ -324,7 +271,7 @@ def test_merge_multiple_scorers(fake_keypoints): fake_keypoints.iloc[:5] = np.nan temp.iloc[5:] = np.nan df = fake_keypoints.merge(temp, left_index=True, right_index=True) - df = misc.merge_multiple_scorers(df) + df = merge_multiple_scorers(df) pd.testing.assert_index_equal(df.columns, fake_keypoints.columns) assert not df.isna().any(axis=None) @@ -335,13 +282,13 @@ def test_merge_multiple_scorers(fake_keypoints): def test_guarantee_multiindex_rows(): fake_index = [f"labeled-data/subfolder_{i}/image_{j}" for i in range(3) for j in range(10)] df = pd.DataFrame(index=fake_index) - misc.guarantee_multiindex_rows(df) + guarantee_multiindex_rows(df) assert isinstance(df.index, pd.MultiIndex) # Substitute index with frame numbers frame_numbers = list(range(df.shape[0])) df.index = frame_numbers - misc.guarantee_multiindex_rows(df) + guarantee_multiindex_rows(df) assert df.index.to_list() == frame_numbers @@ -350,14 +297,14 @@ def test_guarantee_multiindex_rows(): # ---------------------------- @pytest.mark.parametrize("n_colors", range(1, 11)) def test_build_color_cycle(n_colors): - color_cycle = misc.build_color_cycle(n_colors) + color_cycle = build_color_cycle(n_colors) assert color_cycle.shape[0] == n_colors # Test whether all colors are different assert len(set(map(tuple, color_cycle))) == n_colors # ---------------------------- -# DLCHeader tests +# DLCHeaderModel tests # ---------------------------- def test_dlc_header(): n_animals = 2 @@ -374,22 +321,22 @@ def test_dlc_header(): ], names=["scorer", "individuals", "bodyparts", "coords"], ) - header = misc.DLCHeader(fake_columns) + header = DLCHeaderModel(columns=fake_columns) + assert header.scorer == scorer + + header2 = header.with_scorer("you") + assert header2.scorer == "you" + # original header unchanged (functional) assert header.scorer == scorer - header.scorer = "you" - assert header.scorer == "you" - assert header.individuals == animals - assert header.bodyparts == keypoints - assert header.coords == ["x", "y", "likelihood"] def test_dlc_header_from_config_multi(config_path): - config = _reader._load_config(config_path) + config = load_config(config_path) config["multianimalproject"] = True config["individuals"] = ["animal"] config["multianimalbodyparts"] = list("abc") config["uniquebodyparts"] = list("de") - header = misc.DLCHeader.from_config(config) + header = DLCHeaderModel.from_config(config) assert header.individuals != [""] diff --git a/src/napari_deeplabcut/_tests/test_reader.py b/src/napari_deeplabcut/_tests/test_reader.py index 1efc73ad..f43994da 100644 --- a/src/napari_deeplabcut/_tests/test_reader.py +++ b/src/napari_deeplabcut/_tests/test_reader.py @@ -3,13 +3,25 @@ import numpy as np import pandas as pd import pytest -from PIL import Image from skimage.io import imsave from napari_deeplabcut import _reader +from napari_deeplabcut.core.io import ( + Video, + _lazy_imread, + is_video, + load_superkeypoints, + load_superkeypoints_diagram, + read_config, + read_hdf, + read_images, + read_video, +) FAKE_EXTENSION = ".notanimage" +# TODO @C-Achard 2026-02-18 Split IO vs reader tests. + @pytest.mark.parametrize("ext", _reader.SUPPORTED_IMAGES) def test_get_image_reader(ext): @@ -28,20 +40,22 @@ def test_get_config_reader_invalid_path(): def test_get_folder_parser(tmp_path_factory, fake_keypoints): folder = tmp_path_factory.mktemp("folder") + # Make it look like a DLC labeled folder + labeled = folder / "labeled-data" / "test" + labeled.mkdir(parents=True) + frame = (np.random.rand(10, 10) * 255).astype(np.uint8) - imsave(folder / "img1.png", frame) - imsave(folder / "img2.png", frame) - layers = _reader.get_folder_parser(folder)(None) - # There should be only an Image layer - assert len(layers) == 1 - assert layers[0][-1] == "image" + imsave(labeled / "img1.png", frame) + imsave(labeled / "img2.png", frame) + + # Add a DLC artifact so reader claims it + (labeled / "CollectedData_me.csv").write_text("dummy", encoding="utf-8") - # Add an annotation data file - fake_keypoints.to_hdf(folder / "data.h5", key="data") - layers = _reader.get_folder_parser(folder)(None) - # There should now be an additional Points layer - assert len(layers) == 2 - assert layers[-1][-1] == "points" + parser = _reader.get_folder_parser(str(labeled)) + assert parser is not None + layers = parser(None) + assert isinstance(layers, list) + assert len(layers) > 0 def test_get_folder_parser_wrong_input(): @@ -49,20 +63,20 @@ def test_get_folder_parser_wrong_input(): def test_get_folder_parser_no_images(tmp_path_factory): - folder = str(tmp_path_factory.mktemp("images")) - with pytest.raises(OSError): - _reader.get_folder_parser(folder) + folder = tmp_path_factory.mktemp("images") + parser = _reader.get_folder_parser(str(folder)) + assert parser is None def test_read_images(tmp_path_factory, fake_image): folder = tmp_path_factory.mktemp("folder") path = str(folder / "img.png") imsave(path, fake_image) - _ = _reader.read_images(path)[0] + _ = read_images(path)[0] def test_read_config(config_path): - dict_ = _reader.read_config(config_path)[0][1] + dict_ = read_config(config_path)[0][1] assert dict_["name"].startswith("CollectedData_") assert config_path.startswith(dict_["metadata"]["project"]) @@ -71,8 +85,8 @@ def test_read_hdf_old_index(tmp_path_factory, fake_keypoints): path = str(tmp_path_factory.mktemp("folder") / "data.h5") old_index = [f"labeled-data/video/img{i}.png" for i in range(fake_keypoints.shape[0])] fake_keypoints.index = old_index - fake_keypoints.to_hdf(path, key="data") - layers = _reader.read_hdf(path) + fake_keypoints.to_hdf(path, key="keypoints") + layers = read_hdf(path) assert len(layers) == 1 image_paths = layers[0][1]["metadata"]["paths"] assert len(image_paths) == len(fake_keypoints) @@ -90,8 +104,8 @@ def test_read_hdf_new_index(tmp_path_factory, fake_keypoints): ] ) fake_keypoints.index = new_index - fake_keypoints.to_hdf(path, key="data") - layers = _reader.read_hdf(path) + fake_keypoints.to_hdf(path, key="keypoints") + layers = read_hdf(path) assert len(layers) == 1 image_paths = layers[0][1]["metadata"]["paths"] assert len(image_paths) == len(fake_keypoints) @@ -107,7 +121,7 @@ def test_read_images_mixed_extensions_list_input(tmp_path): imsave(p_jpg, img) imsave(p_png, img) - layers = _reader.read_images([p_jpg, p_png]) + layers = read_images([p_jpg, p_png]) assert len(layers) == 1 data, params, kind = layers[0] assert kind == "image" @@ -131,7 +145,7 @@ def test_read_images_mixed_extensions_directory_ignores_unsupported(tmp_path): imsave(p_png, img) imsave(p_unsupp, img) - layers = _reader.read_images(tmp_path) # pass directory + layers = read_images(tmp_path) # pass directory assert len(layers) == 1 data, params, kind = layers[0] assert kind == "image" @@ -155,7 +169,7 @@ def test_lazy_imread_mixed_extensions_list(tmp_path): imsave(p_jpg, img1) imsave(p_png, img2) - result = _reader._lazy_imread([p_jpg, p_png], use_dask=True, stack=True) + result = _lazy_imread([p_jpg, p_png], use_dask=True, stack=True) assert isinstance(result, da.Array) assert result.shape == (2, 12, 9, 3) first = result[0].compute() @@ -173,7 +187,7 @@ def test_read_images_mixed_extensions_globs(tmp_path): imsave(p2, img) imsave(p3, img) - layers = _reader.read_images([str(tmp_path / "*.jpg"), str(tmp_path / "*.png")]) + layers = read_images([str(tmp_path / "*.jpg"), str(tmp_path / "*.png")]) assert len(layers) == 1 data, params, kind = layers[0] assert kind == "image" @@ -195,7 +209,7 @@ def test_read_images_mixed_extensions_tuple_input(tmp_path): imsave(p1, img) imsave(p2, img) - layers = _reader.read_images((p1, p2)) # tuple + layers = read_images((p1, p2)) # tuple assert len(layers) == 1 data, params, kind = layers[0] assert kind == "image" @@ -213,7 +227,7 @@ def test_lazy_imread_single_image(tmp_path): path = tmp_path / "img.png" imsave(path, img) - result = _reader._lazy_imread(path, use_dask=False) + result = _lazy_imread(path, use_dask=False) assert isinstance(result, np.ndarray) assert result.shape == img.shape @@ -226,7 +240,7 @@ def test_lazy_imread_multiple_images_equal_shape(tmp_path): imsave(path1, img1) imsave(path2, img2) - result = _reader._lazy_imread([path1, path2], use_dask=True) + result = _lazy_imread([path1, path2], use_dask=True) assert isinstance(result, da.Array) assert result.shape == (2, 10, 10, 3) @@ -241,7 +255,7 @@ def test_lazy_imread_mixed_shapes(tmp_path): # Should fail when stacking mixed shapes without padding with pytest.raises(ValueError): - _ = _reader._lazy_imread([path1, path2], use_dask=False, stack=True) + _ = _lazy_imread([path1, path2], use_dask=False, stack=True) def test_read_images_list_input(tmp_path): @@ -251,7 +265,7 @@ def test_read_images_list_input(tmp_path): imsave(path1, img) imsave(path2, img) - layers = _reader.read_images([path1, path2]) + layers = read_images([path1, path2]) assert len(layers) == 1 data, params, kind = layers[0] assert kind == "image" @@ -262,7 +276,7 @@ def test_read_images_list_input(tmp_path): def test_read_images_empty_list(): with pytest.raises(OSError): - _reader.read_images([]) + read_images([]) def test_read_images_single_glob_pattern(tmp_path): @@ -276,7 +290,7 @@ def test_read_images_single_glob_pattern(tmp_path): imsave(p2, img) imsave(p3, img) - layers = _reader.read_images(str(tmp_path / "*.png")) + layers = read_images(str(tmp_path / "*.png")) assert len(layers) == 1 data, params, kind = layers[0] assert kind == "image" @@ -292,7 +306,7 @@ def test_read_images_single_glob_pattern(tmp_path): def test_video_init_and_properties(video_path): """Ensure Video object initializes and exposes correct properties.""" - vid = _reader.Video(video_path) + vid = Video(video_path) assert len(vid) == 5 # number of frames created by fixture assert vid.width == 50 @@ -304,7 +318,7 @@ def test_video_init_and_properties(video_path): def test_video_read_single_frame(video_path): """Check that we can read at least one frame correctly.""" - vid = _reader.Video(video_path) + vid = Video(video_path) vid.set_to_frame(0) frame = vid.read_frame() @@ -318,12 +332,12 @@ def test_video_read_single_frame(video_path): def test_video_reader_invalid_path(): """Invalid path should raise ValueError.""" with pytest.raises(ValueError): - _ = _reader.Video("") + _ = Video("") def test_read_video_output(video_path): """Test the full read_video() API returns expected tuple structure.""" - layers = _reader.read_video(video_path) + layers = read_video(video_path) assert len(layers) == 1 data, params, kind = layers[0] @@ -342,9 +356,9 @@ def test_read_video_output(video_path): def test_get_video_reader_dispatch(video_path): assert _reader.get_video_reader(video_path) is not None - assert _reader.is_video(str(video_path)) + assert is_video(str(video_path)) assert _reader.get_video_reader("file.txt") is None - assert not _reader.is_video("file.png") + assert not is_video("file.png") def test_lazy_imread_list_no_stack(tmp_path): @@ -352,7 +366,7 @@ def test_lazy_imread_list_no_stack(tmp_path): p1, p2 = tmp_path / "a.png", tmp_path / "b.png" imsave(p1, img) imsave(p2, img) - res = _reader._lazy_imread([p1, p2], use_dask=True, stack=False) + res = _lazy_imread([p1, p2], use_dask=True, stack=False) assert isinstance(res, list) and len(res) == 2 assert all(isinstance(x, da.Array) for x in res) @@ -362,7 +376,7 @@ def test_read_images_list_metadata_paths(tmp_path): p1, p2 = tmp_path / "img1.png", tmp_path / "img2.png" imsave(p1, img) imsave(p2, img) - [(data, params, kind)] = _reader.read_images([p2, p1]) # unordered input + [(data, params, kind)] = read_images([p2, p1]) # unordered input assert params["metadata"]["paths"] # exists assert len(params["metadata"]["paths"]) == 2 # natsorted is applied; assert the order is deterministic by name @@ -376,70 +390,17 @@ def test_lazy_imread_grayscale_and_rgba(tmp_path): p1, p2 = tmp_path / "g.png", tmp_path / "r.png" cv2.imwrite(str(p1), gray) # cv2 writes grayscale as-is; color images are written as BGR cv2.imwrite(str(p2), cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)) - res = _reader._lazy_imread([p1, p2], use_dask=False, stack=False) + res = _lazy_imread([p1, p2], use_dask=False, stack=False) assert all(img.shape[-1] == 3 for img in res) -@pytest.mark.parametrize( - "exists", - [ - True, - False, - ], -) -def test_load_superkeypoints(monkeypatch, tmp_path, exists): - """Test loading of superkeypoints JSON with and without the file present.""" - module_dir = tmp_path / "module" - assets_dir = module_dir / "assets" - assets_dir.mkdir(parents=True) - - super_animal = "fake" - json_path = assets_dir / f"{super_animal}.json" - - if exists: - json_path.write_text('{"SK1": [1, 2]}') - - # Patch module __file__ - fake_file = module_dir / "_reader_fake.py" - fake_file.write_text("# fake module") - monkeypatch.setattr("napari_deeplabcut._reader.__file__", str(fake_file)) - - if exists: - assert _reader._load_superkeypoints(super_animal) == {"SK1": [1, 2]} - else: - with pytest.raises(FileNotFoundError): - _reader._load_superkeypoints(super_animal) - - -@pytest.mark.parametrize( - "exists", - [ - True, - False, - ], -) -def test_load_superkeypoints_diagram(monkeypatch, tmp_path, exists): - """Test loading of superkeypoints diagram with and without the file present.""" - module_dir = tmp_path / "module" - assets_dir = module_dir / "assets" - assets_dir.mkdir(parents=True) - - super_animal = "fake" - jpg_path = assets_dir / f"{super_animal}.jpg" - - if exists: - Image.new("RGB", (10, 10), "white").save(jpg_path) - - # Patch module __file__ - fake_file = module_dir / "_reader_fake.py" - fake_file.write_text("# fake") - monkeypatch.setattr("napari_deeplabcut._reader.__file__", str(fake_file)) - - if exists: - array, meta, layer_type = _reader._load_superkeypoints_diagram(super_animal) - assert layer_type == "images" - assert meta == {"root": ""} - assert tuple(array.shape[-3:-1]) == (10, 10) - else: - with pytest.raises(FileNotFoundError): - _reader._load_superkeypoints_diagram(super_animal) +def test_load_superkeypoints(): + """Test loading of superkeypoints JSON to ensure file is present and correctly parsed.""" + json_file = load_superkeypoints("superanimal_quadruped") + assert isinstance(json_file, dict) + + +def test_load_superkeypoints_diagram(): + """Test loading of superkeypoints diagram to ensure file is present and correctly read.""" + diagram = load_superkeypoints_diagram("superanimal_quadruped") + assert diagram.ndim == 3 # should be an RGB image diff --git a/src/napari_deeplabcut/_tests/test_widgets.py b/src/napari_deeplabcut/_tests/test_widgets.py index 05f60f56..c66e5bfd 100644 --- a/src/napari_deeplabcut/_tests/test_widgets.py +++ b/src/napari_deeplabcut/_tests/test_widgets.py @@ -1,61 +1,85 @@ +# src/napari_deeplabcut/_tests/test_widgets.py import os import types +from pathlib import Path import numpy as np import pytest import yaml -from qtpy.QtSvgWidgets import QSvgWidget +from napari.layers import Image, Tracks +from qtpy.QtWidgets import QScrollArea from vispy import keys from napari_deeplabcut import _widgets +from napari_deeplabcut.core import io, keypoints +from napari_deeplabcut.core.io import populate_keypoint_layer_properties +from napari_deeplabcut.ui.color_scheme_display import ColorSchemeDisplay +from napari_deeplabcut.ui.dialogs import ShortcutRow +from napari_deeplabcut.ui.labels_and_dropdown import KeypointsDropdownMenu, LabelPair +from napari_deeplabcut.ui.plots.trajectory import KeypointMatplotlibCanvas + +from .conftest import force_show def test_guess_continuous(): - # Hack: guess_continuous overrides napari's default logic to avoid misclassifying categorical properties - assert _widgets.guess_continuous(np.array([0.0])) # Floats → continuous - assert not _widgets.guess_continuous(np.array(list("abc"))) # Strings → categorical + import numpy as np + from napari.layers.utils import color_manager + + # Patch is applied during KeypointControls init (or import-time depending on your setup) + # Expect float -> continuous + assert color_manager.guess_continuous(np.array([0.0])) + # Expect object/categorical -> NOT continuous + assert not color_manager.guess_continuous(np.array(["a", "b"], dtype=object)) -def test_keypoint_controls(viewer, qtbot): - controls = _widgets.KeypointControls(viewer) +@pytest.mark.usefixtures("qtbot") +def test_keypoint_controls(keypoint_controls): + controls = keypoint_controls controls.label_mode = "loop" assert controls._radio_group.checkedButton().text() == "Loop" controls.cycle_through_label_modes() assert controls._radio_group.checkedButton().text() == "Sequential" -def test_save_layers(viewer, points): - controls = _widgets.KeypointControls(viewer) +@pytest.mark.usefixtures("qtbot") +def test_save_layers(viewer, keypoint_controls, points): viewer.layers.selection.add(points) - # _save_layers_dialog bypasses napari's Save dialog for Points layers (used in headless tests) - _widgets._save_layers_dialog(controls) + keypoint_controls._save_layers_dialog() + + +@pytest.mark.usefixtures("qtbot") +def test_show_trails(viewer, keypoint_controls, store): + keypoint_controls._stores[store.layer] = store + viewer.layers.selection.active = store.layer + keypoint_controls._is_saved = True + keypoint_controls._trail_cb.setChecked(True) -def test_show_trails(viewer, store): - controls = _widgets.KeypointControls(viewer) - controls._stores["temp"] = store - controls._is_saved = True - controls._show_trails(state=2) + trails = keypoint_controls._trails_controller.layer + assert trails is not None + assert isinstance(trails, Tracks) + assert trails.visible is True -def test_extract_single_frame(viewer, images): +@pytest.mark.usefixtures("qtbot") +def test_extract_single_frame(keypoint_controls, viewer, images): viewer.layers.selection.add(images) - controls = _widgets.KeypointControls(viewer) - controls._extract_single_frame() + keypoint_controls._extract_single_frame() -def test_store_crop_coordinates(viewer, images, config_path): +@pytest.mark.usefixtures("qtbot") +def test_store_crop_coordinates(keypoint_controls, viewer, images, config_path): viewer.layers.selection.add(images) _ = viewer.add_shapes(np.random.random((4, 3)), shape_type="rectangle") - controls = _widgets.KeypointControls(viewer) - controls._images_meta = { - "name": "fake_video", - "project": os.path.dirname(config_path), - } + # _image_meta is expected to be an ImageMetadata instance + keypoint_controls._image_meta = _widgets.ImageMetadata(name="fake_video") + # _store_crop_coordinates now uses _project_path instead of reading "project" from _image_meta + keypoint_controls._project_path = os.path.dirname(config_path) # Stores crop coordinates from a rectangle shape into the project's config.yaml - controls._store_crop_coordinates() + keypoint_controls._store_crop_coordinates() +@pytest.mark.usefixtures("qtbot") def test_toggle_face_color(viewer, points): viewer.layers.selection.add(points) view = viewer.window._qt_viewer @@ -67,6 +91,7 @@ def test_toggle_face_color(viewer, points): assert points._face.color_properties.name == "id" +@pytest.mark.usefixtures("qtbot") def test_toggle_edge_color(viewer, points): viewer.layers.selection.add(points) view = viewer.window._qt_viewer @@ -76,6 +101,7 @@ def test_toggle_edge_color(viewer, points): np.testing.assert_array_equal(points.border_width, 2) +@pytest.mark.usefixtures("qtbot") def test_dropdown_menu(qtbot): widget = _widgets.DropdownMenu(list("abc")) qtbot.add_widget(widget) @@ -86,8 +112,9 @@ def test_dropdown_menu(qtbot): assert widget.currentText() == "a" +@pytest.mark.usefixtures("qtbot") def test_keypoints_dropdown_menu_selection_updates_store(store, qtbot): - widget = _widgets.KeypointsDropdownMenu(store) + widget = KeypointsDropdownMenu(store) qtbot.add_widget(widget) id_menu = widget.menus.get("id") label_menu = widget.menus["label"] @@ -103,16 +130,18 @@ def test_keypoints_dropdown_menu_selection_updates_store(store, qtbot): assert store.current_label == label_menu.currentText() +@pytest.mark.usefixtures("qtbot") def test_keypoints_dropdown_menu_single_animal_has_no_id_menu(single_animal_store, qtbot): - widget = _widgets.KeypointsDropdownMenu(single_animal_store) + widget = KeypointsDropdownMenu(single_animal_store) qtbot.add_widget(widget) assert "id" not in widget.menus assert "label" in widget.menus assert widget.menus["label"].count() > 0 +@pytest.mark.usefixtures("qtbot") def test_keypoints_dropdown_menu(store, qtbot): - widget = _widgets.KeypointsDropdownMenu(store) + widget = KeypointsDropdownMenu(store) qtbot.add_widget(widget) # Menus for both "id" and "label" should exist; label menu reflects current keypoint # This confirms we have multi-animal data @@ -132,16 +161,18 @@ def test_keypoints_dropdown_menu(store, qtbot): assert [label_menu.itemText(i) for i in range(label_menu.count())] == expected_labels_second +@pytest.mark.usefixtures("qtbot") def test_keypoints_dropdown_menu_unknown_id_yields_empty_list(store): # If an invalid ID is selected, the label menu should be empty - widget = _widgets.KeypointsDropdownMenu(store) + widget = KeypointsDropdownMenu(store) label_menu = widget.menus["label"] widget.refresh_label_menu("__NON_EXISTENT_ID__") assert label_menu.count() == 0 # defaultdict(list) → no labels +@pytest.mark.usefixtures("qtbot") def test_keypoints_dropdown_menu_updates_from_store_current_properties(store, qtbot): - widget = _widgets.KeypointsDropdownMenu(store) + widget = KeypointsDropdownMenu(store) qtbot.add_widget(widget) id_menu = widget.menus.get("id") label_menu = widget.menus["label"] @@ -158,9 +189,11 @@ def test_keypoints_dropdown_menu_updates_from_store_current_properties(store, qt assert label_menu.currentText() == target.label +@pytest.mark.usefixtures("qtbot") def test_keypoints_dropdown_menu_smart_reset(store, qtbot): - widget = _widgets.KeypointsDropdownMenu(store) + widget = KeypointsDropdownMenu(store) qtbot.add_widget(widget) + force_show(widget, qtbot) label_menu = widget.menus["label"] label_menu.update_to("kpt_2") widget._locked = True @@ -173,8 +206,9 @@ def test_keypoints_dropdown_menu_smart_reset(store, qtbot): assert label_menu.currentText() == "kpt_0" +@pytest.mark.usefixtures("qtbot") def test_color_pair(qtbot): - pair = _widgets.LabelPair(color="pink", name="kpt", parent=None) + pair = LabelPair(color="pink", name="kpt", parent=None) qtbot.add_widget(pair) # LabelPair couples a color swatch with a clickable label # Ensure setters update both UI and tooltip @@ -185,8 +219,9 @@ def test_color_pair(qtbot): assert pair.color_label.toolTip() == "kpt2" +@pytest.mark.usefixtures("qtbot") def test_color_scheme_display(qtbot): - widget = _widgets.ColorSchemeDisplay(None) + widget = ColorSchemeDisplay(None) qtbot.add_widget(widget) widget._build() # Initially empty: no color scheme entries and no layout widgets @@ -197,9 +232,10 @@ def test_color_scheme_display(qtbot): assert widget._container.layout().count() == 1 +@pytest.mark.usefixtures("qtbot") def test_matplotlib_canvas_initialization_and_slider(viewer, points, qtbot): # Create the canvas widget - canvas = _widgets.KeypointMatplotlibCanvas(viewer) + canvas = KeypointMatplotlibCanvas(viewer) qtbot.add_widget(canvas) # Simulate adding a Points layer (triggers _load_dataframe) @@ -236,10 +272,10 @@ def _no_autodock(monkeypatch): monkeypatch.setattr(_widgets.QTimer, "singleShot", lambda *args, **kwargs: None) -def test_ensure_mpl_canvas_docked_already_docked(viewer, qtbot, monkeypatch): +@pytest.mark.usefixtures("qtbot") +def test_ensure_mpl_canvas_docked_already_docked(keypoint_controls, qtbot, monkeypatch): """If already docked, it must be a no-op: do not call add_dock_widget again.""" - controls = _widgets.KeypointControls(viewer) - qtbot.add_widget(controls) + controls = keypoint_controls controls._mpl_docked = True # simulate already docked called = {"count": 0} @@ -255,9 +291,10 @@ def fake_add_dock_widget(*args, **kwargs): assert controls._mpl_docked is True # stays docked -def test_ensure_mpl_canvas_docked_missing_window(viewer, qtbot): +@pytest.mark.usefixtures("qtbot") +def test_ensure_mpl_canvas_docked_missing_window(keypoint_controls, qtbot): """If viewer has no window attribute, method should safely no-op.""" - controls = _widgets.KeypointControls(viewer) + controls = keypoint_controls qtbot.add_widget(controls) # Swap the viewer for a minimal stub object with *no* 'window' attribute @@ -270,9 +307,28 @@ def test_ensure_mpl_canvas_docked_missing_window(viewer, qtbot): assert controls._mpl_docked is False -def test_ensure_mpl_canvas_docked_missing_qt_window(viewer, qtbot): +@pytest.mark.usefixtures("qtbot") +def test_trajectory_loader_ignores_invalid_properties(viewer, keypoint_controls, make_real_header_factory): + header = make_real_header_factory(individuals=("",)) + md = populate_keypoint_layer_properties( + header, + labels=["bodypart1"], + ids=[""], + likelihood=np.array([1.0], dtype=float), + paths=[], + colormap="viridis", + ) + md["properties"]["label"] = [np.nan] # invalid + + layer = viewer.add_points(np.array([[0.0, 10.0, 20.0]]), **md) + assert layer is not None + assert keypoint_controls._matplotlib_canvas.df is None # loader should have bailed out safely + + +@pytest.mark.usefixtures("qtbot") +def test_ensure_mpl_canvas_docked_missing_qt_window(keypoint_controls, qtbot): """If window._qt_window is None, method should safely no-op.""" - controls = _widgets.KeypointControls(viewer) + controls = keypoint_controls qtbot.add_widget(controls) class DummyWindow: @@ -291,9 +347,10 @@ def add_dock_widget(self, *args, **kwargs): assert controls._mpl_docked is False -def test_ensure_mpl_canvas_docked_exception_during_docking(viewer, qtbot): +@pytest.mark.usefixtures("qtbot") +def test_ensure_mpl_canvas_docked_exception_during_docking(keypoint_controls, qtbot): """If add_dock_widget raises, method should catch, log, and remain undocked (no crash).""" - controls = _widgets.KeypointControls(viewer) + controls = keypoint_controls qtbot.add_widget(controls) class DummyWindow: @@ -314,9 +371,10 @@ def add_dock_widget(self, *args, **kwargs): assert controls._mpl_docked is False -def test_display_shortcuts_dialog(viewer, qtbot): +@pytest.mark.usefixtures("qtbot") +def test_display_shortcuts_dialog(keypoint_controls, qtbot): """Ensure that the Shortcuts dialog can be created and shown without errors.""" - controls = _widgets.KeypointControls(viewer) + controls = keypoint_controls qtbot.add_widget(controls) # Create the dialog directly @@ -329,48 +387,217 @@ def test_display_shortcuts_dialog(viewer, qtbot): # Verify it is visible assert dlg.isVisible() - - # Ensure the SVG widget is present - found_svg = False - for child in dlg.children(): - if isinstance(child, QSvgWidget): - found_svg = True - break - - assert found_svg, "Shortcuts dialog should contain a QSvgWidget with the shortcuts image." + assert dlg.windowTitle() == "Keyboard shortcuts" + assert dlg.findChildren(QScrollArea) + assert dlg.findChildren(ShortcutRow) # NOTE SuperAnimal keypoints functionality and testing may need an overhaul in the future: # these tests currently exercise only a narrow "everything fine" path and rely on specific metadata # layout and SuperAnimal conversion-table conventions, which makes them susceptible to API changes -def test_widget_load_superkeypoints_diagram(viewer, qtbot, points, superkeypoints_assets): - controls = _widgets.KeypointControls(viewer) - qtbot.add_widget(controls) - - # Inject conversion table into the existing Points layer +@pytest.mark.usefixtures("qtbot") +def test_widget_load_superkeypoints_diagram(keypoint_controls, viewer, qtbot, points, monkeypatch): + # Arrange: conversion table uses *realistic* keys (not SK1/SK2), + # and does not depend on any asset conventions. layer = points - super_animal = superkeypoints_assets["super_animal"] - layer.metadata["tables"] = {super_animal: {"kp1": "SK1", "kp2": "SK2"}} + super_animal = "superanimal_quadruped" + layer.metadata["tables"] = {super_animal: {"kp1": "nose", "kp2": "upper_jaw"}} + + # Arrange: stub I/O so the test doesn't depend on installed assets + dummy_img = np.zeros((8, 8), dtype=np.uint8) + dummy_superkpts = { + "nose": [1.0, 2.0], + "upper_jaw": [3.0, 4.0], + } + monkeypatch.setattr(io, "load_superkeypoints_diagram", lambda name: dummy_img) + monkeypatch.setattr(io, "load_superkeypoints", lambda name: dummy_superkpts) n_layers_before = len(viewer.layers) - controls.load_superkeypoints_diagram() + # Act + keypoint_controls.load_superkeypoints_diagram() + + # Assert: one new image layer is added assert len(viewer.layers) == n_layers_before + 1 + assert isinstance(viewer.layers[-1], Image) + assert viewer.layers[-1].data.shape == dummy_img.shape + + # Assert: labels match the table keys (reference keypoints) assert list(layer.properties["label"]) == ["kp1", "kp2"] - assert controls._keypoint_mapping_button.text() == "Map keypoints" + # Assert: points data updated to [0, x, y] for each mapping + assert layer.data.shape == (2, 3) + assert np.allclose(layer.data[:, 0], 0.0) + assert np.allclose(layer.data[:, 1:], np.array([[1.0, 2.0], [3.0, 4.0]])) + + # Assert: UI updated + assert keypoint_controls._keypoint_mapping_button.text() == "Map keypoints" + assert keypoint_controls._keypoint_mapping_button.text() == "Map keypoints" -def test_widget_map_keypoints_writes_to_config(viewer, qtbot, mapped_points, config_path): - controls = _widgets.KeypointControls(viewer) - qtbot.add_widget(controls) - _, super_animal, bp1, bp2 = mapped_points - controls._map_keypoints(super_animal) +@pytest.mark.usefixtures("qtbot") +def test_widget_map_keypoints_writes_to_config(keypoint_controls, qtbot, points, config_path, monkeypatch): + controls = keypoint_controls + qtbot.add_widget(controls) + # Arrange: ensure the points layer has some data (shape: [t, x, y]) + points.data = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 30.0, 40.0], + ], + dtype=float, + ) + + # Arrange: provide the metadata that _map_keypoints expects + # _map_keypoints builds config_path as Path(project)/"config.yaml" + project_dir = Path(config_path).parent + points.metadata["project"] = str(project_dir) + points.metadata["tables"] = {"superanimal_quadruped": {}} + + import pandas as pd + + from napari_deeplabcut.config.models import DLCHeaderModel + + cols = pd.MultiIndex.from_product( + [["S"], [""], ["bp1", "bp2"], ["x", "y"]], + names=["scorer", "individuals", "bodyparts", "coords"], + ) + points.metadata["header"] = DLCHeaderModel( + columns=cols, + ) + + # Ensure config file exists (some setups create it already; this is safe) + Path(config_path).write_text("{}", encoding="utf-8") + + # Arrange: stub superkeypoints + nearest-neighbor results to be deterministic + # Your JSON is dict(key -> [x,y]) so we mimic that. + dummy_superkpts = {"nose": [0.0, 0.0], "upper_jaw": [1.0, 1.0]} + monkeypatch.setattr(io, "load_superkeypoints", lambda name: dummy_superkpts) + + # neighbors indices correspond to ordering of list(dummy_superkpts) + # Here: ["nose", "upper_jaw"] -> indices [0, 1] + monkeypatch.setattr(keypoints, "_find_nearest_neighbors", lambda xy, xy_ref: np.array([0, 1])) + + # If your io.load_config / io.write_config do more than YAML I/O, + # you can keep them. Otherwise stubbing them makes the test isolated. + def _load_config(path): + with open(path, encoding="utf-8") as fh: + return yaml.safe_load(fh) or {} + + def _write_config(path, cfg): + with open(path, "w", encoding="utf-8") as fh: + yaml.safe_dump(cfg, fh, sort_keys=False) + + monkeypatch.setattr(io, "load_config", _load_config) + monkeypatch.setattr(io, "write_config", _write_config) + + # Act + controls._map_keypoints("superanimal_quadruped") + + # Assert with open(config_path, encoding="utf-8") as fh: - cfg = yaml.safe_load(fh) + cfg = yaml.safe_load(fh) or {} + assert "SuperAnimalConversionTables" in cfg - assert cfg["SuperAnimalConversionTables"][super_animal] == { - bp1: "SK1", - bp2: "SK2", + + # Optional stronger assertion: verify the mapping is written as expected + assert cfg["SuperAnimalConversionTables"]["superanimal_quadruped"] == { + "bp1": "nose", + "bp2": "upper_jaw", + } + + +def test_read_config_injects_tables_metadata(tmp_path): + cfg = { + "Task": "demo", + "scorer": "Tester", + "date": "2026-03-27", + "multianimalproject": False, + "identity": "", + "project_path": str(tmp_path), + "bodyparts": ["bp1", "bp2"], + "skeleton": [], + "pcutoff": 0.6, + "dotsize": 8, + "colormap": "viridis", + "SuperAnimalConversionTables": { + "superanimal_quadruped": { + "bp1": "nose", + "bp2": "upper_jaw", + } + }, + } + + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.safe_dump(cfg), encoding="utf-8") + + layers = io.read_config(str(config_path)) + _, layer_props, layer_type = layers[0] + + assert layer_type == "points" + assert "tables" in layer_props["metadata"] + assert layer_props["metadata"]["tables"] == { + "superanimal_quadruped": { + "bp1": "nose", + "bp2": "upper_jaw", + } } + + +@pytest.mark.usefixtures("qtbot") +def test_points_layer_with_tables_shows_superkeypoints_button(keypoint_controls, qtbot, points): + controls = keypoint_controls + qtbot.add_widget(controls) + + assert not controls._keypoint_mapping_button.isVisible() + + points.metadata["tables"] = {"superanimal_quadruped": {"bp1": "nose", "bp2": "upper_jaw"}} + + # Simulate the same setup path that real inserted/adopted layers use + controls._setup_points_layer(points, allow_merge=False) + + assert not controls._keypoint_mapping_button.isHidden() + assert controls._keypoint_mapping_button.text() == "Load superkeypoints diagram" + + +@pytest.mark.usefixtures("qtbot") +def test_points_layer_with_tables_button_not_lost_on_merge_path(keypoint_controls, qtbot, points, monkeypatch): + controls = keypoint_controls + qtbot.add_widget(controls) + + points.metadata["tables"] = {"superanimal_quadruped": {"bp1": "nose"}} + + # Force the merge branch to happen + monkeypatch.setattr(controls, "_maybe_merge_config_points_layer", lambda layer: True) + + controls._setup_points_layer(points, allow_merge=True) + + assert controls._keypoint_mapping_button.isHidden() + + +@pytest.mark.usefixtures("qtbot") +def test_video_panel_has_extraction_options(keypoint_controls): + controls = keypoint_controls + panel = controls._video_group + assert panel.extract_button.text() == "Extract current frame" + assert panel.crop_button.text() == "Save crop to config" + assert panel.export_labels_cb.text() == "Also export labels" + assert panel.apply_crop_cb.text() == "Crop to rectangle" + + +@pytest.mark.usefixtures("qtbot") +def test_extract_single_frame_warns_without_image_layer(keypoint_controls, qtbot, monkeypatch): + controls = keypoint_controls + qtbot.addWidget(controls) + + seen = {} + + monkeypatch.setattr( + "napari_deeplabcut.ui.cropping.show_warning", + lambda msg: seen.setdefault("warning", msg), + ) + + controls._extract_single_frame() + + assert "No image/video layer is active." in seen["warning"] diff --git a/src/napari_deeplabcut/_tests/test_writer.py b/src/napari_deeplabcut/_tests/test_writer.py index 4fac82bc..355fb018 100644 --- a/src/napari_deeplabcut/_tests/test_writer.py +++ b/src/napari_deeplabcut/_tests/test_writer.py @@ -2,10 +2,15 @@ import numpy as np import pandas as pd +import pytest import yaml from skimage.io import imread -from napari_deeplabcut import _writer, misc +from napari_deeplabcut import _writer +from napari_deeplabcut.config.models import AnnotationKind, DLCHeaderModel +from napari_deeplabcut.core import io +from napari_deeplabcut.core.dataframes import guarantee_multiindex_rows +from napari_deeplabcut.core.errors import MissingProvenanceError rng = np.random.default_rng(42) @@ -15,7 +20,7 @@ def test_write_config(tmp_path): cfg = {"a": 1, "b": 2} path = tmp_path / "config.yaml" - _writer._write_config(str(path), cfg) + io.write_config(str(path), cfg) assert path.exists() text = path.read_text() @@ -35,10 +40,8 @@ def test_write_image(tmp_path): # Form_df — multi-animal + single-animal - - def _fake_metadata_for_df(df, paths): - """Helper for metadata for _form_df. + """Helper for metadata for form_df. IMPORTANT: The writer assigns properties row-wise, so we must provide per-row arrays (length == n_rows). We cycle through (individual, bodypart) @@ -46,7 +49,7 @@ def _fake_metadata_for_df(df, paths): """ from itertools import cycle, islice, product - header = misc.DLCHeader(df.columns) + header = DLCHeaderModel(columns=df.columns) n_rows = len(df) # Build a cyclic sequence of (id, label) pairs from header @@ -78,6 +81,35 @@ def _fake_metadata_for_df(df, paths): } +def _add_source_io(metadata: dict, *, root: Path, kind: AnnotationKind, source_name: str) -> None: + """Attach minimal PointsMetadata.io dict to metadata['metadata'].""" + md = metadata.setdefault("metadata", {}) + md["io"] = { + "schema_version": 1, + "project_root": str(root), + "source_relpath_posix": source_name.replace("\\", "/"), + "kind": kind, # AnnotationKind.GT or AnnotationKind.MACHINE + "dataset_key": "keypoints", + } + # legacy migration compatibility (optional but good) + md["source_h5"] = str((root / source_name).resolve()) + md["source_h5_name"] = source_name + md["source_h5_stem"] = Path(source_name).stem + + +def _add_save_target(metadata: dict, *, root: Path, scorer: str) -> None: + """Attach promotion save_target (GT) to metadata['metadata'].""" + md = metadata.setdefault("metadata", {}) + md["save_target"] = { + "schema_version": 1, + "project_root": str(root), + "source_relpath_posix": f"CollectedData_{scorer}.h5", + "kind": AnnotationKind.GT, + "dataset_key": "keypoints", + "scorer": scorer, + } + + def test_form_df_multi_animal(fake_keypoints): n = len(fake_keypoints) metadata = _fake_metadata_for_df(fake_keypoints, [f"img{i}.png" for i in range(n)]) @@ -85,7 +117,7 @@ def test_form_df_multi_animal(fake_keypoints): # inds + (x,y) data = np.column_stack([np.arange(n), rng.random(n), rng.random(n)]) - df = _writer._form_df(data, metadata) + df = io.form_df(data, layer_metadata=metadata["metadata"], layer_properties=metadata["properties"]) assert isinstance(df, pd.DataFrame) assert len(df) == n @@ -99,7 +131,7 @@ def test_form_df_multi_animal(fake_keypoints): def test_form_df_single_animal(fake_keypoints): - """Drop the individuals level and check that _form_df handles it.""" + """Drop the individuals level and check that form_df handles it.""" df_single = fake_keypoints.xs("animal_0", axis=1, level="individuals") scorer_values = df_single.columns.get_level_values("scorer").unique() bodyparts_values = df_single.columns.get_level_values("bodyparts").unique() @@ -119,7 +151,7 @@ def test_form_df_single_animal(fake_keypoints): # inds + (x,y) points = np.column_stack([np.arange(n), rng.random(n), rng.random(n)]) - out = _writer._form_df(points, metadata) + out = io.form_df(points, layer_metadata=metadata["metadata"], layer_properties=metadata["properties"]) assert isinstance(out, pd.DataFrame) assert len(out) == n @@ -140,7 +172,7 @@ def test_write_hdf_basic(tmp_path, fake_keypoints): root.mkdir() fake_keypoints.to_hdf(root / "data.h5", key="data") - header = misc.DLCHeader(fake_keypoints.columns) + header = DLCHeaderModel(columns=fake_keypoints.columns) # Build per-row properties (length == n_rows) n_rows = len(fake_keypoints) @@ -173,10 +205,11 @@ def test_write_hdf_basic(tmp_path, fake_keypoints): ] ) - fname = _writer.write_hdf("whatever.h5", points, metadata) + fnames = _writer.write_hdf_napari_dlc("whatever.h5", points, metadata) - h5_path = root / fname - csv_path = h5_path.with_suffix(".csv") + h5_path = Path(fnames[0]) + csv_path = Path(fnames[1]) + assert h5_path.name == csv_path.with_suffix(".h5").name == "CollectedData_me.h5" assert h5_path.exists() assert csv_path.exists() @@ -187,26 +220,22 @@ def test_write_hdf_basic(tmp_path, fake_keypoints): assert len(df) == n_rows -def test_write_hdf_machine_prediction_merge(tmp_path, fake_keypoints): +def test_write_hdf_promotion_merges_into_existing_gt(tmp_path, fake_keypoints, monkeypatch): """ - Trigger the special 'machine' branch: - - metadata["name"] contains 'machine' - - an existing CollectedData*.h5 file exists - -> data should be merged + Promotion contract: + - source is machine/prediction (io.kind is AnnotationKind.MACHINE) + - save_target points to CollectedData_.h5 + - writer must MERGE safely into GT (not overwrite blindly), + and must NOT write back to prediction file. """ root = tmp_path / "proj" root.mkdir() - # --- FIX: make GT index consistent with writer output (single-level MultiIndex) --- - gt = fake_keypoints.copy() - gt_idx = [f"img{i}.png" for i in gt.index] - gt.index = pd.MultiIndex.from_tuples([(x,) for x in gt_idx]) # ('img0.png',) etc. - gt_path = root / "CollectedData_me.h5" - gt.to_hdf(gt_path, key="data") + # Always allow overwrite confirmation in unit test + # monkeypatch.setattr(dialogs, "maybe_confirm_overwrite", lambda *args, **kwargs: True) - header = misc.DLCHeader(fake_keypoints.columns) + header = DLCHeaderModel(columns=fake_keypoints.columns) - # Build per-row properties n_rows = len(fake_keypoints) from itertools import cycle, islice, product @@ -216,7 +245,7 @@ def test_write_hdf_machine_prediction_merge(tmp_path, fake_keypoints): per_row_labels = [p[1] for p in sel] metadata = { - "name": "machine_predictions", + "name": "machinelabels-iter0", "properties": { "label": per_row_labels, "id": per_row_ids, @@ -229,46 +258,59 @@ def test_write_hdf_machine_prediction_merge(tmp_path, fake_keypoints): }, } - points = np.column_stack( - [ - np.arange(n_rows), - rng.random(n_rows), - rng.random(n_rows), - ] - ) + # Source provenance: machine/prediction file + _add_source_io(metadata, root=root, kind=AnnotationKind.MACHINE, source_name="machinelabels-iter0.h5") + + # Promotion target: existing GT + _add_save_target(metadata, root=root, scorer="me") - fname = _writer.write_hdf("ignored.h5", points, metadata) - out_h5 = root / fname + # Create existing GT file with DLC-like path-based index (not RangeIndex) + gt_path = root / "CollectedData_me.h5" + gt = fake_keypoints.copy() + + # Use the same "paths" convention as the writer uses when forming df_new + gt.index = [f"img{i}.png" for i in range(len(gt))] + + # Convert to MultiIndex of path components (matches refactored indexing model) + guarantee_multiindex_rows(gt) - df = pd.read_hdf(out_h5) + gt.to_hdf(gt_path, key="keypoints", mode="w") - # merged data must include at least as many rows as the original - assert len(df) >= n_rows + # Create a machine file too; it must remain untouched + machine_path = root / "machinelabels-iter0.h5" + df_machine = pd.DataFrame(np.nan, index=[0], columns=fake_keypoints.columns) + df_machine.to_hdf(machine_path, key="keypoints", mode="w") + machine_before = pd.read_hdf(machine_path, key="keypoints") + + points = np.column_stack([np.arange(n_rows), rng.random(n_rows), rng.random(n_rows)]) + + fnames = _writer.write_hdf_napari_dlc("ignored.h5", points, metadata) + assert Path(fnames[0]).name == "CollectedData_me.h5" + + # GT should exist and be readable + df = pd.read_hdf(fnames[0], key="keypoints") + assert isinstance(df, pd.DataFrame) - # scorer should match original GT scorer + # Must still be scored as "me" after promotion assert df.columns.get_level_values("scorer")[0] == "me" + # Machine file must be unchanged + machine_after = pd.read_hdf(machine_path, key="keypoints") + pd.testing.assert_frame_equal(machine_before, machine_after) -def test_write_hdf_machine_pred_no_gt(tmp_path, fake_keypoints): + +def test_write_hdf_machine_source_without_save_target_aborts(tmp_path, fake_keypoints): """ - Trigger machine branch, but **without** a CollectedData*.h5 file. - It should: - - load config.yaml to get scorer - - write under "CollectedData_{scorer}.h5" + New contract: + - machine/prediction sources must NEVER be written back. + - if save_target is missing, writer must abort deterministically. """ - project_root = tmp_path / "proj" - project_root.mkdir() - - # The writer looks for config.yaml at Path(root).parents[1] / "config.yaml". - # With root = str(project_root), that is two levels above 'proj'. - cfg_path = project_root.parents[1] / "config.yaml" - cfg_path.parent.mkdir(parents=True, exist_ok=True) - cfg_path.write_text("scorer: alice") - - header = misc.DLCHeader(fake_keypoints.columns) + root = tmp_path / "proj" + root.mkdir() - # Build per-row properties + header = DLCHeaderModel(columns=fake_keypoints.columns) n_rows = len(fake_keypoints) + from itertools import cycle, islice, product pairs = list(product(header.individuals, header.bodyparts)) @@ -277,7 +319,7 @@ def test_write_hdf_machine_pred_no_gt(tmp_path, fake_keypoints): per_row_labels = [p[1] for p in sel] metadata = { - "name": "machine_predictions", + "name": "machinelabels-iter0", "properties": { "label": per_row_labels, "id": per_row_ids, @@ -286,54 +328,69 @@ def test_write_hdf_machine_pred_no_gt(tmp_path, fake_keypoints): "metadata": { "header": header, "paths": [f"img{i}.png" for i in range(n_rows)], - "root": str(project_root), + "root": str(root), }, } - points = np.column_stack( - [ - np.arange(n_rows), - rng.random(n_rows), - rng.random(n_rows), - ] - ) + _add_source_io(metadata, root=root, kind=AnnotationKind.MACHINE, source_name="machinelabels-iter0.h5") - fname = _writer.write_hdf("ignored.h5", points, metadata) + points = np.column_stack([np.arange(n_rows), rng.random(n_rows), rng.random(n_rows)]) - # Should name file based on scorer - assert fname.startswith("CollectedData_alice") + with pytest.raises(MissingProvenanceError): + _writer.write_hdf_napari_dlc("ignored.h5", points, metadata) - out_h5 = project_root / fname - df = pd.read_hdf(out_h5) - # columns scorer should be "alice" - assert df.columns.get_level_values("scorer")[0] == "alice" +def test_write_hdf_promotion_creates_gt_when_missing(tmp_path, fake_keypoints, monkeypatch): + """ + Promotion contract: + - machine source + save_target => create/update CollectedData_.h5 + - scorer level should be rewritten to chosen scorer + - machine file must not be created/modified by writer + """ + root = tmp_path / "proj" + root.mkdir() + # monkeypatch.setattr(dialogs, "maybe_confirm_overwrite", lambda *args, **kwargs: True) -# Write_masks — verify masks & vertices -def test_write_masks(tmp_path): - foldername = str(tmp_path / "masks.h5") + header = DLCHeaderModel(columns=fake_keypoints.columns) + n_rows = len(fake_keypoints) - # fake polygon: frame index always 0 - data = [ - np.array([[0, 5, 5], [0, 5, 2]]).T # (inds, y, x) - ] + from itertools import cycle, islice, product + + pairs = list(product(header.individuals, header.bodyparts)) + sel = list(islice(cycle(pairs), n_rows)) + per_row_ids = [p[0] for p in sel] + per_row_labels = [p[1] for p in sel] metadata = { + "name": "machinelabels-iter0", + "properties": { + "label": per_row_labels, + "id": per_row_ids, + "likelihood": [1.0] * n_rows, + }, "metadata": { - "shape": (1, 10, 10), - "paths": ["frame0.png"], - } + "header": header, + "paths": [f"img{i}.png" for i in range(n_rows)], + "root": str(root), + }, } - output_dir = _writer.write_masks(foldername, data, metadata) - out_path = Path(output_dir) + _add_source_io(metadata, root=root, kind=AnnotationKind.MACHINE, source_name="machinelabels-iter0.h5") + _add_save_target(metadata, root=root, scorer="alice") + + points = np.column_stack([np.arange(n_rows), rng.random(n_rows), rng.random(n_rows)]) - assert out_path.exists() + fnames = _writer.write_hdf_napari_dlc("ignored.h5", points, metadata) + assert Path(fnames[0]).name == "CollectedData_alice.h5" - # mask files present - mask_files = list(out_path.glob("*_obj_*.png")) - assert mask_files + out_h5 = Path(fnames[0]) + assert out_h5.exists() - # vertices.csv must be present - assert (out_path / "vertices.csv").exists() + df = pd.read_hdf(out_h5, key="keypoints") + assert df.columns.get_level_values("scorer")[0] == "alice" + + # Ensure we still did NOT write back to a machine source file + assert not (root / "machinelabels-iter0.h5").exists(), ( + "Writer should not create/overwrite prediction files during promotion." + ) diff --git a/src/napari_deeplabcut/_tests/ui/__init__.py b/src/napari_deeplabcut/_tests/ui/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/_tests/ui/conftest.py b/src/napari_deeplabcut/_tests/ui/conftest.py new file mode 100644 index 00000000..a73d4c01 --- /dev/null +++ b/src/napari_deeplabcut/_tests/ui/conftest.py @@ -0,0 +1,198 @@ +# src/napari_deeplabcut/_tests/ui/conftest.py +from __future__ import annotations + +import numpy as np +import pytest +from napari.layers import Points +from napari.utils.events import EmitterGroup, Event +from qtpy.QtWidgets import QWidget + + +class FakeSelection: + def __init__(self): + self._active = None + self.events = EmitterGroup(source=self, active=Event) + + @property + def active(self): + return self._active + + @active.setter + def active(self, value): + self._active = value + self.events.active(value=value) + + +class FakeLayers(list): + def __init__(self, iterable=()): + super().__init__(iterable) + self.events = EmitterGroup(source=self, inserted=Event, removed=Event) + self.selection = FakeSelection() + + def append(self, layer): + super().append(layer) + self.events.inserted(value=layer) + + def remove(self, layer): + super().remove(layer) + self.events.removed(value=layer) + + +class FakeDims: + def __init__(self, current_step=(0,)): + self._current_step = tuple(current_step) + self.events = EmitterGroup(source=self, current_step=Event) + + @property + def current_step(self): + return self._current_step + + @current_step.setter + def current_step(self, value): + self._current_step = tuple(value) + self.events.current_step(value=self._current_step) + + +class FakeViewer: + def __init__(self, layers=None, current_step=(0,)): + self.layers = FakeLayers(layers or []) + self.dims = FakeDims(current_step=current_step) + + +@pytest.fixture +def fake_viewer(): + return FakeViewer() + + +@pytest.fixture +def make_header(make_real_header_factory): + """ + Return a callable header factory, not a single pre-built header. + """ + + def _make_header( + *, + bodyparts=("nose", "tail"), + individuals=("",), + scorer="tester", + ): + return make_real_header_factory( + bodyparts=bodyparts, + individuals=individuals, + scorer=scorer, + ) + + return _make_header + + +@pytest.fixture +def get_header_model(): + """ + Match the production callback signature used by ColorSchemeResolver/Panel. + """ + + def _get_header_model(md: dict): + return md.get("header", None) + + return _get_header_model + + +def _make_cycles_for_bodyparts(bodyparts: list[str] | tuple[str, ...]): + base = { + "nose": np.array([1.0, 0.0, 0.0, 1.0]), # red + "tail": np.array([0.0, 1.0, 0.0, 1.0]), # green + "ear": np.array([0.0, 0.0, 1.0, 1.0]), # blue + "paw": np.array([1.0, 1.0, 0.0, 1.0]), # yellow + "cfg1": np.array([1.0, 0.0, 0.0, 1.0]), + "cfg2": np.array([0.0, 1.0, 0.0, 1.0]), + "bodypart1": np.array([1.0, 0.0, 0.0, 1.0]), + "bodypart2": np.array([0.0, 1.0, 0.0, 1.0]), + "bodypart3": np.array([0.0, 0.0, 1.0, 1.0]), + } + return {bp: base[bp] for bp in bodyparts} + + +def _make_cycles_for_ids(ids: list[str] | tuple[str, ...]): + base = { + "animal1": np.array([1.0, 0.0, 1.0, 1.0]), # magenta + "animal2": np.array([0.0, 1.0, 1.0, 1.0]), # cyan + "animal3": np.array([0.5, 0.5, 0.5, 1.0]), # gray + } + return {i: base[i] for i in ids} + + +@pytest.fixture +def make_points_layer(make_header): + def _make_points_layer( + *, + data: np.ndarray | None = None, + labels: list[str] | tuple[str, ...] = ("nose", "tail"), + ids: list[str] | tuple[str, ...] | None = None, + bodyparts: list[str] | tuple[str, ...] | None = None, + individuals: list[str] | tuple[str, ...] = ("",), + project: str | None = None, + shown: list[bool] | np.ndarray | None = None, + visible: bool = True, + include_id_cycle: bool = True, + extra_metadata: dict | None = None, + ) -> Points: + labels = list(labels) + if ids is None: + ids = [""] * len(labels) + ids = list(ids) + + if bodyparts is None: + bodyparts = tuple(dict.fromkeys(labels)) + + header = make_header(bodyparts=bodyparts, individuals=individuals) + + if data is None: + # default: frame, y, x + data = np.array([[0, float(i), float(i)] for i in range(len(labels))], dtype=float) + + metadata = { + "header": header, + "face_color_cycles": { + "label": _make_cycles_for_bodyparts(bodyparts), + }, + } + + if include_id_cycle: + non_empty_ids = [x for x in individuals if x != ""] + if non_empty_ids: + metadata["face_color_cycles"]["id"] = _make_cycles_for_ids(non_empty_ids) + + if project is not None: + metadata["project"] = str(project) + + if extra_metadata: + metadata.update(extra_metadata) + + properties = { + "label": np.asarray(labels, dtype=object), + "id": np.asarray(ids, dtype=object), + } + + layer = Points( + data=np.asarray(data, dtype=float), + properties=properties, + metadata=metadata, + name="points", + ) + layer.visible = visible + + if shown is not None: + layer.shown = np.asarray(shown, dtype=bool) + + return layer + + return _make_points_layer + + +@pytest.fixture +def dialog_parent(qtbot): + parent = QWidget() + parent.setGeometry(100, 200, 800, 600) + qtbot.addWidget(parent) + parent.show() + return parent diff --git a/src/napari_deeplabcut/_tests/ui/test_color_scheme.py b/src/napari_deeplabcut/_tests/ui/test_color_scheme.py new file mode 100644 index 00000000..95c7612e --- /dev/null +++ b/src/napari_deeplabcut/_tests/ui/test_color_scheme.py @@ -0,0 +1,393 @@ +# src/napari_deeplabcut/_tests/ui/test_color_scheme.py +from __future__ import annotations + +import numpy as np +import pytest + +from napari_deeplabcut.config.models import DLCHeaderModel +from napari_deeplabcut.config.settings import ( + DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP, + DEFAULT_SINGLE_ANIMAL_CMAP, +) +from napari_deeplabcut.core import keypoints +from napari_deeplabcut.ui.color_scheme_display import ( + ColorSchemeDisplay, + ColorSchemePanel, + ColorSchemeResolver, + _to_hex, +) + +from ..conftest import force_show + + +def _header_model_from_layer(layer) -> DLCHeaderModel: + hdr = layer.metadata.get("header") + assert hdr is not None, "Expected header in layer metadata" + return hdr if isinstance(hdr, DLCHeaderModel) else DLCHeaderModel.model_validate(hdr) + + +def _is_multianimal_header(header: DLCHeaderModel) -> bool: + inds = list(getattr(header, "individuals", []) or []) + return bool(inds and str(inds[0]) != "") + + +def _config_colormap_from_layer(layer) -> str: + md = layer.metadata or {} + cmap = md.get("config_colormap") + if isinstance(cmap, str) and cmap: + return cmap + return DEFAULT_SINGLE_ANIMAL_CMAP + + +def _expected_cycles_for_policy(layer) -> dict[str, dict[str, np.ndarray]]: + """ + Compute expected cycles from the new centralized policy. + + Source of truth: + - layer header + - metadata['config_colormap'] + - multi-animal policy for individual coloring + """ + header = _header_model_from_layer(layer) + config_cmap = _config_colormap_from_layer(layer) + + config_cycles = keypoints.build_color_cycles(header, config_cmap) or {} + + if _is_multianimal_header(header): + individual_cycles = keypoints.build_color_cycles(header, DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP) or {} + else: + individual_cycles = config_cycles + + return { + "label": config_cycles.get("label", {}), + "id": individual_cycles.get("id", {}), + } + + +def _expected_scheme(layer, *, prop: str, names: list[str]) -> dict[str, str]: + cycles = _expected_cycles_for_policy(layer) + mapping = cycles.get(prop, {}) + return {name: _to_hex(mapping[name]) for name in names if name in mapping} + + +def _expected_scheme_from_policy(layer, *, prop: str, names: list[str]) -> dict[str, str]: + cycles = _expected_cycles_for_policy(layer) + mapping = cycles.get(prop, {}) + return {name: _to_hex(mapping[name]) for name in names if name in mapping} + + +def test_to_hex_converts_rgb_and_rgba(): + assert _to_hex([1.0, 0.0, 0.0]) == "#ff0000" + assert _to_hex([0.0, 1.0, 0.0, 0.5]) == "#00ff00" + assert _to_hex([0.0, 0.0, 1.0, 1.0]) == "#0000ff" + + +def test_to_hex_returns_black_for_too_short_input(): + assert _to_hex([]) == "#000000" + assert _to_hex([1.0, 0.0]) == "#000000" + + +@pytest.mark.usefixtures("qtbot") +def test_color_scheme_display_update_reuse_and_reset(qtbot): + widget = ColorSchemeDisplay() + qtbot.addWidget(widget) + + # Parent does not need to be shown for these tests; check hidden-state instead of isVisible(). + widget.update_color_scheme({"nose": "#ff0000", "tail": "#00ff00"}) + assert widget.scheme_dict == {"nose": "#ff0000", "tail": "#00ff00"} + assert len(widget.labels) == 2 + assert widget.labels[0].part_name == "nose" + assert widget.labels[0].color == "#ff0000" + assert widget.labels[1].part_name == "tail" + assert widget.labels[1].color == "#00ff00" + assert widget.labels[0].isHidden() is False + assert widget.labels[1].isHidden() is False + + # Reuse first widget, hide extra second widget + widget.update_color_scheme({"ear": "#0000ff"}) + assert widget.scheme_dict == {"ear": "#0000ff"} + assert len(widget.labels) == 2 + assert widget.labels[0].part_name == "ear" + assert widget.labels[0].color == "#0000ff" + assert widget.labels[0].isHidden() is False + assert widget.labels[1].isHidden() is True + + widget.reset() + assert widget.scheme_dict == {} + assert all(w.isHidden() for w in widget.labels) + + +def test_resolver_get_target_layer_prefers_active_visible(fake_viewer, make_points_layer, get_header_model): + layer1 = make_points_layer(labels=["nose"], bodyparts=["nose"], visible=True) + layer2 = make_points_layer(labels=["tail"], bodyparts=["tail"], visible=True) + + fake_viewer.layers.append(layer1) + fake_viewer.layers.append(layer2) + fake_viewer.layers.selection.active = layer1 + + resolver = ColorSchemeResolver( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.BODYPART), + get_header_model=get_header_model, + ) + + assert resolver.get_target_layer() is layer1 + + +def test_resolver_get_target_layer_falls_back_to_topmost_visible_when_active_hidden( + fake_viewer, + make_points_layer, + get_header_model, +): + layer1 = make_points_layer(labels=["nose"], bodyparts=["nose"], visible=True) + layer2 = make_points_layer(labels=["tail"], bodyparts=["tail"], visible=True) + + fake_viewer.layers.append(layer1) + fake_viewer.layers.append(layer2) + + # Active layer exists but is hidden -> should fall back to topmost visible + layer1.visible = False + fake_viewer.layers.selection.active = layer1 + + resolver = ColorSchemeResolver( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.BODYPART), + get_header_model=get_header_model, + ) + + assert resolver.get_target_layer() is layer2 + + +def test_resolver_get_color_property_prefers_id_in_multianimal_individual_mode( + fake_viewer, + make_points_layer, + get_header_model, +): + layer = make_points_layer( + labels=["bodypart1", "bodypart2"], + ids=["animal1", "animal2"], + bodyparts=["bodypart1", "bodypart2"], + individuals=["animal1", "animal2"], + ) + fake_viewer.layers.append(layer) + fake_viewer.layers.selection.active = layer + + resolver = ColorSchemeResolver( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.INDIVIDUAL), + get_header_model=get_header_model, + ) + + assert resolver.get_color_property(layer) == "id" + + +def test_resolver_get_visible_categories_filters_frame_shown_and_deduplicates( + fake_viewer, + make_points_layer, + get_header_model, +): + data = np.array( + [ + [0, 0.0, 0.0], # nose, shown + [0, 1.0, 1.0], # tail, hidden + [1, 2.0, 2.0], # nose, other frame + [0, 3.0, 3.0], # ear, shown + [0, 4.0, 4.0], # nose, duplicate in current frame + ], + dtype=float, + ) + layer = make_points_layer( + data=data, + labels=["nose", "tail", "nose", "ear", "nose"], + bodyparts=["nose", "tail", "ear"], + shown=[True, False, True, True, True], + ) + + fake_viewer.layers.append(layer) + fake_viewer.layers.selection.active = layer + fake_viewer.dims.current_step = (0,) + + resolver = ColorSchemeResolver( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.BODYPART), + get_header_model=get_header_model, + ) + + visible = resolver.get_visible_categories(layer, "label") + assert visible == ["nose", "ear"] + + +def test_resolver_get_config_categories_prefers_config_yaml_bodyparts( + fake_viewer, + make_points_layer, + get_header_model, + single_animal_project, +): + project, _config_path = single_animal_project + layer = make_points_layer( + labels=["nose"], + bodyparts=["nose", "tail"], + individuals=[""], + project=str(project), + ) + fake_viewer.layers.append(layer) + fake_viewer.layers.selection.active = layer + + resolver = ColorSchemeResolver( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.BODYPART), + get_header_model=get_header_model, + ) + + assert resolver.get_config_categories(layer, "label") == ["cfg1", "cfg2"] + + +def test_resolver_get_config_categories_id_falls_back_to_bodyparts_for_single_animal_config( + fake_viewer, + make_points_layer, + get_header_model, + single_animal_project, +): + project, _config_path = single_animal_project + layer = make_points_layer( + labels=["nose"], + bodyparts=["nose", "tail"], + individuals=[""], + project=str(project), + ) + fake_viewer.layers.append(layer) + fake_viewer.layers.selection.active = layer + + resolver = ColorSchemeResolver( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.INDIVIDUAL), + get_header_model=get_header_model, + ) + + assert resolver.get_config_categories(layer, "id") == ["cfg1", "cfg2"] + + +@pytest.mark.usefixtures("qtbot") +def test_panel_initial_active_mode_updates_display_from_current_frame( + qtbot, + fake_viewer, + make_points_layer, + get_header_model, +): + data = np.array( + [ + [0, 0.0, 0.0], # nose + [1, 1.0, 1.0], # tail + ], + dtype=float, + ) + layer = make_points_layer( + data=data, + labels=["nose", "tail"], + bodyparts=["nose", "tail"], + ) + fake_viewer.layers.append(layer) + fake_viewer.layers.selection.active = layer + fake_viewer.dims.current_step = (0,) + + panel = ColorSchemePanel( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.BODYPART), + get_header_model=get_header_model, + ) + qtbot.addWidget(panel) + force_show(panel, qtbot) + + expected = _expected_scheme_from_policy(layer, prop="label", names=["nose"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected) + + +@pytest.mark.usefixtures("qtbot") +def test_panel_reacts_to_frame_change_event( + qtbot, + fake_viewer, + make_points_layer, + get_header_model, +): + data = np.array( + [ + [0, 0.0, 0.0], # nose + [1, 1.0, 1.0], # tail + ], + dtype=float, + ) + layer = make_points_layer( + data=data, + labels=["nose", "tail"], + bodyparts=["nose", "tail"], + ) + fake_viewer.layers.append(layer) + fake_viewer.layers.selection.active = layer + fake_viewer.dims.current_step = (0,) + + panel = ColorSchemePanel( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.BODYPART), + get_header_model=get_header_model, + ) + qtbot.addWidget(panel) + force_show(panel, qtbot) + + expected0 = _expected_scheme_from_policy(layer, prop="label", names=["nose"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected0) + + fake_viewer.dims.current_step = (1,) + expected1 = _expected_scheme_from_policy(layer, prop="label", names=["tail"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected1) + + +@pytest.mark.usefixtures("qtbot") +def test_panel_toggle_switches_from_active_to_config_preview( + qtbot, + fake_viewer, + make_points_layer, + get_header_model, + single_animal_project, +): + project, _config_path = single_animal_project + layer = make_points_layer( + data=np.array([[0, 0.0, 0.0]], dtype=float), + labels=["nose"], + bodyparts=["nose", "tail"], + project=str(project), + extra_metadata={ + "config_colormap": "rainbow", + }, + ) + fake_viewer.layers.append(layer) + fake_viewer.layers.selection.active = layer + fake_viewer.dims.current_step = (0,) + + panel = ColorSchemePanel( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.BODYPART), + get_header_model=get_header_model, + ) + qtbot.addWidget(panel) + force_show(panel, qtbot) + + # Active mode should show the currently visible label from the layer. + expected_active = _expected_scheme_from_policy(layer, prop="label", names=["nose"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected_active) + + # Config preview should show the configured bodyparts from config.yaml. + panel._toggle.setChecked(True) + expected_config = _expected_scheme_from_policy(layer, prop="label", names=["cfg1", "cfg2"]) + qtbot.waitUntil(lambda: panel.display.scheme_dict == expected_config) + + +@pytest.mark.usefixtures("qtbot") +def test_color_scheme_panel_delete_later_does_not_crash_on_pending_update(qtbot, fake_viewer, get_header_model): + panel = ColorSchemePanel( + viewer=fake_viewer, + get_color_mode=lambda: str(keypoints.ColorMode.BODYPART), + get_header_model=get_header_model, + ) + qtbot.addWidget(panel) + panel.schedule_update() + qtbot.wait(50) diff --git a/src/napari_deeplabcut/_tests/ui/test_cropping.py b/src/napari_deeplabcut/_tests/ui/test_cropping.py new file mode 100644 index 00000000..c750d230 --- /dev/null +++ b/src/napari_deeplabcut/_tests/ui/test_cropping.py @@ -0,0 +1,516 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pandas as pd +import pytest + +from napari_deeplabcut.ui import cropping as cropping_mod + +# ----------------------------------------------------------------------------- +# Small fake layer/viewer helpers (no qtbot, no napari viewer fixture) +# ----------------------------------------------------------------------------- + + +class FakeShapes: + def __init__( + self, + *, + data, + shape_type, + metadata=None, + name="shapes", + selected_data=None, + ): + self.data = data + self.shape_type = shape_type + self.metadata = metadata or {} + self.name = name + self.selected_data = selected_data or set() + + +class FakeImage: + def __init__(self, *, data=None, metadata=None, name="image", source_path=None): + self.data = data + self.metadata = metadata or {} + self.name = name + self.source = SimpleNamespace(path=source_path) + + +class FakePoints: + def __init__(self, *, data=None, metadata=None, properties=None, name="points"): + self.data = data + self.metadata = metadata or {} + self.properties = properties or {} + self.name = name + + +class FakeLayerList(list): + def __init__(self, layers=(), active=None): + super().__init__(layers) + self.selection = SimpleNamespace(active=active) + + +class DummyPanel: + def __init__(self): + self.text = None + + def set_context_text(self, text: str) -> None: + self.text = text + + +# ----------------------------------------------------------------------------- +# Basic schema validation +# ----------------------------------------------------------------------------- + + +def test_viewer_crop_coords_accept_valid_tuple(): + coords = cropping_mod.ViewerCropCoords(values=(1, 10, 2, 20)) + assert coords.values == (1, 10, 2, 20) + + +@pytest.mark.parametrize( + "bad", + [ + (1, 1, 2, 20), # x2 <= x1 + (1, 10, 5, 5), # y2 <= y1 + ], +) +def test_viewer_crop_coords_reject_invalid_tuple(bad): + with pytest.raises(ValueError, match="ViewerCropCoords"): + cropping_mod.ViewerCropCoords(values=bad) + + +def test_dlc_config_crop_coords_accept_valid_tuple(): + coords = cropping_mod.DLCConfigCropCoords(values=(5, 15, 6, 30)) + assert coords.values == (5, 15, 6, 30) + + +def test_crop_save_plan_rejects_viewer_coords_for_config(tmp_path: Path): + with pytest.raises(ValueError, match="Refusing to write napari/viewer crop coordinates"): + cropping_mod.CropSavePlan( + config_path=tmp_path / "config.yaml", + project_root=tmp_path, + video_key="video.mp4", + config_crop=cropping_mod.ViewerCropCoords(values=(1, 10, 2, 20)), + ) + + +# ----------------------------------------------------------------------------- +# Rectangle resolution / coordinate conventions +# ----------------------------------------------------------------------------- + + +def test_rectangle_spec_returns_viewer_and_DLC_config_coords(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + # rectangle vertices in [t, y, x] + rect = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 10.0, 60.0], + [0.0, 40.0, 60.0], + [0.0, 40.0, 20.0], + ], + dtype=float, + ) + layer = FakeShapes( + data=[rect], + shape_type=["rectangle"], + metadata={cropping_mod.DLC_CROP_LAYER_META_KEY: True}, + selected_data={0}, + ) + + viewer = SimpleNamespace( + dims=SimpleNamespace( + # Y extent = 100, X extent = 200 + # This matches the [t, y, x] rectangle coordinates used in the test + range=[(0, 10, 1), (0, 100, 1), (0, 200, 1)] + ) + ) + + spec = cropping_mod._rectangle_spec(viewer, layer, 0) + assert spec is not None + + # raw napari/image-data coords + assert spec.viewer_crop.values == (20, 60, 10, 40) + + # DLC config coords (y flipped with h=100) + assert spec.config_crop.values == (20, 60, 60, 90) + + +def test_find_crop_rectangle_prefers_dedicated_crop_layer(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + dedicated_rect = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 10.0, 60.0], + [0.0, 40.0, 60.0], + [0.0, 40.0, 20.0], + ], + dtype=float, + ) + other_rect = np.array( + [ + [0.0, 1.0, 2.0], + [0.0, 1.0, 5.0], + [0.0, 3.0, 5.0], + [0.0, 3.0, 2.0], + ], + dtype=float, + ) + + dedicated = FakeShapes( + data=[dedicated_rect], + shape_type=["rectangle"], + metadata={cropping_mod.DLC_CROP_LAYER_META_KEY: True}, + name=cropping_mod.DLC_CROP_LAYER_NAME, + selected_data={0}, + ) + other = FakeShapes( + data=[other_rect], + shape_type=["rectangle"], + metadata={}, + name="other", + selected_data={0}, + ) + + viewer = SimpleNamespace( + layers=FakeLayerList([other, dedicated], active=other), + dims=SimpleNamespace(range=[(0, 10, 1), (0, 100, 1), (0, 200, 1)]), + ) + + spec = cropping_mod.find_crop_rectangle(viewer, prefer_selected=True) + assert spec is not None + assert spec.viewer_crop.values == (20, 60, 10, 40) + + +def test_find_crop_rectangle_ignores_non_rectangles(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + poly = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 15.0, 25.0], + [0.0, 12.0, 30.0], + ], + dtype=float, + ) + rect = np.array( + [ + [0.0, 5.0, 10.0], + [0.0, 5.0, 20.0], + [0.0, 15.0, 20.0], + [0.0, 15.0, 10.0], + ], + dtype=float, + ) + + poly_layer = FakeShapes( + data=[poly], + shape_type=["polygon"], + metadata={}, + name="poly", + selected_data={0}, + ) + rect_layer = FakeShapes( + data=[rect], + shape_type=["rectangle"], + metadata={}, + name="rect", + selected_data={0}, + ) + + viewer = SimpleNamespace( + layers=FakeLayerList([poly_layer, rect_layer], active=poly_layer), + dims=SimpleNamespace(range=[(0, 10, 1), (0, 100, 1), (0, 200, 1)]), + ) + + spec = cropping_mod.find_crop_rectangle(viewer, prefer_selected=True) + assert spec is not None + assert spec.viewer_crop.values == (10, 20, 5, 15) + + +# ----------------------------------------------------------------------------- +# Planning logic +# ----------------------------------------------------------------------------- + + +def test_plan_frame_extraction_uses_viewer_crop(tmp_path: Path, monkeypatch): + from napari.layers import Image + + monkeypatch.setattr( + cropping_mod, + "find_crop_rectangle", + lambda viewer, prefer_selected=True: cropping_mod.CropRectangleSpec( + viewer_crop=cropping_mod.ViewerCropCoords(values=(2, 8, 3, 9)), + config_crop=cropping_mod.DLCConfigCropCoords(values=(2, 8, 91, 97)), + ), + ) + + image = Image( + np.zeros((10, 20, 30), dtype=np.uint8), + name="demo.mp4", + metadata={"root": str(tmp_path)}, + ) + + viewer = SimpleNamespace( + dims=SimpleNamespace(current_step=(4,), range=[(0, 10, 1), (0, 20, 1), (0, 100, 1)]), + ) + + plan, error = cropping_mod.plan_frame_extraction( + viewer, + image_layer=image, + export_labels=False, + apply_crop=True, + ) + + assert error is None + assert plan is not None + assert isinstance(plan.viewer_crop, cropping_mod.ViewerCropCoords) + assert plan.viewer_crop.values == (2, 8, 3, 9) + + +def test_plan_crop_save_uses_config_crop(monkeypatch, tmp_path: Path): + monkeypatch.setattr( + cropping_mod, + "find_crop_rectangle", + lambda viewer, prefer_selected=True: cropping_mod.CropRectangleSpec( + viewer_crop=cropping_mod.ViewerCropCoords(values=(2, 8, 3, 9)), + config_crop=cropping_mod.DLCConfigCropCoords(values=(2, 8, 91, 97)), + ), + ) + monkeypatch.setattr( + cropping_mod, + "infer_dlc_project_from_image_layer", + lambda image_layer, prefer_project_root=True: SimpleNamespace( + config_path=tmp_path / "config.yaml", + project_root=tmp_path, + root_anchor=tmp_path, + ), + ) + + image = FakeImage( + data=np.zeros((10, 20, 30), dtype=np.uint8), + metadata={}, + name="demo.mp4", + source_path=str(tmp_path / "videos" / "demo.mp4"), + ) + + viewer = SimpleNamespace( + dims=SimpleNamespace(range=[(0, 10, 1), (0, 20, 1), (0, 100, 1)]), + layers=FakeLayerList([], active=None), + ) + + plan, error = cropping_mod.plan_crop_save( + viewer, + image_layer=image, + explicit_project_path=None, + fallback_video_name="demo.mp4", + ) + + assert error is None + assert plan is not None + assert isinstance(plan.config_crop, cropping_mod.DLCConfigCropCoords) + assert plan.config_crop.values == (2, 8, 91, 97) + + +# ----------------------------------------------------------------------------- +# Saving regression +# ----------------------------------------------------------------------------- + + +def test_store_crop_coordinates_saves_DLC_config_coords(monkeypatch, tmp_path: Path): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + project = tmp_path / "project" + project.mkdir() + + config_path = project / "config.yaml" + config_path.write_text("video_sets: {}\n", encoding="utf-8") + + video_path = project / "videos" / "demo.mp4" + video_path.parent.mkdir() + video_path.touch() + + rect = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 10.0, 60.0], + [0.0, 40.0, 60.0], + [0.0, 40.0, 20.0], + ], + dtype=float, + ) + + crop_layer = FakeShapes( + data=[rect], + shape_type=["rectangle"], + metadata={cropping_mod.DLC_CROP_LAYER_META_KEY: True}, + name=cropping_mod.DLC_CROP_LAYER_NAME, + selected_data={0}, + ) + + image_layer = FakeImage( + data=np.zeros((10, 20, 30), dtype=np.uint8), + metadata={"root": str(project / "labeled-data" / "demo")}, + name="demo.mp4", + source_path=str(video_path), + ) + + viewer = SimpleNamespace( + layers=FakeLayerList([crop_layer, image_layer], active=crop_layer), + dims=SimpleNamespace( + current_step=(0,), + # Y extent is 100 (index -2), X extent is 200 (index -1 / old hardcoded index 2) + range=[(0, 10, 1), (0, 100, 1), (0, 200, 1)], + ), + ) + + ok, msg = cropping_mod.store_crop_coordinates( + viewer, + image_layer=image_layer, + explicit_project_path=str(project), + fallback_video_name="demo.mp4", + ) + + assert ok is True + assert "Saved crop" in msg + + cfg = cropping_mod.io.load_config(str(config_path)) + assert cfg["video_sets"][str(video_path)]["crop"] == "20, 60, 60, 90" + + +def test_rectangle_spec_uses_y_axis_extent_for_DLC_config_coords(monkeypatch): + monkeypatch.setattr(cropping_mod, "Shapes", FakeShapes) + + # rectangle vertices in [t, y, x] + rect = np.array( + [ + [0.0, 10.0, 20.0], + [0.0, 10.0, 60.0], + [0.0, 40.0, 60.0], + [0.0, 40.0, 20.0], + ], + dtype=float, + ) + layer = FakeShapes( + data=[rect], + shape_type=["rectangle"], + metadata={cropping_mod.DLC_CROP_LAYER_META_KEY: True}, + selected_data={0}, + ) + + # Make Y extent and X extent different so using the wrong axis would fail. + # dims.range[-2][1] == 100 (Y) + # dims.range[2][1] == 200 (X) <-- old buggy code would incorrectly use this + viewer = SimpleNamespace(dims=SimpleNamespace(range=[(0, 10, 1), (0, 100, 1), (0, 200, 1)])) + + spec = cropping_mod._rectangle_spec(viewer, layer, 0) + assert spec is not None + + # raw napari/image-data coords + assert spec.viewer_crop.values == (20, 60, 10, 40) + + # DLC config coords must use Y extent = 100, not X extent = 200 + assert spec.config_crop.values == (20, 60, 60, 90) + + +# ----------------------------------------------------------------------------- +# Context rendering (logic only, no Qt) +# ----------------------------------------------------------------------------- + + +def test_update_video_panel_context_renders_current_summary(monkeypatch, tmp_path: Path): + monkeypatch.setattr(cropping_mod, "Image", FakeImage) + monkeypatch.setattr(cropping_mod, "sync_crop_layer_autorefresh", lambda viewer, panel, refresh_callback: None) + monkeypatch.setattr( + cropping_mod, + "get_crop_source_summary", + lambda viewer: ( + "DLC crop layer", + cropping_mod.CropRectangleSpec( + viewer_crop=cropping_mod.ViewerCropCoords(values=(1, 10, 2, 20)), + config_crop=cropping_mod.DLCConfigCropCoords(values=(1, 10, 80, 98)), + ), + ), + ) + + image = FakeImage( + data=np.zeros((5, 20, 30), dtype=np.uint8), + metadata={"root": str(tmp_path / "dataset")}, + name="demo.mp4", + ) + + viewer = SimpleNamespace( + layers=FakeLayerList([image], active=image), + dims=SimpleNamespace(current_step=(2,), range=[(0, 5, 1), (0, 20, 1), (0, 100, 1)]), + ) + panel = DummyPanel() + + cropping_mod.update_video_panel_context(viewer, panel) + + assert "Frame 3/5" in panel.text + assert f"Output folder: {tmp_path / 'dataset'}" in panel.text + assert "Crop source: DLC crop layer" in panel.text + + +def test_execute_frame_extraction_keeps_new_labels_row_on_duplicate_index(monkeypatch, tmp_path: Path): + from napari.layers import Image, Points + + # Avoid writing a real image file through skimage; just create the output file. + monkeypatch.setattr( + cropping_mod, + "_write_image", + lambda arr, path: Path(path).write_bytes(b"fake-image"), + ) + + image = Image( + np.zeros((3, 20, 30), dtype=np.uint8), + name="demo.mp4", + metadata={"root": str(tmp_path)}, + ) + + output_path = tmp_path / "img1.png" + labels_path = tmp_path / "machinelabels-iter0.h5" + + # Existing row for the same extracted image path + idx = pd.MultiIndex.from_tuples([("tmp", "pytest", "img1.png")]) + df_prev = pd.DataFrame({"bp1": [111.0]}, index=idx) + df_prev.to_hdf(labels_path, key="df_with_missing") + + # New extracted row should overwrite the previous one + df_new = pd.DataFrame({"bp1": [222.0]}, index=idx) + + monkeypatch.setattr( + cropping_mod, + "_build_extracted_frame_labels_df", + lambda plan: (df_new, None), + ) + + points = Points( + np.empty((0, 3), dtype=float), + name="pts", + ) + + plan = cropping_mod.FrameExtractionPlan( + image_layer=image, + points_layer=points, + frame_index=1, + output_root=tmp_path, + output_path=output_path, + labels_path=labels_path, + export_labels=True, + viewer_crop=None, + ) + + written, note = cropping_mod.execute_frame_extraction(plan) + + assert note is None + assert labels_path in written + + df_written = pd.read_hdf(labels_path, key="df_with_missing") + assert float(df_written.iloc[0]["bp1"]) == 222.0 diff --git a/src/napari_deeplabcut/_tests/ui/test_dialogs.py b/src/napari_deeplabcut/_tests/ui/test_dialogs.py new file mode 100644 index 00000000..6b7951d1 --- /dev/null +++ b/src/napari_deeplabcut/_tests/ui/test_dialogs.py @@ -0,0 +1,680 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np +import pytest +from napari.layers import Image, Points +from qtpy.QtCore import QPoint, Qt +from qtpy.QtWidgets import QDialog, QLabel, QPlainTextEdit, QPushButton, QScrollArea + +import napari_deeplabcut.ui.dialogs as ui_dialogs +from napari_deeplabcut.config.keybinds import iter_shortcuts +from napari_deeplabcut.ui.dialogs import ( + OverwriteConflictsDialog, + ProjectConfigPromptAction, + ShortcutRow, + Shortcuts, + Tutorial, + load_scorer_from_config, + maybe_confirm_overwrite, + prompt_for_project_config_for_save, +) + +# ----------------------------------------------------------------------------- +# Shortcuts +# ----------------------------------------------------------------------------- + + +class _DummyEmitter: + def __init__(self): + self._callbacks = [] + + def connect(self, callback): + self._callbacks.append(callback) + + def disconnect(self, callback): + self._callbacks.remove(callback) + + def emit(self, event=None): + for callback in list(self._callbacks): + callback(event) + + +class _DummyViewer: + def __init__(self, *, active=None): + self.layers = SimpleNamespace( + selection=SimpleNamespace( + active=active, + events=SimpleNamespace(active=_DummyEmitter()), + ) + ) + + +def _points_layer(name: str = "points"): + return Points(np.empty((0, 2)), name=name) + + +def _image_layer(name: str = "image"): + return Image(np.zeros((8, 8)), name=name) + + +def test_shortcuts_dialog_without_viewer_renders_registry(dialog_parent, qtbot): + dlg = Shortcuts(dialog_parent) + qtbot.addWidget(dlg) + + assert dlg.parent() is dialog_parent + assert dlg.windowTitle() == "Keyboard shortcuts" + assert dlg.testAttribute(Qt.WA_DeleteOnClose) + + scroll_areas = dlg.findChildren(QScrollArea) + assert len(scroll_areas) == 1 + + rows = dlg.findChildren(ShortcutRow) + assert len(rows) == len(tuple(iter_shortcuts())) + + assert ( + dlg.context_banner.text() + == "Showing all known shortcuts. Availability cannot be determined without a viewer context." + ) + + +def test_shortcuts_dialog_marks_points_shortcuts_unavailable_for_non_points_layer(dialog_parent, qtbot): + viewer = _DummyViewer(active=_image_layer("raw image")) + dlg = Shortcuts(dialog_parent, viewer=viewer) + qtbot.addWidget(dlg) + + assert "No active Points layer" in dlg.context_banner.text() + + rows = dlg.findChildren(ShortcutRow) + points_rows = [row for row in rows if row.spec.scope == "points-layer"] + global_rows = [row for row in rows if row.spec.scope == "global-points"] + + assert points_rows, "expected at least one points-layer shortcut row" + assert global_rows, "expected at least one global-points shortcut row" + + for row in points_rows: + assert row.graphicsEffect().opacity() == 0.45 + assert "No active Points layer." in row.toolTip() + + for row in global_rows: + assert row.graphicsEffect().opacity() == 1.0 + assert "No active Points layer." not in row.toolTip() + + +def test_shortcuts_dialog_updates_when_active_layer_changes_and_escapes_layer_name(dialog_parent, qtbot): + viewer = _DummyViewer(active=_image_layer("raw image")) + dlg = Shortcuts(dialog_parent, viewer=viewer) + qtbot.addWidget(dlg) + + # Start with a non-Points layer: points shortcuts are dimmed. + row = next(row for row in dlg.findChildren(ShortcutRow) if row.spec.scope == "points-layer") + assert row.graphicsEffect().opacity() == 0.45 + assert "No active Points layer." in row.toolTip() + + # Switch to a Points layer with characters that must be escaped in rich text. + viewer.layers.selection.active = _points_layer("a & layer") + viewer.layers.selection.events.active.emit() + + banner = dlg.context_banner.text() + assert "Active Points layer:" in banner + assert "<bizarre>" in banner + assert "& layer" in banner + assert "a & layer" not in banner # raw rich-text-breaking text should not appear + + assert row.graphicsEffect().opacity() == 1.0 + assert "No active Points layer." not in row.toolTip() + + +def test_shortcuts_dialog_disconnects_from_viewer_on_close(dialog_parent, qtbot): + viewer = _DummyViewer(active=_image_layer("raw image")) + emitter = viewer.layers.selection.events.active + + dlg = Shortcuts(dialog_parent, viewer=viewer) + qtbot.addWidget(dlg) + + assert len(emitter._callbacks) == 1 + + dlg.show() + dlg.close() + qtbot.wait(0) + + assert emitter._callbacks == [] + + +# ----------------------------------------------------------------------------- +# Tutorial +# ----------------------------------------------------------------------------- + + +def test_tutorial_initial_state(dialog_parent, qtbot): + dlg = Tutorial(dialog_parent) + qtbot.addWidget(dlg) + + assert dlg.parent() is dialog_parent + assert dlg.isModal() + assert dlg._current_tip == -1 + assert dlg.count.text() == "" + + # initial nav state with "intro" screen before first tip + assert not dlg.prev_button.isEnabled() + assert dlg.next_button.isEnabled() + + assert dlg.message.openExternalLinks() + assert dlg.message.textInteractionFlags() & Qt.LinksAccessibleByMouse + + +def test_tutorial_next_advances_to_first_tip_and_updates_position(dialog_parent, qtbot): + dlg = Tutorial(dialog_parent) + qtbot.addWidget(dlg) + + qtbot.mouseClick(dlg.next_button, Qt.LeftButton) + + assert dlg._current_tip == 0 + + # first real tip still has prev disabled, next enabled + assert not dlg.prev_button.isEnabled() + assert dlg.next_button.isEnabled() + + xrel, yrel = dlg._tips[0].pos + geom = dialog_parent.geometry() + expected = QPoint( + int(geom.left() + geom.width() * xrel), + int(geom.top() + geom.height() * yrel), + ) + assert dlg.pos() == expected + + +def test_tutorial_navigation_enables_and_disables_buttons(dialog_parent, qtbot): + dlg = Tutorial(dialog_parent) + qtbot.addWidget(dlg) + + qtbot.mouseClick(dlg.next_button, Qt.LeftButton) # tip 1 + qtbot.mouseClick(dlg.next_button, Qt.LeftButton) # tip 2 + + assert dlg._current_tip == 1 + assert dlg.prev_button.isEnabled() + assert dlg.next_button.isEnabled() + + qtbot.mouseClick(dlg.prev_button, Qt.LeftButton) + assert dlg._current_tip == 0 + assert not dlg.prev_button.isEnabled() + assert dlg.next_button.isEnabled() + + +def test_tutorial_last_tip_has_no_emoji_prefix_and_disables_next(dialog_parent, qtbot): + dlg = Tutorial(dialog_parent) + qtbot.addWidget(dlg) + + for _ in range(len(dlg._tips)): + qtbot.mouseClick(dlg.next_button, Qt.LeftButton) + + assert dlg._current_tip == len(dlg._tips) - 1 + assert dlg.prev_button.isEnabled() + assert not dlg.next_button.isEnabled() + + # last tip should not be prefixed with the emoji + assert "napari-deeplabcut" in dlg.message.text() + + +# ----------------------------------------------------------------------------- +# OverwriteConflictsDialog +# ----------------------------------------------------------------------------- + + +def test_overwrite_conflicts_dialog_smoke(dialog_parent, qtbot): + dlg = OverwriteConflictsDialog( + dialog_parent, + title="Overwrite warning", + summary="Saving will overwrite existing keypoints in the destination file.", + layer_text="points", + dest_text="/tmp/labels.h5", + affected_text="3 keypoint overwrite(s) across 2 frame(s)/image(s).", + details="img001.png -> nose, tail\nimg002.png -> paw", + ) + qtbot.addWidget(dlg) + + assert dlg.windowTitle() == "Overwrite warning" + assert dlg.isModal() + assert dlg.text is not None + assert isinstance(dlg.text, QPlainTextEdit) + assert dlg.text.isReadOnly() + assert dlg.text.toPlainText() == "img001.png -> nose, tail\nimg002.png -> paw" + + assert isinstance(dlg.cancel_btn, QPushButton) + assert isinstance(dlg.overwrite_btn, QPushButton) + assert dlg.overwrite_btn.isDefault() + assert dlg.overwrite_btn.autoDefault() + + labels = [w.text() for w in dlg.findChildren(QLabel)] + assert any("Saving will overwrite existing keypoints" in text for text in labels) + assert any("Layer: points" == text for text in labels) + assert any("Destination: /tmp/labels.h5" == text for text in labels) + assert any("Affected: 3 keypoint overwrite(s) across 2 frame(s)/image(s)." == text for text in labels) + assert any("Conflicts (frame/image" in text for text in labels) + + +def test_overwrite_conflicts_dialog_cancel_button_rejects(dialog_parent, qtbot): + dlg = OverwriteConflictsDialog( + dialog_parent, + title="Overwrite warning", + summary="summary", + layer_text="layer", + dest_text="dest", + affected_text="affected", + details="details", + ) + qtbot.addWidget(dlg) + + dlg.show() + qtbot.mouseClick(dlg.cancel_btn, Qt.LeftButton) + + assert dlg.result() == QDialog.Rejected + + +def test_overwrite_conflicts_dialog_overwrite_button_accepts(dialog_parent, qtbot): + dlg = OverwriteConflictsDialog( + dialog_parent, + title="Overwrite warning", + summary="summary", + layer_text="layer", + dest_text="dest", + affected_text="affected", + details="details", + ) + qtbot.addWidget(dlg) + + dlg.show() + qtbot.mouseClick(dlg.overwrite_btn, Qt.LeftButton) + + assert dlg.result() == QDialog.Accepted + + +# Confirm +def test_overwrite_conflicts_dialog_confirm_returns_true_on_accept(monkeypatch, dialog_parent): + def fake_exec(self): + return QDialog.Accepted + + monkeypatch.setattr(OverwriteConflictsDialog, "exec_", fake_exec) + + result = OverwriteConflictsDialog.confirm( + dialog_parent, + summary="summary", + layer_text="layer", + dest_text="dest", + affected_text="affected", + details="details", + title="Overwrite warning", + ) + + assert result is True + + +def test_overwrite_conflicts_dialog_confirm_returns_false_on_reject(monkeypatch, dialog_parent): + def fake_exec(self): + return QDialog.Rejected + + monkeypatch.setattr(OverwriteConflictsDialog, "exec_", fake_exec) + + result = OverwriteConflictsDialog.confirm( + dialog_parent, + summary="summary", + layer_text="layer", + dest_text="dest", + affected_text="affected", + details="details", + title="Overwrite warning", + ) + + assert result is False + + +# ----------------------------------------------------------------------------- +# maybe_confirm_overwrite +# ----------------------------------------------------------------------------- + + +def test_maybe_confirm_overwrite_returns_true_when_no_conflicts(monkeypatch, dialog_parent): + report = SimpleNamespace( + has_conflicts=False, + layer_name="layer", + destination_path="/tmp/file.h5", + n_overwrites=0, + n_frames=0, + details_text="", + ) + + called = [] + + monkeypatch.setattr( + "napari_deeplabcut.ui.dialogs.get_overwrite_confirmation_enabled", + lambda: True, + ) + monkeypatch.setattr( + "napari_deeplabcut.ui.dialogs.OverwriteConflictsDialog.confirm", + lambda *args, **kwargs: called.append((args, kwargs)), + ) + + result = maybe_confirm_overwrite(dialog_parent, report) + + assert result is True + assert called == [] + + +def test_maybe_confirm_overwrite_returns_true_when_confirmation_disabled(monkeypatch, dialog_parent): + report = SimpleNamespace( + has_conflicts=True, + layer_name="layer", + destination_path="/tmp/file.h5", + n_overwrites=3, + n_frames=2, + details_text="details", + ) + + called = [] + + monkeypatch.setattr( + "napari_deeplabcut.ui.dialogs.get_overwrite_confirmation_enabled", + lambda: False, + ) + monkeypatch.setattr( + "napari_deeplabcut.ui.dialogs.OverwriteConflictsDialog.confirm", + lambda *args, **kwargs: called.append((args, kwargs)), + ) + + result = maybe_confirm_overwrite(dialog_parent, report) + + assert result is True + assert called == [] + + +def test_maybe_confirm_overwrite_delegates_to_confirm(monkeypatch, dialog_parent): + report = SimpleNamespace( + has_conflicts=True, + layer_name="pose-layer", + destination_path="/tmp/labels.h5", + n_overwrites=3, + n_frames=2, + details_text="img001.png -> nose, tail", + ) + + monkeypatch.setattr( + "napari_deeplabcut.ui.dialogs.get_overwrite_confirmation_enabled", + lambda: True, + ) + + captured = {} + + def fake_confirm(parent, **kwargs): + captured["parent"] = parent + captured["kwargs"] = kwargs + return False + + monkeypatch.setattr( + "napari_deeplabcut.ui.dialogs.OverwriteConflictsDialog.confirm", + fake_confirm, + ) + + result = maybe_confirm_overwrite(dialog_parent, report) + + assert result is False + assert captured["parent"] is dialog_parent + assert captured["kwargs"] == { + "summary": "Saving will overwrite existing keypoints in the destination file.", + "layer_text": "pose-layer", + "dest_text": "/tmp/labels.h5", + "affected_text": "3 keypoint overwrite(s) across 2 frame(s)/image(s).", + "details": "img001.png -> nose, tail", + } + + +# ----------------------------------------------------------------------------- +# Project config / scorer resolution dialogs +# ----------------------------------------------------------------------------- + + +class _FakeButton: + def __init__(self, text=None, role=None): + self.text = text + self.role = role + + +class _FakeMessageBox: + Question = object() + YesRole = object() + NoRole = object() + Cancel = object() + Rejected = 0 + + planned_click = "yes" # "yes" | "no" | "cancel" + warnings = [] + last_instance = None + + def __init__(self, parent=None): + self.parent = parent + self._buttons = [] + self._clicked = None + self.window_title = None + self.text = None + self.default_button = None + type(self).last_instance = self + + def setIcon(self, icon): + self.icon = icon + + def setWindowTitle(self, title): + self.window_title = title + + def setText(self, text): + self.text = text + + def addButton(self, *args): + if len(args) == 2: + text, role = args + btn = _FakeButton(text=text, role=role) + else: + btn = _FakeButton(text="cancel", role=None) + self._buttons.append(btn) + return btn + + def setDefaultButton(self, btn): + self.default_button = btn + + def exec_(self): + if self.planned_click == "cancel": + self._clicked = None + return self.Rejected + + if self.planned_click == "no": + self._clicked = next((b for b in self._buttons if b.role is self.NoRole), None) + return 1 + + self._clicked = next((b for b in self._buttons if b.role is self.YesRole), None) + return 1 + + def clickedButton(self): + return self._clicked + + @staticmethod + def warning(parent, title, text): + _FakeMessageBox.warnings.append((title, text)) + + +class _FakeFileDialog: + next_result = ("", "") + calls = [] + + @staticmethod + def getOpenFileName(*args, **kwargs): + _FakeFileDialog.calls.append((args, kwargs)) + return _FakeFileDialog.next_result + + +@pytest.fixture +def fake_config_prompt_qt(monkeypatch): + _FakeMessageBox.planned_click = "yes" + _FakeMessageBox.warnings = [] + _FakeMessageBox.last_instance = None + _FakeFileDialog.next_result = ("", "") + _FakeFileDialog.calls = [] + monkeypatch.setattr(ui_dialogs, "QMessageBox", _FakeMessageBox) + monkeypatch.setattr(ui_dialogs, "QFileDialog", _FakeFileDialog) + return _FakeMessageBox, _FakeFileDialog + + +def test_load_scorer_from_config_returns_trimmed_scorer(tmp_path): + cfg = tmp_path / "config.yaml" + cfg.write_text("scorer: ' John '\n", encoding="utf-8") + + scorer = load_scorer_from_config(cfg) + + assert scorer == "John" + + +def test_load_scorer_from_config_returns_none_when_missing(tmp_path): + cfg = tmp_path / "config.yaml" + cfg.write_text("dotsize: 5\npcutoff: 0.6\n", encoding="utf-8") + + scorer = load_scorer_from_config(cfg) + + assert scorer is None + + +def test_load_scorer_from_config_returns_none_when_blank(tmp_path): + cfg = tmp_path / "config.yaml" + cfg.write_text("scorer: ' '\n", encoding="utf-8") + + scorer = load_scorer_from_config(cfg) + + assert scorer is None + + +def test_prompt_for_project_config_for_save_returns_skip_when_user_chooses_no(fake_config_prompt_qt): + fake_messagebox, fake_filedialog = fake_config_prompt_qt + fake_messagebox.planned_click = "no" + + result = prompt_for_project_config_for_save(parent=None) + + assert result.action is ProjectConfigPromptAction.SKIP + assert result.config_path is None + assert result.scorer is None + assert fake_filedialog.calls == [] + + +def test_prompt_for_project_config_for_save_returns_cancel_when_messagebox_cancelled(fake_config_prompt_qt): + fake_messagebox, fake_filedialog = fake_config_prompt_qt + fake_messagebox.planned_click = "cancel" + + result = prompt_for_project_config_for_save(parent=None) + + assert result.action is ProjectConfigPromptAction.CANCEL + assert result.config_path is None + assert result.scorer is None + assert fake_filedialog.calls == [] + + +def test_prompt_for_project_config_for_save_resolve_scorer_valid_config(fake_config_prompt_qt, tmp_path): + fake_messagebox, fake_filedialog = fake_config_prompt_qt + fake_messagebox.planned_click = "yes" + + cfg = tmp_path / "config.yaml" + cfg.write_text("scorer: John\n", encoding="utf-8") + fake_filedialog.next_result = (str(cfg), "DeepLabCut config (config.yaml)") + + result = prompt_for_project_config_for_save(parent=None, resolve_scorer=True) + + assert result.action is ProjectConfigPromptAction.ASSOCIATE + assert result.config_path == str(cfg) + assert result.scorer == "John" + assert fake_messagebox.warnings == [] + + +def test_prompt_for_project_config_for_save_resolve_scorer_invalid_config_missing_scorer( + fake_config_prompt_qt, + tmp_path, +): + fake_messagebox, fake_filedialog = fake_config_prompt_qt + fake_messagebox.planned_click = "yes" + + cfg = tmp_path / "config.yaml" + cfg.write_text("dotsize: 8\n", encoding="utf-8") + fake_filedialog.next_result = (str(cfg), "DeepLabCut config (config.yaml)") + + result = prompt_for_project_config_for_save(parent=None, resolve_scorer=True) + + assert result.action is ProjectConfigPromptAction.CANCEL + assert result.config_path is None + assert result.scorer is None + + assert len(fake_messagebox.warnings) == 1 + title, text = fake_messagebox.warnings[0] + assert title == "Invalid project configuration" + assert "does not define a valid non-empty 'scorer' field" in text + assert str(cfg) in text + + +def test_prompt_for_project_config_for_save_resolve_scorer_unreadable_config( + monkeypatch, + fake_config_prompt_qt, + tmp_path, +): + fake_messagebox, fake_filedialog = fake_config_prompt_qt + fake_messagebox.planned_click = "yes" + + cfg = tmp_path / "config.yaml" + cfg.write_text("scorer: John\n", encoding="utf-8") + fake_filedialog.next_result = (str(cfg), "DeepLabCut config (config.yaml)") + + def _boom(*args, **kwargs): + raise ValueError("bad yaml") + + monkeypatch.setattr(ui_dialogs, "load_scorer_from_config", _boom) + + result = prompt_for_project_config_for_save(parent=None, resolve_scorer=True) + + assert result.action is ProjectConfigPromptAction.CANCEL + assert result.config_path is None + assert result.scorer is None + + assert len(fake_messagebox.warnings) == 1 + title, text = fake_messagebox.warnings[0] + assert title == "Invalid project configuration" + assert "could not be read as a DeepLabCut config.yaml" in text + assert str(cfg) in text + + +def test_prompt_for_project_config_for_save_uses_custom_text(fake_config_prompt_qt): + fake_messagebox, _ = fake_config_prompt_qt + fake_messagebox.planned_click = "cancel" + + prompt_for_project_config_for_save( + parent=None, + window_title="Locate config", + message="Pick a config for scorer resolution", + choose_button_text="Browse…", + skip_button_text="Continue without config", + ) + + inst = fake_messagebox.last_instance + assert inst is not None + assert inst.window_title == "Locate config" + assert inst.text == "Pick a config for scorer resolution" + assert [b.text for b in inst._buttons[:2]] == ["Browse…", "Continue without config"] + + +def test_warn_invalid_config_for_scorer_auto_found_unreadable(fake_config_prompt_qt): + fake_messagebox, _ = fake_config_prompt_qt + + ui_dialogs.warn_invalid_config_for_scorer( + parent=None, + config_path="/tmp/config.yaml", + reason="unreadable", + auto_found=True, + ) + + assert len(fake_messagebox.warnings) == 1 + title, text = fake_messagebox.warnings[0] + assert title == "Invalid project configuration" + assert "found automatically" in text + assert "could not be read" in text + assert "/tmp/config.yaml" in text diff --git a/src/napari_deeplabcut/_tests/ui/test_layer_stats.py b/src/napari_deeplabcut/_tests/ui/test_layer_stats.py new file mode 100644 index 00000000..ae01179c --- /dev/null +++ b/src/napari_deeplabcut/_tests/ui/test_layer_stats.py @@ -0,0 +1,27 @@ +from napari_deeplabcut.ui.layer_stats import LayerStatusPanel + + +def test_set_invalid_points_layer_disables_slider_and_updates_text(qtbot): + panel = LayerStatusPanel() + qtbot.addWidget(panel) + + panel.set_invalid_points_layer() + + assert panel._progress_value.text() == "Active layer is not a DLC keypoints layer" + assert not panel._size_slider.isEnabled() + assert not panel._size_value.isEnabled() + + +def test_set_no_active_points_layer_disables_slider_and_value_label(qtbot): + panel = LayerStatusPanel() + qtbot.addWidget(panel) + + panel.set_point_size_enabled(True) + assert panel._size_slider.isEnabled() + assert panel._size_value.isEnabled() + + panel.set_no_active_points_layer() + + assert panel._progress_value.text() == "No active keypoints layer" + assert not panel._size_slider.isEnabled() + assert not panel._size_value.isEnabled() diff --git a/src/napari_deeplabcut/_tests/utils/test_deprecation.py b/src/napari_deeplabcut/_tests/utils/test_deprecation.py new file mode 100644 index 00000000..e9fd8c91 --- /dev/null +++ b/src/napari_deeplabcut/_tests/utils/test_deprecation.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import pytest + +from napari_deeplabcut.utils.deprecations import ( + DeprecationMode, + NapariDLCDeprecationWarning, + deprecated, + deprecation_mode, +) + + +def test_deprecated_function_warns(): + @deprecated( + since="0.9.0", + remove_in="1.1.0", + replacement="new_func", + ) + def old_func(): + return 123 + + with pytest.warns(NapariDLCDeprecationWarning, match="new_func"): + assert old_func() == 123 + + +def test_deprecated_function_can_error_via_context(): + @deprecated(since="0.9.0", remove_in="1.1.0") + def old_func(): + return 123 + + with pytest.raises(RuntimeError, match="deprecated"): + with deprecation_mode(DeprecationMode.ERROR): + old_func() + + +def test_deprecated_class_warns(): + @deprecated(since="0.9.0", replacement="NewThing") + class OldThing: + def __init__(self, x): + self.x = x + + with pytest.warns(NapariDLCDeprecationWarning, match="NewThing"): + obj = OldThing(5) + + assert obj.x == 5 diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index b391dfff..293c9270 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -1,559 +1,169 @@ +"""Main widget and controls for napari-deeplabcut, including the tutorial and shortcuts windows. + +NOTE: This file is generally already too long. For future development, please consider: +- Moving existing responsibilities out into separate modules (existing or new) +- Avoiding adding anything that is not strictly related to : + - Building the final UI (blocks can be moved to ui/ for better organization) + - Wiring to the core plugin functionality (e.g. via signals/slots, method calls, etc.) + - Anything that requires the full widget+viewer+signal/event context to function properly + - Similarly, test_widgets.py is a bit of a default drawer right now, please create new tests in _tests/ui +""" + +# src/napari_deeplabcut/_widgets.py +from __future__ import annotations + import logging -import os -from collections import defaultdict, namedtuple -from collections.abc import Sequence +from contextlib import contextmanager from copy import deepcopy from datetime import datetime from functools import cached_property, partial -from math import ceil, log10 from pathlib import Path from types import MethodType import matplotlib.pyplot as plt -import matplotlib.style as mplstyle -import napari import numpy as np -import pandas as pd -from matplotlib.backends.backend_qtagg import FigureCanvas, NavigationToolbar2QT -from napari._qt.widgets.qt_welcome import QtWelcomeLabel -from napari.layers import Image, Points, Shapes, Tracks -from napari.layers.points._points_key_bindings import register_points_action -from napari.layers.utils import color_manager -from napari.layers.utils.layer_utils import _features_to_properties +from napari.layers import Image, Points, Tracks from napari.utils.events import Event from napari.utils.history import get_save_history, update_save_history -from qtpy.QtCore import QPoint, QSettings, QSize, Qt, QTimer, Signal -from qtpy.QtGui import QAction, QCursor, QIcon, QPainter -from qtpy.QtSvgWidgets import QSvgWidget +from pydantic import ValidationError +from qtpy.QtCore import QSettings, QSignalBlocker, Qt, QTimer +from qtpy.QtGui import QAction from qtpy.QtWidgets import ( QButtonGroup, QCheckBox, - QComboBox, - QDialog, QFileDialog, QGridLayout, QGroupBox, QHBoxLayout, + QInputDialog, QLabel, QMessageBox, QPushButton, QRadioButton, - QScrollArea, - QSizePolicy, - QSlider, - QStyle, - QStyleOption, QVBoxLayout, QWidget, ) -from napari_deeplabcut import keypoints -from napari_deeplabcut._reader import ( - _load_config, - _load_superkeypoints, - _load_superkeypoints_diagram, - is_video, +import napari_deeplabcut.core.io as io +from napari_deeplabcut import misc +from napari_deeplabcut.config import settings +from napari_deeplabcut.config.keybinds import ( + install_global_points_keybindings, + install_points_layer_keybindings, ) -from napari_deeplabcut._writer import _form_df, _write_config, _write_image -from napari_deeplabcut.misc import ( - build_color_cycles, - canonicalize_path, - encode_categories, - guarantee_multiindex_rows, - remap_array, +from napari_deeplabcut.config.models import DLCHeaderModel, ImageMetadata, PointsMetadata +from napari_deeplabcut.core import keypoints +from napari_deeplabcut.core.config_sync import ( + load_point_size_from_config, + resolve_config_path_from_layer, + save_point_size_to_config, ) - -Tip = namedtuple("Tip", ["msg", "pos"]) - - -class Shortcuts(QDialog): - """Opens a window displaying available napari-deeplabcut shortcuts""" - - def __init__(self, parent): - super().__init__(parent=parent) - self.setParent(parent) - self.setWindowTitle("Shortcuts") - - image_path = str(Path(__file__).parent / "assets" / "napari_shortcuts.svg") - - vlayout = QVBoxLayout() - svg_widget = QSvgWidget(image_path) - svg_widget.setStyleSheet("background-color: white;") - vlayout.addWidget(svg_widget) - self.setLayout(vlayout) - - -class Tutorial(QDialog): - def __init__(self, parent): - super().__init__(parent=parent) - self.setParent(parent) - self.setWindowTitle("Tutorial") - self.setModal(True) - self.setStyleSheet("background:#361AE5") - self.setAttribute(Qt.WA_DeleteOnClose) - self.setWindowOpacity(0.95) - self.setWindowFlags(self.windowFlags() | Qt.WindowCloseButtonHint) - - self._current_tip = -1 - self._tips = [ - Tip( - "Load a folder of annotated data\n" - "(and optionally a config file if labeling from scratch)\n" - "from the menu File > Open File or Open Folder.\n" - "Alternatively, files and folders of images can be dragged\n" - "and dropped onto the main window.", - (0.35, 0.15), - ), - Tip( - "Data layers will be listed at the bottom left;\n" - "their visibility can be toggled by clicking on the small eye icon.", - (0.1, 0.65), - ), - Tip( - "Corresponding layer controls can be found at the top left.\n" - "Switch between labeling and selection mode using the numeric keys 2 and 3,\n" - "or clicking on the + or -> icons.", - (0.1, 0.2), - ), - Tip( - "There are three keypoint labeling modes:\nthe key M can be used to cycle between them.", - (0.65, 0.05), - ), - Tip( - "When done labeling, save your data by selecting the Points layer\n" - "and hitting Ctrl+S (or File > Save Selected Layer(s)...).", - (0.1, 0.65), - ), - Tip( - "Read more at napari-deeplabcut", - (0.4, 0.4), - ), - ] - - vlayout = QVBoxLayout() - self.message = QLabel("💡\n\nLet's get started with a quick walkthrough!") - self.message.setTextInteractionFlags(Qt.LinksAccessibleByMouse) - self.message.setOpenExternalLinks(True) - vlayout.addWidget(self.message) - - nav_layout = QHBoxLayout() - self.prev_button = QPushButton("<") - self.prev_button.clicked.connect(self.prev_tip) - nav_layout.addWidget(self.prev_button) - self.next_button = QPushButton(">") - self.next_button.clicked.connect(self.next_tip) - nav_layout.addWidget(self.next_button) - - self.update_nav_buttons() - - hlayout = QHBoxLayout() - self.count = QLabel("") - hlayout.addWidget(self.count) - hlayout.addLayout(nav_layout) - vlayout.addLayout(hlayout) - self.setLayout(vlayout) - - def prev_tip(self): - self._current_tip = (self._current_tip - 1) % len(self._tips) - self.update_tip() - self.update_nav_buttons() - - def next_tip(self): - self._current_tip = (self._current_tip + 1) % len(self._tips) - self.update_tip() - self.update_nav_buttons() - - def update_tip(self): - tip = self._tips[self._current_tip] - msg = tip.msg - if self._current_tip < len(self._tips) - 1: # No emoji in the last tip otherwise the hyperlink breaks - msg = "💡\n\n" + msg - self.message.setText(msg) - self.count.setText(f"Tip {self._current_tip + 1}|{len(self._tips)}") - self.adjustSize() - xrel, yrel = tip.pos - geom = self.parent().geometry() - p = QPoint( - int(geom.left() + geom.width() * xrel), - int(geom.top() + geom.height() * yrel), - ) - self.move(p) - - def update_nav_buttons(self): - self.prev_button.setEnabled(self._current_tip > 0) - self.next_button.setEnabled(self._current_tip < len(self._tips) - 1) - - -def _get_and_try_preferred_reader( - self, - dialog, - *args, -): +from napari_deeplabcut.core.conflicts import compute_overwrite_report_for_points_save +from napari_deeplabcut.core.layer_versioning import mark_layer_presentation_changed +from napari_deeplabcut.core.layers import ( + compute_label_progress, + find_relevant_image_layer, + get_first_points_layer, + get_points_layer_with_tables, + get_uniform_point_size, + infer_folder_display_name, + is_machine_layer, + set_uniform_point_size, +) +from napari_deeplabcut.core.metadata import ( + MergePolicy, + apply_project_paths_override_to_points_meta, + infer_image_root, + migrate_points_layer_metadata, + read_points_meta, + sync_points_from_image, + write_points_meta, +) +from napari_deeplabcut.core.project_paths import ( + PathMatchPolicy, + coerce_paths_to_dlc_row_keys, + dataset_folder_has_files, + find_nearest_config, + resolve_project_root_from_config, + target_dataset_folder_for_config, +) +from napari_deeplabcut.core.provenance import ( + apply_gt_save_target, + is_projectless_folder_association_candidate, + requires_gt_promotion, + suggest_human_placeholder, +) +from napari_deeplabcut.core.remap import remap_layer_data_by_paths +from napari_deeplabcut.core.sidecar import ( + get_default_scorer, + set_default_scorer, +) +from napari_deeplabcut.core.trails import TrailsController, safe_folder_anchor_from_points_layer +from napari_deeplabcut.napari_compat import ( + apply_points_layer_ui_tweaks, + install_add_wrapper, + install_paste_patch, + patch_color_manager_guess_continuous, + register_points_action, +) +from napari_deeplabcut.napari_compat.points_layer import make_paste_data +from napari_deeplabcut.ui import dialogs as ui_dialogs +from napari_deeplabcut.ui.color_scheme_display import ColorSchemePanel +from napari_deeplabcut.ui.cropping import ( + build_video_action_menu, + handle_apply_crop_toggled, + resolve_project_path_from_image_layer, + run_extract_current_frame, + run_store_crop_coordinates, + update_video_panel_context, +) +from napari_deeplabcut.ui.dialogs import Shortcuts, Tutorial +from napari_deeplabcut.ui.labels_and_dropdown import ( + DropdownMenu, + KeypointsDropdownMenu, +) +from napari_deeplabcut.ui.layer_stats import LayerStatusPanel +from napari_deeplabcut.ui.plots.trajectory import KeypointMatplotlibCanvas + +logger = logging.getLogger("napari-deeplabcut._widgets") +# logger.setLevel(logging.DEBUG) # FIXME @C-Achard temp remove before merging + + +def _prompt_for_scorer(parent_widget, *, anchor: str, suggested: str) -> str | None: + """Prompt user for a scorer name. Returns non-empty string or None if cancelled.""" + text, ok = QInputDialog.getText( + parent_widget, + "Choose scorer", + "No DLC config.yaml scorer found.\n" + "Please enter a scorer name for the CollectedData file.\n\n" + "Tip: Use your name or a stable lab identifier.\n" + "(We strongly discourage keeping the generic 'human_xxxxxx'.)", + text=suggested, + ) + if not ok: + return None + scorer = (text or "").strip() + if not scorer: + return None + return scorer + + +@contextmanager +def _temporary_layer_metadata(layer: Points, metadata: dict): + old_metadata = dict(layer.metadata or {}) + layer.metadata = metadata try: - self.viewer.open( - dialog._current_file, - plugin="napari-deeplabcut", - ) - except ValueError: - self.viewer.open( - dialog._current_file, - plugin="builtins", - ) - - -# Hack to avoid napari's silly variable type guess, -# where property is understood as continuous if -# there are more than 16 unique categories... -def guess_continuous(property): - if issubclass(property.dtype.type, np.floating): - return True - else: - return False - - -color_manager.guess_continuous = guess_continuous - - -def _paste_data(self, store): - """Paste only currently unannotated data.""" - features = self._clipboard.pop("features", None) - if features is None: - return - - unannotated = [ - keypoints.Keypoint(label, id_) not in store.annotated_keypoints - for label, id_ in zip(features["label"], features["id"], strict=False) - ] - if not any(unannotated): - return - - new_features = features.iloc[unannotated] - indices_ = self._clipboard.pop("indices") - text_ = self._clipboard.pop("text") - self._clipboard = {k: v[unannotated] for k, v in self._clipboard.items()} - self._clipboard["features"] = new_features - self._clipboard["indices"] = indices_ - if text_ is not None: - new_text = { - "string": text_["string"][unannotated], - "color": text_["color"], - } - self._clipboard["text"] = new_text - - npoints = len(self._view_data) - totpoints = len(self.data) - - if len(self._clipboard.keys()) > 0: - not_disp = self._slice_input.not_displayed - data = deepcopy(self._clipboard["data"]) - offset = [self._slice_indices[i] - self._clipboard["indices"][i] for i in not_disp] - data[:, not_disp] = data[:, not_disp] + np.array(offset) - self._data = np.append(self.data, data, axis=0) - self._shown = np.append(self.shown, deepcopy(self._clipboard["shown"]), axis=0) - self._size = np.append(self.size, deepcopy(self._clipboard["size"]), axis=0) - self._symbol = np.append(self.symbol, deepcopy(self._clipboard["symbol"]), axis=0) - - self._feature_table.append(self._clipboard["features"]) - - self.text._paste(**self._clipboard["text"]) - - self._edge_width = np.append( - self.edge_width, - deepcopy(self._clipboard["edge_width"]), - axis=0, - ) - self._edge._paste( - colors=self._clipboard["edge_color"], - properties=_features_to_properties(self._clipboard["features"]), - ) - self._face._paste( - colors=self._clipboard["face_color"], - properties=_features_to_properties(self._clipboard["features"]), - ) - - self._selected_view = list(range(npoints, npoints + len(self._clipboard["data"]))) - self._selected_data = set(range(totpoints, totpoints + len(self._clipboard["data"]))) - self.refresh() - - -# Hack to save a KeyPoints layer without showing the Save dialog -def _save_layers_dialog(self, selected=False): - """Save layers (all or selected) to disk, using ``LayerList.save()``. - Parameters - ---------- - selected : bool - If True, only layers that are selected in the viewer will be saved. - By default, all layers are saved. - """ - selected_layers = list(self.viewer.layers.selection) - msg = "" - if not len(self.viewer.layers): - msg = "There are no layers in the viewer to save." - elif selected and not len(selected_layers): - msg = "Please select a Points layer to save." - if msg: - QMessageBox.warning(self, "Nothing to save", msg, QMessageBox.Ok) - return - if len(selected_layers) == 1 and isinstance(selected_layers[0], Points): - self.viewer.layers.save("", selected=True, plugin="napari-deeplabcut") - self.viewer.status = "Data successfully saved" - else: - dlg = QFileDialog() - hist = get_save_history() - dlg.setHistory(hist) - filename, _ = dlg.getSaveFileName( - caption=f"Save {'selected' if selected else 'all'} layers", - dir=hist[0], # home dir by default - ) - if filename: - self.viewer.layers.save(filename, selected=selected) - else: - return - self._is_saved = True - self.last_saved_label.setText(f"Last saved at {str(datetime.now().time()).split('.')[0]}") - self.last_saved_label.show() - - -def on_close(self, event, widget): - if widget._stores and not widget._is_saved: - choice = QMessageBox.warning( - widget, - "Warning", - "Data were not saved. Are you certain you want to leave?", - QMessageBox.Yes | QMessageBox.No, - ) - if choice == QMessageBox.Yes: - event.accept() - else: - event.ignore() - else: - event.accept() - - -# Class taken from https://github.com/matplotlib/napari-matplotlib/blob/53aa5ec95c1f3901e21dedce8347d3f95efe1f79/src/napari_matplotlib/base.py#L309 -class NapariNavigationToolbar(NavigationToolbar2QT): - """Custom Toolbar style for Napari.""" - - def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] - super().__init__(*args, **kwargs) - self.setIconSize(QSize(28, 28)) - - def _update_buttons_checked(self) -> None: - """Update toggle tool icons when selected/unselected.""" - super()._update_buttons_checked() - icon_dir = self.parentWidget()._get_path_to_icon() - - # changes pan/zoom icons depending on state (checked or not) - if "pan" in self._actions: - if self._actions["pan"].isChecked(): - self._actions["pan"].setIcon(QIcon(os.path.join(icon_dir, "Pan_checked.png"))) - else: - self._actions["pan"].setIcon(QIcon(os.path.join(icon_dir, "Pan.png"))) - if "zoom" in self._actions: - if self._actions["zoom"].isChecked(): - self._actions["zoom"].setIcon(QIcon(os.path.join(icon_dir, "Zoom_checked.png"))) - else: - self._actions["zoom"].setIcon(QIcon(os.path.join(icon_dir, "Zoom.png"))) - - -class KeypointMatplotlibCanvas(QWidget): - """ - Class containing the trajectory plot using matplotlib. - Shown if selected at the bottom of the screen - Uses keypoints from a specified range of frames to plot them on a t-y axis. - """ - - # FIXME : y axis should be reversed due to napari using top-left as origin - def __init__(self, napari_viewer, parent=None): - super().__init__(parent=parent) - - self.viewer = napari_viewer - with mplstyle.context(self.mpl_style_sheet_path): - self.canvas = FigureCanvas() - self.canvas.figure.set_size_inches(4, 2, forward=True) - self.canvas.figure.set_layout_engine("constrained") - self.ax = self.canvas.figure.subplots() - self.toolbar = NapariNavigationToolbar(self.canvas, parent=self) - self._replace_toolbar_icons() - self.canvas.mpl_connect("button_press_event", self.on_doubleclick) - self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") - self.ax.set_xlabel("Frame") - self.ax.set_ylabel("Y position") - # Add a slot to specify the range of frames to plot - self.slider = QSlider(Qt.Horizontal) - self.slider.setMinimum(50) - self.slider.setMaximum(10000) - self.slider.setValue(50) - self.slider.setToolTip("Adjust the range of frames to display around the current frame") - self.slider.setTickPosition(QSlider.TicksBelow) - self.slider.setTickInterval(50) - self.slider_value = QLabel(str(self.slider.value())) - self._window = self.slider.value() - # Connect slider to window setter - self.slider.valueChanged.connect(self.set_window) - - layout = QVBoxLayout() - layout.addWidget(self.canvas) - layout.addWidget(self.toolbar) - layout2 = QHBoxLayout() - layout2.addWidget(self.slider) - layout2.addWidget(self.slider_value) - - layout.addLayout(layout2) - self.setLayout(layout) - - self.frames = [] - self.keypoints = [] - self.df = None - # Make widget larger - self.setMinimumHeight(300) - # connect sliders to update plot - self.viewer.dims.events.current_step.connect(self.update_plot_range) - - # Run update plot range once to initialize the plot - self._n = 0 - self.update_plot_range(Event(type_name="", value=[self.viewer.dims.current_step[0]])) - - self.viewer.layers.events.inserted.connect(self._load_dataframe) - self.viewer.dims.events.range.connect(self._update_slider_max) - self._lines = {} - - def on_doubleclick(self, event): - if event.dblclick: - show = list(self._lines.values())[0][0].get_visible() - for lines in self._lines.values(): - for l in lines: - l.set_visible(not show) - self._refresh_canvas(value=self._n) - - def _napari_theme_has_light_bg(self) -> bool: - """ - Does this theme have a light background? - - Returns - ------- - bool - True if theme's background colour has hsl lighter than 50%, False if darker. - """ - theme = napari.utils.theme.get_theme( - self.viewer.theme, - # as_dict=False # deprecated as of napari 0.6.6 - ) - _, _, bg_lightness = theme.background.as_hsl_tuple() - return bg_lightness > 0.5 - - @property - def mpl_style_sheet_path(self) -> Path: - """ - Path to the set Matplotlib style sheet. - """ - if self._napari_theme_has_light_bg(): - return Path(__file__).parent / "styles" / "light.mplstyle" - else: - return Path(__file__).parent / "styles" / "dark.mplstyle" - - def _get_path_to_icon(self) -> Path: - """ - Get the icons directory (which is theme-dependent). - - Icons modified from - https://github.com/matplotlib/matplotlib/tree/main/lib/matplotlib/mpl-data/images - """ - icon_root = Path(__file__).parent / "assets" - if self._napari_theme_has_light_bg(): - return icon_root / "black" - else: - return icon_root / "white" - - def _replace_toolbar_icons(self) -> None: - """ - Modifies toolbar icons to match the napari theme, and add some tooltips. - """ - icon_dir = self._get_path_to_icon() - for action in self.toolbar.actions(): - text = action.text() - if text == "Pan": - action.setToolTip( - "Pan/Zoom: Left button pans; Right button zooms; Click once to activate; Click again to deactivate" - ) - if text == "Zoom": - action.setToolTip("Zoom to rectangle; Click once to activate; Click again to deactivate") - if len(text) > 0: # i.e. not a separator item - icon_path = os.path.join(icon_dir, text + ".png") - action.setIcon(QIcon(icon_path)) - - def _load_dataframe(self): - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points): - points_layer = layer - break - - if points_layer is None or ~np.any(points_layer.data): - return - - self.show() # Silly hack so the window does not hang the first time it is shown - self.hide() - - self.df = _form_df( - points_layer.data, - { - "metadata": points_layer.metadata, - "properties": points_layer.properties, - }, - ) - for keypoint in self.df.columns.get_level_values("bodyparts").unique(): - y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]) - x = np.arange(len(y)) - color = points_layer.metadata["face_color_cycles"]["label"][keypoint] - lines = self.ax.plot(x, y, color=color, label=keypoint) - self._lines[keypoint] = lines - - self._refresh_canvas(value=self._n) - - def _toggle_line_visibility(self, keypoint): - for artist in self._lines[keypoint]: - artist.set_visible(not artist.get_visible()) - self._refresh_canvas(value=self._n) - - def _refresh_canvas(self, value): - start = max(0, value - self._window // 2) - end = min(value + self._window // 2, len(self.df)) - - self.ax.set_xlim(start, end) - self.vline.set_xdata([value]) - self.canvas.draw() - - def set_window(self, value): - self._window = value - self.slider_value.setText(str(value)) - self.update_plot_range(Event(type_name="", value=[self._n])) - - def update_plot_range(self, event): - value = event.value[0] - self._n = value - - if self.df is None: - return - - self._refresh_canvas(value) - - def _update_slider_max(self, event): - """Update the slider's maximum value based on the number of frames in the data.""" - for layer in self.viewer.layers: - if isinstance(layer, Image) and len(layer.data.shape) >= 3: - n_frames = layer.data.shape[0] - # if less than 50 frames, set max to min to avoid slider issues - if n_frames < self.slider.minimum(): - self.slider.setMaximum(self.slider.minimum()) - else: - self.slider.setMaximum(n_frames - 1) - break + yield + finally: + layer.metadata = old_metadata class KeypointControls(QWidget): def __init__(self, napari_viewer): super().__init__() + # Monkey-patch napari continuous variable type guess + patch_color_manager_guess_continuous() + self._is_saved = False self.viewer = napari_viewer @@ -561,41 +171,43 @@ def __init__(self, napari_viewer): self.viewer.layers.events.inserted.connect(self.on_insert) self.viewer.layers.events.removed.connect(self.on_remove) - self.viewer.window.qt_viewer._get_and_try_preferred_reader = MethodType( - _get_and_try_preferred_reader, - self.viewer.window.qt_viewer, - ) + # self.viewer.window.qt_viewer._get_and_try_preferred_reader = MethodType( + # _get_and_try_preferred_reader, + # self.viewer.window.qt_viewer, + # ) + # Project data + self._project_path: str | None = None status_bar = self.viewer.window._qt_window.statusBar() self.last_saved_label = QLabel("") self.last_saved_label.hide() status_bar.addPermanentWidget(self.last_saved_label) - # Hack napari's Welcome overlay to show more relevant instructions - overlay = self.viewer.window._qt_viewer._welcome_widget - welcome_widget = overlay.layout().itemAt(1).widget() - welcome_widget.deleteLater() - w = QtWelcomeWidget(None) - overlay._overlay = w - overlay.addWidget(w) - # overlay._overlay.sig_dropped.connect(overlay.sig_dropped) - self._color_mode = keypoints.ColorMode.default() self._label_mode = keypoints.LabelMode.default() # Hold references to the KeypointStores self._stores = {} # Intercept close event if data were not saved - self.viewer.window._qt_window.closeEvent = partial( - on_close, - self.viewer.window._qt_window, - widget=self, - ) + qt_win = self.viewer.window._qt_window + orig_close_event = qt_win.closeEvent + + # Wrap event without overriding the original + # for future-proofing + def _close_event(event): + self.on_close(event) + # if accepted, call original + if event.isAccepted(): + orig_close_event(event) + + qt_win.closeEvent = _close_event # Storage for extra image metadata that are relevant to other layers. # These are updated anytime images are added to the Viewer # and passed on to the other layers upon creation. - self._images_meta = dict() + self._image_meta = ImageMetadata() + # Storage for layers requiring recoloring + self._recolor_pending = set() # Add some more controls self._layout = QVBoxLayout(self) @@ -603,23 +215,44 @@ def __init__(self, napari_viewer): self._layer_to_menu = {} self.viewer.layers.selection.events.active.connect(self.on_active_layer_change) - self._video_group = self._form_video_action_menu() + self._video_group = build_video_action_menu( + on_extract_frame=self._extract_single_frame, + on_store_crop=self._store_crop_coordinates, + ) self.video_widget = self.viewer.window.add_dock_widget(self._video_group, name="video", area="right") self.video_widget.setVisible(False) + self._video_group.export_labels_cb.toggled.connect(lambda _checked: self._refresh_video_panel_context()) + self._video_group.apply_crop_cb.toggled.connect(self._on_apply_crop_toggled) + self.viewer.dims.events.current_step.connect(lambda event: self._refresh_video_panel_context()) + self.viewer.layers.selection.events.active.connect(lambda event: self._refresh_video_panel_context()) + QTimer.singleShot(0, self._refresh_video_panel_context) # form helper display self._keypoint_mapping_button = None - self._func_id = None + self._load_superkeypoints_action = None help_buttons = self._form_help_buttons() self._layout.addLayout(help_buttons) grid = QGridLayout() + + self._confirm_overwrite_cb = QCheckBox("Warn on overwrite", parent=self) + self._confirm_overwrite_cb.setToolTip( + "When enabled, saving a layer that would overwrite existing keypoints will show a confirmation dialog." + ) + self._confirm_overwrite_cb.setChecked(settings.get_overwrite_confirmation_enabled()) + self._confirm_overwrite_cb.stateChanged.connect(self._toggle_overwrite_confirmation) + self._trail_cb = QCheckBox("Show trails", parent=self) self._trail_cb.setToolTip("Show the trails for each keypoint over time, in the main video viewer") self._trail_cb.setChecked(False) self._trail_cb.setEnabled(False) - self._trail_cb.stateChanged.connect(self._show_trails) - self._trails = None + self._trail_cb.stateChanged.connect(self._on_show_trails_toggled) + self._trails_controller = TrailsController( + self.viewer, + managed_points_layers_getter=lambda: tuple(self._stores.keys()), + color_mode_getter=lambda: self.color_mode, + resolved_cycle_getter=self._resolved_cycle_for_layer, + ) self._mpl_docked = False self._matplotlib_canvas = KeypointMatplotlibCanvas(self.viewer) @@ -630,9 +263,16 @@ def __init__(self, napari_viewer): self._show_traj_plot_cb.setEnabled(False) self._view_scheme_cb = QCheckBox("Show color scheme", parent=self) - grid.addWidget(self._show_traj_plot_cb, 0, 0) - grid.addWidget(self._trail_cb, 1, 0) - grid.addWidget(self._view_scheme_cb, 2, 0) + grid.addWidget(self._confirm_overwrite_cb, 0, 0) + grid.addWidget(self._show_traj_plot_cb, 1, 0) + grid.addWidget(self._trail_cb, 2, 0) + grid.addWidget(self._view_scheme_cb, 3, 0) + + # UX / status panel (folder, progress, point size) + self._layer_status_panel = LayerStatusPanel(self) + self._layer_status_panel.point_size_changed.connect(self._on_active_points_size_changed) + self._layer_status_panel.point_size_commit_requested.connect(self._commit_active_points_size_to_config) + self._layout.addWidget(self._layer_status_panel) self._layout.addLayout(grid) @@ -643,13 +283,30 @@ def __init__(self, napari_viewer): # form color scheme display + color mode selector self._color_grp, self._color_mode_selector = self._form_color_mode_selector() self._color_grp.setEnabled(False) - self._display = ColorSchemeDisplay(parent=self) - self._color_scheme_display = self._form_color_scheme_display(self.viewer) + + # Color scheme display panel + self._color_scheme_panel = ColorSchemePanel( + viewer=self.viewer, + get_color_mode=lambda: self.color_mode, + get_header_model=self._get_header_model_from_metadata, + parent=self, + ) + self._color_scheme_display = self.viewer.window.add_dock_widget( + self._color_scheme_panel, + name="Color scheme reference", + area="left", + ) + self._view_scheme_cb.setChecked(True) self._view_scheme_cb.toggled.connect(self._show_color_scheme) - self._view_scheme_cb.toggle() - self._display.added.connect( + self._show_color_scheme() + self._color_scheme_panel.display.added.connect( lambda w: w.part_label.clicked.connect(self._matplotlib_canvas._toggle_line_visibility), ) + ### UI setup ends here + + # Modes init + self.color_mode = self._color_mode + self.label_mode = self._label_mode # Substitute default menu action with custom one for action in self.viewer.window.file_menu.actions()[::-1]: @@ -657,8 +314,7 @@ def __init__(self, napari_viewer): if "save selected layer" in action_name: action.triggered.disconnect() action.triggered.connect( - lambda: _save_layers_dialog( - self, + lambda: self._save_layers_dialog( selected=True, ) ) @@ -674,6 +330,8 @@ def __init__(self, napari_viewer): display_shortcuts_action = QAction("&Shortcuts", self) display_shortcuts_action.triggered.connect(self.display_shortcuts) self.viewer.window.help_menu.addAction(display_shortcuts_action) + # Install global keybinds + install_global_points_keybindings() # Hide some unused viewer buttons # NOTE (future) do we truly want to disable these ? Tracking util may need to create new points layers @@ -697,6 +355,286 @@ def __init__(self, napari_viewer): # (Of course this will be a problem if we start using it everywhere so do not reuse lightly) QTimer.singleShot(10, self.silently_dock_matplotlib_canvas) + # If layers already exist (user loaded data before opening this widget), + # adopt them so keypoint controls take ownership immediately. + QTimer.singleShot(0, self._adopt_existing_layers) + + # Refresh layers stats widget + QTimer.singleShot(0, self._refresh_layer_status_panel) + + # ######################## # + # Layer setup core methods # + # ######################## # + + def _setup_image_layer(self, layer: Image, index: int | None = None, *, reorder: bool = True) -> None: + md = layer.metadata or {} + paths = md.get("paths") + if paths is None and io.is_video(layer.name): + self.video_widget.setVisible(True) + + self._update_image_meta_from_layer(layer) + + if not self._project_path: + self._cache_project_path_from_image_layer(layer) + if self._project_path is not None: + try: + layer.metadata = dict(layer.metadata or {}) + layer.metadata.setdefault("project", self._project_path) + except Exception: + logger.debug( + "Failed to set project path metadata on image layer %r", + getattr(layer, "name", layer), + exc_info=True, + ) + + self._sync_points_layers_from_image_meta() + self._refresh_video_panel_context() + + if reorder and index is not None: + QTimer.singleShot(10, partial(self._move_image_layer_to_bottom, index)) + + def _maybe_merge_config_points_layer(self, layer: Points) -> bool: + if not layer.metadata.get("project", "") or not self._stores: + return False + + new_metadata = layer.metadata.copy() + + keypoints_menu = self._menus[0].menus["label"] + current_keypoint_set = {keypoints_menu.itemText(i) for i in range(keypoints_menu.count())} + hdr = self._get_header_model_from_metadata(new_metadata) + if hdr is None: + return False + new_keypoint_set = set(hdr.bodyparts) + diff = new_keypoint_set.difference(current_keypoint_set) + + if diff: + answer = QMessageBox.question(self, "", "Do you want to display the new keypoints only?") + if answer == QMessageBox.Yes: + self.viewer.layers[-2].shown = False + + self.viewer.status = f"New keypoint{'s' if len(diff) > 1 else ''} {', '.join(diff)} found." + for _layer, store in self._stores.items(): + pts = read_points_meta(_layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if not hasattr(pts, "errors"): + updated = pts.model_copy(update={"header": hdr}) + write_points_meta(_layer, updated, merge_policy=MergePolicy.MERGE, fields={"header"}) + store.layer = _layer + + for menu in self._menus: + menu._map_individuals_to_bodyparts() + menu._update_items() + + QTimer.singleShot(10, self.viewer.layers.pop) + + # apply the new color cycles + recolor safely + for _layer, store in self._stores.items(): + _layer.metadata["config_colormap"] = new_metadata.get( + "config_colormap", _layer.metadata.get("config_colormap") + ) + _layer.metadata["face_color_cycles"] = new_metadata["face_color_cycles"] + _layer.metadata["colormap_name"] = new_metadata.get("colormap_name", _layer.metadata.get("colormap_name")) + mark_layer_presentation_changed(_layer) + self._apply_points_coloring_from_metadata(_layer) + store.layer = _layer + + self._update_color_scheme() + return True + + def _get_header_model_from_metadata(self, md: dict) -> DLCHeaderModel | None: + """Return DLCHeaderModel regardless of whether md['header'] is a model, dict payload, or MultiIndex.""" + if not isinstance(md, dict): + return None + hdr = md.get("header", None) + if hdr is None: + return None + + if isinstance(hdr, DLCHeaderModel): + logger.debug("Header is already a DLCHeaderModel instance.") + return hdr + + if isinstance(hdr, dict): + try: + return DLCHeaderModel.model_validate(hdr) + except Exception: + return None + + # fallback: allow MultiIndex / list-of-tuples / Index inputs + try: + return DLCHeaderModel(columns=hdr) + except Exception: + return None + + @staticmethod + def get_layer_controls(layer: Points) -> KeypointControls | None: + return getattr(layer, "_dlc_controls", None) + + @staticmethod + def get_layer_store(layer: Points) -> keypoints.KeypointStore | None: + return getattr(layer, "_dlc_store", None) + + def _wire_points_layer(self, layer: Points) -> keypoints.KeypointStore | None: + if not self._validate_header(layer): + return None + existing = getattr(layer, "_dlc_store", None) + if existing is not None: + self._stores[layer] = existing + layer._dlc_controls = self + return existing + + # ensure presence of IO metadata for saving & routing + mig = migrate_points_layer_metadata(layer) + if hasattr(mig, "errors"): + logger.warning( + "Points metadata validation failed during wiring for layer=%r: %s", + getattr(layer, "name", layer), + mig, + ) + + store = keypoints.KeypointStore(self.viewer, layer) + self._stores[layer] = store + layer._dlc_store = store + layer._dlc_controls = self + + # default root/paths from current image meta if missing + if not layer.metadata.get("root") and self._image_meta.root: + layer.metadata["root"] = self._image_meta.root + if not layer.metadata.get("paths") and self._image_meta.paths: + layer.metadata["paths"] = self._image_meta.paths + + # save history + if root := layer.metadata.get("root"): + update_save_history(root) + + store._get_label_mode = lambda: self._label_mode + layer.text.visible = False + + paste_func = make_paste_data(self, store=store) + install_paste_patch(layer, paste_func=paste_func) + + add_impl = MethodType(keypoints._add, store) # bind store to add implementation + install_add_wrapper(layer, add_impl=add_impl, schedule_recolor=self._schedule_recolor) + + # store events / navigation + layer.events.add(query_next_frame=Event) + layer.events.query_next_frame.connect(store._advance_step) + + # Install keybinds + install_points_layer_keybindings(layer, self, store) + + if len(self._stores) == 1 and self._is_multianimal(layer): + # set internal mode without triggering recolor storms + self._color_mode = keypoints.ColorMode.INDIVIDUAL + # update button state so UI matches (optional but good) + for btn in self._color_mode_selector.buttons(): + if btn.text().lower() == str(self._color_mode).lower(): + btn.setChecked(True) + break + + # apply cycles (works even if empty; see method) + self._apply_points_coloring_from_metadata(layer) + self._maybe_initialize_layer_point_size_from_config(layer) + self._connect_layer_status_events(layer) + # refresh trails if enabled (e.g. when merging a config points layer with trails metadata) + self._trails_controller.on_points_layer_added_or_rewired(checkbox_checked=self._trail_cb.isChecked()) + + # menus + self._form_dropdown_menus(store) + + # project path + proj = layer.metadata.get("project") + if proj: + self._project_path = proj + + # enable GUI groups + self._radio_box.setEnabled(True) + self._color_grp.setEnabled(True) + self._trail_cb.setEnabled(True) + self._show_traj_plot_cb.setEnabled(True) + + return store + + def _setup_points_layer(self, layer: Points, *, allow_merge: bool = True) -> None: + if not self._validate_header(layer): + return + + if allow_merge and self._maybe_merge_config_points_layer(layer): + return + + if layer.metadata.get("tables", ""): + self._keypoint_mapping_button.show() + + store = self._wire_points_layer(layer) + if store is None: + return + + selector = apply_points_layer_ui_tweaks(self.viewer, layer, dropdown_cls=DropdownMenu, plt_module=plt) + if selector is not None: + try: + selector.currentTextChanged.connect(self._update_colormap) + except Exception: + pass + + self._update_color_scheme() + + def _adopt_existing_layers(self) -> None: + """ + When the widget is opened after layers already exist, we need to + run the same initialization as if they had been inserted. + """ + # Iterate over a snapshot, because on_insert may modify layer order + layers_snapshot = list(self.viewer.layers) + + for idx, layer in enumerate(layers_snapshot): + self._adopt_layer(layer, idx) + + # After adoption, refresh UI state + try: + active = self.viewer.layers.selection.active + if active is not None: + # Force the GUI to update visibility of menus, etc. + self.on_active_layer_change(Event(type_name="active", value=active)) + except Exception: + pass + + def _adopt_layer(self, layer, index: int) -> None: + """ + Run the relevant portion of on_insert() for an already-existing layer. + This avoids duplicating your logic and prevents reliance on napari's Event object. + """ + if isinstance(layer, Image): + self._setup_image_layer(layer, index, reorder=True) + elif isinstance(layer, Points): + if layer not in self._stores: + self._setup_points_layer(layer, allow_merge=False) # typically don’t merge during adopt + if not isinstance(layer, Image): + self._remap_frame_indices(layer) + + def _validate_header(self, layer) -> bool: + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if isinstance(res, ValidationError) or res.header is None: + self.viewer.status = ( + "This Points layer does not look like a DLC keypoints layer. Missing a valid DLC header." + ) + return False + return True + + def _schedule_recolor(self, layer: Points) -> None: + if not hasattr(self, "_recolor_pending"): + self._recolor_pending = set() + + if layer in self._recolor_pending: + return + + self._recolor_pending.add(layer) + + def _do(): + try: + self._apply_points_coloring_from_metadata(layer) + finally: + self._recolor_pending.discard(layer) + + QTimer.singleShot(0, _do) + def _ensure_mpl_canvas_docked(self) -> None: """ Dock the Matplotlib canvas as a napari dock widget, exactly once, @@ -744,12 +682,7 @@ def settings(self): return QSettings() def load_superkeypoints_diagram(self): - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points): - points_layer = layer - break - + points_layer = get_first_points_layer(self.viewer) if points_layer is None: return @@ -758,9 +691,9 @@ def load_superkeypoints_diagram(self): return super_animal, table = tables.popitem() - layer_data = _load_superkeypoints_diagram(super_animal) - self.viewer.add_image(layer_data[0], metadata=layer_data[1]) - superkpts_dict = _load_superkeypoints(super_animal) + image = io.load_superkeypoints_diagram(super_animal) + self.viewer.add_image(image, name=f"{super_animal} keypoint diagram", metadata={"super_animal": super_animal}) + superkpts_dict = io.load_superkeypoints(super_animal) xy = [] labels = [] for kpt_ref, kpt_super in table.items(): @@ -770,8 +703,12 @@ def load_superkeypoints_diagram(self): properties = deepcopy(points_layer.properties) properties["label"] = np.array(labels) points_layer.properties = properties + self._apply_points_coloring_from_metadata(points_layer) self._keypoint_mapping_button.setText("Map keypoints") - self._keypoint_mapping_button.clicked.disconnect(self._func_id) + try: + self._keypoint_mapping_button.clicked.disconnect(self._load_superkeypoints_action) + except TypeError: + pass self._keypoint_mapping_button.clicked.connect(lambda: self._map_keypoints(super_animal)) def _map_keypoints(self, super_animal: str): @@ -780,221 +717,471 @@ def _map_keypoints(self, super_animal: str): # - Assumes _load_superkeypoints and _load_config succeed # and return well-formed data; I/O errors are not handled. # - Silently ignores keypoints that have no nearest neighbor in the superkeypoint set (no user feedback). - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata.get("tables"): - points_layer = layer - break - - if points_layer is None or ~np.any(points_layer.data): + points_layer = get_points_layer_with_tables(self.viewer) + if points_layer is None or not np.any(points_layer.data): return xy = points_layer.data[:, 1:3] - superkpts_dict = _load_superkeypoints(super_animal) - xy_ref = np.c_[[val for val in superkpts_dict.values()]] + superkpts_dict = io.load_superkeypoints(super_animal) + xy_ref = np.asarray(list(superkpts_dict.values()), dtype=float) neighbors = keypoints._find_nearest_neighbors(xy, xy_ref) found = neighbors != -1 - if ~np.any(found): + if not np.any(found): return project_path = points_layer.metadata["project"] config_path = str(Path(project_path) / "config.yaml") - cfg = _load_config(config_path) + cfg = io.load_config(config_path) conversion_tables = cfg.get("SuperAnimalConversionTables", {}) + hdr = self._get_header_model_from_metadata(points_layer.metadata or {}) + if hdr is None: + return + bdprts_map = map(str, hdr.bodyparts) conversion_tables[super_animal] = dict( zip( - map(str, points_layer.metadata["header"].bodyparts), # Needed to fix an ugly yaml RepresenterError + bdprts_map, # Needed to fix an ugly yaml RepresenterError map(str, list(np.array(list(superkpts_dict))[neighbors[found]])), strict=False, ) ) cfg["SuperAnimalConversionTables"] = conversion_tables - _write_config(config_path, cfg) + io.write_config(config_path, cfg) self.viewer.status = "Mapping to superkeypoint set successfully saved" def start_tutorial(self): Tutorial(self.viewer.window._qt_window.current()).show() def display_shortcuts(self): - Shortcuts(self.viewer.window._qt_window.current()).show() + Shortcuts(self.viewer.window._qt_window.current(), viewer=self.viewer).show() def _move_image_layer_to_bottom(self, index): if (ind := index) != 0: self.viewer.layers.move_selected(ind, 0) self.viewer.layers.select_next() # Auto-select the Points layer - def _show_color_scheme(self): - show = self._view_scheme_cb.isChecked() - self._color_scheme_display.setVisible(show) + # ------------------------------------------------------------------ + # Metadata helpers (authoritative models + napari-friendly dict sync) + # ------------------------------------------------------------------ + @staticmethod + def _layer_source_path(layer) -> str | None: + """Best-effort access to napari layer source path (may not exist).""" + try: + src = getattr(layer, "source", None) + p = getattr(src, "path", None) if src is not None else None + return str(p) if p else None + except Exception: + return None - def _show_trails(self, state): - if Qt.CheckState(state) != Qt.CheckState.Checked: - if self._trails is not None: - self._trails.visible = False - return + def _update_image_meta_from_layer(self, layer: Image) -> None: + """ + Update authoritative self._images_meta using an Image layer. + Also keep a dict-like subset synced for other layers (non-breaking). + """ + md = layer.metadata or {} - if self._trails is None: - store = list(self._stores.values())[0] + paths = md.get("paths") + shape = None + try: + shape = layer.level_shapes[0] + except Exception: + shape = None - # Determine coloring mode - mode = "label" - if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): - mode = "id" + root = infer_image_root( + explicit_root=md.get("root"), + paths=paths, + source_path=self._layer_source_path(layer), + ) - categories = store.layer.properties.get(mode) - # Check for single animal data - if categories is None or (mode == "id" and (not categories[0])): - mode = "label" - categories = store.layer.properties["label"] + incoming = ImageMetadata( + paths=list(paths) if paths else None, + root=str(root) if root else None, + shape=tuple(shape) if shape is not None else None, + name=getattr(layer, "name", None), + ) - inds = encode_categories(categories, is_path=False, do_sort=False) + # Merge without clobbering already-known values + # (same behavior as old "only set if non-empty") + base = self._image_meta + merged = base.model_copy(deep=True) + for field, value in incoming.model_dump().items(): + if getattr(merged, field) in (None, "", []) and value not in (None, "", []): + setattr(merged, field, value) - # Build Tracks data - temp = np.c_[inds, store.layer.data] - cmap = "viridis" - for layer in self.viewer.layers: - if isinstance(layer, Points): - colormap_name = layer.metadata.get("colormap_name") - if colormap_name: - cmap = colormap_name - break + self._image_meta = merged - # 5) Create Tracks layer - self._trails = self.viewer.add_tracks( - temp, - colormap=cmap, - tail_length=50, - head_length=50, - tail_width=6, - name="trails", - ) + def _sync_points_layers_from_image_meta(self) -> None: + """ + Ensure all Points layers have core fields required for saving. - self._trails.visible = True + Adapter-based flow: + - read validated points meta (visible failures) + - apply sync logic against authoritative self._image_meta + - write back validated dict via gateway + """ + if self._image_meta is None: + return - def _form_video_action_menu(self): - group_box = QGroupBox("Video") - layout = QVBoxLayout() - extract_button = QPushButton("Extract frame") - extract_button.clicked.connect(self._extract_single_frame) - layout.addWidget(extract_button) - crop_button = QPushButton("Store crop coordinates") - crop_button.clicked.connect(self._store_crop_coordinates) - layout.addWidget(crop_button) - group_box.setLayout(layout) - return group_box + for ly in list(self.viewer.layers): + if not isinstance(ly, Points): + continue - def _form_help_buttons(self): - layout = QVBoxLayout() - help_buttons_layout = QHBoxLayout() - show_shortcuts = QPushButton("View shortcuts") - show_shortcuts.clicked.connect(self.display_shortcuts) - help_buttons_layout.addWidget(show_shortcuts) - tutorial = QPushButton("Start tutorial") - tutorial.clicked.connect(self.start_tutorial) - help_buttons_layout.addWidget(tutorial) - layout.addLayout(help_buttons_layout) - self._keypoint_mapping_button = QPushButton("Load superkeypoints diagram") - self._func_id = self._keypoint_mapping_button.clicked.connect(self.load_superkeypoints_diagram) - self._keypoint_mapping_button.hide() - layout.addWidget(self._keypoint_mapping_button) - return layout + if ly.metadata is None: + ly.metadata = {} - def _extract_single_frame(self, *args): - image_layer = None - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Image): - image_layer = layer - elif isinstance(layer, Points): - points_layer = layer - if image_layer is not None: - ind = self.viewer.dims.current_step[0] - frame = image_layer.data[ind] - n_frames = image_layer.data.shape[0] - name = f"img{str(ind).zfill(int(ceil(log10(n_frames))))}.png" - output_path = os.path.join(image_layer.metadata["root"], name) - _write_image(frame, str(output_path)) - - # If annotations were loaded, they should be written to a machinefile.h5 file - if points_layer is not None: - df = _form_df( - points_layer.data, - { - "metadata": points_layer.metadata, - "properties": points_layer.properties, - }, + # 1) Read + migrate legacy (io from source_h5, header coercion, etc.) + res = read_points_meta(ly, migrate_legacy=True, drop_controls=False, drop_header=False) + if hasattr(res, "errors"): # ValidationError duck-typing + logger.warning( + "Points metadata validation failed during sync for layer=%r: %s", + getattr(ly, "name", ly), + res, ) - df = df.iloc[ind : ind + 1] - canon = canonicalize_path(output_path, 3) - df.index = pd.MultiIndex.from_tuples([tuple(canon.split("/"))]) - filepath = os.path.join(image_layer.metadata["root"], "machinelabels-iter0.h5") - if Path(filepath).is_file(): - df_prev = pd.read_hdf(filepath) - guarantee_multiindex_rows(df_prev) - df = pd.concat([df_prev, df]) - df = df[~df.index.duplicated(keep="first")] - df.to_hdf(filepath, key="df_with_missing") + continue - def _store_crop_coordinates(self, *args): - if not (project_path := self._images_meta.get("project")): - return - for layer in self.viewer.layers: - if isinstance(layer, Shapes): - try: - ind = layer.shape_type.index("rectangle") - except ValueError: - return - bbox = layer.data[ind][:, 1:] - h = self.viewer.dims.range[2][1] - bbox[:, 0] = h - bbox[:, 0] - bbox = np.clip(bbox, 0, a_max=None).astype(int) - y1, x1 = bbox.min(axis=0) - y2, x2 = bbox.max(axis=0) - temp = {"crop": ", ".join(map(str, [x1, x2, y1, y2]))} - config_path = os.path.join(project_path, "config.yaml") - cfg = _load_config(config_path) - cfg["video_sets"][os.path.join(project_path, "videos", self._images_meta["name"])] = temp - _write_config(config_path, cfg) - break + pts_model: PointsMetadata = res - def _form_dropdown_menus(self, store): - menu = KeypointsDropdownMenu(store) - self.viewer.dims.events.current_step.connect( - menu.smart_reset, - position="last", - ) - menu.smart_reset(event=None) - self._menus.append(menu) - self._layer_to_menu[store.layer] = len(self._menus) - 1 - layout = QVBoxLayout() - layout.addWidget(menu) - self._layout.addLayout(layout) + # 2) Sync missing fields from image meta (pure model transform) + synced = sync_points_from_image(self._image_meta, pts_model) - def _form_mode_radio_buttons(self): - group_box = QGroupBox("Labeling mode") - layout = QHBoxLayout() - group = QButtonGroup(self) - for i, mode in enumerate(keypoints.LabelMode.__members__, start=1): - btn = QRadioButton(mode.capitalize()) - btn.setToolTip(keypoints.TOOLTIPS[mode]) - group.addButton(btn, i) - layout.addWidget(btn) - group.button(1).setChecked(True) - group_box.setLayout(layout) - self._layout.addWidget(group_box) + # 3) Write back through gateway (fill missing only; never clobber) + out = write_points_meta( + ly, + synced, + merge_policy=MergePolicy.MERGE_MISSING, + migrate_legacy=True, + validate=True, + ) + if hasattr(out, "errors"): + logger.warning( + "Failed to write synced points metadata for layer=%r: %s", + getattr(ly, "name", ly), + out, + ) - def _func(): - self.label_mode = group.checkedButton().text().lower() + def _resolve_config_path_for_layer(self, layer: Points | None) -> Path | None: + if layer is None: + return None - group.buttonClicked.connect(_func) - return group_box, group + image_layer = find_relevant_image_layer(self.viewer) - def _form_color_mode_selector(self): - group_box = QGroupBox("Keypoint coloring mode") - layout = QHBoxLayout() - group = QButtonGroup(self) - for i, mode in enumerate(keypoints.ColorMode.__members__, start=1): - btn = QRadioButton(mode.lower()) + return resolve_config_path_from_layer( + layer, + fallback_project=self._project_path, + fallback_root=self._image_meta.root, + image_layer=image_layer, + prefer_project_root=True, + max_levels=5, + ) + + def _maybe_prepare_project_path_override_metadata(self, layer: Points) -> tuple[dict | None, bool]: + """ + Optionally prepare save-time metadata by associating a project-less labeled + folder with an explicit DLC project chosen via config.yaml. + + Returns + ------- + tuple[dict | None, bool] + (overridden_metadata, abort_save) + + - (None, False): feature not applicable; continue normal save + - (metadata, False): apply metadata override and continue + - (None, True): user cancelled or operation was refused; abort save + """ + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if isinstance(res, ValidationError): + return None, False + + pts_meta: PointsMetadata = res + paths = pts_meta.paths or [] + if not paths: + return None, False + + if not is_projectless_folder_association_candidate(pts_meta): + return None, False + + source_root = pts_meta.root + if not source_root: + return None, False + + try: + source_root_path = Path(source_root).expanduser().resolve(strict=False) + except Exception: + source_root_path = Path(source_root) + + # NOTE: @C-Achard 2026-03-27 Currently does not let user choose + # a different dataset name than the source folder, + # to keep a lightweight workflow. + # This could be allowed in the future if there is demand. + dataset_name = source_root_path.name + if not dataset_name: + return None, False + + initial_dir = self._project_path or pts_meta.project or str(source_root_path) + dialog_result = ui_dialogs.prompt_for_project_config_for_save(self, initial_dir=initial_dir) + + if dialog_result.action is ui_dialogs.ProjectConfigPromptAction.CANCEL: + logger.debug("User cancelled project association prompt.") + return None, True # abort save + + if dialog_result.action is ui_dialogs.ProjectConfigPromptAction.SKIP: + logger.debug("User chose to continue without project association.") + return None, False # continue normal save path + + if dialog_result.action is not ui_dialogs.ProjectConfigPromptAction.ASSOCIATE: + logger.warning("Unexpected project association dialog result: %r", dialog_result) + return None, True # fail safe: abort save + + config_path = dialog_result.config_path + if not config_path: + logger.warning("Project association result was ASSOCIATE but config_path was empty.") + return None, True # fail safe: abort save + + project_root = resolve_project_root_from_config(config_path) + if project_root is None: + QMessageBox.warning( + self, + "Invalid project configuration", + "The selected file is not a valid DeepLabCut config.yaml or project root. " + "The save operation has been cancelled.", + ) + return None, True + + target_folder = target_dataset_folder_for_config(config_path, dataset_name=dataset_name) + if dataset_folder_has_files(target_folder): + ui_dialogs.warn_existing_dataset_folder_conflict(self, target_folder=target_folder) + return None, True # refuse the save + + rewritten_paths, unresolved = coerce_paths_to_dlc_row_keys( + paths, + source_root=source_root_path, + dataset_name=dataset_name, + ) + + if not ui_dialogs.maybe_confirm_dataset_path_rewrite( + self, + project_root=project_root, + dataset_name=dataset_name, + n_paths=len(paths), + n_unresolved=len(unresolved), + ): + return None, True # user declined + + overridden = apply_project_paths_override_to_points_meta( + pts_meta, + project_root=project_root, + rewritten_paths=rewritten_paths, + ) + + return overridden.model_dump(mode="python", exclude_none=True), False + + def _show_color_scheme(self): + show = self._view_scheme_cb.isChecked() + self._color_scheme_display.setVisible(show) + + def _current_dlc_points_layer(self) -> Points | None: + active = self.viewer.layers.selection.active + if not isinstance(active, Points): + return None + + try: + res = read_points_meta(active, migrate_legacy=True, drop_controls=True, drop_header=False) + except Exception: + return None + + if isinstance(res, ValidationError): + return None + + if getattr(res, "header", None) is None: + return None + + return active + + def _refresh_layer_status_panel(self) -> None: + active_layer = self.viewer.layers.selection.active + active_dlc_points = self._current_dlc_points_layer() + active_image = find_relevant_image_layer(self.viewer) + + folder_name = infer_folder_display_name( + active_image if active_image is not None else active_layer, + fallback_root=self._image_meta.root, + ) + self._layer_status_panel.set_folder_name(folder_name) + + # No active layer or not a Points layer at all + if active_layer is None or not isinstance(active_layer, Points): + self._layer_status_panel.set_no_active_points_layer() + return + + # Active layer is a Points layer, but not a valid DLC points layer + if active_dlc_points is None: + self._layer_status_panel.set_invalid_points_layer() + return + + self._layer_status_panel.set_point_size_enabled(True) + self._layer_status_panel.set_point_size(get_uniform_point_size(active_dlc_points)) + + progress = compute_label_progress(active_dlc_points, fallback_paths=self._image_meta.paths) + self._layer_status_panel.set_progress_summary( + labeled_percent=progress.labeled_percent, + remaining_percent=progress.remaining_percent, + labeled_points=progress.labeled_points, + total_points=progress.total_points, + frame_count=progress.frame_count, + bodypart_count=progress.bodypart_count, + individual_count=progress.individual_count, + ) + + def _on_active_points_size_changed(self, size: int) -> None: + layer = self._current_dlc_points_layer() + if layer is None: + return + + set_uniform_point_size(layer, size) + mark_layer_presentation_changed(layer) + + def _commit_active_points_size_to_config(self, size: int) -> None: + layer = self._current_dlc_points_layer() + if layer is None: + return + + config_path = self._resolve_config_path_for_layer(layer) + if config_path is None: + logger.debug( + "No config.yaml could be resolved at commit time for active layer %r", + getattr(layer, "name", layer), + ) + return + + try: + changed = save_point_size_to_config(config_path, int(size)) + if changed: + self.viewer.status = f"Updated config dotsize to {int(size)}" + except Exception: + logger.debug("Failed to sync point size to config", exc_info=True) + + def _maybe_initialize_layer_point_size_from_config(self, layer: Points) -> None: + config_path = self._resolve_config_path_for_layer(layer) + if config_path is None: + return + + config_size = load_point_size_from_config(config_path) + if config_size is None: + return + + current_size = get_uniform_point_size(layer) + + # Conservative initialization + if current_size <= 8: + try: + set_uniform_point_size(layer, config_size) + mark_layer_presentation_changed(layer) + except Exception: + logger.debug("Could not initialize layer point size from config", exc_info=True) + + def _connect_layer_status_events(self, layer: Points) -> None: + """ + Keep the UX panel live without adding heavy watchers. + """ + try: + layer.events.data.connect(lambda event=None, _layer=layer: self._refresh_layer_status_panel()) + except Exception: + pass + + try: + layer.events.size.connect(lambda event=None, _layer=layer: self._refresh_layer_status_panel()) + except Exception: + pass + + try: + layer.events.properties.connect(lambda event=None, _layer=layer: self._refresh_layer_status_panel()) + except Exception: + pass + + def _form_help_buttons(self): + layout = QVBoxLayout() + help_buttons_layout = QHBoxLayout() + self.show_shortcuts_btn = QPushButton("View shortcuts") + self.show_shortcuts_btn.clicked.connect(self.display_shortcuts) + help_buttons_layout.addWidget(self.show_shortcuts_btn) + self.tutorial_btn = QPushButton("Start tutorial") + self.tutorial_btn.clicked.connect(self.start_tutorial) + help_buttons_layout.addWidget(self.tutorial_btn) + layout.addLayout(help_buttons_layout) + self._keypoint_mapping_button = QPushButton("Load superkeypoints diagram") + self._load_superkeypoints_action = self._keypoint_mapping_button.clicked.connect( + self.load_superkeypoints_diagram + ) + self._keypoint_mapping_button.hide() + layout.addWidget(self._keypoint_mapping_button) + return layout + + def _refresh_video_panel_context(self) -> None: + update_video_panel_context(self.viewer, self._video_group) + + def _cache_project_path_from_image_layer(self, layer: Image) -> None: + """Best-effort cache of project path from an image/video layer.""" + project_path = resolve_project_path_from_image_layer(layer) + if project_path is None: + return + + self._project_path = project_path + try: + layer.metadata = dict(layer.metadata or {}) + layer.metadata.setdefault("project", self._project_path) + except Exception: + logger.debug( + "Failed to set project path metadata on image layer %r", + getattr(layer, "name", layer), + exc_info=True, + ) + + self._refresh_video_panel_context() + + def _extract_single_frame(self, *args): + ok, msg = run_extract_current_frame( + self.viewer, + self._video_group, + validate_points_layer=self._validate_header, + ) + self.viewer.status = msg + self._refresh_video_panel_context() + + def _on_apply_crop_toggled(self, checked) -> None: + handle_apply_crop_toggled(self.viewer, self._video_group, bool(checked)) + self._refresh_video_panel_context() + + def _store_crop_coordinates(self, *args): + ok, msg, project_path = run_store_crop_coordinates( + self.viewer, + self._video_group, + explicit_project_path=self._project_path, + fallback_video_name=self._image_meta.name, + ) + if project_path is not None: + self._project_path = project_path + self.viewer.status = msg + self._refresh_video_panel_context() + + def _form_dropdown_menus(self, store): + menu = KeypointsDropdownMenu(store) + self.viewer.dims.events.current_step.connect( + menu.smart_reset, + position="last", + ) + menu.smart_reset(event=None) + self._menus.append(menu) + self._layer_to_menu[store.layer] = len(self._menus) - 1 + layout = QVBoxLayout() + layout.addWidget(menu) + self._layout.addLayout(layout) + + def _form_mode_radio_buttons(self): + group_box = QGroupBox("Labeling mode") + layout = QHBoxLayout() + group = QButtonGroup(self) + for i, mode in enumerate(keypoints.LabelMode.__members__, start=1): + btn = QRadioButton(mode.capitalize()) + btn.setToolTip(keypoints.TOOLTIPS[mode]) group.addButton(btn, i) layout.addWidget(btn) group.button(1).setChecked(True) @@ -1002,328 +1189,513 @@ def _form_color_mode_selector(self): self._layout.addWidget(group_box) def _func(): - self.color_mode = group.checkedButton().text() + self.label_mode = group.checkedButton().text().lower() group.buttonClicked.connect(_func) return group_box, group - def _form_color_scheme_display(self, viewer): - self.viewer.layers.events.inserted.connect(self._update_color_scheme) - return viewer.window.add_dock_widget(self._display, name="Color scheme reference", area="left") + def _form_color_mode_selector(self): + group_box = QGroupBox("Keypoint coloring mode") + layout = QHBoxLayout() + group = QButtonGroup(self) + for i, mode in enumerate(keypoints.ColorMode.__members__, start=1): + btn = QRadioButton(mode.lower()) + group.addButton(btn, i) + layout.addWidget(btn) + group.button(1).setChecked(True) + group_box.setLayout(layout) + self._layout.addWidget(group_box) + + def _func(): + self.color_mode = group.checkedButton().text() + + group.buttonClicked.connect(_func) + return group_box, group def _update_color_scheme(self): - def to_hex(nparray): - a = np.array(nparray * 255, dtype=int) + if hasattr(self, "_color_scheme_panel"): + self._color_scheme_panel.schedule_update() + + def _apply_points_coloring_from_metadata(self, layer: Points) -> None: + """Apply categorical coloring using centralized resolver policy.""" + resolver = self._color_scheme_panel._resolver + cycles = resolver.get_face_color_cycles(layer) + if not cycles: + try: + layer.face_color_mode = "direct" + except Exception: + pass + return + + prop = resolver.get_active_color_property(layer) + if prop not in cycles or not cycles[prop]: + return - def rgb2hex(r, g, b, _): - return f"#{r:02x}{g:02x}{b:02x}" + props = getattr(layer, "properties", {}) or {} + values = props.get(prop, None) - res = rgb2hex(*a) - return res + # id mode on single-animal / blank ids -> fallback to label + if prop == "id": + try: + vals = np.asarray(values, dtype=object).ravel() if values is not None else np.array([], dtype=object) + if len(vals) == 0 or all(v in ("", None) or misc._is_nan_value(v) for v in vals): + prop = "label" + values = props.get("label", None) + except Exception: + prop = "label" + values = props.get("label", None) - self._display.reset() - mode = "label" - if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): - mode = "id" + if values is None or len(values) == 0 or misc._array_has_nan(values): + try: + layer.face_color_mode = "direct" + except Exception: + pass + return - for layer in self.viewer.layers: - if isinstance(layer, Points) and layer.metadata: - self._display.update_color_scheme( - {name: to_hex(color) for name, color in layer.metadata["face_color_cycles"][mode].items()} - ) + try: + layer.face_color = prop + layer.face_color_cycle = cycles[prop] + layer.face_color_mode = "cycle" + layer.events.face_color() + except Exception: + try: + layer.face_color_mode = "direct" + except Exception: + pass def _remap_frame_indices(self, layer): """ Best-effort remap of time/frame indices in non-Image layers to match current Image order. - Safety principles: - - Never delete user data automatically (only remap what is safe). - - Only write back to layer.data after successful transformation. - - Always sync layer.metadata with self._images_meta when possible. + Safety principles + ----------------- + - Never delete or silently corrupt user data. + - Only write back to layer.data after a remap has been accepted as safe. + - Always sync non-path image metadata when possible. + - Do NOT replace metadata["paths"] unless remap is accepted as safe. + - Specifically reject ambiguous basename-only remaps (depth=1 with duplicate / + non-bijective warnings), which commonly happen when data are moved out of the + standard DLC labeled-data layout. """ try: - # Need new image paths to define the reference order - new_paths_raw = self._images_meta.get("paths") - if not new_paths_raw: + new_paths = self._image_meta.paths + if not new_paths: return - # Determine layer's stored paths - md = layer.metadata or {} - old_paths_raw = md.get("paths") or [] - # Always sync basic metadata (even if we can't remap) - try: - layer.metadata.update(self._images_meta) - except Exception: - pass - - if not old_paths_raw: - return + if layer.metadata is None: + layer.metadata = {} - # Try different canonicalization depths to find overlap - depth_used = None - for depth in (3, 2, 1): - new_keys = [canonicalize_path(p, depth) for p in new_paths_raw] - old_keys = [canonicalize_path(p, depth) for p in old_paths_raw] - overlap = set(new_keys) & set(old_keys) - if overlap: - depth_used = depth - break + md = layer.metadata + old_paths = md.get("paths") or [] - if depth_used is None: - logging.warning( - "Cannot remap %s: no path overlap found for all attempted matchings", + # Always sync safe non-path metadata from image meta. + # Do NOT sync "paths" yet; that is only safe after we decide remap is acceptable. + try: + safe_image_meta = self._image_meta.model_dump(exclude_none=True) + safe_image_meta.pop("paths", None) + layer.metadata.update(safe_image_meta) + except Exception: + logger.debug( + "Failed to sync non-path image metadata for layer=%r", getattr(layer, "name", str(layer)), + exc_info=True, ) - # logging.debug("Old keys (sample): %s... | New keys (sample): %s...", old_keys[:5], new_keys[:5]) - # logging.debug("Old basename sample: %s", [Path(p).name for p in old_paths_raw[:5]]) - # logging.debug("New basename sample: %s", [Path(p).name for p in new_paths_raw[:5]]) - return - if old_keys == new_keys: + if not old_paths: + logger.debug( + "Skipping remap for layer=%r: no existing layer metadata paths.", + getattr(layer, "name", str(layer)), + ) return - # Build map: canonical key -> new frame index - key_to_new_idx = {k: i for i, k in enumerate(new_keys)} - - # Determine which column is time/frame + # Determine time column (napari-specific choice) time_col = 1 if isinstance(layer, Tracks) else 0 - data = layer.data - if data is None: - return - - # Napari layers differ: Points/Tracks are ndarray-like; Shapes is list-like - is_list_like = isinstance(data, list) - - # Build an "old index -> new index" dict when we can map safely - # We only map old indices that correspond to an old canonical key present in new_keys. - idx_map = {} - for old_idx, k in enumerate(old_keys): - new_idx = key_to_new_idx.get(k) - if new_idx is not None: - idx_map[old_idx] = new_idx - - if not idx_map: - # No overlap at all; safest is to do nothing. - logging.warning( - f"Cannot remap {getattr(layer, 'name', str(layer))}:" - " no overlap between layer paths and current image paths.", + if logger.isEnabledFor(logging.DEBUG): + arr_before = np.asarray(layer.data) + logger.debug( + "Remap start layer=%r old_paths_len=%s new_paths_len=%s data_shape=%s frame_min=%s frame_max=%s", + getattr(layer, "name", str(layer)), + len(old_paths), + len(new_paths or []), + getattr(arr_before, "shape", None), + int(np.nanmin(arr_before[:, time_col])) if arr_before.size else None, + int(np.nanmax(arr_before[:, time_col])) if arr_before.size else None, ) - # logging.debug(f"Old keys (sample): {old_keys[:5]}... | New keys (sample): {new_keys[:5]}...") - return - - if is_list_like: - # Shapes-like: list of vertices arrays - new_data = [] - for verts in data: - arr = np.asarray(verts) - if arr.size == 0: - new_data.append(arr) - continue - arr2 = arr.copy() - t = arr2[:, time_col] + res = remap_layer_data_by_paths( + data=layer.data, + old_paths=old_paths, + new_paths=new_paths, + time_col=time_col, + policy=PathMatchPolicy.ORDERED_DEPTHS, + ) - # If t isn't integer-like, attempt safe conversion - try: - t_int = np.asarray(t).astype(int, copy=False) - except Exception: - # Can't interpret time column; keep shape as-is - new_data.append(arr2) - continue + logger.debug( + "Remap result layer=%r changed=%s mapped_count=%s depth=%s message=%s warnings=%s", + getattr(layer, "name", str(layer)), + res.changed, + res.mapped_count, + res.depth_used, + res.message, + res.warnings, + ) - arr2[:, time_col] = remap_array(t_int, idx_map) - new_data.append(arr2) + if res.applied and res.data is not None: + layer.data = res.data - layer.data = new_data + if res.accept_paths_update: + layer.metadata["paths"] = list(new_paths) + if isinstance(layer, Points): + mark_layer_presentation_changed(layer) + # Final debug logging + if res.depth_used is None: + logger.debug("Remap skipped for %s: %s", getattr(layer, "name", str(layer)), res.message) else: - arr = np.asarray(data) - if arr.size == 0: - return - if arr.ndim < 2 or arr.shape[1] <= time_col: - return - - arr2 = arr.copy() - t = arr2[:, time_col] - - # Handle NaNs/float time indices safely - # (If conversion fails, we skip remap.) - try: - t_int = np.asarray(t).astype(int, copy=False) - except Exception: - logging.warning( - f"Cannot remap {getattr(layer, 'name', str(layer))}: " - "time column could not be converted to int.", - ) - return - - arr2[:, time_col] = remap_array(t_int, idx_map) - layer.data = arr2 + logger.debug( + "Remap %s for %s (depth=%s, mapped=%s): %s", + "applied" if res.changed else "accepted-noop", + getattr(layer, "name", str(layer)), + res.depth_used, + res.mapped_count, + res.message, + ) except Exception: - logging.exception( - f"Failed to remap frame indices for layer {getattr(layer, 'name', str(layer))}", - ) - # Intentionally do not raise, simply warn and skip remapping + logger.exception("Failed to remap frame indices for layer %s", getattr(layer, "name", str(layer))) return def on_insert(self, event): layer = event.source[-1] - logging.debug(f"Inserting Layer {layer}") if isinstance(layer, Image): - paths = layer.metadata.get("paths") - if paths is None and is_video(layer.name): - self.video_widget.setVisible(True) - # Store the metadata and pass them on to the other layers - self._images_meta.update( - { - "paths": paths, - "shape": layer.level_shapes[0], - "root": layer.metadata["root"], - "name": layer.name, - } - ) - # Delay layer sorting - QTimer.singleShot(10, partial(self._move_image_layer_to_bottom, event.index)) + self._setup_image_layer(layer, event.index, reorder=True) elif isinstance(layer, Points): - # If the current Points layer comes from a config file, some have already - # been added and the body part names are different from the existing ones, - # then we update store's metadata and menus. - if layer.metadata.get("project", "") and self._stores: - new_metadata = layer.metadata.copy() - - keypoints_menu = self._menus[0].menus["label"] - current_keypoint_set = {keypoints_menu.itemText(i) for i in range(keypoints_menu.count())} - new_keypoint_set = set(new_metadata["header"].bodyparts) - diff = new_keypoint_set.difference(current_keypoint_set) - if diff: - answer = QMessageBox.question(self, "", "Do you want to display the new keypoints only?") - if answer == QMessageBox.Yes: - self.viewer.layers[-2].shown = False - - self.viewer.status = f"New keypoint{'s' if len(diff) > 1 else ''} {', '.join(diff)} found." - for _layer, store in self._stores.items(): - _layer.metadata["header"] = new_metadata["header"] - store.layer = _layer - - for menu in self._menus: - menu._map_individuals_to_bodyparts() - menu._update_items() - - # Remove the unnecessary layer newly added - QTimer.singleShot(10, self.viewer.layers.pop) - - # Always update the colormap to reflect the one in the config.yaml file - for _layer, store in self._stores.items(): - _layer.metadata["face_color_cycles"] = new_metadata["face_color_cycles"] - _layer.face_color = "label" - _layer.face_color_cycle = new_metadata["face_color_cycles"]["label"] - _layer.events.face_color() - store.layer = _layer - self._update_color_scheme() - - return - - if layer.metadata.get("tables", ""): - self._keypoint_mapping_button.show() - - store = keypoints.KeypointStore(self.viewer, layer) - self._stores[layer] = store - # TODO Set default dir of the save file dialog - if root := layer.metadata.get("root"): - update_save_history(root) - layer.metadata["controls"] = self - layer.text.visible = False - layer.bind_key("M", self.cycle_through_label_modes) - layer.bind_key("F", self.cycle_through_color_modes) - func = partial(_paste_data, store=store) - layer._paste_data = MethodType(func, layer) - layer.add = MethodType(keypoints._add, store) - layer.events.add(query_next_frame=Event) - layer.events.query_next_frame.connect(store._advance_step) - layer.bind_key("Shift-Right", store._find_first_unlabeled_frame) - layer.bind_key("Shift-Left", store._find_first_unlabeled_frame) - - layer.bind_key("Down", store.next_keypoint, overwrite=True) - layer.bind_key("Up", store.prev_keypoint, overwrite=True) - layer.face_color_mode = "cycle" - self._form_dropdown_menus(store) - - self._images_meta.update( - { - "project": layer.metadata.get("project"), - } - ) - self._radio_box.setEnabled(True) - self._color_grp.setEnabled(True) - self._trail_cb.setEnabled(True) - self._show_traj_plot_cb.setEnabled(True) - - # Hide the color pickers, as colormaps are strictly defined by users - controls = self.viewer.window.qt_viewer.dockLayerControls - point_controls = controls.widget().widgets[layer] - # Attempt to hide several napari UI elements. - # To avoid potential breakage, we pass if they don't exist. - try: - face_color_controls = point_controls._face_color_control.face_color_edit - face_color_label = point_controls._face_color_control.face_color_label - face_color_controls.hide() - face_color_label.hide() - except AttributeError: - pass - try: - # Border color edit in latest napari versions (0.6.6) - edge_color_controls = point_controls._border_color_control.border_color_edit - border_color_label = point_controls._border_color_control.border_color_edit_label - edge_color_controls.hide() - border_color_label.hide() - except AttributeError: - pass - # Hide out of slice checkbox - try: - out_of_slice_controls = point_controls._out_slice_checkbox_control.out_of_slice_checkbox - out_of_slice_label = point_controls._out_slice_checkbox_control.out_of_slice_checkbox_label - out_of_slice_controls.hide() - out_of_slice_label.hide() - except AttributeError: - pass - - # Add dropdown menu for colormap picking - colormap_selector = DropdownMenu(plt.colormaps, self) - colormap_selector.update_to(layer.metadata["colormap_name"]) - colormap_selector.currentTextChanged.connect(self._update_colormap) - point_controls.layout().addRow("colormap", colormap_selector) + self._setup_points_layer(layer, allow_merge=True) for layer_ in self.viewer.layers: if not isinstance(layer_, Image): self._remap_frame_indices(layer_) + self._refresh_video_panel_context() + self._refresh_layer_status_panel() def on_remove(self, event): layer = event.value n_points_layer = sum(isinstance(l, Points) for l in self.viewer.layers) - if isinstance(layer, Points) and n_points_layer == 0: - if self._color_scheme_display is not None: - self._display.reset() + + if isinstance(layer, Points): self._stores.pop(layer, None) - while self._menus: - menu = self._menus.pop() - self._layout.removeWidget(menu) - menu.deleteLater() - menu.destroy() - self._layer_to_menu = {} - self._trail_cb.setEnabled(False) - self._show_traj_plot_cb.setEnabled(False) - self.last_saved_label.hide() + + # Refresh color scheme panel regardless; it will clear itself if no valid target remains. + self._update_color_scheme() + self._trails_controller.on_points_layer_removed(layer) + + if n_points_layer == 0: + while self._menus: + menu = self._menus.pop() + self._layout.removeWidget(menu) + menu.deleteLater() + menu.destroy() + + self._layer_to_menu = {} + self._trail_cb.setEnabled(False) + self._show_traj_plot_cb.setEnabled(False) + self.last_saved_label.hide() + elif isinstance(layer, Image): - self._images_meta = dict() + self._image_meta = ImageMetadata() paths = layer.metadata.get("paths") if paths is None: self.video_widget.setVisible(False) + elif isinstance(layer, Tracks): - self._trail_cb.setChecked(False) - self._show_traj_plot_cb.setChecked(False) - self._trails = None + was_trails = self._trails_controller.on_tracks_layer_removed(layer) + if was_trails: + with QSignalBlocker(self._trail_cb): + self._trail_cb.setChecked(False) + + self._refresh_video_panel_context() + self._refresh_layer_status_panel() + + def _on_show_trails_toggled(self, state): + self._trails_controller.toggle(Qt.CheckState(state) == Qt.CheckState.Checked) + + def _ensure_promotion_save_target(self, layer: Points) -> bool: + """Ensure a prediction/machine source layer has a GT save_target set. + + Returns True if save_target is set (or already existed), False if user cancels. + """ + if not is_machine_layer(layer): + return True + + mig = migrate_points_layer_metadata(layer) + if hasattr(mig, "errors"): + logger.warning( + "Failed to migrate points layer metadata for layer=%r: %s", + getattr(layer, "name", layer), + mig, + ) + + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if isinstance(res, ValidationError): + logger.warning( + "Points metadata validation failed for layer=%r during save target check: %s", + getattr(layer, "name", layer), + res, + ) + QMessageBox.warning(self, "Cannot save", "Layer metadata is invalid; see logs for details.") + return False + + pts: PointsMetadata = res + + if not requires_gt_promotion(pts): + return True + + anchor = safe_folder_anchor_from_points_layer(layer) + if not anchor: + QMessageBox.warning(self, "Cannot save", "Could not determine a folder anchor for saving.") + return False + + scorer = None + + # 1) Auto-discovered config.yaml always wins + cfg_path = None + try: + cfg_path = find_nearest_config(anchor) + except Exception: + logger.debug("Automatic config discovery failed for anchor=%r", anchor, exc_info=True) + + if cfg_path: + try: + scorer = ui_dialogs.load_scorer_from_config(cfg_path) + except Exception: + logger.exception("Failed to load auto-discovered config.yaml: %s", cfg_path) + ui_dialogs.warn_invalid_config_for_scorer( + self, + config_path=cfg_path, + reason="unreadable", + auto_found=True, + ) + return False + + if not scorer: + ui_dialogs.warn_invalid_config_for_scorer( + self, + config_path=cfg_path, + reason="missing_scorer", + auto_found=True, + ) + return False + + else: + # 2) No config found automatically -> let the user choose one + dialog_result = ui_dialogs.prompt_for_project_config_for_save( + self, + initial_dir=self._project_path or anchor, + window_title="Locate DLC config for scorer resolution", + message=( + "No DeepLabCut config.yaml could be found automatically for this machine-labeled layer.\n\n" + "If this layer belongs to a DLC project, choose its config.yaml so the save uses the " + "project scorer and standard naming.\n\n" + "If no config.yaml exists, you can continue without one." + ), + choose_button_text="Choose config.yaml", + skip_button_text="Continue without config", + resolve_scorer=True, + ) + + if dialog_result.action is ui_dialogs.ProjectConfigPromptAction.CANCEL: + return False + + if dialog_result.action is ui_dialogs.ProjectConfigPromptAction.ASSOCIATE: + scorer = dialog_result.scorer + + else: + # 3) Only if no config is available at all may sidecar be consulted + scorer = get_default_scorer(anchor) + + # 4) Final fallback: prompt manually + if not scorer: + suggested = suggest_human_placeholder(anchor) + while True: + s = _prompt_for_scorer(self, anchor=anchor, suggested=suggested) + if s is None: + return False + if s.startswith("human_"): + choice = QMessageBox.question( + self, + "Generic scorer name", + "You entered a generic scorer name starting with 'human_'.\n\n" + "We strongly recommend using a real name or stable identifier.\n" + "Do you want to keep this generic scorer anyway?", + QMessageBox.Yes | QMessageBox.No, + ) + if choice == QMessageBox.No: + suggested = s + continue + scorer = s + break + try: + set_default_scorer(anchor, scorer) + except Exception: + logger.debug("Failed to persist default scorer to sidecar", exc_info=True) + + updated = apply_gt_save_target( + pts, + anchor=anchor, + scorer=scorer, + dataset_key="keypoints", + ) + + out = write_points_meta( + layer, + updated, + merge_policy=MergePolicy.MERGE, + fields={"save_target"}, + migrate_legacy=True, + validate=True, + ) + + if hasattr(out, "errors"): + logger.warning("Failed to write save_target for layer=%r: %s", getattr(layer, "name", layer), out) + QMessageBox.warning(self, "Cannot save", "Failed to write save target metadata; see logs for details.") + return False + + return True + + def _toggle_overwrite_confirmation(self, state) -> None: + enabled = Qt.CheckState(state) == Qt.CheckState.Checked + settings.set_overwrite_confirmation_enabled(enabled) + self.viewer.status = "Overwrite confirmation enabled" if enabled else "Overwrite confirmation disabled" + + # Hack to save a KeyPoints layer without showing the Save dialog + def _save_layers_dialog(self, selected=False): + """Save layers (all or selected) to disk, using ``LayerList.save()``. + Parameters + ---------- + selected : bool + If True, only layers that are selected in the viewer will be saved. + By default, all layers are saved. + """ + + selected_layers = list(self.viewer.layers.selection) + msg = "" + if not len(self.viewer.layers): + msg = "There are no layers in the viewer to save." + elif selected and not len(selected_layers): + msg = "Please select a Points layer to save." + if msg: + QMessageBox.warning(self, "Nothing to save", msg, QMessageBox.Ok) + return + if len(selected_layers) == 1 and isinstance(selected_layers[0], Points): + layer = selected_layers[0] + + # Promotion-to-GT policy: never write back to machine/prediction sources. + ok = self._ensure_promotion_save_target(layer) + if not ok: + return + + logger.debug( + "About to save. io.kind=%r save_target=%r", + layer.metadata.get("io", {}).get("kind"), + layer.metadata.get("save_target"), + ) + try: + overridden_metadata, abort_save = self._maybe_prepare_project_path_override_metadata(layer) + if abort_save: + logger.debug("Save aborted during project-association path handling.") + return + + attributes = { + "name": layer.name, + "metadata": overridden_metadata if overridden_metadata is not None else dict(layer.metadata or {}), + "properties": dict(layer.properties or {}), + } + report = compute_overwrite_report_for_points_save(layer.data, attributes) + except Exception as e: + logger.exception("Failed to compute overwrite preflight for layer %r", getattr(layer, "name", layer)) + QMessageBox.warning( + self, + "Cannot save", + f"Failed to prepare save preflight:\n{e}", + QMessageBox.Ok, + ) + return + + if report is not None: + if not ui_dialogs.maybe_confirm_overwrite( + parent=self, + report=report, + ): + logger.debug("Save cancelled.") + return + + if overridden_metadata is not None: + with _temporary_layer_metadata(layer, overridden_metadata): + self.viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + # Persist the successful override into live metadata after save + layer.metadata = dict(overridden_metadata) + else: + self.viewer.layers.save("__dlc__.h5", selected=True, plugin="napari-deeplabcut") + # hook to persist UI state on successful save + try: + self._trails_controller.persist_folder_ui_state_for_points_layer( + layer, + checkbox_checked=self._trail_cb.isChecked(), + ) + except Exception: + logger.debug( + "Failed to persist folder UI state after save for layer=%r", + getattr(layer, "name", layer), + exc_info=True, + ) + self.viewer.status = "Data successfully saved" + else: + dlg = QFileDialog() + hist = get_save_history() + dlg.setHistory(hist) + filename, _ = dlg.getSaveFileName( + caption=f"Save {'selected' if selected else 'all'} layers", + dir=hist[0], # home dir by default + ) + if filename: + self.viewer.layers.save(filename, selected=selected) + # hook to persist UI state on successful save + try: + if selected: + candidate_layers = [ly for ly in selected_layers if isinstance(ly, Points)] + else: + candidate_layers = list(self._stores.keys()) + + for ly in candidate_layers: + if ly in self.viewer.layers: + self._trails_controller.persist_folder_ui_state_for_points_layer( + ly, + checkbox_checked=self._trail_cb.isChecked(), + ) + except Exception: + logger.debug("Failed to persist sidecar UI state after multi-layer save", exc_info=True) + + else: + return + self._is_saved = True + self.last_saved_label.setText(f"Last saved at {str(datetime.now().time()).split('.')[0]}") + self.last_saved_label.show() + + def on_close(self, event): + if self._stores and not self._is_saved: + choice = QMessageBox.warning( + self, + "Warning", + "Data were not saved. Are you certain you want to leave?", + QMessageBox.Yes | QMessageBox.No, + ) + if choice == QMessageBox.Yes: + event.accept() + else: + event.ignore() + else: + event.accept() def on_active_layer_change(self, event) -> None: """Updates the GUI when the active layer changes @@ -1332,6 +1704,7 @@ def on_active_layer_change(self, event) -> None: is a multi-animal one, or False otherwise """ self._color_grp.setVisible(self._is_multianimal(event.value)) + # self._update_color_scheme() # if needed menu_idx = -1 if event.value is not None and isinstance(event.value, Points): menu_idx = self._layer_to_menu.get(event.value, -1) @@ -1342,22 +1715,20 @@ def on_active_layer_change(self, event) -> None: else: menu.setHidden(True) - def _update_colormap(self, colormap_name): + self._refresh_video_panel_context() + self._refresh_layer_status_panel() + + def _update_colormap(self, colormap_name: str): for layer in self.viewer.layers.selection: - if isinstance(layer, Points) and layer.metadata: - face_color_cycle_maps = build_color_cycles( - layer.metadata["header"], - colormap_name, - ) - layer.metadata["face_color_cycles"] = face_color_cycle_maps - face_color_prop = "label" - if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): - face_color_prop = "id" + if not isinstance(layer, Points) or not layer.metadata: + continue - layer.face_color = face_color_prop - layer.face_color_cycle = face_color_cycle_maps[face_color_prop] - layer.events.face_color() - self._update_color_scheme() + layer.metadata["config_colormap"] = colormap_name + mark_layer_presentation_changed(layer) + self._apply_points_coloring_from_metadata(layer) + + self._update_color_scheme() + self._trails_controller.on_points_visual_inputs_changed(checkbox_checked=self._trail_cb.isChecked()) @register_points_action("Change labeling mode") def cycle_through_label_modes(self, *args): @@ -1395,16 +1766,10 @@ def color_mode(self): @color_mode.setter def color_mode(self, mode: str | keypoints.ColorMode): self._color_mode = keypoints.ColorMode(mode) - if self._color_mode == keypoints.ColorMode.BODYPART: - face_color_mode = "label" - else: - face_color_mode = "id" - for layer in self.viewer.layers: + for layer in list(self._stores.keys()): if isinstance(layer, Points) and layer.metadata: - layer.face_color = face_color_mode - layer.face_color_cycle = layer.metadata["face_color_cycles"][face_color_mode] - layer.events.face_color() + self._apply_points_coloring_from_metadata(layer) for btn in self._color_mode_selector.buttons(): if btn.text().lower() == str(mode).lower(): @@ -1412,19 +1777,22 @@ def color_mode(self, mode: str | keypoints.ColorMode): break self._update_color_scheme() + self._trails_controller.on_points_visual_inputs_changed(checkbox_checked=self._trail_cb.isChecked()) def _is_multianimal(self, layer) -> bool: - is_multi = False - if layer is not None and isinstance(layer, Points): - try: - header = layer.metadata.get("header") - if header is not None: - ids = header.individuals - is_multi = len(ids) > 0 and ids[0] != "" - except AttributeError: - pass + if layer is None or not isinstance(layer, Points): + return False + + md = layer.metadata or {} + hdr = self._get_header_model_from_metadata(md) + if hdr is None: + return False - return is_multi + try: + inds = hdr.individuals + return bool(inds and len(inds) > 0 and str(inds[0]) != "") + except Exception: + return False def _active_layer_is_multianimal(self) -> bool: """Returns: whether the active layer is a multi-animal points layer""" @@ -1434,383 +1802,25 @@ def _active_layer_is_multianimal(self) -> bool: return False - -@Points.bind_key("E") -def toggle_edge_color(layer): - # Trick to toggle between 0 and 2 - layer.border_width = np.bitwise_xor(layer.border_width, 2) - - -class DropdownMenu(QComboBox): - def __init__(self, labels: Sequence[str], parent: QWidget | None = None): - super().__init__(parent) - self.update_items(labels) - - def update_to(self, text: str): - index = self.findText(text) - if index >= 0: - self.setCurrentIndex(index) - - def reset(self): - self.setCurrentIndex(0) - - def update_items(self, items): - self.clear() - self.addItems(items) - - -class KeypointsDropdownMenu(QWidget): - def __init__( - self, - store: keypoints.KeypointStore, - parent: QWidget | None = None, - ): - super().__init__(parent) - self.store = store - self.store.layer.events.current_properties.connect(self.update_menus) - self._locked = False - - self.id2label = defaultdict(list) - self.menus = dict() - self._map_individuals_to_bodyparts() - self._populate_menus() - - layout1 = QVBoxLayout() - layout1.addStretch(1) - group_box = QGroupBox("Keypoint selection") - layout2 = QVBoxLayout() - for menu in self.menus.values(): - layout2.addWidget(menu) - group_box.setLayout(layout2) - layout1.addWidget(group_box) - self.setLayout(layout1) - - def _map_individuals_to_bodyparts(self): - self.id2label.clear() # Empty dict so entries are ordered as in the config - for keypoint in self.store._keypoints: - label = keypoint.label - id_ = keypoint.id - if label not in self.id2label[id_]: - self.id2label[id_].append(label) - - def _populate_menus(self): - id_ = self.store.ids[0] - if id_: - menu = create_dropdown_menu(self.store, list(self.id2label), "id") - menu.currentTextChanged.connect(self.refresh_label_menu) - self.menus["id"] = menu - self.menus["label"] = create_dropdown_menu( - self.store, - self.id2label[id_], - "label", - ) - - def _update_items(self): - id_ = self.store.ids[0] - if id_: - self.menus["id"].update_items(list(self.id2label)) - self.menus["label"].update_items(self.id2label[id_]) - - def update_menus(self, event): - keypoint = self.store.current_keypoint - for attr, menu in self.menus.items(): - val = getattr(keypoint, attr) - if menu.currentText() != val: - menu.update_to(val) - - def refresh_label_menu(self, text: str): - menu = self.menus["label"] - menu.blockSignals(True) - menu.clear() - menu.blockSignals(False) - menu.addItems(self.id2label[text]) - - def smart_reset(self, event): - """Set current keypoint to the first unlabeled one.""" - if self._locked: # The currently selected point is not updated - return - unannotated = "" - already_annotated = self.store.annotated_keypoints - for keypoint in self.store._keypoints: - if keypoint not in already_annotated: - unannotated = keypoint - break - self.store.current_keypoint = unannotated if unannotated else self.store._keypoints[0] - - -def create_dropdown_menu(store, items, attr): - menu = DropdownMenu(items) - - def item_changed(ind): - current_item = menu.itemText(ind) - if current_item is not None: - setattr(store, f"current_{attr}", current_item) - - menu.currentIndexChanged.connect(item_changed) - return menu - - -# WelcomeWidget modified from: -# https://github.com/napari/napari/blob/a72d512972a274380645dae16b9aa93de38c3ba2/napari/_qt/widgets/qt_welcome.py#L28 -class QtWelcomeWidget(QWidget): - """Welcome widget to display initial information and shortcuts to user.""" - - sig_dropped = Signal("QEvent") - - def __init__(self, parent): - super().__init__(parent) - - # Create colored icon using theme - self._image = QLabel() - self._image.setObjectName("logo_silhouette") - self._image.setMinimumSize(300, 300) - self._label = QtWelcomeLabel( - """ - Drop a folder from within a DeepLabCut's labeled-data directory, - and, if labeling from scratch, - the corresponding project's config.yaml file. - """ - ) - - # Widget setup - self.setAutoFillBackground(True) - self.setAcceptDrops(True) - self._image.setAlignment(Qt.AlignCenter) - self._label.setAlignment(Qt.AlignCenter) - - # Layout - text_layout = QVBoxLayout() - text_layout.addWidget(self._label) - - layout = QVBoxLayout() - layout.addStretch() - layout.setSpacing(30) - layout.addWidget(self._image) - layout.addLayout(text_layout) - layout.addStretch() - - self.setLayout(layout) - - def paintEvent(self, event): - """Override Qt method. - - Parameters - ---------- - event : qtpy.QtCore.QEvent - Event from the Qt context. - """ - option = QStyleOption() - option.initFrom(self) - p = QPainter(self) - self.style().drawPrimitive(QStyle.PE_Widget, option, p, self) - - def _update_property(self, prop, value): - """Update properties of widget to update style. - - Parameters - ---------- - prop : str - Property name to update. - value : bool - Property value to update. - """ - self.setProperty(prop, value) - self.style().unpolish(self) - self.style().polish(self) - - def dragEnterEvent(self, event): - """Override Qt method. - - Provide style updates on event. - - Parameters - ---------- - event : qtpy.QtCore.QEvent - Event from the Qt context. - """ - self._update_property("drag", True) - if event.mimeData().hasUrls(): - event.accept() - else: - event.ignore() - - def dragLeaveEvent(self, event): - """Override Qt method. - - Provide style updates on event. - - Parameters - ---------- - event : qtpy.QtCore.QEvent - Event from the Qt context. + def _resolved_cycle_for_layer(self, layer: Points) -> dict: """ - self._update_property("drag", False) - - def dropEvent(self, event): - """Override Qt method. - - Provide style updates on event and emit the drop event. - - Parameters - ---------- - event : qtpy.QtCore.QEvent - Event from the Qt context. + Return the resolved category->color mapping used by the points layer, + so trails match the exact displayed colors. """ - self._update_property("drag", False) - self.sig_dropped.emit(event) - - -class ClickableLabel(QLabel): - clicked = Signal(str) - - def __init__(self, text="", color="turquoise", parent=None): - super().__init__(text, parent) - self._default_style = self.styleSheet() - self.color = color - - def mousePressEvent(self, event): - self.clicked.emit(self.text()) - - def enterEvent(self, event): - self.setCursor(QCursor(Qt.PointingHandCursor)) - self.setStyleSheet(f"color: {self.color}") - - def leaveEvent(self, event): - self.unsetCursor() - self.setStyleSheet(self._default_style) - + resolver = self._color_scheme_panel._resolver + cycles = resolver.get_face_color_cycles(layer) or {} -class LabelPair(QWidget): - def __init__(self, color: str, name: str, parent: QWidget): - super().__init__(parent) + prop = resolver.get_active_color_property(layer) + props = getattr(layer, "properties", {}) or {} + values = props.get(prop) - self._color = color - self._part_name = name - - self.color_label = QLabel("", parent=self) - self.part_label = ClickableLabel(name, color=color, parent=self) - - self.color_label.setToolTip(name) - self.part_label.setToolTip(name) - - self._format_label(self.color_label, 10, 10) - self._format_label(self.part_label) - - self.color_label.setStyleSheet(f"background-color: {color};") - - self._build() - - @staticmethod - def _format_label(label: QLabel, height: int = None, width: int = None): - label.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - if height is not None: - label.setMaximumHeight(height) - if width is not None: - label.setMaximumWidth(width) - - def _build(self): - layout = QHBoxLayout() - layout.addWidget(self.color_label, alignment=Qt.AlignmentFlag.AlignLeft) - layout.addWidget(self.part_label, alignment=Qt.AlignmentFlag.AlignLeft) - self.setLayout(layout) - - @property - def color(self): - return self._color - - @color.setter - def color(self, color: str): - self._color = color - self.color_label.setStyleSheet(f"background-color: {color};") - - @property - def part_name(self): - return self._part_name - - @part_name.setter - def part_name(self, part_name: str): - self._part_name = part_name - self.part_label.setText(part_name) - self.part_label.setToolTip(part_name) - self.color_label.setToolTip(part_name) - - -class ColorSchemeDisplay(QScrollArea): - added = Signal(object) - - def __init__(self, parent): - super().__init__(parent) - - self.scheme_dict = {} # {name: color} mapping - self._layout = QVBoxLayout() - self._layout.setSpacing(0) - self._container = QWidget(parent=self) # workaround to use setWidget, let me know if there's a better option - - self._build() + if prop == "id": + try: + vals = np.asarray(values, dtype=object).ravel() if values is not None else np.array([], dtype=object) + if len(vals) == 0 or all(v in ("", None) or misc._is_nan_value(v) for v in vals): + prop = "label" + except Exception: + prop = "label" - @property - def labels(self): - labels = [] - for i in range(self._layout.count()): - item = self._layout.itemAt(i) - if w := item.widget(): - labels.append(w) - return labels - - def _build(self): - self._container.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Maximum) # feel free to change those - self._container.setLayout(self._layout) - self._container.adjustSize() - - self.setWidget(self._container) - - self.setWidgetResizable(True) - self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding) # feel free to change those - # self.setMaximumHeight(150) - self.setBaseSize(100, 200) - - self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) - - def add_entry(self, name, color): - self.scheme_dict.update({name: color}) - - widget = LabelPair(color, name, self) - self._layout.addWidget(widget, alignment=Qt.AlignmentFlag.AlignLeft) - self.added.emit(widget) - - def update_color_scheme(self, new_color_scheme) -> None: - logging.debug(f"Updating color scheme: {self._layout.count()} widgets") - self.scheme_dict = {name: color for name, color in new_color_scheme.items()} - names = list(new_color_scheme.keys()) - existing_widgets = self._layout.count() - required_widgets = len(self.scheme_dict) - - # update existing widgets - for idx in range(min(existing_widgets, required_widgets)): - logging.debug(f" updating {idx}") - w = self._layout.itemAt(idx).widget() - w.setVisible(True) - w.part_name = names[idx] - w.color = self.scheme_dict[names[idx]] - - # remove extra widgets - for i in range(max(existing_widgets - required_widgets, 0)): - logging.debug(f" hiding {required_widgets + i}") - if w := self._layout.itemAt(required_widgets + i).widget(): - logging.debug(" done!") - w.setVisible(False) - - # add missing widgets - for i in range(max(required_widgets - existing_widgets, 0)): - logging.debug(f" adding {existing_widgets + i}") - name = names[existing_widgets + i] - self.add_entry(name, self.scheme_dict[name]) - logging.debug(" done!") - - def reset(self): - self.scheme_dict = {} - for i in range(self._layout.count()): - w = self._layout.itemAt(i).widget() - logging.debug(f"making {w} invisible") - w.setVisible(False) + cycle = cycles.get(prop, {}) or {} + return {str(k): np.asarray(v, dtype=float) for k, v in cycle.items()} diff --git a/src/napari_deeplabcut/_writer.py b/src/napari_deeplabcut/_writer.py index a9bf6b55..daddedeb 100644 --- a/src/napari_deeplabcut/_writer.py +++ b/src/napari_deeplabcut/_writer.py @@ -1,86 +1,32 @@ -import os -from itertools import groupby +"""Writers for DeepLabCut data formats.""" + +# src/napari_deeplabcut/_writer.py +import logging from pathlib import Path -import pandas as pd -import yaml -from napari.layers import Shapes -from napari_builtins.io import napari_write_shapes from skimage.io import imsave from skimage.util import img_as_ubyte -from napari_deeplabcut import misc -from napari_deeplabcut._reader import _load_config - - -def _write_config(config_path: str, params: dict): - with open(config_path, "w") as file: - yaml.safe_dump(params, file) +from napari_deeplabcut.core.io import write_hdf +logger = logging.getLogger(__name__) -def _form_df(points_data, metadata): - temp = pd.DataFrame(points_data[:, -1:0:-1], columns=["x", "y"]) - properties = metadata["properties"] - meta = metadata["metadata"] - temp["bodyparts"] = properties["label"] - temp["individuals"] = properties["id"] - temp["inds"] = points_data[:, 0].astype(int) - temp["likelihood"] = properties["likelihood"] - temp["scorer"] = meta["header"].scorer - df = temp.set_index(["scorer", "individuals", "bodyparts", "inds"]).stack() - df.index.set_names("coords", level=-1, inplace=True) - df = df.unstack(["scorer", "individuals", "bodyparts", "coords"]) - df.index.name = None - if not properties["id"][0]: - df = df.droplevel("individuals", axis=1) - df = df.reindex(meta["header"].columns, axis=1) - # Fill unannotated rows with NaNs - # df = df.reindex(range(len(meta['paths']))) - # df.index = meta['paths'] - if meta["paths"]: - df.index = [meta["paths"][i] for i in df.index] - misc.guarantee_multiindex_rows(df) - return df - -def write_hdf(filename, data, metadata): - file, _ = os.path.splitext(filename) # FIXME Unused currently - df = _form_df(data, metadata) - meta = metadata["metadata"] - name = metadata["name"] - root = meta["root"] - if "machine" in name: # We are attempting to save refined model predictions - df.drop("likelihood", axis=1, level="coords", inplace=True, errors="ignore") - header = misc.DLCHeader(df.columns) - gt_file = "" - for file in os.listdir(root): - if file.startswith("CollectedData") and file.endswith("h5"): - gt_file = file - break - if gt_file: # Refined predictions must be merged into the existing data - df_gt = pd.read_hdf(os.path.join(root, gt_file)) - new_scorer = df_gt.columns.get_level_values("scorer")[0] - header.scorer = new_scorer - df.columns = header.columns - df = pd.concat((df, df_gt)) - df = df[~df.index.duplicated(keep="first")] - name = os.path.splitext(gt_file)[0] - else: - # Let us fetch the config.yaml file to get the scorer name... - project_folder = Path(root).parents[1] - config = _load_config(str(project_folder / "config.yaml")) - new_scorer = config["scorer"] - header.scorer = new_scorer - df.columns = header.columns - name = f"CollectedData_{new_scorer}" - df.sort_index(inplace=True) - filename = name + ".h5" - path = os.path.join(root, filename) - df.to_hdf(path, key="keypoints", mode="w") - df.to_csv(path.replace(".h5", ".csv")) - return filename +def write_hdf_napari_dlc(path: str, data, attributes: dict) -> list[str]: + if not path: + path = "__dlc__.h5" # dummy path to trigger napari-deeplabcut-specific handling in write_hdf + if path != "__dlc__.h5": + logger.info( + "This function should not be used with a user-specified path." + "Layer metadata from the reader (in attributes) is used to decide where to save rather than user input." + "One path that requires user input is when machine labels" + "are refined by a human (as we do not want to overwrite machine labels)," + "but that case is handled separately." + ) + return write_hdf(path, data, attributes) +# TODO rewrite explicitly as napari-facing func def _write_image(data, output_path, plugin=None): Path(output_path).parent.mkdir(parents=True, exist_ok=True) imsave( @@ -89,22 +35,3 @@ def _write_image(data, output_path, plugin=None): plugin=plugin, check_contrast=False, ) - - -def write_masks(foldername, data, metadata): - folder, _ = os.path.splitext(foldername) - os.makedirs(folder, exist_ok=True) - filename = os.path.join(folder, "{}_obj_{}.png") - shapes = Shapes(data, shape_type="polygon") - meta = metadata["metadata"] - frame_inds = [int(array[0, 0]) for array in data] - shape_inds = [] - for _, group in groupby(frame_inds): - shape_inds += range(sum(1 for _ in group)) - masks = shapes.to_masks(mask_shape=meta["shape"][1:]) - for n, mask in enumerate(masks): - image_name = os.path.basename(meta["paths"][frame_inds[n]]) - output_path = filename.format(os.path.splitext(image_name)[0], shape_inds[n]) - _write_image(mask, output_path) - napari_write_shapes(os.path.join(folder, "vertices.csv"), data, metadata) - return folder diff --git a/src/napari_deeplabcut/config/__init__.py b/src/napari_deeplabcut/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/config/_autostart.py b/src/napari_deeplabcut/config/_autostart.py new file mode 100644 index 00000000..2fb0b955 --- /dev/null +++ b/src/napari_deeplabcut/config/_autostart.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import logging +from weakref import WeakSet + +import napari +from napari.layers import Points +from napari.utils.events import Event +from qtpy.QtCore import QTimer + +from napari_deeplabcut._widgets import KeypointControls +from napari_deeplabcut.config.settings import get_auto_open_keypoint_controls +from napari_deeplabcut.core.metadata import read_points_meta + +logger = logging.getLogger(__name__) + +# Track viewers where the observer has already been installed. +_INSTALLED_VIEWERS: WeakSet = WeakSet() + + +def _is_dlc_points_layer(layer) -> bool: + """Return True if layer looks like a valid DLC Points layer.""" + if not isinstance(layer, Points): + return False + + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if hasattr(res, "errors"): + return False + return res.header is not None + + +def get_existing_keypoint_controls(viewer): + for widget in viewer.window.dock_widgets.values(): + if isinstance(widget, KeypointControls): + return widget + return None + + +def _ensure_keypoint_controls_open(viewer) -> None: + """Open Keypoint controls dock widget if enabled in settings.""" + if viewer is None or not get_auto_open_keypoint_controls(): + return + if get_existing_keypoint_controls(viewer) is not None: + return + try: + # Public API: returns the existing widget if already docked. + viewer.window.add_plugin_dock_widget( + "napari-deeplabcut", + "Keypoint controls", + ) + except Exception: + logger.debug("Failed to open Keypoint controls dock widget.", exc_info=True) + + +def _maybe_open_for_inserted_layer(viewer, layer) -> None: + """Open controls when a qualifying DLC points layer is present.""" + if viewer is None or layer is None: + return + if not _is_dlc_points_layer(layer): + return + + # Defer slightly to avoid re-entrancy during layer insertion. + QTimer.singleShot(0, lambda: _ensure_keypoint_controls_open(viewer)) + + +def maybe_install_keypoint_controls_autostart(viewer=None) -> None: + """ + Install a per-viewer observer that auto-opens Keypoint controls when a valid + DLC Points layer is inserted. + + Safe to call repeatedly; installation happens once per viewer. + """ + if viewer is None: + viewer = napari.current_viewer() + if viewer is None: + return + + if viewer in _INSTALLED_VIEWERS: + return + + _INSTALLED_VIEWERS.add(viewer) + + def _on_insert(event: Event) -> None: + try: + layer = event.value if hasattr(event, "value") else event.source[-1] + except Exception: + layer = None + _maybe_open_for_inserted_layer(viewer, layer) + + viewer.layers.events.inserted.connect(_on_insert) + + # Also scan already-present layers in case installation happens late. + for layer in list(viewer.layers): + if _is_dlc_points_layer(layer): + QTimer.singleShot(0, lambda v=viewer: _ensure_keypoint_controls_open(v)) + break diff --git a/src/napari_deeplabcut/config/keybinds.py b/src/napari_deeplabcut/config/keybinds.py new file mode 100644 index 00000000..765a6974 --- /dev/null +++ b/src/napari_deeplabcut/config/keybinds.py @@ -0,0 +1,159 @@ +"""Central registry and installers for napari-deeplabcut keybindings (source of truth).""" + +# src/napari_deeplabcut/config/keybinds.py +from __future__ import annotations + +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from enum import Enum, auto + +import numpy as np +from napari.layers import Points + +_global_points_bindings_installed = False + + +@dataclass(frozen=True) +class BindingContext: + controls: object + store: object + + +@dataclass(frozen=True) +class ShortcutSpec: + keys: tuple[str, ...] + description: str + group: str + scope: str + action: ShortcutAction | None = None # optional enum for programmatic reference + get_callback: Callable[[BindingContext], Callable] | None = None + overwrite: bool = False + when: str | None = None # optional UI note, e.g. "Multi-animal layers only" + + +class ShortcutAction(Enum): + CYCLE_LABEL_MODE = auto() + CYCLE_COLOR_MODE = auto() + NEXT_KEYPOINT = auto() + PREV_KEYPOINT = auto() + JUMP_UNLABELED_FRAME = auto() + TOGGLE_EDGE_COLOR = auto() + + +# ---------------------------------------- +# Functions with associated keybind callbacks +# ---------------------------------------- +def _cycle_label_mode(ctx: BindingContext): + return ctx.controls.cycle_through_label_modes + + +def _cycle_color_mode(ctx: BindingContext): + return ctx.controls.cycle_through_color_modes + + +def _next_keypoint(ctx: BindingContext): + return ctx.store.next_keypoint + + +def _prev_keypoint(ctx: BindingContext): + return ctx.store.prev_keypoint + + +def _jump_unlabeled_frame(ctx: BindingContext): + return ctx.store._find_first_unlabeled_frame + + +# ---- Single source of truth for displayed shortcuts ---- + +SHORTCUTS: tuple[ShortcutSpec, ...] = ( + ShortcutSpec( + keys=("M",), + action=ShortcutAction.CYCLE_LABEL_MODE, + get_callback=_cycle_label_mode, + description="Change labeling mode", + group="Annotation", + scope="points-layer", + ), + ShortcutSpec( + keys=("F",), + action=ShortcutAction.CYCLE_COLOR_MODE, + get_callback=_cycle_color_mode, + description="Change color mode", + group="Display", + scope="points-layer", + when="Only cycles beyond bodypart mode for multi-animal layers", + ), + ShortcutSpec( + keys=("Down",), + action=ShortcutAction.NEXT_KEYPOINT, + get_callback=_next_keypoint, + description="Select next keypoint", + group="Navigation", + scope="points-layer", + overwrite=True, + ), + ShortcutSpec( + keys=("Up",), + action=ShortcutAction.PREV_KEYPOINT, + get_callback=_prev_keypoint, + description="Select previous keypoint", + group="Navigation", + scope="points-layer", + overwrite=True, + ), + ShortcutSpec( + keys=("Shift-Right", "Shift-Left"), + action=ShortcutAction.JUMP_UNLABELED_FRAME, + get_callback=_jump_unlabeled_frame, + description="Jump to first unlabeled frame", + group="Navigation", + scope="points-layer", + ), + ShortcutSpec( + keys=("E",), + action=ShortcutAction.TOGGLE_EDGE_COLOR, + description="Toggle point edge color", + group="Display", + scope="global-points", + ), +) + + +def iter_shortcuts() -> Iterable[ShortcutSpec]: + return SHORTCUTS + + +def _bind_each_key(layer: Points, keys: tuple[str, ...], callback, *, overwrite: bool = False) -> None: + for key in keys: + layer.bind_key(key, callback, overwrite=overwrite) + + +def install_points_layer_keybindings(layer: Points, controls, store) -> None: + ctx = BindingContext(controls=controls, store=store) + + for spec in SHORTCUTS: + if spec.scope != "points-layer" or spec.get_callback is None: + continue + + callback = spec.get_callback(ctx) + _bind_each_key(layer, spec.keys, callback, overwrite=spec.overwrite) + + +# ------- Global keybinds that apply to all points layers, e.g. toggling edge color ------- + + +def toggle_edge_color(layer): + layer.border_width = np.bitwise_xor(layer.border_width, 2) + + +def install_global_points_keybindings() -> None: + global _global_points_bindings_installed + if _global_points_bindings_installed: + return + + for spec in SHORTCUTS: + if spec.scope == "global-points" and spec.action == ShortcutAction.TOGGLE_EDGE_COLOR: + for key in spec.keys: + Points.bind_key(key)(toggle_edge_color) + + _global_points_bindings_installed = True diff --git a/src/napari_deeplabcut/config/models.py b/src/napari_deeplabcut/config/models.py new file mode 100644 index 00000000..39bf515d --- /dev/null +++ b/src/napari_deeplabcut/config/models.py @@ -0,0 +1,628 @@ +# src/napari_deeplabcut/config/models.py +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + + +def unsorted_unique(array: Sequence) -> np.ndarray: + """Return the unsorted unique elements of an array.""" + _, inds = np.unique(array, return_index=True) + return np.asarray(array)[np.sort(inds)] + + +# ----------------------------------------------------------------------------- +# Enums +# ----------------------------------------------------------------------------- +class MetadataKind(str, Enum): + """High-level metadata container type.""" + + IMAGE = "image" + POINTS = "points" + + +# ----------------------------------------------------------------------------- +# Project structure models +# ----------------------------------------------------------------------------- +class DLCProjectContext(BaseModel): + """ + Best-effort DLC project/location context inferred from available hints. + + All fields are optional because users may open partial project fragments + (e.g. only a video, only a labeled-data folder, only annotations). + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) + + root_anchor: Path | None = Field( + default=None, + description="Base folder anchor used for project-relative resolution when no stronger hint exists.", + ) + project_root: Path | None = Field( + default=None, + description="Folder containing config.yaml, if inferable.", + ) + config_path: Path | None = Field( + default=None, + description="Resolved path to DLC config.yaml, if inferable.", + ) + dataset_folder: Path | None = Field( + default=None, + description="Resolved labeled-data/ folder, if inferable.", + ) + + @model_validator(mode="after") + def _normalize_and_validate(self) -> DLCProjectContext: + def _norm(p: Path | None) -> Path | None: + if p is None: + return None + try: + return p.expanduser().resolve() + except Exception: + return p + + root_anchor = _norm(self.root_anchor) + project_root = _norm(self.project_root) + config_path = _norm(self.config_path) + dataset_folder = _norm(self.dataset_folder) + + # If config_path is present, project_root should default to its parent + if config_path is not None and project_root is None: + project_root = config_path.parent + + # If project_root exists and config_path is missing, infer config path conventionally + if project_root is not None and config_path is None: + candidate = project_root / "config.yaml" + if candidate.exists(): + config_path = candidate + + # If root_anchor is missing, prefer project_root, otherwise dataset_folder + if root_anchor is None: + root_anchor = project_root or dataset_folder + + object.__setattr__(self, "root_anchor", root_anchor) + object.__setattr__(self, "project_root", project_root) + object.__setattr__(self, "config_path", config_path) + object.__setattr__(self, "dataset_folder", dataset_folder) + return self + + +# ----------------------------------------------------------------------------- +# Header model (authoritative wrapper) +# ----------------------------------------------------------------------------- +class DLCHeaderModel(BaseModel): + """ + Authoritative, pandas-optional DLC header specification. + + Design goals + ------------ + - Pandas is NOT required at runtime for this model. + - Internal representation is always portable: + columns: list[tuple[str, ...]] + names: optional list[str] aligned to tuple length (may be None) + - Accepts pandas.MultiIndex as input when pandas is installed (best-effort), + but never stores it internally. + + Semantics + --------- + Canonical meaning (when present): + scorer, individuals, bodyparts, coords + + We support both: + - 4-level canonical tuples: (scorer, individuals, bodyparts, coords) + - 3-level legacy tuples: (scorer, bodyparts, coords) -> treated as individuals="" + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + columns: list[tuple[str, ...]] = Field(default_factory=list) + names: list[str] | None = None + + # ---------------------------- + # Input normalization + # ---------------------------- + @field_validator("columns", mode="before") + @classmethod + def _coerce_columns(cls, v: Any) -> list[tuple[str, ...]]: + """ + Accept common header representations and normalize to list-of-tuples. + + Supported: + - list/tuple of tuples + - pandas.MultiIndex (if pandas installed) + - pandas.Index (if pandas installed) + """ + if v is None: + return [] + + # If payload dict accidentally arrives here, unwrap. + if isinstance(v, dict) and "columns" in v: + v = v["columns"] + + # pandas objects: import lazily (optional dependency) + try: + import pandas as pd # type: ignore + + if isinstance(v, pd.MultiIndex): + return [tuple(map(str, t)) for t in v.to_list()] + if isinstance(v, pd.Index): + return [(str(x),) for x in v.to_list()] + except Exception: + pass + + # list/tuple input + if isinstance(v, (list, tuple)): + out: list[tuple[str, ...]] = [] + for item in v: + if isinstance(item, (list, tuple)): + out.append(tuple(map(str, item))) + else: + out.append((str(item),)) + return out + + raise TypeError(f"Unsupported columns type: {type(v)!r}") + + @field_validator("names", mode="before") + @classmethod + def _coerce_names(cls, v: Any) -> list[str] | None: + if v is None: + return None + if isinstance(v, (list, tuple)): + return [str(x) for x in v] + return None + + @field_validator("names") + @classmethod + def _validate_names(cls, v: list[str] | None, info) -> list[str] | None: + # If provided, length must match tuple length (when columns non-empty). + cols = info.data.get("columns") or [] + if v is None or not cols: + return v + n = len(cols[0]) + if len(v) != n: + # Be tolerant: if mismatch, drop names rather than fail hard. + return None + return v + + # ---------------------------- + # Core shape helpers (pandas-free) + # ---------------------------- + @property + def nlevels(self) -> int: + return len(self.columns[0]) if self.columns else 0 + + def _level_index(self, name: str) -> int | None: + if not self.names: + return None + try: + return self.names.index(name) + except ValueError: + return None + + def _get_level_values(self, idx: int) -> list[str]: + if not self.columns: + return [] + vals = [] + for t in self.columns: + if idx < len(t): + vals.append(str(t[idx])) + # stable unique, preserve order + return list(dict.fromkeys(vals)) + + def _canonical_4(self) -> list[tuple[str, str, str, str]]: + """ + Return canonical 4-tuples (scorer, individuals, bodyparts, coords). + + - If already 4-level: mapped by names if present, else by position. + - If 3-level legacy: individuals="" inserted. + - Otherwise: best-effort fallback to positional mapping. + """ + out: list[tuple[str, str, str, str]] = [] + if not self.columns: + return out + + # If names are present, use them preferentially. + ix_scorer = self._level_index("scorer") + ix_inds = self._level_index("individuals") + ix_bp = self._level_index("bodyparts") + ix_coords = self._level_index("coords") + + for t in self.columns: + tt = tuple(map(str, t)) + if len(tt) == 4: + if None not in (ix_scorer, ix_inds, ix_bp, ix_coords): + out.append((tt[ix_scorer], tt[ix_inds], tt[ix_bp], tt[ix_coords])) # type: ignore[index] + else: + out.append((tt[0], tt[1], tt[2], tt[3])) + elif len(tt) == 3: + # legacy (scorer, bodyparts, coords) + if None not in (ix_scorer, ix_bp, ix_coords): + out.append((tt[ix_scorer], "", tt[ix_bp], tt[ix_coords])) # type: ignore[index] + else: + out.append((tt[0], "", tt[1], tt[2])) + else: + # unknown shape: do best-effort padding/truncation + scorer = tt[0] if len(tt) > 0 else "" + bodypart = tt[1] if len(tt) > 1 else "" + coords = tt[-1] if len(tt) > 0 else "" + out.append((scorer, "", bodypart, coords)) + + return out + + # ---------------------------- + # Pandas interop + # ---------------------------- + def as_multiindex(self): + """ + OPTIONAL: Convert to pandas.MultiIndex if pandas is installed. + This keeps pandas-specific modules working without making pandas a core invariant. + """ + try: + import pandas as pd # type: ignore + except Exception as e: + raise RuntimeError("pandas is required for as_multiindex() but is not installed") from e + + canon = self._canonical_4() + names = ["scorer", "individuals", "bodyparts", "coords"] + return pd.MultiIndex.from_tuples(canon, names=names) + + # ---------------------------- + # Self-documenting API (what callers should use) + # ---------------------------- + @property + def scorer(self) -> str | None: + canon = self._canonical_4() + return canon[0][0] if canon else None + + @property + def scorers(self) -> list[str]: + canon = self._canonical_4() + scorers = [s for s, _, _, _ in canon] + return list(dict.fromkeys(scorers)) + + @property + def individuals(self) -> list[str]: + canon = self._canonical_4() + inds = [i for _, i, _, _ in canon] + uniq = list(dict.fromkeys(inds)) + return uniq if uniq else [""] + + @property + def bodyparts(self) -> list[str]: + canon = self._canonical_4() + bps = [b for _, _, b, _ in canon] + return list(dict.fromkeys(bps)) + + @property + def coords(self) -> list[str]: + canon = self._canonical_4() + cs = [c for _, _, _, c in canon] + return list(dict.fromkeys(cs)) + + def with_scorer(self, scorer: str) -> DLCHeaderModel: + """ + Return a new header with scorer replaced (pandas-free). + + Replaces legacy `header.scorer = ...`. + """ + canon = self._canonical_4() + new_cols = [(str(scorer), ind, bp, coord) for _, ind, bp, coord in canon] + return self.model_copy(update={"columns": new_cols, "names": ["scorer", "individuals", "bodyparts", "coords"]}) + + def form_individual_bodypart_pairs(self) -> list[tuple[str, str]]: + """ + Return ordered list of (individual, bodypart) pairs. + + This matches the previous DLCHeader behavior but is pandas-free. + """ + canon = self._canonical_4() + pairs = [(ind, bp) for _, ind, bp, _ in canon] + # stable unique preserving first seen order + return list(dict.fromkeys(pairs)) + + @classmethod + def from_config(cls, config: dict) -> DLCHeaderModel: + """ + Build header from DLC config.yaml content (single or multi-animal), + without requiring pandas. + """ + multi = bool(config.get("multianimalproject", False)) + scorer = str(config["scorer"]) + + cols: list[tuple[str, ...]] = [] + names: list[str] + + if multi: + inds = [str(x) for x in config["individuals"]] + bps = [str(x) for x in config["multianimalbodyparts"]] + coords = ["x", "y"] + for i in inds: + for bp in bps: + for c in coords: + cols.append((scorer, i, bp, c)) + # unique bodyparts in "single" individual bucket + for bp in [str(x) for x in config.get("uniquebodyparts", [])]: + for c in coords: + cols.append((scorer, "single", bp, c)) + names = ["scorer", "individuals", "bodyparts", "coords"] + else: + bps = [str(x) for x in config["bodyparts"]] + coords = ["x", "y"] + for bp in bps: + for c in coords: + cols.append((scorer, bp, c)) + names = ["scorer", "bodyparts", "coords"] + + return cls(columns=cols, names=names) + + def to_metadata_payload(self) -> dict[str, Any]: + """ + Portable payload to store in napari layer.metadata. + + Never stores pandas objects. + """ + return {"columns": self.columns, "names": self.names} + + +# ----------------------------------------------------------------------------- +# Metadata & I/O models +# ----------------------------------------------------------------------------- + + +class AnnotationKind(str, Enum): + """Semantic kind of keypoint annotations for deterministic IO routing. + + Notes + ----- + This is used to enforce safe saving policies: + - ``gt``: ground-truth labels (e.g. ``CollectedData_*.h5``) + - ``machine``: machine predictions/refinements (e.g. ``machinelabels*.h5``) + + The napari layer display name must never be used to infer this value. + """ + + GT = "gt" + MACHINE = "machine" + + +class IOProvenance(BaseModel): + """Authoritative provenance for a Points layer. + + This model captures *identity* for IO, independent of the napari layer name. + + Design goals + ------------ + - Prefer project-relative, OS-agnostic paths. + - Store relative paths using POSIX separators ('/'), even on Windows. + - Be explicit about annotation kind so saving never relies on directory ordering. + + Fields + ------ + schema_version: + Version marker for forward-compatible evolution. + project_root: + Optional project root directory. When set, ``source_relpath_posix`` is + interpreted relative to this root. + source_relpath_posix: + Project-relative path encoded with POSIX separators ('/'). + Example: ``labeled-data/test/CollectedData_John.h5``. + kind: + Whether this layer is ground-truth or machine output. + dataset_key: + HDF5 key used for the keypoints table (default: ``keypoints``). + """ + + # Keep minimal but resilient to future additions + model_config = ConfigDict(extra="allow") + + schema_version: int = Field(default=1, description="Provenance schema version") + project_root: str | None = Field(default=None, description="Project root directory") + source_relpath_posix: str | None = Field( + default=None, + description="Project-relative POSIX path to the source .h5 (forward slashes).", + ) + kind: AnnotationKind | None = Field(default=None, description="Annotation kind for routing", strict=True) + dataset_key: str = Field(default="keypoints", description="HDF5 key for keypoints table") + + @field_validator("source_relpath_posix") + @classmethod + def _normalize_relpath(cls, v: str | None) -> str | None: + """Normalize provenance paths to POSIX separators. + + This keeps stored metadata OS-agnostic and stable across platforms. + """ + if v is None: + return None + return v.replace("\\", "/") + + @field_validator("kind") + @classmethod + def _validate_kind(cls, v: AnnotationKind | str | None) -> AnnotationKind | None: + """Validate that kind is either an AnnotationKind or a valid string.""" + if v is None: + return None + if isinstance(v, AnnotationKind): + return v + try: + return AnnotationKind(v) + except ValueError as e: + raise ValueError(f"Invalid annotation kind: {v!r}") from e + + @field_validator("project_root") + @classmethod + def _validate_project_root(cls, v: str | None) -> str | None: + """Store project_root without requiring the path to exist. + + Existence and type (file vs directory) checks are intentionally deferred + to path resolution time to keep serialized provenance portable across + machines and project locations. + """ + if v is None: + return None + return str(v) + + +class ImageMetadata(BaseModel): + """ + Metadata for Image layers. + + Stored in napari layer.metadata. + + Invariants + ---------- + - paths, if present, define frame order + - root, if present, is a directory path + """ + + model_config = ConfigDict(extra="allow") + + kind: MetadataKind = Field(default=MetadataKind.IMAGE, strict=True) + paths: list[str] | None = None + root: str | None = None + shape: tuple[int, ...] | None = None + name: str | None = None + + def __repr__(self) -> str: + # Only show non-None fields, truncate long lists + fields = [] + for k in ("kind", "name", "root", "shape", "paths"): + v = getattr(self, k) + if v is not None: + if k == "paths": + if isinstance(v, list): + v = f"[{len(v)} paths]" + fields.append(f"{k}={v!r}") + return f"ImageMetadata({', '.join(fields)})" + + +class PointsMetadata(BaseModel): + """ + Metadata for Points layers. + + Invariants + ---------- + - header defines keypoint structure + - root + paths must align with ImageMetadata when present + - config_colormap stores the configured bodypart colormap when known + - face_color_cycles, when present, is derived display state and not authoritative + """ + + kind: MetadataKind = Field(default=MetadataKind.POINTS) + + root: str | None = None + paths: list[str] | None = None + shape: tuple[int, ...] | None = None + name: str | None = None + + project: str | None = None + header: DLCHeaderModel | None = None + io: IOProvenance | None = None + save_target: IOProvenance | None = None + + config_colormap: str | None = None + face_color_cycles: dict[str, dict[str, Any]] | None = None + colormap_name: str | None = None + + tables: dict[str, dict[str, str]] | None = None + + # Non-serializable runtime attachments (allowed but ignored by pydantic) + controls: Any | None = Field(default=None, exclude=True) + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + +# ----------------------------------------------------------------------------- +# Save conflict models +# ----------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ConflictEntry: + frame_label: str + keypoints: tuple[str, ...] + + +@dataclass(frozen=True) +class OverwriteConflictReport: + """ + UI-facing overwrite-conflict contract. + + This model is intentionally decoupled from pandas so the dialog layer only + depends on plain Python data structures. + + Semantics + --------- + n_overwrites: + Number of (frame/image, keypoint) overwrite events. + n_frames: + Number of distinct frames/images affected. + entries: + Detailed per-frame/image conflict rows to display in the dialog. + truncated_entries: + Number of additional frame/image rows omitted from `entries`. + layer_name: + Optional source layer name to display in the dialog. + destination_path: + Optional destination path to display in the dialog. + """ + + n_overwrites: int + n_frames: int + entries: tuple[ConflictEntry, ...] + truncated_entries: int = 0 + layer_name: str | None = None + destination_path: str | None = None + + @property + def has_conflicts(self) -> bool: + return self.n_overwrites > 0 + + @property + def details_text(self) -> str: + if not self.entries: + return "No detailed conflicts." + lines = [f"{entry.frame_label} → {', '.join(entry.keypoints)}" for entry in self.entries] + if self.truncated_entries: + lines.append("") + lines.append(f"… and {self.truncated_entries} more frame/image entries.") + return "\n".join(lines) + + +# ----------------------------------------------------------------------------- +# Trails metadata and user settings models +# ----------------------------------------------------------------------------- + + +class TrailsDisplayConfig(BaseModel): + tail_length: int = Field(default=50, ge=0) + head_length: int = Field(default=50, ge=0) + tail_width: float = Field(default=6.0, gt=0) + opacity: float = Field(default=1.0, ge=0.0, le=1.0) + blending: str = Field(default="translucent") + visible: bool = Field(default=True) + + @field_validator("blending") + @classmethod + def _validate_blending(cls, v: str) -> str: + # Keep this list minimal and permissive; extend if needed. + allowed = {"translucent", "opaque", "additive", "minimum"} + vv = str(v).strip().lower() + return vv if vv in allowed else "translucent" + + +class FolderUIState(BaseModel): + """ + Folder-scoped persisted UI state stored in .napari-deeplabcut.json. + """ + + model_config = ConfigDict(extra="allow") + + schema_version: int = Field(default=1, ge=1) + default_scorer: str | None = None + trails: TrailsDisplayConfig = Field(default_factory=TrailsDisplayConfig) diff --git a/src/napari_deeplabcut/config/settings.py b/src/napari_deeplabcut/config/settings.py new file mode 100644 index 00000000..8edbc833 --- /dev/null +++ b/src/napari_deeplabcut/config/settings.py @@ -0,0 +1,31 @@ +from qtpy.QtCore import QSettings + +DEFAULT_SINGLE_ANIMAL_CMAP = "rainbow" +DEFAULT_MULTI_ANIMAL_INDIVIDUAL_CMAP = "Set3" + +_OVERWRITE_CONFIRM_ENABLED_KEY = "napari_deeplabcut/overwrite/confirm_enabled" +AUTO_OPEN_KEYPOINT_CONTROLS_KEY = "napari_deeplabcut/ui/auto_open_keypoint_controls" + + +def get_overwrite_confirmation_enabled() -> bool: + """Return whether overwrite confirmation dialogs are enabled.""" + settings = QSettings() + return settings.value(_OVERWRITE_CONFIRM_ENABLED_KEY, True, type=bool) + + +def set_overwrite_confirmation_enabled(enabled: bool) -> None: + """Persist whether overwrite confirmation dialogs are enabled.""" + settings = QSettings() + settings.setValue(_OVERWRITE_CONFIRM_ENABLED_KEY, bool(enabled)) + + +def get_auto_open_keypoint_controls() -> bool: + """Return whether keypoint controls should be auto-opened.""" + settings = QSettings() + return settings.value(AUTO_OPEN_KEYPOINT_CONTROLS_KEY, True, type=bool) + + +def set_auto_open_keypoint_controls(enabled: bool) -> None: + """Persist whether keypoint controls should be auto-opened.""" + settings = QSettings() + settings.setValue(AUTO_OPEN_KEYPOINT_CONTROLS_KEY, bool(enabled)) diff --git a/src/napari_deeplabcut/config/supported_files.py b/src/napari_deeplabcut/config/supported_files.py new file mode 100644 index 00000000..2c93727c --- /dev/null +++ b/src/napari_deeplabcut/config/supported_files.py @@ -0,0 +1,4 @@ +"""Supported file formats for images and videos.""" + +SUPPORTED_IMAGES = (".jpg", ".jpeg", ".png") +SUPPORTED_VIDEOS = (".mp4", ".mov", ".avi") diff --git a/src/napari_deeplabcut/core/__init__.py b/src/napari_deeplabcut/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_deeplabcut/core/config_sync.py b/src/napari_deeplabcut/core/config_sync.py new file mode 100644 index 00000000..dd59e8dd --- /dev/null +++ b/src/napari_deeplabcut/core/config_sync.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import logging +from pathlib import Path + +import napari_deeplabcut.core.io as io +from napari_deeplabcut.core.metadata import read_points_meta +from napari_deeplabcut.core.project_paths import ( + find_nearest_config, + infer_dlc_project, + infer_dlc_project_from_image_layer, + infer_dlc_project_from_points_meta, +) + +logger = logging.getLogger("napari-deeplabcut.core.config_sync") + +_POINT_SIZE_KEY = "dotsize" + + +def _coerce_point_size(value, *, default: int = 6, minimum: int = 1, maximum: int = 100) -> int: + try: + size = int(round(float(value))) + except Exception: + size = default + return max(minimum, min(maximum, size)) + + +def _layer_source_path(layer) -> str | None: + try: + src = getattr(layer, "source", None) + p = getattr(src, "path", None) if src is not None else None + return str(p) if p else None + except Exception: + return None + + +def resolve_config_path_from_layer( + layer, + *, + fallback_project: str | Path | None = None, + fallback_root: str | Path | None = None, + image_layer=None, + prefer_project_root: bool = True, + max_levels: int = 5, +) -> Path | None: + """ + Best-effort, lightweight config resolution using centralized DLC project inference. + + Resolution order + ---------------- + 1. Infer from Points metadata via infer_dlc_project_from_points_meta(...) + 2. Infer from current image/video layer via infer_dlc_project_from_image_layer(...) + 3. Infer from generic path-like hints via infer_dlc_project(...) + 4. Last-resort upward search with find_nearest_config(...) + + This intentionally: + - does not do recursive filesystem crawling + - only searches upward with bounded max_levels + - reuses the plugin's root-anchor / project-context semantics + """ + # ------------------------------------------------------------------ + # 1) Points-layer-centric inference (authoritative when available) + # ------------------------------------------------------------------ + try: + pts_meta = read_points_meta( + layer, + migrate_legacy=True, + drop_controls=True, + drop_header=False, + ) + except Exception: + pts_meta = None + + if pts_meta is not None and not hasattr(pts_meta, "errors"): + try: + ctx = infer_dlc_project_from_points_meta( + pts_meta, + prefer_project_root=prefer_project_root, + max_levels=max_levels, + ) + if ctx.config_path is not None and ctx.config_path.is_file(): + return ctx.config_path + except Exception: + logger.debug("Failed to infer config from points metadata", exc_info=True) + + # ------------------------------------------------------------------ + # 2) Image/video-layer-centric inference + # ------------------------------------------------------------------ + if image_layer is not None: + try: + ctx = infer_dlc_project_from_image_layer( + image_layer, + prefer_project_root=prefer_project_root, + max_levels=max_levels, + ) + if ctx.config_path is not None and ctx.config_path.is_file(): + return ctx.config_path + except Exception: + logger.debug("Failed to infer config from image layer", exc_info=True) + + # ------------------------------------------------------------------ + # 3) Generic fallback inference from path-like hints + # ------------------------------------------------------------------ + md = getattr(layer, "metadata", {}) or {} + paths = md.get("paths") or [] + + anchor_candidates: list[str | Path] = [] + dataset_candidates: list[str | Path] = [] + + for value in ( + md.get("project"), + md.get("root"), + _layer_source_path(layer), + fallback_project, + fallback_root, + ): + if value: + anchor_candidates.append(value) + + # Paths can help infer the labeled-data dataset folder/root anchor + if paths: + # dataset candidate: first opened path / row-key hint + dataset_candidates.append(paths[0]) + + # Add a few paths as anchors (bounded/lightweight) + for value in paths[:3]: + anchor_candidates.append(value) + + try: + ctx = infer_dlc_project( + anchor_candidates=anchor_candidates, + dataset_candidates=dataset_candidates, + explicit_root=None, + prefer_project_root=prefer_project_root, + max_levels=max_levels, + ) + if ctx.config_path is not None and ctx.config_path.is_file(): + return ctx.config_path + except Exception: + logger.debug("Failed to infer config from generic path hints", exc_info=True) + + # ------------------------------------------------------------------ + # 4) Fallback upward search on a bounded set of candidates + # ------------------------------------------------------------------ + for candidate in anchor_candidates: + try: + cfg = find_nearest_config(candidate, max_levels=max_levels) + if cfg is not None and cfg.is_file(): + return cfg + except Exception: + logger.debug("find_nearest_config failed for %r", candidate, exc_info=True) + + return None + + +def load_point_size_from_config(config_path: str | Path | None) -> int | None: + if not config_path: + return None + + try: + cfg = io.load_config(str(config_path)) + except Exception: + logger.debug("Could not read config file %r", config_path, exc_info=True) + return None + + if not isinstance(cfg, dict): + logger.debug( + "Config file %r did not contain a mapping; ignoring for point-size load.", + config_path, + ) + return None + + if _POINT_SIZE_KEY in cfg: + return _coerce_point_size(cfg.get(_POINT_SIZE_KEY)) + return None + + +def save_point_size_to_config(config_path: str | Path | None, size: int) -> bool: + """ + Persist point size in config.yaml if possible. + + Returns + ------- + bool + True if the config was changed and written, False otherwise. + """ + if not config_path: + logger.debug("Skipping point-size config sync: no config path resolved.") + return False + + size = _coerce_point_size(size) + + try: + cfg = io.load_config(str(config_path)) + except Exception: + logger.debug("Could not read config file %r", config_path, exc_info=True) + return False + + if not isinstance(cfg, dict): + logger.debug( + "Config file %r did not contain a mapping; replacing with empty config for point-size sync.", + config_path, + ) + cfg = {} + + old_value = cfg.get(_POINT_SIZE_KEY, None) + + try: + if old_value is not None and _coerce_point_size(old_value) == size: + logger.debug("Skipping point-size config sync: dotsize already %s", size) + return False + except Exception: + pass + + cfg[_POINT_SIZE_KEY] = size + + try: + io.write_config(str(config_path), cfg) + logger.debug("Updated dotsize=%s in %s", size, config_path) + return True + except Exception: + logger.debug("Could not write config file %r", config_path, exc_info=True) + return False diff --git a/src/napari_deeplabcut/core/conflicts.py b/src/napari_deeplabcut/core/conflicts.py new file mode 100644 index 00000000..41360df8 --- /dev/null +++ b/src/napari_deeplabcut/core/conflicts.py @@ -0,0 +1,208 @@ +# src/napari_deeplabcut/core/conflicts.py +from __future__ import annotations + +from pathlib import Path + +import pandas as pd + +from napari_deeplabcut.config.models import AnnotationKind, OverwriteConflictReport, PointsMetadata +from napari_deeplabcut.core import schemas as dlc_schemas +from napari_deeplabcut.core.dataframes import set_df_scorer +from napari_deeplabcut.core.errors import AmbiguousSaveError, MissingProvenanceError +from napari_deeplabcut.core.metadata import parse_points_metadata +from napari_deeplabcut.core.project_paths import infer_dlc_project_from_points_meta +from napari_deeplabcut.core.provenance import ( + resolve_output_path_from_metadata, +) + + +def compute_overwrite_report_for_points_save( + data, + attributes: dict, +) -> OverwriteConflictReport | None: + """ + Compute an overwrite-conflict report for a prospective points-layer save. + + This is a non-interactive preflight helper intended for UI/controller code + to call *before* invoking the napari writer. It mirrors the writer's save + routing logic closely enough to predict whether saving this points layer + would merge into an existing GT file and overwrite existing keypoints. + + Parameters + ---------- + data: + Napari Points layer data, expected to be array-like of shape (N, 3) + in [frame, y, x] order. + attributes: + Napari layer attributes dict for the points layer. This is the same + payload shape passed to the npe2 writer. + + Returns + ------- + OverwriteConflictReport | None + - OverwriteConflictReport if the save target is an existing GT file and + at least one keypoint overwrite conflict would occur. + - None if: + * there is no existing GT file to merge into, + * the destination is not GT, + * or no overwrite conflicts are detected. + + Raises + ------ + ValueError + If the layer attributes / points payload are invalid for save. + MissingProvenanceError + If saving a MACHINE source without a resolvable promotion target. + AmbiguousSaveError + If GT fallback resolution finds multiple CollectedData_*.h5 files. + """ + # Local imports keep core.conflicts free of import cycles: + # - core.dataframes imports ConflictEntry / OverwriteConflictReport + # - core.io imports dataframe helpers and metadata parsing + from napari_deeplabcut.core.dataframes import ( + build_overwrite_conflict_report, + form_df_from_validated, + keypoint_conflicts, + ) + + attrs = dlc_schemas.PointsLayerAttributesModel.model_validate(attributes or {}) + pts_meta: PointsMetadata = parse_points_metadata(attrs.metadata, drop_header=False) + + if not pts_meta.header: + raise ValueError("Layer metadata must include a valid DLC header to write keypoints.") + + points = dlc_schemas.PointsDataModel.model_validate({"data": data}) + props = dlc_schemas.KeypointPropertiesModel.model_validate(attrs.properties) + + # Bundle + validate cross-field invariants exactly like the writer + ctx = dlc_schemas.PointsWriteInputModel.model_validate( + { + "points": points, + "meta": pts_meta, + "props": props, + } + ) + + # Build the outgoing dataframe exactly like the writer + df_new = form_df_from_validated(ctx) + + # Resolve output path using the same provenance-first routing as write_hdf(...) + out_path, target_scorer, source_kind = resolve_output_path_from_metadata(attributes) + + # Promotion to GT may rewrite scorer level + if target_scorer: + df_new = set_df_scorer(df_new, target_scorer) + + # Never write back to machine sources without an explicit promotion target + if not out_path and source_kind == AnnotationKind.MACHINE: + raise MissingProvenanceError("Cannot resolve provenance output path for MACHINE source.") + + # Same GT fallback logic as write_hdf(...) + if not out_path: + project_ctx = infer_dlc_project_from_points_meta(pts_meta, prefer_project_root=False) + dataset_dir = project_ctx.dataset_folder + + if dataset_dir is not None: + dataset_dir.mkdir(parents=True, exist_ok=True) + root_path = dataset_dir + else: + root = pts_meta.root + if not root: + raise MissingProvenanceError("GT fallback requires root (and dataset folder could not be inferred).") + root_path = Path(root) + + candidates = sorted(root_path.glob("CollectedData_*.h5")) + if len(candidates) > 1: + raise AmbiguousSaveError( + f"Multiple CollectedData_*.h5 files found in {root_path}." + " Cannot determine where to save." + " Please specify a save_target with explicit path and scorer.", + candidates=[str(c) for c in candidates], + ) + elif len(candidates) == 1: + out = candidates[0] + else: + scorer = target_scorer or pts_meta.header.scorer + out = root_path / f"CollectedData_{scorer}.h5" + else: + out = Path(out_path) + + # Only GT merge-on-save can produce overwrite conflicts + has_save_target = pts_meta.save_target is not None + destination_kind = ( + AnnotationKind.GT + if has_save_target + else ((pts_meta.io.kind if pts_meta.io is not None else None) or AnnotationKind.GT) + ) + + if destination_kind != AnnotationKind.GT: + return None + + # No existing file -> no merge -> no overwrite conflict + if not out.exists(): + return None + + try: + df_old = pd.read_hdf(out, key="keypoints") + except (KeyError, ValueError): + df_old = pd.read_hdf(out) + + key_conflict = keypoint_conflicts(df_old, df_new) + + report = build_overwrite_conflict_report( + key_conflict, + layer_name=attributes.get("name"), + destination_path=str(out), + ) + + return report if report.has_conflicts else None + + +def compute_overwrite_report_for_extracted_labels_row( + destination_path: str | Path, + df_new: pd.DataFrame, + *, + layer_name: str | None = None, +) -> OverwriteConflictReport | None: + """ + Compute an overwrite-conflict report for a single extracted-frame labels row + being merged into an existing machinelabels-iter0.h5 file. + + Parameters + ---------- + destination_path: + Existing or prospective machinelabels file. + df_new: + A one-row DLC-style dataframe for the extracted frame, indexed by the + canonical image path tuple. + layer_name: + Optional display name for the source layer in the dialog. + + Returns + ------- + OverwriteConflictReport | None + Report if overwrites would occur, otherwise None. + """ + from napari_deeplabcut.core.dataframes import ( + build_overwrite_conflict_report, + keypoint_conflicts, + ) + + out = Path(destination_path) + if not out.exists(): + return None + + try: + df_old = pd.read_hdf(out, key="df_with_missing") + except (KeyError, ValueError): + df_old = pd.read_hdf(out) + + key_conflict = keypoint_conflicts(df_old, df_new) + + report = build_overwrite_conflict_report( + key_conflict, + layer_name=layer_name or "Extracted frame labels", + destination_path=str(out), + ) + + return report if report.has_conflicts else None diff --git a/src/napari_deeplabcut/core/dataframes.py b/src/napari_deeplabcut/core/dataframes.py new file mode 100644 index 00000000..8a0fc142 --- /dev/null +++ b/src/napari_deeplabcut/core/dataframes.py @@ -0,0 +1,464 @@ +# src/napari_deeplabcut/core/dataframes.py + +from __future__ import annotations + +import logging + +import numpy as np +import pandas as pd + +from napari_deeplabcut.config.models import ConflictEntry, OverwriteConflictReport +from napari_deeplabcut.core.schemas import PointsWriteInputModel + +logger = logging.getLogger(__name__) + + +def set_df_scorer(df: pd.DataFrame, scorer: str) -> pd.DataFrame: + """Return df with scorer level set to the given scorer (if present).""" + scorer = (scorer or "").strip() + if not scorer: + return df + if not hasattr(df.columns, "names") or "scorer" not in df.columns.names: + return df + + try: + cols = df.columns.to_frame(index=False) + cols["scorer"] = scorer + df = df.copy() + df.columns = pd.MultiIndex.from_frame(cols) + except Exception: + pass + return df + + +def merge_multiple_scorers(df: pd.DataFrame) -> pd.DataFrame: + """ + If df has multiple scorers in its column MultiIndex, merge them. + + - If likelihood exists, keep the scorer with max likelihood per keypoint/frame. + - Else, pick the first scorer deterministically. + """ + if not isinstance(df.columns, pd.MultiIndex): + return df + + n_frames = df.shape[0] + cols = df.columns + names = list(cols.names or []) + + # Identify scorer level + scorer_level = "scorer" if "scorer" in names else 0 + scorers = list(dict.fromkeys(cols.get_level_values(scorer_level).astype(str).tolist())) + n_scorers = len(scorers) + if n_scorers <= 1: + return df + + # Identify coords level + coords_level = "coords" if "coords" in names else (cols.nlevels - 1) + coords_vals = cols.get_level_values(coords_level).astype(str).tolist() + has_likelihood = "likelihood" in set(coords_vals) + + # Helper: take columns for a given scorer + def _cols_for_scorer(scorer: str) -> pd.MultiIndex: + mask = cols.get_level_values(scorer_level).astype(str) == str(scorer) + return cols[mask] + + if has_likelihood: + # Ensure each scorer block has same column ordering/shape + cols0 = _cols_for_scorer(scorers[0]) + per_scorer = len(cols0) + if per_scorer == 0: + return df + + # If other scorers don't match shape, fall back to first scorer + for s in scorers[1:]: + if len(_cols_for_scorer(s)) != per_scorer: + logger.debug("Scorer column blocks differ in size; falling back to first scorer.") + return df.loc[:, cols0] + + # Stack scorer axis -> (n_frames, n_scorers, per_scorer) + data = df.to_numpy(copy=True).reshape((n_frames, n_scorers, per_scorer)) + + # We need likelihood position within per_scorer block. + # Find likelihood columns within the first scorer block: + coords0 = cols0.get_level_values(coords_level).astype(str).to_numpy() + like_mask = coords0 == "likelihood" + if not np.any(like_mask): + # coords said likelihood exists, but not in the first scorer block - fallback + return df.loc[:, cols0] + + # Reshape per keypoint: assume (x, y, likelihood) triplets per keypoint. + # We infer n_keypoints from number of likelihood entries. + n_keypoints = int(np.sum(like_mask)) + # Triplet width is per_scorer / n_keypoints if structured, but be defensive: + triplet = per_scorer // max(1, n_keypoints) + if triplet < 3: + # Not in expected (x,y,likelihood) shape; fallback + return df.loc[:, cols0] + + data3 = data.reshape((n_frames, n_scorers, n_keypoints, triplet)) + + # likelihood is assumed at index 2 in each triplet (legacy DLC layout) + try: + idx = np.nanargmax(data3[..., 2], axis=1) + except ValueError: # All-NaN slice encountered + mask = np.isnan(data3[..., 2]).all(axis=1, keepdims=True) + mask = np.broadcast_to(mask[..., None], data3.shape) + data3[mask] = -1 + idx = np.nanargmax(data3[..., 2], axis=1) + data3[mask] = np.nan + + data_best = data3[np.arange(n_frames)[:, None], idx, np.arange(n_keypoints)] + data_best = data_best.reshape((n_frames, -1)) + + # Output columns: use first scorer block columns (structure preserved) + out_cols = cols0[: data_best.shape[1]] + return pd.DataFrame(data_best, index=df.index, columns=out_cols) + + # No likelihood: pick first scorer deterministically + cols0 = _cols_for_scorer(scorers[0]) + return df.loc[:, cols0] + + +def guarantee_multiindex_rows(df: pd.DataFrame) -> None: + """Ensure that DataFrame rows are a MultiIndex of path components. + Legacy DLC data may use an index with pathto/video/file.png strings as Index. + The new format uses a MultiIndex with each path component as a level. + """ + # Make paths platform-agnostic if they are not already + if not isinstance(df.index, pd.MultiIndex): # Backwards compatibility + path = df.index[0] + try: + sep = "/" if "/" in path else "\\" + splits = tuple(df.index.str.split(sep)) + df.index = pd.MultiIndex.from_tuples(splits) + except TypeError: # Ignore numerical index of frame indices + pass + + +def form_df_from_validated(ctx: PointsWriteInputModel) -> pd.DataFrame: + """Create a DLC-style DataFrame from validated napari points + metadata.""" + header = ctx.meta.header # DLCHeaderModel (validated) + props = ctx.props + + # DLC expects x,y columns; ctx.points.xy_dlc converts napari [y,x] -> [x,y] + temp_df = pd.DataFrame(ctx.points.xy_dlc, columns=["x", "y"]) + temp_df["bodyparts"] = props.label + temp_df["individuals"] = props.id + temp_df["inds"] = ctx.points.frame_inds + temp_df["likelihood"] = props.likelihood if props.likelihood is not None else 1.0 + + temp_df["scorer"] = header.scorer or "unknown" + + # Mark rows that have actual coords + temp_df["_has_xy"] = temp_df[["x", "y"]].notna().all(axis=1) + + # Sort so that rows WITH coords come last (so keep="last" keeps them) + temp_df = temp_df.sort_values("_has_xy") + + # Drop duplicates on the key that defines a unique keypoint observation + temp_df = temp_df.drop_duplicates( + subset=["scorer", "individuals", "bodyparts", "inds"], + keep="last", + ) + temp_df = temp_df.drop(columns="_has_xy") + + df = temp_df.set_index(["scorer", "individuals", "bodyparts", "inds"]).stack() + df.index.set_names("coords", level=-1, inplace=True) + df = df.unstack(["scorer", "individuals", "bodyparts", "coords"]) + df.index.name = None + + hdr_cols = ctx.meta.header.as_multiindex() # pandas-only helper; raises if pandas missing + + logger.debug("Before reindex: cols nlevels %s, names %s", df.columns.nlevels, df.columns.names) + logger.debug("header cols nlevels %s, names %s", hdr_cols.nlevels, hdr_cols.names) + + # If df columns dropped individuals, drop it from header too (if present) + # if df.columns.nlevels == 3 and isinstance(hdr_cols, pd.MultiIndex) and hdr_cols.nlevels == 4: + # if "individuals" in hdr_cols.names: + # hdr_cols = hdr_cols.droplevel("individuals") + + # If df columns kept individuals but header doesn't have it, add it (single-animal) + if df.columns.nlevels == 4 and isinstance(hdr_cols, pd.MultiIndex) and hdr_cols.nlevels == 3: + # Insert empty individuals level into header tuples + frame = hdr_cols.to_frame(index=False) + frame.insert(1, "individuals", "") + hdr_cols = pd.MultiIndex.from_frame(frame, names=["scorer", "individuals", "bodyparts", "coords"]) + + df = df.reindex(hdr_cols, axis=1) + + logger.debug("After reindex: cols nlevels %s, names %s", df.columns.nlevels, df.columns.names) + logger.debug( + "header cols nlevels %s, names %s", + ctx.meta.header.as_multiindex().nlevels, + ctx.meta.header.as_multiindex().names, + ) + + # Replace integer frame index with path keys if available + if ctx.meta.paths: + df.index = [ctx.meta.paths[i] for i in df.index] + + guarantee_multiindex_rows(df) + + # Writer invariant: if there are finite points in the layer, df must contain finite coords + layer_xy = np.asarray(ctx.points.xy_dlc) # (N, 2) in [x,y] + n_layer = np.isfinite(layer_xy).all(axis=1).sum() + + # Count finite values in df (x/y columns only) + n_df = np.isfinite(df.to_numpy()).sum() + + if n_layer > 0 and n_df == 0: + raise RuntimeError( + "Writer produced no finite coordinates although layer contains finite points. " + "Likely a header/column MultiIndex mismatch during reindex." + ) + + return df + + +def harmonize_keypoint_row_index(df_new: pd.DataFrame, df_old: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Harmonize row index representation between a freshly formed points dataframe (df_new) + and an existing on-disk dataframe (df_old) to make combine_first/align stable. + + Strategy: + - Ensure both indices are MultiIndex via misc.guarantee_multiindex_rows + - If nlevels differ and one is 1-level while the other is >1, attempt to collapse + the deeper one to basename if it matches the 1-level index sufficiently. + + """ + # FUTURE NOTE @C-Achard 2026-02-18 hardcoded DLC structure: + # DLC's CollectedData files are commonly keyed by per-folder image names (basenames), + # even when the runtime layer may store relpaths including subfolders. + df_new2 = df_new.copy() + df_old2 = df_old.copy() + + # Make both MultiIndex (project convention) + guarantee_multiindex_rows(df_new2) + guarantee_multiindex_rows(df_old2) + + inew = df_new2.index + iold = df_old2.index + + if not isinstance(inew, pd.MultiIndex) or not isinstance(iold, pd.MultiIndex): + return df_new2, df_old2 + + if inew.nlevels == iold.nlevels: + return df_new2, df_old2 + + # Identify which is "deep" and which is "shallow" + if inew.nlevels > iold.nlevels: + deep_df, shallow_df = df_new2, df_old2 + else: + deep_df, shallow_df = df_old2, df_new2 + + deep_idx = deep_df.index + shallow_idx = shallow_df.index + + # Only try collapse when shallow is 1-level and deep is >1 + if not isinstance(shallow_idx, pd.MultiIndex) or shallow_idx.nlevels != 1: + return df_new2, df_old2 + if not isinstance(deep_idx, pd.MultiIndex) or deep_idx.nlevels <= 1: + return df_new2, df_old2 + + # Collapse deep MultiIndex to last component (basename) + deep_last = deep_idx.to_frame(index=False).iloc[:, -1].astype(str).tolist() + shallow_vals = shallow_idx.to_frame(index=False).iloc[:, 0].astype(str).tolist() + + # Measure overlap (set-based) + overlap = len(set(deep_last) & set(shallow_vals)) + denom = max(1, len(set(shallow_vals))) + ratio = overlap / denom + + # If most shallow keys exist as basenames in deep, collapse deep to shallow representation + if ratio >= 0.8: + deep_df2 = deep_df.copy() + deep_df2.index = pd.MultiIndex.from_arrays([deep_last]) + + # Return in original order + if deep_df is df_new2: + return deep_df2, shallow_df + else: + return shallow_df, deep_df2 + + return df_new2, df_old2 + + +def harmonize_keypoint_column_index(df: pd.DataFrame) -> pd.DataFrame: + """Ensure DLC keypoints columns are a 4-level MultiIndex with individuals inserted if missing.""" + if not isinstance(df.columns, pd.MultiIndex): + return df + + cols = df.columns + + # Already 4 levels: try to ensure correct names + if cols.nlevels == 4: + # set_names is safe even if already correct + df2 = df.copy() + df2.columns = cols.set_names(["scorer", "individuals", "bodyparts", "coords"]) + return df2 + + # Legacy 3-level: (scorer, bodyparts, coords) -> insert individuals="" + if cols.nlevels == 3: + # We only insert individuals if it looks like the DLC pattern + # (names might be missing/None depending on earlier ops) + list(cols.names) + # accept either correct names or unknown names + # but we assume order is scorer/bodyparts/coords + frame = cols.to_frame(index=False) + + # If names are already scorer/bodyparts/coords, this is perfect. + # If not, we still insert individuals at position 1. + frame.insert(1, "individuals", "") + + df2 = df.copy() + df2.columns = pd.MultiIndex.from_frame(frame, names=["scorer", "individuals", "bodyparts", "coords"]) + return df2 + + # Other nlevels not expected: leave unchanged + return df + + +def align_old_new(df_old: pd.DataFrame, df_new: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: + """Align both dataframes to union of index and columns.""" + # First harmonize row index structure (deep path MI vs shallow basename index) + df_new, df_old = harmonize_keypoint_row_index(df_new, df_old) + + df_old = harmonize_keypoint_column_index(df_old) + df_new = harmonize_keypoint_column_index(df_new) + + idx = df_old.index.union(df_new.index) + cols = df_old.columns.union(df_new.columns) + return ( + df_old.reindex(index=idx, columns=cols), + df_new.reindex(index=idx, columns=cols), + ) + + +def keypoint_conflicts(df_old: pd.DataFrame, df_new: pd.DataFrame) -> pd.DataFrame: + """ + Return a boolean DataFrame indexed by image, with columns as keypoints, + True when any coord (x/y[/likelihood]) would overwrite an existing value. + + Columns in output are MultiIndex levels subset: (individuals?, bodyparts) + or just (bodyparts) for single animal. + """ + old, new = align_old_new(df_old, df_new) + + old_has = old.notna() + new_has = new.notna() + + # cell-level conflicts: both have values and differ + cell_conflict = (old != new) & old_has & new_has + + # Identify which levels exist + col_names = list(old.columns.names) + has_inds = "individuals" in col_names + has_body = "bodyparts" in col_names + has_coords = "coords" in col_names + + if not (has_body and has_coords): + # Unexpected format; fall back to cell-level summary + return cell_conflict.any(axis=1).to_frame(name="conflict") + + # Drop scorer level if present (not meaningful for end-user warning) + # We want to aggregate per (individual, bodypart) across coords. + # Build the grouping levels that define a "keypoint". + key_levels = [] + if has_inds: + key_levels.append("individuals") + key_levels.append("bodyparts") + + # Reduce across coords first -> any conflict for that coord-set + # This yields a DataFrame with columns still multi-level including scorer and coords. + # We then group by key_levels. + # Ensure we can group by dropping coords by grouping over it using "any". + # We'll group over all columns that share the same (individual/bodypart), ignoring coords. + # To do that cleanly: swap coords to last, then groupby on key_levels. + conflict_cols = cell_conflict.copy() + + # Group columns by key_levels and reduce with any() across remaining levels (coords + scorer) + # pandas no longer allows groupby on axis=1 by level names + # we use .T to swap axes, groupby on rows, then .T back to original orientation instead + key_conflict = conflict_cols.T.groupby(level=key_levels).any().T + + return key_conflict + + +def _format_image_id(img_id) -> str: + """Format a row index value into a user-friendly frame/image identifier.""" + if isinstance(img_id, tuple): + # e.g. ('labeled-data', 'test', 'img000.png') + return "/".join(map(str, img_id)) + return str(img_id) + + +def _format_keypoint_id(kp) -> str: + """Format a keypoint column label into a user-friendly identifier.""" + # kp can be: + # - scalar bodypart: "nose" + # - tuple(individual, bodypart): ("animal1", "nose") + # - larger tuple if upstream grouping shape changes + if isinstance(kp, tuple): + if len(kp) == 2: + ind, bp = kp + return f"{bp} (id: {ind})" if ind else str(bp) + return " / ".join(str(x) for x in kp if x not in (None, "")) + return str(kp) + + +def build_overwrite_conflict_report( + key_conflict: pd.DataFrame, + *, + max_entries: int = 50, + layer_name: str | None = None, + destination_path: str | None = None, +) -> OverwriteConflictReport: + """ + Convert a pandas key-conflict table into a UI-facing overwrite report. + + Parameters + ---------- + key_conflict: + Boolean-like DataFrame indexed by frame/image identifier, with columns + representing keypoints. Truthy cells indicate a keypoint overwrite conflict. + + Returns + ------- + OverwriteConflictReport + Plain-Python UI contract describing overwrite counts and detailed entries. + + Notes + ----- + This function is the pandas boundary for overwrite reporting. The UI should + depend only on OverwriteConflictReport, not on DataFrame structure. + """ + n_overwrites = int(key_conflict.to_numpy().sum()) + n_frames = int(key_conflict.any(axis=1).to_numpy().sum()) + + entries: list[ConflictEntry] = [] + + for img, row in key_conflict.iterrows(): + conflicted: list[str] = [] + for kp, flag in row.items(): + if bool(flag): + conflicted.append(_format_keypoint_id(kp)) + + if conflicted: + entries.append( + ConflictEntry( + frame_label=_format_image_id(img), + keypoints=tuple(conflicted), + ) + ) + + shown = tuple(entries[:max_entries]) + truncated = max(0, len(entries) - len(shown)) + + return OverwriteConflictReport( + n_overwrites=n_overwrites, + n_frames=n_frames, + entries=shown, + truncated_entries=truncated, + layer_name=layer_name, + destination_path=destination_path, + ) diff --git a/src/napari_deeplabcut/core/discovery.py b/src/napari_deeplabcut/core/discovery.py new file mode 100644 index 00000000..dacbf044 --- /dev/null +++ b/src/napari_deeplabcut/core/discovery.py @@ -0,0 +1,121 @@ +# src/napari_deeplabcut/core/discovery.py +""" +Deterministic discovery of DeepLabCut annotation artifacts in a folder. + +This module is pure filesystem logic (no napari imports). +It enumerates all relevant files and classifies them into AnnotationKind. + +# FUTURE NOTE hardcoded DLC structure: +# DLC naming conventions (CollectedData*, machinelabels*) are hardcoded here. +# If DLC expands formats/patterns, update ONLY this module. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path + +from natsort import natsorted + +from napari_deeplabcut.config.models import AnnotationKind + + +@dataclass(frozen=True) +class AnnotationArtifact: + """A discovered annotation artifact (H5/CSV) with inferred semantics.""" + + kind: AnnotationKind | None + h5_path: Path | None + csv_path: Path | None + stem: str + + @property + def primary_path(self) -> Path | None: + """Preferred path for opening: H5 preferred over CSV.""" + return self.h5_path or self.csv_path + + +def _infer_kind_from_stem(stem: str) -> AnnotationKind | None: + """Infer kind from filename stem.""" + low = stem.lower() + + # FUTURE NOTE hardcoded DLC structure: + if low.startswith("collecteddata"): + return AnnotationKind.GT + if low.startswith("machinelabels"): + return AnnotationKind.MACHINE + return None + + +def _is_relevant_artifact(p: Path) -> bool: + """Return True if path looks like a DLC annotation artifact.""" + if not p.is_file(): + return False + low = p.name.lower() + + # FUTURE NOTE hardcoded DLC structure: + if low.endswith(".h5") or low.endswith(".csv"): + return low.startswith("collecteddata") or low.startswith("machinelabels") + return False + + +def discover_annotations(folder: str | Path) -> list[AnnotationArtifact]: + """Discover DLC annotation artifacts in a folder (deterministic order).""" + root = Path(folder) + if not root.exists() or not root.is_dir(): + return [] + + files = [p for p in root.iterdir() if _is_relevant_artifact(p)] + files = natsorted(files, key=lambda p: p.name) + + by_stem: dict[str, dict[str, Path]] = {} + for p in files: + entry = by_stem.setdefault(p.stem, {}) + if p.suffix.lower() == ".h5": + entry["h5"] = p + elif p.suffix.lower() == ".csv": + entry["csv"] = p + + artifacts: list[AnnotationArtifact] = [] + for stem in natsorted(by_stem.keys()): + entry = by_stem[stem] + kind = _infer_kind_from_stem(stem) + artifacts.append( + AnnotationArtifact( + kind=kind, + h5_path=entry.get("h5"), + csv_path=entry.get("csv"), + stem=stem, + ) + ) + + # Stable ordering by primary filename + return natsorted(artifacts, key=lambda a: a.primary_path.name if a.primary_path else a.stem) + + +def discover_annotation_paths(folder: str | Path) -> list[Path]: + """Return primary paths to open (H5 preferred else CSV).""" + return [a.primary_path for a in discover_annotations(folder) if a.primary_path is not None] + + +def iter_annotation_candidates(paths: Iterable[str | Path]) -> list[Path]: + """Expand folders to annotation candidates (deterministic).""" + out: list[Path] = [] + for p in paths: + pp = Path(p) + if pp.is_dir(): + out.extend(discover_annotation_paths(pp)) + elif _is_relevant_artifact(pp): + out.append(pp) + return natsorted(out, key=lambda x: x.name) + + +def infer_annotation_kind_for_file(file_path: str | Path) -> AnnotationKind | None: + """Infer kind for a specific file path by scanning its parent folder.""" + fp = Path(file_path) + parent = fp.parent + for art in discover_annotations(parent): + if art.h5_path == fp or art.csv_path == fp: + return art.kind + return None diff --git a/src/napari_deeplabcut/core/errors.py b/src/napari_deeplabcut/core/errors.py new file mode 100644 index 00000000..8a486201 --- /dev/null +++ b/src/napari_deeplabcut/core/errors.py @@ -0,0 +1,30 @@ +"""Typed exceptions for napari-deeplabcut core utilities. + +They are used by IO routing/provenance helpers to express deterministic +failure modes (e.g., ambiguity) without relying on ad-hoc strings. + +No behavior changes are introduced by merely defining these exceptions. +""" + +# src/napari_deeplabcut/core/errors.py +from __future__ import annotations + + +class NapariDLCError(RuntimeError): + """Base class for napari-deeplabcut domain errors.""" + + +class MissingProvenanceError(NapariDLCError): + """Raised when required provenance is absent (cannot determine save target).""" + + +class AmbiguousSaveError(NapariDLCError): + """Raised when multiple save targets are plausible and policy forbids guessing.""" + + def __init__(self, message: str, candidates: list[str] | None = None): + super().__init__(message) + self.candidates = candidates or [] + + +class UnresolvablePathError(NapariDLCError): + """Raised when provenance exists but cannot be resolved to a concrete filesystem path.""" diff --git a/src/napari_deeplabcut/core/io.py b/src/napari_deeplabcut/core/io.py new file mode 100644 index 00000000..d7fa59d8 --- /dev/null +++ b/src/napari_deeplabcut/core/io.py @@ -0,0 +1,808 @@ +""" +Core IO utilities. + +Includes: +- Config file reading/writing +- HDF reading with provenance attachment +- Lazy image reading with Dask support +- Video reading with OpenCV and optional PyAV fallback +- Superkeypoints diagram and JSON loading +""" +# src/napari_deeplabcut/core/io.py + +from __future__ import annotations + +import fnmatch +import json +import logging +import os +from collections.abc import Callable +from importlib import resources +from pathlib import Path +from typing import Any + +import cv2 +import dask.array as da +import numpy as np +import pandas as pd +import yaml +from dask import delayed +from dask_image.imread import imread +from napari.types import LayerData +from natsort import natsorted +from pydantic import ValidationError + +from napari_deeplabcut import misc +from napari_deeplabcut.config.models import AnnotationKind, DLCHeaderModel, PointsMetadata +from napari_deeplabcut.config.settings import DEFAULT_SINGLE_ANIMAL_CMAP +from napari_deeplabcut.config.supported_files import SUPPORTED_IMAGES, SUPPORTED_VIDEOS +from napari_deeplabcut.core import schemas as dlc_schemas +from napari_deeplabcut.core.dataframes import ( + form_df_from_validated, + guarantee_multiindex_rows, + harmonize_keypoint_column_index, + harmonize_keypoint_row_index, + merge_multiple_scorers, + set_df_scorer, +) +from napari_deeplabcut.core.errors import AmbiguousSaveError, MissingProvenanceError +from napari_deeplabcut.core.layers import populate_keypoint_layer_properties +from napari_deeplabcut.core.metadata import attach_source_and_io_to_layer_kwargs, parse_points_metadata +from napari_deeplabcut.core.project_paths import ( + canonicalize_path, + find_nearest_config, + infer_dlc_project_from_points_meta, +) +from napari_deeplabcut.core.provenance import resolve_output_path_from_metadata + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Supported formats (shared by image/video readers) +# ----------------------------------------------------------------------------- +_GLOB_MAGIC = set("*?[") +_SUPPORTED_SUFFIXES = {ext.lower() for ext in SUPPORTED_IMAGES} + + +def _has_glob_magic(name: str) -> bool: + return any(ch in name for ch in _GLOB_MAGIC) + + +# ============================================================================= +# CONFIG (YAML) +# ============================================================================= + + +def load_config(config_path: str): + # NOTE: intentionally minimal; callers own error handling + with open(config_path) as file: + return yaml.safe_load(file) + + +# Read config file and create keypoint layer metadata +def read_config(configname: str) -> list[LayerData]: + config = load_config(configname) + header = DLCHeaderModel.from_config(config) + layer_props = populate_keypoint_layer_properties( + header, + size=config["dotsize"], + pcutoff=config["pcutoff"], + colormap=config["colormap"], + likelihood=np.array([1]), + ) + layer_props["name"] = f"CollectedData_{config['scorer']}" + layer_props["ndim"] = 3 + layer_props["property_choices"] = layer_props.pop("properties") + layer_props["metadata"]["project"] = str(Path(configname).parent) + layer_props["metadata"]["config_colormap"] = str(config.get("colormap", DEFAULT_SINGLE_ANIMAL_CMAP)) + + conversion_tables = config.get("SuperAnimalConversionTables") + if conversion_tables is not None: + super_animal, table = conversion_tables.popitem() + layer_props["metadata"]["tables"] = {super_animal: table} + return [(None, layer_props, "points")] + + +def write_config(config_path: str | Path, params: dict[str, Any]) -> None: + """Write DeepLabCut config.yaml parameters.""" + with open(str(config_path), "w", encoding="utf-8") as f: + yaml.safe_dump(params, f) + + +# ============================================================================= +# KEYPOINTS / ANNOTATIONS (HDF5) +# ============================================================================= +# NOTE: This reader returns a napari Points layer (data + metadata + "points") +# and attaches provenance via attach_source_and_io_to_layer_kwargs. + + +def read_hdf(filename: str) -> list[LayerData]: + layers = [] + for file in Path(filename).parent.glob(Path(filename).name): + layers.extend(read_hdf_single(file)) + return layers + + +def read_hdf_single(file: Path, *, kind: AnnotationKind | None = None) -> list[LayerData]: + """Read a single H5 file and attach provenance with optional explicit kind. + + - Produces one Points layer per H5 file + - Points.data contains only finite coordinates + - Unlabeled keypoints are omitted from Points.data + - Empty Points layers are valid + """ + temp = pd.read_hdf(str(file)) + temp = merge_multiple_scorers(temp) + header = DLCHeaderModel(columns=temp.columns) + temp = temp.droplevel("scorer", axis=1) + + # Handle legacy/single-animal column layout by inserting empty "individuals" level. + # Colormap selection also falls back to config when possible. + try: + cfg = load_config(find_nearest_config(file, max_levels=3)) + config_colormap = str(cfg.get("colormap", DEFAULT_SINGLE_ANIMAL_CMAP)) + except Exception as e: + logger.warning("Could not load config for %s; falling back to default colormap. Error: %s", file, e) + config_colormap = DEFAULT_SINGLE_ANIMAL_CMAP + if "individuals" not in temp.columns.names: + old_idx = temp.columns.to_frame() + old_idx.insert(0, "individuals", "") + temp.columns = pd.MultiIndex.from_frame(old_idx) + + # If the on-disk index is a MultiIndex (path parts), collapse it to string paths. + if isinstance(temp.index, pd.MultiIndex): + temp.index = [str(Path(*row)) for row in temp.index] + + df = ( + temp.stack(["individuals", "bodyparts"]) + .reindex(header.individuals, level="individuals") + .reindex(header.bodyparts, level="bodyparts") + .reset_index() + ) + + nrows = df.shape[0] + data = np.empty((nrows, 3)) + image_paths = df["level_0"] + + # Convert image keys to integer indices when they are already numeric, + # otherwise encode category paths deterministically. + if pd.api.types.is_numeric_dtype(getattr(image_paths, "dtype", np.asarray(image_paths).dtype)): + image_inds = image_paths.values + paths2inds = [] + else: + image_inds, paths2inds = misc.encode_categories( + image_paths, + is_path=True, + return_unique=True, + do_sort=True, + ) + + data[:, 0] = image_inds + data[:, 1:] = df[["y", "x"]].to_numpy() + finite = np.isfinite(data).all(axis=1) + # Keep only finite coords in data, but keep all rows in df for metadata completeness. + data = data[finite] + df = df.loc[finite].reset_index(drop=True) + + layer_props = populate_keypoint_layer_properties( + header, + labels=df["bodyparts"], + ids=df["individuals"], + likelihood=df.get("likelihood"), + paths=list(paths2inds), + colormap=config_colormap, + ) + layer_props["name"] = file.stem + layer_props["metadata"]["root"] = str(file.parent) + layer_props["metadata"]["name"] = layer_props["name"] + layer_props["metadata"]["config_colormap"] = config_colormap + + # Attach provenance. If explicit kind provided, we store it directly. + if kind is not None: + meta = layer_props.setdefault("metadata", {}) + # Keep legacy source fields too + attach_source_and_io_to_layer_kwargs(layer_props, file) + # Override kind in io with explicit kind arg + if isinstance(meta.get("io"), dict): + meta["io"]["kind"] = kind # stored as actual enum, not value + else: + attach_source_and_io_to_layer_kwargs(layer_props, file) + + return [(data, layer_props, "points")] + + +# TODO move to dataframes.py +def form_df( + points_data, + layer_metadata: dict, + layer_properties: dict, +) -> pd.DataFrame: + """ + Form a DataFrame from points data + layer metadata, structured according to DLC conventions. + + Arguments + --------- + points_data: + array-like of shape (N, 3) in napari-style [frame, y, x] + layer_metadata: + dict that must contain at least: 'header' (DLCHeaderModel), optional 'paths' + layer_properties: + dict that must contain: 'label', 'id', optional 'likelihood' + """ + layer_metadata = layer_metadata or {} + layer_properties = layer_properties or {} + + # ----------------------------- + # 1) Normalize/wrap header + # ----------------------------- + header_obj = layer_metadata.get("header", None) + if header_obj is None: + raise KeyError("layer_metadata['header'] is required to write DLC keypoints.") + + if isinstance(header_obj, dict) and "columns" in header_obj: + header_model = DLCHeaderModel.model_validate(header_obj) + elif isinstance(header_obj, DLCHeaderModel): + header_model = header_obj + else: + # Accept a DLCHeaderModel-like object (has .columns) + cols = getattr(header_obj, "columns", None) + if cols is None: + raise TypeError("layer_metadata['header'] must be a DLCHeaderModel or an object with a .columns attribute.") + header_model = DLCHeaderModel(columns=cols) + + # Build a PointsMetadata model from the layer_metadata dict, + # but replace raw header with our DLCHeaderModel wrapper. + meta_payload = dict(layer_metadata) + meta_payload["header"] = header_model + pts_meta = PointsMetadata.model_validate(meta_payload) + + # ----------------------------- + # 2) Fill missing likelihood (preserve old behavior) + # ----------------------------- + # Your old code assumed likelihood always existed. + # To remain backwards compatible, auto-fill with 1.0 if missing/None. + n = np.asarray(points_data).shape[0] if points_data is not None else 0 + props_payload = dict(layer_properties) + if props_payload.get("likelihood", None) is None: + props_payload["likelihood"] = [1.0] * n + + # ----------------------------- + # 3) Validate with dedicated schemas + # ----------------------------- + try: + points = dlc_schemas.PointsDataModel.model_validate({"data": points_data}) + props = dlc_schemas.KeypointPropertiesModel.model_validate(props_payload) + ctx = dlc_schemas.PointsWriteInputModel.model_validate({"points": points, "meta": pts_meta, "props": props}) + except ValidationError as e: + # Give a concise error that points to the failing part. + # The full `e` still has structured details if you want to log it. + raise ValueError(f"Invalid keypoint write inputs: {e}") from e + + # ----------------------------- + # 4) Delegate transformation + # ----------------------------- + df = form_df_from_validated(ctx) + + # Keep your belt-and-suspenders guarantee + guarantee_multiindex_rows(df) + return df + + +def _atomic_to_hdf(df: pd.DataFrame, out_path: Path, key: str = "keypoints") -> None: + """Best-effort atomic write: write to temp and replace.""" + out_path.parent.mkdir(parents=True, exist_ok=True) + tmp = out_path.with_suffix(out_path.suffix + ".tmp") + # Write temp + df.to_hdf(tmp, key=key, mode="w") + # Replace + tmp.replace(out_path) + + +def write_hdf(path: str, data, attributes: dict) -> list[str]: + """ + NPE2 single-layer writer. + + Signature required by napari (manifest-based writers): + def writer(path: str, data: Any, attributes: dict) -> List[str] + Writers must return a list of successfully-written paths. + + Contract: + - Empty Points layers may be written only if promoted + - Finite Points must always produce finite stored coordinates + + This function writes DLC keypoints to .h5 (and companion .csv). + """ + attrs = dlc_schemas.PointsLayerAttributesModel.model_validate(attributes or {}) + pts_meta: PointsMetadata = parse_points_metadata(attrs.metadata, drop_header=False) + if not pts_meta.header: + raise ValueError("Layer metadata must include a valid DLC header to write keypoints.") + + points = dlc_schemas.PointsDataModel.model_validate({"data": data}) + props = dlc_schemas.KeypointPropertiesModel.model_validate(attrs.properties) + + # Bundle + validate cross-field invariants + ctx = dlc_schemas.PointsWriteInputModel.model_validate( + { + "points": points, + "meta": pts_meta, + "props": props, + } + ) + + logger.debug("HEADER nlevels: %s", ctx.meta.header.as_multiindex().nlevels) + logger.debug("HEADER names: %s", ctx.meta.header.as_multiindex().names) + + # Build df from points + plugin metadata + layer properties + df_new = form_df_from_validated(ctx) + + logger.debug("DF_NEW columns nlevels: %s", df_new.columns.nlevels) + logger.debug("DF_NEW columns names: %s", df_new.columns.names) + logger.debug("DF_NEW finite count: %s", np.isfinite(df_new.to_numpy()).sum()) + + # Decide output path: + # 1) User-requested path should be ignored in favor of provenance when available + # This is a fallback only used when provenance is missing or unresolvable, + # and is never expected to be set for this plugin + # requested_out = _normalize_requested_out_path(path, layer_name) + + # 2) provenance/save_target is always the source of truth for where to write + out_path, target_scorer, source_kind = resolve_output_path_from_metadata(attributes) + + # If promoting to GT and scorer is known, rewrite scorer level + if target_scorer: + df_new = set_df_scorer(df_new, target_scorer) + + # Never write back to machine sources without an explicit promotion target + if not out_path and source_kind == AnnotationKind.MACHINE: + raise MissingProvenanceError("Cannot resolve provenance output path for MACHINE source.") + + # If provenance returned nothing, default to requested path + if not out_path: + # Strict only for MACHINE + if source_kind == AnnotationKind.MACHINE: + raise MissingProvenanceError("Cannot resolve provenance output path for MACHINE source.") + + # Prefer dataset folder if inferable (DLC convention) + project_ctx = infer_dlc_project_from_points_meta(pts_meta, prefer_project_root=False) + dataset_dir = None + if project_ctx is not None and project_ctx.dataset_folder is not None: + dataset_dir = project_ctx.dataset_folder + + # If dataset_dir exists or can be created, use it; else fall back to pts_meta.root + if dataset_dir is not None: + dataset_dir.mkdir(parents=True, exist_ok=True) + root_path = dataset_dir + else: + root = pts_meta.root + if not root: + raise MissingProvenanceError("GT fallback requires root (and dataset folder could not be inferred).") + root_path = Path(root) + + candidates = sorted(root_path.glob("CollectedData_*.h5")) + if len(candidates) > 1: + raise AmbiguousSaveError( + f"Multiple CollectedData_*.h5 files found in {root_path}." + " Cannot determine where to save." + " Please specify a save_target with explicit path and scorer.", + candidates=[str(c) for c in candidates], + ) + elif len(candidates) == 1: + out = candidates[0] + else: + scorer = target_scorer or pts_meta.header.scorer + out = root_path / f"CollectedData_{scorer}.h5" + else: + out = Path(out_path) + + # Determine destination kind (promotion writes to GT target) + has_save_target = pts_meta.save_target is not None + destination_kind = ( + AnnotationKind.GT + if has_save_target + else ((pts_meta.io.kind if pts_meta.io is not None else None) or AnnotationKind.GT) + ) + + # Merge-on-save for GT + if destination_kind == AnnotationKind.GT and out.exists(): + try: + df_old = pd.read_hdf(out, key="keypoints") + except (KeyError, ValueError): + df_old = pd.read_hdf(out) + + # Harmonize indices and merge + try: + guarantee_multiindex_rows(df_new) + guarantee_multiindex_rows(df_old) + except Exception: + pass + + df_new, df_old = harmonize_keypoint_row_index(df_new, df_old) + df_new = harmonize_keypoint_column_index(df_new) + df_old = harmonize_keypoint_column_index(df_old) + df_out = df_new.combine_first(df_old) + + # Normalize columns to DLC header if possible + try: + header = DLCHeaderModel(columns=df_out.columns) + df_out = df_out.reindex(header.columns, axis=1) + except Exception: + pass + else: + df_out = df_new + + # Final cleanup + try: + guarantee_multiindex_rows(df_out) + except Exception: + pass + df_out.sort_index(inplace=True) + + # Write .h5 and .csv + _atomic_to_hdf(df_out, out, key="keypoints") + csv_path = out.with_suffix(".csv") + df_out.to_csv(csv_path) + + return [str(out), str(csv_path)] + + +# ============================================================================= +# SUPERKEYPOINTS (assets: diagram + JSON) +# ============================================================================= +# NOTE: These are used to support DLCHeaderModel superkeypoints workflows. + + +def load_superkeypoints_json_from_path(json_path: str | Path): + path = Path(json_path) + if not path.is_file(): + raise FileNotFoundError(f"Superkeypoints JSON file not found at {json_path}.") + with open(path) as f: + payload = json.load(f) + if payload: + return payload + else: + raise ValueError(f"Superkeypoints JSON file at {json_path} is empty or invalid.") + + +def load_superkeypoints_diagram_from_path(image_path: str | Path): + path = Path(image_path) + if not path.is_file(): + raise FileNotFoundError(f"Superkeypoints diagram not found at {image_path}.") + try: + return imread(path).squeeze() + except Exception as e: + raise RuntimeError(f"Superkeypoints diagram could not be loaded from {image_path}.") from e + + +def load_superkeypoints_diagram(super_animal: str): + path = resources.files("napari_deeplabcut") / "assets" / f"{super_animal}.jpg" + return load_superkeypoints_diagram_from_path(path) + + +def load_superkeypoints(super_animal: str): + path = resources.files("napari_deeplabcut") / "assets" / f"{super_animal}.json" + return load_superkeypoints_json_from_path(path) + + +# ============================================================================= +# IMAGES (lazy stack with Dask) +# ============================================================================= +# NOTE: Image reading uses OpenCV for normalization and Dask for laziness. + + +# Helper functions for lazy image reading and normalization +# NOTE : forced keyword-only arguments for clarity +def _read_and_normalize(*, filepath: Path, normalize_func: Callable[[np.ndarray], np.ndarray]) -> np.ndarray: + arr = cv2.imread(str(filepath), cv2.IMREAD_UNCHANGED) + if arr is None: + raise OSError(f"Could not read image: {filepath}") + return normalize_func(arr) + + +def _normalize_to_rgb(arr: np.ndarray) -> np.ndarray: + if arr.ndim == 2: + return cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB) + if arr.ndim == 3 and arr.shape[2] == 4: + return cv2.cvtColor(arr, cv2.COLOR_BGRA2RGB) + return cv2.cvtColor(arr, cv2.COLOR_BGR2RGB) + + +# FIXME remove later +# def _expand_image_paths(path: str | Path | list[str | Path] | tuple[str | Path, ...]) -> list[Path]: +# # Normalize input to list[Path] +# raw_paths = [Path(p) for p in path] if isinstance(path, (list, tuple)) else [Path(path)] + +# expanded: list[Path] = [] +# for p in tqdm(raw_paths, desc="Expanding image paths", leave=False, unit="files"): +# if p.is_dir() and p.suffix.lower() != ".zarr": +# file_matches: list[Path] = [] +# for ext in SUPPORTED_IMAGES: +# file_matches.extend(p.glob(f"*{ext}")) +# expanded.extend(x for x in natsorted(file_matches, key=str) if x.is_file()) +# else: +# matches = list(p.parent.glob(p.name)) +# expanded.extend(matches or [p]) + +# return [p for p in expanded if p.is_file() and p.suffix.lower() in SUPPORTED_IMAGES] + + +def _expand_image_paths(path: str | Path | list[str | Path] | tuple[str | Path, ...]) -> list[Path]: + raw_paths = [Path(p) for p in path] if isinstance(path, (list, tuple)) else [Path(path)] + expanded: list[Path] = [] + + for p in raw_paths: + if p.is_dir() and p.suffix.lower() != ".zarr": + try: + with os.scandir(p) as it: + files = [ + Path(entry.path) + for entry in it + if entry.is_file() and Path(entry.name).suffix.lower() in _SUPPORTED_SUFFIXES + ] + except OSError: + continue + + files.sort(key=lambda q: q.name) + expanded.extend(files) + continue + + if not _has_glob_magic(p.name): + if p.is_file() and p.suffix.lower() in _SUPPORTED_SUFFIXES: + expanded.append(p) + continue + + parent = p.parent if str(p.parent) else Path(".") + pattern = p.name + try: + with os.scandir(parent) as it: + matches = [ + Path(entry.path) + for entry in it + if entry.is_file() + and fnmatch.fnmatchcase(entry.name, pattern) + and Path(entry.name).suffix.lower() in _SUPPORTED_SUFFIXES + ] + except OSError: + continue + + matches.sort(key=lambda q: q.name) + expanded.extend(matches) + + return expanded + + +# Lazy image reader that supports directories and lists of files +def _lazy_imread( + filenames: str | Path | list[str | Path], + use_dask: bool | None = None, + stack: bool = True, +) -> np.ndarray | da.Array | list[np.ndarray | da.Array]: + """Lazily reads one or more images with optional Dask support. + + Resolves file paths using `_expand_image_paths`, ensuring consistent + handling of directories, glob patterns, and lists/tuples of paths. + Images are normalized to RGB and may be wrapped in Dask delayed + objects for lazy loading. + + Behavior: + * If a single image is resolved: + - The image is read eagerly and returned as a NumPy array. + * If multiple images are resolved: + - The first image is read eagerly to determine shape and dtype. + - Subsequent images are loaded lazily via Dask unless + `use_dask=False`. + - Stacking behavior is controlled by `stack`. + + Args: + filenames (str | Path | list[str | Path]): + File path(s), directory, or glob pattern(s) to load. + use_dask (bool | None, optional): + Whether to load images lazily using Dask. + Defaults to `True` when multiple files are found, otherwise + `False`. + stack (bool, optional): + If True, stack images along axis 0 into a single array. + If False, return a list of arrays or delayed arrays. + Defaults to True. + + Returns: + np.ndarray | da.Array | list[np.ndarray | da.Array]: + Loaded image data. The return type depends on the number of + images found, the `use_dask` flag, and the `stack` option. + + Raises: + ValueError: If no supported images are found. + """ + expanded = _expand_image_paths(filenames) + + if not expanded: + raise ValueError(f"No supported images were found for input: {filenames}") + + if use_dask is None: + use_dask = len(expanded) > 1 + + images = [] + first_shape = None + first_dtype = None + + def make_delayed_array(fp: Path, first_shape: tuple[int, ...], first_dtype: np.dtype) -> da.Array: + """Create a dask array for a single file.""" + return da.from_delayed( + delayed(_read_and_normalize)(filepath=fp, normalize_func=_normalize_to_rgb), + shape=first_shape, + dtype=first_dtype, + ) + + for fp in expanded: + if first_shape is None: + arr0 = _read_and_normalize(filepath=fp, normalize_func=_normalize_to_rgb) + first_shape = arr0.shape + first_dtype = arr0.dtype + + if use_dask: + images.append(make_delayed_array(fp, first_shape, first_dtype)) + else: + images.append(arr0) + continue + + if use_dask: + images.append(make_delayed_array(fp, first_shape, first_dtype)) + else: + images.append(_read_and_normalize(filepath=fp, normalize_func=_normalize_to_rgb)) + + if len(images) == 1: + return images[0] + + try: + return da.stack(images) if use_dask and stack else (np.stack(images) if stack else images) + except ValueError as e: + raise ValueError( + "Cannot stack images with different shapes using NumPy. " + "Ensure all images have the same shape or set stack=False." + ) from e + + +# Read images from a list of files or a glob/string path +def read_images(path: str | Path | list[str | Path]): + """Reads one or multiple images and returns a Napari Image layer. + + Uses `_expand_image_paths` to resolve the input into a list of valid + image files. Supports single paths, glob expressions, directories, + and lists or tuples of such paths. + + Behavior: + * If one file is found: + - Loaded using `dask_image.imread.imread`. + * If multiple files are found: + - Loaded lazily using `lazy_imread` into a stacked image + layer. + + Args: + path (str | Path | list[str | Path]): + Input path(s), directory, or glob pattern(s) to expand into + supported image files. + + Returns: + list[LayerData]: + A list containing one Napari layer tuple of the form + `(data, metadata, "image")`. + + Raises: + OSError: If no supported images are found after expansion. + """ + filepaths = _expand_image_paths(path) + + if not filepaths: + raise OSError(f"No supported images were found in {path}") + + filepaths = natsorted(filepaths, key=str) + + # Multiple images → lazy-imread stack + if len(filepaths) > 1: + # NOTE: canonicalize_path(fp, 3) stores a stable relative-ish path for the UI/metadata. + relative_paths = [canonicalize_path(fp, 3) for fp in filepaths] + params = { + "name": "images", + "metadata": { + "paths": relative_paths, + "root": str(filepaths[0].parent), + }, + } + data = _lazy_imread(filepaths, use_dask=True, stack=True) + return [(data, params, "image")] + + # Single image → old behavior + image_path = filepaths[0] + params = { + "name": "images", + "metadata": { + "paths": [canonicalize_path(image_path, 3)], + "root": str(image_path.parent), + }, + } + return [(imread(str(image_path)), params, "image")] + + +# ============================================================================= +# VIDEO (OpenCV; optional PyAV fallback) +# ============================================================================= + + +def is_video(filename: str) -> bool: + return any(filename.lower().endswith(ext) for ext in SUPPORTED_VIDEOS) + + +# Video reader using OpenCV +class Video: + def __init__(self, video_path): + if not Path(video_path).is_file(): + raise ValueError(f'Video path "{video_path}" does not point to a file.') + + self.path = video_path + self.stream = cv2.VideoCapture(video_path) + if not self.stream.isOpened(): + raise OSError("Video could not be opened.") + + self._n_frames = int(self.stream.get(cv2.CAP_PROP_FRAME_COUNT)) + self._width = int(self.stream.get(cv2.CAP_PROP_FRAME_WIDTH)) + self._height = int(self.stream.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self._frame = cv2.UMat(self._height, self._width, cv2.CV_8UC3) + + def __len__(self): + return self._n_frames + + @property + def width(self): + return self._width + + @property + def height(self): + return self._height + + def set_to_frame(self, ind): + ind = min(ind, len(self) - 1) + ind += 1 # Unclear why this is needed at all + self.stream.set(cv2.CAP_PROP_POS_FRAMES, ind) + + def read_frame(self): + self.stream.retrieve(self._frame) + cv2.cvtColor(self._frame, cv2.COLOR_BGR2RGB, self._frame, 3) + return self._frame.get() + + def close(self): + self.stream.release() + + +def read_video(filename: str, opencv: bool = True): + if opencv: + stream = Video(filename) + # NOTE construct output shape tuple in (H, W, C) order to match read_frame() data + shape = stream.height, stream.width, 3 + + def _read_frame(ind): + stream.set_to_frame(ind) + return stream.read_frame() + + lazy_reader = delayed(_read_frame) + else: # pragma: no cover + from pims import PyAVReaderIndexed + + try: + stream = PyAVReaderIndexed(filename) + except ImportError: + raise ImportError("`pip install av` to use the PyAV video reader.") from None + + shape = stream.frame_shape + lazy_reader = delayed(stream.get_frame) + + movie = da.stack([da.from_delayed(lazy_reader(i), shape=shape, dtype=np.uint8) for i in range(len(stream))]) + elems = list(Path(filename).parts) + elems[-2] = "labeled-data" + elems[-1] = Path(elems[-1]).stem # + Path(filename).suffix + root = str(Path(*elems)) + params = { + "name": filename, + "metadata": { + "root": root, + }, + } + return [(movie, params, "image")] diff --git a/src/napari_deeplabcut/core/keypoints.py b/src/napari_deeplabcut/core/keypoints.py new file mode 100644 index 00000000..ff91b27f --- /dev/null +++ b/src/napari_deeplabcut/core/keypoints.py @@ -0,0 +1,448 @@ +# src/napari_deeplabcut/keypoints.py +import logging +from collections import namedtuple +from collections.abc import Sequence +from enum import auto + +import numpy as np +from matplotlib import colormaps as mpl_colormaps +from napari._qt.layer_controls.qt_points_controls import QtPointsControls +from napari.layers import Points +from napari.layers.points._points_constants import SYMBOL_TRANSLATION_INVERTED +from napari.layers.points._points_utils import coerce_symbols +from napari.utils import colormaps +from pydantic import ValidationError +from scipy.spatial import cKDTree + +from napari_deeplabcut.config.models import DLCHeaderModel +from napari_deeplabcut.core.metadata import read_points_meta +from napari_deeplabcut.misc import CycleEnum, HeaderLike + +logger = logging.getLogger(__name__) + + +# Monkeypatch the point size slider +def _change_size(self, value): + """Resize all points at once regardless of the current selection.""" + self.layer._current_size = value + if self.layer._update_properties: + self.layer.size = (self.layer.size > 0) * value + self.layer.refresh() + self.layer.events.size() + + +def _change_symbol(self, text): + symbol = coerce_symbols(np.array([SYMBOL_TRANSLATION_INVERTED[text]]))[0] + self.layer._current_symbol = symbol + if self.layer._update_properties: + self.layer.symbol = symbol + self.layer.events.symbol() + self.layer.events.current_symbol() + + +QtPointsControls.changeCurrentSize = _change_size +QtPointsControls.changeCurrentSymbol = _change_symbol + + +def _validate_points_meta_best_effort(layer) -> bool: + """ + Phase-2 friendly: validate points metadata without mutating it. + We drop header + controls during validation to avoid runtime-object issues. + """ + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=True) + if isinstance(res, ValidationError): + logger.debug("Points metadata invalid for layer=%r: %s", getattr(layer, "name", layer), res) + return False + return True + + +class ColorMode(CycleEnum): + """Modes in which keypoints can be colored + + BODYPART: the keypoints are grouped by bodypart (all bodyparts have the same color) + INDIVIDUAL: the keypoints are grouped by individual (all keypoints for the same + individual have the same color) + """ + + BODYPART = auto() + INDIVIDUAL = auto() + + @classmethod + def default(cls): + return cls.BODYPART + + +class LabelMode(CycleEnum): + """ + Labeling modes. + SEQUENTIAL: points are placed in sequence, then frame after frame; + clicking to add an already annotated point has no effect. + QUICK: similar to SEQUENTIAL, but trying to add an already + annotated point actually moves it to the cursor location. + LOOP: the currently selected point is placed frame after frame, + before wrapping at the end to frame 1, etc. + """ + + SEQUENTIAL = auto() + QUICK = auto() + LOOP = auto() + + @classmethod + def default(cls): + return cls.SEQUENTIAL + + +# Description tooltips for the labeling modes radio buttons. +TOOLTIPS = { + "SEQUENTIAL": "Points are placed in sequence, then frame after frame;\n" + "clicking to add an already annotated point has no effect.", + "QUICK": "Similar to SEQUENTIAL, but trying to add an already\n" + "annotated point actually moves it to the cursor location.", + "LOOP": "The currently selected point is placed frame after frame,\nbefore wrapping at the end to frame 1, etc.", +} + + +Keypoint = namedtuple("Keypoint", ["label", "id"]) + + +class KeypointStore: + def __init__(self, viewer, layer: Points): + self.viewer = viewer + self._keypoints = [] + self._header: DLCHeaderModel | None = None + self.layer = layer + self.viewer.dims.set_current_step(0, 0) + + @property + def layer(self): + return self._layer + + @layer.setter + def layer(self, layer): + self._layer = layer + + res = read_points_meta(layer, migrate_legacy=True, drop_controls=True, drop_header=False) + if isinstance(res, ValidationError) or res.header is None: + self._header = None + self._keypoints = [] + return + + self._header = res.header + pairs = self._header.form_individual_bodypart_pairs() + self._keypoints = [Keypoint(label, id_) for id_, label in pairs] + + @property + def labels(self) -> list[str]: + return self._header.bodyparts if self._header is not None else [] + + @property + def ids(self) -> list[str]: + return self._header.individuals if self._header is not None else [] + + @property + def current_step(self): + return self.viewer.dims.current_step[0] + + @property + def n_steps(self): + return self.viewer.dims.nsteps[0] + + @property + def annotated_keypoints(self) -> list[Keypoint]: + mask = self.current_mask + labels = self.layer.properties["label"][mask] + ids = self.layer.properties["id"][mask] + return [Keypoint(label, id_) for label, id_ in zip(labels, ids, strict=False)] + + @property + def current_mask(self) -> Sequence[bool]: + return np.asarray(self.layer.data[:, 0] == self.current_step) + + @property + def current_keypoint(self) -> Keypoint: + props = getattr(self.layer, "current_properties", {}) or {} + try: + label = props.get("label", [""])[0] + except Exception: + label = "" + try: + id_ = props.get("id", [""])[0] + except Exception: + id_ = "" + return Keypoint(label=label, id=id_) + + @current_keypoint.setter + def current_keypoint(self, keypoint: Keypoint): + # Avoid changing the properties of a selected point + if not len(self.layer.selected_data): + current_properties = self.layer.current_properties + current_properties["label"] = np.asarray([keypoint.label]) + current_properties["id"] = np.asarray([keypoint.id]) + self.layer.current_properties = current_properties + + def next_keypoint(self, *args): + ind = self._keypoints.index(self.current_keypoint) + 1 + if ind <= len(self._keypoints) - 1: + self.current_keypoint = self._keypoints[ind] + + def prev_keypoint(self, *args): + ind = self._keypoints.index(self.current_keypoint) - 1 + if ind >= 0: + self.current_keypoint = self._keypoints[ind] + + @property + def current_label(self) -> str: + return self.layer.current_properties["label"][0] + + @current_label.setter + def current_label(self, label: str): + if not len(self.layer.selected_data): + current_properties = self.layer.current_properties + current_properties["label"] = np.asarray([label]) + self.layer.current_properties = current_properties + + @property + def current_id(self) -> str: + return self.layer.current_properties["id"][0] + + @current_id.setter + def current_id(self, id_: str): + if not len(self.layer.selected_data): + current_properties = self.layer.current_properties + current_properties["id"] = np.asarray([id_]) + self.layer.current_properties = current_properties + + def _advance_step(self, event): + ind = (self.current_step + 1) % self.n_steps + self.viewer.dims.set_current_step(0, ind) + + def _find_first_unlabeled_frame(self, event): + inds = set(range(self.n_steps)) + unlabeled_inds = inds.difference(self.layer.data[:, 0].astype(int)) + if not unlabeled_inds: + self.viewer.dims.set_current_step(0, self.n_steps - 1) + else: + self.viewer.dims.set_current_step(0, min(unlabeled_inds)) + + +def _add(store, coord): + coord = np.atleast_2d(coord) + + # Controls are runtime-only; prefer layer attribute, fall back to metadata. + get_mode = getattr(store, "_get_label_mode", None) + label_mode = get_mode() if callable(get_mode) else None + + if store.current_keypoint not in store.annotated_keypoints: + # 1) append data + store.layer.data = np.append(store.layer.data, coord, axis=0) + + # 2) append/align properties to match number of points + kp = store.current_keypoint + n_new = coord.shape[0] + n_total = len(store.layer.data) + n_old = n_total - n_new + + props = store.layer.properties.copy() + + def _as_array(key, dtype): + arr = props.get(key, None) + if arr is None: + return np.array([], dtype=dtype) + return np.asarray(arr, dtype=dtype) + + # Existing values truncated/padded to n_old, then append new rows + label_arr = _as_array("label", object)[:n_old] + id_arr = _as_array("id", object)[:n_old] + lik_arr = _as_array("likelihood", float)[:n_old] + + # If any are shorter than n_old, pad (rare but safe) + if label_arr.size < n_old: + label_arr = np.concatenate([label_arr, np.array([kp.label] * (n_old - label_arr.size), dtype=object)]) + if id_arr.size < n_old: + id_arr = np.concatenate([id_arr, np.array([kp.id] * (n_old - id_arr.size), dtype=object)]) + if lik_arr.size < n_old: + lik_arr = np.concatenate([lik_arr, np.ones(n_old - lik_arr.size, dtype=float)]) + + props["label"] = np.concatenate([label_arr, np.array([kp.label] * n_new, dtype=object)]) + props["id"] = np.concatenate([id_arr, np.array([kp.id] * n_new, dtype=object)]) + props["likelihood"] = np.concatenate([lik_arr, np.ones(n_new, dtype=float)]) + + store.layer.properties = props + + elif label_mode is LabelMode.QUICK: + ind = store.annotated_keypoints.index(store.current_keypoint) + data = store.layer.data + data[np.flatnonzero(store.current_mask)[ind]] = coord.squeeze() + store.layer.data = data + + store.layer.selected_data = set() + + # If controls are missing, behave like the default mode (advance keypoint) + if label_mode is LabelMode.LOOP: + store.layer.events.query_next_frame() + else: + store.next_keypoint() + + +def _find_nearest_neighbors(xy_true, xy_pred, k=5): + n_preds = xy_pred.shape[0] + tree = cKDTree(xy_pred) + dist, inds = tree.query(xy_true, k=k) + idx = np.argsort(dist[:, 0]) + neighbors = np.full(len(xy_true), -1, dtype=int) + picked = set() + for i, ind in enumerate(inds[idx]): + for j in ind: + if j not in picked: + picked.add(j) + neighbors[idx[i]] = j + break + if len(picked) == n_preds: + break + return neighbors + + +# ---------------------------- +# Colormap functions +# ---------------------------- +def _rgba_array(colors) -> np.ndarray: + """Normalize a color list/array to float RGBA shape (N, 4).""" + arr = np.asarray(colors, dtype=float) + + if arr.size == 0: + return np.empty((0, 4), dtype=float) + + if arr.ndim == 1: + if arr.shape[0] == 3: + arr = np.r_[arr, 1.0][None, :] + elif arr.shape[0] == 4: + arr = arr[None, :] + else: + raise ValueError(f"Unexpected color shape: {arr.shape!r}") + elif arr.ndim == 2 and arr.shape[1] == 3: + arr = np.c_[arr, np.ones(len(arr), dtype=float)] + elif arr.ndim != 2 or arr.shape[1] != 4: + raise ValueError(f"Unexpected colors array shape: {arr.shape!r}") + + return np.asarray(arr, dtype=float) + + +def _repeat_or_trim(colors: np.ndarray, n_colors: int) -> np.ndarray: + """Return exactly n_colors rows by trimming or cycling a palette.""" + if n_colors <= 0: + return np.empty((0, 4), dtype=float) + + colors = _rgba_array(colors) + if len(colors) == 0: + return np.empty((0, 4), dtype=float) + + if len(colors) >= n_colors: + return colors[:n_colors] + + reps = int(np.ceil(n_colors / len(colors))) + out = np.tile(colors, (reps, 1))[:n_colors] + + logger.debug( + "Requested %d colors from a listed palette of length %d; cycling palette.", + n_colors, + len(colors), + ) + return out + + +def _try_matplotlib_listed_colors(colormap: str | None) -> np.ndarray | None: + """ + Return listed RGBA colors from a matplotlib colormap when available. + + This is the preferred path for qualitative palettes like Set3, tab10, tab20, + Dark2, etc., because they should be treated as discrete palettes, not sampled + continuously. + """ + if not colormap: + return None + + try: + mpl_cmap = mpl_colormaps.get_cmap(colormap) + except Exception: + return None + + listed = getattr(mpl_cmap, "colors", None) + if listed is None: + return None + + try: + return _rgba_array(listed) + except Exception: + logger.debug("Failed to normalize matplotlib listed colors for %r", colormap, exc_info=True) + return None + + +def _sample_continuous_colormap(cmap, n_colors: int) -> np.ndarray: + """ + Sample a continuous colormap at bin centers. + + Using centers instead of endpoints avoids repeated-looking adjacent colors and + behaves better for categorical assignment. + """ + if n_colors <= 0: + return np.empty((0, 4), dtype=float) + + values = (np.arange(n_colors, dtype=float) + 0.5) / n_colors + return _rgba_array(cmap.map(values)) + + +def build_color_cycle(n_colors: int, colormap: str | None = "viridis") -> np.ndarray: + """ + Build a robust RGBA color cycle. + + Policy + ------ + 1) If `colormap` is a listed matplotlib palette (e.g. Set3, tab10, tab20, + Dark2...), use its listed colors directly. + 2) Otherwise, resolve via napari and sample at bin centers. + """ + if n_colors <= 0: + return np.empty((0, 4), dtype=float) + + # Prefer discrete/listed matplotlib palettes directly. + listed = _try_matplotlib_listed_colors(colormap) + if listed is not None and len(listed) > 0: + return _repeat_or_trim(listed, n_colors) + + # Fall back to napari colormap resolution. + cmap = colormaps.ensure_colormap(colormap) + + # If napari resolved something that itself behaves like a listed palette, + # prefer those colors directly as well. + try: + cmap_colors = getattr(cmap, "colors", None) + if cmap_colors is not None: + cmap_colors = _rgba_array(cmap_colors) + interp = str(getattr(cmap, "interpolation", "")).lower() + if len(cmap_colors) > 0 and interp == "zero": + return _repeat_or_trim(cmap_colors, n_colors) + except Exception: + logger.debug("Failed to inspect napari colormap %r for listed colors", colormap, exc_info=True) + + return _sample_continuous_colormap(cmap, n_colors) + + +def build_color_cycles(header: HeaderLike, colormap: str | None = "viridis"): + """ + Build categorical label/id color mappings from a DLC-style header. + + Notes + ----- + - bodyparts always preserve header order + - individuals preserve header order, excluding blank single-animal placeholders + """ + bodyparts = [str(x) for x in header.bodyparts] + individuals = [str(x) for x in header.individuals if str(x) != ""] + + label_colors = build_color_cycle(len(bodyparts), colormap) + id_colors = build_color_cycle(len(individuals), colormap) + + return { + "label": dict(zip(bodyparts, label_colors, strict=False)), + "id": dict(zip(individuals, id_colors, strict=False)), + } diff --git a/src/napari_deeplabcut/core/layer_versioning.py b/src/napari_deeplabcut/core/layer_versioning.py new file mode 100644 index 00000000..70893902 --- /dev/null +++ b/src/napari_deeplabcut/core/layer_versioning.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from weakref import WeakKeyDictionary, ref + +from napari.layers import Layer, Points + + +@dataclass +class LayerChangeGenerations: + """ + Monotonic "versioning" tokens for a layer. + This helps build any derived state that needs to be invalidated on upstream changes, + without needing to inspect the nature of those changes. + + Notes + ----- + - `content` is intended for semantic/model changes (data, properties, features). + - `presentation` is intended for display/config changes (metadata, visual config). + """ + + content: int = 0 + presentation: int = 0 + + def bump_content(self) -> None: + self.content += 1 + + def bump_presentation(self) -> None: + self.presentation += 1 + + +@dataclass +class _LayerState: + generations: LayerChangeGenerations = field(default_factory=LayerChangeGenerations) + connections: list[tuple[object, Callable]] = field(default_factory=list) + installed: bool = False + + +class LayerChangeRegistry: + """ + Centralized registry for per-layer change generations and reusable event hooks. + Layer agnostic, but with some Points-specific hooks currently. + If the hook list needs updated, check `_content_emitters` and `_presentation_emitters` below. + + Design goals + ------------ + - O(1) version-token reads + - one-time event hookup per layer + - explicit detach support + - no mutation of napari internals + """ + + def __init__(self) -> None: + self._states: WeakKeyDictionary[Layer, _LayerState] = WeakKeyDictionary() + + def ensure_hooks(self, layer: Layer) -> LayerChangeGenerations: + state = self._states.get(layer) + if state is None: + state = _LayerState() + self._states[layer] = state + + if not state.installed: + self._install_hooks(layer, state) + state.installed = True + + return state.generations + + def generations_for(self, layer: Layer) -> LayerChangeGenerations: + return self.ensure_hooks(layer) + + def mark_content_changed(self, layer: Layer) -> None: + self.ensure_hooks(layer).bump_content() + + def mark_presentation_changed(self, layer: Layer) -> None: + self.ensure_hooks(layer).bump_presentation() + + def detach(self, layer: Layer) -> None: + state = self._states.get(layer) + if state is None: + return + + for emitter, callback in state.connections: + try: + emitter.disconnect(callback) + except Exception: + # Best-effort disconnect; emitter/layer may already be torn down. + pass + + state.connections.clear() + state.installed = False + self._states.pop(layer, None) + + # ---------- Private ---------- + + def _install_hooks(self, layer: Layer, state: _LayerState) -> None: + layer_ref = ref(layer) + + def _with_layer(fn: Callable[[Layer], None]) -> Callable: + def _callback(event=None) -> None: + target = layer_ref() + if target is None: + return + fn(target) + + return _callback + + def _on_content_change(target: Layer) -> None: + state = self._states.get(target) + if state is None: + return + state.generations.bump_content() + + def _on_presentation_change(target: Layer) -> None: + state = self._states.get(target) + if state is None: + return + state.generations.bump_presentation() + + for emitter in self._content_emitters(layer): + callback = _with_layer(_on_content_change) + emitter.connect(callback) + state.connections.append((emitter, callback)) + + for emitter in self._presentation_emitters(layer): + callback = _with_layer(_on_presentation_change) + emitter.connect(callback) + state.connections.append((emitter, callback)) + + def _content_emitters(self, layer: Layer) -> list[object]: + emitters = [ + getattr(layer.events, "data", None), + getattr(layer.events, "set_data", None), + ] + + # Points-specific semantic state + if isinstance(layer, Points): + emitters.extend( + [ + getattr(layer.events, "properties", None), + getattr(layer.events, "features", None), + ] + ) + + return [emitter for emitter in emitters if emitter is not None] + + def _presentation_emitters(self, layer: Layer) -> list[object]: + emitters = [ + getattr(layer.events, "metadata", None), + ] + return [emitter for emitter in emitters if emitter is not None] + + +_LAYER_CHANGES = LayerChangeRegistry() + + +def ensure_layer_change_hooks(layer: Layer) -> LayerChangeGenerations: + return _LAYER_CHANGES.ensure_hooks(layer) + + +def layer_change_generations(layer: Layer) -> LayerChangeGenerations: + return _LAYER_CHANGES.generations_for(layer) + + +def mark_layer_content_changed(layer: Layer) -> None: + _LAYER_CHANGES.mark_content_changed(layer) + + +def mark_layer_presentation_changed(layer: Layer) -> None: + _LAYER_CHANGES.mark_presentation_changed(layer) + + +def detach_layer_change_hooks(layer: Layer) -> None: + _LAYER_CHANGES.detach(layer) diff --git a/src/napari_deeplabcut/core/layers.py b/src/napari_deeplabcut/core/layers.py new file mode 100644 index 00000000..afbc9ea7 --- /dev/null +++ b/src/napari_deeplabcut/core/layers.py @@ -0,0 +1,392 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable, Iterable, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any, TypeVar + +import numpy as np +from napari.layers import Image, Points, Shapes, Tracks + +from napari_deeplabcut.config.models import AnnotationKind, DLCHeaderModel +from napari_deeplabcut.core.keypoints import build_color_cycles + +T = TypeVar("T") + +# TODO move to a layers/ folder? +logger = logging.getLogger(__name__) + + +# Helper to populate keypoint layer properties +def populate_keypoint_layer_properties( + header: DLCHeaderModel, + *, + labels: Sequence[str] | None = None, + ids: Sequence[str] | None = None, + likelihood: Sequence[float] | None = None, + paths: list[str] | None = None, + size: int | None = 8, + pcutoff: float | None = 0.6, + colormap: str | None = "viridis", +) -> dict: + """ + Populate metadata and display properties for a keypoint Points layer. + + Notes + ----- + - Single-animal DLC: "individuals" level is effectively absent; we represent + that as ids[0] == "" (falsy) => color/text by label. + - Multi-animal DLC: ids[0] is a non-empty individual identifier => color/text by id. + - Must accept empty labels/ids/likelihood and must not assume ≥ 1 entry. + """ + + if labels is None: + labels = header.bodyparts + if ids is None: + ids = header.individuals + if likelihood is None: + likelihood_arr = np.ones(len(labels), dtype=float) + else: + likelihood_arr = np.asarray(likelihood, dtype=float) + + # 1) Normalize inputs to plain lists (Series-safe) + # This prevents pandas Series truthiness errors. + labels_list = list(labels) if labels is not None else [] + ids_list = list(ids) if ids is not None else [] + + # 2) Likelihood: always numeric ndarray for vector ops + if likelihood is None: + likelihood_arr = np.ones(len(labels_list), dtype=float) + else: + likelihood_arr = np.asarray(list(likelihood), dtype=float) + + # 3) Determine single vs multi animal: + # - empty ids => treat as single-animal (label-based) + # - ids[0] == "" => also single-animal (label-based) + first_id = ids_list[0] if len(ids_list) > 0 else "" + use_id = bool(first_id) + + face_color_cycle_maps = build_color_cycles(header, colormap) + face_color_prop = "id" if use_id else "label" + + return { + "name": "keypoints", + "text": "{id}–{label}" if use_id else "label", + "properties": { + "label": list(labels), + "id": list(ids), + "likelihood": likelihood_arr, + "valid": likelihood_arr > pcutoff, + }, + "face_color_cycle": face_color_cycle_maps[face_color_prop], + "face_color": face_color_prop, + "face_colormap": colormap, + "border_color": "valid", + "border_color_cycle": ["black", "red"], + "border_width": 0, + "border_width_is_relative": False, + "size": size, + "metadata": { + "header": header, + "face_color_cycles": face_color_cycle_maps, + "colormap_name": colormap, + "paths": paths or [], + }, + } + + +def is_machine_layer(layer) -> bool: + md = getattr(layer, "metadata", {}) or {} + io = md.get("io") or {} + k = io.get("kind") + # allow enum or string + if k is AnnotationKind.MACHINE: + return True + is_machine = str(k).lower() == "machine" + if is_machine: + logger.info( + "A literal 'machine' str was used for io.kind; please use AnnotationKind.MACHINE for better validation." + ) + return is_machine + + +# ----------------------------------------------- +# Layer-finding utilities +# ----------------------------------------------- +def iter_layers(viewer_or_layers: Any) -> Iterable[Any]: + """Yield layers from a napari Viewer or an iterable of layers.""" + layers = getattr(viewer_or_layers, "layers", viewer_or_layers) + return layers + + +def find_first_layer( + viewer_or_layers: Any, + layer_type: type[T], + predicate: Callable[[T], bool] | None = None, +) -> T | None: + """Return the first layer of type ``layer_type`` that matches ``predicate``. + + Parameters + ---------- + viewer_or_layers: + A napari Viewer, LayerList, or any iterable of layers. + layer_type: + The desired layer type (e.g., napari.layers.Points). + predicate: + Optional function to further filter matching layers. + + Notes + ----- + This intentionally mirrors the common pattern used throughout the plugin: + "iterate viewer.layers in order and pick the first match". + """ + pred = predicate or (lambda _ly: True) + for ly in iter_layers(viewer_or_layers): + if isinstance(ly, layer_type) and pred(ly): + return ly + return None + + +def find_last_layer( + viewer_or_layers: Any, + layer_type: type[T], + predicate: Callable[[T], bool] | None = None, +) -> T | None: + """Return the last layer of type ``layer_type`` that matches ``predicate``.""" + pred = predicate or (lambda _ly: True) + last: T | None = None + for ly in iter_layers(viewer_or_layers): + if isinstance(ly, layer_type) and pred(ly): + last = ly + return last + + +# -------------------- +# Convenience wrappers +# -------------------- + + +def get_first_points_layer(viewer_or_layers: Any) -> Any | None: + return find_first_layer(viewer_or_layers, Points) + + +def get_first_image_layer(viewer_or_layers: Any) -> Any | None: + return find_first_layer(viewer_or_layers, Image) + + +def get_first_video_image_layer(viewer_or_layers: Any) -> Any | None: + """First Image layer that looks like a video (>=3D data).""" + + def _is_video(img: Any) -> bool: + try: + return hasattr(img, "data") and getattr(img.data, "ndim", 0) >= 3 + except Exception: + return False + + return find_first_layer(viewer_or_layers, Image, _is_video) + + +def get_points_layer_with_tables(viewer_or_layers: Any) -> Any | None: + """First Points layer whose metadata has a non-empty 'tables' entry.""" + + def _has_tables(pts: Any) -> bool: + try: + md = getattr(pts, "metadata", None) or {} + return bool(md.get("tables")) + except Exception: + return False + + return find_first_layer(viewer_or_layers, Points, _has_tables) + + +def get_first_shapes_layer(viewer_or_layers: Any) -> Any | None: + return find_first_layer(viewer_or_layers, Shapes) + + +def get_first_tracks_layer(viewer_or_layers: Any) -> Any | None: + return find_first_layer(viewer_or_layers, Tracks) + + +@dataclass(frozen=True) +class LabelProgress: + labeled_points: int + total_points: int + labeled_percent: float + remaining_percent: float + frame_count: int + bodypart_count: int + individual_count: int + + +def _get_header_model_from_metadata(md: dict) -> DLCHeaderModel | None: + if not isinstance(md, dict): + return None + + hdr = md.get("header") + if hdr is None: + return None + + if isinstance(hdr, DLCHeaderModel): + return hdr + + if isinstance(hdr, dict): + try: + return DLCHeaderModel.model_validate(hdr) + except Exception: + return None + + try: + return DLCHeaderModel(columns=hdr) + except Exception: + return None + + +def get_uniform_point_size(layer: Points, *, default: int = 6) -> int: + size = getattr(layer, "size", default) + try: + arr = np.asarray(size, dtype=float).ravel() + if arr.size == 0: + return default + return int(round(float(np.nanmean(arr)))) + except Exception: + try: + return int(round(float(size))) + except Exception: + return default + + +def set_uniform_point_size(layer: Points, size: int) -> None: + # Scalar assignment keeps it lightweight and applies uniformly. + layer.size = float(size) + + +def infer_frame_count(layer: Points, *, fallback_paths: list[str] | None = None) -> int: + md = getattr(layer, "metadata", {}) or {} + + paths = md.get("paths") or fallback_paths or [] + if paths: + return len(paths) + + data = np.asarray(getattr(layer, "data", [])) + if data.size == 0: + return 0 + + try: + # Points layers use frame/time in first column + return int(np.nanmax(data[:, 0])) + 1 + except Exception: + return 0 + + +def infer_bodypart_count(layer: Points) -> int: + hdr = _get_header_model_from_metadata(getattr(layer, "metadata", {}) or {}) + if hdr is None: + return 0 + + try: + return len([bp for bp in hdr.bodyparts if str(bp) != ""]) + except Exception: + return 0 + + +def infer_individual_count(layer: Points) -> int: + """ + Returns the number of valid DLC individuals. + + Single-animal convention: + - if no individuals are defined + - or individuals are empty / blank + => returns 1 + """ + hdr = _get_header_model_from_metadata(getattr(layer, "metadata", {}) or {}) + if hdr is None: + return 1 + + try: + inds = [str(ind) for ind in hdr.individuals if str(ind) != ""] + return max(1, len(inds)) + except Exception: + return 1 + + +def compute_label_progress(layer: Points, *, fallback_paths: list[str] | None = None) -> LabelProgress: + frame_count = infer_frame_count(layer, fallback_paths=fallback_paths) + bodypart_count = infer_bodypart_count(layer) + individual_count = infer_individual_count(layer) + + total_points = frame_count * bodypart_count * individual_count + + data = np.asarray(getattr(layer, "data", [])) + labeled_points = int(data.shape[0]) if data.ndim >= 2 else 0 + + if total_points > 0: + labeled_points = min(labeled_points, total_points) + labeled_percent = 100.0 * labeled_points / total_points + else: + labeled_percent = 0.0 + + remaining_percent = max(0.0, 100.0 - labeled_percent) + + return LabelProgress( + labeled_points=labeled_points, + total_points=total_points, + labeled_percent=labeled_percent, + remaining_percent=remaining_percent, + frame_count=frame_count, + bodypart_count=bodypart_count, + individual_count=individual_count, + ) + + +def infer_folder_display_name( + active_layer, + *, + fallback_root: str | None = None, +) -> str: + """ + Best-effort label for the current image/video folder context. + """ + if active_layer is None: + return "—" + + md = getattr(active_layer, "metadata", {}) or {} + + paths = md.get("paths") or [] + if paths: + try: + return Path(paths[0]).expanduser().parent.name or "—" + except Exception: + pass + + root = md.get("root") or fallback_root + if root: + try: + return Path(root).expanduser().name or "—" + except Exception: + pass + + try: + src = getattr(getattr(active_layer, "source", None), "path", None) + if src: + p = Path(str(src)) + if p.is_file(): + # video source: show parent folder name + return p.parent.name or p.stem or "—" + return p.name or "—" + except Exception: + pass + + return "—" + + +def find_relevant_image_layer(viewer) -> Image | None: + active = viewer.layers.selection.active + if isinstance(active, Image): + return active + + for layer in viewer.layers: + if isinstance(layer, Image): + return layer + + return None diff --git a/src/napari_deeplabcut/core/metadata.py b/src/napari_deeplabcut/core/metadata.py new file mode 100644 index 00000000..36710eb3 --- /dev/null +++ b/src/napari_deeplabcut/core/metadata.py @@ -0,0 +1,864 @@ +# src/napari_deeplabcut/core/metadata.py +from __future__ import annotations + +import logging +from collections.abc import Iterable, Mapping +from enum import Enum +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, ValidationError + +from napari_deeplabcut.config.models import AnnotationKind, DLCHeaderModel, ImageMetadata, IOProvenance, PointsMetadata +from napari_deeplabcut.core.discovery import infer_annotation_kind_for_file +from napari_deeplabcut.core.errors import AmbiguousSaveError, MissingProvenanceError +from napari_deeplabcut.core.project_paths import canonicalize_path + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Inference +# ----------------------------------------------------------------------------- +def _coerce_path(p: str | None) -> Path | None: + if not p: + return None + try: + return Path(p).expanduser().resolve() + except Exception: + return Path(p) + + +def _is_dlc_dataset_root(p: Path) -> bool: + """ + Heuristic: DLC dataset folder usually looks like: + /labeled-data/ + + True if path contains a 'labeled-data' segment AND is deeper than that folder. + """ + parts = [s.lower() for s in p.parts] + if "labeled-data" not in parts: + return False + return parts[-1] != "labeled-data" + + +def _paths_look_like_labeled_data(paths: list[str] | None) -> bool: + """ + Check if any path strings contain 'labeled-data//'. + Works with canonicalized paths like 'labeled-data/test/img000.png'. + """ + if not paths: + return False + for s in paths: + if isinstance(s, str) and "labeled-data" in s.replace("\\", "/").lower(): + return True + return False + + +def _looks_like_project_root(points_root: str | None, project: str | None) -> bool: + """ + Root equals project root (config parent) => this is WRONG for saving GT. + """ + if not points_root or not project: + return False + try: + return Path(points_root).expanduser().resolve() == Path(project).expanduser().resolve() + except Exception: + return points_root == project + + +def build_io_provenance_dict( + *, + project_root: str | Path, + source_relpath_posix: str, + kind: AnnotationKind | None, + dataset_key: str, + **extra: Any, +) -> dict[str, Any]: + """ + Build a provenance dict for storage in napari layer.metadata. + + Important: uses mode="python" so AnnotationKind stays an enum at runtime. + """ + io = IOProvenance( + project_root=str(project_root), + source_relpath_posix=source_relpath_posix, + kind=kind, + dataset_key=dataset_key, + **extra, + ) + return io.model_dump(mode="python", exclude_none=True) + + +def infer_image_root( + *, + explicit_root: str | None = None, + paths: Iterable[str] | None = None, + source_path: str | None = None, +) -> str | None: + """ + Best-effort inference of an image root directory. + + Priority: + 1. explicit_root + 2. parent of first path + 3. parent of source_path + """ + if explicit_root: + return explicit_root + + if paths: + try: + return str(Path(next(iter(paths))).expanduser().resolve().parent) + except Exception: + pass + + if source_path: + try: + return str(Path(source_path).expanduser().resolve().parent) + except Exception: + pass + + return None + + +# ----------------------------------------------------------------------------- +# Safe update / merge rules +# ----------------------------------------------------------------------------- + + +def merge_image_metadata(base: ImageMetadata, incoming: ImageMetadata) -> ImageMetadata: + """Merge ImageMetadata without clobbering existing values.""" + data = base.model_dump(mode="python") + for field, value in incoming.model_dump(mode="python").items(): + if data.get(field) in (None, "", []) and value not in (None, "", []): + data[field] = value + return ImageMetadata(**data) + + +def merge_points_metadata(base: PointsMetadata, incoming: PointsMetadata) -> PointsMetadata: + """Merge PointsMetadata without clobbering existing values.""" + data = base.model_dump(mode="python") + for field, value in incoming.model_dump(mode="python").items(): + if field == "controls": + continue + if data.get(field) in (None, "", []) and value not in (None, "", []): + data[field] = value + return PointsMetadata(**data) + + +# ----------------------------------------------------------------------------- +# Synchronization helpers +# ----------------------------------------------------------------------------- +def sync_points_from_image(image_meta: ImageMetadata, points_meta: PointsMetadata) -> PointsMetadata: + """ + Ensure PointsMetadata contains required image-derived fields. + + Robust DLC policy: + - If image root looks like a DLC dataset folder (…/labeled-data/), + prefer it for points_meta.root even if points_meta.root is already set + but equals project root (config parent) or is not a dataset root. + """ + updated = points_meta.model_dump(mode="python") + + # --- First: fill missing fields (existing behavior) --- + for key in ("root", "paths", "shape", "name"): + if updated.get(key) in (None, "", []): + value = getattr(image_meta, key, None) + if value not in (None, "", []): + updated[key] = value + + # --- Second: if we have dataset context, correct stale root --- + img_root_p = _coerce_path(getattr(image_meta, "root", None)) + pts_root_p = _coerce_path(updated.get("root")) + project_p = _coerce_path(updated.get("project")) + + # Determine if the image root is a DLC dataset directory + image_is_dataset = bool(img_root_p is not None and _is_dlc_dataset_root(img_root_p)) + + # Additional hint: sometimes image_meta.root might be missing, but paths show labeled-data + # (depends on readers / napari versions). Use that as secondary signal. + if not image_is_dataset: + if _paths_look_like_labeled_data(getattr(image_meta, "paths", None)): + # If image paths look like labeled-data/... and we have a root-like string, + # try to interpret image_meta.root anyway. + image_is_dataset = bool(img_root_p is not None and _is_dlc_dataset_root(img_root_p)) + + if image_is_dataset and img_root_p is not None: + # Override root if: + # - points root equals project root (typical config-first bug), OR + # - points root exists but isn't a dataset root. + should_override_root = False + + if _looks_like_project_root(str(pts_root_p) if pts_root_p else None, str(project_p) if project_p else None): + should_override_root = True + elif pts_root_p is not None and not _is_dlc_dataset_root(pts_root_p): + should_override_root = True + + if should_override_root: + updated["root"] = str(img_root_p) + + return PointsMetadata(**updated) + + +def apply_project_paths_override_to_points_meta( + pts_meta: PointsMetadata, + *, + project_root: str | Path, + rewritten_paths: list[str], +) -> PointsMetadata: + """ + Return a copy of PointsMetadata with a save-time project/path override applied. + + This updates: + - project + - paths + - save_target.project_root (if present) + + It intentionally clears `io` so downstream output routing cannot prefer stale + provenance (e.g. legacy-migrated source_h5 -> io) over the rewritten DLC row + keys. This keeps save routing consistent with the project-association rewrite. + + It intentionally does NOT rewrite `root`, so the current source-folder anchor + remains available for remapping / plugin-local workflows. + """ + project_root_str = str(project_root) + + updates = { + "project": project_root_str, + "paths": list(rewritten_paths), + # force save routing to use rewritten row keys + dataset-folder inference + # instead of any pre-existing legacy IO provenance + "io": None, + } + + if pts_meta.save_target is not None: + updates["save_target"] = pts_meta.save_target.model_copy(update={"project_root": project_root_str}) + + return pts_meta.model_copy(update=updates) + + +def ensure_metadata_models( + image_meta: dict | ImageMetadata | None, + points_meta: dict | PointsMetadata | None, +) -> tuple[ImageMetadata | None, PointsMetadata | None]: + """Normalize raw metadata dicts into authoritative models.""" + img = None + pts = None + if image_meta is not None: + img = image_meta if isinstance(image_meta, ImageMetadata) else ImageMetadata(**image_meta) + if points_meta is not None: + pts = points_meta if isinstance(points_meta, PointsMetadata) else PointsMetadata(**points_meta) + return img, pts + + +# ----------------------------------------------------------------------------- +# Parsing / round-tripping +# ----------------------------------------------------------------------------- +def _normalize_columns(cols: Any) -> Any: + try: + import pandas as pd + + if isinstance(cols, pd.MultiIndex): + return [tuple(map(str, t)) for t in cols.to_list()] + if isinstance(cols, pd.Index): + return [str(x) for x in cols.to_list()] + except Exception: + pass + return cols + + +def _coerce_header_to_model(header: Any, *, strict: bool = False) -> DLCHeaderModel | None: + """Coerce various runtime header forms into DLCHeaderModel. + + Accepts: + - DLCHeaderModel (returned as-is) + - dict (validated into DLCHeaderModel) + """ + if header is None: + if strict: + raise ValueError("Header is None; cannot write to DLCHeaderModel.") + return None + + if isinstance(header, DLCHeaderModel): + return header + + def _fail(msg: str, exc: Exception | None = None): + logger.debug(msg, exc_info=exc is not None) + if strict: + raise ValueError(msg) from exc + return None + + # dict-ish header + if isinstance(header, Mapping): + try: + hd = dict(header) + if "columns" in hd: + hd["columns"] = _normalize_columns(hd["columns"]) + return DLCHeaderModel.model_validate(hd) + except Exception as e: + return _fail("Failed to parse header dict into DLCHeaderModel.", exc=e) + + # cols = getattr(header, "columns", None) + # if cols is not None: + # try: + # return DLCHeaderModel(columns=_normalize_columns(cols)) + # except Exception as e: + # return _fail("Failed to coerce header with columns into DLCHeaderModel.", exc=e) + + return None + + +def _coerce_io_kind(d: dict, key: str = "kind") -> None: + k = d.get(key) + if isinstance(k, str): + try: + d[key] = AnnotationKind(k) # works if enum values are "gt"/"machine" + except Exception: + # optionally accept upper-cased names + try: + d[key] = AnnotationKind[k.upper()] + except Exception: + pass + + +def parse_points_metadata( + md: Mapping[str, Any] | PointsMetadata | None, + *, + drop_header: bool = False, + drop_controls: bool = True, + # TODO defaults may need adjusted @C-Achard +) -> PointsMetadata: + """ + Parse PointsMetadata from a napari layer.metadata mapping. + + Robust to runtime objects + - controls are dropped by default (runtime-only) + - header is kept by default (needed for writing + conflict checking) + """ + if md is None: + return PointsMetadata() + if isinstance(md, PointsMetadata): + return md + + raw = dict(md) + + # Drop runtime-only / non-serializable fields + if drop_controls: + raw.pop("controls", None) + + # Coerce header unless explicitly dropped + if drop_header: + raw.pop("header", None) + else: + hdr = raw.get("header", None) + logger.debug("Raw header type=%r", type(hdr)) + logger.debug("Raw header has columns=%s", hasattr(hdr, "columns")) + cols = getattr(hdr, "columns", None) + logger.debug("columns type=%r", type(cols)) + coerced = _coerce_header_to_model(hdr) + if coerced is not None: + raw["header"] = coerced + else: + # If a header was present but not usable, remove it so we can still parse. + raw.pop("header", None) + + io_dict = raw.get("io", None) + if isinstance(io_dict, dict): + _coerce_io_kind(io_dict) + + st_dict = raw.get("save_target", None) + if isinstance(st_dict, dict): + _coerce_io_kind(st_dict) + + try: + return PointsMetadata.model_validate(raw) + except Exception: + logger.debug("Failed to parse PointsMetadata; falling back to empty model.", exc_info=True) + return PointsMetadata() + + +def merge_model_into_metadata( + metadata: dict[str, Any], + model: BaseModel, + *, + exclude_none: bool = True, + exclude: set[str] | None = None, +) -> dict[str, Any]: + """Merge a Pydantic model into an existing metadata dict (shallow merge).""" + if exclude is None: + exclude = set() + + try: + dumped = model.model_dump(mode="python", exclude_none=exclude_none, exclude=exclude) + except Exception: + dumped = {} + + for k, v in dumped.items(): + metadata[k] = v + return metadata + + +# ----------------------------------------------------------------------------- +# Save target utilities +# ----------------------------------------------------------------------------- + + +def require_unique_target(candidates: list[Path], *, context: str = "save target") -> Path: + """Ensure a candidate list resolves to exactly one path.""" + if not candidates: + raise MissingProvenanceError(f"No candidates found for {context}.") + if len(candidates) > 1: + raise AmbiguousSaveError(f"Ambiguous {context}: {len(candidates)} candidates: {[c.name for c in candidates]}") + return candidates[0] + + +# ----------------------------------------------------------------------------- +# Provenance attachment (napari metadata glue) +# ----------------------------------------------------------------------------- +def attach_source_and_io_to_layer_kwargs( + layer_kwargs: dict[str, Any], + file_path: Path, + *, + kind: AnnotationKind | None = None, + dataset_key: str = "keypoints", +) -> None: + """ + Attach authoritative source info + IO provenance to napari layer metadata dict. + + - Keeps legacy fields (source_h5*) for debugging/migration. + - Stores IOProvenance as a plain dict under metadata['metadata']['io']. + - Stores AnnotationKind as enum object (runtime invariant). + + If kind is None, we fall back to discovery-based inference. + + # FUTURE NOTE hardcoded DLC structure: + # kind inference relies on discovery filename patterns (CollectedData*, machinelabels*). + + Notes + ----- + This function does not expect ``layer.metadata`` directly. + It expects the middle dict of a napari LayerData tuple, e.g.: + + (data, layer_kwargs, "points") + + and writes into: + + layer_kwargs["metadata"] + """ + + meta = layer_kwargs.setdefault("metadata", {}) + + # Legacy migration fields + try: + src_abs = str(file_path.expanduser().resolve()) + except Exception: + src_abs = str(file_path) + + meta["source_h5"] = src_abs + meta["source_h5_name"] = file_path.name + meta["source_h5_stem"] = file_path.stem + + # Anchor root: file parent (robust for shared labeled-data folders) + try: + anchor = str(file_path.expanduser().resolve().parent) + except Exception: + anchor = str(file_path.parent) + + # Relative path stored as POSIX (OS-agnostic) + relposix = canonicalize_path(file_path, n=1) + + # If caller didn't provide kind, infer from discovery + if kind is None: + kind = infer_annotation_kind_for_file(file_path) + + meta["io"] = build_io_provenance_dict( + project_root=anchor, + source_relpath_posix=relposix, + kind=kind, + dataset_key=dataset_key, + ) + + +# ------------------------------------------------------------------------- +# Central metadata adapter gateway (validation + migration + controlled write) +# ------------------------------------------------------------------------- +class MergePolicy(str, Enum): + """ + How to apply an incoming model to an existing layer.metadata dict. + + - MERGE_MISSING: only fill missing/empty keys on the layer + - MERGE: shallow update (incoming overwrites existing) + - REPLACE: replace entire metadata mapping with incoming (rarely desired) + """ + + MERGE_MISSING = "merge_missing" + MERGE = "merge" + REPLACE = "replace" + + +# _EMPTY = (None, "", [], {}) + + +def _is_empty_value(v: Any) -> bool: + # Treat 0 and False as legitimate values, not "empty" + if v is None: + return True + if isinstance(v, str) and v == "": + return True + if isinstance(v, (list, tuple, set)) and len(v) == 0: + return True + if isinstance(v, dict) and len(v) == 0: + return True + return False + + +def _layer_metadata_dict(layer: Any) -> dict[str, Any]: + md = getattr(layer, "metadata", None) + if isinstance(md, dict): + return md + if md is None: + return {} + # best-effort cast if napari gives a Mapping-like object + try: + return dict(md) + except Exception: + return {} + + +def _infer_kind_from_source_name(p: Path) -> AnnotationKind | None: + # best-effort legacy inference; discovery-based inference is preferred + try: + return infer_annotation_kind_for_file(p) + except Exception: + low = p.name.lower() + if low.startswith("collecteddata"): + return AnnotationKind.GT + if low.startswith("machinelabels"): + return AnnotationKind.MACHINE + return None + + +def _build_io_from_source_h5( + src: str, + *, + dataset_key: str = "keypoints", +) -> dict[str, Any] | None: + """Legacy migration: build io provenance dict from source_h5 string.""" + if not isinstance(src, str) or not src: + return None + try: + p = Path(src).expanduser().resolve() + except Exception: + p = Path(src) + + kind = _infer_kind_from_source_name(p) + + try: + anchor = str(p.expanduser().resolve().parent) + except Exception: + anchor = str(p.parent) + + relposix = canonicalize_path(p, n=1) + + try: + return build_io_provenance_dict( + project_root=anchor, + source_relpath_posix=relposix, + kind=kind, + dataset_key=dataset_key, + ) + except Exception: + logger.debug("Failed to build io provenance from legacy source_h5=%r", src, exc_info=True) + return None + + +def _prepare_points_payload( + md: Mapping[str, Any], + *, + drop_controls: bool = True, + drop_header: bool = False, + migrate_legacy: bool = True, +) -> dict[str, Any]: + """ + Prepare a dict suitable for PointsMetadata.model_validate(). + + - coerces header into DLCHeaderModel when possible + - coerces io.kind / save_target.kind into AnnotationKind + - drops runtime-only fields (controls) by default + - optionally migrates legacy source_h5 -> io dict + """ + raw = dict(md) + + if drop_controls: + raw.pop("controls", None) + + # legacy migration: io from source_h5 + if migrate_legacy and not raw.get("io"): + src = raw.get("source_h5") + io_dict = _build_io_from_source_h5(src, dataset_key="keypoints") + if io_dict: + raw["io"] = io_dict + + # Coerce header into DLCHeaderModel (schema) if present + hdr = raw.get("header", None) + if hdr is not None: + try: + if isinstance(hdr, DLCHeaderModel): + raw["header"] = hdr + elif isinstance(hdr, dict) and "columns" in hdr: + raw["header"] = DLCHeaderModel.model_validate(hdr) + else: + # support MultiIndex / list-of-tuples via validator + raw["header"] = DLCHeaderModel(columns=hdr) + except Exception: + raw.pop("header", None) + + io_dict = raw.get("io", None) + if isinstance(io_dict, dict): + _coerce_io_kind(io_dict) + + st_dict = raw.get("save_target", None) + if isinstance(st_dict, dict): + _coerce_io_kind(st_dict) + + return raw + + +def _prepare_image_payload( + md: Mapping[str, Any], +) -> dict[str, Any]: + """Prepare a dict suitable for ImageMetadata.model_validate().""" + raw = dict(md) + return raw + + +# ------------------------- +# Public adapter API +# ------------------------- + + +def read_points_meta( + layer: Any, + *, + migrate_legacy: bool = True, + drop_controls: bool = True, + drop_header: bool = False, +) -> PointsMetadata | ValidationError: + """ + Read PointsMetadata from a layer with strict validation. + + Returns: + - PointsMetadata on success + - ValidationError on failure (visible to caller) + """ + md = _layer_metadata_dict(layer) + payload = _prepare_points_payload( + md, + drop_controls=drop_controls, + drop_header=drop_header, + migrate_legacy=migrate_legacy, + ) + try: + return PointsMetadata.model_validate(payload) + except ValidationError as e: + return e + + +def write_points_meta( + layer: Any, + model: PointsMetadata | Mapping[str, Any], + merge_policy: MergePolicy | str = MergePolicy.MERGE_MISSING, + *, + fields: set[str] | None = None, + exclude_none: bool = True, + validate: bool = True, + migrate_legacy: bool = True, +) -> PointsMetadata | ValidationError: + """ + Write Points metadata through a single validated gateway. + + - Applies merge_policy to layer.metadata + - Validates the final dict by default + - Returns PointsMetadata on success, ValidationError on failure + """ + if isinstance(merge_policy, str): + merge_policy = MergePolicy(merge_policy) + + existing = _layer_metadata_dict(layer) + # preserve a strong reference to the existing header (runtime object sometimes) + existing_header = existing.get("header", None) + + if isinstance(model, Mapping): + incoming = dict(model) + else: + incoming = model.model_dump(mode="python", exclude_none=exclude_none) + + # never write runtime-only field + incoming.pop("controls", None) + + if fields is not None: + incoming = {k: v for k, v in incoming.items() if k in fields} + + if merge_policy is MergePolicy.REPLACE: + merged = dict(incoming) + elif merge_policy is MergePolicy.MERGE: + merged = dict(existing) + merged.update(incoming) + else: # MERGE_MISSING + merged = dict(existing) + for k, v in incoming.items(): + if _is_empty_value(merged.get(k)): + merged[k] = v + + # restore header if it existed but got dropped by incoming dict + if existing_header is not None and merged.get("header") is None: + merged["header"] = existing_header + + # legacy migration (optional): if caller writes anything, keep io stable + if migrate_legacy and not merged.get("io") and merged.get("source_h5"): + io_dict = _build_io_from_source_h5(str(merged.get("source_h5")), dataset_key="keypoints") + if io_dict: + merged["io"] = io_dict + + if validate: + payload = _prepare_points_payload( + merged, + drop_controls=True, + drop_header=False, + migrate_legacy=migrate_legacy, + ) + try: + validated = PointsMetadata.model_validate(payload) + except ValidationError as e: + logger.warning("write_points_meta validation failed for layer=%r: %s", getattr(layer, "name", layer), e) + return e + + # Write back validated python dict (exclude_none keeps metadata lean) + final_dict = validated.model_dump(mode="python", exclude_none=True) + hdr = final_dict.get("header", None) + if isinstance(hdr, DLCHeaderModel): + final_dict["header"] = hdr.to_metadata_payload() + # preserve header if pydantic excluded it (or coercion removed it) + if existing_header is not None and final_dict.get("header") is None: + final_dict["header"] = existing_header + + # mutate in place (napari likes stable mapping refs) + if getattr(layer, "metadata", None) is None or not isinstance(getattr(layer, "metadata", None), dict): + layer.metadata = {} + layer.metadata.clear() + layer.metadata.update(final_dict) + return validated + + # No validation mode (rare): still write merged mapping + if getattr(layer, "metadata", None) is None or not isinstance(getattr(layer, "metadata", None), dict): + layer.metadata = {} + layer.metadata.clear() + layer.metadata.update(merged) + # best-effort return model + try: + return PointsMetadata.model_validate(_prepare_points_payload(merged, migrate_legacy=migrate_legacy)) + except ValidationError as e: + return e + + +def read_image_meta(layer: Any) -> ImageMetadata | ValidationError: + """ + Read ImageMetadata from a layer with strict validation. + """ + md = _layer_metadata_dict(layer) + payload = _prepare_image_payload(md) + try: + return ImageMetadata.model_validate(payload) + except ValidationError as e: + return e + + +def write_image_meta( + layer: Any, + model: ImageMetadata | Mapping[str, Any], + merge_policy: MergePolicy | str = MergePolicy.MERGE_MISSING, + *, + fields: set[str] | None = None, + exclude_none: bool = True, + validate: bool = True, +) -> ImageMetadata | ValidationError: + """ + Write Image metadata through a single validated gateway. + """ + if isinstance(merge_policy, str): + merge_policy = MergePolicy(merge_policy) + + existing = _layer_metadata_dict(layer) + + if isinstance(model, Mapping): + incoming = dict(model) + else: + incoming = model.model_dump(mode="python", exclude_none=exclude_none) + + if fields is not None: + incoming = {k: v for k, v in incoming.items() if k in fields} + + if merge_policy is MergePolicy.REPLACE: + merged = dict(incoming) + elif merge_policy is MergePolicy.MERGE: + merged = dict(existing) + merged.update(incoming) + else: # MERGE_MISSING + merged = dict(existing) + for k, v in incoming.items(): + if _is_empty_value(merged.get(k)): + merged[k] = v + + if validate: + try: + validated = ImageMetadata.model_validate(_prepare_image_payload(merged)) + except ValidationError as e: + logger.warning("write_image_meta validation failed for layer=%r: %s", getattr(layer, "name", layer), e) + return e + + final_dict = validated.model_dump(mode="python", exclude_none=True) + + if getattr(layer, "metadata", None) is None or not isinstance(getattr(layer, "metadata", None), dict): + layer.metadata = {} + layer.metadata.clear() + layer.metadata.update(final_dict) + return validated + + if getattr(layer, "metadata", None) is None or not isinstance(getattr(layer, "metadata", None), dict): + layer.metadata = {} + layer.metadata.clear() + layer.metadata.update(merged) + try: + return ImageMetadata.model_validate(_prepare_image_payload(merged)) + except ValidationError as e: + return e + + +def migrate_points_layer_metadata(layer: Any) -> PointsMetadata | ValidationError: + """ + Convenience migration entrypoint: + - reads (with legacy migration) + - writes back (merge_missing) through gateway + """ + res = read_points_meta(layer, migrate_legacy=True) + if isinstance(res, ValidationError): + return res + return write_points_meta(layer, res, MergePolicy.MERGE_MISSING, migrate_legacy=True) + + +def coerce_header_model(header: Any) -> DLCHeaderModel | None: + """ + Convert any supported header representation to DLCHeaderModel. + + Supported: + - DLCHeaderModel + - dict-like payload (including {"columns": ...}) + - pandas.MultiIndex / Index via existing _coerce_header_to_model logic + """ + if header is None: + return None + if isinstance(header, DLCHeaderModel): + return header + # fall back to existing coercion path (dict, MultiIndex, etc.) + return _coerce_header_to_model(header) diff --git a/src/napari_deeplabcut/core/project_paths.py b/src/napari_deeplabcut/core/project_paths.py new file mode 100644 index 00000000..3c7286e8 --- /dev/null +++ b/src/napari_deeplabcut/core/project_paths.py @@ -0,0 +1,667 @@ +""" +Regarding root anchor : +The "root anchor" is a configurable directory used to resolve project-relative +paths in IO provenance. + +Motivation +---------- +We may not always load a full DLC project (config.yaml may be missing). +In particular there may only be a labeled-data folder containing images + h5/csv files. +Therefore the root anchor must be inferable from what the user opened: + +- If the user opens a file: the default anchor is the file's parent directory. +- If the user opens a folder: the anchor is that folder. +- If a config.yaml exists nearby: the anchor *may* be elevated to the project + root, but must remain configurable and must not be required. +""" + +# src/napari_deeplabcut/core/project_paths.py +from __future__ import annotations + +import logging +from collections.abc import Iterable +from enum import Enum +from pathlib import Path, PureWindowsPath + +from napari_deeplabcut.config.models import DLCProjectContext, PointsMetadata + +logger = logging.getLogger(__name__) + + +def is_windows_absolute_path(value: str | Path) -> bool: + try: + return PureWindowsPath(str(value)).is_absolute() + except Exception: + return False + + +# ----------------------------------------------------------------------------- +# Canonicalization +# ----------------------------------------------------------------------------- + + +def canonicalize_path(p: str | Path, n: int = 3) -> str: + """ + Return canonical POSIX path built from the last n path components. + + This implementation is intentionally identical to the legacy behavior + in misc.canonicalize_path to preserve remapping semantics. + + Parameters + ---------- + p : str | Path + Input path. + n : int + Number of trailing components to keep. + + Returns + ------- + str + Canonicalized POSIX-style path, or empty string on failure. + """ + if n <= 0: + raise ValueError("n must be a positive integer") + + try: + s = str(p) + except Exception: + logger.debug("Failed to stringify path of type %s", type(p).__name__, exc_info=True) + return "" + + s = s.replace("\\", "/") + s = s.rstrip("/") + parts = [part for part in s.split("/") if part and part not in (".", "..")] + + if not parts: + return "" + + return "/".join(parts[-n:]) + + +# ----------------------------------------------------------------------------- +# Path matching policy +# ----------------------------------------------------------------------------- + + +class PathMatchPolicy(Enum): + """ + Policy controlling how image paths are matched across datasets. + + ORDERED_DEPTHS means: + - Try matching with depth=3 + - If no overlap, try depth=2 + - If still no overlap, try depth=1 + """ + + ORDERED_DEPTHS = "ordered_depths" + + @property + def depths(self) -> tuple[int, ...]: + if self is PathMatchPolicy.ORDERED_DEPTHS: + return (3, 2, 1) + raise NotImplementedError(f"Unhandled PathMatchPolicy: {self}") + + +def find_matching_depth( + old_paths: Iterable[str | Path], + new_paths: Iterable[str | Path], + policy: PathMatchPolicy = PathMatchPolicy.ORDERED_DEPTHS, +) -> int | None: + """ + Find the first canonicalization depth producing overlapping path keys. + + Returns + ------- + int | None + Depth used for matching, or None if no overlap found. + """ + old_paths = list(old_paths) + new_paths = list(new_paths) + + if not old_paths or not new_paths: + return None + + for depth in policy.depths: + old_keys = {canonicalize_path(p, depth) for p in old_paths} + new_keys = {canonicalize_path(p, depth) for p in new_paths} + if old_keys & new_keys: + return depth + + return None + + +# ----------------------------------------------------------------------------- +# DLC path heuristics +# ----------------------------------------------------------------------------- + + +def is_config_yaml(path: str | Path) -> bool: + """Return True if path points to a DLC config.yaml file.""" + try: + p = Path(path) + except TypeError: + return False + return p.is_file() and p.name.lower() == "config.yaml" + + +def has_dlc_datafiles(folder: str | Path) -> bool: + """ + True if folder contains DLC label artifacts such as: + - CollectedData*.h5 / .csv + - machinelabels*.h5 / .csv + """ + # FUTURE NOTE @C-Achard 2026-02-17: Do not hardcode these patterns + # and clearly expose these if data file formats change or expand. + p = Path(folder) + if not p.exists() or not p.is_dir(): + return False + + patterns = ( + "CollectedData*.h5", + "CollectedData*.csv", + "machinelabels*.h5", + "machinelabels*.csv", + ) + return any(any(p.glob(pat)) for pat in patterns) + + +def looks_like_dlc_labeled_folder(folder: str | Path) -> bool: + """ + Heuristic for DLC labeled-data folders. + + True if: + - DLC artifacts are present, OR + - Folder is inside a 'labeled-data' directory. + """ + p = Path(folder) + if not p.exists() or not p.is_dir(): + return False + + if has_dlc_datafiles(p): + return True + + return any(part.lower() == "labeled-data" for part in p.parts) + + +def should_force_dlc_reader(paths: str | Path | Iterable[str | Path]) -> bool: + """ + Decide whether napari-deeplabcut reader should be preferred. + + Rules (unchanged from legacy behavior): + - Any config.yaml -> DLC reader + - Any folder that looks like DLC labeled-data -> DLC reader + """ + if isinstance(paths, (str, Path)): + paths = [paths] + + paths = list(paths) + if not paths: + return False + + if any(is_config_yaml(p) for p in paths): + return True + + if any(looks_like_dlc_labeled_folder(p) for p in paths): + return True + + return False + + +# ----------------------------------------------------------------------------- +# Root-anchor inference utilities +# ----------------------------------------------------------------------------- +def _collect_anchor_candidates( + *values: str | Path | None, +) -> list[Path]: + anchors: list[Path] = [] + for value in values: + anchor = normalize_anchor_candidate(value) + if anchor is not None and anchor not in anchors: + anchors.append(anchor) + return anchors + + +def _is_labeled_data_dataset_folder(path: Path | None) -> bool: + if path is None: + return False + lowered = [part.lower() for part in path.parts] + return "labeled-data" in lowered and path.name.lower() != "labeled-data" + + +def _extract_dataset_name_from_paths(paths: Iterable[str | Path]) -> str | None: + for value in paths: + try: + text = str(value).replace("\\", "/") + except Exception: + continue + parts = [p for p in text.split("/") if p] + lowered = [p.lower() for p in parts] + try: + idx = lowered.index("labeled-data") + except ValueError: + continue + if idx + 1 < len(parts): + return parts[idx + 1] + return None + + +def normalize_anchor_candidate(value: str | Path | None) -> Path | None: + """Return a normalized directory anchor from a file/folder candidate.""" + if value is None: + return None + + try: + p = Path(value).expanduser().resolve() + except Exception: + try: + p = Path(value) + except Exception: + return None + + # If this is an existing file, anchor on its parent directory. + if p.is_file(): + return p.parent + + # For non-existent paths, heuristically treat file-like paths (with a suffix) + # as files so that their parent directory is used as the anchor. This avoids + # searching for config files under a non-existent "/config.yaml". + if not p.exists() and p.suffix: + return p.parent + + # Existing directories (or suffix-less paths) are used as-is. + return p + + +def infer_dlc_project( + *, + anchor_candidates: list[str | Path] | tuple[str | Path, ...] = (), + dataset_candidates: list[str | Path] | tuple[str | Path, ...] = (), + explicit_root: str | Path | None = None, + prefer_project_root: bool = True, + max_levels: int = 5, +) -> DLCProjectContext: + """ + Infer a best-effort DLC project context from generic path-like hints. + + Parameters + ---------- + anchor_candidates: + Ordered candidates that may indicate a project anchor, project root, + file parent, source directory, etc. + dataset_candidates: + Ordered candidates that may already point at a labeled-data dataset folder. + explicit_root: + Strongest hint. If provided, used first. + prefer_project_root: + If True, root_anchor prefers the folder containing config.yaml. + Otherwise it prefers the first valid anchor candidate. + """ + anchors = _collect_anchor_candidates(explicit_root, *anchor_candidates) + + dataset_folder = None + for cand in dataset_candidates: + d = normalize_anchor_candidate(cand) + if d is not None: + dataset_folder = d + break + + for anchor in anchors: + cfg = find_nearest_config(anchor, max_levels=max_levels) + if cfg is not None: + project_root = cfg.parent + root_anchor = project_root if prefer_project_root else anchor + return DLCProjectContext( + root_anchor=root_anchor, + project_root=project_root, + config_path=cfg, + dataset_folder=dataset_folder, + ) + + return DLCProjectContext( + root_anchor=anchors[0] if anchors else dataset_folder, + project_root=None, + config_path=None, + dataset_folder=dataset_folder, + ) + + +def infer_labeled_data_folder_from_paths( + paths: Iterable[str | Path], + *, + project_root: str | Path | None = None, + fallback_root: str | Path | None = None, +) -> Path | None: + """ + Infer a DLC labeled-data/ folder from path hints. + """ + fallback = normalize_anchor_candidate(fallback_root) + if _is_labeled_data_dataset_folder(fallback): + return fallback + + dataset_name = _extract_dataset_name_from_paths(paths) + if not dataset_name: + return None + + proj = normalize_anchor_candidate(project_root) + if proj is None: + return None + + return proj / "labeled-data" / dataset_name + + +def find_nearest_config( + start: str | Path | None, + *, + max_levels: int = 5, +) -> Path | None: + """ + Walk upward from start to find the nearest config.yaml. + """ + anchor = normalize_anchor_candidate(start) + if anchor is None: + return None + + cur = anchor + for _ in range(max_levels + 1): + cfg = cur / "config.yaml" + if cfg.is_file() and cfg.name.lower() == "config.yaml": + return cfg + if cur.parent == cur: + break + cur = cur.parent + + return None + + +def infer_dlc_project_from_config(config_path: str | Path) -> DLCProjectContext: + root = resolve_project_root_from_config(config_path) + if root is None: + raise ValueError(f"Not a valid DLC config.yaml: {config_path!r}") + cfg = Path(config_path).expanduser().resolve(strict=False) + return DLCProjectContext( + root_anchor=root, + project_root=root, + config_path=cfg, + dataset_folder=None, + ) + + +# ----------------------------------------------------------------------------- +# Explicit config-based DLC path normalization for project-less labeled folders +# ----------------------------------------------------------------------------- +def resolve_project_root_from_config(config_path: str | Path | None) -> Path | None: + """ + Return the DLC project root (= parent directory) from an explicit config.yaml path. + """ + if config_path is None: + return None + + try: + p = Path(config_path).expanduser().resolve(strict=False) + except Exception: + try: + p = Path(config_path) + except Exception: + return None + + if p.name.lower() != "config.yaml": + return None + + if not p.is_file(): + return None + + return p.parent + + +def coerce_paths_to_dlc_row_keys( + paths: Iterable[str | Path], + *, + source_root: str | Path, + dataset_name: str | None = None, +) -> tuple[list[str], tuple[int, ...]]: + """ + Rewrite paths from a project-less labeled folder into canonical DLC row-key form: + + labeled-data// + + Intended use + ------------ + This is for the workflow where the user labeled a folder outside any DLC + project, then chooses a target config.yaml at save time to associate the + labels with a DLC project. + + Rules + ----- + - If a path is already a DLC row key (`labeled-data//`), + normalize it to POSIX and keep it. + - If a path is an absolute file directly inside `source_root`, + rewrite it to `labeled-data//`. + - If a path is a relative basename (e.g. `img001.png`), + rewrite it similarly. + - All other paths are preserved unchanged (POSIX-normalized) and reported as unresolved. + + This deliberately does NOT: + - invent nested DLC row keys for subdirectories, + - coerce multi-folder or ambiguous layouts, + - validate against the selected project root. + """ + root = normalize_anchor_candidate(source_root) + if root is None: + raise ValueError("source_root must resolve to a valid directory-like anchor") + + try: + root = root.expanduser().resolve(strict=False) + except Exception: + pass + + ds_name = (dataset_name or root.name).strip() + if not ds_name: + raise ValueError("dataset_name must be non-empty") + + rewritten: list[str] = [] + unresolved: list[int] = [] + + for i, value in enumerate(paths): + try: + text = str(value).replace("\\", "/") + except Exception: + text = "" + + if not text: + rewritten.append(text) + unresolved.append(i) + continue + + parts = [p for p in text.split("/") if p] + + # Already canonical-ish DLC row key -> preserve from labeled-data onward + lowered = [p.lower() for p in parts] + if "labeled-data" in lowered: + try: + idx = lowered.index("labeled-data") + if idx + 2 < len(parts): + rewritten.append("/".join(parts[idx:])) + continue + except Exception: + pass + + try: + p = Path(value) + except Exception: + p = None + + if p is not None and not p.is_absolute() and is_windows_absolute_path(value): + rewritten.append(PureWindowsPath(str(value)).as_posix()) + unresolved.append(i) + continue + + # Relative basename only -> coerce safely. + # Refuse any relative path that contains '.', '..', or multiple path parts. + if p is not None and not p.is_absolute(): + raw_parts = [str(part) for part in p.parts if str(part) != ""] + if len(raw_parts) == 1 and raw_parts[0] not in (".", ".."): + rewritten.append(f"labeled-data/{ds_name}/{raw_parts[0]}") + else: + rewritten.append(text) + unresolved.append(i) + continue + + # Absolute file directly under source_root -> coerce safely + try: + abs_path = Path(value).expanduser().resolve(strict=False) + except Exception: + rewritten.append(text) + unresolved.append(i) + continue + + try: + rel_to_root = abs_path.relative_to(root) + except Exception: + rewritten.append(abs_path.as_posix()) + unresolved.append(i) + continue + + # Only direct children of source_root are coerced in this lightweight version. + if len(rel_to_root.parts) == 1: + rewritten.append(f"labeled-data/{ds_name}/{rel_to_root.name}") + else: + rewritten.append(abs_path.as_posix()) + unresolved.append(i) + + return rewritten, tuple(unresolved) + + +def target_dataset_folder_for_config( + config_path: str | Path, + *, + dataset_name: str, +) -> Path | None: + """ + Return the target DLC dataset folder under the chosen project: + + /labeled-data/ + """ + project_root = resolve_project_root_from_config(config_path) + if project_root is None: + return None + return project_root / "labeled-data" / dataset_name + + +def dataset_folder_has_files(folder: str | Path | None) -> bool: + """ + Return True if the given folder exists and contains any files. + + This is intentionally conservative: any existing file content means we refuse + the project-association override to avoid colliding with an existing dataset. + """ + if folder is None: + return False + + p = Path(folder) + if not p.exists() or not p.is_dir(): + return False + + return any(child.is_file() for child in p.iterdir()) + + +# ----------------------------------------------------------------------------- +# Source-specific adapters +# ----------------------------------------------------------------------------- +def infer_dlc_project_from_opened( + opened: str | Path, + *, + explicit_root: str | Path | None = None, + prefer_project_root: bool = True, + max_levels: int = 5, +) -> DLCProjectContext: + return infer_dlc_project( + anchor_candidates=[opened], + explicit_root=explicit_root, + prefer_project_root=prefer_project_root, + max_levels=max_levels, + ) + + +def infer_dlc_project_from_points_meta( + pts_meta: PointsMetadata, + *, + prefer_project_root: bool = True, + max_levels: int = 5, +) -> DLCProjectContext: + """ + Infer DLC dataset folder (…/labeled-data/) from PointsMetadata. + + Uses: + - pts_meta.project (config parent) as project root + - pts_meta.paths (canonicalized relpaths like labeled-data/test/img000.png) + - pts_meta.root as a fallback hint + + Args: + pts_meta: PointsMetadata object containing project-related metadata. + prefer_project_root: If True, root_anchor prefers the folder containing config.yaml. + max_levels: Maximum number of levels to search upward for config.yaml. + + + Returns a DLCProjectContext object representing the inferred project context. + """ + project = getattr(pts_meta, "project", None) + root = getattr(pts_meta, "root", None) + paths = getattr(pts_meta, "paths", None) or [] + + dataset_folder = infer_labeled_data_folder_from_paths( + paths, + project_root=project, + fallback_root=root, + ) + + return infer_dlc_project( + anchor_candidates=[project, root, dataset_folder], + dataset_candidates=[dataset_folder], + explicit_root=None, + prefer_project_root=prefer_project_root, + max_levels=max_levels, + ) + + +def infer_dlc_project_from_image_layer( + layer, + *, + prefer_project_root: bool = True, + max_levels: int = 5, +) -> DLCProjectContext: + """Best-effort inference of the DLC project context from an Image/video layer using its source metadata. + + Uses: + - layer.metadata.project as project root + - layer.metadata.root as a fallback hint + - layer.source.path as a fallback hint + + Returns a DLCProjectContext object representing the inferred project context. + """ + md = getattr(layer, "metadata", {}) or {} + + candidates: list[str | Path] = [] + + project = md.get("project") + if isinstance(project, str) and project: + candidates.append(project) + + root = md.get("root") + if isinstance(root, str) and root: + candidates.append(root) + + try: + src = getattr(getattr(layer, "source", None), "path", None) + except Exception: + src = None + + if src: + candidates.append(src) + + return infer_dlc_project( + anchor_candidates=candidates, + dataset_candidates=[], + explicit_root=None, + prefer_project_root=prefer_project_root, + max_levels=max_levels, + ) diff --git a/src/napari_deeplabcut/core/provenance.py b/src/napari_deeplabcut/core/provenance.py new file mode 100644 index 00000000..78b9d34b --- /dev/null +++ b/src/napari_deeplabcut/core/provenance.py @@ -0,0 +1,279 @@ +# src/napari_deeplabcut/core/provenance.py +from __future__ import annotations + +import hashlib +import logging +from pathlib import Path, PurePosixPath + +from pydantic import ValidationError + +from napari_deeplabcut.config.models import AnnotationKind, IOProvenance, PointsMetadata +from napari_deeplabcut.core.errors import MissingProvenanceError, UnresolvablePathError +from napari_deeplabcut.core.metadata import parse_points_metadata +from napari_deeplabcut.core.project_paths import ( + infer_dlc_project_from_points_meta, + is_windows_absolute_path, +) + +logger = logging.getLogger(__name__) + + +# ---------------------------------------- +# Helper functions +# ---------------------------------------- +def suggest_human_placeholder(anchor: str) -> str: + """ + Deterministic fallback scorer placeholder derived from anchor path. + """ + h = hashlib.sha1(anchor.encode("utf-8", errors="ignore")).hexdigest()[:6] + return f"human_{h}" + + +def requires_gt_promotion(pts_meta: PointsMetadata) -> bool: + """ + Return True when a machine/prediction source must be promoted to a GT save_target + before saving. + + Rules: + - if save_target already exists -> no promotion needed + - if io.kind is MACHINE -> promotion required + - otherwise -> no promotion required + """ + if getattr(pts_meta, "save_target", None) is not None: + return False + + io_meta = getattr(pts_meta, "io", None) + src_kind = getattr(io_meta, "kind", None) if io_meta is not None else None + return src_kind is AnnotationKind.MACHINE + + +def build_gt_save_target( + anchor: str, + scorer: str, + *, + dataset_key: str = "keypoints", +) -> IOProvenance: + """ + Build a GT save_target pointing to CollectedData_.h5 under a folder anchor. + """ + scorer_clean = str(scorer).strip() + target_name = f"CollectedData_{scorer_clean}.h5" + return IOProvenance( + project_root=anchor, + source_relpath_posix=target_name, + kind=AnnotationKind.GT, + dataset_key=dataset_key, + scorer=scorer_clean, + ) + + +def apply_gt_save_target( + pts_meta: PointsMetadata, + *, + anchor: str, + scorer: str, + dataset_key: str = "keypoints", +) -> PointsMetadata: + """ + Return an updated PointsMetadata with a GT promotion save_target attached. + """ + st = build_gt_save_target(anchor, scorer, dataset_key=dataset_key) + return pts_meta.model_copy(update={"save_target": st}) + + +def is_projectless_folder_association_candidate( + pts_meta: PointsMetadata, + *, + treat_machine_as_ineligible: bool = True, +) -> bool: + """ + Return True for the 'associate current labeled folder with a DLC project' workflow. + + Non-candidates include: + - machine/promotion layers (optional policy) + - layers with an explicit save_target + - layers that already have a resolved DLC project/config context + - layers without a usable folder root + - layers whose paths do not look like a simple single-folder labeling session + """ + if treat_machine_as_ineligible and requires_gt_promotion(pts_meta): + return False + + if getattr(pts_meta, "save_target", None) is not None: + return False + + project_ctx = infer_dlc_project_from_points_meta(pts_meta, prefer_project_root=False) + if project_ctx.project_root is not None and project_ctx.config_path is not None: + return False + + root = getattr(pts_meta, "root", None) + paths = list(getattr(pts_meta, "paths", None) or []) + if not root or not paths: + return False + + try: + root_path = Path(root).expanduser().resolve(strict=False) + except Exception: + root_path = Path(root) + + for value in paths: + text = str(value).replace("\\", "/") + p = Path(value) + + parts = [part for part in text.split("/") if part] + lowered = [part.lower() for part in parts] + if "labeled-data" in lowered: + idx = lowered.index("labeled-data") + if idx + 2 < len(parts): + continue + + is_windows_abs_misclassified = not p.is_absolute() and is_windows_absolute_path(value) + if not p.is_absolute(): + if is_windows_abs_misclassified: + continue + + if len(p.parts) == 1 and p.parts[0] not in (".", ".."): + continue + return False + + try: + rel_to_root = p.expanduser().resolve(strict=False).relative_to(root_path) + except Exception: + continue + + if len(rel_to_root.parts) == 1: + continue + + return False + + return True + + +# ---------------------------------------- +# Core provenance logic +# ---------------------------------------- + + +def resolve_output_path_from_metadata(metadata: dict) -> tuple[str | None, str | None, AnnotationKind | None]: + """ + Resolve output path with promotion support. + + Returns: + (out_path, target_scorer, source_kind) + + - Prefer PointsMetadata.save_target (promotion-to-GT). + - For GT sources, fall back to io/source_h5. + - For machine sources without save_target, return (None, None, "machine") to allow safe abort. + """ + layer_meta = metadata.get("metadata") + if not isinstance(layer_meta, dict): + layer_meta = {} + + pts = parse_points_metadata(layer_meta) + io = pts.io + st = pts.save_target + + source_kind = getattr(io, "kind", None) if io is not None else None + + # Promotion target wins + if st is not None: + try: + p = resolve_provenance_path(st, root_anchor=st.project_root, allow_missing=True) + target_scorer = getattr(st, "scorer", None) + if isinstance(target_scorer, str) and target_scorer.strip(): + return str(p), target_scorer.strip(), source_kind + # Also accept scorer stored in dict extra + if isinstance(layer_meta.get("save_target"), dict): + s2 = layer_meta["save_target"].get("scorer") + if isinstance(s2, str) and s2.strip(): + return str(p), s2.strip(), source_kind + return str(p), None, source_kind + except (MissingProvenanceError, UnresolvablePathError): + return None, None, source_kind + + # Never save back to machine sources + if source_kind == AnnotationKind.MACHINE: + return None, None, source_kind + # GT source: prefer io if available + if io is not None: + try: + p = resolve_provenance_path(io, root_anchor=io.project_root, allow_missing=True) + return str(p), None, source_kind + except (MissingProvenanceError, UnresolvablePathError): + pass + + # Legacy fallback: source_h5 (GT only) + src = layer_meta.get("source_h5") + if isinstance(src, str) and src: + return src, None, source_kind + + return None, None, source_kind + + +def ensure_io_provenance(obj: IOProvenance | dict | None) -> IOProvenance | None: + """ + Validate/normalize IO provenance payload. + + Policy: runtime must carry AnnotationKind objects. + Invalid dicts raise MissingProvenanceError for deterministic behavior. + """ + if obj is None: + return None + if isinstance(obj, IOProvenance): + return obj + if isinstance(obj, dict): + try: + # This must succeed only if kinds are AnnotationKind instances + return IOProvenance.model_validate(obj) + except (ValidationError, TypeError) as e: + raise MissingProvenanceError(f"Invalid IO provenance payload: {e}") from e + raise MissingProvenanceError(f"Invalid IO provenance type: {type(obj).__name__}") + + +def normalize_provenance(io: IOProvenance | None) -> IOProvenance | None: + """ + Normalize provenance fields for stable storage. + + - Ensures source_relpath_posix uses '/' separators. + - Leaves extra fields untouched. + """ + if io is None: + return None + + src = io.source_relpath_posix + if isinstance(src, str): + src = src.replace("\\\\", "/").replace("\\", "/") + + return io.model_copy(update={"source_relpath_posix": src}) + + +def resolve_provenance_path( + io: IOProvenance | dict | None, + *, + root_anchor: str | Path | None = None, + allow_missing: bool = False, +) -> Path: + """ + Resolve IOProvenance into a concrete filesystem Path. + + io may be a dict stored in napari metadata; it will be validated strictly. + """ + io2 = ensure_io_provenance(io) + if io2 is None or not io2.source_relpath_posix: + raise MissingProvenanceError("Missing IO provenance (source_relpath_posix is required).") + + io2 = normalize_provenance(io2) or io2 + + anchor = root_anchor or io2.project_root + if not anchor: + raise UnresolvablePathError( + "Cannot resolve provenance path: no root anchor provided and io.project_root is missing." + ) + + rel = PurePosixPath(io2.source_relpath_posix) + resolved = Path(anchor) / Path(*rel.parts) + + if not allow_missing and not resolved.exists(): + raise UnresolvablePathError(f"Resolved provenance path does not exist: {resolved}") + + return resolved diff --git a/src/napari_deeplabcut/core/remap.py b/src/napari_deeplabcut/core/remap.py new file mode 100644 index 00000000..db5f5d37 --- /dev/null +++ b/src/napari_deeplabcut/core/remap.py @@ -0,0 +1,355 @@ +# src/napari_deeplabcut/core/remap.py +from __future__ import annotations + +import logging +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from napari_deeplabcut.core.project_paths import PathMatchPolicy, canonicalize_path, find_matching_depth + +logger = logging.getLogger(__name__) + +# Heuristic thresholds for "risky remap" warnings. +# These do NOT change behavior; they only control warning emissions. +_WARN_OVERLAP_RATIO = 0.80 # Warn if canonicalized path overlap is below this ratio (relative to smaller set size). +_WARN_MAPPED_RATIO = 0.80 # Warn if mapping coverage of old paths is below this ratio (mapped / old). +_SAMPLE_N = 5 # Number of examples to include in warnings about duplicate keys. + + +@dataclass(frozen=True) +class RemapResult: + """ + Result of an attempted time/frame remapping. + + Attributes + ---------- + changed: + True if output data differs from input data. + applied: + True if the remap is considered safe and should be applied upstream. + accept_paths_update: + True if upstream may safely replace metadata["paths"] with the new paths. + is_ambiguous: + True if remap was rejected because matching was ambiguous/risky. + depth_used: + Canonicalization depth used to match paths (e.g. 3, 2, 1), or None. + mapped_count: + Number of old frame indices that had a mapping into new indices. + message: + Human-readable summary suitable for logs / UI. + data: + Remapped data object (same type shape intent as input), or None if not applied. + warnings: + Tuple of warning strings describing ambiguity/risk detected during remap decision. + """ + + changed: bool + applied: bool + accept_paths_update: bool + is_ambiguous: bool + depth_used: int | None + mapped_count: int + message: str + data: Any | None + warnings: tuple[str, ...] = () + + +def _remap_array(values: np.ndarray, idx_map: Mapping[int, int]) -> np.ndarray: + """ + Remap time indices in an array of indices. + + Parameters + ---------- + values: + Array-like of integer time/frame indices to remap. + idx_map: + Mapping from old integer frame index -> new integer frame index. + Indices not present in the mapping are left unchanged. + """ + values = np.asarray(values) + if values.size == 0: + return values + + try: + values_int = values.astype(int, copy=False) + except Exception: + values_int = values.astype(int) + + mapped = np.fromiter( + (idx_map.get(int(v), int(v)) for v in values_int), + dtype=values_int.dtype, + count=len(values_int), + ) + return mapped + + +def _find_duplicates(keys: list[str]) -> dict[str, int]: + """Return dict of duplicate key -> count (only for keys occurring > 1).""" + counts: dict[str, int] = {} + for k in keys: + counts[k] = counts.get(k, 0) + 1 + return {k: c for k, c in counts.items() if c > 1} + + +def build_frame_index_map( + *, + old_paths: Iterable[str], + new_paths: Iterable[str], + policy: PathMatchPolicy = PathMatchPolicy.ORDERED_DEPTHS, +) -> tuple[dict[int, int], int | None]: + """ + Build a mapping from old frame indices -> new frame indices using canonicalized path overlap. + + Returns + ------- + (idx_map, depth_used) + """ + old_paths = list(old_paths or []) + new_paths = list(new_paths or []) + if not old_paths or not new_paths: + return {}, None + + depth = find_matching_depth(old_paths, new_paths, policy=policy) + if depth is None: + return {}, None + + old_keys = [canonicalize_path(p, depth) for p in old_paths] + new_keys = [canonicalize_path(p, depth) for p in new_paths] + + key_to_new_idx = {k: i for i, k in enumerate(new_keys)} + + idx_map: dict[int, int] = {} + for old_idx, k in enumerate(old_keys): + new_idx = key_to_new_idx.get(k) + if new_idx is not None: + idx_map[old_idx] = new_idx + + return idx_map, depth + + +def remap_time_indices( + *, + data: Any, + time_col: int, + idx_map: Mapping[int, int], +) -> RemapResult: + """ + Remap time indices in a data container (array-like or list-of-arrays). + + This function is intentionally policy-free: it only applies a provided + index mapping and reports whether anything changed. + """ + if data is None: + return RemapResult(False, False, False, False, None, 0, "No data to remap (data is None).", None) + + if not idx_map: + return RemapResult(False, False, False, False, None, 0, "No index mapping available (empty idx_map).", None) + + # Shapes-like: list of arrays + if isinstance(data, list): + new_data = [] + changed = False + + for verts in data: + arr = np.asarray(verts) + if arr.size == 0: + new_data.append(arr) + continue + + if arr.ndim < 2 or arr.shape[1] <= time_col: + new_data.append(arr) + continue + + arr2 = np.array(arr, copy=True) + t = arr2[:, time_col] + + try: + t2 = _remap_array(t, idx_map) + except Exception: + new_data.append(arr2) + continue + + try: + if not np.array_equal(t2, np.asarray(t).astype(int, copy=False)): + changed = True + except Exception: + changed = True + + arr2[:, time_col] = t2 + new_data.append(arr2) + + return RemapResult( + changed=changed, + applied=changed, + accept_paths_update=changed, + is_ambiguous=False, + depth_used=None, + mapped_count=len(idx_map), + message="Remapped list-like vertices." if changed else "List-like vertices unchanged.", + data=new_data if changed else None, + ) + + # Array-like + arr = np.asarray(data) + if arr.size == 0: + return RemapResult(False, False, False, False, None, len(idx_map), "No data to remap (empty array).", None) + + if arr.ndim < 2 or arr.shape[1] <= time_col: + return RemapResult( + False, False, False, False, None, len(idx_map), "Data shape does not contain a time column.", None + ) + + arr2 = np.array(arr, copy=True) + t = arr2[:, time_col] + + try: + t2 = _remap_array(t, idx_map) + except Exception: + return RemapResult(False, False, False, False, None, len(idx_map), "Failed to remap time column.", None) + + try: + unchanged = np.array_equal(t2, np.asarray(t).astype(int, copy=False)) + except Exception: + unchanged = False + + if unchanged: + return RemapResult(False, False, False, False, None, len(idx_map), "Time column unchanged after remap.", None) + + arr2[:, time_col] = t2 + return RemapResult(True, True, True, False, None, len(idx_map), "Remapped array-like data.", arr2) + + +def remap_layer_data_by_paths( + *, + data: Any, + old_paths: Iterable[str] | None, + new_paths: Iterable[str] | None, + time_col: int, + policy: PathMatchPolicy = PathMatchPolicy.ORDERED_DEPTHS, +) -> RemapResult: + """ + High-level remap: infer idx_map from old/new paths, assess safety, then + remap `data` only when the mapping is acceptable. + + Policy + ------ + - Safe remaps are applied automatically. + - Ambiguous basename-only remaps (depth=1 with duplicate canonical keys + and/or non-bijective mapping) are rejected. + """ + old_paths = list(old_paths or []) + new_paths = list(new_paths or []) + + if not old_paths: + return RemapResult(False, False, False, False, None, 0, "No old paths present; cannot remap.", None) + if not new_paths: + return RemapResult(False, False, False, False, None, 0, "No new paths present; cannot remap.", None) + + depth = find_matching_depth(old_paths, new_paths, policy=policy) + if depth is None: + return RemapResult( + False, False, False, False, None, 0, "No overlap between old and new paths; skipping remap.", None + ) + + old_keys = [canonicalize_path(p, depth) for p in old_paths] + new_keys = [canonicalize_path(p, depth) for p in new_paths] + + overlap = set(old_keys) & set(new_keys) + overlap_ratio = (len(overlap) / max(1, min(len(old_keys), len(new_keys)))) if overlap else 0.0 + + key_to_new_idx = {k: i for i, k in enumerate(new_keys)} + idx_map: dict[int, int] = {} + for old_idx, k in enumerate(old_keys): + new_idx = key_to_new_idx.get(k) + if new_idx is not None: + idx_map[old_idx] = new_idx + + identity_mappings = sum(1 for old_i, new_i in idx_map.items() if old_i == new_i) + logger.debug( + "Remap mapping stats: depth=%s old=%s new=%s mapped=%s identity=%s overlap=%s", + depth, + len(old_keys), + len(new_keys), + len(idx_map), + identity_mappings, + len(overlap), + ) + + if not idx_map: + return RemapResult( + False, False, False, False, None, 0, "No overlap between old and new paths; skipping remap.", None + ) + + # If ordering already matches, accept metadata paths update but no data remap needed. + if old_keys == new_keys: + return RemapResult( + changed=False, + applied=False, + accept_paths_update=True, + is_ambiguous=False, + depth_used=depth, + mapped_count=len(idx_map), + message="Path keys already aligned; no remap needed.", + data=None, + ) + + warnings: list[str] = [] + + dup_old = _find_duplicates(old_keys) + dup_new = _find_duplicates(new_keys) + if dup_old: + examples = ", ".join(list(dup_old.keys())[:_SAMPLE_N]) + warnings.append(f"Duplicate canonical keys in old_paths at depth={depth} (examples: {examples}).") + if dup_new: + examples = ", ".join(list(dup_new.keys())[:_SAMPLE_N]) + warnings.append(f"Duplicate canonical keys in new_paths at depth={depth} (examples: {examples}).") + + mapped_ratio = len(idx_map) / max(1, len(old_keys)) + if overlap_ratio < _WARN_OVERLAP_RATIO: + warnings.append( + f"Low path overlap ratio at depth={depth}: {overlap_ratio:.2f} " + f"(overlap={len(overlap)}, old={len(old_keys)}, new={len(new_keys)})." + ) + if mapped_ratio < _WARN_MAPPED_RATIO: + warnings.append(f"Low mapping coverage: {mapped_ratio:.2f} (mapped={len(idx_map)} of old={len(old_keys)}).") + + non_bijective = len(set(idx_map.values())) < len(idx_map) + if non_bijective: + warnings.append("Non-bijective mapping detected (multiple old indices map to the same new index).") + + for w in warnings: + logger.warning("Remap may be ambiguous/risky: %s", w) + + # Reject ambiguous basename-only remaps. + ambiguous_depth1 = depth == 1 and (bool(dup_old) or bool(dup_new) or non_bijective) + if ambiguous_depth1: + msg = "Rejected ambiguous depth=1 remap; keeping original frame indices and paths." + logger.warning(msg) + return RemapResult( + changed=False, + applied=False, + accept_paths_update=False, + is_ambiguous=True, + depth_used=depth, + mapped_count=len(idx_map), + message=msg, + data=None, + warnings=tuple(warnings), + ) + + res = remap_time_indices(data=data, time_col=time_col, idx_map=idx_map) + + return RemapResult( + changed=res.changed, + applied=res.changed, + accept_paths_update=True, + is_ambiguous=False, + depth_used=depth, + mapped_count=res.mapped_count, + message=res.message, + data=res.data if res.changed else None, + warnings=tuple(warnings), + ) diff --git a/src/napari_deeplabcut/core/schemas.py b/src/napari_deeplabcut/core/schemas.py new file mode 100644 index 00000000..4606f097 --- /dev/null +++ b/src/napari_deeplabcut/core/schemas.py @@ -0,0 +1,139 @@ +# src/napari_deeplabcut/core/schemas.py +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from napari_deeplabcut.config.models import DLCHeaderModel, PointsMetadata + + +class PointsDataModel(BaseModel): + """Validated napari Points data for DLC keypoints writing. + + Expected napari-style layout: (N, 3) with columns [frame_index, y, x]. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + data: Any = Field(..., description="Array-like of shape (N, 3): [frame, y, x]") + + @field_validator("data") + @classmethod + def _validate_points_array(cls, v: Any) -> np.ndarray: + arr = np.asarray(v) + if arr.ndim != 2 or arr.shape[1] != 3: + raise ValueError(f"Points data must have shape (N, 3) [frame, y, x]. Got {arr.shape}.") + # Ensure numeric + if not np.issubdtype(arr.dtype, np.number): + raise TypeError(f"Points data must be numeric. Got dtype={arr.dtype}.") + return arr + + @property + def n(self) -> int: + return int(self.data.shape[0]) + + @property + def frame_inds(self) -> np.ndarray: + # floor/cast to int is typical for frame indices + return self.data[:, 0].astype(int) + + @property + def xy_dlc(self) -> np.ndarray: + # Convert napari [y, x] -> DLC [x, y] + return self.data[:, [2, 1]] + + +class KeypointPropertiesModel(BaseModel): + """Validated napari layer.properties for keypoint points. + + Napari stores per-point properties as sequences (often numpy arrays). + """ + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + label: Sequence[str] = Field(..., description="Bodypart label per point") + id: Sequence[str] = Field(..., description="Individual id per point ('' if single animal)") + likelihood: Sequence[float] | None = Field(default=None, description="Likelihood per point (optional)") + + @field_validator("label", "id", mode="before") + @classmethod + def _coerce_str_seq(cls, v): + # make sure we can len() it and iterate + if v is None: + return v + if isinstance(v, np.ndarray): + return v.tolist() + return list(v) + + @field_validator("likelihood", mode="before") + @classmethod + def _coerce_float_seq(cls, v): + if v is None: + return None + if isinstance(v, np.ndarray): + return v.tolist() + return list(v) + + +class PointsWriteInputModel(BaseModel): + """Validated bundle of what form_df needs to write DLC keypoints.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + points: PointsDataModel + meta: PointsMetadata + props: KeypointPropertiesModel + + @model_validator(mode="after") + def _validate_required_header(self): + # For writing, header must exist + if self.meta.header is None: + raise ValueError("PointsMetadata.header is required for writing keypoints.") + if not isinstance(self.meta.header, DLCHeaderModel): + raise TypeError("PointsMetadata.header must be DLCHeaderModel.") + return self + + @model_validator(mode="after") + def _validate_lengths(self): + n = self.points.n + if len(self.props.label) != n: + raise ValueError(f"properties['label'] length {len(self.props.label)} != N points {n}") + if len(self.props.id) != n: + raise ValueError(f"properties['id'] length {len(self.props.id)} != N points {n}") + if self.props.likelihood is not None and len(self.props.likelihood) != n: + raise ValueError(f"properties['likelihood'] length {len(self.props.likelihood)} != N points {n}") + return self + + @model_validator(mode="after") + def _validate_paths_indexing(self): + # If paths exist, frame indices must be in range + if self.meta.paths: + max_idx = len(self.meta.paths) - 1 + fi = self.points.frame_inds + if fi.size and (fi.min() < 0 or fi.max() > max_idx): + raise ValueError( + f"Frame indices out of bounds for metadata.paths: " + f"min={fi.min()}, max={fi.max()}, paths_len={len(self.meta.paths)}" + ) + return self + + +class PointsLayerAttributesModel(BaseModel): + """NPE2 writer attributes bundle for a Points layer.""" + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + name: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + properties: dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="after") + def _ensure_metadata_dict(self): + if not isinstance(self.metadata, dict): + raise TypeError("attributes['metadata'] must be a dict") + if not isinstance(self.properties, dict): + raise TypeError("attributes['properties'] must be a dict") + return self diff --git a/src/napari_deeplabcut/core/sidecar.py b/src/napari_deeplabcut/core/sidecar.py new file mode 100644 index 00000000..90e16b4c --- /dev/null +++ b/src/napari_deeplabcut/core/sidecar.py @@ -0,0 +1,239 @@ +# src/napari_deeplabcut/core/sidecar.py +""" +Sidecar storage for folder-scoped napari-deeplabcut preferences. + +This is intentionally non-invasive: DeepLabCut ignores unknown files in folders. +We store minimal, portable UI state (e.g. default scorer, trails display config) +to avoid repeated prompts and to restore per-folder preferences. + +File name: .napari-deeplabcut.json +Location: anchor folder (typically labeled-data/