diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py index 955b9b88..0c02f1c7 100644 --- a/cellpose/gui/gui.py +++ b/cellpose/gui/gui.py @@ -6,22 +6,23 @@ import sys, os, pathlib, warnings, datetime, time, copy from qtpy import QtGui, QtCore -from superqt import QRangeSlider, QCollapsible -from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \ +from superqt import QCollapsible +from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, \ QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \ - QLineEdit, QMessageBox, QGroupBox, QMenu, QAction + QLineEdit, QGroupBox import pyqtgraph as pg import numpy as np from scipy.stats import mode import cv2 +from cellpose.gui.guiparts import ClickableSlider, SaturationSliderDialog + from . import guiparts, menus, io from .. import models, core, dynamics, version, train -from ..utils import download_url_to_file, masks_to_outlines, diameters -from ..io import get_image_files, imsave, imread +from ..utils import download_url_to_file +from ..io import get_image_files from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img -from ..models import normalize_default from ..plot import disk try: @@ -30,27 +31,6 @@ except: MATPLOTLIB = False -Horizontal = QtCore.Qt.Orientation.Horizontal - - -class Slider(QRangeSlider): - - def __init__(self, parent, name, color): - super().__init__(Horizontal) - self.setEnabled(False) - self.valueChanged.connect(lambda: self.levelChanged(parent)) - self.name = name - - self.setStyleSheet(""" QSlider{ - background-color: transparent; - } - """) - self.show() - - def levelChanged(self, parent): - parent.level_change(self.name) - - class QHLine(QFrame): def __init__(self): @@ -394,13 +374,15 @@ def make_buttons(self): label.setStyleSheet(f"color: {colornames[r]}") label.setFont(self.boldmedfont) self.satBoxG.addWidget(label, widget_row, 0, 1, 2) - self.sliders.append(Slider(self, names[r], colors[r])) + self.sliders.append(ClickableSlider(self)) self.sliders[-1].setMinimum(-.1) self.sliders[-1].setMaximum(255.1) self.sliders[-1].setValue([0, 255]) - self.sliders[-1].setToolTip( - "NOTE: manually changing the saturation bars does not affect normalization in segmentation" + self.sliders[-1].setToolTip("Right click to pop out slider.\n" + + "NOTE: this saturation bar does not affect normalization in segmentation" ) + self.sliders[-1].sigRightClick.connect(lambda r=r:self.open_slider_popup(r)) + self.sliders[-1].valueChanged.connect(lambda value, r=r: self.color_level_change(value, r)) self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7) b += 1 @@ -668,17 +650,54 @@ def make_buttons(self): return b + + def open_slider_popup(self, r:int) -> None: + """Open a saturation slider popup for the color channel at index ``r``. + + The popup closes when the user clicks outside of it. + + Args: + r (int): Color channel index. + + Raises: + AssertionError: If ``r`` is not in {0, 1, 2}. + """ + assert r in [0, 1, 2], f'The color index, `r`, must be in (0, 1, 2), got {r}' + low, high = self.saturation[r][self.currentZ] + low = int(low) + high = int(high) + dialog = SaturationSliderDialog(self, low=low, high=high) + dialog.valueChanged.connect(lambda val, r=r: self.color_level_change(val, r)) + dialog.show() + + # move the dialog to the top of the window: + global_pos = self.win.mapToGlobal(QtCore.QPoint(0, 0)) + x_pad = 10 + y_offset = 100 + x = global_pos.x() + x_pad + y = global_pos.y() + y_offset + dialog.move(x, y) + dialog.setFixedWidth(self.win.width() - 2*x_pad) + + + def color_level_change(self, lohi:tuple, r:int) -> None: + """Update the saturation range for color channel ``r`` and refresh the plot. + + If the auto-adjust button is unchecked, the same range is applied + to all z-layers for all channels. + + Args: + lohi (tuple[int, int]): (low, high) saturation values. + r (int): Color channel index. + """ + self.saturation[r][self.currentZ] = lohi + # update all the layers if autobtn is unchecked + if not self.autobtn.isChecked(): + for ch in range(3): + for i in range(len(self.saturation[ch])): + self.saturation[ch][i] = self.saturation[ch][self.currentZ] + self.update_plot() - def level_change(self, r): - r = ["red", "green", "blue"].index(r) - if self.loaded: - sval = self.sliders[r].value() - self.saturation[r][self.currentZ] = sval - # if not self.autobtn.isChecked(): - for r in range(3): - for i in range(len(self.saturation[r])): - self.saturation[r][i] = self.saturation[r][self.currentZ] - self.update_plot() def keyPressEvent(self, event): event.ignore() @@ -771,7 +790,7 @@ def keyPressEvent(self, event): # when in stroke, allow escaping out of drawing else: - if event.key() == QtCore.Qt.Key_Escape: + if self.in_stroke and event.key() == QtCore.Qt.Key_Escape: self.layer.end_stroke(keep_stroke=False) if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal: self.p0.keyPressEvent(event) diff --git a/cellpose/gui/guiparts.py b/cellpose/gui/guiparts.py index 9e64a151..61831072 100644 --- a/cellpose/gui/guiparts.py +++ b/cellpose/gui/guiparts.py @@ -10,6 +10,8 @@ import numpy as np import pathlib, os +from superqt import QRangeSlider + from cellpose import models from cellpose.gui.io import _save_sets @@ -796,3 +798,174 @@ def setDrawKernel(self, kernel_size=3): opamask = 100 * kernel[:, :, np.newaxis] self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1) self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1) + + +Horizontal = QtCore.Qt.Orientation.Horizontal + + +class BaseSlider(QRangeSlider): + """Horizontal range slider. Starts disabled.""" + def __init__(self, parent): + super().__init__(Horizontal) + self.setParent(parent) + self.setEnabled(False) + self.setStyleSheet(""" QSlider{ + background-color: transparent; + } + """) + self.show() + + +class SaturationSliderDialog(QDialog): + """A range slider dialog with text box inputs for fine-grained control. + + Connect to ``valueChanged`` to receive ``(low, high)`` tuples whenever the + slider or text boxes change. + """ + + # Note: The textboxes update the slider and only the slider is connected to signals. + # This way, only the ``valueChanged`` signal is required for all syncing. + + valueChanged = QtCore.Signal(tuple) + + def __init__(self, parent: QWidget, low:int|None=None, high:int|None=None, dtype:np.dtype|None=None): + """ + Args: + parent (QWidget): Parent widget. + low (int | None, optional): Initial low value. Defaults to ``dtype`` minimum. + high (int | None, optional): Initial high value. Defaults to ``dtype`` maximum. + dtype (np.dtype | None, optional): Data type used to set slider bounds. Defaults to ``np.uint8``. + """ + super().__init__(parent) + + if dtype is None: + dtype = np.dtype(np.uint8) + + try: + dtype = np.dtype(dtype) + except TypeError: + raise TypeError(f"Expected valid numpy data type, got {type(dtype)}") + + self.dtype_min = np.iinfo(dtype).min + self.dtype_max = np.iinfo(dtype).max + + if low is None: + low = self.dtype_min + if high is None: + high = self.dtype_max + + self.slider = BaseSlider(self) + layout = QGridLayout(self) + low_textbox = QLineEdit(self) + low_textbox.setFixedWidth(50) + self.low_textbox = low_textbox + + high_textbox = QLineEdit(self) + high_textbox.setFixedWidth(50) + self.high_textbox = high_textbox + + layout.addWidget(low_textbox, 0, 0) + layout.addWidget(self.slider, 0, 1) + layout.addWidget(high_textbox, 0, 2) + layout.setColumnStretch(1, 1) + self.setLayout(layout) + + self.slider.setMinimum(self.dtype_min) + self.slider.setMaximum(self.dtype_max) + self.slider.setValue([low, high]) + self.slider.setEnabled(True) + + self._low = low + self._high = high + self.low_textbox.setText(str(low)) + self.high_textbox.setText(str(high)) + + self.setWindowFlags(QtCore.Qt.Popup) # make it stationary and temporary + low_textbox.textChanged.connect(self._validate_update_low_textbox) + high_textbox.textChanged.connect(self._validate_update_high_textbox) + self.slider.valueChanged.connect(self.slider_changed) + + + def _validate_text_input(self, value) -> int: + """Convert textbox string to ``int``, returning 0 for empty input.""" + if len(value) < 1: + value = 0 + return int(value) + + + def _validate_update_low_textbox(self) -> None: + """Read the low textbox and update the slider; no-op if low would exceed high.""" + low = self._validate_text_input(self.low_textbox.text()) + high = self.slider.value()[1] + if low <= high: + self.slider.setValue((low, high)) + + + def _validate_update_high_textbox(self) -> None: + """Read the high textbox and update the slider; no-op if high would go below low.""" + low = self.slider.value()[0] + high = self._validate_text_input(self.high_textbox.text()) + if low <= high: + self.slider.setValue((low, high)) + + + @property + def low(self) -> int: + """ slider low value """ + return self._low + + + @low.setter + def low(self, value:int) -> None: + """ set `value` must be <= self.high, otherwise ignored""" + if value == self._low: + return + if value > self.high: + return + self._low = value + self.low_textbox.setText(str(value)) + self._update_validators() + + + @property + def high(self) -> int: + """ slider high value""" + return self._high + + + @high.setter + def high(self, value:int) -> None: + """ set `value` must be >= self.low, otherwise ignored """ + if value == self._high: + return + if value < self.low: + return + self._high = value + self.high_textbox.setText(str(value)) + self._update_validators() + + + def _update_validators(self) -> None: + """ Update the textbox validators after editing """ + self.high_textbox.setValidator(QtGui.QIntValidator(self.low, self.dtype_max)) + self.low_textbox.setValidator(QtGui.QIntValidator(self.dtype_min, self.high)) + + + def slider_changed(self, lohi) -> None: + """ Set the low and high values and emit the valueChanged signal. """ + lo, hi = lohi + self.low = lo + self.high = hi + self.valueChanged.emit(lohi) + + +class ClickableSlider(BaseSlider): + """ Slider that emits a signal on right click """ + sigRightClick = QtCore.Signal() + + def __init__(self, parent): + super().__init__(parent) + + def contextMenuEvent(self, event): + self.sigRightClick.emit() + event.accept()