Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 60 additions & 41 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,23 @@
import sys, os, pathlib, warnings, datetime, time, copy

from qtpy import QtGui, QtCore
from superqt import QRangeSlider, QCollapsible
from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \
from superqt import QCollapsible
from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, \
QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \
QLineEdit, QMessageBox, QGroupBox, QMenu, QAction
QLineEdit, QGroupBox
import pyqtgraph as pg

import numpy as np
from scipy.stats import mode
import cv2

from cellpose.gui.guiparts import ClickableSlider, SaturationSliderDialog

from . import guiparts, menus, io
from .. import models, core, dynamics, version, train
from ..utils import download_url_to_file, masks_to_outlines, diameters
from ..io import get_image_files, imsave, imread
from ..utils import download_url_to_file
from ..io import get_image_files
from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img
from ..models import normalize_default
from ..plot import disk

try:
Expand All @@ -30,27 +31,6 @@
except:
MATPLOTLIB = False

Horizontal = QtCore.Qt.Orientation.Horizontal


class Slider(QRangeSlider):

def __init__(self, parent, name, color):
super().__init__(Horizontal)
self.setEnabled(False)
self.valueChanged.connect(lambda: self.levelChanged(parent))
self.name = name

self.setStyleSheet(""" QSlider{
background-color: transparent;
}
""")
self.show()

def levelChanged(self, parent):
parent.level_change(self.name)


class QHLine(QFrame):

def __init__(self):
Expand Down Expand Up @@ -394,13 +374,15 @@ def make_buttons(self):
label.setStyleSheet(f"color: {colornames[r]}")
label.setFont(self.boldmedfont)
self.satBoxG.addWidget(label, widget_row, 0, 1, 2)
self.sliders.append(Slider(self, names[r], colors[r]))
self.sliders.append(ClickableSlider(self))
self.sliders[-1].setMinimum(-.1)
self.sliders[-1].setMaximum(255.1)
self.sliders[-1].setValue([0, 255])
self.sliders[-1].setToolTip(
"NOTE: manually changing the saturation bars does not affect normalization in segmentation"
self.sliders[-1].setToolTip("Right click to pop out slider.\n" +
"NOTE: this saturation bar does not affect normalization in segmentation"
)
self.sliders[-1].sigRightClick.connect(lambda r=r:self.open_slider_popup(r))
self.sliders[-1].valueChanged.connect(lambda value, r=r: self.color_level_change(value, r))
self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7)

b += 1
Expand Down Expand Up @@ -668,17 +650,54 @@ def make_buttons(self):


return b

def open_slider_popup(self, r:int) -> None:
"""Open a saturation slider popup for the color channel at index ``r``.

The popup closes when the user clicks outside of it.

Args:
r (int): Color channel index.

Raises:
AssertionError: If ``r`` is not in {0, 1, 2}.
"""
assert r in [0, 1, 2], f'The color index, `r`, must be in (0, 1, 2), got {r}'
low, high = self.saturation[r][self.currentZ]
low = int(low)
high = int(high)
dialog = SaturationSliderDialog(self, low=low, high=high)
dialog.valueChanged.connect(lambda val, r=r: self.color_level_change(val, r))
dialog.show()

# move the dialog to the top of the window:
global_pos = self.win.mapToGlobal(QtCore.QPoint(0, 0))
x_pad = 10
y_offset = 100
x = global_pos.x() + x_pad
y = global_pos.y() + y_offset
dialog.move(x, y)
dialog.setFixedWidth(self.win.width() - 2*x_pad)


def color_level_change(self, lohi:tuple, r:int) -> None:
"""Update the saturation range for color channel ``r`` and refresh the plot.

If the auto-adjust button is unchecked, the same range is applied
to all z-layers for all channels.

Args:
lohi (tuple[int, int]): (low, high) saturation values.
r (int): Color channel index.
"""
self.saturation[r][self.currentZ] = lohi
# update all the layers if autobtn is unchecked
if not self.autobtn.isChecked():
for ch in range(3):
for i in range(len(self.saturation[ch])):
self.saturation[ch][i] = self.saturation[ch][self.currentZ]
self.update_plot()

def level_change(self, r):
r = ["red", "green", "blue"].index(r)
if self.loaded:
sval = self.sliders[r].value()
self.saturation[r][self.currentZ] = sval
# if not self.autobtn.isChecked():
for r in range(3):
for i in range(len(self.saturation[r])):
self.saturation[r][i] = self.saturation[r][self.currentZ]
self.update_plot()

def keyPressEvent(self, event):
event.ignore()
Expand Down Expand Up @@ -771,7 +790,7 @@ def keyPressEvent(self, event):

# when in stroke, allow escaping out of drawing
else:
if event.key() == QtCore.Qt.Key_Escape:
if self.in_stroke and event.key() == QtCore.Qt.Key_Escape:
self.layer.end_stroke(keep_stroke=False)
if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
self.p0.keyPressEvent(event)
Expand Down
173 changes: 173 additions & 0 deletions cellpose/gui/guiparts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import numpy as np
import pathlib, os

