|
4 | 4 | import numpy as np |
5 | 5 | import pandas as pd |
6 | 6 | from typing import List |
7 | | -from magicgui.widgets import Container, FileEdit, ComboBox, Label, CheckBox, PushButton |
| 7 | +from magicgui.widgets import Container, FileEdit, LineEdit, ComboBox, Label, CheckBox, PushButton |
8 | 8 | from qtpy.QtWidgets import QListWidget, QListWidgetItem, QMessageBox, QSizePolicy |
9 | 9 | from qtpy.QtGui import QFont, QImage, QPixmap, QIcon |
10 | 10 | from qtpy.QtCore import QSize |
|
20 | 20 | write_annotations, |
21 | 21 | create_empty_annotations, |
22 | 22 | find_keypoint_conflicts, |
23 | | - ensure_images_in_yaml) |
| 23 | + ensure_images_in_yaml, |
| 24 | + save_shapes_yaml, |
| 25 | + load_shapes_yaml) |
24 | 26 |
|
25 | 27 | class FrameAnnotatorWidget(Container): |
26 | 28 | def __init__(self, viewer: Viewer): |
@@ -565,3 +567,117 @@ def jump_to_time(self, item): |
565 | 567 | self.viewer.dims.current_step = (frame_index,) + self.viewer.dims.current_step[1:] |
566 | 568 | except Exception as e: |
567 | 569 | QMessageBox.warning(None, "Jump Failed", str(e)) |
| 570 | + |
| 571 | +class CurveAnnotatorWidget(Container): |
| 572 | + def __init__(self, viewer: Viewer): |
| 573 | + super().__init__() |
| 574 | + self.viewer = viewer |
| 575 | + self.viewer.layers.clear() |
| 576 | + self.viewer.window.remove_dock_widget('all') |
| 577 | + self.viewer.grid.enabled = False |
| 578 | + |
| 579 | + self.root_folder = FileEdit(label="Root Folder", mode="d") |
| 580 | + self.labeler_name = LineEdit(label="Labeler", value="houlab") |
| 581 | + |
| 582 | + self.folder_list = QListWidget() |
| 583 | + self.folder_list.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum) |
| 584 | + self.folder_list.setMaximumHeight(500) |
| 585 | + self.folder_list.setFont(QFont("", 12)) |
| 586 | + self.folder_list.setStyleSheet("QListWidget::item { height: 32px; }") |
| 587 | + self.folder_list.itemClicked.connect(self.load_subfolder) |
| 588 | + |
| 589 | + self.help_label = Label(value=""" |
| 590 | + <b>Instructions:</b><br> |
| 591 | + - Select Root Folder.<br> |
| 592 | + - Pick a frame folder to annotate.<br> |
| 593 | + - Use shape tool to draw curves on each frame.<br> |
| 594 | + - Use slider to move between frames.<br> |
| 595 | + - Press <b>S</b> to save. |
| 596 | + """) |
| 597 | + |
| 598 | + self.extend([ |
| 599 | + self.labeler_name, |
| 600 | + self.root_folder, |
| 601 | + ]) |
| 602 | + self.native.layout().addWidget(self.folder_list) |
| 603 | + self.native.layout().addWidget(self.help_label.native) |
| 604 | + |
| 605 | + self.root_folder.changed.connect(self.refresh_folders) |
| 606 | + |
| 607 | + self.parts = ["right_ear", "left_ear"] |
| 608 | + self.shape_layers = {} # {part: shape_layer} |
| 609 | + self.current_folder = None |
| 610 | + self.last_frame = None |
| 611 | + |
| 612 | + self._bind_keys() |
| 613 | + # TODO: add on change callback for parts_dropdown |
| 614 | + self.viewer.dims.events.current_step.connect(self._on_frame_change) |
| 615 | + |
| 616 | + def refresh_folders(self): |
| 617 | + self.folder_list.clear() |
| 618 | + base = self.root_folder.value |
| 619 | + if not base or not os.path.isdir(base): |
| 620 | + return |
| 621 | + for f in sorted(os.listdir(base)): |
| 622 | + if os.path.isdir(os.path.join(base, f)): |
| 623 | + self.folder_list.addItem(f) |
| 624 | + |
| 625 | + def _on_frame_change(self, event): |
| 626 | + frame = self.viewer.dims.current_step[0] |
| 627 | + if self.last_frame is not None and frame != self.last_frame: |
| 628 | + self._save(self.viewer) # auto-save when frame changes |
| 629 | + self.last_frame = frame |
| 630 | + |
| 631 | + def _bind_keys(self): |
| 632 | + @self.viewer.bind_key("y", overwrite=True) |
| 633 | + def _save(viewer): |
| 634 | + self._save(viewer) |
| 635 | + |
| 636 | + def load_subfolder(self, item: QListWidgetItem): |
| 637 | + if self.current_folder and self.shape_layers: |
| 638 | + self._save(self.viewer) |
| 639 | + |
| 640 | + folder_name = item.text() |
| 641 | + folder_path = os.path.join(self.root_folder.value, folder_name) |
| 642 | + self.viewer.layers.clear() |
| 643 | + self.shape_layers.clear() |
| 644 | + |
| 645 | + # Load image stack |
| 646 | + self.filenames = sorted(glob(os.path.join(folder_path, "*.png"))) |
| 647 | + if not self.filenames: |
| 648 | + QMessageBox.warning(None, "No images", f"No PNGs found in {folder_name}") |
| 649 | + return |
| 650 | + |
| 651 | + stack = np.stack([ |
| 652 | + rgb2gray(imread(p)) if imread(p).ndim == 3 else imread(p) |
| 653 | + for p in self.filenames |
| 654 | + ]) |
| 655 | + self.viewer.add_image(stack, name=folder_name) |
| 656 | + self.viewer.dims.ndisplay = 2 |
| 657 | + self.viewer.dims.axis_labels = ("frame", "y", "x") |
| 658 | + self.current_folder = folder_path |
| 659 | + self.last_frame = self.viewer.dims.current_step[0] |
| 660 | + |
| 661 | + # Load shapes |
| 662 | + annotation_path = os.path.join(folder_path, "annotation.yaml") |
| 663 | + all_data = load_shapes_yaml(annotation_path, self.filenames) if os.path.exists(annotation_path) else {} |
| 664 | + |
| 665 | + for part in self.parts: |
| 666 | + layer = self.viewer.add_shapes( |
| 667 | + name=f"{part}_shapes", shape_type="path", ndim=3, edge_width=4 |
| 668 | + ) |
| 669 | + if part in all_data and all_data[part]: |
| 670 | + layer.data = all_data[part] |
| 671 | + layer.shape_type = "path" |
| 672 | + self.shape_layers[part] = layer |
| 673 | + |
| 674 | + def _save(self, viewer): |
| 675 | + if self.current_folder and self.shape_layers: |
| 676 | + labeler = self.labeler_name.value |
| 677 | + save_shapes_yaml( |
| 678 | + self.shape_layers, |
| 679 | + self.filenames, |
| 680 | + os.path.join(self.current_folder, "annotation.yaml"), |
| 681 | + labeler=labeler |
| 682 | + ) |
| 683 | + print(f"✅ Saved to {self.current_folder}/annotation.yaml") |
0 commit comments