Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
fail-fast: true
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.10", "3.11"]
python-version: ["3.11", "3.12"]

steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions cellpose/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def get_arg_parser():

# model settings
model_args = parser.add_argument_group("Model Arguments")
model_args.add_argument("--pretrained_model", required=False, default="cpsam",
model_args.add_argument("--pretrained_model", required=False, default="cpdino",
type=str,
help="model to use for running or starting training")
help="path to model for segmentation or starting training, or builtin model name: cpdino, cpsam_v2, cpdino-vitb, or cpsam")
model_args.add_argument(
"--add_model", required=False, default=None, type=str,
help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
Expand Down
6 changes: 2 additions & 4 deletions cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,8 @@ def run_3D(net, imgs, batch_size=8, augment=False,
# per image
core_logger.info("running %s: %d planes of size (%d, %d)" %
(sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
y, style = run_net(net,
xsl, batch_size=batch_size, augment=augment,
bsize=bsize, tile_overlap=tile_overlap,
rsz=None)
y, style = run_net(net, xsl, batch_size=batch_size, augment=augment,
bsize=bsize, tile_overlap=tile_overlap, rsz=None)
yf[..., -1] += y[..., -1].transpose(ipm[p])
for j in range(2):
yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
Expand Down
1,290 changes: 2 additions & 1,288 deletions cellpose/denoise.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cellpose/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from cellpose.io import imread
from cellpose.utils import download_url_to_file
from cellpose.transforms import pad_image_ND, normalize_img, convert_image
from cellpose.vit_sam import CPnetBioImageIO
from cellpose.vit import CPnetBioImageIO

from bioimageio.spec.model.v0_5 import (
ArchitectureFromFileDescr,
Expand Down
84 changes: 44 additions & 40 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def __init__(self, image=None, logger=None):
"learning_rate": 1e-5,
"weight_decay": 0.1,
"n_epochs": 100,
"model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"),
"model_name": "cp4" + d.strftime("_%Y%m%d_%H%M%S"),
}

self.stitch_threshold = 0.
Expand Down Expand Up @@ -496,21 +496,32 @@ def make_buttons(self):
)
self.useGPU.setFont(self.medfont)
self.check_gpu()
self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3)

# compute segmentation with general models
self.net_text = ["run CPSAM"]
nett = ["cellpose super-generalist model"]

self.StyleButtons = []
jj = 4
for j in range(len(self.net_text)):
self.StyleButtons.append(
guiparts.ModelButton(self, self.net_text[j], self.net_text[j]))
w = 5
self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w)
jj += w
self.StyleButtons[-1].setToolTip(nett[j])
self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3)

self.progress = QProgressBar(self)
self.segBoxG.addWidget(self.progress, widget_row, 3, 1, 5)

# compute segmentation with built-in models
widget_row += 1
self.ModelChooseB = QComboBox()
self.ModelChooseB.setFont(self.medfont)
current_index = 0
self.ModelChooseB.addItems(models.MODEL_NAMES)
self.ModelChooseB.setFixedWidth(175)
self.ModelChooseB.setCurrentIndex(current_index)
tipstr = 'built-in models'
self.ModelChooseB.setToolTip(tipstr)
self.ModelChooseB.activated.connect(lambda: self.model_choose(custom=False))
self.segBoxG.addWidget(self.ModelChooseB, widget_row, 0, 1, 8)

# compute segmentation w/ custom model
self.ModelButtonB = QPushButton(u"run")
self.ModelButtonB.setFont(self.medfont)
self.ModelButtonB.setFixedWidth(35)
self.ModelButtonB.clicked.connect(
lambda: self.compute_segmentation(custom=False))
self.segBoxG.addWidget(self.ModelButtonB, widget_row, 8, 1, 1)
self.ModelButtonB.setEnabled(False)

widget_row += 1
self.ncells = guiparts.ObservableVariable(0)
Expand All @@ -521,10 +532,7 @@ def make_buttons(self):
lambda n: self.roi_count.setText(f'{str(n)} ROIs')
)

self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4)

self.progress = QProgressBar(self)
self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5)
self.segBoxG.addWidget(self.roi_count, widget_row, 3, 1, 4)

widget_row += 1

Expand Down Expand Up @@ -786,15 +794,13 @@ def check_gpu(self, torch=True):


def model_choose(self, custom=False):
index = self.ModelChooseC.currentIndex(
) if custom else self.ModelChooseB.currentIndex()
if index > 0:
if custom:
model_name = self.ModelChooseC.currentText()
else:
model_name = self.net_names[index - 1]
print(f"GUI_INFO: selected model {model_name}, loading now")
self.initialize_model(model_name=model_name, custom=custom)
if custom:
model_name = self.ModelChooseC.currentText()
else:
model_name = self.ModelChooseB.currentText()
print(f"GUI_INFO: selected model {model_name}")
# avoid double-loading model unless we need to?
# self.initialize_model(model_name=model_name, custom=custom)

