Skip to content

Commit 41095c6

Browse files
committed
2 parents 778b5c1 + f7805f6 commit 41095c6

22 files changed

Lines changed: 7429 additions & 5558 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
[![Data](https://img.shields.io/badge/download-demodata-blue)](https://labshare.cshl.edu/shares/houlab/www-data/cheese3d_paper_data/cheese3d_demo.tar.gz)
55
<!--[Download demo data](https://labshare.cshl.edu/shares/houlab/www-data/cheese3d_paper_data/cheese3d_demo.tar.gz)-->
66

7-
Cheese3D is a pipeline for tracking mouse facial movements built on top of existing tools like [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut) and [Anipose](https://github.com/lambdaloop/anipose). By tracking anatomically-informed keypoints using multiple cameras registered in 3D, our pipeline produces sensitive, high-precision facial movement data that can be related internal state (e.g., electrophysiology).
7+
Cheese3D is a pipeline for tracking mouse facial movements built on top of existing tools ([DeepLabCut](https://github.com/DeepLabCut/DeepLabCut) and [Anipose](https://github.com/lambdaloop/anipose)). By tracking anatomically-informed keypoints using multiple cameras registered in 3D, our pipeline produces sensitive, high-precision facial movement data that can be related internal state (e.g., electrophysiology).
88

99
<p align="center">
1010
<img src="docs/source/_static/Cheese3D.gif" alt="Animation of Cheese3D pipeline", width=60%>

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Cheese3D
44
========
55

6-
Cheese3D is a pipeline for tracking mouse facial movements built on top of existing tools like https://github.com/DeepLabCut/DeepLabCut and https://github.com/lambdaloop/anipose. By tracking anatomically-informed keypoints using multiple cameras registered in 3D, our pipeline produces sensitive, high-precision facial movement data that can be related internal state (e.g., electrophysiology).
6+
Cheese3D is a pipeline for tracking mouse facial movements built on top of existing tools (https://github.com/DeepLabCut/DeepLabCut and https://github.com/lambdaloop/anipose). By tracking anatomically-informed keypoints using multiple cameras registered in 3D, our pipeline produces sensitive, high-precision facial movement data that can be related internal state (e.g., electrophysiology).
77

88
.. image:: /_static/Cheese3D.gif
99
:width: 59%

packages/cheese3d-annotator/README.md

Whitespace-only changes.

packages/cheese3d-annotator/cheese3d_annotator/data.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,55 @@ def ensure_images_in_yaml(image_files: List[str],
111111
with open(yaml_path, "w") as f:
112112
yaml.safe_dump(annotations, f)
113113
print(f"▶︎ Added {n_files_added} new image(s) to {os.path.basename(yaml_path)}")
114+
115+
def save_shapes_yaml(layers_by_part, filenames, path, labeler=""):
116+
"""
117+
Save shape annotations using image filenames as keys.
118+
119+
Parameters
120+
----------
121+
layers_by_part : dict
122+
Dict of {part: napari shape layer}
123+
filenames : list of str
124+
List of frame image paths
125+
path : str
126+
Path to save annotation.yaml
127+
labeler : str
128+
Labeler name to include
129+
"""
130+
out = {}
131+
for part, layer in layers_by_part.items():
132+
part_data = {}
133+
for shape in layer.data:
134+
shape = np.asarray(shape)
135+
if shape.ndim != 2 or shape.shape[1] != 3 or np.isnan(shape).any():
136+
continue
137+
z = int(round(shape[0, 0]))
138+
if z < 0 or z >= len(filenames):
139+
continue
140+
fname = os.path.basename(filenames[z])
141+
part_data.setdefault(fname, []).append(shape[:, 1:].tolist()) # strip z-axis
142+
out[part] = part_data
143+
144+
with open(path, "w") as f:
145+
yaml.safe_dump(out, f)
146+
147+
148+
def load_shapes_yaml(path, filenames):
149+
with open(path, "r") as f:
150+
raw = yaml.safe_load(f) or {}
151+
152+
file_to_index = {os.path.basename(f): i for i, f in enumerate(filenames)}
153+
154+
out = {}
155+
for part, part_data in raw.items():
156+
data = []
157+
for fname, shapes in part_data.items():
158+
z = file_to_index.get(fname, None)
159+
if z is None:
160+
continue
161+
for s in shapes:
162+
shape = np.array([[z, y, x] for y, x in s], dtype=np.float32)
163+
data.append(shape)
164+
out[part] = data
165+
return out

packages/cheese3d-annotator/cheese3d_annotator/napari.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@ display_name: Cheese3D Annotator
33
contributions:
44
commands:
55
- id: cheese3d-annotator.frame_picker
6-
python_name: cheese3d-annotator.widget:FramePickerWidget
6+
python_name: cheese3d_annotator.widget:FramePickerWidget
77
title: Cheese3D Frame Picker
88
- id: cheese3d-annotator.frame_annotator
9-
python_name: cheese3d-annotator.widget:FrameAnnotatorWidget
9+
python_name: cheese3d_annotator.widget:FrameAnnotatorWidget
1010
title: Cheese3D Frame Annotator
11+
- id: cheese3d-annotator.curve_annotator
12+
python_name: cheese3d_annotator.widget:CurveAnnotatorWidget
13+
title: Cheese3D Curve Annotator (experimental)
1114
widgets:
1215
- command: cheese3d-annotator.frame_annotator
1316
display_name: Cheese3D Frame Annotator
1417
- command: cheese3d-annotator.frame_picker
1518
display_name: Cheese3D Frame Picker
19+
- command: cheese3d-annotator.curve_annotator
20+
display_name: Cheese3D Curve Annotator (experimental)

packages/cheese3d-annotator/cheese3d_annotator/widget.py

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pandas as pd
66
from typing import List
7-
from magicgui.widgets import Container, FileEdit, ComboBox, Label, CheckBox, PushButton
7+
from magicgui.widgets import Container, FileEdit, LineEdit, ComboBox, Label, CheckBox, PushButton
88
from qtpy.QtWidgets import QListWidget, QListWidgetItem, QMessageBox, QSizePolicy
99
from qtpy.QtGui import QFont, QImage, QPixmap, QIcon
1010
from qtpy.QtCore import QSize
@@ -20,7 +20,9 @@
2020
write_annotations,
2121
create_empty_annotations,
2222
find_keypoint_conflicts,
23-
ensure_images_in_yaml)
23+
ensure_images_in_yaml,
24+
save_shapes_yaml,
25+
load_shapes_yaml)
2426

2527
class FrameAnnotatorWidget(Container):
2628
def __init__(self, viewer: Viewer):
@@ -565,3 +567,117 @@ def jump_to_time(self, item):
565567
self.viewer.dims.current_step = (frame_index,) + self.viewer.dims.current_step[1:]
566568
except Exception as e:
567569
QMessageBox.warning(None, "Jump Failed", str(e))
570+
571+
class CurveAnnotatorWidget(Container):
572+
def __init__(self, viewer: Viewer):
573+
super().__init__()
574+
self.viewer = viewer
575+
self.viewer.layers.clear()
576+
self.viewer.window.remove_dock_widget('all')
577+
self.viewer.grid.enabled = False
578+
579+
self.root_folder = FileEdit(label="Root Folder", mode="d")
580+
self.labeler_name = LineEdit(label="Labeler", value="houlab")
581+
582+
self.folder_list = QListWidget()
583+
self.folder_list.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum)
584+
self.folder_list.setMaximumHeight(500)
585+
self.folder_list.setFont(QFont("", 12))
586+
self.folder_list.setStyleSheet("QListWidget::item { height: 32px; }")
587+
self.folder_list.itemClicked.connect(self.load_subfolder)
588+
589+
self.help_label = Label(value="""
590+
<b>Instructions:</b><br>
591+
- Select Root Folder.<br>
592+
- Pick a frame folder to annotate.<br>
593+
- Use shape tool to draw curves on each frame.<br>
594+
- Use slider to move between frames.<br>
595+
- Press <b>S</b> to save.
596+
""")
597+
598+
self.extend([
599+
self.labeler_name,
600+
self.root_folder,
601+
])
602+
self.native.layout().addWidget(self.folder_list)
603+
self.native.layout().addWidget(self.help_label.native)
604+
605+
self.root_folder.changed.connect(self.refresh_folders)
606+
607+
self.parts = ["right_ear", "left_ear"]
608+
self.shape_layers = {} # {part: shape_layer}
609+
self.current_folder = None
610+
self.last_frame = None
611+
612+
self._bind_keys()
613+
# TODO: add on change callback for parts_dropdown
614+
self.viewer.dims.events.current_step.connect(self._on_frame_change)
615+
616+
def refresh_folders(self):
617+
self.folder_list.clear()
618+
base = self.root_folder.value
619+
if not base or not os.path.isdir(base):
620+
return
621+
for f in sorted(os.listdir(base)):
622+
if os.path.isdir(os.path.join(base, f)):
623+
self.folder_list.addItem(f)
624+
625+
def _on_frame_change(self, event):
626+
frame = self.viewer.dims.current_step[0]
627+
if self.last_frame is not None and frame != self.last_frame:
628+
self._save(self.viewer) # auto-save when frame changes
629+
self.last_frame = frame
630+
631+
def _bind_keys(self):
632+
@self.viewer.bind_key("y", overwrite=True)
633+
def _save(viewer):
634+
self._save(viewer)
635+
636+
def load_subfolder(self, item: QListWidgetItem):
637+
if self.current_folder and self.shape_layers:
638+
self._save(self.viewer)
639+
640+
folder_name = item.text()
641+
folder_path = os.path.join(self.root_folder.value, folder_name)
642+
self.viewer.layers.clear()
643+
self.shape_layers.clear()
644+
645+
# Load image stack
646+
self.filenames = sorted(glob(os.path.join(folder_path, "*.png")))
647+
if not self.filenames:
648+
QMessageBox.warning(None, "No images", f"No PNGs found in {folder_name}")
649+
return
650+
651+
stack = np.stack([
652+
rgb2gray(imread(p)) if imread(p).ndim == 3 else imread(p)
653+
for p in self.filenames
654+
])
655+
self.viewer.add_image(stack, name=folder_name)
656+
self.viewer.dims.ndisplay = 2
657+
self.viewer.dims.axis_labels = ("frame", "y", "x")
658+
self.current_folder = folder_path
659+
self.last_frame = self.viewer.dims.current_step[0]
660+
661+
# Load shapes
662+
annotation_path = os.path.join(folder_path, "annotation.yaml")
663+
all_data = load_shapes_yaml(annotation_path, self.filenames) if os.path.exists(annotation_path) else {}
664+
665+
for part in self.parts:
666+
layer = self.viewer.add_shapes(
667+
name=f"{part}_shapes", shape_type="path", ndim=3, edge_width=4
668+
)
669+
if part in all_data and all_data[part]:
670+
layer.data = all_data[part]
671+
layer.shape_type = "path"
672+
self.shape_layers[part] = layer
673+
674+
def _save(self, viewer):
675+
if self.current_folder and self.shape_layers:
676+
labeler = self.labeler_name.value
677+
save_shapes_yaml(
678+
self.shape_layers,
679+
self.filenames,
680+
os.path.join(self.current_folder, "annotation.yaml"),
681+
labeler=labeler
682+
)
683+
print(f"✅ Saved to {self.current_folder}/annotation.yaml")

packages/cheese3d-annotator/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ authors = [
77
{ name = "Kyle Daruwalla", email = "daruwal@cshl.edu"},
88
{ name = "Rubin Zhao", email = "rzhao@cshl.edu"}
99
]
10-
readme = "README.md"
11-
requires-python = "==3.10.*"
10+
readme = "../../README.md"
11+
requires-python = ">=3.10"
1212
dependencies = [
1313
"tables>=3.7.0",
1414
"pandas>=2.1.4",

packages/cheese3d/cheese3d/backends/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@ def extract_frames(self, videos: Optional[List[Path]] = None):
2626
"""Extract frames from videos."""
2727
raise NotImplementedError("This method should be implemented by subclasses.")
2828

29-
def train(self, gpu):
29+
def train(self, gpu, iterate_dataset: bool = True):
3030
"""Train the model using GPU ID `gpu`."""
3131
raise NotImplementedError("This method should be implemented by subclasses.")

0 commit comments

Comments
 (0)