Skip to content

Commit 4ea76eb

Browse files
Merge pull request #1454 from MouseLand/dino
adding dino models and gpu-accelerated augmentations
2 parents c69c4ae + be5047b commit 4ea76eb

28 files changed

Lines changed: 3401 additions & 2642 deletions

.github/workflows/test_and_deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
fail-fast: true
2323
matrix:
2424
platform: [ubuntu-latest, windows-latest, macos-latest]
25-
python-version: ["3.10", "3.11"]
25+
python-version: ["3.11", "3.12"]
2626

2727
steps:
2828
- uses: actions/checkout@v4

cellpose/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def get_arg_parser():
6565

6666
# model settings
6767
model_args = parser.add_argument_group("Model Arguments")
68-
model_args.add_argument("--pretrained_model", required=False, default="cpsam",
68+
model_args.add_argument("--pretrained_model", required=False, default="cpdino",
6969
type=str,
70-
help="model to use for running or starting training")
70+
help="path to model for segmentation or starting training, or builtin model name: cpdino, cpsam_v2, cpdino-vitb, or cpsam")
7171
model_args.add_argument(
7272
"--add_model", required=False, default=None, type=str,
7373
help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")

cellpose/core.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,8 @@ def run_3D(net, imgs, batch_size=8, augment=False,
292292
# per image
293293
core_logger.info("running %s: %d planes of size (%d, %d)" %
294294
(sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
295-
y, style = run_net(net,
296-
xsl, batch_size=batch_size, augment=augment,
297-
bsize=bsize, tile_overlap=tile_overlap,
298-
rsz=None)
295+
y, style = run_net(net, xsl, batch_size=batch_size, augment=augment,
296+
bsize=bsize, tile_overlap=tile_overlap, rsz=None)
299297
yf[..., -1] += y[..., -1].transpose(ipm[p])
300298
for j in range(2):
301299
yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])

cellpose/denoise.py

Lines changed: 2 additions & 1288 deletions
Large diffs are not rendered by default.

cellpose/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from cellpose.io import imread
5252
from cellpose.utils import download_url_to_file
5353
from cellpose.transforms import pad_image_ND, normalize_img, convert_image
54-
from cellpose.vit_sam import CPnetBioImageIO
54+
from cellpose.vit import CPnetBioImageIO
5555

5656
from bioimageio.spec.model.v0_5 import (
5757
ArchitectureFromFileDescr,

cellpose/gui/gui.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def __init__(self, image=None, logger=None):
290290
"learning_rate": 1e-5,
291291
"weight_decay": 0.1,
292292
"n_epochs": 100,
293-
"model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"),
293+
"model_name": "cp4" + d.strftime("_%Y%m%d_%H%M%S"),
294294
}
295295

296296
self.stitch_threshold = 0.
@@ -496,21 +496,32 @@ def make_buttons(self):
496496
)
497497
self.useGPU.setFont(self.medfont)
498498
self.check_gpu()
499-
self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3)
500-
501-
# compute segmentation with general models
502-
self.net_text = ["run CPSAM"]
503-
nett = ["cellpose super-generalist model"]
504-
505-
self.StyleButtons = []
506-
jj = 4
507-
for j in range(len(self.net_text)):
508-
self.StyleButtons.append(
509-
guiparts.ModelButton(self, self.net_text[j], self.net_text[j]))
510-
w = 5
511-
self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w)
512-
jj += w
513-
self.StyleButtons[-1].setToolTip(nett[j])
499+
self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3)
500+
501+
self.progress = QProgressBar(self)
502+
self.segBoxG.addWidget(self.progress, widget_row, 3, 1, 5)
503+
504+
# compute segmentation with built-in models
505+
widget_row += 1
506+
self.ModelChooseB = QComboBox()
507+
self.ModelChooseB.setFont(self.medfont)
508+
current_index = 0
509+
self.ModelChooseB.addItems(models.MODEL_NAMES)
510+
self.ModelChooseB.setFixedWidth(175)
511+
self.ModelChooseB.setCurrentIndex(current_index)
512+
tipstr = 'built-in models'
513+
self.ModelChooseB.setToolTip(tipstr)
514+
self.ModelChooseB.activated.connect(lambda: self.model_choose(custom=False))
515+
self.segBoxG.addWidget(self.ModelChooseB, widget_row, 0, 1, 8)
516+
517+
# compute segmentation w/ custom model
518+
self.ModelButtonB = QPushButton(u"run")
519+
self.ModelButtonB.setFont(self.medfont)
520+
self.ModelButtonB.setFixedWidth(35)
521+
self.ModelButtonB.clicked.connect(
522+
lambda: self.compute_segmentation(custom=False))
523+
self.segBoxG.addWidget(self.ModelButtonB, widget_row, 8, 1, 1)
524+
self.ModelButtonB.setEnabled(False)
514525