def toggle_scale(self):
if self.scale_on:
Expand All @@ -805,11 +811,10 @@ def toggle_scale(self):
self.scale_on = True

def enable_buttons(self):
self.ModelButtonB.setEnabled(True)
if len(self.model_strings) > 0:
self.ModelButtonC.setEnabled(True)
for i in range(len(self.StyleButtons)):
self.StyleButtons[i].setEnabled(True)


for i in range(len(self.FilterButtons)):
self.FilterButtons[i].setEnabled(True)
if self.load_3D:
Expand Down Expand Up @@ -1889,8 +1894,8 @@ def get_model_path(self, custom=False):
self.current_model_path = os.fspath(
models.MODEL_DIR.joinpath(self.current_model))
else:
self.current_model = "cpsam"
self.current_model_path = models.model_path(self.current_model)
self.current_model = self.ModelChooseB.currentText()
self.current_model_path = models.cache_model_path(self.current_model)

def initialize_model(self, model_name=None, custom=False):
if model_name is None or custom:
Expand All @@ -1907,7 +1912,7 @@ def initialize_model(self, model_name=None, custom=False):
models.MODEL_DIR.joinpath(self.current_model))

self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
pretrained_model=self.current_model)
pretrained_model=self.current_model_path)

def add_model(self):
io._add_model(self)
Expand All @@ -1926,7 +1931,8 @@ def new_model(self):
image_names = self.get_files()[0]
self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set(
image_names)
TW = guiparts.TrainWindow(self, models.MODEL_NAMES)
self.training_params["model_index"] = self.ModelChooseB.currentIndex()
TW = guiparts.TrainWindow(self)
train = TW.exec_()
if train:
self.logger.info(
Expand All @@ -1944,7 +1950,7 @@ def train_model(self, restore=None, normalize_params=None):
self.current_model = model_type

self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
model_type=model_type)
pretrained_model=model_type)
save_path = os.path.dirname(self.filename)

print("GUI_INFO: name of new model: " + self.training_params["model_name"])
Expand Down Expand Up @@ -2048,9 +2054,7 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
print(normalize_params)
try:
masks, flows = self.model.eval(
data,
diameter=diameter,
cellprob_threshold=cellprob_threshold,
data, diameter=diameter, cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
normalize=normalize_params, stitch_threshold=stitch_threshold,
anisotropy=anisotropy, flow3D_smooth=flow3D_smooth,
Expand Down
44 changes: 3 additions & 41 deletions cellpose/gui/guiparts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import pathlib, os

from cellpose import models
from cellpose.gui.io import _save_sets


Expand Down Expand Up @@ -111,32 +112,6 @@ def setup(self):
)


# def create_channel_choose():
# # choose channel
# ChannelChoose = [QComboBox(), QComboBox()]
# ChannelLabels = []
# ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
# ChannelChoose[1].addItems(["none", "red", "green", "blue"])
# cstr = ["chan to segment:", "chan2 (optional): "]
# for i in range(2):
# ChannelLabels.append(QLabel(cstr[i]))
# if i == 0:
# ChannelLabels[i].setToolTip(
# "this is the channel in which the cytoplasm or nuclei exist \
# that you want to segment")
# ChannelChoose[i].setToolTip(
# "this is the channel in which the cytoplasm or nuclei exist \
# that you want to segment")
# else:
# ChannelLabels[i].setToolTip(
# "if <em>cytoplasm</em> model is chosen, and you also have a \
# nuclear channel, then choose the nuclear channel for this option")
# ChannelChoose[i].setToolTip(
# "if <em>cytoplasm</em> model is chosen, and you also have a \
# nuclear channel, then choose the nuclear channel for this option")

# return ChannelChoose, ChannelLabels

def unsilence_exceptions(func):
""" Wrapper to unsilence Qt exceptions and re-raise them """
def wrapper(*args, **kwargs):
Expand All @@ -148,19 +123,6 @@ def wrapper(*args, **kwargs):
logger.debug(''.join(traceback.format_exception(type(e), e, e.__traceback__)))
return wrapper

class ModelButton(QPushButton):

def __init__(self, parent, model_name, text):
super().__init__()
self.setEnabled(False)
self.setText(text)
self.setFont(parent.boldfont)
self.clicked.connect(lambda: self.press(parent))
self.model_name = "cpsam"

def press(self, parent):
parent.compute_segmentation(model_name="cpsam")


class FilterButton(QPushButton):

Expand Down Expand Up @@ -428,7 +390,7 @@ def niter(self):

class TrainWindow(QDialog):

def __init__(self, parent, model_strings):
def __init__(self, parent):
super().__init__(parent)
self.setGeometry(100, 100, 900, 550)
self.setWindowTitle("train settings")
Expand All @@ -446,7 +408,7 @@ def __init__(self, parent, model_strings):
# choose initial model
yoff += 1
self.ModelChoose = QComboBox()
self.ModelChoose.addItems(model_strings)
self.ModelChoose.addItems(models.MODEL_NAMES)
self.ModelChoose.setFixedWidth(150)
self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
Expand Down
Loading
Loading