from superqt import QRangeSlider

from cellpose import models
from cellpose.gui.io import _save_sets

Expand Down Expand Up @@ -796,3 +798,174 @@ def setDrawKernel(self, kernel_size=3):
opamask = 100 * kernel[:, :, np.newaxis]
self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)


Horizontal = QtCore.Qt.Orientation.Horizontal


class BaseSlider(QRangeSlider):
"""Horizontal range slider. Starts disabled."""
def __init__(self, parent):
super().__init__(Horizontal)
self.setParent(parent)
self.setEnabled(False)
self.setStyleSheet(""" QSlider{
background-color: transparent;
}
""")
self.show()


class SaturationSliderDialog(QDialog):
"""A range slider dialog with text box inputs for fine-grained control.

Connect to ``valueChanged`` to receive ``(low, high)`` tuples whenever the
slider or text boxes change.
"""

# Note: The textboxes update the slider and only the slider is connected to signals.
# This way, only the ``valueChanged`` signal is required for all syncing.

valueChanged = QtCore.Signal(tuple)

def __init__(self, parent: QWidget, low:int|None=None, high:int|None=None, dtype:np.dtype|None=None):
"""
Args:
parent (QWidget): Parent widget.
low (int | None, optional): Initial low value. Defaults to ``dtype`` minimum.
high (int | None, optional): Initial high value. Defaults to ``dtype`` maximum.
dtype (np.dtype | None, optional): Data type used to set slider bounds. Defaults to ``np.uint8``.
"""
super().__init__(parent)

if dtype is None:
dtype = np.dtype(np.uint8)

try:
dtype = np.dtype(dtype)
except TypeError:
raise TypeError(f"Expected valid numpy data type, got {type(dtype)}")

self.dtype_min = np.iinfo(dtype).min
self.dtype_max = np.iinfo(dtype).max

if low is None:
low = self.dtype_min
if high is None:
high = self.dtype_max

self.slider = BaseSlider(self)
layout = QGridLayout(self)
low_textbox = QLineEdit(self)
low_textbox.setFixedWidth(50)
self.low_textbox = low_textbox

high_textbox = QLineEdit(self)
high_textbox.setFixedWidth(50)
self.high_textbox = high_textbox

layout.addWidget(low_textbox, 0, 0)
layout.addWidget(self.slider, 0, 1)
layout.addWidget(high_textbox, 0, 2)
layout.setColumnStretch(1, 1)
self.setLayout(layout)

self.slider.setMinimum(self.dtype_min)
self.slider.setMaximum(self.dtype_max)
self.slider.setValue([low, high])
self.slider.setEnabled(True)

self._low = low
self._high = high
self.low_textbox.setText(str(low))
self.high_textbox.setText(str(high))

self.setWindowFlags(QtCore.Qt.Popup) # make it stationary and temporary
low_textbox.textChanged.connect(self._validate_update_low_textbox)
high_textbox.textChanged.connect(self._validate_update_high_textbox)
self.slider.valueChanged.connect(self.slider_changed)


def _validate_text_input(self, value) -> int:
"""Convert textbox string to ``int``, returning 0 for empty input."""
if len(value) < 1:
value = 0
return int(value)


def _validate_update_low_textbox(self) -> None:
"""Read the low textbox and update the slider; no-op if low would exceed high."""
low = self._validate_text_input(self.low_textbox.text())
high = self.slider.value()[1]
if low <= high:
self.slider.setValue((low, high))


def _validate_update_high_textbox(self) -> None:
"""Read the high textbox and update the slider; no-op if high would go below low."""
low = self.slider.value()[0]
high = self._validate_text_input(self.high_textbox.text())
if low <= high:
self.slider.setValue((low, high))


@property
def low(self) -> int:
""" slider low value """
return self._low


@low.setter
def low(self, value:int) -> None:
""" set `value` must be <= self.high, otherwise ignored"""
if value == self._low:
return
if value > self.high:
return
self._low = value
self.low_textbox.setText(str(value))
self._update_validators()


@property
def high(self) -> int:
""" slider high value"""
return self._high


@high.setter
def high(self, value:int) -> None:
""" set `value` must be >= self.low, otherwise ignored """
if value == self._high:
return
if value < self.low:
return
self._high = value
self.high_textbox.setText(str(value))
self._update_validators()


def _update_validators(self) -> None:
""" Update the textbox validators after editing """
self.high_textbox.setValidator(QtGui.QIntValidator(self.low, self.dtype_max))
self.low_textbox.setValidator(QtGui.QIntValidator(self.dtype_min, self.high))


def slider_changed(self, lohi) -> None:
""" Set the low and high values and emit the valueChanged signal. """
lo, hi = lohi
self.low = lo
self.high = hi
self.valueChanged.emit(lohi)


class ClickableSlider(BaseSlider):
""" Slider that emits a signal on right click """
sigRightClick = QtCore.Signal()

def __init__(self, parent):
super().__init__(parent)

def contextMenuEvent(self, event):
self.sigRightClick.emit()
event.accept()
Loading