515526
widget_row += 1
516527
self.ncells = guiparts.ObservableVariable(0)
@@ -521,10 +532,7 @@ def make_buttons(self):
521532
lambda n: self.roi_count.setText(f'{str(n)} ROIs')
522533
)
523534

524-
self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4)
525-
526-
self.progress = QProgressBar(self)
527-
self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5)
535+
self.segBoxG.addWidget(self.roi_count, widget_row, 3, 1, 4)
528536

529537
widget_row += 1
530538

@@ -786,15 +794,13 @@ def check_gpu(self, torch=True):
786794

787795

788796
def model_choose(self, custom=False):
789-
index = self.ModelChooseC.currentIndex(
790-
) if custom else self.ModelChooseB.currentIndex()
791-
if index > 0:
792-
if custom:
793-
model_name = self.ModelChooseC.currentText()
794-
else:
795-
model_name = self.net_names[index - 1]
796-
print(f"GUI_INFO: selected model {model_name}, loading now")
797-
self.initialize_model(model_name=model_name, custom=custom)
797+
if custom:
798+
model_name = self.ModelChooseC.currentText()
799+
else:
800+
model_name = self.ModelChooseB.currentText()
801+
print(f"GUI_INFO: selected model {model_name}")
802+
# avoid double-loading model unless we need to?
803+
# self.initialize_model(model_name=model_name, custom=custom)
798804

799805
def toggle_scale(self):
800806
if self.scale_on:
@@ -805,11 +811,10 @@ def toggle_scale(self):
805811
self.scale_on = True
806812

807813
def enable_buttons(self):
814+
self.ModelButtonB.setEnabled(True)
808815
if len(self.model_strings) > 0:
809816
self.ModelButtonC.setEnabled(True)
810-
for i in range(len(self.StyleButtons)):
811-
self.StyleButtons[i].setEnabled(True)
812-
817+
813818
for i in range(len(self.FilterButtons)):
814819
self.FilterButtons[i].setEnabled(True)
815820
if self.load_3D:
@@ -1889,8 +1894,8 @@ def get_model_path(self, custom=False):
18891894
self.current_model_path = os.fspath(
18901895
models.MODEL_DIR.joinpath(self.current_model))
18911896
else:
1892-
self.current_model = "cpsam"
1893-
self.current_model_path = models.model_path(self.current_model)
1897+
self.current_model = self.ModelChooseB.currentText()
1898+
self.current_model_path = models.cache_model_path(self.current_model)
18941899

18951900
def initialize_model(self, model_name=None, custom=False):
18961901
if model_name is None or custom:
@@ -1907,7 +1912,7 @@ def initialize_model(self, model_name=None, custom=False):
19071912
models.MODEL_DIR.joinpath(self.current_model))
19081913

19091914
self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
1910-
pretrained_model=self.current_model)
1915+
pretrained_model=self.current_model_path)
19111916

19121917
def add_model(self):
19131918
io._add_model(self)
@@ -1926,7 +1931,8 @@ def new_model(self):
19261931
image_names = self.get_files()[0]
19271932
self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set(
19281933
image_names)
1929-
TW = guiparts.TrainWindow(self, models.MODEL_NAMES)
1934+
self.training_params["model_index"] = self.ModelChooseB.currentIndex()
1935+
TW = guiparts.TrainWindow(self)
19301936
train = TW.exec_()
19311937
if train:
19321938
self.logger.info(
@@ -1944,7 +1950,7 @@ def train_model(self, restore=None, normalize_params=None):
19441950
self.current_model = model_type
19451951

19461952
self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
1947-
model_type=model_type)
1953+
pretrained_model=model_type)
19481954
save_path = os.path.dirname(self.filename)
19491955

