Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 138 additions & 4 deletions movement/napari/loader_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from napari.utils.notifications import show_error, show_warning
from napari.viewer import Viewer
from qtpy.QtWidgets import (
QCheckBox,
QComboBox,
QDoubleSpinBox,
QFileDialog,
Expand All @@ -26,12 +27,15 @@
from movement.napari.layer_styles import BoxesStyle, PointsStyle, TracksStyle
from movement.utils.logging import logger
from movement.validators.datasets import ValidBboxesInputs, ValidPosesInputs
from movement.validators.files import DEFAULT_FRAME_REGEXP

# Allowed file suffixes for each supported source software
SUPPORTED_POSES_FILES = {
"DeepLabCut": ["h5", "csv"],
"LightningPose": ["csv"],
"SLEAP": ["h5", "slp"],
"Anipose": ["csv"],
"NWB": ["nwb"],
}

SUPPORTED_BBOXES_FILES = {
Expand Down Expand Up @@ -61,6 +65,9 @@ def __init__(self, napari_viewer: Viewer, parent=None):
# Create widgets
self._create_source_software_widget()
self._create_fps_widget()
self._create_anipose_widgets()
self._create_nwb_widgets()
self._create_via_tracks_widgets()
self._create_file_path_widget()
self._create_load_button()

Expand Down Expand Up @@ -102,6 +109,81 @@ def _create_fps_widget(self):
)
self.layout().addRow("fps:", self.fps_spinbox)

def _create_anipose_widgets(self):
"""Create a line edit for the Anipose individual name.

Hidden by default; revealed when 'Anipose' is selected.
"""
self.individual_name_edit = QLineEdit()
self.individual_name_edit.setObjectName("individual_name_edit")
self.individual_name_edit.setText("id_0")
self.individual_name_edit.setToolTip(
"Name assigned to the individual in the resulting dataset.\n"
"Defaults to 'id_0'."
)
self.layout().addRow("individual name:", self.individual_name_edit)
self.layout().setRowVisible(self.individual_name_edit, False)

def _create_nwb_widgets(self):
"""Create line edits for NWB processing module and pose estimation key.

Hidden by default; revealed when 'NWB' is selected.
"""
self.processing_module_key_edit = QLineEdit()
self.processing_module_key_edit.setObjectName(
"processing_module_key_edit"
)
self.processing_module_key_edit.setText("behavior")
self.processing_module_key_edit.setToolTip(
"Name of the NWB ProcessingModule that contains\n"
"the pose estimation data. Default: 'behavior'."
)
self.layout().addRow(
"processing module:", self.processing_module_key_edit
)
self.layout().setRowVisible(self.processing_module_key_edit, False)

self.pose_estimation_key_edit = QLineEdit()
self.pose_estimation_key_edit.setObjectName("pose_estimation_key_edit")
self.pose_estimation_key_edit.setText("PoseEstimation")
self.pose_estimation_key_edit.setToolTip(
"Name of the PoseEstimation object inside the processing module.\n"
"Default: 'PoseEstimation'."
)
self.layout().addRow(
"pose estimation key:", self.pose_estimation_key_edit
)
self.layout().setRowVisible(self.pose_estimation_key_edit, False)

def _create_via_tracks_widgets(self):
"""Create widgets for VIA-tracks-specific loading options.

Hidden by default; revealed when 'VIA-tracks' is selected.
"""
self.use_frame_numbers_checkbox = QCheckBox()
self.use_frame_numbers_checkbox.setObjectName(
"use_frame_numbers_checkbox"
)
self.use_frame_numbers_checkbox.setChecked(False)
self.use_frame_numbers_checkbox.setToolTip(
"If checked, frame numbers from the file are used as-is.\n"
"Otherwise, frames are re-indexed from 0."
)
self.layout().addRow(
"use file frame numbers:", self.use_frame_numbers_checkbox
)
self.layout().setRowVisible(self.use_frame_numbers_checkbox, False)

self.frame_regexp_edit = QLineEdit()
self.frame_regexp_edit.setObjectName("frame_regexp_edit")
self.frame_regexp_edit.setText(DEFAULT_FRAME_REGEXP)
self.frame_regexp_edit.setToolTip(
"Regex to extract frame number from filenames.\n"
"Only used when 'use file frame numbers' is checked."
)
self.layout().addRow("frame regexp:", self.frame_regexp_edit)
self.layout().setRowVisible(self.frame_regexp_edit, False)

def _create_file_path_widget(self):
"""Create a line edit and browse button for selecting the file path.

Expand All @@ -111,8 +193,10 @@ def _create_file_path_widget(self):
# File path line edit and browse button
self.file_path_edit = QLineEdit()
self.file_path_edit.setObjectName("file_path_edit")
self.file_path_edit.setMinimumHeight(28)
self.browse_button = QPushButton("Browse")
self.browse_button.setObjectName("browse_button")
self.browse_button.setMinimumHeight(28)
self.browse_button.clicked.connect(self._on_browse_clicked)

# Layout for line edit and button
Expand All @@ -122,23 +206,48 @@ def _create_file_path_widget(self):
self.layout().addRow("file path:", self.file_path_layout)

def _on_source_software_changed(self, current_text: str):
"""Enable/disable the fps spinbox based on source software."""
"""Update widget state based on the selected source software.

- Disables the fps spinbox for netCDF and NWB files (both read fps
directly from the file).
- Reveals only the input rows relevant to the selected software and
hides all others.
"""
is_netcdf = current_text in SUPPORTED_NETCDF_FILES
# Disable fps box if netCDF
self.fps_spinbox.setEnabled(not is_netcdf)
is_nwb = current_text == "NWB"

# Disable fps spinbox for formats that read it from the file
self.fps_spinbox.setEnabled(not is_netcdf and not is_nwb)
if is_netcdf:
self.fps_spinbox.setToolTip(
"The fps (frames per second) is read directly \n"
"from the netCDF file attributes."
)
elif is_nwb:
self.fps_spinbox.setToolTip(
"The fps is read directly from the NWB file metadata."
)
else:
self.fps_spinbox.setToolTip(self.fps_default_tooltip)

# Toggle per-software rows
self.layout().setRowVisible(
self.individual_name_edit, current_text == "Anipose"
)
self.layout().setRowVisible(self.processing_module_key_edit, is_nwb)
self.layout().setRowVisible(self.pose_estimation_key_edit, is_nwb)
self.layout().setRowVisible(
self.use_frame_numbers_checkbox, current_text == "VIA-tracks"
)
self.layout().setRowVisible(
self.frame_regexp_edit, current_text == "VIA-tracks"
)

def _create_load_button(self):
"""Create a button to load the file and add layers to the viewer."""
self.load_button = QPushButton("Load")
self.load_button.setObjectName("load_button")
self.load_button.setMinimumHeight(30)
self.load_button.clicked.connect(lambda: self._on_load_clicked())
self.layout().addRow(self.load_button)

Expand Down Expand Up @@ -231,9 +340,34 @@ def _format_data_for_layers(self) -> bool:
def _load_third_party_file(self) -> xr.Dataset:
"""Load a third-party file as a ``movement`` dataset.

Builds a software-specific ``kwargs`` dict from the visible form
widgets and forwards it to :func:`movement.io.load_dataset`.
Validation is handled by the loader functions.
"""
ds = load_dataset(self.file_path, self.source_software, self.fps)
kwargs: dict = {}
if self.source_software == "Anipose":
kwargs["individual_name"] = (
self.individual_name_edit.text().strip() or "id_0"
)
elif self.source_software == "NWB":
kwargs["processing_module_key"] = (
self.processing_module_key_edit.text().strip() or "behavior"
)
kwargs["pose_estimation_key"] = (
self.pose_estimation_key_edit.text().strip()
or "PoseEstimation"
)
elif self.source_software == "VIA-tracks":
use_file_frames = self.use_frame_numbers_checkbox.isChecked()
kwargs["use_frame_numbers_from_file"] = use_file_frames
if use_file_frames:
kwargs["frame_regexp"] = (
self.frame_regexp_edit.text().strip()
or DEFAULT_FRAME_REGEXP
)
ds = load_dataset(
self.file_path, self.source_software, self.fps, **kwargs
)
return ds

def _load_netcdf_file(self) -> xr.Dataset | None:
Expand Down
72 changes: 70 additions & 2 deletions tests/test_unit/test_napari_plugin/test_data_loader_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
from napari.utils.events import EmitterGroup
from pytest import DATA_PATHS
from qtpy.QtWidgets import (
QCheckBox,
QComboBox,
QDoubleSpinBox,
QLineEdit,
QPushButton,
QWidget,
)

from movement.napari.loader_widgets import (
Expand All @@ -46,13 +48,17 @@ def test_data_loader_widget_instantiation(make_napari_viewer_proxy):
# Instantiate the data loader widget
data_loader_widget = DataLoader(make_napari_viewer_proxy())

# Check that the widget has the expected number of rows
assert data_loader_widget.layout().rowCount() == 4
assert data_loader_widget.layout().rowCount() == 9

# Check that the expected widgets are present in the layout
expected_widgets = [
(QComboBox, "source_software_combo"),
(QDoubleSpinBox, "fps_spinbox"),
(QLineEdit, "individual_name_edit"),
(QLineEdit, "processing_module_key_edit"),
(QLineEdit, "pose_estimation_key_edit"),
(QCheckBox, "use_frame_numbers_checkbox"),
(QLineEdit, "frame_regexp_edit"),
(QLineEdit, "file_path_edit"),
(QPushButton, "load_button"),
(QPushButton, "browse_button"),
Expand Down Expand Up @@ -155,6 +161,7 @@ def test_on_layer_added_and_deleted(
"choice, fps_enabled, tooltip_contains",
[
("movement (netCDF)", False, "netCDF file attributes"),
("NWB", False, "NWB file metadata"),
("SLEAP", True, "Set the frames per second"),
("DeepLabCut", True, "Set the frames per second"),
],
Expand Down Expand Up @@ -183,6 +190,65 @@ def test_on_source_software_changed_sets_fps_state(
assert tooltip_contains in data_loader_widget.fps_spinbox.toolTip()


@pytest.mark.parametrize(
"choice, visible_widgets, hidden_widgets",
[
(
"Anipose",
["individual_name_edit"],
[
"processing_module_key_edit",
"pose_estimation_key_edit",
"use_frame_numbers_checkbox",
"frame_regexp_edit",
],
),
(
"NWB",
["processing_module_key_edit", "pose_estimation_key_edit"],
[
"individual_name_edit",
"use_frame_numbers_checkbox",
"frame_regexp_edit",
],
),
(
"VIA-tracks",
["use_frame_numbers_checkbox", "frame_regexp_edit"],
[
"individual_name_edit",
"processing_module_key_edit",
"pose_estimation_key_edit",
],
),
(
"DeepLabCut",
[],
[
"individual_name_edit",
"processing_module_key_edit",
"pose_estimation_key_edit",
"use_frame_numbers_checkbox",
"frame_regexp_edit",
],
),
],
)
def test_on_source_software_changed_row_visibility(
make_napari_viewer_proxy, choice, visible_widgets, hidden_widgets
):
"""Selecting a source software shows only its relevant rows."""
data_loader_widget = DataLoader(make_napari_viewer_proxy())
data_loader_widget._on_source_software_changed(choice)
layout = data_loader_widget.layout()
for name in visible_widgets:
widget = data_loader_widget.findChild(QWidget, name)
assert layout.isRowVisible(widget), f"{name} should be visible"
for name in hidden_widgets:
widget = data_loader_widget.findChild(QWidget, name)
assert not layout.isRowVisible(widget), f"{name} should be hidden"


@pytest.mark.parametrize(
"file_path",
[
Expand Down Expand Up @@ -222,6 +288,8 @@ def test_on_browse_clicked(file_path, make_napari_viewer_proxy, mocker):
("DeepLabCut", "*.h5 *.csv"),
("SLEAP", "*.h5 *.slp"),
("LightningPose", "*.csv"),
("Anipose", "*.csv"),
("NWB", "*.nwb"),
("VIA-tracks", "*.csv"),
("movement (netCDF)", "*.nc"),
],
Expand Down
Loading