19501956
print("GUI_INFO: name of new model: " + self.training_params["model_name"])
@@ -2048,9 +2054,7 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
20482054
print(normalize_params)
20492055
try:
20502056
masks, flows = self.model.eval(
2051-
data,
2052-
diameter=diameter,
2053-
cellprob_threshold=cellprob_threshold,
2057+
data, diameter=diameter, cellprob_threshold=cellprob_threshold,
20542058
flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
20552059
normalize=normalize_params, stitch_threshold=stitch_threshold,
20562060
anisotropy=anisotropy, flow3D_smooth=flow3D_smooth,

cellpose/gui/guiparts.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
import pathlib, os
1212

13+
from cellpose import models
1314
from cellpose.gui.io import _save_sets
1415

1516

@@ -111,32 +112,6 @@ def setup(self):
111112
)
112113

113114

114-
# def create_channel_choose():
115-
# # choose channel
116-
# ChannelChoose = [QComboBox(), QComboBox()]
117-
# ChannelLabels = []
118-
# ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
119-
# ChannelChoose[1].addItems(["none", "red", "green", "blue"])
120-
# cstr = ["chan to segment:", "chan2 (optional): "]
121-
# for i in range(2):
122-
# ChannelLabels.append(QLabel(cstr[i]))
123-
# if i == 0:
124-
# ChannelLabels[i].setToolTip(
125-
# "this is the channel in which the cytoplasm or nuclei exist \
126-
# that you want to segment")
127-
# ChannelChoose[i].setToolTip(
128-
# "this is the channel in which the cytoplasm or nuclei exist \
129-
# that you want to segment")
130-
# else:
131-
# ChannelLabels[i].setToolTip(
132-
# "if <em>cytoplasm</em> model is chosen, and you also have a \
133-
# nuclear channel, then choose the nuclear channel for this option")
134-
# ChannelChoose[i].setToolTip(
135-
# "if <em>cytoplasm</em> model is chosen, and you also have a \
136-
# nuclear channel, then choose the nuclear channel for this option")
137-
138-
# return ChannelChoose, ChannelLabels
139-
140115
def unsilence_exceptions(func):
141116
""" Wrapper to unsilence Qt exceptions and re-raise them """
142117
def wrapper(*args, **kwargs):
@@ -148,19 +123,6 @@ def wrapper(*args, **kwargs):
148123
logger.debug(''.join(traceback.format_exception(type(e), e, e.__traceback__)))
149124
return wrapper
150125

151-
class ModelButton(QPushButton):
152-
153-
def __init__(self, parent, model_name, text):
154-
super().__init__()
155-
self.setEnabled(False)
156-
self.setText(text)
157-
self.setFont(parent.boldfont)
158-
self.clicked.connect(lambda: self.press(parent))
159-
self.model_name = "cpsam"
160-
161-
def press(self, parent):
162-
parent.compute_segmentation(model_name="cpsam")
163-
164126

165127
class FilterButton(QPushButton):
166128

@@ -428,7 +390,7 @@ def niter(self):
428390

429391
class TrainWindow(QDialog):
430392

431-
def __init__(self, parent, model_strings):
393+
def __init__(self, parent):
432394
super().__init__(parent)
433395
self.setGeometry(100, 100, 900, 550)
434396
self.setWindowTitle("train settings")
@@ -446,7 +408,7 @@ def __init__(self, parent, model_strings):
446408
# choose initial model
447409
yoff += 1
448410
self.ModelChoose = QComboBox()
449-
self.ModelChoose.addItems(model_strings)
411+
self.ModelChoose.addItems(models.MODEL_NAMES)
450412
self.ModelChoose.setFixedWidth(150)
451413
self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
452414
self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)

0 commit comments

Comments
 (0)